agent: Ensure tool names are unique (#33237)
Closes #31903 Release Notes: - agent: Fix an issue where an error would occur when MCP servers specified tools with the same name --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
28f56ad7ae
commit
e68b95c61b
1 changed files with 226 additions and 5 deletions
|
@ -8,7 +8,7 @@ use anyhow::{Result, anyhow};
|
|||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use collections::HashMap;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::display_map::CreaseMetadata;
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use futures::future::Shared;
|
||||
|
@ -932,14 +932,13 @@ impl Thread {
|
|||
model: Arc<dyn LanguageModel>,
|
||||
) -> Vec<LanguageModelRequestTool> {
|
||||
if model.supports_tools() {
|
||||
self.profile
|
||||
.enabled_tools(cx)
|
||||
resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
.filter_map(|(name, tool)| {
|
||||
// Skip tools that cannot be supported
|
||||
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
name,
|
||||
description: tool.description(),
|
||||
input_schema,
|
||||
})
|
||||
|
@ -2847,6 +2846,85 @@ struct PendingCompletion {
|
|||
_task: Task<()>,
|
||||
}
|
||||
|
||||
/// Resolves tool name conflicts by ensuring all tool names are unique.
|
||||
///
|
||||
/// When multiple tools have the same name, this function applies the following rules:
|
||||
/// 1. Native tools always keep their original name
|
||||
/// 2. Context server tools get prefixed with their server ID and an underscore
|
||||
/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
|
||||
/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
|
||||
///
|
||||
/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
|
||||
fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
|
||||
fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
|
||||
let mut tool_name = tool.name();
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
tool_name
|
||||
}
|
||||
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
let mut duplicated_tool_names = HashSet::default();
|
||||
let mut seen_tool_names = HashSet::default();
|
||||
for tool in tools {
|
||||
let tool_name = resolve_tool_name(tool);
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
debug_assert!(
|
||||
tool.source() != assistant_tool::ToolSource::Native,
|
||||
"There are two built-in tools with the same name: {}",
|
||||
tool_name
|
||||
);
|
||||
duplicated_tool_names.insert(tool_name);
|
||||
} else {
|
||||
seen_tool_names.insert(tool_name);
|
||||
}
|
||||
}
|
||||
|
||||
if duplicated_tool_names.is_empty() {
|
||||
return tools
|
||||
.into_iter()
|
||||
.map(|tool| (resolve_tool_name(tool), tool.clone()))
|
||||
.collect();
|
||||
}
|
||||
|
||||
tools
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
let mut tool_name = resolve_tool_name(tool);
|
||||
if !duplicated_tool_names.contains(&tool_name) {
|
||||
return Some((tool_name, tool.clone()));
|
||||
}
|
||||
match tool.source() {
|
||||
assistant_tool::ToolSource::Native => {
|
||||
// Built-in tools always keep their original name
|
||||
Some((tool_name, tool.clone()))
|
||||
}
|
||||
assistant_tool::ToolSource::ContextServer { id } => {
|
||||
// Context server tools are prefixed with the context server ID, and truncated if necessary
|
||||
tool_name.insert(0, '_');
|
||||
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
|
||||
let mut id = id.to_string();
|
||||
id.truncate(len);
|
||||
tool_name.insert_str(0, &id);
|
||||
} else {
|
||||
tool_name.insert_str(0, &id);
|
||||
}
|
||||
|
||||
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
|
||||
|
||||
if seen_tool_names.contains(&tool_name) {
|
||||
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
|
||||
None
|
||||
} else {
|
||||
Some((tool_name, tool.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -2862,6 +2940,7 @@ mod tests {
|
|||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use theme::ThemeSettings;
|
||||
use ui::IconName;
|
||||
use util::path;
|
||||
use workspace::Workspace;
|
||||
|
||||
|
@ -3493,6 +3572,148 @@ fn main() {{
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_tool_name_conflicts() {
|
||||
use assistant_tool::{Tool, ToolSource};
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool1", ToolSource::Native),
|
||||
TestTool::new("tool2", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::Native),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
|
||||
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
|
||||
],
|
||||
vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
|
||||
);
|
||||
|
||||
// Test that tool with very long name is always truncated
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![TestTool::new(
|
||||
"tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
|
||||
ToolSource::Native,
|
||||
)],
|
||||
vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
|
||||
);
|
||||
|
||||
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
|
||||
assert_resolve_tool_name_conflicts(
|
||||
vec![
|
||||
TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
|
||||
TestTool::new(
|
||||
"tool-with-very-very-very-long-name",
|
||||
ToolSource::ContextServer {
|
||||
id: "mcp-with-very-very-very-long-name".into(),
|
||||
},
|
||||
),
|
||||
],
|
||||
vec![
|
||||
"tool-with-very-very-very-long-name",
|
||||
"mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
|
||||
],
|
||||
);
|
||||
|
||||
fn assert_resolve_tool_name_conflicts(
|
||||
tools: Vec<TestTool>,
|
||||
expected: Vec<impl Into<String>>,
|
||||
) {
|
||||
let tools: Vec<Arc<dyn Tool>> = tools
|
||||
.into_iter()
|
||||
.map(|t| Arc::new(t) as Arc<dyn Tool>)
|
||||
.collect();
|
||||
let tools = resolve_tool_name_conflicts(&tools);
|
||||
assert_eq!(tools.len(), expected.len());
|
||||
for (i, expected_name) in expected.into_iter().enumerate() {
|
||||
let expected_name = expected_name.into();
|
||||
let actual_name = &tools[i].0;
|
||||
assert_eq!(
|
||||
actual_name, &expected_name,
|
||||
"Expected '{}' got '{}' at index {}",
|
||||
expected_name, actual_name, i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
struct TestTool {
|
||||
name: String,
|
||||
source: ToolSource,
|
||||
}
|
||||
|
||||
impl TestTool {
|
||||
fn new(name: impl Into<String>, source: ToolSource) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for TestTool {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Ai
|
||||
}
|
||||
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
self.source.clone()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Test tool".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> assistant_tool::ToolResult {
|
||||
assistant_tool::ToolResult {
|
||||
output: Task::ready(Err(anyhow::anyhow!("No content"))),
|
||||
card: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn test_summarize_error(
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
thread: &Entity<Thread>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue