assistant edit tool: Fix editing files in context (#26751)

When the user attached context in the thread, the editor model request
would fail because its tool use wouldn't be removed properly leading to
an API error.

Also, after an edit, we'd keep the old file snapshot in the context.
This would make the model think that the edits didn't apply and make it
go in a loop.

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2025-03-14 17:07:43 -03:00 committed by GitHub
parent ba8b9ec2c7
commit 1bf1c7223f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 239 additions and 57 deletions

1
Cargo.lock generated
View file

@ -660,6 +660,7 @@ dependencies = [
"collections", "collections",
"derive_more", "derive_more",
"gpui", "gpui",
"language",
"language_model", "language_model",
"parking_lot", "parking_lot",
"project", "project",

View file

@ -22,10 +22,13 @@ use ui::Color;
use ui::{prelude::*, Disclosure, KeyBinding}; use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _; use util::ResultExt as _;
use crate::context_store::{refresh_context_store_text, ContextStore};
pub struct ActiveThread { pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
thread: Entity<Thread>, thread: Entity<Thread>,
context_store: Entity<ContextStore>,
save_thread_task: Option<Task<()>>, save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>, messages: Vec<MessageId>,
list_state: ListState, list_state: ListState,
@ -46,6 +49,7 @@ impl ActiveThread {
thread: Entity<Thread>, thread: Entity<Thread>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
context_store: Entity<ContextStore>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -58,6 +62,7 @@ impl ActiveThread {
language_registry, language_registry,
thread_store, thread_store,
thread: thread.clone(), thread: thread.clone(),
context_store,
save_thread_task: None, save_thread_task: None,
messages: Vec::new(), messages: Vec::new(),
rendered_messages_by_id: HashMap::default(), rendered_messages_by_id: HashMap::default(),
@ -350,11 +355,51 @@ impl ActiveThread {
} }
if self.thread.read(cx).all_tools_finished() { if self.thread.read(cx).all_tools_finished() {
let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
thread.action_log().update(cx, |action_log, _cx| {
action_log.take_pending_refresh_buffers()
})
});
let context_update_task = if !pending_refresh_buffers.is_empty() {
let refresh_task = refresh_context_store_text(
self.context_store.clone(),
&pending_refresh_buffers,
cx,
);
cx.spawn(|this, mut cx| async move {
let updated_context_ids = refresh_task.await;
this.update(&mut cx, |this, cx| {
this.context_store.read_with(cx, |context_store, cx| {
context_store
.context()
.iter()
.filter(|context| {
updated_context_ids.contains(&context.id())
})
.flat_map(|context| context.snapshot(cx))
.collect()
})
})
})
} else {
Task::ready(anyhow::Ok(Vec::new()))
};
let model_registry = LanguageModelRegistry::read_global(cx); let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() { if let Some(model) = model_registry.active_model() {
self.thread.update(cx, |thread, cx| { cx.spawn(|this, mut cx| async move {
thread.send_tool_results_to_model(model, cx); let updated_context = context_update_task.await?;
});
this.update(&mut cx, |this, cx| {
this.thread.update(cx, |thread, cx| {
thread.send_tool_results_to_model(model, updated_context, cx);
});
})
})
.detach();
} }
} }
} }

View file

@ -155,10 +155,14 @@ impl AssistantPanel {
let workspace = workspace.weak_handle(); let workspace = workspace.weak_handle();
let weak_self = cx.entity().downgrade(); let weak_self = cx.entity().downgrade();
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(workspace.clone()));
let message_editor = cx.new(|cx| { let message_editor = cx.new(|cx| {
MessageEditor::new( MessageEditor::new(
fs.clone(), fs.clone(),
workspace.clone(), workspace.clone(),
message_editor_context_store.clone(),
thread_store.downgrade(), thread_store.downgrade(),
thread.clone(), thread.clone(),
window, window,
@ -174,6 +178,7 @@ impl AssistantPanel {
thread.clone(), thread.clone(),
thread_store.clone(), thread_store.clone(),
language_registry.clone(), language_registry.clone(),
message_editor_context_store.clone(),
window, window,
cx, cx,
) )
@ -242,11 +247,16 @@ impl AssistantPanel {
.update(cx, |this, cx| this.create_thread(cx)); .update(cx, |this, cx| this.create_thread(cx));
self.active_view = ActiveView::Thread; self.active_view = ActiveView::Thread;
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(self.workspace.clone()));
self.thread = cx.new(|cx| { self.thread = cx.new(|cx| {
ActiveThread::new( ActiveThread::new(
thread.clone(), thread.clone(),
self.thread_store.clone(), self.thread_store.clone(),
self.language_registry.clone(), self.language_registry.clone(),
message_editor_context_store.clone(),
window, window,
cx, cx,
) )
@ -255,6 +265,7 @@ impl AssistantPanel {
MessageEditor::new( MessageEditor::new(
self.fs.clone(), self.fs.clone(),
self.workspace.clone(), self.workspace.clone(),
message_editor_context_store,
self.thread_store.downgrade(), self.thread_store.downgrade(),
thread, thread,
window, window,
@ -375,11 +386,14 @@ impl AssistantPanel {
let thread = open_thread_task.await?; let thread = open_thread_task.await?;
this.update_in(&mut cx, |this, window, cx| { this.update_in(&mut cx, |this, window, cx| {
this.active_view = ActiveView::Thread; this.active_view = ActiveView::Thread;
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(this.workspace.clone()));
this.thread = cx.new(|cx| { this.thread = cx.new(|cx| {
ActiveThread::new( ActiveThread::new(
thread.clone(), thread.clone(),
this.thread_store.clone(), this.thread_store.clone(),
this.language_registry.clone(), this.language_registry.clone(),
message_editor_context_store.clone(),
window, window,
cx, cx,
) )
@ -388,6 +402,7 @@ impl AssistantPanel {
MessageEditor::new( MessageEditor::new(
this.fs.clone(), this.fs.clone(),
this.workspace.clone(), this.workspace.clone(),
message_editor_context_store,
this.thread_store.downgrade(), this.thread_store.downgrade(),
thread, thread,
window, window,

View file

@ -9,6 +9,7 @@ use language::Buffer;
use project::{ProjectPath, Worktree}; use project::{ProjectPath, Worktree};
use rope::Rope; use rope::Rope;
use text::BufferId; use text::BufferId;
use util::maybe;
use workspace::Workspace; use workspace::Workspace;
use crate::context::{ use crate::context::{
@ -531,35 +532,59 @@ fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<Arc<Path>> {
pub fn refresh_context_store_text( pub fn refresh_context_store_text(
context_store: Entity<ContextStore>, context_store: Entity<ContextStore>,
changed_buffers: &HashSet<Entity<Buffer>>,
cx: &App, cx: &App,
) -> impl Future<Output = ()> { ) -> impl Future<Output = Vec<ContextId>> {
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for context in &context_store.read(cx).context { for context in &context_store.read(cx).context {
match context { let id = context.id();
AssistantContext::File(file_context) => {
let context_store = context_store.clone(); let task = maybe!({
if let Some(task) = refresh_file_text(context_store, file_context, cx) { match context {
tasks.push(task); AssistantContext::File(file_context) => {
if changed_buffers.is_empty()
|| changed_buffers.contains(&file_context.context_buffer.buffer)
{
let context_store = context_store.clone();
return refresh_file_text(context_store, file_context, cx);
}
} }
} AssistantContext::Directory(directory_context) => {
AssistantContext::Directory(directory_context) => { let should_refresh = changed_buffers.is_empty()
let context_store = context_store.clone(); || changed_buffers.iter().any(|buffer| {
if let Some(task) = refresh_directory_text(context_store, directory_context, cx) { let buffer = buffer.read(cx);
tasks.push(task);
buffer_path_log_err(&buffer)
.map_or(false, |path| path.starts_with(&directory_context.path))
});
if should_refresh {
let context_store = context_store.clone();
return refresh_directory_text(context_store, directory_context, cx);
}
} }
AssistantContext::Thread(thread_context) => {
if changed_buffers.is_empty() {
let context_store = context_store.clone();
return Some(refresh_thread_text(context_store, thread_context, cx));
}
}
// Intentionally omit refreshing fetched URLs as it doesn't seem all that useful,
// and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?).
AssistantContext::FetchedUrl(_) => {}
} }
AssistantContext::Thread(thread_context) => {
let context_store = context_store.clone(); None
tasks.push(refresh_thread_text(context_store, thread_context, cx)); });
}
// Intentionally omit refreshing fetched URLs as it doesn't seem all that useful, if let Some(task) = task {
// and doing the caching properly could be tricky (unless it's already handled by tasks.push(task.map(move |_| id));
// the HttpClient?).
AssistantContext::FetchedUrl(_) => {}
} }
} }
future::join_all(tasks).map(|_| ()) future::join_all(tasks)
} }
fn refresh_file_text( fn refresh_file_text(

View file

@ -1,5 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use collections::HashSet;
use editor::actions::MoveUp; use editor::actions::MoveUp;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle}; use editor::{Editor, EditorElement, EditorEvent, EditorStyle};
use file_icons::FileIcons; use file_icons::FileIcons;
@ -51,13 +52,13 @@ impl MessageEditor {
pub fn new( pub fn new(
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
thread_store: WeakEntity<ThreadStore>, thread_store: WeakEntity<ThreadStore>,
thread: Entity<Thread>, thread: Entity<Thread>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let tools = thread.read(cx).tools().clone(); let tools = thread.read(cx).tools().clone();
let context_store = cx.new(|_cx| ContextStore::new(workspace.clone()));
let context_picker_menu_handle = PopoverMenuHandle::default(); let context_picker_menu_handle = PopoverMenuHandle::default();
let inline_context_picker_menu_handle = PopoverMenuHandle::default(); let inline_context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default(); let model_selector_menu_handle = PopoverMenuHandle::default();
@ -200,7 +201,8 @@ impl MessageEditor {
text text
}); });
let refresh_task = refresh_context_store_text(self.context_store.clone(), cx); let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let thread = self.thread.clone(); let thread = self.thread.clone();
let context_store = self.context_store.clone(); let context_store = self.context_store.clone();

View file

@ -2,7 +2,7 @@ use std::io::Write;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result};
use assistant_tool::ToolWorkingSet; use assistant_tool::{ActionLog, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use futures::future::Shared; use futures::future::Shared;
@ -104,6 +104,7 @@ pub struct Thread {
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState, tool_use: ToolUseState,
action_log: Entity<ActionLog>,
scripting_session: Entity<ScriptingSession>, scripting_session: Entity<ScriptingSession>,
scripting_tool_use: ToolUseState, scripting_tool_use: ToolUseState,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>, initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
@ -134,6 +135,7 @@ impl Thread {
tool_use: ToolUseState::new(), tool_use: ToolUseState::new(),
scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)), scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
scripting_tool_use: ToolUseState::new(), scripting_tool_use: ToolUseState::new(),
action_log: cx.new(|_| ActionLog::new()),
initial_project_snapshot: { initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project, cx); let project_snapshot = Self::project_snapshot(project, cx);
cx.foreground_executor() cx.foreground_executor()
@ -191,6 +193,7 @@ impl Thread {
prompt_builder, prompt_builder,
tools, tools,
tool_use, tool_use,
action_log: cx.new(|_| ActionLog::new()),
scripting_session, scripting_session,
scripting_tool_use, scripting_tool_use,
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
@ -750,7 +753,13 @@ impl Thread {
for tool_use in pending_tool_uses { for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) { if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, &request.messages, self.project.clone(), cx); let task = tool.run(
tool_use.input,
&request.messages,
self.project.clone(),
self.action_log.clone(),
cx,
);
self.insert_tool_output(tool_use.id.clone(), task, cx); self.insert_tool_output(tool_use.id.clone(), task, cx);
} }
@ -857,8 +866,15 @@ impl Thread {
pub fn send_tool_results_to_model( pub fn send_tool_results_to_model(
&mut self, &mut self,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
updated_context: Vec<ContextSnapshot>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
self.context.extend(
updated_context
.into_iter()
.map(|context| (context.id, context)),
);
// Insert a user message to contain the tool results. // Insert a user message to contain the tool results.
self.insert_user_message( self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back // TODO: Sending up a user message without any content results in the model sending back
@ -1057,6 +1073,10 @@ impl Thread {
Ok(String::from_utf8_lossy(&markdown).to_string()) Ok(String::from_utf8_lossy(&markdown).to_string())
} }
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn cumulative_token_usage(&self) -> TokenUsage { pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone() self.cumulative_token_usage.clone()
} }

View file

@ -226,12 +226,12 @@ impl ToolUseState {
output: Result<String>, output: Result<String>,
) -> Option<PendingToolUse> { ) -> Option<PendingToolUse> {
match output { match output {
Ok(output) => { Ok(tool_result) => {
self.tool_results.insert( self.tool_results.insert(
tool_use_id.clone(), tool_use_id.clone(),
LanguageModelToolResult { LanguageModelToolResult {
tool_use_id: tool_use_id.clone(), tool_use_id: tool_use_id.clone(),
content: output.into(), content: tool_result.into(),
is_error: false, is_error: false,
}, },
); );

View file

@ -15,8 +15,9 @@ path = "src/assistant_tool.rs"
anyhow.workspace = true anyhow.workspace = true
collections.workspace = true collections.workspace = true
derive_more.workspace = true derive_more.workspace = true
language_model.workspace = true
gpui.workspace = true gpui.workspace = true
language.workspace = true
language_model.workspace = true
parking_lot.workspace = true parking_lot.workspace = true
project.workspace = true project.workspace = true
serde.workspace = true serde.workspace = true

View file

@ -4,7 +4,10 @@ mod tool_working_set;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use collections::HashSet;
use gpui::Context;
use gpui::{App, Entity, SharedString, Task}; use gpui::{App, Entity, SharedString, Task};
use language::Buffer;
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::Project; use project::Project;
@ -47,6 +50,39 @@ pub trait Tool: 'static + Send + Sync {
input: serde_json::Value, input: serde_json::Value,
messages: &[LanguageModelRequestMessage], messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>>; ) -> Task<Result<String>>;
} }
/// Tracks actions performed by tools in a thread
#[derive(Debug)]
pub struct ActionLog {
changed_buffers: HashSet<Entity<Buffer>>,
pending_refresh: HashSet<Entity<Buffer>>,
}
impl ActionLog {
/// Creates a new, empty action log.
pub fn new() -> Self {
Self {
changed_buffers: HashSet::default(),
pending_refresh: HashSet::default(),
}
}
/// Registers buffers that have changed and need refreshing.
pub fn notify_buffers_changed(
&mut self,
buffers: HashSet<Entity<Buffer>>,
_cx: &mut Context<Self>,
) {
self.changed_buffers.extend(buffers.clone());
self.pending_refresh.extend(buffers);
}
/// Takes and returns the set of buffers pending refresh, clearing internal state.
pub fn take_pending_refresh_buffers(&mut self) -> HashSet<Entity<Buffer>> {
std::mem::take(&mut self.pending_refresh)
}
}

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::Project; use project::Project;
@ -37,6 +37,7 @@ impl Tool for BashTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let input: BashToolInput = match serde_json::from_value(input) { let input: BashToolInput = match serde_json::from_value(input) {

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::Project; use project::Project;
@ -45,6 +45,7 @@ impl Tool for DeletePathTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let glob = match serde_json::from_value::<DeletePathToolInput>(input) { let glob = match serde_json::from_value::<DeletePathToolInput>(input) {

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt}; use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
@ -51,6 +51,7 @@ impl Tool for DiagnosticsTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let input = match serde_json::from_value::<DiagnosticsToolInput>(input) { let input = match serde_json::from_value::<DiagnosticsToolInput>(input) {

View file

@ -2,13 +2,13 @@ mod edit_action;
pub mod log; pub mod log;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use collections::HashSet; use collections::HashSet;
use edit_action::{EditAction, EditActionParser}; use edit_action::{EditAction, EditActionParser};
use futures::StreamExt; use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task}; use gpui::{App, AsyncApp, Entity, Task};
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
}; };
use log::{EditToolLog, EditToolRequestId}; use log::{EditToolLog, EditToolRequestId};
use project::{search::SearchQuery, Project}; use project::{search::SearchQuery, Project};
@ -80,6 +80,7 @@ impl Tool for EditFilesTool {
input: serde_json::Value, input: serde_json::Value,
messages: &[LanguageModelRequestMessage], messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let input = match serde_json::from_value::<EditFilesToolInput>(input) { let input = match serde_json::from_value::<EditFilesToolInput>(input) {
@ -93,8 +94,14 @@ impl Tool for EditFilesTool {
log.new_request(input.edit_instructions.clone(), cx) log.new_request(input.edit_instructions.clone(), cx)
}); });
let task = let task = EditToolRequest::new(
EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx); input,
messages,
project,
action_log,
Some((log.clone(), req_id)),
cx,
);
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let result = task.await; let result = task.await;
@ -113,7 +120,7 @@ impl Tool for EditFilesTool {
}) })
} }
None => EditToolRequest::new(input, messages, project, None, cx), None => EditToolRequest::new(input, messages, project, action_log, None, cx),
} }
} }
} }
@ -123,7 +130,8 @@ struct EditToolRequest {
changed_buffers: HashSet<Entity<language::Buffer>>, changed_buffers: HashSet<Entity<language::Buffer>>,
bad_searches: Vec<BadSearch>, bad_searches: Vec<BadSearch>,
project: Entity<Project>, project: Entity<Project>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>, action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
} }
#[derive(Debug)] #[derive(Debug)]
@ -143,7 +151,8 @@ impl EditToolRequest {
input: EditFilesToolInput, input: EditFilesToolInput,
messages: &[LanguageModelRequestMessage], messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>, action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let model_registry = LanguageModelRegistry::read_global(cx); let model_registry = LanguageModelRegistry::read_global(cx);
@ -152,12 +161,23 @@ impl EditToolRequest {
}; };
let mut messages = messages.to_vec(); let mut messages = messages.to_vec();
if let Some(last_message) = messages.last_mut() { // Remove the last tool use (this run) to prevent an invalid request
// Strip out tool use from the last message because we're in the middle of executing a tool call. 'outer: for message in messages.iter_mut().rev() {
last_message for (index, content) in message.content.iter().enumerate().rev() {
.content match content {
.retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_))) MessageContent::ToolUse(_) => {
message.content.remove(index);
break 'outer;
}
MessageContent::ToolResult(_) => {
// If we find any tool results before a tool use, the request is already valid
break 'outer;
}
MessageContent::Text(_) | MessageContent::Image(_) => {}
}
}
} }
messages.push(LanguageModelRequestMessage { messages.push(LanguageModelRequestMessage {
role: Role::User, role: Role::User,
content: vec![ content: vec![
@ -182,8 +202,9 @@ impl EditToolRequest {
parser: EditActionParser::new(), parser: EditActionParser::new(),
changed_buffers: HashSet::default(), changed_buffers: HashSet::default(),
bad_searches: Vec::new(), bad_searches: Vec::new(),
action_log,
project, project,
log, tool_log,
}; };
while let Some(chunk) = chunks.stream.next().await { while let Some(chunk) = chunks.stream.next().await {
@ -197,7 +218,7 @@ impl EditToolRequest {
async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> { async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
let new_actions = self.parser.parse_chunk(chunk); let new_actions = self.parser.parse_chunk(chunk);
if let Some((ref log, req_id)) = self.log { if let Some((ref log, req_id)) = self.tool_log {
log.update(cx, |log, cx| { log.update(cx, |log, cx| {
log.push_editor_response_chunk(req_id, chunk, &new_actions, cx) log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
}) })
@ -310,7 +331,7 @@ impl EditToolRequest {
}; };
// Save each buffer once at the end // Save each buffer once at the end
for buffer in self.changed_buffers { for buffer in &self.changed_buffers {
let (path, save_task) = self.project.update(cx, |project, cx| { let (path, save_task) = self.project.update(cx, |project, cx| {
let path = buffer let path = buffer
.read(cx) .read(cx)
@ -329,10 +350,17 @@ impl EditToolRequest {
} }
} }
self.action_log
.update(cx, |log, cx| {
log.notify_buffers_changed(self.changed_buffers, cx)
})
.log_err();
let errors = self.parser.errors(); let errors = self.parser.errors();
if errors.is_empty() && self.bad_searches.is_empty() { if errors.is_empty() && self.bad_searches.is_empty() {
Ok(answer.trim_end().to_string()) let answer = answer.trim_end().to_string();
Ok(answer)
} else { } else {
if !self.bad_searches.is_empty() { if !self.bad_searches.is_empty() {
writeln!( writeln!(
@ -369,7 +397,7 @@ impl EditToolRequest {
but errors are part of the conversation so you don't need to repeat them." but errors are part of the conversation so you don't need to repeat them."
)?; )?;
Err(anyhow!(answer)) Err(anyhow!(answer.trim_end().to_string()))
} }
} }
} }

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::Project; use project::Project;
@ -55,6 +55,7 @@ impl Tool for ListDirectoryTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) { let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use chrono::{Local, Utc}; use chrono::{Local, Utc};
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
@ -45,6 +45,7 @@ impl Tool for NowTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>,
_cx: &mut App, _cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let input: NowToolInput = match serde_json::from_value(input) { let input: NowToolInput = match serde_json::from_value(input) {

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::Project; use project::Project;
@ -45,6 +45,7 @@ impl Tool for PathSearchTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let glob = match serde_json::from_value::<PathSearchToolInput>(input) { let glob = match serde_json::from_value::<PathSearchToolInput>(input) {

View file

@ -2,7 +2,7 @@ use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::Project; use project::Project;
@ -49,6 +49,7 @@ impl Tool for ReadFileTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let input = match serde_json::from_value::<ReadFileToolInput>(input) { let input = match serde_json::from_value::<ReadFileToolInput>(input) {

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_tool::Tool; use assistant_tool::{ActionLog, Tool};
use futures::StreamExt; use futures::StreamExt;
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language::OffsetRangeExt; use language::OffsetRangeExt;
@ -38,6 +38,7 @@ impl Tool for RegexSearchTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
const CONTEXT_LINES: u32 = 2; const CONTEXT_LINES: u32 = 2;
@ -110,7 +111,7 @@ impl Tool for RegexSearchTool {
} }
if output.is_empty() { if output.is_empty() {
Ok("No matches found".into()) Ok("No matches found".to_string())
} else { } else {
Ok(output) Ok(output)
} }

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, bail, Result}; use anyhow::{anyhow, bail, Result};
use assistant_tool::{Tool, ToolSource}; use assistant_tool::{ActionLog, Tool, ToolSource};
use gpui::{App, Entity, Task}; use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage; use language_model::LanguageModelRequestMessage;
use project::Project; use project::Project;
@ -61,6 +61,7 @@ impl Tool for ContextServerTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
_project: Entity<Project>, _project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) { if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {