Compare commits

...
Sign in to create a new pull request.

30 commits

Author SHA1 Message Date
Nathan Sobo
f81993574e Connect Native Agent responses to UI display
User-visible improvements:
- Native Agent now shows AI responses in the chat interface
- Uses configured default model from settings instead of random selection
- Streams responses in real-time as the model generates them

Technical changes:
- Implemented response stream forwarding from Thread to AcpThread
- Created Session struct to manage Thread and AcpThread together
- Added proper SessionUpdate handling for text chunks and tool calls
- Fixed model selection to use LanguageModelRegistry's default
- Added comprehensive logging for debugging model interactions
- Removed unused cwd parameter - native agent captures context differently than external agents
2025-08-02 09:58:33 -06:00
Nathan Sobo
bc1f861d3f Add Native Agent to UI and implement NativeAgentServer
- Created NativeAgentServer that implements AgentServer trait
- Added NativeAgent to ExternalAgent enum
- Added Native Agent option to both the + menu and empty state view
- Added necessary dependencies (agent_servers, ui) to agent2 crate
- Added agent2 dependency to agent_ui crate
- Temporarily removed feature flag check for testing
2025-08-02 09:17:26 -06:00
Nathan Sobo
4f2d6a9ea9 Rename Agent to NativeAgent and AgentConnection to NativeAgentConnection
- Renamed Agent struct to NativeAgent to better reflect its native implementation
- Renamed AgentConnection to NativeAgentConnection for consistency
- Updated all references and implementations
- Bumped agent-client-protocol version to 0.0.14
2025-08-02 08:59:16 -06:00
Nathan Sobo
604a88f6e3 Add comprehensive test for AgentConnection with ModelSelector
- Add public session_id() method to AcpThread to enable testing
- Fix ModelSelector methods to use async move closures properly to avoid borrow conflicts
- Add test_agent_connection that verifies:
  - Model selector is available for agent2
  - Can list available models
  - Can create threads with default model
  - Can query selected model for a session
  - Can send prompts using the selected model
  - Can cancel sessions
  - Handles errors for invalid sessions
- Remove unnecessary mut keywords from async closures
2025-08-02 08:45:51 -06:00
Nathan Sobo
a4fe8c6972 Add ModelSelector capability to AgentConnection
- Add ModelSelector trait to acp_thread crate with list_models, select_model, and selected_model methods
- Extend AgentConnection trait with optional model_selector() method returning Option<Rc<dyn ModelSelector>>
- Implement ModelSelector for agent2's AgentConnection using LanguageModelRegistry
- Make selected_model field mandatory on Thread struct
- Update Thread::new to require a default_model parameter
- Update agent2 to fetch default model from registry when creating threads
- Fix prompt method to use the thread's selected model directly
- All methods use &mut AsyncApp for async-friendly operations
2025-08-02 08:28:37 -06:00
Nathan Sobo
5d621bef78 WIP 2025-08-02 00:15:22 -06:00
Nathan Sobo
27877325bc Implement agent-client-protocol Agent trait for agent2
Added Agent struct that implements the acp::Agent trait with:
- Complete: initialize (protocol handshake) and authenticate (stub auth)
- Partial: new_session (creates ID but needs GPUI context for Thread)
- Partial: cancelled (removes session but needs GPUI cleanup)
- Stub: load_session and prompt (need GPUI context integration)

The implementation uses RefCell for session management since trait methods
take &self, and Cell for simple authentication state. Templates are Arc'd
for potential Send requirements.

Next steps:
- Integrate GPUI context for Thread creation/management
- Implement content type conversions between acp and agent2
- Add proper session persistence for load_session
- Stream responses back through the protocol
2025-08-01 23:22:21 -06:00
Nathan Sobo
9e1c7fdfea Rename agent to thread in tests to avoid confusion
Preparing for the introduction of a new Agent type that will implement
the agent-client-protocol Agent trait. The existing Thread type represents
individual conversation sessions, while Agent will manage multiple sessions.
2025-08-01 23:14:21 -06:00
Nathan Sobo
387ee1be8d Clean up warnings in agent2
- Remove underscore prefix from BasePrompt struct name
- Remove unused imports and variables in tests
- Fix unused parameter warning in async closure
- Rename AgentToolErased to AnyAgentTool for clarity
2025-08-01 22:54:34 -06:00
Nathan Sobo
afc8cf6098 Refactor agent2 tests structure
- Move tests from thread/tests.rs to tests/mod.rs
- Move test_tools from thread/tests/test_tools.rs to tests/test_tools.rs
- Update imports and fix compilation errors in tests
- Fix private field access by using public messages() method
- Add necessary imports for test modules
2025-08-01 22:48:35 -06:00
Nathan Sobo
84d6a0fae9 Fix agent2 compilation errors and warnings
- Add cloud_llm_client dependency for CompletionIntent and CompletionMode
- Fix LanguageModelRequest initialization with missing thinking_allowed field
- Update StartMessage handling to use Assistant role
- Fix MessageContent conversions to use enum variants directly
- Fix input_schema implementation to use schemars directly
- Suppress unused variable and dead code warnings
2025-08-01 22:39:08 -06:00
Nathan Sobo
afb5c4147a WIP: Add agent2 crate from test-driven-agent branch 2025-08-01 22:32:36 -06:00
Nathan Sobo
8890f590b1 Fix some breakages against agent-client-protocol/main 2025-08-01 22:25:04 -06:00
Agus Zubiaga
76f0f9163d Merge branch 'main' into new-acp 2025-08-01 18:17:48 -03:00
Agus Zubiaga
8acb58b6e5 Fix types 2025-08-01 18:15:26 -03:00
Agus Zubiaga
8563ed2252 Compiling 2025-08-01 17:07:50 -03:00
Agus Zubiaga
02d3043ec5 Rename arg to experimental-mcp 2025-07-30 14:28:01 -03:00
Agus Zubiaga
30739041a4 Merge branch 'mcp-acp-gemini' of github.com:zed-industries/zed into mcp-acp-gemini 2025-07-30 13:33:03 -03:00
Agus Zubiaga
27708143ec Fix auth 2025-07-30 13:30:50 -03:00
Agus Zubiaga
738296345e Inline tool schemas 2025-07-30 11:46:11 -03:00
Ben Brandt
81c111510f
Refactor handling of ContextServer notifications
The notification handler registration is now more explicit, with
handlers set up before server initialization to avoid potential race
conditions.
2025-07-30 15:48:40 +02:00
Ben Brandt
f028ca4d1a
Merge branch 'main' into mcp-acp-gemini 2025-07-30 12:24:01 +02:00
Agus Zubiaga
6656403ce8 Auth WIP 2025-07-29 21:15:00 -03:00
Ben Brandt
254c6be42b
Fix broken test 2025-07-29 10:12:57 +02:00
Ben Brandt
745e4b5f1e
Merge branch 'main' into mcp-acp-gemini 2025-07-29 10:10:28 +02:00
Agus Zubiaga
912ab505b2 Connect to gemini over MCP 2025-07-28 20:04:32 -03:00
Agus Zubiaga
b48faddaf4 Restore gemini change 2025-07-28 18:45:05 -03:00
Agus Zubiaga
477731d77d Merge branch 'main' into mcp-acp-gemini 2025-07-28 18:43:25 -03:00
Agus Zubiaga
ced3d09f10 Extract acp_connection 2025-07-28 18:43:01 -03:00
Agus Zubiaga
0f395df9a8 Update to new schema 2025-07-28 18:02:21 -03:00
37 changed files with 2347 additions and 917 deletions

62
Cargo.lock generated
View file

@ -19,6 +19,7 @@ dependencies = [
"indoc", "indoc",
"itertools 0.14.0", "itertools 0.14.0",
"language", "language",
"language_model",
"markdown", "markdown",
"project", "project",
"serde", "serde",
@ -137,15 +138,57 @@ dependencies = [
[[package]] [[package]]
name = "agent-client-protocol" name = "agent-client-protocol"
version = "0.0.11" version = "0.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b"
dependencies = [ dependencies = [
"anyhow",
"futures 0.3.31",
"log",
"parking_lot",
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
] ]
[[package]]
name = "agent2"
version = "0.1.0"
dependencies = [
"acp_thread",
"agent-client-protocol",
"agent_servers",
"anyhow",
"assistant_tool",
"assistant_tools",
"chrono",
"client",
"cloud_llm_client",
"collections",
"ctor",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"language_model",
"language_models",
"log",
"parking_lot",
"project",
"reqwest_client",
"rust-embed",
"schemars",
"serde",
"serde_json",
"settings",
"smol",
"thiserror 2.0.12",
"ui",
"util",
"uuid",
"worktree",
]
[[package]] [[package]]
name = "agent_servers" name = "agent_servers"
version = "0.1.0" version = "0.1.0"
@ -209,6 +252,7 @@ dependencies = [
"acp_thread", "acp_thread",
"agent", "agent",
"agent-client-protocol", "agent-client-protocol",
"agent2",
"agent_servers", "agent_servers",
"agent_settings", "agent_settings",
"ai_onboarding", "ai_onboarding",
@ -9570,9 +9614,9 @@ dependencies = [
[[package]] [[package]]
name = "lock_api" name = "lock_api"
version = "0.4.12" version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"scopeguard", "scopeguard",
@ -11286,9 +11330,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.3" version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13"
dependencies = [ dependencies = [
"lock_api", "lock_api",
"parking_lot_core", "parking_lot_core",
@ -11296,9 +11340,9 @@ dependencies = [
[[package]] [[package]]
name = "parking_lot_core" name = "parking_lot_core"
version = "0.9.10" version = "0.9.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",

View file

@ -4,6 +4,7 @@ members = [
"crates/acp_thread", "crates/acp_thread",
"crates/activity_indicator", "crates/activity_indicator",
"crates/agent", "crates/agent",
"crates/agent2",
"crates/agent_servers", "crates/agent_servers",
"crates/agent_settings", "crates/agent_settings",
"crates/agent_ui", "crates/agent_ui",
@ -228,6 +229,7 @@ edition = "2024"
acp_thread = { path = "crates/acp_thread" } acp_thread = { path = "crates/acp_thread" }
agent = { path = "crates/agent" } agent = { path = "crates/agent" }
agent2 = { path = "crates/agent2" }
activity_indicator = { path = "crates/activity_indicator" } activity_indicator = { path = "crates/activity_indicator" }
agent_ui = { path = "crates/agent_ui" } agent_ui = { path = "crates/agent_ui" }
agent_settings = { path = "crates/agent_settings" } agent_settings = { path = "crates/agent_settings" }
@ -421,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" }
# #
agentic-coding-protocol = "0.0.10" agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.11" agent-client-protocol = {path="../agent-client-protocol"}
aho-corasick = "1.1" aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14" any_vec = "0.14"

View file

@ -26,6 +26,7 @@ futures.workspace = true
gpui.workspace = true gpui.workspace = true
itertools.workspace = true itertools.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true
markdown.workspace = true markdown.workspace = true
project.workspace = true project.workspace = true
serde.workspace = true serde.workspace = true

View file

@ -391,7 +391,7 @@ impl ToolCallContent {
cx: &mut App, cx: &mut App,
) -> Self { ) -> Self {
match content { match content {
acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock { acp::ToolCallContent::Content { content } => Self::ContentBlock {
content: ContentBlock::new(content, &language_registry, cx), content: ContentBlock::new(content, &language_registry, cx),
}, },
acp::ToolCallContent::Diff { diff } => Self::Diff { acp::ToolCallContent::Diff { diff } => Self::Diff {
@ -619,6 +619,7 @@ impl Error for LoadError {}
impl AcpThread { impl AcpThread {
pub fn new( pub fn new(
title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>, connection: Rc<dyn AgentConnection>,
project: Entity<Project>, project: Entity<Project>,
session_id: acp::SessionId, session_id: acp::SessionId,
@ -631,7 +632,7 @@ impl AcpThread {
shared_buffers: Default::default(), shared_buffers: Default::default(),
entries: Default::default(), entries: Default::default(),
plan: Default::default(), plan: Default::default(),
title: connection.name().into(), title: title.into(),
project, project,
send_task: None, send_task: None,
connection, connection,
@ -655,6 +656,10 @@ impl AcpThread {
&self.entries &self.entries
} }
pub fn session_id(&self) -> &acp::SessionId {
&self.session_id
}
pub fn status(&self) -> ThreadStatus { pub fn status(&self) -> ThreadStatus {
if self.send_task.is_some() { if self.send_task.is_some() {
if self.waiting_for_tool_confirmation() { if self.waiting_for_tool_confirmation() {
@ -697,14 +702,14 @@ impl AcpThread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Result<()> { ) -> Result<()> {
match update { match update {
acp::SessionUpdate::UserMessage(content_block) => { acp::SessionUpdate::UserMessageChunk { content } => {
self.push_user_content_block(content_block, cx); self.push_user_content_block(content, cx);
} }
acp::SessionUpdate::AgentMessageChunk(content_block) => { acp::SessionUpdate::AgentMessageChunk { content } => {
self.push_assistant_content_block(content_block, false, cx); self.push_assistant_content_block(content, false, cx);
} }
acp::SessionUpdate::AgentThoughtChunk(content_block) => { acp::SessionUpdate::AgentThoughtChunk { content } => {
self.push_assistant_content_block(content_block, true, cx); self.push_assistant_content_block(content, true, cx);
} }
acp::SessionUpdate::ToolCall(tool_call) => { acp::SessionUpdate::ToolCall(tool_call) => {
self.upsert_tool_call(tool_call, cx); self.upsert_tool_call(tool_call, cx);
@ -973,10 +978,6 @@ impl AcpThread {
cx.notify(); cx.notify();
} }
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
self.connection.authenticate(cx)
}
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub fn send_raw( pub fn send_raw(
&mut self, &mut self,
@ -1018,7 +1019,7 @@ impl AcpThread {
let result = this let result = this
.update(cx, |this, cx| { .update(cx, |this, cx| {
this.connection.prompt( this.connection.prompt(
acp::PromptArguments { acp::PromptRequest {
prompt: message, prompt: message,
session_id: this.session_id.clone(), session_id: this.session_id.clone(),
}, },
@ -1620,9 +1621,15 @@ mod tests {
connection, connection,
child_status: io_task, child_status: io_task,
current_thread: thread_rc, current_thread: thread_rc,
auth_methods: [acp::AuthMethod {
id: acp::AuthMethodId("acp-old-no-id".into()),
label: "Log in".into(),
description: None,
}],
}; };
AcpThread::new( AcpThread::new(
"test",
Rc::new(connection), Rc::new(connection),
project, project,
acp::SessionId("test".into()), acp::SessionId("test".into()),

View file

@ -1,16 +1,62 @@
use std::{path::Path, rc::Rc}; use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use agent_client_protocol as acp; use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;
use gpui::{AsyncApp, Entity, Task}; use gpui::{AsyncApp, Entity, Task};
use language_model::LanguageModel;
use project::Project; use project::Project;
use ui::App; use ui::App;
use crate::AcpThread; use crate::AcpThread;
pub trait AgentConnection { /// Trait for agents that support listing, selecting, and querying language models.
fn name(&self) -> &'static str; ///
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
pub trait ModelSelector: 'static {
/// Lists all available language models for this agent.
///
/// # Parameters
/// - `cx`: The GPUI app context for async operations and global access.
///
/// # Returns
/// A task resolving to the list of models or an error (e.g., if no models are configured).
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
/// Selects a model for a specific session (thread).
///
/// This sets the default model for future interactions in the session.
/// If the session doesn't exist or the model is invalid, it returns an error.
///
/// # Parameters
/// - `session_id`: The ID of the session (thread) to apply the model to.
/// - `model`: The model to select (should be one from [list_models]).
/// - `cx`: The GPUI app context.
///
/// # Returns
/// A task resolving to `Ok(())` on success or an error.
fn select_model(
&self,
session_id: acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Task<Result<()>>;
/// Retrieves the currently selected model for a specific session (thread).
///
/// # Parameters
/// - `session_id`: The ID of the session (thread) to query.
/// - `cx`: The GPUI app context.
///
/// # Returns
/// A task resolving to the selected model (always set) or an error (e.g., session not found).
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>>;
}
pub trait AgentConnection {
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
@ -18,9 +64,29 @@ pub trait AgentConnection {
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>; ) -> Task<Result<Entity<AcpThread>>>;
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>; fn auth_methods(&self) -> &[acp::AuthMethod];
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
/// This allows sharing the selector in UI components.
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
None // Default impl for agents that don't support it
}
}
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
}
} }

View file

@ -5,11 +5,11 @@ use anyhow::{Context as _, Result};
use futures::channel::oneshot; use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project; use project::Project;
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; use std::{cell::RefCell, path::Path, rc::Rc};
use ui::App; use ui::App;
use util::ResultExt as _; use util::ResultExt as _;
use crate::{AcpThread, AgentConnection}; use crate::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)] #[derive(Clone)]
pub struct OldAcpClientDelegate { pub struct OldAcpClientDelegate {
@ -351,28 +351,15 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
} }
} }
#[derive(Debug)]
pub struct Unauthenticated;
impl Error for Unauthenticated {}
impl fmt::Display for Unauthenticated {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unauthenticated")
}
}
pub struct OldAcpAgentConnection { pub struct OldAcpAgentConnection {
pub name: &'static str, pub name: &'static str,
pub connection: acp_old::AgentConnection, pub connection: acp_old::AgentConnection,
pub child_status: Task<Result<()>>, pub child_status: Task<Result<()>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>, pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
pub auth_methods: [acp::AuthMethod; 1],
} }
impl AgentConnection for OldAcpAgentConnection { impl AgentConnection for OldAcpAgentConnection {
fn name(&self) -> &'static str {
self.name
}
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
@ -391,13 +378,13 @@ impl AgentConnection for OldAcpAgentConnection {
let result = acp_old::InitializeParams::response_from_any(result)?; let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated { if !result.is_authenticated {
anyhow::bail!(Unauthenticated) anyhow::bail!(AuthRequired)
} }
cx.update(|cx| { cx.update(|cx| {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into()); let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new(self.clone(), project, session_id, cx) AcpThread::new("Gemini", self.clone(), project, session_id, cx)
}); });
current_thread.replace(thread.downgrade()); current_thread.replace(thread.downgrade());
thread thread
@ -405,7 +392,11 @@ impl AgentConnection for OldAcpAgentConnection {
}) })
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn auth_methods(&self) -> &[acp::AuthMethod] {
&self.auth_methods
}
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let task = self let task = self
.connection .connection
.request_any(acp_old::AuthenticateParams.into_any()); .request_any(acp_old::AuthenticateParams.into_any());
@ -415,7 +406,7 @@ impl AgentConnection for OldAcpAgentConnection {
}) })
} }
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> { fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let chunks = params let chunks = params
.prompt .prompt
.into_iter() .into_iter()

56
crates/agent2/Cargo.toml Normal file
View file

@ -0,0 +1,56 @@
[package]
name = "agent2"
version = "0.1.0"
edition = "2021"
license = "GPL-3.0-or-later"
publish = false
[lib]
path = "src/agent2.rs"
[lints]
workspace = true
[dependencies]
acp_thread.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
chrono.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
language_model.workspace = true
language_models.workspace = true
log.workspace = true
parking_lot.workspace = true
project.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
thiserror.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
worktree.workspace = true
[dev-dependencies]
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language_model = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }
worktree = { workspace = true, "features" = ["test-support"] }

1
crates/agent2/LICENSE-GPL Symbolic link
View file

@ -0,0 +1 @@
../../LICENSE-GPL

377
crates/agent2/src/agent.rs Normal file
View file

@ -0,0 +1,377 @@
use acp_thread::ModelSelector;
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModel, LanguageModelRegistry};
use project::Project;
use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use crate::{templates::Templates, Thread};
/// Holds both the internal Thread and the AcpThread for a session
#[derive(Clone)]
struct Session {
/// The internal thread that processes messages
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication
acp_thread: Entity<acp_thread::AcpThread>,
}
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
/// Shared templates for all threads
templates: Arc<Templates>,
}
impl NativeAgent {
pub fn new(templates: Arc<Templates>) -> Self {
log::info!("Creating new NativeAgent");
Self {
sessions: HashMap::new(),
templates,
}
}
}
/// Wrapper struct that implements the AgentConnection trait
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl ModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
log::debug!("NativeAgentConnection::list_models called");
cx.spawn(async move |cx| {
cx.update(|cx| {
let registry = LanguageModelRegistry::read_global(cx);
let models = registry.available_models(cx).collect::<Vec<_>>();
log::info!("Found {} available models", models.len());
if models.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(models)
}
})?
})
}
fn select_model(
&self,
session_id: acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Task<Result<()>> {
log::info!(
"Setting model for session {}: {:?}",
session_id,
model.name()
);
let agent = self.0.clone();
cx.spawn(async move |cx| {
agent.update(cx, |agent, cx| {
if let Some(session) = agent.sessions.get(&session_id) {
session.thread.update(cx, |thread, _cx| {
thread.selected_model = model;
});
Ok(())
} else {
Err(anyhow!("Session not found"))
}
})?
})
}
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>> {
let agent = self.0.clone();
let session_id = session_id.clone();
cx.spawn(async move |cx| {
let session = agent
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
let selected = session
.thread
.read_with(cx, |thread, _| thread.selected_model.clone())?;
Ok(selected)
})
}
}
impl acp_thread::AgentConnection for NativeAgentConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let agent = self.0.clone();
log::info!("Creating new thread for project at: {:?}", cwd);
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
// Create Thread
let (session_id, thread) = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
// Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
let default_model = registry
.default_model()
.map(|configured| {
log::info!(
"Using configured default model: {:?} from provider: {:?}",
configured.model.name(),
configured.provider.name()
);
configured.model
})
.ok_or_else(|| {
log::warn!("No default model configured in settings");
anyhow!("No default model configured. Please configure a default model in settings.")
})?;
let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
Ok((session_id, thread))
},
)??;
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
})
})?;
// Store the session
agent.update(cx, |agent, _cx| {
agent.sessions.insert(
session_id,
Session {
thread,
acp_thread: acp_thread.clone(),
},
);
})?;
Ok(acp_thread)
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[] // No auth for in-process
}
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Ok(()))
}
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let session_id = params.session_id.clone();
let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
cx.spawn(async move |cx| {
// Get session
let session = agent
.read_with(cx, |agent, _| {
agent.sessions.get(&session_id).map(|s| Session {
thread: s.thread.clone(),
acp_thread: s.acp_thread.clone(),
})
})?
.ok_or_else(|| {
log::error!("Session not found: {}", session_id);
anyhow::anyhow!("Session not found")
})?;
log::debug!("Found session for: {}", session_id);
// Convert prompt to message
let message = convert_prompt_to_message(params.prompt);
log::info!("Converted prompt to message: {} chars", message.len());
log::debug!("Message content: {}", message);
// Get model using the ModelSelector capability (always available for agent2)
// Get the selected model from the thread directly
let model = session
.thread
.read_with(cx, |thread, _| thread.selected_model.clone())?;
// Send to thread
log::info!("Sending message to thread with model: {:?}", model.name());
let response_stream = session
.thread
.update(cx, |thread, cx| thread.send(model, message, cx))?;
// Handle response stream and forward to session.acp_thread
let acp_thread = session.acp_thread.clone();
cx.spawn(async move |cx| {
use futures::StreamExt;
use language_model::LanguageModelCompletionEvent;
let mut response_stream = response_stream;
while let Some(result) = response_stream.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
LanguageModelCompletionEvent::Text(text) => {
// Send text chunk as agent message
acp_thread.update(cx, |thread, cx| {
thread.handle_session_update(
acp::SessionUpdate::AgentMessageChunk {
content: acp::ContentBlock::Text(
acp::TextContent {
text: text.into(),
annotations: None,
},
),
},
cx,
)
})??;
}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
// Convert LanguageModelToolUse to ACP ToolCall
acp_thread.update(cx, |thread, cx| {
thread.handle_session_update(
acp::SessionUpdate::ToolCall(acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
label: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(tool_use.input),
}),
cx,
)
})??;
}
LanguageModelCompletionEvent::StartMessage { .. } => {
log::debug!("Started new assistant message");
}
LanguageModelCompletionEvent::UsageUpdate(usage) => {
log::debug!("Token usage update: {:?}", usage);
}
LanguageModelCompletionEvent::Thinking { text, .. } => {
// Send thinking text as agent thought chunk
acp_thread.update(cx, |thread, cx| {
thread.handle_session_update(
acp::SessionUpdate::AgentThoughtChunk {
content: acp::ContentBlock::Text(
acp::TextContent {
text: text.into(),
annotations: None,
},
),
},
cx,
)
})??;
}
LanguageModelCompletionEvent::StatusUpdate(status) => {
log::trace!("Status update: {:?}", status);
}
LanguageModelCompletionEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
}
LanguageModelCompletionEvent::RedactedThinking { .. } => {
log::trace!("Redacted thinking event");
}
LanguageModelCompletionEvent::ToolUseJsonParseError {
id,
tool_name,
raw_input,
json_parse_error,
} => {
log::error!(
"Tool use JSON parse error for tool '{}' (id: {}): {} - input: {}",
tool_name,
id,
json_parse_error,
raw_input
);
}
}
}
Err(e) => {
log::error!("Error in model response stream: {:?}", e);
// TODO: Consider sending an error message to the UI
break;
}
}
}
log::info!("Response stream completed");
anyhow::Ok(())
})
.detach();
log::info!("Successfully sent prompt to thread and started response handler");
Ok(())
})
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
log::info!("Cancelling session: {}", session_id);
self.0.update(cx, |agent, _cx| {
agent.sessions.remove(session_id);
});
}
}
/// Convert ACP content blocks to a message string
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
log::debug!("Converting {} content blocks to message", blocks.len());
let mut message = String::new();
for block in blocks {
match block {
acp::ContentBlock::Text(text) => {
log::trace!("Processing text block: {} chars", text.text.len());
message.push_str(&text.text);
}
acp::ContentBlock::ResourceLink(link) => {
log::trace!("Processing resource link: {}", link.uri);
message.push_str(&format!(" @{} ", link.uri));
}
acp::ContentBlock::Image(_) => {
log::trace!("Processing image block");
message.push_str(" [image] ");
}
acp::ContentBlock::Audio(_) => {
log::trace!("Processing audio block");
message.push_str(" [audio] ");
}
acp::ContentBlock::Resource(resource) => {
log::trace!("Processing resource block: {:?}", resource.resource);
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
}
}
}
message
}

View file

@ -0,0 +1,13 @@
mod agent;
mod native_agent_server;
mod prompts;
mod templates;
mod thread;
mod tools;
#[cfg(test)]
mod tests;
pub use agent::*;
pub use native_agent_server::NativeAgentServer;
pub use thread::*;

View file

@ -0,0 +1,58 @@
use std::path::Path;
use std::rc::Rc;
use agent_servers::AgentServer;
use anyhow::Result;
use gpui::{App, AppContext, Entity, Task};
use project::Project;
use crate::{templates::Templates, NativeAgent, NativeAgentConnection};
#[derive(Clone)]
pub struct NativeAgentServer;
impl AgentServer for NativeAgentServer {
fn name(&self) -> &'static str {
"Native Agent"
}
fn empty_state_headline(&self) -> &'static str {
"Native Agent"
}
fn empty_state_message(&self) -> &'static str {
"How can I help you today?"
}
fn logo(&self) -> ui::IconName {
// Using the ZedAssistant icon as it's the native built-in agent
ui::IconName::ZedAssistant
}
fn connect(
&self,
_root_dir: &Path,
_project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
log::info!(
"NativeAgentServer::connect called for path: {:?}",
_root_dir
);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
// Create templates (you might want to load these from files or resources)
let templates = Templates::new();
// Create the native agent
log::debug!("Creating native agent entity");
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);
log::info!("NativeAgentServer connection established successfully");
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
})
}
}

View file

@ -0,0 +1,30 @@
use crate::{
templates::{BaseTemplate, Template, Templates, WorktreeData},
thread::Prompt,
};
use anyhow::Result;
use gpui::{App, Entity};
use project::Project;
#[allow(dead_code)]
struct BasePrompt {
project: Entity<Project>,
}
impl Prompt for BasePrompt {
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
BaseTemplate {
os: std::env::consts::OS.to_string(),
shell: util::get_system_shell(),
worktrees: self
.project
.read(cx)
.worktrees(cx)
.map(|worktree| WorktreeData {
root_name: worktree.read(cx).root_name().to_string(),
})
.collect(),
}
.render(templates)
}
}

View file

@ -0,0 +1,57 @@
use std::sync::Arc;
use anyhow::Result;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
#[derive(RustEmbed)]
#[folder = "src/templates"]
#[include = "*.hbs"]
struct Assets;
pub struct Templates(Handlebars<'static>);
impl Templates {
pub fn new() -> Arc<Self> {
let mut handlebars = Handlebars::new();
handlebars.register_embed_templates::<Assets>().unwrap();
Arc::new(Self(handlebars))
}
}
pub trait Template: Sized {
const TEMPLATE_NAME: &'static str;
fn render(&self, templates: &Templates) -> Result<String>
where
Self: Serialize + Sized,
{
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
}
}
#[derive(Serialize)]
pub struct BaseTemplate {
pub os: String,
pub shell: String,
pub worktrees: Vec<WorktreeData>,
}
impl Template for BaseTemplate {
const TEMPLATE_NAME: &'static str = "base.hbs";
}
#[derive(Serialize)]
pub struct WorktreeData {
pub root_name: String,
}
#[derive(Serialize)]
pub struct GlobTemplate {
pub project_roots: String,
}
impl Template for GlobTemplate {
const TEMPLATE_NAME: &'static str = "glob.hbs";
}

View file

@ -0,0 +1,56 @@
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
## Communication
1. Be conversational but professional.
2. Refer to the USER in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
## Tool Use
1. Make sure to adhere to the tools schema.
2. Provide every required argument.
3. DO NOT use tools to access items that are already available in the context section.
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
## Searching and Reading
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
- `{{root_name}}`
{{/each}}
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
- When looking for symbols in the project, prefer the `grep` tool.
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- Bias towards not asking the user for help if you can find the answer yourself.
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
## Debugging
When debugging, only make code changes if you are certain that you can solve the problem.
Otherwise, follow debugging best practices:
1. Address the root cause instead of the symptoms.
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
## Calling External APIs
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
## System Information
Operating System: {{os}}
Default Shell: {{shell}}

View file

@ -0,0 +1,8 @@
Find paths on disk with glob patterns.
Assume that all glob patterns are matched in a project directory with the following entries.
{{project_roots}}
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
sure to anchor them with one of the directories listed above.

View file

@ -0,0 +1,378 @@
use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection as _;
use agent_client_protocol as acp;
use client::{Client, UserStore};
use gpui::{AppContext, Entity, TestAppContext};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRegistry, MessageContent, StopReason,
};
use project::Project;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::stream::StreamExt;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
mod test_tools;
use test_tools::*;
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx).await;
let events = thread
.update(cx, |thread, cx| {
thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
})
.collect()
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.messages().last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
}
#[gpui::test]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx).await;
// Test a tool call that's likely to complete *before* streaming stops.
let events = thread
.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(
model.clone(),
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
cx,
)
})
.collect()
.await;
assert_eq!(
stop_events(events),
vec![StopReason::ToolUse, StopReason::EndTurn]
);
// Test a tool calls that's likely to complete *after* streaming stops.
let events = thread
.update(cx, |thread, cx| {
thread.remove_tool(&AgentTool::name(&EchoTool));
thread.add_tool(DelayTool);
thread.send(
model.clone(),
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
cx,
)
})
.collect()
.await;
assert_eq!(
stop_events(events),
vec![StopReason::ToolUse, StopReason::EndTurn]
);
thread.update(cx, |thread, _cx| {
assert!(thread
.messages()
.last()
.unwrap()
.content
.iter()
.any(|content| {
if let MessageContent::Text(text) = content {
text.contains("Ding")
} else {
false
}
}));
});
}
#[gpui::test]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx).await;
// Test a tool call that's likely to complete *before* streaming stops.
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(WordListTool);
thread.send(model.clone(), "Test the word_list tool.", cx)
});
let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await {
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
let last_content = thread.messages().last().unwrap().content.last().unwrap();
if let MessageContent::ToolUse(last_tool_use) = last_content {
assert_eq!(last_tool_use.name.as_ref(), "word_list");
if tool_use_event.is_input_complete {
last_tool_use
.input
.get("a")
.expect("'a' has streamed because input is now complete");
last_tool_use
.input
.get("g")
.expect("'g' has streamed because input is now complete");
} else {
if !last_tool_use.is_input_complete
&& last_tool_use.input.get("g").is_none()
{
saw_partial_tool_use = true;
}
}
} else {
panic!("last content should be a tool use");
}
});
}
}
assert!(
saw_partial_tool_use,
"should see at least one partially streamed tool use in the history"
);
}
#[gpui::test]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx).await;
// Test concurrent tool calls with different delay times
let events = thread
.update(cx, |thread, cx| {
thread.add_tool(DelayTool);
thread.send(
model.clone(),
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
cx,
)
})
.collect()
.await;
let stop_reasons = stop_events(events);
if stop_reasons.len() == 2 {
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
} else if stop_reasons.len() == 3 {
assert_eq!(
stop_reasons,
vec![
StopReason::ToolUse,
StopReason::ToolUse,
StopReason::EndTurn
]
);
} else {
panic!("Expected either 1 or 2 tool uses followed by end turn");
}
thread.update(cx, |thread, _cx| {
let last_message = thread.messages().last().unwrap();
let text = last_message
.content
.iter()
.filter_map(|content| {
if let MessageContent::Text(text) = content {
Some(text.as_str())
} else {
None
}
})
.collect::<String>();
assert!(text.contains("Ding"));
});
}
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.executor().allow_parking();
cx.update(settings::init);
let templates = Templates::new();
// Initialize language model system with test provider
cx.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
// Initialize project settings
Project::init_settings(cx);
// Use test registry with fake provider
LanguageModelRegistry::test(cx);
});
// Create agent and connection
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
let connection = NativeAgentConnection(agent.clone());
// Test model_selector returns Some
let selector_opt = connection.model_selector();
assert!(
selector_opt.is_some(),
"agent2 should always support ModelSelector"
);
let selector = selector_opt.unwrap();
// Test list_models
let listed_models = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.list_models(&mut async_cx)
})
.await
.expect("list_models should succeed");
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(listed_models[0].id().0, "fake");
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
// Create a thread using new_thread
let cwd = Path::new("/test");
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| {
let mut async_cx = cx.to_async();
connection_rc.new_thread(project, cwd, &mut async_cx)
})
.await
.expect("new_thread should succeed");
// Get the session_id from the AcpThread
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
// Test selected_model returns the default
let selected = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&session_id, &mut async_cx)
})
.await
.expect("selected_model should succeed");
assert_eq!(selected.id().0, "fake", "should return default model");
// The thread was created via prompt with the default model
// We can verify it through selected_model
// Test prompt uses the selected model
let prompt_request = acp::PromptRequest {
session_id: session_id.clone(),
prompt: vec![acp::ContentBlock::Text(acp::TextContent {
text: "Test prompt".into(),
annotations: None,
})],
};
cx.update(|cx| connection.prompt(prompt_request, cx))
.await
.expect("prompt should succeed");
// The prompt was sent successfully
// Test cancel
cx.update(|cx| connection.cancel(&session_id, cx));
// After cancel, selected_model should fail
let result = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&session_id, &mut async_cx)
})
.await;
assert!(result.is_err(), "selected_model should fail after cancel");
// Test error case: invalid session
let invalid_session = acp::SessionId("invalid".into());
let result = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&invalid_session, &mut async_cx)
})
.await;
assert!(result.is_err(), "should fail for invalid session");
if let Err(e) = result {
assert!(
e.to_string().contains("Session not found"),
"should have correct error message"
);
}
}
/// Filters out the stop events for asserting against in tests
fn stop_events(
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> Vec<StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
_ => None,
})
.collect()
}
struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
}
async fn setup(cx: &mut TestAppContext) -> ThreadTest {
cx.executor().allow_parking();
cx.update(settings::init);
let templates = Templates::new();
let model = cx
.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
let models = LanguageModelRegistry::read_global(cx);
let model = models
.available_models(cx)
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
.unwrap();
let provider = models.provider(&model.provider_id()).unwrap();
let authenticated = provider.authenticate(cx);
cx.spawn(async move |_cx| {
authenticated.await.unwrap();
model
})
})
.await;
let thread = cx.new(|_| Thread::new(templates, model.clone()));
ThreadTest { model, thread }
}
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}

View file

@ -0,0 +1,85 @@
use super::*;
use anyhow::Result;
use gpui::{App, SharedString, Task};
/// A tool that echoes its input
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
/// The text to echo.
text: String,
}
pub struct EchoTool;
impl AgentTool for EchoTool {
type Input = EchoToolInput;
fn name(&self) -> SharedString {
"echo".into()
}
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
Task::ready(Ok(input.text))
}
}
/// A tool that waits for a specified delay
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct DelayToolInput {
/// The delay in milliseconds.
ms: u64,
}
pub struct DelayTool;
impl AgentTool for DelayTool {
type Input = DelayToolInput;
fn name(&self) -> SharedString {
"delay".into()
}
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
where
Self: Sized,
{
cx.foreground_executor().spawn(async move {
smol::Timer::after(Duration::from_millis(input.ms)).await;
Ok("Ding".to_string())
})
}
}
/// A tool that takes an object with map from letters to random words starting with that letter.
/// All fiealds are required! Pass a word for every letter!
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct WordListInput {
/// Provide a random word that starts with A.
a: Option<String>,
/// Provide a random word that starts with B.
b: Option<String>,
/// Provide a random word that starts with C.
c: Option<String>,
/// Provide a random word that starts with D.
d: Option<String>,
/// Provide a random word that starts with E.
e: Option<String>,
/// Provide a random word that starts with F.
f: Option<String>,
/// Provide a random word that starts with G.
g: Option<String>,
}
pub struct WordListTool;
impl AgentTool for WordListTool {
type Input = WordListInput;
fn name(&self) -> SharedString {
"word_list".into()
}
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
Task::ready(Ok("ok".to_string()))
}
}

493
crates/agent2/src/thread.rs Normal file
View file

@ -0,0 +1,493 @@
use crate::templates::Templates;
use anyhow::{anyhow, Result};
use cloud_llm_client::{CompletionIntent, CompletionMode};
use futures::{channel::mpsc, future};
use gpui::{App, Context, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, MessageContent, Role, StopReason,
};
use log;
use schemars::{JsonSchema, Schema};
use serde::Deserialize;
use smol::stream::StreamExt;
use std::{collections::BTreeMap, sync::Arc};
use util::ResultExt;
#[derive(Debug)]
pub struct AgentMessage {
pub role: Role,
pub content: Vec<MessageContent>,
}
pub type AgentResponseEvent = LanguageModelCompletionEvent;
pub trait Prompt {
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
}
pub struct Thread {
messages: Vec<AgentMessage>,
completion_mode: CompletionMode,
/// Holds the task that handles agent interaction until the end of the turn.
/// Survives across multiple requests as the model performs tool calls and
/// we run tools, report their results.
running_turn: Option<Task<()>>,
system_prompts: Vec<Arc<dyn Prompt>>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
// action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
Self {
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
system_prompts: Vec::new(),
running_turn: None,
tools: BTreeMap::default(),
templates,
selected_model: default_model,
}
}
pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
pub fn messages(&self) -> &[AgentMessage] {
&self.messages
}
pub fn add_tool(&mut self, tool: impl AgentTool) {
self.tools.insert(tool.name(), tool.erase());
}
pub fn remove_tool(&mut self, name: &str) -> bool {
self.tools.remove(name).is_some()
}
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
pub fn send(
&mut self,
model: Arc<dyn LanguageModel>,
content: impl Into<MessageContent>,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
let content = content.into();
log::info!("Thread::send called with model: {:?}", model.name());
log::debug!("Thread::send content: {:?}", content);
cx.notify();
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let system_message = self.build_system_message(cx);
log::debug!(
"System messages count: {}",
if system_message.is_some() { 1 } else { 0 }
);
self.messages.extend(system_message);
self.messages.push(AgentMessage {
role: Role::User,
content: vec![content],
});
log::info!("Total messages in thread: {}", self.messages.len());
self.running_turn = Some(cx.spawn(async move |thread, cx| {
log::info!("Starting agent turn execution");
let turn_result = async {
// Perform one request, then keep looping if the model makes tool calls.
let mut completion_intent = CompletionIntent::UserPrompt;
loop {
log::debug!(
"Building completion request with intent: {:?}",
completion_intent
);
let request = thread.update(cx, |thread, cx| {
thread.build_completion_request(completion_intent, cx)
})?;
// println!(
// "request: {}",
// serde_json::to_string_pretty(&request).unwrap()
// );
// Stream events, appending to messages and collecting up tool uses.
log::info!("Calling model.stream_completion");
let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully");
let mut tool_uses = Vec::new();
while let Some(event) = events.next().await {
match event {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
thread
.update(cx, |thread, cx| {
tool_uses.extend(thread.handle_streamed_completion_event(
event,
events_tx.clone(),
cx,
));
})
.ok();
}
Err(error) => {
log::error!("Error in completion stream: {:?}", error);
events_tx.unbounded_send(Err(error)).ok();
break;
}
}
}
// If there are no tool uses, the turn is done.
if tool_uses.is_empty() {
log::info!("No tool uses found, completing turn");
break;
}
log::info!("Found {} tool uses to execute", tool_uses.len());
// If there are tool uses, wait for their results to be
// computed, then send them together in a single message on
// the next loop iteration.
let tool_results = future::join_all(tool_uses).await;
log::debug!("Tool execution completed, {} results", tool_results.len());
thread
.update(cx, |thread, _cx| {
thread.messages.push(AgentMessage {
role: Role::User,
content: tool_results
.into_iter()
.map(MessageContent::ToolResult)
.collect(),
});
})
.ok();
completion_intent = CompletionIntent::ToolResults;
}
Ok(())
}
.await;
if let Err(error) = turn_result {
log::error!("Turn execution failed: {:?}", error);
events_tx.unbounded_send(Err(error)).ok();
} else {
log::info!("Turn execution completed successfully");
}
}));
events_rx
}
pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
log::debug!("Building system message");
let mut system_message = AgentMessage {
role: Role::System,
content: Vec::new(),
};
for prompt in &self.system_prompts {
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
system_message
.content
.push(MessageContent::Text(rendered_prompt));
}
}
let result = (!system_message.content.is_empty()).then_some(system_message);
log::debug!("System message built: {}", result.is_some());
result
}
/// A helper method that's called on every streamed completion event.
/// Returns an optional tool result task, which the main agentic loop in
/// send will send back to the model when it resolves.
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
log::trace!("Handling streamed completion event: {:?}", event);
use LanguageModelCompletionEvent::*;
events_tx.unbounded_send(Ok(event.clone())).ok();
match event {
Text(new_text) => self.handle_text_event(new_text, cx),
Thinking {
text: _text,
signature: _signature,
} => {
todo!()
}
ToolUse(tool_use) => {
return self.handle_tool_use_event(tool_use, cx);
}
StartMessage { .. } => {
self.messages.push(AgentMessage {
role: Role::Assistant,
content: Vec::new(),
});
}
UsageUpdate(_) => {}
Stop(stop_reason) => self.handle_stop_event(stop_reason),
StatusUpdate(_completion_request_status) => {}
RedactedThinking { data: _data } => todo!(),
ToolUseJsonParseError {
id: _id,
tool_name: _tool_name,
raw_input: _raw_input,
json_parse_error: _json_parse_error,
} => todo!(),
}
None
}
fn handle_stop_event(&mut self, stop_reason: StopReason) {
match stop_reason {
StopReason::EndTurn | StopReason::ToolUse => {}
StopReason::MaxTokens => todo!(),
StopReason::Refusal => todo!(),
}
}
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
let last_message = self.last_assistant_message();
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
text.push_str(&new_text);
} else {
last_message.content.push(MessageContent::Text(new_text));
}
cx.notify();
}
fn handle_tool_use_event(
&mut self,
tool_use: LanguageModelToolUse,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
cx.notify();
let last_message = self.last_assistant_message();
// Ensure the last message ends in the current tool use
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
if let MessageContent::ToolUse(last_tool_use) = content {
if last_tool_use.id == tool_use.id {
*last_tool_use = tool_use.clone();
false
} else {
true
}
} else {
true
}
});
if push_new_tool_use {
last_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
}
if !tool_use.is_input_complete {
return None;
}
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
let pending_tool_result = tool.clone().run(tool_use.input, cx);
Some(cx.foreground_executor().spawn(async move {
match pending_tool_result.await {
Ok(tool_output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
output: None,
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
output: None,
},
}
}))
} else {
Some(Task::ready(LanguageModelToolResult {
content: LanguageModelToolResultContent::Text(Arc::from(format!(
"No tool named {} exists",
tool_use.name
))),
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
output: None,
}))
}
}
/// Guarantees the last message is from the assistant and returns a mutable reference.
fn last_assistant_message(&mut self) -> &mut AgentMessage {
if self
.messages
.last()
.map_or(true, |m| m.role != Role::Assistant)
{
self.messages.push(AgentMessage {
role: Role::Assistant,
content: Vec::new(),
});
}
self.messages.last_mut().unwrap()
}
fn build_completion_request(
&self,
completion_intent: CompletionIntent,
cx: &mut App,
) -> LanguageModelRequest {
log::debug!("Building completion request");
log::debug!("Completion intent: {:?}", completion_intent);
log::debug!("Completion mode: {:?}", self.completion_mode);
let messages = self.build_request_messages();
log::info!("Request will include {} messages", messages.len());
let tools: Vec<LanguageModelRequestTool> = self
.tools
.values()
.filter_map(|tool| {
let tool_name = tool.name().to_string();
log::trace!("Including tool: {}", tool_name);
Some(LanguageModelRequestTool {
name: tool_name,
description: tool.description(cx).to_string(),
input_schema: tool
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
.log_err()?,
})
})
.collect();
log::info!("Request includes {} tools", tools.len());
let request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: Some(completion_intent),
mode: Some(self.completion_mode),
messages,
tools,
tool_choice: None,
stop: Vec::new(),
temperature: None,
thinking_allowed: false,
};
log::debug!("Completion request built successfully");
request
}
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
log::trace!(
"Building request messages from {} thread messages",
self.messages.len()
);
let messages = self
.messages
.iter()
.map(|message| {
log::trace!(
" - {} message with {} content items",
match message.role {
Role::System => "System",
Role::User => "User",
Role::Assistant => "Assistant",
},
message.content.len()
);
LanguageModelRequestMessage {
role: message.role,
content: message.content.clone(),
cache: false,
}
})
.collect();
messages
}
}
pub trait AgentTool
where
Self: 'static + Sized,
{
type Input: for<'de> Deserialize<'de> + JsonSchema;
fn name(&self) -> SharedString;
fn description(&self, _cx: &mut App) -> SharedString {
let schema = schemars::schema_for!(Self::Input);
SharedString::new(
schema
.get("description")
.and_then(|description| description.as_str())
.unwrap_or_default(),
)
}
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
schemars::schema_for!(Self::Input)
}
/// Runs the tool with the provided input.
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
}
}
pub struct Erased<T>(T);
pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
}
impl<T> AnyAgentTool for Erased<Arc<T>>
where
T: AgentTool,
{
fn name(&self) -> SharedString {
self.0.name()
}
fn description(&self, cx: &mut App) -> SharedString {
self.0.description(cx)
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
Ok(serde_json::to_value(self.0.input_schema(format))?)
}
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
Ok(input) => self.0.clone().run(input, cx),
Err(error) => Task::ready(Err(anyhow!(error))),
}
}
}

View file

@ -0,0 +1 @@
mod glob;

View file

@ -0,0 +1,76 @@
use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task};
use project::Project;
use schemars::JsonSchema;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
use worktree::Snapshot as WorktreeSnapshot;
use crate::{
templates::{GlobTemplate, Template, Templates},
thread::AgentTool,
};
// Description is dynamic, see `fn description` below
#[derive(Deserialize, JsonSchema)]
struct GlobInput {
/// A POSIX glob pattern
glob: SharedString,
}
struct GlobTool {
project: Entity<Project>,
templates: Arc<Templates>,
}
impl AgentTool for GlobTool {
type Input = GlobInput;
fn name(&self) -> SharedString {
"glob".into()
}
fn description(&self, cx: &mut App) -> SharedString {
let project_roots = self
.project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).root_name().into())
.collect::<Vec<String>>()
.join("\n");
GlobTemplate { project_roots }
.render(&self.templates)
.expect("template failed to render")
.into()
}
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
let path_matcher = match PathMatcher::new([&input.glob]) {
Ok(matcher) => matcher,
Err(error) => return Task::ready(Err(anyhow!(error))),
};
let snapshots: Vec<WorktreeSnapshot> = self
.project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
let paths = snapshots.iter().flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
});
let output = paths
.map(|path| format!("{}\n", path.display()))
.collect::<String>();
Ok(output)
})
}
}

0
crates/agent_servers/acp Normal file
View file

View file

@ -0,0 +1,245 @@
use agent_client_protocol::{self as acp, Agent as _};
use collections::HashMap;
use futures::channel::oneshot;
use project::Project;
use std::cell::RefCell;
use std::path::Path;
use std::rc::Rc;
use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::AgentServerCommand;
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
pub struct AcpConnection {
server_name: &'static str,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
_io_task: Task<Result<()>>,
}
pub struct AcpSession {
thread: WeakEntity<AcpThread>,
}
impl AcpConnection {
pub async fn stdio(
server_name: &'static str,
command: AgentServerCommand,
root_dir: &Path,
cx: &mut AsyncApp,
) -> Result<Self> {
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter().map(|arg| arg.as_str()))
.envs(command.env.iter().flatten())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let stdout = child.stdout.take().expect("Failed to take stdout");
let stdin = child.stdin.take().expect("Failed to take stdin");
let sessions = Rc::new(RefCell::new(HashMap::default()));
let client = ClientDelegate {
sessions: sessions.clone(),
cx: cx.clone(),
};
let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
let foreground_executor = cx.foreground_executor().clone();
move |fut| {
foreground_executor.spawn(fut).detach();
}
});
let io_task = cx.background_spawn(io_task);
let response = connection
.initialize(acp::InitializeRequest {
protocol_version: acp::VERSION,
client_capabilities: acp::ClientCapabilities {
fs: acp::FileSystemCapability {
read_text_file: true,
write_text_file: true,
},
},
})
.await?;
// todo! check version
Ok(Self {
auth_methods: response.auth_methods,
connection: connection.into(),
server_name,
sessions,
_io_task: io_task,
})
}
}
impl AgentConnection for AcpConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let response = conn
.new_session(acp::NewSessionRequest {
// todo! Zed MCP server?
mcp_servers: vec![],
cwd,
})
.await?;
let Some(session_id) = response.session_id else {
anyhow::bail!(AuthRequired);
};
let thread = cx.new(|cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
session_id.clone(),
cx,
)
})?;
let session = AcpSession {
thread: thread.downgrade(),
};
sessions.borrow_mut().insert(session_id, session);
Ok(thread)
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&self.auth_methods
}
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let conn = self.connection.clone();
cx.foreground_executor().spawn(async move {
let result = conn
.authenticate(acp::AuthenticateRequest {
method_id: method_id.clone(),
})
.await?;
Ok(result)
})
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let conn = self.connection.clone();
cx.foreground_executor()
.spawn(async move { Ok(conn.prompt(params).await?) })
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let conn = self.connection.clone();
let params = acp::CancelledNotification {
session_id: session_id.clone(),
};
cx.foreground_executor()
.spawn(async move { conn.cancelled(params).await })
.detach();
}
}
struct ClientDelegate {
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
cx: AsyncApp,
}
impl acp::Client for ClientDelegate {
async fn request_permission(
&self,
arguments: acp::RequestPermissionRequest,
) -> Result<acp::RequestPermissionResponse, acp::Error> {
let cx = &mut self.cx.clone();
let result = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
})?
.await;
let outcome = match result {
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
};
Ok(acp::RequestPermissionResponse { outcome })
}
async fn write_text_file(
&self,
arguments: acp::WriteTextFileRequest,
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
self.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.write_text_file(arguments.path, arguments.content, cx)
})?
.await?;
Ok(())
}
async fn read_text_file(
&self,
arguments: acp::ReadTextFileRequest,
) -> Result<acp::ReadTextFileResponse, acp::Error> {
let cx = &mut self.cx.clone();
let content = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
})?
.await?;
Ok(acp::ReadTextFileResponse { content })
}
async fn session_notification(
&self,
notification: acp::SessionNotification,
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
let sessions = self.sessions.borrow();
let session = sessions
.get(&notification.session_id)
.context("Failed to get session")?;
session.thread.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})??;
Ok(())
}
}

View file

@ -1,14 +1,12 @@
mod acp_connection;
mod claude; mod claude;
mod codex;
mod gemini; mod gemini;
mod mcp_server;
mod settings; mod settings;
#[cfg(test)] #[cfg(test)]
mod e2e_tests; mod e2e_tests;
pub use claude::*; pub use claude::*;
pub use codex::*;
pub use gemini::*; pub use gemini::*;
pub use settings::*; pub use settings::*;
@ -38,7 +36,6 @@ pub trait AgentServer: Send {
fn connect( fn connect(
&self, &self,
// these will go away when old_acp is fully removed
root_dir: &Path, root_dir: &Path,
project: &Entity<Project>, project: &Entity<Project>,
cx: &mut App, cx: &mut App,

View file

@ -70,10 +70,6 @@ struct ClaudeAgentConnection {
} }
impl AgentConnection for ClaudeAgentConnection { impl AgentConnection for ClaudeAgentConnection {
fn name(&self) -> &'static str {
ClaudeCode.name()
}
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection {
} }
}); });
let thread = let thread = cx.new(|cx| {
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
})?;
thread_tx.send(thread.downgrade())?; thread_tx.send(thread.downgrade())?;
@ -186,11 +183,15 @@ impl AgentConnection for ClaudeAgentConnection {
}) })
} }
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> { fn auth_methods(&self) -> &[acp::AuthMethod] {
&[]
}
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported"))) Task::ready(Err(anyhow!("Authentication not supported")))
} }
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> { fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let sessions = self.sessions.borrow(); let sessions = self.sessions.borrow();
let Some(session) = sessions.get(&params.session_id) else { let Some(session) = sessions.get(&params.session_id) else {
return Task::ready(Err(anyhow!( return Task::ready(Err(anyhow!(

View file

@ -1,319 +0,0 @@
use agent_client_protocol as acp;
use anyhow::anyhow;
use collections::HashMap;
use context_server::listener::McpServerTool;
use context_server::types::requests;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project;
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::cell::RefCell;
use std::rc::Rc;
use std::{path::Path, sync::Arc};
use util::ResultExt;
use anyhow::{Context, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::mcp_server::ZedMcpServer;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
use acp_thread::{AcpThread, AgentConnection};
#[derive(Clone)]
pub struct Codex;
impl AgentServer for Codex {
fn name(&self) -> &'static str {
"Codex"
}
fn empty_state_headline(&self) -> &'static str {
"Welcome to Codex"
}
fn empty_state_message(&self) -> &'static str {
"What can I help with?"
}
fn logo(&self) -> ui::IconName {
ui::IconName::AiOpenAi
}
fn connect(
&self,
_root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let project = project.clone();
let working_directory = project.read(cx).active_project_directory(cx);
cx.spawn(async move |cx| {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).codex.clone()
})?;
let Some(command) =
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
else {
anyhow::bail!("Failed to find codex binary");
};
let client: Arc<ContextServer> = ContextServer::stdio(
ContextServerId("codex-mcp-server".into()),
ContextServerCommand {
path: command.path,
args: command.args,
env: command.env,
},
working_directory,
)
.into();
ContextServer::start(client.clone(), cx).await?;
let (notification_tx, mut notification_rx) = mpsc::unbounded();
client
.client()
.context("Failed to subscribe")?
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
move |notification, _cx| {
let notification_tx = notification_tx.clone();
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(notification) =
serde_json::from_value::<acp::SessionNotification>(notification)
.log_err()
{
notification_tx.unbounded_send(notification).ok();
}
}
});
let sessions = Rc::new(RefCell::new(HashMap::default()));
let notification_handler_task = cx.spawn({
let sessions = sessions.clone();
async move |cx| {
while let Some(notification) = notification_rx.next().await {
CodexConnection::handle_session_notification(
notification,
sessions.clone(),
cx,
)
}
}
});
let connection = CodexConnection {
client,
sessions,
_notification_handler_task: notification_handler_task,
};
Ok(Rc::new(connection) as _)
})
}
}
struct CodexConnection {
client: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
_notification_handler_task: Task<()>,
}
struct CodexSession {
thread: WeakEntity<AcpThread>,
cancel_tx: Option<oneshot::Sender<()>>,
_mcp_server: ZedMcpServer,
}
impl AgentConnection for CodexConnection {
fn name(&self) -> &'static str {
"Codex"
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let client = self.client.client();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let client = client.context("MCP server is not initialized yet")?;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
let response = client
.request::<requests::CallTool>(context_server::types::CallToolParams {
name: acp::NEW_SESSION_TOOL_NAME.into(),
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
mcp_servers: [(
mcp_server::SERVER_NAME.to_string(),
mcp_server.server_config()?,
)]
.into(),
client_tools: acp::ClientTools {
request_permission: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
}),
read_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
}),
write_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
}),
},
cwd,
})?),
meta: None,
})
.await?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
let result = serde_json::from_value::<acp::NewSessionOutput>(
response.structured_content.context("Empty response")?,
)?;
let thread =
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
thread_tx.send(thread.downgrade())?;
let session = CodexSession {
thread: thread.downgrade(),
cancel_tx: None,
_mcp_server: mcp_server,
};
sessions.borrow_mut().insert(result.session_id, session);
Ok(thread)
})
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported")))
}
fn prompt(
&self,
params: agent_client_protocol::PromptArguments,
cx: &mut App,
) -> Task<Result<()>> {
let client = self.client.client();
let sessions = self.sessions.clone();
cx.foreground_executor().spawn(async move {
let client = client.context("MCP server is not initialized yet")?;
let (new_cancel_tx, cancel_rx) = oneshot::channel();
{
let mut sessions = sessions.borrow_mut();
let session = sessions
.get_mut(&params.session_id)
.context("Session not found")?;
session.cancel_tx.replace(new_cancel_tx);
}
let result = client
.request_with::<requests::CallTool>(
context_server::types::CallToolParams {
name: acp::PROMPT_TOOL_NAME.into(),
arguments: Some(serde_json::to_value(params)?),
meta: None,
},
Some(cancel_rx),
None,
)
.await;
if let Err(err) = &result
&& err.is::<context_server::client::RequestCanceled>()
{
return Ok(());
}
let response = result?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
Ok(())
})
}
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
let mut sessions = self.sessions.borrow_mut();
if let Some(cancel_tx) = sessions
.get_mut(session_id)
.and_then(|session| session.cancel_tx.take())
{
cancel_tx.send(()).ok();
}
}
}
impl CodexConnection {
pub fn handle_session_notification(
notification: acp::SessionNotification,
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
cx: &mut AsyncApp,
) {
let threads = threads.borrow();
let Some(thread) = threads
.get(&notification.session_id)
.and_then(|session| session.thread.upgrade())
else {
log::error!(
"Thread not found for session ID: {}",
notification.session_id
);
return;
};
thread
.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})
.log_err();
}
}
impl Drop for CodexConnection {
fn drop(&mut self) {
self.client.stop().log_err();
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::AgentServerCommand;
use std::path::Path;
crate::common_e2e_tests!(Codex, allow_option_id = "approve");
pub fn local_command() -> AgentServerCommand {
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../../../codex/codex-rs/target/debug/codex");
AgentServerCommand {
path: cli_path,
args: vec![],
env: None,
}
}
}

View file

@ -375,9 +375,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
gemini: Some(AgentServerSettings { gemini: Some(AgentServerSettings {
command: crate::gemini::tests::local_command(), command: crate::gemini::tests::local_command(),
}), }),
codex: Some(AgentServerSettings {
command: crate::codex::tests::local_command(),
}),
}, },
cx, cx,
); );

View file

@ -1,14 +1,10 @@
use anyhow::anyhow;
use std::cell::RefCell;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
use util::ResultExt as _;
use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; use crate::{AgentServer, AgentServerCommand, acp_connection::AcpConnection};
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; use acp_thread::AgentConnection;
use agentic_coding_protocol as acp_old; use anyhow::Result;
use anyhow::{Context as _, Result}; use gpui::{Entity, Task};
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use ui::App; use ui::App;
@ -43,146 +39,27 @@ impl AgentServer for Gemini {
project: &Entity<Project>, project: &Entity<Project>,
cx: &mut App, cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> { ) -> Task<Result<Rc<dyn AgentConnection>>> {
let root_dir = root_dir.to_path_buf();
let project = project.clone(); let project = project.clone();
let this = self.clone(); let server_name = self.name();
let name = self.name(); let root_dir = root_dir.to_path_buf();
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let command = this.command(&project, cx).await?; let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).gemini.clone()
})?;
let mut child = util::command::new_smol_command(&command.path) let Some(command) =
.args(command.args.iter()) AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
.current_dir(root_dir) else {
.stdin(std::process::Stdio::piped()) anyhow::bail!("Failed to find gemini binary");
.stdout(std::process::Stdio::piped()) };
.stderr(std::process::Stdio::inherit()) // todo! check supported version
.kill_on_drop(true)
.spawn()?;
let stdin = child.stdin.take().unwrap(); let conn = AcpConnection::stdio(server_name, command, &root_dir, cx).await?;
let stdout = child.stdout.take().unwrap(); Ok(Rc::new(conn) as _)
let foreground_executor = cx.foreground_executor().clone();
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => {
if let Some(AgentServerVersion::Unsupported {
error_message,
upgrade_message,
upgrade_command,
}) = this.version(&command).await.log_err()
{
Err(anyhow!(LoadError::Unsupported {
error_message,
upgrade_message,
upgrade_command
}))
} else {
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
}
}
};
drop(io_task);
result
});
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
name,
connection,
child_status,
current_thread: thread_rc,
});
Ok(connection)
}) })
} }
} }
impl Gemini {
async fn command(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<AgentServerCommand> {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).gemini.clone()
})?;
if let Some(command) =
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
{
return Ok(command);
};
let (fs, node_runtime) = project.update(cx, |project, _| {
(project.fs().clone(), project.node_runtime().cloned())
})?;
let node_runtime = node_runtime.context("gemini not found on path")?;
let directory = ::paths::agent_servers_dir().join("gemini");
fs.create_dir(&directory).await?;
node_runtime
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
.await?;
let path = directory.join("node_modules/.bin/gemini");
Ok(AgentServerCommand {
path,
args: vec![ACP_ARG.into()],
env: None,
})
}
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
let version_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--version")
.kill_on_drop(true)
.output();
let help_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--help")
.kill_on_drop(true)
.output();
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
let current_version = String::from_utf8(version_output?.stdout)?;
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
if supported {
Ok(AgentServerVersion::Supported)
} else {
Ok(AgentServerVersion::Unsupported {
error_message: format!(
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
current_version
).into(),
upgrade_message: "Upgrade Gemini to Latest".into(),
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
})
}
}
}
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use super::*; use super::*;
@ -199,7 +76,7 @@ pub(crate) mod tests {
AgentServerCommand { AgentServerCommand {
path: "node".into(), path: "node".into(),
args: vec![cli_path, ACP_ARG.into()], args: vec![cli_path],
env: None, env: None,
} }
} }

View file

@ -1,207 +0,0 @@
use acp_thread::AcpThread;
use agent_client_protocol as acp;
use anyhow::Result;
use context_server::listener::{McpServerTool, ToolResponse};
use context_server::types::{
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
ToolsCapabilities, requests,
};
use futures::channel::oneshot;
use gpui::{App, AsyncApp, Task, WeakEntity};
use indoc::indoc;
pub struct ZedMcpServer {
server: context_server::listener::McpServer,
}
pub const SERVER_NAME: &str = "zed";
impl ZedMcpServer {
pub async fn new(
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
cx: &AsyncApp,
) -> Result<Self> {
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.add_tool(RequestPermissionTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(ReadTextFileTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(WriteTextFileTool {
thread_rx: thread_rx.clone(),
});
Ok(Self { server: mcp_server })
}
pub fn server_config(&self) -> Result<acp::McpServerConfig> {
#[cfg(not(test))]
let zed_path = anyhow::Context::context(
std::env::current_exe(),
"finding current executable path for use in mcp_server",
)?;
#[cfg(test)]
let zed_path = crate::e2e_tests::get_zed_path();
Ok(acp::McpServerConfig {
command: zed_path,
args: vec![
"--nc".into(),
self.server.socket_path().display().to_string(),
],
env: None,
})
}
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
cx.foreground_executor().spawn(async move {
Ok(InitializeResponse {
protocol_version: ProtocolVersion("2025-06-18".into()),
capabilities: ServerCapabilities {
experimental: None,
logging: None,
completions: None,
prompts: None,
resources: None,
tools: Some(ToolsCapabilities {
list_changed: Some(false),
}),
},
server_info: Implementation {
name: SERVER_NAME.into(),
version: "0.1.0".into(),
},
meta: None,
})
})
}
}
// Tools
#[derive(Clone)]
pub struct RequestPermissionTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for RequestPermissionTool {
type Input = acp::RequestPermissionArguments;
type Output = acp::RequestPermissionOutput;
const NAME: &'static str = "Confirmation";
fn description(&self) -> &'static str {
indoc! {"
Request permission for tool calls.
This tool is meant to be called programmatically by the agent loop, not the LLM.
"}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let result = thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(input.tool_call, input.options, cx)
})?
.await;
let outcome = match result {
Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
};
Ok(ToolResponse {
content: vec![],
structured_content: acp::RequestPermissionOutput { outcome },
})
}
}
#[derive(Clone)]
pub struct ReadTextFileTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for ReadTextFileTool {
type Input = acp::ReadTextFileArguments;
type Output = acp::ReadTextFileOutput;
const NAME: &'static str = "Read";
fn description(&self) -> &'static str {
"Reads the content of the given file in the project including unsaved changes."
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.path, input.line, input.limit, false, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: acp::ReadTextFileOutput { content },
})
}
}
#[derive(Clone)]
pub struct WriteTextFileTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for WriteTextFileTool {
type Input = acp::WriteTextFileArguments;
type Output = ();
const NAME: &'static str = "Write";
fn description(&self) -> &'static str {
"Write to a file replacing its contents"
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.path, input.content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: (),
})
}
}

View file

@ -13,7 +13,6 @@ pub fn init(cx: &mut App) {
pub struct AllAgentServersSettings { pub struct AllAgentServersSettings {
pub gemini: Option<AgentServerSettings>, pub gemini: Option<AgentServerSettings>,
pub claude: Option<AgentServerSettings>, pub claude: Option<AgentServerSettings>,
pub codex: Option<AgentServerSettings>,
} }
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] #[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings {
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> { fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let mut settings = AllAgentServersSettings::default(); let mut settings = AllAgentServersSettings::default();
for AllAgentServersSettings { for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
gemini,
claude,
codex,
} in sources.defaults_and_customizations()
{
if gemini.is_some() { if gemini.is_some() {
settings.gemini = gemini.clone(); settings.gemini = gemini.clone();
} }
if claude.is_some() { if claude.is_some() {
settings.claude = claude.clone(); settings.claude = claude.clone();
} }
if codex.is_some() {
settings.codex = codex.clone();
}
} }
Ok(settings) Ok(settings)

View file

@ -19,6 +19,7 @@ test-support = ["gpui/test-support", "language/test-support"]
acp_thread.workspace = true acp_thread.workspace = true
agent-client-protocol.workspace = true agent-client-protocol.workspace = true
agent.workspace = true agent.workspace = true
agent2.workspace = true
agent_servers.workspace = true agent_servers.workspace = true
agent_settings.workspace = true agent_settings.workspace = true
ai_onboarding.workspace = true ai_onboarding.workspace = true

View file

@ -232,7 +232,8 @@ impl AcpThreadView {
{ {
Err(e) => { Err(e) => {
let mut cx = cx.clone(); let mut cx = cx.clone();
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() { // todo! remove duplication
if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection }; this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify(); cx.notify();
@ -675,13 +676,18 @@ impl AcpThreadView {
Some(entry.diffs().map(|diff| diff.multibuffer.clone())) Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
} }
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn authenticate(
&mut self,
method: acp::AuthMethodId,
window: &mut Window,
cx: &mut Context<Self>,
) {
let ThreadState::Unauthenticated { ref connection } = self.thread_state else { let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
return; return;
}; };
self.last_error.take(); self.last_error.take();
let authenticate = connection.authenticate(cx); let authenticate = connection.authenticate(method, cx);
self.auth_task = Some(cx.spawn_in(window, { self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone(); let project = self.project.clone();
let agent = self.agent.clone(); let agent = self.agent.clone();
@ -2380,22 +2386,26 @@ impl Render for AcpThreadView {
.on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::next_history_message))
.on_action(cx.listener(Self::open_agent_diff)) .on_action(cx.listener(Self::open_agent_diff))
.child(match &self.thread_state { .child(match &self.thread_state {
ThreadState::Unauthenticated { .. } => { ThreadState::Unauthenticated { connection } => v_flex()
v_flex() .p_2()
.p_2() .flex_1()
.flex_1() .items_center()
.items_center() .justify_center()
.justify_center() .child(self.render_pending_auth_state())
.child(self.render_pending_auth_state()) .child(h_flex().mt_1p5().justify_center().children(
.child( connection.auth_methods().into_iter().map(|method| {
h_flex().mt_1p5().justify_center().child( Button::new(
Button::new("sign-in", format!("Sign in to {}", self.agent.name())) SharedString::from(method.id.0.clone()),
.on_click(cx.listener(|this, _, window, cx| { method.label.clone(),
this.authenticate(window, cx) )
})), .on_click({
), let method_id = method.id.clone();
) cx.listener(move |this, _, window, cx| {
} this.authenticate(method_id.clone(), window, cx)
})
})
}),
)),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
ThreadState::LoadError(e) => v_flex() ThreadState::LoadError(e) => v_flex()
.p_2() .p_2()
@ -2834,10 +2844,6 @@ mod tests {
} }
impl AgentConnection for StubAgentConnection { impl AgentConnection for StubAgentConnection {
fn name(&self) -> &'static str {
"StubAgentConnection"
}
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
@ -2853,17 +2859,27 @@ mod tests {
.into(), .into(),
); );
let thread = cx let thread = cx
.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx)) .new(|cx| {
AcpThread::new("New Thread", self.clone(), project, session_id.clone(), cx)
})
.unwrap(); .unwrap();
self.sessions.lock().insert(session_id, thread.downgrade()); self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread)) Task::ready(Ok(thread))
} }
fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> { fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
todo!()
}
fn authenticate(
&self,
_method: acp::AuthMethodId,
_cx: &mut App,
) -> Task<gpui::Result<()>> {
unimplemented!() unimplemented!()
} }
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> { fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
let sessions = self.sessions.lock(); let sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap(); let thread = sessions.get(&params.session_id).unwrap();
let mut tasks = vec![]; let mut tasks = vec![];
@ -2910,10 +2926,6 @@ mod tests {
struct SaboteurAgentConnection; struct SaboteurAgentConnection;
impl AgentConnection for SaboteurAgentConnection { impl AgentConnection for SaboteurAgentConnection {
fn name(&self) -> &'static str {
"SaboteurAgentConnection"
}
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
@ -2921,15 +2933,23 @@ mod tests {
cx: &mut gpui::AsyncApp, cx: &mut gpui::AsyncApp,
) -> Task<gpui::Result<Entity<AcpThread>>> { ) -> Task<gpui::Result<Entity<AcpThread>>> {
Task::ready(Ok(cx Task::ready(Ok(cx
.new(|cx| AcpThread::new(self, project, SessionId("test".into()), cx)) .new(|cx| AcpThread::new("New Thread", self, project, SessionId("test".into()), cx))
.unwrap())) .unwrap()))
} }
fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> { fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
todo!()
}
fn authenticate(
&self,
_method: acp::AuthMethodId,
_cx: &mut App,
) -> Task<gpui::Result<()>> {
unimplemented!() unimplemented!()
} }
fn prompt(&self, _params: acp::PromptArguments, _cx: &mut App) -> Task<gpui::Result<()>> { fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
Task::ready(Err(anyhow::anyhow!("Error prompting"))) Task::ready(Err(anyhow::anyhow!("Error prompting")))
} }

View file

@ -1954,54 +1954,54 @@ impl AgentPanel {
this this
} }
}) })
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| { // Temporarily removed feature flag check for testing
this.separator() // .when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
.header("External Agents") // this
.item( .separator()
ContextMenuEntry::new("New Gemini Thread") .header("External Agents")
.icon(IconName::AiGemini) .item(
.icon_color(Color::Muted) ContextMenuEntry::new("New Gemini Thread")
.handler(move |window, cx| { .icon(IconName::AiGemini)
window.dispatch_action( .icon_color(Color::Muted)
NewExternalAgentThread { .handler(move |window, cx| {
agent: Some(crate::ExternalAgent::Gemini), window.dispatch_action(
} NewExternalAgentThread {
.boxed_clone(), agent: Some(crate::ExternalAgent::Gemini),
cx, }
); .boxed_clone(),
}), cx,
) );
.item( }),
ContextMenuEntry::new("New Claude Code Thread") )
.icon(IconName::AiClaude) .item(
.icon_color(Color::Muted) ContextMenuEntry::new("New Claude Code Thread")
.handler(move |window, cx| { .icon(IconName::AiClaude)
window.dispatch_action( .icon_color(Color::Muted)
NewExternalAgentThread { .handler(move |window, cx| {
agent: Some( window.dispatch_action(
crate::ExternalAgent::ClaudeCode, NewExternalAgentThread {
), agent: Some(crate::ExternalAgent::ClaudeCode),
} }
.boxed_clone(), .boxed_clone(),
cx, cx,
); );
}), }),
) )
.item( .item(
ContextMenuEntry::new("New Codex Thread") ContextMenuEntry::new("New Native Agent Thread")
.icon(IconName::AiOpenAi) .icon(IconName::ZedAssistant)
.icon_color(Color::Muted) .icon_color(Color::Muted)
.handler(move |window, cx| { .handler(move |window, cx| {
window.dispatch_action( window.dispatch_action(
NewExternalAgentThread { NewExternalAgentThread {
agent: Some(crate::ExternalAgent::Codex), agent: Some(crate::ExternalAgent::NativeAgent),
} }
.boxed_clone(), .boxed_clone(),
cx, cx,
); );
}), }),
) );
}); // });
menu menu
})) }))
} }
@ -2608,82 +2608,87 @@ impl AgentPanel {
), ),
), ),
) )
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| { // Temporarily removed feature flag check for testing
this.child( // .when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
h_flex() // this
.w_full() .child(
.gap_2() h_flex()
.child( .w_full()
NewThreadButton::new( .gap_2()
"new-gemini-thread-btn", .child(
"New Gemini Thread", NewThreadButton::new(
IconName::AiGemini, "new-gemini-thread-btn",
) "New Gemini Thread",
// .keybinding(KeyBinding::for_action_in( IconName::AiGemini,
// &OpenHistory,
// &self.focus_handle(cx),
// window,
// cx,
// ))
.on_click(
|window, cx| {
window.dispatch_action(
Box::new(NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::Gemini,
),
}),
cx,
)
},
),
) )
.child( // .keybinding(KeyBinding::for_action_in(
NewThreadButton::new( // &OpenHistory,
"new-claude-thread-btn", // &self.focus_handle(cx),
"New Claude Code Thread", // window,
IconName::AiClaude, // cx,
) // ))
// .keybinding(KeyBinding::for_action_in( .on_click(
// &OpenHistory, |window, cx| {
// &self.focus_handle(cx), window.dispatch_action(
// window, Box::new(NewExternalAgentThread {
// cx, agent: Some(crate::ExternalAgent::Gemini),
// )) }),
.on_click( cx,
|window, cx| { )
window.dispatch_action( },
Box::new(NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::ClaudeCode,
),
}),
cx,
)
},
),
)
.child(
NewThreadButton::new(
"new-codex-thread-btn",
"New Codex Thread",
IconName::AiOpenAi,
)
.on_click(
|window, cx| {
window.dispatch_action(
Box::new(NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::Codex,
),
}),
cx,
)
},
),
), ),
) )
}), .child(
NewThreadButton::new(
"new-claude-thread-btn",
"New Claude Code Thread",
IconName::AiClaude,
)
// .keybinding(KeyBinding::for_action_in(
// &OpenHistory,
// &self.focus_handle(cx),
// window,
// cx,
// ))
.on_click(
|window, cx| {
window.dispatch_action(
Box::new(NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::ClaudeCode,
),
}),
cx,
)
},
),
)
.child(
NewThreadButton::new(
"new-native-agent-thread-btn",
"New Native Agent Thread",
IconName::ZedAssistant,
)
// .keybinding(KeyBinding::for_action_in(
// &OpenHistory,
// &self.focus_handle(cx),
// window,
// cx,
// ))
.on_click(
|window, cx| {
window.dispatch_action(
Box::new(NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::NativeAgent,
),
}),
cx,
)
},
),
),
), // })
) )
.when_some(configuration_error.as_ref(), |this, err| { .when_some(configuration_error.as_ref(), |this, err| {
this.child(self.render_configuration_error(err, &focus_handle, window, cx)) this.child(self.render_configuration_error(err, &focus_handle, window, cx))

View file

@ -150,7 +150,7 @@ enum ExternalAgent {
#[default] #[default]
Gemini, Gemini,
ClaudeCode, ClaudeCode,
Codex, NativeAgent,
} }
impl ExternalAgent { impl ExternalAgent {
@ -158,7 +158,7 @@ impl ExternalAgent {
match self { match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::Codex => Rc::new(agent_servers::Codex), ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
} }
} }
} }

View file

@ -441,14 +441,12 @@ impl Client {
Ok(()) Ok(())
} }
#[allow(unused)] pub fn on_notification(
pub fn on_notification<F>(&self, method: &'static str, f: F) &self,
where method: &'static str,
F: 'static + Send + FnMut(Value, AsyncApp), f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
{ ) {
self.notification_handlers self.notification_handlers.lock().insert(method, f);
.lock()
.insert(method, Box::new(f));
} }
} }

View file

@ -95,8 +95,28 @@ impl ContextServer {
self.client.read().clone() self.client.read().clone()
} }
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> { pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
let client = match &self.configuration { self.initialize(self.new_client(cx)?).await
}
/// Starts the context server, making sure handlers are registered before initialization happens
pub async fn start_with_handlers(
&self,
notification_handlers: Vec<(
&'static str,
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
)>,
cx: &AsyncApp,
) -> Result<()> {
let client = self.new_client(cx)?;
for (method, handler) in notification_handlers {
client.on_notification(method, handler);
}
self.initialize(client).await
}
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
Ok(match &self.configuration {
ContextServerTransport::Stdio(command, working_directory) => Client::stdio( ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
client::ContextServerId(self.id.0.clone()), client::ContextServerId(self.id.0.clone()),
client::ModelContextServerBinary { client::ModelContextServerBinary {
@ -113,8 +133,7 @@ impl ContextServer {
transport.clone(), transport.clone(),
cx.clone(), cx.clone(),
)?, )?,
}; })
self.initialize(client).await
} }
async fn initialize(&self, client: Client) -> Result<()> { async fn initialize(&self, client: Client) -> Result<()> {

View file

@ -83,14 +83,18 @@ impl McpServer {
} }
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) { pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
let output_schema = schemars::schema_for!(T::Output); let mut settings = schemars::generate::SchemaSettings::draft07();
let unit_schema = schemars::schema_for!(()); settings.inline_subschemas = true;
let mut generator = settings.into_generator();
let output_schema = generator.root_schema_for::<T::Output>();
let unit_schema = generator.root_schema_for::<T::Output>();
let registered_tool = RegisteredTool { let registered_tool = RegisteredTool {
tool: Tool { tool: Tool {
name: T::NAME.into(), name: T::NAME.into(),
description: Some(tool.description().into()), description: Some(tool.description().into()),
input_schema: schemars::schema_for!(T::Input).into(), input_schema: generator.root_schema_for::<T::Input>().into(),
output_schema: if output_schema == unit_schema { output_schema: if output_schema == unit_schema {
None None
} else { } else {

View file

@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
self.inner.notify(T::METHOD, params) self.inner.notify(T::METHOD, params)
} }
pub fn on_notification<F>(&self, method: &'static str, f: F) pub fn on_notification(
where &self,
F: 'static + Send + FnMut(Value, AsyncApp), method: &'static str,
{ f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
) {
self.inner.on_notification(method, f); self.inner.on_notification(method, f);
} }
} }