Auth WIP
This commit is contained in:
parent
254c6be42b
commit
6656403ce8
8 changed files with 140 additions and 56 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -139,8 +139,6 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "agent-client-protocol"
|
name = "agent-client-protocol"
|
||||||
version = "0.0.13"
|
version = "0.0.13"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "4255a06cc2414033d1fe4baf1968bcc8f16d7e5814f272b97779b5806d129142"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
|
|
|
@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||||
#
|
#
|
||||||
|
|
||||||
agentic-coding-protocol = "0.0.10"
|
agentic-coding-protocol = "0.0.10"
|
||||||
agent-client-protocol = "0.0.13"
|
agent-client-protocol = {path="../agent-client-protocol"}
|
||||||
aho-corasick = "1.1"
|
aho-corasick = "1.1"
|
||||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||||
any_vec = "0.14"
|
any_vec = "0.14"
|
||||||
|
|
|
@ -958,10 +958,6 @@ impl AcpThread {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
|
|
||||||
self.connection.authenticate(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
pub fn send_raw(
|
pub fn send_raw(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::{path::Path, rc::Rc};
|
use std::{cell::Ref, path::Path, rc::Rc};
|
||||||
|
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol::{self as acp};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::{AsyncApp, Entity, Task};
|
use gpui::{AsyncApp, Entity, Task};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
@ -16,7 +16,9 @@ pub trait AgentConnection {
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Task<Result<Entity<AcpThread>>>;
|
) -> Task<Result<Entity<AcpThread>>>;
|
||||||
|
|
||||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
|
fn state(&self) -> Ref<'_, acp::AgentState>;
|
||||||
|
|
||||||
|
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||||
|
|
||||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
|
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,13 @@ use anyhow::{Context as _, Result};
|
||||||
use futures::channel::oneshot;
|
use futures::channel::oneshot;
|
||||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
|
use std::{
|
||||||
|
cell::{Ref, RefCell},
|
||||||
|
error::Error,
|
||||||
|
fmt,
|
||||||
|
path::Path,
|
||||||
|
rc::Rc,
|
||||||
|
};
|
||||||
use ui::App;
|
use ui::App;
|
||||||
|
|
||||||
use crate::{AcpThread, AgentConnection};
|
use crate::{AcpThread, AgentConnection};
|
||||||
|
@ -364,6 +370,7 @@ pub struct OldAcpAgentConnection {
|
||||||
pub name: &'static str,
|
pub name: &'static str,
|
||||||
pub connection: acp_old::AgentConnection,
|
pub connection: acp_old::AgentConnection,
|
||||||
pub child_status: Task<Result<()>>,
|
pub child_status: Task<Result<()>>,
|
||||||
|
pub agent_state: Rc<RefCell<acp::AgentState>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentConnection for OldAcpAgentConnection {
|
impl AgentConnection for OldAcpAgentConnection {
|
||||||
|
@ -397,7 +404,11 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
|
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||||
|
self.agent_state.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||||
let task = self
|
let task = self
|
||||||
.connection
|
.connection
|
||||||
.request_any(acp_old::AuthenticateParams.into_any());
|
.request_any(acp_old::AuthenticateParams.into_any());
|
||||||
|
|
|
@ -7,10 +7,10 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||||
use futures::channel::{mpsc, oneshot};
|
use futures::channel::{mpsc, oneshot};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use smol::stream::StreamExt as _;
|
use smol::stream::StreamExt as _;
|
||||||
use std::cell::RefCell;
|
use std::cell::{Ref, RefCell};
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::{path::Path, sync::Arc};
|
use std::{path::Path, sync::Arc};
|
||||||
use util::ResultExt;
|
use util::{ResultExt, TryFutureExt};
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||||
|
@ -20,10 +20,12 @@ use crate::{AgentServerCommand, mcp_server};
|
||||||
use acp_thread::{AcpThread, AgentConnection};
|
use acp_thread::{AcpThread, AgentConnection};
|
||||||
|
|
||||||
pub struct AcpConnection {
|
pub struct AcpConnection {
|
||||||
|
agent_state: Rc<RefCell<acp::AgentState>>,
|
||||||
server_name: &'static str,
|
server_name: &'static str,
|
||||||
client: Arc<context_server::ContextServer>,
|
client: Arc<context_server::ContextServer>,
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
_notification_handler_task: Task<()>,
|
_agent_state_task: Task<()>,
|
||||||
|
_session_update_task: Task<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AcpConnection {
|
impl AcpConnection {
|
||||||
|
@ -43,29 +45,55 @@ impl AcpConnection {
|
||||||
.into();
|
.into();
|
||||||
ContextServer::start(client.clone(), cx).await?;
|
ContextServer::start(client.clone(), cx).await?;
|
||||||
|
|
||||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
|
||||||
client
|
let mcp_client = client.client().context("Failed to subscribe")?;
|
||||||
.client()
|
|
||||||
.context("Failed to subscribe")?
|
|
||||||
.on_notification(acp::AGENT_METHODS.session_update, {
|
|
||||||
move |notification, _cx| {
|
|
||||||
let notification_tx = notification_tx.clone();
|
|
||||||
log::trace!(
|
|
||||||
"ACP Notification: {}",
|
|
||||||
serde_json::to_string_pretty(¬ification).unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(notification) =
|
mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
|
||||||
serde_json::from_value::<acp::SessionNotification>(notification).log_err()
|
move |notification, _cx| {
|
||||||
{
|
log::trace!(
|
||||||
notification_tx.unbounded_send(notification).ok();
|
"ACP Notification: {}",
|
||||||
}
|
serde_json::to_string_pretty(¬ification).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(state) =
|
||||||
|
serde_json::from_value::<acp::AgentState>(notification).log_err()
|
||||||
|
{
|
||||||
|
state_tx.send(state).log_err();
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||||
|
mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
|
||||||
|
move |notification, _cx| {
|
||||||
|
let notification_tx = notification_tx.clone();
|
||||||
|
log::trace!(
|
||||||
|
"ACP Notification: {}",
|
||||||
|
serde_json::to_string_pretty(¬ification).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(notification) =
|
||||||
|
serde_json::from_value::<acp::SessionNotification>(notification).log_err()
|
||||||
|
{
|
||||||
|
notification_tx.unbounded_send(notification).ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||||
|
let initial_state = state_rx.recv().await?;
|
||||||
|
let agent_state = Rc::new(RefCell::new(initial_state));
|
||||||
|
|
||||||
let notification_handler_task = cx.spawn({
|
let agent_state_task = cx.foreground_executor().spawn({
|
||||||
|
let agent_state = agent_state.clone();
|
||||||
|
async move {
|
||||||
|
while let Some(state) = state_rx.recv().log_err().await {
|
||||||
|
agent_state.replace(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let session_update_handler_task = cx.spawn({
|
||||||
let sessions = sessions.clone();
|
let sessions = sessions.clone();
|
||||||
async move |cx| {
|
async move |cx| {
|
||||||
while let Some(notification) = notification_rx.next().await {
|
while let Some(notification) = notification_rx.next().await {
|
||||||
|
@ -78,7 +106,9 @@ impl AcpConnection {
|
||||||
server_name,
|
server_name,
|
||||||
client,
|
client,
|
||||||
sessions,
|
sessions,
|
||||||
_notification_handler_task: notification_handler_task,
|
agent_state,
|
||||||
|
_agent_state_task: agent_state_task,
|
||||||
|
_session_update_task: session_update_handler_task,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,8 +215,30 @@ impl AgentConnection for AcpConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
self.agent_state.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||||
|
let client = self.client.client();
|
||||||
|
cx.foreground_executor().spawn(async move {
|
||||||
|
let params = acp::AuthenticateArguments { method_id };
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.context("MCP server is not initialized yet")?
|
||||||
|
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
||||||
|
name: acp::AGENT_METHODS.authenticate.into(),
|
||||||
|
arguments: Some(serde_json::to_value(params)?),
|
||||||
|
meta: None,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if response.is_error.unwrap_or_default() {
|
||||||
|
Err(anyhow!(response.text_contents()))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
|
|
|
@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::process::Child;
|
use smol::process::Child;
|
||||||
use std::cell::RefCell;
|
use std::cell::{Ref, RefCell};
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
@ -58,6 +58,7 @@ impl AgentServer for ClaudeCode {
|
||||||
_cx: &mut App,
|
_cx: &mut App,
|
||||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||||
let connection = ClaudeAgentConnection {
|
let connection = ClaudeAgentConnection {
|
||||||
|
agent_state: Default::default(),
|
||||||
sessions: Default::default(),
|
sessions: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -66,6 +67,7 @@ impl AgentServer for ClaudeCode {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ClaudeAgentConnection {
|
struct ClaudeAgentConnection {
|
||||||
|
agent_state: Rc<RefCell<acp::AgentState>>,
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
|
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,7 +185,11 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
fn state(&self) -> Ref<'_, acp::AgentState> {
|
||||||
|
self.agent_state.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -216,6 +216,15 @@ impl AcpThreadView {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if connection.state().needs_authentication {
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let result = match connection
|
let result = match connection
|
||||||
.clone()
|
.clone()
|
||||||
.new_thread(project.clone(), &root_dir, cx)
|
.new_thread(project.clone(), &root_dir, cx)
|
||||||
|
@ -223,6 +232,7 @@ impl AcpThreadView {
|
||||||
{
|
{
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let mut cx = cx.clone();
|
let mut cx = cx.clone();
|
||||||
|
// todo! remove duplication
|
||||||
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||||
|
@ -640,13 +650,18 @@ impl AcpThreadView {
|
||||||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
fn authenticate(
|
||||||
|
&mut self,
|
||||||
|
method: acp::AuthMethodId,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
self.last_error.take();
|
self.last_error.take();
|
||||||
let authenticate = connection.authenticate(cx);
|
let authenticate = connection.authenticate(method, cx);
|
||||||
self.auth_task = Some(cx.spawn_in(window, {
|
self.auth_task = Some(cx.spawn_in(window, {
|
||||||
let project = self.project.clone();
|
let project = self.project.clone();
|
||||||
let agent = self.agent.clone();
|
let agent = self.agent.clone();
|
||||||
|
@ -2197,22 +2212,26 @@ impl Render for AcpThreadView {
|
||||||
.on_action(cx.listener(Self::next_history_message))
|
.on_action(cx.listener(Self::next_history_message))
|
||||||
.on_action(cx.listener(Self::open_agent_diff))
|
.on_action(cx.listener(Self::open_agent_diff))
|
||||||
.child(match &self.thread_state {
|
.child(match &self.thread_state {
|
||||||
ThreadState::Unauthenticated { .. } => {
|
ThreadState::Unauthenticated { connection } => v_flex()
|
||||||
v_flex()
|
.p_2()
|
||||||
.p_2()
|
.flex_1()
|
||||||
.flex_1()
|
.items_center()
|
||||||
.items_center()
|
.justify_center()
|
||||||
.justify_center()
|
.child(self.render_pending_auth_state())
|
||||||
.child(self.render_pending_auth_state())
|
.child(h_flex().mt_1p5().justify_center().children(
|
||||||
.child(
|
connection.state().auth_methods.iter().map(|method| {
|
||||||
h_flex().mt_1p5().justify_center().child(
|
Button::new(
|
||||||
Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
|
SharedString::from(method.id.0.clone()),
|
||||||
.on_click(cx.listener(|this, _, window, cx| {
|
method.label.clone(),
|
||||||
this.authenticate(window, cx)
|
)
|
||||||
})),
|
.on_click({
|
||||||
),
|
let method_id = method.id.clone();
|
||||||
)
|
cx.listener(move |this, _, window, cx| {
|
||||||
}
|
this.authenticate(method_id.clone(), window, cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
)),
|
||||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||||
ThreadState::LoadError(e) => v_flex()
|
ThreadState::LoadError(e) => v_flex()
|
||||||
.p_2()
|
.p_2()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue