Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
Bennet Bo Fenner
dbf3c31a83 agent: Encode tool input with associated type 2025-07-07 00:40:53 +02:00
26 changed files with 728 additions and 571 deletions

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
use assistant_tool::{Tool, ToolSource, ToolWorkingSet, UniqueToolName};
use assistant_tool::{AnyTool, ToolSource, ToolWorkingSet, UniqueToolName};
use collections::IndexMap;
use convert_case::{Case, Casing};
use fs::Fs;
@ -72,7 +72,7 @@ impl AgentProfile {
&self.id
}
pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
return Vec::new();
};
@ -108,7 +108,7 @@ impl AgentProfile {
#[cfg(test)]
mod tests {
use agent_settings::ContextServerPreset;
use assistant_tool::ToolRegistry;
use assistant_tool::{Tool, ToolRegistry};
use collections::IndexMap;
use gpui::SharedString;
use gpui::{AppContext, TestAppContext};
@ -269,8 +269,14 @@ mod tests {
fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
cx.new(|cx| {
let mut tool_set = ToolWorkingSet::default();
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")), cx);
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")), cx);
tool_set.insert(
Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")).into(),
cx,
);
tool_set.insert(
Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")).into(),
cx,
);
tool_set
})
}
@ -290,6 +296,8 @@ mod tests {
}
impl Tool for FakeTool {
type Input = ();
fn name(&self) -> String {
self.name.clone()
}
@ -308,17 +316,17 @@ mod tests {
unimplemented!()
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
unimplemented!()
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
fn ui_text(&self, _input: &Self::Input) -> String {
unimplemented!()
}
fn run(
self: Arc<Self>,
_input: serde_json::Value,
_input: Self::Input,
_request: Arc<language_model::LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<assistant_tool::ActionLog>,

View file

@ -29,6 +29,8 @@ impl ContextServerTool {
}
impl Tool for ContextServerTool {
type Input = serde_json::Value;
fn name(&self) -> String {
self.tool.name.clone()
}
@ -47,7 +49,7 @@ impl Tool for ContextServerTool {
}
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
true
}
@ -69,13 +71,13 @@ impl Tool for ContextServerTool {
})
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
fn ui_text(&self, _input: &Self::Input) -> String {
format!("Run MCP tool `{}`", self.tool.name)
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,

View file

@ -10,7 +10,7 @@ use crate::{
};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use assistant_tool::{ActionLog, AnyTool, AnyToolCard, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use collections::HashMap;
@ -2452,7 +2452,7 @@ impl Thread {
ui_text: impl Into<SharedString>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>,
tool: AnyTool,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
@ -2468,7 +2468,7 @@ impl Thread {
tool_use_id: LanguageModelToolUseId,
request: Arc<LanguageModelRequest>,
input: serde_json::Value,
tool: Arc<dyn Tool>,
tool: AnyTool,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,

View file

@ -6,7 +6,7 @@ use crate::{
};
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet};
use assistant_tool::{ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::ContextServerId;
@ -576,7 +576,8 @@ impl ThreadStore {
context_server_store.clone(),
server.id(),
tool,
)) as Arc<dyn Tool>
))
.into()
}),
cx,
)

View file

@ -4,7 +4,7 @@ use crate::{
};
use anyhow::Result;
use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
AnyTool, AnyToolCard, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
@ -378,7 +378,7 @@ impl ToolUseState {
ui_text: impl Into<Arc<str>>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>,
tool: AnyTool,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
let ui_text = ui_text.into();
@ -533,7 +533,7 @@ pub struct Confirmation {
pub input: serde_json::Value,
pub ui_text: Arc<str>,
pub request: Arc<LanguageModelRequest>,
pub tool: Arc<dyn Tool>,
pub tool: AnyTool,
}
#[derive(Debug, Clone)]

View file

@ -1,5 +1,5 @@
use agent::{Thread, ThreadEvent};
use assistant_tool::{Tool, ToolSource};
use assistant_tool::{AnyTool, ToolSource};
use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
@ -7,7 +7,7 @@ use std::sync::Arc;
use ui::prelude::*;
pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
cache: HashMap<LanguageModelToolSchemaFormat, Vec<AnyTool>>,
thread: Entity<Thread>,
_thread_subscription: Subscription,
}
@ -29,11 +29,7 @@ impl IncompatibleToolsState {
}
}
pub fn incompatible_tools(
&mut self,
model: &Arc<dyn LanguageModel>,
cx: &App,
) -> &[Arc<dyn Tool>] {
pub fn incompatible_tools(&mut self, model: &Arc<dyn LanguageModel>, cx: &App) -> &[AnyTool] {
self.cache
.entry(model.tool_input_format())
.or_insert_with(|| {
@ -50,7 +46,7 @@ impl IncompatibleToolsState {
}
pub struct IncompatibleToolsTooltip {
pub incompatible_tools: Vec<Arc<dyn Tool>>,
pub incompatible_tools: Vec<AnyTool>,
}
impl Render for IncompatibleToolsTooltip {

View file

@ -4,25 +4,19 @@ mod tool_registry;
mod tool_schema;
mod tool_working_set;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;
use std::{fmt, fmt::Debug, fmt::Formatter, ops::Deref, sync::Arc};
use anyhow::Result;
use gpui::AnyElement;
use gpui::AnyWindowHandle;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use gpui::{
AnyElement, AnyWindowHandle, App, Context, Entity, IntoElement, SharedString, Task, WeakEntity,
Window,
};
use icons::IconName;
use language_model::LanguageModel;
use language_model::LanguageModelImage;
use language_model::LanguageModelRequest;
use language_model::LanguageModelToolSchemaFormat;
use language_model::{
LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
};
use project::Project;
use serde::de::DeserializeOwned;
use workspace::Workspace;
pub use crate::action_log::*;
@ -199,7 +193,10 @@ pub enum ToolSource {
}
/// A tool that can be used by a language model.
pub trait Tool: 'static + Send + Sync {
pub trait Tool: Send + Sync + 'static {
/// The input type that is accepted by the tool.
type Input: DeserializeOwned;
/// Returns the name of the tool.
fn name(&self) -> String;
@ -216,7 +213,7 @@ pub trait Tool: 'static + Send + Sync {
/// Returns true if the tool needs the users's confirmation
/// before having permission to run.
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
fn needs_confirmation(&self, input: &Self::Input, cx: &App) -> bool;
/// Returns true if the tool may perform edits.
fn may_perform_edits(&self) -> bool;
@ -227,18 +224,18 @@ pub trait Tool: 'static + Send + Sync {
}
/// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &serde_json::Value) -> String;
fn ui_text(&self, input: &Self::Input) -> String;
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
/// (so information may be missing).
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
self.ui_text(input)
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@ -258,7 +255,199 @@ pub trait Tool: 'static + Send + Sync {
}
}
impl Debug for dyn Tool {
#[derive(Clone)]
pub struct AnyTool {
inner: Arc<dyn ErasedTool>,
}
/// Copy of `Tool` where the Input type is erased.
trait ErasedTool: Send + Sync {
fn name(&self) -> String;
fn description(&self) -> String;
fn icon(&self) -> IconName;
fn source(&self) -> ToolSource;
fn may_perform_edits(&self) -> bool;
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn ui_text(&self, input: &serde_json::Value) -> String;
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String;
fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult;
fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard>;
}
struct ErasedToolWrapper<T: Tool> {
tool: Arc<T>,
}
impl<T: Tool> ErasedTool for ErasedToolWrapper<T> {
fn name(&self) -> String {
self.tool.name()
}
fn description(&self) -> String {
self.tool.description()
}
fn icon(&self) -> IconName {
self.tool.icon()
}
fn source(&self) -> ToolSource {
self.tool.source()
}
fn may_perform_edits(&self) -> bool {
self.tool.may_perform_edits()
}
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.needs_confirmation(&parsed_input, cx),
Err(_) => true,
}
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
self.tool.input_schema(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.ui_text(&parsed_input),
Err(_) => "Invalid input".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.still_streaming_ui_text(&parsed_input),
Err(_) => "Invalid input".to_string(),
}
}
fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match serde_json::from_value::<T::Input>(input) {
Ok(parsed_input) => self.tool.clone().run(
parsed_input,
request,
project,
action_log,
model,
window,
cx,
),
Err(err) => ToolResult::from(Task::ready(Err(err.into()))),
}
}
fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
self.tool
.clone()
.deserialize_card(output, project, window, cx)
}
}
impl<T: Tool> From<Arc<T>> for AnyTool {
fn from(tool: Arc<T>) -> Self {
Self {
inner: Arc::new(ErasedToolWrapper { tool }),
}
}
}
impl AnyTool {
pub fn name(&self) -> String {
self.inner.name()
}
pub fn description(&self) -> String {
self.inner.description()
}
pub fn icon(&self) -> IconName {
self.inner.icon()
}
pub fn source(&self) -> ToolSource {
self.inner.source()
}
pub fn may_perform_edits(&self) -> bool {
self.inner.may_perform_edits()
}
pub fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
self.inner.needs_confirmation(input, cx)
}
pub fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
self.inner.input_schema(format)
}
pub fn ui_text(&self, input: &serde_json::Value) -> String {
self.inner.ui_text(input)
}
pub fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
self.inner.still_streaming_ui_text(input)
}
pub fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
self.inner
.run(input, request, project, action_log, model, window, cx)
}
pub fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
self.inner.deserialize_card(output, project, window, cx)
}
}
impl Debug for AnyTool {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Tool").field("name", &self.name()).finish()
}

View file

@ -6,7 +6,7 @@ use gpui::Global;
use gpui::{App, ReadGlobal};
use parking_lot::RwLock;
use crate::Tool;
use crate::{AnyTool, Tool};
#[derive(Default, Deref, DerefMut)]
struct GlobalToolRegistry(Arc<ToolRegistry>);
@ -15,7 +15,7 @@ impl Global for GlobalToolRegistry {}
#[derive(Default)]
struct ToolRegistryState {
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
tools: HashMap<Arc<str>, AnyTool>,
}
#[derive(Default)]
@ -48,7 +48,7 @@ impl ToolRegistry {
pub fn register_tool(&self, tool: impl Tool) {
let mut state = self.state.write();
let tool_name: Arc<str> = tool.name().into();
state.tools.insert(tool_name, Arc::new(tool));
state.tools.insert(tool_name, Arc::new(tool).into());
}
/// Unregisters the provided [`Tool`].
@ -63,12 +63,12 @@ impl ToolRegistry {
}
/// Returns the list of tools in the registry.
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
pub fn tools(&self) -> Vec<AnyTool> {
self.state.read().tools.values().cloned().collect()
}
/// Returns the [`Tool`] with the given name.
pub fn tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
pub fn tool(&self, name: &str) -> Option<AnyTool> {
self.state.read().tools.get(name).cloned()
}
}

View file

@ -1,6 +1,6 @@
use std::{borrow::Borrow, sync::Arc};
use std::borrow::Borrow;
use crate::{Tool, ToolRegistry, ToolSource};
use crate::{AnyTool, ToolRegistry, ToolSource};
use collections::{HashMap, HashSet, IndexMap};
use gpui::{App, SharedString};
use util::debug_panic;
@ -45,20 +45,20 @@ impl std::fmt::Display for UniqueToolName {
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<UniqueToolName, Arc<dyn Tool>>,
context_server_tools_by_id: HashMap<ToolId, AnyTool>,
context_server_tools_by_name: HashMap<UniqueToolName, AnyTool>,
next_tool_id: ToolId,
}
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
pub fn tool(&self, name: &str, cx: &App) -> Option<AnyTool> {
self.context_server_tools_by_name
.get(name)
.cloned()
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
let mut tools = ToolRegistry::global(cx)
.tools()
.into_iter()
@ -68,7 +68,7 @@ impl ToolWorkingSet {
tools
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<AnyTool>> {
let mut tools_by_source = IndexMap::default();
for (_, tool) in self.tools(cx) {
@ -87,13 +87,13 @@ impl ToolWorkingSet {
tools_by_source
}
pub fn insert(&mut self, tool: Arc<dyn Tool>, cx: &App) -> ToolId {
pub fn insert(&mut self, tool: AnyTool, cx: &App) -> ToolId {
let tool_id = self.register_tool(tool);
self.tools_changed(cx);
tool_id
}
pub fn extend(&mut self, tools: impl Iterator<Item = Arc<dyn Tool>>, cx: &App) -> Vec<ToolId> {
pub fn extend(&mut self, tools: impl Iterator<Item = AnyTool>, cx: &App) -> Vec<ToolId> {
let ids = tools.map(|tool| self.register_tool(tool)).collect();
self.tools_changed(cx);
ids
@ -105,7 +105,7 @@ impl ToolWorkingSet {
self.tools_changed(cx);
}
fn register_tool(&mut self, tool: Arc<dyn Tool>) -> ToolId {
fn register_tool(&mut self, tool: AnyTool) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
@ -126,10 +126,10 @@ impl ToolWorkingSet {
}
fn resolve_context_server_tool_name_conflicts(
context_server_tools: &[Arc<dyn Tool>],
native_tools: &[Arc<dyn Tool>],
) -> HashMap<UniqueToolName, Arc<dyn Tool>> {
fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
context_server_tools: &[AnyTool],
native_tools: &[AnyTool],
) -> HashMap<UniqueToolName, AnyTool> {
fn resolve_tool_name(tool: &AnyTool) -> String {
let mut tool_name = tool.name();
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
tool_name
@ -201,11 +201,13 @@ fn resolve_context_server_tool_name_conflicts(
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
use language_model::{LanguageModel, LanguageModelRequest};
use project::Project;
use crate::{ActionLog, ToolResult};
use crate::{ActionLog, Tool, ToolResult};
use super::*;
@ -234,11 +236,13 @@ mod tests {
Arc::new(TestTool::new(
"tool2",
ToolSource::ContextServer { id: "mcp-1".into() },
)) as Arc<dyn Tool>,
))
.into(),
Arc::new(TestTool::new(
"tool2",
ToolSource::ContextServer { id: "mcp-2".into() },
)) as Arc<dyn Tool>,
))
.into(),
]
.into_iter(),
cx,
@ -324,13 +328,13 @@ mod tests {
context_server_tools: Vec<TestTool>,
expected: Vec<&'static str>,
) {
let context_server_tools: Vec<Arc<dyn Tool>> = context_server_tools
let context_server_tools: Vec<AnyTool> = context_server_tools
.into_iter()
.map(|t| Arc::new(t) as Arc<dyn Tool>)
.map(|t| Arc::new(t).into())
.collect();
let builtin_tools: Vec<Arc<dyn Tool>> = builtin_tools
let builtin_tools: Vec<AnyTool> = builtin_tools
.into_iter()
.map(|t| Arc::new(t) as Arc<dyn Tool>)
.map(|t| Arc::new(t).into())
.collect();
let tools =
resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);
@ -363,6 +367,8 @@ mod tests {
}
impl Tool for TestTool {
type Input = ();
fn name(&self) -> String {
self.name.clone()
}
@ -375,7 +381,7 @@ mod tests {
false
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
true
}
@ -387,13 +393,13 @@ mod tests {
"Test tool".to_string()
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
fn ui_text(&self, _input: &Self::Input) -> String {
"Test tool".to_string()
}
fn run(
self: Arc<Self>,
_input: serde_json::Value,
_input: Self::Input,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,

View file

@ -40,11 +40,13 @@ pub struct CopyPathToolInput {
pub struct CopyPathTool;
impl Tool for CopyPathTool {
type Input = CopyPathToolInput;
fn name(&self) -> String {
"copy_path".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -64,20 +66,15 @@ impl Tool for CopyPathTool {
json_schema_for::<CopyPathToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CopyPathToolInput>(input.clone()) {
Ok(input) => {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
format!("Copy {src} to {dest}")
}
Err(_) => "Copy path".to_string(),
}
fn ui_text(&self, input: &Self::Input) -> String {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
format!("Copy {src} to {dest}")
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -85,10 +82,6 @@ impl Tool for CopyPathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let copy_task = project.update(cx, |project, cx| {
match project
.find_project_path(&input.source_path, cx)

View file

@ -29,6 +29,8 @@ pub struct CreateDirectoryToolInput {
pub struct CreateDirectoryTool;
impl Tool for CreateDirectoryTool {
type Input = CreateDirectoryToolInput;
fn name(&self) -> String {
"create_directory".into()
}
@ -37,7 +39,7 @@ impl Tool for CreateDirectoryTool {
include_str!("./create_directory_tool/description.md").into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -53,18 +55,13 @@ impl Tool for CreateDirectoryTool {
json_schema_for::<CreateDirectoryToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CreateDirectoryToolInput>(input.clone()) {
Ok(input) => {
format!("Create directory {}", MarkdownInlineCode(&input.path))
}
Err(_) => "Create directory".to_string(),
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Create directory {}", MarkdownInlineCode(&input.path))
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -72,10 +69,6 @@ impl Tool for CreateDirectoryTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match project.read(cx).find_project_path(&input.path, cx) {
Some(project_path) => project_path,
None => {

View file

@ -29,11 +29,13 @@ pub struct DeletePathToolInput {
pub struct DeletePathTool;
impl Tool for DeletePathTool {
type Input = DeletePathToolInput;
fn name(&self) -> String {
"delete_path".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -53,16 +55,13 @@ impl Tool for DeletePathTool {
json_schema_for::<DeletePathToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<DeletePathToolInput>(input.clone()) {
Ok(input) => format!("Delete “`{}`”", input.path),
Err(_) => "Delete path".to_string(),
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Delete “`{}`”", input.path)
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@ -70,10 +69,7 @@ impl Tool for DeletePathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let path_str = input.path;
let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else {
return Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project."

View file

@ -42,11 +42,13 @@ where
pub struct DiagnosticsTool;
impl Tool for DiagnosticsTool {
type Input = DiagnosticsToolInput;
fn name(&self) -> String {
"diagnostics".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -66,15 +68,9 @@ impl Tool for DiagnosticsTool {
json_schema_for::<DiagnosticsToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone())
.ok()
.and_then(|input| match input.path {
Some(path) if !path.is_empty() => Some(path),
_ => None,
})
{
format!("Check diagnostics for {}", MarkdownInlineCode(&path))
fn ui_text(&self, input: &Self::Input) -> String {
if let Some(path) = input.path.as_ref().filter(|p| !p.is_empty()) {
format!("Check diagnostics for {}", MarkdownInlineCode(path))
} else {
"Check project diagnostics".to_string()
}
@ -82,7 +78,7 @@ impl Tool for DiagnosticsTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@ -90,10 +86,7 @@ impl Tool for DiagnosticsTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match serde_json::from_value::<DiagnosticsToolInput>(input)
.ok()
.and_then(|input| input.path)
{
match input.path {
Some(path) if !path.is_empty() => {
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!("Could not find path {path} in project",)))

View file

@ -121,11 +121,13 @@ struct PartialInput {
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for EditFileTool {
type Input = EditFileToolInput;
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -145,24 +147,20 @@ impl Tool for EditFileTool {
json_schema_for::<EditFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
}
fn ui_text(&self, input: &Self::Input) -> String {
input.display_description.clone()
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
let path = input.path.to_string_lossy();
let path = path.trim();
if !path.is_empty() {
return path.to_string();
}
DEFAULT_UI_TEXT.to_string()
@ -170,7 +168,7 @@ impl Tool for EditFileTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@ -178,11 +176,6 @@ impl Tool for EditFileTool {
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<EditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
@ -1169,12 +1162,11 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
let input = EditFileToolInput {
display_description: "Some edit".into(),
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
})
.unwrap();
};
Arc::new(EditFileTool)
.run(
input,
@ -1288,24 +1280,22 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_path() {
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
let input = EditFileToolInput {
path: "src/main.rs".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
let input = EditFileToolInput {
path: "".into(),
display_description: "Fix error handling".into(),
mode: EditFileMode::Edit,
};
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@ -1315,12 +1305,11 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
let input = EditFileToolInput {
path: "src/main.rs".into(),
display_description: "Fix error handling".into(),
mode: EditFileMode::Edit,
};
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@ -1330,12 +1319,11 @@ mod tests {
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
let input = EditFileToolInput {
path: "".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@ -1345,7 +1333,11 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_null() {
let input = serde_json::Value::Null;
let input = EditFileToolInput {
path: "".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@ -1457,12 +1449,11 @@ mod tests {
// Have the model stream unformatted content
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
let input = EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
};
Arc::new(EditFileTool)
.run(
input,
@ -1521,12 +1512,11 @@ mod tests {
// Stream unformatted edits again
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
let input = EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
};
Arc::new(EditFileTool)
.run(
input,
@ -1600,12 +1590,11 @@ mod tests {
// Have the model stream content that contains trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
let input = EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
};
Arc::new(EditFileTool)
.run(
input,
@ -1657,12 +1646,11 @@ mod tests {
// Stream edits again with trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
let input = EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
};
Arc::new(EditFileTool)
.run(
input,

View file

@ -3,10 +3,10 @@ use std::sync::Arc;
use std::{borrow::Cow, cell::RefCell};
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail};
use anyhow::{Context as _, Result, bail};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::AsyncReadExt as _;
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext as _, Entity};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
@ -113,11 +113,13 @@ impl FetchTool {
}
impl Tool for FetchTool {
type Input = FetchToolInput;
fn name(&self) -> String {
"fetch".to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -137,16 +139,13 @@ impl Tool for FetchTool {
json_schema_for::<FetchToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FetchToolInput>(input.clone()) {
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
Err(_) => "Fetch URL".to_string(),
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Fetch {}", MarkdownEscaped(&input.url))
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -154,11 +153,6 @@ impl Tool for FetchTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FetchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let text = cx.background_spawn({
let http_client = self.http_client.clone();
async move { Self::build_message(http_client, &input.url).await }

View file

@ -51,11 +51,13 @@ const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool;
impl Tool for FindPathTool {
type Input = FindPathToolInput;
fn name(&self) -> String {
"find_path".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -75,16 +77,13 @@ impl Tool for FindPathTool {
json_schema_for::<FindPathToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FindPathToolInput>(input.clone()) {
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
Err(_) => "Search paths".to_string(),
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Find paths matching \"`{}`\"", input.glob)
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -92,10 +91,7 @@ impl Tool for FindPathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let (offset, glob) = (input.offset, input.glob);
let (sender, receiver) = oneshot::channel();

View file

@ -53,11 +53,13 @@ const RESULTS_PER_PAGE: u32 = 20;
pub struct GrepTool;
impl Tool for GrepTool {
type Input = GrepToolInput;
fn name(&self) -> String {
"grep".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -77,30 +79,25 @@ impl Tool for GrepTool {
json_schema_for::<GrepToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<GrepToolInput>(input.clone()) {
Ok(input) => {
let page = input.page();
let regex_str = MarkdownInlineCode(&input.regex);
let case_info = if input.case_sensitive {
" (case-sensitive)"
} else {
""
};
fn ui_text(&self, input: &Self::Input) -> String {
let page = input.page();
let regex_str = MarkdownInlineCode(&input.regex);
let case_info = if input.case_sensitive {
" (case-sensitive)"
} else {
""
};
if page > 1 {
format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else {
format!("Search files for regex {regex_str}{case_info}")
}
}
Err(_) => "Search with regex".to_string(),
if page > 1 {
format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else {
format!("Search files for regex {regex_str}{case_info}")
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -111,13 +108,6 @@ impl Tool for GrepTool {
const CONTEXT_LINES: u32 = 2;
const MAX_ANCESTOR_LINES: u32 = 10;
let input = match serde_json::from_value::<GrepToolInput>(input) {
Ok(input) => input,
Err(error) => {
return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
}
};
let include_matcher = match PathMatcher::new(
input
.include_pattern
@ -348,13 +338,12 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Test with include pattern for Rust files inside the root of the project
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "println".to_string(),
include_pattern: Some("root/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(result.contains("main.rs"), "Should find matches in main.rs");
@ -368,13 +357,12 @@ mod tests {
);
// Test with include pattern for src directory only
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "fn".to_string(),
include_pattern: Some("root/**/src/**".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@ -391,13 +379,12 @@ mod tests {
);
// Test with empty include pattern (should default to all files)
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "fn".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(result.contains("main.rs"), "Should find matches in main.rs");
@ -428,13 +415,12 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Test case-insensitive search (default)
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "uppercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@ -443,13 +429,12 @@ mod tests {
);
// Test case-sensitive search
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "uppercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@ -458,13 +443,12 @@ mod tests {
);
// Test case-sensitive search
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "LOWERCASE".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
@ -474,13 +458,12 @@ mod tests {
);
// Test case-sensitive search for lowercase pattern
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "lowercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@ -576,13 +559,12 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line at the top level of the file
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "This is at the top level".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@ -606,13 +588,12 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line inside a function body
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "Function in nested module".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@ -638,13 +619,12 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line with a function argument
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "second_arg".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@ -674,13 +654,12 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line inside an if block
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "Inside if block".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@ -705,13 +684,12 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line in the middle of a long function - should show message about remaining lines
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "Line 5".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@ -746,13 +724,12 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line in the long function
let input = serde_json::to_value(GrepToolInput {
let input = GrepToolInput {
regex: "Line 12".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
};
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@ -774,7 +751,7 @@ mod tests {
}
async fn run_grep_tool(
input: serde_json::Value,
input: GrepToolInput,
project: Entity<Project>,
cx: &mut TestAppContext,
) -> String {
@ -876,9 +853,12 @@ mod tests {
// Searching for files outside the project worktree should return no results
let result = cx
.update(|cx| {
let input = json!({
"regex": "outside_function"
});
let input = GrepToolInput {
regex: "outside_function".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -902,9 +882,12 @@ mod tests {
// Searching within the project should succeed
let result = cx
.update(|cx| {
let input = json!({
"regex": "main"
});
let input = GrepToolInput {
regex: "main".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -928,9 +911,12 @@ mod tests {
// Searching files that match file_scan_exclusions should return no results
let result = cx
.update(|cx| {
let input = json!({
"regex": "special_configuration"
});
let input = GrepToolInput {
regex: "special_configuration".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -953,9 +939,12 @@ mod tests {
let result = cx
.update(|cx| {
let input = json!({
"regex": "custom_metadata"
});
let input = GrepToolInput {
regex: "custom_metadata".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -979,9 +968,12 @@ mod tests {
// Searching private files should return no results
let result = cx
.update(|cx| {
let input = json!({
"regex": "SECRET_KEY"
});
let input = GrepToolInput {
regex: "SECRET_KEY".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -1004,9 +996,12 @@ mod tests {
let result = cx
.update(|cx| {
let input = json!({
"regex": "private_key_content"
});
let input = GrepToolInput {
regex: "private_key_content".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -1029,9 +1024,12 @@ mod tests {
let result = cx
.update(|cx| {
let input = json!({
"regex": "sensitive_data"
});
let input = GrepToolInput {
regex: "sensitive_data".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -1055,9 +1053,12 @@ mod tests {
// Searching a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = json!({
"regex": "normal_file_content"
});
let input = GrepToolInput {
regex: "normal_file_content".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -1081,10 +1082,12 @@ mod tests {
// Path traversal attempts with .. in include_pattern should not escape project
let result = cx
.update(|cx| {
let input = json!({
"regex": "outside_function",
"include_pattern": "../outside_project/**/*.rs"
});
let input = GrepToolInput {
regex: "outside_function".to_string(),
include_pattern: Some("../outside_project/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -1185,10 +1188,12 @@ mod tests {
// Search for "secret" - should exclude files based on worktree-specific settings
let result = cx
.update(|cx| {
let input = json!({
"regex": "secret",
"case_sensitive": false
});
let input = GrepToolInput {
regex: "secret".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,
@ -1250,10 +1255,12 @@ mod tests {
// Test with `include_pattern` specific to one worktree
let result = cx
.update(|cx| {
let input = json!({
"regex": "secret",
"include_pattern": "worktree1/**/*.rs"
});
let input = GrepToolInput {
regex: "secret".to_string(),
include_pattern: Some("worktree1/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
Arc::new(GrepTool)
.run(
input,

View file

@ -41,11 +41,13 @@ pub struct ListDirectoryToolInput {
pub struct ListDirectoryTool;
impl Tool for ListDirectoryTool {
type Input = ListDirectoryToolInput;
fn name(&self) -> String {
"list_directory".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -65,19 +67,14 @@ impl Tool for ListDirectoryTool {
json_schema_for::<ListDirectoryToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownInlineCode(&input.path);
format!("List the {path} directory's contents")
}
Err(_) => "List directory".to_string(),
}
fn ui_text(&self, input: &Self::Input) -> String {
let path = MarkdownInlineCode(&input.path);
format!("List the {path} directory's contents")
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -85,11 +82,6 @@ impl Tool for ListDirectoryTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
// When this happens, just list the root worktree directories.
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
@ -285,9 +277,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test listing root directory
let input = json!({
"path": "project"
});
let input = ListDirectoryToolInput {
path: "project".to_string(),
};
let result = cx
.update(|cx| {
@ -320,9 +312,9 @@ mod tests {
);
// Test listing src directory
let input = json!({
"path": "project/src"
});
let input = ListDirectoryToolInput {
path: "project/src".to_string(),
};
let result = cx
.update(|cx| {
@ -355,9 +347,9 @@ mod tests {
);
// Test listing directory with only files
let input = json!({
"path": "project/tests"
});
let input = ListDirectoryToolInput {
path: "project/tests".to_string(),
};
let result = cx
.update(|cx| {
@ -399,9 +391,9 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let tool = Arc::new(ListDirectoryTool);
let input = json!({
"path": "project/empty_dir"
});
let input = ListDirectoryToolInput {
path: "project/empty_dir".to_string(),
};
let result = cx
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
@ -432,9 +424,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test non-existent path
let input = json!({
"path": "project/nonexistent"
});
let input = ListDirectoryToolInput {
path: "project/nonexistent".to_string(),
};
let result = cx
.update(|cx| {
@ -455,9 +447,9 @@ mod tests {
assert!(result.unwrap_err().to_string().contains("Path not found"));
// Test trying to list a file instead of directory
let input = json!({
"path": "project/file.txt"
});
let input = ListDirectoryToolInput {
path: "project/file.txt".to_string(),
};
let result = cx
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
@ -527,9 +519,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Listing root directory should exclude private and excluded files
let input = json!({
"path": "project"
});
let input = ListDirectoryToolInput {
path: "project".to_string(),
};
let result = cx
.update(|cx| {
@ -568,9 +560,9 @@ mod tests {
);
// Trying to list an excluded directory should fail
let input = json!({
"path": "project/.secretdir"
});
let input = ListDirectoryToolInput {
path: "project/.secretdir".to_string(),
};
let result = cx
.update(|cx| {
@ -600,9 +592,9 @@ mod tests {
);
// Listing a directory should exclude private files within it
let input = json!({
"path": "project/visible_dir"
});
let input = ListDirectoryToolInput {
path: "project/visible_dir".to_string(),
};
let result = cx
.update(|cx| {
@ -720,9 +712,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings
let input = json!({
"path": "worktree1/src"
});
let input = ListDirectoryToolInput {
path: "worktree1/src".to_string(),
};
let result = cx
.update(|cx| {
@ -752,9 +744,9 @@ mod tests {
);
// Test listing worktree1/tests - should exclude fixture.sql based on local settings
let input = json!({
"path": "worktree1/tests"
});
let input = ListDirectoryToolInput {
path: "worktree1/tests".to_string(),
};
let result = cx
.update(|cx| {
@ -780,9 +772,9 @@ mod tests {
);
// Test listing worktree2/lib - should exclude private.js and data.json based on local settings
let input = json!({
"path": "worktree2/lib"
});
let input = ListDirectoryToolInput {
path: "worktree2/lib".to_string(),
};
let result = cx
.update(|cx| {
@ -812,9 +804,9 @@ mod tests {
);
// Test listing worktree2/docs - should exclude internal.md based on local settings
let input = json!({
"path": "worktree2/docs"
});
let input = ListDirectoryToolInput {
path: "worktree2/docs".to_string(),
};
let result = cx
.update(|cx| {
@ -840,9 +832,9 @@ mod tests {
);
// Test trying to list an excluded directory directly
let input = json!({
"path": "worktree1/src/secret.rs"
});
let input = ListDirectoryToolInput {
path: "worktree1/src/secret.rs".to_string(),
};
let result = cx
.update(|cx| {

View file

@ -38,11 +38,13 @@ pub struct MovePathToolInput {
pub struct MovePathTool;
impl Tool for MovePathTool {
type Input = MovePathToolInput;
fn name(&self) -> String {
"move_path".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -62,34 +64,29 @@ impl Tool for MovePathTool {
json_schema_for::<MovePathToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<MovePathToolInput>(input.clone()) {
Ok(input) => {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
let src_path = Path::new(&input.source_path);
let dest_path = Path::new(&input.destination_path);
fn ui_text(&self, input: &Self::Input) -> String {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
let src_path = Path::new(&input.source_path);
let dest_path = Path::new(&input.destination_path);
match dest_path
.file_name()
.and_then(|os_str| os_str.to_os_string().into_string().ok())
{
Some(filename) if src_path.parent() == dest_path.parent() => {
let filename = MarkdownInlineCode(&filename);
format!("Rename {src} to {filename}")
}
_ => {
format!("Move {src} to {dest}")
}
}
match dest_path
.file_name()
.and_then(|os_str| os_str.to_os_string().into_string().ok())
{
Some(filename) if src_path.parent() == dest_path.parent() => {
let filename = MarkdownInlineCode(&filename);
format!("Rename {src} to {filename}")
}
_ => {
format!("Move {src} to {dest}")
}
Err(_) => "Move path".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -97,10 +94,6 @@ impl Tool for MovePathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<MovePathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let rename_task = project.update(cx, |project, cx| {
match project
.find_project_path(&input.source_path, cx)

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use anyhow::Result;
use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc};
use gpui::{AnyWindowHandle, App, Entity, Task};
@ -29,11 +29,13 @@ pub struct NowToolInput {
pub struct NowTool;
impl Tool for NowTool {
type Input = NowToolInput;
fn name(&self) -> String {
"now".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -53,13 +55,13 @@ impl Tool for NowTool {
json_schema_for::<NowToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
fn ui_text(&self, _input: &Self::Input) -> String {
"Get current time".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -67,11 +69,6 @@ impl Tool for NowTool {
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
let input: NowToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let now = match input.timezone {
Timezone::Utc => Utc::now().to_rfc3339(),
Timezone::Local => Local::now().to_rfc3339(),

View file

@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@ -19,11 +19,13 @@ pub struct OpenToolInput {
pub struct OpenTool;
impl Tool for OpenTool {
type Input = OpenToolInput;
fn name(&self) -> String {
"open".to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
true
}
fn may_perform_edits(&self) -> bool {
@ -41,16 +43,13 @@ impl Tool for OpenTool {
json_schema_for::<OpenToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<OpenToolInput>(input.clone()) {
Ok(input) => format!("Open `{}`", MarkdownEscaped(&input.path_or_url)),
Err(_) => "Open file or URL".to_string(),
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Open `{}`", MarkdownEscaped(&input.path_or_url))
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -58,11 +57,6 @@ impl Tool for OpenTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: OpenToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, project, cx);

View file

@ -51,11 +51,13 @@ pub struct ReadFileToolInput {
pub struct ReadFileTool;
impl Tool for ReadFileTool {
type Input = ReadFileToolInput;
fn name(&self) -> String {
"read_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -75,23 +77,18 @@ impl Tool for ReadFileTool {
json_schema_for::<ReadFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownInlineCode(&input.path);
match (input.start_line, input.end_line) {
(Some(start), None) => format!("Read file {path} (from line {start})"),
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
_ => format!("Read file {path}"),
}
}
Err(_) => "Read file".to_string(),
fn ui_text(&self, input: &Self::Input) -> String {
let path = MarkdownInlineCode(&input.path);
match (input.start_line, input.end_line) {
(Some(start), None) => format!("Read file {path} (from line {start})"),
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
_ => format!("Read file {path}"),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@ -99,11 +96,6 @@ impl Tool for ReadFileTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
};
@ -308,9 +300,12 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = json!({
"path": "root/nonexistent_file.txt"
});
let input = ReadFileToolInput {
path: "root/nonexistent_file.txt".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -347,9 +342,11 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = json!({
"path": "root/small_file.txt"
});
let input = ReadFileToolInput {
path: "root/small_file.txt".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -389,9 +386,11 @@ mod test {
let result = cx
.update(|cx| {
let input = json!({
"path": "root/large_file.rs"
});
let input = ReadFileToolInput {
path: "root/large_file.rs".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -421,10 +420,11 @@ mod test {
let result = cx
.update(|cx| {
let input = json!({
"path": "root/large_file.rs",
"offset": 1
});
let input = ReadFileToolInput {
path: "root/large_file.rs".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -477,11 +477,11 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 2,
"end_line": 4
});
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(2),
end_line: Some(4),
};
Arc::new(ReadFileTool)
.run(
input,
@ -520,11 +520,11 @@ mod test {
// start_line of 0 should be treated as 1
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 0,
"end_line": 2
});
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(0),
end_line: Some(2),
};
Arc::new(ReadFileTool)
.run(
input,
@ -543,11 +543,11 @@ mod test {
// end_line of 0 should result in at least 1 line
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 1,
"end_line": 0
});
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(1),
end_line: Some(0),
};
Arc::new(ReadFileTool)
.run(
input,
@ -566,11 +566,11 @@ mod test {
// when start_line > end_line, should still return at least 1 line
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 3,
"end_line": 2
});
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(3),
end_line: Some(2),
};
Arc::new(ReadFileTool)
.run(
input,
@ -694,9 +694,11 @@ mod test {
// Reading a file outside the project worktree should fail
let result = cx
.update(|cx| {
let input = json!({
"path": "/outside_project/sensitive_file.txt"
});
let input = ReadFileToolInput {
path: "/outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -718,9 +720,11 @@ mod test {
// Reading a file within the project should succeed
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/allowed_file.txt"
});
let input = ReadFileToolInput {
path: "project_root/allowed_file.txt".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -742,9 +746,11 @@ mod test {
// Reading files that match file_scan_exclusions should fail
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/.secretdir/config"
});
let input = ReadFileToolInput {
path: "project_root/.secretdir/config".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -765,9 +771,11 @@ mod test {
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/.mymetadata"
});
let input = ReadFileToolInput {
path: "project_root/.mymetadata".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -789,9 +797,11 @@ mod test {
// Reading private files should fail
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/.mysecrets"
});
let input = ReadFileToolInput {
path: "project_root/secrets/.mysecrets".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -812,9 +822,11 @@ mod test {
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/subdir/special.privatekey"
});
let input = ReadFileToolInput {
path: "project_root/subdir/special.privatekey".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -835,9 +847,11 @@ mod test {
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/subdir/data.mysensitive"
});
let input = ReadFileToolInput {
path: "project_root/subdir/data.mysensitive".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -859,9 +873,11 @@ mod test {
// Reading a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/subdir/normal_file.txt"
});
let input = ReadFileToolInput {
path: "project_root/subdir/normal_file.txt".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -884,9 +900,11 @@ mod test {
// Path traversal attempts with .. should fail
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/../outside_project/sensitive_file.txt"
});
let input = ReadFileToolInput {
path: "project_root/../outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
Arc::new(ReadFileTool)
.run(
input,
@ -981,9 +999,11 @@ mod test {
let tool = Arc::new(ReadFileTool);
// Test reading allowed files in worktree1
let input = json!({
"path": "worktree1/src/main.rs"
});
let input = ReadFileToolInput {
path: "worktree1/src/main.rs".to_string(),
start_line: None,
end_line: None,
};
let result = cx
.update(|cx| {
@ -1007,9 +1027,11 @@ mod test {
);
// Test reading private file in worktree1 should fail
let input = json!({
"path": "worktree1/src/secret.rs"
});
let input = ReadFileToolInput {
path: "worktree1/src/secret.rs".to_string(),
start_line: None,
end_line: None,
};
let result = cx
.update(|cx| {
@ -1036,9 +1058,11 @@ mod test {
);
// Test reading excluded file in worktree1 should fail
let input = json!({
"path": "worktree1/tests/fixture.sql"
});
let input = ReadFileToolInput {
path: "worktree1/tests/fixture.sql".to_string(),
start_line: None,
end_line: None,
};
let result = cx
.update(|cx| {
@ -1065,9 +1089,11 @@ mod test {
);
// Test reading allowed files in worktree2
let input = json!({
"path": "worktree2/lib/public.js"
});
let input = ReadFileToolInput {
path: "worktree2/lib/public.js".to_string(),
start_line: None,
end_line: None,
};
let result = cx
.update(|cx| {
@ -1091,9 +1117,11 @@ mod test {
);
// Test reading private file in worktree2 should fail
let input = json!({
"path": "worktree2/lib/private.js"
});
let input = ReadFileToolInput {
path: "worktree2/lib/private.js".to_string(),
start_line: None,
end_line: None,
};
let result = cx
.update(|cx| {
@ -1120,9 +1148,11 @@ mod test {
);
// Test reading excluded file in worktree2 should fail
let input = json!({
"path": "worktree2/docs/internal.md"
});
let input = ReadFileToolInput {
path: "worktree2/docs/internal.md".to_string(),
start_line: None,
end_line: None,
};
let result = cx
.update(|cx| {
@ -1150,9 +1180,11 @@ mod test {
// Test that files allowed in one worktree but not in another are handled correctly
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
let input = json!({
"path": "worktree1/src/config.toml"
});
let input = ReadFileToolInput {
path: "worktree1/src/config.toml".to_string(),
start_line: None,
end_line: None,
};
let result = cx
.update(|cx| {

View file

@ -2,7 +2,7 @@ use crate::{
schema::json_schema_for,
ui::{COLLAPSED_LINES, ToolOutputPreview},
};
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt as _, future::Shared};
use gpui::{
@ -72,11 +72,13 @@ impl TerminalTool {
}
impl Tool for TerminalTool {
type Input = TerminalToolInput;
fn name(&self) -> String {
Self::NAME.to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
true
}
@ -96,30 +98,24 @@ impl Tool for TerminalTool {
json_schema_for::<TerminalToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<TerminalToolInput>(input.clone()) {
Ok(input) => {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownInlineCode(&first_line).to_string(),
1 => MarkdownInlineCode(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.to_string(),
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n))
.to_string(),
}
}
Err(_) => "Run terminal command".to_string(),
fn ui_text(&self, input: &Self::Input) -> String {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownInlineCode(&first_line).to_string(),
1 => MarkdownInlineCode(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.to_string(),
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)).to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -127,11 +123,6 @@ impl Tool for TerminalTool {
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: TerminalToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let working_dir = match working_dir(&input, &project, cx) {
Ok(dir) => dir,
Err(err) => return Task::ready(Err(err)).into(),
@ -756,7 +747,7 @@ mod tests {
let result = cx.update(|cx| {
TerminalTool::run(
Arc::new(TerminalTool::new(cx)),
serde_json::to_value(input).unwrap(),
input,
Arc::default(),
project.clone(),
action_log.clone(),
@ -791,7 +782,7 @@ mod tests {
let check = |input, expected, cx: &mut App| {
let headless_result = TerminalTool::run(
Arc::new(TerminalTool::new(cx)),
serde_json::to_value(input).unwrap(),
input,
Arc::default(),
project.clone(),
action_log.clone(),

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use anyhow::Result;
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
@ -20,11 +20,13 @@ pub struct ThinkingToolInput {
pub struct ThinkingTool;
impl Tool for ThinkingTool {
type Input = ThinkingToolInput;
fn name(&self) -> String {
"thinking".to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -44,13 +46,13 @@ impl Tool for ThinkingTool {
json_schema_for::<ThinkingToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
fn ui_text(&self, _input: &Self::Input) -> String {
"Thinking".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_input: Self::Input,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -59,10 +61,6 @@ impl Tool for ThinkingTool {
_cx: &mut App,
) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions.
Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) {
Ok(_input) => Ok("Finished thinking.".to_string().into()),
Err(err) => Err(anyhow!(err)),
})
.into()
Task::ready(Ok("Finished thinking.".to_string().into())).into()
}
}

View file

@ -28,11 +28,13 @@ pub struct WebSearchToolInput {
pub struct WebSearchTool;
impl Tool for WebSearchTool {
type Input = WebSearchToolInput;
fn name(&self) -> String {
"web_search".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
false
}
@ -52,13 +54,13 @@ impl Tool for WebSearchTool {
json_schema_for::<WebSearchToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
fn ui_text(&self, _input: &Self::Input) -> String {
"Searching the Web".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
input: Self::Input,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@ -66,10 +68,6 @@ impl Tool for WebSearchTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
return Task::ready(Err(anyhow!("Web search is not available."))).into();
};

View file

@ -1736,7 +1736,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu
let exists_result = cx.update(|cx| {
ReadFileTool::run(
Arc::new(ReadFileTool),
serde_json::to_value(input).unwrap(),
input,
request.clone(),
project.clone(),
action_log.clone(),
@ -1756,7 +1756,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu
let does_not_exist_result = cx.update(|cx| {
ReadFileTool::run(
Arc::new(ReadFileTool),
serde_json::to_value(input).unwrap(),
input,
request.clone(),
project.clone(),
action_log.clone(),