From 27708143ecac70b963bbe70153d426fe5f061277 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 30 Jul 2025 13:30:50 -0300 Subject: [PATCH] Fix auth --- crates/acp_thread/src/acp_thread.rs | 1 - crates/acp_thread/src/connection.rs | 14 +++++- crates/acp_thread/src/old_acp_support.rs | 31 ++++--------- crates/agent_servers/src/acp_connection.rs | 54 +++++++--------------- crates/agent_servers/src/claude.rs | 8 ++-- crates/agent_ui/src/acp/thread_view.rs | 28 ++++------- 6 files changed, 48 insertions(+), 88 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 8aa07da330..bc2f8c756a 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1595,7 +1595,6 @@ mod tests { connection, child_status: io_task, current_thread: thread_rc, - agent_state: Default::default(), }; AcpThread::new( diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 2e7deaf7df..11f1fcc94c 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,4 +1,4 @@ -use std::{cell::Ref, path::Path, rc::Rc}; +use std::{error::Error, fmt, path::Path, rc::Rc}; use agent_client_protocol::{self as acp}; use anyhow::Result; @@ -16,7 +16,7 @@ pub trait AgentConnection { cx: &mut AsyncApp, ) -> Task>>; - fn state(&self) -> Ref<'_, acp::AgentState>; + fn auth_methods(&self) -> Vec; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; @@ -24,3 +24,13 @@ pub trait AgentConnection { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); } + +#[derive(Debug)] +pub struct AuthRequired; + +impl Error for AuthRequired {} +impl fmt::Display for AuthRequired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthRequired") + } +} diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs index 718ad0da03..88313e0fd5 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/acp_thread/src/old_acp_support.rs @@ -5,17 +5,11 @@ use anyhow::{Context as _, Result}; use futures::channel::oneshot; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; -use std::{ - cell::{Ref, RefCell}, - error::Error, - fmt, - path::Path, - rc::Rc, -}; +use std::{cell::RefCell, path::Path, rc::Rc}; use ui::App; use util::ResultExt as _; -use crate::{AcpThread, AgentConnection}; +use crate::{AcpThread, AgentConnection, AuthRequired}; #[derive(Clone)] pub struct OldAcpClientDelegate { @@ -357,21 +351,10 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu } } -#[derive(Debug)] -pub struct Unauthenticated; - -impl Error for Unauthenticated {} -impl fmt::Display for Unauthenticated { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Unauthenticated") - } -} - pub struct OldAcpAgentConnection { pub name: &'static str, pub connection: acp_old::AgentConnection, pub child_status: Task>, - pub agent_state: Rc>, pub current_thread: Rc>>, } @@ -394,7 +377,7 @@ impl AgentConnection for OldAcpAgentConnection { let result = acp_old::InitializeParams::response_from_any(result)?; if !result.is_authenticated { - anyhow::bail!(Unauthenticated) + anyhow::bail!(AuthRequired) } cx.update(|cx| { @@ -408,8 +391,12 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn state(&self) -> Ref<'_, acp::AgentState> { - self.agent_state.borrow() + fn auth_methods(&self) -> Vec { + vec![acp::AuthMethod { + id: acp::AuthMethodId("acp-old-no-id".into()), + label: "Log in".into(), + description: None, + }] } fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index 95c09e2c52..c19a145196 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -7,24 +7,23 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use futures::channel::{mpsc, oneshot}; use project::Project; use smol::stream::StreamExt as _; -use std::cell::{Ref, RefCell}; +use std::cell::RefCell; use std::rc::Rc; use std::{path::Path, sync::Arc}; -use util::{ResultExt, TryFutureExt}; +use util::ResultExt; use anyhow::{Context, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use crate::mcp_server::ZedMcpServer; use crate::{AgentServerCommand, mcp_server}; -use acp_thread::{AcpThread, AgentConnection}; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; pub struct AcpConnection { - agent_state: Rc>, + auth_methods: Rc>>, server_name: &'static str, client: Arc, sessions: Rc>>, - _agent_state_task: Task<()>, _session_update_task: Task<()>, } @@ -47,24 +46,8 @@ impl AcpConnection { .into(); ContextServer::start(client.clone(), cx).await?; - let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default()); let mcp_client = client.client().context("Failed to subscribe")?; - mcp_client.on_notification(acp::AGENT_METHODS.agent_state, { - move |notification, _cx| { - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(state) = - serde_json::from_value::(notification).log_err() - { - state_tx.send(state).log_err(); - } - } - }); - let (notification_tx, mut notification_rx) = mpsc::unbounded(); mcp_client.on_notification(acp::AGENT_METHODS.session_update, { move |notification, _cx| { @@ -83,17 +66,6 @@ impl AcpConnection { }); let sessions = Rc::new(RefCell::new(HashMap::default())); - let initial_state = state_rx.recv().await?; - let agent_state = Rc::new(RefCell::new(initial_state)); - - let agent_state_task = cx.foreground_executor().spawn({ - let agent_state = agent_state.clone(); - async move { - while let Some(state) = state_rx.recv().log_err().await { - agent_state.replace(state); - } - } - }); let session_update_handler_task = cx.spawn({ let sessions = sessions.clone(); @@ -105,11 +77,10 @@ impl AcpConnection { }); Ok(Self { + auth_methods: Default::default(), server_name, client, sessions, - agent_state, - _agent_state_task: agent_state_task, _session_update_task: session_update_handler_task, }) } @@ -154,6 +125,7 @@ impl AgentConnection for AcpConnection { ) -> Task>> { let client = self.client.client(); let sessions = self.sessions.clone(); + let auth_methods = self.auth_methods.clone(); let cwd = cwd.to_path_buf(); cx.spawn(async move |cx| { let client = client.context("MCP server is not initialized yet")?; @@ -194,12 +166,18 @@ impl AgentConnection for AcpConnection { response.structured_content.context("Empty response")?, )?; + auth_methods.replace(result.auth_methods); + + let Some(session_id) = result.session_id else { + anyhow::bail!(AuthRequired); + }; + let thread = cx.new(|cx| { AcpThread::new( self.server_name, self.clone(), project, - result.session_id.clone(), + session_id.clone(), cx, ) })?; @@ -211,14 +189,14 @@ impl AgentConnection for AcpConnection { cancel_tx: None, _mcp_server: mcp_server, }; - sessions.borrow_mut().insert(result.session_id, session); + sessions.borrow_mut().insert(session_id, session); Ok(thread) }) } - fn state(&self) -> Ref<'_, acp::AgentState> { - self.agent_state.borrow() + fn auth_methods(&self) -> Vec { + self.auth_methods.borrow().clone() } fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 0f49403a0b..736fdd2726 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,7 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; -use std::cell::{Ref, RefCell}; +use std::cell::RefCell; use std::fmt::Display; use std::path::Path; use std::rc::Rc; @@ -58,7 +58,6 @@ impl AgentServer for ClaudeCode { _cx: &mut App, ) -> Task>> { let connection = ClaudeAgentConnection { - agent_state: Default::default(), sessions: Default::default(), }; @@ -67,7 +66,6 @@ impl AgentServer for ClaudeCode { } struct ClaudeAgentConnection { - agent_state: Rc>, sessions: Rc>>, } @@ -185,8 +183,8 @@ impl AgentConnection for ClaudeAgentConnection { }) } - fn state(&self) -> Ref<'_, acp::AgentState> { - self.agent_state.borrow() + fn auth_methods(&self) -> Vec { + vec![] } fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 824748a0aa..6d7684bbfc 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -216,15 +216,6 @@ impl AcpThreadView { } }; - if connection.state().needs_authentication { - this.update(cx, |this, cx| { - this.thread_state = ThreadState::Unauthenticated { connection }; - cx.notify(); - }) - .ok(); - return; - } - let result = match connection .clone() .new_thread(project.clone(), &root_dir, cx) @@ -233,7 +224,7 @@ impl AcpThreadView { Err(e) => { let mut cx = cx.clone(); // todo! remove duplication - if e.downcast_ref::().is_some() { + if e.downcast_ref::().is_some() { this.update(&mut cx, |this, cx| { this.thread_state = ThreadState::Unauthenticated { connection }; cx.notify(); @@ -2219,17 +2210,14 @@ impl Render for AcpThreadView { .justify_center() .child(self.render_pending_auth_state()) .child(h_flex().mt_1p5().justify_center().children( - connection.state().auth_methods.iter().map(|method| { - Button::new( - SharedString::from(method.id.0.clone()), - method.label.clone(), - ) - .on_click({ - let method_id = method.id.clone(); - cx.listener(move |this, _, window, cx| { - this.authenticate(method_id.clone(), window, cx) + connection.auth_methods().into_iter().map(|method| { + Button::new(SharedString::from(method.id.0.clone()), method.label) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) }) - }) }), )), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),