Restore tool cards on thread deserialization (#30053)

Release Notes:

- N/A

---------

Co-authored-by: Julia Ryan <juliaryan3.14@gmail.com>
Mikayla Maki, 2025-05-06 18:16:34 -07:00 (committed via GitHub)
parent ab3e5cdc6c
commit 0cdd8bdded
30 changed files with 307 additions and 135 deletions

@@ -510,6 +510,7 @@ impl AssistantPanel {
                     thread_store.clone(),
                     context_store.clone(),
                     [RecentEntry::Thread(thread_id, thread.clone())],
+                    window,
                     cx,
                 )
             });
@@ -764,9 +765,9 @@ impl AssistantPanel {
         });

         if let Some(other_thread_id) = action.from_thread_id.clone() {
-            let other_thread_task = self
-                .thread_store
-                .update(cx, |this, cx| this.open_thread(&other_thread_id, cx));
+            let other_thread_task = self.thread_store.update(cx, |this, cx| {
+                this.open_thread(&other_thread_id, window, cx)
+            });

             cx.spawn({
                 let context_store = context_store.clone();
@@ -967,7 +968,7 @@ impl AssistantPanel {
     ) -> Task<Result<()>> {
         let open_thread_task = self
             .thread_store
-            .update(cx, |this, cx| this.open_thread(thread_id, cx));
+            .update(cx, |this, cx| this.open_thread(thread_id, window, cx));
         cx.spawn_in(window, async move |this, cx| {
             let thread = open_thread_task.await?;
             this.update_in(cx, |this, window, cx| {

@@ -425,9 +425,9 @@ impl ContextPicker {
                         render_thread_context_entry(&view_thread, context_store.clone(), cx)
                             .into_any()
                     },
-                    move |_window, cx| {
+                    move |window, cx| {
                         context_picker.update(cx, |this, cx| {
-                            this.add_recent_thread(thread.clone(), cx)
+                            this.add_recent_thread(thread.clone(), window, cx)
                                 .detach_and_log_err(cx);
                         })
                     },
@@ -459,6 +459,7 @@ impl ContextPicker {
    fn add_recent_thread(
        &self,
        entry: ThreadContextEntry,
+       window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(context_store) = self.context_store.upgrade() else {
@@ -476,7 +477,7 @@ impl ContextPicker {
                };
                let open_thread_task =
-                   thread_store.update(cx, |this, cx| this.open_thread(&id, cx));
+                   thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx));
                cx.spawn(async move |this, cx| {
                    let thread = open_thread_task.await?;
                    context_store.update(cx, |context_store, cx| {

@@ -438,15 +438,15 @@ impl ContextPickerCompletionProvider {
                     new_text_len,
                     editor.clone(),
                     context_store.clone(),
-                    move |cx| match &thread_entry {
+                    move |window, cx| match &thread_entry {
                         ThreadContextEntry::Thread { id, .. } => {
                             let thread_id = id.clone();
                             let context_store = context_store.clone();
                             let thread_store = thread_store.clone();
-                            cx.spawn::<_, Option<_>>(async move |cx| {
+                            window.spawn::<_, Option<_>>(cx, async move |cx| {
                                 let thread: Entity<Thread> = thread_store
-                                    .update(cx, |thread_store, cx| {
-                                        thread_store.open_thread(&thread_id, cx)
+                                    .update_in(cx, |thread_store, window, cx| {
+                                        thread_store.open_thread(&thread_id, window, cx)
                                     })
                                     .ok()?
                                     .await
@@ -507,7 +507,7 @@ impl ContextPickerCompletionProvider {
                     new_text_len,
                     editor.clone(),
                     context_store.clone(),
-                    move |cx| {
+                    move |_, cx| {
                         let user_prompt_id = rules.prompt_id;
                         let context = context_store.update(cx, |context_store, cx| {
                             context_store.add_rules(user_prompt_id, false, cx)
@@ -544,7 +544,7 @@ impl ContextPickerCompletionProvider {
                     new_text_len,
                     editor.clone(),
                     context_store.clone(),
-                    move |cx| {
+                    move |_, cx| {
                         let context_store = context_store.clone();
                         let http_client = http_client.clone();
                         let url_to_fetch = url_to_fetch.clone();
@@ -629,7 +629,7 @@ impl ContextPickerCompletionProvider {
                     new_text_len,
                     editor,
                     context_store.clone(),
-                    move |cx| {
+                    move |_, cx| {
                         if is_directory {
                             Task::ready(
                                 context_store
@@ -700,7 +700,7 @@ impl ContextPickerCompletionProvider {
                     new_text_len,
                     editor.clone(),
                     context_store.clone(),
-                    move |cx| {
+                    move |_, cx| {
                         let symbol = symbol.clone();
                         let context_store = context_store.clone();
                         let workspace = workspace.clone();
@@ -954,10 +954,13 @@ fn confirm_completion_callback(
     content_len: usize,
     editor: Entity<Editor>,
     context_store: Entity<ContextStore>,
-    add_context_fn: impl Fn(&mut App) -> Task<Option<AgentContextHandle>> + Send + Sync + 'static,
+    add_context_fn: impl Fn(&mut Window, &mut App) -> Task<Option<AgentContextHandle>>
+    + Send
+    + Sync
+    + 'static,
 ) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
     Arc::new(move |_, window, cx| {
-        let context = add_context_fn(cx);
+        let context = add_context_fn(window, cx);
         let crease_text = crease_text.clone();
         let crease_icon_path = crease_icon_path.clone();

@@ -154,7 +154,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
         })
     }

-    fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
+    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
         let Some(entry) = self.matches.get(self.selected_index) else {
             return;
         };
@@ -165,7 +165,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
                     return;
                 };
                 let open_thread_task =
-                    thread_store.update(cx, |this, cx| this.open_thread(&id, cx));
+                    thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx));
                 cx.spawn(async move |this, cx| {
                     let thread = open_thread_task.await?;

@@ -115,7 +115,7 @@ impl Tool for ContextServerTool {
                         }
                     }
                 }
-                Ok(result)
+                Ok(result.into())
            })
            .into()
        } else {

@@ -8,7 +8,7 @@ use gpui::{Entity, Task, prelude::*};
 use serde::{Deserialize, Serialize};
 use smol::future::FutureExt;
 use std::time::Duration;
-use ui::{App, SharedString};
+use ui::{App, SharedString, Window};
 use util::ResultExt as _;

 use crate::{
@@ -82,6 +82,7 @@ impl HistoryStore {
         thread_store: Entity<ThreadStore>,
         context_store: Entity<assistant_context_editor::ContextStore>,
         initial_recent_entries: impl IntoIterator<Item = RecentEntry>,
+        window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
         let subscriptions = vec![
@@ -89,56 +90,62 @@ impl HistoryStore {
             cx.observe(&context_store, |_, _, cx| cx.notify()),
         ];

-        cx.spawn({
-            let thread_store = thread_store.downgrade();
-            let context_store = context_store.downgrade();
-            async move |this, cx| {
+        window
+            .spawn(cx, {
+                let thread_store = thread_store.downgrade();
+                let context_store = context_store.downgrade();
+                let this = cx.weak_entity();
+                async move |cx| {
                     let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
                     let contents = cx
                         .background_spawn(async move { std::fs::read_to_string(path) })
                         .await
                         .ok()?;
                     let entries = serde_json::from_str::<Vec<SerializedRecentEntry>>(&contents)
                         .context("deserializing persisted agent panel navigation history")
                         .log_err()?
                         .into_iter()
                         .take(MAX_RECENTLY_OPENED_ENTRIES)
                         .map(|serialized| match serialized {
                             SerializedRecentEntry::Thread(id) => thread_store
-                                .update(cx, |thread_store, cx| {
+                                .update_in(cx, |thread_store, window, cx| {
                                     let thread_id = ThreadId::from(id.as_str());
                                     thread_store
-                                        .open_thread(&thread_id, cx)
+                                        .open_thread(&thread_id, window, cx)
                                         .map_ok(|thread| RecentEntry::Thread(thread_id, thread))
                                         .boxed()
                                 })
-                                .unwrap_or_else(|_| async { Err(anyhow!("no thread store")) }.boxed()),
+                                .unwrap_or_else(|_| {
+                                    async { Err(anyhow!("no thread store")) }.boxed()
+                                }),
                             SerializedRecentEntry::Context(id) => context_store
                                 .update(cx, |context_store, cx| {
                                     context_store
                                         .open_local_context(Path::new(&id).into(), cx)
                                         .map_ok(RecentEntry::Context)
                                         .boxed()
                                 })
-                                .unwrap_or_else(|_| async { Err(anyhow!("no context store")) }.boxed()),
+                                .unwrap_or_else(|_| {
+                                    async { Err(anyhow!("no context store")) }.boxed()
+                                }),
                         });
                     let entries = join_all(entries)
                         .await
                         .into_iter()
                         .filter_map(|result| result.log_err())
                         .collect::<VecDeque<_>>();

                     this.update(cx, |this, _| {
                         this.recently_opened_entries.extend(entries);
                         this.recently_opened_entries
                             .truncate(MAX_RECENTLY_OPENED_ENTRIES);
                     })
                     .ok();
                     Some(())
                 }
             })
             .detach();

         Self {
             thread_store,

@@ -35,6 +35,7 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Settings;
 use thiserror::Error;
+use ui::Window;
 use util::{ResultExt as _, TryFutureExt as _, post_inc};
 use uuid::Uuid;
 use zed_llm_client::CompletionRequestStatus;
@@ -430,6 +431,7 @@ impl Thread {
         tools: Entity<ToolWorkingSet>,
         prompt_builder: Arc<PromptBuilder>,
         project_context: SharedProjectContext,
+        window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
         let next_message_id = MessageId(
@@ -439,7 +441,13 @@ impl Thread {
                 .map(|message| message.id.0 + 1)
                 .unwrap_or(0),
         );
-        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
+        let tool_use = ToolUseState::from_serialized_messages(
+            tools.clone(),
+            &serialized.messages,
+            project.clone(),
+            window,
+            cx,
+        );
         let (detailed_summary_tx, detailed_summary_rx) =
             postage::watch::channel_with(serialized.detailed_summary_state);
@@ -1064,6 +1072,7 @@ impl Thread {
                         tool_use_id: tool_result.tool_use_id.clone(),
                         is_error: tool_result.is_error,
                         content: tool_result.content.clone(),
+                        output: tool_result.output.clone(),
                     })
                     .collect(),
                 context: message.loaded_context.text.clone(),

@@ -28,6 +28,7 @@ use prompt_store::{
 };
 use serde::{Deserialize, Serialize};
 use settings::{Settings as _, SettingsStore};
+use ui::Window;
 use util::ResultExt as _;

 use crate::context_server_tool::ContextServerTool;
@@ -388,18 +389,20 @@ impl ThreadStore {
     pub fn open_thread(
         &self,
         id: &ThreadId,
+        window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Task<Result<Entity<Thread>>> {
         let id = id.clone();
         let database_future = ThreadsDatabase::global_future(cx);
-        cx.spawn(async move |this, cx| {
+        let this = cx.weak_entity();
+        window.spawn(cx, async move |cx| {
             let database = database_future.await.map_err(|err| anyhow!(err))?;
             let thread = database
                 .try_find_thread(id.clone())
                 .await?
                 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;

-            let thread = this.update(cx, |this, cx| {
+            let thread = this.update_in(cx, |this, window, cx| {
                 cx.new(|cx| {
                     Thread::deserialize(
                         id.clone(),
@@ -408,6 +411,7 @@ impl ThreadStore {
                         this.tools.clone(),
                         this.prompt_builder.clone(),
                         this.project_context.clone(),
+                        window,
                         cx,
                     )
                 })
@@ -772,6 +776,7 @@ pub struct SerializedToolResult {
     pub tool_use_id: LanguageModelToolUseId,
     pub is_error: bool,
     pub content: Arc<str>,
+    pub output: Option<serde_json::Value>,
 }

 #[derive(Serialize, Deserialize)]

@@ -1,7 +1,7 @@
 use std::sync::Arc;

 use anyhow::Result;
-use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
+use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet};
 use collections::HashMap;
 use futures::FutureExt as _;
 use futures::future::Shared;
@@ -10,7 +10,8 @@ use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
     LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 };
-use ui::IconName;
+use project::Project;
+use ui::{IconName, Window};
 use util::truncate_lines_to_byte_limit;

 use crate::thread::{MessageId, PromptId, ThreadId};
@@ -54,6 +55,9 @@ impl ToolUseState {
     pub fn from_serialized_messages(
         tools: Entity<ToolWorkingSet>,
         messages: &[SerializedMessage],
+        project: Entity<Project>,
+        window: &mut Window,
+        cx: &mut App,
     ) -> Self {
         let mut this = Self::new(tools);
         let mut tool_names_by_id = HashMap::default();
@@ -93,12 +97,23 @@ impl ToolUseState {
                             this.tool_results.insert(
                                 tool_use_id.clone(),
                                 LanguageModelToolResult {
-                                    tool_use_id,
+                                    tool_use_id: tool_use_id.clone(),
                                     tool_name: tool_use.clone(),
                                     is_error: tool_result.is_error,
                                     content: tool_result.content.clone(),
+                                    output: tool_result.output.clone(),
                                 },
                             );
+
+                            if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
+                                if let Some(output) = tool_result.output.clone() {
+                                    if let Some(card) =
+                                        tool.deserialize_card(output, project.clone(), window, cx)
+                                    {
+                                        this.tool_result_cards.insert(tool_use_id, card);
+                                    }
+                                }
+                            }
                         }
                     }
                 }
@@ -124,6 +139,7 @@ impl ToolUseState {
                     tool_use_id: tool_use_id.clone(),
                     tool_name: tool_use.name.clone(),
                     content,
+                    output: None,
                     is_error: true,
                 },
             );
@@ -359,7 +375,7 @@ impl ToolUseState {
         &mut self,
         tool_use_id: LanguageModelToolUseId,
         tool_name: Arc<str>,
-        output: Result<String>,
+        output: Result<ToolResultOutput>,
         configured_model: Option<&ConfiguredModel>,
     ) -> Option<PendingToolUse> {
         let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
@@ -379,7 +395,8 @@ impl ToolUseState {
         );

         match output {
-            Ok(tool_result) => {
+            Ok(output) => {
+                let tool_result = output.content;
                 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

                 // Protect from clearly large output
@@ -406,6 +423,7 @@ impl ToolUseState {
                         tool_name,
                         content: tool_result.into(),
                         is_error: false,
+                        output: output.output,
                     },
                 );
                 self.pending_tool_uses_by_id.remove(&tool_use_id)
@@ -418,6 +436,7 @@ impl ToolUseState {
                         tool_name,
                         content: err.to_string().into(),
                         is_error: true,
+                        output: None,
                     },
                 );
@@ -490,6 +509,7 @@ impl ToolUseState {
                     } else {
                         tool_result.content.clone()
                     },
+                    output: None,
                 }));
             }
         }

@@ -7,6 +7,7 @@ mod tool_working_set;
 use std::fmt;
 use std::fmt::Debug;
 use std::fmt::Formatter;
+use std::ops::Deref;
 use std::sync::Arc;

 use anyhow::Result;
@@ -61,11 +62,34 @@ impl ToolUseStatus {
     }
 }

+#[derive(Debug)]
+pub struct ToolResultOutput {
+    pub content: String,
+    pub output: Option<serde_json::Value>,
+}
+
+impl From<String> for ToolResultOutput {
+    fn from(value: String) -> Self {
+        ToolResultOutput {
+            content: value,
+            output: None,
+        }
+    }
+}
+
+impl Deref for ToolResultOutput {
+    type Target = String;
+
+    fn deref(&self) -> &Self::Target {
+        &self.content
+    }
+}
+
 /// The result of running a tool, containing both the asynchronous output
 /// and an optional card view that can be rendered immediately.
 pub struct ToolResult {
     /// The asynchronous task that will eventually resolve to the tool's output
-    pub output: Task<Result<String>>,
+    pub output: Task<Result<ToolResultOutput>>,
     /// An optional view to present the output of the tool.
     pub card: Option<AnyToolCard>,
 }
@@ -128,9 +152,9 @@ impl AnyToolCard {
     }
 }

-impl From<Task<Result<String>>> for ToolResult {
+impl From<Task<Result<ToolResultOutput>>> for ToolResult {
     /// Convert from a task to a ToolResult with no card
-    fn from(output: Task<Result<String>>) -> Self {
+    fn from(output: Task<Result<ToolResultOutput>>) -> Self {
         Self { output, card: None }
     }
 }
@@ -187,6 +211,16 @@ pub trait Tool: 'static + Send + Sync {
         window: Option<AnyWindowHandle>,
         cx: &mut App,
     ) -> ToolResult;
+
+    fn deserialize_card(
+        self: Arc<Self>,
+        _output: serde_json::Value,
+        _project: Entity<Project>,
+        _window: &mut Window,
+        _cx: &mut App,
+    ) -> Option<AnyToolCard> {
+        None
+    }
 }

 impl Debug for dyn Tool {
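
The trait additions above are the whole persistence contract: a tool's `run` now resolves to a `ToolResultOutput`, whose optional `output` JSON is stored with the serialized thread, and `deserialize_card` (defaulting to `None`) rebuilds the card from that JSON when the thread is reopened. A rough sketch of how a tool opts in, using hypothetical `ExampleTool`, `ExampleToolOutput`, and `ExampleToolCard` names and omitting the trait's other required methods (EditFileTool and StreamingEditFileTool below are the actual implementations in this commit):

// Hypothetical sketch only; ExampleTool, ExampleToolOutput, and ExampleToolCard
// are stand-ins, and the rest of the Tool impl is elided.
#[derive(serde::Serialize, serde::Deserialize)]
struct ExampleToolOutput {
    summary: String,
}

// Inside ExampleTool::run(...) -> ToolResult: return human-readable content
// plus the structured value that should be persisted for card restoration.
let structured = ExampleToolOutput { summary: "did the thing".into() };
let result = ToolResultOutput {
    content: "Did the thing".to_string(),
    output: serde_json::to_value(structured).ok(),
};
Task::ready(Ok(result)).into()

// In impl Tool for ExampleTool: rebuild the card from the persisted value.
fn deserialize_card(
    self: Arc<Self>,
    output: serde_json::Value,
    project: Entity<Project>,
    window: &mut Window,
    cx: &mut App,
) -> Option<AnyToolCard> {
    let output: ExampleToolOutput = serde_json::from_value(output).ok()?;
    let card = cx.new(|cx| ExampleToolCard::new(output, project, window, cx));
    Some(card.into())
}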

@@ -107,10 +107,9 @@ impl Tool for CopyPathTool {
         cx.background_spawn(async move {
             match copy_task.await {
-                Ok(_) => Ok(format!(
-                    "Copied {} to {}",
-                    input.source_path, input.destination_path
-                )),
+                Ok(_) => Ok(
+                    format!("Copied {} to {}", input.source_path, input.destination_path).into(),
+                ),
                 Err(err) => Err(anyhow!(
                     "Failed to copy {} to {}: {}",
                     input.source_path,

@@ -88,7 +88,7 @@ impl Tool for CreateDirectoryTool {
                 .await
                 .map_err(|err| anyhow!("Unable to create directory {destination_path}: {err}"))?;

-            Ok(format!("Created directory {destination_path}"))
+            Ok(format!("Created directory {destination_path}").into())
         })
         .into()
     }

@@ -131,7 +131,7 @@ impl Tool for CreateFileTool {
                 .await
                 .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;

-            Ok(format!("Created file {destination_path}"))
+            Ok(format!("Created file {destination_path}").into())
         })
         .into()
     }

@@ -127,7 +127,7 @@ impl Tool for DeletePathTool {
         match delete {
             Some(deletion_task) => match deletion_task.await {
-                Ok(()) => Ok(format!("Deleted {path_str}")),
+                Ok(()) => Ok(format!("Deleted {path_str}").into()),
                 Err(err) => Err(anyhow!("Failed to delete {path_str}: {err}")),
             },
             None => Err(anyhow!(

@@ -122,9 +122,9 @@ impl Tool for DiagnosticsTool {
             }

             if output.is_empty() {
-                Ok("File doesn't have errors or warnings!".to_string())
+                Ok("File doesn't have errors or warnings!".to_string().into())
             } else {
-                Ok(output)
+                Ok(output.into())
             }
         })
         .into()
@@ -158,10 +158,12 @@ impl Tool for DiagnosticsTool {
         });

         if has_diagnostics {
-            Task::ready(Ok(output)).into()
+            Task::ready(Ok(output.into())).into()
         } else {
-            Task::ready(Ok("No errors or warnings found in the project.".to_string()))
-                .into()
+            Task::ready(Ok("No errors or warnings found in the project."
+                .to_string()
+                .into()))
+            .into()
         }
     }
 }

@@ -895,6 +895,7 @@ fn tool_result(
         tool_name: name.into(),
         is_error: false,
         content: result.into(),
+        output: None,
     })
 }

@@ -1,9 +1,12 @@
 use crate::{
     replace::{replace_exact, replace_with_flexible_indent},
     schema::json_schema_for,
+    streaming_edit_file_tool::StreamingEditFileToolOutput,
 };
 use anyhow::{Context as _, Result, anyhow};
-use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
+use assistant_tool::{
+    ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus,
+};
 use buffer_diff::{BufferDiff, BufferDiffSnapshot};
 use editor::{Editor, EditorElement, EditorMode, EditorStyle, MultiBuffer, PathKey};
 use gpui::{
@@ -153,7 +156,7 @@ impl Tool for EditFileTool {
         });

         let card_clone = card.clone();
-        let task = cx.spawn(async move |cx: &mut AsyncApp| {
+        let task: Task<Result<ToolResultOutput, _>> = cx.spawn(async move |cx: &mut AsyncApp| {
             let project_path = project.read_with(cx, |project, cx| {
                 project
                     .find_project_path(&input.path, cx)
@@ -281,16 +284,29 @@ impl Tool for EditFileTool {
             if let Some(card) = card_clone {
                 card.update(cx, |card, cx| {
-                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
+                    card.set_diff(
+                        project_path.path.clone(),
+                        old_text.clone(),
+                        new_text.clone(),
+                        cx,
+                    );
                 })
                 .log_err();
             }

-            Ok(format!(
-                "Edited {}:\n\n```diff\n{}\n```",
-                input.path.display(),
-                diff_str
-            ))
+            Ok(ToolResultOutput {
+                content: format!(
+                    "Edited {}:\n\n```diff\n{}\n```",
+                    input.path.display(),
+                    diff_str
+                ),
+                output: serde_json::to_value(StreamingEditFileToolOutput {
+                    original_path: input.path,
+                    new_text,
+                    old_text,
+                })
+                .ok(),
+            })
         });

         ToolResult {
@@ -298,6 +314,32 @@ impl Tool for EditFileTool {
             card: card.map(AnyToolCard::from),
         }
     }
+
+    fn deserialize_card(
+        self: Arc<Self>,
+        output: serde_json::Value,
+        project: Entity<Project>,
+        window: &mut Window,
+        cx: &mut App,
+    ) -> Option<AnyToolCard> {
+        let output = match serde_json::from_value::<StreamingEditFileToolOutput>(output) {
+            Ok(output) => output,
+            Err(_) => return None,
+        };
+
+        let card = cx.new(|cx| {
+            let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx);
+            card.set_diff(
+                output.original_path.into(),
+                output.old_text,
+                output.new_text,
+                cx,
+            );
+            card
+        });
+
+        Some(card.into())
+    }
 }

 pub struct EditFileToolCard {

@@ -166,7 +166,7 @@ impl Tool for FetchTool {
                 bail!("no textual content found");
             }

-            Ok(text)
+            Ok(text.into())
         })
         .into()
     }

@@ -98,7 +98,7 @@ impl Tool for FindPathTool {
             sender.send(paginated_matches.to_vec()).log_err();

             if matches.is_empty() {
-                Ok("No matches found".to_string())
+                Ok("No matches found".to_string().into())
             } else {
                 let mut message = format!("Found {} total matches.", matches.len());
                 if matches.len() > RESULTS_PER_PAGE {
@@ -113,7 +113,7 @@ impl Tool for FindPathTool {
                 for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
                     write!(&mut message, "\n{}", mat.display()).unwrap();
                 }
-                Ok(message)
+                Ok(message.into())
             }
         });

@@ -260,16 +260,16 @@ impl Tool for GrepTool {
             }

             if matches_found == 0 {
-                Ok("No matches found".to_string())
+                Ok("No matches found".to_string().into())
             } else if has_more_matches {
                 Ok(format!(
                     "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
                     input.offset + 1,
                     input.offset + matches_found,
                     input.offset + RESULTS_PER_PAGE,
-                ))
+                ).into())
             } else {
-                Ok(format!("Found {matches_found} matches:\n{output}"))
+                Ok(format!("Found {matches_found} matches:\n{output}").into())
             }
         }).into()
     }
@@ -748,9 +748,9 @@ mod tests {
         match task.output.await {
             Ok(result) => {
                 if cfg!(windows) {
-                    result.replace("root\\", "root/")
+                    result.content.replace("root\\", "root/")
                 } else {
-                    result
+                    result.content
                 }
             }
             Err(e) => panic!("Failed to run grep tool: {}", e),

@@ -102,7 +102,7 @@ impl Tool for ListDirectoryTool {
                 .collect::<Vec<_>>()
                 .join("\n");

-            return Task::ready(Ok(output)).into();
+            return Task::ready(Ok(output.into())).into();
         }

         let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
@@ -134,8 +134,8 @@ impl Tool for ListDirectoryTool {
                 .unwrap();
         }

         if output.is_empty() {
-            return Task::ready(Ok(format!("{} is empty.", input.path))).into();
+            return Task::ready(Ok(format!("{} is empty.", input.path).into())).into();
         }

-        Task::ready(Ok(output)).into()
+        Task::ready(Ok(output.into())).into()
     }
 }

@@ -117,10 +117,9 @@ impl Tool for MovePathTool {
         cx.background_spawn(async move {
             match rename_task.await {
-                Ok(_) => Ok(format!(
-                    "Moved {} to {}",
-                    input.source_path, input.destination_path
-                )),
+                Ok(_) => {
+                    Ok(format!("Moved {} to {}", input.source_path, input.destination_path).into())
+                }
                 Err(err) => Err(anyhow!(
                     "Failed to move {} to {}: {}",
                     input.source_path,

@@ -73,6 +73,6 @@ impl Tool for NowTool {
         };
         let text = format!("The current datetime is {now}.");

-        Task::ready(Ok(text)).into()
+        Task::ready(Ok(text.into())).into()
     }
 }

@@ -70,7 +70,7 @@ impl Tool for OpenTool {
             }
             .context("Failed to open URL or file path")?;

-            Ok(format!("Successfully opened {}", input.path_or_url))
+            Ok(format!("Successfully opened {}", input.path_or_url).into())
         })
         .into()
     }

@@ -145,9 +145,9 @@ impl Tool for ReadFileTool {
                 let lines = text.split('\n').skip(start_row as usize);
                 if let Some(end) = input.end_line {
                     let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
-                    Itertools::intersperse(lines.take(count as usize), "\n").collect()
+                    Itertools::intersperse(lines.take(count as usize), "\n").collect::<String>().into()
                 } else {
-                    Itertools::intersperse(lines, "\n").collect()
+                    Itertools::intersperse(lines, "\n").collect::<String>().into()
                 }
             })?;
@@ -180,7 +180,7 @@ impl Tool for ReadFileTool {
                     log.buffer_read(buffer, cx);
                 })?;

-                Ok(result)
+                Ok(result.into())
             } else {
                 // File is too big, so return the outline
                 // and a suggestion to read again with line numbers.
@@ -192,7 +192,7 @@ impl Tool for ReadFileTool {
                     Using the line numbers in this outline, you can call this tool again while specifying
                     the start_line and end_line fields to see the implementations of symbols in the outline."
-                })
+                }.into())
             }
         }
     })
@@ -258,7 +258,7 @@ mod test {
                 .output
             })
             .await;
-        assert_eq!(result.unwrap(), "This is a small file content");
+        assert_eq!(result.unwrap().content, "This is a small file content");
     }

     #[gpui::test]
@@ -358,7 +358,7 @@ mod test {
                 .output
             })
             .await;
-        assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
+        assert_eq!(result.unwrap().content, "Line 2\nLine 3\nLine 4");
     }

     #[gpui::test]
@@ -389,7 +389,7 @@ mod test {
                 .output
             })
             .await;
-        assert_eq!(result.unwrap(), "Line 1\nLine 2");
+        assert_eq!(result.unwrap().content, "Line 1\nLine 2");

         // end_line of 0 should result in at least 1 line
         let result = cx
@@ -404,7 +404,7 @@ mod test {
                 .output
             })
             .await;
-        assert_eq!(result.unwrap(), "Line 1");
+        assert_eq!(result.unwrap().content, "Line 1");

         // when start_line > end_line, should still return at least 1 line
         let result = cx
@@ -419,7 +419,7 @@ mod test {
                 .output
             })
             .await;
-        assert_eq!(result.unwrap(), "Line 3");
+        assert_eq!(result.unwrap().content, "Line 3");
     }

     fn init_test(cx: &mut TestAppContext) {

@@ -5,7 +5,7 @@ use crate::{
     schema::json_schema_for,
 };
 use anyhow::{Context as _, Result, anyhow};
-use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
+use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult, ToolResultOutput};
 use futures::StreamExt;
 use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
 use indoc::formatdoc;
@@ -67,6 +67,13 @@ pub struct StreamingEditFileToolInput {
     pub create_or_overwrite: bool,
 }

+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct StreamingEditFileToolOutput {
+    pub original_path: PathBuf,
+    pub new_text: String,
+    pub old_text: String,
+}
+
 #[derive(Debug, Serialize, Deserialize, JsonSchema)]
 struct PartialInput {
     #[serde(default)]
@@ -248,6 +255,12 @@ impl Tool for StreamingEditFileTool {
             });

             let (new_text, diff) = futures::join!(new_text, diff);

+            let output = StreamingEditFileToolOutput {
+                original_path: project_path.path.to_path_buf(),
+                new_text: new_text.clone(),
+                old_text: old_text.clone(),
+            };
+
             if let Some(card) = card_clone {
                 card.update(cx, |card, cx| {
                     card.set_diff(project_path.path.clone(), old_text, new_text, cx);
@@ -264,10 +277,13 @@ impl Tool for StreamingEditFileTool {
                         I can perform the requested edits.
                     "}))
                 } else {
-                    Ok("No edits were made.".to_string())
+                    Ok("No edits were made.".to_string().into())
                 }
             } else {
-                Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
+                Ok(ToolResultOutput {
+                    content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff),
+                    output: serde_json::to_value(output).ok(),
+                })
             }
         });
@@ -276,6 +292,32 @@ impl Tool for StreamingEditFileTool {
             card: card.map(AnyToolCard::from),
         }
     }
+
+    fn deserialize_card(
+        self: Arc<Self>,
+        output: serde_json::Value,
+        project: Entity<Project>,
+        window: &mut Window,
+        cx: &mut App,
+    ) -> Option<AnyToolCard> {
+        let output = match serde_json::from_value::<StreamingEditFileToolOutput>(output) {
+            Ok(output) => output,
+            Err(_) => return None,
+        };
+
+        let card = cx.new(|cx| {
+            let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx);
+            card.set_diff(
+                output.original_path.into(),
+                output.old_text,
+                output.new_text,
+                cx,
+            );
+            card
+        });
+
+        Some(card.into())
+    }
 }

 #[cfg(test)]

@@ -178,7 +178,7 @@ impl Tool for TerminalTool {
                 let exit_status = child.wait()?;
                 let (processed_content, _) =
                     process_content(content, &input.command, Some(exit_status));
-                Ok(processed_content)
+                Ok(processed_content.into())
             });
             return ToolResult {
                 output: task,
@@ -266,7 +266,7 @@ impl Tool for TerminalTool {
                     card.elapsed_time = Some(card.start_instant.elapsed());
                 });

-                Ok(processed_content)
+                Ok(processed_content.into())
             }
         });
@@ -661,7 +661,7 @@ mod tests {
             )
         });

-        let output = result.output.await.log_err();
+        let output = result.output.await.log_err().map(|output| output.content);
         assert_eq!(output, Some("Command executed successfully.".into()));
     }
@@ -693,7 +693,11 @@ mod tests {
                 cx,
             );
             cx.spawn(async move |_| {
-                let output = headless_result.output.await.log_err();
+                let output = headless_result
+                    .output
+                    .await
+                    .log_err()
+                    .map(|output| output.content);
                 assert_eq!(output, expected);
             })
         };

@@ -55,7 +55,7 @@ impl Tool for ThinkingTool {
     ) -> ToolResult {
         // This tool just "thinks out loud" and doesn't perform any actions.
         Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) {
-            Ok(_input) => Ok("Finished thinking.".to_string()),
+            Ok(_input) => Ok("Finished thinking.".to_string().into()),
             Err(err) => Err(anyhow!(err)),
         })
         .into()

@@ -72,7 +72,9 @@ impl Tool for WebSearchTool {
             let search_task = search_task.clone();
             async move {
                 let response = search_task.await.map_err(|err| anyhow!(err))?;
-                serde_json::to_string(&response).context("Failed to serialize search results")
+                serde_json::to_string(&response)
+                    .context("Failed to serialize search results")
+                    .map(Into::into)
             }
         });

@@ -131,6 +131,7 @@ pub struct LanguageModelToolResult {
     pub tool_name: Arc<str>,
     pub is_error: bool,
     pub content: Arc<str>,
+    pub output: Option<serde_json::Value>,
 }

 #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]