assistant edit tool: Fix editing files in context (#26751)

When the user attached context in the thread, the editor model request
would fail because its tool use wouldn't be removed properly leading to
an API error.

Also, after an edit, we'd keep the old file snapshot in the context.
This would make the model think that the edits didn't apply and make it
go in a loop.

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2025-03-14 17:07:43 -03:00 committed by GitHub
parent ba8b9ec2c7
commit 1bf1c7223f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 239 additions and 57 deletions

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@ -37,6 +37,7 @@ impl Tool for BashTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input: BashToolInput = match serde_json::from_value(input) {

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@ -45,6 +45,7 @@ impl Tool for DeletePathTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let glob = match serde_json::from_value::<DeletePathToolInput>(input) {

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::LanguageModelRequestMessage;
@ -51,6 +51,7 @@ impl Tool for DiagnosticsTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<DiagnosticsToolInput>(input) {

View file

@ -2,13 +2,13 @@ mod edit_action;
pub mod log;
use anyhow::{anyhow, Context, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use collections::HashSet;
use edit_action::{EditAction, EditActionParser};
use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
use log::{EditToolLog, EditToolRequestId};
use project::{search::SearchQuery, Project};
@ -80,6 +80,7 @@ impl Tool for EditFilesTool {
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<EditFilesToolInput>(input) {
@ -93,8 +94,14 @@ impl Tool for EditFilesTool {
log.new_request(input.edit_instructions.clone(), cx)
});
let task =
EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx);
let task = EditToolRequest::new(
input,
messages,
project,
action_log,
Some((log.clone(), req_id)),
cx,
);
cx.spawn(|mut cx| async move {
let result = task.await;
@ -113,7 +120,7 @@ impl Tool for EditFilesTool {
})
}
None => EditToolRequest::new(input, messages, project, None, cx),
None => EditToolRequest::new(input, messages, project, action_log, None, cx),
}
}
}
@ -123,7 +130,8 @@ struct EditToolRequest {
changed_buffers: HashSet<Entity<language::Buffer>>,
bad_searches: Vec<BadSearch>,
project: Entity<Project>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
#[derive(Debug)]
@ -143,7 +151,8 @@ impl EditToolRequest {
input: EditFilesToolInput,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
cx: &mut App,
) -> Task<Result<String>> {
let model_registry = LanguageModelRegistry::read_global(cx);
@ -152,12 +161,23 @@ impl EditToolRequest {
};
let mut messages = messages.to_vec();
if let Some(last_message) = messages.last_mut() {
// Strip out tool use from the last message because we're in the middle of executing a tool call.
last_message
.content
.retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
// Remove the last tool use (this run) to prevent an invalid request
'outer: for message in messages.iter_mut().rev() {
for (index, content) in message.content.iter().enumerate().rev() {
match content {
MessageContent::ToolUse(_) => {
message.content.remove(index);
break 'outer;
}
MessageContent::ToolResult(_) => {
// If we find any tool results before a tool use, the request is already valid
break 'outer;
}
MessageContent::Text(_) | MessageContent::Image(_) => {}
}
}
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![
@ -182,8 +202,9 @@ impl EditToolRequest {
parser: EditActionParser::new(),
changed_buffers: HashSet::default(),
bad_searches: Vec::new(),
action_log,
project,
log,
tool_log,
};
while let Some(chunk) = chunks.stream.next().await {
@ -197,7 +218,7 @@ impl EditToolRequest {
async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
let new_actions = self.parser.parse_chunk(chunk);
if let Some((ref log, req_id)) = self.log {
if let Some((ref log, req_id)) = self.tool_log {
log.update(cx, |log, cx| {
log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
})
@ -310,7 +331,7 @@ impl EditToolRequest {
};
// Save each buffer once at the end
for buffer in self.changed_buffers {
for buffer in &self.changed_buffers {
let (path, save_task) = self.project.update(cx, |project, cx| {
let path = buffer
.read(cx)
@ -329,10 +350,17 @@ impl EditToolRequest {
}
}
self.action_log
.update(cx, |log, cx| {
log.notify_buffers_changed(self.changed_buffers, cx)
})
.log_err();
let errors = self.parser.errors();
if errors.is_empty() && self.bad_searches.is_empty() {
Ok(answer.trim_end().to_string())
let answer = answer.trim_end().to_string();
Ok(answer)
} else {
if !self.bad_searches.is_empty() {
writeln!(
@ -369,7 +397,7 @@ impl EditToolRequest {
but errors are part of the conversation so you don't need to repeat them."
)?;
Err(anyhow!(answer))
Err(anyhow!(answer.trim_end().to_string()))
}
}
}

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@ -55,6 +55,7 @@ impl Tool for ListDirectoryTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use chrono::{Local, Utc};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
@ -45,6 +45,7 @@ impl Tool for NowTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_cx: &mut App,
) -> Task<Result<String>> {
let input: NowToolInput = match serde_json::from_value(input) {

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@ -45,6 +45,7 @@ impl Tool for PathSearchTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let glob = match serde_json::from_value::<PathSearchToolInput>(input) {

View file

@ -2,7 +2,7 @@ use std::path::Path;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@ -49,6 +49,7 @@ impl Tool for ReadFileTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use assistant_tool::{ActionLog, Tool};
use futures::StreamExt;
use gpui::{App, Entity, Task};
use language::OffsetRangeExt;
@ -38,6 +38,7 @@ impl Tool for RegexSearchTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
const CONTEXT_LINES: u32 = 2;
@ -110,7 +111,7 @@ impl Tool for RegexSearchTool {
}
if output.is_empty() {
Ok("No matches found".into())
Ok("No matches found".to_string())
} else {
Ok(output)
}