Auth WIP
This commit is contained in:
parent
254c6be42b
commit
6656403ce8
8 changed files with 140 additions and 56 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -139,8 +139,6 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
version = "0.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4255a06cc2414033d1fe4baf1968bcc8f16d7e5814f272b97779b5806d129142"
|
||||
dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
|
|
|
@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
|||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = "0.0.13"
|
||||
agent-client-protocol = {path="../agent-client-protocol"}
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
|
|
|
@ -958,10 +958,6 @@ impl AcpThread {
|
|||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.connection.authenticate(cx)
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn send_raw(
|
||||
&mut self,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::{path::Path, rc::Rc};
|
||||
use std::{cell::Ref, path::Path, rc::Rc};
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use project::Project;
|
||||
|
@ -16,7 +16,9 @@ pub trait AgentConnection {
|
|||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
|
||||
fn state(&self) -> Ref<'_, acp::AgentState>;
|
||||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
|
|
|
@ -5,7 +5,13 @@ use anyhow::{Context as _, Result};
|
|||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
|
||||
use std::{
|
||||
cell::{Ref, RefCell},
|
||||
error::Error,
|
||||
fmt,
|
||||
path::Path,
|
||||
rc::Rc,
|
||||
};
|
||||
use ui::App;
|
||||
|
||||
use crate::{AcpThread, AgentConnection};
|
||||
|
@ -364,6 +370,7 @@ pub struct OldAcpAgentConnection {
|
|||
pub name: &'static str,
|
||||
pub connection: acp_old::AgentConnection,
|
||||
pub child_status: Task<Result<()>>,
|
||||
pub agent_state: Rc<RefCell<acp::AgentState>>,
|
||||
}
|
||||
|
||||
impl AgentConnection for OldAcpAgentConnection {
|
||||
|
@ -397,7 +404,11 @@ impl AgentConnection for OldAcpAgentConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||
self.agent_state.borrow()
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
let task = self
|
||||
.connection
|
||||
.request_any(acp_old::AuthenticateParams.into_any());
|
||||
|
|
|
@ -7,10 +7,10 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
|||
use futures::channel::{mpsc, oneshot};
|
||||
use project::Project;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::cell::RefCell;
|
||||
use std::cell::{Ref, RefCell};
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
@ -20,10 +20,12 @@ use crate::{AgentServerCommand, mcp_server};
|
|||
use acp_thread::{AcpThread, AgentConnection};
|
||||
|
||||
pub struct AcpConnection {
|
||||
agent_state: Rc<RefCell<acp::AgentState>>,
|
||||
server_name: &'static str,
|
||||
client: Arc<context_server::ContextServer>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
_notification_handler_task: Task<()>,
|
||||
_agent_state_task: Task<()>,
|
||||
_session_update_task: Task<()>,
|
||||
}
|
||||
|
||||
impl AcpConnection {
|
||||
|
@ -43,29 +45,55 @@ impl AcpConnection {
|
|||
.into();
|
||||
ContextServer::start(client.clone(), cx).await?;
|
||||
|
||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||
client
|
||||
.client()
|
||||
.context("Failed to subscribe")?
|
||||
.on_notification(acp::AGENT_METHODS.session_update, {
|
||||
move |notification, _cx| {
|
||||
let notification_tx = notification_tx.clone();
|
||||
log::trace!(
|
||||
"ACP Notification: {}",
|
||||
serde_json::to_string_pretty(¬ification).unwrap()
|
||||
);
|
||||
let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
|
||||
let mcp_client = client.client().context("Failed to subscribe")?;
|
||||
|
||||
if let Some(notification) =
|
||||
serde_json::from_value::<acp::SessionNotification>(notification).log_err()
|
||||
{
|
||||
notification_tx.unbounded_send(notification).ok();
|
||||
}
|
||||
mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
|
||||
move |notification, _cx| {
|
||||
log::trace!(
|
||||
"ACP Notification: {}",
|
||||
serde_json::to_string_pretty(¬ification).unwrap()
|
||||
);
|
||||
|
||||
if let Some(state) =
|
||||
serde_json::from_value::<acp::AgentState>(notification).log_err()
|
||||
{
|
||||
state_tx.send(state).log_err();
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||
mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
|
||||
move |notification, _cx| {
|
||||
let notification_tx = notification_tx.clone();
|
||||
log::trace!(
|
||||
"ACP Notification: {}",
|
||||
serde_json::to_string_pretty(¬ification).unwrap()
|
||||
);
|
||||
|
||||
if let Some(notification) =
|
||||
serde_json::from_value::<acp::SessionNotification>(notification).log_err()
|
||||
{
|
||||
notification_tx.unbounded_send(notification).ok();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
let initial_state = state_rx.recv().await?;
|
||||
let agent_state = Rc::new(RefCell::new(initial_state));
|
||||
|
||||
let notification_handler_task = cx.spawn({
|
||||
let agent_state_task = cx.foreground_executor().spawn({
|
||||
let agent_state = agent_state.clone();
|
||||
async move {
|
||||
while let Some(state) = state_rx.recv().log_err().await {
|
||||
agent_state.replace(state);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let session_update_handler_task = cx.spawn({
|
||||
let sessions = sessions.clone();
|
||||
async move |cx| {
|
||||
while let Some(notification) = notification_rx.next().await {
|
||||
|
@ -78,7 +106,9 @@ impl AcpConnection {
|
|||
server_name,
|
||||
client,
|
||||
sessions,
|
||||
_notification_handler_task: notification_handler_task,
|
||||
agent_state,
|
||||
_agent_state_task: agent_state_task,
|
||||
_session_update_task: session_update_handler_task,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -185,8 +215,30 @@ impl AgentConnection for AcpConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||
self.agent_state.borrow()
|
||||
}
|
||||
|
||||
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
let client = self.client.client();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let params = acp::AuthenticateArguments { method_id };
|
||||
|
||||
let response = client
|
||||
.context("MCP server is not initialized yet")?
|
||||
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
||||
name: acp::AGENT_METHODS.authenticate.into(),
|
||||
arguments: Some(serde_json::to_value(params)?),
|
||||
meta: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
Err(anyhow!(response.text_contents()))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
|
|
|
@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
|
|||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
use std::cell::RefCell;
|
||||
use std::cell::{Ref, RefCell};
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
@ -58,6 +58,7 @@ impl AgentServer for ClaudeCode {
|
|||
_cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let connection = ClaudeAgentConnection {
|
||||
agent_state: Default::default(),
|
||||
sessions: Default::default(),
|
||||
};
|
||||
|
||||
|
@ -66,6 +67,7 @@ impl AgentServer for ClaudeCode {
|
|||
}
|
||||
|
||||
struct ClaudeAgentConnection {
|
||||
agent_state: Rc<RefCell<acp::AgentState>>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
|
||||
}
|
||||
|
||||
|
@ -183,7 +185,11 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||
self.agent_state.borrow()
|
||||
}
|
||||
|
||||
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
|
|
|
@ -216,6 +216,15 @@ impl AcpThreadView {
|
|||
}
|
||||
};
|
||||
|
||||
if connection.state().needs_authentication {
|
||||
this.update(cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
return;
|
||||
}
|
||||
|
||||
let result = match connection
|
||||
.clone()
|
||||
.new_thread(project.clone(), &root_dir, cx)
|
||||
|
@ -223,6 +232,7 @@ impl AcpThreadView {
|
|||
{
|
||||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
// todo! remove duplication
|
||||
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
|
@ -640,13 +650,18 @@ impl AcpThreadView {
|
|||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
fn authenticate(
|
||||
&mut self,
|
||||
method: acp::AuthMethodId,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.last_error.take();
|
||||
let authenticate = connection.authenticate(cx);
|
||||
let authenticate = connection.authenticate(method, cx);
|
||||
self.auth_task = Some(cx.spawn_in(window, {
|
||||
let project = self.project.clone();
|
||||
let agent = self.agent.clone();
|
||||
|
@ -2197,22 +2212,26 @@ impl Render for AcpThreadView {
|
|||
.on_action(cx.listener(Self::next_history_message))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.child(match &self.thread_state {
|
||||
ThreadState::Unauthenticated { .. } => {
|
||||
v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(
|
||||
h_flex().mt_1p5().justify_center().child(
|
||||
Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.authenticate(window, cx)
|
||||
})),
|
||||
),
|
||||
)
|
||||
}
|
||||
ThreadState::Unauthenticated { connection } => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(h_flex().mt_1p5().justify_center().children(
|
||||
connection.state().auth_methods.iter().map(|method| {
|
||||
Button::new(
|
||||
SharedString::from(method.id.0.clone()),
|
||||
method.label.clone(),
|
||||
)
|
||||
.on_click({
|
||||
let method_id = method.id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.authenticate(method_id.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
}),
|
||||
)),
|
||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||
ThreadState::LoadError(e) => v_flex()
|
||||
.p_2()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue