Lay the groundwork to support history in agent2 (#36483)

This pull request introduces title generation and history replaying. We
still need to wire up the rest of the history but this gets us very
close. I extracted a lot of this code from `agent2-history` because that
branch was starting to get long-lived and there were lots of changes
since we started.

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2025-08-19 16:24:23 +02:00 committed by GitHub
parent c4083b9b63
commit 6c255c1973
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 929 additions and 328 deletions

View file

@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
use language::{LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent;
use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
@ -98,11 +98,13 @@ pub enum EditFileMode {
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
#[serde(alias = "original_path")]
input_path: PathBuf,
project_path: PathBuf,
new_text: String,
old_text: Arc<String>,
#[serde(default)]
diff: String,
#[serde(alias = "raw_output")]
edit_agent_output: EditAgentOutput,
}
@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
}
pub struct EditFileTool {
thread: Entity<Thread>,
thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
}
impl EditFileTool {
pub fn new(thread: Entity<Thread>) -> Self {
Self { thread }
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
Self {
thread,
language_registry,
}
}
fn authorize(
@ -167,8 +173,11 @@ impl EditFileTool {
// Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize
let thread = self.thread.read(cx);
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
thread.project().read(cx).find_project_path(&input.path, cx)
}) else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
// If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary.
@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let project = self.thread.read(cx).project().clone();
let Ok(project) = self
.thread
.read_with(cx, |thread, _cx| thread.project().clone())
else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
});
}
let Some(request) = self.thread.update(cx, |thread, cx| {
thread
.build_completion_request(CompletionIntent::ToolResults, cx)
.ok()
}) else {
return Task::ready(Err(anyhow!("Failed to build completion request")));
};
let thread = self.thread.read(cx);
let Some(model) = thread.model().cloned() else {
return Task::ready(Err(anyhow!("No language model configured")));
};
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?;
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
(request, thread.model().cloned(), thread.action_log().clone())
})?;
let request = request?;
let model = model.context("No language model configured")?;
let edit_format = EditFormat::from_model(model.clone())?;
let edit_agent = EditAgent::new(
model,
@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
Ok(EditFileToolOutput {
input_path: input.path,
project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text,
diff: unified_diff,
@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
})
})
}
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
event_stream.update_diff(cx.new(|cx| {
Diff::finalized(
output.input_path,
Some(output.old_text.to_string()),
output.new_text,
self.language_registry.clone(),
cx,
)
}));
Ok(())
}
}
/// Validate that the file path is valid, meaning:
@ -515,6 +541,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -527,6 +554,7 @@ mod tests {
action_log,
Templates::new(),
Some(model),
None,
cx,
)
});
@ -537,7 +565,11 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
})
.await;
assert_eq!(
@ -724,6 +756,7 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
@ -750,9 +783,10 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
Arc::new(EditFileTool::new(
thread.downgrade(),
language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx)
});
@ -806,7 +840,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the unformatted content
@ -850,6 +888,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
@ -860,6 +899,7 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
@ -887,9 +927,10 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
Arc::new(EditFileTool::new(
thread.downgrade(),
language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx)
});
@ -938,10 +979,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
.run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the content with trailing whitespace
@ -976,6 +1018,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
@ -986,10 +1029,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@ -1111,6 +1155,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@ -1123,10 +1168,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@ -1220,7 +1266,7 @@ mod tests {
cx,
)
.await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1233,10 +1279,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees
let test_cases = vec![
@ -1302,6 +1349,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1314,10 +1362,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases
let test_cases = vec![
@ -1386,6 +1435,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1398,10 +1448,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values
let modes = vec![
@ -1467,6 +1518,7 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1479,10 +1531,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!(
tool.initial_title(Err(json!({