diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 1a6b9604b5..dfbb21a196 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -8,7 +8,7 @@ use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; -use collections::HashMap; +use collections::{HashMap, HashSet}; use editor::display_map::CreaseMetadata; use feature_flags::{self, FeatureFlagAppExt}; use futures::future::Shared; @@ -932,14 +932,13 @@ impl Thread { model: Arc, ) -> Vec { if model.supports_tools() { - self.profile - .enabled_tools(cx) + resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice()) .into_iter() - .filter_map(|tool| { + .filter_map(|(name, tool)| { // Skip tools that cannot be supported let input_schema = tool.input_schema(model.tool_input_format()).ok()?; Some(LanguageModelRequestTool { - name: tool.name(), + name, description: tool.description(), input_schema, }) @@ -2847,6 +2846,85 @@ struct PendingCompletion { _task: Task<()>, } +/// Resolves tool name conflicts by ensuring all tool names are unique. +/// +/// When multiple tools have the same name, this function applies the following rules: +/// 1. Native tools always keep their original name +/// 2. Context server tools get prefixed with their server ID and an underscore +/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters) +/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out +/// +/// Note: This function assumes that built-in tools occur before MCP tools in the tools list. +fn resolve_tool_name_conflicts(tools: &[Arc]) -> Vec<(String, Arc)> { + fn resolve_tool_name(tool: &Arc) -> String { + let mut tool_name = tool.name(); + tool_name.truncate(MAX_TOOL_NAME_LENGTH); + tool_name + } + + const MAX_TOOL_NAME_LENGTH: usize = 64; + + let mut duplicated_tool_names = HashSet::default(); + let mut seen_tool_names = HashSet::default(); + for tool in tools { + let tool_name = resolve_tool_name(tool); + if seen_tool_names.contains(&tool_name) { + debug_assert!( + tool.source() != assistant_tool::ToolSource::Native, + "There are two built-in tools with the same name: {}", + tool_name + ); + duplicated_tool_names.insert(tool_name); + } else { + seen_tool_names.insert(tool_name); + } + } + + if duplicated_tool_names.is_empty() { + return tools + .into_iter() + .map(|tool| (resolve_tool_name(tool), tool.clone())) + .collect(); + } + + tools + .into_iter() + .filter_map(|tool| { + let mut tool_name = resolve_tool_name(tool); + if !duplicated_tool_names.contains(&tool_name) { + return Some((tool_name, tool.clone())); + } + match tool.source() { + assistant_tool::ToolSource::Native => { + // Built-in tools always keep their original name + Some((tool_name, tool.clone())) + } + assistant_tool::ToolSource::ContextServer { id } => { + // Context server tools are prefixed with the context server ID, and truncated if necessary + tool_name.insert(0, '_'); + if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH { + let len = MAX_TOOL_NAME_LENGTH - tool_name.len(); + let mut id = id.to_string(); + id.truncate(len); + tool_name.insert_str(0, &id); + } else { + tool_name.insert_str(0, &id); + } + + tool_name.truncate(MAX_TOOL_NAME_LENGTH); + + if seen_tool_names.contains(&tool_name) { + log::error!("Cannot resolve tool name conflict for tool {}", tool.name()); + None + } else { + Some((tool_name, tool.clone())) + } + } + } + }) + .collect() +} + #[cfg(test)] mod tests { use super::*; @@ -2862,6 +2940,7 @@ mod tests { use settings::{Settings, SettingsStore}; use std::sync::Arc; use theme::ThemeSettings; + use ui::IconName; use util::path; use workspace::Workspace; @@ -3493,6 +3572,148 @@ fn main() {{ }); } + #[gpui::test] + fn test_resolve_tool_name_conflicts() { + use assistant_tool::{Tool, ToolSource}; + + assert_resolve_tool_name_conflicts( + vec![ + TestTool::new("tool1", ToolSource::Native), + TestTool::new("tool2", ToolSource::Native), + TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), + ], + vec!["tool1", "tool2", "tool3"], + ); + + assert_resolve_tool_name_conflicts( + vec![ + TestTool::new("tool1", ToolSource::Native), + TestTool::new("tool2", ToolSource::Native), + TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), + TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }), + ], + vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"], + ); + + assert_resolve_tool_name_conflicts( + vec![ + TestTool::new("tool1", ToolSource::Native), + TestTool::new("tool2", ToolSource::Native), + TestTool::new("tool3", ToolSource::Native), + TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), + TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }), + ], + vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"], + ); + + // Test that tool with very long name is always truncated + assert_resolve_tool_name_conflicts( + vec![TestTool::new( + "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah", + ToolSource::Native, + )], + vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"], + ); + + // Test deduplication of tools with very long names, in this case the mcp server name should be truncated + assert_resolve_tool_name_conflicts( + vec![ + TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native), + TestTool::new( + "tool-with-very-very-very-long-name", + ToolSource::ContextServer { + id: "mcp-with-very-very-very-long-name".into(), + }, + ), + ], + vec![ + "tool-with-very-very-very-long-name", + "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name", + ], + ); + + fn assert_resolve_tool_name_conflicts( + tools: Vec, + expected: Vec>, + ) { + let tools: Vec> = tools + .into_iter() + .map(|t| Arc::new(t) as Arc) + .collect(); + let tools = resolve_tool_name_conflicts(&tools); + assert_eq!(tools.len(), expected.len()); + for (i, expected_name) in expected.into_iter().enumerate() { + let expected_name = expected_name.into(); + let actual_name = &tools[i].0; + assert_eq!( + actual_name, &expected_name, + "Expected '{}' got '{}' at index {}", + expected_name, actual_name, i + ); + } + } + + struct TestTool { + name: String, + source: ToolSource, + } + + impl TestTool { + fn new(name: impl Into, source: ToolSource) -> Self { + Self { + name: name.into(), + source, + } + } + } + + impl Tool for TestTool { + fn name(&self) -> String { + self.name.clone() + } + + fn icon(&self) -> IconName { + IconName::Ai + } + + fn may_perform_edits(&self) -> bool { + false + } + + fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + true + } + + fn source(&self) -> ToolSource { + self.source.clone() + } + + fn description(&self) -> String { + "Test tool".to_string() + } + + fn ui_text(&self, _input: &serde_json::Value) -> String { + "Test tool".to_string() + } + + fn run( + self: Arc, + _input: serde_json::Value, + _request: Arc, + _project: Entity, + _action_log: Entity, + _model: Arc, + _window: Option, + _cx: &mut App, + ) -> assistant_tool::ToolResult { + assistant_tool::ToolResult { + output: Task::ready(Err(anyhow::anyhow!("No content"))), + card: None, + } + } + } + } + fn test_summarize_error( model: &Arc, thread: &Entity,