From 0cdd8bdded9bfea62a4508372eb3ff19179801ba Mon Sep 17 00:00:00 2001 From: Mikayla Maki Date: Tue, 6 May 2025 18:16:34 -0700 Subject: [PATCH] Restore tool cards on thread deserialization (#30053) Release Notes: - N/A --------- Co-authored-by: Julia Ryan --- crates/agent/src/assistant_panel.rs | 9 +- crates/agent/src/context_picker.rs | 7 +- .../src/context_picker/completion_provider.rs | 23 ++-- .../context_picker/thread_context_picker.rs | 4 +- crates/agent/src/context_server_tool.rs | 2 +- crates/agent/src/history_store.rs | 105 ++++++++++-------- crates/agent/src/thread.rs | 11 +- crates/agent/src/thread_store.rs | 9 +- crates/agent/src/tool_use.rs | 30 ++++- crates/assistant_tool/src/assistant_tool.rs | 40 ++++++- crates/assistant_tools/src/copy_path_tool.rs | 7 +- .../src/create_directory_tool.rs | 2 +- .../assistant_tools/src/create_file_tool.rs | 2 +- .../assistant_tools/src/delete_path_tool.rs | 2 +- .../assistant_tools/src/diagnostics_tool.rs | 12 +- .../assistant_tools/src/edit_agent/evals.rs | 1 + crates/assistant_tools/src/edit_file_tool.rs | 58 ++++++++-- crates/assistant_tools/src/fetch_tool.rs | 2 +- crates/assistant_tools/src/find_path_tool.rs | 4 +- crates/assistant_tools/src/grep_tool.rs | 10 +- .../src/list_directory_tool.rs | 6 +- crates/assistant_tools/src/move_path_tool.rs | 7 +- crates/assistant_tools/src/now_tool.rs | 2 +- crates/assistant_tools/src/open_tool.rs | 2 +- crates/assistant_tools/src/read_file_tool.rs | 18 +-- .../src/streaming_edit_file_tool.rs | 48 +++++++- crates/assistant_tools/src/terminal_tool.rs | 12 +- crates/assistant_tools/src/thinking_tool.rs | 2 +- crates/assistant_tools/src/web_search_tool.rs | 4 +- crates/language_model/src/request.rs | 1 + 30 files changed, 307 insertions(+), 135 deletions(-) diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 6b46e95bae..53a1e9edd1 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -510,6 +510,7 @@ impl AssistantPanel { thread_store.clone(), context_store.clone(), [RecentEntry::Thread(thread_id, thread.clone())], + window, cx, ) }); @@ -764,9 +765,9 @@ impl AssistantPanel { }); if let Some(other_thread_id) = action.from_thread_id.clone() { - let other_thread_task = self - .thread_store - .update(cx, |this, cx| this.open_thread(&other_thread_id, cx)); + let other_thread_task = self.thread_store.update(cx, |this, cx| { + this.open_thread(&other_thread_id, window, cx) + }); cx.spawn({ let context_store = context_store.clone(); @@ -967,7 +968,7 @@ impl AssistantPanel { ) -> Task> { let open_thread_task = self .thread_store - .update(cx, |this, cx| this.open_thread(thread_id, cx)); + .update(cx, |this, cx| this.open_thread(thread_id, window, cx)); cx.spawn_in(window, async move |this, cx| { let thread = open_thread_task.await?; this.update_in(cx, |this, window, cx| { diff --git a/crates/agent/src/context_picker.rs b/crates/agent/src/context_picker.rs index ab4f4d2d9c..a9cf7889ca 100644 --- a/crates/agent/src/context_picker.rs +++ b/crates/agent/src/context_picker.rs @@ -425,9 +425,9 @@ impl ContextPicker { render_thread_context_entry(&view_thread, context_store.clone(), cx) .into_any() }, - move |_window, cx| { + move |window, cx| { context_picker.update(cx, |this, cx| { - this.add_recent_thread(thread.clone(), cx) + this.add_recent_thread(thread.clone(), window, cx) .detach_and_log_err(cx); }) }, @@ -459,6 +459,7 @@ impl ContextPicker { fn add_recent_thread( &self, entry: ThreadContextEntry, + window: &mut Window, cx: &mut Context, ) -> Task> { let Some(context_store) = self.context_store.upgrade() else { @@ -476,7 +477,7 @@ impl ContextPicker { }; let open_thread_task = - thread_store.update(cx, |this, cx| this.open_thread(&id, cx)); + thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx)); cx.spawn(async move |this, cx| { let thread = open_thread_task.await?; context_store.update(cx, |context_store, cx| { diff --git a/crates/agent/src/context_picker/completion_provider.rs b/crates/agent/src/context_picker/completion_provider.rs index b886725913..ebdd984d48 100644 --- a/crates/agent/src/context_picker/completion_provider.rs +++ b/crates/agent/src/context_picker/completion_provider.rs @@ -438,15 +438,15 @@ impl ContextPickerCompletionProvider { new_text_len, editor.clone(), context_store.clone(), - move |cx| match &thread_entry { + move |window, cx| match &thread_entry { ThreadContextEntry::Thread { id, .. } => { let thread_id = id.clone(); let context_store = context_store.clone(); let thread_store = thread_store.clone(); - cx.spawn::<_, Option<_>>(async move |cx| { + window.spawn::<_, Option<_>>(cx, async move |cx| { let thread: Entity = thread_store - .update(cx, |thread_store, cx| { - thread_store.open_thread(&thread_id, cx) + .update_in(cx, |thread_store, window, cx| { + thread_store.open_thread(&thread_id, window, cx) }) .ok()? .await @@ -507,7 +507,7 @@ impl ContextPickerCompletionProvider { new_text_len, editor.clone(), context_store.clone(), - move |cx| { + move |_, cx| { let user_prompt_id = rules.prompt_id; let context = context_store.update(cx, |context_store, cx| { context_store.add_rules(user_prompt_id, false, cx) @@ -544,7 +544,7 @@ impl ContextPickerCompletionProvider { new_text_len, editor.clone(), context_store.clone(), - move |cx| { + move |_, cx| { let context_store = context_store.clone(); let http_client = http_client.clone(); let url_to_fetch = url_to_fetch.clone(); @@ -629,7 +629,7 @@ impl ContextPickerCompletionProvider { new_text_len, editor, context_store.clone(), - move |cx| { + move |_, cx| { if is_directory { Task::ready( context_store @@ -700,7 +700,7 @@ impl ContextPickerCompletionProvider { new_text_len, editor.clone(), context_store.clone(), - move |cx| { + move |_, cx| { let symbol = symbol.clone(); let context_store = context_store.clone(); let workspace = workspace.clone(); @@ -954,10 +954,13 @@ fn confirm_completion_callback( content_len: usize, editor: Entity, context_store: Entity, - add_context_fn: impl Fn(&mut App) -> Task> + Send + Sync + 'static, + add_context_fn: impl Fn(&mut Window, &mut App) -> Task> + + Send + + Sync + + 'static, ) -> Arc bool + Send + Sync> { Arc::new(move |_, window, cx| { - let context = add_context_fn(cx); + let context = add_context_fn(window, cx); let crease_text = crease_text.clone(); let crease_icon_path = crease_icon_path.clone(); diff --git a/crates/agent/src/context_picker/thread_context_picker.rs b/crates/agent/src/context_picker/thread_context_picker.rs index 96ee399a7f..c189d071be 100644 --- a/crates/agent/src/context_picker/thread_context_picker.rs +++ b/crates/agent/src/context_picker/thread_context_picker.rs @@ -154,7 +154,7 @@ impl PickerDelegate for ThreadContextPickerDelegate { }) } - fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context>) { + fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { let Some(entry) = self.matches.get(self.selected_index) else { return; }; @@ -165,7 +165,7 @@ impl PickerDelegate for ThreadContextPickerDelegate { return; }; let open_thread_task = - thread_store.update(cx, |this, cx| this.open_thread(&id, cx)); + thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx)); cx.spawn(async move |this, cx| { let thread = open_thread_task.await?; diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index 69283b9b63..decba03cdd 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -115,7 +115,7 @@ impl Tool for ContextServerTool { } } } - Ok(result) + Ok(result.into()) }) .into() } else { diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index c2018e1c3b..2c03c81d41 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -8,7 +8,7 @@ use gpui::{Entity, Task, prelude::*}; use serde::{Deserialize, Serialize}; use smol::future::FutureExt; use std::time::Duration; -use ui::{App, SharedString}; +use ui::{App, SharedString, Window}; use util::ResultExt as _; use crate::{ @@ -82,6 +82,7 @@ impl HistoryStore { thread_store: Entity, context_store: Entity, initial_recent_entries: impl IntoIterator, + window: &mut Window, cx: &mut Context, ) -> Self { let subscriptions = vec![ @@ -89,56 +90,62 @@ impl HistoryStore { cx.observe(&context_store, |_, _, cx| cx.notify()), ]; - cx.spawn({ - let thread_store = thread_store.downgrade(); - let context_store = context_store.downgrade(); - async move |this, cx| { - let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH); - let contents = cx - .background_spawn(async move { std::fs::read_to_string(path) }) - .await - .ok()?; - let entries = serde_json::from_str::>(&contents) - .context("deserializing persisted agent panel navigation history") - .log_err()? - .into_iter() - .take(MAX_RECENTLY_OPENED_ENTRIES) - .map(|serialized| match serialized { - SerializedRecentEntry::Thread(id) => thread_store - .update(cx, |thread_store, cx| { - let thread_id = ThreadId::from(id.as_str()); - thread_store - .open_thread(&thread_id, cx) - .map_ok(|thread| RecentEntry::Thread(thread_id, thread)) - .boxed() - }) - .unwrap_or_else(|_| async { Err(anyhow!("no thread store")) }.boxed()), - SerializedRecentEntry::Context(id) => context_store - .update(cx, |context_store, cx| { - context_store - .open_local_context(Path::new(&id).into(), cx) - .map_ok(RecentEntry::Context) - .boxed() - }) - .unwrap_or_else(|_| async { Err(anyhow!("no context store")) }.boxed()), - }); - let entries = join_all(entries) - .await - .into_iter() - .filter_map(|result| result.log_err()) - .collect::>(); + window + .spawn(cx, { + let thread_store = thread_store.downgrade(); + let context_store = context_store.downgrade(); + let this = cx.weak_entity(); + async move |cx| { + let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH); + let contents = cx + .background_spawn(async move { std::fs::read_to_string(path) }) + .await + .ok()?; + let entries = serde_json::from_str::>(&contents) + .context("deserializing persisted agent panel navigation history") + .log_err()? + .into_iter() + .take(MAX_RECENTLY_OPENED_ENTRIES) + .map(|serialized| match serialized { + SerializedRecentEntry::Thread(id) => thread_store + .update_in(cx, |thread_store, window, cx| { + let thread_id = ThreadId::from(id.as_str()); + thread_store + .open_thread(&thread_id, window, cx) + .map_ok(|thread| RecentEntry::Thread(thread_id, thread)) + .boxed() + }) + .unwrap_or_else(|_| { + async { Err(anyhow!("no thread store")) }.boxed() + }), + SerializedRecentEntry::Context(id) => context_store + .update(cx, |context_store, cx| { + context_store + .open_local_context(Path::new(&id).into(), cx) + .map_ok(RecentEntry::Context) + .boxed() + }) + .unwrap_or_else(|_| { + async { Err(anyhow!("no context store")) }.boxed() + }), + }); + let entries = join_all(entries) + .await + .into_iter() + .filter_map(|result| result.log_err()) + .collect::>(); - this.update(cx, |this, _| { - this.recently_opened_entries.extend(entries); - this.recently_opened_entries - .truncate(MAX_RECENTLY_OPENED_ENTRIES); - }) - .ok(); + this.update(cx, |this, _| { + this.recently_opened_entries.extend(entries); + this.recently_opened_entries + .truncate(MAX_RECENTLY_OPENED_ENTRIES); + }) + .ok(); - Some(()) - } - }) - .detach(); + Some(()) + } + }) + .detach(); Self { thread_store, diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index dacde7cda4..ca6f0f99d4 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -35,6 +35,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; use thiserror::Error; +use ui::Window; use util::{ResultExt as _, TryFutureExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::CompletionRequestStatus; @@ -430,6 +431,7 @@ impl Thread { tools: Entity, prompt_builder: Arc, project_context: SharedProjectContext, + window: &mut Window, cx: &mut Context, ) -> Self { let next_message_id = MessageId( @@ -439,7 +441,13 @@ impl Thread { .map(|message| message.id.0 + 1) .unwrap_or(0), ); - let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages); + let tool_use = ToolUseState::from_serialized_messages( + tools.clone(), + &serialized.messages, + project.clone(), + window, + cx, + ); let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel_with(serialized.detailed_summary_state); @@ -1064,6 +1072,7 @@ impl Thread { tool_use_id: tool_result.tool_use_id.clone(), is_error: tool_result.is_error, content: tool_result.content.clone(), + output: tool_result.output.clone(), }) .collect(), context: message.loaded_context.text.clone(), diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 3d3a1b757e..99ecd3d442 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -28,6 +28,7 @@ use prompt_store::{ }; use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; +use ui::Window; use util::ResultExt as _; use crate::context_server_tool::ContextServerTool; @@ -388,18 +389,20 @@ impl ThreadStore { pub fn open_thread( &self, id: &ThreadId, + window: &mut Window, cx: &mut Context, ) -> Task>> { let id = id.clone(); let database_future = ThreadsDatabase::global_future(cx); - cx.spawn(async move |this, cx| { + let this = cx.weak_entity(); + window.spawn(cx, async move |cx| { let database = database_future.await.map_err(|err| anyhow!(err))?; let thread = database .try_find_thread(id.clone()) .await? .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?; - let thread = this.update(cx, |this, cx| { + let thread = this.update_in(cx, |this, window, cx| { cx.new(|cx| { Thread::deserialize( id.clone(), @@ -408,6 +411,7 @@ impl ThreadStore { this.tools.clone(), this.prompt_builder.clone(), this.project_context.clone(), + window, cx, ) }) @@ -772,6 +776,7 @@ pub struct SerializedToolResult { pub tool_use_id: LanguageModelToolUseId, pub is_error: bool, pub content: Arc, + pub output: Option, } #[derive(Serialize, Deserialize)] diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 38b00786ff..f7c02f7d74 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use anyhow::Result; -use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet}; +use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet}; use collections::HashMap; use futures::FutureExt as _; use futures::future::Shared; @@ -10,7 +10,8 @@ use language_model::{ ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, }; -use ui::IconName; +use project::Project; +use ui::{IconName, Window}; use util::truncate_lines_to_byte_limit; use crate::thread::{MessageId, PromptId, ThreadId}; @@ -54,6 +55,9 @@ impl ToolUseState { pub fn from_serialized_messages( tools: Entity, messages: &[SerializedMessage], + project: Entity, + window: &mut Window, + cx: &mut App, ) -> Self { let mut this = Self::new(tools); let mut tool_names_by_id = HashMap::default(); @@ -93,12 +97,23 @@ impl ToolUseState { this.tool_results.insert( tool_use_id.clone(), LanguageModelToolResult { - tool_use_id, + tool_use_id: tool_use_id.clone(), tool_name: tool_use.clone(), is_error: tool_result.is_error, content: tool_result.content.clone(), + output: tool_result.output.clone(), }, ); + + if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) { + if let Some(output) = tool_result.output.clone() { + if let Some(card) = + tool.deserialize_card(output, project.clone(), window, cx) + { + this.tool_result_cards.insert(tool_use_id, card); + } + } + } } } } @@ -124,6 +139,7 @@ impl ToolUseState { tool_use_id: tool_use_id.clone(), tool_name: tool_use.name.clone(), content, + output: None, is_error: true, }, ); @@ -359,7 +375,7 @@ impl ToolUseState { &mut self, tool_use_id: LanguageModelToolUseId, tool_name: Arc, - output: Result, + output: Result, configured_model: Option<&ConfiguredModel>, ) -> Option { let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id); @@ -379,7 +395,8 @@ impl ToolUseState { ); match output { - Ok(tool_result) => { + Ok(output) => { + let tool_result = output.content; const BYTES_PER_TOKEN_ESTIMATE: usize = 3; // Protect from clearly large output @@ -406,6 +423,7 @@ impl ToolUseState { tool_name, content: tool_result.into(), is_error: false, + output: output.output, }, ); self.pending_tool_uses_by_id.remove(&tool_use_id) @@ -418,6 +436,7 @@ impl ToolUseState { tool_name, content: err.to_string().into(), is_error: true, + output: None, }, ); @@ -490,6 +509,7 @@ impl ToolUseState { } else { tool_result.content.clone() }, + output: None, })); } } diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 3ef54d57da..68a4f9746d 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -7,6 +7,7 @@ mod tool_working_set; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; +use std::ops::Deref; use std::sync::Arc; use anyhow::Result; @@ -61,11 +62,34 @@ impl ToolUseStatus { } } +#[derive(Debug)] +pub struct ToolResultOutput { + pub content: String, + pub output: Option, +} + +impl From for ToolResultOutput { + fn from(value: String) -> Self { + ToolResultOutput { + content: value, + output: None, + } + } +} + +impl Deref for ToolResultOutput { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.content + } +} + /// The result of running a tool, containing both the asynchronous output /// and an optional card view that can be rendered immediately. pub struct ToolResult { /// The asynchronous task that will eventually resolve to the tool's output - pub output: Task>, + pub output: Task>, /// An optional view to present the output of the tool. pub card: Option, } @@ -128,9 +152,9 @@ impl AnyToolCard { } } -impl From>> for ToolResult { +impl From>> for ToolResult { /// Convert from a task to a ToolResult with no card - fn from(output: Task>) -> Self { + fn from(output: Task>) -> Self { Self { output, card: None } } } @@ -187,6 +211,16 @@ pub trait Tool: 'static + Send + Sync { window: Option, cx: &mut App, ) -> ToolResult; + + fn deserialize_card( + self: Arc, + _output: serde_json::Value, + _project: Entity, + _window: &mut Window, + _cx: &mut App, + ) -> Option { + None + } } impl Debug for dyn Tool { diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs index 07d4e58302..8839ef7fc2 100644 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ b/crates/assistant_tools/src/copy_path_tool.rs @@ -107,10 +107,9 @@ impl Tool for CopyPathTool { cx.background_spawn(async move { match copy_task.await { - Ok(_) => Ok(format!( - "Copied {} to {}", - input.source_path, input.destination_path - )), + Ok(_) => Ok( + format!("Copied {} to {}", input.source_path, input.destination_path).into(), + ), Err(err) => Err(anyhow!( "Failed to copy {} to {}: {}", input.source_path, diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs index a0ccd6f425..7354ef3eb7 100644 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ b/crates/assistant_tools/src/create_directory_tool.rs @@ -88,7 +88,7 @@ impl Tool for CreateDirectoryTool { .await .map_err(|err| anyhow!("Unable to create directory {destination_path}: {err}"))?; - Ok(format!("Created directory {destination_path}")) + Ok(format!("Created directory {destination_path}").into()) }) .into() } diff --git a/crates/assistant_tools/src/create_file_tool.rs b/crates/assistant_tools/src/create_file_tool.rs index d52c704e7c..ae76fcb459 100644 --- a/crates/assistant_tools/src/create_file_tool.rs +++ b/crates/assistant_tools/src/create_file_tool.rs @@ -131,7 +131,7 @@ impl Tool for CreateFileTool { .await .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?; - Ok(format!("Created file {destination_path}")) + Ok(format!("Created file {destination_path}").into()) }) .into() } diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index f5452d0eb8..931a989d49 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -127,7 +127,7 @@ impl Tool for DeletePathTool { match delete { Some(deletion_task) => match deletion_task.await { - Ok(()) => Ok(format!("Deleted {path_str}")), + Ok(()) => Ok(format!("Deleted {path_str}").into()), Err(err) => Err(anyhow!("Failed to delete {path_str}: {err}")), }, None => Err(anyhow!( diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 17f96177c1..702d5f4277 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -122,9 +122,9 @@ impl Tool for DiagnosticsTool { } if output.is_empty() { - Ok("File doesn't have errors or warnings!".to_string()) + Ok("File doesn't have errors or warnings!".to_string().into()) } else { - Ok(output) + Ok(output.into()) } }) .into() @@ -158,10 +158,12 @@ impl Tool for DiagnosticsTool { }); if has_diagnostics { - Task::ready(Ok(output)).into() + Task::ready(Ok(output.into())).into() } else { - Task::ready(Ok("No errors or warnings found in the project.".to_string())) - .into() + Task::ready(Ok("No errors or warnings found in the project." + .to_string() + .into())) + .into() } } } diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index d08ef84a46..8b5c8c8d05 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -895,6 +895,7 @@ fn tool_result( tool_name: name.into(), is_error: false, content: result.into(), + output: None, }) } diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 9297648389..033ac34d5e 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -1,9 +1,12 @@ use crate::{ replace::{replace_exact, replace_with_flexible_indent}, schema::json_schema_for, + streaming_edit_file_tool::StreamingEditFileToolOutput, }; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus}; +use assistant_tool::{ + ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus, +}; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{Editor, EditorElement, EditorMode, EditorStyle, MultiBuffer, PathKey}; use gpui::{ @@ -153,7 +156,7 @@ impl Tool for EditFileTool { }); let card_clone = card.clone(); - let task = cx.spawn(async move |cx: &mut AsyncApp| { + let task: Task> = cx.spawn(async move |cx: &mut AsyncApp| { let project_path = project.read_with(cx, |project, cx| { project .find_project_path(&input.path, cx) @@ -281,16 +284,29 @@ impl Tool for EditFileTool { if let Some(card) = card_clone { card.update(cx, |card, cx| { - card.set_diff(project_path.path.clone(), old_text, new_text, cx); + card.set_diff( + project_path.path.clone(), + old_text.clone(), + new_text.clone(), + cx, + ); }) .log_err(); } - Ok(format!( - "Edited {}:\n\n```diff\n{}\n```", - input.path.display(), - diff_str - )) + Ok(ToolResultOutput { + content: format!( + "Edited {}:\n\n```diff\n{}\n```", + input.path.display(), + diff_str + ), + output: serde_json::to_value(StreamingEditFileToolOutput { + original_path: input.path, + new_text, + old_text, + }) + .ok(), + }) }); ToolResult { @@ -298,6 +314,32 @@ impl Tool for EditFileTool { card: card.map(AnyToolCard::from), } } + + fn deserialize_card( + self: Arc, + output: serde_json::Value, + project: Entity, + window: &mut Window, + cx: &mut App, + ) -> Option { + let output = match serde_json::from_value::(output) { + Ok(output) => output, + Err(_) => return None, + }; + + let card = cx.new(|cx| { + let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx); + card.set_diff( + output.original_path.into(), + output.old_text, + output.new_text, + cx, + ); + card + }); + + Some(card.into()) + } } pub struct EditFileToolCard { diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index 413a4c5589..92a403d868 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -166,7 +166,7 @@ impl Tool for FetchTool { bail!("no textual content found"); } - Ok(text) + Ok(text.into()) }) .into() } diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index c2202b445a..9422ad2178 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -98,7 +98,7 @@ impl Tool for FindPathTool { sender.send(paginated_matches.to_vec()).log_err(); if matches.is_empty() { - Ok("No matches found".to_string()) + Ok("No matches found".to_string().into()) } else { let mut message = format!("Found {} total matches.", matches.len()); if matches.len() > RESULTS_PER_PAGE { @@ -113,7 +113,7 @@ impl Tool for FindPathTool { for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) { write!(&mut message, "\n{}", mat.display()).unwrap(); } - Ok(message) + Ok(message.into()) } }); diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index e296a472b2..e821a7fda4 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -260,16 +260,16 @@ impl Tool for GrepTool { } if matches_found == 0 { - Ok("No matches found".to_string()) + Ok("No matches found".to_string().into()) } else if has_more_matches { Ok(format!( "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}", input.offset + 1, input.offset + matches_found, input.offset + RESULTS_PER_PAGE, - )) + ).into()) } else { - Ok(format!("Found {matches_found} matches:\n{output}")) + Ok(format!("Found {matches_found} matches:\n{output}").into()) } }).into() } @@ -748,9 +748,9 @@ mod tests { match task.output.await { Ok(result) => { if cfg!(windows) { - result.replace("root\\", "root/") + result.content.replace("root\\", "root/") } else { - result + result.content } } Err(e) => panic!("Failed to run grep tool: {}", e), diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index 26665f311c..a988a145e4 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -102,7 +102,7 @@ impl Tool for ListDirectoryTool { .collect::>() .join("\n"); - return Task::ready(Ok(output)).into(); + return Task::ready(Ok(output.into())).into(); } let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { @@ -134,8 +134,8 @@ impl Tool for ListDirectoryTool { .unwrap(); } if output.is_empty() { - return Task::ready(Ok(format!("{} is empty.", input.path))).into(); + return Task::ready(Ok(format!("{} is empty.", input.path).into())).into(); } - Task::ready(Ok(output)).into() + Task::ready(Ok(output.into())).into() } } diff --git a/crates/assistant_tools/src/move_path_tool.rs b/crates/assistant_tools/src/move_path_tool.rs index 9c9493c4f5..fba73d0b8d 100644 --- a/crates/assistant_tools/src/move_path_tool.rs +++ b/crates/assistant_tools/src/move_path_tool.rs @@ -117,10 +117,9 @@ impl Tool for MovePathTool { cx.background_spawn(async move { match rename_task.await { - Ok(_) => Ok(format!( - "Moved {} to {}", - input.source_path, input.destination_path - )), + Ok(_) => { + Ok(format!("Moved {} to {}", input.source_path, input.destination_path).into()) + } Err(err) => Err(anyhow!( "Failed to move {} to {}: {}", input.source_path, diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index 46d009715f..bcef45b104 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -73,6 +73,6 @@ impl Tool for NowTool { }; let text = format!("The current datetime is {now}."); - Task::ready(Ok(text)).into() + Task::ready(Ok(text.into())).into() } } diff --git a/crates/assistant_tools/src/open_tool.rs b/crates/assistant_tools/src/open_tool.rs index 2df0dda905..c4e3fa822f 100644 --- a/crates/assistant_tools/src/open_tool.rs +++ b/crates/assistant_tools/src/open_tool.rs @@ -70,7 +70,7 @@ impl Tool for OpenTool { } .context("Failed to open URL or file path")?; - Ok(format!("Successfully opened {}", input.path_or_url)) + Ok(format!("Successfully opened {}", input.path_or_url).into()) }) .into() } diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 709683e99e..72718ba136 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -145,9 +145,9 @@ impl Tool for ReadFileTool { let lines = text.split('\n').skip(start_row as usize); if let Some(end) = input.end_line { let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line - Itertools::intersperse(lines.take(count as usize), "\n").collect() + Itertools::intersperse(lines.take(count as usize), "\n").collect::().into() } else { - Itertools::intersperse(lines, "\n").collect() + Itertools::intersperse(lines, "\n").collect::().into() } })?; @@ -180,7 +180,7 @@ impl Tool for ReadFileTool { log.buffer_read(buffer, cx); })?; - Ok(result) + Ok(result.into()) } else { // File is too big, so return the outline // and a suggestion to read again with line numbers. @@ -192,7 +192,7 @@ impl Tool for ReadFileTool { Using the line numbers in this outline, you can call this tool again while specifying the start_line and end_line fields to see the implementations of symbols in the outline." - }) + }.into()) } } }) @@ -258,7 +258,7 @@ mod test { .output }) .await; - assert_eq!(result.unwrap(), "This is a small file content"); + assert_eq!(result.unwrap().content, "This is a small file content"); } #[gpui::test] @@ -358,7 +358,7 @@ mod test { .output }) .await; - assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4"); + assert_eq!(result.unwrap().content, "Line 2\nLine 3\nLine 4"); } #[gpui::test] @@ -389,7 +389,7 @@ mod test { .output }) .await; - assert_eq!(result.unwrap(), "Line 1\nLine 2"); + assert_eq!(result.unwrap().content, "Line 1\nLine 2"); // end_line of 0 should result in at least 1 line let result = cx @@ -404,7 +404,7 @@ mod test { .output }) .await; - assert_eq!(result.unwrap(), "Line 1"); + assert_eq!(result.unwrap().content, "Line 1"); // when start_line > end_line, should still return at least 1 line let result = cx @@ -419,7 +419,7 @@ mod test { .output }) .await; - assert_eq!(result.unwrap(), "Line 3"); + assert_eq!(result.unwrap().content, "Line 3"); } fn init_test(cx: &mut TestAppContext) { diff --git a/crates/assistant_tools/src/streaming_edit_file_tool.rs b/crates/assistant_tools/src/streaming_edit_file_tool.rs index e9e69de54c..60d0d354aa 100644 --- a/crates/assistant_tools/src/streaming_edit_file_tool.rs +++ b/crates/assistant_tools/src/streaming_edit_file_tool.rs @@ -5,7 +5,7 @@ use crate::{ schema::json_schema_for, }; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult}; +use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult, ToolResultOutput}; use futures::StreamExt; use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task}; use indoc::formatdoc; @@ -67,6 +67,13 @@ pub struct StreamingEditFileToolInput { pub create_or_overwrite: bool, } +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct StreamingEditFileToolOutput { + pub original_path: PathBuf, + pub new_text: String, + pub old_text: String, +} + #[derive(Debug, Serialize, Deserialize, JsonSchema)] struct PartialInput { #[serde(default)] @@ -248,6 +255,12 @@ impl Tool for StreamingEditFileTool { }); let (new_text, diff) = futures::join!(new_text, diff); + let output = StreamingEditFileToolOutput { + original_path: project_path.path.to_path_buf(), + new_text: new_text.clone(), + old_text: old_text.clone(), + }; + if let Some(card) = card_clone { card.update(cx, |card, cx| { card.set_diff(project_path.path.clone(), old_text, new_text, cx); @@ -264,10 +277,13 @@ impl Tool for StreamingEditFileTool { I can perform the requested edits. "})) } else { - Ok("No edits were made.".to_string()) + Ok("No edits were made.".to_string().into()) } } else { - Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff)) + Ok(ToolResultOutput { + content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff), + output: serde_json::to_value(output).ok(), + }) } }); @@ -276,6 +292,32 @@ impl Tool for StreamingEditFileTool { card: card.map(AnyToolCard::from), } } + + fn deserialize_card( + self: Arc, + output: serde_json::Value, + project: Entity, + window: &mut Window, + cx: &mut App, + ) -> Option { + let output = match serde_json::from_value::(output) { + Ok(output) => output, + Err(_) => return None, + }; + + let card = cx.new(|cx| { + let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx); + card.set_diff( + output.original_path.into(), + output.old_text, + output.new_text, + cx, + ); + card + }); + + Some(card.into()) + } } #[cfg(test)] diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index a82b80298b..e5c09d942f 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -178,7 +178,7 @@ impl Tool for TerminalTool { let exit_status = child.wait()?; let (processed_content, _) = process_content(content, &input.command, Some(exit_status)); - Ok(processed_content) + Ok(processed_content.into()) }); return ToolResult { output: task, @@ -266,7 +266,7 @@ impl Tool for TerminalTool { card.elapsed_time = Some(card.start_instant.elapsed()); }); - Ok(processed_content) + Ok(processed_content.into()) } }); @@ -661,7 +661,7 @@ mod tests { ) }); - let output = result.output.await.log_err(); + let output = result.output.await.log_err().map(|output| output.content); assert_eq!(output, Some("Command executed successfully.".into())); } @@ -693,7 +693,11 @@ mod tests { cx, ); cx.spawn(async move |_| { - let output = headless_result.output.await.log_err(); + let output = headless_result + .output + .await + .log_err() + .map(|output| output.content); assert_eq!(output, expected); }) }; diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index fc6946d4d7..ae0cf31945 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -55,7 +55,7 @@ impl Tool for ThinkingTool { ) -> ToolResult { // This tool just "thinks out loud" and doesn't perform any actions. Task::ready(match serde_json::from_value::(input) { - Ok(_input) => Ok("Finished thinking.".to_string()), + Ok(_input) => Ok("Finished thinking.".to_string().into()), Err(err) => Err(anyhow!(err)), }) .into() diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index a2afde9f97..747c62f508 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -72,7 +72,9 @@ impl Tool for WebSearchTool { let search_task = search_task.clone(); async move { let response = search_task.await.map_err(|err| anyhow!(err))?; - serde_json::to_string(&response).context("Failed to serialize search results") + serde_json::to_string(&response) + .context("Failed to serialize search results") + .map(Into::into) } }); diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 28f6b8c133..23c899e9d0 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -131,6 +131,7 @@ pub struct LanguageModelToolResult { pub tool_name: Arc, pub is_error: bool, pub content: Arc, + pub output: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]