diff --git a/Cargo.lock b/Cargo.lock index 1682b80a8c..f68136d978 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -139,8 +139,6 @@ dependencies = [ [[package]] name = "agent-client-protocol" version = "0.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4255a06cc2414033d1fe4baf1968bcc8f16d7e5814f272b97779b5806d129142" dependencies = [ "schemars", "serde", diff --git a/Cargo.toml b/Cargo.toml index 81da82cbb7..d733f2242e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.13" +agent-client-protocol = {path="../agent-client-protocol"} aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 841d320796..3b9f0842bd 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -958,10 +958,6 @@ impl AcpThread { cx.notify(); } - pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future> { - self.connection.authenticate(cx) - } - #[cfg(any(test, feature = "test-support"))] pub fn send_raw( &mut self, diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 97161a19c0..2e7deaf7df 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,6 +1,6 @@ -use std::{path::Path, rc::Rc}; +use std::{cell::Ref, path::Path, rc::Rc}; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp}; use anyhow::Result; use gpui::{AsyncApp, Entity, Task}; use project::Project; @@ -16,7 +16,9 @@ pub trait AgentConnection { cx: &mut AsyncApp, ) -> Task>>; - fn authenticate(&self, cx: &mut App) -> Task>; + fn state(&self) -> Ref<'_, acp::AgentState>; + + fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task>; diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs index d7ef1b73da..4d06f81d06 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/acp_thread/src/old_acp_support.rs @@ -5,7 +5,13 @@ use anyhow::{Context as _, Result}; use futures::channel::oneshot; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; -use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; +use std::{ + cell::{Ref, RefCell}, + error::Error, + fmt, + path::Path, + rc::Rc, +}; use ui::App; use crate::{AcpThread, AgentConnection}; @@ -364,6 +370,7 @@ pub struct OldAcpAgentConnection { pub name: &'static str, pub connection: acp_old::AgentConnection, pub child_status: Task>, + pub agent_state: Rc>, } impl AgentConnection for OldAcpAgentConnection { @@ -397,7 +404,11 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn authenticate(&self, cx: &mut App) -> Task> { + fn state(&self) -> Ref<'_, acp::AgentState> { + self.agent_state.borrow() + } + + fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { let task = self .connection .request_any(acp_old::AuthenticateParams.into_any()); diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index 9139d62c38..96067fe520 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -7,10 +7,10 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use futures::channel::{mpsc, oneshot}; use project::Project; use smol::stream::StreamExt as _; -use std::cell::RefCell; +use std::cell::{Ref, RefCell}; use std::rc::Rc; use std::{path::Path, sync::Arc}; -use util::ResultExt; +use util::{ResultExt, TryFutureExt}; use anyhow::{Context, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; @@ -20,10 +20,12 @@ use crate::{AgentServerCommand, mcp_server}; use acp_thread::{AcpThread, AgentConnection}; pub struct AcpConnection { + agent_state: Rc>, server_name: &'static str, client: Arc, sessions: Rc>>, - _notification_handler_task: Task<()>, + _agent_state_task: Task<()>, + _session_update_task: Task<()>, } impl AcpConnection { @@ -43,29 +45,55 @@ impl AcpConnection { .into(); ContextServer::start(client.clone(), cx).await?; - let (notification_tx, mut notification_rx) = mpsc::unbounded(); - client - .client() - .context("Failed to subscribe")? - .on_notification(acp::AGENT_METHODS.session_update, { - move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); + let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default()); + let mcp_client = client.client().context("Failed to subscribe")?; - if let Some(notification) = - serde_json::from_value::(notification).log_err() - { - notification_tx.unbounded_send(notification).ok(); - } + mcp_client.on_notification(acp::AGENT_METHODS.agent_state, { + move |notification, _cx| { + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(state) = + serde_json::from_value::(notification).log_err() + { + state_tx.send(state).log_err(); } - }); + } + }); + + let (notification_tx, mut notification_rx) = mpsc::unbounded(); + mcp_client.on_notification(acp::AGENT_METHODS.session_update, { + move |notification, _cx| { + let notification_tx = notification_tx.clone(); + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(notification) = + serde_json::from_value::(notification).log_err() + { + notification_tx.unbounded_send(notification).ok(); + } + } + }); let sessions = Rc::new(RefCell::new(HashMap::default())); + let initial_state = state_rx.recv().await?; + let agent_state = Rc::new(RefCell::new(initial_state)); - let notification_handler_task = cx.spawn({ + let agent_state_task = cx.foreground_executor().spawn({ + let agent_state = agent_state.clone(); + async move { + while let Some(state) = state_rx.recv().log_err().await { + agent_state.replace(state); + } + } + }); + + let session_update_handler_task = cx.spawn({ let sessions = sessions.clone(); async move |cx| { while let Some(notification) = notification_rx.next().await { @@ -78,7 +106,9 @@ impl AcpConnection { server_name, client, sessions, - _notification_handler_task: notification_handler_task, + agent_state, + _agent_state_task: agent_state_task, + _session_update_task: session_update_handler_task, }) } @@ -185,8 +215,30 @@ impl AgentConnection for AcpConnection { }) } - fn authenticate(&self, _cx: &mut App) -> Task> { - Task::ready(Err(anyhow!("Authentication not supported"))) + fn state(&self) -> Ref<'_, acp::AgentState> { + self.agent_state.borrow() + } + + fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let client = self.client.client(); + cx.foreground_executor().spawn(async move { + let params = acp::AuthenticateArguments { method_id }; + + let response = client + .context("MCP server is not initialized yet")? + .request::(context_server::types::CallToolParams { + name: acp::AGENT_METHODS.authenticate.into(), + arguments: Some(serde_json::to_value(params)?), + meta: None, + }) + .await?; + + if response.is_error.unwrap_or_default() { + Err(anyhow!(response.text_contents())) + } else { + Ok(()) + } + }) } fn prompt( diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 590da69cd8..0f49403a0b 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,7 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; -use std::cell::RefCell; +use std::cell::{Ref, RefCell}; use std::fmt::Display; use std::path::Path; use std::rc::Rc; @@ -58,6 +58,7 @@ impl AgentServer for ClaudeCode { _cx: &mut App, ) -> Task>> { let connection = ClaudeAgentConnection { + agent_state: Default::default(), sessions: Default::default(), }; @@ -66,6 +67,7 @@ impl AgentServer for ClaudeCode { } struct ClaudeAgentConnection { + agent_state: Rc>, sessions: Rc>>, } @@ -183,7 +185,11 @@ impl AgentConnection for ClaudeAgentConnection { }) } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn state(&self) -> Ref<'_, acp::AgentState> { + self.agent_state.borrow() + } + + fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { Task::ready(Err(anyhow!("Authentication not supported"))) } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index e46e1ae3ab..824748a0aa 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -216,6 +216,15 @@ impl AcpThreadView { } }; + if connection.state().needs_authentication { + this.update(cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { connection }; + cx.notify(); + }) + .ok(); + return; + } + let result = match connection .clone() .new_thread(project.clone(), &root_dir, cx) @@ -223,6 +232,7 @@ impl AcpThreadView { { Err(e) => { let mut cx = cx.clone(); + // todo! remove duplication if e.downcast_ref::().is_some() { this.update(&mut cx, |this, cx| { this.thread_state = ThreadState::Unauthenticated { connection }; @@ -640,13 +650,18 @@ impl AcpThreadView { Some(entry.diffs().map(|diff| diff.multibuffer.clone())) } - fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { + fn authenticate( + &mut self, + method: acp::AuthMethodId, + window: &mut Window, + cx: &mut Context, + ) { let ThreadState::Unauthenticated { ref connection } = self.thread_state else { return; }; self.last_error.take(); - let authenticate = connection.authenticate(cx); + let authenticate = connection.authenticate(method, cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); let agent = self.agent.clone(); @@ -2197,22 +2212,26 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) .child(match &self.thread_state { - ThreadState::Unauthenticated { .. } => { - v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child( - h_flex().mt_1p5().justify_center().child( - Button::new("sign-in", format!("Sign in to {}", self.agent.name())) - .on_click(cx.listener(|this, _, window, cx| { - this.authenticate(window, cx) - })), - ), - ) - } + ThreadState::Unauthenticated { connection } => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_pending_auth_state()) + .child(h_flex().mt_1p5().justify_center().children( + connection.state().auth_methods.iter().map(|method| { + Button::new( + SharedString::from(method.id.0.clone()), + method.label.clone(), + ) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), + )), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2()