Restore tool cards on thread deserialization (#30053)

Release Notes:

- N/A

---------

Co-authored-by: Julia Ryan <juliaryan3.14@gmail.com>
This commit is contained in:
Mikayla Maki 2025-05-06 18:16:34 -07:00 committed by GitHub
parent ab3e5cdc6c
commit 0cdd8bdded
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 307 additions and 135 deletions

View file

@ -510,6 +510,7 @@ impl AssistantPanel {
thread_store.clone(),
context_store.clone(),
[RecentEntry::Thread(thread_id, thread.clone())],
window,
cx,
)
});
@ -764,9 +765,9 @@ impl AssistantPanel {
});
if let Some(other_thread_id) = action.from_thread_id.clone() {
let other_thread_task = self
.thread_store
.update(cx, |this, cx| this.open_thread(&other_thread_id, cx));
let other_thread_task = self.thread_store.update(cx, |this, cx| {
this.open_thread(&other_thread_id, window, cx)
});
cx.spawn({
let context_store = context_store.clone();
@ -967,7 +968,7 @@ impl AssistantPanel {
) -> Task<Result<()>> {
let open_thread_task = self
.thread_store
.update(cx, |this, cx| this.open_thread(thread_id, cx));
.update(cx, |this, cx| this.open_thread(thread_id, window, cx));
cx.spawn_in(window, async move |this, cx| {
let thread = open_thread_task.await?;
this.update_in(cx, |this, window, cx| {

View file

@ -425,9 +425,9 @@ impl ContextPicker {
render_thread_context_entry(&view_thread, context_store.clone(), cx)
.into_any()
},
move |_window, cx| {
move |window, cx| {
context_picker.update(cx, |this, cx| {
this.add_recent_thread(thread.clone(), cx)
this.add_recent_thread(thread.clone(), window, cx)
.detach_and_log_err(cx);
})
},
@ -459,6 +459,7 @@ impl ContextPicker {
fn add_recent_thread(
&self,
entry: ThreadContextEntry,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let Some(context_store) = self.context_store.upgrade() else {
@ -476,7 +477,7 @@ impl ContextPicker {
};
let open_thread_task =
thread_store.update(cx, |this, cx| this.open_thread(&id, cx));
thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx));
cx.spawn(async move |this, cx| {
let thread = open_thread_task.await?;
context_store.update(cx, |context_store, cx| {

View file

@ -438,15 +438,15 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor.clone(),
context_store.clone(),
move |cx| match &thread_entry {
move |window, cx| match &thread_entry {
ThreadContextEntry::Thread { id, .. } => {
let thread_id = id.clone();
let context_store = context_store.clone();
let thread_store = thread_store.clone();
cx.spawn::<_, Option<_>>(async move |cx| {
window.spawn::<_, Option<_>>(cx, async move |cx| {
let thread: Entity<Thread> = thread_store
.update(cx, |thread_store, cx| {
thread_store.open_thread(&thread_id, cx)
.update_in(cx, |thread_store, window, cx| {
thread_store.open_thread(&thread_id, window, cx)
})
.ok()?
.await
@ -507,7 +507,7 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor.clone(),
context_store.clone(),
move |cx| {
move |_, cx| {
let user_prompt_id = rules.prompt_id;
let context = context_store.update(cx, |context_store, cx| {
context_store.add_rules(user_prompt_id, false, cx)
@ -544,7 +544,7 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor.clone(),
context_store.clone(),
move |cx| {
move |_, cx| {
let context_store = context_store.clone();
let http_client = http_client.clone();
let url_to_fetch = url_to_fetch.clone();
@ -629,7 +629,7 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor,
context_store.clone(),
move |cx| {
move |_, cx| {
if is_directory {
Task::ready(
context_store
@ -700,7 +700,7 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor.clone(),
context_store.clone(),
move |cx| {
move |_, cx| {
let symbol = symbol.clone();
let context_store = context_store.clone();
let workspace = workspace.clone();
@ -954,10 +954,13 @@ fn confirm_completion_callback(
content_len: usize,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
add_context_fn: impl Fn(&mut App) -> Task<Option<AgentContextHandle>> + Send + Sync + 'static,
add_context_fn: impl Fn(&mut Window, &mut App) -> Task<Option<AgentContextHandle>>
+ Send
+ Sync
+ 'static,
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
Arc::new(move |_, window, cx| {
let context = add_context_fn(cx);
let context = add_context_fn(window, cx);
let crease_text = crease_text.clone();
let crease_icon_path = crease_icon_path.clone();

View file

@ -154,7 +154,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(entry) = self.matches.get(self.selected_index) else {
return;
};
@ -165,7 +165,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
return;
};
let open_thread_task =
thread_store.update(cx, |this, cx| this.open_thread(&id, cx));
thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx));
cx.spawn(async move |this, cx| {
let thread = open_thread_task.await?;

View file

@ -115,7 +115,7 @@ impl Tool for ContextServerTool {
}
}
}
Ok(result)
Ok(result.into())
})
.into()
} else {

View file

@ -8,7 +8,7 @@ use gpui::{Entity, Task, prelude::*};
use serde::{Deserialize, Serialize};
use smol::future::FutureExt;
use std::time::Duration;
use ui::{App, SharedString};
use ui::{App, SharedString, Window};
use util::ResultExt as _;
use crate::{
@ -82,6 +82,7 @@ impl HistoryStore {
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
initial_recent_entries: impl IntoIterator<Item = RecentEntry>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let subscriptions = vec![
@ -89,56 +90,62 @@ impl HistoryStore {
cx.observe(&context_store, |_, _, cx| cx.notify()),
];
cx.spawn({
let thread_store = thread_store.downgrade();
let context_store = context_store.downgrade();
async move |this, cx| {
let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
let contents = cx
.background_spawn(async move { std::fs::read_to_string(path) })
.await
.ok()?;
let entries = serde_json::from_str::<Vec<SerializedRecentEntry>>(&contents)
.context("deserializing persisted agent panel navigation history")
.log_err()?
.into_iter()
.take(MAX_RECENTLY_OPENED_ENTRIES)
.map(|serialized| match serialized {
SerializedRecentEntry::Thread(id) => thread_store
.update(cx, |thread_store, cx| {
let thread_id = ThreadId::from(id.as_str());
thread_store
.open_thread(&thread_id, cx)
.map_ok(|thread| RecentEntry::Thread(thread_id, thread))
.boxed()
})
.unwrap_or_else(|_| async { Err(anyhow!("no thread store")) }.boxed()),
SerializedRecentEntry::Context(id) => context_store
.update(cx, |context_store, cx| {
context_store
.open_local_context(Path::new(&id).into(), cx)
.map_ok(RecentEntry::Context)
.boxed()
})
.unwrap_or_else(|_| async { Err(anyhow!("no context store")) }.boxed()),
});
let entries = join_all(entries)
.await
.into_iter()
.filter_map(|result| result.log_err())
.collect::<VecDeque<_>>();
window
.spawn(cx, {
let thread_store = thread_store.downgrade();
let context_store = context_store.downgrade();
let this = cx.weak_entity();
async move |cx| {
let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
let contents = cx
.background_spawn(async move { std::fs::read_to_string(path) })
.await
.ok()?;
let entries = serde_json::from_str::<Vec<SerializedRecentEntry>>(&contents)
.context("deserializing persisted agent panel navigation history")
.log_err()?
.into_iter()
.take(MAX_RECENTLY_OPENED_ENTRIES)
.map(|serialized| match serialized {
SerializedRecentEntry::Thread(id) => thread_store
.update_in(cx, |thread_store, window, cx| {
let thread_id = ThreadId::from(id.as_str());
thread_store
.open_thread(&thread_id, window, cx)
.map_ok(|thread| RecentEntry::Thread(thread_id, thread))
.boxed()
})
.unwrap_or_else(|_| {
async { Err(anyhow!("no thread store")) }.boxed()
}),
SerializedRecentEntry::Context(id) => context_store
.update(cx, |context_store, cx| {
context_store
.open_local_context(Path::new(&id).into(), cx)
.map_ok(RecentEntry::Context)
.boxed()
})
.unwrap_or_else(|_| {
async { Err(anyhow!("no context store")) }.boxed()
}),
});
let entries = join_all(entries)
.await
.into_iter()
.filter_map(|result| result.log_err())
.collect::<VecDeque<_>>();
this.update(cx, |this, _| {
this.recently_opened_entries.extend(entries);
this.recently_opened_entries
.truncate(MAX_RECENTLY_OPENED_ENTRIES);
})
.ok();
this.update(cx, |this, _| {
this.recently_opened_entries.extend(entries);
this.recently_opened_entries
.truncate(MAX_RECENTLY_OPENED_ENTRIES);
})
.ok();
Some(())
}
})
.detach();
Some(())
}
})
.detach();
Self {
thread_store,

View file

@ -35,6 +35,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use ui::Window;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionRequestStatus;
@ -430,6 +431,7 @@ impl Thread {
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
@ -439,7 +441,13 @@ impl Thread {
.map(|message| message.id.0 + 1)
.unwrap_or(0),
);
let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
let tool_use = ToolUseState::from_serialized_messages(
tools.clone(),
&serialized.messages,
project.clone(),
window,
cx,
);
let (detailed_summary_tx, detailed_summary_rx) =
postage::watch::channel_with(serialized.detailed_summary_state);
@ -1064,6 +1072,7 @@ impl Thread {
tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
output: tool_result.output.clone(),
})
.collect(),
context: message.loaded_context.text.clone(),

View file

@ -28,6 +28,7 @@ use prompt_store::{
};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use ui::Window;
use util::ResultExt as _;
use crate::context_server_tool::ContextServerTool;
@ -388,18 +389,20 @@ impl ThreadStore {
pub fn open_thread(
&self,
id: &ThreadId,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<Entity<Thread>>> {
let id = id.clone();
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| {
let this = cx.weak_entity();
window.spawn(cx, async move |cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
let thread = database
.try_find_thread(id.clone())
.await?
.ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
let thread = this.update(cx, |this, cx| {
let thread = this.update_in(cx, |this, window, cx| {
cx.new(|cx| {
Thread::deserialize(
id.clone(),
@ -408,6 +411,7 @@ impl ThreadStore {
this.tools.clone(),
this.prompt_builder.clone(),
this.project_context.clone(),
window,
cx,
)
})
@ -772,6 +776,7 @@ pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool,
pub content: Arc<str>,
pub output: Option<serde_json::Value>,
}
#[derive(Serialize, Deserialize)]

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
@ -10,7 +10,8 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
};
use ui::IconName;
use project::Project;
use ui::{IconName, Window};
use util::truncate_lines_to_byte_limit;
use crate::thread::{MessageId, PromptId, ThreadId};
@ -54,6 +55,9 @@ impl ToolUseState {
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Self {
let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default();
@ -93,12 +97,23 @@ impl ToolUseState {
this.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id,
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
output: tool_result.output.clone(),
},
);
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
if let Some(output) = tool_result.output.clone() {
if let Some(card) =
tool.deserialize_card(output, project.clone(), window, cx)
{
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
}
}
}
@ -124,6 +139,7 @@ impl ToolUseState {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.name.clone(),
content,
output: None,
is_error: true,
},
);
@ -359,7 +375,7 @@ impl ToolUseState {
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<String>,
output: Result<ToolResultOutput>,
configured_model: Option<&ConfiguredModel>,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
@ -379,7 +395,8 @@ impl ToolUseState {
);
match output {
Ok(tool_result) => {
Ok(output) => {
let tool_result = output.content;
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
// Protect from clearly large output
@ -406,6 +423,7 @@ impl ToolUseState {
tool_name,
content: tool_result.into(),
is_error: false,
output: output.output,
},
);
self.pending_tool_uses_by_id.remove(&tool_use_id)
@ -418,6 +436,7 @@ impl ToolUseState {
tool_name,
content: err.to_string().into(),
is_error: true,
output: None,
},
);
@ -490,6 +509,7 @@ impl ToolUseState {
} else {
tool_result.content.clone()
},
output: None,
}));
}
}

View file

@ -7,6 +7,7 @@ mod tool_working_set;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;
use anyhow::Result;
@ -61,11 +62,34 @@ impl ToolUseStatus {
}
}
#[derive(Debug)]
pub struct ToolResultOutput {
pub content: String,
pub output: Option<serde_json::Value>,
}
impl From<String> for ToolResultOutput {
fn from(value: String) -> Self {
ToolResultOutput {
content: value,
output: None,
}
}
}
impl Deref for ToolResultOutput {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.content
}
}
/// The result of running a tool, containing both the asynchronous output
/// and an optional card view that can be rendered immediately.
pub struct ToolResult {
/// The asynchronous task that will eventually resolve to the tool's output
pub output: Task<Result<String>>,
pub output: Task<Result<ToolResultOutput>>,
/// An optional view to present the output of the tool.
pub card: Option<AnyToolCard>,
}
@ -128,9 +152,9 @@ impl AnyToolCard {
}
}
impl From<Task<Result<String>>> for ToolResult {
impl From<Task<Result<ToolResultOutput>>> for ToolResult {
/// Convert from a task to a ToolResult with no card
fn from(output: Task<Result<String>>) -> Self {
fn from(output: Task<Result<ToolResultOutput>>) -> Self {
Self { output, card: None }
}
}
@ -187,6 +211,16 @@ pub trait Tool: 'static + Send + Sync {
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult;
fn deserialize_card(
self: Arc<Self>,
_output: serde_json::Value,
_project: Entity<Project>,
_window: &mut Window,
_cx: &mut App,
) -> Option<AnyToolCard> {
None
}
}
impl Debug for dyn Tool {

View file

@ -107,10 +107,9 @@ impl Tool for CopyPathTool {
cx.background_spawn(async move {
match copy_task.await {
Ok(_) => Ok(format!(
"Copied {} to {}",
input.source_path, input.destination_path
)),
Ok(_) => Ok(
format!("Copied {} to {}", input.source_path, input.destination_path).into(),
),
Err(err) => Err(anyhow!(
"Failed to copy {} to {}: {}",
input.source_path,

View file

@ -88,7 +88,7 @@ impl Tool for CreateDirectoryTool {
.await
.map_err(|err| anyhow!("Unable to create directory {destination_path}: {err}"))?;
Ok(format!("Created directory {destination_path}"))
Ok(format!("Created directory {destination_path}").into())
})
.into()
}

View file

@ -131,7 +131,7 @@ impl Tool for CreateFileTool {
.await
.map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;
Ok(format!("Created file {destination_path}"))
Ok(format!("Created file {destination_path}").into())
})
.into()
}

View file

@ -127,7 +127,7 @@ impl Tool for DeletePathTool {
match delete {
Some(deletion_task) => match deletion_task.await {
Ok(()) => Ok(format!("Deleted {path_str}")),
Ok(()) => Ok(format!("Deleted {path_str}").into()),
Err(err) => Err(anyhow!("Failed to delete {path_str}: {err}")),
},
None => Err(anyhow!(

View file

@ -122,9 +122,9 @@ impl Tool for DiagnosticsTool {
}
if output.is_empty() {
Ok("File doesn't have errors or warnings!".to_string())
Ok("File doesn't have errors or warnings!".to_string().into())
} else {
Ok(output)
Ok(output.into())
}
})
.into()
@ -158,10 +158,12 @@ impl Tool for DiagnosticsTool {
});
if has_diagnostics {
Task::ready(Ok(output)).into()
Task::ready(Ok(output.into())).into()
} else {
Task::ready(Ok("No errors or warnings found in the project.".to_string()))
.into()
Task::ready(Ok("No errors or warnings found in the project."
.to_string()
.into()))
.into()
}
}
}

View file

@ -895,6 +895,7 @@ fn tool_result(
tool_name: name.into(),
is_error: false,
content: result.into(),
output: None,
})
}

View file

@ -1,9 +1,12 @@
use crate::{
replace::{replace_exact, replace_with_flexible_indent},
schema::json_schema_for,
streaming_edit_file_tool::StreamingEditFileToolOutput,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
use assistant_tool::{
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus,
};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorElement, EditorMode, EditorStyle, MultiBuffer, PathKey};
use gpui::{
@ -153,7 +156,7 @@ impl Tool for EditFileTool {
});
let card_clone = card.clone();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
let task: Task<Result<ToolResultOutput, _>> = cx.spawn(async move |cx: &mut AsyncApp| {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
@ -281,16 +284,29 @@ impl Tool for EditFileTool {
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
card.set_diff(
project_path.path.clone(),
old_text.clone(),
new_text.clone(),
cx,
);
})
.log_err();
}
Ok(format!(
"Edited {}:\n\n```diff\n{}\n```",
input.path.display(),
diff_str
))
Ok(ToolResultOutput {
content: format!(
"Edited {}:\n\n```diff\n{}\n```",
input.path.display(),
diff_str
),
output: serde_json::to_value(StreamingEditFileToolOutput {
original_path: input.path,
new_text,
old_text,
})
.ok(),
})
});
ToolResult {
@ -298,6 +314,32 @@ impl Tool for EditFileTool {
card: card.map(AnyToolCard::from),
}
}
fn deserialize_card(
self: Arc<Self>,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
let output = match serde_json::from_value::<StreamingEditFileToolOutput>(output) {
Ok(output) => output,
Err(_) => return None,
};
let card = cx.new(|cx| {
let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx);
card.set_diff(
output.original_path.into(),
output.old_text,
output.new_text,
cx,
);
card
});
Some(card.into())
}
}
pub struct EditFileToolCard {

View file

@ -166,7 +166,7 @@ impl Tool for FetchTool {
bail!("no textual content found");
}
Ok(text)
Ok(text.into())
})
.into()
}

View file

@ -98,7 +98,7 @@ impl Tool for FindPathTool {
sender.send(paginated_matches.to_vec()).log_err();
if matches.is_empty() {
Ok("No matches found".to_string())
Ok("No matches found".to_string().into())
} else {
let mut message = format!("Found {} total matches.", matches.len());
if matches.len() > RESULTS_PER_PAGE {
@ -113,7 +113,7 @@ impl Tool for FindPathTool {
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
Ok(message.into())
}
});

View file

@ -260,16 +260,16 @@ impl Tool for GrepTool {
}
if matches_found == 0 {
Ok("No matches found".to_string())
Ok("No matches found".to_string().into())
} else if has_more_matches {
Ok(format!(
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
input.offset + 1,
input.offset + matches_found,
input.offset + RESULTS_PER_PAGE,
))
).into())
} else {
Ok(format!("Found {matches_found} matches:\n{output}"))
Ok(format!("Found {matches_found} matches:\n{output}").into())
}
}).into()
}
@ -748,9 +748,9 @@ mod tests {
match task.output.await {
Ok(result) => {
if cfg!(windows) {
result.replace("root\\", "root/")
result.content.replace("root\\", "root/")
} else {
result
result.content
}
}
Err(e) => panic!("Failed to run grep tool: {}", e),

View file

@ -102,7 +102,7 @@ impl Tool for ListDirectoryTool {
.collect::<Vec<_>>()
.join("\n");
return Task::ready(Ok(output)).into();
return Task::ready(Ok(output.into())).into();
}
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
@ -134,8 +134,8 @@ impl Tool for ListDirectoryTool {
.unwrap();
}
if output.is_empty() {
return Task::ready(Ok(format!("{} is empty.", input.path))).into();
return Task::ready(Ok(format!("{} is empty.", input.path).into())).into();
}
Task::ready(Ok(output)).into()
Task::ready(Ok(output.into())).into()
}
}

View file

@ -117,10 +117,9 @@ impl Tool for MovePathTool {
cx.background_spawn(async move {
match rename_task.await {
Ok(_) => Ok(format!(
"Moved {} to {}",
input.source_path, input.destination_path
)),
Ok(_) => {
Ok(format!("Moved {} to {}", input.source_path, input.destination_path).into())
}
Err(err) => Err(anyhow!(
"Failed to move {} to {}: {}",
input.source_path,

View file

@ -73,6 +73,6 @@ impl Tool for NowTool {
};
let text = format!("The current datetime is {now}.");
Task::ready(Ok(text)).into()
Task::ready(Ok(text.into())).into()
}
}

View file

@ -70,7 +70,7 @@ impl Tool for OpenTool {
}
.context("Failed to open URL or file path")?;
Ok(format!("Successfully opened {}", input.path_or_url))
Ok(format!("Successfully opened {}", input.path_or_url).into())
})
.into()
}

View file

@ -145,9 +145,9 @@ impl Tool for ReadFileTool {
let lines = text.split('\n').skip(start_row as usize);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
Itertools::intersperse(lines.take(count as usize), "\n").collect()
Itertools::intersperse(lines.take(count as usize), "\n").collect::<String>().into()
} else {
Itertools::intersperse(lines, "\n").collect()
Itertools::intersperse(lines, "\n").collect::<String>().into()
}
})?;
@ -180,7 +180,7 @@ impl Tool for ReadFileTool {
log.buffer_read(buffer, cx);
})?;
Ok(result)
Ok(result.into())
} else {
// File is too big, so return the outline
// and a suggestion to read again with line numbers.
@ -192,7 +192,7 @@ impl Tool for ReadFileTool {
Using the line numbers in this outline, you can call this tool again while specifying
the start_line and end_line fields to see the implementations of symbols in the outline."
})
}.into())
}
}
})
@ -258,7 +258,7 @@ mod test {
.output
})
.await;
assert_eq!(result.unwrap(), "This is a small file content");
assert_eq!(result.unwrap().content, "This is a small file content");
}
#[gpui::test]
@ -358,7 +358,7 @@ mod test {
.output
})
.await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
assert_eq!(result.unwrap().content, "Line 2\nLine 3\nLine 4");
}
#[gpui::test]
@ -389,7 +389,7 @@ mod test {
.output
})
.await;
assert_eq!(result.unwrap(), "Line 1\nLine 2");
assert_eq!(result.unwrap().content, "Line 1\nLine 2");
// end_line of 0 should result in at least 1 line
let result = cx
@ -404,7 +404,7 @@ mod test {
.output
})
.await;
assert_eq!(result.unwrap(), "Line 1");
assert_eq!(result.unwrap().content, "Line 1");
// when start_line > end_line, should still return at least 1 line
let result = cx
@ -419,7 +419,7 @@ mod test {
.output
})
.await;
assert_eq!(result.unwrap(), "Line 3");
assert_eq!(result.unwrap().content, "Line 3");
}
fn init_test(cx: &mut TestAppContext) {

View file

@ -5,7 +5,7 @@ use crate::{
schema::json_schema_for,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult, ToolResultOutput};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
@ -67,6 +67,13 @@ pub struct StreamingEditFileToolInput {
pub create_or_overwrite: bool,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolOutput {
pub original_path: PathBuf,
pub new_text: String,
pub old_text: String,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
@ -248,6 +255,12 @@ impl Tool for StreamingEditFileTool {
});
let (new_text, diff) = futures::join!(new_text, diff);
let output = StreamingEditFileToolOutput {
original_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text: old_text.clone(),
};
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
@ -264,10 +277,13 @@ impl Tool for StreamingEditFileTool {
I can perform the requested edits.
"}))
} else {
Ok("No edits were made.".to_string())
Ok("No edits were made.".to_string().into())
}
} else {
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
Ok(ToolResultOutput {
content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff),
output: serde_json::to_value(output).ok(),
})
}
});
@ -276,6 +292,32 @@ impl Tool for StreamingEditFileTool {
card: card.map(AnyToolCard::from),
}
}
fn deserialize_card(
self: Arc<Self>,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
let output = match serde_json::from_value::<StreamingEditFileToolOutput>(output) {
Ok(output) => output,
Err(_) => return None,
};
let card = cx.new(|cx| {
let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx);
card.set_diff(
output.original_path.into(),
output.old_text,
output.new_text,
cx,
);
card
});
Some(card.into())
}
}
#[cfg(test)]

View file

@ -178,7 +178,7 @@ impl Tool for TerminalTool {
let exit_status = child.wait()?;
let (processed_content, _) =
process_content(content, &input.command, Some(exit_status));
Ok(processed_content)
Ok(processed_content.into())
});
return ToolResult {
output: task,
@ -266,7 +266,7 @@ impl Tool for TerminalTool {
card.elapsed_time = Some(card.start_instant.elapsed());
});
Ok(processed_content)
Ok(processed_content.into())
}
});
@ -661,7 +661,7 @@ mod tests {
)
});
let output = result.output.await.log_err();
let output = result.output.await.log_err().map(|output| output.content);
assert_eq!(output, Some("Command executed successfully.".into()));
}
@ -693,7 +693,11 @@ mod tests {
cx,
);
cx.spawn(async move |_| {
let output = headless_result.output.await.log_err();
let output = headless_result
.output
.await
.log_err()
.map(|output| output.content);
assert_eq!(output, expected);
})
};

View file

@ -55,7 +55,7 @@ impl Tool for ThinkingTool {
) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions.
Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) {
Ok(_input) => Ok("Finished thinking.".to_string()),
Ok(_input) => Ok("Finished thinking.".to_string().into()),
Err(err) => Err(anyhow!(err)),
})
.into()

View file

@ -72,7 +72,9 @@ impl Tool for WebSearchTool {
let search_task = search_task.clone();
async move {
let response = search_task.await.map_err(|err| anyhow!(err))?;
serde_json::to_string(&response).context("Failed to serialize search results")
serde_json::to_string(&response)
.context("Failed to serialize search results")
.map(Into::into)
}
});

View file

@ -131,6 +131,7 @@ pub struct LanguageModelToolResult {
pub tool_name: Arc<str>,
pub is_error: bool,
pub content: Arc<str>,
pub output: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]