acp: Support calling tools provided by MCP servers (#36752)
Release Notes: - N/A
This commit is contained in:
parent
3b7c1744b4
commit
4f0fad6996
3 changed files with 561 additions and 64 deletions
|
@ -4,26 +4,35 @@ use agent_client_protocol::{self as acp};
|
|||
use agent_settings::AgentProfileId;
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||
use fs::{FakeFs, Fs};
|
||||
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
|
||||
use futures::{
|
||||
StreamExt,
|
||||
channel::{
|
||||
mpsc::{self, UnboundedReceiver},
|
||||
oneshot,
|
||||
},
|
||||
};
|
||||
use gpui::{
|
||||
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
|
||||
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
fake_provider::FakeLanguageModel,
|
||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
|
||||
};
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::Project;
|
||||
use project::{
|
||||
Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
|
||||
};
|
||||
use prompt_store::ProjectContext;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
|
@ -931,6 +940,334 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
|||
assert_eq!(tool_names, vec![InfiniteTool::name()]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_mcp_tools(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
context_server_store,
|
||||
fs,
|
||||
..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
// Override profiles and wait for settings to be loaded.
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"profiles": {
|
||||
"test": {
|
||||
"name": "Test Profile",
|
||||
"enable_all_context_servers": true,
|
||||
"tools": {
|
||||
EchoTool::name(): true,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_profile(AgentProfileId("test".into()))
|
||||
});
|
||||
|
||||
let mut mcp_tool_calls = setup_context_server(
|
||||
"test_server",
|
||||
vec![context_server::types::Tool {
|
||||
name: "echo".into(),
|
||||
description: None,
|
||||
input_schema: serde_json::to_value(
|
||||
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||
)
|
||||
.unwrap(),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
}],
|
||||
&context_server_store,
|
||||
cx,
|
||||
);
|
||||
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
// Simulate the model calling the MCP tool.
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_1".into(),
|
||||
name: "echo".into(),
|
||||
raw_input: json!({"text": "test"}).to_string(),
|
||||
input: json!({"text": "test"}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
||||
assert_eq!(tool_call_params.name, "echo");
|
||||
assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
|
||||
tool_call_response
|
||||
.send(context_server::types::CallToolResponse {
|
||||
content: vec![context_server::types::ToolResponseContent::Text {
|
||||
text: "test".into(),
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
structured_content: None,
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
|
||||
fake_model.send_last_completion_stream_text_chunk("Done!");
|
||||
fake_model.end_last_completion_stream();
|
||||
events.collect::<Vec<_>>().await;
|
||||
|
||||
// Send again after adding the echo tool, ensuring the name collision is resolved.
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
tool_names_for_completion(&completion),
|
||||
vec!["echo", "test_server_echo"]
|
||||
);
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_2".into(),
|
||||
name: "test_server_echo".into(),
|
||||
raw_input: json!({"text": "mcp"}).to_string(),
|
||||
input: json!({"text": "mcp"}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_3".into(),
|
||||
name: "echo".into(),
|
||||
raw_input: json!({"text": "native"}).to_string(),
|
||||
input: json!({"text": "native"}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
||||
assert_eq!(tool_call_params.name, "echo");
|
||||
assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
|
||||
tool_call_response
|
||||
.send(context_server::types::CallToolResponse {
|
||||
content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
structured_content: None,
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
// Ensure the tool results were inserted with the correct names.
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
completion.messages.last().unwrap().content,
|
||||
vec![
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: "tool_3".into(),
|
||||
tool_name: "echo".into(),
|
||||
is_error: false,
|
||||
content: "native".into(),
|
||||
output: Some("native".into()),
|
||||
},),
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: "tool_2".into(),
|
||||
tool_name: "test_server_echo".into(),
|
||||
is_error: false,
|
||||
content: "mcp".into(),
|
||||
output: Some("mcp".into()),
|
||||
},),
|
||||
]
|
||||
);
|
||||
fake_model.end_last_completion_stream();
|
||||
events.collect::<Vec<_>>().await;
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
context_server_store,
|
||||
fs,
|
||||
..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
// Set up a profile with all tools enabled
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"profiles": {
|
||||
"test": {
|
||||
"name": "Test Profile",
|
||||
"enable_all_context_servers": true,
|
||||
"tools": {
|
||||
EchoTool::name(): true,
|
||||
DelayTool::name(): true,
|
||||
WordListTool::name(): true,
|
||||
ToolRequiringPermission::name(): true,
|
||||
InfiniteTool::name(): true,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.set_profile(AgentProfileId("test".into()));
|
||||
thread.add_tool(EchoTool);
|
||||
thread.add_tool(DelayTool);
|
||||
thread.add_tool(WordListTool);
|
||||
thread.add_tool(ToolRequiringPermission);
|
||||
thread.add_tool(InfiniteTool);
|
||||
});
|
||||
|
||||
// Set up multiple context servers with some overlapping tool names
|
||||
let _server1_calls = setup_context_server(
|
||||
"xxx",
|
||||
vec![
|
||||
context_server::types::Tool {
|
||||
name: "echo".into(), // Conflicts with native EchoTool
|
||||
description: None,
|
||||
input_schema: serde_json::to_value(
|
||||
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||
)
|
||||
.unwrap(),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "unique_tool_1".into(),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
],
|
||||
&context_server_store,
|
||||
cx,
|
||||
);
|
||||
|
||||
let _server2_calls = setup_context_server(
|
||||
"yyy",
|
||||
vec![
|
||||
context_server::types::Tool {
|
||||
name: "echo".into(), // Also conflicts with native EchoTool
|
||||
description: None,
|
||||
input_schema: serde_json::to_value(
|
||||
EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
|
||||
)
|
||||
.unwrap(),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "unique_tool_2".into(),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
],
|
||||
&context_server_store,
|
||||
cx,
|
||||
);
|
||||
let _server3_calls = setup_context_server(
|
||||
"zzz",
|
||||
vec![
|
||||
context_server::types::Tool {
|
||||
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
context_server::types::Tool {
|
||||
name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
|
||||
description: None,
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
},
|
||||
],
|
||||
&context_server_store,
|
||||
cx,
|
||||
);
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Go"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
tool_names_for_completion(&completion),
|
||||
vec![
|
||||
"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
|
||||
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
|
||||
"delay",
|
||||
"echo",
|
||||
"infinite",
|
||||
"tool_requiring_permission",
|
||||
"unique_tool_1",
|
||||
"unique_tool_2",
|
||||
"word_list",
|
||||
"xxx_echo",
|
||||
"y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
"yyy_echo",
|
||||
"z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
|
@ -1806,6 +2143,7 @@ struct ThreadTest {
|
|||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
fs: Arc<FakeFs>,
|
||||
}
|
||||
|
||||
|
@ -1844,6 +2182,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
WordListTool::name(): true,
|
||||
ToolRequiringPermission::name(): true,
|
||||
InfiniteTool::name(): true,
|
||||
ThinkingTool::name(): true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1900,8 +2239,9 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
.await;
|
||||
|
||||
let project_context = cx.new(|_cx| ProjectContext::default());
|
||||
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
|
@ -1916,6 +2256,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
model,
|
||||
thread,
|
||||
project_context,
|
||||
context_server_store,
|
||||
fs,
|
||||
}
|
||||
}
|
||||
|
@ -1950,3 +2291,89 @@ fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
|
|||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
|
||||
completion
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn setup_context_server(
|
||||
name: &'static str,
|
||||
tools: Vec<context_server::types::Tool>,
|
||||
context_server_store: &Entity<ContextServerStore>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> mpsc::UnboundedReceiver<(
|
||||
context_server::types::CallToolParams,
|
||||
oneshot::Sender<context_server::types::CallToolResponse>,
|
||||
)> {
|
||||
cx.update(|cx| {
|
||||
let mut settings = ProjectSettings::get_global(cx).clone();
|
||||
settings.context_servers.insert(
|
||||
name.into(),
|
||||
project::project_settings::ContextServerSettings::Custom {
|
||||
enabled: true,
|
||||
command: ContextServerCommand {
|
||||
path: "somebinary".into(),
|
||||
args: Vec::new(),
|
||||
env: None,
|
||||
},
|
||||
},
|
||||
);
|
||||
ProjectSettings::override_global(settings, cx);
|
||||
});
|
||||
|
||||
let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
|
||||
let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
|
||||
.on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
|
||||
context_server::types::InitializeResponse {
|
||||
protocol_version: context_server::types::ProtocolVersion(
|
||||
context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
|
||||
),
|
||||
server_info: context_server::types::Implementation {
|
||||
name: name.into(),
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
capabilities: context_server::types::ServerCapabilities {
|
||||
tools: Some(context_server::types::ToolsCapabilities {
|
||||
list_changed: Some(true),
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
meta: None,
|
||||
}
|
||||
})
|
||||
.on_request::<context_server::types::requests::ListTools, _>(move |_params| {
|
||||
let tools = tools.clone();
|
||||
async move {
|
||||
context_server::types::ListToolsResponse {
|
||||
tools,
|
||||
next_cursor: None,
|
||||
meta: None,
|
||||
}
|
||||
}
|
||||
})
|
||||
.on_request::<context_server::types::requests::CallTool, _>(move |params| {
|
||||
let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
|
||||
async move {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
mcp_tool_calls_tx
|
||||
.unbounded_send((params, response_tx))
|
||||
.unwrap();
|
||||
response_rx.await.unwrap()
|
||||
}
|
||||
});
|
||||
context_server_store.update(cx, |store, cx| {
|
||||
store.start_server(
|
||||
Arc::new(ContextServer::new(
|
||||
ContextServerId(name.into()),
|
||||
Arc::new(fake_transport),
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
mcp_tool_calls_rx
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue