Introduce a new StreamingEditFileTool (#29733)

This pull request introduces a new tool for streaming edits. The
short-term goal is for this tool to replace the existing `EditFileTool`,
but we want to get this out the door as soon as possible so that we can
start testing it.

`StreamingEditFileTool` is mutually exclusive with `EditFileTool`. It
will be enabled by default for anyone who has the `agent-stream-edits`
feature flag, as well as people that set `assistant.stream_edits` to
`true` in their settings.

### Implementation

Streaming is achieved by requesting a completion while the `edit_file`
tool gets called. We invoke the model by taking the existing
conversation with the agent and appending a prompt specifically tailored
for editing. In that prompt, we ask the model to produce a stream of
`<old_text>`/`<new_text>` tags. As the model streams text in, we
incrementally parse it and start editing as soon as we can.

### Evals

Note that, as part of this pull request, I also defined some new evals
that I used to drive the behavior of the recursive LLM call. To run
them, use this command:

```bash
cargo test --package=assistant_tools --features eval -- eval_extract_handle_command_output
```

Or comment out the `#[cfg_attr(not(feature = "eval"), ignore)]` attribute.

I recommend running them one at a time, because right now we don't
really have a way of orchestrating all of these evals. I think we should
invest in that effort once the new agent panel goes live.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-05-01 17:37:43 +02:00 committed by GitHub
parent e3a2d52472
commit f891dfb358
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 49077 additions and 20 deletions

16
Cargo.lock generated
View file

@ -704,7 +704,9 @@ dependencies = [
name = "assistant_tools"
version = "0.1.0"
dependencies = [
"aho-corasick",
"anyhow",
"assistant_settings",
"assistant_tool",
"buffer_diff",
"chrono",
@ -712,25 +714,36 @@ dependencies = [
"clock",
"collections",
"component",
"derive_more",
"editor",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"html_to_markdown",
"http_client",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"language_models",
"linkme",
"open",
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"reqwest_client",
"rust-embed",
"schemars",
"serde",
"serde_json",
"settings",
"smallvec",
"streaming_diff",
"strsim",
"task",
"tempfile",
"terminal",
@ -5003,6 +5016,7 @@ dependencies = [
"node_runtime",
"pathdiff",
"paths",
"pretty_assertions",
"project",
"prompt_store",
"regex",
@ -6367,6 +6381,7 @@ dependencies = [
"log",
"pest",
"pest_derive",
"rust-embed",
"serde",
"serde_json",
"thiserror 1.0.69",
@ -18032,6 +18047,7 @@ dependencies = [
"getrandom 0.2.15",
"getrandom 0.3.2",
"gimli",
"handlebars 4.5.0",
"hashbrown 0.14.5",
"hashbrown 0.15.2",
"heck 0.4.1",

View file

@ -657,6 +657,8 @@
},
// When enabled, the agent can run potentially destructive actions without asking for your confirmation.
"always_allow_tool_actions": false,
// When enabled, the agent will stream edits.
"stream_edits": false,
"default_profile": "write",
"profiles": {
"ask": {

View file

@ -6,7 +6,7 @@ use ::open_ai::Model as OpenAiModel;
use anthropic::Model as AnthropicModel;
use anyhow::{Result, bail};
use deepseek::Model as DeepseekModel;
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
use feature_flags::{AgentStreamEditsFeatureFlag, Assistant2FeatureFlag, FeatureFlagAppExt};
use gpui::{App, Pixels};
use indexmap::IndexMap;
use language_model::{CloudModel, LanguageModel};
@ -87,9 +87,14 @@ pub struct AssistantSettings {
pub profiles: IndexMap<AgentProfileId, AgentProfile>,
pub always_allow_tool_actions: bool,
pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
pub stream_edits: bool,
}
impl AssistantSettings {
/// Returns whether the agent should stream edits.
///
/// Streaming is on when the `AgentStreamEditsFeatureFlag` flag is set for
/// this app, or when the `stream_edits` setting is `true`.
pub fn stream_edits(&self, cx: &App) -> bool {
    let enabled_by_flag = cx.has_flag::<AgentStreamEditsFeatureFlag>();
    enabled_by_flag || self.stream_edits
}
pub fn are_live_diffs_enabled(&self, cx: &App) -> bool {
if cx.has_flag::<Assistant2FeatureFlag>() {
return false;
@ -218,6 +223,7 @@ impl AssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
@ -245,6 +251,7 @@ impl AssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
None => AssistantSettingsContentV2::default(),
}
@ -495,6 +502,7 @@ impl Default for VersionedAssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
})
}
}
@ -550,6 +558,10 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: "primary_screen"
notify_when_agent_waiting: Option<NotifyWhenAgentWaiting>,
/// Whether to stream edits from the agent as they are received.
///
/// Default: false
stream_edits: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@ -712,6 +724,7 @@ impl Settings for AssistantSettings {
&mut settings.notify_when_agent_waiting,
value.notify_when_agent_waiting,
);
merge(&mut settings.stream_edits, value.stream_edits);
merge(&mut settings.default_profile, value.default_profile);
if let Some(profiles) = value.profiles {
@ -843,6 +856,7 @@ mod tests {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
)),
}

View file

@ -11,16 +11,24 @@ workspace = true
[lib]
path = "src/assistant_tools.rs"
[features]
eval = []
[dependencies]
aho-corasick.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_settings.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
editor.workspace = true
derive_more.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
@ -31,9 +39,14 @@ linkme.workspace = true
open.workspace = true
project.workspace = true
regex.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smallvec.workspace = true
streaming_diff.workspace = true
strsim.workspace = true
task.workspace = true
terminal.workspace = true
terminal_view.workspace = true
@ -49,10 +62,15 @@ client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
gpui_tokio.workspace = true
fs = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
language_models.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
pretty_assertions.workspace = true
reqwest_client.workspace = true
settings = { workspace = true, features = ["test-support"] }
task = { workspace = true, features = ["test-support"]}
tempfile.workspace = true

View file

@ -7,6 +7,7 @@ mod create_directory_tool;
mod create_file_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_agent;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
@ -19,7 +20,9 @@ mod read_file_tool;
mod rename_tool;
mod replace;
mod schema;
mod streaming_edit_file_tool;
mod symbol_info_tool;
mod templates;
mod terminal_tool;
mod thinking_tool;
mod ui;
@ -27,14 +30,19 @@ mod web_search_tool;
use std::sync::Arc;
use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
use gpui::App;
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
use settings::{Settings, SettingsStore};
use web_search_tool::WebSearchTool;
pub(crate) use templates::*;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
@ -52,6 +60,7 @@ use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
use crate::read_file_tool::ReadFileTool;
use crate::rename_tool::RenameTool;
use crate::streaming_edit_file_tool::StreamingEditFileTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
@ -71,7 +80,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(EditFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
@ -88,6 +96,12 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
register_edit_file_tool(cx);
cx.observe_flag::<AgentStreamEditsFeatureFlag, _>(|_, cx| register_edit_file_tool(cx))
.detach();
cx.observe_global::<SettingsStore>(register_edit_file_tool)
.detach();
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
@ -108,6 +122,19 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
.detach();
}
/// Registers exactly one of the two edit-file tools in the global registry.
///
/// Both tools are unregistered first, then `StreamingEditFileTool` is
/// registered when `AssistantSettings::stream_edits` returns `true`, and
/// `EditFileTool` otherwise.
fn register_edit_file_tool(cx: &mut App) {
    let registry = ToolRegistry::global(cx);

    // Start from a clean slate so the two tools are never registered together.
    registry.unregister_tool(EditFileTool);
    registry.unregister_tool(StreamingEditFileTool);

    let use_streaming_tool = AssistantSettings::get_global(cx).stream_edits(cx);
    if !use_streaming_tool {
        registry.register_tool(EditFileTool);
    } else {
        registry.register_tool(StreamingEditFileTool);
    }
}
#[cfg(test)]
mod tests {
use super::*;
@ -146,6 +173,7 @@ mod tests {
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
settings::init(cx);
AssistantSettings::register(cx);
let client = Client::new(
Arc::new(FakeSystemClock::new()),

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,408 @@
use derive_more::{Add, AddAssign};
use smallvec::SmallVec;
use std::{cmp, mem, ops::Range};
/// Closing tag that is expected to terminate an `<old_text>` section.
const OLD_TEXT_END_TAG: &str = "</old_text>";
/// Closing tag that is expected to terminate a `<new_text>` section.
const NEW_TEXT_END_TAG: &str = "</new_text>";
/// Byte length of a closing tag; both closing tags must have this length.
const END_TAG_LEN: usize = OLD_TEXT_END_TAG.len();
// Compile-time check that both closing tags have the same length, which
// `END_TAG_LEN` relies on. This uses `assert!` rather than `debug_assert!`
// because the latter is compiled out when `debug_assertions` is disabled,
// which would silently skip the check in release builds.
const _: () = assert!(OLD_TEXT_END_TAG.len() == NEW_TEXT_END_TAG.len());
/// An event produced by `EditParser::push` while parsing streamed
/// `<old_text>`/`<new_text>` sections.
#[derive(Debug)]
pub enum EditParserEvent {
/// The full contents of an old-text section, emitted once its closing tag
/// has been seen. At most one leading and one trailing `\n` are removed.
OldText(String),
/// A piece of a new-text section. `done` is `true` for the piece that is
/// emitted when the section's closing tag is found, and `false` for the
/// pieces streamed before that.
NewTextChunk { chunk: String, done: bool },
}
/// Counters collected by `EditParser` while parsing; returned by
/// `EditParser::finish`. Instances can be summed via `+` / `+=`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Add, AddAssign)]
pub struct EditParserMetrics {
/// Number of closing tags consumed by the parser.
pub tags: usize,
/// Number of consumed closing tags that did not match the kind of section
/// they closed (e.g. `</new_text>` closing an old-text section).
pub mismatched_tags: usize,
}
/// Incremental parser for a stream of `<old_text>…</old_text>` /
/// `<new_text>…</new_text>` pairs.
#[derive(Debug)]
pub struct EditParser {
// Which part of an edit the parser is currently inside of.
state: EditParserState,
// Text that was pushed but has not been consumed or emitted yet.
buffer: String,
// Tag counters accumulated so far.
metrics: EditParserMetrics,
}
/// Position of the parser within the old-text/new-text sequence.
#[derive(Debug, PartialEq)]
enum EditParserState {
/// Waiting for an `<old_text>` opening tag.
Pending,
/// Inside an old-text section, waiting for a closing tag.
WithinOldText,
/// Old text was emitted; waiting for a `<new_text>` opening tag.
AfterOldText,
/// Inside a new-text section. `start` stays `true` until the first
/// non-empty buffer has been inspected, so that a single leading `\n`
/// can be dropped.
WithinNewText { start: bool },
}
impl EditParser {
    /// Creates a parser with an empty buffer that is waiting for the first
    /// `<old_text>` tag.
    pub fn new() -> Self {
        Self {
            state: EditParserState::Pending,
            buffer: String::new(),
            metrics: EditParserMetrics::default(),
        }
    }

    /// Appends `chunk` to the internal buffer and returns every event that
    /// can be produced from the text buffered so far.
    ///
    /// Text outside of the tagged sections is discarded. Old text is only
    /// emitted once its closing tag has arrived. New text is streamed out as
    /// it arrives, except that the buffer is held back while it ends with
    /// `\n` or with something that could be the beginning of a closing tag.
    pub fn push(&mut self, chunk: &str) -> SmallVec<[EditParserEvent; 1]> {
        const OLD_TEXT_START_TAG: &str = "<old_text>";
        const NEW_TEXT_START_TAG: &str = "<new_text>";

        self.buffer.push_str(chunk);

        let mut events = SmallVec::new();
        loop {
            match &mut self.state {
                EditParserState::Pending => {
                    let Some(tag_ix) = self.buffer.find(OLD_TEXT_START_TAG) else {
                        break;
                    };
                    // Drop everything up to and including the opening tag.
                    self.buffer.drain(..tag_ix + OLD_TEXT_START_TAG.len());
                    self.state = EditParserState::WithinOldText;
                }
                EditParserState::WithinOldText => {
                    let Some(end_tag) = self.find_end_tag() else {
                        break;
                    };

                    // Strip at most one newline on each side of the section.
                    let body = &self.buffer[..end_tag.start];
                    let body = body.strip_prefix('\n').unwrap_or(body);
                    let body = body.strip_suffix('\n').unwrap_or(body);
                    let old_text = body.to_string();

                    self.metrics.tags += 1;
                    if &self.buffer[end_tag.clone()] != OLD_TEXT_END_TAG {
                        self.metrics.mismatched_tags += 1;
                    }

                    self.buffer.drain(..end_tag.end);
                    self.state = EditParserState::AfterOldText;
                    events.push(EditParserEvent::OldText(old_text));
                }
                EditParserState::AfterOldText => {
                    let Some(tag_ix) = self.buffer.find(NEW_TEXT_START_TAG) else {
                        break;
                    };
                    self.buffer.drain(..tag_ix + NEW_TEXT_START_TAG.len());
                    self.state = EditParserState::WithinNewText { start: true };
                }
                EditParserState::WithinNewText { start } => {
                    // The first time there is any text, drop a single leading
                    // newline and remember that the section has started.
                    if !self.buffer.is_empty() {
                        let at_section_start = mem::replace(start, false);
                        if at_section_start && self.buffer.starts_with('\n') {
                            self.buffer.remove(0);
                        }
                    }

                    match self.find_end_tag() {
                        Some(end_tag) => {
                            let body = &self.buffer[..end_tag.start];
                            let chunk = body.strip_suffix('\n').unwrap_or(body).to_string();

                            self.metrics.tags += 1;
                            if &self.buffer[end_tag.clone()] != NEW_TEXT_END_TAG {
                                self.metrics.mismatched_tags += 1;
                            }

                            self.buffer.drain(..end_tag.end);
                            self.state = EditParserState::Pending;
                            events.push(EditParserEvent::NewTextChunk { chunk, done: true });
                        }
                        None => {
                            // Hold the buffer back while its tail might be a
                            // trailing newline or a partially received
                            // closing tag.
                            let may_be_incomplete = self.buffer.ends_with("\n")
                                || (1..END_TAG_LEN).any(|len| {
                                    self.buffer.ends_with(&NEW_TEXT_END_TAG[..len])
                                        || self.buffer.ends_with(&OLD_TEXT_END_TAG[..len])
                                });
                            if !may_be_incomplete {
                                events.push(EditParserEvent::NewTextChunk {
                                    chunk: mem::take(&mut self.buffer),
                                    done: false,
                                });
                            }
                            break;
                        }
                    }
                }
            }
        }
        events
    }

    /// Returns the byte range of the earliest closing tag in the buffer,
    /// regardless of whether it is `</old_text>` or `</new_text>`, or `None`
    /// when the buffer contains neither.
    fn find_end_tag(&self) -> Option<Range<usize>> {
        let old_ix = self.buffer.find(OLD_TEXT_END_TAG);
        let new_ix = self.buffer.find(NEW_TEXT_END_TAG);
        let start_ix = match (old_ix, new_ix) {
            (Some(old_ix), Some(new_ix)) => cmp::min(old_ix, new_ix),
            (Some(ix), None) | (None, Some(ix)) => ix,
            (None, None) => return None,
        };
        Some(start_ix..start_ix + END_TAG_LEN)
    }

    /// Consumes the parser and returns the metrics it accumulated.
    pub fn finish(self) -> EditParserMetrics {
        self.metrics
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use rand::prelude::*;
    use std::cmp;

    #[gpui::test(iterations = 1000)]
    fn test_single_edit(mut rng: StdRng) {
        let mut parser = EditParser::new();
        let edits = parse_random_chunks(
            "<old_text>original</old_text><new_text>updated</new_text>",
            &mut parser,
            &mut rng,
        );
        assert_eq!(edits, vec![edit("original", "updated")]);
        assert_eq!(parser.finish(), metrics(2, 0));
    }

    #[gpui::test(iterations = 1000)]
    fn test_multiple_edits(mut rng: StdRng) {
        let mut parser = EditParser::new();
        let edits = parse_random_chunks(
            indoc! {"
                <old_text>
                first old
                </old_text><new_text>first new</new_text>
                <old_text>second old</old_text><new_text>
                second new
                </new_text>
            "},
            &mut parser,
            &mut rng,
        );
        assert_eq!(
            edits,
            vec![
                edit("first old", "first new"),
                edit("second old", "second new"),
            ]
        );
        assert_eq!(parser.finish(), metrics(4, 0));
    }

    #[gpui::test(iterations = 1000)]
    fn test_edits_with_extra_text(mut rng: StdRng) {
        let mut parser = EditParser::new();
        let edits = parse_random_chunks(
            indoc! {"
                ignore this <old_text>
                content</old_text>extra stuff<new_text>updated content</new_text>trailing data
                more text <old_text>second item
                </old_text>middle text<new_text>modified second item</new_text>end
                <old_text>third case</old_text><new_text>improved third case</new_text> with trailing text
            "},
            &mut parser,
            &mut rng,
        );
        assert_eq!(
            edits,
            vec![
                edit("content", "updated content"),
                edit("second item", "modified second item"),
                edit("third case", "improved third case"),
            ]
        );
        assert_eq!(parser.finish(), metrics(6, 0));
    }

    #[gpui::test(iterations = 1000)]
    fn test_nested_tags(mut rng: StdRng) {
        let mut parser = EditParser::new();
        let edits = parse_random_chunks(
            "<old_text>code with <tag>nested</tag> elements</old_text><new_text>new <code>content</code></new_text>",
            &mut parser,
            &mut rng,
        );
        assert_eq!(
            edits,
            vec![edit(
                "code with <tag>nested</tag> elements",
                "new <code>content</code>"
            )]
        );
        assert_eq!(parser.finish(), metrics(2, 0));
    }

    #[gpui::test(iterations = 1000)]
    fn test_empty_old_and_new_text(mut rng: StdRng) {
        let mut parser = EditParser::new();
        let edits = parse_random_chunks(
            "<old_text></old_text><new_text></new_text>",
            &mut parser,
            &mut rng,
        );
        assert_eq!(edits, vec![edit("", "")]);
        assert_eq!(parser.finish(), metrics(2, 0));
    }

    #[gpui::test(iterations = 100)]
    fn test_multiline_content(mut rng: StdRng) {
        let mut parser = EditParser::new();
        let edits = parse_random_chunks(
            "<old_text>line1\nline2\nline3</old_text><new_text>line1\nmodified line2\nline3</new_text>",
            &mut parser,
            &mut rng,
        );
        assert_eq!(
            edits,
            vec![edit("line1\nline2\nline3", "line1\nmodified line2\nline3")]
        );
        assert_eq!(parser.finish(), metrics(2, 0));
    }

    #[gpui::test(iterations = 1000)]
    fn test_mismatched_tags(mut rng: StdRng) {
        let mut parser = EditParser::new();
        let edits = parse_random_chunks(
            // Reduced from an actual Sonnet 3.7 output
            indoc! {"
                <old_text>
                a
                b
                c
                </new_text>
                <new_text>
                a
                B
                c
                </old_text>
                <old_text>
                d
                e
                f
                </new_text>
                <new_text>
                D
                e
                F
                </old_text>
            "},
            &mut parser,
            &mut rng,
        );
        assert_eq!(
            edits,
            vec![edit("a\nb\nc", "a\nB\nc"), edit("d\ne\nf", "D\ne\nF")]
        );
        assert_eq!(parser.finish(), metrics(4, 4));
    }

    /// One fully parsed old-text/new-text pair.
    #[derive(Default, Debug, PartialEq, Eq)]
    struct Edit {
        old_text: String,
        new_text: String,
    }

    /// Builds the expected [`Edit`] for an assertion.
    fn edit(old_text: &str, new_text: &str) -> Edit {
        Edit {
            old_text: old_text.to_string(),
            new_text: new_text.to_string(),
        }
    }

    /// Builds the expected [`EditParserMetrics`] for an assertion.
    fn metrics(tags: usize, mismatched_tags: usize) -> EditParserMetrics {
        EditParserMetrics {
            tags,
            mismatched_tags,
        }
    }

    /// Splits `input` at randomly chosen byte offsets, feeds the pieces to
    /// `parser` in order, and assembles the emitted events into edits.
    fn parse_random_chunks(input: &str, parser: &mut EditParser, rng: &mut StdRng) -> Vec<Edit> {
        let max_chunk_count = cmp::min(input.len(), 50);
        let chunk_count = rng.gen_range(1..=max_chunk_count);
        let mut boundaries = (0..input.len()).choose_multiple(rng, chunk_count);
        boundaries.sort();
        // Always end with the full input so that nothing is left unparsed.
        boundaries.push(input.len());

        let mut edits = Vec::new();
        let mut current_edit = Edit::default();
        let mut chunk_start = 0;
        for chunk_end in boundaries {
            let events = parser.push(&input[chunk_start..chunk_end]);
            chunk_start = chunk_end;

            for event in events {
                match event {
                    EditParserEvent::OldText(old_text) => current_edit.old_text = old_text,
                    EditParserEvent::NewTextChunk { chunk, done } => {
                        current_edit.new_text.push_str(&chunk);
                        if done {
                            edits.push(std::mem::take(&mut current_edit));
                        }
                    }
                }
            }
        }
        edits
    }
}

View file

@ -0,0 +1,889 @@
use super::*;
use crate::{
ReadFileToolInput, grep_tool::GrepToolInput,
streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
cmp::Reverse,
fmt::{self, Display},
io::Write as _,
sync::mpsc,
};
use util::path;
/// Eval: the model must extract a `handle_command_output` method from
/// `run_git_blame`; the edited buffer is compared against the `after.rs`
/// fixture.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `formatdoc!` (unlike `indoc!`) interpolates, so the
                    // prompt contains the actual path rather than the
                    // literal text `{input_file_path}`.
                    [text(indoc::formatdoc! {"
                        Read the `{input_file_path}` file and extract a method in
                        the final stanza of `run_git_blame` to deal with command failures,
                        call it `handle_command_output` and take the std::process::Output as the only parameter.
                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::AssertEqual(output_file_content.into()),
        },
    );
}
/// Eval: the model must delete the `run_git_blame` function (and nothing
/// else); the edited buffer is compared against the `after.rs` fixture.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
    let edit_description = "Delete the `run_git_blame` function.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `formatdoc!` (unlike `indoc!`) interpolates, so the
                    // prompt contains the actual path rather than the
                    // literal text `{input_file_path}`.
                    [text(indoc::formatdoc! {"
                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
                        one function, not its usages.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::AssertEqual(output_file_content.into()),
        },
    );
}
/// Eval: the model must rewrite `compile_parser_to_wasm` to use `wasi-sdk`;
/// the resulting diff is judged against a list of expectations.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `formatdoc!` (unlike `indoc!`) interpolates, so the
                    // prompt contains the actual path rather than the
                    // literal text `{input_file_path}`. `{{language_name}}`
                    // is escaped so that it stays a literal placeholder.
                    [text(indoc::formatdoc! {"
                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                        Use `ureq` to download the SDK for the current platform and architecture.
                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                        that's inside of the archive.
                        Don't re-download the SDK if that executable already exists.
                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}
                        Here are the available wasi-sdk assets:
                        - wasi-sdk-25.0-x86_64-macos.tar.gz
                        - wasi-sdk-25.0-arm64-macos.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-windows.tar.gz
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::JudgeDiff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}
/// Eval: after a simulated `grep` for "blink", the model must comment out
/// the code that interacts with the `BlinkManager`; the edited buffer is
/// compared against the `after.rs` fixture.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
    let path = "root/editor.rs";
    let before = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let after = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
    let description = "Comment out the call to `BlinkManager::enable`";

    // Excerpts of the input file that stand in for the grep tool's output.
    let grep_output = [
        lines(before, 100..400),
        lines(before, 800..1300),
        lines(before, 1600..2000),
        lines(before, 5000..5500),
        lines(before, 8000..9000),
        lines(before, 18455..18470),
        lines(before, 20000..20500),
        lines(before, 21000..21300),
    ]
    .join("Match found:\n\n");

    let conversation = vec![
        message(User, [text("Let's research how to cursor blinking works.")]),
        message(
            Assistant,
            [tool_use(
                "tool_1",
                "grep",
                GrepToolInput {
                    regex: "blink".into(),
                    include_pattern: None,
                    offset: 0,
                    case_sensitive: false,
                },
            )],
        ),
        message(User, [tool_result("tool_1", "grep", grep_output)]),
        message(
            User,
            [text(indoc! {"
                Comment out the lines that interact with the BlinkManager.
                Keep the outer `update` blocks, but comments everything that's inside (including if statements).
                Don't add additional comments.
            "})],
        ),
        message(
            Assistant,
            [tool_use(
                "tool_4",
                "edit_file",
                StreamingEditFileToolInput {
                    display_description: description.into(),
                    path: path.into(),
                },
            )],
        ),
    ];

    eval(
        100,
        0.6, // TODO: make this eval better
        EvalInput {
            conversation,
            input_path: path.into(),
            input_content: before.into(),
            edit_description: description.into(),
            assertion: EvalAssertion::AssertEqual(after.into()),
        },
    );
}
/// Eval: after reading the file and several simulated `grep` calls, the
/// model must add a `from_pixels` constructor plus tests; the resulting diff
/// is judged against a list of expectations.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
    let path = "root/canvas.rs";
    let before = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let description = "Implement from_pixels constructor and add tests.";

    // Canned output of the last grep call in the conversation.
    let test_matches = indoc! {"
        Found 6 matches:
        ## Matches in font-kit/src/loaders/core_text.rs
        ### mod test L926-936
        ```
        mod test {
        use super::Font;
        use crate::properties::{Stretch, Weight};
        #[cfg(feature = \"source\")]
        use crate::source::SystemSource;
        static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
        #[cfg(feature = \"source\")]
        #[test]
        ```
        55 lines remaining in ancestor node. Read the file to see all.
        ### mod test L947-951
        ```
        }
        #[test]
        fn test_core_text_to_css_font_weight() {
        // Exact matches
        ```
        ### mod test L959-963
        ```
        }
        #[test]
        fn test_core_text_to_css_font_stretch() {
        // Exact matches
        ```
        ## Matches in font-kit/src/loaders/freetype.rs
        ### mod test L1238-1248
        ```
        mod test {
        use crate::loaders::freetype::Font;
        static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
        static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
        #[test]
        fn get_pcf_postscript_name() {
        let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
        assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
        }
        ```
        1 lines remaining in ancestor node. Read the file to see all.
        ## Matches in font-kit/src/sources/core_text.rs
        ### mod test L265-275
        ```
        mod test {
        use crate::properties::{Stretch, Weight};
        #[test]
        fn test_css_to_core_text_font_weight() {
        // Exact matches
        assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
        assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
        assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
        assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
        ```
        27 lines remaining in ancestor node. Read the file to see all.
        ### mod test L278-282
        ```
        }
        #[test]
        fn test_css_to_core_text_font_stretch() {
        // Exact matches
        ```
    "};

    let conversation = vec![
        message(
            User,
            [text(indoc! {"
                Introduce a new `from_pixels` constructor in Canvas and
                also add tests for it in the same file.
            "})],
        ),
        message(
            Assistant,
            [tool_use(
                "tool_1",
                "read_file",
                ReadFileToolInput {
                    path: path.into(),
                    start_line: None,
                    end_line: None,
                },
            )],
        ),
        message(User, [tool_result("tool_1", "read_file", before)]),
        message(
            Assistant,
            [tool_use(
                "tool_2",
                "grep",
                GrepToolInput {
                    regex: "mod\\s+tests".into(),
                    include_pattern: Some("font-kit/src/canvas.rs".into()),
                    offset: 0,
                    case_sensitive: false,
                },
            )],
        ),
        message(User, [tool_result("tool_2", "grep", "No matches found")]),
        message(
            Assistant,
            [tool_use(
                "tool_3",
                "grep",
                GrepToolInput {
                    regex: "mod\\s+tests".into(),
                    include_pattern: Some("font-kit/src/**/*.rs".into()),
                    offset: 0,
                    case_sensitive: false,
                },
            )],
        ),
        message(User, [tool_result("tool_3", "grep", "No matches found")]),
        message(
            Assistant,
            [tool_use(
                "tool_4",
                "grep",
                GrepToolInput {
                    regex: "#\\[test\\]".into(),
                    include_pattern: Some("font-kit/src/**/*.rs".into()),
                    offset: 0,
                    case_sensitive: false,
                },
            )],
        ),
        message(User, [tool_result("tool_4", "grep", test_matches)]),
        message(
            Assistant,
            [tool_use(
                "tool_5",
                "edit_file",
                StreamingEditFileToolInput {
                    display_description: description.into(),
                    path: path.into(),
                },
            )],
        ),
    ];

    eval(
        100,
        0.95,
        EvalInput {
            conversation,
            input_path: path.into(),
            input_content: before.into(),
            edit_description: description.into(),
            assertion: EvalAssertion::JudgeDiff(indoc! {"
                - The diff contains a new `from_pixels` constructor
                - The diff contains new tests for the `from_pixels` constructor
            "}),
        },
    );
}
/// Builds an uncached request message with the given role and contents.
fn message(
    role: Role,
    contents: impl IntoIterator<Item = MessageContent>,
) -> LanguageModelRequestMessage {
    let content = contents.into_iter().collect();
    LanguageModelRequestMessage {
        role,
        content,
        cache: false,
    }
}
/// Wraps `text` in a `MessageContent::Text`.
fn text(text: impl Into<String>) -> MessageContent {
    let owned: String = text.into();
    MessageContent::Text(owned)
}
/// Returns the lines of `input` whose zero-based index falls within `range`,
/// joined with `\n`. Indices past the end of `input` are ignored.
fn lines(input: &str, range: Range<usize>) -> String {
    let mut selected = String::new();
    let wanted = input.lines().skip(range.start).take(range.len());
    for (ix, line) in wanted.enumerate() {
        if ix > 0 {
            selected.push('\n');
        }
        selected.push_str(line);
    }
    selected
}
/// Builds a completed tool-use message content whose `input` and `raw_input`
/// are both the JSON serialization of `input`.
///
/// Panics if `input` cannot be serialized.
fn tool_use(
    id: impl Into<Arc<str>>,
    name: impl Into<Arc<str>>,
    input: impl Serialize,
) -> MessageContent {
    let raw_input = serde_json::to_string_pretty(&input).unwrap();
    let input = serde_json::to_value(input).unwrap();
    let tool_use = LanguageModelToolUse {
        id: LanguageModelToolUseId::from(id.into()),
        name: name.into(),
        raw_input,
        input,
        is_input_complete: true,
    };
    MessageContent::ToolUse(tool_use)
}
/// Builds a successful (non-error) tool-result message content for the tool
/// use identified by `id`.
fn tool_result(
    id: impl Into<Arc<str>>,
    name: impl Into<Arc<str>>,
    result: impl Into<Arc<str>>,
) -> MessageContent {
    let tool_result = LanguageModelToolResult {
        tool_use_id: LanguageModelToolUseId::from(id.into()),
        tool_name: name.into(),
        is_error: false,
        content: result.into(),
    };
    MessageContent::ToolResult(tool_result)
}
/// Everything needed to run one eval; cloned once per iteration by `eval`.
#[derive(Clone)]
struct EvalInput {
// Messages leading up to the edit; `eval` marks the last one as cached.
conversation: Vec<LanguageModelRequestMessage>,
// NOTE(review): presumably the path of the file being edited — confirm
// against `EditAgentTest::eval`, which is outside this view.
input_path: PathBuf,
// NOTE(review): presumably the initial contents of that file — confirm.
input_content: String,
// NOTE(review): presumably the description handed to the edit agent — confirm.
edit_description: String,
// How the result of the edit is scored.
assertion: EvalAssertion,
}
/// Runs `eval` `iterations` times and panics if the results are not good
/// enough.
///
/// The first run is executed synchronously; the remaining ones are spawned
/// on the background executor. A run counts as failed when it returns an
/// error or when its assertion score is below 80. The function panics when
/// the pass ratio is lower than `expected_pass_ratio`, or when more than 2%
/// of the parsed tags were mismatched.
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
    let mut evaluated_count = 0;
    report_progress(evaluated_count, iterations);

    let (tx, rx) = mpsc::channel();

    // Cache the last message in the conversation, and run one instance of the eval so that
    // all the next ones are cached.
    eval.conversation.last_mut().unwrap().cache = true;
    run_eval(eval.clone(), tx.clone());

    let executor = gpui::background_executor();
    for _ in 1..iterations {
        let eval_input = eval.clone();
        let sender = tx.clone();
        executor
            .spawn(async move { run_eval(eval_input, sender) })
            .detach();
    }
    // Drop our own sender so that the receiver terminates once every run is done.
    drop(tx);

    let mut failed_count = 0;
    let mut failed_evals = HashMap::default();
    let mut errored_evals = HashMap::default();
    let mut eval_outputs = Vec::new();
    let mut cumulative_parser_metrics = EditParserMetrics::default();
    for outcome in rx {
        match outcome {
            Ok(output) => {
                cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
                eval_outputs.push(output.clone());
                if output.assertion.score < 80 {
                    failed_count += 1;
                    // Group failures by the text they produced.
                    failed_evals
                        .entry(output.buffer_text.clone())
                        .or_insert_with(Vec::new)
                        .push(output);
                }
            }
            Err(error) => {
                failed_count += 1;
                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
            }
        }

        evaluated_count += 1;
        report_progress(evaluated_count, iterations);
    }

    let passed_count = iterations - failed_count;
    let actual_pass_ratio = passed_count as f32 / iterations as f32;
    println!("Actual pass ratio: {}\n", actual_pass_ratio);
    if actual_pass_ratio < expected_pass_ratio {
        // Most frequent errors first.
        let mut errors_by_count = errored_evals.into_iter().collect::<Vec<_>>();
        errors_by_count.sort_by_key(|(_, count)| Reverse(*count));
        for (error, count) in errors_by_count {
            println!("Eval errored {} times. Error: {}", count, error);
        }

        // Most frequent failing outputs first; print one sample per group.
        let mut failures_by_count = failed_evals.into_iter().collect::<Vec<_>>();
        failures_by_count.sort_by_key(|(_, evals)| Reverse(evals.len()));
        for (_buffer_output, group) in failures_by_count {
            let sample = group.first().unwrap();
            println!("Eval failed {} times", group.len());
            println!("{}", sample);
        }

        panic!(
            "Actual pass ratio: {}\nExpected pass ratio: {}",
            actual_pass_ratio, expected_pass_ratio
        );
    }

    let mismatched_tags = cumulative_parser_metrics.mismatched_tags as f32;
    let total_tags = cumulative_parser_metrics.tags as f32;
    let mismatched_tag_ratio = mismatched_tags / total_tags;
    if mismatched_tag_ratio > 0.02 {
        for eval_output in eval_outputs {
            println!("{}", eval_output);
        }
        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
    }
}
/// Runs a single eval inside a freshly built, entropy-seeded `TestAppContext` and sends
/// the outcome (success or error) over `tx`.
///
/// Blocks the calling thread until the eval has finished. Panics if the receiving end of
/// `tx` has already been dropped.
fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
    let rng = StdRng::from_entropy();
    let mut cx = TestAppContext::build(gpui::TestDispatcher::new(rng), None);

    let executor = cx.executor();
    let result = executor.block_test(async {
        let harness = EditAgentTest::new(&mut cx).await;
        harness.eval(eval, &mut cx).await
    });

    tx.send(result).unwrap();
}
// Result of one successful eval run, as produced by `EditAgentTest::eval`.
#[derive(Clone)]
struct EvalOutput {
// Score (and optional judge message) assigned to the edited buffer.
assertion: EvalAssertionResult,
// Full text of the buffer after the agent finished editing.
buffer_text: String,
// Output returned by the edit agent; its `_parser_metrics` and `_raw_edits` are
// used for reporting.
edit_output: EditAgentOutput,
// Unified diff between the eval's input content and `buffer_text`.
diff: String,
}
impl Display for EvalOutput {
    /// Writes a multi-line report: the score, the judge's message (when present), the
    /// diff, the parser metrics and the raw edits produced by the agent.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            assertion,
            diff,
            edit_output,
            ..
        } = self;

        writeln!(f, "Score: {:?}", assertion.score)?;
        if let Some(message) = &assertion.message {
            writeln!(f, "Message: {}", message)?;
        }
        writeln!(f, "Diff:\n{}", diff)?;
        writeln!(f, "Parser Metrics:\n{:#?}", edit_output._parser_metrics)?;
        writeln!(f, "Raw Edits:\n{}", edit_output._raw_edits)
    }
}
// Prints the progress counter `Evaluated <evaluated_count>/<iterations>` to stdout.
//
// The leading `\r` moves the cursor back to the start of the line and `\x1b[K` is the
// ANSI "erase to end of line" sequence, so each call overwrites the previous counter
// instead of adding a new line.
fn report_progress(evaluated_count: usize, iterations: usize) {
print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
// `print!` emits no newline, so flush explicitly to make the counter visible right away.
std::io::stdout().flush().unwrap();
}
// Test harness that owns everything a single eval run needs.
struct EditAgentTest {
// The agent under test; `eval` calls its `edit` method.
agent: EditAgent,
// Test project used to resolve the eval's input path and open the buffer to edit.
project: Entity<Project>,
// Model that `judge_diff` asks to score a diff against the eval's assertions.
judge_model: Arc<dyn LanguageModel>,
}
impl EditAgentTest {
    /// Builds a harness backed by a fake filesystem containing an empty `/root`, a test
    /// project rooted there, and the agent and judge models.
    ///
    /// The models are loaded through a production `Client` using a real HTTP client, so
    /// this talks to the provider. Panics if either model cannot be loaded.
    async fn new(cx: &mut TestAppContext) -> Self {
        cx.executor().allow_parking();
        cx.update(settings::init);
        cx.update(Project::init_settings);
        cx.update(language::init);
        cx.update(gpui_tokio::init);
        cx.update(client::init_settings);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let (agent_model, judge_model) = cx
            .update(|cx| {
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));

                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);

                cx.spawn(async move |cx| {
                    let agent_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    let judge_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    (agent_model.unwrap(), judge_model.unwrap())
                })
            })
            .await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            agent: EditAgent::new(agent_model, action_log, Templates::new()),
            project,
            judge_model,
        }
    }

    /// Looks up the model `id` of `provider` in the global `LanguageModelRegistry`,
    /// authenticates with its provider and returns the model.
    ///
    /// Returns an error (instead of panicking) when the model is not among the available
    /// models, when its provider is not registered, or when authentication fails, so the
    /// caller can see which model could not be loaded.
    async fn load_model(
        provider: &str,
        id: &str,
        cx: &mut AsyncApp,
    ) -> Result<Arc<dyn LanguageModel>> {
        // The closure itself is fallible, hence the double `?`: the outer one is for
        // `cx.update`, the inner one for the lookups.
        let (provider, model) = cx.update(|cx| {
            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
                .with_context(|| format!("model {provider}/{id} is not available"))?;
            let provider = models
                .provider(&model.provider_id())
                .with_context(|| format!("provider {provider} is not registered"))?;
            anyhow::Ok((provider, model))
        })??;
        cx.update(|cx| provider.authenticate(cx))?.await?;
        Ok(model)
    }

    /// Runs one eval: opens the buffer at `eval.input_path`, sets its text to
    /// `eval.input_content`, lets the agent edit it and scores the result.
    ///
    /// `AssertEqual` scores 100 when the buffer matches the expected text after blank
    /// lines are removed from both, and 0 otherwise; `JudgeDiff` delegates to
    /// `judge_diff`.
    ///
    /// Returns an error if the agent's edit or the judge fails. Panics if the input path
    /// cannot be resolved in the project or the buffer cannot be opened.
    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
        let path = self
            .project
            .read_with(cx, |project, cx| {
                project.find_project_path(eval.input_path, cx)
            })
            .unwrap();
        let buffer = self
            .project
            .update(cx, |project, cx| project.open_buffer(path, cx))
            .await
            .unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.set_text(eval.input_content.clone(), cx)
        });

        let (edit_output, _events) = self.agent.edit(
            buffer.clone(),
            eval.edit_description,
            eval.conversation,
            &mut cx.to_async(),
        );
        let edit_output = edit_output.await?;

        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
        let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
        let assertion = match eval.assertion {
            EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
                score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
                    100
                } else {
                    0
                },
                message: None,
            },
            EvalAssertion::JudgeDiff(assertions) => self
                .judge_diff(&actual_diff, assertions, &cx.to_async())
                .await
                .context("failed comparing diffs")?,
        };

        Ok(EvalOutput {
            assertion,
            diff: actual_diff,
            buffer_text,
            edit_output,
        })
    }

    /// Asks the judge model to score `diff` against `assertions`.
    ///
    /// The rendered `DiffJudgeTemplate` is sent as a single user message and the streamed
    /// response is concatenated. The score is taken from the first `<score>N</score>` tag
    /// in the response (0 if the digits do not fit the score type) and the whole response
    /// becomes the result's message.
    ///
    /// Returns an error if the completion fails or the response contains no score tag.
    /// Panics if the template cannot be rendered.
    async fn judge_diff(
        &self,
        diff: &str,
        assertions: &'static str,
        cx: &AsyncApp,
    ) -> Result<EvalAssertionResult> {
        let prompt = DiffJudgeTemplate {
            diff: diff.to_string(),
            assertions,
        }
        .render(&self.agent.templates)
        .unwrap();

        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![prompt.into()],
                cache: false,
            }],
            ..Default::default()
        };
        let mut response = self.judge_model.stream_completion_text(request, cx).await?;
        let mut output = String::new();
        while let Some(chunk) = response.stream.next().await {
            let chunk = chunk?;
            output.push_str(&chunk);
        }

        // Parse the score from the response
        let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
        if let Some(captures) = re.captures(&output) {
            if let Some(score_match) = captures.get(1) {
                let score = score_match.as_str().parse().unwrap_or(0);
                return Ok(EvalAssertionResult {
                    score,
                    message: Some(output),
                });
            }
        }

        Err(anyhow!(
            "No score found in response. Raw output: {}",
            output
        ))
    }
}
// How the outcome of an eval is scored by `EditAgentTest::eval`.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
// Expected buffer text. Scores 100 when the edited buffer equals it after blank
// (whitespace-only) lines are stripped from both sides, and 0 otherwise.
AssertEqual(String),
// Assertions handed, together with the diff, to the judge model, which returns the score.
JudgeDiff(&'static str),
}
// Outcome of checking an `EvalAssertion`.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
// Score of the run; `eval` counts a run as failed when this is below 80.
score: usize,
// Full response of the judge model for `JudgeDiff`; `None` for `AssertEqual`.
message: Option<String>,
}
// Data rendered into the prompt that `judge_diff` sends to the judge model.
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
// The diff to be judged.
diff: String,
// The assertions the diff is judged against.
assertions: &'static str,
}
impl Template for DiffJudgeTemplate {
// Name of the template associated with this type; presumably looked up by `render`
// in the `Templates` registry — TODO confirm against the `Template` trait.
const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
/// Returns `text` without its blank lines.
///
/// A line is blank when it is empty or consists only of whitespace. The remaining lines
/// are kept verbatim (including their own leading and trailing whitespace) and joined
/// with `\n`; the result has no trailing newline.
fn strip_empty_lines(text: &str) -> String {
    let mut kept = String::new();
    for line in text.lines() {
        if line.trim().is_empty() {
            continue;
        }
        // Every kept line is non-empty, so `kept` is empty only before the first one.
        if !kept.is_empty() {
            kept.push('\n');
        }
        kept.push_str(line);
    }
    kept
}

View file

@ -0,0 +1,328 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
// Blame information for a single file, as assembled by `Blame::for_path`.
#[derive(Debug, Clone, Default)]
pub struct Blame {
// Parsed blame entries, sorted by the start of their line range.
pub entries: Vec<BlameEntry>,
// Result of `get_messages` for the distinct SHAs that appear in `entries`.
pub messages: HashMap<Oid, String>,
// Remote URL supplied by the caller of `Blame::for_path`; stored unchanged.
pub remote_url: Option<String>,
}
// A commit message plus optional metadata about where the commit lives.
// NOTE(review): nothing in this file constructs or reads this type, so the field notes
// below are inferred from names and types only — confirm against the callers.
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
// The commit message text.
pub message: SharedString,
// Presumably a link to the commit on the hosting provider — TODO confirm.
pub permalink: Option<url::Url>,
// Presumably the pull request associated with the commit, if any — TODO confirm.
pub pull_request: Option<crate::hosting_provider::PullRequest>,
// Presumably the remote the commit belongs to — TODO confirm.
pub remote: Option<GitRemote>,
}
impl Blame {
    // Computes the blame for `path` using `content` as the file's contents.
    //
    // The output of `run_git_blame` is parsed into entries, which are sorted by the start
    // of their line range. The commit messages for the distinct SHAs are then fetched
    // with `get_messages`. `remote_url` is stored in the result unchanged.
    //
    // Returns an error if running or parsing the blame fails, or if the commit messages
    // cannot be retrieved.
    pub async fn for_path(
        git_binary: &Path,
        working_directory: &Path,
        path: &Path,
        content: &Rope,
        remote_url: Option<String>,
    ) -> Result<Self> {
        let raw_blame = run_git_blame(git_binary, working_directory, path, content).await?;

        let mut entries = parse_git_blame(&raw_blame)?;
        entries.sort_unstable_by(|left, right| left.range.start.cmp(&right.range.start));

        // Several entries usually share a commit; collect each SHA once.
        let mut distinct_shas = HashSet::default();
        for entry in &entries {
            distinct_shas.insert(entry.sha);
        }
        let shas: Vec<_> = distinct_shas.into_iter().collect();

        let messages = get_messages(working_directory, &shas)
            .await
            .context("failed to get commit messages")?;

        let blame = Self {
            entries,
            messages,
            remote_url,
        };
        Ok(blame)
    }
}
// Error texts that `git blame` can print on stderr.
// NOTE(review): nothing in this version of the file reads these constants; presumably
// they mark failures that should be treated as "no blame available" — confirm.
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
// One entry of `git blame --incremental` output: a run of consecutive lines attributed
// to a single commit.
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
// Commit the lines are attributed to. `parse_git_blame` treats a SHA for which
// `is_zero()` holds as "not committed" and drops such entries.
pub sha: Oid,
// Zero-based, half-open line range, computed from the header as
// `resultline - 1 .. resultline - 1 + num-lines`.
pub range: Range<u32>,
// `<sourceline>` from the entry header, stored as reported by git.
pub original_line_number: u32,
// Values of the `author`, `author-mail`, `author-time` and `author-tz` lines.
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
// Values of the `committer`, `committer-mail`, `committer-time` and `committer-tz` lines.
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
// Value of the `summary` line.
pub summary: Option<String>,
// Raw value of the `previous` line (`<sha> <filename>` in the documented format).
pub previous: Option<String>,
// Value of the `filename` line, which terminates the entry.
pub filename: String,
}
impl BlameEntry {
    // Returns a BlameEntry by parsing the first line of a `git blame --incremental`
    // entry. The line MUST have this format:
    //
    //     <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
    //
    // `<resultline>` is 1-indexed; the returned `range` is 0-indexed and half-open
    // (`resultline - 1 .. resultline - 1 + num-lines`). Only `sha`, `range` and
    // `original_line_number` are filled in; every other field keeps its default value.
    //
    // Returns an error naming the field if one of the four fields is missing or
    // cannot be parsed.
    fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
        let mut parts = line.split_whitespace();

        let sha = parts
            .next()
            .and_then(|part| part.parse::<Oid>().ok())
            .ok_or_else(|| anyhow!("failed to parse sha"))?;

        let original_line_number = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse original line number"))?;
        let final_line_number = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse final line number"))?;

        // This is `<num-lines>`; report it as such instead of repeating the message of
        // the previous field.
        let line_count = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse line count"))?;

        // Convert the 1-indexed result line into a 0-indexed start.
        let start_line = final_line_number.saturating_sub(1);
        let end_line = start_line + line_count;
        let range = start_line..end_line;

        Ok(Self {
            sha,
            range,
            original_line_number,
            ..Default::default()
        })
    }

    // Returns the author time of this entry in the author's own UTC offset.
    //
    // `author_tz` is parsed as `[offset_hour][offset_minute]` (for example `+0100`) and
    // `author_time` is interpreted as a Unix timestamp. Returns an error if either value
    // is invalid.
    pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
        if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
            let format = format_description!("[offset_hour][offset_minute]");
            let offset = UtcOffset::parse(author_tz, &format)?;
            let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;

            Ok(date_time_utc.to_offset(offset))
        } else {
            // Directly return current time in UTC if there's no author time or timezone
            Ok(time::OffsetDateTime::now_utc())
        }
    }
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
    let mut entries: Vec<BlameEntry> = Vec::new();
    // For each SHA, the length of `entries` at the moment an entry with that SHA was
    // last completed, i.e. the slot that entry occupies if it gets pushed.
    let mut slot_by_sha: HashMap<Oid, usize> = HashMap::default();
    // The entry whose header has been read but whose closing `filename` line has not.
    let mut pending: Option<BlameEntry> = None;

    for line in output.lines() {
        let Some(entry) = pending.as_mut() else {
            // Nothing is in progress, so this line has to be an entry header.
            let mut header = BlameEntry::new_from_blame_line(line)?;
            let earlier = slot_by_sha
                .get(&header.sha)
                .and_then(|&slot| entries.get(slot));
            if let Some(earlier) = earlier {
                // Start from the signature recorded for this SHA; any signature lines
                // that follow the header still overwrite these values.
                header.author.clone_from(&earlier.author);
                header.author_mail.clone_from(&earlier.author_mail);
                header.author_time = earlier.author_time;
                header.author_tz.clone_from(&earlier.author_tz);
                header.committer_name.clone_from(&earlier.committer_name);
                header.committer_email.clone_from(&earlier.committer_email);
                header.committer_time = earlier.committer_time;
                header.committer_tz.clone_from(&earlier.committer_tz);
                header.summary.clone_from(&earlier.summary);
            }
            pending = Some(header);
            continue;
        };

        // Lines without a space carry no `key value` pair and are ignored.
        let Some((key, value)) = line.split_once(' ') else {
            continue;
        };

        // Signature fields are only recorded for entries that belong to a commit.
        let is_committed = !entry.sha.is_zero();
        let mut finished = false;
        match (key, is_committed) {
            ("filename", _) => {
                entry.filename = value.into();
                finished = true;
            }
            ("previous", _) => entry.previous = Some(value.into()),
            ("summary", true) => entry.summary = Some(value.into()),
            ("author", true) => entry.author = Some(value.into()),
            ("author-mail", true) => entry.author_mail = Some(value.into()),
            ("author-time", true) => entry.author_time = Some(value.parse::<i64>()?),
            ("author-tz", true) => entry.author_tz = Some(value.into()),
            ("committer", true) => entry.committer_name = Some(value.into()),
            ("committer-mail", true) => entry.committer_email = Some(value.into()),
            ("committer-time", true) => entry.committer_time = Some(value.parse::<i64>()?),
            ("committer-tz", true) => entry.committer_tz = Some(value.into()),
            _ => {}
        }

        // The `filename` line closes the entry.
        if finished {
            if let Some(entry) = pending.take() {
                slot_by_sha.insert(entry.sha, entries.len());
                // We only want annotations that have a commit.
                if !entry.sha.is_zero() {
                    entries.push(entry);
                }
            }
        }
    }

    Ok(entries)
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::BlameEntry;
    use super::parse_git_blame;

    // Reads `test_data/<filename>` (relative to the crate's manifest directory) into a
    // string. Panics if the file cannot be read.
    fn read_test_data(filename: &str) -> String {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("test_data");
        path.push(filename);

        std::fs::read_to_string(&path)
            .unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
    }

    // Compares `entries`, serialized as pretty-printed JSON with a trailing newline,
    // against `test_data/golden/<golden_filename>.json`.
    //
    // When the environment variable `UPDATE_GOLDEN` is set to `true` (case-insensitive)
    // the golden file is (re)written instead of compared. Takes a slice rather than
    // `&Vec<_>`; call sites passing `&entries` for a `Vec` keep working.
    fn assert_eq_golden(entries: &[BlameEntry], golden_filename: &str) {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("test_data");
        path.push("golden");
        path.push(format!("{}.json", golden_filename));

        let mut have_json =
            serde_json::to_string_pretty(entries).expect("could not serialize entries to JSON");
        // We always want to save with a trailing newline.
        have_json.push('\n');

        let update = std::env::var("UPDATE_GOLDEN")
            .map(|val| val.eq_ignore_ascii_case("true"))
            .unwrap_or(false);

        if update {
            std::fs::create_dir_all(path.parent().unwrap())
                .expect("could not create golden test data directory");
            std::fs::write(&path, have_json).expect("could not write out golden data");
        } else {
            // Normalize line endings so that a CRLF checkout of the golden file still matches.
            let want_json =
                std::fs::read_to_string(&path).unwrap_or_else(|_| {
                    panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
                }).replace("\r\n", "\n");

            pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
        }
    }

    #[test]
    fn test_parse_git_blame_not_committed() {
        let output = read_test_data("blame_incremental_not_committed");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_not_committed");
    }

    #[test]
    fn test_parse_git_blame_simple() {
        let output = read_test_data("blame_incremental_simple");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_simple");
    }

    #[test]
    fn test_parse_git_blame_complex() {
        let output = read_test_data("blame_incremental_complex");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_complex");
    }
}

View file

@ -0,0 +1,374 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
// Blame information for a single file, as assembled by `Blame::for_path`.
#[derive(Debug, Clone, Default)]
pub struct Blame {
// Parsed blame entries, sorted by the start of their line range.
pub entries: Vec<BlameEntry>,
// Result of `get_messages` for the distinct SHAs that appear in `entries`.
pub messages: HashMap<Oid, String>,
// Remote URL supplied by the caller of `Blame::for_path`; stored unchanged.
pub remote_url: Option<String>,
}
// A commit message plus optional metadata about where the commit lives.
// NOTE(review): nothing in this file constructs or reads this type, so the field notes
// below are inferred from names and types only — confirm against the callers.
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
// The commit message text.
pub message: SharedString,
// Presumably a link to the commit on the hosting provider — TODO confirm.
pub permalink: Option<url::Url>,
// Presumably the pull request associated with the commit, if any — TODO confirm.
pub pull_request: Option<crate::hosting_provider::PullRequest>,
// Presumably the remote the commit belongs to — TODO confirm.
pub remote: Option<GitRemote>,
}
impl Blame {
    // Computes the blame for `path` using `content` as the file's contents.
    //
    // The output of `run_git_blame` is parsed into entries, which are sorted by the start
    // of their line range. The commit messages for the distinct SHAs are then fetched
    // with `get_messages`. `remote_url` is stored in the result unchanged.
    //
    // Returns an error if running or parsing the blame fails, or if the commit messages
    // cannot be retrieved.
    pub async fn for_path(
        git_binary: &Path,
        working_directory: &Path,
        path: &Path,
        content: &Rope,
        remote_url: Option<String>,
    ) -> Result<Self> {
        let raw_blame = run_git_blame(git_binary, working_directory, path, content).await?;

        let mut entries = parse_git_blame(&raw_blame)?;
        entries.sort_unstable_by(|left, right| left.range.start.cmp(&right.range.start));

        // Several entries usually share a commit; collect each SHA once.
        let mut distinct_shas = HashSet::default();
        for entry in &entries {
            distinct_shas.insert(entry.sha);
        }
        let shas: Vec<_> = distinct_shas.into_iter().collect();

        let messages = get_messages(working_directory, &shas)
            .await
            .context("failed to get commit messages")?;

        let blame = Self {
            entries,
            messages,
            remote_url,
        };
        Ok(blame)
    }
}
// stderr texts of a failed `git blame` that `run_git_blame` does not treat as errors:
// when the trimmed stderr equals `GIT_BLAME_NO_COMMIT_ERROR` or contains
// `GIT_BLAME_NO_PATH`, an empty blame output is returned instead.
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
// Runs `git blame --incremental --contents - <path>` in `working_directory`, feeding
// `contents` to the process on stdin, and returns what it printed on stdout.
//
// If git exits unsuccessfully and the trimmed stderr equals `GIT_BLAME_NO_COMMIT_ERROR`
// or contains `GIT_BLAME_NO_PATH`, an empty string is returned; any other failure is an
// error. Also fails if the process cannot be started, stdin cannot be written, or
// stdout is not valid UTF-8.
async fn run_git_blame(
    git_binary: &Path,
    working_directory: &Path,
    path: &Path,
    contents: &Rope,
) -> Result<String> {
    let mut command = util::command::new_smol_command(git_binary);
    command
        .current_dir(working_directory)
        .arg("blame")
        .arg("--incremental")
        .arg("--contents")
        .arg("-")
        .arg(path.as_os_str())
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());
    let mut child = command
        .spawn()
        .map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;

    // Stream the buffer contents to git chunk by chunk.
    let stdin = child
        .stdin
        .as_mut()
        .context("failed to get pipe to stdin of git blame command")?;
    for chunk in contents.chunks() {
        stdin.write_all(chunk.as_bytes()).await?;
    }
    stdin.flush().await?;

    let output = child
        .output()
        .await
        .map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;

    if output.status.success() {
        return Ok(String::from_utf8(output.stdout)?);
    }

    let stderr = String::from_utf8_lossy(&output.stderr);
    let message = stderr.trim();
    let is_expected = message == GIT_BLAME_NO_COMMIT_ERROR || message.contains(GIT_BLAME_NO_PATH);
    if is_expected {
        Ok(String::new())
    } else {
        Err(anyhow!("git blame process failed: {}", stderr))
    }
}
// One entry of `git blame --incremental` output: a run of consecutive lines attributed
// to a single commit.
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
// Commit the lines are attributed to. `parse_git_blame` treats a SHA for which
// `is_zero()` holds as "not committed" and drops such entries.
pub sha: Oid,
// Zero-based, half-open line range, computed from the header as
// `resultline - 1 .. resultline - 1 + num-lines`.
pub range: Range<u32>,
// `<sourceline>` from the entry header, stored as reported by git.
pub original_line_number: u32,
// Values of the `author`, `author-mail`, `author-time` and `author-tz` lines.
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
// Values of the `committer`, `committer-mail`, `committer-time` and `committer-tz` lines.
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
// Value of the `summary` line.
pub summary: Option<String>,
// Raw value of the `previous` line (`<sha> <filename>` in the documented format).
pub previous: Option<String>,
// Value of the `filename` line, which terminates the entry.
pub filename: String,
}
impl BlameEntry {
    // Returns a BlameEntry by parsing the first line of a `git blame --incremental`
    // entry. The line MUST have this format:
    //
    //     <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
    //
    // `<resultline>` is 1-indexed; the returned `range` is 0-indexed and half-open
    // (`resultline - 1 .. resultline - 1 + num-lines`). Only `sha`, `range` and
    // `original_line_number` are filled in; every other field keeps its default value.
    //
    // Returns an error naming the field if one of the four fields is missing or
    // cannot be parsed.
    fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
        let mut parts = line.split_whitespace();

        let sha = parts
            .next()
            .and_then(|part| part.parse::<Oid>().ok())
            .ok_or_else(|| anyhow!("failed to parse sha"))?;

        let original_line_number = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse original line number"))?;
        let final_line_number = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse final line number"))?;

        // This is `<num-lines>`; report it as such instead of repeating the message of
        // the previous field.
        let line_count = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse line count"))?;

        // Convert the 1-indexed result line into a 0-indexed start.
        let start_line = final_line_number.saturating_sub(1);
        let end_line = start_line + line_count;
        let range = start_line..end_line;

        Ok(Self {
            sha,
            range,
            original_line_number,
            ..Default::default()
        })
    }

    // Returns the author time of this entry in the author's own UTC offset.
    //
    // `author_tz` is parsed as `[offset_hour][offset_minute]` (for example `+0100`) and
    // `author_time` is interpreted as a Unix timestamp. Returns an error if either value
    // is invalid.
    pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
        if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
            let format = format_description!("[offset_hour][offset_minute]");
            let offset = UtcOffset::parse(author_tz, &format)?;
            let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;

            Ok(date_time_utc.to_offset(offset))
        } else {
            // Directly return current time in UTC if there's no author time or timezone
            Ok(time::OffsetDateTime::now_utc())
        }
    }
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
    let mut entries: Vec<BlameEntry> = Vec::new();
    // For each SHA, the length of `entries` at the moment an entry with that SHA was
    // last completed, i.e. the slot that entry occupies if it gets pushed.
    let mut slot_by_sha: HashMap<Oid, usize> = HashMap::default();
    // The entry whose header has been read but whose closing `filename` line has not.
    let mut pending: Option<BlameEntry> = None;

    for line in output.lines() {
        let Some(entry) = pending.as_mut() else {
            // Nothing is in progress, so this line has to be an entry header.
            let mut header = BlameEntry::new_from_blame_line(line)?;
            let earlier = slot_by_sha
                .get(&header.sha)
                .and_then(|&slot| entries.get(slot));
            if let Some(earlier) = earlier {
                // Start from the signature recorded for this SHA; any signature lines
                // that follow the header still overwrite these values.
                header.author.clone_from(&earlier.author);
                header.author_mail.clone_from(&earlier.author_mail);
                header.author_time = earlier.author_time;
                header.author_tz.clone_from(&earlier.author_tz);
                header.committer_name.clone_from(&earlier.committer_name);
                header.committer_email.clone_from(&earlier.committer_email);
                header.committer_time = earlier.committer_time;
                header.committer_tz.clone_from(&earlier.committer_tz);
                header.summary.clone_from(&earlier.summary);
            }
            pending = Some(header);
            continue;
        };

        // Lines without a space carry no `key value` pair and are ignored.
        let Some((key, value)) = line.split_once(' ') else {
            continue;
        };

        // Signature fields are only recorded for entries that belong to a commit.
        let is_committed = !entry.sha.is_zero();
        let mut finished = false;
        match (key, is_committed) {
            ("filename", _) => {
                entry.filename = value.into();
                finished = true;
            }
            ("previous", _) => entry.previous = Some(value.into()),
            ("summary", true) => entry.summary = Some(value.into()),
            ("author", true) => entry.author = Some(value.into()),
            ("author-mail", true) => entry.author_mail = Some(value.into()),
            ("author-time", true) => entry.author_time = Some(value.parse::<i64>()?),
            ("author-tz", true) => entry.author_tz = Some(value.into()),
            ("committer", true) => entry.committer_name = Some(value.into()),
            ("committer-mail", true) => entry.committer_email = Some(value.into()),
            ("committer-time", true) => entry.committer_time = Some(value.parse::<i64>()?),
            ("committer-tz", true) => entry.committer_tz = Some(value.into()),
            _ => {}
        }

        // The `filename` line closes the entry.
        if finished {
            if let Some(entry) = pending.take() {
                slot_by_sha.insert(entry.sha, entries.len());
                // We only want annotations that have a commit.
                if !entry.sha.is_zero() {
                    entries.push(entry);
                }
            }
        }
    }

    Ok(entries)
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::BlameEntry;
    use super::parse_git_blame;

    // Reads `test_data/<filename>` (relative to the crate's manifest directory) into a
    // string. Panics if the file cannot be read.
    fn read_test_data(filename: &str) -> String {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("test_data");
        path.push(filename);

        std::fs::read_to_string(&path)
            .unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
    }

    // Compares `entries`, serialized as pretty-printed JSON with a trailing newline,
    // against `test_data/golden/<golden_filename>.json`.
    //
    // When the environment variable `UPDATE_GOLDEN` is set to `true` (case-insensitive)
    // the golden file is (re)written instead of compared. Takes a slice rather than
    // `&Vec<_>`; call sites passing `&entries` for a `Vec` keep working.
    fn assert_eq_golden(entries: &[BlameEntry], golden_filename: &str) {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("test_data");
        path.push("golden");
        path.push(format!("{}.json", golden_filename));

        let mut have_json =
            serde_json::to_string_pretty(entries).expect("could not serialize entries to JSON");
        // We always want to save with a trailing newline.
        have_json.push('\n');

        let update = std::env::var("UPDATE_GOLDEN")
            .map(|val| val.eq_ignore_ascii_case("true"))
            .unwrap_or(false);

        if update {
            std::fs::create_dir_all(path.parent().unwrap())
                .expect("could not create golden test data directory");
            std::fs::write(&path, have_json).expect("could not write out golden data");
        } else {
            // Normalize line endings so that a CRLF checkout of the golden file still matches.
            let want_json =
                std::fs::read_to_string(&path).unwrap_or_else(|_| {
                    panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
                }).replace("\r\n", "\n");

            pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
        }
    }

    #[test]
    fn test_parse_git_blame_not_committed() {
        let output = read_test_data("blame_incremental_not_committed");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_not_committed");
    }

    #[test]
    fn test_parse_git_blame_simple() {
        let output = read_test_data("blame_incremental_simple");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_simple");
    }

    #[test]
    fn test_parse_git_blame_complex() {
        let output = read_test_data("blame_incremental_complex");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_complex");
    }
}

View file

@ -0,0 +1,378 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
// Blame information for a single file, as assembled by `Blame::for_path`.
#[derive(Debug, Clone, Default)]
pub struct Blame {
// Parsed blame entries, sorted by the start of their line range.
pub entries: Vec<BlameEntry>,
// Result of `get_messages` for the distinct SHAs that appear in `entries`.
pub messages: HashMap<Oid, String>,
// Remote URL supplied by the caller of `Blame::for_path`; stored unchanged.
pub remote_url: Option<String>,
}
// A commit message plus optional metadata about where the commit lives.
// NOTE(review): nothing in this file constructs or reads this type, so the field notes
// below are inferred from names and types only — confirm against the callers.
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
// The commit message text.
pub message: SharedString,
// Presumably a link to the commit on the hosting provider — TODO confirm.
pub permalink: Option<url::Url>,
// Presumably the pull request associated with the commit, if any — TODO confirm.
pub pull_request: Option<crate::hosting_provider::PullRequest>,
// Presumably the remote the commit belongs to — TODO confirm.
pub remote: Option<GitRemote>,
}
impl Blame {
    // Computes the blame for `path` using `content` as the file's contents.
    //
    // The output of `run_git_blame` is parsed into entries, which are sorted by the start
    // of their line range. The commit messages for the distinct SHAs are then fetched
    // with `get_messages`. `remote_url` is stored in the result unchanged.
    //
    // Returns an error if running or parsing the blame fails, or if the commit messages
    // cannot be retrieved.
    pub async fn for_path(
        git_binary: &Path,
        working_directory: &Path,
        path: &Path,
        content: &Rope,
        remote_url: Option<String>,
    ) -> Result<Self> {
        let raw_blame = run_git_blame(git_binary, working_directory, path, content).await?;

        let mut entries = parse_git_blame(&raw_blame)?;
        entries.sort_unstable_by(|left, right| left.range.start.cmp(&right.range.start));

        // Several entries usually share a commit; collect each SHA once.
        let mut distinct_shas = HashSet::default();
        for entry in &entries {
            distinct_shas.insert(entry.sha);
        }
        let shas: Vec<_> = distinct_shas.into_iter().collect();

        let messages = get_messages(working_directory, &shas)
            .await
            .context("failed to get commit messages")?;

        let blame = Self {
            entries,
            messages,
            remote_url,
        };
        Ok(blame)
    }
}
// stderr texts of a failed `git blame` that `handle_command_output` does not treat as
// errors: when the trimmed stderr equals `GIT_BLAME_NO_COMMIT_ERROR` or contains
// `GIT_BLAME_NO_PATH`, an empty blame output is returned instead.
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
// Runs `git blame --incremental --contents - <path>` in `working_directory`, feeding
// `contents` to the process on stdin, and hands the finished process output to
// `handle_command_output`.
//
// Fails if the process cannot be started, if stdin cannot be written, if the output
// cannot be collected, or if `handle_command_output` reports an error.
async fn run_git_blame(
    git_binary: &Path,
    working_directory: &Path,
    path: &Path,
    contents: &Rope,
) -> Result<String> {
    let mut command = util::command::new_smol_command(git_binary);
    command
        .current_dir(working_directory)
        .arg("blame")
        .arg("--incremental")
        .arg("--contents")
        .arg("-")
        .arg(path.as_os_str())
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());
    let mut child = command
        .spawn()
        .map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;

    // Stream the buffer contents to git chunk by chunk.
    let stdin = child
        .stdin
        .as_mut()
        .context("failed to get pipe to stdin of git blame command")?;
    for chunk in contents.chunks() {
        stdin.write_all(chunk.as_bytes()).await?;
    }
    stdin.flush().await?;

    let finished = child
        .output()
        .await
        .map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;

    handle_command_output(finished)
}
// Turns the output of a finished `git blame` process into its stdout text.
//
// On success, stdout is returned (an error if it is not valid UTF-8). On failure, an
// empty string is returned when the trimmed stderr equals `GIT_BLAME_NO_COMMIT_ERROR`
// or contains `GIT_BLAME_NO_PATH`; any other failure becomes an error that carries the
// untrimmed stderr text.
fn handle_command_output(output: std::process::Output) -> Result<String> {
    if output.status.success() {
        return Ok(String::from_utf8(output.stdout)?);
    }

    let stderr = String::from_utf8_lossy(&output.stderr);
    let message = stderr.trim();
    let is_expected = message == GIT_BLAME_NO_COMMIT_ERROR || message.contains(GIT_BLAME_NO_PATH);
    if is_expected {
        Ok(String::new())
    } else {
        Err(anyhow!("git blame process failed: {}", stderr))
    }
}
// One entry of `git blame --incremental` output: a run of consecutive lines attributed
// to a single commit.
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
// Commit the lines are attributed to. `parse_git_blame` treats a SHA for which
// `is_zero()` holds as "not committed".
pub sha: Oid,
// Zero-based, half-open line range, computed from the header as
// `resultline - 1 .. resultline - 1 + num-lines`.
pub range: Range<u32>,
// `<sourceline>` from the entry header, stored as reported by git.
pub original_line_number: u32,
// Values of the `author`, `author-mail`, `author-time` and `author-tz` lines.
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
// Values of the `committer`, `committer-mail`, `committer-time` and `committer-tz` lines.
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
// Value of the `summary` line.
pub summary: Option<String>,
// Raw value of the `previous` line (`<sha> <filename>` in the documented format).
pub previous: Option<String>,
// Value of the `filename` line, which terminates the entry.
pub filename: String,
}
impl BlameEntry {
    // Returns a BlameEntry by parsing the first line of a `git blame --incremental`
    // entry. The line MUST have this format:
    //
    //     <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
    //
    // `<resultline>` is 1-indexed; the returned `range` is 0-indexed and half-open
    // (`resultline - 1 .. resultline - 1 + num-lines`). Only `sha`, `range` and
    // `original_line_number` are filled in; every other field keeps its default value.
    //
    // Returns an error naming the field if one of the four fields is missing or
    // cannot be parsed.
    fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
        let mut parts = line.split_whitespace();

        let sha = parts
            .next()
            .and_then(|part| part.parse::<Oid>().ok())
            .ok_or_else(|| anyhow!("failed to parse sha"))?;

        let original_line_number = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse original line number"))?;
        let final_line_number = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse final line number"))?;

        // This is `<num-lines>`; report it as such instead of repeating the message of
        // the previous field.
        let line_count = parts
            .next()
            .and_then(|part| part.parse::<u32>().ok())
            .ok_or_else(|| anyhow!("Failed to parse line count"))?;

        // Convert the 1-indexed result line into a 0-indexed start.
        let start_line = final_line_number.saturating_sub(1);
        let end_line = start_line + line_count;
        let range = start_line..end_line;

        Ok(Self {
            sha,
            range,
            original_line_number,
            ..Default::default()
        })
    }

    // Returns the author time of this entry in the author's own UTC offset.
    //
    // `author_tz` is parsed as `[offset_hour][offset_minute]` (for example `+0100`) and
    // `author_time` is interpreted as a Unix timestamp. Returns an error if either value
    // is invalid.
    pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
        if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
            let format = format_description!("[offset_hour][offset_minute]");
            let offset = UtcOffset::parse(author_tz, &format)?;
            let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;

            Ok(date_time_utc.to_offset(offset))
        } else {
            // Directly return current time in UTC if there's no author time or timezone
            Ok(time::OffsetDateTime::now_utc())
        }
    }
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
/// Parses the complete output of `git blame --incremental` into blame entries.
///
/// Entries whose SHA is zero (`Oid::is_zero`) are parsed but left out of the
/// result. When a commit shows up again without its metadata, the author,
/// committer and summary fields are copied from the entry already recorded
/// for that SHA.
///
/// # Errors
///
/// Fails if an entry header line cannot be parsed, or if an `author-time` /
/// `committer-time` value is not an integer.
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
    let mut entries: Vec<BlameEntry> = Vec::new();
    // Maps a commit SHA to the slot in `entries` that holds an entry for it.
    let mut index: HashMap<Oid, usize> = HashMap::default();
    let mut current_entry: Option<BlameEntry> = None;

    for line in output.lines() {
        let mut done = false;

        match &mut current_entry {
            // No entry in progress: this line must be an entry header.
            None => {
                let mut new_entry = BlameEntry::new_from_blame_line(line)?;

                if let Some(existing_entry) = index
                    .get(&new_entry.sha)
                    .and_then(|slot| entries.get(*slot))
                {
                    new_entry.author.clone_from(&existing_entry.author);
                    new_entry
                        .author_mail
                        .clone_from(&existing_entry.author_mail);
                    new_entry.author_time = existing_entry.author_time;
                    new_entry.author_tz.clone_from(&existing_entry.author_tz);
                    new_entry
                        .committer_name
                        .clone_from(&existing_entry.committer_name);
                    new_entry
                        .committer_email
                        .clone_from(&existing_entry.committer_email);
                    new_entry.committer_time = existing_entry.committer_time;
                    new_entry
                        .committer_tz
                        .clone_from(&existing_entry.committer_tz);
                    new_entry.summary.clone_from(&existing_entry.summary);
                }

                current_entry.replace(new_entry);
            }
            // Inside an entry: every line is `<key> <value>`; lines without a
            // space carry nothing we need.
            Some(entry) => {
                let Some((key, value)) = line.split_once(' ') else {
                    continue;
                };
                let is_committed = !entry.sha.is_zero();
                match key {
                    // `filename` is always the last line of an entry.
                    "filename" => {
                        entry.filename = value.into();
                        done = true;
                    }
                    "previous" => entry.previous = Some(value.into()),

                    "summary" if is_committed => entry.summary = Some(value.into()),
                    "author" if is_committed => entry.author = Some(value.into()),
                    "author-mail" if is_committed => entry.author_mail = Some(value.into()),
                    "author-time" if is_committed => {
                        entry.author_time = Some(value.parse::<i64>()?)
                    }
                    "author-tz" if is_committed => entry.author_tz = Some(value.into()),

                    "committer" if is_committed => entry.committer_name = Some(value.into()),
                    "committer-mail" if is_committed => entry.committer_email = Some(value.into()),
                    "committer-time" if is_committed => {
                        entry.committer_time = Some(value.parse::<i64>()?)
                    }
                    "committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
                    _ => {}
                }
            }
        }

        if done {
            if let Some(entry) = current_entry.take() {
                // We only want annotations that have a commit. The index is
                // updated only for entries that are actually stored, so a slot
                // never refers to an entry of a different commit.
                if !entry.sha.is_zero() {
                    index.insert(entry.sha, entries.len());
                    entries.push(entry);
                }
            }
        }
    }

    Ok(entries)
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::BlameEntry;
    use super::parse_git_blame;

    /// Reads a raw `git blame --incremental` fixture from `test_data/`.
    fn read_test_data(filename: &str) -> String {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("test_data");
        path.push(filename);

        std::fs::read_to_string(&path)
            .unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
    }

    /// Serializes `entries` to pretty JSON and compares it with the golden file
    /// `test_data/golden/<golden_filename>.json`.
    ///
    /// When the `UPDATE_GOLDEN` environment variable is `true` (any case), the
    /// golden file is (re)written instead of compared.
    fn assert_eq_golden(entries: &[BlameEntry], golden_filename: &str) {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("test_data");
        path.push("golden");
        path.push(format!("{}.json", golden_filename));

        let mut have_json =
            serde_json::to_string_pretty(entries).expect("could not serialize entries to JSON");
        // We always want to save with a trailing newline.
        have_json.push('\n');

        let update = std::env::var("UPDATE_GOLDEN")
            .map(|val| val.eq_ignore_ascii_case("true"))
            .unwrap_or(false);

        if update {
            std::fs::create_dir_all(path.parent().unwrap())
                .expect("could not create golden test data directory");
            std::fs::write(&path, have_json).expect("could not write out golden data");
        } else {
            // Normalize line endings so a CRLF checkout still matches.
            let want_json =
                std::fs::read_to_string(&path).unwrap_or_else(|_| {
                    panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
                }).replace("\r\n", "\n");

            pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
        }
    }

    #[test]
    fn test_parse_git_blame_not_committed() {
        let output = read_test_data("blame_incremental_not_committed");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_not_committed");
    }

    #[test]
    fn test_parse_git_blame_simple() {
        let output = read_test_data("blame_incremental_simple");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_simple");
    }

    #[test]
    fn test_parse_git_blame_complex() {
        let output = read_test_data("blame_incremental_complex");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_complex");
    }
}

View file

@ -0,0 +1,374 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
/// The result of blaming a single file, as produced by `Blame::for_path`.
#[derive(Debug, Clone, Default)]
pub struct Blame {
/// Parsed blame entries; `Blame::for_path` sorts them by the start of their line range.
pub entries: Vec<BlameEntry>,
/// Values returned by `get_messages` for the commits in `entries`, keyed by commit SHA.
pub messages: HashMap<Oid, String>,
/// The remote URL handed to `Blame::for_path`, stored unchanged.
pub remote_url: Option<String>,
}
/// A commit message bundled with optional hosting-related metadata.
// NOTE(review): nothing in this file constructs this type; the field notes
// below are read off the field types only — confirm against the callers.
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
/// The commit message text.
pub message: SharedString,
/// Optional URL; presumably a permalink to the commit — TODO confirm.
pub permalink: Option<url::Url>,
/// Optional pull request associated with the commit.
pub pull_request: Option<crate::hosting_provider::PullRequest>,
/// Optional git remote associated with the commit.
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
// stderr text (exact match after trimming) that `run_git_blame` treats as "nothing to blame".
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
// stderr fragment (substring match) that `run_git_blame` treats as "nothing to blame".
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";

/// Runs `git blame --incremental --contents - <path>` in `working_directory`,
/// feeding `contents` to the process through stdin, and returns its stdout.
///
/// Returns an empty string when git fails with one of the two known
/// "nothing to blame" messages above.
///
/// # Errors
///
/// Fails if the process cannot be spawned, if writing to its stdin or reading
/// its output fails, if git exits unsuccessfully for any other reason, or if
/// stdout is not valid UTF-8.
async fn run_git_blame(
    git_binary: &Path,
    working_directory: &Path,
    path: &Path,
    contents: &Rope,
) -> Result<String> {
    let mut command = util::command::new_smol_command(git_binary);
    command
        .current_dir(working_directory)
        .arg("blame")
        .arg("--incremental")
        .arg("--contents")
        .arg("-")
        .arg(path.as_os_str())
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());
    let mut child = command
        .spawn()
        .map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;

    // Stream the buffer contents chunk by chunk into git's stdin.
    let blame_input = child
        .stdin
        .as_mut()
        .context("failed to get pipe to stdin of git blame command")?;
    for chunk in contents.chunks() {
        blame_input.write_all(chunk.as_bytes()).await?;
    }
    blame_input.flush().await?;

    let output = child
        .output()
        .await
        .map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;

    if output.status.success() {
        return Ok(String::from_utf8(output.stdout)?);
    }

    let stderr = String::from_utf8_lossy(&output.stderr);
    let message = stderr.trim();
    let nothing_to_blame =
        message == GIT_BLAME_NO_COMMIT_ERROR || message.contains(GIT_BLAME_NO_PATH);
    if nothing_to_blame {
        Ok(String::new())
    } else {
        Err(anyhow!("git blame process failed: {}", stderr))
    }
}
/// One entry of `git blame --incremental` output: a run of consecutive lines
/// attributed to a single commit.
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
/// SHA of the commit the lines are attributed to.
pub sha: Oid,
/// Zero-based, half-open range of lines covered by this entry, derived from
/// the `<resultline>` and `<num-lines>` fields of the entry header.
pub range: Range<u32>,
/// The `<sourceline>` field of the entry header, stored as parsed.
pub original_line_number: u32,
/// Value of the `author` line, if present.
pub author: Option<String>,
/// Value of the `author-mail` line, if present.
pub author_mail: Option<String>,
/// Value of the `author-time` line parsed as an integer, if present.
pub author_time: Option<i64>,
/// Value of the `author-tz` line (e.g. `+0100`), if present.
pub author_tz: Option<String>,
/// Value of the `committer` line, if present.
pub committer_name: Option<String>,
/// Value of the `committer-mail` line, if present.
pub committer_email: Option<String>,
/// Value of the `committer-time` line parsed as an integer, if present.
pub committer_time: Option<i64>,
/// Value of the `committer-tz` line, if present.
pub committer_tz: Option<String>,
/// Value of the `summary` line, if present.
pub summary: Option<String>,
/// Value of the `previous` line (`<sha> <filename>`), if present.
pub previous: Option<String>,
/// Value of the `filename` line that terminates the entry.
pub filename: String,
}
impl BlameEntry {
    /// Returns a `BlameEntry` by parsing the first line of a
    /// `git blame --incremental` entry. The line MUST have this format:
    ///
    /// ```text
    /// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
    /// ```
    ///
    /// Only `sha`, `range` and `original_line_number` are filled in; every
    /// other field keeps its default value. `range` is zero-based and
    /// half-open: it starts at `<resultline> - 1` and spans `<num-lines>` lines.
    ///
    /// # Errors
    ///
    /// Fails if any of the four fields is missing or cannot be parsed.
    fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
        let mut parts = line.split_whitespace();

        let sha = parts
            .next()
            .and_then(|part| part.parse::<Oid>().ok())
            .ok_or_else(|| anyhow!("failed to parse sha"))?;

        // The three remaining fields are all unsigned integers.
        let mut next_number = |what: &str| {
            parts
                .next()
                .and_then(|part| part.parse::<u32>().ok())
                .ok_or_else(|| anyhow!("Failed to parse {}", what))
        };
        let original_line_number = next_number("original line number")?;
        let final_line_number = next_number("final line number")?;
        let line_count = next_number("line count")?;

        // Git line numbers are 1-indexed; the stored range is 0-indexed.
        let start_line = final_line_number.saturating_sub(1);
        // Saturate instead of overflowing on absurd input.
        let end_line = start_line.saturating_add(line_count);
        let range = start_line..end_line;

        Ok(Self {
            sha,
            range,
            original_line_number,
            ..Default::default()
        })
    }

    /// Returns the author timestamp of this entry, expressed in the author's
    /// own UTC offset.
    ///
    /// When either `author_time` or `author_tz` is missing, the current time
    /// in UTC is returned instead.
    ///
    /// # Errors
    ///
    /// Fails if `author_tz` does not match `[offset_hour][offset_minute]` or
    /// if `author_time` is rejected by `OffsetDateTime::from_unix_timestamp`.
    pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
        if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
            let format = format_description!("[offset_hour][offset_minute]");
            let offset = UtcOffset::parse(author_tz, &format)?;
            let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;

            Ok(date_time_utc.to_offset(offset))
        } else {
            // Directly return current time in UTC if there's no author time or timezone
            Ok(time::OffsetDateTime::now_utc())
        }
    }
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
/// Parses the complete output of `git blame --incremental` into blame entries.
///
/// Entries whose SHA is zero (`Oid::is_zero`) are parsed but left out of the
/// result. When a commit shows up again without its metadata, the author,
/// committer and summary fields are copied from the entry already recorded
/// for that SHA.
///
/// # Errors
///
/// Fails if an entry header line cannot be parsed, or if an `author-time` /
/// `committer-time` value is not an integer.
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
    let mut entries: Vec<BlameEntry> = Vec::new();
    // Maps a commit SHA to the slot in `entries` that holds an entry for it.
    let mut index: HashMap<Oid, usize> = HashMap::default();
    let mut current_entry: Option<BlameEntry> = None;

    for line in output.lines() {
        let mut done = false;

        match &mut current_entry {
            // No entry in progress: this line must be an entry header.
            None => {
                let mut new_entry = BlameEntry::new_from_blame_line(line)?;

                if let Some(existing_entry) = index
                    .get(&new_entry.sha)
                    .and_then(|slot| entries.get(*slot))
                {
                    new_entry.author.clone_from(&existing_entry.author);
                    new_entry
                        .author_mail
                        .clone_from(&existing_entry.author_mail);
                    new_entry.author_time = existing_entry.author_time;
                    new_entry.author_tz.clone_from(&existing_entry.author_tz);
                    new_entry
                        .committer_name
                        .clone_from(&existing_entry.committer_name);
                    new_entry
                        .committer_email
                        .clone_from(&existing_entry.committer_email);
                    new_entry.committer_time = existing_entry.committer_time;
                    new_entry
                        .committer_tz
                        .clone_from(&existing_entry.committer_tz);
                    new_entry.summary.clone_from(&existing_entry.summary);
                }

                current_entry.replace(new_entry);
            }
            // Inside an entry: every line is `<key> <value>`; lines without a
            // space carry nothing we need.
            Some(entry) => {
                let Some((key, value)) = line.split_once(' ') else {
                    continue;
                };
                let is_committed = !entry.sha.is_zero();
                match key {
                    // `filename` is always the last line of an entry.
                    "filename" => {
                        entry.filename = value.into();
                        done = true;
                    }
                    "previous" => entry.previous = Some(value.into()),

                    "summary" if is_committed => entry.summary = Some(value.into()),
                    "author" if is_committed => entry.author = Some(value.into()),
                    "author-mail" if is_committed => entry.author_mail = Some(value.into()),
                    "author-time" if is_committed => {
                        entry.author_time = Some(value.parse::<i64>()?)
                    }
                    "author-tz" if is_committed => entry.author_tz = Some(value.into()),

                    "committer" if is_committed => entry.committer_name = Some(value.into()),
                    "committer-mail" if is_committed => entry.committer_email = Some(value.into()),
                    "committer-time" if is_committed => {
                        entry.committer_time = Some(value.parse::<i64>()?)
                    }
                    "committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
                    _ => {}
                }
            }
        }

        if done {
            if let Some(entry) = current_entry.take() {
                // We only want annotations that have a commit. The index is
                // updated only for entries that are actually stored, so a slot
                // never refers to an entry of a different commit.
                if !entry.sha.is_zero() {
                    index.insert(entry.sha, entries.len());
                    entries.push(entry);
                }
            }
        }
    }

    Ok(entries)
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::BlameEntry;
    use super::parse_git_blame;

    /// Reads a raw `git blame --incremental` fixture from `test_data/`.
    fn read_test_data(filename: &str) -> String {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("test_data");
        path.push(filename);

        std::fs::read_to_string(&path)
            .unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
    }

    /// Serializes `entries` to pretty JSON and compares it with the golden file
    /// `test_data/golden/<golden_filename>.json`.
    ///
    /// When the `UPDATE_GOLDEN` environment variable is `true` (any case), the
    /// golden file is (re)written instead of compared.
    fn assert_eq_golden(entries: &[BlameEntry], golden_filename: &str) {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("test_data");
        path.push("golden");
        path.push(format!("{}.json", golden_filename));

        let mut have_json =
            serde_json::to_string_pretty(entries).expect("could not serialize entries to JSON");
        // We always want to save with a trailing newline.
        have_json.push('\n');

        let update = std::env::var("UPDATE_GOLDEN")
            .map(|val| val.eq_ignore_ascii_case("true"))
            .unwrap_or(false);

        if update {
            std::fs::create_dir_all(path.parent().unwrap())
                .expect("could not create golden test data directory");
            std::fs::write(&path, have_json).expect("could not write out golden data");
        } else {
            // Normalize line endings so a CRLF checkout still matches.
            let want_json =
                std::fs::read_to_string(&path).unwrap_or_else(|_| {
                    panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
                }).replace("\r\n", "\n");

            pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
        }
    }

    #[test]
    fn test_parse_git_blame_not_committed() {
        let output = read_test_data("blame_incremental_not_committed");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_not_committed");
    }

    #[test]
    fn test_parse_git_blame_simple() {
        let output = read_test_data("blame_incremental_simple");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_simple");
    }

    #[test]
    fn test_parse_git_blame_complex() {
        let output = read_test_data("blame_incremental_complex");
        let entries = parse_git_blame(&output).unwrap();
        assert_eq_golden(&entries, "blame_incremental_complex");
    }
}

View file

@ -0,0 +1,339 @@
// font-kit/src/canvas.rs
//
// Copyright © 2018 The Pathfinder Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! An in-memory bitmap surface for glyph rasterization.
use lazy_static::lazy_static;
use pathfinder_geometry::rect::RectI;
use pathfinder_geometry::vector::Vector2I;
use std::cmp;
use std::fmt;
use crate::utils;
lazy_static! {
    /// Lookup table expanding one byte of 1bpp bitmap data into eight 8bpp
    /// values: entry `b` holds `0xff` at position `i` when bit `0x80 >> i` of
    /// `b` is set, and `0` otherwise.
    static ref BITMAP_1BPP_TO_8BPP_LUT: [[u8; 8]; 256] = {
        let mut table = [[0u8; 8]; 256];
        for (byte, expanded) in table.iter_mut().enumerate() {
            for (bit, pixel) in expanded.iter_mut().enumerate() {
                if byte & (0x80 >> bit) != 0 {
                    *pixel = 0xff;
                }
            }
        }
        table
    };
}
/// An in-memory bitmap surface for glyph rasterization.
///
/// All fields are public; `Canvas::new` and `Canvas::with_stride` create a
/// zero-filled canvas whose `pixels` buffer is `stride * size.y()` bytes long.
pub struct Canvas {
/// The raw pixel data.
pub pixels: Vec<u8>,
/// The size of the buffer, in pixels.
pub size: Vector2I,
/// The number of *bytes* between successive rows.
pub stride: usize,
/// The image format of the canvas.
pub format: Format,
}
impl Canvas {
/// Creates a new blank canvas with the given pixel size and format.
///
/// Stride is automatically calculated from width.
///
/// The canvas is initialized with transparent black (all values 0).
#[inline]
pub fn new(size: Vector2I, format: Format) -> Canvas {
Canvas::with_stride(
size,
// Tightly packed rows: width in pixels times bytes per pixel.
size.x() as usize * format.bytes_per_pixel() as usize,
format,
)
}
/// Creates a new blank canvas with the given pixel size, stride (number of bytes between
/// successive rows), and format.
///
/// The canvas is initialized with transparent black (all values 0).
pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
Canvas {
pixels: vec![0; stride * size.y() as usize],
size,
stride,
format,
}
}
/// Blits the whole of `src` onto this canvas with its origin at (0, 0),
/// by forwarding to `blit_from`.
#[allow(dead_code)]
pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
self.blit_from(
Vector2I::default(),
&src.pixels,
src.size,
src.stride,
src.format,
)
}
/// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
/// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
/// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
/// `src_stride` must be at least `src_size.x()` times the bytes per pixel of `src_format`,
/// and `src_bytes.len()` must equal `src_stride * src_size.y()`.
///
/// # Panics
///
/// Panics if either of the two conditions above is violated, and (via `unimplemented!`)
/// when converting between `Format::Rgba32` and `Format::A8` in either direction.
#[allow(dead_code)]
pub(crate) fn blit_from(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
src_format: Format,
) {
assert_eq!(
src_stride * src_size.y() as usize,
src_bytes.len(),
"Number of pixels in src_bytes does not match stride and size."
);
assert!(
src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
"src_stride must be >= than src_size.x()"
);
// Clip the destination rectangle to the canvas; nothing to do if they do not intersect.
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
// Pick the per-row conversion from the (destination, source) format pair.
match (self.format, src_format) {
(Format::A8, Format::A8)
| (Format::Rgb24, Format::Rgb24)
| (Format::Rgba32, Format::Rgba32) => {
self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::A8, Format::Rgb24) => {
self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::A8) => {
self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::Rgba32) => self
.blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::Rgb24) => self
.blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
}
}
/// Blits a 1-bit-per-pixel bitmap onto this canvas, expanding every source bit into one
/// destination byte (`0xff` for a set bit, `0` for a clear bit) via `BITMAP_1BPP_TO_8BPP_LUT`.
///
/// `dst_point` and `src_size` are in pixels, `src_stride` is in bytes. The target rectangle
/// is clipped to the canvas; nothing is written if it lies entirely outside.
///
/// # Panics
///
/// Panics (via `unimplemented!`) unless the canvas format is `Format::A8`.
#[allow(dead_code)]
pub(crate) fn blit_from_bitmap_1bpp(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
) {
if self.format != Format::A8 {
unimplemented!()
}
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
let size = dst_rect.size();
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
// Bytes written per destination row, and bytes read per source row (8 pixels per byte).
let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
let src_row_stride = utils::div_round_up(size.x() as usize, 8);
for y in 0..size.y() {
// NOTE(review): source rows/columns are always read starting at row 0, byte 0, even
// when clipping moved the destination origin (e.g. a negative `dst_point`) — confirm
// callers never rely on the clipped-away part being skipped.
let (dest_row_start, src_row_start) = (
(y + dst_rect.origin_y()) as usize * self.stride
+ dst_rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + dest_row_stride;
let src_row_end = src_row_start + src_row_stride;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
for x in 0..src_row_stride {
let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
let dest_start = x * 8;
// The last source byte of a row may cover fewer than 8 destination pixels.
let dest_end = cmp::min(dest_start + 8, dest_row_stride);
let src = &pattern[0..(dest_end - dest_start)];
dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
}
}
}
/// Blits to area `rect` using the data given in the buffer `src_bytes`.
/// `src_stride` must be specified in bytes.
/// The dimensions of `rect` must be in pixels.
///
/// `B` converts one source row slice into one destination row slice; `rect` is expected
/// to be already clipped to the canvas, since the row slices are taken by plain indexing.
fn blit_from_with<B: Blit>(
&mut self,
rect: RectI,
src_bytes: &[u8],
src_stride: usize,
src_format: Format,
) {
let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
for y in 0..rect.height() {
// Byte offsets of the first pixel of this row in the destination and the source.
let (dest_row_start, src_row_start) = (
(y + rect.origin_y()) as usize * self.stride
+ rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
B::blit(dest_row_pixels, src_row_pixels)
}
}
}
impl fmt::Debug for Canvas {
    /// Formats the canvas as a struct, printing the length of `pixels` in
    /// place of the pixel data itself.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Do not dump a vector content.
        let pixel_bytes = self.pixels.len();

        let mut out = f.debug_struct("Canvas");
        out.field("pixels", &pixel_bytes);
        out.field("size", &self.size);
        out.field("stride", &self.stride);
        out.field("format", &self.format);
        out.finish()
    }
}
/// The image format for the canvas.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
    /// Premultiplied R8G8B8A8, little-endian.
    Rgba32,
    /// R8G8B8, little-endian.
    Rgb24,
    /// A8.
    A8,
}

impl Format {
    /// Number of bits one pixel occupies in this format.
    #[inline]
    pub fn bits_per_pixel(self) -> u8 {
        match self {
            Self::A8 => 8,
            Self::Rgb24 => 24,
            Self::Rgba32 => 32,
        }
    }

    /// Number of color channels one pixel has in this format.
    #[inline]
    pub fn components_per_pixel(self) -> u8 {
        match self {
            Self::A8 => 1,
            Self::Rgb24 => 3,
            Self::Rgba32 => 4,
        }
    }

    /// Number of bits each color channel occupies in this format.
    #[inline]
    pub fn bits_per_component(self) -> u8 {
        let bits = self.bits_per_pixel();
        let components = self.components_per_pixel();
        bits / components
    }

    /// Number of bytes one pixel occupies in this format.
    #[inline]
    pub fn bytes_per_pixel(self) -> u8 {
        const BITS_PER_BYTE: u8 = 8;
        self.bits_per_pixel() / BITS_PER_BYTE
    }
}
/// The antialiasing strategy that should be used when rasterizing glyphs.
// NOTE(review): no code in this file reads this enum; presumably it is consumed
// by the rasterization backends elsewhere in the crate — confirm.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RasterizationOptions {
/// "Black-and-white" rendering. Each pixel is either entirely on or off.
Bilevel,
/// Grayscale antialiasing. Only one channel is used.
GrayscaleAa,
/// Subpixel RGB antialiasing, for LCD screens.
SubpixelAa,
}
/// Copies one row of pixels from `src` into `dest`, converting between pixel
/// formats where needed.
///
/// Implementations are stateless marker types selected at compile time.
/// Except for [`BlitMemcpy`], a trailing partial pixel in either slice is
/// ignored and the corresponding destination bytes are left untouched.
trait Blit {
    fn blit(dest: &mut [u8], src: &[u8]);
}

/// Same-format copy.
///
/// Panics if `dest` and `src` differ in length.
struct BlitMemcpy;

impl Blit for BlitMemcpy {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        // `u8` is `Copy`, so this is a plain memcpy.
        dest.copy_from_slice(src)
    }
}

/// RGB24 to A8: each destination byte takes the green channel of its source pixel.
struct BlitRgb24ToA8;

impl Blit for BlitRgb24ToA8 {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        // TODO(pcwalton): SIMD.
        for (dest, src) in dest.iter_mut().zip(src.chunks_exact(3)) {
            *dest = src[1]
        }
    }
}

/// A8 to RGB24: the source byte is replicated into all three channels.
struct BlitA8ToRgb24;

impl Blit for BlitA8ToRgb24 {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        for (dest, src) in dest.chunks_exact_mut(3).zip(src.iter()) {
            dest[0] = *src;
            dest[1] = *src;
            dest[2] = *src;
        }
    }
}

/// RGBA32 to RGB24: the fourth (alpha) byte of every source pixel is dropped.
struct BlitRgba32ToRgb24;

impl Blit for BlitRgba32ToRgb24 {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        // TODO(pcwalton): SIMD.
        for (dest, src) in dest.chunks_exact_mut(3).zip(src.chunks_exact(4)) {
            dest.copy_from_slice(&src[..3])
        }
    }
}

/// RGB24 to RGBA32: the three color bytes are copied and alpha is set to 255.
struct BlitRgb24ToRgba32;

impl Blit for BlitRgb24ToRgba32 {
    #[inline]
    fn blit(dest: &mut [u8], src: &[u8]) {
        for (dest, src) in dest.chunks_exact_mut(4).zip(src.chunks_exact(3)) {
            dest[..3].copy_from_slice(src);
            dest[3] = 255;
        }
    }
}

View file

@ -282,7 +282,7 @@ pub struct EditFileToolCard {
}
impl EditFileToolCard {
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
pub fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
let editor = cx.new(|cx| {
let mut editor = Editor::new(
@ -323,7 +323,7 @@ impl EditFileToolCard {
}
}
fn set_diff(
pub fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
@ -343,6 +343,7 @@ impl EditFileToolCard {
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
multibuffer.clear(cx);
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,

View file

@ -0,0 +1,339 @@
use crate::{
Templates,
edit_agent::{EditAgent, EditAgentOutputEvent},
edit_file_tool::EditFileToolCard,
schema::json_schema_for,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
/// Tool exposed under the name `edit_file` that edits a project file by
/// driving an `EditAgent` and streaming its edits into the buffer.
pub struct StreamingEditFileTool;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be
    /// shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this
    /// edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object
    /// so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to modify in the project.
    ///
    /// WARNING: When specifying which file path needs changing, you MUST
    /// start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - backend
    /// - frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path
    /// would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
}
// Lenient view of the tool input used while the arguments are still streaming
// in: both fields default to an empty string when absent, so a partially
// received object still deserializes. Read only by `still_streaming_ui_text`.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
// Path received so far; empty when not yet present.
#[serde(default)]
path: String,
// Description received so far; empty when not yet present.
#[serde(default)]
display_description: String,
}
/// Label returned by `still_streaming_ui_text` when neither a description nor a path is available yet.
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for StreamingEditFileTool {
    /// The tool is exposed to the model as `edit_file`.
    fn name(&self) -> String {
        "edit_file".into()
    }

    /// This tool never asks for confirmation.
    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
        false
    }

    /// Returns the tool description embedded from `description.md`.
    fn description(&self) -> String {
        include_str!("streaming_edit_file_tool/description.md").to_string()
    }

    fn icon(&self) -> IconName {
        IconName::Pencil
    }

    /// Returns the JSON schema generated from `StreamingEditFileToolInput`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
        json_schema_for::<StreamingEditFileToolInput>(format)
    }

    /// Returns the edit's `display_description`, or the default label when the
    /// input does not deserialize.
    fn ui_text(&self, input: &serde_json::Value) -> String {
        match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
            Ok(input) => input.display_description,
            Err(_) => DEFAULT_UI_TEXT.to_string(),
        }
    }

    /// Returns the label to show while the input is still streaming: the
    /// trimmed description if non-empty, otherwise the trimmed path if
    /// non-empty, otherwise the default label.
    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
        if let Ok(input) = serde_json::from_value::<PartialInput>(input.clone()) {
            let description = input.display_description.trim();
            if !description.is_empty() {
                return description.to_string();
            }

            let path = input.path.trim();
            if !path.is_empty() {
                return path.to_string();
            }
        }

        DEFAULT_UI_TEXT.to_string()
    }

    /// Performs the edit described by `input`.
    ///
    /// Resolves the path inside the project, opens its buffer, lets an
    /// `EditAgent` (using the registry's default model) stream edits into it,
    /// saves the buffer and reports a unified diff of the change. If a window
    /// is available, an `EditFileToolCard` is created and refreshed with the
    /// diff after every edit event.
    ///
    /// The returned task fails if the input is malformed, the path or its
    /// worktree cannot be found, the file does not exist, no default model is
    /// set, or nothing could be applied because the model produced old text
    /// that does not occur in the file.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        project: Entity<Project>,
        action_log: Entity<ActionLog>,
        window: Option<AnyWindowHandle>,
        cx: &mut App,
    ) -> ToolResult {
        let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
            Ok(input) => input,
            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
        };

        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
            return Task::ready(Err(anyhow!(
                "Path {} not found in project",
                input.path.display()
            )))
            .into();
        };
        let Some(worktree) = project
            .read(cx)
            .worktree_for_id(project_path.worktree_id, cx)
        else {
            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
        };
        // Kick off the existence check now; it is awaited inside the task below.
        let exists = worktree.update(cx, |worktree, cx| {
            worktree.file_exists(&project_path.path, cx)
        });

        // The card can only be created when there is a window to host it.
        let card = window.and_then(|window| {
            window
                .update(cx, |_, window, cx| {
                    cx.new(|cx| {
                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
                    })
                })
                .ok()
        });

        let card_clone = card.clone();
        let messages = messages.to_vec();
        let task = cx.spawn(async move |cx: &mut AsyncApp| {
            if !exists.await? {
                return Err(anyhow!("{} not found", input.path.display()));
            }

            let model = cx
                .update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
                .context("default model not set")?
                .model;
            let edit_agent = EditAgent::new(model, action_log, Templates::new());

            let buffer = project
                .update(cx, |project, cx| {
                    project.open_buffer(project_path.clone(), cx)
                })?
                .await?;

            // Snapshot the text before editing so diffs can be computed later.
            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
            let old_text = cx
                .background_spawn({
                    let old_snapshot = old_snapshot.clone();
                    async move { old_snapshot.text() }
                })
                .await;

            let (output, mut events) = edit_agent.edit(
                buffer.clone(),
                input.display_description.clone(),
                messages,
                cx,
            );

            let mut hallucinated_old_text = false;
            while let Some(event) = events.next().await {
                match event {
                    // Refresh the card's diff after every applied edit.
                    EditAgentOutputEvent::Edited => {
                        if let Some(card) = card_clone.as_ref() {
                            let new_snapshot =
                                buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
                            let new_text = cx
                                .background_spawn({
                                    let new_snapshot = new_snapshot.clone();
                                    async move { new_snapshot.text() }
                                })
                                .await;
                            card.update(cx, |card, cx| {
                                card.set_diff(
                                    project_path.path.clone(),
                                    old_text.clone(),
                                    new_text,
                                    cx,
                                );
                            })
                            .log_err();
                        }
                    }
                    EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
                }
            }
            output.await?;

            project
                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
                .await?;

            // Compute the final text and the unified diff concurrently in the background.
            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
            let new_text = cx.background_spawn({
                let new_snapshot = new_snapshot.clone();
                async move { new_snapshot.text() }
            });
            let diff = cx.background_spawn(async move {
                language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
            });
            let (new_text, diff) = futures::join!(new_text, diff);

            if let Some(card) = card_clone {
                card.update(cx, |card, cx| {
                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
                })
                .log_err();
            }

            let input_path = input.path.display();
            if diff.is_empty() {
                // An empty diff is only an error when the model tried to edit
                // text that is not in the file.
                if hallucinated_old_text {
                    Err(anyhow!(formatdoc! {"
                        Some edits were produced but none of them could be applied.
                        Read the relevant sections of {input_path} again so that
                        I can perform the requested edits.
                    "}))
                } else {
                    Ok("No edits were made.".to_string())
                }
            } else {
                Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
            }
        });

        ToolResult {
            output: task,
            card: card.map(AnyToolCard::from),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a tool input with the given path and description and returns
    /// the label the tool shows while that input is still streaming.
    fn streaming_label(path: &str, display_description: &str) -> String {
        let input = json!({
            "path": path,
            "display_description": display_description,
            "old_string": "old code",
            "new_string": "new code"
        });
        StreamingEditFileTool.still_streaming_ui_text(&input)
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        assert_eq!(streaming_label("src/main.rs", ""), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        assert_eq!(
            streaming_label("", "Fix error handling"),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        // The description wins over the path when both are present.
        assert_eq!(
            streaming_label("src/main.rs", "Fix error handling"),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        assert_eq!(streaming_label("", ""), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let null_input = serde_json::Value::Null;
        let label = StreamingEditFileTool.still_streaming_ui_text(&null_input);
        assert_eq!(label, DEFAULT_UI_TEXT);
    }
}

View file

@ -0,0 +1,8 @@
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
Before using this tool:
1. Use the `read_file` tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the `list_directory` tool to verify the parent directory exists and is the correct location

View file

@ -0,0 +1,32 @@
use anyhow::Result;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
use std::sync::Arc;
/// Embedded template assets: per the attributes below, the `*.hbs` files
/// found in the `src/templates` folder.
#[derive(RustEmbed)]
#[folder = "src/templates"]
#[include = "*.hbs"]
struct Assets;
/// Handlebars registry holding the templates embedded in `Assets`.
pub struct Templates(Handlebars<'static>);

impl Templates {
    /// Creates a shared registry with every embedded template registered.
    ///
    /// The escape function is replaced with one that returns its input
    /// unchanged, so rendered values are inserted verbatim.
    ///
    /// # Panics
    ///
    /// Panics if registering the embedded templates fails.
    pub fn new() -> Arc<Self> {
        let mut registry = Handlebars::new();
        registry.register_embed_templates::<Assets>().unwrap();
        registry.register_escape_fn(|raw| raw.to_owned());
        Arc::new(Templates(registry))
    }
}
/// A value that can be rendered through a named template in `Templates`.
pub trait Template: Sized {
    /// Name under which this type's template is looked up in the registry.
    const TEMPLATE_NAME: &'static str;

    /// Renders `self` using the template named `TEMPLATE_NAME`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the underlying registry's `render`.
    fn render(&self, templates: &Templates) -> Result<String>
    where
        Self: Serialize + Sized,
    {
        let registry = &templates.0;
        let rendered = registry.render(Self::TEMPLATE_NAME, self)?;
        Ok(rendered)
    }
}

View file

@ -0,0 +1,23 @@
You are an expert coder, and have been tasked with looking at the following diff:
<diff>
{{diff}}
</diff>
Evaluate the following assertions:
<assertions>
{{assertions}}
</assertions>
You must respond with a short analysis and a score between 0 and 100, where:
- 0 means no assertions pass
- 100 means all the assertions pass perfectly
<analysis>
- Assertion 1: one line describing why the first assertion passes or fails (even partially)
- Assertion 2: one line describing why the second assertion passes or fails (even partially)
- ...
- Assertion N: one line describing why the Nth assertion passes or fails (even partially)
</analysis>
<score>YOUR FINAL SCORE HERE</score>

View file

@ -0,0 +1,49 @@
You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make.
You MUST respond with a series of edits to that one file in the following format:
```
<edits>
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
<old_text>
OLD TEXT 2 HERE
</old_text>
<new_text>
NEW TEXT 2 HERE
</new_text>
<old_text>
OLD TEXT 3 HERE
</old_text>
<new_text>
NEW TEXT 3 HERE
</new_text>
</edits>
```
Rules for editing:
- `old_text` represents lines in the input file that will be replaced with `new_text`. `old_text` MUST exactly match the existing file content, character for character, including indentation.
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
- If you want to replace many occurrences of the same text, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
- When reporting multiple edits, each edit assumes the previous one has already been applied! Therefore, you must ensure `old_text` doesn't reference text that has already been modified by a previous edit.
- Don't explain the edits, just report them.
- Only edit the file specified in `<file_to_edit>` and NEVER include edits to other files!
- If you open an <old_text> tag, you MUST close it using </old_text>
- If you open a <new_text> tag, you MUST close it using </new_text>
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>

View file

@ -48,6 +48,7 @@ markdown.workspace = true
node_runtime.workspace = true
pathdiff.workspace = true
paths.workspace = true
pretty_assertions.workspace = true
project.workspace = true
prompt_store.workspace = true
regex.workspace = true

View file

@ -1,6 +1,7 @@
{
"assistant": {
"always_allow_tool_actions": true,
"stream_edits": true,
"version": "2"
}
}

View file

@ -420,12 +420,12 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
assistant_tools::init(client.http_client(), cx);
context_server::init(cx);
prompt_store::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
assistant_tools::init(client.http_client(), cx);
SettingsStore::update_global(cx, |store, cx| {
store.set_user_settings(include_str!("../runner_settings.json"), cx)

View file

@ -160,7 +160,11 @@ impl ExampleContext {
if left == right {
Ok(())
} else {
println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
println!(
"{}{}",
self.log_prefix,
pretty_assertions::Comparison::new(&left, &right)
);
Err(anyhow::Error::from(FailedAssertion(message.clone())))
},
message,
@ -334,8 +338,8 @@ impl ExampleContext {
}
pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
self.app
.read_entity(&self.agent_thread, |thread, cx| {
self.agent_thread
.read_with(&self.app, |thread, cx| {
let action_log = thread.action_log().read(cx);
HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
|(buffer, diff)| {
@ -503,16 +507,16 @@ impl ToolUse {
}
}
#[derive(Debug)]
#[derive(Debug, Eq, PartialEq)]
pub struct FileEdits {
hunks: Vec<FileEditHunk>,
pub hunks: Vec<FileEditHunk>,
}
#[derive(Debug)]
struct FileEditHunk {
base_text: String,
text: String,
status: DiffHunkStatus,
#[derive(Debug, Eq, PartialEq)]
pub struct FileEditHunk {
pub base_text: String,
pub text: String,
pub status: DiffHunkStatus,
}
impl FileEdits {

View file

@ -59,6 +59,12 @@ impl FeatureFlag for Assistant2FeatureFlag {
const NAME: &'static str = "assistant2";
}
/// Feature flag whose `NAME` is `agent-stream-edits`.
///
/// Per the change that introduced it, users holding this flag get the
/// streaming edit tool enabled by default.
pub struct AgentStreamEditsFeatureFlag;
impl FeatureFlag for AgentStreamEditsFeatureFlag {
const NAME: &'static str = "agent-stream-edits";
}
pub struct NewBillingFeatureFlag;
impl FeatureFlag for NewBillingFeatureFlag {

View file

@ -78,7 +78,7 @@ pub(crate) use test::*;
pub(crate) use windows::*;
#[cfg(any(test, feature = "test-support"))]
pub use test::TestScreenCaptureSource;
pub use test::{TestDispatcher, TestScreenCaptureSource};
/// Returns a background executor for the current platform.
pub fn background_executor() -> BackgroundExecutor {

View file

@ -3,7 +3,7 @@ mod display;
mod platform;
mod window;
pub(crate) use dispatcher::*;
pub use dispatcher::*;
pub(crate) use display::*;
pub(crate) use platform::*;
pub(crate) use window::*;

View file

@ -59,9 +59,9 @@ use text::operation_queue::OperationQueue;
use text::*;
pub use text::{
Anchor, Bias, Buffer as TextBuffer, BufferId, BufferSnapshot as TextBufferSnapshot, Edit,
OffsetRangeExt, OffsetUtf16, Patch, Point, PointUtf16, Rope, Selection, SelectionGoal,
Subscription, TextDimension, TextSummary, ToOffset, ToOffsetUtf16, ToPoint, ToPointUtf16,
Transaction, TransactionId, Unclipped,
LineIndent, OffsetRangeExt, OffsetUtf16, Patch, Point, PointUtf16, Rope, Selection,
SelectionGoal, Subscription, TextDimension, TextSummary, ToOffset, ToOffsetUtf16, ToPoint,
ToPointUtf16, Transaction, TransactionId, Unclipped,
};
use theme::{ActiveTheme as _, SyntaxTheme};
#[cfg(any(test, feature = "test-support"))]

View file

@ -7,6 +7,7 @@ use std::{
ops::Range,
};
#[derive(Default)]
struct Matrix {
cells: Vec<f64>,
rows: usize,
@ -95,6 +96,7 @@ pub enum CharOperation {
Keep { bytes: usize },
}
#[derive(Default)]
pub struct StreamingDiff {
old: Vec<char>,
new: Vec<char>,

View file

@ -60,6 +60,7 @@ futures-sink = { version = "0.3" }
futures-task = { version = "0.3", default-features = false, features = ["std"] }
futures-util = { version = "0.3", features = ["channel", "io-compat", "sink"] }
getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["std"] }
handlebars = { version = "4", features = ["rust-embed"] }
hashbrown-3575ec1268b04181 = { package = "hashbrown", version = "0.15", features = ["serde"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
hmac = { version = "0.12", default-features = false, features = ["reset"] }
@ -170,6 +171,7 @@ futures-sink = { version = "0.3" }
futures-task = { version = "0.3", default-features = false, features = ["std"] }
futures-util = { version = "0.3", features = ["channel", "io-compat", "sink"] }
getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["std"] }
handlebars = { version = "4", features = ["rust-embed"] }
hashbrown-3575ec1268b04181 = { package = "hashbrown", version = "0.15", features = ["serde"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
heck = { version = "0.4", features = ["unicode"] }