about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-11-22 08:48:17 +0000
committerbors <bors@rust-lang.org>2023-11-22 08:48:17 +0000
commit7ceefc7ee981f7dd9de4cfdd070696e48b4ab43e (patch)
tree1e1de9dc7bcfedb197ad96dbd3484785cf3deaa3
parent45136511a5b4f6b216aba371f541fb2251868c2b (diff)
parent81c2d3552e6d377bf0332523c1278b7e5723c27c (diff)
downloadrust-7ceefc7ee981f7dd9de4cfdd070696e48b4ab43e.tar.gz
rust-7ceefc7ee981f7dd9de4cfdd070696e48b4ab43e.zip
Auto merge of #15894 - schrieveslaach:cancelable-initialization, r=Veykril
Cancelable Initialization

This commit provides additional initialization methods to Connection in order to support CTRL + C sigterm handling.

In the process of adding LSP to Nushell (see https://github.com/nushell/nushell/pull/10941) this gap has been identified.
-rw-r--r--Cargo.lock32
-rw-r--r--lib/lsp-server/Cargo.toml1
-rw-r--r--lib/lsp-server/src/lib.rs169
3 files changed, 188 insertions, 14 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 701e36d74a6..11283034ea2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -216,7 +216,7 @@ version = "2.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5080df6b0f0ecb76cab30808f00d937ba725cebe266a3da8cd89dff92f2a9916"
 dependencies = [
- "nix",
+ "nix 0.26.2",
  "winapi",
 ]
 
@@ -290,6 +290,16 @@ dependencies = [
 ]
 
 [[package]]
+name = "ctrlc"
+version = "3.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "82e95fbd621905b854affdc67943b043a0fbb6ed7385fd5a25650d19a8a6cfdf"
+dependencies = [
+ "nix 0.27.1",
+ "windows-sys 0.48.0",
+]
+
+[[package]]
 name = "dashmap"
 version = "5.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -961,6 +971,7 @@ name = "lsp-server"
 version = "0.7.4"
 dependencies = [
  "crossbeam-channel",
+ "ctrlc",
  "log",
  "lsp-types",
  "serde",
@@ -1101,6 +1112,17 @@ dependencies = [
 ]
 
 [[package]]
+name = "nix"
+version = "0.27.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
+dependencies = [
+ "bitflags 2.4.1",
+ "cfg-if",
+ "libc",
+]
+
+[[package]]
 name = "nohash-hasher"
 version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1701,18 +1723,18 @@ dependencies = [
 
 [[package]]
 name = "serde"
-version = "1.0.192"
+version = "1.0.193"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001"
+checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
 dependencies = [
  "serde_derive",
 ]
 
 [[package]]
 name = "serde_derive"
-version = "1.0.192"
+version = "1.0.193"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1"
+checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/lib/lsp-server/Cargo.toml b/lib/lsp-server/Cargo.toml
index 8d00813b0d7..be1573913ff 100644
--- a/lib/lsp-server/Cargo.toml
+++ b/lib/lsp-server/Cargo.toml
@@ -14,3 +14,4 @@ crossbeam-channel = "0.5.6"
 
 [dev-dependencies]
 lsp-types = "=0.94"
+ctrlc = "3.4.1"
diff --git a/lib/lsp-server/src/lib.rs b/lib/lsp-server/src/lib.rs
index affab60a227..b190c0af73d 100644
--- a/lib/lsp-server/src/lib.rs
+++ b/lib/lsp-server/src/lib.rs
@@ -17,7 +17,7 @@ use std::{
     net::{TcpListener, TcpStream, ToSocketAddrs},
 };
 
-use crossbeam_channel::{Receiver, Sender};
+use crossbeam_channel::{Receiver, RecvTimeoutError, Sender};
 
 pub use crate::{
     error::{ExtractError, ProtocolError},
@@ -113,11 +113,62 @@ impl Connection {
     /// }
     /// ```
     pub fn initialize_start(&self) -> Result<(RequestId, serde_json::Value), ProtocolError> {
-        loop {
-            break match self.receiver.recv() {
-                Ok(Message::Request(req)) if req.is_initialize() => Ok((req.id, req.params)),
+        self.initialize_start_while(|| true)
+    }
+
+    /// Starts the initialization process by waiting for an initialize as described in
+    /// [`Self::initialize_start`] as long as `running` returns
+    /// `true` while the return value can be changed through a sig handler such as `CTRL + C`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use std::sync::atomic::{AtomicBool, Ordering};
+    /// use std::sync::Arc;
+    /// # use std::error::Error;
+    /// # use lsp_types::{ClientCapabilities, InitializeParams, ServerCapabilities};
+    /// # use lsp_server::{Connection, Message, Request, RequestId, Response};
+    /// # fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
+    /// let running = Arc::new(AtomicBool::new(true));
+    /// # running.store(true, Ordering::SeqCst);
+    /// let r = running.clone();
+    ///
+    /// ctrlc::set_handler(move || {
+    ///     r.store(false, Ordering::SeqCst);
+    /// }).expect("Error setting Ctrl-C handler");
+    ///
+    /// let (connection, io_threads) = Connection::stdio();
+    ///
+    /// let res = connection.initialize_start_while(|| running.load(Ordering::SeqCst));
+    /// # assert!(res.is_err());
+    ///
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn initialize_start_while<C>(
+        &self,
+        running: C,
+    ) -> Result<(RequestId, serde_json::Value), ProtocolError>
+    where
+        C: Fn() -> bool,
+    {
+        while running() {
+            let msg = match self.receiver.recv_timeout(std::time::Duration::from_secs(1)) {
+                Ok(msg) => msg,
+                Err(RecvTimeoutError::Timeout) => {
+                    continue;
+                }
+                Err(e) => {
+                    return Err(ProtocolError(format!(
+                        "expected initialize request, got error: {e}"
+                    )))
+                }
+            };
+
+            match msg {
+                Message::Request(req) if req.is_initialize() => return Ok((req.id, req.params)),
                 // Respond to non-initialize requests with ServerNotInitialized
-                Ok(Message::Request(req)) => {
+                Message::Request(req) => {
                     let resp = Response::new_err(
                         req.id.clone(),
                         ErrorCode::ServerNotInitialized as i32,
@@ -126,15 +177,18 @@ impl Connection {
                     self.sender.send(resp.into()).unwrap();
                     continue;
                 }
-                Ok(Message::Notification(n)) if !n.is_exit() => {
+                Message::Notification(n) if !n.is_exit() => {
                     continue;
                 }
-                Ok(msg) => Err(ProtocolError(format!("expected initialize request, got {msg:?}"))),
-                Err(e) => {
-                    Err(ProtocolError(format!("expected initialize request, got error: {e}")))
+                msg => {
+                    return Err(ProtocolError(format!("expected initialize request, got {msg:?}")));
                 }
             };
         }
+
+        return Err(ProtocolError(String::from(
+            "Initialization has been aborted during initialization",
+        )));
     }
 
     /// Finishes the initialization process by sending an `InitializeResult` to the client
@@ -156,6 +210,51 @@ impl Connection {
         }
     }
 
+    /// Finishes the initialization process as described in [`Self::initialize_finish`] as
+    /// long as `running` returns `true` while the return value can be changed through a sig
+    /// handler such as `CTRL + C`.
+    pub fn initialize_finish_while<C>(
+        &self,
+        initialize_id: RequestId,
+        initialize_result: serde_json::Value,
+        running: C,
+    ) -> Result<(), ProtocolError>
+    where
+        C: Fn() -> bool,
+    {
+        let resp = Response::new_ok(initialize_id, initialize_result);
+        self.sender.send(resp.into()).unwrap();
+
+        while running() {
+            let msg = match self.receiver.recv_timeout(std::time::Duration::from_secs(1)) {
+                Ok(msg) => msg,
+                Err(RecvTimeoutError::Timeout) => {
+                    continue;
+                }
+                Err(e) => {
+                    return Err(ProtocolError(format!(
+                        "expected initialized notification, got error: {e}",
+                    )));
+                }
+            };
+
+            match msg {
+                Message::Notification(n) if n.is_initialized() => {
+                    return Ok(());
+                }
+                msg => {
+                    return Err(ProtocolError(format!(
+                        r#"expected initialized notification, got: {msg:?}"#
+                    )));
+                }
+            }
+        }
+
+        return Err(ProtocolError(String::from(
+            "Initialization has been aborted during initialization",
+        )));
+    }
+
     /// Initialize the connection. Sends the server capabilities
     /// to the client and returns the serialized client capabilities
     /// on success. If more fine-grained initialization is required use
@@ -198,6 +297,58 @@ impl Connection {
         Ok(params)
     }
 
+    /// Initialize the connection as described in [`Self::initialize`] as long as `running` returns
+    /// `true` while the return value can be changed through a sig handler such as `CTRL + C`.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use std::sync::atomic::{AtomicBool, Ordering};
+    /// use std::sync::Arc;
+    /// # use std::error::Error;
+    /// # use lsp_types::ServerCapabilities;
+    /// # use lsp_server::{Connection, Message, Request, RequestId, Response};
+    ///
+    /// # fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
+    /// let running = Arc::new(AtomicBool::new(true));
+    /// # running.store(true, Ordering::SeqCst);
+    /// let r = running.clone();
+    ///
+    /// ctrlc::set_handler(move || {
+    ///     r.store(false, Ordering::SeqCst);
+    /// }).expect("Error setting Ctrl-C handler");
+    ///
+    /// let (connection, io_threads) = Connection::stdio();
+    ///
+    /// let server_capabilities = serde_json::to_value(&ServerCapabilities::default()).unwrap();
+    /// let initialization_params = connection.initialize_while(
+    ///     server_capabilities,
+    ///     || running.load(Ordering::SeqCst)
+    /// );
+    ///
+    /// # assert!(initialization_params.is_err());
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn initialize_while<C>(
+        &self,
+        server_capabilities: serde_json::Value,
+        running: C,
+    ) -> Result<serde_json::Value, ProtocolError>
+    where
+        C: Fn() -> bool,
+    {
+        let (id, params) = self.initialize_start_while(&running)?;
+
+        let initialize_data = serde_json::json!({
+            "capabilities": server_capabilities,
+        });
+
+        self.initialize_finish_while(id, initialize_data, running)?;
+
+        Ok(params)
+    }
+
     /// If `req` is `Shutdown`, respond to it and return `true`, otherwise return `false`
     pub fn handle_shutdown(&self, req: &Request) -> Result<bool, ProtocolError> {
         if !req.is_shutdown() {