This commit is contained in:
Agus Zubiaga 2025-07-30 13:30:50 -03:00
parent 738296345e
commit 27708143ec
6 changed files with 48 additions and 88 deletions

View file

@ -1595,7 +1595,6 @@ mod tests {
connection, connection,
child_status: io_task, child_status: io_task,
current_thread: thread_rc, current_thread: thread_rc,
agent_state: Default::default(),
}; };
AcpThread::new( AcpThread::new(

View file

@ -1,4 +1,4 @@
use std::{cell::Ref, path::Path, rc::Rc}; use std::{error::Error, fmt, path::Path, rc::Rc};
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;
@ -16,7 +16,7 @@ pub trait AgentConnection {
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>; ) -> Task<Result<Entity<AcpThread>>>;
fn state(&self) -> Ref<'_, acp::AgentState>; fn auth_methods(&self) -> Vec<acp::AuthMethod>;
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
@ -24,3 +24,13 @@ pub trait AgentConnection {
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
} }
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
}
}

View file

@ -5,17 +5,11 @@ use anyhow::{Context as _, Result};
use futures::channel::oneshot; use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project; use project::Project;
use std::{ use std::{cell::RefCell, path::Path, rc::Rc};
cell::{Ref, RefCell},
error::Error,
fmt,
path::Path,
rc::Rc,
};
use ui::App; use ui::App;
use util::ResultExt as _; use util::ResultExt as _;
use crate::{AcpThread, AgentConnection}; use crate::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)] #[derive(Clone)]
pub struct OldAcpClientDelegate { pub struct OldAcpClientDelegate {
@ -357,21 +351,10 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
} }
} }
#[derive(Debug)]
pub struct Unauthenticated;
impl Error for Unauthenticated {}
impl fmt::Display for Unauthenticated {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unauthenticated")
}
}
pub struct OldAcpAgentConnection { pub struct OldAcpAgentConnection {
pub name: &'static str, pub name: &'static str,
pub connection: acp_old::AgentConnection, pub connection: acp_old::AgentConnection,
pub child_status: Task<Result<()>>, pub child_status: Task<Result<()>>,
pub agent_state: Rc<RefCell<acp::AgentState>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>, pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
} }
@ -394,7 +377,7 @@ impl AgentConnection for OldAcpAgentConnection {
let result = acp_old::InitializeParams::response_from_any(result)?; let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated { if !result.is_authenticated {
anyhow::bail!(Unauthenticated) anyhow::bail!(AuthRequired)
} }
cx.update(|cx| { cx.update(|cx| {
@ -408,8 +391,12 @@ impl AgentConnection for OldAcpAgentConnection {
}) })
} }
fn state(&self) -> Ref<'_, acp::AgentState> { fn auth_methods(&self) -> Vec<acp::AuthMethod> {
self.agent_state.borrow() vec![acp::AuthMethod {
id: acp::AuthMethodId("acp-old-no-id".into()),
label: "Log in".into(),
description: None,
}]
} }
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {

View file

@ -7,24 +7,23 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot}; use futures::channel::{mpsc, oneshot};
use project::Project; use project::Project;
use smol::stream::StreamExt as _; use smol::stream::StreamExt as _;
use std::cell::{Ref, RefCell}; use std::cell::RefCell;
use std::rc::Rc; use std::rc::Rc;
use std::{path::Path, sync::Arc}; use std::{path::Path, sync::Arc};
use util::{ResultExt, TryFutureExt}; use util::ResultExt;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::mcp_server::ZedMcpServer; use crate::mcp_server::ZedMcpServer;
use crate::{AgentServerCommand, mcp_server}; use crate::{AgentServerCommand, mcp_server};
use acp_thread::{AcpThread, AgentConnection}; use acp_thread::{AcpThread, AgentConnection, AuthRequired};
pub struct AcpConnection { pub struct AcpConnection {
agent_state: Rc<RefCell<acp::AgentState>>, auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
server_name: &'static str, server_name: &'static str,
client: Arc<context_server::ContextServer>, client: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>, sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
_agent_state_task: Task<()>,
_session_update_task: Task<()>, _session_update_task: Task<()>,
} }
@ -47,24 +46,8 @@ impl AcpConnection {
.into(); .into();
ContextServer::start(client.clone(), cx).await?; ContextServer::start(client.clone(), cx).await?;
let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
let mcp_client = client.client().context("Failed to subscribe")?; let mcp_client = client.client().context("Failed to subscribe")?;
mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
move |notification, _cx| {
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(state) =
serde_json::from_value::<acp::AgentState>(notification).log_err()
{
state_tx.send(state).log_err();
}
}
});
let (notification_tx, mut notification_rx) = mpsc::unbounded(); let (notification_tx, mut notification_rx) = mpsc::unbounded();
mcp_client.on_notification(acp::AGENT_METHODS.session_update, { mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
move |notification, _cx| { move |notification, _cx| {
@ -83,17 +66,6 @@ impl AcpConnection {
}); });
let sessions = Rc::new(RefCell::new(HashMap::default())); let sessions = Rc::new(RefCell::new(HashMap::default()));
let initial_state = state_rx.recv().await?;
let agent_state = Rc::new(RefCell::new(initial_state));
let agent_state_task = cx.foreground_executor().spawn({
let agent_state = agent_state.clone();
async move {
while let Some(state) = state_rx.recv().log_err().await {
agent_state.replace(state);
}
}
});
let session_update_handler_task = cx.spawn({ let session_update_handler_task = cx.spawn({
let sessions = sessions.clone(); let sessions = sessions.clone();
@ -105,11 +77,10 @@ impl AcpConnection {
}); });
Ok(Self { Ok(Self {
auth_methods: Default::default(),
server_name, server_name,
client, client,
sessions, sessions,
agent_state,
_agent_state_task: agent_state_task,
_session_update_task: session_update_handler_task, _session_update_task: session_update_handler_task,
}) })
} }
@ -154,6 +125,7 @@ impl AgentConnection for AcpConnection {
) -> Task<Result<Entity<AcpThread>>> { ) -> Task<Result<Entity<AcpThread>>> {
let client = self.client.client(); let client = self.client.client();
let sessions = self.sessions.clone(); let sessions = self.sessions.clone();
let auth_methods = self.auth_methods.clone();
let cwd = cwd.to_path_buf(); let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let client = client.context("MCP server is not initialized yet")?; let client = client.context("MCP server is not initialized yet")?;
@ -194,12 +166,18 @@ impl AgentConnection for AcpConnection {
response.structured_content.context("Empty response")?, response.structured_content.context("Empty response")?,
)?; )?;
auth_methods.replace(result.auth_methods);
let Some(session_id) = result.session_id else {
anyhow::bail!(AuthRequired);
};
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
AcpThread::new( AcpThread::new(
self.server_name, self.server_name,
self.clone(), self.clone(),
project, project,
result.session_id.clone(), session_id.clone(),
cx, cx,
) )
})?; })?;
@ -211,14 +189,14 @@ impl AgentConnection for AcpConnection {
cancel_tx: None, cancel_tx: None,
_mcp_server: mcp_server, _mcp_server: mcp_server,
}; };
sessions.borrow_mut().insert(result.session_id, session); sessions.borrow_mut().insert(session_id, session);
Ok(thread) Ok(thread)
}) })
} }
fn state(&self) -> Ref<'_, acp::AgentState> { fn auth_methods(&self) -> Vec<agent_client_protocol::AuthMethod> {
self.agent_state.borrow() self.auth_methods.borrow().clone()
} }
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {

View file

@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use smol::process::Child; use smol::process::Child;
use std::cell::{Ref, RefCell}; use std::cell::RefCell;
use std::fmt::Display; use std::fmt::Display;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
@ -58,7 +58,6 @@ impl AgentServer for ClaudeCode {
_cx: &mut App, _cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> { ) -> Task<Result<Rc<dyn AgentConnection>>> {
let connection = ClaudeAgentConnection { let connection = ClaudeAgentConnection {
agent_state: Default::default(),
sessions: Default::default(), sessions: Default::default(),
}; };
@ -67,7 +66,6 @@ impl AgentServer for ClaudeCode {
} }
struct ClaudeAgentConnection { struct ClaudeAgentConnection {
agent_state: Rc<RefCell<acp::AgentState>>,
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>, sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
} }
@ -185,8 +183,8 @@ impl AgentConnection for ClaudeAgentConnection {
}) })
} }
fn state(&self) -> Ref<'_, acp::AgentState> { fn auth_methods(&self) -> Vec<acp::AuthMethod> {
self.agent_state.borrow() vec![]
} }
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {

View file

@ -216,15 +216,6 @@ impl AcpThreadView {
} }
}; };
if connection.state().needs_authentication {
this.update(cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify();
})
.ok();
return;
}
let result = match connection let result = match connection
.clone() .clone()
.new_thread(project.clone(), &root_dir, cx) .new_thread(project.clone(), &root_dir, cx)
@ -233,7 +224,7 @@ impl AcpThreadView {
Err(e) => { Err(e) => {
let mut cx = cx.clone(); let mut cx = cx.clone();
// todo! remove duplication // todo! remove duplication
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() { if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection }; this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify(); cx.notify();
@ -2219,17 +2210,14 @@ impl Render for AcpThreadView {
.justify_center() .justify_center()
.child(self.render_pending_auth_state()) .child(self.render_pending_auth_state())
.child(h_flex().mt_1p5().justify_center().children( .child(h_flex().mt_1p5().justify_center().children(
connection.state().auth_methods.iter().map(|method| { connection.auth_methods().into_iter().map(|method| {
Button::new( Button::new(SharedString::from(method.id.0.clone()), method.label)
SharedString::from(method.id.0.clone()), .on_click({
method.label.clone(), let method_id = method.id.clone();
) cx.listener(move |this, _, window, cx| {
.on_click({ this.authenticate(method_id.clone(), window, cx)
let method_id = method.id.clone(); })
cx.listener(move |this, _, window, cx| {
this.authenticate(method_id.clone(), window, cx)
}) })
})
}), }),
)), )),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),