From 244432175669cf3bc4c1c49c794692e8f0947fd3 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 12 Aug 2025 14:17:48 +0200 Subject: [PATCH] Support profiles in agent2 (#36034) We still need a profile selector. Release Notes: - N/A --------- Co-authored-by: Ben Brandt --- Cargo.lock | 1 + crates/acp_thread/src/acp_thread.rs | 51 ++++ crates/agent2/Cargo.toml | 2 + crates/agent2/src/agent.rs | 34 ++- crates/agent2/src/tests/mod.rs | 142 +++++++++-- crates/agent2/src/thread.rs | 87 +++++-- crates/agent2/src/tools.rs | 2 + .../src/tools/context_server_registry.rs | 231 ++++++++++++++++++ crates/agent2/src/tools/diagnostics_tool.rs | 18 +- crates/agent2/src/tools/edit_file_tool.rs | 66 ++++- crates/agent2/src/tools/fetch_tool.rs | 8 +- crates/agent2/src/tools/find_path_tool.rs | 3 - crates/agent2/src/tools/grep_tool.rs | 25 +- crates/agent2/src/tools/now_tool.rs | 11 +- crates/agent_settings/src/agent_profile.rs | 14 ++ 15 files changed, 587 insertions(+), 108 deletions(-) create mode 100644 crates/agent2/src/tools/context_server_registry.rs diff --git a/Cargo.lock b/Cargo.lock index 79bce189e2..dc28a1cb44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -196,6 +196,7 @@ dependencies = [ "clock", "cloud_llm_client", "collections", + "context_server", "ctor", "editor", "env_logger 0.11.8", diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d632e6e570..1c0a9479df 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -254,6 +254,15 @@ impl ToolCall { } if let Some(raw_output) = raw_output { + if self.content.is_empty() { + if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx) + { + self.content + .push(ToolCallContent::ContentBlock(ContentBlock::Markdown { + markdown, + })); + } + } self.raw_output = Some(raw_output); } } @@ -1266,6 +1275,48 @@ impl AcpThread { } } +fn markdown_for_raw_output( + raw_output: &serde_json::Value, + language_registry: &Arc, + cx: &mut App, +) -> Option> { + match raw_output { + serde_json::Value::Null => None, + serde_json::Value::Bool(value) => Some(cx.new(|cx| { + Markdown::new( + value.to_string().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + serde_json::Value::Number(value) => Some(cx.new(|cx| { + Markdown::new( + value.to_string().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + serde_json::Value::String(value) => Some(cx.new(|cx| { + Markdown::new( + value.clone().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + value => Some(cx.new(|cx| { + Markdown::new( + format!("```json\n{}\n```", value).into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 7ee48aca04..1030380dc0 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -23,6 +23,7 @@ assistant_tools.workspace = true chrono.workspace = true cloud_llm_client.workspace = true collections.workspace = true +context_server.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true @@ -60,6 +61,7 @@ workspace-hack.workspace = true ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } +context_server = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 66893f49f9..18a830b978 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,8 +1,8 @@ use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ - CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, - GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, - ThinkingTool, ToolCallAuthorization, WebSearchTool, + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, + FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, + ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, }; use acp_thread::ModelSelector; use agent_client_protocol as acp; @@ -55,6 +55,7 @@ pub struct NativeAgent { project_context: Rc>, project_context_needs_refresh: watch::Sender<()>, _maintain_project_context: Task>, + context_server_registry: Entity, /// Shared templates for all threads templates: Arc, project: Entity, @@ -90,6 +91,9 @@ impl NativeAgent { _maintain_project_context: cx.spawn(async move |this, cx| { Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await }), + context_server_registry: cx.new(|cx| { + ContextServerRegistry::new(project.read(cx).context_server_store(), cx) + }), templates, project, prompt_store, @@ -385,7 +389,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Create AcpThread let acp_thread = cx.update(|cx| { cx.new(|cx| { - acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx) + acp_thread::AcpThread::new( + "agent2", + self.clone(), + project.clone(), + session_id.clone(), + cx, + ) }) })?; let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; @@ -413,11 +423,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection { }) .ok_or_else(|| { log::warn!("No default model configured in settings"); - anyhow!("No default model configured. Please configure a default model in settings.") + anyhow!( + "No default model. Please configure a default model in settings." + ) })?; let thread = cx.new(|cx| { - let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model); + let mut thread = Thread::new( + project.clone(), + agent.project_context.clone(), + agent.context_server_registry.clone(), + action_log.clone(), + agent.templates.clone(), + default_model, + cx, + ); thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(CopyPathTool::new(project.clone())); thread.add_tool(DiagnosticsTool::new(project.clone())); @@ -450,7 +470,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { acp_thread: acp_thread.downgrade(), _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); - }) + }), }, ); })?; diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index d6aaddf2c2..7f4b934c08 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -2,6 +2,7 @@ use super::*; use acp_thread::AgentConnection; use action_log::ActionLog; use agent_client_protocol::{self as acp}; +use agent_settings::AgentProfileId; use anyhow::Result; use client::{Client, UserStore}; use fs::{FakeFs, Fs}; @@ -165,7 +166,9 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { } else { false } - }) + }), + "{}", + thread.to_markdown() ); }); } @@ -469,6 +472,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_profiles(cx: &mut TestAppContext) { + let ThreadTest { + model, thread, fs, .. + } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread.update(cx, |thread, _cx| { + thread.add_tool(DelayTool); + thread.add_tool(EchoTool); + thread.add_tool(InfiniteTool); + }); + + // Override profiles and wait for settings to be loaded. + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "profiles": { + "test-1": { + "name": "Test Profile 1", + "tools": { + EchoTool.name(): true, + DelayTool.name(): true, + } + }, + "test-2": { + "name": "Test Profile 2", + "tools": { + InfiniteTool.name(): true, + } + } + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + cx.run_until_parked(); + + // Test that test-1 profile (default) has echo and delay tools + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test-1".into())); + thread.send("test", cx); + }); + cx.run_until_parked(); + + let mut pending_completions = fake_model.pending_completions(); + assert_eq!(pending_completions.len(), 1); + let completion = pending_completions.pop().unwrap(); + let tool_names: Vec = completion + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect(); + assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]); + fake_model.end_last_completion_stream(); + + // Switch to test-2 profile, and verify that it has only the infinite tool. + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test-2".into())); + thread.send("test2", cx) + }); + cx.run_until_parked(); + let mut pending_completions = fake_model.pending_completions(); + assert_eq!(pending_completions.len(), 1); + let completion = pending_completions.pop().unwrap(); + let tool_names: Vec = completion + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect(); + assert_eq!(tool_names, vec![InfiniteTool.name()]); +} + #[gpui::test] #[ignore = "can't run on CI yet"] async fn test_cancellation(cx: &mut TestAppContext) { @@ -595,6 +674,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { language_models::init(user_store.clone(), client.clone(), cx); Project::init_settings(cx); LanguageModelRegistry::test(cx); + agent_settings::init(cx); }); cx.executor().forbid_parking(); @@ -790,6 +870,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { id: acp::ToolCallId("1".into()), fields: acp::ToolCallUpdateFields { status: Some(acp::ToolCallStatus::Completed), + raw_output: Some("Finished thinking.".into()), ..Default::default() }, } @@ -813,6 +894,7 @@ struct ThreadTest { model: Arc, thread: Entity, project_context: Rc>, + fs: Arc, } enum TestModel { @@ -835,30 +917,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { cx.executor().allow_parking(); let fs = FakeFs::new(cx.background_executor.clone()); + fs.create_dir(paths::settings_file().parent().unwrap()) + .await + .unwrap(); + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "default_profile": "test-profile", + "profiles": { + "test-profile": { + "name": "Test Profile", + "tools": { + EchoTool.name(): true, + DelayTool.name(): true, + WordListTool.name(): true, + ToolRequiringPermission.name(): true, + InfiniteTool.name(): true, + } + } + } + } + }) + .to_string() + .into_bytes(), + ) + .await; cx.update(|cx| { settings::init(cx); - watch_settings(fs.clone(), cx); Project::init_settings(cx); agent_settings::init(cx); + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + watch_settings(fs.clone(), cx); }); + let templates = Templates::new(); fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let model = cx .update(|cx| { - gpui_tokio::init(cx); - let http_client = ReqwestClient::user_agent("agent tests").unwrap(); - cx.set_http_client(Arc::new(http_client)); - - client::init_settings(cx); - let client = Client::production(cx); - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); - if let TestModel::Fake = model { Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>) } else { @@ -881,20 +990,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { .await; let project_context = Rc::new(RefCell::new(ProjectContext::default())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, project_context.clone(), + context_server_registry, action_log, templates, model.clone(), + cx, ) }); ThreadTest { model, thread, project_context, + fs, } } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 23a0f7972d..231f83ce20 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,7 +1,7 @@ -use crate::{SystemPromptTemplate, Template, Templates}; +use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; use action_log::ActionLog; use agent_client_protocol as acp; -use agent_settings::AgentSettings; +use agent_settings::{AgentProfileId, AgentSettings}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use cloud_llm_client::{CompletionIntent, CompletionMode}; @@ -126,6 +126,8 @@ pub struct Thread { running_turn: Option>, pending_tool_uses: HashMap, tools: BTreeMap>, + context_server_registry: Entity, + profile_id: AgentProfileId, project_context: Rc>, templates: Arc, pub selected_model: Arc, @@ -137,16 +139,21 @@ impl Thread { pub fn new( project: Entity, project_context: Rc>, + context_server_registry: Entity, action_log: Entity, templates: Arc, default_model: Arc, + cx: &mut Context, ) -> Self { + let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, pending_tool_uses: HashMap::default(), tools: BTreeMap::default(), + context_server_registry, + profile_id, project_context, templates, selected_model: default_model, @@ -179,6 +186,10 @@ impl Thread { self.tools.remove(name).is_some() } + pub fn set_profile(&mut self, profile_id: AgentProfileId) { + self.profile_id = profile_id; + } + pub fn cancel(&mut self) { self.running_turn.take(); @@ -298,6 +309,7 @@ impl Thread { } else { acp::ToolCallStatus::Completed }), + raw_output: tool_result.output.clone(), ..Default::default() }, ); @@ -604,21 +616,23 @@ impl Thread { let messages = self.build_request_messages(); log::info!("Request will include {} messages", messages.len()); - let tools: Vec = self - .tools - .values() - .filter_map(|tool| { - let tool_name = tool.name().to_string(); - log::trace!("Including tool: {}", tool_name); - Some(LanguageModelRequestTool { - name: tool_name, - description: tool.description(cx).to_string(), - input_schema: tool - .input_schema(self.selected_model.tool_input_format()) - .log_err()?, + let tools = if let Some(tools) = self.tools(cx).log_err() { + tools + .filter_map(|tool| { + let tool_name = tool.name().to_string(); + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name, + description: tool.description().to_string(), + input_schema: tool + .input_schema(self.selected_model.tool_input_format()) + .log_err()?, + }) }) - }) - .collect(); + .collect() + } else { + Vec::new() + }; log::info!("Request includes {} tools", tools.len()); @@ -639,6 +653,35 @@ impl Thread { request } + fn tools<'a>(&'a self, cx: &'a App) -> Result>> { + let profile = AgentSettings::get_global(cx) + .profiles + .get(&self.profile_id) + .context("profile not found")?; + + Ok(self + .tools + .iter() + .filter_map(|(tool_name, tool)| { + if profile.is_tool_enabled(tool_name) { + Some(tool) + } else { + None + } + }) + .chain(self.context_server_registry.read(cx).servers().flat_map( + |(server_id, tools)| { + tools.iter().filter_map(|(tool_name, tool)| { + if profile.is_context_server_tool_enabled(&server_id.0, tool_name) { + Some(tool) + } else { + None + } + }) + }, + ))) + } + fn build_request_messages(&self) -> Vec { log::trace!( "Building request messages from {} thread messages", @@ -686,7 +729,7 @@ where fn name(&self) -> SharedString; - fn description(&self, _cx: &mut App) -> SharedString { + fn description(&self) -> SharedString { let schema = schemars::schema_for!(Self::Input); SharedString::new( schema @@ -722,13 +765,13 @@ where pub struct Erased(T); pub struct AgentToolOutput { - llm_output: LanguageModelToolResultContent, - raw_output: serde_json::Value, + pub llm_output: LanguageModelToolResultContent, + pub raw_output: serde_json::Value, } pub trait AnyAgentTool { fn name(&self) -> SharedString; - fn description(&self, cx: &mut App) -> SharedString; + fn description(&self) -> SharedString; fn kind(&self) -> acp::ToolKind; fn initial_title(&self, input: serde_json::Value) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; @@ -748,8 +791,8 @@ where self.0.name() } - fn description(&self, cx: &mut App) -> SharedString { - self.0.description(cx) + fn description(&self) -> SharedString { + self.0.description() } fn kind(&self) -> agent_client_protocol::ToolKind { diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs index 8896b14538..d1f2b3b1c7 100644 --- a/crates/agent2/src/tools.rs +++ b/crates/agent2/src/tools.rs @@ -1,3 +1,4 @@ +mod context_server_registry; mod copy_path_tool; mod create_directory_tool; mod delete_path_tool; @@ -15,6 +16,7 @@ mod terminal_tool; mod thinking_tool; mod web_search_tool; +pub use context_server_registry::*; pub use copy_path_tool::*; pub use create_directory_tool::*; pub use delete_path_tool::*; diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent2/src/tools/context_server_registry.rs new file mode 100644 index 0000000000..db39e9278c --- /dev/null +++ b/crates/agent2/src/tools/context_server_registry.rs @@ -0,0 +1,231 @@ +use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream}; +use agent_client_protocol::ToolKind; +use anyhow::{Result, anyhow, bail}; +use collections::{BTreeMap, HashMap}; +use context_server::ContextServerId; +use gpui::{App, Context, Entity, SharedString, Task}; +use project::context_server_store::{ContextServerStatus, ContextServerStore}; +use std::sync::Arc; +use util::ResultExt; + +pub struct ContextServerRegistry { + server_store: Entity, + registered_servers: HashMap, + _subscription: gpui::Subscription, +} + +struct RegisteredContextServer { + tools: BTreeMap>, + load_tools: Task>, +} + +impl ContextServerRegistry { + pub fn new(server_store: Entity, cx: &mut Context) -> Self { + let mut this = Self { + server_store: server_store.clone(), + registered_servers: HashMap::default(), + _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event), + }; + for server in server_store.read(cx).running_servers() { + this.reload_tools_for_server(server.id(), cx); + } + this + } + + pub fn servers( + &self, + ) -> impl Iterator< + Item = ( + &ContextServerId, + &BTreeMap>, + ), + > { + self.registered_servers + .iter() + .map(|(id, server)| (id, &server.tools)) + } + + fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context) { + let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else { + return; + }; + let Some(client) = server.client() else { + return; + }; + if !client.capable(context_server::protocol::ServerCapability::Tools) { + return; + } + + let registered_server = + self.registered_servers + .entry(server_id.clone()) + .or_insert(RegisteredContextServer { + tools: BTreeMap::default(), + load_tools: Task::ready(Ok(())), + }); + registered_server.load_tools = cx.spawn(async move |this, cx| { + let response = client + .request::(()) + .await; + + this.update(cx, |this, cx| { + let Some(registered_server) = this.registered_servers.get_mut(&server_id) else { + return; + }; + + registered_server.tools.clear(); + if let Some(response) = response.log_err() { + for tool in response.tools { + let tool = Arc::new(ContextServerTool::new( + this.server_store.clone(), + server.id(), + tool, + )); + registered_server.tools.insert(tool.name(), tool); + } + cx.notify(); + } + }) + }); + } + + fn handle_context_server_store_event( + &mut self, + _: Entity, + event: &project::context_server_store::Event, + cx: &mut Context, + ) { + match event { + project::context_server_store::Event::ServerStatusChanged { server_id, status } => { + match status { + ContextServerStatus::Starting => {} + ContextServerStatus::Running => { + self.reload_tools_for_server(server_id.clone(), cx); + } + ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { + self.registered_servers.remove(&server_id); + cx.notify(); + } + } + } + } + } +} + +struct ContextServerTool { + store: Entity, + server_id: ContextServerId, + tool: context_server::types::Tool, +} + +impl ContextServerTool { + fn new( + store: Entity, + server_id: ContextServerId, + tool: context_server::types::Tool, + ) -> Self { + Self { + store, + server_id, + tool, + } + } +} + +impl AnyAgentTool for ContextServerTool { + fn name(&self) -> SharedString { + self.tool.name.clone().into() + } + + fn description(&self) -> SharedString { + self.tool.description.clone().unwrap_or_default().into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Other + } + + fn initial_title(&self, _input: serde_json::Value) -> SharedString { + format!("Run MCP tool `{}`", self.tool.name).into() + } + + fn input_schema( + &self, + format: language_model::LanguageModelToolSchemaFormat, + ) -> Result { + let mut schema = self.tool.input_schema.clone(); + assistant_tool::adapt_schema_to_format(&mut schema, format)?; + Ok(match schema { + serde_json::Value::Null => { + serde_json::json!({ "type": "object", "properties": [] }) + } + serde_json::Value::Object(map) if map.is_empty() => { + serde_json::json!({ "type": "object", "properties": [] }) + } + _ => schema, + }) + } + + fn run( + self: Arc, + input: serde_json::Value, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else { + return Task::ready(Err(anyhow!("Context server not found"))); + }; + let tool_name = self.tool.name.clone(); + let server_clone = server.clone(); + let input_clone = input.clone(); + + cx.spawn(async move |_cx| { + let Some(protocol) = server_clone.client() else { + bail!("Context server not initialized"); + }; + + let arguments = if let serde_json::Value::Object(map) = input_clone { + Some(map.into_iter().collect()) + } else { + None + }; + + log::trace!( + "Running tool: {} with arguments: {:?}", + tool_name, + arguments + ); + let response = protocol + .request::( + context_server::types::CallToolParams { + name: tool_name, + arguments, + meta: None, + }, + ) + .await?; + + let mut result = String::new(); + for content in response.content { + match content { + context_server::types::ToolResponseContent::Text { text } => { + result.push_str(&text); + } + context_server::types::ToolResponseContent::Image { .. } => { + log::warn!("Ignoring image content from tool response"); + } + context_server::types::ToolResponseContent::Audio { .. } => { + log::warn!("Ignoring audio content from tool response"); + } + context_server::types::ToolResponseContent::Resource { .. } => { + log::warn!("Ignoring resource content from tool response"); + } + } + } + Ok(AgentToolOutput { + raw_output: result.clone().into(), + llm_output: result.into(), + }) + }) + } +} diff --git a/crates/agent2/src/tools/diagnostics_tool.rs b/crates/agent2/src/tools/diagnostics_tool.rs index bd0b20df5a..6ba8b7b377 100644 --- a/crates/agent2/src/tools/diagnostics_tool.rs +++ b/crates/agent2/src/tools/diagnostics_tool.rs @@ -85,7 +85,7 @@ impl AgentTool for DiagnosticsTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { match input.path { @@ -119,11 +119,6 @@ impl AgentTool for DiagnosticsTool { range.start.row + 1, entry.diagnostic.message )?; - - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![output.clone().into()]), - ..Default::default() - }); } if output.is_empty() { @@ -158,18 +153,9 @@ impl AgentTool for DiagnosticsTool { } if has_diagnostics { - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![output.clone().into()]), - ..Default::default() - }); Task::ready(Ok(output)) } else { - let text = "No errors or warnings found in the project."; - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![text.into()]), - ..Default::default() - }); - Task::ready(Ok(text.into())) + Task::ready(Ok("No errors or warnings found in the project.".into())) } } } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 88764d1953..134bc5e5e4 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -454,9 +454,8 @@ fn resolve_path( #[cfg(test)] mod tests { - use crate::Templates; - use super::*; + use crate::{ContextServerRegistry, Templates}; use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; @@ -475,9 +474,20 @@ mod tests { fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = - cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model)); + let thread = cx.new(|cx| { + Thread::new( + project, + Rc::default(), + context_server_registry, + action_log, + Templates::new(), + model, + cx, + ) + }); let result = cx .update(|cx| { let input = EditFileToolInput { @@ -661,14 +671,18 @@ mod tests { }); let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); @@ -792,15 +806,19 @@ mod tests { .unwrap(); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); @@ -914,15 +932,19 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1041,15 +1063,19 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1148,14 +1174,18 @@ mod tests { .await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1225,14 +1255,18 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1305,14 +1339,18 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1382,14 +1420,18 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); diff --git a/crates/agent2/src/tools/fetch_tool.rs b/crates/agent2/src/tools/fetch_tool.rs index 7f3752843c..ae26c5fe19 100644 --- a/crates/agent2/src/tools/fetch_tool.rs +++ b/crates/agent2/src/tools/fetch_tool.rs @@ -136,7 +136,7 @@ impl AgentTool for FetchTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let text = cx.background_spawn({ @@ -149,12 +149,6 @@ impl AgentTool for FetchTool { if text.trim().is_empty() { bail!("no textual content found"); } - - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![text.clone().into()]), - ..Default::default() - }); - Ok(text) }) } diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent2/src/tools/find_path_tool.rs index 611d34e701..552de144a7 100644 --- a/crates/agent2/src/tools/find_path_tool.rs +++ b/crates/agent2/src/tools/find_path_tool.rs @@ -139,9 +139,6 @@ impl AgentTool for FindPathTool { }) .collect(), ), - raw_output: Some(serde_json::json!({ - "paths": &matches, - })), ..Default::default() }); diff --git a/crates/agent2/src/tools/grep_tool.rs b/crates/agent2/src/tools/grep_tool.rs index 3266cb5734..e5d92b3c1d 100644 --- a/crates/agent2/src/tools/grep_tool.rs +++ b/crates/agent2/src/tools/grep_tool.rs @@ -101,7 +101,7 @@ impl AgentTool for GrepTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { const CONTEXT_LINES: u32 = 2; @@ -282,33 +282,22 @@ impl AgentTool for GrepTool { } } - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![output.clone().into()]), - ..Default::default() - }); matches_found += 1; } } - let output = if matches_found == 0 { - "No matches found".to_string() + if matches_found == 0 { + Ok("No matches found".into()) } else if has_more_matches { - format!( + Ok(format!( "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}", input.offset + 1, input.offset + matches_found, input.offset + RESULTS_PER_PAGE, - ) + )) } else { - format!("Found {matches_found} matches:\n{output}") - }; - - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![output.clone().into()]), - ..Default::default() - }); - - Ok(output) + Ok(format!("Found {matches_found} matches:\n{output}")) + } }) } } diff --git a/crates/agent2/src/tools/now_tool.rs b/crates/agent2/src/tools/now_tool.rs index 71698b8275..a72ede26fe 100644 --- a/crates/agent2/src/tools/now_tool.rs +++ b/crates/agent2/src/tools/now_tool.rs @@ -47,20 +47,13 @@ impl AgentTool for NowTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, _cx: &mut App, ) -> Task> { let now = match input.timezone { Timezone::Utc => Utc::now().to_rfc3339(), Timezone::Local => Local::now().to_rfc3339(), }; - let content = format!("The current datetime is {now}."); - - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![content.clone().into()]), - ..Default::default() - }); - - Task::ready(Ok(content)) + Task::ready(Ok(format!("The current datetime is {now}."))) } } diff --git a/crates/agent_settings/src/agent_profile.rs b/crates/agent_settings/src/agent_profile.rs index a6b8633b34..402cf81678 100644 --- a/crates/agent_settings/src/agent_profile.rs +++ b/crates/agent_settings/src/agent_profile.rs @@ -48,6 +48,20 @@ pub struct AgentProfileSettings { pub context_servers: IndexMap, ContextServerPreset>, } +impl AgentProfileSettings { + pub fn is_tool_enabled(&self, tool_name: &str) -> bool { + self.tools.get(tool_name) == Some(&true) + } + + pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool { + self.enable_all_context_servers + || self + .context_servers + .get(server_id) + .map_or(false, |preset| preset.tools.get(tool_name) == Some(&true)) + } +} + #[derive(Debug, Clone, Default)] pub struct ContextServerPreset { pub tools: IndexMap, bool>,