This commit is contained in:
Agus Zubiaga 2025-07-30 13:30:50 -03:00
parent 738296345e
commit 27708143ec
6 changed files with 48 additions and 88 deletions

View file

@ -1595,7 +1595,6 @@ mod tests {
connection,
child_status: io_task,
current_thread: thread_rc,
agent_state: Default::default(),
};
AcpThread::new(

View file

@ -1,4 +1,4 @@
use std::{cell::Ref, path::Path, rc::Rc};
use std::{error::Error, fmt, path::Path, rc::Rc};
use agent_client_protocol::{self as acp};
use anyhow::Result;
@ -16,7 +16,7 @@ pub trait AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>;
fn state(&self) -> Ref<'_, acp::AgentState>;
fn auth_methods(&self) -> Vec<acp::AuthMethod>;
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
@ -24,3 +24,13 @@ pub trait AgentConnection {
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
}
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
}
}

View file

@ -5,17 +5,11 @@ use anyhow::{Context as _, Result};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use std::{
cell::{Ref, RefCell},
error::Error,
fmt,
path::Path,
rc::Rc,
};
use std::{cell::RefCell, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
use crate::{AcpThread, AgentConnection};
use crate::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)]
pub struct OldAcpClientDelegate {
@ -357,21 +351,10 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
}
}
#[derive(Debug)]
pub struct Unauthenticated;
impl Error for Unauthenticated {}
impl fmt::Display for Unauthenticated {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unauthenticated")
}
}
pub struct OldAcpAgentConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
pub child_status: Task<Result<()>>,
pub agent_state: Rc<RefCell<acp::AgentState>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
}
@ -394,7 +377,7 @@ impl AgentConnection for OldAcpAgentConnection {
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
anyhow::bail!(Unauthenticated)
anyhow::bail!(AuthRequired)
}
cx.update(|cx| {
@ -408,8 +391,12 @@ impl AgentConnection for OldAcpAgentConnection {
})
}
fn state(&self) -> Ref<'_, acp::AgentState> {
self.agent_state.borrow()
fn auth_methods(&self) -> Vec<acp::AuthMethod> {
vec![acp::AuthMethod {
id: acp::AuthMethodId("acp-old-no-id".into()),
label: "Log in".into(),
description: None,
}]
}
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {

View file

@ -7,24 +7,23 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project;
use smol::stream::StreamExt as _;
use std::cell::{Ref, RefCell};
use std::cell::RefCell;
use std::rc::Rc;
use std::{path::Path, sync::Arc};
use util::{ResultExt, TryFutureExt};
use util::ResultExt;
use anyhow::{Context, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::mcp_server::ZedMcpServer;
use crate::{AgentServerCommand, mcp_server};
use acp_thread::{AcpThread, AgentConnection};
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
pub struct AcpConnection {
agent_state: Rc<RefCell<acp::AgentState>>,
auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
server_name: &'static str,
client: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
_agent_state_task: Task<()>,
_session_update_task: Task<()>,
}
@ -47,24 +46,8 @@ impl AcpConnection {
.into();
ContextServer::start(client.clone(), cx).await?;
let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
let mcp_client = client.client().context("Failed to subscribe")?;
mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
move |notification, _cx| {
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(state) =
serde_json::from_value::<acp::AgentState>(notification).log_err()
{
state_tx.send(state).log_err();
}
}
});
let (notification_tx, mut notification_rx) = mpsc::unbounded();
mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
move |notification, _cx| {
@ -83,17 +66,6 @@ impl AcpConnection {
});
let sessions = Rc::new(RefCell::new(HashMap::default()));
let initial_state = state_rx.recv().await?;
let agent_state = Rc::new(RefCell::new(initial_state));
let agent_state_task = cx.foreground_executor().spawn({
let agent_state = agent_state.clone();
async move {
while let Some(state) = state_rx.recv().log_err().await {
agent_state.replace(state);
}
}
});
let session_update_handler_task = cx.spawn({
let sessions = sessions.clone();
@ -105,11 +77,10 @@ impl AcpConnection {
});
Ok(Self {
auth_methods: Default::default(),
server_name,
client,
sessions,
agent_state,
_agent_state_task: agent_state_task,
_session_update_task: session_update_handler_task,
})
}
@ -154,6 +125,7 @@ impl AgentConnection for AcpConnection {
) -> Task<Result<Entity<AcpThread>>> {
let client = self.client.client();
let sessions = self.sessions.clone();
let auth_methods = self.auth_methods.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let client = client.context("MCP server is not initialized yet")?;
@ -194,12 +166,18 @@ impl AgentConnection for AcpConnection {
response.structured_content.context("Empty response")?,
)?;
auth_methods.replace(result.auth_methods);
let Some(session_id) = result.session_id else {
anyhow::bail!(AuthRequired);
};
let thread = cx.new(|cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
result.session_id.clone(),
session_id.clone(),
cx,
)
})?;
@ -211,14 +189,14 @@ impl AgentConnection for AcpConnection {
cancel_tx: None,
_mcp_server: mcp_server,
};
sessions.borrow_mut().insert(result.session_id, session);
sessions.borrow_mut().insert(session_id, session);
Ok(thread)
})
}
fn state(&self) -> Ref<'_, acp::AgentState> {
self.agent_state.borrow()
fn auth_methods(&self) -> Vec<agent_client_protocol::AuthMethod> {
self.auth_methods.borrow().clone()
}
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {

View file

@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
use std::cell::{Ref, RefCell};
use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
use std::rc::Rc;
@ -58,7 +58,6 @@ impl AgentServer for ClaudeCode {
_cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let connection = ClaudeAgentConnection {
agent_state: Default::default(),
sessions: Default::default(),
};
@ -67,7 +66,6 @@ impl AgentServer for ClaudeCode {
}
struct ClaudeAgentConnection {
agent_state: Rc<RefCell<acp::AgentState>>,
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
@ -185,8 +183,8 @@ impl AgentConnection for ClaudeAgentConnection {
})
}
fn state(&self) -> Ref<'_, acp::AgentState> {
self.agent_state.borrow()
fn auth_methods(&self) -> Vec<acp::AuthMethod> {
vec![]
}
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {

View file

@ -216,15 +216,6 @@ impl AcpThreadView {
}
};
if connection.state().needs_authentication {
this.update(cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify();
})
.ok();
return;
}
let result = match connection
.clone()
.new_thread(project.clone(), &root_dir, cx)
@ -233,7 +224,7 @@ impl AcpThreadView {
Err(e) => {
let mut cx = cx.clone();
// todo! remove duplication
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify();
@ -2219,17 +2210,14 @@ impl Render for AcpThreadView {
.justify_center()
.child(self.render_pending_auth_state())
.child(h_flex().mt_1p5().justify_center().children(
connection.state().auth_methods.iter().map(|method| {
Button::new(
SharedString::from(method.id.0.clone()),
method.label.clone(),
)
.on_click({
let method_id = method.id.clone();
cx.listener(move |this, _, window, cx| {
this.authenticate(method_id.clone(), window, cx)
connection.auth_methods().into_iter().map(|method| {
Button::new(SharedString::from(method.id.0.clone()), method.label)
.on_click({
let method_id = method.id.clone();
cx.listener(move |this, _, window, cx| {
this.authenticate(method_id.clone(), window, cx)
})
})
})
}),
)),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),