Refactor to use new ACP crate (#35043)
This will prepare us for running the protocol over MCP Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com> Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com> Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent
45ddf32a1d
commit
2d0f10c48a
21 changed files with 1830 additions and 1748 deletions
|
@ -1,6 +1,6 @@
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt, StreamExt, channel::oneshot, select};
|
||||
use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
|
||||
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
|
||||
use parking_lot::Mutex;
|
||||
use postage::barrier;
|
||||
|
@ -10,15 +10,19 @@ use smol::channel;
|
|||
use std::{
|
||||
fmt,
|
||||
path::PathBuf,
|
||||
pin::pin,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicI32, Ordering::SeqCst},
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use util::TryFutureExt;
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
|
||||
use crate::transport::{StdioTransport, Transport};
|
||||
use crate::{
|
||||
transport::{StdioTransport, Transport},
|
||||
types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
|
||||
};
|
||||
|
||||
const JSON_RPC_VERSION: &str = "2.0";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
@ -32,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603;
|
|||
|
||||
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
|
||||
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
|
||||
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
|
@ -78,6 +83,15 @@ pub struct Request<'a, T> {
|
|||
pub params: T,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AnyRequest<'a> {
|
||||
pub jsonrpc: &'a str,
|
||||
pub id: RequestId,
|
||||
pub method: &'a str,
|
||||
#[serde(skip_serializing_if = "is_null_value")]
|
||||
pub params: Option<&'a RawValue>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AnyResponse<'a> {
|
||||
jsonrpc: &'a str,
|
||||
|
@ -176,15 +190,23 @@ impl Client {
|
|||
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
|
||||
let response_handlers =
|
||||
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
|
||||
let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
|
||||
|
||||
let receive_input_task = cx.spawn({
|
||||
let notification_handlers = notification_handlers.clone();
|
||||
let response_handlers = response_handlers.clone();
|
||||
let request_handlers = request_handlers.clone();
|
||||
let transport = transport.clone();
|
||||
async move |cx| {
|
||||
Self::handle_input(transport, notification_handlers, response_handlers, cx)
|
||||
.log_err()
|
||||
.await
|
||||
Self::handle_input(
|
||||
transport,
|
||||
notification_handlers,
|
||||
request_handlers,
|
||||
response_handlers,
|
||||
cx,
|
||||
)
|
||||
.log_err()
|
||||
.await
|
||||
}
|
||||
});
|
||||
let receive_err_task = cx.spawn({
|
||||
|
@ -230,13 +252,24 @@ impl Client {
|
|||
async fn handle_input(
|
||||
transport: Arc<dyn Transport>,
|
||||
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
|
||||
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut receiver = transport.receive();
|
||||
|
||||
while let Some(message) = receiver.next().await {
|
||||
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
|
||||
log::trace!("recv: {}", &message);
|
||||
if let Ok(request) = serde_json::from_str::<AnyRequest>(&message) {
|
||||
let mut request_handlers = request_handlers.lock();
|
||||
if let Some(handler) = request_handlers.get_mut(request.method) {
|
||||
handler(
|
||||
request.id,
|
||||
request.params.unwrap_or(RawValue::NULL),
|
||||
cx.clone(),
|
||||
);
|
||||
}
|
||||
} else if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
|
||||
if let Some(handlers) = response_handlers.lock().as_mut() {
|
||||
if let Some(handler) = handlers.remove(&response.id) {
|
||||
handler(Ok(message.to_string()));
|
||||
|
@ -247,6 +280,8 @@ impl Client {
|
|||
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
|
||||
handler(notification.params.unwrap_or(Value::Null), cx.clone());
|
||||
}
|
||||
} else {
|
||||
log::error!("Unhandled JSON from context_server: {}", message);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -294,6 +329,24 @@ impl Client {
|
|||
&self,
|
||||
method: &str,
|
||||
params: impl Serialize,
|
||||
) -> Result<T> {
|
||||
self.request_impl(method, params, None).await
|
||||
}
|
||||
|
||||
pub async fn cancellable_request<T: DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: impl Serialize,
|
||||
cancel_rx: oneshot::Receiver<()>,
|
||||
) -> Result<T> {
|
||||
self.request_impl(method, params, Some(cancel_rx)).await
|
||||
}
|
||||
|
||||
pub async fn request_impl<T: DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: impl Serialize,
|
||||
cancel_rx: Option<oneshot::Receiver<()>>,
|
||||
) -> Result<T> {
|
||||
let id = self.next_id.fetch_add(1, SeqCst);
|
||||
let request = serde_json::to_string(&Request {
|
||||
|
@ -330,6 +383,16 @@ impl Client {
|
|||
send?;
|
||||
|
||||
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
|
||||
let mut cancel_fut = pin!(
|
||||
match cancel_rx {
|
||||
Some(rx) => future::Either::Left(async {
|
||||
rx.await.log_err();
|
||||
}),
|
||||
None => future::Either::Right(future::pending()),
|
||||
}
|
||||
.fuse()
|
||||
);
|
||||
|
||||
select! {
|
||||
response = rx.fuse() => {
|
||||
let elapsed = started.elapsed();
|
||||
|
@ -348,6 +411,16 @@ impl Client {
|
|||
Err(_) => anyhow::bail!("cancelled")
|
||||
}
|
||||
}
|
||||
_ = cancel_fut => {
|
||||
self.notify(
|
||||
Cancelled::METHOD,
|
||||
ClientNotification::Cancelled(CancelledParams {
|
||||
request_id: RequestId::Int(id),
|
||||
reason: None
|
||||
})
|
||||
).log_err();
|
||||
anyhow::bail!("Request cancelled")
|
||||
}
|
||||
_ = timeout => {
|
||||
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
|
||||
anyhow::bail!("Context server request timeout");
|
||||
|
|
|
@ -6,6 +6,9 @@
|
|||
//! of messages.
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::AsyncApp;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::client::Client;
|
||||
use crate::types::{self, Notification, Request};
|
||||
|
@ -95,7 +98,24 @@ impl InitializedContextServerProtocol {
|
|||
self.inner.request(T::METHOD, params).await
|
||||
}
|
||||
|
||||
pub async fn cancellable_request<T: Request>(
|
||||
&self,
|
||||
params: T::Params,
|
||||
cancel_rx: oneshot::Receiver<()>,
|
||||
) -> Result<T::Response> {
|
||||
self.inner
|
||||
.cancellable_request(T::METHOD, params, cancel_rx)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
|
||||
self.inner.notify(T::METHOD, params)
|
||||
}
|
||||
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
self.inner.on_notification(method, f);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ use serde::de::DeserializeOwned;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
use crate::client::RequestId;
|
||||
|
||||
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
|
||||
pub const VERSION_2024_11_05: &str = "2024-11-05";
|
||||
|
||||
|
@ -100,6 +102,7 @@ pub mod notifications {
|
|||
notification!("notifications/initialized", Initialized, ());
|
||||
notification!("notifications/progress", Progress, ProgressParams);
|
||||
notification!("notifications/message", Message, MessageParams);
|
||||
notification!("notifications/cancelled", Cancelled, CancelledParams);
|
||||
notification!(
|
||||
"notifications/resources/updated",
|
||||
ResourcesUpdated,
|
||||
|
@ -617,11 +620,14 @@ pub enum ClientNotification {
|
|||
Initialized,
|
||||
Progress(ProgressParams),
|
||||
RootsListChanged,
|
||||
Cancelled {
|
||||
request_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
reason: Option<String>,
|
||||
},
|
||||
Cancelled(CancelledParams),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CancelledParams {
|
||||
pub request_id: RequestId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue