about summary refs log tree commit diff
path: root/lib
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2024-01-01 14:06:56 +0100
committerLukas Wirth <lukastw97@gmail.com>2024-01-01 14:10:46 +0100
commit3c8dd9e89ef8295241187ddc092d9b675b7db48d (patch)
tree7c60c182a332bee45a5a1d58fcfb13a414829892 /lib
parent06be1b1f34ec358594e6d6493876eb7af22c8b47 (diff)
downloadrust-3c8dd9e89ef8295241187ddc092d9b675b7db48d.tar.gz
rust-3c8dd9e89ef8295241187ddc092d9b675b7db48d.zip
Expose whether a channel has been dropped in lsp-server errors
Diffstat (limited to 'lib')
-rw-r--r--lib/lsp-server/Cargo.toml4
-rw-r--r--lib/lsp-server/examples/goto_def.rs10
-rw-r--r--lib/lsp-server/src/error.rs17
-rw-r--r--lib/lsp-server/src/lib.rs49
-rw-r--r--lib/lsp-server/src/msg.rs4
-rw-r--r--lib/lsp-server/src/stdio.rs3
6 files changed, 56 insertions, 31 deletions
diff --git a/lib/lsp-server/Cargo.toml b/lib/lsp-server/Cargo.toml
index e802bf185b3..116b376b0b0 100644
--- a/lib/lsp-server/Cargo.toml
+++ b/lib/lsp-server/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "lsp-server"
-version = "0.7.5"
+version = "0.7.6"
 description = "Generic LSP server scaffold."
 license = "MIT OR Apache-2.0"
 repository = "https://github.com/rust-lang/rust-analyzer/tree/master/lib/lsp-server"
@@ -10,7 +10,7 @@ edition = "2021"
 log = "0.4.17"
 serde_json = "1.0.108"
 serde = { version = "1.0.192", features = ["derive"] }
-crossbeam-channel = "0.5.6"
+crossbeam-channel = "0.5.8"
 
 [dev-dependencies]
 lsp-types = "=0.95"
diff --git a/lib/lsp-server/examples/goto_def.rs b/lib/lsp-server/examples/goto_def.rs
index 2f270afbbf1..71f66254069 100644
--- a/lib/lsp-server/examples/goto_def.rs
+++ b/lib/lsp-server/examples/goto_def.rs
@@ -64,7 +64,15 @@ fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
         ..Default::default()
     })
     .unwrap();
-    let initialization_params = connection.initialize(server_capabilities)?;
+    let initialization_params = match connection.initialize(server_capabilities) {
+        Ok(it) => it,
+        Err(e) => {
+            if e.channel_is_disconnected() {
+                io_threads.join()?;
+            }
+            return Err(e.into());
+        }
+    };
     main_loop(connection, initialization_params)?;
     io_threads.join()?;
 
diff --git a/lib/lsp-server/src/error.rs b/lib/lsp-server/src/error.rs
index 755b3fd9596..ebdd153b5b3 100644
--- a/lib/lsp-server/src/error.rs
+++ b/lib/lsp-server/src/error.rs
@@ -3,7 +3,22 @@ use std::fmt;
 use crate::{Notification, Request};
 
 #[derive(Debug, Clone, PartialEq)]
-pub struct ProtocolError(pub(crate) String);
+pub struct ProtocolError(String, bool);
+
+impl ProtocolError {
+    pub(crate) fn new(msg: impl Into<String>) -> Self {
+        ProtocolError(msg.into(), false)
+    }
+
+    pub(crate) fn disconnected() -> ProtocolError {
+        ProtocolError("disconnected channel".into(), true)
+    }
+
+    /// Whether this error occured due to a disconnected channel.
+    pub fn channel_is_disconnected(&self) -> bool {
+        self.1
+    }
+}
 
 impl std::error::Error for ProtocolError {}
 
diff --git a/lib/lsp-server/src/lib.rs b/lib/lsp-server/src/lib.rs
index 2797a6b60de..6b732d47029 100644
--- a/lib/lsp-server/src/lib.rs
+++ b/lib/lsp-server/src/lib.rs
@@ -17,7 +17,7 @@ use std::{
     net::{TcpListener, TcpStream, ToSocketAddrs},
 };
 
-use crossbeam_channel::{Receiver, RecvTimeoutError, Sender};
+use crossbeam_channel::{Receiver, RecvError, RecvTimeoutError, Sender};
 
 pub use crate::{
     error::{ExtractError, ProtocolError},
@@ -158,11 +158,7 @@ impl Connection {
                 Err(RecvTimeoutError::Timeout) => {
                     continue;
                 }
-                Err(e) => {
-                    return Err(ProtocolError(format!(
-                        "expected initialize request, got error: {e}"
-                    )))
-                }
+                Err(RecvTimeoutError::Disconnected) => return Err(ProtocolError::disconnected()),
             };
 
             match msg {
@@ -181,12 +177,14 @@ impl Connection {
                     continue;
                 }
                 msg => {
-                    return Err(ProtocolError(format!("expected initialize request, got {msg:?}")));
+                    return Err(ProtocolError::new(format!(
+                        "expected initialize request, got {msg:?}"
+                    )));
                 }
             };
         }
 
-        return Err(ProtocolError(String::from(
+        return Err(ProtocolError::new(String::from(
             "Initialization has been aborted during initialization",
         )));
     }
@@ -201,12 +199,10 @@ impl Connection {
         self.sender.send(resp.into()).unwrap();
         match &self.receiver.recv() {
             Ok(Message::Notification(n)) if n.is_initialized() => Ok(()),
-            Ok(msg) => {
-                Err(ProtocolError(format!(r#"expected initialized notification, got: {msg:?}"#)))
-            }
-            Err(e) => {
-                Err(ProtocolError(format!("expected initialized notification, got error: {e}",)))
-            }
+            Ok(msg) => Err(ProtocolError::new(format!(
+                r#"expected initialized notification, got: {msg:?}"#
+            ))),
+            Err(RecvError) => Err(ProtocolError::disconnected()),
         }
     }
 
@@ -231,10 +227,8 @@ impl Connection {
                 Err(RecvTimeoutError::Timeout) => {
                     continue;
                 }
-                Err(e) => {
-                    return Err(ProtocolError(format!(
-                        "expected initialized notification, got error: {e}",
-                    )));
+                Err(RecvTimeoutError::Disconnected) => {
+                    return Err(ProtocolError::disconnected());
                 }
             };
 
@@ -243,14 +237,14 @@ impl Connection {
                     return Ok(());
                 }
                 msg => {
-                    return Err(ProtocolError(format!(
+                    return Err(ProtocolError::new(format!(
                         r#"expected initialized notification, got: {msg:?}"#
                     )));
                 }
             }
         }
 
-        return Err(ProtocolError(String::from(
+        return Err(ProtocolError::new(String::from(
             "Initialization has been aborted during initialization",
         )));
     }
@@ -359,9 +353,18 @@ impl Connection {
         match &self.receiver.recv_timeout(std::time::Duration::from_secs(30)) {
             Ok(Message::Notification(n)) if n.is_exit() => (),
             Ok(msg) => {
-                return Err(ProtocolError(format!("unexpected message during shutdown: {msg:?}")))
+                return Err(ProtocolError::new(format!(
+                    "unexpected message during shutdown: {msg:?}"
+                )))
+            }
+            Err(RecvTimeoutError::Timeout) => {
+                return Err(ProtocolError::new(format!("timed out waiting for exit notification")))
+            }
+            Err(RecvTimeoutError::Disconnected) => {
+                return Err(ProtocolError::new(format!(
+                    "channel disconnected waiting for exit notification"
+                )))
             }
-            Err(e) => return Err(ProtocolError(format!("unexpected error during shutdown: {e}"))),
         }
         Ok(true)
     }
@@ -426,7 +429,7 @@ mod tests {
 
         initialize_start_test(TestCase {
             test_messages: vec![notification_msg.clone()],
-            expected_resp: Err(ProtocolError(format!(
+            expected_resp: Err(ProtocolError::new(format!(
                 "expected initialize request, got {:?}",
                 notification_msg
             ))),
diff --git a/lib/lsp-server/src/msg.rs b/lib/lsp-server/src/msg.rs
index 730ad51f424..ba318dd1690 100644
--- a/lib/lsp-server/src/msg.rs
+++ b/lib/lsp-server/src/msg.rs
@@ -264,12 +264,12 @@ fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> {
         let mut parts = buf.splitn(2, ": ");
         let header_name = parts.next().unwrap();
         let header_value =
-            parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?;
+            parts.next().ok_or_else(|| invalid_data(format!("malformed header: {:?}", buf)))?;
         if header_name.eq_ignore_ascii_case("Content-Length") {
             size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
         }
     }
-    let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?;
+    let size: usize = size.ok_or_else(|| invalid_data("no Content-Length".to_string()))?;
     let mut buf = buf.into_bytes();
     buf.resize(size, 0);
     inp.read_exact(&mut buf)?;
diff --git a/lib/lsp-server/src/stdio.rs b/lib/lsp-server/src/stdio.rs
index e487b9b4622..cea199d0293 100644
--- a/lib/lsp-server/src/stdio.rs
+++ b/lib/lsp-server/src/stdio.rs
@@ -15,8 +15,7 @@ pub(crate) fn stdio_transport() -> (Sender<Message>, Receiver<Message>, IoThread
     let writer = thread::spawn(move || {
         let stdout = stdout();
         let mut stdout = stdout.lock();
-        writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))?;
-        Ok(())
+        writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))
     });
     let (reader_sender, reader_receiver) = bounded::<Message>(0);
     let reader = thread::spawn(move || {