Introduce a new StreamingEditFileTool (#29733)

This pull request introduces a new tool for streaming edits. The
short-term goal is for this tool to replace the existing `EditFileTool`,
but we want to get this out the door as soon as possible so that we can
start testing it.

`StreamingEditFileTool` is mutually exclusive with `EditFileTool`. It
will be enabled by default for anyone who has the `agent-stream-edits`
feature flag, as well as for anyone who sets `assistant.stream_edits` to
`true` in their settings.
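
For example, in `settings.json` (a sketch; the exact nesting is assumed from the setting name):

```json
{
  "assistant": {
    "stream_edits": true
  }
}
```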

### Implementation

Streaming is achieved by requesting a separate completion when the
`edit_file` tool is called. We invoke the model by taking the existing
conversation with the agent and appending a prompt tailored specifically
for editing. In that prompt, we ask the model to produce a stream of
`<old_text>`/`<new_text>` tags. As the model streams text in, we parse it
incrementally and begin applying edits as soon as we can.
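
To illustrate the format, here is a minimal sketch of feeding a hypothetical
stream through the new `EditParser` (the import path and the handler bodies
are assumptions, not the tool's actual wiring; the parser API is from this PR):

```rust
use edit_parser::{EditParser, EditParserEvent};

fn main() {
    // Two chunks as they might arrive from the model; note the closing
    // tag split across a chunk boundary.
    let chunks = [
        "<old_text>\nfn main() {}\n</old_",
        "text>\n<new_text>\nfn main() { println!(\"hi\"); }\n</new_text>",
    ];

    let mut parser = EditParser::new();
    for chunk in chunks {
        for event in parser.push(chunk) {
            match event {
                // Emitted once the full `</old_text>` tag has been parsed.
                EditParserEvent::OldText(old) => println!("locate: {old:?}"),
                // Streamed incrementally; `done` marks the closing tag.
                EditParserEvent::NewTextChunk { chunk, done } => {
                    println!("apply: {chunk:?} (done: {done})")
                }
            }
        }
    }
    // Metrics count how many end tags were mismatched, which matters
    // because models occasionally swap </old_text> and </new_text>.
    println!("{:?}", parser.finish());
}
```

Note how the parser holds back output when a chunk boundary falls in the
middle of a closing tag, so partial tags are never emitted as content.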

### Evals

Note that, as part of this pull request, I also defined some new evals
that I used to drive the behavior of the recursive LLM call. To run
them, use this command:

```bash
cargo test --package=assistant_tools --features eval -- eval_extract_handle_command_output
```

Or comment out the `#[cfg_attr(not(feature = "eval"), ignore)]` attribute.

I recommend running them one at a time, because right now we don't
really have a way of orchestrating all these evals. I think we should
invest in that effort once the new agent panel goes live.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
Commit f891dfb358 (parent e3a2d52472) by Antonio Scandurra, 2025-05-01 17:37:43 +02:00, committed by GitHub.
32 changed files with 49077 additions and 20 deletions

@@ -11,16 +11,24 @@ workspace = true
[lib]
path = "src/assistant_tools.rs"
[features]
eval = []
[dependencies]
aho-corasick.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_settings.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
editor.workspace = true
derive_more.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
@@ -31,9 +39,14 @@ linkme.workspace = true
open.workspace = true
project.workspace = true
regex.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smallvec.workspace = true
streaming_diff.workspace = true
strsim.workspace = true
task.workspace = true
terminal.workspace = true
terminal_view.workspace = true
@@ -49,10 +62,15 @@ client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
gpui_tokio.workspace = true
fs = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
language_models.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
pretty_assertions.workspace = true
reqwest_client.workspace = true
settings = { workspace = true, features = ["test-support"] }
task = { workspace = true, features = ["test-support"]}
tempfile.workspace = true

@@ -7,6 +7,7 @@ mod create_directory_tool;
mod create_file_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_agent;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
@@ -19,7 +20,9 @@ mod read_file_tool;
mod rename_tool;
mod replace;
mod schema;
mod streaming_edit_file_tool;
mod symbol_info_tool;
mod templates;
mod terminal_tool;
mod thinking_tool;
mod ui;
@@ -27,14 +30,19 @@ mod web_search_tool;
use std::sync::Arc;
use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
use gpui::App;
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
use settings::{Settings, SettingsStore};
use web_search_tool::WebSearchTool;
pub(crate) use templates::*;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
@@ -52,6 +60,7 @@ use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
use crate::read_file_tool::ReadFileTool;
use crate::rename_tool::RenameTool;
use crate::streaming_edit_file_tool::StreamingEditFileTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
@@ -71,7 +80,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(EditFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
@@ -88,6 +96,12 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
register_edit_file_tool(cx);
cx.observe_flag::<AgentStreamEditsFeatureFlag, _>(|_, cx| register_edit_file_tool(cx))
.detach();
cx.observe_global::<SettingsStore>(register_edit_file_tool)
.detach();
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
@@ -108,6 +122,19 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
.detach();
}
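// Registers exactly one of the two edit tools: `StreamingEditFileTool` when
// stream edits are enabled (via the feature flag or the
// `assistant.stream_edits` setting), `EditFileTool` otherwise.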
fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx);
registry.unregister_tool(EditFileTool);
registry.unregister_tool(StreamingEditFileTool);
if AssistantSettings::get_global(cx).stream_edits(cx) {
registry.register_tool(StreamingEditFileTool);
} else {
registry.register_tool(EditFileTool);
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -146,6 +173,7 @@ mod tests {
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
settings::init(cx);
AssistantSettings::register(cx);
let client = Client::new(
Arc::new(FakeSystemClock::new()),

File diff suppressed because it is too large.

@@ -0,0 +1,408 @@
use derive_more::{Add, AddAssign};
use smallvec::SmallVec;
use std::{cmp, mem, ops::Range};
const OLD_TEXT_END_TAG: &str = "</old_text>";
const NEW_TEXT_END_TAG: &str = "</new_text>";
const END_TAG_LEN: usize = OLD_TEXT_END_TAG.len();
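// The two end tags must be the same length for `END_TAG_LEN` to apply to
// both; this is checked at compile time below.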
const _: () = debug_assert!(OLD_TEXT_END_TAG.len() == NEW_TEXT_END_TAG.len());
#[derive(Debug)]
pub enum EditParserEvent {
OldText(String),
NewTextChunk { chunk: String, done: bool },
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Add, AddAssign)]
pub struct EditParserMetrics {
pub tags: usize,
pub mismatched_tags: usize,
}
#[derive(Debug)]
pub struct EditParser {
state: EditParserState,
buffer: String,
metrics: EditParserMetrics,
}
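// Parser states: scan for `<old_text>`, buffer the old text until its end
// tag, then stream the `<new_text>` contents out chunk by chunk.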
#[derive(Debug, PartialEq)]
enum EditParserState {
Pending,
WithinOldText,
AfterOldText,
WithinNewText { start: bool },
}
impl EditParser {
pub fn new() -> Self {
EditParser {
state: EditParserState::Pending,
buffer: String::new(),
metrics: EditParserMetrics::default(),
}
}
pub fn push(&mut self, chunk: &str) -> SmallVec<[EditParserEvent; 1]> {
self.buffer.push_str(chunk);
let mut edit_events = SmallVec::new();
loop {
match &mut self.state {
EditParserState::Pending => {
if let Some(start) = self.buffer.find("<old_text>") {
self.buffer.drain(..start + "<old_text>".len());
self.state = EditParserState::WithinOldText;
} else {
break;
}
}
EditParserState::WithinOldText => {
if let Some(tag_range) = self.find_end_tag() {
let mut start = 0;
if self.buffer.starts_with('\n') {
start = 1;
}
let mut old_text = self.buffer[start..tag_range.start].to_string();
if old_text.ends_with('\n') {
old_text.pop();
}
self.metrics.tags += 1;
if &self.buffer[tag_range.clone()] != OLD_TEXT_END_TAG {
self.metrics.mismatched_tags += 1;
}
self.buffer.drain(..tag_range.end);
self.state = EditParserState::AfterOldText;
edit_events.push(EditParserEvent::OldText(old_text));
} else {
break;
}
}
EditParserState::AfterOldText => {
if let Some(start) = self.buffer.find("<new_text>") {
self.buffer.drain(..start + "<new_text>".len());
self.state = EditParserState::WithinNewText { start: true };
} else {
break;
}
}
EditParserState::WithinNewText { start } => {
if !self.buffer.is_empty() {
if *start && self.buffer.starts_with('\n') {
self.buffer.remove(0);
}
*start = false;
}
if let Some(tag_range) = self.find_end_tag() {
let mut chunk = self.buffer[..tag_range.start].to_string();
if chunk.ends_with('\n') {
chunk.pop();
}
self.metrics.tags += 1;
if &self.buffer[tag_range.clone()] != NEW_TEXT_END_TAG {
self.metrics.mismatched_tags += 1;
}
self.buffer.drain(..tag_range.end);
self.state = EditParserState::Pending;
edit_events.push(EditParserEvent::NewTextChunk { chunk, done: true });
} else {
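// Hold back any suffix that could be the start of a closing tag (or a
// trailing newline) until more input arrives, so partial tags are never
// emitted as content.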
let mut end_prefixes = (1..END_TAG_LEN)
.flat_map(|i| [&NEW_TEXT_END_TAG[..i], &OLD_TEXT_END_TAG[..i]])
.chain(["\n"]);
if end_prefixes.all(|prefix| !self.buffer.ends_with(&prefix)) {
edit_events.push(EditParserEvent::NewTextChunk {
chunk: mem::take(&mut self.buffer),
done: false,
});
}
break;
}
}
}
}
edit_events
}
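// Returns the earliest occurrence of either end tag. Accepting a mismatched
// tag lets parsing recover when the model swaps `</old_text>` and `</new_text>`.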
fn find_end_tag(&self) -> Option<Range<usize>> {
let old_text_end_tag_ix = self.buffer.find(OLD_TEXT_END_TAG);
let new_text_end_tag_ix = self.buffer.find(NEW_TEXT_END_TAG);
let start_ix = if let Some((old_text_ix, new_text_ix)) =
old_text_end_tag_ix.zip(new_text_end_tag_ix)
{
cmp::min(old_text_ix, new_text_ix)
} else {
old_text_end_tag_ix.or(new_text_end_tag_ix)?
};
Some(start_ix..start_ix + END_TAG_LEN)
}
pub fn finish(self) -> EditParserMetrics {
self.metrics
}
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
use rand::prelude::*;
use std::cmp;
#[gpui::test(iterations = 1000)]
fn test_single_edit(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>original</old_text><new_text>updated</new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "original".to_string(),
new_text: "updated".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_multiple_edits(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
<old_text>
first old
</old_text><new_text>first new</new_text>
<old_text>second old</old_text><new_text>
second new
</new_text>
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "first old".to_string(),
new_text: "first new".to_string(),
},
Edit {
old_text: "second old".to_string(),
new_text: "second new".to_string(),
},
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 4,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_edits_with_extra_text(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
ignore this <old_text>
content</old_text>extra stuff<new_text>updated content</new_text>trailing data
more text <old_text>second item
</old_text>middle text<new_text>modified second item</new_text>end
<old_text>third case</old_text><new_text>improved third case</new_text> with trailing text
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "content".to_string(),
new_text: "updated content".to_string(),
},
Edit {
old_text: "second item".to_string(),
new_text: "modified second item".to_string(),
},
Edit {
old_text: "third case".to_string(),
new_text: "improved third case".to_string(),
},
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 6,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_nested_tags(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>code with <tag>nested</tag> elements</old_text><new_text>new <code>content</code></new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "code with <tag>nested</tag> elements".to_string(),
new_text: "new <code>content</code>".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_empty_old_and_new_text(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text></old_text><new_text></new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "".to_string(),
new_text: "".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 100)]
fn test_multiline_content(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>line1\nline2\nline3</old_text><new_text>line1\nmodified line2\nline3</new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "line1\nline2\nline3".to_string(),
new_text: "line1\nmodified line2\nline3".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_mismatched_tags(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
// Reduced from an actual Sonnet 3.7 output
indoc! {"
<old_text>
a
b
c
</new_text>
<new_text>
a
B
c
</old_text>
<old_text>
d
e
f
</new_text>
<new_text>
D
e
F
</old_text>
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "a\nb\nc".to_string(),
new_text: "a\nB\nc".to_string(),
},
Edit {
old_text: "d\ne\nf".to_string(),
new_text: "D\ne\nF".to_string(),
}
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 4,
mismatched_tags: 4
}
);
}
#[derive(Default, Debug, PartialEq, Eq)]
struct Edit {
old_text: String,
new_text: String,
}
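// Splits `input` at random positions and feeds the pieces to the parser,
// simulating arbitrary chunk boundaries in a streamed response.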
fn parse_random_chunks(input: &str, parser: &mut EditParser, rng: &mut StdRng) -> Vec<Edit> {
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
chunk_indices.sort();
chunk_indices.push(input.len());
let mut pending_edit = Edit::default();
let mut edits = Vec::new();
let mut last_ix = 0;
for chunk_ix in chunk_indices {
for event in parser.push(&input[last_ix..chunk_ix]) {
match event {
EditParserEvent::OldText(old_text) => {
pending_edit.old_text = old_text;
}
EditParserEvent::NewTextChunk { chunk, done } => {
pending_edit.new_text.push_str(&chunk);
if done {
edits.push(pending_edit);
pending_edit = Edit::default();
}
}
}
}
last_ix = chunk_ix;
}
edits
}
}

@@ -0,0 +1,889 @@
use super::*;
use crate::{
ReadFileToolInput, grep_tool::GrepToolInput,
streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
cmp::Reverse,
fmt::{self, Display},
io::Write as _,
sync::mpsc,
};
use util::path;
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
let input_file_path = "root/blame.rs";
let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
eval(
100,
0.95,
EvalInput {
conversation: vec![
message(
User,
[text(indoc! {"
Read the `{input_file_path}` file and extract a method in
the final stanza of `run_git_blame` to deal with command failures,
call it `handle_command_output` and take the std::process::Output as the only parameter.
Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: None,
end_line: None,
},
)],
),
message(
User,
[tool_result("tool_1", "read_file", input_file_content)],
),
message(
Assistant,
[tool_use(
"tool_2",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
},
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
let input_file_path = "root/blame.rs";
let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
let edit_description = "Delete the `run_git_blame` function.";
eval(
100,
0.95,
EvalInput {
conversation: vec![
message(
User,
[text(indoc! {"
Read the `{input_file_path}` file and delete `run_git_blame`. Just that
one function, not its usages.
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: None,
end_line: None,
},
)],
),
message(
User,
[tool_result("tool_1", "read_file", input_file_content)],
),
message(
Assistant,
[tool_use(
"tool_2",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
},
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
let input_file_path = "root/lib.rs";
let input_file_content =
include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
eval(
100,
0.95,
EvalInput {
conversation: vec![
message(
User,
[text(indoc! {"
Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
Use `ureq` to download the SDK for the current platform and architecture.
Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
that's inside of the archive.
Don't re-download the SDK if that executable already exists.
Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}
Here are the available wasi-sdk assets:
- wasi-sdk-25.0-x86_64-macos.tar.gz
- wasi-sdk-25.0-arm64-macos.tar.gz
- wasi-sdk-25.0-x86_64-linux.tar.gz
- wasi-sdk-25.0-arm64-linux.tar.gz
- wasi-sdk-25.0-x86_64-linux.tar.gz
- wasi-sdk-25.0-arm64-linux.tar.gz
- wasi-sdk-25.0-x86_64-windows.tar.gz
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: Some(971),
end_line: Some(1050),
},
)],
),
message(
User,
[tool_result(
"tool_1",
"read_file",
lines(input_file_content, 971..1050),
)],
),
message(
Assistant,
[tool_use(
"tool_2",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: Some(1050),
end_line: Some(1100),
},
)],
),
message(
User,
[tool_result(
"tool_2",
"read_file",
lines(input_file_content, 1050..1100),
)],
),
message(
Assistant,
[tool_use(
"tool_3",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: Some(1100),
end_line: Some(1150),
},
)],
),
message(
User,
[tool_result(
"tool_3",
"read_file",
lines(input_file_content, 1100..1150),
)],
),
message(
Assistant,
[tool_use(
"tool_4",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
- The compile_parser_to_wasm method has been changed to use wasi-sdk
- ureq is used to download the SDK for current platform and architecture
"}),
},
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
let input_file_path = "root/editor.rs";
let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
let edit_description = "Comment out the call to `BlinkManager::enable`";
eval(
100,
0.6, // TODO: make this eval better
EvalInput {
conversation: vec![
message(User, [text("Let's research how to cursor blinking works.")]),
message(
Assistant,
[tool_use(
"tool_1",
"grep",
GrepToolInput {
regex: "blink".into(),
include_pattern: None,
offset: 0,
case_sensitive: false,
},
)],
),
message(
User,
[tool_result(
"tool_1",
"grep",
[
lines(input_file_content, 100..400),
lines(input_file_content, 800..1300),
lines(input_file_content, 1600..2000),
lines(input_file_content, 5000..5500),
lines(input_file_content, 8000..9000),
lines(input_file_content, 18455..18470),
lines(input_file_content, 20000..20500),
lines(input_file_content, 21000..21300),
]
.join("Match found:\n\n"),
)],
),
message(
User,
[text(indoc! {"
Comment out the lines that interact with the BlinkManager.
Keep the outer `update` blocks, but comment out everything that's inside (including if statements).
Don't add additional comments.
"})],
),
message(
Assistant,
[tool_use(
"tool_4",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
},
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
let input_file_path = "root/canvas.rs";
let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
let edit_description = "Implement from_pixels constructor and add tests.";
eval(
100,
0.95,
EvalInput {
conversation: vec![
message(
User,
[text(indoc! {"
Introduce a new `from_pixels` constructor in Canvas and
also add tests for it in the same file.
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: None,
end_line: None,
},
)],
),
message(
User,
[tool_result("tool_1", "read_file", input_file_content)],
),
message(
Assistant,
[tool_use(
"tool_2",
"grep",
GrepToolInput {
regex: "mod\\s+tests".into(),
include_pattern: Some("font-kit/src/canvas.rs".into()),
offset: 0,
case_sensitive: false,
},
)],
),
message(User, [tool_result("tool_2", "grep", "No matches found")]),
message(
Assistant,
[tool_use(
"tool_3",
"grep",
GrepToolInput {
regex: "mod\\s+tests".into(),
include_pattern: Some("font-kit/src/**/*.rs".into()),
offset: 0,
case_sensitive: false,
},
)],
),
message(User, [tool_result("tool_3", "grep", "No matches found")]),
message(
Assistant,
[tool_use(
"tool_4",
"grep",
GrepToolInput {
regex: "#\\[test\\]".into(),
include_pattern: Some("font-kit/src/**/*.rs".into()),
offset: 0,
case_sensitive: false,
},
)],
),
message(
User,
[tool_result(
"tool_4",
"grep",
indoc! {"
Found 6 matches:
## Matches in font-kit/src/loaders/core_text.rs
### mod test L926-936
```
mod test {
use super::Font;
use crate::properties::{Stretch, Weight};
#[cfg(feature = \"source\")]
use crate::source::SystemSource;
static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
#[cfg(feature = \"source\")]
#[test]
```
55 lines remaining in ancestor node. Read the file to see all.
### mod test L947-951
```
}
#[test]
fn test_core_text_to_css_font_weight() {
// Exact matches
```
### mod test L959-963
```
}
#[test]
fn test_core_text_to_css_font_stretch() {
// Exact matches
```
## Matches in font-kit/src/loaders/freetype.rs
### mod test L1238-1248
```
mod test {
use crate::loaders::freetype::Font;
static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
#[test]
fn get_pcf_postscript_name() {
let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
}
```
1 lines remaining in ancestor node. Read the file to see all.
## Matches in font-kit/src/sources/core_text.rs
### mod test L265-275
```
mod test {
use crate::properties::{Stretch, Weight};
#[test]
fn test_css_to_core_text_font_weight() {
// Exact matches
assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
```
27 lines remaining in ancestor node. Read the file to see all.
### mod test L278-282
```
}
#[test]
fn test_css_to_core_text_font_stretch() {
// Exact matches
```
"},
)],
),
message(
Assistant,
[tool_use(
"tool_5",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
- The diff contains a new `from_pixels` constructor
- The diff contains new tests for the `from_pixels` constructor
"}),
},
);
}
fn message(
role: Role,
contents: impl IntoIterator<Item = MessageContent>,
) -> LanguageModelRequestMessage {
LanguageModelRequestMessage {
role,
content: contents.into_iter().collect(),
cache: false,
}
}
fn text(text: impl Into<String>) -> MessageContent {
MessageContent::Text(text.into())
}
fn lines(input: &str, range: Range<usize>) -> String {
input
.lines()
.skip(range.start)
.take(range.len())
.collect::<Vec<_>>()
.join("\n")
}
fn tool_use(
id: impl Into<Arc<str>>,
name: impl Into<Arc<str>>,
input: impl Serialize,
) -> MessageContent {
MessageContent::ToolUse(LanguageModelToolUse {
id: LanguageModelToolUseId::from(id.into()),
name: name.into(),
raw_input: serde_json::to_string_pretty(&input).unwrap(),
input: serde_json::to_value(input).unwrap(),
is_input_complete: true,
})
}
fn tool_result(
id: impl Into<Arc<str>>,
name: impl Into<Arc<str>>,
result: impl Into<Arc<str>>,
) -> MessageContent {
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: LanguageModelToolUseId::from(id.into()),
tool_name: name.into(),
is_error: false,
content: result.into(),
})
}
#[derive(Clone)]
struct EvalInput {
conversation: Vec<LanguageModelRequestMessage>,
input_path: PathBuf,
input_content: String,
edit_description: String,
assertion: EvalAssertion,
}
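// Runs the eval `iterations` times concurrently, collecting results over a
// channel, and panics if the pass ratio falls below `expected_pass_ratio`.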
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
let mut evaluated_count = 0;
report_progress(evaluated_count, iterations);
let (tx, rx) = mpsc::channel();
// Cache the last message in the conversation, and run one instance of the eval so that
// all the next ones are cached.
eval.conversation.last_mut().unwrap().cache = true;
run_eval(eval.clone(), tx.clone());
let executor = gpui::background_executor();
for _ in 1..iterations {
let eval = eval.clone();
let tx = tx.clone();
executor.spawn(async move { run_eval(eval, tx) }).detach();
}
drop(tx);
let mut failed_count = 0;
let mut failed_evals = HashMap::default();
let mut errored_evals = HashMap::default();
let mut eval_outputs = Vec::new();
let mut cumulative_parser_metrics = EditParserMetrics::default();
while let Ok(output) = rx.recv() {
match output {
Ok(output) => {
cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
eval_outputs.push(output.clone());
if output.assertion.score < 80 {
failed_count += 1;
failed_evals
.entry(output.buffer_text.clone())
.or_insert(Vec::new())
.push(output);
}
}
Err(error) => {
failed_count += 1;
*errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
}
}
evaluated_count += 1;
report_progress(evaluated_count, iterations);
}
let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
println!("Actual pass ratio: {}\n", actual_pass_ratio);
if actual_pass_ratio < expected_pass_ratio {
let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
errored_evals.sort_by_key(|(_, count)| Reverse(*count));
for (error, count) in errored_evals {
println!("Eval errored {} times. Error: {}", count, error);
}
let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
for (_buffer_output, failed_evals) in failed_evals {
let eval_output = failed_evals.first().unwrap();
println!("Eval failed {} times", failed_evals.len());
println!("{}", eval_output);
}
panic!(
"Actual pass ratio: {}\nExpected pass ratio: {}",
actual_pass_ratio, expected_pass_ratio
);
}
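// Even when enough edits applied cleanly, fail if more than 2% of the parsed
// tags were mismatched, so regressions in tag discipline get caught.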
let mismatched_tag_ratio =
cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
if mismatched_tag_ratio > 0.02 {
for eval_output in eval_outputs {
println!("{}", eval_output);
}
panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
}
}
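// Each run gets a fresh TestAppContext with a randomly seeded dispatcher, so
// iterations are independent.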
fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
let mut cx = TestAppContext::build(dispatcher, None);
let output = cx.executor().block_test(async {
let test = EditAgentTest::new(&mut cx).await;
test.eval(eval, &mut cx).await
});
tx.send(output).unwrap();
}
#[derive(Clone)]
struct EvalOutput {
assertion: EvalAssertionResult,
buffer_text: String,
edit_output: EditAgentOutput,
diff: String,
}
impl Display for EvalOutput {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "Score: {:?}", self.assertion.score)?;
if let Some(message) = self.assertion.message.as_ref() {
writeln!(f, "Message: {}", message)?;
}
writeln!(f, "Diff:\n{}", self.diff)?;
writeln!(
f,
"Parser Metrics:\n{:#?}",
self.edit_output._parser_metrics
)?;
writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
Ok(())
}
}
fn report_progress(evaluated_count: usize, iterations: usize) {
print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
std::io::stdout().flush().unwrap();
}
struct EditAgentTest {
agent: EditAgent,
project: Entity<Project>,
judge_model: Arc<dyn LanguageModel>,
}
impl EditAgentTest {
async fn new(cx: &mut TestAppContext) -> Self {
cx.executor().allow_parking();
cx.update(settings::init);
cx.update(Project::init_settings);
cx.update(language::init);
cx.update(gpui_tokio::init);
cx.update(client::init_settings);
let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let (agent_model, judge_model) = cx
.update(|cx| {
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
cx.spawn(async move |cx| {
let agent_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
let judge_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
(agent_model.unwrap(), judge_model.unwrap())
})
})
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self {
agent: EditAgent::new(agent_model, action_log, Templates::new()),
project,
judge_model,
}
}
async fn load_model(
provider: &str,
id: &str,
cx: &mut AsyncApp,
) -> Result<Arc<dyn LanguageModel>> {
let (provider, model) = cx.update(|cx| {
let models = LanguageModelRegistry::read_global(cx);
let model = models
.available_models(cx)
.find(|model| model.provider_id().0 == provider && model.id().0 == id)
.unwrap();
let provider = models.provider(&model.provider_id()).unwrap();
(provider, model)
})?;
cx.update(|cx| provider.authenticate(cx))?.await?;
Ok(model)
}
async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
let path = self
.project
.read_with(cx, |project, cx| {
project.find_project_path(eval.input_path, cx)
})
.unwrap();
let buffer = self
.project
.update(cx, |project, cx| project.open_buffer(path, cx))
.await
.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text(eval.input_content.clone(), cx)
});
let (edit_output, _events) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
let edit_output = edit_output.await?;
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
let assertion = match eval.assertion {
EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
100
} else {
0
},
message: None,
},
EvalAssertion::JudgeDiff(assertions) => self
.judge_diff(&actual_diff, assertions, &cx.to_async())
.await
.context("failed comparing diffs")?,
};
Ok(EvalOutput {
assertion,
diff: actual_diff,
buffer_text,
edit_output,
})
}
async fn judge_diff(
&self,
diff: &str,
assertions: &'static str,
cx: &AsyncApp,
) -> Result<EvalAssertionResult> {
let prompt = DiffJudgeTemplate {
diff: diff.to_string(),
assertions,
}
.render(&self.agent.templates)
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
};
let mut response = self.judge_model.stream_completion_text(request, cx).await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionResult {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
AssertEqual(String),
JudgeDiff(&'static str),
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
score: usize,
message: Option<String>,
}
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
diff: String,
assertions: &'static str,
}
impl Template for DiffJudgeTemplate {
const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
fn strip_empty_lines(text: &str) -> String {
text.lines()
.filter(|line| !line.trim().is_empty())
.collect::<Vec<_>>()
.join("\n")
}

@@ -0,0 +1,328 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

@@ -0,0 +1,374 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

@@ -0,0 +1,378 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
handle_command_output(output)
}
fn handle_command_output(output: std::process::Output) -> Result<String> {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View file

@ -0,0 +1,374 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View file

@ -0,0 +1,339 @@
// font-kit/src/canvas.rs
//
// Copyright © 2018 The Pathfinder Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! An in-memory bitmap surface for glyph rasterization.
use lazy_static::lazy_static;
use pathfinder_geometry::rect::RectI;
use pathfinder_geometry::vector::Vector2I;
use std::cmp;
use std::fmt;
use crate::utils;
lazy_static! {
static ref BITMAP_1BPP_TO_8BPP_LUT: [[u8; 8]; 256] = {
let mut lut = [[0; 8]; 256];
for byte in 0..0x100 {
let mut value = [0; 8];
for bit in 0..8 {
if (byte & (0x80 >> bit)) != 0 {
value[bit] = 0xff;
}
}
lut[byte] = value
}
lut
};
}
/// An in-memory bitmap surface for glyph rasterization.
pub struct Canvas {
/// The raw pixel data.
pub pixels: Vec<u8>,
/// The size of the buffer, in pixels.
pub size: Vector2I,
/// The number of *bytes* between successive rows.
pub stride: usize,
/// The image format of the canvas.
pub format: Format,
}
impl Canvas {
/// Creates a new blank canvas with the given pixel size and format.
///
/// Stride is automatically calculated from width.
///
/// The canvas is initialized with transparent black (all values 0).
#[inline]
pub fn new(size: Vector2I, format: Format) -> Canvas {
Canvas::with_stride(
size,
size.x() as usize * format.bytes_per_pixel() as usize,
format,
)
}
/// Creates a new blank canvas with the given pixel size, stride (number of bytes between
/// successive rows), and format.
///
/// The canvas is initialized with transparent black (all values 0).
pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
Canvas {
pixels: vec![0; stride * size.y() as usize],
size,
stride,
format,
}
}
#[allow(dead_code)]
pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
self.blit_from(
Vector2I::default(),
&src.pixels,
src.size,
src.stride,
src.format,
)
}
/// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
/// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
/// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
/// `src_stride` must be equal or larger than the actual data length.
#[allow(dead_code)]
pub(crate) fn blit_from(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
src_format: Format,
) {
assert_eq!(
src_stride * src_size.y() as usize,
src_bytes.len(),
"Number of pixels in src_bytes does not match stride and size."
);
assert!(
src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
"src_stride must be >= than src_size.x()"
);
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
match (self.format, src_format) {
(Format::A8, Format::A8)
| (Format::Rgb24, Format::Rgb24)
| (Format::Rgba32, Format::Rgba32) => {
self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::A8, Format::Rgb24) => {
self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::A8) => {
self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::Rgba32) => self
.blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::Rgb24) => self
.blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
}
}
#[allow(dead_code)]
pub(crate) fn blit_from_bitmap_1bpp(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
) {
if self.format != Format::A8 {
unimplemented!()
}
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
let size = dst_rect.size();
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
let src_row_stride = utils::div_round_up(size.x() as usize, 8);
for y in 0..size.y() {
let (dest_row_start, src_row_start) = (
(y + dst_rect.origin_y()) as usize * self.stride
+ dst_rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + dest_row_stride;
let src_row_end = src_row_start + src_row_stride;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
for x in 0..src_row_stride {
let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
let dest_start = x * 8;
let dest_end = cmp::min(dest_start + 8, dest_row_stride);
let src = &pattern[0..(dest_end - dest_start)];
dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
}
}
}
/// Blits to area `rect` using the data given in the buffer `src_bytes`.
/// `src_stride` must be specified in bytes.
/// The dimensions of `rect` must be in pixels.
fn blit_from_with<B: Blit>(
&mut self,
rect: RectI,
src_bytes: &[u8],
src_stride: usize,
src_format: Format,
) {
let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
for y in 0..rect.height() {
let (dest_row_start, src_row_start) = (
(y + rect.origin_y()) as usize * self.stride
+ rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
B::blit(dest_row_pixels, src_row_pixels)
}
}
}
impl fmt::Debug for Canvas {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Canvas")
.field("pixels", &self.pixels.len()) // Do not dump a vector content.
.field("size", &self.size)
.field("stride", &self.stride)
.field("format", &self.format)
.finish()
}
}
/// The image format for the canvas.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Premultiplied R8G8B8A8, little-endian.
Rgba32,
/// R8G8B8, little-endian.
Rgb24,
/// A8.
A8,
}
impl Format {
/// Returns the number of bits per pixel that this image format corresponds to.
#[inline]
pub fn bits_per_pixel(self) -> u8 {
match self {
Format::Rgba32 => 32,
Format::Rgb24 => 24,
Format::A8 => 8,
}
}
/// Returns the number of color channels per pixel that this image format corresponds to.
#[inline]
pub fn components_per_pixel(self) -> u8 {
match self {
Format::Rgba32 => 4,
Format::Rgb24 => 3,
Format::A8 => 1,
}
}
/// Returns the number of bits per color channel that this image format contains.
#[inline]
pub fn bits_per_component(self) -> u8 {
self.bits_per_pixel() / self.components_per_pixel()
}
/// Returns the number of bytes per pixel that this image format corresponds to.
#[inline]
pub fn bytes_per_pixel(self) -> u8 {
self.bits_per_pixel() / 8
}
}
/// The antialiasing strategy that should be used when rasterizing glyphs.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RasterizationOptions {
/// "Black-and-white" rendering. Each pixel is either entirely on or off.
Bilevel,
/// Grayscale antialiasing. Only one channel is used.
GrayscaleAa,
/// Subpixel RGB antialiasing, for LCD screens.
SubpixelAa,
}
trait Blit {
fn blit(dest: &mut [u8], src: &[u8]);
}
struct BlitMemcpy;
impl Blit for BlitMemcpy {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
dest.clone_from_slice(src)
}
}
struct BlitRgb24ToA8;
impl Blit for BlitRgb24ToA8 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
// TODO(pcwalton): SIMD.
for (dest, src) in dest.iter_mut().zip(src.chunks(3)) {
*dest = src[1]
}
}
}
struct BlitA8ToRgb24;
impl Blit for BlitA8ToRgb24 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
for (dest, src) in dest.chunks_mut(3).zip(src.iter()) {
dest[0] = *src;
dest[1] = *src;
dest[2] = *src;
}
}
}
struct BlitRgba32ToRgb24;
impl Blit for BlitRgba32ToRgb24 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
// TODO(pcwalton): SIMD.
for (dest, src) in dest.chunks_mut(3).zip(src.chunks(4)) {
dest.copy_from_slice(&src[0..3])
}
}
}
struct BlitRgb24ToRgba32;
impl Blit for BlitRgb24ToRgba32 {
fn blit(dest: &mut [u8], src: &[u8]) {
for (dest, src) in dest.chunks_mut(4).zip(src.chunks(3)) {
dest[0] = src[0];
dest[1] = src[1];
dest[2] = src[2];
dest[3] = 255;
}
}
}

View file

@ -282,7 +282,7 @@ pub struct EditFileToolCard {
}
impl EditFileToolCard {
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
pub fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
let editor = cx.new(|cx| {
let mut editor = Editor::new(
@ -323,7 +323,7 @@ impl EditFileToolCard {
}
}
fn set_diff(
pub fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
@ -343,6 +343,7 @@ impl EditFileToolCard {
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
multibuffer.clear(cx);
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,

View file

@ -0,0 +1,339 @@
use crate::{
Templates,
edit_agent::{EditAgent, EditAgentOutputEvent},
edit_file_tool::EditFileToolCard,
schema::json_schema_for,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
pub struct StreamingEditFileTool;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
/// A one-line, user-friendly markdown description of the edit. This will be
/// shown in the UI and also passed to another model to perform the edit.
///
/// Be terse, but also descriptive in what you want to achieve with this
/// edit. Avoid generic instructions.
///
/// NEVER mention the file path in this description.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
///
/// Make sure to include this field before all the others in the input object
/// so that we can display it immediately.
pub display_description: String,
/// The full path of the file to modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - backend
/// - frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with root-1. Without that, the path
/// would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
path: String,
#[serde(default)]
display_description: String,
}
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for StreamingEditFileTool {
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("streaming_edit_file_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<StreamingEditFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!(
"Path {} not found in project",
input.path.display()
)))
.into();
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
};
let exists = worktree.update(cx, |worktree, cx| {
worktree.file_exists(&project_path.path, cx)
});
let card = window.and_then(|window| {
window
.update(cx, |_, window, cx| {
cx.new(|cx| {
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
})
})
.ok()
});
let card_clone = card.clone();
let messages = messages.to_vec();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
if !exists.await? {
return Err(anyhow!("{} not found", input.path.display()));
}
let model = cx
.update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
.context("default model not set")?
.model;
let edit_agent = EditAgent::new(model, action_log, Templates::new());
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?
.await?;
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let old_text = cx
.background_spawn({
let old_snapshot = old_snapshot.clone();
async move { old_snapshot.text() }
})
.await;
let (output, mut events) = edit_agent.edit(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
);
let mut hallucinated_old_text = false;
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited => {
if let Some(card) = card_clone.as_ref() {
let new_snapshot =
buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
})
.await;
card.update(cx, |card, cx| {
card.set_diff(
project_path.path.clone(),
old_text.clone(),
new_text,
cx,
);
})
.log_err();
}
}
EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
}
}
output.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
});
let diff = cx.background_spawn(async move {
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
});
let (new_text, diff) = futures::join!(new_text, diff);
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
})
.log_err();
}
let input_path = input.path.display();
if diff.is_empty() {
if hallucinated_old_text {
Err(anyhow!(formatdoc! {"
Some edits were produced but none of them could be applied.
Read the relevant sections of {input_path} again so that
I can perform the requested edits.
"}))
} else {
Ok("No edits were made.".to_string())
}
} else {
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
}
});
ToolResult {
output: task,
card: card.map(AnyToolCard::from),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn still_streaming_ui_text_with_path() {
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"src/main.rs"
);
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
#[test]
fn still_streaming_ui_text_with_null() {
let input = serde_json::Value::Null;
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
}

View file

@ -0,0 +1,8 @@
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
Before using this tool:
1. Use the `read_file` tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the `list_directory` tool to verify the parent directory exists and is the correct location

View file

@ -0,0 +1,32 @@
use anyhow::Result;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "src/templates"]
#[include = "*.hbs"]
struct Assets;
pub struct Templates(Handlebars<'static>);
impl Templates {
pub fn new() -> Arc<Self> {
let mut handlebars = Handlebars::new();
handlebars.register_embed_templates::<Assets>().unwrap();
handlebars.register_escape_fn(|text| text.into());
Arc::new(Self(handlebars))
}
}
pub trait Template: Sized {
const TEMPLATE_NAME: &'static str;
fn render(&self, templates: &Templates) -> Result<String>
where
Self: Serialize + Sized,
{
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
}
}

View file

@ -0,0 +1,23 @@
You are an expert coder, and have been tasked with looking at the following diff:
<diff>
{{diff}}
</diff>
Evaluate the following assertions:
<assertions>
{{assertions}}
</assertions>
You must respond with a short analysis and a score between 0 and 100, where:
- 0 means no assertions pass
- 100 means all the assertions pass perfectly
<analysis>
- Assertion 1: one line describing why the first assertion passes or fails (even partially)
- Assertion 2: one line describing why the second assertion passes or fails (even partially)
- ...
- Assertion N: one line describing why the Nth assertion passes or fails (even partially)
</analysis>
<score>YOUR FINAL SCORE HERE</score>

View file

@ -0,0 +1,49 @@
You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make.
You MUST respond with a series of edits to that one file in the following format:
```
<edits>
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
<old_text>
OLD TEXT 2 HERE
</old_text>
<new_text>
NEW TEXT 2 HERE
</new_text>
<old_text>
OLD TEXT 3 HERE
</old_text>
<new_text>
NEW TEXT 3 HERE
</new_text>
</edits>
```
Rules for editing:
- `old_text` represents lines in the input file that will be replaced with `new_text`. `old_text` MUST exactly match the existing file content, character for character, including indentation.
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
- If you want to replace many occurrences of the same text, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
- When reporting multiple edits, each edit assumes the previous one has already been applied! Therefore, you must ensure `old_text` doesn't reference text that has already been modified by a previous edit.
- Don't explain the edits, just report them.
- Only edit the file specified in `<file_to_edit>` and NEVER include edits to other files!
- If you open an <old_text> tag, you MUST close it using </old_text>
- If you open an <new_text> tag, you MUST close it using </new_text>
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>