Fix auth
This commit is contained in:
parent
738296345e
commit
27708143ec
6 changed files with 48 additions and 88 deletions
|
@ -1595,7 +1595,6 @@ mod tests {
|
|||
connection,
|
||||
child_status: io_task,
|
||||
current_thread: thread_rc,
|
||||
agent_state: Default::default(),
|
||||
};
|
||||
|
||||
AcpThread::new(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::{cell::Ref, path::Path, rc::Rc};
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc};
|
||||
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
|
@ -16,7 +16,7 @@ pub trait AgentConnection {
|
|||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn state(&self) -> Ref<'_, acp::AgentState>;
|
||||
fn auth_methods(&self) -> Vec<acp::AuthMethod>;
|
||||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
|
@ -24,3 +24,13 @@ pub trait AgentConnection {
|
|||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,17 +5,11 @@ use anyhow::{Context as _, Result};
|
|||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use std::{
|
||||
cell::{Ref, RefCell},
|
||||
error::Error,
|
||||
fmt,
|
||||
path::Path,
|
||||
rc::Rc,
|
||||
};
|
||||
use std::{cell::RefCell, path::Path, rc::Rc};
|
||||
use ui::App;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{AcpThread, AgentConnection};
|
||||
use crate::{AcpThread, AgentConnection, AuthRequired};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OldAcpClientDelegate {
|
||||
|
@ -357,21 +351,10 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Unauthenticated;
|
||||
|
||||
impl Error for Unauthenticated {}
|
||||
impl fmt::Display for Unauthenticated {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Unauthenticated")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OldAcpAgentConnection {
|
||||
pub name: &'static str,
|
||||
pub connection: acp_old::AgentConnection,
|
||||
pub child_status: Task<Result<()>>,
|
||||
pub agent_state: Rc<RefCell<acp::AgentState>>,
|
||||
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
||||
}
|
||||
|
||||
|
@ -394,7 +377,7 @@ impl AgentConnection for OldAcpAgentConnection {
|
|||
let result = acp_old::InitializeParams::response_from_any(result)?;
|
||||
|
||||
if !result.is_authenticated {
|
||||
anyhow::bail!(Unauthenticated)
|
||||
anyhow::bail!(AuthRequired)
|
||||
}
|
||||
|
||||
cx.update(|cx| {
|
||||
|
@ -408,8 +391,12 @@ impl AgentConnection for OldAcpAgentConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||
self.agent_state.borrow()
|
||||
fn auth_methods(&self) -> Vec<acp::AuthMethod> {
|
||||
vec![acp::AuthMethod {
|
||||
id: acp::AuthMethodId("acp-old-no-id".into()),
|
||||
label: "Log in".into(),
|
||||
description: None,
|
||||
}]
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
|
|
|
@ -7,24 +7,23 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
|||
use futures::channel::{mpsc, oneshot};
|
||||
use project::Project;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::cell::{Ref, RefCell};
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
use util::ResultExt;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
||||
use crate::mcp_server::ZedMcpServer;
|
||||
use crate::{AgentServerCommand, mcp_server};
|
||||
use acp_thread::{AcpThread, AgentConnection};
|
||||
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
||||
|
||||
pub struct AcpConnection {
|
||||
agent_state: Rc<RefCell<acp::AgentState>>,
|
||||
auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
|
||||
server_name: &'static str,
|
||||
client: Arc<context_server::ContextServer>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
_agent_state_task: Task<()>,
|
||||
_session_update_task: Task<()>,
|
||||
}
|
||||
|
||||
|
@ -47,24 +46,8 @@ impl AcpConnection {
|
|||
.into();
|
||||
ContextServer::start(client.clone(), cx).await?;
|
||||
|
||||
let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
|
||||
let mcp_client = client.client().context("Failed to subscribe")?;
|
||||
|
||||
mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
|
||||
move |notification, _cx| {
|
||||
log::trace!(
|
||||
"ACP Notification: {}",
|
||||
serde_json::to_string_pretty(¬ification).unwrap()
|
||||
);
|
||||
|
||||
if let Some(state) =
|
||||
serde_json::from_value::<acp::AgentState>(notification).log_err()
|
||||
{
|
||||
state_tx.send(state).log_err();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||
mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
|
||||
move |notification, _cx| {
|
||||
|
@ -83,17 +66,6 @@ impl AcpConnection {
|
|||
});
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
let initial_state = state_rx.recv().await?;
|
||||
let agent_state = Rc::new(RefCell::new(initial_state));
|
||||
|
||||
let agent_state_task = cx.foreground_executor().spawn({
|
||||
let agent_state = agent_state.clone();
|
||||
async move {
|
||||
while let Some(state) = state_rx.recv().log_err().await {
|
||||
agent_state.replace(state);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let session_update_handler_task = cx.spawn({
|
||||
let sessions = sessions.clone();
|
||||
|
@ -105,11 +77,10 @@ impl AcpConnection {
|
|||
});
|
||||
|
||||
Ok(Self {
|
||||
auth_methods: Default::default(),
|
||||
server_name,
|
||||
client,
|
||||
sessions,
|
||||
agent_state,
|
||||
_agent_state_task: agent_state_task,
|
||||
_session_update_task: session_update_handler_task,
|
||||
})
|
||||
}
|
||||
|
@ -154,6 +125,7 @@ impl AgentConnection for AcpConnection {
|
|||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
let auth_methods = self.auth_methods.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
|
@ -194,12 +166,18 @@ impl AgentConnection for AcpConnection {
|
|||
response.structured_content.context("Empty response")?,
|
||||
)?;
|
||||
|
||||
auth_methods.replace(result.auth_methods);
|
||||
|
||||
let Some(session_id) = result.session_id else {
|
||||
anyhow::bail!(AuthRequired);
|
||||
};
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
self.server_name,
|
||||
self.clone(),
|
||||
project,
|
||||
result.session_id.clone(),
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
@ -211,14 +189,14 @@ impl AgentConnection for AcpConnection {
|
|||
cancel_tx: None,
|
||||
_mcp_server: mcp_server,
|
||||
};
|
||||
sessions.borrow_mut().insert(result.session_id, session);
|
||||
sessions.borrow_mut().insert(session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||
self.agent_state.borrow()
|
||||
fn auth_methods(&self) -> Vec<agent_client_protocol::AuthMethod> {
|
||||
self.auth_methods.borrow().clone()
|
||||
}
|
||||
|
||||
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
|
|
|
@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
|
|||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
use std::cell::{Ref, RefCell};
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
@ -58,7 +58,6 @@ impl AgentServer for ClaudeCode {
|
|||
_cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let connection = ClaudeAgentConnection {
|
||||
agent_state: Default::default(),
|
||||
sessions: Default::default(),
|
||||
};
|
||||
|
||||
|
@ -67,7 +66,6 @@ impl AgentServer for ClaudeCode {
|
|||
}
|
||||
|
||||
struct ClaudeAgentConnection {
|
||||
agent_state: Rc<RefCell<acp::AgentState>>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
|
||||
}
|
||||
|
||||
|
@ -185,8 +183,8 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||
self.agent_state.borrow()
|
||||
fn auth_methods(&self) -> Vec<acp::AuthMethod> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
|
|
|
@ -216,15 +216,6 @@ impl AcpThreadView {
|
|||
}
|
||||
};
|
||||
|
||||
if connection.state().needs_authentication {
|
||||
this.update(cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
return;
|
||||
}
|
||||
|
||||
let result = match connection
|
||||
.clone()
|
||||
.new_thread(project.clone(), &root_dir, cx)
|
||||
|
@ -233,7 +224,7 @@ impl AcpThreadView {
|
|||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
// todo! remove duplication
|
||||
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
||||
if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
|
@ -2219,17 +2210,14 @@ impl Render for AcpThreadView {
|
|||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(h_flex().mt_1p5().justify_center().children(
|
||||
connection.state().auth_methods.iter().map(|method| {
|
||||
Button::new(
|
||||
SharedString::from(method.id.0.clone()),
|
||||
method.label.clone(),
|
||||
)
|
||||
.on_click({
|
||||
let method_id = method.id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.authenticate(method_id.clone(), window, cx)
|
||||
connection.auth_methods().into_iter().map(|method| {
|
||||
Button::new(SharedString::from(method.id.0.clone()), method.label)
|
||||
.on_click({
|
||||
let method_id = method.id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.authenticate(method_id.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
})
|
||||
}),
|
||||
)),
|
||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue