Allow StreamingEditFileTool
to also create files (#29785)
Refs #29733 This pull request adds a new field (`create_or_overwrite`) to the `StreamingEditFileTool` input that lets the model create or overwrite a file in a streaming way. When either the `assistant.stream_edits` setting or the `agent-stream-edits` feature flag is enabled, the `CreateFileTool` is disabled so that the agent model can only use `StreamingEditFileTool` for file creation. Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com> Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
parent
f619d5f02a
commit
35539847a4
11 changed files with 2914 additions and 140 deletions
|
@ -77,7 +77,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||||
registry.register_tool(TerminalTool);
|
registry.register_tool(TerminalTool);
|
||||||
registry.register_tool(BatchTool);
|
registry.register_tool(BatchTool);
|
||||||
registry.register_tool(CreateDirectoryTool);
|
registry.register_tool(CreateDirectoryTool);
|
||||||
registry.register_tool(CreateFileTool);
|
|
||||||
registry.register_tool(CopyPathTool);
|
registry.register_tool(CopyPathTool);
|
||||||
registry.register_tool(DeletePathTool);
|
registry.register_tool(DeletePathTool);
|
||||||
registry.register_tool(SymbolInfoTool);
|
registry.register_tool(SymbolInfoTool);
|
||||||
|
@ -125,12 +124,14 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||||
fn register_edit_file_tool(cx: &mut App) {
|
fn register_edit_file_tool(cx: &mut App) {
|
||||||
let registry = ToolRegistry::global(cx);
|
let registry = ToolRegistry::global(cx);
|
||||||
|
|
||||||
|
registry.unregister_tool(CreateFileTool);
|
||||||
registry.unregister_tool(EditFileTool);
|
registry.unregister_tool(EditFileTool);
|
||||||
registry.unregister_tool(StreamingEditFileTool);
|
registry.unregister_tool(StreamingEditFileTool);
|
||||||
|
|
||||||
if AssistantSettings::get_global(cx).stream_edits(cx) {
|
if AssistantSettings::get_global(cx).stream_edits(cx) {
|
||||||
registry.register_tool(StreamingEditFileTool);
|
registry.register_tool(StreamingEditFileTool);
|
||||||
} else {
|
} else {
|
||||||
|
registry.register_tool(CreateFileTool);
|
||||||
registry.register_tool(EditFileTool);
|
registry.register_tool(EditFileTool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
|
||||||
use futures::{
|
use futures::{
|
||||||
Stream, StreamExt,
|
Stream, StreamExt,
|
||||||
channel::mpsc::{self, UnboundedReceiver},
|
channel::mpsc::{self, UnboundedReceiver},
|
||||||
|
pin_mut,
|
||||||
stream::BoxStream,
|
stream::BoxStream,
|
||||||
};
|
};
|
||||||
use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
|
use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
|
||||||
|
@ -23,19 +24,29 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
|
||||||
use streaming_diff::{CharOperation, StreamingDiff};
|
use streaming_diff::{CharOperation, StreamingDiff};
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub struct EditAgentTemplate {
|
struct CreateFilePromptTemplate {
|
||||||
path: Option<PathBuf>,
|
path: Option<PathBuf>,
|
||||||
edit_description: String,
|
edit_description: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Template for EditAgentTemplate {
|
impl Template for CreateFilePromptTemplate {
|
||||||
const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
|
const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct EditFilePromptTemplate {
|
||||||
|
path: Option<PathBuf>,
|
||||||
|
edit_description: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Template for EditFilePromptTemplate {
|
||||||
|
const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||||
pub enum EditAgentOutputEvent {
|
pub enum EditAgentOutputEvent {
|
||||||
Edited,
|
Edited,
|
||||||
HallucinatedOldText(SharedString),
|
OldTextNotFound(SharedString),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -64,6 +75,82 @@ impl EditAgent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn overwrite(
|
||||||
|
&self,
|
||||||
|
buffer: Entity<Buffer>,
|
||||||
|
edit_description: String,
|
||||||
|
previous_messages: Vec<LanguageModelRequestMessage>,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> (
|
||||||
|
Task<Result<EditAgentOutput>>,
|
||||||
|
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
|
||||||
|
) {
|
||||||
|
let this = self.clone();
|
||||||
|
let (events_tx, events_rx) = mpsc::unbounded();
|
||||||
|
let output = cx.spawn(async move |cx| {
|
||||||
|
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||||
|
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||||
|
let prompt = CreateFilePromptTemplate {
|
||||||
|
path,
|
||||||
|
edit_description,
|
||||||
|
}
|
||||||
|
.render(&this.templates)?;
|
||||||
|
let new_chunks = this.request(previous_messages, prompt, cx).await?;
|
||||||
|
|
||||||
|
let (output, mut inner_events) = this.replace_text_with_chunks(buffer, new_chunks, cx);
|
||||||
|
while let Some(event) = inner_events.next().await {
|
||||||
|
events_tx.unbounded_send(event).ok();
|
||||||
|
}
|
||||||
|
output.await
|
||||||
|
});
|
||||||
|
(output, events_rx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn replace_text_with_chunks(
|
||||||
|
&self,
|
||||||
|
buffer: Entity<Buffer>,
|
||||||
|
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> (
|
||||||
|
Task<Result<EditAgentOutput>>,
|
||||||
|
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
|
||||||
|
) {
|
||||||
|
let (output_events_tx, output_events_rx) = mpsc::unbounded();
|
||||||
|
let this = self.clone();
|
||||||
|
let task = cx.spawn(async move |cx| {
|
||||||
|
// Ensure the buffer is tracked by the action log.
|
||||||
|
this.action_log
|
||||||
|
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
|
||||||
|
|
||||||
|
cx.update(|cx| {
|
||||||
|
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
|
||||||
|
this.action_log
|
||||||
|
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut raw_edits = String::new();
|
||||||
|
pin_mut!(edit_chunks);
|
||||||
|
while let Some(chunk) = edit_chunks.next().await {
|
||||||
|
let chunk = chunk?;
|
||||||
|
raw_edits.push_str(&chunk);
|
||||||
|
cx.update(|cx| {
|
||||||
|
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
|
||||||
|
this.action_log
|
||||||
|
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||||
|
})?;
|
||||||
|
output_events_tx
|
||||||
|
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(EditAgentOutput {
|
||||||
|
_raw_edits: raw_edits,
|
||||||
|
_parser_metrics: EditParserMetrics::default(),
|
||||||
|
})
|
||||||
|
});
|
||||||
|
(task, output_events_rx)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn edit(
|
pub fn edit(
|
||||||
&self,
|
&self,
|
||||||
buffer: Entity<Buffer>,
|
buffer: Entity<Buffer>,
|
||||||
|
@ -78,10 +165,15 @@ impl EditAgent {
|
||||||
let (events_tx, events_rx) = mpsc::unbounded();
|
let (events_tx, events_rx) = mpsc::unbounded();
|
||||||
let output = cx.spawn(async move |cx| {
|
let output = cx.spawn(async move |cx| {
|
||||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||||
let edit_chunks = this
|
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||||
.request_edits(snapshot, edit_description, previous_messages, cx)
|
let prompt = EditFilePromptTemplate {
|
||||||
.await?;
|
path,
|
||||||
let (output, mut inner_events) = this.apply_edits(buffer, edit_chunks, cx);
|
edit_description,
|
||||||
|
}
|
||||||
|
.render(&this.templates)?;
|
||||||
|
let edit_chunks = this.request(previous_messages, prompt, cx).await?;
|
||||||
|
|
||||||
|
let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
|
||||||
while let Some(event) = inner_events.next().await {
|
while let Some(event) = inner_events.next().await {
|
||||||
events_tx.unbounded_send(event).ok();
|
events_tx.unbounded_send(event).ok();
|
||||||
}
|
}
|
||||||
|
@ -90,7 +182,7 @@ impl EditAgent {
|
||||||
(output, events_rx)
|
(output, events_rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_edits(
|
fn apply_edit_chunks(
|
||||||
&self,
|
&self,
|
||||||
buffer: Entity<Buffer>,
|
buffer: Entity<Buffer>,
|
||||||
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
|
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
|
||||||
|
@ -138,7 +230,7 @@ impl EditAgent {
|
||||||
let Some(old_range) = old_range else {
|
let Some(old_range) = old_range else {
|
||||||
// We couldn't find the old text in the buffer. Report the error.
|
// We couldn't find the old text in the buffer. Report the error.
|
||||||
output_events
|
output_events
|
||||||
.unbounded_send(EditAgentOutputEvent::HallucinatedOldText(old_text_query))
|
.unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
|
||||||
.ok();
|
.ok();
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
@ -232,7 +324,7 @@ impl EditAgent {
|
||||||
) {
|
) {
|
||||||
let (tx, rx) = mpsc::unbounded();
|
let (tx, rx) = mpsc::unbounded();
|
||||||
let output = cx.background_spawn(async move {
|
let output = cx.background_spawn(async move {
|
||||||
futures::pin_mut!(chunks);
|
pin_mut!(chunks);
|
||||||
|
|
||||||
let mut parser = EditParser::new();
|
let mut parser = EditParser::new();
|
||||||
let mut raw_edits = String::new();
|
let mut raw_edits = String::new();
|
||||||
|
@ -336,20 +428,12 @@ impl EditAgent {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn request_edits(
|
async fn request(
|
||||||
&self,
|
&self,
|
||||||
snapshot: BufferSnapshot,
|
|
||||||
edit_description: String,
|
|
||||||
mut messages: Vec<LanguageModelRequestMessage>,
|
mut messages: Vec<LanguageModelRequestMessage>,
|
||||||
|
prompt: String,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
|
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
|
||||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
|
||||||
let prompt = EditAgentTemplate {
|
|
||||||
path,
|
|
||||||
edit_description,
|
|
||||||
}
|
|
||||||
.render(&self.templates)?;
|
|
||||||
|
|
||||||
let mut message_content = Vec::new();
|
let mut message_content = Vec::new();
|
||||||
if let Some(last_message) = messages.last_mut() {
|
if let Some(last_message) = messages.last_mut() {
|
||||||
if last_message.role == Role::Assistant {
|
if last_message.role == Role::Assistant {
|
||||||
|
@ -611,7 +695,8 @@ mod tests {
|
||||||
&mut rng,
|
&mut rng,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
let (apply, _events) =
|
||||||
|
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||||
apply.await.unwrap();
|
apply.await.unwrap();
|
||||||
pretty_assertions::assert_eq!(
|
pretty_assertions::assert_eq!(
|
||||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||||
|
@ -648,7 +733,8 @@ mod tests {
|
||||||
&mut rng,
|
&mut rng,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
let (apply, _events) =
|
||||||
|
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||||
apply.await.unwrap();
|
apply.await.unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||||
|
@ -679,7 +765,8 @@ mod tests {
|
||||||
&mut rng,
|
&mut rng,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
let (apply, _events) =
|
||||||
|
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||||
apply.await.unwrap();
|
apply.await.unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||||
|
@ -692,7 +779,7 @@ mod tests {
|
||||||
let agent = init_test(cx).await;
|
let agent = init_test(cx).await;
|
||||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
|
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
|
||||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||||
let (apply, mut events) = agent.apply_edits(
|
let (apply, mut events) = agent.apply_edit_chunks(
|
||||||
buffer.clone(),
|
buffer.clone(),
|
||||||
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
|
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
|
||||||
&mut cx.to_async(),
|
&mut cx.to_async(),
|
||||||
|
@ -744,7 +831,7 @@ mod tests {
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
drain_events(&mut events),
|
drain_events(&mut events),
|
||||||
vec![EditAgentOutputEvent::HallucinatedOldText(
|
vec![EditAgentOutputEvent::OldTextNotFound(
|
||||||
"hallucinated old".into()
|
"hallucinated old".into()
|
||||||
)]
|
)]
|
||||||
);
|
);
|
||||||
|
|
|
@ -4,10 +4,11 @@ use crate::{
|
||||||
streaming_edit_file_tool::StreamingEditFileToolInput,
|
streaming_edit_file_tool::StreamingEditFileToolInput,
|
||||||
};
|
};
|
||||||
use Role::*;
|
use Role::*;
|
||||||
use anyhow::{Context, anyhow};
|
use anyhow::anyhow;
|
||||||
use client::{Client, UserStore};
|
use client::{Client, UserStore};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use fs::FakeFs;
|
use fs::FakeFs;
|
||||||
|
use futures::{FutureExt, future::LocalBoxFuture};
|
||||||
use gpui::{AppContext, TestAppContext};
|
use gpui::{AppContext, TestAppContext};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
|
@ -71,14 +72,15 @@ fn eval_extract_handle_command_output() {
|
||||||
StreamingEditFileToolInput {
|
StreamingEditFileToolInput {
|
||||||
display_description: edit_description.into(),
|
display_description: edit_description.into(),
|
||||||
path: input_file_path.into(),
|
path: input_file_path.into(),
|
||||||
|
create_or_overwrite: false,
|
||||||
},
|
},
|
||||||
)],
|
)],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
input_path: input_file_path.into(),
|
input_path: input_file_path.into(),
|
||||||
input_content: input_file_content.into(),
|
input_content: Some(input_file_content.into()),
|
||||||
edit_description: edit_description.into(),
|
edit_description: edit_description.into(),
|
||||||
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
|
assertion: EvalAssertion::assert_eq(output_file_content),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -126,14 +128,15 @@ fn eval_delete_run_git_blame() {
|
||||||
StreamingEditFileToolInput {
|
StreamingEditFileToolInput {
|
||||||
display_description: edit_description.into(),
|
display_description: edit_description.into(),
|
||||||
path: input_file_path.into(),
|
path: input_file_path.into(),
|
||||||
|
create_or_overwrite: false,
|
||||||
},
|
},
|
||||||
)],
|
)],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
input_path: input_file_path.into(),
|
input_path: input_file_path.into(),
|
||||||
input_content: input_file_content.into(),
|
input_content: Some(input_file_content.into()),
|
||||||
edit_description: edit_description.into(),
|
edit_description: edit_description.into(),
|
||||||
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
|
assertion: EvalAssertion::assert_eq(output_file_content),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -240,14 +243,15 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
|
||||||
StreamingEditFileToolInput {
|
StreamingEditFileToolInput {
|
||||||
display_description: edit_description.into(),
|
display_description: edit_description.into(),
|
||||||
path: input_file_path.into(),
|
path: input_file_path.into(),
|
||||||
|
create_or_overwrite: false,
|
||||||
},
|
},
|
||||||
)],
|
)],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
input_path: input_file_path.into(),
|
input_path: input_file_path.into(),
|
||||||
input_content: input_file_content.into(),
|
input_content: Some(input_file_content.into()),
|
||||||
edit_description: edit_description.into(),
|
edit_description: edit_description.into(),
|
||||||
assertion: EvalAssertion::JudgeDiff(indoc! {"
|
assertion: EvalAssertion::judge_diff(indoc! {"
|
||||||
- The compile_parser_to_wasm method has been changed to use wasi-sdk
|
- The compile_parser_to_wasm method has been changed to use wasi-sdk
|
||||||
- ureq is used to download the SDK for current platform and architecture
|
- ureq is used to download the SDK for current platform and architecture
|
||||||
"}),
|
"}),
|
||||||
|
@ -315,14 +319,15 @@ fn eval_disable_cursor_blinking() {
|
||||||
StreamingEditFileToolInput {
|
StreamingEditFileToolInput {
|
||||||
display_description: edit_description.into(),
|
display_description: edit_description.into(),
|
||||||
path: input_file_path.into(),
|
path: input_file_path.into(),
|
||||||
|
create_or_overwrite: false,
|
||||||
},
|
},
|
||||||
)],
|
)],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
input_path: input_file_path.into(),
|
input_path: input_file_path.into(),
|
||||||
input_content: input_file_content.into(),
|
input_content: Some(input_file_content.into()),
|
||||||
edit_description: edit_description.into(),
|
edit_description: edit_description.into(),
|
||||||
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
|
assertion: EvalAssertion::assert_eq(output_file_content),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -504,14 +509,15 @@ fn eval_from_pixels_constructor() {
|
||||||
StreamingEditFileToolInput {
|
StreamingEditFileToolInput {
|
||||||
display_description: edit_description.into(),
|
display_description: edit_description.into(),
|
||||||
path: input_file_path.into(),
|
path: input_file_path.into(),
|
||||||
|
create_or_overwrite: false,
|
||||||
},
|
},
|
||||||
)],
|
)],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
input_path: input_file_path.into(),
|
input_path: input_file_path.into(),
|
||||||
input_content: input_file_content.into(),
|
input_content: Some(input_file_content.into()),
|
||||||
edit_description: edit_description.into(),
|
edit_description: edit_description.into(),
|
||||||
assertion: EvalAssertion::JudgeDiff(indoc! {"
|
assertion: EvalAssertion::assert_eq(indoc! {"
|
||||||
- The diff contains a new `from_pixels` constructor
|
- The diff contains a new `from_pixels` constructor
|
||||||
- The diff contains new tests for the `from_pixels` constructor
|
- The diff contains new tests for the `from_pixels` constructor
|
||||||
"}),
|
"}),
|
||||||
|
@ -519,6 +525,104 @@ fn eval_from_pixels_constructor() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Eval: the agent creates `root/zode.py` from scratch.
///
/// `input_content` is `None` and the recorded tool call sets `create_or_overwrite: true`,
/// so there is no pre-existing file content. A sample scores 100 unless the generated
/// text begins with a space, a backtick, or a newline; in that case it scores 0 and the
/// message names the offending first character.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
    let input_file_path = "root/zode.py";
    let edit_description = "Create the main Zode CLI script";

    let read_react = tool_use(
        "tool_1",
        "read_file",
        ReadFileToolInput {
            path: "root/eval/react.py".into(),
            start_line: None,
            end_line: None,
        },
    );
    let read_react_test = tool_use(
        "tool_2",
        "read_file",
        ReadFileToolInput {
            path: "root/eval/react_test.py".into(),
            start_line: None,
            end_line: None,
        },
    );
    let react_contents = tool_result(
        "tool_1",
        "read_file",
        include_str!("evals/fixtures/zode/react.py"),
    );
    let react_test_contents = tool_result(
        "tool_2",
        "read_file",
        include_str!("evals/fixtures/zode/react_test.py"),
    );
    let create_script = tool_use(
        "tool_3",
        "edit_file",
        StreamingEditFileToolInput {
            display_description: edit_description.into(),
            path: input_file_path.into(),
            create_or_overwrite: true,
        },
    );

    let conversation = vec![
        message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
        message(Assistant, [read_react, read_react_test]),
        message(User, [react_contents, react_test_contents]),
        message(
            Assistant,
            [
                text(
                    "Now that I understand what we need to build, I'll create the main Python script:",
                ),
                create_script,
            ],
        ),
    ];

    let assertion = EvalAssertion::new(async move |sample, _, _cx| {
        // Report only the first forbidden leading character, checked in this order.
        let mut message = [' ', '`', '\n']
            .into_iter()
            .find(|&start| sample.text.starts_with(start))
            .map(|start| format!("The sample starts with a {:?}\n", start))
            .unwrap_or_default();
        // Remove trailing newline.
        message.pop();

        let outcome = if message.is_empty() {
            EvalAssertionOutcome {
                score: 100,
                message: None,
            }
        } else {
            EvalAssertionOutcome {
                score: 0,
                message: Some(message),
            }
        };
        Ok(outcome)
    });

    eval(
        200,
        1.,
        EvalInput {
            conversation,
            input_path: input_file_path.into(),
            input_content: None,
            edit_description: edit_description.into(),
            assertion,
        },
    );
}
|
||||||
|
|
||||||
fn message(
|
fn message(
|
||||||
role: Role,
|
role: Role,
|
||||||
contents: impl IntoIterator<Item = MessageContent>,
|
contents: impl IntoIterator<Item = MessageContent>,
|
||||||
|
@ -574,11 +678,135 @@ fn tool_result(
|
||||||
struct EvalInput {
|
struct EvalInput {
|
||||||
conversation: Vec<LanguageModelRequestMessage>,
|
conversation: Vec<LanguageModelRequestMessage>,
|
||||||
input_path: PathBuf,
|
input_path: PathBuf,
|
||||||
input_content: String,
|
input_content: Option<String>,
|
||||||
edit_description: String,
|
edit_description: String,
|
||||||
assertion: EvalAssertion,
|
assertion: EvalAssertion,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct EvalSample {
|
||||||
|
text: String,
|
||||||
|
edit_output: EditAgentOutput,
|
||||||
|
diff: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
trait AssertionFn: 'static + Send + Sync {
|
||||||
|
fn assert<'a>(
|
||||||
|
&'a self,
|
||||||
|
sample: &'a EvalSample,
|
||||||
|
judge_model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &'a mut TestAppContext,
|
||||||
|
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> AssertionFn for F
|
||||||
|
where
|
||||||
|
F: 'static
|
||||||
|
+ Send
|
||||||
|
+ Sync
|
||||||
|
+ AsyncFn(
|
||||||
|
&EvalSample,
|
||||||
|
Arc<dyn LanguageModel>,
|
||||||
|
&mut TestAppContext,
|
||||||
|
) -> Result<EvalAssertionOutcome>,
|
||||||
|
{
|
||||||
|
fn assert<'a>(
|
||||||
|
&'a self,
|
||||||
|
sample: &'a EvalSample,
|
||||||
|
judge_model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &'a mut TestAppContext,
|
||||||
|
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
|
||||||
|
(self)(sample, judge_model, cx).boxed_local()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct EvalAssertion(Arc<dyn AssertionFn>);
|
||||||
|
|
||||||
|
impl EvalAssertion {
|
||||||
|
fn new<F>(f: F) -> Self
|
||||||
|
where
|
||||||
|
F: 'static
|
||||||
|
+ Send
|
||||||
|
+ Sync
|
||||||
|
+ AsyncFn(
|
||||||
|
&EvalSample,
|
||||||
|
Arc<dyn LanguageModel>,
|
||||||
|
&mut TestAppContext,
|
||||||
|
) -> Result<EvalAssertionOutcome>,
|
||||||
|
{
|
||||||
|
EvalAssertion(Arc::new(f))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assert_eq(expected: impl Into<String>) -> Self {
|
||||||
|
let expected = expected.into();
|
||||||
|
Self::new(async move |sample, _judge, _cx| {
|
||||||
|
Ok(EvalAssertionOutcome {
|
||||||
|
score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
|
||||||
|
100
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
},
|
||||||
|
message: None,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn judge_diff(assertions: &'static str) -> Self {
|
||||||
|
Self::new(async move |sample, judge, cx| {
|
||||||
|
let prompt = DiffJudgeTemplate {
|
||||||
|
diff: sample.diff.clone(),
|
||||||
|
assertions,
|
||||||
|
}
|
||||||
|
.render(&Templates::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let request = LanguageModelRequest {
|
||||||
|
messages: vec![LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec![prompt.into()],
|
||||||
|
cache: false,
|
||||||
|
}],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut response = judge
|
||||||
|
.stream_completion_text(request, &cx.to_async())
|
||||||
|
.await?;
|
||||||
|
let mut output = String::new();
|
||||||
|
while let Some(chunk) = response.stream.next().await {
|
||||||
|
let chunk = chunk?;
|
||||||
|
output.push_str(&chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the score from the response
|
||||||
|
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
|
||||||
|
if let Some(captures) = re.captures(&output) {
|
||||||
|
if let Some(score_match) = captures.get(1) {
|
||||||
|
let score = score_match.as_str().parse().unwrap_or(0);
|
||||||
|
return Ok(EvalAssertionOutcome {
|
||||||
|
score,
|
||||||
|
message: Some(output),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow!(
|
||||||
|
"No score found in response. Raw output: {}",
|
||||||
|
output
|
||||||
|
))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
input: &EvalSample,
|
||||||
|
judge_model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &mut TestAppContext,
|
||||||
|
) -> Result<EvalAssertionOutcome> {
|
||||||
|
self.0.assert(input, judge_model, cx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
|
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
|
||||||
let mut evaluated_count = 0;
|
let mut evaluated_count = 0;
|
||||||
report_progress(evaluated_count, iterations);
|
report_progress(evaluated_count, iterations);
|
||||||
|
@ -606,12 +834,12 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
|
||||||
while let Ok(output) = rx.recv() {
|
while let Ok(output) = rx.recv() {
|
||||||
match output {
|
match output {
|
||||||
Ok(output) => {
|
Ok(output) => {
|
||||||
cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
|
cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
|
||||||
eval_outputs.push(output.clone());
|
eval_outputs.push(output.clone());
|
||||||
if output.assertion.score < 80 {
|
if output.assertion.score < 80 {
|
||||||
failed_count += 1;
|
failed_count += 1;
|
||||||
failed_evals
|
failed_evals
|
||||||
.entry(output.buffer_text.clone())
|
.entry(output.sample.text.clone())
|
||||||
.or_insert(Vec::new())
|
.or_insert(Vec::new())
|
||||||
.push(output);
|
.push(output);
|
||||||
}
|
}
|
||||||
|
@ -671,10 +899,8 @@ fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct EvalOutput {
|
struct EvalOutput {
|
||||||
assertion: EvalAssertionResult,
|
sample: EvalSample,
|
||||||
buffer_text: String,
|
assertion: EvalAssertionOutcome,
|
||||||
edit_output: EditAgentOutput,
|
|
||||||
diff: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for EvalOutput {
|
impl Display for EvalOutput {
|
||||||
|
@ -684,14 +910,14 @@ impl Display for EvalOutput {
|
||||||
writeln!(f, "Message: {}", message)?;
|
writeln!(f, "Message: {}", message)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
writeln!(f, "Diff:\n{}", self.diff)?;
|
writeln!(f, "Diff:\n{}", self.sample.diff)?;
|
||||||
|
|
||||||
writeln!(
|
writeln!(
|
||||||
f,
|
f,
|
||||||
"Parser Metrics:\n{:#?}",
|
"Parser Metrics:\n{:#?}",
|
||||||
self.edit_output._parser_metrics
|
self.sample.edit_output._parser_metrics
|
||||||
)?;
|
)?;
|
||||||
writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
|
writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -777,96 +1003,45 @@ impl EditAgentTest {
|
||||||
.update(cx, |project, cx| project.open_buffer(path, cx))
|
.update(cx, |project, cx| project.open_buffer(path, cx))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
buffer.update(cx, |buffer, cx| {
|
let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
|
||||||
buffer.set_text(eval.input_content.clone(), cx)
|
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
|
||||||
});
|
let (edit_output, _) = self.agent.edit(
|
||||||
let (edit_output, _events) = self.agent.edit(
|
buffer.clone(),
|
||||||
buffer.clone(),
|
eval.edit_description,
|
||||||
eval.edit_description,
|
eval.conversation,
|
||||||
eval.conversation,
|
&mut cx.to_async(),
|
||||||
&mut cx.to_async(),
|
);
|
||||||
);
|
edit_output.await?
|
||||||
let edit_output = edit_output.await?;
|
} else {
|
||||||
|
let (edit_output, _) = self.agent.overwrite(
|
||||||
|
buffer.clone(),
|
||||||
|
eval.edit_description,
|
||||||
|
eval.conversation,
|
||||||
|
&mut cx.to_async(),
|
||||||
|
);
|
||||||
|
edit_output.await?
|
||||||
|
};
|
||||||
|
|
||||||
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
|
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
|
||||||
let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
|
let sample = EvalSample {
|
||||||
let assertion = match eval.assertion {
|
|
||||||
EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
|
|
||||||
score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
|
|
||||||
100
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
},
|
|
||||||
message: None,
|
|
||||||
},
|
|
||||||
EvalAssertion::JudgeDiff(assertions) => self
|
|
||||||
.judge_diff(&actual_diff, assertions, &cx.to_async())
|
|
||||||
.await
|
|
||||||
.context("failed comparing diffs")?,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(EvalOutput {
|
|
||||||
assertion,
|
|
||||||
diff: actual_diff,
|
|
||||||
buffer_text,
|
|
||||||
edit_output,
|
edit_output,
|
||||||
})
|
diff: language::unified_diff(
|
||||||
}
|
eval.input_content.as_deref().unwrap_or_default(),
|
||||||
|
&buffer_text,
|
||||||
async fn judge_diff(
|
),
|
||||||
&self,
|
text: buffer_text,
|
||||||
diff: &str,
|
|
||||||
assertions: &'static str,
|
|
||||||
cx: &AsyncApp,
|
|
||||||
) -> Result<EvalAssertionResult> {
|
|
||||||
let prompt = DiffJudgeTemplate {
|
|
||||||
diff: diff.to_string(),
|
|
||||||
assertions,
|
|
||||||
}
|
|
||||||
.render(&self.agent.templates)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let request = LanguageModelRequest {
|
|
||||||
messages: vec![LanguageModelRequestMessage {
|
|
||||||
role: Role::User,
|
|
||||||
content: vec![prompt.into()],
|
|
||||||
cache: false,
|
|
||||||
}],
|
|
||||||
..Default::default()
|
|
||||||
};
|
};
|
||||||
let mut response = self.judge_model.stream_completion_text(request, cx).await?;
|
let assertion = eval
|
||||||
let mut output = String::new();
|
.assertion
|
||||||
while let Some(chunk) = response.stream.next().await {
|
.run(&sample, self.judge_model.clone(), cx)
|
||||||
let chunk = chunk?;
|
.await?;
|
||||||
output.push_str(&chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the score from the response
|
Ok(EvalOutput { assertion, sample })
|
||||||
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
|
|
||||||
if let Some(captures) = re.captures(&output) {
|
|
||||||
if let Some(score_match) = captures.get(1) {
|
|
||||||
let score = score_match.as_str().parse().unwrap_or(0);
|
|
||||||
return Ok(EvalAssertionResult {
|
|
||||||
score,
|
|
||||||
message: Some(output),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Err(anyhow!(
|
|
||||||
"No score found in response. Raw output: {}",
|
|
||||||
output
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
||||||
enum EvalAssertion {
|
struct EvalAssertionOutcome {
|
||||||
AssertEqual(String),
|
|
||||||
JudgeDiff(&'static str),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
|
||||||
struct EvalAssertionResult {
|
|
||||||
score: usize,
|
score: usize,
|
||||||
message: Option<String>,
|
message: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
2193
crates/assistant_tools/src/edit_agent/evals/fixtures/zode/prompt.md
Normal file
2193
crates/assistant_tools/src/edit_agent/evals/fixtures/zode/prompt.md
Normal file
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,14 @@
|
||||||
|
class InputCell:
|
||||||
|
def __init__(self, initial_value):
|
||||||
|
self.value = None
|
||||||
|
|
||||||
|
|
||||||
|
class ComputeCell:
|
||||||
|
def __init__(self, inputs, compute_function):
|
||||||
|
self.value = None
|
||||||
|
|
||||||
|
def add_callback(self, callback):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def remove_callback(self, callback):
|
||||||
|
pass
|
|
@ -0,0 +1,271 @@
|
||||||
|
# These tests are auto-generated with test data from:
|
||||||
|
# https://github.com/exercism/problem-specifications/tree/main/exercises/react/canonical-data.json
|
||||||
|
# File last updated on 2023-07-19
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from react import (
|
||||||
|
InputCell,
|
||||||
|
ComputeCell,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReactTest(unittest.TestCase):
|
||||||
|
def test_input_cells_have_a_value(self):
|
||||||
|
input = InputCell(10)
|
||||||
|
self.assertEqual(input.value, 10)
|
||||||
|
|
||||||
|
def test_an_input_cell_s_value_can_be_set(self):
|
||||||
|
input = InputCell(4)
|
||||||
|
input.value = 20
|
||||||
|
self.assertEqual(input.value, 20)
|
||||||
|
|
||||||
|
def test_compute_cells_calculate_initial_value(self):
|
||||||
|
input = InputCell(1)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
self.assertEqual(output.value, 2)
|
||||||
|
|
||||||
|
def test_compute_cells_take_inputs_in_the_right_order(self):
|
||||||
|
one = InputCell(1)
|
||||||
|
two = InputCell(2)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
one,
|
||||||
|
two,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + inputs[1] * 10,
|
||||||
|
)
|
||||||
|
self.assertEqual(output.value, 21)
|
||||||
|
|
||||||
|
def test_compute_cells_update_value_when_dependencies_are_changed(self):
|
||||||
|
input = InputCell(1)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
input.value = 3
|
||||||
|
self.assertEqual(output.value, 4)
|
||||||
|
|
||||||
|
def test_compute_cells_can_depend_on_other_compute_cells(self):
|
||||||
|
input = InputCell(1)
|
||||||
|
times_two = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] * 2,
|
||||||
|
)
|
||||||
|
times_thirty = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] * 30,
|
||||||
|
)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
times_two,
|
||||||
|
times_thirty,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + inputs[1],
|
||||||
|
)
|
||||||
|
self.assertEqual(output.value, 32)
|
||||||
|
input.value = 3
|
||||||
|
self.assertEqual(output.value, 96)
|
||||||
|
|
||||||
|
def test_compute_cells_fire_callbacks(self):
|
||||||
|
input = InputCell(1)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
cb1_observer = []
|
||||||
|
callback1 = self.callback_factory(cb1_observer)
|
||||||
|
output.add_callback(callback1)
|
||||||
|
input.value = 3
|
||||||
|
self.assertEqual(cb1_observer[-1], 4)
|
||||||
|
|
||||||
|
def test_callback_cells_only_fire_on_change(self):
|
||||||
|
input = InputCell(1)
|
||||||
|
output = ComputeCell([input], lambda inputs: 111 if inputs[0] < 3 else 222)
|
||||||
|
cb1_observer = []
|
||||||
|
callback1 = self.callback_factory(cb1_observer)
|
||||||
|
output.add_callback(callback1)
|
||||||
|
input.value = 2
|
||||||
|
self.assertEqual(cb1_observer, [])
|
||||||
|
input.value = 4
|
||||||
|
self.assertEqual(cb1_observer[-1], 222)
|
||||||
|
|
||||||
|
def test_callbacks_do_not_report_already_reported_values(self):
|
||||||
|
input = InputCell(1)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
cb1_observer = []
|
||||||
|
callback1 = self.callback_factory(cb1_observer)
|
||||||
|
output.add_callback(callback1)
|
||||||
|
input.value = 2
|
||||||
|
self.assertEqual(cb1_observer[-1], 3)
|
||||||
|
input.value = 3
|
||||||
|
self.assertEqual(cb1_observer[-1], 4)
|
||||||
|
|
||||||
|
def test_callbacks_can_fire_from_multiple_cells(self):
|
||||||
|
input = InputCell(1)
|
||||||
|
plus_one = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
minus_one = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] - 1,
|
||||||
|
)
|
||||||
|
cb1_observer = []
|
||||||
|
cb2_observer = []
|
||||||
|
callback1 = self.callback_factory(cb1_observer)
|
||||||
|
callback2 = self.callback_factory(cb2_observer)
|
||||||
|
plus_one.add_callback(callback1)
|
||||||
|
minus_one.add_callback(callback2)
|
||||||
|
input.value = 10
|
||||||
|
self.assertEqual(cb1_observer[-1], 11)
|
||||||
|
self.assertEqual(cb2_observer[-1], 9)
|
||||||
|
|
||||||
|
def test_callbacks_can_be_added_and_removed(self):
|
||||||
|
input = InputCell(11)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
cb1_observer = []
|
||||||
|
cb2_observer = []
|
||||||
|
cb3_observer = []
|
||||||
|
callback1 = self.callback_factory(cb1_observer)
|
||||||
|
callback2 = self.callback_factory(cb2_observer)
|
||||||
|
callback3 = self.callback_factory(cb3_observer)
|
||||||
|
output.add_callback(callback1)
|
||||||
|
output.add_callback(callback2)
|
||||||
|
input.value = 31
|
||||||
|
self.assertEqual(cb1_observer[-1], 32)
|
||||||
|
self.assertEqual(cb2_observer[-1], 32)
|
||||||
|
output.remove_callback(callback1)
|
||||||
|
output.add_callback(callback3)
|
||||||
|
input.value = 41
|
||||||
|
self.assertEqual(len(cb1_observer), 1)
|
||||||
|
self.assertEqual(cb2_observer[-1], 42)
|
||||||
|
self.assertEqual(cb3_observer[-1], 42)
|
||||||
|
|
||||||
|
def test_removing_a_callback_multiple_times_doesn_t_interfere_with_other_callbacks(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
input = InputCell(1)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
cb1_observer = []
|
||||||
|
cb2_observer = []
|
||||||
|
callback1 = self.callback_factory(cb1_observer)
|
||||||
|
callback2 = self.callback_factory(cb2_observer)
|
||||||
|
output.add_callback(callback1)
|
||||||
|
output.add_callback(callback2)
|
||||||
|
output.remove_callback(callback1)
|
||||||
|
output.remove_callback(callback1)
|
||||||
|
output.remove_callback(callback1)
|
||||||
|
input.value = 2
|
||||||
|
self.assertEqual(cb1_observer, [])
|
||||||
|
self.assertEqual(cb2_observer[-1], 3)
|
||||||
|
|
||||||
|
def test_callbacks_should_only_be_called_once_even_if_multiple_dependencies_change(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
input = InputCell(1)
|
||||||
|
plus_one = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
minus_one1 = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] - 1,
|
||||||
|
)
|
||||||
|
minus_one2 = ComputeCell(
|
||||||
|
[
|
||||||
|
minus_one1,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] - 1,
|
||||||
|
)
|
||||||
|
output = ComputeCell(
|
||||||
|
[
|
||||||
|
plus_one,
|
||||||
|
minus_one2,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] * inputs[1],
|
||||||
|
)
|
||||||
|
cb1_observer = []
|
||||||
|
callback1 = self.callback_factory(cb1_observer)
|
||||||
|
output.add_callback(callback1)
|
||||||
|
input.value = 4
|
||||||
|
self.assertEqual(cb1_observer[-1], 10)
|
||||||
|
|
||||||
|
def test_callbacks_should_not_be_called_if_dependencies_change_but_output_value_doesn_t_change(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
input = InputCell(1)
|
||||||
|
plus_one = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] + 1,
|
||||||
|
)
|
||||||
|
minus_one = ComputeCell(
|
||||||
|
[
|
||||||
|
input,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] - 1,
|
||||||
|
)
|
||||||
|
always_two = ComputeCell(
|
||||||
|
[
|
||||||
|
plus_one,
|
||||||
|
minus_one,
|
||||||
|
],
|
||||||
|
lambda inputs: inputs[0] - inputs[1],
|
||||||
|
)
|
||||||
|
cb1_observer = []
|
||||||
|
callback1 = self.callback_factory(cb1_observer)
|
||||||
|
always_two.add_callback(callback1)
|
||||||
|
input.value = 2
|
||||||
|
self.assertEqual(cb1_observer, [])
|
||||||
|
input.value = 3
|
||||||
|
self.assertEqual(cb1_observer, [])
|
||||||
|
input.value = 4
|
||||||
|
self.assertEqual(cb1_observer, [])
|
||||||
|
input.value = 5
|
||||||
|
self.assertEqual(cb1_observer, [])
|
||||||
|
|
||||||
|
# Utility functions.
|
||||||
|
def callback_factory(self, observer):
|
||||||
|
def callback(observer, value):
|
||||||
|
observer.append(value)
|
||||||
|
|
||||||
|
return partial(callback, observer)
|
|
@ -38,7 +38,7 @@ pub struct StreamingEditFileToolInput {
|
||||||
/// so that we can display it immediately.
|
/// so that we can display it immediately.
|
||||||
pub display_description: String,
|
pub display_description: String,
|
||||||
|
|
||||||
/// The full path of the file to modify in the project.
|
/// The full path of the file to create or modify in the project.
|
||||||
///
|
///
|
||||||
/// WARNING: When specifying which file path needs changing, you MUST
|
/// WARNING: When specifying which file path needs changing, you MUST
|
||||||
/// start each path with one of the project's root directories.
|
/// start each path with one of the project's root directories.
|
||||||
|
@ -58,6 +58,10 @@ pub struct StreamingEditFileToolInput {
|
||||||
/// `frontend/db.js`
|
/// `frontend/db.js`
|
||||||
/// </example>
|
/// </example>
|
||||||
pub path: PathBuf,
|
pub path: PathBuf,
|
||||||
|
|
||||||
|
/// If true, this tool will recreate the file from scratch.
|
||||||
|
/// If false, this tool will produce granular edits to an existing file.
|
||||||
|
pub create_or_overwrite: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
@ -158,7 +162,7 @@ impl Tool for StreamingEditFileTool {
|
||||||
let card_clone = card.clone();
|
let card_clone = card.clone();
|
||||||
let messages = messages.to_vec();
|
let messages = messages.to_vec();
|
||||||
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
||||||
if !exists.await? {
|
if !input.create_or_overwrite && !exists.await? {
|
||||||
return Err(anyhow!("{} not found", input.path.display()));
|
return Err(anyhow!("{} not found", input.path.display()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,12 +186,21 @@ impl Tool for StreamingEditFileTool {
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let (output, mut events) = edit_agent.edit(
|
let (output, mut events) = if input.create_or_overwrite {
|
||||||
buffer.clone(),
|
edit_agent.overwrite(
|
||||||
input.display_description.clone(),
|
buffer.clone(),
|
||||||
messages,
|
input.display_description.clone(),
|
||||||
cx,
|
messages,
|
||||||
);
|
cx,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
edit_agent.edit(
|
||||||
|
buffer.clone(),
|
||||||
|
input.display_description.clone(),
|
||||||
|
messages,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
let mut hallucinated_old_text = false;
|
let mut hallucinated_old_text = false;
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
|
@ -213,7 +226,7 @@ impl Tool for StreamingEditFileTool {
|
||||||
.log_err();
|
.log_err();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
|
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
output.await?;
|
output.await?;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
|
This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
|
||||||
|
|
||||||
Before using this tool:
|
Before using this tool:
|
||||||
|
|
||||||
|
|
12
crates/assistant_tools/src/templates/create_file_prompt.hbs
Normal file
12
crates/assistant_tools/src/templates/create_file_prompt.hbs
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
You are an expert engineer and your task is to write a new file from scratch.
|
||||||
|
|
||||||
|
<file_to_edit>
|
||||||
|
{{path}}
|
||||||
|
</file_to_edit>
|
||||||
|
|
||||||
|
<edit_description>
|
||||||
|
{{edit_description}}
|
||||||
|
</edit_description>
|
||||||
|
|
||||||
|
You MUST respond directly with the file's content, without explanations, additional text or triple backticks.
|
||||||
|
The text you output will be saved verbatim as the content of the file.
|
|
@ -2141,6 +2141,14 @@ impl Buffer {
|
||||||
self.edit([(0..self.len(), text)], None, cx)
|
self.edit([(0..self.len(), text)], None, cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Appends the given text to the end of the buffer.
|
||||||
|
pub fn append<T>(&mut self, text: T, cx: &mut Context<Self>) -> Option<clock::Lamport>
|
||||||
|
where
|
||||||
|
T: Into<Arc<str>>,
|
||||||
|
{
|
||||||
|
self.edit([(self.len()..self.len(), text)], None, cx)
|
||||||
|
}
|
||||||
|
|
||||||
/// Applies the given edits to the buffer. Each edit is specified as a range of text to
|
/// Applies the given edits to the buffer. Each edit is specified as a range of text to
|
||||||
/// delete, and a string of text to insert at that location.
|
/// delete, and a string of text to insert at that location.
|
||||||
///
|
///
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue