Compare commits
30 commits
main
...
acp-champa
Author | SHA1 | Date | |
---|---|---|---|
![]() |
f81993574e | ||
![]() |
bc1f861d3f | ||
![]() |
4f2d6a9ea9 | ||
![]() |
604a88f6e3 | ||
![]() |
a4fe8c6972 | ||
![]() |
5d621bef78 | ||
![]() |
27877325bc | ||
![]() |
9e1c7fdfea | ||
![]() |
387ee1be8d | ||
![]() |
afc8cf6098 | ||
![]() |
84d6a0fae9 | ||
![]() |
afb5c4147a | ||
![]() |
8890f590b1 | ||
![]() |
76f0f9163d | ||
![]() |
8acb58b6e5 | ||
![]() |
8563ed2252 | ||
![]() |
02d3043ec5 | ||
![]() |
30739041a4 | ||
![]() |
27708143ec | ||
![]() |
738296345e | ||
![]() |
81c111510f | ||
![]() |
f028ca4d1a | ||
![]() |
6656403ce8 | ||
![]() |
254c6be42b | ||
![]() |
745e4b5f1e | ||
![]() |
912ab505b2 | ||
![]() |
b48faddaf4 | ||
![]() |
477731d77d | ||
![]() |
ced3d09f10 | ||
![]() |
0f395df9a8 |
37 changed files with 2347 additions and 917 deletions
62
Cargo.lock
generated
62
Cargo.lock
generated
|
@ -19,6 +19,7 @@ dependencies = [
|
|||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"markdown",
|
||||
"project",
|
||||
"serde",
|
||||
|
@ -137,15 +138,57 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
version = "0.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b"
|
||||
version = "0.0.14"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acp_thread",
|
||||
"agent-client-protocol",
|
||||
"agent_servers",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"chrono",
|
||||
"client",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"reqwest_client",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"thiserror 2.0.12",
|
||||
"ui",
|
||||
"util",
|
||||
"uuid",
|
||||
"worktree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_servers"
|
||||
version = "0.1.0"
|
||||
|
@ -209,6 +252,7 @@ dependencies = [
|
|||
"acp_thread",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"agent2",
|
||||
"agent_servers",
|
||||
"agent_settings",
|
||||
"ai_onboarding",
|
||||
|
@ -9570,9 +9614,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.12"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
|
||||
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"scopeguard",
|
||||
|
@ -11286,9 +11330,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
|
|||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.3"
|
||||
version = "0.12.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
|
||||
checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
|
@ -11296,9 +11340,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.10"
|
||||
version = "0.9.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
|
||||
checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
|
|
|
@ -4,6 +4,7 @@ members = [
|
|||
"crates/acp_thread",
|
||||
"crates/activity_indicator",
|
||||
"crates/agent",
|
||||
"crates/agent2",
|
||||
"crates/agent_servers",
|
||||
"crates/agent_settings",
|
||||
"crates/agent_ui",
|
||||
|
@ -228,6 +229,7 @@ edition = "2024"
|
|||
|
||||
acp_thread = { path = "crates/acp_thread" }
|
||||
agent = { path = "crates/agent" }
|
||||
agent2 = { path = "crates/agent2" }
|
||||
activity_indicator = { path = "crates/activity_indicator" }
|
||||
agent_ui = { path = "crates/agent_ui" }
|
||||
agent_settings = { path = "crates/agent_settings" }
|
||||
|
@ -421,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
|||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = "0.0.11"
|
||||
agent-client-protocol = {path="../agent-client-protocol"}
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
|
|
|
@ -26,6 +26,7 @@ futures.workspace = true
|
|||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
markdown.workspace = true
|
||||
project.workspace = true
|
||||
serde.workspace = true
|
||||
|
|
|
@ -391,7 +391,7 @@ impl ToolCallContent {
|
|||
cx: &mut App,
|
||||
) -> Self {
|
||||
match content {
|
||||
acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock {
|
||||
acp::ToolCallContent::Content { content } => Self::ContentBlock {
|
||||
content: ContentBlock::new(content, &language_registry, cx),
|
||||
},
|
||||
acp::ToolCallContent::Diff { diff } => Self::Diff {
|
||||
|
@ -619,6 +619,7 @@ impl Error for LoadError {}
|
|||
|
||||
impl AcpThread {
|
||||
pub fn new(
|
||||
title: impl Into<SharedString>,
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
project: Entity<Project>,
|
||||
session_id: acp::SessionId,
|
||||
|
@ -631,7 +632,7 @@ impl AcpThread {
|
|||
shared_buffers: Default::default(),
|
||||
entries: Default::default(),
|
||||
plan: Default::default(),
|
||||
title: connection.name().into(),
|
||||
title: title.into(),
|
||||
project,
|
||||
send_task: None,
|
||||
connection,
|
||||
|
@ -655,6 +656,10 @@ impl AcpThread {
|
|||
&self.entries
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> &acp::SessionId {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
pub fn status(&self) -> ThreadStatus {
|
||||
if self.send_task.is_some() {
|
||||
if self.waiting_for_tool_confirmation() {
|
||||
|
@ -697,14 +702,14 @@ impl AcpThread {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
match update {
|
||||
acp::SessionUpdate::UserMessage(content_block) => {
|
||||
self.push_user_content_block(content_block, cx);
|
||||
acp::SessionUpdate::UserMessageChunk { content } => {
|
||||
self.push_user_content_block(content, cx);
|
||||
}
|
||||
acp::SessionUpdate::AgentMessageChunk(content_block) => {
|
||||
self.push_assistant_content_block(content_block, false, cx);
|
||||
acp::SessionUpdate::AgentMessageChunk { content } => {
|
||||
self.push_assistant_content_block(content, false, cx);
|
||||
}
|
||||
acp::SessionUpdate::AgentThoughtChunk(content_block) => {
|
||||
self.push_assistant_content_block(content_block, true, cx);
|
||||
acp::SessionUpdate::AgentThoughtChunk { content } => {
|
||||
self.push_assistant_content_block(content, true, cx);
|
||||
}
|
||||
acp::SessionUpdate::ToolCall(tool_call) => {
|
||||
self.upsert_tool_call(tool_call, cx);
|
||||
|
@ -973,10 +978,6 @@ impl AcpThread {
|
|||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.connection.authenticate(cx)
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn send_raw(
|
||||
&mut self,
|
||||
|
@ -1018,7 +1019,7 @@ impl AcpThread {
|
|||
let result = this
|
||||
.update(cx, |this, cx| {
|
||||
this.connection.prompt(
|
||||
acp::PromptArguments {
|
||||
acp::PromptRequest {
|
||||
prompt: message,
|
||||
session_id: this.session_id.clone(),
|
||||
},
|
||||
|
@ -1620,9 +1621,15 @@ mod tests {
|
|||
connection,
|
||||
child_status: io_task,
|
||||
current_thread: thread_rc,
|
||||
auth_methods: [acp::AuthMethod {
|
||||
id: acp::AuthMethodId("acp-old-no-id".into()),
|
||||
label: "Log in".into(),
|
||||
description: None,
|
||||
}],
|
||||
};
|
||||
|
||||
AcpThread::new(
|
||||
"test",
|
||||
Rc::new(connection),
|
||||
project,
|
||||
acp::SessionId("test".into()),
|
||||
|
|
|
@ -1,16 +1,62 @@
|
|||
use std::{path::Path, rc::Rc};
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language_model::LanguageModel;
|
||||
use project::Project;
|
||||
use ui::App;
|
||||
|
||||
use crate::AcpThread;
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn name(&self) -> &'static str;
|
||||
/// Trait for agents that support listing, selecting, and querying language models.
|
||||
///
|
||||
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
|
||||
pub trait ModelSelector: 'static {
|
||||
/// Lists all available language models for this agent.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `cx`: The GPUI app context for async operations and global access.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to the list of models or an error (e.g., if no models are configured).
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
|
||||
|
||||
/// Selects a model for a specific session (thread).
|
||||
///
|
||||
/// This sets the default model for future interactions in the session.
|
||||
/// If the session doesn't exist or the model is invalid, it returns an error.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `session_id`: The ID of the session (thread) to apply the model to.
|
||||
/// - `model`: The model to select (should be one from [list_models]).
|
||||
/// - `cx`: The GPUI app context.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to `Ok(())` on success or an error.
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>>;
|
||||
|
||||
/// Retrieves the currently selected model for a specific session (thread).
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `session_id`: The ID of the session (thread) to query.
|
||||
/// - `cx`: The GPUI app context.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to the selected model (always set) or an error (e.g., session not found).
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>>;
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
|
@ -18,9 +64,29 @@ pub trait AgentConnection {
|
|||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||
///
|
||||
/// If the agent does not support model selection, returns [None].
|
||||
/// This allows sharing the selector in UI components.
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
None // Default impl for agents that don't support it
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,11 +5,11 @@ use anyhow::{Context as _, Result};
|
|||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
|
||||
use std::{cell::RefCell, path::Path, rc::Rc};
|
||||
use ui::App;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{AcpThread, AgentConnection};
|
||||
use crate::{AcpThread, AgentConnection, AuthRequired};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OldAcpClientDelegate {
|
||||
|
@ -351,28 +351,15 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Unauthenticated;
|
||||
|
||||
impl Error for Unauthenticated {}
|
||||
impl fmt::Display for Unauthenticated {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Unauthenticated")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OldAcpAgentConnection {
|
||||
pub name: &'static str,
|
||||
pub connection: acp_old::AgentConnection,
|
||||
pub child_status: Task<Result<()>>,
|
||||
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
||||
pub auth_methods: [acp::AuthMethod; 1],
|
||||
}
|
||||
|
||||
impl AgentConnection for OldAcpAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
|
@ -391,13 +378,13 @@ impl AgentConnection for OldAcpAgentConnection {
|
|||
let result = acp_old::InitializeParams::response_from_any(result)?;
|
||||
|
||||
if !result.is_authenticated {
|
||||
anyhow::bail!(Unauthenticated)
|
||||
anyhow::bail!(AuthRequired)
|
||||
}
|
||||
|
||||
cx.update(|cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||
AcpThread::new(self.clone(), project, session_id, cx)
|
||||
AcpThread::new("Gemini", self.clone(), project, session_id, cx)
|
||||
});
|
||||
current_thread.replace(thread.downgrade());
|
||||
thread
|
||||
|
@ -405,7 +392,11 @@ impl AgentConnection for OldAcpAgentConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&self.auth_methods
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
let task = self
|
||||
.connection
|
||||
.request_any(acp_old::AuthenticateParams.into_any());
|
||||
|
@ -415,7 +406,7 @@ impl AgentConnection for OldAcpAgentConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let chunks = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
|
|
56
crates/agent2/Cargo.toml
Normal file
56
crates/agent2/Cargo.toml
Normal file
|
@ -0,0 +1,56 @@
|
|||
[package]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "GPL-3.0-or-later"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/agent2.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_servers.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
worktree.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
worktree = { workspace = true, "features" = ["test-support"] }
|
1
crates/agent2/LICENSE-GPL
Symbolic link
1
crates/agent2/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
377
crates/agent2/src/agent.rs
Normal file
377
crates/agent2/src/agent.rs
Normal file
|
@ -0,0 +1,377 @@
|
|||
use acp_thread::ModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use project::Project;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{templates::Templates, Thread};
|
||||
|
||||
/// Holds both the internal Thread and the AcpThread for a session
|
||||
#[derive(Clone)]
|
||||
struct Session {
|
||||
/// The internal thread that processes messages
|
||||
thread: Entity<Thread>,
|
||||
/// The ACP thread that handles protocol communication
|
||||
acp_thread: Entity<acp_thread::AcpThread>,
|
||||
}
|
||||
|
||||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
/// Shared templates for all threads
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl NativeAgent {
|
||||
pub fn new(templates: Arc<Templates>) -> Self {
|
||||
log::info!("Creating new NativeAgent");
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
templates,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct that implements the AgentConnection trait
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||
|
||||
impl ModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||
log::debug!("NativeAgentConnection::list_models called");
|
||||
cx.spawn(async move |cx| {
|
||||
cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||
log::info!("Found {} available models", models.len());
|
||||
if models.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(models)
|
||||
}
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!(
|
||||
"Setting model for session {}: {:?}",
|
||||
session_id,
|
||||
model.name()
|
||||
);
|
||||
let agent = self.0.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
agent.update(cx, |agent, cx| {
|
||||
if let Some(session) = agent.sessions.get(&session_id) {
|
||||
session.thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model;
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Session not found"))
|
||||
}
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>> {
|
||||
let agent = self.0.clone();
|
||||
let session_id = session_id.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let session = agent
|
||||
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
||||
let selected = session
|
||||
.thread
|
||||
.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
Ok(selected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||
let agent = self.0.clone();
|
||||
log::info!("Creating new thread for project at: {:?}", cwd);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Starting thread creation in async context");
|
||||
// Create Thread
|
||||
let (session_id, thread) = agent.update(
|
||||
cx,
|
||||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||
// Fetch default model from registry settings
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
|
||||
// Log available models for debugging
|
||||
let available_count = registry.available_models(cx).count();
|
||||
log::debug!("Total available models: {}", available_count);
|
||||
|
||||
let default_model = registry
|
||||
.default_model()
|
||||
.map(|configured| {
|
||||
log::info!(
|
||||
"Using configured default model: {:?} from provider: {:?}",
|
||||
configured.model.name(),
|
||||
configured.provider.name()
|
||||
);
|
||||
configured.model
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
log::warn!("No default model configured in settings");
|
||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
|
||||
|
||||
// Generate session ID
|
||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
Ok((session_id, thread))
|
||||
},
|
||||
)??;
|
||||
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
|
||||
})
|
||||
})?;
|
||||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, _cx| {
|
||||
agent.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread,
|
||||
acp_thread: acp_thread.clone(),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
|
||||
Ok(acp_thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[] // No auth for in-process
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let session_id = params.session_id.clone();
|
||||
let agent = self.0.clone();
|
||||
log::info!("Received prompt request for session: {}", session_id);
|
||||
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Get session
|
||||
let session = agent
|
||||
.read_with(cx, |agent, _| {
|
||||
agent.sessions.get(&session_id).map(|s| Session {
|
||||
thread: s.thread.clone(),
|
||||
acp_thread: s.acp_thread.clone(),
|
||||
})
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
log::error!("Session not found: {}", session_id);
|
||||
anyhow::anyhow!("Session not found")
|
||||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
// Convert prompt to message
|
||||
let message = convert_prompt_to_message(params.prompt);
|
||||
log::info!("Converted prompt to message: {} chars", message.len());
|
||||
log::debug!("Message content: {}", message);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
let model = session
|
||||
.thread
|
||||
.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
|
||||
// Send to thread
|
||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||
let response_stream = session
|
||||
.thread
|
||||
.update(cx, |thread, cx| thread.send(model, message, cx))?;
|
||||
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
let acp_thread = session.acp_thread.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
use futures::StreamExt;
|
||||
use language_model::LanguageModelCompletionEvent;
|
||||
|
||||
let mut response_stream = response_stream;
|
||||
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::Text(text) => {
|
||||
// Send text chunk as agent message
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::AgentMessageChunk {
|
||||
content: acp::ContentBlock::Text(
|
||||
acp::TextContent {
|
||||
text: text.into(),
|
||||
annotations: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||
// Convert LanguageModelToolUse to ACP ToolCall
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::ToolCall(acp::ToolCall {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
label: tool_use.name.to_string(),
|
||||
kind: acp::ToolKind::Other,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(tool_use.input),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
log::debug!("Started new assistant message");
|
||||
}
|
||||
LanguageModelCompletionEvent::UsageUpdate(usage) => {
|
||||
log::debug!("Token usage update: {:?}", usage);
|
||||
}
|
||||
LanguageModelCompletionEvent::Thinking { text, .. } => {
|
||||
// Send thinking text as agent thought chunk
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::AgentThoughtChunk {
|
||||
content: acp::ContentBlock::Text(
|
||||
acp::TextContent {
|
||||
text: text.into(),
|
||||
annotations: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
LanguageModelCompletionEvent::StatusUpdate(status) => {
|
||||
log::trace!("Status update: {:?}", status);
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
}
|
||||
LanguageModelCompletionEvent::RedactedThinking { .. } => {
|
||||
log::trace!("Redacted thinking event");
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
} => {
|
||||
log::error!(
|
||||
"Tool use JSON parse error for tool '{}' (id: {}): {} - input: {}",
|
||||
tool_name,
|
||||
id,
|
||||
json_parse_error,
|
||||
raw_input
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error in model response stream: {:?}", e);
|
||||
// TODO: Consider sending an error message to the UI
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Response stream completed");
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
log::info!("Successfully sent prompt to thread and started response handler");
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
log::info!("Cancelling session: {}", session_id);
|
||||
self.0.update(cx, |agent, _cx| {
|
||||
agent.sessions.remove(session_id);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert ACP content blocks to a message string
|
||||
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
|
||||
log::debug!("Converting {} content blocks to message", blocks.len());
|
||||
let mut message = String::new();
|
||||
|
||||
for block in blocks {
|
||||
match block {
|
||||
acp::ContentBlock::Text(text) => {
|
||||
log::trace!("Processing text block: {} chars", text.text.len());
|
||||
message.push_str(&text.text);
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(link) => {
|
||||
log::trace!("Processing resource link: {}", link.uri);
|
||||
message.push_str(&format!(" @{} ", link.uri));
|
||||
}
|
||||
acp::ContentBlock::Image(_) => {
|
||||
log::trace!("Processing image block");
|
||||
message.push_str(" [image] ");
|
||||
}
|
||||
acp::ContentBlock::Audio(_) => {
|
||||
log::trace!("Processing audio block");
|
||||
message.push_str(" [audio] ");
|
||||
}
|
||||
acp::ContentBlock::Resource(resource) => {
|
||||
log::trace!("Processing resource block: {:?}", resource.resource);
|
||||
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message
|
||||
}
|
13
crates/agent2/src/agent2.rs
Normal file
13
crates/agent2/src/agent2.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
mod agent;
|
||||
mod native_agent_server;
|
||||
mod prompts;
|
||||
mod templates;
|
||||
mod thread;
|
||||
mod tools;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use agent::*;
|
||||
pub use native_agent_server::NativeAgentServer;
|
||||
pub use thread::*;
|
58
crates/agent2/src/native_agent_server.rs
Normal file
58
crates/agent2/src/native_agent_server.rs
Normal file
|
@ -0,0 +1,58 @@
|
|||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::{templates::Templates, NativeAgent, NativeAgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentServer;
|
||||
|
||||
impl AgentServer for NativeAgentServer {
|
||||
fn name(&self) -> &'static str {
|
||||
"Native Agent"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"Native Agent"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"How can I help you today?"
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
// Using the ZedAssistant icon as it's the native built-in agent
|
||||
ui::IconName::ZedAssistant
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
_project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
|
||||
log::info!(
|
||||
"NativeAgentServer::connect called for path: {:?}",
|
||||
_root_dir
|
||||
);
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Creating templates for native agent");
|
||||
// Create templates (you might want to load these from files or resources)
|
||||
let templates = Templates::new();
|
||||
|
||||
// Create the native agent
|
||||
log::debug!("Creating native agent entity");
|
||||
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
|
||||
|
||||
// Create the connection wrapper
|
||||
let connection = NativeAgentConnection(agent);
|
||||
log::info!("NativeAgentServer connection established successfully");
|
||||
|
||||
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
|
||||
})
|
||||
}
|
||||
}
|
30
crates/agent2/src/prompts.rs
Normal file
30
crates/agent2/src/prompts.rs
Normal file
|
@ -0,0 +1,30 @@
|
|||
use crate::{
|
||||
templates::{BaseTemplate, Template, Templates, WorktreeData},
|
||||
thread::Prompt,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use gpui::{App, Entity};
|
||||
use project::Project;
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct BasePrompt {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl Prompt for BasePrompt {
|
||||
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
|
||||
BaseTemplate {
|
||||
os: std::env::consts::OS.to_string(),
|
||||
shell: util::get_system_shell(),
|
||||
worktrees: self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| WorktreeData {
|
||||
root_name: worktree.read(cx).root_name().to_string(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
.render(templates)
|
||||
}
|
||||
}
|
57
crates/agent2/src/templates.rs
Normal file
57
crates/agent2/src/templates.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
#[include = "*.hbs"]
|
||||
struct Assets;
|
||||
|
||||
pub struct Templates(Handlebars<'static>);
|
||||
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Template: Sized {
|
||||
const TEMPLATE_NAME: &'static str;
|
||||
|
||||
fn render(&self, templates: &Templates) -> Result<String>
|
||||
where
|
||||
Self: Serialize + Sized,
|
||||
{
|
||||
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct BaseTemplate {
|
||||
pub os: String,
|
||||
pub shell: String,
|
||||
pub worktrees: Vec<WorktreeData>,
|
||||
}
|
||||
|
||||
impl Template for BaseTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "base.hbs";
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct WorktreeData {
|
||||
pub root_name: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GlobTemplate {
|
||||
pub project_roots: String,
|
||||
}
|
||||
|
||||
impl Template for GlobTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "glob.hbs";
|
||||
}
|
56
crates/agent2/src/templates/base.hbs
Normal file
56
crates/agent2/src/templates/base.hbs
Normal file
|
@ -0,0 +1,56 @@
|
|||
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
|
||||
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the USER in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
2. Provide every required argument.
|
||||
3. DO NOT use tools to access items that are already available in the context section.
|
||||
4. Use only the tools that are currently available.
|
||||
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
|
||||
|
||||
## Searching and Reading
|
||||
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
- `{{root_name}}`
|
||||
{{/each}}
|
||||
|
||||
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
|
||||
|
||||
## Debugging
|
||||
|
||||
When debugging, only make code changes if you are certain that you can solve the problem.
|
||||
Otherwise, follow debugging best practices:
|
||||
1. Address the root cause instead of the symptoms.
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
8
crates/agent2/src/templates/glob.hbs
Normal file
8
crates/agent2/src/templates/glob.hbs
Normal file
|
@ -0,0 +1,8 @@
|
|||
Find paths on disk with glob patterns.
|
||||
|
||||
Assume that all glob patterns are matched in a project directory with the following entries.
|
||||
|
||||
{{project_roots}}
|
||||
|
||||
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
|
||||
sure to anchor them with one of the directories listed above.
|
378
crates/agent2/src/tests/mod.rs
Normal file
378
crates/agent2/src/tests/mod.rs
Normal file
|
@ -0,0 +1,378 @@
|
|||
use super::*;
|
||||
use crate::templates::Templates;
|
||||
use acp_thread::AgentConnection as _;
|
||||
use agent_client_protocol as acp;
|
||||
use client::{Client, UserStore};
|
||||
use gpui::{AppContext, Entity, TestAppContext};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, MessageContent, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
|
||||
mod test_tools;
|
||||
use test_tools::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_echo(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx).await;
|
||||
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
|
||||
// Test a tool calls that's likely to complete *after* streaming stops.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.remove_tool(&AgentTool::name(&EchoTool));
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert!(thread
|
||||
.messages()
|
||||
.last()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(WordListTool);
|
||||
thread.send(model.clone(), "Test the word_list tool.", cx)
|
||||
});
|
||||
|
||||
let mut saw_partial_tool_use = false;
|
||||
while let Some(event) = events.next().await {
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
// Look for a tool use in the thread's last message
|
||||
let last_content = thread.messages().last().unwrap().content.last().unwrap();
|
||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||
if tool_use_event.is_input_complete {
|
||||
last_tool_use
|
||||
.input
|
||||
.get("a")
|
||||
.expect("'a' has streamed because input is now complete");
|
||||
last_tool_use
|
||||
.input
|
||||
.get("g")
|
||||
.expect("'g' has streamed because input is now complete");
|
||||
} else {
|
||||
if !last_tool_use.is_input_complete
|
||||
&& last_tool_use.input.get("g").is_none()
|
||||
{
|
||||
saw_partial_tool_use = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("last content should be a tool use");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_partial_tool_use,
|
||||
"should see at least one partially streamed tool use in the history"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx).await;
|
||||
|
||||
// Test concurrent tool calls with different delay times
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let stop_reasons = stop_events(events);
|
||||
if stop_reasons.len() == 2 {
|
||||
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
|
||||
} else if stop_reasons.len() == 3 {
|
||||
assert_eq!(
|
||||
stop_reasons,
|
||||
vec![
|
||||
StopReason::ToolUse,
|
||||
StopReason::ToolUse,
|
||||
StopReason::EndTurn
|
||||
]
|
||||
);
|
||||
} else {
|
||||
panic!("Expected either 1 or 2 tool uses followed by end turn");
|
||||
}
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
let text = last_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
assert!(text.contains("Ding"));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
let templates = Templates::new();
|
||||
|
||||
// Initialize language model system with test provider
|
||||
cx.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
// Initialize project settings
|
||||
Project::init_settings(cx);
|
||||
|
||||
// Use test registry with fake provider
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
|
||||
// Create agent and connection
|
||||
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
let selector_opt = connection.model_selector();
|
||||
assert!(
|
||||
selector_opt.is_some(),
|
||||
"agent2 should always support ModelSelector"
|
||||
);
|
||||
let selector = selector_opt.unwrap();
|
||||
|
||||
// Test list_models
|
||||
let listed_models = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.list_models(&mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("list_models should succeed");
|
||||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
|
||||
// Create a thread using new_thread
|
||||
let cwd = Path::new("/test");
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
connection_rc.new_thread(project, cwd, &mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
// Get the session_id from the AcpThread
|
||||
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
||||
|
||||
// Test selected_model returns the default
|
||||
let selected = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("selected_model should succeed");
|
||||
assert_eq!(selected.id().0, "fake", "should return default model");
|
||||
|
||||
// The thread was created via prompt with the default model
|
||||
// We can verify it through selected_model
|
||||
|
||||
// Test prompt uses the selected model
|
||||
let prompt_request = acp::PromptRequest {
|
||||
session_id: session_id.clone(),
|
||||
prompt: vec![acp::ContentBlock::Text(acp::TextContent {
|
||||
text: "Test prompt".into(),
|
||||
annotations: None,
|
||||
})],
|
||||
};
|
||||
|
||||
cx.update(|cx| connection.prompt(prompt_request, cx))
|
||||
.await
|
||||
.expect("prompt should succeed");
|
||||
|
||||
// The prompt was sent successfully
|
||||
|
||||
// Test cancel
|
||||
cx.update(|cx| connection.cancel(&session_id, cx));
|
||||
|
||||
// After cancel, selected_model should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_err(), "selected_model should fail after cancel");
|
||||
|
||||
// Test error case: invalid session
|
||||
let invalid_session = acp::SessionId("invalid".into());
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&invalid_session, &mut async_cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_err(), "should fail for invalid session");
|
||||
if let Err(e) = result {
|
||||
assert!(
|
||||
e.to_string().contains("Session not found"),
|
||||
"should have correct error message"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(
|
||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> Vec<StopReason> {
|
||||
result_events
|
||||
.into_iter()
|
||||
.filter_map(|event| match event.unwrap() {
|
||||
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct ThreadTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
}
|
||||
|
||||
async fn setup(cx: &mut TestAppContext) -> ThreadTest {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
let templates = Templates::new();
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
|
||||
.unwrap();
|
||||
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
})
|
||||
.await;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(templates, model.clone()));
|
||||
|
||||
ThreadTest { model, thread }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
85
crates/agent2/src/tests/test_tools.rs
Normal file
85
crates/agent2/src/tests/test_tools.rs
Normal file
|
@ -0,0 +1,85 @@
|
|||
use super::*;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, SharedString, Task};
|
||||
|
||||
/// A tool that echoes its input
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct EchoToolInput {
|
||||
/// The text to echo.
|
||||
text: String,
|
||||
}
|
||||
|
||||
pub struct EchoTool;
|
||||
|
||||
impl AgentTool for EchoTool {
|
||||
type Input = EchoToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"echo".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok(input.text))
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that waits for a specified delay
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct DelayToolInput {
|
||||
/// The delay in milliseconds.
|
||||
ms: u64,
|
||||
}
|
||||
|
||||
pub struct DelayTool;
|
||||
|
||||
impl AgentTool for DelayTool {
|
||||
type Input = DelayToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"delay".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
cx.foreground_executor().spawn(async move {
|
||||
smol::Timer::after(Duration::from_millis(input.ms)).await;
|
||||
Ok("Ding".to_string())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that takes an object with map from letters to random words starting with that letter.
|
||||
/// All fiealds are required! Pass a word for every letter!
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct WordListInput {
|
||||
/// Provide a random word that starts with A.
|
||||
a: Option<String>,
|
||||
/// Provide a random word that starts with B.
|
||||
b: Option<String>,
|
||||
/// Provide a random word that starts with C.
|
||||
c: Option<String>,
|
||||
/// Provide a random word that starts with D.
|
||||
d: Option<String>,
|
||||
/// Provide a random word that starts with E.
|
||||
e: Option<String>,
|
||||
/// Provide a random word that starts with F.
|
||||
f: Option<String>,
|
||||
/// Provide a random word that starts with G.
|
||||
g: Option<String>,
|
||||
}
|
||||
|
||||
pub struct WordListTool;
|
||||
|
||||
impl AgentTool for WordListTool {
|
||||
type Input = WordListInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"word_list".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok("ok".to_string()))
|
||||
}
|
||||
}
|
493
crates/agent2/src/thread.rs
Normal file
493
crates/agent2/src/thread.rs
Normal file
|
@ -0,0 +1,493 @@
|
|||
use crate::templates::Templates;
|
||||
use anyhow::{anyhow, Result};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use futures::{channel::mpsc, future};
|
||||
use gpui::{App, Context, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
};
|
||||
use log;
|
||||
use schemars::{JsonSchema, Schema};
|
||||
use serde::Deserialize;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AgentMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
}
|
||||
|
||||
pub type AgentResponseEvent = LanguageModelCompletionEvent;
|
||||
|
||||
pub trait Prompt {
|
||||
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
system_prompts: Vec<Arc<dyn Prompt>>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
// action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
system_prompts: Vec::new(),
|
||||
running_turn: None,
|
||||
tools: BTreeMap::default(),
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> &[AgentMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
||||
self.tools.insert(tool.name(), tool.erase());
|
||||
}
|
||||
|
||||
pub fn remove_tool(&mut self, name: &str) -> bool {
|
||||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
content: impl Into<MessageContent>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
||||
let content = content.into();
|
||||
log::info!("Thread::send called with model: {:?}", model.name());
|
||||
log::debug!("Thread::send content: {:?}", content);
|
||||
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let system_message = self.build_system_message(cx);
|
||||
log::debug!(
|
||||
"System messages count: {}",
|
||||
if system_message.is_some() { 1 } else { 0 }
|
||||
);
|
||||
self.messages.extend(system_message);
|
||||
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: vec![content],
|
||||
});
|
||||
log::info!("Total messages in thread: {}", self.messages.len());
|
||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let turn_result = async {
|
||||
// Perform one request, then keep looping if the model makes tool calls.
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
log::debug!(
|
||||
"Building completion request with intent: {:?}",
|
||||
completion_intent
|
||||
);
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
|
||||
// println!(
|
||||
// "request: {}",
|
||||
// serde_json::to_string_pretty(&request).unwrap()
|
||||
// );
|
||||
|
||||
// Stream events, appending to messages and collecting up tool uses.
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
let mut tool_uses = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
tool_uses.extend(thread.handle_streamed_completion_event(
|
||||
event,
|
||||
events_tx.clone(),
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("Error in completion stream: {:?}", error);
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no tool uses, the turn is done.
|
||||
if tool_uses.is_empty() {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
break;
|
||||
}
|
||||
log::info!("Found {} tool uses to execute", tool_uses.len());
|
||||
|
||||
// If there are tool uses, wait for their results to be
|
||||
// computed, then send them together in a single message on
|
||||
// the next loop iteration.
|
||||
let tool_results = future::join_all(tool_uses).await;
|
||||
log::debug!("Tool execution completed, {} results", tool_results.len());
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: tool_results
|
||||
.into_iter()
|
||||
.map(MessageContent::ToolResult)
|
||||
.collect(),
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(error) = turn_result {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
} else {
|
||||
log::info!("Turn execution completed successfully");
|
||||
}
|
||||
}));
|
||||
events_rx
|
||||
}
|
||||
|
||||
pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
|
||||
log::debug!("Building system message");
|
||||
let mut system_message = AgentMessage {
|
||||
role: Role::System,
|
||||
content: Vec::new(),
|
||||
};
|
||||
|
||||
for prompt in &self.system_prompts {
|
||||
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
|
||||
system_message
|
||||
.content
|
||||
.push(MessageContent::Text(rendered_prompt));
|
||||
}
|
||||
}
|
||||
|
||||
let result = (!system_message.content.is_empty()).then_some(system_message);
|
||||
log::debug!("System message built: {}", result.is_some());
|
||||
result
|
||||
}
|
||||
|
||||
/// A helper method that's called on every streamed completion event.
|
||||
/// Returns an optional tool result task, which the main agentic loop in
|
||||
/// send will send back to the model when it resolves.
|
||||
fn handle_streamed_completion_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
log::trace!("Handling streamed completion event: {:?}", event);
|
||||
use LanguageModelCompletionEvent::*;
|
||||
events_tx.unbounded_send(Ok(event.clone())).ok();
|
||||
|
||||
match event {
|
||||
Text(new_text) => self.handle_text_event(new_text, cx),
|
||||
Thinking {
|
||||
text: _text,
|
||||
signature: _signature,
|
||||
} => {
|
||||
todo!()
|
||||
}
|
||||
ToolUse(tool_use) => {
|
||||
return self.handle_tool_use_event(tool_use, cx);
|
||||
}
|
||||
StartMessage { .. } => {
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
UsageUpdate(_) => {}
|
||||
Stop(stop_reason) => self.handle_stop_event(stop_reason),
|
||||
StatusUpdate(_completion_request_status) => {}
|
||||
RedactedThinking { data: _data } => todo!(),
|
||||
ToolUseJsonParseError {
|
||||
id: _id,
|
||||
tool_name: _tool_name,
|
||||
raw_input: _raw_input,
|
||||
json_parse_error: _json_parse_error,
|
||||
} => todo!(),
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_stop_event(&mut self, stop_reason: StopReason) {
|
||||
match stop_reason {
|
||||
StopReason::EndTurn | StopReason::ToolUse => {}
|
||||
StopReason::MaxTokens => todo!(),
|
||||
StopReason::Refusal => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_tool_use_event(
|
||||
&mut self,
|
||||
tool_use: LanguageModelToolUse,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
|
||||
// Ensure the last message ends in the current tool use
|
||||
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
||||
if let MessageContent::ToolUse(last_tool_use) = content {
|
||||
if last_tool_use.id == tool_use.id {
|
||||
*last_tool_use = tool_use.clone();
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if push_new_tool_use {
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
}
|
||||
|
||||
if !tool_use.is_input_complete {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
||||
let pending_tool_result = tool.clone().run(tool_use.input, cx);
|
||||
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match pending_tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
Some(Task::ready(LanguageModelToolResult {
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(format!(
|
||||
"No tool named {} exists",
|
||||
tool_use.name
|
||||
))),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
output: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Guarantees the last message is from the assistant and returns a mutable reference.
|
||||
fn last_assistant_message(&mut self) -> &mut AgentMessage {
|
||||
if self
|
||||
.messages
|
||||
.last()
|
||||
.map_or(true, |m| m.role != Role::Assistant)
|
||||
{
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
fn build_completion_request(
|
||||
&self,
|
||||
completion_intent: CompletionIntent,
|
||||
cx: &mut App,
|
||||
) -> LanguageModelRequest {
|
||||
log::debug!("Building completion request");
|
||||
log::debug!("Completion intent: {:?}", completion_intent);
|
||||
log::debug!("Completion mode: {:?}", self.completion_mode);
|
||||
|
||||
let messages = self.build_request_messages();
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools: Vec<LanguageModelRequestTool> = self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
log::info!("Request includes {} tools", tools.len());
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: Some(completion_intent),
|
||||
mode: Some(self.completion_mode),
|
||||
messages,
|
||||
tools,
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
thinking_allowed: false,
|
||||
};
|
||||
|
||||
log::debug!("Completion request built successfully");
|
||||
request
|
||||
}
|
||||
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
log::trace!(
|
||||
"Building request messages from {} thread messages",
|
||||
self.messages.len()
|
||||
);
|
||||
let messages = self
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
log::trace!(
|
||||
" - {} message with {} content items",
|
||||
match message.role {
|
||||
Role::System => "System",
|
||||
Role::User => "User",
|
||||
Role::Assistant => "Assistant",
|
||||
},
|
||||
message.content.len()
|
||||
);
|
||||
LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
content: message.content.clone(),
|
||||
cache: false,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
schema
|
||||
.get("description")
|
||||
.and_then(|description| description.as_str())
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
where
|
||||
T: AgentTool,
|
||||
{
|
||||
fn name(&self) -> SharedString {
|
||||
self.0.name()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
self.0.description(cx)
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
}
|
1
crates/agent2/src/tools.rs
Normal file
1
crates/agent2/src/tools.rs
Normal file
|
@ -0,0 +1 @@
|
|||
mod glob;
|
76
crates/agent2/src/tools/glob.rs
Normal file
76
crates/agent2/src/tools/glob.rs
Normal file
|
@ -0,0 +1,76 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
use worktree::Snapshot as WorktreeSnapshot;
|
||||
|
||||
use crate::{
|
||||
templates::{GlobTemplate, Template, Templates},
|
||||
thread::AgentTool,
|
||||
};
|
||||
|
||||
// Description is dynamic, see `fn description` below
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct GlobInput {
|
||||
/// A POSIX glob pattern
|
||||
glob: SharedString,
|
||||
}
|
||||
|
||||
struct GlobTool {
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl AgentTool for GlobTool {
|
||||
type Input = GlobInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"glob".into()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
let project_roots = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).root_name().into())
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
GlobTemplate { project_roots }
|
||||
.render(&self.templates)
|
||||
.expect("template failed to render")
|
||||
.into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
||||
let path_matcher = match PathMatcher::new([&input.glob]) {
|
||||
Ok(matcher) => matcher,
|
||||
Err(error) => return Task::ready(Err(anyhow!(error))),
|
||||
};
|
||||
|
||||
let snapshots: Vec<WorktreeSnapshot> = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).snapshot())
|
||||
.collect();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let paths = snapshots.iter().flat_map(|snapshot| {
|
||||
let root_name = PathBuf::from(snapshot.root_name());
|
||||
snapshot
|
||||
.entries(false, 0)
|
||||
.map(move |entry| root_name.join(&entry.path))
|
||||
.filter(|path| path_matcher.is_match(&path))
|
||||
});
|
||||
let output = paths
|
||||
.map(|path| format!("{}\n", path.display()))
|
||||
.collect::<String>();
|
||||
Ok(output)
|
||||
})
|
||||
}
|
||||
}
|
0
crates/agent_servers/acp
Normal file
0
crates/agent_servers/acp
Normal file
245
crates/agent_servers/src/acp_connection.rs
Normal file
245
crates/agent_servers/src/acp_connection.rs
Normal file
|
@ -0,0 +1,245 @@
|
|||
use agent_client_protocol::{self as acp, Agent as _};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use project::Project;
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
||||
use crate::AgentServerCommand;
|
||||
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
||||
|
||||
pub struct AcpConnection {
|
||||
server_name: &'static str,
|
||||
connection: Rc<acp::ClientSideConnection>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
auth_methods: Vec<acp::AuthMethod>,
|
||||
_io_task: Task<Result<()>>,
|
||||
}
|
||||
|
||||
pub struct AcpSession {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
}
|
||||
|
||||
impl AcpConnection {
|
||||
pub async fn stdio(
|
||||
server_name: &'static str,
|
||||
command: AgentServerCommand,
|
||||
root_dir: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter().map(|arg| arg.as_str()))
|
||||
.envs(command.env.iter().flatten())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdout = child.stdout.take().expect("Failed to take stdout");
|
||||
let stdin = child.stdin.take().expect("Failed to take stdin");
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let client = ClientDelegate {
|
||||
sessions: sessions.clone(),
|
||||
cx: cx.clone(),
|
||||
};
|
||||
let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
move |fut| {
|
||||
foreground_executor.spawn(fut).detach();
|
||||
}
|
||||
});
|
||||
|
||||
let io_task = cx.background_spawn(io_task);
|
||||
|
||||
let response = connection
|
||||
.initialize(acp::InitializeRequest {
|
||||
protocol_version: acp::VERSION,
|
||||
client_capabilities: acp::ClientCapabilities {
|
||||
fs: acp::FileSystemCapability {
|
||||
read_text_file: true,
|
||||
write_text_file: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
|
||||
// todo! check version
|
||||
|
||||
Ok(Self {
|
||||
auth_methods: response.auth_methods,
|
||||
connection: connection.into(),
|
||||
server_name,
|
||||
sessions,
|
||||
_io_task: io_task,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for AcpConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let conn = self.connection.clone();
|
||||
let sessions = self.sessions.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let response = conn
|
||||
.new_session(acp::NewSessionRequest {
|
||||
// todo! Zed MCP server?
|
||||
mcp_servers: vec![],
|
||||
cwd,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let Some(session_id) = response.session_id else {
|
||||
anyhow::bail!(AuthRequired);
|
||||
};
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
self.server_name,
|
||||
self.clone(),
|
||||
project,
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let session = AcpSession {
|
||||
thread: thread.downgrade(),
|
||||
};
|
||||
sessions.borrow_mut().insert(session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&self.auth_methods
|
||||
}
|
||||
|
||||
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
let conn = self.connection.clone();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let result = conn
|
||||
.authenticate(acp::AuthenticateRequest {
|
||||
method_id: method_id.clone(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let conn = self.connection.clone();
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { Ok(conn.prompt(params).await?) })
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
let conn = self.connection.clone();
|
||||
let params = acp::CancelledNotification {
|
||||
session_id: session_id.clone(),
|
||||
};
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { conn.cancelled(params).await })
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientDelegate {
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
cx: AsyncApp,
|
||||
}
|
||||
|
||||
impl acp::Client for ClientDelegate {
|
||||
async fn request_permission(
|
||||
&self,
|
||||
arguments: acp::RequestPermissionRequest,
|
||||
) -> Result<acp::RequestPermissionResponse, acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let result = self
|
||||
.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
|
||||
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
|
||||
};
|
||||
|
||||
Ok(acp::RequestPermissionResponse { outcome })
|
||||
}
|
||||
|
||||
async fn write_text_file(
|
||||
&self,
|
||||
arguments: acp::WriteTextFileRequest,
|
||||
) -> Result<(), acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
self.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(arguments.path, arguments.content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_text_file(
|
||||
&self,
|
||||
arguments: acp::ReadTextFileRequest,
|
||||
) -> Result<acp::ReadTextFileResponse, acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let content = self
|
||||
.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(acp::ReadTextFileResponse { content })
|
||||
}
|
||||
|
||||
async fn session_notification(
|
||||
&self,
|
||||
notification: acp::SessionNotification,
|
||||
) -> Result<(), acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let sessions = self.sessions.borrow();
|
||||
let session = sessions
|
||||
.get(¬ification.session_id)
|
||||
.context("Failed to get session")?;
|
||||
|
||||
session.thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(notification.update, cx)
|
||||
})??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,14 +1,12 @@
|
|||
mod acp_connection;
|
||||
mod claude;
|
||||
mod codex;
|
||||
mod gemini;
|
||||
mod mcp_server;
|
||||
mod settings;
|
||||
|
||||
#[cfg(test)]
|
||||
mod e2e_tests;
|
||||
|
||||
pub use claude::*;
|
||||
pub use codex::*;
|
||||
pub use gemini::*;
|
||||
pub use settings::*;
|
||||
|
||||
|
@ -38,7 +36,6 @@ pub trait AgentServer: Send {
|
|||
|
||||
fn connect(
|
||||
&self,
|
||||
// these will go away when old_acp is fully removed
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
|
|
|
@ -70,10 +70,6 @@ struct ClaudeAgentConnection {
|
|||
}
|
||||
|
||||
impl AgentConnection for ClaudeAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
ClaudeCode.name()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
|
@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
}
|
||||
});
|
||||
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
|
||||
})?;
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
||||
|
@ -186,11 +183,15 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let sessions = self.sessions.borrow();
|
||||
let Some(session) = sessions.get(¶ms.session_id) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
|
|
|
@ -1,319 +0,0 @@
|
|||
use agent_client_protocol as acp;
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use context_server::types::requests;
|
||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
||||
use crate::mcp_server::ZedMcpServer;
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
|
||||
use acp_thread::{AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Codex;
|
||||
|
||||
impl AgentServer for Codex {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"Welcome to Codex"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"What can I help with?"
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiOpenAi
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let project = project.clone();
|
||||
let working_directory = project.read(cx).active_project_directory(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
})?;
|
||||
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find codex binary");
|
||||
};
|
||||
|
||||
let client: Arc<ContextServer> = ContextServer::stdio(
|
||||
ContextServerId("codex-mcp-server".into()),
|
||||
ContextServerCommand {
|
||||
path: command.path,
|
||||
args: command.args,
|
||||
env: command.env,
|
||||
},
|
||||
working_directory,
|
||||
)
|
||||
.into();
|
||||
ContextServer::start(client.clone(), cx).await?;
|
||||
|
||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||
client
|
||||
.client()
|
||||
.context("Failed to subscribe")?
|
||||
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
|
||||
move |notification, _cx| {
|
||||
let notification_tx = notification_tx.clone();
|
||||
log::trace!(
|
||||
"ACP Notification: {}",
|
||||
serde_json::to_string_pretty(¬ification).unwrap()
|
||||
);
|
||||
|
||||
if let Some(notification) =
|
||||
serde_json::from_value::<acp::SessionNotification>(notification)
|
||||
.log_err()
|
||||
{
|
||||
notification_tx.unbounded_send(notification).ok();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let notification_handler_task = cx.spawn({
|
||||
let sessions = sessions.clone();
|
||||
async move |cx| {
|
||||
while let Some(notification) = notification_rx.next().await {
|
||||
CodexConnection::handle_session_notification(
|
||||
notification,
|
||||
sessions.clone(),
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let connection = CodexConnection {
|
||||
client,
|
||||
sessions,
|
||||
_notification_handler_task: notification_handler_task,
|
||||
};
|
||||
Ok(Rc::new(connection) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct CodexConnection {
|
||||
client: Arc<context_server::ContextServer>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
||||
_notification_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
struct CodexSession {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
cancel_tx: Option<oneshot::Sender<()>>,
|
||||
_mcp_server: ZedMcpServer,
|
||||
}
|
||||
|
||||
impl AgentConnection for CodexConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
||||
|
||||
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
|
||||
|
||||
let response = client
|
||||
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
||||
name: acp::NEW_SESSION_TOOL_NAME.into(),
|
||||
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
|
||||
mcp_servers: [(
|
||||
mcp_server::SERVER_NAME.to_string(),
|
||||
mcp_server.server_config()?,
|
||||
)]
|
||||
.into(),
|
||||
client_tools: acp::ClientTools {
|
||||
request_permission: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
|
||||
}),
|
||||
read_text_file: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
|
||||
}),
|
||||
write_text_file: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
|
||||
}),
|
||||
},
|
||||
cwd,
|
||||
})?),
|
||||
meta: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
return Err(anyhow!(response.text_contents()));
|
||||
}
|
||||
|
||||
let result = serde_json::from_value::<acp::NewSessionOutput>(
|
||||
response.structured_content.context("Empty response")?,
|
||||
)?;
|
||||
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
||||
let session = CodexSession {
|
||||
thread: thread.downgrade(),
|
||||
cancel_tx: None,
|
||||
_mcp_server: mcp_server,
|
||||
};
|
||||
sessions.borrow_mut().insert(result.session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
params: agent_client_protocol::PromptArguments,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
|
||||
let (new_cancel_tx, cancel_rx) = oneshot::channel();
|
||||
{
|
||||
let mut sessions = sessions.borrow_mut();
|
||||
let session = sessions
|
||||
.get_mut(¶ms.session_id)
|
||||
.context("Session not found")?;
|
||||
session.cancel_tx.replace(new_cancel_tx);
|
||||
}
|
||||
|
||||
let result = client
|
||||
.request_with::<requests::CallTool>(
|
||||
context_server::types::CallToolParams {
|
||||
name: acp::PROMPT_TOOL_NAME.into(),
|
||||
arguments: Some(serde_json::to_value(params)?),
|
||||
meta: None,
|
||||
},
|
||||
Some(cancel_rx),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = &result
|
||||
&& err.is::<context_server::client::RequestCanceled>()
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let response = result?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
return Err(anyhow!(response.text_contents()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
|
||||
let mut sessions = self.sessions.borrow_mut();
|
||||
|
||||
if let Some(cancel_tx) = sessions
|
||||
.get_mut(session_id)
|
||||
.and_then(|session| session.cancel_tx.take())
|
||||
{
|
||||
cancel_tx.send(()).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CodexConnection {
|
||||
pub fn handle_session_notification(
|
||||
notification: acp::SessionNotification,
|
||||
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let threads = threads.borrow();
|
||||
let Some(thread) = threads
|
||||
.get(¬ification.session_id)
|
||||
.and_then(|session| session.thread.upgrade())
|
||||
else {
|
||||
log::error!(
|
||||
"Thread not found for session ID: {}",
|
||||
notification.session_id
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(notification.update, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CodexConnection {
|
||||
fn drop(&mut self) {
|
||||
self.client.stop().log_err();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use crate::AgentServerCommand;
|
||||
use std::path::Path;
|
||||
|
||||
crate::common_e2e_tests!(Codex, allow_option_id = "approve");
|
||||
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../../codex/codex-rs/target/debug/codex");
|
||||
|
||||
AgentServerCommand {
|
||||
path: cli_path,
|
||||
args: vec![],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -375,9 +375,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
|||
gemini: Some(AgentServerSettings {
|
||||
command: crate::gemini::tests::local_command(),
|
||||
}),
|
||||
codex: Some(AgentServerSettings {
|
||||
command: crate::codex::tests::local_command(),
|
||||
}),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
use anyhow::anyhow;
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
|
||||
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
|
||||
use agentic_coding_protocol as acp_old;
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use crate::{AgentServer, AgentServerCommand, acp_connection::AcpConnection};
|
||||
use acp_thread::AgentConnection;
|
||||
use anyhow::Result;
|
||||
use gpui::{Entity, Task};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use ui::App;
|
||||
|
@ -43,146 +39,27 @@ impl AgentServer for Gemini {
|
|||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let project = project.clone();
|
||||
let this = self.clone();
|
||||
let name = self.name();
|
||||
|
||||
let server_name = self.name();
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let command = this.command(&project, cx).await?;
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||
})?;
|
||||
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find gemini binary");
|
||||
};
|
||||
// todo! check supported version
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
|
||||
|
||||
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
|
||||
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
|
||||
stdin,
|
||||
stdout,
|
||||
move |fut| foreground_executor.spawn(fut).detach(),
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn(async move {
|
||||
io_fut.await.log_err();
|
||||
});
|
||||
|
||||
let child_status = cx.background_spawn(async move {
|
||||
let result = match child.status().await {
|
||||
Err(e) => Err(anyhow!(e)),
|
||||
Ok(result) if result.success() => Ok(()),
|
||||
Ok(result) => {
|
||||
if let Some(AgentServerVersion::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command,
|
||||
}) = this.version(&command).await.log_err()
|
||||
{
|
||||
Err(anyhow!(LoadError::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command
|
||||
}))
|
||||
} else {
|
||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
||||
}
|
||||
}
|
||||
};
|
||||
drop(io_task);
|
||||
result
|
||||
});
|
||||
|
||||
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
|
||||
name,
|
||||
connection,
|
||||
child_status,
|
||||
current_thread: thread_rc,
|
||||
});
|
||||
|
||||
Ok(connection)
|
||||
let conn = AcpConnection::stdio(server_name, command, &root_dir, cx).await?;
|
||||
Ok(Rc::new(conn) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Gemini {
|
||||
async fn command(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<AgentServerCommand> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||
})?;
|
||||
|
||||
if let Some(command) =
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||
{
|
||||
return Ok(command);
|
||||
};
|
||||
|
||||
let (fs, node_runtime) = project.update(cx, |project, _| {
|
||||
(project.fs().clone(), project.node_runtime().cloned())
|
||||
})?;
|
||||
let node_runtime = node_runtime.context("gemini not found on path")?;
|
||||
|
||||
let directory = ::paths::agent_servers_dir().join("gemini");
|
||||
fs.create_dir(&directory).await?;
|
||||
node_runtime
|
||||
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
|
||||
.await?;
|
||||
let path = directory.join("node_modules/.bin/gemini");
|
||||
|
||||
Ok(AgentServerCommand {
|
||||
path,
|
||||
args: vec![ACP_ARG.into()],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
||||
let version_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--version")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let help_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--help")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
|
||||
|
||||
let current_version = String::from_utf8(version_output?.stdout)?;
|
||||
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
|
||||
|
||||
if supported {
|
||||
Ok(AgentServerVersion::Supported)
|
||||
} else {
|
||||
Ok(AgentServerVersion::Unsupported {
|
||||
error_message: format!(
|
||||
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
|
||||
current_version
|
||||
).into(),
|
||||
upgrade_message: "Upgrade Gemini to Latest".into(),
|
||||
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
|
@ -199,7 +76,7 @@ pub(crate) mod tests {
|
|||
|
||||
AgentServerCommand {
|
||||
path: "node".into(),
|
||||
args: vec![cli_path, ACP_ARG.into()],
|
||||
args: vec![cli_path],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,207 +0,0 @@
|
|||
use acp_thread::AcpThread;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use context_server::listener::{McpServerTool, ToolResponse};
|
||||
use context_server::types::{
|
||||
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
||||
ToolsCapabilities, requests,
|
||||
};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||
use indoc::indoc;
|
||||
|
||||
pub struct ZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
}
|
||||
|
||||
pub const SERVER_NAME: &str = "zed";
|
||||
|
||||
impl ZedMcpServer {
|
||||
pub async fn new(
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||
|
||||
mcp_server.add_tool(RequestPermissionTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(ReadTextFileTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(WriteTextFileTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
|
||||
Ok(Self { server: mcp_server })
|
||||
}
|
||||
|
||||
pub fn server_config(&self) -> Result<acp::McpServerConfig> {
|
||||
#[cfg(not(test))]
|
||||
let zed_path = anyhow::Context::context(
|
||||
std::env::current_exe(),
|
||||
"finding current executable path for use in mcp_server",
|
||||
)?;
|
||||
|
||||
#[cfg(test)]
|
||||
let zed_path = crate::e2e_tests::get_zed_path();
|
||||
|
||||
Ok(acp::McpServerConfig {
|
||||
command: zed_path,
|
||||
args: vec![
|
||||
"--nc".into(),
|
||||
self.server.socket_path().display().to_string(),
|
||||
],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(InitializeResponse {
|
||||
protocol_version: ProtocolVersion("2025-06-18".into()),
|
||||
capabilities: ServerCapabilities {
|
||||
experimental: None,
|
||||
logging: None,
|
||||
completions: None,
|
||||
prompts: None,
|
||||
resources: None,
|
||||
tools: Some(ToolsCapabilities {
|
||||
list_changed: Some(false),
|
||||
}),
|
||||
},
|
||||
server_info: Implementation {
|
||||
name: SERVER_NAME.into(),
|
||||
version: "0.1.0".into(),
|
||||
},
|
||||
meta: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Tools
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RequestPermissionTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for RequestPermissionTool {
|
||||
type Input = acp::RequestPermissionArguments;
|
||||
type Output = acp::RequestPermissionOutput;
|
||||
|
||||
const NAME: &'static str = "Confirmation";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
indoc! {"
|
||||
Request permission for tool calls.
|
||||
|
||||
This tool is meant to be called programmatically by the agent loop, not the LLM.
|
||||
"}
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let result = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(input.tool_call, input.options, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
|
||||
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
|
||||
};
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: acp::RequestPermissionOutput { outcome },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReadTextFileTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for ReadTextFileTool {
|
||||
type Input = acp::ReadTextFileArguments;
|
||||
type Output = acp::ReadTextFileOutput;
|
||||
|
||||
const NAME: &'static str = "Read";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Reads the content of the given file in the project including unsaved changes."
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(input.path, input.line, input.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: acp::ReadTextFileOutput { content },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WriteTextFileTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for WriteTextFileTool {
|
||||
type Input = acp::WriteTextFileArguments;
|
||||
type Output = ();
|
||||
|
||||
const NAME: &'static str = "Write";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Write to a file replacing its contents"
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(input.path, input.content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: (),
|
||||
})
|
||||
}
|
||||
}
|
|
@ -13,7 +13,6 @@ pub fn init(cx: &mut App) {
|
|||
pub struct AllAgentServersSettings {
|
||||
pub gemini: Option<AgentServerSettings>,
|
||||
pub claude: Option<AgentServerSettings>,
|
||||
pub codex: Option<AgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
|
@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings {
|
|||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let mut settings = AllAgentServersSettings::default();
|
||||
|
||||
for AllAgentServersSettings {
|
||||
gemini,
|
||||
claude,
|
||||
codex,
|
||||
} in sources.defaults_and_customizations()
|
||||
{
|
||||
for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
|
||||
if gemini.is_some() {
|
||||
settings.gemini = gemini.clone();
|
||||
}
|
||||
if claude.is_some() {
|
||||
settings.claude = claude.clone();
|
||||
}
|
||||
if codex.is_some() {
|
||||
settings.codex = codex.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
|
|
|
@ -19,6 +19,7 @@ test-support = ["gpui/test-support", "language/test-support"]
|
|||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent.workspace = true
|
||||
agent2.workspace = true
|
||||
agent_servers.workspace = true
|
||||
agent_settings.workspace = true
|
||||
ai_onboarding.workspace = true
|
||||
|
|
|
@ -232,7 +232,8 @@ impl AcpThreadView {
|
|||
{
|
||||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
||||
// todo! remove duplication
|
||||
if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
|
@ -675,13 +676,18 @@ impl AcpThreadView {
|
|||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
fn authenticate(
|
||||
&mut self,
|
||||
method: acp::AuthMethodId,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.last_error.take();
|
||||
let authenticate = connection.authenticate(cx);
|
||||
let authenticate = connection.authenticate(method, cx);
|
||||
self.auth_task = Some(cx.spawn_in(window, {
|
||||
let project = self.project.clone();
|
||||
let agent = self.agent.clone();
|
||||
|
@ -2380,22 +2386,26 @@ impl Render for AcpThreadView {
|
|||
.on_action(cx.listener(Self::next_history_message))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.child(match &self.thread_state {
|
||||
ThreadState::Unauthenticated { .. } => {
|
||||
v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(
|
||||
h_flex().mt_1p5().justify_center().child(
|
||||
Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.authenticate(window, cx)
|
||||
})),
|
||||
),
|
||||
)
|
||||
}
|
||||
ThreadState::Unauthenticated { connection } => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(h_flex().mt_1p5().justify_center().children(
|
||||
connection.auth_methods().into_iter().map(|method| {
|
||||
Button::new(
|
||||
SharedString::from(method.id.0.clone()),
|
||||
method.label.clone(),
|
||||
)
|
||||
.on_click({
|
||||
let method_id = method.id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.authenticate(method_id.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
}),
|
||||
)),
|
||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||
ThreadState::LoadError(e) => v_flex()
|
||||
.p_2()
|
||||
|
@ -2834,10 +2844,6 @@ mod tests {
|
|||
}
|
||||
|
||||
impl AgentConnection for StubAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
"StubAgentConnection"
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
|
@ -2853,17 +2859,27 @@ mod tests {
|
|||
.into(),
|
||||
);
|
||||
let thread = cx
|
||||
.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
|
||||
.new(|cx| {
|
||||
AcpThread::new("New Thread", self.clone(), project, session_id.clone(), cx)
|
||||
})
|
||||
.unwrap();
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||
let mut tasks = vec![];
|
||||
|
@ -2910,10 +2926,6 @@ mod tests {
|
|||
struct SaboteurAgentConnection;
|
||||
|
||||
impl AgentConnection for SaboteurAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
"SaboteurAgentConnection"
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
|
@ -2921,15 +2933,23 @@ mod tests {
|
|||
cx: &mut gpui::AsyncApp,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
Task::ready(Ok(cx
|
||||
.new(|cx| AcpThread::new(self, project, SessionId("test".into()), cx))
|
||||
.new(|cx| AcpThread::new("New Thread", self, project, SessionId("test".into()), cx))
|
||||
.unwrap()))
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(&self, _params: acp::PromptArguments, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
Task::ready(Err(anyhow::anyhow!("Error prompting")))
|
||||
}
|
||||
|
||||
|
|
|
@ -1954,54 +1954,54 @@ impl AgentPanel {
|
|||
this
|
||||
}
|
||||
})
|
||||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.separator()
|
||||
.header("External Agents")
|
||||
.item(
|
||||
ContextMenuEntry::new("New Gemini Thread")
|
||||
.icon(IconName::AiGemini)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Claude Code Thread")
|
||||
.icon(IconName::AiClaude)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::ClaudeCode,
|
||||
),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Codex Thread")
|
||||
.icon(IconName::AiOpenAi)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Codex),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
});
|
||||
// Temporarily removed feature flag check for testing
|
||||
// .when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
// this
|
||||
.separator()
|
||||
.header("External Agents")
|
||||
.item(
|
||||
ContextMenuEntry::new("New Gemini Thread")
|
||||
.icon(IconName::AiGemini)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Claude Code Thread")
|
||||
.icon(IconName::AiClaude)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::ClaudeCode),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Native Agent Thread")
|
||||
.icon(IconName::ZedAssistant)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::NativeAgent),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
);
|
||||
// });
|
||||
menu
|
||||
}))
|
||||
}
|
||||
|
@ -2608,82 +2608,87 @@ impl AgentPanel {
|
|||
),
|
||||
),
|
||||
)
|
||||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-gemini-thread-btn",
|
||||
"New Gemini Thread",
|
||||
IconName::AiGemini,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::Gemini,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
// Temporarily removed feature flag check for testing
|
||||
// .when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
// this
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-gemini-thread-btn",
|
||||
"New Gemini Thread",
|
||||
IconName::AiGemini,
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-claude-thread-btn",
|
||||
"New Claude Code Thread",
|
||||
IconName::AiClaude,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::ClaudeCode,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-codex-thread-btn",
|
||||
"New Codex Thread",
|
||||
IconName::AiOpenAi,
|
||||
)
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::Codex,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-claude-thread-btn",
|
||||
"New Claude Code Thread",
|
||||
IconName::AiClaude,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::ClaudeCode,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-native-agent-thread-btn",
|
||||
"New Native Agent Thread",
|
||||
IconName::ZedAssistant,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::NativeAgent,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
), // })
|
||||
)
|
||||
.when_some(configuration_error.as_ref(), |this, err| {
|
||||
this.child(self.render_configuration_error(err, &focus_handle, window, cx))
|
||||
|
|
|
@ -150,7 +150,7 @@ enum ExternalAgent {
|
|||
#[default]
|
||||
Gemini,
|
||||
ClaudeCode,
|
||||
Codex,
|
||||
NativeAgent,
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
|
@ -158,7 +158,7 @@ impl ExternalAgent {
|
|||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
|
||||
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -441,14 +441,12 @@ impl Client {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
self.notification_handlers
|
||||
.lock()
|
||||
.insert(method, Box::new(f));
|
||||
pub fn on_notification(
|
||||
&self,
|
||||
method: &'static str,
|
||||
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||
) {
|
||||
self.notification_handlers.lock().insert(method, f);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -95,8 +95,28 @@ impl ContextServer {
|
|||
self.client.read().clone()
|
||||
}
|
||||
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
let client = match &self.configuration {
|
||||
pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
|
||||
self.initialize(self.new_client(cx)?).await
|
||||
}
|
||||
|
||||
/// Starts the context server, making sure handlers are registered before initialization happens
|
||||
pub async fn start_with_handlers(
|
||||
&self,
|
||||
notification_handlers: Vec<(
|
||||
&'static str,
|
||||
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
|
||||
)>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<()> {
|
||||
let client = self.new_client(cx)?;
|
||||
for (method, handler) in notification_handlers {
|
||||
client.on_notification(method, handler);
|
||||
}
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
|
||||
Ok(match &self.configuration {
|
||||
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
|
||||
client::ContextServerId(self.id.0.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
|
@ -113,8 +133,7 @@ impl ContextServer {
|
|||
transport.clone(),
|
||||
cx.clone(),
|
||||
)?,
|
||||
};
|
||||
self.initialize(client).await
|
||||
})
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
|
|
|
@ -83,14 +83,18 @@ impl McpServer {
|
|||
}
|
||||
|
||||
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
|
||||
let output_schema = schemars::schema_for!(T::Output);
|
||||
let unit_schema = schemars::schema_for!(());
|
||||
let mut settings = schemars::generate::SchemaSettings::draft07();
|
||||
settings.inline_subschemas = true;
|
||||
let mut generator = settings.into_generator();
|
||||
|
||||
let output_schema = generator.root_schema_for::<T::Output>();
|
||||
let unit_schema = generator.root_schema_for::<T::Output>();
|
||||
|
||||
let registered_tool = RegisteredTool {
|
||||
tool: Tool {
|
||||
name: T::NAME.into(),
|
||||
description: Some(tool.description().into()),
|
||||
input_schema: schemars::schema_for!(T::Input).into(),
|
||||
input_schema: generator.root_schema_for::<T::Input>().into(),
|
||||
output_schema: if output_schema == unit_schema {
|
||||
None
|
||||
} else {
|
||||
|
|
|
@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
|
|||
self.inner.notify(T::METHOD, params)
|
||||
}
|
||||
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
pub fn on_notification(
|
||||
&self,
|
||||
method: &'static str,
|
||||
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||
) {
|
||||
self.inner.on_notification(method, f);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue