acp: Support calling tools provided by MCP servers (#36752)
Release Notes: - N/A
This commit is contained in:
parent
3b7c1744b4
commit
4f0fad6996
3 changed files with 561 additions and 64 deletions
|
@ -4,26 +4,35 @@ use agent_client_protocol::{self as acp};
|
|||
use agent_settings::AgentProfileId;
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||
use fs::{FakeFs, Fs};
|
||||
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
|
||||
use futures::{
|
||||
StreamExt,
|
||||
channel::{
|
||||
mpsc::{self, UnboundedReceiver},
|
||||
oneshot,
|
||||
},
|
||||
};
|
||||
use gpui::{
|
||||
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
|
||||
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
fake_provider::FakeLanguageModel,
|
||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
|
||||
};
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::Project;
|
||||
use project::{
|
||||
Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
|
||||
};
|
||||
use prompt_store::ProjectContext;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
|
@ -931,6 +940,334 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
|||
assert_eq!(tool_names, vec![InfiniteTool::name()]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_mcp_tools(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
context_server_store,
|
||||
fs,
|
||||
..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
// Override profiles and wait for settings to be loaded.
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"profiles": {
|
||||
"test": {
|
||||
"name": "Test Profile",
|
||||
"enable_all_context_servers": true,
|
||||
"tools": {
|
||||
EchoTool::name(): true,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_profile(AgentProfileId("test".into()))
|
||||
});
|
||||
|
||||
let mut mcp_tool_calls = setup_context_server(
|
||||
"test_server",
|
||||
vec![context_server::types::Tool {
|
||||
name: "echo".into(),
|
||||
description: None,
|
||||
input_schema: serde_json::to_value(
|
||||
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||
)
|
||||
.unwrap(),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
}],
|
||||
&context_server_store,
|
||||
cx,
|
||||
);
|
||||
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
// Simulate the model calling the MCP tool.
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_1".into(),
|
||||
name: "echo".into(),
|
||||
raw_input: json!({"text": "test"}).to_string(),
|
||||
input: json!({"text": "test"}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
||||
assert_eq!(tool_call_params.name, "echo");
|
||||
assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
|
||||
tool_call_response
|
||||
.send(context_server::types::CallToolResponse {
|
||||
content: vec![context_server::types::ToolResponseContent::Text {
|
||||
text: "test".into(),
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
structured_content: None,
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
|
||||
fake_model.send_last_completion_stream_text_chunk("Done!");
|
||||
fake_model.end_last_completion_stream();
|
||||
events.collect::<Vec<_>>().await;
|
||||
|
||||
// Send again after adding the echo tool, ensuring the name collision is resolved.
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
tool_names_for_completion(&completion),
|
||||
vec!["echo", "test_server_echo"]
|
||||
);
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_2".into(),
|
||||
name: "test_server_echo".into(),
|
||||
raw_input: json!({"text": "mcp"}).to_string(),
|
||||
input: json!({"text": "mcp"}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_3".into(),
|
||||
name: "echo".into(),
|
||||
raw_input: json!({"text": "native"}).to_string(),
|
||||
input: json!({"text": "native"}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
||||
assert_eq!(tool_call_params.name, "echo");
|
||||
assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
|
||||
tool_call_response
|
||||
.send(context_server::types::CallToolResponse {
|
||||
content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
structured_content: None,
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
// Ensure the tool results were inserted with the correct names.
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
completion.messages.last().unwrap().content,
|
||||
vec![
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: "tool_3".into(),
|
||||
tool_name: "echo".into(),
|
||||
is_error: false,
|
||||
content: "native".into(),
|
||||
output: Some("native".into()),
|
||||
},),
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: "tool_2".into(),
|
||||
tool_name: "test_server_echo".into(),
|
||||
is_error: false,
|
||||
content: "mcp".into(),
|
||||
output: Some("mcp".into()),
|
||||
},),
|
||||
]
|
||||
);
|
||||
fake_model.end_last_completion_stream();
|
||||
events.collect::<Vec<_>>().await;
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
context_server_store,
|
||||
fs,
|
||||
..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
// Set up a profile with all tools enabled
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"profiles": {
|
||||
"test": {
|
||||
"name": "Test Profile",
|
||||
"enable_all_context_servers": true,
|
||||
"tools": {
|
||||
EchoTool::name(): true,
|
||||
DelayTool::name(): true,
|
||||
WordListTool::name(): true,
|
||||
ToolRequiringPermission::name(): true,
|
||||
InfiniteTool::name(): true,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_profile(AgentProfileId("test".into()));
|
||||
thread.add_tool(EchoTool);
|
||||
thread.add_tool(DelayTool);
|
||||
thread.add_tool(WordListTool);
|
||||
thread.add_tool(ToolRequiringPermission);
|
||||
thread.add_tool(InfiniteTool);
|
||||
});
|
||||
|
||||
// Set up multiple context servers with some overlapping tool names
|
||||
let _server1_calls = setup_context_server(
|
||||
"xxx",
|
||||
vec![
|
||||
context_server::types::Tool {
|
||||
name: "echo".into(), // Conflicts with native EchoTool
|
||||
description: None,
|
||||
input_schema: serde_json::to_value(
|
||||
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||
)
|
||||
.unwrap(),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "unique_tool_1".into(),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
],
|
||||
&context_server_store,
|
||||
cx,
|
||||
);
|
||||
|
||||
let _server2_calls = setup_context_server(
|
||||
"yyy",
|
||||
vec![
|
||||
context_server::types::Tool {
|
||||
name: "echo".into(), // Also conflicts with native EchoTool
|
||||
description: None,
|
||||
input_schema: serde_json::to_value(
|
||||
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||
)
|
||||
.unwrap(),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "unique_tool_2".into(),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
],
|
||||
&context_server_store,
|
||||
cx,
|
||||
);
|
||||
let _server3_calls = setup_context_server(
|
||||
"zzz",
|
||||
vec![
|
||||
context_server::types::Tool {
|
||||
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
],
|
||||
&context_server_store,
|
||||
cx,
|
||||
);
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Go"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
tool_names_for_completion(&completion),
|
||||
vec![
|
||||
"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
|
||||
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
|
||||
"delay",
|
||||
"echo",
|
||||
"infinite",
|
||||
"tool_requiring_permission",
|
||||
"unique_tool_1",
|
||||
"unique_tool_2",
|
||||
"word_list",
|
||||
"xxx_echo",
|
||||
"y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
"yyy_echo",
|
||||
"z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
|
@ -1806,6 +2143,7 @@ struct ThreadTest {
|
|||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
fs: Arc<FakeFs>,
|
||||
}
|
||||
|
||||
|
@ -1844,6 +2182,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
WordListTool::name(): true,
|
||||
ToolRequiringPermission::name(): true,
|
||||
InfiniteTool::name(): true,
|
||||
ThinkingTool::name(): true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1900,8 +2239,9 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
.await;
|
||||
|
||||
let project_context = cx.new(|_cx| ProjectContext::default());
|
||||
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
|
@ -1916,6 +2256,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
model,
|
||||
thread,
|
||||
project_context,
|
||||
context_server_store,
|
||||
fs,
|
||||
}
|
||||
}
|
||||
|
@ -1950,3 +2291,89 @@ fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
|
|||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
|
||||
completion
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn setup_context_server(
|
||||
name: &'static str,
|
||||
tools: Vec<context_server::types::Tool>,
|
||||
context_server_store: &Entity<ContextServerStore>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> mpsc::UnboundedReceiver<(
|
||||
context_server::types::CallToolParams,
|
||||
oneshot::Sender<context_server::types::CallToolResponse>,
|
||||
)> {
|
||||
cx.update(|cx| {
|
||||
let mut settings = ProjectSettings::get_global(cx).clone();
|
||||
settings.context_servers.insert(
|
||||
name.into(),
|
||||
project::project_settings::ContextServerSettings::Custom {
|
||||
enabled: true,
|
||||
command: ContextServerCommand {
|
||||
path: "somebinary".into(),
|
||||
args: Vec::new(),
|
||||
env: None,
|
||||
},
|
||||
},
|
||||
);
|
||||
ProjectSettings::override_global(settings, cx);
|
||||
});
|
||||
|
||||
let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
|
||||
let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
|
||||
.on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
|
||||
context_server::types::InitializeResponse {
|
||||
protocol_version: context_server::types::ProtocolVersion(
|
||||
context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
|
||||
),
|
||||
server_info: context_server::types::Implementation {
|
||||
name: name.into(),
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
capabilities: context_server::types::ServerCapabilities {
|
||||
tools: Some(context_server::types::ToolsCapabilities {
|
||||
list_changed: Some(true),
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
meta: None,
|
||||
}
|
||||
})
|
||||
.on_request::<context_server::types::requests::ListTools, _>(move |_params| {
|
||||
let tools = tools.clone();
|
||||
async move {
|
||||
context_server::types::ListToolsResponse {
|
||||
tools,
|
||||
next_cursor: None,
|
||||
meta: None,
|
||||
}
|
||||
}
|
||||
})
|
||||
.on_request::<context_server::types::requests::CallTool, _>(move |params| {
|
||||
let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
|
||||
async move {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
mcp_tool_calls_tx
|
||||
.unbounded_send((params, response_tx))
|
||||
.unwrap();
|
||||
response_rx.await.unwrap()
|
||||
}
|
||||
});
|
||||
context_server_store.update(cx, |store, cx| {
|
||||
store.start_server(
|
||||
Arc::new(ContextServer::new(
|
||||
ContextServerId(name.into()),
|
||||
Arc::new(fake_transport),
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
mcp_tool_calls_rx
|
||||
}
|
||||
|
|
|
@ -9,15 +9,15 @@ use action_log::ActionLog;
|
|||
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{
|
||||
AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
|
||||
SUMMARIZE_THREAD_PROMPT,
|
||||
AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
|
||||
SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
use collections::{HashMap, IndexMap};
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
FutureExt,
|
||||
|
@ -56,6 +56,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
|
|||
use uuid::Uuid;
|
||||
|
||||
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
|
||||
pub const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
/// The ID of the user prompt that initiated a request.
|
||||
///
|
||||
|
@ -627,7 +628,20 @@ impl Thread {
|
|||
stream: &ThreadEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
|
||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
|
||||
self.context_server_registry
|
||||
.read(cx)
|
||||
.servers()
|
||||
.find_map(|(_, tools)| {
|
||||
if let Some(tool) = tools.get(tool_use.name.as_ref()) {
|
||||
Some(tool.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let Some(tool) = tool else {
|
||||
stream
|
||||
.0
|
||||
.unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
|
||||
|
@ -1079,6 +1093,10 @@ impl Thread {
|
|||
self.cancel(cx);
|
||||
|
||||
let model = self.model.clone().context("No language model configured")?;
|
||||
let profile = AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("Profile not found")?;
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
|
||||
let event_stream = ThreadEventStream(events_tx);
|
||||
let message_ix = self.messages.len().saturating_sub(1);
|
||||
|
@ -1086,6 +1104,7 @@ impl Thread {
|
|||
self.summary = None;
|
||||
self.running_turn = Some(RunningTurn {
|
||||
event_stream: event_stream.clone(),
|
||||
tools: self.enabled_tools(profile, &model, cx),
|
||||
_task: cx.spawn(async move |this, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
|
||||
|
@ -1417,7 +1436,7 @@ impl Thread {
|
|||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
||||
let tool = self.tool(tool_use.name.as_ref());
|
||||
let mut title = SharedString::from(&tool_use.name);
|
||||
let mut kind = acp::ToolKind::Other;
|
||||
if let Some(tool) = tool.as_ref() {
|
||||
|
@ -1727,6 +1746,21 @@ impl Thread {
|
|||
cx: &mut App,
|
||||
) -> Result<LanguageModelRequest> {
|
||||
let model = self.model().context("No language model configured")?;
|
||||
let tools = if let Some(turn) = self.running_turn.as_ref() {
|
||||
turn.tools
|
||||
.iter()
|
||||
.filter_map(|(tool_name, tool)| {
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name.to_string(),
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
log::debug!("Building completion request");
|
||||
log::debug!("Completion intent: {:?}", completion_intent);
|
||||
|
@ -1734,23 +1768,6 @@ impl Thread {
|
|||
|
||||
let messages = self.build_request_messages(cx);
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools = if let Some(tools) = self.tools(cx).log_err() {
|
||||
tools
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
log::info!("Request includes {} tools", tools.len());
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
|
@ -1770,37 +1787,76 @@ impl Thread {
|
|||
Ok(request)
|
||||
}
|
||||
|
||||
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
|
||||
let model = self.model().context("No language model configured")?;
|
||||
fn enabled_tools(
|
||||
&self,
|
||||
profile: &AgentProfileSettings,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
cx: &App,
|
||||
) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
|
||||
fn truncate(tool_name: &SharedString) -> SharedString {
|
||||
if tool_name.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let mut truncated = tool_name.to_string();
|
||||
truncated.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
truncated.into()
|
||||
} else {
|
||||
tool_name.clone()
|
||||
}
|
||||
}
|
||||
|
||||
let profile = AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("profile not found")?;
|
||||
let provider_id = model.provider_id();
|
||||
|
||||
Ok(self
|
||||
let mut tools = self
|
||||
.tools
|
||||
.iter()
|
||||
.filter(move |(_, tool)| tool.supported_provider(&provider_id))
|
||||
.filter_map(|(tool_name, tool)| {
|
||||
if profile.is_tool_enabled(tool_name) {
|
||||
Some(tool)
|
||||
if tool.supported_provider(&model.provider_id())
|
||||
&& profile.is_tool_enabled(tool_name)
|
||||
{
|
||||
Some((truncate(tool_name), tool.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(self.context_server_registry.read(cx).servers().flat_map(
|
||||
|(server_id, tools)| {
|
||||
tools.iter().filter_map(|(tool_name, tool)| {
|
||||
if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
|
||||
Some(tool)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
},
|
||||
)))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
|
||||
let mut context_server_tools = Vec::new();
|
||||
let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
|
||||
let mut duplicate_tool_names = HashSet::default();
|
||||
for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
|
||||
for (tool_name, tool) in server_tools {
|
||||
if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
|
||||
let tool_name = truncate(tool_name);
|
||||
if !seen_tools.insert(tool_name.clone()) {
|
||||
duplicate_tool_names.insert(tool_name.clone());
|
||||
}
|
||||
context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When there are duplicate tool names, disambiguate by prefixing them
|
||||
// with the server ID. In the rare case there isn't enough space for the
|
||||
// disambiguated tool name, keep only the last tool with this name.
|
||||
for (server_id, tool_name, tool) in context_server_tools {
|
||||
if duplicate_tool_names.contains(&tool_name) {
|
||||
let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
|
||||
if available >= 2 {
|
||||
let mut disambiguated = server_id.0.to_string();
|
||||
disambiguated.truncate(available - 1);
|
||||
disambiguated.push('_');
|
||||
disambiguated.push_str(&tool_name);
|
||||
tools.insert(disambiguated.into(), tool.clone());
|
||||
} else {
|
||||
tools.insert(tool_name, tool.clone());
|
||||
}
|
||||
} else {
|
||||
tools.insert(tool_name, tool.clone());
|
||||
}
|
||||
}
|
||||
|
||||
tools
|
||||
}
|
||||
|
||||
fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
|
||||
self.running_turn.as_ref()?.tools.get(name).cloned()
|
||||
}
|
||||
|
||||
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
|
||||
|
@ -1965,6 +2021,8 @@ struct RunningTurn {
|
|||
/// The current event stream for the running turn. Used to report a final
|
||||
/// cancellation event if we cancel the turn.
|
||||
event_stream: ThreadEventStream,
|
||||
/// The tools that were enabled for this turn.
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
}
|
||||
|
||||
impl RunningTurn {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use anyhow::Context as _;
|
||||
use collections::HashMap;
|
||||
use futures::{Stream, StreamExt as _, lock::Mutex};
|
||||
use futures::{FutureExt, Stream, StreamExt as _, future::BoxFuture, lock::Mutex};
|
||||
use gpui::BackgroundExecutor;
|
||||
use std::{pin::Pin, sync::Arc};
|
||||
|
||||
|
@ -14,9 +14,12 @@ pub fn create_fake_transport(
|
|||
executor: BackgroundExecutor,
|
||||
) -> FakeTransport {
|
||||
let name = name.into();
|
||||
FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
|
||||
create_initialize_response(name.clone())
|
||||
})
|
||||
FakeTransport::new(executor).on_request::<crate::types::requests::Initialize, _>(
|
||||
move |_params| {
|
||||
let name = name.clone();
|
||||
async move { create_initialize_response(name.clone()) }
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn create_initialize_response(server_name: String) -> InitializeResponse {
|
||||
|
@ -32,8 +35,10 @@ fn create_initialize_response(server_name: String) -> InitializeResponse {
|
|||
}
|
||||
|
||||
pub struct FakeTransport {
|
||||
request_handlers:
|
||||
HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>,
|
||||
request_handlers: HashMap<
|
||||
&'static str,
|
||||
Arc<dyn Send + Sync + Fn(serde_json::Value) -> BoxFuture<'static, serde_json::Value>>,
|
||||
>,
|
||||
tx: futures::channel::mpsc::UnboundedSender<String>,
|
||||
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
||||
executor: BackgroundExecutor,
|
||||
|
@ -50,18 +55,25 @@ impl FakeTransport {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn on_request<T: crate::types::Request>(
|
||||
pub fn on_request<T, Fut>(
|
||||
mut self,
|
||||
handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
handler: impl 'static + Send + Sync + Fn(T::Params) -> Fut,
|
||||
) -> Self
|
||||
where
|
||||
T: crate::types::Request,
|
||||
Fut: 'static + Send + Future<Output = T::Response>,
|
||||
{
|
||||
self.request_handlers.insert(
|
||||
T::METHOD,
|
||||
Arc::new(move |value| {
|
||||
let params = value.get("params").expect("Missing parameters").clone();
|
||||
let params = value
|
||||
.get("params")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
let params: T::Params =
|
||||
serde_json::from_value(params).expect("Invalid parameters received");
|
||||
let response = handler(params);
|
||||
serde_json::to_value(response).unwrap()
|
||||
async move { serde_json::to_value(response.await).unwrap() }.boxed()
|
||||
}),
|
||||
);
|
||||
self
|
||||
|
@ -77,7 +89,7 @@ impl Transport for FakeTransport {
|
|||
if let Some(method) = msg.get("method") {
|
||||
let method = method.as_str().expect("Invalid method received");
|
||||
if let Some(handler) = self.request_handlers.get(method) {
|
||||
let payload = handler(msg);
|
||||
let payload = handler(msg).await;
|
||||
let response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue