Skip to content

Commit da769cb

Browse files
udpate database via messages
1 parent c22d977 commit da769cb

File tree

3 files changed

+99
-76
lines changed

3 files changed

+99
-76
lines changed

crates/pg_lsp/src/client.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ impl LspClient {
5050
Ok(())
5151
}
5252

53+
/// This will ignore any errors that occur while sending the notification.
54+
pub fn send_info_notification(&self, message: &str) {
55+
let _ = self.send_notification::<ShowMessage>(ShowMessageParams {
56+
message: message.into(),
57+
typ: MessageType::INFO,
58+
});
59+
}
60+
5361
pub fn send_request<R>(&self, params: R::Params) -> Result<R::Result>
5462
where
5563
R: lsp_types::request::Request,

crates/pg_lsp/src/db_connection.rs

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use pg_schema_cache::SchemaCache;
22
use sqlx::{postgres::PgListener, PgPool};
3+
use tokio::task::JoinHandle;
34

45
#[derive(Debug)]
56
pub(crate) struct DbConnection {
67
pub pool: PgPool,
78
connection_string: String,
9+
schema_update_handle: Option<JoinHandle<()>>,
810
}
911

1012
impl DbConnection {
@@ -13,49 +15,52 @@ impl DbConnection {
1315
Ok(Self {
1416
pool,
1517
connection_string: connection_string,
18+
schema_update_handle: None,
1619
})
1720
}
1821

19-
pub(crate) async fn refresh_db_connection(
20-
self,
21-
connection_string: Option<String>,
22-
) -> anyhow::Result<Self> {
23-
if connection_string.is_none()
24-
|| connection_string.as_ref() == Some(&self.connection_string)
25-
{
26-
return Ok(self);
27-
}
22+
pub(crate) fn connected_to(&self, connection_string: &str) -> bool {
23+
connection_string == self.connection_string
24+
}
2825

26+
pub(crate) async fn close(self) {
27+
if self.schema_update_handle.is_some() {
28+
self.schema_update_handle.unwrap().abort();
29+
}
2930
self.pool.close().await;
30-
31-
let conn = DbConnection::new(connection_string.unwrap()).await?;
32-
33-
Ok(conn)
3431
}
3532

36-
pub(crate) async fn start_listening<F>(&self, on_schema_update: F) -> anyhow::Result<()>
33+
pub(crate) async fn listen_for_schema_updates<F>(
34+
&mut self,
35+
on_schema_update: F,
36+
) -> anyhow::Result<()>
3737
where
38-
F: Fn() -> () + Send + 'static,
38+
F: Fn(SchemaCache) -> () + Send + 'static,
3939
{
4040
let mut listener = PgListener::connect_with(&self.pool).await?;
4141
listener.listen_all(["postgres_lsp", "pgrst"]).await?;
4242

43-
loop {
44-
match listener.recv().await {
45-
Ok(notification) => {
46-
if notification.payload().to_string() == "reload schema" {
47-
on_schema_update();
43+
let pool = self.pool.clone();
44+
45+
let handle: JoinHandle<()> = tokio::spawn(async move {
46+
loop {
47+
match listener.recv().await {
48+
Ok(not) => {
49+
if not.payload().to_string() == "reload schema" {
50+
let schema_cache = SchemaCache::load(&pool).await;
51+
on_schema_update(schema_cache);
52+
};
53+
}
54+
Err(why) => {
55+
eprintln!("Error receiving notification: {:?}", why);
56+
break;
4857
}
49-
}
50-
Err(e) => {
51-
eprintln!("Listener error: {}", e);
52-
return Err(e.into());
5358
}
5459
}
55-
}
56-
}
60+
});
61+
62+
self.schema_update_handle = Some(handle);
5763

58-
pub(crate) async fn get_schema_cache(&self) -> SchemaCache {
59-
SchemaCache::load(&self.pool).await
64+
Ok(())
6065
}
6166
}

crates/pg_lsp/src/server.rs

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,14 @@ use crate::{
4040
};
4141

4242
use self::{debouncer::EventDebouncer, options::Options};
43-
use sqlx::{
44-
postgres::{PgListener, PgPool},
45-
Executor,
46-
};
43+
use sqlx::{postgres::PgPool, Executor};
4744

4845
#[derive(Debug)]
4946
enum InternalMessage {
5047
PublishDiagnostics(lsp_types::Url),
5148
SetOptions(Options),
52-
RefreshSchemaCache,
5349
SetSchemaCache(SchemaCache),
50+
SetDatabaseConnection(DbConnection),
5451
}
5552

5653
/// `lsp-servers` `Connection` type uses a crossbeam channel, which is not compatible with tokio's async runtime.
@@ -210,29 +207,54 @@ impl Server {
210207
});
211208
}
212209

213-
async fn update_options(&mut self, options: Options) -> anyhow::Result<()> {
214-
if options.db_connection_string.is_none() {
210+
fn update_db_connection(&self, options: Options) -> anyhow::Result<()> {
211+
if options.db_connection_string.is_none()
212+
|| self
213+
.db_conn
214+
.as_ref()
215+
.is_some_and(|c| c.connected_to(options.db_connection_string.as_ref().unwrap()))
216+
{
215217
return Ok(());
216218
}
217219

218-
let new_conn = if self.db_conn.is_none() {
219-
DbConnection::new(options.db_connection_string.clone().unwrap()).await?
220-
} else {
221-
let current_conn = self.db_conn.take().unwrap();
222-
current_conn
223-
.refresh_db_connection(options.db_connection_string.clone())
224-
.await?
225-
};
220+
let connection_string = options.db_connection_string.unwrap();
226221

227222
let internal_tx = self.internal_tx.clone();
223+
let client = self.client.clone();
228224
self.spawn_with_cancel(async move {
229-
new_conn.start_listening(move || {
225+
match DbConnection::new(connection_string.into()).await {
226+
Ok(conn) => {
227+
internal_tx
228+
.send(InternalMessage::SetDatabaseConnection(conn))
229+
.unwrap();
230+
}
231+
Err(why) => {
232+
client.send_info_notification(&format!("Unable to update database connection: {}", why));
233+
234+
}
235+
}
236+
});
237+
238+
Ok(())
239+
}
240+
241+
async fn listen_for_schema_updates(&mut self) -> anyhow::Result<()> {
242+
if self.db_conn.is_none() {
243+
eprintln!("Error trying to listen for schema updates: No database connection");
244+
return Ok(());
245+
}
246+
247+
let internal_tx = self.internal_tx.clone();
248+
self.db_conn
249+
.as_mut()
250+
.unwrap()
251+
.listen_for_schema_updates(move |schema_cache| {
230252
internal_tx
231-
.send(InternalMessage::RefreshSchemaCache)
253+
.send(InternalMessage::SetSchemaCache(schema_cache))
232254
.unwrap();
233-
// TODO: handle result
234-
}).await.unwrap()
235-
});
255+
// TODO: handle result
256+
})
257+
.await?;
236258

237259
Ok(())
238260
}
@@ -692,26 +714,6 @@ impl Server {
692714
});
693715
}
694716

695-
async fn refresh_schema_cache(&self) {
696-
if self.db_conn.is_none() {
697-
return;
698-
}
699-
700-
let tx = self.internal_tx.clone();
701-
let conn = self.db_conn.as_ref().unwrap().pool.clone();
702-
let client = self.client.clone();
703-
704-
client
705-
.send_notification::<ShowMessage>(ShowMessageParams {
706-
typ: lsp_types::MessageType::INFO,
707-
message: "Refreshing schema cache...".to_string(),
708-
})
709-
.unwrap();
710-
let schema_cache = SchemaCache::load(&conn).await;
711-
tx.send(InternalMessage::SetSchemaCache(schema_cache))
712-
.unwrap();
713-
}
714-
715717
fn did_change_configuration(
716718
&mut self,
717719
params: DidChangeConfigurationParams,
@@ -720,7 +722,7 @@ impl Server {
720722
self.pull_options();
721723
} else {
722724
let options = self.client.parse_options(params.settings)?;
723-
self.update_options(options);
725+
self.update_db_connection(options);
724726
}
725727

726728
Ok(())
@@ -744,14 +746,14 @@ impl Server {
744746
msg = self.client_rx.recv() => {
745747
match msg {
746748
None => panic!("The LSP's client closed, but not via an 'exit' method. This should never happen."),
747-
Some(m) => self.handle_message(m)
749+
Some(m) => self.handle_message(m).await
748750
}
749751
},
750752
}?;
751753
}
752754
}
753755

754-
fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> {
756+
async fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> {
755757
match msg {
756758
Message::Request(request) => {
757759
if let Some(response) = dispatch::RequestDispatcher::new(request)
@@ -768,7 +770,8 @@ impl Server {
768770
Message::Notification(notification) => {
769771
dispatch::NotificationDispatcher::new(notification)
770772
.on::<DidChangeConfiguration, _>(|params| {
771-
self.did_change_configuration(params)
773+
self.did_change_configuration(params);
774+
Ok(())
772775
})?
773776
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
774777
.on::<DidOpenTextDocument, _>(|params| self.did_open(params))?
@@ -788,17 +791,24 @@ impl Server {
788791
async fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> {
789792
match msg {
790793
InternalMessage::SetSchemaCache(c) => {
794+
self.client
795+
.send_info_notification("Refreshing Schema Cache...");
791796
self.ide.set_schema_cache(c);
797+
self.client.send_info_notification("Updated Schema Cache.");
792798
self.compute_now();
793799
}
794-
InternalMessage::RefreshSchemaCache => {
795-
self.refresh_schema_cache().await;
796-
}
797800
InternalMessage::PublishDiagnostics(uri) => {
798801
self.publish_diagnostics(uri)?;
799802
}
800803
InternalMessage::SetOptions(options) => {
801-
self.update_options(options);
804+
self.update_db_connection(options);
805+
}
806+
InternalMessage::SetDatabaseConnection(conn) => {
807+
let current = self.db_conn.replace(conn);
808+
if current.is_some() {
809+
current.unwrap().close().await
810+
}
811+
self.listen_for_schema_updates();
802812
}
803813
}
804814

0 commit comments

Comments
 (0)