diff --git a/Cargo.lock b/Cargo.lock index 64470b5abe..6daa01dca3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,6 +19,7 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", "project", "serde", @@ -137,15 +138,57 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b" +version = "0.0.14" dependencies = [ + "anyhow", + "futures 0.3.31", + "log", + "parking_lot", "schemars", "serde", "serde_json", ] +[[package]] +name = "agent2" +version = "0.1.0" +dependencies = [ + "acp_thread", + "agent-client-protocol", + "agent_servers", + "anyhow", + "assistant_tool", + "assistant_tools", + "chrono", + "client", + "cloud_llm_client", + "collections", + "ctor", + "env_logger 0.11.8", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "handlebars 4.5.0", + "language_model", + "language_models", + "log", + "parking_lot", + "project", + "reqwest_client", + "rust-embed", + "schemars", + "serde", + "serde_json", + "settings", + "smol", + "thiserror 2.0.12", + "ui", + "util", + "uuid", + "worktree", +] + [[package]] name = "agent_servers" version = "0.1.0" @@ -209,6 +252,7 @@ dependencies = [ "acp_thread", "agent", "agent-client-protocol", + "agent2", "agent_servers", "agent_settings", "ai_onboarding", @@ -9570,9 +9614,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -11286,9 +11330,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -11296,9 +11340,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", diff --git a/Cargo.toml b/Cargo.toml index 5b97596d0c..d01ae9f683 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/acp_thread", "crates/activity_indicator", "crates/agent", + "crates/agent2", "crates/agent_servers", "crates/agent_settings", "crates/agent_ui", @@ -228,6 +229,7 @@ edition = "2024" acp_thread = { path = "crates/acp_thread" } agent = { path = "crates/agent" } +agent2 = { path = "crates/agent2" } activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } @@ -421,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.11" +agent-client-protocol = {path="../agent-client-protocol"} aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 011f26f364..308756b038 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -26,6 +26,7 @@ futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true project.workspace = true serde.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 7a10f3bd72..c42c155e89 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -391,7 +391,7 @@ impl ToolCallContent { cx: &mut App, ) -> Self { match content { - acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock { + acp::ToolCallContent::Content { content } => Self::ContentBlock { content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { @@ -619,6 +619,7 @@ impl Error for LoadError {} impl AcpThread { pub fn new( + title: impl Into, connection: Rc, project: Entity, session_id: acp::SessionId, @@ -631,7 +632,7 @@ impl AcpThread { shared_buffers: Default::default(), entries: Default::default(), plan: Default::default(), - title: connection.name().into(), + title: title.into(), project, send_task: None, connection, @@ -655,6 +656,10 @@ impl AcpThread { &self.entries } + pub fn session_id(&self) -> &acp::SessionId { + &self.session_id + } + pub fn status(&self) -> ThreadStatus { if self.send_task.is_some() { if self.waiting_for_tool_confirmation() { @@ -697,14 +702,14 @@ impl AcpThread { cx: &mut Context, ) -> Result<()> { match update { - acp::SessionUpdate::UserMessage(content_block) => { - self.push_user_content_block(content_block, cx); + acp::SessionUpdate::UserMessageChunk { content } => { + self.push_user_content_block(content, cx); } - acp::SessionUpdate::AgentMessageChunk(content_block) => { - self.push_assistant_content_block(content_block, false, cx); + acp::SessionUpdate::AgentMessageChunk { content } => { + self.push_assistant_content_block(content, false, cx); } - acp::SessionUpdate::AgentThoughtChunk(content_block) => { - self.push_assistant_content_block(content_block, true, cx); + acp::SessionUpdate::AgentThoughtChunk { content } => { + self.push_assistant_content_block(content, true, cx); } acp::SessionUpdate::ToolCall(tool_call) => { self.upsert_tool_call(tool_call, cx); @@ -973,10 +978,6 @@ impl AcpThread { cx.notify(); } - pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future> { - self.connection.authenticate(cx) - } - #[cfg(any(test, feature = "test-support"))] pub fn send_raw( &mut self, @@ -1018,7 +1019,7 @@ impl AcpThread { let result = this .update(cx, |this, cx| { this.connection.prompt( - acp::PromptArguments { + acp::PromptRequest { prompt: message, session_id: this.session_id.clone(), }, @@ -1620,9 +1621,15 @@ mod tests { connection, child_status: io_task, current_thread: thread_rc, + auth_methods: [acp::AuthMethod { + id: acp::AuthMethodId("acp-old-no-id".into()), + label: "Log in".into(), + description: None, + }], }; AcpThread::new( + "test", Rc::new(connection), project, acp::SessionId("test".into()), diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 5b25b71863..9659433edb 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,16 +1,62 @@ -use std::{path::Path, rc::Rc}; +use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp}; use anyhow::Result; use gpui::{AsyncApp, Entity, Task}; +use language_model::LanguageModel; use project::Project; use ui::App; use crate::AcpThread; -pub trait AgentConnection { - fn name(&self) -> &'static str; +/// Trait for agents that support listing, selecting, and querying language models. +/// +/// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. +pub trait ModelSelector: 'static { + /// Lists all available language models for this agent. + /// + /// # Parameters + /// - `cx`: The GPUI app context for async operations and global access. + /// + /// # Returns + /// A task resolving to the list of models or an error (e.g., if no models are configured). + fn list_models(&self, cx: &mut AsyncApp) -> Task>>>; + /// Selects a model for a specific session (thread). + /// + /// This sets the default model for future interactions in the session. + /// If the session doesn't exist or the model is invalid, it returns an error. + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to apply the model to. + /// - `model`: The model to select (should be one from [list_models]). + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to `Ok(())` on success or an error. + fn select_model( + &self, + session_id: acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task>; + + /// Retrieves the currently selected model for a specific session (thread). + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to query. + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to the selected model (always set) or an error (e.g., session not found). + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>>; +} + +pub trait AgentConnection { fn new_thread( self: Rc, project: Entity, @@ -18,9 +64,29 @@ pub trait AgentConnection { cx: &mut AsyncApp, ) -> Task>>; - fn authenticate(&self, cx: &mut App) -> Task>; + fn auth_methods(&self) -> &[acp::AuthMethod]; - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task>; + fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; + + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task>; fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + + /// Returns this agent as an [Rc] if the model selection capability is supported. + /// + /// If the agent does not support model selection, returns [None]. + /// This allows sharing the selector in UI components. + fn model_selector(&self) -> Option> { + None // Default impl for agents that don't support it + } +} + +#[derive(Debug)] +pub struct AuthRequired; + +impl Error for AuthRequired {} +impl fmt::Display for AuthRequired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthRequired") + } } diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs index 571023239f..adb27e21c4 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/acp_thread/src/old_acp_support.rs @@ -5,11 +5,11 @@ use anyhow::{Context as _, Result}; use futures::channel::oneshot; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; -use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; +use std::{cell::RefCell, path::Path, rc::Rc}; use ui::App; use util::ResultExt as _; -use crate::{AcpThread, AgentConnection}; +use crate::{AcpThread, AgentConnection, AuthRequired}; #[derive(Clone)] pub struct OldAcpClientDelegate { @@ -351,28 +351,15 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu } } -#[derive(Debug)] -pub struct Unauthenticated; - -impl Error for Unauthenticated {} -impl fmt::Display for Unauthenticated { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Unauthenticated") - } -} - pub struct OldAcpAgentConnection { pub name: &'static str, pub connection: acp_old::AgentConnection, pub child_status: Task>, pub current_thread: Rc>>, + pub auth_methods: [acp::AuthMethod; 1], } impl AgentConnection for OldAcpAgentConnection { - fn name(&self) -> &'static str { - self.name - } - fn new_thread( self: Rc, project: Entity, @@ -391,13 +378,13 @@ impl AgentConnection for OldAcpAgentConnection { let result = acp_old::InitializeParams::response_from_any(result)?; if !result.is_authenticated { - anyhow::bail!(Unauthenticated) + anyhow::bail!(AuthRequired) } cx.update(|cx| { let thread = cx.new(|cx| { let session_id = acp::SessionId("acp-old-no-id".into()); - AcpThread::new(self.clone(), project, session_id, cx) + AcpThread::new("Gemini", self.clone(), project, session_id, cx) }); current_thread.replace(thread.downgrade()); thread @@ -405,7 +392,11 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn authenticate(&self, cx: &mut App) -> Task> { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods + } + + fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { let task = self .connection .request_any(acp_old::AuthenticateParams.into_any()); @@ -415,7 +406,7 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let chunks = params .prompt .into_iter() diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml new file mode 100644 index 0000000000..b2e83e2252 --- /dev/null +++ b/crates/agent2/Cargo.toml @@ -0,0 +1,56 @@ +[package] +name = "agent2" +version = "0.1.0" +edition = "2021" +license = "GPL-3.0-or-later" +publish = false + +[lib] +path = "src/agent2.rs" + +[lints] +workspace = true + +[dependencies] +acp_thread.workspace = true +agent-client-protocol.workspace = true +agent_servers.workspace = true +anyhow.workspace = true +assistant_tool.workspace = true +assistant_tools.workspace = true +chrono.workspace = true +cloud_llm_client.workspace = true +collections.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +handlebars = { workspace = true, features = ["rust-embed"] } +language_model.workspace = true +language_models.workspace = true +log.workspace = true +parking_lot.workspace = true +project.workspace = true +rust-embed.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +thiserror.workspace = true +ui.workspace = true +util.workspace = true +uuid.workspace = true +worktree.workspace = true + +[dev-dependencies] +ctor.workspace = true +client = { workspace = true, "features" = ["test-support"] } +env_logger.workspace = true +fs = { workspace = true, "features" = ["test-support"] } +gpui = { workspace = true, "features" = ["test-support"] } +gpui_tokio.workspace = true +language_model = { workspace = true, "features" = ["test-support"] } +project = { workspace = true, "features" = ["test-support"] } +reqwest_client.workspace = true +settings = { workspace = true, "features" = ["test-support"] } +worktree = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/LICENSE-GPL b/crates/agent2/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/agent2/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs new file mode 100644 index 0000000000..0491f7c81a --- /dev/null +++ b/crates/agent2/src/agent.rs @@ -0,0 +1,377 @@ +use acp_thread::ModelSelector; +use agent_client_protocol as acp; +use anyhow::{anyhow, Result}; +use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use language_model::{LanguageModel, LanguageModelRegistry}; +use project::Project; +use std::collections::HashMap; +use std::path::Path; +use std::rc::Rc; +use std::sync::Arc; + +use crate::{templates::Templates, Thread}; + +/// Holds both the internal Thread and the AcpThread for a session +#[derive(Clone)] +struct Session { + /// The internal thread that processes messages + thread: Entity, + /// The ACP thread that handles protocol communication + acp_thread: Entity, +} + +pub struct NativeAgent { + /// Session ID -> Session mapping + sessions: HashMap, + /// Shared templates for all threads + templates: Arc, +} + +impl NativeAgent { + pub fn new(templates: Arc) -> Self { + log::info!("Creating new NativeAgent"); + Self { + sessions: HashMap::new(), + templates, + } + } +} + +/// Wrapper struct that implements the AgentConnection trait +#[derive(Clone)] +pub struct NativeAgentConnection(pub Entity); + +impl ModelSelector for NativeAgentConnection { + fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { + log::debug!("NativeAgentConnection::list_models called"); + cx.spawn(async move |cx| { + cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let models = registry.available_models(cx).collect::>(); + log::info!("Found {} available models", models.len()); + if models.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(models) + } + })? + }) + } + + fn select_model( + &self, + session_id: acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task> { + log::info!( + "Setting model for session {}: {:?}", + session_id, + model.name() + ); + let agent = self.0.clone(); + + cx.spawn(async move |cx| { + agent.update(cx, |agent, cx| { + if let Some(session) = agent.sessions.get(&session_id) { + session.thread.update(cx, |thread, _cx| { + thread.selected_model = model; + }); + Ok(()) + } else { + Err(anyhow!("Session not found")) + } + })? + }) + } + + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + let session_id = session_id.clone(); + cx.spawn(async move |cx| { + let session = agent + .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? + .ok_or_else(|| anyhow::anyhow!("Session not found"))?; + let selected = session + .thread + .read_with(cx, |thread, _| thread.selected_model.clone())?; + Ok(selected) + }) + } +} + +impl acp_thread::AgentConnection for NativeAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + log::info!("Creating new thread for project at: {:?}", cwd); + + cx.spawn(async move |cx| { + log::debug!("Starting thread creation in async context"); + // Create Thread + let (session_id, thread) = agent.update( + cx, + |agent, cx: &mut gpui::Context| -> Result<_> { + // Fetch default model from registry settings + let registry = LanguageModelRegistry::read_global(cx); + + // Log available models for debugging + let available_count = registry.available_models(cx).count(); + log::debug!("Total available models: {}", available_count); + + let default_model = registry + .default_model() + .map(|configured| { + log::info!( + "Using configured default model: {:?} from provider: {:?}", + configured.model.name(), + configured.provider.name() + ); + configured.model + }) + .ok_or_else(|| { + log::warn!("No default model configured in settings"); + anyhow!("No default model configured. Please configure a default model in settings.") + })?; + + let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model)); + + // Generate session ID + let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + log::info!("Created session with ID: {}", session_id); + Ok((session_id, thread)) + }, + )??; + + // Create AcpThread + let acp_thread = cx.update(|cx| { + cx.new(|cx| { + acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx) + }) + })?; + + // Store the session + agent.update(cx, |agent, _cx| { + agent.sessions.insert( + session_id, + Session { + thread, + acp_thread: acp_thread.clone(), + }, + ); + })?; + + Ok(acp_thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] // No auth for in-process + } + + fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } + + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + let session_id = params.session_id.clone(); + let agent = self.0.clone(); + log::info!("Received prompt request for session: {}", session_id); + log::debug!("Prompt blocks count: {}", params.prompt.len()); + + cx.spawn(async move |cx| { + // Get session + let session = agent + .read_with(cx, |agent, _| { + agent.sessions.get(&session_id).map(|s| Session { + thread: s.thread.clone(), + acp_thread: s.acp_thread.clone(), + }) + })? + .ok_or_else(|| { + log::error!("Session not found: {}", session_id); + anyhow::anyhow!("Session not found") + })?; + log::debug!("Found session for: {}", session_id); + + // Convert prompt to message + let message = convert_prompt_to_message(params.prompt); + log::info!("Converted prompt to message: {} chars", message.len()); + log::debug!("Message content: {}", message); + + // Get model using the ModelSelector capability (always available for agent2) + // Get the selected model from the thread directly + let model = session + .thread + .read_with(cx, |thread, _| thread.selected_model.clone())?; + + // Send to thread + log::info!("Sending message to thread with model: {:?}", model.name()); + let response_stream = session + .thread + .update(cx, |thread, cx| thread.send(model, message, cx))?; + + // Handle response stream and forward to session.acp_thread + let acp_thread = session.acp_thread.clone(); + cx.spawn(async move |cx| { + use futures::StreamExt; + use language_model::LanguageModelCompletionEvent; + + let mut response_stream = response_stream; + + while let Some(result) = response_stream.next().await { + match result { + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + + match event { + LanguageModelCompletionEvent::Text(text) => { + // Send text chunk as agent message + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text( + acp::TextContent { + text: text.into(), + annotations: None, + }, + ), + }, + cx, + ) + })??; + } + LanguageModelCompletionEvent::ToolUse(tool_use) => { + // Convert LanguageModelToolUse to ACP ToolCall + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + label: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(tool_use.input), + }), + cx, + ) + })??; + } + LanguageModelCompletionEvent::StartMessage { .. } => { + log::debug!("Started new assistant message"); + } + LanguageModelCompletionEvent::UsageUpdate(usage) => { + log::debug!("Token usage update: {:?}", usage); + } + LanguageModelCompletionEvent::Thinking { text, .. } => { + // Send thinking text as agent thought chunk + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: acp::ContentBlock::Text( + acp::TextContent { + text: text.into(), + annotations: None, + }, + ), + }, + cx, + ) + })??; + } + LanguageModelCompletionEvent::StatusUpdate(status) => { + log::trace!("Status update: {:?}", status); + } + LanguageModelCompletionEvent::Stop(stop_reason) => { + log::debug!("Assistant message complete: {:?}", stop_reason); + } + LanguageModelCompletionEvent::RedactedThinking { .. } => { + log::trace!("Redacted thinking event"); + } + LanguageModelCompletionEvent::ToolUseJsonParseError { + id, + tool_name, + raw_input, + json_parse_error, + } => { + log::error!( + "Tool use JSON parse error for tool '{}' (id: {}): {} - input: {}", + tool_name, + id, + json_parse_error, + raw_input + ); + } + } + } + Err(e) => { + log::error!("Error in model response stream: {:?}", e); + // TODO: Consider sending an error message to the UI + break; + } + } + } + + log::info!("Response stream completed"); + anyhow::Ok(()) + }) + .detach(); + + log::info!("Successfully sent prompt to thread and started response handler"); + Ok(()) + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + log::info!("Cancelling session: {}", session_id); + self.0.update(cx, |agent, _cx| { + agent.sessions.remove(session_id); + }); + } +} + +/// Convert ACP content blocks to a message string +fn convert_prompt_to_message(blocks: Vec) -> String { + log::debug!("Converting {} content blocks to message", blocks.len()); + let mut message = String::new(); + + for block in blocks { + match block { + acp::ContentBlock::Text(text) => { + log::trace!("Processing text block: {} chars", text.text.len()); + message.push_str(&text.text); + } + acp::ContentBlock::ResourceLink(link) => { + log::trace!("Processing resource link: {}", link.uri); + message.push_str(&format!(" @{} ", link.uri)); + } + acp::ContentBlock::Image(_) => { + log::trace!("Processing image block"); + message.push_str(" [image] "); + } + acp::ContentBlock::Audio(_) => { + log::trace!("Processing audio block"); + message.push_str(" [audio] "); + } + acp::ContentBlock::Resource(resource) => { + log::trace!("Processing resource block: {:?}", resource.resource); + message.push_str(&format!(" [resource: {:?}] ", resource.resource)); + } + } + } + + message +} diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs new file mode 100644 index 0000000000..aa665fe313 --- /dev/null +++ b/crates/agent2/src/agent2.rs @@ -0,0 +1,13 @@ +mod agent; +mod native_agent_server; +mod prompts; +mod templates; +mod thread; +mod tools; + +#[cfg(test)] +mod tests; + +pub use agent::*; +pub use native_agent_server::NativeAgentServer; +pub use thread::*; diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs new file mode 100644 index 0000000000..aafe70a8a2 --- /dev/null +++ b/crates/agent2/src/native_agent_server.rs @@ -0,0 +1,58 @@ +use std::path::Path; +use std::rc::Rc; + +use agent_servers::AgentServer; +use anyhow::Result; +use gpui::{App, AppContext, Entity, Task}; +use project::Project; + +use crate::{templates::Templates, NativeAgent, NativeAgentConnection}; + +#[derive(Clone)] +pub struct NativeAgentServer; + +impl AgentServer for NativeAgentServer { + fn name(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_headline(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_message(&self) -> &'static str { + "How can I help you today?" + } + + fn logo(&self) -> ui::IconName { + // Using the ZedAssistant icon as it's the native built-in agent + ui::IconName::ZedAssistant + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + cx: &mut App, + ) -> Task>> { + log::info!( + "NativeAgentServer::connect called for path: {:?}", + _root_dir + ); + cx.spawn(async move |cx| { + log::debug!("Creating templates for native agent"); + // Create templates (you might want to load these from files or resources) + let templates = Templates::new(); + + // Create the native agent + log::debug!("Creating native agent entity"); + let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?; + + // Create the connection wrapper + let connection = NativeAgentConnection(agent); + log::info!("NativeAgentServer connection established successfully"); + + Ok(Rc::new(connection) as Rc) + }) + } +} diff --git a/crates/agent2/src/prompts.rs b/crates/agent2/src/prompts.rs new file mode 100644 index 0000000000..015f56f4db --- /dev/null +++ b/crates/agent2/src/prompts.rs @@ -0,0 +1,30 @@ +use crate::{ + templates::{BaseTemplate, Template, Templates, WorktreeData}, + thread::Prompt, +}; +use anyhow::Result; +use gpui::{App, Entity}; +use project::Project; + +#[allow(dead_code)] +struct BasePrompt { + project: Entity, +} + +impl Prompt for BasePrompt { + fn render(&self, templates: &Templates, cx: &App) -> Result { + BaseTemplate { + os: std::env::consts::OS.to_string(), + shell: util::get_system_shell(), + worktrees: self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| WorktreeData { + root_name: worktree.read(cx).root_name().to_string(), + }) + .collect(), + } + .render(templates) + } +} diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs new file mode 100644 index 0000000000..04569369be --- /dev/null +++ b/crates/agent2/src/templates.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use anyhow::Result; +use handlebars::Handlebars; +use rust_embed::RustEmbed; +use serde::Serialize; + +#[derive(RustEmbed)] +#[folder = "src/templates"] +#[include = "*.hbs"] +struct Assets; + +pub struct Templates(Handlebars<'static>); + +impl Templates { + pub fn new() -> Arc { + let mut handlebars = Handlebars::new(); + handlebars.register_embed_templates::().unwrap(); + Arc::new(Self(handlebars)) + } +} + +pub trait Template: Sized { + const TEMPLATE_NAME: &'static str; + + fn render(&self, templates: &Templates) -> Result + where + Self: Serialize + Sized, + { + Ok(templates.0.render(Self::TEMPLATE_NAME, self)?) + } +} + +#[derive(Serialize)] +pub struct BaseTemplate { + pub os: String, + pub shell: String, + pub worktrees: Vec, +} + +impl Template for BaseTemplate { + const TEMPLATE_NAME: &'static str = "base.hbs"; +} + +#[derive(Serialize)] +pub struct WorktreeData { + pub root_name: String, +} + +#[derive(Serialize)] +pub struct GlobTemplate { + pub project_roots: String, +} + +impl Template for GlobTemplate { + const TEMPLATE_NAME: &'static str = "glob.hbs"; +} diff --git a/crates/agent2/src/templates/base.hbs b/crates/agent2/src/templates/base.hbs new file mode 100644 index 0000000000..7eef231e32 --- /dev/null +++ b/crates/agent2/src/templates/base.hbs @@ -0,0 +1,56 @@ +You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. + +## Communication + +1. Be conversational but professional. +2. Refer to the USER in the second person and yourself in the first person. +3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. +4. NEVER lie or make things up. +5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. + +## Tool Use + +1. Make sure to adhere to the tools schema. +2. Provide every required argument. +3. DO NOT use tools to access items that are already available in the context section. +4. Use only the tools that are currently available. +5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off. + +## Searching and Reading + +If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions. + +If appropriate, use tool calls to explore the current project, which contains the following root directories: + +{{#each worktrees}} +- `{{root_name}}` +{{/each}} + +- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above. +- When looking for symbols in the project, prefer the `grep` tool. +- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project. +- Bias towards not asking the user for help if you can find the answer yourself. + +## Fixing Diagnostics + +1. Make 1-2 attempts at fixing diagnostics, then defer to the user. +2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem. + +## Debugging + +When debugging, only make code changes if you are certain that you can solve the problem. +Otherwise, follow debugging best practices: +1. Address the root cause instead of the symptoms. +2. Add descriptive logging statements and error messages to track variable and code state. +3. Add test functions and statements to isolate the problem. + +## Calling External APIs + +1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission. +2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data. +3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed) + +## System Information + +Operating System: {{os}} +Default Shell: {{shell}} diff --git a/crates/agent2/src/templates/glob.hbs b/crates/agent2/src/templates/glob.hbs new file mode 100644 index 0000000000..3bf992b093 --- /dev/null +++ b/crates/agent2/src/templates/glob.hbs @@ -0,0 +1,8 @@ +Find paths on disk with glob patterns. + +Assume that all glob patterns are matched in a project directory with the following entries. + +{{project_roots}} + +When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be +sure to anchor them with one of the directories listed above. diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs new file mode 100644 index 0000000000..ced3c239f0 --- /dev/null +++ b/crates/agent2/src/tests/mod.rs @@ -0,0 +1,378 @@ +use super::*; +use crate::templates::Templates; +use acp_thread::AgentConnection as _; +use agent_client_protocol as acp; +use client::{Client, UserStore}; +use gpui::{AppContext, Entity, TestAppContext}; +use language_model::{ + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelRegistry, MessageContent, StopReason, +}; +use project::Project; +use reqwest_client::ReqwestClient; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use smol::stream::StreamExt; +use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; + +mod test_tools; +use test_tools::*; + +#[gpui::test] +async fn test_echo(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx).await; + + let events = thread + .update(cx, |thread, cx| { + thread.send(model.clone(), "Testing: Reply with 'Hello'", cx) + }) + .collect() + .await; + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.messages().last().unwrap().content, + vec![MessageContent::Text("Hello".to_string())] + ); + }); + assert_eq!(stop_events(events), vec![StopReason::EndTurn]); +} + +#[gpui::test] +async fn test_basic_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx).await; + + // Test a tool call that's likely to complete *before* streaming stops. + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send( + model.clone(), + "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", + cx, + ) + }) + .collect() + .await; + assert_eq!( + stop_events(events), + vec![StopReason::ToolUse, StopReason::EndTurn] + ); + + // Test a tool calls that's likely to complete *after* streaming stops. + let events = thread + .update(cx, |thread, cx| { + thread.remove_tool(&AgentTool::name(&EchoTool)); + thread.add_tool(DelayTool); + thread.send( + model.clone(), + "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", + cx, + ) + }) + .collect() + .await; + assert_eq!( + stop_events(events), + vec![StopReason::ToolUse, StopReason::EndTurn] + ); + thread.update(cx, |thread, _cx| { + assert!(thread + .messages() + .last() + .unwrap() + .content + .iter() + .any(|content| { + if let MessageContent::Text(text) = content { + text.contains("Ding") + } else { + false + } + })); + }); +} + +#[gpui::test] +async fn test_streaming_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx).await; + + // Test a tool call that's likely to complete *before* streaming stops. + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(WordListTool); + thread.send(model.clone(), "Test the word_list tool.", cx) + }); + + let mut saw_partial_tool_use = false; + while let Some(event) = events.next().await { + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event { + thread.update(cx, |thread, _cx| { + // Look for a tool use in the thread's last message + let last_content = thread.messages().last().unwrap().content.last().unwrap(); + if let MessageContent::ToolUse(last_tool_use) = last_content { + assert_eq!(last_tool_use.name.as_ref(), "word_list"); + if tool_use_event.is_input_complete { + last_tool_use + .input + .get("a") + .expect("'a' has streamed because input is now complete"); + last_tool_use + .input + .get("g") + .expect("'g' has streamed because input is now complete"); + } else { + if !last_tool_use.is_input_complete + && last_tool_use.input.get("g").is_none() + { + saw_partial_tool_use = true; + } + } + } else { + panic!("last content should be a tool use"); + } + }); + } + } + + assert!( + saw_partial_tool_use, + "should see at least one partially streamed tool use in the history" + ); +} + +#[gpui::test] +async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx).await; + + // Test concurrent tool calls with different delay times + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(DelayTool); + thread.send( + model.clone(), + "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", + cx, + ) + }) + .collect() + .await; + + let stop_reasons = stop_events(events); + if stop_reasons.len() == 2 { + assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]); + } else if stop_reasons.len() == 3 { + assert_eq!( + stop_reasons, + vec![ + StopReason::ToolUse, + StopReason::ToolUse, + StopReason::EndTurn + ] + ); + } else { + panic!("Expected either 1 or 2 tool uses followed by end turn"); + } + + thread.update(cx, |thread, _cx| { + let last_message = thread.messages().last().unwrap(); + let text = last_message + .content + .iter() + .filter_map(|content| { + if let MessageContent::Text(text) = content { + Some(text.as_str()) + } else { + None + } + }) + .collect::(); + + assert!(text.contains("Ding")); + }); +} + +#[gpui::test] +async fn test_agent_connection(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + cx.update(settings::init); + let templates = Templates::new(); + + // Initialize language model system with test provider + cx.update(|cx| { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + // Initialize project settings + Project::init_settings(cx); + + // Use test registry with fake provider + LanguageModelRegistry::test(cx); + }); + + // Create agent and connection + let agent = cx.new(|_| NativeAgent::new(templates.clone())); + let connection = NativeAgentConnection(agent.clone()); + + // Test model_selector returns Some + let selector_opt = connection.model_selector(); + assert!( + selector_opt.is_some(), + "agent2 should always support ModelSelector" + ); + let selector = selector_opt.unwrap(); + + // Test list_models + let listed_models = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.list_models(&mut async_cx) + }) + .await + .expect("list_models should succeed"); + assert!(!listed_models.is_empty(), "should have at least one model"); + assert_eq!(listed_models[0].id().0, "fake"); + + // Create a project for new_thread + let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); + let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + + // Create a thread using new_thread + let cwd = Path::new("/test"); + let connection_rc = Rc::new(connection.clone()); + let acp_thread = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + connection_rc.new_thread(project, cwd, &mut async_cx) + }) + .await + .expect("new_thread should succeed"); + + // Get the session_id from the AcpThread + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + + // Test selected_model returns the default + let selected = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await + .expect("selected_model should succeed"); + assert_eq!(selected.id().0, "fake", "should return default model"); + + // The thread was created via prompt with the default model + // We can verify it through selected_model + + // Test prompt uses the selected model + let prompt_request = acp::PromptRequest { + session_id: session_id.clone(), + prompt: vec![acp::ContentBlock::Text(acp::TextContent { + text: "Test prompt".into(), + annotations: None, + })], + }; + + cx.update(|cx| connection.prompt(prompt_request, cx)) + .await + .expect("prompt should succeed"); + + // The prompt was sent successfully + + // Test cancel + cx.update(|cx| connection.cancel(&session_id, cx)); + + // After cancel, selected_model should fail + let result = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await; + assert!(result.is_err(), "selected_model should fail after cancel"); + + // Test error case: invalid session + let invalid_session = acp::SessionId("invalid".into()); + let result = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&invalid_session, &mut async_cx) + }) + .await; + assert!(result.is_err(), "should fail for invalid session"); + if let Err(e) = result { + assert!( + e.to_string().contains("Session not found"), + "should have correct error message" + ); + } +} + +/// Filters out the stop events for asserting against in tests +fn stop_events( + result_events: Vec>, +) -> Vec { + result_events + .into_iter() + .filter_map(|event| match event.unwrap() { + LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason), + _ => None, + }) + .collect() +} + +struct ThreadTest { + model: Arc, + thread: Entity, +} + +async fn setup(cx: &mut TestAppContext) -> ThreadTest { + cx.executor().allow_parking(); + cx.update(settings::init); + let templates = Templates::new(); + + let model = cx + .update(|cx| { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + let models = LanguageModelRegistry::read_global(cx); + let model = models + .available_models(cx) + .find(|model| model.id().0 == "claude-3-7-sonnet-latest") + .unwrap(); + + let provider = models.provider(&model.provider_id()).unwrap(); + let authenticated = provider.authenticate(cx); + + cx.spawn(async move |_cx| { + authenticated.await.unwrap(); + model + }) + }) + .await; + + let thread = cx.new(|_| Thread::new(templates, model.clone())); + + ThreadTest { model, thread } +} + +#[cfg(test)] +#[ctor::ctor] +fn init_logger() { + if std::env::var("RUST_LOG").is_ok() { + env_logger::init(); + } +} diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs new file mode 100644 index 0000000000..43d0414499 --- /dev/null +++ b/crates/agent2/src/tests/test_tools.rs @@ -0,0 +1,85 @@ +use super::*; +use anyhow::Result; +use gpui::{App, SharedString, Task}; + +/// A tool that echoes its input +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct EchoToolInput { + /// The text to echo. + text: String, +} + +pub struct EchoTool; + +impl AgentTool for EchoTool { + type Input = EchoToolInput; + + fn name(&self) -> SharedString { + "echo".into() + } + + fn run(self: Arc, input: Self::Input, _cx: &mut App) -> Task> { + Task::ready(Ok(input.text)) + } +} + +/// A tool that waits for a specified delay +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct DelayToolInput { + /// The delay in milliseconds. + ms: u64, +} + +pub struct DelayTool; + +impl AgentTool for DelayTool { + type Input = DelayToolInput; + + fn name(&self) -> SharedString { + "delay".into() + } + + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> + where + Self: Sized, + { + cx.foreground_executor().spawn(async move { + smol::Timer::after(Duration::from_millis(input.ms)).await; + Ok("Ding".to_string()) + }) + } +} + +/// A tool that takes an object with map from letters to random words starting with that letter. +/// All fiealds are required! Pass a word for every letter! +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct WordListInput { + /// Provide a random word that starts with A. + a: Option, + /// Provide a random word that starts with B. + b: Option, + /// Provide a random word that starts with C. + c: Option, + /// Provide a random word that starts with D. + d: Option, + /// Provide a random word that starts with E. + e: Option, + /// Provide a random word that starts with F. + f: Option, + /// Provide a random word that starts with G. + g: Option, +} + +pub struct WordListTool; + +impl AgentTool for WordListTool { + type Input = WordListInput; + + fn name(&self) -> SharedString { + "word_list".into() + } + + fn run(self: Arc, _input: Self::Input, _cx: &mut App) -> Task> { + Task::ready(Ok("ok".to_string())) + } +} diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs new file mode 100644 index 0000000000..4bc2932aeb --- /dev/null +++ b/crates/agent2/src/thread.rs @@ -0,0 +1,493 @@ +use crate::templates::Templates; +use anyhow::{anyhow, Result}; +use cloud_llm_client::{CompletionIntent, CompletionMode}; +use futures::{channel::mpsc, future}; +use gpui::{App, Context, SharedString, Task}; +use language_model::{ + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, + LanguageModelToolUse, MessageContent, Role, StopReason, +}; +use log; +use schemars::{JsonSchema, Schema}; +use serde::Deserialize; +use smol::stream::StreamExt; +use std::{collections::BTreeMap, sync::Arc}; +use util::ResultExt; + +#[derive(Debug)] +pub struct AgentMessage { + pub role: Role, + pub content: Vec, +} + +pub type AgentResponseEvent = LanguageModelCompletionEvent; + +pub trait Prompt { + fn render(&self, prompts: &Templates, cx: &App) -> Result; +} + +pub struct Thread { + messages: Vec, + completion_mode: CompletionMode, + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + running_turn: Option>, + system_prompts: Vec>, + tools: BTreeMap>, + templates: Arc, + pub selected_model: Arc, + // action_log: Entity, +} + +impl Thread { + pub fn new(templates: Arc, default_model: Arc) -> Self { + Self { + messages: Vec::new(), + completion_mode: CompletionMode::Normal, + system_prompts: Vec::new(), + running_turn: None, + tools: BTreeMap::default(), + templates, + selected_model: default_model, + } + } + + pub fn set_mode(&mut self, mode: CompletionMode) { + self.completion_mode = mode; + } + + pub fn messages(&self) -> &[AgentMessage] { + &self.messages + } + + pub fn add_tool(&mut self, tool: impl AgentTool) { + self.tools.insert(tool.name(), tool.erase()); + } + + pub fn remove_tool(&mut self, name: &str) -> bool { + self.tools.remove(name).is_some() + } + + /// Sending a message results in the model streaming a response, which could include tool calls. + /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. + /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. + pub fn send( + &mut self, + model: Arc, + content: impl Into, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let content = content.into(); + log::info!("Thread::send called with model: {:?}", model.name()); + log::debug!("Thread::send content: {:?}", content); + + cx.notify(); + let (events_tx, events_rx) = + mpsc::unbounded::>(); + + let system_message = self.build_system_message(cx); + log::debug!( + "System messages count: {}", + if system_message.is_some() { 1 } else { 0 } + ); + self.messages.extend(system_message); + + self.messages.push(AgentMessage { + role: Role::User, + content: vec![content], + }); + log::info!("Total messages in thread: {}", self.messages.len()); + self.running_turn = Some(cx.spawn(async move |thread, cx| { + log::info!("Starting agent turn execution"); + let turn_result = async { + // Perform one request, then keep looping if the model makes tool calls. + let mut completion_intent = CompletionIntent::UserPrompt; + loop { + log::debug!( + "Building completion request with intent: {:?}", + completion_intent + ); + let request = thread.update(cx, |thread, cx| { + thread.build_completion_request(completion_intent, cx) + })?; + + // println!( + // "request: {}", + // serde_json::to_string_pretty(&request).unwrap() + // ); + + // Stream events, appending to messages and collecting up tool uses. + log::info!("Calling model.stream_completion"); + let mut events = model.stream_completion(request, cx).await?; + log::debug!("Stream completion started successfully"); + let mut tool_uses = Vec::new(); + while let Some(event) = events.next().await { + match event { + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + thread + .update(cx, |thread, cx| { + tool_uses.extend(thread.handle_streamed_completion_event( + event, + events_tx.clone(), + cx, + )); + }) + .ok(); + } + Err(error) => { + log::error!("Error in completion stream: {:?}", error); + events_tx.unbounded_send(Err(error)).ok(); + break; + } + } + } + + // If there are no tool uses, the turn is done. + if tool_uses.is_empty() { + log::info!("No tool uses found, completing turn"); + break; + } + log::info!("Found {} tool uses to execute", tool_uses.len()); + + // If there are tool uses, wait for their results to be + // computed, then send them together in a single message on + // the next loop iteration. + let tool_results = future::join_all(tool_uses).await; + log::debug!("Tool execution completed, {} results", tool_results.len()); + thread + .update(cx, |thread, _cx| { + thread.messages.push(AgentMessage { + role: Role::User, + content: tool_results + .into_iter() + .map(MessageContent::ToolResult) + .collect(), + }); + }) + .ok(); + completion_intent = CompletionIntent::ToolResults; + } + + Ok(()) + } + .await; + + if let Err(error) = turn_result { + log::error!("Turn execution failed: {:?}", error); + events_tx.unbounded_send(Err(error)).ok(); + } else { + log::info!("Turn execution completed successfully"); + } + })); + events_rx + } + + pub fn build_system_message(&mut self, cx: &App) -> Option { + log::debug!("Building system message"); + let mut system_message = AgentMessage { + role: Role::System, + content: Vec::new(), + }; + + for prompt in &self.system_prompts { + if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() { + system_message + .content + .push(MessageContent::Text(rendered_prompt)); + } + } + + let result = (!system_message.content.is_empty()).then_some(system_message); + log::debug!("System message built: {}", result.is_some()); + result + } + + /// A helper method that's called on every streamed completion event. + /// Returns an optional tool result task, which the main agentic loop in + /// send will send back to the model when it resolves. + fn handle_streamed_completion_event( + &mut self, + event: LanguageModelCompletionEvent, + events_tx: mpsc::UnboundedSender>, + cx: &mut Context, + ) -> Option> { + log::trace!("Handling streamed completion event: {:?}", event); + use LanguageModelCompletionEvent::*; + events_tx.unbounded_send(Ok(event.clone())).ok(); + + match event { + Text(new_text) => self.handle_text_event(new_text, cx), + Thinking { + text: _text, + signature: _signature, + } => { + todo!() + } + ToolUse(tool_use) => { + return self.handle_tool_use_event(tool_use, cx); + } + StartMessage { .. } => { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + UsageUpdate(_) => {} + Stop(stop_reason) => self.handle_stop_event(stop_reason), + StatusUpdate(_completion_request_status) => {} + RedactedThinking { data: _data } => todo!(), + ToolUseJsonParseError { + id: _id, + tool_name: _tool_name, + raw_input: _raw_input, + json_parse_error: _json_parse_error, + } => todo!(), + } + + None + } + + fn handle_stop_event(&mut self, stop_reason: StopReason) { + match stop_reason { + StopReason::EndTurn | StopReason::ToolUse => {} + StopReason::MaxTokens => todo!(), + StopReason::Refusal => todo!(), + } + } + + fn handle_text_event(&mut self, new_text: String, cx: &mut Context) { + let last_message = self.last_assistant_message(); + if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + text.push_str(&new_text); + } else { + last_message.content.push(MessageContent::Text(new_text)); + } + + cx.notify(); + } + + fn handle_tool_use_event( + &mut self, + tool_use: LanguageModelToolUse, + cx: &mut Context, + ) -> Option> { + cx.notify(); + + let last_message = self.last_assistant_message(); + + // Ensure the last message ends in the current tool use + let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { + if let MessageContent::ToolUse(last_tool_use) = content { + if last_tool_use.id == tool_use.id { + *last_tool_use = tool_use.clone(); + false + } else { + true + } + } else { + true + } + }); + if push_new_tool_use { + last_message + .content + .push(MessageContent::ToolUse(tool_use.clone())); + } + + if !tool_use.is_input_complete { + return None; + } + + if let Some(tool) = self.tools.get(tool_use.name.as_ref()) { + let pending_tool_result = tool.clone().run(tool_use.input, cx); + + Some(cx.foreground_executor().spawn(async move { + match pending_tool_result.await { + Ok(tool_output) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: false, + content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), + output: None, + }, + Err(error) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), + output: None, + }, + } + })) + } else { + Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(format!( + "No tool named {} exists", + tool_use.name + ))), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })) + } + } + + /// Guarantees the last message is from the assistant and returns a mutable reference. + fn last_assistant_message(&mut self) -> &mut AgentMessage { + if self + .messages + .last() + .map_or(true, |m| m.role != Role::Assistant) + { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + self.messages.last_mut().unwrap() + } + + fn build_completion_request( + &self, + completion_intent: CompletionIntent, + cx: &mut App, + ) -> LanguageModelRequest { + log::debug!("Building completion request"); + log::debug!("Completion intent: {:?}", completion_intent); + log::debug!("Completion mode: {:?}", self.completion_mode); + + let messages = self.build_request_messages(); + log::info!("Request will include {} messages", messages.len()); + + let tools: Vec = self + .tools + .values() + .filter_map(|tool| { + let tool_name = tool.name().to_string(); + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name, + description: tool.description(cx).to_string(), + input_schema: tool + .input_schema(LanguageModelToolSchemaFormat::JsonSchema) + .log_err()?, + }) + }) + .collect(); + + log::info!("Request includes {} tools", tools.len()); + + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: Some(completion_intent), + mode: Some(self.completion_mode), + messages, + tools, + tool_choice: None, + stop: Vec::new(), + temperature: None, + thinking_allowed: false, + }; + + log::debug!("Completion request built successfully"); + request + } + + fn build_request_messages(&self) -> Vec { + log::trace!( + "Building request messages from {} thread messages", + self.messages.len() + ); + let messages = self + .messages + .iter() + .map(|message| { + log::trace!( + " - {} message with {} content items", + match message.role { + Role::System => "System", + Role::User => "User", + Role::Assistant => "Assistant", + }, + message.content.len() + ); + LanguageModelRequestMessage { + role: message.role, + content: message.content.clone(), + cache: false, + } + }) + .collect(); + messages + } +} + +pub trait AgentTool +where + Self: 'static + Sized, +{ + type Input: for<'de> Deserialize<'de> + JsonSchema; + + fn name(&self) -> SharedString; + fn description(&self, _cx: &mut App) -> SharedString { + let schema = schemars::schema_for!(Self::Input); + SharedString::new( + schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap_or_default(), + ) + } + + /// Returns the JSON schema that describes the tool's input. + fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema { + schemars::schema_for!(Self::Input) + } + + /// Runs the tool with the provided input. + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task>; + + fn erase(self) -> Arc { + Arc::new(Erased(Arc::new(self))) + } +} + +pub struct Erased(T); + +pub trait AnyAgentTool { + fn name(&self) -> SharedString; + fn description(&self, cx: &mut App) -> SharedString; + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task>; +} + +impl AnyAgentTool for Erased> +where + T: AgentTool, +{ + fn name(&self) -> SharedString { + self.0.name() + } + + fn description(&self, cx: &mut App) -> SharedString { + self.0.description(cx) + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + Ok(serde_json::to_value(self.0.input_schema(format))?) + } + + fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task> { + let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); + match parsed_input { + Ok(input) => self.0.clone().run(input, cx), + Err(error) => Task::ready(Err(anyhow!(error))), + } + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs new file mode 100644 index 0000000000..cf3162abfa --- /dev/null +++ b/crates/agent2/src/tools.rs @@ -0,0 +1 @@ +mod glob; diff --git a/crates/agent2/src/tools/glob.rs b/crates/agent2/src/tools/glob.rs new file mode 100644 index 0000000000..9434311aaf --- /dev/null +++ b/crates/agent2/src/tools/glob.rs @@ -0,0 +1,76 @@ +use anyhow::{anyhow, Result}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::Deserialize; +use std::{path::PathBuf, sync::Arc}; +use util::paths::PathMatcher; +use worktree::Snapshot as WorktreeSnapshot; + +use crate::{ + templates::{GlobTemplate, Template, Templates}, + thread::AgentTool, +}; + +// Description is dynamic, see `fn description` below +#[derive(Deserialize, JsonSchema)] +struct GlobInput { + /// A POSIX glob pattern + glob: SharedString, +} + +struct GlobTool { + project: Entity, + templates: Arc, +} + +impl AgentTool for GlobTool { + type Input = GlobInput; + + fn name(&self) -> SharedString { + "glob".into() + } + + fn description(&self, cx: &mut App) -> SharedString { + let project_roots = self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).root_name().into()) + .collect::>() + .join("\n"); + + GlobTemplate { project_roots } + .render(&self.templates) + .expect("template failed to render") + .into() + } + + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> { + let path_matcher = match PathMatcher::new([&input.glob]) { + Ok(matcher) => matcher, + Err(error) => return Task::ready(Err(anyhow!(error))), + }; + + let snapshots: Vec = self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect(); + + cx.background_spawn(async move { + let paths = snapshots.iter().flat_map(|snapshot| { + let root_name = PathBuf::from(snapshot.root_name()); + snapshot + .entries(false, 0) + .map(move |entry| root_name.join(&entry.path)) + .filter(|path| path_matcher.is_match(&path)) + }); + let output = paths + .map(|path| format!("{}\n", path.display())) + .collect::(); + Ok(output) + }) + } +} diff --git a/crates/agent_servers/acp b/crates/agent_servers/acp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs new file mode 100644 index 0000000000..0ced22fc65 --- /dev/null +++ b/crates/agent_servers/src/acp_connection.rs @@ -0,0 +1,245 @@ +use agent_client_protocol::{self as acp, Agent as _}; +use collections::HashMap; +use futures::channel::oneshot; +use project::Project; +use std::cell::RefCell; +use std::path::Path; +use std::rc::Rc; + +use anyhow::{Context as _, Result}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; + +use crate::AgentServerCommand; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; + +pub struct AcpConnection { + server_name: &'static str, + connection: Rc, + sessions: Rc>>, + auth_methods: Vec, + _io_task: Task>, +} + +pub struct AcpSession { + thread: WeakEntity, +} + +impl AcpConnection { + pub async fn stdio( + server_name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, + ) -> Result { + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter().map(|arg| arg.as_str())) + .envs(command.env.iter().flatten()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdout = child.stdout.take().expect("Failed to take stdout"); + let stdin = child.stdin.take().expect("Failed to take stdin"); + + let sessions = Rc::new(RefCell::new(HashMap::default())); + + let client = ClientDelegate { + sessions: sessions.clone(), + cx: cx.clone(), + }; + let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { + let foreground_executor = cx.foreground_executor().clone(); + move |fut| { + foreground_executor.spawn(fut).detach(); + } + }); + + let io_task = cx.background_spawn(io_task); + + let response = connection + .initialize(acp::InitializeRequest { + protocol_version: acp::VERSION, + client_capabilities: acp::ClientCapabilities { + fs: acp::FileSystemCapability { + read_text_file: true, + write_text_file: true, + }, + }, + }) + .await?; + + // todo! check version + + Ok(Self { + auth_methods: response.auth_methods, + connection: connection.into(), + server_name, + sessions, + _io_task: io_task, + }) + } +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let conn = self.connection.clone(); + let sessions = self.sessions.clone(); + let cwd = cwd.to_path_buf(); + cx.spawn(async move |cx| { + let response = conn + .new_session(acp::NewSessionRequest { + // todo! Zed MCP server? + mcp_servers: vec![], + cwd, + }) + .await?; + + let Some(session_id) = response.session_id else { + anyhow::bail!(AuthRequired); + }; + + let thread = cx.new(|cx| { + AcpThread::new( + self.server_name, + self.clone(), + project, + session_id.clone(), + cx, + ) + })?; + + let session = AcpSession { + thread: thread.downgrade(), + }; + sessions.borrow_mut().insert(session_id, session); + + Ok(thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods + } + + fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor().spawn(async move { + let result = conn + .authenticate(acp::AuthenticateRequest { + method_id: method_id.clone(), + }) + .await?; + + Ok(result) + }) + } + + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor() + .spawn(async move { Ok(conn.prompt(params).await?) }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let conn = self.connection.clone(); + let params = acp::CancelledNotification { + session_id: session_id.clone(), + }; + cx.foreground_executor() + .spawn(async move { conn.cancelled(params).await }) + .detach(); + } +} + +struct ClientDelegate { + sessions: Rc>>, + cx: AsyncApp, +} + +impl acp::Client for ClientDelegate { + async fn request_permission( + &self, + arguments: acp::RequestPermissionRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let result = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx) + })? + .await; + + let outcome = match result { + Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, + Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, + }; + + Ok(acp::RequestPermissionResponse { outcome }) + } + + async fn write_text_file( + &self, + arguments: acp::WriteTextFileRequest, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + self.sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.write_text_file(arguments.path, arguments.content, cx) + })? + .await?; + + Ok(()) + } + + async fn read_text_file( + &self, + arguments: acp::ReadTextFileRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let content = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) + })? + .await?; + + Ok(acp::ReadTextFileResponse { content }) + } + + async fn session_notification( + &self, + notification: acp::SessionNotification, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let sessions = self.sessions.borrow(); + let session = sessions + .get(¬ification.session_id) + .context("Failed to get session")?; + + session.thread.update(cx, |thread, cx| { + thread.handle_session_update(notification.update, cx) + })??; + + Ok(()) + } +} diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 212bb74d8a..13bad53cd9 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,14 +1,12 @@ +mod acp_connection; mod claude; -mod codex; mod gemini; -mod mcp_server; mod settings; #[cfg(test)] mod e2e_tests; pub use claude::*; -pub use codex::*; pub use gemini::*; pub use settings::*; @@ -38,7 +36,6 @@ pub trait AgentServer: Send { fn connect( &self, - // these will go away when old_acp is fully removed root_dir: &Path, project: &Entity, cx: &mut App, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 6565786204..9040b83085 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -70,10 +70,6 @@ struct ClaudeAgentConnection { } impl AgentConnection for ClaudeAgentConnection { - fn name(&self) -> &'static str { - ClaudeCode.name() - } - fn new_thread( self: Rc, project: Entity, @@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection { } }); - let thread = - cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; + let thread = cx.new(|cx| { + AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) + })?; thread_tx.send(thread.downgrade())?; @@ -186,11 +183,15 @@ impl AgentConnection for ClaudeAgentConnection { }) } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { Task::ready(Err(anyhow!("Authentication not supported"))) } - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let sessions = self.sessions.borrow(); let Some(session) = sessions.get(¶ms.session_id) else { return Task::ready(Err(anyhow!( diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs deleted file mode 100644 index 712c333221..0000000000 --- a/crates/agent_servers/src/codex.rs +++ /dev/null @@ -1,319 +0,0 @@ -use agent_client_protocol as acp; -use anyhow::anyhow; -use collections::HashMap; -use context_server::listener::McpServerTool; -use context_server::types::requests; -use context_server::{ContextServer, ContextServerCommand, ContextServerId}; -use futures::channel::{mpsc, oneshot}; -use project::Project; -use settings::SettingsStore; -use smol::stream::StreamExt as _; -use std::cell::RefCell; -use std::rc::Rc; -use std::{path::Path, sync::Arc}; -use util::ResultExt; - -use anyhow::{Context, Result}; -use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; - -use crate::mcp_server::ZedMcpServer; -use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server}; -use acp_thread::{AcpThread, AgentConnection}; - -#[derive(Clone)] -pub struct Codex; - -impl AgentServer for Codex { - fn name(&self) -> &'static str { - "Codex" - } - - fn empty_state_headline(&self) -> &'static str { - "Welcome to Codex" - } - - fn empty_state_message(&self) -> &'static str { - "What can I help with?" - } - - fn logo(&self) -> ui::IconName { - ui::IconName::AiOpenAi - } - - fn connect( - &self, - _root_dir: &Path, - project: &Entity, - cx: &mut App, - ) -> Task>> { - let project = project.clone(); - let working_directory = project.read(cx).active_project_directory(cx); - cx.spawn(async move |cx| { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).codex.clone() - })?; - - let Some(command) = - AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await - else { - anyhow::bail!("Failed to find codex binary"); - }; - - let client: Arc = ContextServer::stdio( - ContextServerId("codex-mcp-server".into()), - ContextServerCommand { - path: command.path, - args: command.args, - env: command.env, - }, - working_directory, - ) - .into(); - ContextServer::start(client.clone(), cx).await?; - - let (notification_tx, mut notification_rx) = mpsc::unbounded(); - client - .client() - .context("Failed to subscribe")? - .on_notification(acp::SESSION_UPDATE_METHOD_NAME, { - move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(notification) = - serde_json::from_value::(notification) - .log_err() - { - notification_tx.unbounded_send(notification).ok(); - } - } - }); - - let sessions = Rc::new(RefCell::new(HashMap::default())); - - let notification_handler_task = cx.spawn({ - let sessions = sessions.clone(); - async move |cx| { - while let Some(notification) = notification_rx.next().await { - CodexConnection::handle_session_notification( - notification, - sessions.clone(), - cx, - ) - } - } - }); - - let connection = CodexConnection { - client, - sessions, - _notification_handler_task: notification_handler_task, - }; - Ok(Rc::new(connection) as _) - }) - } -} - -struct CodexConnection { - client: Arc, - sessions: Rc>>, - _notification_handler_task: Task<()>, -} - -struct CodexSession { - thread: WeakEntity, - cancel_tx: Option>, - _mcp_server: ZedMcpServer, -} - -impl AgentConnection for CodexConnection { - fn name(&self) -> &'static str { - "Codex" - } - - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let client = self.client.client(); - let sessions = self.sessions.clone(); - let cwd = cwd.to_path_buf(); - cx.spawn(async move |cx| { - let client = client.context("MCP server is not initialized yet")?; - let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); - - let mcp_server = ZedMcpServer::new(thread_rx, cx).await?; - - let response = client - .request::(context_server::types::CallToolParams { - name: acp::NEW_SESSION_TOOL_NAME.into(), - arguments: Some(serde_json::to_value(acp::NewSessionArguments { - mcp_servers: [( - mcp_server::SERVER_NAME.to_string(), - mcp_server.server_config()?, - )] - .into(), - client_tools: acp::ClientTools { - request_permission: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::RequestPermissionTool::NAME.into(), - }), - read_text_file: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::ReadTextFileTool::NAME.into(), - }), - write_text_file: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::WriteTextFileTool::NAME.into(), - }), - }, - cwd, - })?), - meta: None, - }) - .await?; - - if response.is_error.unwrap_or_default() { - return Err(anyhow!(response.text_contents())); - } - - let result = serde_json::from_value::( - response.structured_content.context("Empty response")?, - )?; - - let thread = - cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?; - - thread_tx.send(thread.downgrade())?; - - let session = CodexSession { - thread: thread.downgrade(), - cancel_tx: None, - _mcp_server: mcp_server, - }; - sessions.borrow_mut().insert(result.session_id, session); - - Ok(thread) - }) - } - - fn authenticate(&self, _cx: &mut App) -> Task> { - Task::ready(Err(anyhow!("Authentication not supported"))) - } - - fn prompt( - &self, - params: agent_client_protocol::PromptArguments, - cx: &mut App, - ) -> Task> { - let client = self.client.client(); - let sessions = self.sessions.clone(); - - cx.foreground_executor().spawn(async move { - let client = client.context("MCP server is not initialized yet")?; - - let (new_cancel_tx, cancel_rx) = oneshot::channel(); - { - let mut sessions = sessions.borrow_mut(); - let session = sessions - .get_mut(¶ms.session_id) - .context("Session not found")?; - session.cancel_tx.replace(new_cancel_tx); - } - - let result = client - .request_with::( - context_server::types::CallToolParams { - name: acp::PROMPT_TOOL_NAME.into(), - arguments: Some(serde_json::to_value(params)?), - meta: None, - }, - Some(cancel_rx), - None, - ) - .await; - - if let Err(err) = &result - && err.is::() - { - return Ok(()); - } - - let response = result?; - - if response.is_error.unwrap_or_default() { - return Err(anyhow!(response.text_contents())); - } - - Ok(()) - }) - } - - fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) { - let mut sessions = self.sessions.borrow_mut(); - - if let Some(cancel_tx) = sessions - .get_mut(session_id) - .and_then(|session| session.cancel_tx.take()) - { - cancel_tx.send(()).ok(); - } - } -} - -impl CodexConnection { - pub fn handle_session_notification( - notification: acp::SessionNotification, - threads: Rc>>, - cx: &mut AsyncApp, - ) { - let threads = threads.borrow(); - let Some(thread) = threads - .get(¬ification.session_id) - .and_then(|session| session.thread.upgrade()) - else { - log::error!( - "Thread not found for session ID: {}", - notification.session_id - ); - return; - }; - - thread - .update(cx, |thread, cx| { - thread.handle_session_update(notification.update, cx) - }) - .log_err(); - } -} - -impl Drop for CodexConnection { - fn drop(&mut self) { - self.client.stop().log_err(); - } -} - -#[cfg(test)] -pub(crate) mod tests { - use super::*; - use crate::AgentServerCommand; - use std::path::Path; - - crate::common_e2e_tests!(Codex, allow_option_id = "approve"); - - pub fn local_command() -> AgentServerCommand { - let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) - .join("../../../codex/codex-rs/target/debug/codex"); - - AgentServerCommand { - path: cli_path, - args: vec![], - env: None, - } - } -} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index e9c72eabc9..16bf1e6b47 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -375,9 +375,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { gemini: Some(AgentServerSettings { command: crate::gemini::tests::local_command(), }), - codex: Some(AgentServerSettings { - command: crate::codex::tests::local_command(), - }), }, cx, ); diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index a97ff3f462..372ce76aa9 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,14 +1,10 @@ -use anyhow::anyhow; -use std::cell::RefCell; use std::path::Path; use std::rc::Rc; -use util::ResultExt as _; -use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; -use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; -use agentic_coding_protocol as acp_old; -use anyhow::{Context as _, Result}; -use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use crate::{AgentServer, AgentServerCommand, acp_connection::AcpConnection}; +use acp_thread::AgentConnection; +use anyhow::Result; +use gpui::{Entity, Task}; use project::Project; use settings::SettingsStore; use ui::App; @@ -43,146 +39,27 @@ impl AgentServer for Gemini { project: &Entity, cx: &mut App, ) -> Task>> { - let root_dir = root_dir.to_path_buf(); let project = project.clone(); - let this = self.clone(); - let name = self.name(); - + let server_name = self.name(); + let root_dir = root_dir.to_path_buf(); cx.spawn(async move |cx| { - let command = this.command(&project, cx).await?; + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() + })?; - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; + let Some(command) = + AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + else { + anyhow::bail!("Failed to find gemini binary"); + }; + // todo! check supported version - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - let foreground_executor = cx.foreground_executor().clone(); - - let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(AgentServerVersion::Unsupported { - error_message, - upgrade_message, - upgrade_command, - }) = this.version(&command).await.log_err() - { - Err(anyhow!(LoadError::Unsupported { - error_message, - upgrade_message, - upgrade_command - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - }; - drop(io_task); - result - }); - - let connection: Rc = Rc::new(OldAcpAgentConnection { - name, - connection, - child_status, - current_thread: thread_rc, - }); - - Ok(connection) + let conn = AcpConnection::stdio(server_name, command, &root_dir, cx).await?; + Ok(Rc::new(conn) as _) }) } } -impl Gemini { - async fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).gemini.clone() - })?; - - if let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await - { - return Ok(command); - }; - - let (fs, node_runtime) = project.update(cx, |project, _| { - (project.fs().clone(), project.node_runtime().cloned()) - })?; - let node_runtime = node_runtime.context("gemini not found on path")?; - - let directory = ::paths::agent_servers_dir().join("gemini"); - fs.create_dir(&directory).await?; - node_runtime - .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) - .await?; - let path = directory.join("node_modules/.bin/gemini"); - - Ok(AgentServerCommand { - path, - args: vec![ACP_ARG.into()], - env: None, - }) - } - - async fn version(&self, command: &AgentServerCommand) -> Result { - let version_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--version") - .kill_on_drop(true) - .output(); - - let help_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--help") - .kill_on_drop(true) - .output(); - - let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; - - let current_version = String::from_utf8(version_output?.stdout)?; - let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); - - if supported { - Ok(AgentServerVersion::Supported) - } else { - Ok(AgentServerVersion::Unsupported { - error_message: format!( - "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", - current_version - ).into(), - upgrade_message: "Upgrade Gemini to Latest".into(), - upgrade_command: "npm install -g @google/gemini-cli@latest".into(), - }) - } - } -} - #[cfg(test)] pub(crate) mod tests { use super::*; @@ -199,7 +76,7 @@ pub(crate) mod tests { AgentServerCommand { path: "node".into(), - args: vec![cli_path, ACP_ARG.into()], + args: vec![cli_path], env: None, } } diff --git a/crates/agent_servers/src/mcp_server.rs b/crates/agent_servers/src/mcp_server.rs deleted file mode 100644 index 055b89dfe2..0000000000 --- a/crates/agent_servers/src/mcp_server.rs +++ /dev/null @@ -1,207 +0,0 @@ -use acp_thread::AcpThread; -use agent_client_protocol as acp; -use anyhow::Result; -use context_server::listener::{McpServerTool, ToolResponse}; -use context_server::types::{ - Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, - ToolsCapabilities, requests, -}; -use futures::channel::oneshot; -use gpui::{App, AsyncApp, Task, WeakEntity}; -use indoc::indoc; - -pub struct ZedMcpServer { - server: context_server::listener::McpServer, -} - -pub const SERVER_NAME: &str = "zed"; - -impl ZedMcpServer { - pub async fn new( - thread_rx: watch::Receiver>, - cx: &AsyncApp, - ) -> Result { - let mut mcp_server = context_server::listener::McpServer::new(cx).await?; - mcp_server.handle_request::(Self::handle_initialize); - - mcp_server.add_tool(RequestPermissionTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(ReadTextFileTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(WriteTextFileTool { - thread_rx: thread_rx.clone(), - }); - - Ok(Self { server: mcp_server }) - } - - pub fn server_config(&self) -> Result { - #[cfg(not(test))] - let zed_path = anyhow::Context::context( - std::env::current_exe(), - "finding current executable path for use in mcp_server", - )?; - - #[cfg(test)] - let zed_path = crate::e2e_tests::get_zed_path(); - - Ok(acp::McpServerConfig { - command: zed_path, - args: vec![ - "--nc".into(), - self.server.socket_path().display().to_string(), - ], - env: None, - }) - } - - fn handle_initialize(_: InitializeParams, cx: &App) -> Task> { - cx.foreground_executor().spawn(async move { - Ok(InitializeResponse { - protocol_version: ProtocolVersion("2025-06-18".into()), - capabilities: ServerCapabilities { - experimental: None, - logging: None, - completions: None, - prompts: None, - resources: None, - tools: Some(ToolsCapabilities { - list_changed: Some(false), - }), - }, - server_info: Implementation { - name: SERVER_NAME.into(), - version: "0.1.0".into(), - }, - meta: None, - }) - }) - } -} - -// Tools - -#[derive(Clone)] -pub struct RequestPermissionTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for RequestPermissionTool { - type Input = acp::RequestPermissionArguments; - type Output = acp::RequestPermissionOutput; - - const NAME: &'static str = "Confirmation"; - - fn description(&self) -> &'static str { - indoc! {" - Request permission for tool calls. - - This tool is meant to be called programmatically by the agent loop, not the LLM. - "} - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let result = thread - .update(cx, |thread, cx| { - thread.request_tool_call_permission(input.tool_call, input.options, cx) - })? - .await; - - let outcome = match result { - Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id }, - Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled, - }; - - Ok(ToolResponse { - content: vec![], - structured_content: acp::RequestPermissionOutput { outcome }, - }) - } -} - -#[derive(Clone)] -pub struct ReadTextFileTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for ReadTextFileTool { - type Input = acp::ReadTextFileArguments; - type Output = acp::ReadTextFileOutput; - - const NAME: &'static str = "Read"; - - fn description(&self) -> &'static str { - "Reads the content of the given file in the project including unsaved changes." - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.path, input.line, input.limit, false, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: acp::ReadTextFileOutput { content }, - }) - } -} - -#[derive(Clone)] -pub struct WriteTextFileTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for WriteTextFileTool { - type Input = acp::WriteTextFileArguments; - type Output = (); - - const NAME: &'static str = "Write"; - - fn description(&self) -> &'static str { - "Write to a file replacing its contents" - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - thread - .update(cx, |thread, cx| { - thread.write_text_file(input.path, input.content, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: (), - }) - } -} diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs index aeb34a5e61..645674b5f1 100644 --- a/crates/agent_servers/src/settings.rs +++ b/crates/agent_servers/src/settings.rs @@ -13,7 +13,6 @@ pub fn init(cx: &mut App) { pub struct AllAgentServersSettings { pub gemini: Option, pub claude: Option, - pub codex: Option, } #[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] @@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings { fn load(sources: SettingsSources, _: &mut App) -> Result { let mut settings = AllAgentServersSettings::default(); - for AllAgentServersSettings { - gemini, - claude, - codex, - } in sources.defaults_and_customizations() - { + for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() { if gemini.is_some() { settings.gemini = gemini.clone(); } if claude.is_some() { settings.claude = claude.clone(); } - if codex.is_some() { - settings.codex = codex.clone(); - } } Ok(settings) diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 95fd2b1757..c145df0eae 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -19,6 +19,7 @@ test-support = ["gpui/test-support", "language/test-support"] acp_thread.workspace = true agent-client-protocol.workspace = true agent.workspace = true +agent2.workspace = true agent_servers.workspace = true agent_settings.workspace = true ai_onboarding.workspace = true diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 8820e4a73d..26166b6960 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -232,7 +232,8 @@ impl AcpThreadView { { Err(e) => { let mut cx = cx.clone(); - if e.downcast_ref::().is_some() { + // todo! remove duplication + if e.downcast_ref::().is_some() { this.update(&mut cx, |this, cx| { this.thread_state = ThreadState::Unauthenticated { connection }; cx.notify(); @@ -675,13 +676,18 @@ impl AcpThreadView { Some(entry.diffs().map(|diff| diff.multibuffer.clone())) } - fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { + fn authenticate( + &mut self, + method: acp::AuthMethodId, + window: &mut Window, + cx: &mut Context, + ) { let ThreadState::Unauthenticated { ref connection } = self.thread_state else { return; }; self.last_error.take(); - let authenticate = connection.authenticate(cx); + let authenticate = connection.authenticate(method, cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); let agent = self.agent.clone(); @@ -2380,22 +2386,26 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) .child(match &self.thread_state { - ThreadState::Unauthenticated { .. } => { - v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child( - h_flex().mt_1p5().justify_center().child( - Button::new("sign-in", format!("Sign in to {}", self.agent.name())) - .on_click(cx.listener(|this, _, window, cx| { - this.authenticate(window, cx) - })), - ), - ) - } + ThreadState::Unauthenticated { connection } => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_pending_auth_state()) + .child(h_flex().mt_1p5().justify_center().children( + connection.auth_methods().into_iter().map(|method| { + Button::new( + SharedString::from(method.id.0.clone()), + method.label.clone(), + ) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), + )), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() @@ -2834,10 +2844,6 @@ mod tests { } impl AgentConnection for StubAgentConnection { - fn name(&self) -> &'static str { - "StubAgentConnection" - } - fn new_thread( self: Rc, project: Entity, @@ -2853,17 +2859,27 @@ mod tests { .into(), ); let thread = cx - .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx)) + .new(|cx| { + AcpThread::new("New Thread", self.clone(), project, session_id.clone(), cx) + }) .unwrap(); self.sessions.lock().insert(session_id, thread.downgrade()); Task::ready(Ok(thread)) } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] { + todo!() + } + + fn authenticate( + &self, + _method: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { unimplemented!() } - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let sessions = self.sessions.lock(); let thread = sessions.get(¶ms.session_id).unwrap(); let mut tasks = vec![]; @@ -2910,10 +2926,6 @@ mod tests { struct SaboteurAgentConnection; impl AgentConnection for SaboteurAgentConnection { - fn name(&self) -> &'static str { - "SaboteurAgentConnection" - } - fn new_thread( self: Rc, project: Entity, @@ -2921,15 +2933,23 @@ mod tests { cx: &mut gpui::AsyncApp, ) -> Task>> { Task::ready(Ok(cx - .new(|cx| AcpThread::new(self, project, SessionId("test".into()), cx)) + .new(|cx| AcpThread::new("New Thread", self, project, SessionId("test".into()), cx)) .unwrap())) } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] { + todo!() + } + + fn authenticate( + &self, + _method: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { unimplemented!() } - fn prompt(&self, _params: acp::PromptArguments, _cx: &mut App) -> Task> { + fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task> { Task::ready(Err(anyhow::anyhow!("Error prompting"))) } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index fcb8dfbac2..5813b48f6e 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1954,54 +1954,54 @@ impl AgentPanel { this } }) - .when(cx.has_flag::(), |this| { - this.separator() - .header("External Agents") - .item( - ContextMenuEntry::new("New Gemini Thread") - .icon(IconName::AiGemini) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::Gemini), - } - .boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Claude Code Thread") - .icon(IconName::AiClaude) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::ClaudeCode, - ), - } - .boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Codex Thread") - .icon(IconName::AiOpenAi) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::Codex), - } - .boxed_clone(), - cx, - ); - }), - ) - }); + // Temporarily removed feature flag check for testing + // .when(cx.has_flag::(), |this| { + // this + .separator() + .header("External Agents") + .item( + ContextMenuEntry::new("New Gemini Thread") + .icon(IconName::AiGemini) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Gemini), + } + .boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Claude Code Thread") + .icon(IconName::AiClaude) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::ClaudeCode), + } + .boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Native Agent Thread") + .icon(IconName::ZedAssistant) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::NativeAgent), + } + .boxed_clone(), + cx, + ); + }), + ); + // }); menu })) } @@ -2608,82 +2608,87 @@ impl AgentPanel { ), ), ) - .when(cx.has_flag::(), |this| { - this.child( - h_flex() - .w_full() - .gap_2() - .child( - NewThreadButton::new( - "new-gemini-thread-btn", - "New Gemini Thread", - IconName::AiGemini, - ) - // .keybinding(KeyBinding::for_action_in( - // &OpenHistory, - // &self.focus_handle(cx), - // window, - // cx, - // )) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::Gemini, - ), - }), - cx, - ) - }, - ), + // Temporarily removed feature flag check for testing + // .when(cx.has_flag::(), |this| { + // this + .child( + h_flex() + .w_full() + .gap_2() + .child( + NewThreadButton::new( + "new-gemini-thread-btn", + "New Gemini Thread", + IconName::AiGemini, ) - .child( - NewThreadButton::new( - "new-claude-thread-btn", - "New Claude Code Thread", - IconName::AiClaude, - ) - // .keybinding(KeyBinding::for_action_in( - // &OpenHistory, - // &self.focus_handle(cx), - // window, - // cx, - // )) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::ClaudeCode, - ), - }), - cx, - ) - }, - ), - ) - .child( - NewThreadButton::new( - "new-codex-thread-btn", - "New Codex Thread", - IconName::AiOpenAi, - ) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::Codex, - ), - }), - cx, - ) - }, - ), + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Gemini), + }), + cx, + ) + }, ), - ) - }), + ) + .child( + NewThreadButton::new( + "new-claude-thread-btn", + "New Claude Code Thread", + IconName::AiClaude, + ) + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::ClaudeCode, + ), + }), + cx, + ) + }, + ), + ) + .child( + NewThreadButton::new( + "new-native-agent-thread-btn", + "New Native Agent Thread", + IconName::ZedAssistant, + ) + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::NativeAgent, + ), + }), + cx, + ) + }, + ), + ), + ), // }) ) .when_some(configuration_error.as_ref(), |this, err| { this.child(self.render_configuration_error(err, &focus_handle, window, cx)) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 0800031abe..adbecb75cb 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -150,7 +150,7 @@ enum ExternalAgent { #[default] Gemini, ClaudeCode, - Codex, + NativeAgent, } impl ExternalAgent { @@ -158,7 +158,7 @@ impl ExternalAgent { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), - ExternalAgent::Codex => Rc::new(agent_servers::Codex), + ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer), } } } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 1eb29bbbf9..65283afa87 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -441,14 +441,12 @@ impl Client { Ok(()) } - #[allow(unused)] - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { - self.notification_handlers - .lock() - .insert(method, Box::new(f)); + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { + self.notification_handlers.lock().insert(method, f); } } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index e76e7972f7..34fa29678d 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -95,8 +95,28 @@ impl ContextServer { self.client.read().clone() } - pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - let client = match &self.configuration { + pub async fn start(&self, cx: &AsyncApp) -> Result<()> { + self.initialize(self.new_client(cx)?).await + } + + /// Starts the context server, making sure handlers are registered before initialization happens + pub async fn start_with_handlers( + &self, + notification_handlers: Vec<( + &'static str, + Box, + )>, + cx: &AsyncApp, + ) -> Result<()> { + let client = self.new_client(cx)?; + for (method, handler) in notification_handlers { + client.on_notification(method, handler); + } + self.initialize(client).await + } + + fn new_client(&self, cx: &AsyncApp) -> Result { + Ok(match &self.configuration { ContextServerTransport::Stdio(command, working_directory) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { @@ -113,8 +133,7 @@ impl ContextServer { transport.clone(), cx.clone(), )?, - }; - self.initialize(client).await + }) } async fn initialize(&self, client: Client) -> Result<()> { diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 34e3a9a78c..0e85fb2129 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -83,14 +83,18 @@ impl McpServer { } pub fn add_tool(&mut self, tool: T) { - let output_schema = schemars::schema_for!(T::Output); - let unit_schema = schemars::schema_for!(()); + let mut settings = schemars::generate::SchemaSettings::draft07(); + settings.inline_subschemas = true; + let mut generator = settings.into_generator(); + + let output_schema = generator.root_schema_for::(); + let unit_schema = generator.root_schema_for::(); let registered_tool = RegisteredTool { tool: Tool { name: T::NAME.into(), description: Some(tool.description().into()), - input_schema: schemars::schema_for!(T::Input).into(), + input_schema: generator.root_schema_for::().into(), output_schema: if output_schema == unit_schema { None } else { diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 9ccbc8a553..5355f20f62 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -115,10 +115,11 @@ impl InitializedContextServerProtocol { self.inner.notify(T::METHOD, params) } - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { self.inner.on_notification(method, f); } }