acp: Support calling tools provided by MCP servers (#36752)
Release Notes: - N/A
This commit is contained in:
parent
3b7c1744b4
commit
4f0fad6996
3 changed files with 561 additions and 64 deletions
|
@ -4,26 +4,35 @@ use agent_client_protocol::{self as acp};
|
||||||
use agent_settings::AgentProfileId;
|
use agent_settings::AgentProfileId;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use client::{Client, UserStore};
|
use client::{Client, UserStore};
|
||||||
|
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||||
use fs::{FakeFs, Fs};
|
use fs::{FakeFs, Fs};
|
||||||
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
|
use futures::{
|
||||||
|
StreamExt,
|
||||||
|
channel::{
|
||||||
|
mpsc::{self, UnboundedReceiver},
|
||||||
|
oneshot,
|
||||||
|
},
|
||||||
|
};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
||||||
};
|
};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
|
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
|
||||||
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
|
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
|
||||||
fake_provider::FakeLanguageModel,
|
LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
|
||||||
};
|
};
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
use project::Project;
|
use project::{
|
||||||
|
Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
|
||||||
|
};
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
use reqwest_client::ReqwestClient;
|
use reqwest_client::ReqwestClient;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::{Settings, SettingsStore};
|
||||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||||
use util::path;
|
use util::path;
|
||||||
|
|
||||||
|
@ -931,6 +940,334 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
assert_eq!(tool_names, vec![InfiniteTool::name()]);
|
assert_eq!(tool_names, vec![InfiniteTool::name()]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_mcp_tools(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest {
|
||||||
|
model,
|
||||||
|
thread,
|
||||||
|
context_server_store,
|
||||||
|
fs,
|
||||||
|
..
|
||||||
|
} = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
// Override profiles and wait for settings to be loaded.
|
||||||
|
fs.insert_file(
|
||||||
|
paths::settings_file(),
|
||||||
|
json!({
|
||||||
|
"agent": {
|
||||||
|
"profiles": {
|
||||||
|
"test": {
|
||||||
|
"name": "Test Profile",
|
||||||
|
"enable_all_context_servers": true,
|
||||||
|
"tools": {
|
||||||
|
EchoTool::name(): true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string()
|
||||||
|
.into_bytes(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
cx.run_until_parked();
|
||||||
|
thread.update(cx, |thread, _| {
|
||||||
|
thread.set_profile(AgentProfileId("test".into()))
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut mcp_tool_calls = setup_context_server(
|
||||||
|
"test_server",
|
||||||
|
vec![context_server::types::Tool {
|
||||||
|
name: "echo".into(),
|
||||||
|
description: None,
|
||||||
|
input_schema: serde_json::to_value(
|
||||||
|
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
}],
|
||||||
|
&context_server_store,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
let events = thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
// Simulate the model calling the MCP tool.
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: "tool_1".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
raw_input: json!({"text": "test"}).to_string(),
|
||||||
|
input: json!({"text": "test"}),
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
||||||
|
assert_eq!(tool_call_params.name, "echo");
|
||||||
|
assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
|
||||||
|
tool_call_response
|
||||||
|
.send(context_server::types::CallToolResponse {
|
||||||
|
content: vec![context_server::types::ToolResponseContent::Text {
|
||||||
|
text: "test".into(),
|
||||||
|
}],
|
||||||
|
is_error: None,
|
||||||
|
meta: None,
|
||||||
|
structured_content: None,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Done!");
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
events.collect::<Vec<_>>().await;
|
||||||
|
|
||||||
|
// Send again after adding the echo tool, ensuring the name collision is resolved.
|
||||||
|
let events = thread.update(cx, |thread, cx| {
|
||||||
|
thread.add_tool(EchoTool);
|
||||||
|
thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
tool_names_for_completion(&completion),
|
||||||
|
vec!["echo", "test_server_echo"]
|
||||||
|
);
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: "tool_2".into(),
|
||||||
|
name: "test_server_echo".into(),
|
||||||
|
raw_input: json!({"text": "mcp"}).to_string(),
|
||||||
|
input: json!({"text": "mcp"}),
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: "tool_3".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
raw_input: json!({"text": "native"}).to_string(),
|
||||||
|
input: json!({"text": "native"}),
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
||||||
|
assert_eq!(tool_call_params.name, "echo");
|
||||||
|
assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
|
||||||
|
tool_call_response
|
||||||
|
.send(context_server::types::CallToolResponse {
|
||||||
|
content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
|
||||||
|
is_error: None,
|
||||||
|
meta: None,
|
||||||
|
structured_content: None,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
// Ensure the tool results were inserted with the correct names.
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
completion.messages.last().unwrap().content,
|
||||||
|
vec![
|
||||||
|
MessageContent::ToolResult(LanguageModelToolResult {
|
||||||
|
tool_use_id: "tool_3".into(),
|
||||||
|
tool_name: "echo".into(),
|
||||||
|
is_error: false,
|
||||||
|
content: "native".into(),
|
||||||
|
output: Some("native".into()),
|
||||||
|
},),
|
||||||
|
MessageContent::ToolResult(LanguageModelToolResult {
|
||||||
|
tool_use_id: "tool_2".into(),
|
||||||
|
tool_name: "test_server_echo".into(),
|
||||||
|
is_error: false,
|
||||||
|
content: "mcp".into(),
|
||||||
|
output: Some("mcp".into()),
|
||||||
|
},),
|
||||||
|
]
|
||||||
|
);
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
events.collect::<Vec<_>>().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest {
|
||||||
|
model,
|
||||||
|
thread,
|
||||||
|
context_server_store,
|
||||||
|
fs,
|
||||||
|
..
|
||||||
|
} = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
// Set up a profile with all tools enabled
|
||||||
|
fs.insert_file(
|
||||||
|
paths::settings_file(),
|
||||||
|
json!({
|
||||||
|
"agent": {
|
||||||
|
"profiles": {
|
||||||
|
"test": {
|
||||||
|
"name": "Test Profile",
|
||||||
|
"enable_all_context_servers": true,
|
||||||
|
"tools": {
|
||||||
|
EchoTool::name(): true,
|
||||||
|
DelayTool::name(): true,
|
||||||
|
WordListTool::name(): true,
|
||||||
|
ToolRequiringPermission::name(): true,
|
||||||
|
InfiniteTool::name(): true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string()
|
||||||
|
.into_bytes(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
thread.update(cx, |thread, _| {
|
||||||
|
thread.set_profile(AgentProfileId("test".into()));
|
||||||
|
thread.add_tool(EchoTool);
|
||||||
|
thread.add_tool(DelayTool);
|
||||||
|
thread.add_tool(WordListTool);
|
||||||
|
thread.add_tool(ToolRequiringPermission);
|
||||||
|
thread.add_tool(InfiniteTool);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Set up multiple context servers with some overlapping tool names
|
||||||
|
let _server1_calls = setup_context_server(
|
||||||
|
"xxx",
|
||||||
|
vec![
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "echo".into(), // Conflicts with native EchoTool
|
||||||
|
description: None,
|
||||||
|
input_schema: serde_json::to_value(
|
||||||
|
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "unique_tool_1".into(),
|
||||||
|
description: None,
|
||||||
|
input_schema: json!({"type": "object", "properties": {}}),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
&context_server_store,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
let _server2_calls = setup_context_server(
|
||||||
|
"yyy",
|
||||||
|
vec![
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "echo".into(), // Also conflicts with native EchoTool
|
||||||
|
description: None,
|
||||||
|
input_schema: serde_json::to_value(
|
||||||
|
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "unique_tool_2".into(),
|
||||||
|
description: None,
|
||||||
|
input_schema: json!({"type": "object", "properties": {}}),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
|
||||||
|
description: None,
|
||||||
|
input_schema: json!({"type": "object", "properties": {}}),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
|
||||||
|
description: None,
|
||||||
|
input_schema: json!({"type": "object", "properties": {}}),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
&context_server_store,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
let _server3_calls = setup_context_server(
|
||||||
|
"zzz",
|
||||||
|
vec![
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
|
||||||
|
description: None,
|
||||||
|
input_schema: json!({"type": "object", "properties": {}}),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
|
||||||
|
description: None,
|
||||||
|
input_schema: json!({"type": "object", "properties": {}}),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
context_server::types::Tool {
|
||||||
|
name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
|
||||||
|
description: None,
|
||||||
|
input_schema: json!({"type": "object", "properties": {}}),
|
||||||
|
output_schema: None,
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
&context_server_store,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Go"], cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
tool_names_for_completion(&completion),
|
||||||
|
vec![
|
||||||
|
"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
|
||||||
|
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
|
||||||
|
"delay",
|
||||||
|
"echo",
|
||||||
|
"infinite",
|
||||||
|
"tool_requiring_permission",
|
||||||
|
"unique_tool_1",
|
||||||
|
"unique_tool_2",
|
||||||
|
"word_list",
|
||||||
|
"xxx_echo",
|
||||||
|
"y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||||
|
"yyy_echo",
|
||||||
|
"z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
|
@ -1806,6 +2143,7 @@ struct ThreadTest {
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
thread: Entity<Thread>,
|
thread: Entity<Thread>,
|
||||||
project_context: Entity<ProjectContext>,
|
project_context: Entity<ProjectContext>,
|
||||||
|
context_server_store: Entity<ContextServerStore>,
|
||||||
fs: Arc<FakeFs>,
|
fs: Arc<FakeFs>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1844,6 +2182,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
WordListTool::name(): true,
|
WordListTool::name(): true,
|
||||||
ToolRequiringPermission::name(): true,
|
ToolRequiringPermission::name(): true,
|
||||||
InfiniteTool::name(): true,
|
InfiniteTool::name(): true,
|
||||||
|
ThinkingTool::name(): true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1900,8 +2239,9 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let project_context = cx.new(|_cx| ProjectContext::default());
|
let project_context = cx.new(|_cx| ProjectContext::default());
|
||||||
|
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
||||||
let context_server_registry =
|
let context_server_registry =
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project,
|
project,
|
||||||
|
@ -1916,6 +2256,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
model,
|
model,
|
||||||
thread,
|
thread,
|
||||||
project_context,
|
project_context,
|
||||||
|
context_server_store,
|
||||||
fs,
|
fs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1950,3 +2291,89 @@ fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
|
||||||
|
completion
|
||||||
|
.tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| tool.name.clone())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn setup_context_server(
|
||||||
|
name: &'static str,
|
||||||
|
tools: Vec<context_server::types::Tool>,
|
||||||
|
context_server_store: &Entity<ContextServerStore>,
|
||||||
|
cx: &mut TestAppContext,
|
||||||
|
) -> mpsc::UnboundedReceiver<(
|
||||||
|
context_server::types::CallToolParams,
|
||||||
|
oneshot::Sender<context_server::types::CallToolResponse>,
|
||||||
|
)> {
|
||||||
|
cx.update(|cx| {
|
||||||
|
let mut settings = ProjectSettings::get_global(cx).clone();
|
||||||
|
settings.context_servers.insert(
|
||||||
|
name.into(),
|
||||||
|
project::project_settings::ContextServerSettings::Custom {
|
||||||
|
enabled: true,
|
||||||
|
command: ContextServerCommand {
|
||||||
|
path: "somebinary".into(),
|
||||||
|
args: Vec::new(),
|
||||||
|
env: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
);
|
||||||
|
ProjectSettings::override_global(settings, cx);
|
||||||
|
});
|
||||||
|
|
||||||
|
let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
|
||||||
|
let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
|
||||||
|
.on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
|
||||||
|
context_server::types::InitializeResponse {
|
||||||
|
protocol_version: context_server::types::ProtocolVersion(
|
||||||
|
context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
|
||||||
|
),
|
||||||
|
server_info: context_server::types::Implementation {
|
||||||
|
name: name.into(),
|
||||||
|
version: "1.0.0".to_string(),
|
||||||
|
},
|
||||||
|
capabilities: context_server::types::ServerCapabilities {
|
||||||
|
tools: Some(context_server::types::ToolsCapabilities {
|
||||||
|
list_changed: Some(true),
|
||||||
|
}),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
meta: None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.on_request::<context_server::types::requests::ListTools, _>(move |_params| {
|
||||||
|
let tools = tools.clone();
|
||||||
|
async move {
|
||||||
|
context_server::types::ListToolsResponse {
|
||||||
|
tools,
|
||||||
|
next_cursor: None,
|
||||||
|
meta: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.on_request::<context_server::types::requests::CallTool, _>(move |params| {
|
||||||
|
let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
|
||||||
|
async move {
|
||||||
|
let (response_tx, response_rx) = oneshot::channel();
|
||||||
|
mcp_tool_calls_tx
|
||||||
|
.unbounded_send((params, response_tx))
|
||||||
|
.unwrap();
|
||||||
|
response_rx.await.unwrap()
|
||||||
|
}
|
||||||
|
});
|
||||||
|
context_server_store.update(cx, |store, cx| {
|
||||||
|
store.start_server(
|
||||||
|
Arc::new(ContextServer::new(
|
||||||
|
ContextServerId(name.into()),
|
||||||
|
Arc::new(fake_transport),
|
||||||
|
)),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
mcp_tool_calls_rx
|
||||||
|
}
|
||||||
|
|
|
@ -9,15 +9,15 @@ use action_log::ActionLog;
|
||||||
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
|
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::{
|
use agent_settings::{
|
||||||
AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
|
AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
|
||||||
SUMMARIZE_THREAD_PROMPT,
|
SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
|
||||||
};
|
};
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tool::adapt_schema_to_format;
|
use assistant_tool::adapt_schema_to_format;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use client::{ModelRequestUsage, RequestUsage};
|
use client::{ModelRequestUsage, RequestUsage};
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||||
use collections::{HashMap, IndexMap};
|
use collections::{HashMap, HashSet, IndexMap};
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use futures::{
|
use futures::{
|
||||||
FutureExt,
|
FutureExt,
|
||||||
|
@ -56,6 +56,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
|
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
|
||||||
|
pub const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||||
|
|
||||||
/// The ID of the user prompt that initiated a request.
|
/// The ID of the user prompt that initiated a request.
|
||||||
///
|
///
|
||||||
|
@ -627,7 +628,20 @@ impl Thread {
|
||||||
stream: &ThreadEventStream,
|
stream: &ThreadEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
|
let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
|
||||||
|
self.context_server_registry
|
||||||
|
.read(cx)
|
||||||
|
.servers()
|
||||||
|
.find_map(|(_, tools)| {
|
||||||
|
if let Some(tool) = tools.get(tool_use.name.as_ref()) {
|
||||||
|
Some(tool.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
let Some(tool) = tool else {
|
||||||
stream
|
stream
|
||||||
.0
|
.0
|
||||||
.unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
|
.unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
|
||||||
|
@ -1079,6 +1093,10 @@ impl Thread {
|
||||||
self.cancel(cx);
|
self.cancel(cx);
|
||||||
|
|
||||||
let model = self.model.clone().context("No language model configured")?;
|
let model = self.model.clone().context("No language model configured")?;
|
||||||
|
let profile = AgentSettings::get_global(cx)
|
||||||
|
.profiles
|
||||||
|
.get(&self.profile_id)
|
||||||
|
.context("Profile not found")?;
|
||||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
|
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
|
||||||
let event_stream = ThreadEventStream(events_tx);
|
let event_stream = ThreadEventStream(events_tx);
|
||||||
let message_ix = self.messages.len().saturating_sub(1);
|
let message_ix = self.messages.len().saturating_sub(1);
|
||||||
|
@ -1086,6 +1104,7 @@ impl Thread {
|
||||||
self.summary = None;
|
self.summary = None;
|
||||||
self.running_turn = Some(RunningTurn {
|
self.running_turn = Some(RunningTurn {
|
||||||
event_stream: event_stream.clone(),
|
event_stream: event_stream.clone(),
|
||||||
|
tools: self.enabled_tools(profile, &model, cx),
|
||||||
_task: cx.spawn(async move |this, cx| {
|
_task: cx.spawn(async move |this, cx| {
|
||||||
log::info!("Starting agent turn execution");
|
log::info!("Starting agent turn execution");
|
||||||
|
|
||||||
|
@ -1417,7 +1436,7 @@ impl Thread {
|
||||||
) -> Option<Task<LanguageModelToolResult>> {
|
) -> Option<Task<LanguageModelToolResult>> {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
|
||||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
let tool = self.tool(tool_use.name.as_ref());
|
||||||
let mut title = SharedString::from(&tool_use.name);
|
let mut title = SharedString::from(&tool_use.name);
|
||||||
let mut kind = acp::ToolKind::Other;
|
let mut kind = acp::ToolKind::Other;
|
||||||
if let Some(tool) = tool.as_ref() {
|
if let Some(tool) = tool.as_ref() {
|
||||||
|
@ -1727,6 +1746,21 @@ impl Thread {
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Result<LanguageModelRequest> {
|
) -> Result<LanguageModelRequest> {
|
||||||
let model = self.model().context("No language model configured")?;
|
let model = self.model().context("No language model configured")?;
|
||||||
|
let tools = if let Some(turn) = self.running_turn.as_ref() {
|
||||||
|
turn.tools
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(tool_name, tool)| {
|
||||||
|
log::trace!("Including tool: {}", tool_name);
|
||||||
|
Some(LanguageModelRequestTool {
|
||||||
|
name: tool_name.to_string(),
|
||||||
|
description: tool.description().to_string(),
|
||||||
|
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
};
|
||||||
|
|
||||||
log::debug!("Building completion request");
|
log::debug!("Building completion request");
|
||||||
log::debug!("Completion intent: {:?}", completion_intent);
|
log::debug!("Completion intent: {:?}", completion_intent);
|
||||||
|
@ -1734,23 +1768,6 @@ impl Thread {
|
||||||
|
|
||||||
let messages = self.build_request_messages(cx);
|
let messages = self.build_request_messages(cx);
|
||||||
log::info!("Request will include {} messages", messages.len());
|
log::info!("Request will include {} messages", messages.len());
|
||||||
|
|
||||||
let tools = if let Some(tools) = self.tools(cx).log_err() {
|
|
||||||
tools
|
|
||||||
.filter_map(|tool| {
|
|
||||||
let tool_name = tool.name().to_string();
|
|
||||||
log::trace!("Including tool: {}", tool_name);
|
|
||||||
Some(LanguageModelRequestTool {
|
|
||||||
name: tool_name,
|
|
||||||
description: tool.description().to_string(),
|
|
||||||
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
} else {
|
|
||||||
Vec::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
log::info!("Request includes {} tools", tools.len());
|
log::info!("Request includes {} tools", tools.len());
|
||||||
|
|
||||||
let request = LanguageModelRequest {
|
let request = LanguageModelRequest {
|
||||||
|
@ -1770,37 +1787,76 @@ impl Thread {
|
||||||
Ok(request)
|
Ok(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
|
fn enabled_tools(
|
||||||
let model = self.model().context("No language model configured")?;
|
&self,
|
||||||
|
profile: &AgentProfileSettings,
|
||||||
|
model: &Arc<dyn LanguageModel>,
|
||||||
|
cx: &App,
|
||||||
|
) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
|
||||||
|
fn truncate(tool_name: &SharedString) -> SharedString {
|
||||||
|
if tool_name.len() > MAX_TOOL_NAME_LENGTH {
|
||||||
|
let mut truncated = tool_name.to_string();
|
||||||
|
truncated.truncate(MAX_TOOL_NAME_LENGTH);
|
||||||
|
truncated.into()
|
||||||
|
} else {
|
||||||
|
tool_name.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let profile = AgentSettings::get_global(cx)
|
let mut tools = self
|
||||||
.profiles
|
|
||||||
.get(&self.profile_id)
|
|
||||||
.context("profile not found")?;
|
|
||||||
let provider_id = model.provider_id();
|
|
||||||
|
|
||||||
Ok(self
|
|
||||||
.tools
|
.tools
|
||||||
.iter()
|
.iter()
|
||||||
.filter(move |(_, tool)| tool.supported_provider(&provider_id))
|
|
||||||
.filter_map(|(tool_name, tool)| {
|
.filter_map(|(tool_name, tool)| {
|
||||||
if profile.is_tool_enabled(tool_name) {
|
if tool.supported_provider(&model.provider_id())
|
||||||
Some(tool)
|
&& profile.is_tool_enabled(tool_name)
|
||||||
|
{
|
||||||
|
Some((truncate(tool_name), tool.clone()))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.chain(self.context_server_registry.read(cx).servers().flat_map(
|
.collect::<BTreeMap<_, _>>();
|
||||||
|(server_id, tools)| {
|
|
||||||
tools.iter().filter_map(|(tool_name, tool)| {
|
let mut context_server_tools = Vec::new();
|
||||||
if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
|
let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
|
||||||
Some(tool)
|
let mut duplicate_tool_names = HashSet::default();
|
||||||
} else {
|
for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
|
||||||
None
|
for (tool_name, tool) in server_tools {
|
||||||
}
|
if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
|
||||||
})
|
let tool_name = truncate(tool_name);
|
||||||
},
|
if !seen_tools.insert(tool_name.clone()) {
|
||||||
)))
|
duplicate_tool_names.insert(tool_name.clone());
|
||||||
|
}
|
||||||
|
context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// When there are duplicate tool names, disambiguate by prefixing them
|
||||||
|
// with the server ID. In the rare case there isn't enough space for the
|
||||||
|
// disambiguated tool name, keep only the last tool with this name.
|
||||||
|
for (server_id, tool_name, tool) in context_server_tools {
|
||||||
|
if duplicate_tool_names.contains(&tool_name) {
|
||||||
|
let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
|
||||||
|
if available >= 2 {
|
||||||
|
let mut disambiguated = server_id.0.to_string();
|
||||||
|
disambiguated.truncate(available - 1);
|
||||||
|
disambiguated.push('_');
|
||||||
|
disambiguated.push_str(&tool_name);
|
||||||
|
tools.insert(disambiguated.into(), tool.clone());
|
||||||
|
} else {
|
||||||
|
tools.insert(tool_name, tool.clone());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tools.insert(tool_name, tool.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tools
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
|
||||||
|
self.running_turn.as_ref()?.tools.get(name).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
|
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
|
||||||
|
@ -1965,6 +2021,8 @@ struct RunningTurn {
|
||||||
/// The current event stream for the running turn. Used to report a final
|
/// The current event stream for the running turn. Used to report a final
|
||||||
/// cancellation event if we cancel the turn.
|
/// cancellation event if we cancel the turn.
|
||||||
event_stream: ThreadEventStream,
|
event_stream: ThreadEventStream,
|
||||||
|
/// The tools that were enabled for this turn.
|
||||||
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RunningTurn {
|
impl RunningTurn {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use anyhow::Context as _;
|
use anyhow::Context as _;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use futures::{Stream, StreamExt as _, lock::Mutex};
|
use futures::{FutureExt, Stream, StreamExt as _, future::BoxFuture, lock::Mutex};
|
||||||
use gpui::BackgroundExecutor;
|
use gpui::BackgroundExecutor;
|
||||||
use std::{pin::Pin, sync::Arc};
|
use std::{pin::Pin, sync::Arc};
|
||||||
|
|
||||||
|
@ -14,9 +14,12 @@ pub fn create_fake_transport(
|
||||||
executor: BackgroundExecutor,
|
executor: BackgroundExecutor,
|
||||||
) -> FakeTransport {
|
) -> FakeTransport {
|
||||||
let name = name.into();
|
let name = name.into();
|
||||||
FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
|
FakeTransport::new(executor).on_request::<crate::types::requests::Initialize, _>(
|
||||||
create_initialize_response(name.clone())
|
move |_params| {
|
||||||
})
|
let name = name.clone();
|
||||||
|
async move { create_initialize_response(name.clone()) }
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_initialize_response(server_name: String) -> InitializeResponse {
|
fn create_initialize_response(server_name: String) -> InitializeResponse {
|
||||||
|
@ -32,8 +35,10 @@ fn create_initialize_response(server_name: String) -> InitializeResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct FakeTransport {
|
pub struct FakeTransport {
|
||||||
request_handlers:
|
request_handlers: HashMap<
|
||||||
HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>,
|
&'static str,
|
||||||
|
Arc<dyn Send + Sync + Fn(serde_json::Value) -> BoxFuture<'static, serde_json::Value>>,
|
||||||
|
>,
|
||||||
tx: futures::channel::mpsc::UnboundedSender<String>,
|
tx: futures::channel::mpsc::UnboundedSender<String>,
|
||||||
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
||||||
executor: BackgroundExecutor,
|
executor: BackgroundExecutor,
|
||||||
|
@ -50,18 +55,25 @@ impl FakeTransport {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn on_request<T: crate::types::Request>(
|
pub fn on_request<T, Fut>(
|
||||||
mut self,
|
mut self,
|
||||||
handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static,
|
handler: impl 'static + Send + Sync + Fn(T::Params) -> Fut,
|
||||||
) -> Self {
|
) -> Self
|
||||||
|
where
|
||||||
|
T: crate::types::Request,
|
||||||
|
Fut: 'static + Send + Future<Output = T::Response>,
|
||||||
|
{
|
||||||
self.request_handlers.insert(
|
self.request_handlers.insert(
|
||||||
T::METHOD,
|
T::METHOD,
|
||||||
Arc::new(move |value| {
|
Arc::new(move |value| {
|
||||||
let params = value.get("params").expect("Missing parameters").clone();
|
let params = value
|
||||||
|
.get("params")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(serde_json::Value::Null);
|
||||||
let params: T::Params =
|
let params: T::Params =
|
||||||
serde_json::from_value(params).expect("Invalid parameters received");
|
serde_json::from_value(params).expect("Invalid parameters received");
|
||||||
let response = handler(params);
|
let response = handler(params);
|
||||||
serde_json::to_value(response).unwrap()
|
async move { serde_json::to_value(response.await).unwrap() }.boxed()
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
self
|
self
|
||||||
|
@ -77,7 +89,7 @@ impl Transport for FakeTransport {
|
||||||
if let Some(method) = msg.get("method") {
|
if let Some(method) = msg.get("method") {
|
||||||
let method = method.as_str().expect("Invalid method received");
|
let method = method.as_str().expect("Invalid method received");
|
||||||
if let Some(handler) = self.request_handlers.get(method) {
|
if let Some(handler) = self.request_handlers.get(method) {
|
||||||
let payload = handler(msg);
|
let payload = handler(msg).await;
|
||||||
let response = serde_json::json!({
|
let response = serde_json::json!({
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": id,
|
"id": id,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue