Allow the assistant to suggest edits to files in the project (#11993)

### Todo

* [x] tuck the new system prompt away somehow
* for now, we're treating it as built-in, and not editable. once we have
a way to fold away default prompts, let's make it a default prompt.
* [x] when applying edits, re-parse the edit from the latest content of
the assistant buffer (to allow for manual editing of edits)
* [x] automatically adjust the indentation of edits suggested by the
assistant
* [x] fix edit row highlights persisting even when assistant messages
with edits are deleted
* ~adjust the fuzzy search to allow for small errors in the old text,
using some string similarity routine~

We decided to defer the fuzzy searching thing to a separate PR, since
it's a little bit involved, and the current functionality works well
enough to be worth landing. A couple of notes on the fuzzy searching:
* sometimes the assistant accidentally omits line breaks from the text
that it wants to replace
* when the old text has hallucinations, the new text often contains the
same hallucinations. so we'll probably need to use a more fine-grained
editing strategy where we perform a character-wise diff of the old and
new text as reported by the assistant, and then adjust that diff so that
it can be applied to the actual buffer text

Release Notes:

- Added the ability to request edits to project files using the
assistant panel.

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-05-17 15:38:14 -07:00 committed by GitHub
parent 4386268a94
commit 84affa96ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 912 additions and 181 deletions

View file

@ -1,21 +1,14 @@
use std::fmt::Write;
use std::iter;
use std::path::PathBuf;
use std::time::Duration;
use anyhow::Result;
use crate::{assistant_panel::Conversation, LanguageModelRequestMessage, Role};
use gpui::{ModelContext, Subscription, Task, WeakModel};
use language::{Buffer, BufferSnapshot, DiagnosticEntry, Point};
use util::ResultExt;
use language::{Buffer, BufferSnapshot, Rope};
use std::{fmt::Write, path::PathBuf, time::Duration};
use crate::ambient_context::ContextUpdated;
use crate::assistant_panel::Conversation;
use crate::{LanguageModelRequestMessage, Role};
use super::ContextUpdated;
pub struct RecentBuffersContext {
pub enabled: bool,
pub buffers: Vec<RecentBuffer>,
pub message: String,
pub snapshot: RecentBuffersSnapshot,
pub pending_message: Option<Task<()>>,
}
@ -29,27 +22,19 @@ impl Default for RecentBuffersContext {
Self {
enabled: true,
buffers: Vec::new(),
message: String::new(),
snapshot: RecentBuffersSnapshot::default(),
pending_message: None,
}
}
}
impl RecentBuffersContext {
/// Returns the [`RecentBuffersContext`] as a message to the language model.
pub fn to_message(&self) -> Option<LanguageModelRequestMessage> {
self.enabled.then(|| LanguageModelRequestMessage {
role: Role::System,
content: self.message.clone(),
})
}
pub fn update(&mut self, cx: &mut ModelContext<Conversation>) -> ContextUpdated {
let buffers = self
let source_buffers = self
.buffers
.iter()
.filter_map(|recent| {
recent
let (full_path, snapshot) = recent
.buffer
.read_with(cx, |buffer, cx| {
(
@ -57,12 +42,18 @@ impl RecentBuffersContext {
buffer.snapshot(),
)
})
.ok()
.ok()?;
Some(SourceBufferSnapshot {
full_path,
model: recent.buffer.clone(),
snapshot,
})
})
.collect::<Vec<_>>();
if !self.enabled || buffers.is_empty() {
self.message.clear();
if !self.enabled || source_buffers.is_empty() {
self.snapshot.message = Default::default();
self.snapshot.source_buffers.clear();
self.pending_message = None;
cx.notify();
ContextUpdated::Disabled
@ -71,131 +62,84 @@ impl RecentBuffersContext {
const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100);
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
let message_task = cx
.background_executor()
.spawn(async move { Self::build_message(&buffers) });
if let Some(message) = message_task.await.log_err() {
this.update(&mut cx, |conversation, cx| {
conversation.ambient_context.recent_buffers.message = message;
conversation.count_remaining_tokens(cx);
cx.notify();
})
.log_err();
}
let message = if source_buffers.is_empty() {
Rope::new()
} else {
cx.background_executor()
.spawn({
let source_buffers = source_buffers.clone();
async move { message_for_recent_buffers(source_buffers) }
})
.await
};
this.update(&mut cx, |this, cx| {
this.ambient_context.recent_buffers.snapshot.source_buffers = source_buffers;
this.ambient_context.recent_buffers.snapshot.message = message;
this.count_remaining_tokens(cx);
cx.notify();
})
.ok();
}));
ContextUpdated::Updating
}
}
fn build_message(buffers: &[(Option<PathBuf>, BufferSnapshot)]) -> Result<String> {
let mut message = String::new();
writeln!(
message,
"The following is a list of recent buffers that the user has opened."
)?;
writeln!(
message,
"For every line in the buffer, I will include a row number that line corresponds to."
)?;
writeln!(
message,
"Lines that don't have a number correspond to errors and warnings. For example:"
)?;
writeln!(message, "path/to/file.md")?;
writeln!(message, "```markdown")?;
writeln!(message, "1 The quick brown fox")?;
writeln!(message, "2 jumps over one active")?;
writeln!(message, " --- error: should be 'the'")?;
writeln!(message, " ------ error: should be 'lazy'")?;
writeln!(message, "3 dog")?;
writeln!(message, "```")?;
message.push('\n');
writeln!(message, "Here's the actual recent buffer list:")?;
for (path, buffer) in buffers {
if let Some(path) = path {
writeln!(message, "{}", path.display())?;
} else {
writeln!(message, "untitled")?;
}
if let Some(language) = buffer.language() {
writeln!(message, "```{}", language.name().to_lowercase())?;
} else {
writeln!(message, "```")?;
}
let mut diagnostics = buffer
.diagnostics_in_range::<_, Point>(
language::Anchor::MIN..language::Anchor::MAX,
false,
)
.peekable();
let mut active_diagnostics = Vec::<DiagnosticEntry<Point>>::new();
const GUTTER_PADDING: usize = 4;
let gutter_width =
((buffer.max_point().row + 1) as f32).log10() as usize + 1 + GUTTER_PADDING;
for buffer_row in 0..=buffer.max_point().row {
let display_row = buffer_row + 1;
active_diagnostics.retain(|diagnostic| {
(diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row)
});
while diagnostics.peek().map_or(false, |diagnostic| {
(diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row)
}) {
active_diagnostics.push(diagnostics.next().unwrap());
}
let row_width = (display_row as f32).log10() as usize + 1;
write!(message, "{}", display_row)?;
if row_width < gutter_width {
message.extend(iter::repeat(' ').take(gutter_width - row_width));
}
for chunk in buffer.text_for_range(
Point::new(buffer_row, 0)..Point::new(buffer_row, buffer.line_len(buffer_row)),
) {
message.push_str(chunk);
}
message.push('\n');
for diagnostic in &active_diagnostics {
message.extend(iter::repeat(' ').take(gutter_width));
let start_column = if diagnostic.range.start.row == buffer_row {
message
.extend(iter::repeat(' ').take(diagnostic.range.start.column as usize));
diagnostic.range.start.column
} else {
0
};
let end_column = if diagnostic.range.end.row == buffer_row {
diagnostic.range.end.column
} else {
buffer.line_len(buffer_row)
};
message.extend(iter::repeat('-').take((end_column - start_column) as usize));
writeln!(message, " {}", diagnostic.diagnostic.message)?;
}
}
message.push('\n');
}
writeln!(
message,
"When quoting the above code, mention which rows the code occurs at."
)?;
writeln!(
message,
"Never include rows in the quoted code itself and only report lines that didn't start with a row number."
)
?;
Ok(message)
/// Returns the [`RecentBuffersContext`] as a message to the language model.
pub fn to_message(&self) -> Option<LanguageModelRequestMessage> {
self.enabled.then(|| LanguageModelRequestMessage {
role: Role::System,
content: self.snapshot.message.to_string(),
})
}
}
#[derive(Clone, Default, Debug)]
pub struct RecentBuffersSnapshot {
pub message: Rope,
pub source_buffers: Vec<SourceBufferSnapshot>,
}
#[derive(Clone)]
pub struct SourceBufferSnapshot {
pub full_path: Option<PathBuf>,
pub model: WeakModel<Buffer>,
pub snapshot: BufferSnapshot,
}
impl std::fmt::Debug for SourceBufferSnapshot {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SourceBufferSnapshot")
.field("full_path", &self.full_path)
.field("model (entity id)", &self.model.entity_id())
.field("snapshot (text)", &self.snapshot.text())
.finish()
}
}
fn message_for_recent_buffers(buffers: Vec<SourceBufferSnapshot>) -> Rope {
let mut message = String::new();
writeln!(
message,
"The following is a list of recent buffers that the user has opened."
)
.unwrap();
for buffer in buffers {
if let Some(path) = buffer.full_path {
writeln!(message, "```{}", path.display()).unwrap();
} else {
writeln!(message, "```untitled").unwrap();
}
for chunk in buffer.snapshot.chunks(0..buffer.snapshot.len(), false) {
message.push_str(chunk.text);
}
if !message.ends_with('\n') {
message.push('\n');
}
message.push_str("```\n");
}
Rope::from(message.as_str())
}