Refactor handling of ContextServer notifications
The notification handler registration is now more explicit, with handlers set up before server initialization to avoid potential race conditions.
This commit is contained in:
parent
f028ca4d1a
commit
81c111510f
4 changed files with 79 additions and 56 deletions
|
@ -22,7 +22,7 @@ use acp_thread::{AcpThread, AgentConnection};
|
||||||
pub struct AcpConnection {
|
pub struct AcpConnection {
|
||||||
agent_state: Rc<RefCell<acp::AgentState>>,
|
agent_state: Rc<RefCell<acp::AgentState>>,
|
||||||
server_name: &'static str,
|
server_name: &'static str,
|
||||||
client: Arc<context_server::ContextServer>,
|
context_server: Arc<context_server::ContextServer>,
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
_agent_state_task: Task<()>,
|
_agent_state_task: Task<()>,
|
||||||
_session_update_task: Task<()>,
|
_session_update_task: Task<()>,
|
||||||
|
@ -35,7 +35,7 @@ impl AcpConnection {
|
||||||
working_directory: Option<Arc<Path>>,
|
working_directory: Option<Arc<Path>>,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let client: Arc<ContextServer> = ContextServer::stdio(
|
let context_server: Arc<ContextServer> = ContextServer::stdio(
|
||||||
ContextServerId(format!("{}-mcp-server", server_name).into()),
|
ContextServerId(format!("{}-mcp-server", server_name).into()),
|
||||||
ContextServerCommand {
|
ContextServerCommand {
|
||||||
path: command.path,
|
path: command.path,
|
||||||
|
@ -45,42 +45,9 @@ impl AcpConnection {
|
||||||
working_directory,
|
working_directory,
|
||||||
)
|
)
|
||||||
.into();
|
.into();
|
||||||
ContextServer::start(client.clone(), cx).await?;
|
|
||||||
|
|
||||||
let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
|
let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
|
||||||
let mcp_client = client.client().context("Failed to subscribe")?;
|
|
||||||
|
|
||||||
mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
|
|
||||||
move |notification, _cx| {
|
|
||||||
log::trace!(
|
|
||||||
"ACP Notification: {}",
|
|
||||||
serde_json::to_string_pretty(¬ification).unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(state) =
|
|
||||||
serde_json::from_value::<acp::AgentState>(notification).log_err()
|
|
||||||
{
|
|
||||||
state_tx.send(state).log_err();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||||
mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
|
|
||||||
move |notification, _cx| {
|
|
||||||
let notification_tx = notification_tx.clone();
|
|
||||||
log::trace!(
|
|
||||||
"ACP Notification: {}",
|
|
||||||
serde_json::to_string_pretty(¬ification).unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(notification) =
|
|
||||||
serde_json::from_value::<acp::SessionNotification>(notification).log_err()
|
|
||||||
{
|
|
||||||
notification_tx.unbounded_send(notification).ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||||
let initial_state = state_rx.recv().await?;
|
let initial_state = state_rx.recv().await?;
|
||||||
|
@ -104,9 +71,47 @@ impl AcpConnection {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
context_server
|
||||||
|
.start_with_handlers(
|
||||||
|
vec![
|
||||||
|
(acp::AGENT_METHODS.agent_state, {
|
||||||
|
Box::new(move |notification, _cx| {
|
||||||
|
log::trace!(
|
||||||
|
"ACP Notification: {}",
|
||||||
|
serde_json::to_string_pretty(¬ification).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(state) =
|
||||||
|
serde_json::from_value::<acp::AgentState>(notification).log_err()
|
||||||
|
{
|
||||||
|
state_tx.send(state).log_err();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
(acp::AGENT_METHODS.session_update, {
|
||||||
|
Box::new(move |notification, _cx| {
|
||||||
|
let notification_tx = notification_tx.clone();
|
||||||
|
log::trace!(
|
||||||
|
"ACP Notification: {}",
|
||||||
|
serde_json::to_string_pretty(¬ification).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(notification) =
|
||||||
|
serde_json::from_value::<acp::SessionNotification>(notification)
|
||||||
|
.log_err()
|
||||||
|
{
|
||||||
|
notification_tx.unbounded_send(notification).ok();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
server_name,
|
server_name,
|
||||||
client,
|
context_server,
|
||||||
sessions,
|
sessions,
|
||||||
agent_state,
|
agent_state,
|
||||||
_agent_state_task: agent_state_task,
|
_agent_state_task: agent_state_task,
|
||||||
|
@ -152,7 +157,7 @@ impl AgentConnection for AcpConnection {
|
||||||
cwd: &Path,
|
cwd: &Path,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Task<Result<Entity<AcpThread>>> {
|
) -> Task<Result<Entity<AcpThread>>> {
|
||||||
let client = self.client.client();
|
let client = self.context_server.client();
|
||||||
let sessions = self.sessions.clone();
|
let sessions = self.sessions.clone();
|
||||||
let cwd = cwd.to_path_buf();
|
let cwd = cwd.to_path_buf();
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
|
@ -222,7 +227,7 @@ impl AgentConnection for AcpConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||||
let client = self.client.client();
|
let client = self.context_server.client();
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
let params = acp::AuthenticateArguments { method_id };
|
let params = acp::AuthenticateArguments { method_id };
|
||||||
|
|
||||||
|
@ -248,7 +253,7 @@ impl AgentConnection for AcpConnection {
|
||||||
params: agent_client_protocol::PromptArguments,
|
params: agent_client_protocol::PromptArguments,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
let client = self.client.client();
|
let client = self.context_server.client();
|
||||||
let sessions = self.sessions.clone();
|
let sessions = self.sessions.clone();
|
||||||
|
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
|
@ -305,6 +310,6 @@ impl AgentConnection for AcpConnection {
|
||||||
|
|
||||||
impl Drop for AcpConnection {
|
impl Drop for AcpConnection {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.client.stop().log_err();
|
self.context_server.stop().log_err();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -441,14 +441,12 @@ impl Client {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(unused)]
|
pub fn on_notification(
|
||||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
&self,
|
||||||
where
|
method: &'static str,
|
||||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||||
{
|
) {
|
||||||
self.notification_handlers
|
self.notification_handlers.lock().insert(method, f);
|
||||||
.lock()
|
|
||||||
.insert(method, Box::new(f));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -95,8 +95,28 @@ impl ContextServer {
|
||||||
self.client.read().clone()
|
self.client.read().clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
|
||||||
let client = match &self.configuration {
|
self.initialize(self.new_client(cx)?).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts the context server, making sure handlers are registered before initialization happens
|
||||||
|
pub async fn start_with_handlers(
|
||||||
|
&self,
|
||||||
|
notification_handlers: Vec<(
|
||||||
|
&'static str,
|
||||||
|
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
|
||||||
|
)>,
|
||||||
|
cx: &AsyncApp,
|
||||||
|
) -> Result<()> {
|
||||||
|
let client = self.new_client(cx)?;
|
||||||
|
for (method, handler) in notification_handlers {
|
||||||
|
client.on_notification(method, handler);
|
||||||
|
}
|
||||||
|
self.initialize(client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
|
||||||
|
Ok(match &self.configuration {
|
||||||
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
|
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
|
||||||
client::ContextServerId(self.id.0.clone()),
|
client::ContextServerId(self.id.0.clone()),
|
||||||
client::ModelContextServerBinary {
|
client::ModelContextServerBinary {
|
||||||
|
@ -113,8 +133,7 @@ impl ContextServer {
|
||||||
transport.clone(),
|
transport.clone(),
|
||||||
cx.clone(),
|
cx.clone(),
|
||||||
)?,
|
)?,
|
||||||
};
|
})
|
||||||
self.initialize(client).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn initialize(&self, client: Client) -> Result<()> {
|
async fn initialize(&self, client: Client) -> Result<()> {
|
||||||
|
|
|
@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
|
||||||
self.inner.notify(T::METHOD, params)
|
self.inner.notify(T::METHOD, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
pub fn on_notification(
|
||||||
where
|
&self,
|
||||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
method: &'static str,
|
||||||
{
|
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||||
|
) {
|
||||||
self.inner.on_notification(method, f);
|
self.inner.on_notification(method, f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue