acp: Support calling tools provided by MCP servers (#36752)
Release Notes: - N/A
This commit is contained in:
parent
3b7c1744b4
commit
4f0fad6996
3 changed files with 561 additions and 64 deletions
|
@ -9,15 +9,15 @@ use action_log::ActionLog;
|
|||
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{
|
||||
AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
|
||||
SUMMARIZE_THREAD_PROMPT,
|
||||
AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
|
||||
SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
use collections::{HashMap, IndexMap};
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
FutureExt,
|
||||
|
@ -56,6 +56,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
|
|||
use uuid::Uuid;
|
||||
|
||||
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
|
||||
pub const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
/// The ID of the user prompt that initiated a request.
|
||||
///
|
||||
|
@ -627,7 +628,20 @@ impl Thread {
|
|||
stream: &ThreadEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
|
||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
|
||||
self.context_server_registry
|
||||
.read(cx)
|
||||
.servers()
|
||||
.find_map(|(_, tools)| {
|
||||
if let Some(tool) = tools.get(tool_use.name.as_ref()) {
|
||||
Some(tool.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let Some(tool) = tool else {
|
||||
stream
|
||||
.0
|
||||
.unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
|
||||
|
@ -1079,6 +1093,10 @@ impl Thread {
|
|||
self.cancel(cx);
|
||||
|
||||
let model = self.model.clone().context("No language model configured")?;
|
||||
let profile = AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("Profile not found")?;
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
|
||||
let event_stream = ThreadEventStream(events_tx);
|
||||
let message_ix = self.messages.len().saturating_sub(1);
|
||||
|
@ -1086,6 +1104,7 @@ impl Thread {
|
|||
self.summary = None;
|
||||
self.running_turn = Some(RunningTurn {
|
||||
event_stream: event_stream.clone(),
|
||||
tools: self.enabled_tools(profile, &model, cx),
|
||||
_task: cx.spawn(async move |this, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
|
||||
|
@ -1417,7 +1436,7 @@ impl Thread {
|
|||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
||||
let tool = self.tool(tool_use.name.as_ref());
|
||||
let mut title = SharedString::from(&tool_use.name);
|
||||
let mut kind = acp::ToolKind::Other;
|
||||
if let Some(tool) = tool.as_ref() {
|
||||
|
@ -1727,6 +1746,21 @@ impl Thread {
|
|||
cx: &mut App,
|
||||
) -> Result<LanguageModelRequest> {
|
||||
let model = self.model().context("No language model configured")?;
|
||||
let tools = if let Some(turn) = self.running_turn.as_ref() {
|
||||
turn.tools
|
||||
.iter()
|
||||
.filter_map(|(tool_name, tool)| {
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name.to_string(),
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
log::debug!("Building completion request");
|
||||
log::debug!("Completion intent: {:?}", completion_intent);
|
||||
|
@ -1734,23 +1768,6 @@ impl Thread {
|
|||
|
||||
let messages = self.build_request_messages(cx);
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools = if let Some(tools) = self.tools(cx).log_err() {
|
||||
tools
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
log::info!("Request includes {} tools", tools.len());
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
|
@ -1770,37 +1787,76 @@ impl Thread {
|
|||
Ok(request)
|
||||
}
|
||||
|
||||
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
|
||||
let model = self.model().context("No language model configured")?;
|
||||
fn enabled_tools(
|
||||
&self,
|
||||
profile: &AgentProfileSettings,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
cx: &App,
|
||||
) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
|
||||
fn truncate(tool_name: &SharedString) -> SharedString {
|
||||
if tool_name.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let mut truncated = tool_name.to_string();
|
||||
truncated.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
truncated.into()
|
||||
} else {
|
||||
tool_name.clone()
|
||||
}
|
||||
}
|
||||
|
||||
let profile = AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("profile not found")?;
|
||||
let provider_id = model.provider_id();
|
||||
|
||||
Ok(self
|
||||
let mut tools = self
|
||||
.tools
|
||||
.iter()
|
||||
.filter(move |(_, tool)| tool.supported_provider(&provider_id))
|
||||
.filter_map(|(tool_name, tool)| {
|
||||
if profile.is_tool_enabled(tool_name) {
|
||||
Some(tool)
|
||||
if tool.supported_provider(&model.provider_id())
|
||||
&& profile.is_tool_enabled(tool_name)
|
||||
{
|
||||
Some((truncate(tool_name), tool.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(self.context_server_registry.read(cx).servers().flat_map(
|
||||
|(server_id, tools)| {
|
||||
tools.iter().filter_map(|(tool_name, tool)| {
|
||||
if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
|
||||
Some(tool)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
},
|
||||
)))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
|
||||
let mut context_server_tools = Vec::new();
|
||||
let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
|
||||
let mut duplicate_tool_names = HashSet::default();
|
||||
for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
|
||||
for (tool_name, tool) in server_tools {
|
||||
if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
|
||||
let tool_name = truncate(tool_name);
|
||||
if !seen_tools.insert(tool_name.clone()) {
|
||||
duplicate_tool_names.insert(tool_name.clone());
|
||||
}
|
||||
context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When there are duplicate tool names, disambiguate by prefixing them
|
||||
// with the server ID. In the rare case there isn't enough space for the
|
||||
// disambiguated tool name, keep only the last tool with this name.
|
||||
for (server_id, tool_name, tool) in context_server_tools {
|
||||
if duplicate_tool_names.contains(&tool_name) {
|
||||
let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
|
||||
if available >= 2 {
|
||||
let mut disambiguated = server_id.0.to_string();
|
||||
disambiguated.truncate(available - 1);
|
||||
disambiguated.push('_');
|
||||
disambiguated.push_str(&tool_name);
|
||||
tools.insert(disambiguated.into(), tool.clone());
|
||||
} else {
|
||||
tools.insert(tool_name, tool.clone());
|
||||
}
|
||||
} else {
|
||||
tools.insert(tool_name, tool.clone());
|
||||
}
|
||||
}
|
||||
|
||||
tools
|
||||
}
|
||||
|
||||
fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
|
||||
self.running_turn.as_ref()?.tools.get(name).cloned()
|
||||
}
|
||||
|
||||
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
|
||||
|
@ -1965,6 +2021,8 @@ struct RunningTurn {
|
|||
/// The current event stream for the running turn. Used to report a final
|
||||
/// cancellation event if we cancel the turn.
|
||||
event_stream: ThreadEventStream,
|
||||
/// The tools that were enabled for this turn.
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
}
|
||||
|
||||
impl RunningTurn {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue