about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-01-01 13:13:38 +0000
committerbors <bors@rust-lang.org>2024-01-01 13:13:38 +0000
commita8d935eedc80df8b453d90539cbe78b7e2c75e3c (patch)
tree7c60c182a332bee45a5a1d58fcfb13a414829892
parentaef441ab3a7bebf8207499becf4d0e258bcd2533 (diff)
parent3c8dd9e89ef8295241187ddc092d9b675b7db48d (diff)
downloadrust-a8d935eedc80df8b453d90539cbe78b7e2c75e3c.tar.gz
rust-a8d935eedc80df8b453d90539cbe78b7e2c75e3c.zip
Auto merge of #16226 - Veykril:lsp-server, r=Veykril
internal: Expose whether a channel has been dropped in lsp-server errors

Not the best way to expose this, but this should allow us to give somewhat better errors when the initialization request is malformed, as currently that just results in a channel disconnected error instead of the deserialization error. cc https://github.com/rust-lang/rust-analyzer/issues/15859
-rw-r--r--Cargo.lock14
-rw-r--r--Cargo.toml2
-rw-r--r--crates/rust-analyzer/src/bin/main.rs17
-rw-r--r--lib/lsp-server/Cargo.toml4
-rw-r--r--lib/lsp-server/examples/goto_def.rs10
-rw-r--r--lib/lsp-server/src/error.rs17
-rw-r--r--lib/lsp-server/src/lib.rs49
-rw-r--r--lib/lsp-server/src/msg.rs4
-rw-r--r--lib/lsp-server/src/stdio.rs3
9 files changed, 79 insertions, 41 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 7310ecc8582..c7d110eafb6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -945,24 +945,24 @@ checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
 
 [[package]]
 name = "lsp-server"
-version = "0.7.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b52dccdf3302eefab8c8a1273047f0a3c3dca4b527c8458d00c09484c8371928"
+version = "0.7.6"
 dependencies = [
  "crossbeam-channel",
+ "ctrlc",
  "log",
+ "lsp-types",
  "serde",
  "serde_json",
 ]
 
 [[package]]
 name = "lsp-server"
-version = "0.7.5"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "248f65b78f6db5d8e1b1604b4098a28b43d21a8eb1deeca22b1c421b276c7095"
 dependencies = [
  "crossbeam-channel",
- "ctrlc",
  "log",
- "lsp-types",
  "serde",
  "serde_json",
 ]
@@ -1526,7 +1526,7 @@ dependencies = [
  "ide-ssr",
  "itertools",
  "load-cargo",
- "lsp-server 0.7.4",
+ "lsp-server 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)",
  "lsp-types",
  "mbe",
  "mimalloc",
diff --git a/Cargo.toml b/Cargo.toml
index d4cff420bcb..e82a14d16e1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -88,7 +88,7 @@ test-utils = { path = "./crates/test-utils" }
 # In-tree crates that are published separately and follow semver. See lib/README.md
 line-index = { version = "0.1.1" }
 la-arena = { version = "0.3.1" }
-lsp-server = { version = "0.7.4" }
+lsp-server = { version = "0.7.6" }
 
 # non-local crates
 anyhow = "1.0.75"
diff --git a/crates/rust-analyzer/src/bin/main.rs b/crates/rust-analyzer/src/bin/main.rs
index 8472e49de98..6f40a4c88ed 100644
--- a/crates/rust-analyzer/src/bin/main.rs
+++ b/crates/rust-analyzer/src/bin/main.rs
@@ -172,7 +172,15 @@ fn run_server() -> anyhow::Result<()> {
 
     let (connection, io_threads) = Connection::stdio();
 
-    let (initialize_id, initialize_params) = connection.initialize_start()?;
+    let (initialize_id, initialize_params) = match connection.initialize_start() {
+        Ok(it) => it,
+        Err(e) => {
+            if e.channel_is_disconnected() {
+                io_threads.join()?;
+            }
+            return Err(e.into());
+        }
+    };
     tracing::info!("InitializeParams: {}", initialize_params);
     let lsp_types::InitializeParams {
         root_uri,
@@ -240,7 +248,12 @@ fn run_server() -> anyhow::Result<()> {
 
     let initialize_result = serde_json::to_value(initialize_result).unwrap();
 
-    connection.initialize_finish(initialize_id, initialize_result)?;
+    if let Err(e) = connection.initialize_finish(initialize_id, initialize_result) {
+        if e.channel_is_disconnected() {
+            io_threads.join()?;
+        }
+        return Err(e.into());
+    }
 
     if !config.has_linked_projects() && config.detached_files().is_empty() {
         config.rediscover_workspaces();
diff --git a/lib/lsp-server/Cargo.toml b/lib/lsp-server/Cargo.toml
index e802bf185b3..116b376b0b0 100644
--- a/lib/lsp-server/Cargo.toml
+++ b/lib/lsp-server/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "lsp-server"
-version = "0.7.5"
+version = "0.7.6"
 description = "Generic LSP server scaffold."
 license = "MIT OR Apache-2.0"
 repository = "https://github.com/rust-lang/rust-analyzer/tree/master/lib/lsp-server"
@@ -10,7 +10,7 @@ edition = "2021"
 log = "0.4.17"
 serde_json = "1.0.108"
 serde = { version = "1.0.192", features = ["derive"] }
-crossbeam-channel = "0.5.6"
+crossbeam-channel = "0.5.8"
 
 [dev-dependencies]
 lsp-types = "=0.95"
diff --git a/lib/lsp-server/examples/goto_def.rs b/lib/lsp-server/examples/goto_def.rs
index 2f270afbbf1..71f66254069 100644
--- a/lib/lsp-server/examples/goto_def.rs
+++ b/lib/lsp-server/examples/goto_def.rs
@@ -64,7 +64,15 @@ fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
         ..Default::default()
     })
     .unwrap();
-    let initialization_params = connection.initialize(server_capabilities)?;
+    let initialization_params = match connection.initialize(server_capabilities) {
+        Ok(it) => it,
+        Err(e) => {
+            if e.channel_is_disconnected() {
+                io_threads.join()?;
+            }
+            return Err(e.into());
+        }
+    };
     main_loop(connection, initialization_params)?;
     io_threads.join()?;
 
diff --git a/lib/lsp-server/src/error.rs b/lib/lsp-server/src/error.rs
index 755b3fd9596..ebdd153b5b3 100644
--- a/lib/lsp-server/src/error.rs
+++ b/lib/lsp-server/src/error.rs
@@ -3,7 +3,22 @@ use std::fmt;
 use crate::{Notification, Request};
 
 #[derive(Debug, Clone, PartialEq)]
-pub struct ProtocolError(pub(crate) String);
+pub struct ProtocolError(String, bool);
+
+impl ProtocolError {
+    pub(crate) fn new(msg: impl Into<String>) -> Self {
+        ProtocolError(msg.into(), false)
+    }
+
+    pub(crate) fn disconnected() -> ProtocolError {
+        ProtocolError("disconnected channel".into(), true)
+    }
+
+    /// Whether this error occured due to a disconnected channel.
+    pub fn channel_is_disconnected(&self) -> bool {
+        self.1
+    }
+}
 
 impl std::error::Error for ProtocolError {}
 
diff --git a/lib/lsp-server/src/lib.rs b/lib/lsp-server/src/lib.rs
index 2797a6b60de..6b732d47029 100644
--- a/lib/lsp-server/src/lib.rs
+++ b/lib/lsp-server/src/lib.rs
@@ -17,7 +17,7 @@ use std::{
     net::{TcpListener, TcpStream, ToSocketAddrs},
 };
 
-use crossbeam_channel::{Receiver, RecvTimeoutError, Sender};
+use crossbeam_channel::{Receiver, RecvError, RecvTimeoutError, Sender};
 
 pub use crate::{
     error::{ExtractError, ProtocolError},
@@ -158,11 +158,7 @@ impl Connection {
                 Err(RecvTimeoutError::Timeout) => {
                     continue;
                 }
-                Err(e) => {
-                    return Err(ProtocolError(format!(
-                        "expected initialize request, got error: {e}"
-                    )))
-                }
+                Err(RecvTimeoutError::Disconnected) => return Err(ProtocolError::disconnected()),
             };
 
             match msg {
@@ -181,12 +177,14 @@ impl Connection {
                     continue;
                 }
                 msg => {
-                    return Err(ProtocolError(format!("expected initialize request, got {msg:?}")));
+                    return Err(ProtocolError::new(format!(
+                        "expected initialize request, got {msg:?}"
+                    )));
                 }
             };
         }
 
-        return Err(ProtocolError(String::from(
+        return Err(ProtocolError::new(String::from(
             "Initialization has been aborted during initialization",
         )));
     }
@@ -201,12 +199,10 @@ impl Connection {
         self.sender.send(resp.into()).unwrap();
         match &self.receiver.recv() {
             Ok(Message::Notification(n)) if n.is_initialized() => Ok(()),
-            Ok(msg) => {
-                Err(ProtocolError(format!(r#"expected initialized notification, got: {msg:?}"#)))
-            }
-            Err(e) => {
-                Err(ProtocolError(format!("expected initialized notification, got error: {e}",)))
-            }
+            Ok(msg) => Err(ProtocolError::new(format!(
+                r#"expected initialized notification, got: {msg:?}"#
+            ))),
+            Err(RecvError) => Err(ProtocolError::disconnected()),
         }
     }
 
@@ -231,10 +227,8 @@ impl Connection {
                 Err(RecvTimeoutError::Timeout) => {
                     continue;
                 }
-                Err(e) => {
-                    return Err(ProtocolError(format!(
-                        "expected initialized notification, got error: {e}",
-                    )));
+                Err(RecvTimeoutError::Disconnected) => {
+                    return Err(ProtocolError::disconnected());
                 }
             };
 
@@ -243,14 +237,14 @@ impl Connection {
                     return Ok(());
                 }
                 msg => {
-                    return Err(ProtocolError(format!(
+                    return Err(ProtocolError::new(format!(
                         r#"expected initialized notification, got: {msg:?}"#
                     )));
                 }
             }
         }
 
-        return Err(ProtocolError(String::from(
+        return Err(ProtocolError::new(String::from(
             "Initialization has been aborted during initialization",
         )));
     }
@@ -359,9 +353,18 @@ impl Connection {
         match &self.receiver.recv_timeout(std::time::Duration::from_secs(30)) {
             Ok(Message::Notification(n)) if n.is_exit() => (),
             Ok(msg) => {
-                return Err(ProtocolError(format!("unexpected message during shutdown: {msg:?}")))
+                return Err(ProtocolError::new(format!(
+                    "unexpected message during shutdown: {msg:?}"
+                )))
+            }
+            Err(RecvTimeoutError::Timeout) => {
+                return Err(ProtocolError::new(format!("timed out waiting for exit notification")))
+            }
+            Err(RecvTimeoutError::Disconnected) => {
+                return Err(ProtocolError::new(format!(
+                    "channel disconnected waiting for exit notification"
+                )))
             }
-            Err(e) => return Err(ProtocolError(format!("unexpected error during shutdown: {e}"))),
         }
         Ok(true)
     }
@@ -426,7 +429,7 @@ mod tests {
 
         initialize_start_test(TestCase {
             test_messages: vec![notification_msg.clone()],
-            expected_resp: Err(ProtocolError(format!(
+            expected_resp: Err(ProtocolError::new(format!(
                 "expected initialize request, got {:?}",
                 notification_msg
             ))),
diff --git a/lib/lsp-server/src/msg.rs b/lib/lsp-server/src/msg.rs
index 730ad51f424..ba318dd1690 100644
--- a/lib/lsp-server/src/msg.rs
+++ b/lib/lsp-server/src/msg.rs
@@ -264,12 +264,12 @@ fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> {
         let mut parts = buf.splitn(2, ": ");
         let header_name = parts.next().unwrap();
         let header_value =
-            parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?;
+            parts.next().ok_or_else(|| invalid_data(format!("malformed header: {:?}", buf)))?;
         if header_name.eq_ignore_ascii_case("Content-Length") {
             size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
         }
     }
-    let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?;
+    let size: usize = size.ok_or_else(|| invalid_data("no Content-Length".to_string()))?;
     let mut buf = buf.into_bytes();
     buf.resize(size, 0);
     inp.read_exact(&mut buf)?;
diff --git a/lib/lsp-server/src/stdio.rs b/lib/lsp-server/src/stdio.rs
index e487b9b4622..cea199d0293 100644
--- a/lib/lsp-server/src/stdio.rs
+++ b/lib/lsp-server/src/stdio.rs
@@ -15,8 +15,7 @@ pub(crate) fn stdio_transport() -> (Sender<Message>, Receiver<Message>, IoThread
     let writer = thread::spawn(move || {
         let stdout = stdout();
         let mut stdout = stdout.lock();
-        writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))?;
-        Ok(())
+        writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))
     });
     let (reader_sender, reader_receiver) = bounded::<Message>(0);
     let reader = thread::spawn(move || {