agent: Encode tool input with associated type

This commit is contained in:
Bennet Bo Fenner 2025-07-07 00:40:53 +02:00
parent 8fecacfbaa
commit dbf3c31a83
26 changed files with 728 additions and 571 deletions

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings}; use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
use assistant_tool::{Tool, ToolSource, ToolWorkingSet, UniqueToolName}; use assistant_tool::{AnyTool, ToolSource, ToolWorkingSet, UniqueToolName};
use collections::IndexMap; use collections::IndexMap;
use convert_case::{Case, Casing}; use convert_case::{Case, Casing};
use fs::Fs; use fs::Fs;
@ -72,7 +72,7 @@ impl AgentProfile {
&self.id &self.id
} }
pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> { pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else { let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
return Vec::new(); return Vec::new();
}; };
@ -108,7 +108,7 @@ impl AgentProfile {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use agent_settings::ContextServerPreset; use agent_settings::ContextServerPreset;
use assistant_tool::ToolRegistry; use assistant_tool::{Tool, ToolRegistry};
use collections::IndexMap; use collections::IndexMap;
use gpui::SharedString; use gpui::SharedString;
use gpui::{AppContext, TestAppContext}; use gpui::{AppContext, TestAppContext};
@ -269,8 +269,14 @@ mod tests {
fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> { fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
cx.new(|cx| { cx.new(|cx| {
let mut tool_set = ToolWorkingSet::default(); let mut tool_set = ToolWorkingSet::default();
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")), cx); tool_set.insert(
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")), cx); Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")).into(),
cx,
);
tool_set.insert(
Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")).into(),
cx,
);
tool_set tool_set
}) })
} }
@ -290,6 +296,8 @@ mod tests {
} }
impl Tool for FakeTool { impl Tool for FakeTool {
type Input = ();
fn name(&self) -> String { fn name(&self) -> String {
self.name.clone() self.name.clone()
} }
@ -308,17 +316,17 @@ mod tests {
unimplemented!() unimplemented!()
} }
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
unimplemented!() unimplemented!()
} }
fn ui_text(&self, _input: &serde_json::Value) -> String { fn ui_text(&self, _input: &Self::Input) -> String {
unimplemented!() unimplemented!()
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
_input: serde_json::Value, _input: Self::Input,
_request: Arc<language_model::LanguageModelRequest>, _request: Arc<language_model::LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<assistant_tool::ActionLog>, _action_log: Entity<assistant_tool::ActionLog>,

View file

@ -29,6 +29,8 @@ impl ContextServerTool {
} }
impl Tool for ContextServerTool { impl Tool for ContextServerTool {
type Input = serde_json::Value;
fn name(&self) -> String { fn name(&self) -> String {
self.tool.name.clone() self.tool.name.clone()
} }
@ -47,7 +49,7 @@ impl Tool for ContextServerTool {
} }
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
true true
} }
@ -69,13 +71,13 @@ impl Tool for ContextServerTool {
}) })
} }
fn ui_text(&self, _input: &serde_json::Value) -> String { fn ui_text(&self, _input: &Self::Input) -> String {
format!("Run MCP tool `{}`", self.tool.name) format!("Run MCP tool `{}`", self.tool.name)
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,

View file

@ -10,7 +10,7 @@ use crate::{
}; };
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use assistant_tool::{ActionLog, AnyTool, AnyToolCard, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage}; use client::{ModelRequestUsage, RequestUsage};
use collections::HashMap; use collections::HashMap;
@ -2452,7 +2452,7 @@ impl Thread {
ui_text: impl Into<SharedString>, ui_text: impl Into<SharedString>,
input: serde_json::Value, input: serde_json::Value,
request: Arc<LanguageModelRequest>, request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>, tool: AnyTool,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>, window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>, cx: &mut Context<Thread>,
@ -2468,7 +2468,7 @@ impl Thread {
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
request: Arc<LanguageModelRequest>, request: Arc<LanguageModelRequest>,
input: serde_json::Value, input: serde_json::Value,
tool: Arc<dyn Tool>, tool: AnyTool,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>, window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>, cx: &mut Context<Thread>,

View file

@ -6,7 +6,7 @@ use crate::{
}; };
use agent_settings::{AgentProfileId, CompletionMode}; use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet}; use assistant_tool::{ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::HashMap; use collections::HashMap;
use context_server::ContextServerId; use context_server::ContextServerId;
@ -576,7 +576,8 @@ impl ThreadStore {
context_server_store.clone(), context_server_store.clone(),
server.id(), server.id(),
tool, tool,
)) as Arc<dyn Tool> ))
.into()
}), }),
cx, cx,
) )

View file

@ -4,7 +4,7 @@ use crate::{
}; };
use anyhow::Result; use anyhow::Result;
use assistant_tool::{ use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet, AnyTool, AnyToolCard, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
}; };
use collections::HashMap; use collections::HashMap;
use futures::{FutureExt as _, future::Shared}; use futures::{FutureExt as _, future::Shared};
@ -378,7 +378,7 @@ impl ToolUseState {
ui_text: impl Into<Arc<str>>, ui_text: impl Into<Arc<str>>,
input: serde_json::Value, input: serde_json::Value,
request: Arc<LanguageModelRequest>, request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>, tool: AnyTool,
) { ) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
let ui_text = ui_text.into(); let ui_text = ui_text.into();
@ -533,7 +533,7 @@ pub struct Confirmation {
pub input: serde_json::Value, pub input: serde_json::Value,
pub ui_text: Arc<str>, pub ui_text: Arc<str>,
pub request: Arc<LanguageModelRequest>, pub request: Arc<LanguageModelRequest>,
pub tool: Arc<dyn Tool>, pub tool: AnyTool,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]

View file

@ -1,5 +1,5 @@
use agent::{Thread, ThreadEvent}; use agent::{Thread, ThreadEvent};
use assistant_tool::{Tool, ToolSource}; use assistant_tool::{AnyTool, ToolSource};
use collections::HashMap; use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window}; use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
use language_model::{LanguageModel, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
@ -7,7 +7,7 @@ use std::sync::Arc;
use ui::prelude::*; use ui::prelude::*;
pub struct IncompatibleToolsState { pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>, cache: HashMap<LanguageModelToolSchemaFormat, Vec<AnyTool>>,
thread: Entity<Thread>, thread: Entity<Thread>,
_thread_subscription: Subscription, _thread_subscription: Subscription,
} }
@ -29,11 +29,7 @@ impl IncompatibleToolsState {
} }
} }
pub fn incompatible_tools( pub fn incompatible_tools(&mut self, model: &Arc<dyn LanguageModel>, cx: &App) -> &[AnyTool] {
&mut self,
model: &Arc<dyn LanguageModel>,
cx: &App,
) -> &[Arc<dyn Tool>] {
self.cache self.cache
.entry(model.tool_input_format()) .entry(model.tool_input_format())
.or_insert_with(|| { .or_insert_with(|| {
@ -50,7 +46,7 @@ impl IncompatibleToolsState {
} }
pub struct IncompatibleToolsTooltip { pub struct IncompatibleToolsTooltip {
pub incompatible_tools: Vec<Arc<dyn Tool>>, pub incompatible_tools: Vec<AnyTool>,
} }
impl Render for IncompatibleToolsTooltip { impl Render for IncompatibleToolsTooltip {

View file

@ -4,25 +4,19 @@ mod tool_registry;
mod tool_schema; mod tool_schema;
mod tool_working_set; mod tool_working_set;
use std::fmt; use std::{fmt, fmt::Debug, fmt::Formatter, ops::Deref, sync::Arc};
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use gpui::AnyElement; use gpui::{
use gpui::AnyWindowHandle; AnyElement, AnyWindowHandle, App, Context, Entity, IntoElement, SharedString, Task, WeakEntity,
use gpui::Context; Window,
use gpui::IntoElement; };
use gpui::Window;
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use icons::IconName; use icons::IconName;
use language_model::LanguageModel; use language_model::{
use language_model::LanguageModelImage; LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
use language_model::LanguageModelRequest; };
use language_model::LanguageModelToolSchemaFormat;
use project::Project; use project::Project;
use serde::de::DeserializeOwned;
use workspace::Workspace; use workspace::Workspace;
pub use crate::action_log::*; pub use crate::action_log::*;
@ -199,7 +193,10 @@ pub enum ToolSource {
} }
/// A tool that can be used by a language model. /// A tool that can be used by a language model.
pub trait Tool: 'static + Send + Sync { pub trait Tool: Send + Sync + 'static {
/// The input type that is accepted by the tool.
type Input: DeserializeOwned;
/// Returns the name of the tool. /// Returns the name of the tool.
fn name(&self) -> String; fn name(&self) -> String;
@ -216,7 +213,7 @@ pub trait Tool: 'static + Send + Sync {
/// Returns true if the tool needs the users's confirmation /// Returns true if the tool needs the users's confirmation
/// before having permission to run. /// before having permission to run.
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool; fn needs_confirmation(&self, input: &Self::Input, cx: &App) -> bool;
/// Returns true if the tool may perform edits. /// Returns true if the tool may perform edits.
fn may_perform_edits(&self) -> bool; fn may_perform_edits(&self) -> bool;
@ -227,18 +224,18 @@ pub trait Tool: 'static + Send + Sync {
} }
/// Returns markdown to be displayed in the UI for this tool. /// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &serde_json::Value) -> String; fn ui_text(&self, input: &Self::Input) -> String;
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
/// (so information may be missing). /// (so information may be missing).
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
self.ui_text(input) self.ui_text(input)
} }
/// Runs the tool with the provided input. /// Runs the tool with the provided input.
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
request: Arc<LanguageModelRequest>, request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
@ -258,7 +255,199 @@ pub trait Tool: 'static + Send + Sync {
} }
} }
impl Debug for dyn Tool { #[derive(Clone)]
pub struct AnyTool {
inner: Arc<dyn ErasedTool>,
}
/// Copy of `Tool` where the Input type is erased.
trait ErasedTool: Send + Sync {
fn name(&self) -> String;
fn description(&self) -> String;
fn icon(&self) -> IconName;
fn source(&self) -> ToolSource;
fn may_perform_edits(&self) -> bool;
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn ui_text(&self, input: &serde_json::Value) -> String;
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String;
fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult;
fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard>;
}
struct ErasedToolWrapper<T: Tool> {
tool: Arc<T>,
}
impl<T: Tool> ErasedTool for ErasedToolWrapper<T> {
fn name(&self) -> String {
self.tool.name()
}
fn description(&self) -> String {
self.tool.description()
}
fn icon(&self) -> IconName {
self.tool.icon()
}
fn source(&self) -> ToolSource {
self.tool.source()
}
fn may_perform_edits(&self) -> bool {
self.tool.may_perform_edits()
}
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.needs_confirmation(&parsed_input, cx),
Err(_) => true,
}
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
self.tool.input_schema(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.ui_text(&parsed_input),
Err(_) => "Invalid input".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.still_streaming_ui_text(&parsed_input),
Err(_) => "Invalid input".to_string(),
}
}
fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match serde_json::from_value::<T::Input>(input) {
Ok(parsed_input) => self.tool.clone().run(
parsed_input,
request,
project,
action_log,
model,
window,
cx,
),
Err(err) => ToolResult::from(Task::ready(Err(err.into()))),
}
}
fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
self.tool
.clone()
.deserialize_card(output, project, window, cx)
}
}
impl<T: Tool> From<Arc<T>> for AnyTool {
fn from(tool: Arc<T>) -> Self {
Self {
inner: Arc::new(ErasedToolWrapper { tool }),
}
}
}
impl AnyTool {
pub fn name(&self) -> String {
self.inner.name()
}
pub fn description(&self) -> String {
self.inner.description()
}
pub fn icon(&self) -> IconName {
self.inner.icon()
}
pub fn source(&self) -> ToolSource {
self.inner.source()
}
pub fn may_perform_edits(&self) -> bool {
self.inner.may_perform_edits()
}
pub fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
self.inner.needs_confirmation(input, cx)
}
pub fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
self.inner.input_schema(format)
}
pub fn ui_text(&self, input: &serde_json::Value) -> String {
self.inner.ui_text(input)
}
pub fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
self.inner.still_streaming_ui_text(input)
}
pub fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
self.inner
.run(input, request, project, action_log, model, window, cx)
}
pub fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
self.inner.deserialize_card(output, project, window, cx)
}
}
impl Debug for AnyTool {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Tool").field("name", &self.name()).finish() f.debug_struct("Tool").field("name", &self.name()).finish()
} }

View file

@ -6,7 +6,7 @@ use gpui::Global;
use gpui::{App, ReadGlobal}; use gpui::{App, ReadGlobal};
use parking_lot::RwLock; use parking_lot::RwLock;
use crate::Tool; use crate::{AnyTool, Tool};
#[derive(Default, Deref, DerefMut)] #[derive(Default, Deref, DerefMut)]
struct GlobalToolRegistry(Arc<ToolRegistry>); struct GlobalToolRegistry(Arc<ToolRegistry>);
@ -15,7 +15,7 @@ impl Global for GlobalToolRegistry {}
#[derive(Default)] #[derive(Default)]
struct ToolRegistryState { struct ToolRegistryState {
tools: HashMap<Arc<str>, Arc<dyn Tool>>, tools: HashMap<Arc<str>, AnyTool>,
} }
#[derive(Default)] #[derive(Default)]
@ -48,7 +48,7 @@ impl ToolRegistry {
pub fn register_tool(&self, tool: impl Tool) { pub fn register_tool(&self, tool: impl Tool) {
let mut state = self.state.write(); let mut state = self.state.write();
let tool_name: Arc<str> = tool.name().into(); let tool_name: Arc<str> = tool.name().into();
state.tools.insert(tool_name, Arc::new(tool)); state.tools.insert(tool_name, Arc::new(tool).into());
} }
/// Unregisters the provided [`Tool`]. /// Unregisters the provided [`Tool`].
@ -63,12 +63,12 @@ impl ToolRegistry {
} }
/// Returns the list of tools in the registry. /// Returns the list of tools in the registry.
pub fn tools(&self) -> Vec<Arc<dyn Tool>> { pub fn tools(&self) -> Vec<AnyTool> {
self.state.read().tools.values().cloned().collect() self.state.read().tools.values().cloned().collect()
} }
/// Returns the [`Tool`] with the given name. /// Returns the [`Tool`] with the given name.
pub fn tool(&self, name: &str) -> Option<Arc<dyn Tool>> { pub fn tool(&self, name: &str) -> Option<AnyTool> {
self.state.read().tools.get(name).cloned() self.state.read().tools.get(name).cloned()
} }
} }

View file

@ -1,6 +1,6 @@
use std::{borrow::Borrow, sync::Arc}; use std::borrow::Borrow;
use crate::{Tool, ToolRegistry, ToolSource}; use crate::{AnyTool, ToolRegistry, ToolSource};
use collections::{HashMap, HashSet, IndexMap}; use collections::{HashMap, HashSet, IndexMap};
use gpui::{App, SharedString}; use gpui::{App, SharedString};
use util::debug_panic; use util::debug_panic;
@ -45,20 +45,20 @@ impl std::fmt::Display for UniqueToolName {
/// A working set of tools for use in one instance of the Assistant Panel. /// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)] #[derive(Default)]
pub struct ToolWorkingSet { pub struct ToolWorkingSet {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>, context_server_tools_by_id: HashMap<ToolId, AnyTool>,
context_server_tools_by_name: HashMap<UniqueToolName, Arc<dyn Tool>>, context_server_tools_by_name: HashMap<UniqueToolName, AnyTool>,
next_tool_id: ToolId, next_tool_id: ToolId,
} }
impl ToolWorkingSet { impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> { pub fn tool(&self, name: &str, cx: &App) -> Option<AnyTool> {
self.context_server_tools_by_name self.context_server_tools_by_name
.get(name) .get(name)
.cloned() .cloned()
.or_else(|| ToolRegistry::global(cx).tool(name)) .or_else(|| ToolRegistry::global(cx).tool(name))
} }
pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> { pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
let mut tools = ToolRegistry::global(cx) let mut tools = ToolRegistry::global(cx)
.tools() .tools()
.into_iter() .into_iter()
@ -68,7 +68,7 @@ impl ToolWorkingSet {
tools tools
} }
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> { pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<AnyTool>> {
let mut tools_by_source = IndexMap::default(); let mut tools_by_source = IndexMap::default();
for (_, tool) in self.tools(cx) { for (_, tool) in self.tools(cx) {
@ -87,13 +87,13 @@ impl ToolWorkingSet {
tools_by_source tools_by_source
} }
pub fn insert(&mut self, tool: Arc<dyn Tool>, cx: &App) -> ToolId { pub fn insert(&mut self, tool: AnyTool, cx: &App) -> ToolId {
let tool_id = self.register_tool(tool); let tool_id = self.register_tool(tool);
self.tools_changed(cx); self.tools_changed(cx);
tool_id tool_id
} }
pub fn extend(&mut self, tools: impl Iterator<Item = Arc<dyn Tool>>, cx: &App) -> Vec<ToolId> { pub fn extend(&mut self, tools: impl Iterator<Item = AnyTool>, cx: &App) -> Vec<ToolId> {
let ids = tools.map(|tool| self.register_tool(tool)).collect(); let ids = tools.map(|tool| self.register_tool(tool)).collect();
self.tools_changed(cx); self.tools_changed(cx);
ids ids
@ -105,7 +105,7 @@ impl ToolWorkingSet {
self.tools_changed(cx); self.tools_changed(cx);
} }
fn register_tool(&mut self, tool: Arc<dyn Tool>) -> ToolId { fn register_tool(&mut self, tool: AnyTool) -> ToolId {
let tool_id = self.next_tool_id; let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1; self.next_tool_id.0 += 1;
self.context_server_tools_by_id self.context_server_tools_by_id
@ -126,10 +126,10 @@ impl ToolWorkingSet {
} }
fn resolve_context_server_tool_name_conflicts( fn resolve_context_server_tool_name_conflicts(
context_server_tools: &[Arc<dyn Tool>], context_server_tools: &[AnyTool],
native_tools: &[Arc<dyn Tool>], native_tools: &[AnyTool],
) -> HashMap<UniqueToolName, Arc<dyn Tool>> { ) -> HashMap<UniqueToolName, AnyTool> {
fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String { fn resolve_tool_name(tool: &AnyTool) -> String {
let mut tool_name = tool.name(); let mut tool_name = tool.name();
tool_name.truncate(MAX_TOOL_NAME_LENGTH); tool_name.truncate(MAX_TOOL_NAME_LENGTH);
tool_name tool_name
@ -201,11 +201,13 @@ fn resolve_context_server_tool_name_conflicts(
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use gpui::{AnyWindowHandle, Entity, Task, TestAppContext}; use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
use language_model::{LanguageModel, LanguageModelRequest}; use language_model::{LanguageModel, LanguageModelRequest};
use project::Project; use project::Project;
use crate::{ActionLog, ToolResult}; use crate::{ActionLog, Tool, ToolResult};
use super::*; use super::*;
@ -234,11 +236,13 @@ mod tests {
Arc::new(TestTool::new( Arc::new(TestTool::new(
"tool2", "tool2",
ToolSource::ContextServer { id: "mcp-1".into() }, ToolSource::ContextServer { id: "mcp-1".into() },
)) as Arc<dyn Tool>, ))
.into(),
Arc::new(TestTool::new( Arc::new(TestTool::new(
"tool2", "tool2",
ToolSource::ContextServer { id: "mcp-2".into() }, ToolSource::ContextServer { id: "mcp-2".into() },
)) as Arc<dyn Tool>, ))
.into(),
] ]
.into_iter(), .into_iter(),
cx, cx,
@ -324,13 +328,13 @@ mod tests {
context_server_tools: Vec<TestTool>, context_server_tools: Vec<TestTool>,
expected: Vec<&'static str>, expected: Vec<&'static str>,
) { ) {
let context_server_tools: Vec<Arc<dyn Tool>> = context_server_tools let context_server_tools: Vec<AnyTool> = context_server_tools
.into_iter() .into_iter()
.map(|t| Arc::new(t) as Arc<dyn Tool>) .map(|t| Arc::new(t).into())
.collect(); .collect();
let builtin_tools: Vec<Arc<dyn Tool>> = builtin_tools let builtin_tools: Vec<AnyTool> = builtin_tools
.into_iter() .into_iter()
.map(|t| Arc::new(t) as Arc<dyn Tool>) .map(|t| Arc::new(t).into())
.collect(); .collect();
let tools = let tools =
resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools); resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);
@ -363,6 +367,8 @@ mod tests {
} }
impl Tool for TestTool { impl Tool for TestTool {
type Input = ();
fn name(&self) -> String { fn name(&self) -> String {
self.name.clone() self.name.clone()
} }
@ -375,7 +381,7 @@ mod tests {
false false
} }
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
true true
} }
@ -387,13 +393,13 @@ mod tests {
"Test tool".to_string() "Test tool".to_string()
} }
fn ui_text(&self, _input: &serde_json::Value) -> String { fn ui_text(&self, _input: &Self::Input) -> String {
"Test tool".to_string() "Test tool".to_string()
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
_input: serde_json::Value, _input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,

View file

@ -40,11 +40,13 @@ pub struct CopyPathToolInput {
pub struct CopyPathTool; pub struct CopyPathTool;
impl Tool for CopyPathTool { impl Tool for CopyPathTool {
type Input = CopyPathToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"copy_path".into() "copy_path".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -64,20 +66,15 @@ impl Tool for CopyPathTool {
json_schema_for::<CopyPathToolInput>(format) json_schema_for::<CopyPathToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<CopyPathToolInput>(input.clone()) { let src = MarkdownInlineCode(&input.source_path);
Ok(input) => { let dest = MarkdownInlineCode(&input.destination_path);
let src = MarkdownInlineCode(&input.source_path); format!("Copy {src} to {dest}")
let dest = MarkdownInlineCode(&input.destination_path);
format!("Copy {src} to {dest}")
}
Err(_) => "Copy path".to_string(),
}
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -85,10 +82,6 @@ impl Tool for CopyPathTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let copy_task = project.update(cx, |project, cx| { let copy_task = project.update(cx, |project, cx| {
match project match project
.find_project_path(&input.source_path, cx) .find_project_path(&input.source_path, cx)

View file

@ -29,6 +29,8 @@ pub struct CreateDirectoryToolInput {
pub struct CreateDirectoryTool; pub struct CreateDirectoryTool;
impl Tool for CreateDirectoryTool { impl Tool for CreateDirectoryTool {
type Input = CreateDirectoryToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"create_directory".into() "create_directory".into()
} }
@ -37,7 +39,7 @@ impl Tool for CreateDirectoryTool {
include_str!("./create_directory_tool/description.md").into() include_str!("./create_directory_tool/description.md").into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -53,18 +55,13 @@ impl Tool for CreateDirectoryTool {
json_schema_for::<CreateDirectoryToolInput>(format) json_schema_for::<CreateDirectoryToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<CreateDirectoryToolInput>(input.clone()) { format!("Create directory {}", MarkdownInlineCode(&input.path))
Ok(input) => {
format!("Create directory {}", MarkdownInlineCode(&input.path))
}
Err(_) => "Create directory".to_string(),
}
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -72,10 +69,6 @@ impl Tool for CreateDirectoryTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match project.read(cx).find_project_path(&input.path, cx) { let project_path = match project.read(cx).find_project_path(&input.path, cx) {
Some(project_path) => project_path, Some(project_path) => project_path,
None => { None => {

View file

@ -29,11 +29,13 @@ pub struct DeletePathToolInput {
pub struct DeletePathTool; pub struct DeletePathTool;
impl Tool for DeletePathTool { impl Tool for DeletePathTool {
type Input = DeletePathToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"delete_path".into() "delete_path".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -53,16 +55,13 @@ impl Tool for DeletePathTool {
json_schema_for::<DeletePathToolInput>(format) json_schema_for::<DeletePathToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<DeletePathToolInput>(input.clone()) { format!("Delete “`{}`”", input.path)
Ok(input) => format!("Delete “`{}`”", input.path),
Err(_) => "Delete path".to_string(),
}
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
@ -70,10 +69,7 @@ impl Tool for DeletePathTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) { let path_str = input.path;
Ok(input) => input.path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else { let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else {
return Task::ready(Err(anyhow!( return Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project." "Couldn't delete {path_str} because that path isn't in this project."

View file

@ -42,11 +42,13 @@ where
pub struct DiagnosticsTool; pub struct DiagnosticsTool;
impl Tool for DiagnosticsTool { impl Tool for DiagnosticsTool {
type Input = DiagnosticsToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"diagnostics".into() "diagnostics".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -66,15 +68,9 @@ impl Tool for DiagnosticsTool {
json_schema_for::<DiagnosticsToolInput>(format) json_schema_for::<DiagnosticsToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone()) if let Some(path) = input.path.as_ref().filter(|p| !p.is_empty()) {
.ok() format!("Check diagnostics for {}", MarkdownInlineCode(path))
.and_then(|input| match input.path {
Some(path) if !path.is_empty() => Some(path),
_ => None,
})
{
format!("Check diagnostics for {}", MarkdownInlineCode(&path))
} else { } else {
"Check project diagnostics".to_string() "Check project diagnostics".to_string()
} }
@ -82,7 +78,7 @@ impl Tool for DiagnosticsTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
@ -90,10 +86,7 @@ impl Tool for DiagnosticsTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
match serde_json::from_value::<DiagnosticsToolInput>(input) match input.path {
.ok()
.and_then(|input| input.path)
{
Some(path) if !path.is_empty() => { Some(path) if !path.is_empty() => {
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else { let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!("Could not find path {path} in project",))) return Task::ready(Err(anyhow!("Could not find path {path} in project",)))

View file

@ -121,11 +121,13 @@ struct PartialInput {
const DEFAULT_UI_TEXT: &str = "Editing file"; const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for EditFileTool { impl Tool for EditFileTool {
type Input = EditFileToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"edit_file".into() "edit_file".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -145,24 +147,20 @@ impl Tool for EditFileTool {
json_schema_for::<EditFileToolInput>(format) json_schema_for::<EditFileToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<EditFileToolInput>(input.clone()) { input.display_description.clone()
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
}
} }
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() { let description = input.display_description.trim();
let description = input.display_description.trim(); if !description.is_empty() {
if !description.is_empty() { return description.to_string();
return description.to_string(); }
}
let path = input.path.trim(); let path = input.path.to_string_lossy();
if !path.is_empty() { let path = path.trim();
return path.to_string(); if !path.is_empty() {
} return path.to_string();
} }
DEFAULT_UI_TEXT.to_string() DEFAULT_UI_TEXT.to_string()
@ -170,7 +168,7 @@ impl Tool for EditFileTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
request: Arc<LanguageModelRequest>, request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
@ -178,11 +176,6 @@ impl Tool for EditFileTool {
window: Option<AnyWindowHandle>, window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input = match serde_json::from_value::<EditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match resolve_path(&input, project.clone(), cx) { let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path, Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(), Err(err) => return Task::ready(Err(anyhow!(err))).into(),
@ -1169,12 +1162,11 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = serde_json::to_value(EditFileToolInput { let input = EditFileToolInput {
display_description: "Some edit".into(), display_description: "Some edit".into(),
path: "root/nonexistent_file.txt".into(), path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit, mode: EditFileMode::Edit,
}) };
.unwrap();
Arc::new(EditFileTool) Arc::new(EditFileTool)
.run( .run(
input, input,
@ -1288,24 +1280,22 @@ mod tests {
#[test] #[test]
fn still_streaming_ui_text_with_path() { fn still_streaming_ui_text_with_path() {
let input = json!({ let input = EditFileToolInput {
"path": "src/main.rs", path: "src/main.rs".into(),
"display_description": "", display_description: "".into(),
"old_string": "old code", mode: EditFileMode::Edit,
"new_string": "new code" };
});
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs"); assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
} }
#[test] #[test]
fn still_streaming_ui_text_with_description() { fn still_streaming_ui_text_with_description() {
let input = json!({ let input = EditFileToolInput {
"path": "", path: "".into(),
"display_description": "Fix error handling", display_description: "Fix error handling".into(),
"old_string": "old code", mode: EditFileMode::Edit,
"new_string": "new code" };
});
assert_eq!( assert_eq!(
EditFileTool.still_streaming_ui_text(&input), EditFileTool.still_streaming_ui_text(&input),
@ -1315,12 +1305,11 @@ mod tests {
#[test] #[test]
fn still_streaming_ui_text_with_path_and_description() { fn still_streaming_ui_text_with_path_and_description() {
let input = json!({ let input = EditFileToolInput {
"path": "src/main.rs", path: "src/main.rs".into(),
"display_description": "Fix error handling", display_description: "Fix error handling".into(),
"old_string": "old code", mode: EditFileMode::Edit,
"new_string": "new code" };
});
assert_eq!( assert_eq!(
EditFileTool.still_streaming_ui_text(&input), EditFileTool.still_streaming_ui_text(&input),
@ -1330,12 +1319,11 @@ mod tests {
#[test] #[test]
fn still_streaming_ui_text_no_path_or_description() { fn still_streaming_ui_text_no_path_or_description() {
let input = json!({ let input = EditFileToolInput {
"path": "", path: "".into(),
"display_description": "", display_description: "".into(),
"old_string": "old code", mode: EditFileMode::Edit,
"new_string": "new code" };
});
assert_eq!( assert_eq!(
EditFileTool.still_streaming_ui_text(&input), EditFileTool.still_streaming_ui_text(&input),
@ -1345,7 +1333,11 @@ mod tests {
#[test] #[test]
fn still_streaming_ui_text_with_null() { fn still_streaming_ui_text_with_null() {
let input = serde_json::Value::Null; let input = EditFileToolInput {
path: "".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
assert_eq!( assert_eq!(
EditFileTool.still_streaming_ui_text(&input), EditFileTool.still_streaming_ui_text(&input),
@ -1457,12 +1449,11 @@ mod tests {
// Have the model stream unformatted content // Have the model stream unformatted content
let edit_result = { let edit_result = {
let edit_task = cx.update(|cx| { let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput { let input = EditFileToolInput {
display_description: "Create main function".into(), display_description: "Create main function".into(),
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}) };
.unwrap();
Arc::new(EditFileTool) Arc::new(EditFileTool)
.run( .run(
input, input,
@ -1521,12 +1512,11 @@ mod tests {
// Stream unformatted edits again // Stream unformatted edits again
let edit_result = { let edit_result = {
let edit_task = cx.update(|cx| { let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput { let input = EditFileToolInput {
display_description: "Update main function".into(), display_description: "Update main function".into(),
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}) };
.unwrap();
Arc::new(EditFileTool) Arc::new(EditFileTool)
.run( .run(
input, input,
@ -1600,12 +1590,11 @@ mod tests {
// Have the model stream content that contains trailing whitespace // Have the model stream content that contains trailing whitespace
let edit_result = { let edit_result = {
let edit_task = cx.update(|cx| { let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput { let input = EditFileToolInput {
display_description: "Create main function".into(), display_description: "Create main function".into(),
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}) };
.unwrap();
Arc::new(EditFileTool) Arc::new(EditFileTool)
.run( .run(
input, input,
@ -1657,12 +1646,11 @@ mod tests {
// Stream edits again with trailing whitespace // Stream edits again with trailing whitespace
let edit_result = { let edit_result = {
let edit_task = cx.update(|cx| { let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput { let input = EditFileToolInput {
display_description: "Update main function".into(), display_description: "Update main function".into(),
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}) };
.unwrap();
Arc::new(EditFileTool) Arc::new(EditFileTool)
.run( .run(
input, input,

View file

@ -3,10 +3,10 @@ use std::sync::Arc;
use std::{borrow::Cow, cell::RefCell}; use std::{borrow::Cow, cell::RefCell};
use crate::schema::json_schema_for; use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail}; use anyhow::{Context as _, Result, bail};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::AsyncReadExt as _; use futures::AsyncReadExt as _;
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task}; use gpui::{AnyWindowHandle, App, AppContext as _, Entity};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl}; use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
@ -113,11 +113,13 @@ impl FetchTool {
} }
impl Tool for FetchTool { impl Tool for FetchTool {
type Input = FetchToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"fetch".to_string() "fetch".to_string()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -137,16 +139,13 @@ impl Tool for FetchTool {
json_schema_for::<FetchToolInput>(format) json_schema_for::<FetchToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<FetchToolInput>(input.clone()) { format!("Fetch {}", MarkdownEscaped(&input.url))
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
Err(_) => "Fetch URL".to_string(),
}
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -154,11 +153,6 @@ impl Tool for FetchTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input = match serde_json::from_value::<FetchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let text = cx.background_spawn({ let text = cx.background_spawn({
let http_client = self.http_client.clone(); let http_client = self.http_client.clone();
async move { Self::build_message(http_client, &input.url).await } async move { Self::build_message(http_client, &input.url).await }

View file

@ -51,11 +51,13 @@ const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool; pub struct FindPathTool;
impl Tool for FindPathTool { impl Tool for FindPathTool {
type Input = FindPathToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"find_path".into() "find_path".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -75,16 +77,13 @@ impl Tool for FindPathTool {
json_schema_for::<FindPathToolInput>(format) json_schema_for::<FindPathToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<FindPathToolInput>(input.clone()) { format!("Find paths matching \"`{}`\"", input.glob)
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
Err(_) => "Search paths".to_string(),
}
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -92,10 +91,7 @@ impl Tool for FindPathTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) { let (offset, glob) = (input.offset, input.glob);
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let (sender, receiver) = oneshot::channel(); let (sender, receiver) = oneshot::channel();

View file

@ -53,11 +53,13 @@ const RESULTS_PER_PAGE: u32 = 20;
pub struct GrepTool; pub struct GrepTool;
impl Tool for GrepTool { impl Tool for GrepTool {
type Input = GrepToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"grep".into() "grep".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -77,30 +79,25 @@ impl Tool for GrepTool {
json_schema_for::<GrepToolInput>(format) json_schema_for::<GrepToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<GrepToolInput>(input.clone()) { let page = input.page();
Ok(input) => { let regex_str = MarkdownInlineCode(&input.regex);
let page = input.page(); let case_info = if input.case_sensitive {
let regex_str = MarkdownInlineCode(&input.regex); " (case-sensitive)"
let case_info = if input.case_sensitive { } else {
" (case-sensitive)" ""
} else { };
""
};
if page > 1 { if page > 1 {
format!("Get page {page} of search results for regex {regex_str}{case_info}") format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else { } else {
format!("Search files for regex {regex_str}{case_info}") format!("Search files for regex {regex_str}{case_info}")
}
}
Err(_) => "Search with regex".to_string(),
} }
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -111,13 +108,6 @@ impl Tool for GrepTool {
const CONTEXT_LINES: u32 = 2; const CONTEXT_LINES: u32 = 2;
const MAX_ANCESTOR_LINES: u32 = 10; const MAX_ANCESTOR_LINES: u32 = 10;
let input = match serde_json::from_value::<GrepToolInput>(input) {
Ok(input) => input,
Err(error) => {
return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
}
};
let include_matcher = match PathMatcher::new( let include_matcher = match PathMatcher::new(
input input
.include_pattern .include_pattern
@ -348,13 +338,12 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Test with include pattern for Rust files inside the root of the project // Test with include pattern for Rust files inside the root of the project
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "println".to_string(), regex: "println".to_string(),
include_pattern: Some("root/**/*.rs".to_string()), include_pattern: Some("root/**/*.rs".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
assert!(result.contains("main.rs"), "Should find matches in main.rs"); assert!(result.contains("main.rs"), "Should find matches in main.rs");
@ -368,13 +357,12 @@ mod tests {
); );
// Test with include pattern for src directory only // Test with include pattern for src directory only
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "fn".to_string(), regex: "fn".to_string(),
include_pattern: Some("root/**/src/**".to_string()), include_pattern: Some("root/**/src/**".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
assert!( assert!(
@ -391,13 +379,12 @@ mod tests {
); );
// Test with empty include pattern (should default to all files) // Test with empty include pattern (should default to all files)
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "fn".to_string(), regex: "fn".to_string(),
include_pattern: None, include_pattern: None,
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
assert!(result.contains("main.rs"), "Should find matches in main.rs"); assert!(result.contains("main.rs"), "Should find matches in main.rs");
@ -428,13 +415,12 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Test case-insensitive search (default) // Test case-insensitive search (default)
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "uppercase".to_string(), regex: "uppercase".to_string(),
include_pattern: Some("**/*.txt".to_string()), include_pattern: Some("**/*.txt".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
assert!( assert!(
@ -443,13 +429,12 @@ mod tests {
); );
// Test case-sensitive search // Test case-sensitive search
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "uppercase".to_string(), regex: "uppercase".to_string(),
include_pattern: Some("**/*.txt".to_string()), include_pattern: Some("**/*.txt".to_string()),
offset: 0, offset: 0,
case_sensitive: true, case_sensitive: true,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
assert!( assert!(
@ -458,13 +443,12 @@ mod tests {
); );
// Test case-sensitive search // Test case-sensitive search
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "LOWERCASE".to_string(), regex: "LOWERCASE".to_string(),
include_pattern: Some("**/*.txt".to_string()), include_pattern: Some("**/*.txt".to_string()),
offset: 0, offset: 0,
case_sensitive: true, case_sensitive: true,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
@ -474,13 +458,12 @@ mod tests {
); );
// Test case-sensitive search for lowercase pattern // Test case-sensitive search for lowercase pattern
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "lowercase".to_string(), regex: "lowercase".to_string(),
include_pattern: Some("**/*.txt".to_string()), include_pattern: Some("**/*.txt".to_string()),
offset: 0, offset: 0,
case_sensitive: true, case_sensitive: true,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
assert!( assert!(
@ -576,13 +559,12 @@ mod tests {
let project = setup_syntax_test(cx).await; let project = setup_syntax_test(cx).await;
// Test: Line at the top level of the file // Test: Line at the top level of the file
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "This is at the top level".to_string(), regex: "This is at the top level".to_string(),
include_pattern: Some("**/*.rs".to_string()), include_pattern: Some("**/*.rs".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#" let expected = r#"
@ -606,13 +588,12 @@ mod tests {
let project = setup_syntax_test(cx).await; let project = setup_syntax_test(cx).await;
// Test: Line inside a function body // Test: Line inside a function body
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "Function in nested module".to_string(), regex: "Function in nested module".to_string(),
include_pattern: Some("**/*.rs".to_string()), include_pattern: Some("**/*.rs".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#" let expected = r#"
@ -638,13 +619,12 @@ mod tests {
let project = setup_syntax_test(cx).await; let project = setup_syntax_test(cx).await;
// Test: Line with a function argument // Test: Line with a function argument
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "second_arg".to_string(), regex: "second_arg".to_string(),
include_pattern: Some("**/*.rs".to_string()), include_pattern: Some("**/*.rs".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#" let expected = r#"
@ -674,13 +654,12 @@ mod tests {
let project = setup_syntax_test(cx).await; let project = setup_syntax_test(cx).await;
// Test: Line inside an if block // Test: Line inside an if block
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "Inside if block".to_string(), regex: "Inside if block".to_string(),
include_pattern: Some("**/*.rs".to_string()), include_pattern: Some("**/*.rs".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#" let expected = r#"
@ -705,13 +684,12 @@ mod tests {
let project = setup_syntax_test(cx).await; let project = setup_syntax_test(cx).await;
// Test: Line in the middle of a long function - should show message about remaining lines // Test: Line in the middle of a long function - should show message about remaining lines
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "Line 5".to_string(), regex: "Line 5".to_string(),
include_pattern: Some("**/*.rs".to_string()), include_pattern: Some("**/*.rs".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#" let expected = r#"
@ -746,13 +724,12 @@ mod tests {
let project = setup_syntax_test(cx).await; let project = setup_syntax_test(cx).await;
// Test: Line in the long function // Test: Line in the long function
let input = serde_json::to_value(GrepToolInput { let input = GrepToolInput {
regex: "Line 12".to_string(), regex: "Line 12".to_string(),
include_pattern: Some("**/*.rs".to_string()), include_pattern: Some("**/*.rs".to_string()),
offset: 0, offset: 0,
case_sensitive: false, case_sensitive: false,
}) };
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await; let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#" let expected = r#"
@ -774,7 +751,7 @@ mod tests {
} }
async fn run_grep_tool( async fn run_grep_tool(
input: serde_json::Value, input: GrepToolInput,
project: Entity<Project>, project: Entity<Project>,
cx: &mut TestAppContext, cx: &mut TestAppContext,
) -> String { ) -> String {
@ -876,9 +853,12 @@ mod tests {
// Searching for files outside the project worktree should return no results // Searching for files outside the project worktree should return no results
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "outside_function" regex: "outside_function".to_string(),
}); include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -902,9 +882,12 @@ mod tests {
// Searching within the project should succeed // Searching within the project should succeed
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "main" regex: "main".to_string(),
}); include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -928,9 +911,12 @@ mod tests {
// Searching files that match file_scan_exclusions should return no results // Searching files that match file_scan_exclusions should return no results
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "special_configuration" regex: "special_configuration".to_string(),
}); include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -953,9 +939,12 @@ mod tests {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "custom_metadata" regex: "custom_metadata".to_string(),
}); include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -979,9 +968,12 @@ mod tests {
// Searching private files should return no results // Searching private files should return no results
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "SECRET_KEY" regex: "SECRET_KEY".to_string(),
}); include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -1004,9 +996,12 @@ mod tests {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "private_key_content" regex: "private_key_content".to_string(),
}); include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -1029,9 +1024,12 @@ mod tests {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "sensitive_data" regex: "sensitive_data".to_string(),
}); include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -1055,9 +1053,12 @@ mod tests {
// Searching a normal file should still work, even with private_files configured // Searching a normal file should still work, even with private_files configured
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "normal_file_content" regex: "normal_file_content".to_string(),
}); include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -1081,10 +1082,12 @@ mod tests {
// Path traversal attempts with .. in include_pattern should not escape project // Path traversal attempts with .. in include_pattern should not escape project
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "outside_function", regex: "outside_function".to_string(),
"include_pattern": "../outside_project/**/*.rs" include_pattern: Some("../outside_project/**/*.rs".to_string()),
}); offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -1185,10 +1188,12 @@ mod tests {
// Search for "secret" - should exclude files based on worktree-specific settings // Search for "secret" - should exclude files based on worktree-specific settings
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "secret", regex: "secret".to_string(),
"case_sensitive": false include_pattern: None,
}); offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,
@ -1250,10 +1255,12 @@ mod tests {
// Test with `include_pattern` specific to one worktree // Test with `include_pattern` specific to one worktree
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = GrepToolInput {
"regex": "secret", regex: "secret".to_string(),
"include_pattern": "worktree1/**/*.rs" include_pattern: Some("worktree1/**/*.rs".to_string()),
}); offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool) Arc::new(GrepTool)
.run( .run(
input, input,

View file

@ -41,11 +41,13 @@ pub struct ListDirectoryToolInput {
pub struct ListDirectoryTool; pub struct ListDirectoryTool;
impl Tool for ListDirectoryTool { impl Tool for ListDirectoryTool {
type Input = ListDirectoryToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"list_directory".into() "list_directory".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -65,19 +67,14 @@ impl Tool for ListDirectoryTool {
json_schema_for::<ListDirectoryToolInput>(format) json_schema_for::<ListDirectoryToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) { let path = MarkdownInlineCode(&input.path);
Ok(input) => { format!("List the {path} directory's contents")
let path = MarkdownInlineCode(&input.path);
format!("List the {path} directory's contents")
}
Err(_) => "List directory".to_string(),
}
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -85,11 +82,6 @@ impl Tool for ListDirectoryTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// Sometimes models will return these even though we tell it to give a path and not a glob. // Sometimes models will return these even though we tell it to give a path and not a glob.
// When this happens, just list the root worktree directories. // When this happens, just list the root worktree directories.
if matches!(input.path.as_str(), "." | "" | "./" | "*") { if matches!(input.path.as_str(), "." | "" | "./" | "*") {
@ -285,9 +277,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool); let tool = Arc::new(ListDirectoryTool);
// Test listing root directory // Test listing root directory
let input = json!({ let input = ListDirectoryToolInput {
"path": "project" path: "project".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -320,9 +312,9 @@ mod tests {
); );
// Test listing src directory // Test listing src directory
let input = json!({ let input = ListDirectoryToolInput {
"path": "project/src" path: "project/src".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -355,9 +347,9 @@ mod tests {
); );
// Test listing directory with only files // Test listing directory with only files
let input = json!({ let input = ListDirectoryToolInput {
"path": "project/tests" path: "project/tests".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -399,9 +391,9 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let tool = Arc::new(ListDirectoryTool); let tool = Arc::new(ListDirectoryTool);
let input = json!({ let input = ListDirectoryToolInput {
"path": "project/empty_dir" path: "project/empty_dir".to_string(),
}); };
let result = cx let result = cx
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx)) .update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
@ -432,9 +424,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool); let tool = Arc::new(ListDirectoryTool);
// Test non-existent path // Test non-existent path
let input = json!({ let input = ListDirectoryToolInput {
"path": "project/nonexistent" path: "project/nonexistent".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -455,9 +447,9 @@ mod tests {
assert!(result.unwrap_err().to_string().contains("Path not found")); assert!(result.unwrap_err().to_string().contains("Path not found"));
// Test trying to list a file instead of directory // Test trying to list a file instead of directory
let input = json!({ let input = ListDirectoryToolInput {
"path": "project/file.txt" path: "project/file.txt".to_string(),
}); };
let result = cx let result = cx
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx)) .update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
@ -527,9 +519,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool); let tool = Arc::new(ListDirectoryTool);
// Listing root directory should exclude private and excluded files // Listing root directory should exclude private and excluded files
let input = json!({ let input = ListDirectoryToolInput {
"path": "project" path: "project".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -568,9 +560,9 @@ mod tests {
); );
// Trying to list an excluded directory should fail // Trying to list an excluded directory should fail
let input = json!({ let input = ListDirectoryToolInput {
"path": "project/.secretdir" path: "project/.secretdir".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -600,9 +592,9 @@ mod tests {
); );
// Listing a directory should exclude private files within it // Listing a directory should exclude private files within it
let input = json!({ let input = ListDirectoryToolInput {
"path": "project/visible_dir" path: "project/visible_dir".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -720,9 +712,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool); let tool = Arc::new(ListDirectoryTool);
// Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings // Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings
let input = json!({ let input = ListDirectoryToolInput {
"path": "worktree1/src" path: "worktree1/src".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -752,9 +744,9 @@ mod tests {
); );
// Test listing worktree1/tests - should exclude fixture.sql based on local settings // Test listing worktree1/tests - should exclude fixture.sql based on local settings
let input = json!({ let input = ListDirectoryToolInput {
"path": "worktree1/tests" path: "worktree1/tests".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -780,9 +772,9 @@ mod tests {
); );
// Test listing worktree2/lib - should exclude private.js and data.json based on local settings // Test listing worktree2/lib - should exclude private.js and data.json based on local settings
let input = json!({ let input = ListDirectoryToolInput {
"path": "worktree2/lib" path: "worktree2/lib".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -812,9 +804,9 @@ mod tests {
); );
// Test listing worktree2/docs - should exclude internal.md based on local settings // Test listing worktree2/docs - should exclude internal.md based on local settings
let input = json!({ let input = ListDirectoryToolInput {
"path": "worktree2/docs" path: "worktree2/docs".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -840,9 +832,9 @@ mod tests {
); );
// Test trying to list an excluded directory directly // Test trying to list an excluded directory directly
let input = json!({ let input = ListDirectoryToolInput {
"path": "worktree1/src/secret.rs" path: "worktree1/src/secret.rs".to_string(),
}); };
let result = cx let result = cx
.update(|cx| { .update(|cx| {

View file

@ -38,11 +38,13 @@ pub struct MovePathToolInput {
pub struct MovePathTool; pub struct MovePathTool;
impl Tool for MovePathTool { impl Tool for MovePathTool {
type Input = MovePathToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"move_path".into() "move_path".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -62,34 +64,29 @@ impl Tool for MovePathTool {
json_schema_for::<MovePathToolInput>(format) json_schema_for::<MovePathToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<MovePathToolInput>(input.clone()) { let src = MarkdownInlineCode(&input.source_path);
Ok(input) => { let dest = MarkdownInlineCode(&input.destination_path);
let src = MarkdownInlineCode(&input.source_path); let src_path = Path::new(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path); let dest_path = Path::new(&input.destination_path);
let src_path = Path::new(&input.source_path);
let dest_path = Path::new(&input.destination_path);
match dest_path match dest_path
.file_name() .file_name()
.and_then(|os_str| os_str.to_os_string().into_string().ok()) .and_then(|os_str| os_str.to_os_string().into_string().ok())
{ {
Some(filename) if src_path.parent() == dest_path.parent() => { Some(filename) if src_path.parent() == dest_path.parent() => {
let filename = MarkdownInlineCode(&filename); let filename = MarkdownInlineCode(&filename);
format!("Rename {src} to {filename}") format!("Rename {src} to {filename}")
} }
_ => { _ => {
format!("Move {src} to {dest}") format!("Move {src} to {dest}")
}
}
} }
Err(_) => "Move path".to_string(),
} }
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -97,10 +94,6 @@ impl Tool for MovePathTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input = match serde_json::from_value::<MovePathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let rename_task = project.update(cx, |project, cx| { let rename_task = project.update(cx, |project, cx| {
match project match project
.find_project_path(&input.source_path, cx) .find_project_path(&input.source_path, cx)

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use crate::schema::json_schema_for; use crate::schema::json_schema_for;
use anyhow::{Result, anyhow}; use anyhow::Result;
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc}; use chrono::{Local, Utc};
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
@ -29,11 +29,13 @@ pub struct NowToolInput {
pub struct NowTool; pub struct NowTool;
impl Tool for NowTool { impl Tool for NowTool {
type Input = NowToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"now".into() "now".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -53,13 +55,13 @@ impl Tool for NowTool {
json_schema_for::<NowToolInput>(format) json_schema_for::<NowToolInput>(format)
} }
fn ui_text(&self, _input: &serde_json::Value) -> String { fn ui_text(&self, _input: &Self::Input) -> String {
"Get current time".to_string() "Get current time".to_string()
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -67,11 +69,6 @@ impl Tool for NowTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
_cx: &mut App, _cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input: NowToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let now = match input.timezone { let now = match input.timezone {
Timezone::Utc => Utc::now().to_rfc3339(), Timezone::Utc => Utc::now().to_rfc3339(),
Timezone::Local => Local::now().to_rfc3339(), Timezone::Local => Local::now().to_rfc3339(),

View file

@ -1,7 +1,7 @@
use crate::schema::json_schema_for; use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; use gpui::{AnyWindowHandle, App, AppContext, Entity};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
@ -19,11 +19,13 @@ pub struct OpenToolInput {
pub struct OpenTool; pub struct OpenTool;
impl Tool for OpenTool { impl Tool for OpenTool {
type Input = OpenToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"open".to_string() "open".to_string()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
true true
} }
fn may_perform_edits(&self) -> bool { fn may_perform_edits(&self) -> bool {
@ -41,16 +43,13 @@ impl Tool for OpenTool {
json_schema_for::<OpenToolInput>(format) json_schema_for::<OpenToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<OpenToolInput>(input.clone()) { format!("Open `{}`", MarkdownEscaped(&input.path_or_url))
Ok(input) => format!("Open `{}`", MarkdownEscaped(&input.path_or_url)),
Err(_) => "Open file or URL".to_string(),
}
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -58,11 +57,6 @@ impl Tool for OpenTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input: OpenToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// If path_or_url turns out to be a path in the project, make it absolute. // If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, project, cx); let abs_path = to_absolute_path(&input.path_or_url, project, cx);

View file

@ -51,11 +51,13 @@ pub struct ReadFileToolInput {
pub struct ReadFileTool; pub struct ReadFileTool;
impl Tool for ReadFileTool { impl Tool for ReadFileTool {
type Input = ReadFileToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"read_file".into() "read_file".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -75,23 +77,18 @@ impl Tool for ReadFileTool {
json_schema_for::<ReadFileToolInput>(format) json_schema_for::<ReadFileToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<ReadFileToolInput>(input.clone()) { let path = MarkdownInlineCode(&input.path);
Ok(input) => { match (input.start_line, input.end_line) {
let path = MarkdownInlineCode(&input.path); (Some(start), None) => format!("Read file {path} (from line {start})"),
match (input.start_line, input.end_line) { (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
(Some(start), None) => format!("Read file {path} (from line {start})"), _ => format!("Read file {path}"),
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
_ => format!("Read file {path}"),
}
}
Err(_) => "Read file".to_string(),
} }
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
@ -99,11 +96,6 @@ impl Tool for ReadFileTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into(); return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
}; };
@ -308,9 +300,12 @@ mod test {
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "root/nonexistent_file.txt" path: "root/nonexistent_file.txt".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -347,9 +342,11 @@ mod test {
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "root/small_file.txt" path: "root/small_file.txt".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -389,9 +386,11 @@ mod test {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "root/large_file.rs" path: "root/large_file.rs".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -421,10 +420,11 @@ mod test {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "root/large_file.rs", path: "root/large_file.rs".to_string(),
"offset": 1 start_line: None,
}); end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -477,11 +477,11 @@ mod test {
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "root/multiline.txt", path: "root/multiline.txt".to_string(),
"start_line": 2, start_line: Some(2),
"end_line": 4 end_line: Some(4),
}); };
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -520,11 +520,11 @@ mod test {
// start_line of 0 should be treated as 1 // start_line of 0 should be treated as 1
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "root/multiline.txt", path: "root/multiline.txt".to_string(),
"start_line": 0, start_line: Some(0),
"end_line": 2 end_line: Some(2),
}); };
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -543,11 +543,11 @@ mod test {
// end_line of 0 should result in at least 1 line // end_line of 0 should result in at least 1 line
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "root/multiline.txt", path: "root/multiline.txt".to_string(),
"start_line": 1, start_line: Some(1),
"end_line": 0 end_line: Some(0),
}); };
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -566,11 +566,11 @@ mod test {
// when start_line > end_line, should still return at least 1 line // when start_line > end_line, should still return at least 1 line
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "root/multiline.txt", path: "root/multiline.txt".to_string(),
"start_line": 3, start_line: Some(3),
"end_line": 2 end_line: Some(2),
}); };
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -694,9 +694,11 @@ mod test {
// Reading a file outside the project worktree should fail // Reading a file outside the project worktree should fail
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "/outside_project/sensitive_file.txt" path: "/outside_project/sensitive_file.txt".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -718,9 +720,11 @@ mod test {
// Reading a file within the project should succeed // Reading a file within the project should succeed
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "project_root/allowed_file.txt" path: "project_root/allowed_file.txt".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -742,9 +746,11 @@ mod test {
// Reading files that match file_scan_exclusions should fail // Reading files that match file_scan_exclusions should fail
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "project_root/.secretdir/config" path: "project_root/.secretdir/config".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -765,9 +771,11 @@ mod test {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "project_root/.mymetadata" path: "project_root/.mymetadata".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -789,9 +797,11 @@ mod test {
// Reading private files should fail // Reading private files should fail
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "project_root/.mysecrets" path: "project_root/secrets/.mysecrets".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -812,9 +822,11 @@ mod test {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "project_root/subdir/special.privatekey" path: "project_root/subdir/special.privatekey".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -835,9 +847,11 @@ mod test {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "project_root/subdir/data.mysensitive" path: "project_root/subdir/data.mysensitive".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -859,9 +873,11 @@ mod test {
// Reading a normal file should still work, even with private_files configured // Reading a normal file should still work, even with private_files configured
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "project_root/subdir/normal_file.txt" path: "project_root/subdir/normal_file.txt".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -884,9 +900,11 @@ mod test {
// Path traversal attempts with .. should fail // Path traversal attempts with .. should fail
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = json!({ let input = ReadFileToolInput {
"path": "project_root/../outside_project/sensitive_file.txt" path: "project_root/../outside_project/sensitive_file.txt".to_string(),
}); start_line: None,
end_line: None,
};
Arc::new(ReadFileTool) Arc::new(ReadFileTool)
.run( .run(
input, input,
@ -981,9 +999,11 @@ mod test {
let tool = Arc::new(ReadFileTool); let tool = Arc::new(ReadFileTool);
// Test reading allowed files in worktree1 // Test reading allowed files in worktree1
let input = json!({ let input = ReadFileToolInput {
"path": "worktree1/src/main.rs" path: "worktree1/src/main.rs".to_string(),
}); start_line: None,
end_line: None,
};
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -1007,9 +1027,11 @@ mod test {
); );
// Test reading private file in worktree1 should fail // Test reading private file in worktree1 should fail
let input = json!({ let input = ReadFileToolInput {
"path": "worktree1/src/secret.rs" path: "worktree1/src/secret.rs".to_string(),
}); start_line: None,
end_line: None,
};
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -1036,9 +1058,11 @@ mod test {
); );
// Test reading excluded file in worktree1 should fail // Test reading excluded file in worktree1 should fail
let input = json!({ let input = ReadFileToolInput {
"path": "worktree1/tests/fixture.sql" path: "worktree1/tests/fixture.sql".to_string(),
}); start_line: None,
end_line: None,
};
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -1065,9 +1089,11 @@ mod test {
); );
// Test reading allowed files in worktree2 // Test reading allowed files in worktree2
let input = json!({ let input = ReadFileToolInput {
"path": "worktree2/lib/public.js" path: "worktree2/lib/public.js".to_string(),
}); start_line: None,
end_line: None,
};
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -1091,9 +1117,11 @@ mod test {
); );
// Test reading private file in worktree2 should fail // Test reading private file in worktree2 should fail
let input = json!({ let input = ReadFileToolInput {
"path": "worktree2/lib/private.js" path: "worktree2/lib/private.js".to_string(),
}); start_line: None,
end_line: None,
};
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -1120,9 +1148,11 @@ mod test {
); );
// Test reading excluded file in worktree2 should fail // Test reading excluded file in worktree2 should fail
let input = json!({ let input = ReadFileToolInput {
"path": "worktree2/docs/internal.md" path: "worktree2/docs/internal.md".to_string(),
}); start_line: None,
end_line: None,
};
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -1150,9 +1180,11 @@ mod test {
// Test that files allowed in one worktree but not in another are handled correctly // Test that files allowed in one worktree but not in another are handled correctly
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2) // (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
let input = json!({ let input = ReadFileToolInput {
"path": "worktree1/src/config.toml" path: "worktree1/src/config.toml".to_string(),
}); start_line: None,
end_line: None,
};
let result = cx let result = cx
.update(|cx| { .update(|cx| {

View file

@ -2,7 +2,7 @@ use crate::{
schema::json_schema_for, schema::json_schema_for,
ui::{COLLAPSED_LINES, ToolOutputPreview}, ui::{COLLAPSED_LINES, ToolOutputPreview},
}; };
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus}; use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt as _, future::Shared}; use futures::{FutureExt as _, future::Shared};
use gpui::{ use gpui::{
@ -72,11 +72,13 @@ impl TerminalTool {
} }
impl Tool for TerminalTool { impl Tool for TerminalTool {
type Input = TerminalToolInput;
fn name(&self) -> String { fn name(&self) -> String {
Self::NAME.to_string() Self::NAME.to_string()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
true true
} }
@ -96,30 +98,24 @@ impl Tool for TerminalTool {
json_schema_for::<TerminalToolInput>(format) json_schema_for::<TerminalToolInput>(format)
} }
fn ui_text(&self, input: &serde_json::Value) -> String { fn ui_text(&self, input: &Self::Input) -> String {
match serde_json::from_value::<TerminalToolInput>(input.clone()) { let mut lines = input.command.lines();
Ok(input) => { let first_line = lines.next().unwrap_or_default();
let mut lines = input.command.lines(); let remaining_line_count = lines.count();
let first_line = lines.next().unwrap_or_default(); match remaining_line_count {
let remaining_line_count = lines.count(); 0 => MarkdownInlineCode(&first_line).to_string(),
match remaining_line_count { 1 => MarkdownInlineCode(&format!(
0 => MarkdownInlineCode(&first_line).to_string(), "{} - {} more line",
1 => MarkdownInlineCode(&format!( first_line, remaining_line_count
"{} - {} more line", ))
first_line, remaining_line_count .to_string(),
)) n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)).to_string(),
.to_string(),
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n))
.to_string(),
}
}
Err(_) => "Run terminal command".to_string(),
} }
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -127,11 +123,6 @@ impl Tool for TerminalTool {
window: Option<AnyWindowHandle>, window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input: TerminalToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let working_dir = match working_dir(&input, &project, cx) { let working_dir = match working_dir(&input, &project, cx) {
Ok(dir) => dir, Ok(dir) => dir,
Err(err) => return Task::ready(Err(err)).into(), Err(err) => return Task::ready(Err(err)).into(),
@ -756,7 +747,7 @@ mod tests {
let result = cx.update(|cx| { let result = cx.update(|cx| {
TerminalTool::run( TerminalTool::run(
Arc::new(TerminalTool::new(cx)), Arc::new(TerminalTool::new(cx)),
serde_json::to_value(input).unwrap(), input,
Arc::default(), Arc::default(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),
@ -791,7 +782,7 @@ mod tests {
let check = |input, expected, cx: &mut App| { let check = |input, expected, cx: &mut App| {
let headless_result = TerminalTool::run( let headless_result = TerminalTool::run(
Arc::new(TerminalTool::new(cx)), Arc::new(TerminalTool::new(cx)),
serde_json::to_value(input).unwrap(), input,
Arc::default(), Arc::default(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use crate::schema::json_schema_for; use crate::schema::json_schema_for;
use anyhow::{Result, anyhow}; use anyhow::Result;
use assistant_tool::{ActionLog, Tool, ToolResult}; use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task}; use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
@ -20,11 +20,13 @@ pub struct ThinkingToolInput {
pub struct ThinkingTool; pub struct ThinkingTool;
impl Tool for ThinkingTool { impl Tool for ThinkingTool {
type Input = ThinkingToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"thinking".to_string() "thinking".to_string()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -44,13 +46,13 @@ impl Tool for ThinkingTool {
json_schema_for::<ThinkingToolInput>(format) json_schema_for::<ThinkingToolInput>(format)
} }
fn ui_text(&self, _input: &serde_json::Value) -> String { fn ui_text(&self, _input: &Self::Input) -> String {
"Thinking".to_string() "Thinking".to_string()
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, _input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -59,10 +61,6 @@ impl Tool for ThinkingTool {
_cx: &mut App, _cx: &mut App,
) -> ToolResult { ) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions. // This tool just "thinks out loud" and doesn't perform any actions.
Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) { Task::ready(Ok("Finished thinking.".to_string().into())).into()
Ok(_input) => Ok("Finished thinking.".to_string().into()),
Err(err) => Err(anyhow!(err)),
})
.into()
} }
} }

View file

@ -28,11 +28,13 @@ pub struct WebSearchToolInput {
pub struct WebSearchTool; pub struct WebSearchTool;
impl Tool for WebSearchTool { impl Tool for WebSearchTool {
type Input = WebSearchToolInput;
fn name(&self) -> String { fn name(&self) -> String {
"web_search".into() "web_search".into()
} }
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false false
} }
@ -52,13 +54,13 @@ impl Tool for WebSearchTool {
json_schema_for::<WebSearchToolInput>(format) json_schema_for::<WebSearchToolInput>(format)
} }
fn ui_text(&self, _input: &serde_json::Value) -> String { fn ui_text(&self, _input: &Self::Input) -> String {
"Searching the Web".to_string() "Searching the Web".to_string()
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: Self::Input,
_request: Arc<LanguageModelRequest>, _request: Arc<LanguageModelRequest>,
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>, _action_log: Entity<ActionLog>,
@ -66,10 +68,6 @@ impl Tool for WebSearchTool {
_window: Option<AnyWindowHandle>, _window: Option<AnyWindowHandle>,
cx: &mut App, cx: &mut App,
) -> ToolResult { ) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
return Task::ready(Err(anyhow!("Web search is not available."))).into(); return Task::ready(Err(anyhow!("Web search is not available."))).into();
}; };

View file

@ -1736,7 +1736,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu
let exists_result = cx.update(|cx| { let exists_result = cx.update(|cx| {
ReadFileTool::run( ReadFileTool::run(
Arc::new(ReadFileTool), Arc::new(ReadFileTool),
serde_json::to_value(input).unwrap(), input,
request.clone(), request.clone(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),
@ -1756,7 +1756,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu
let does_not_exist_result = cx.update(|cx| { let does_not_exist_result = cx.update(|cx| {
ReadFileTool::run( ReadFileTool::run(
Arc::new(ReadFileTool), Arc::new(ReadFileTool),
serde_json::to_value(input).unwrap(), input,
request.clone(), request.clone(),
project.clone(), project.clone(),
action_log.clone(), action_log.clone(),