acp: Support calling tools provided by MCP servers (#36752)

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2025-08-22 15:16:42 +02:00 committed by GitHub
parent 3b7c1744b4
commit 4f0fad6996
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 561 additions and 64 deletions

View file

@ -4,26 +4,35 @@ use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId; use agent_settings::AgentProfileId;
use anyhow::Result; use anyhow::Result;
use client::{Client, UserStore}; use client::{Client, UserStore};
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use fs::{FakeFs, Fs}; use fs::{FakeFs, Fs};
use futures::{StreamExt, channel::mpsc::UnboundedReceiver}; use futures::{
StreamExt,
channel::{
mpsc::{self, UnboundedReceiver},
oneshot,
},
};
use gpui::{ use gpui::{
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient, App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
}; };
use indoc::indoc; use indoc::indoc;
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason, LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
fake_provider::FakeLanguageModel, LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
}; };
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use project::Project; use project::{
Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
};
use prompt_store::ProjectContext; use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient; use reqwest_client::ReqwestClient;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use settings::SettingsStore; use settings::{Settings, SettingsStore};
use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path; use util::path;
@ -931,6 +940,334 @@ async fn test_profiles(cx: &mut TestAppContext) {
assert_eq!(tool_names, vec![InfiniteTool::name()]); assert_eq!(tool_names, vec![InfiniteTool::name()]);
} }
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
let ThreadTest {
model,
thread,
context_server_store,
fs,
..
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
// Override profiles and wait for settings to be loaded.
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"profiles": {
"test": {
"name": "Test Profile",
"enable_all_context_servers": true,
"tools": {
EchoTool::name(): true,
}
},
}
}
})
.to_string()
.into_bytes(),
)
.await;
cx.run_until_parked();
thread.update(cx, |thread, _| {
thread.set_profile(AgentProfileId("test".into()))
});
let mut mcp_tool_calls = setup_context_server(
"test_server",
vec![context_server::types::Tool {
name: "echo".into(),
description: None,
input_schema: serde_json::to_value(
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
)
.unwrap(),
output_schema: None,
annotations: None,
}],
&context_server_store,
cx,
);
let events = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
});
cx.run_until_parked();
// Simulate the model calling the MCP tool.
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_1".into(),
name: "echo".into(),
raw_input: json!({"text": "test"}).to_string(),
input: json!({"text": "test"}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
assert_eq!(tool_call_params.name, "echo");
assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
tool_call_response
.send(context_server::types::CallToolResponse {
content: vec![context_server::types::ToolResponseContent::Text {
text: "test".into(),
}],
is_error: None,
meta: None,
structured_content: None,
})
.unwrap();
cx.run_until_parked();
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
fake_model.send_last_completion_stream_text_chunk("Done!");
fake_model.end_last_completion_stream();
events.collect::<Vec<_>>().await;
// Send again after adding the echo tool, ensuring the name collision is resolved.
let events = thread.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
});
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
tool_names_for_completion(&completion),
vec!["echo", "test_server_echo"]
);
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_2".into(),
name: "test_server_echo".into(),
raw_input: json!({"text": "mcp"}).to_string(),
input: json!({"text": "mcp"}),
is_input_complete: true,
},
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_3".into(),
name: "echo".into(),
raw_input: json!({"text": "native"}).to_string(),
input: json!({"text": "native"}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
assert_eq!(tool_call_params.name, "echo");
assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
tool_call_response
.send(context_server::types::CallToolResponse {
content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
is_error: None,
meta: None,
structured_content: None,
})
.unwrap();
cx.run_until_parked();
// Ensure the tool results were inserted with the correct names.
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages.last().unwrap().content,
vec![
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: "tool_3".into(),
tool_name: "echo".into(),
is_error: false,
content: "native".into(),
output: Some("native".into()),
},),
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: "tool_2".into(),
tool_name: "test_server_echo".into(),
is_error: false,
content: "mcp".into(),
output: Some("mcp".into()),
},),
]
);
fake_model.end_last_completion_stream();
events.collect::<Vec<_>>().await;
}
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
let ThreadTest {
model,
thread,
context_server_store,
fs,
..
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
// Set up a profile with all tools enabled
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"profiles": {
"test": {
"name": "Test Profile",
"enable_all_context_servers": true,
"tools": {
EchoTool::name(): true,
DelayTool::name(): true,
WordListTool::name(): true,
ToolRequiringPermission::name(): true,
InfiniteTool::name(): true,
}
},
}
}
})
.to_string()
.into_bytes(),
)
.await;
cx.run_until_parked();
thread.update(cx, |thread, _| {
thread.set_profile(AgentProfileId("test".into()));
thread.add_tool(EchoTool);
thread.add_tool(DelayTool);
thread.add_tool(WordListTool);
thread.add_tool(ToolRequiringPermission);
thread.add_tool(InfiniteTool);
});
// Set up multiple context servers with some overlapping tool names
let _server1_calls = setup_context_server(
"xxx",
vec![
context_server::types::Tool {
name: "echo".into(), // Conflicts with native EchoTool
description: None,
input_schema: serde_json::to_value(
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
)
.unwrap(),
output_schema: None,
annotations: None,
},
context_server::types::Tool {
name: "unique_tool_1".into(),
description: None,
input_schema: json!({"type": "object", "properties": {}}),
output_schema: None,
annotations: None,
},
],
&context_server_store,
cx,
);
let _server2_calls = setup_context_server(
"yyy",
vec![
context_server::types::Tool {
name: "echo".into(), // Also conflicts with native EchoTool
description: None,
input_schema: serde_json::to_value(
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
)
.unwrap(),
output_schema: None,
annotations: None,
},
context_server::types::Tool {
name: "unique_tool_2".into(),
description: None,
input_schema: json!({"type": "object", "properties": {}}),
output_schema: None,
annotations: None,
},
context_server::types::Tool {
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
description: None,
input_schema: json!({"type": "object", "properties": {}}),
output_schema: None,
annotations: None,
},
context_server::types::Tool {
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
description: None,
input_schema: json!({"type": "object", "properties": {}}),
output_schema: None,
annotations: None,
},
],
&context_server_store,
cx,
);
let _server3_calls = setup_context_server(
"zzz",
vec![
context_server::types::Tool {
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
description: None,
input_schema: json!({"type": "object", "properties": {}}),
output_schema: None,
annotations: None,
},
context_server::types::Tool {
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
description: None,
input_schema: json!({"type": "object", "properties": {}}),
output_schema: None,
annotations: None,
},
context_server::types::Tool {
name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
description: None,
input_schema: json!({"type": "object", "properties": {}}),
output_schema: None,
annotations: None,
},
],
&context_server_store,
cx,
);
thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Go"], cx)
})
.unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
tool_names_for_completion(&completion),
vec![
"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
"delay",
"echo",
"infinite",
"tool_requiring_permission",
"unique_tool_1",
"unique_tool_2",
"word_list",
"xxx_echo",
"y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
"yyy_echo",
"z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
]
);
}
#[gpui::test] #[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)] #[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) { async fn test_cancellation(cx: &mut TestAppContext) {
@ -1806,6 +2143,7 @@ struct ThreadTest {
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
thread: Entity<Thread>, thread: Entity<Thread>,
project_context: Entity<ProjectContext>, project_context: Entity<ProjectContext>,
context_server_store: Entity<ContextServerStore>,
fs: Arc<FakeFs>, fs: Arc<FakeFs>,
} }
@ -1844,6 +2182,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
WordListTool::name(): true, WordListTool::name(): true,
ToolRequiringPermission::name(): true, ToolRequiringPermission::name(): true,
InfiniteTool::name(): true, InfiniteTool::name(): true,
ThinkingTool::name(): true,
} }
} }
} }
@ -1900,8 +2239,9 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
.await; .await;
let project_context = cx.new(|_cx| ProjectContext::default()); let project_context = cx.new(|_cx| ProjectContext::default());
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project, project,
@ -1916,6 +2256,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
model, model,
thread, thread,
project_context, project_context,
context_server_store,
fs, fs,
} }
} }
@ -1950,3 +2291,89 @@ fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
}) })
.detach(); .detach();
} }
fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
completion
.tools
.iter()
.map(|tool| tool.name.clone())
.collect()
}
fn setup_context_server(
name: &'static str,
tools: Vec<context_server::types::Tool>,
context_server_store: &Entity<ContextServerStore>,
cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
context_server::types::CallToolParams,
oneshot::Sender<context_server::types::CallToolResponse>,
)> {
cx.update(|cx| {
let mut settings = ProjectSettings::get_global(cx).clone();
settings.context_servers.insert(
name.into(),
project::project_settings::ContextServerSettings::Custom {
enabled: true,
command: ContextServerCommand {
path: "somebinary".into(),
args: Vec::new(),
env: None,
},
},
);
ProjectSettings::override_global(settings, cx);
});
let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
.on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
context_server::types::InitializeResponse {
protocol_version: context_server::types::ProtocolVersion(
context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
),
server_info: context_server::types::Implementation {
name: name.into(),
version: "1.0.0".to_string(),
},
capabilities: context_server::types::ServerCapabilities {
tools: Some(context_server::types::ToolsCapabilities {
list_changed: Some(true),
}),
..Default::default()
},
meta: None,
}
})
.on_request::<context_server::types::requests::ListTools, _>(move |_params| {
let tools = tools.clone();
async move {
context_server::types::ListToolsResponse {
tools,
next_cursor: None,
meta: None,
}
}
})
.on_request::<context_server::types::requests::CallTool, _>(move |params| {
let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
async move {
let (response_tx, response_rx) = oneshot::channel();
mcp_tool_calls_tx
.unbounded_send((params, response_tx))
.unwrap();
response_rx.await.unwrap()
}
});
context_server_store.update(cx, |store, cx| {
store.start_server(
Arc::new(ContextServer::new(
ContextServerId(name.into()),
Arc::new(fake_transport),
)),
cx,
);
});
cx.run_until_parked();
mcp_tool_calls_rx
}

View file

@ -9,15 +9,15 @@ use action_log::ActionLog;
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot}; use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::{ use agent_settings::{
AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT, AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
SUMMARIZE_THREAD_PROMPT, SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
}; };
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format; use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage}; use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use collections::{HashMap, IndexMap}; use collections::{HashMap, HashSet, IndexMap};
use fs::Fs; use fs::Fs;
use futures::{ use futures::{
FutureExt, FutureExt,
@ -56,6 +56,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid; use uuid::Uuid;
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
pub const MAX_TOOL_NAME_LENGTH: usize = 64;
/// The ID of the user prompt that initiated a request. /// The ID of the user prompt that initiated a request.
/// ///
@ -627,7 +628,20 @@ impl Thread {
stream: &ThreadEventStream, stream: &ThreadEventStream,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let Some(tool) = self.tools.get(tool_use.name.as_ref()) else { let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
self.context_server_registry
.read(cx)
.servers()
.find_map(|(_, tools)| {
if let Some(tool) = tools.get(tool_use.name.as_ref()) {
Some(tool.clone())
} else {
None
}
})
});
let Some(tool) = tool else {
stream stream
.0 .0
.unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
@ -1079,6 +1093,10 @@ impl Thread {
self.cancel(cx); self.cancel(cx);
let model = self.model.clone().context("No language model configured")?; let model = self.model.clone().context("No language model configured")?;
let profile = AgentSettings::get_global(cx)
.profiles
.get(&self.profile_id)
.context("Profile not found")?;
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>(); let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
let event_stream = ThreadEventStream(events_tx); let event_stream = ThreadEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1); let message_ix = self.messages.len().saturating_sub(1);
@ -1086,6 +1104,7 @@ impl Thread {
self.summary = None; self.summary = None;
self.running_turn = Some(RunningTurn { self.running_turn = Some(RunningTurn {
event_stream: event_stream.clone(), event_stream: event_stream.clone(),
tools: self.enabled_tools(profile, &model, cx),
_task: cx.spawn(async move |this, cx| { _task: cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution"); log::info!("Starting agent turn execution");
@ -1417,7 +1436,7 @@ impl Thread {
) -> Option<Task<LanguageModelToolResult>> { ) -> Option<Task<LanguageModelToolResult>> {
cx.notify(); cx.notify();
let tool = self.tools.get(tool_use.name.as_ref()).cloned(); let tool = self.tool(tool_use.name.as_ref());
let mut title = SharedString::from(&tool_use.name); let mut title = SharedString::from(&tool_use.name);
let mut kind = acp::ToolKind::Other; let mut kind = acp::ToolKind::Other;
if let Some(tool) = tool.as_ref() { if let Some(tool) = tool.as_ref() {
@ -1727,6 +1746,21 @@ impl Thread {
cx: &mut App, cx: &mut App,
) -> Result<LanguageModelRequest> { ) -> Result<LanguageModelRequest> {
let model = self.model().context("No language model configured")?; let model = self.model().context("No language model configured")?;
let tools = if let Some(turn) = self.running_turn.as_ref() {
turn.tools
.iter()
.filter_map(|(tool_name, tool)| {
log::trace!("Including tool: {}", tool_name);
Some(LanguageModelRequestTool {
name: tool_name.to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
})
})
.collect::<Vec<_>>()
} else {
Vec::new()
};
log::debug!("Building completion request"); log::debug!("Building completion request");
log::debug!("Completion intent: {:?}", completion_intent); log::debug!("Completion intent: {:?}", completion_intent);
@ -1734,23 +1768,6 @@ impl Thread {
let messages = self.build_request_messages(cx); let messages = self.build_request_messages(cx);
log::info!("Request will include {} messages", messages.len()); log::info!("Request will include {} messages", messages.len());
let tools = if let Some(tools) = self.tools(cx).log_err() {
tools
.filter_map(|tool| {
let tool_name = tool.name().to_string();
log::trace!("Including tool: {}", tool_name);
Some(LanguageModelRequestTool {
name: tool_name,
description: tool.description().to_string(),
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
})
})
.collect()
} else {
Vec::new()
};
log::info!("Request includes {} tools", tools.len()); log::info!("Request includes {} tools", tools.len());
let request = LanguageModelRequest { let request = LanguageModelRequest {
@ -1770,37 +1787,76 @@ impl Thread {
Ok(request) Ok(request)
} }
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> { fn enabled_tools(
let model = self.model().context("No language model configured")?; &self,
profile: &AgentProfileSettings,
model: &Arc<dyn LanguageModel>,
cx: &App,
) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
fn truncate(tool_name: &SharedString) -> SharedString {
if tool_name.len() > MAX_TOOL_NAME_LENGTH {
let mut truncated = tool_name.to_string();
truncated.truncate(MAX_TOOL_NAME_LENGTH);
truncated.into()
} else {
tool_name.clone()
}
}
let profile = AgentSettings::get_global(cx) let mut tools = self
.profiles
.get(&self.profile_id)
.context("profile not found")?;
let provider_id = model.provider_id();
Ok(self
.tools .tools
.iter() .iter()
.filter(move |(_, tool)| tool.supported_provider(&provider_id))
.filter_map(|(tool_name, tool)| { .filter_map(|(tool_name, tool)| {
if profile.is_tool_enabled(tool_name) { if tool.supported_provider(&model.provider_id())
Some(tool) && profile.is_tool_enabled(tool_name)
{
Some((truncate(tool_name), tool.clone()))
} else { } else {
None None
} }
}) })
.chain(self.context_server_registry.read(cx).servers().flat_map( .collect::<BTreeMap<_, _>>();
|(server_id, tools)| {
tools.iter().filter_map(|(tool_name, tool)| { let mut context_server_tools = Vec::new();
if profile.is_context_server_tool_enabled(&server_id.0, tool_name) { let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
Some(tool) let mut duplicate_tool_names = HashSet::default();
} else { for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
None for (tool_name, tool) in server_tools {
} if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
}) let tool_name = truncate(tool_name);
}, if !seen_tools.insert(tool_name.clone()) {
))) duplicate_tool_names.insert(tool_name.clone());
}
context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
}
}
}
// When there are duplicate tool names, disambiguate by prefixing them
// with the server ID. In the rare case there isn't enough space for the
// disambiguated tool name, keep only the last tool with this name.
for (server_id, tool_name, tool) in context_server_tools {
if duplicate_tool_names.contains(&tool_name) {
let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
if available >= 2 {
let mut disambiguated = server_id.0.to_string();
disambiguated.truncate(available - 1);
disambiguated.push('_');
disambiguated.push_str(&tool_name);
tools.insert(disambiguated.into(), tool.clone());
} else {
tools.insert(tool_name, tool.clone());
}
} else {
tools.insert(tool_name, tool.clone());
}
}
tools
}
fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
self.running_turn.as_ref()?.tools.get(name).cloned()
} }
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> { fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
@ -1965,6 +2021,8 @@ struct RunningTurn {
/// The current event stream for the running turn. Used to report a final /// The current event stream for the running turn. Used to report a final
/// cancellation event if we cancel the turn. /// cancellation event if we cancel the turn.
event_stream: ThreadEventStream, event_stream: ThreadEventStream,
/// The tools that were enabled for this turn.
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
} }
impl RunningTurn { impl RunningTurn {

View file

@ -1,6 +1,6 @@
use anyhow::Context as _; use anyhow::Context as _;
use collections::HashMap; use collections::HashMap;
use futures::{Stream, StreamExt as _, lock::Mutex}; use futures::{FutureExt, Stream, StreamExt as _, future::BoxFuture, lock::Mutex};
use gpui::BackgroundExecutor; use gpui::BackgroundExecutor;
use std::{pin::Pin, sync::Arc}; use std::{pin::Pin, sync::Arc};
@ -14,9 +14,12 @@ pub fn create_fake_transport(
executor: BackgroundExecutor, executor: BackgroundExecutor,
) -> FakeTransport { ) -> FakeTransport {
let name = name.into(); let name = name.into();
FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| { FakeTransport::new(executor).on_request::<crate::types::requests::Initialize, _>(
create_initialize_response(name.clone()) move |_params| {
}) let name = name.clone();
async move { create_initialize_response(name.clone()) }
},
)
} }
fn create_initialize_response(server_name: String) -> InitializeResponse { fn create_initialize_response(server_name: String) -> InitializeResponse {
@ -32,8 +35,10 @@ fn create_initialize_response(server_name: String) -> InitializeResponse {
} }
pub struct FakeTransport { pub struct FakeTransport {
request_handlers: request_handlers: HashMap<
HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>, &'static str,
Arc<dyn Send + Sync + Fn(serde_json::Value) -> BoxFuture<'static, serde_json::Value>>,
>,
tx: futures::channel::mpsc::UnboundedSender<String>, tx: futures::channel::mpsc::UnboundedSender<String>,
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>, rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
executor: BackgroundExecutor, executor: BackgroundExecutor,
@ -50,18 +55,25 @@ impl FakeTransport {
} }
} }
pub fn on_request<T: crate::types::Request>( pub fn on_request<T, Fut>(
mut self, mut self,
handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static, handler: impl 'static + Send + Sync + Fn(T::Params) -> Fut,
) -> Self { ) -> Self
where
T: crate::types::Request,
Fut: 'static + Send + Future<Output = T::Response>,
{
self.request_handlers.insert( self.request_handlers.insert(
T::METHOD, T::METHOD,
Arc::new(move |value| { Arc::new(move |value| {
let params = value.get("params").expect("Missing parameters").clone(); let params = value
.get("params")
.cloned()
.unwrap_or(serde_json::Value::Null);
let params: T::Params = let params: T::Params =
serde_json::from_value(params).expect("Invalid parameters received"); serde_json::from_value(params).expect("Invalid parameters received");
let response = handler(params); let response = handler(params);
serde_json::to_value(response).unwrap() async move { serde_json::to_value(response.await).unwrap() }.boxed()
}), }),
); );
self self
@ -77,7 +89,7 @@ impl Transport for FakeTransport {
if let Some(method) = msg.get("method") { if let Some(method) = msg.get("method") {
let method = method.as_str().expect("Invalid method received"); let method = method.as_str().expect("Invalid method received");
if let Some(handler) = self.request_handlers.get(method) { if let Some(handler) = self.request_handlers.get(method) {
let payload = handler(msg); let payload = handler(msg).await;
let response = serde_json::json!({ let response = serde_json::json!({
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": id, "id": id,