Use tool calling instead of XML parsing to generate edit operations (#15385)

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-29 16:42:08 +02:00 committed by GitHub
parent f6012cd86e
commit 6e1f7c6e1d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 1155 additions and 853 deletions

View file

@ -1,6 +1,6 @@
use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
MessageId, MessageStatus,
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
LanguageModelCompletionProvider, MessageId, MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequest, Role};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
cmp,
@ -352,7 +352,7 @@ pub struct EditSuggestion {
pub range: Range<language::Anchor>,
/// If None, assume this is a suggestion to delete the range rather than transform it.
pub description: Option<String>,
pub prepend_newline: bool,
pub initial_insertion: Option<InitialInsertion>,
}
impl EditStep {
@ -361,7 +361,7 @@ impl EditStep {
project: &Model<Project>,
cx: &AppContext,
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
let Some(EditStepOperations::Ready(operations)) = &self.operations else {
return Task::ready(HashMap::default());
};
@ -471,32 +471,28 @@ impl EditStep {
}
pub enum EditStepOperations {
Pending(Task<Result<()>>),
Parsed {
operations: Vec<EditOperation>,
raw_output: String,
},
Pending(Task<Option<()>>),
Ready(Vec<EditOperation>),
}
impl Debug for EditStepOperations {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
EditStepOperations::Parsed {
operations,
raw_output,
} => f
EditStepOperations::Ready(operations) => f
.debug_struct("EditStepOperations::Parsed")
.field("operations", operations)
.field("raw_output", raw_output)
.finish(),
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
/// A description of an operation to apply to one location in the codebase.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
pub struct EditOperation {
/// The path to the file containing the relevant operation
pub path: String,
#[serde(flatten)]
pub kind: EditOperationKind,
}
@ -523,7 +519,7 @@ impl EditOperation {
parse_status.changed().await?;
}
let prepend_newline = kind.prepend_newline();
let initial_insertion = kind.initial_insertion();
let suggestion_range = if let Some(symbol) = kind.symbol() {
let outline = buffer
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
@ -601,39 +597,61 @@ impl EditOperation {
EditSuggestion {
range: suggestion_range,
description: kind.description().map(ToString::to_string),
prepend_newline,
initial_insertion,
},
))
})
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
#[serde(tag = "kind")]
pub enum EditOperationKind {
/// Rewrite the specified symbol in its entirely based on the given description.
Update {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Create a new file with the given path based on the given description.
Create {
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol based on the given description before the specified symbol.
InsertSiblingBefore {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol based on the given description after the specified symbol.
InsertSiblingAfter {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol as a child of the specified symbol at the start.
PrependChild {
/// An optional full path to the symbol to be rewritten from the provided list.
/// If not provided, the edit should be applied at the top of the file.
symbol: Option<String>,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol as a child of the specified symbol at the end.
AppendChild {
/// An optional full path to the symbol to be rewritten from the provided list.
/// If not provided, the edit should be applied at the top of the file.
symbol: Option<String>,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Delete the specified symbol.
Delete {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
},
}
@ -663,13 +681,13 @@ impl EditOperationKind {
}
}
pub fn prepend_newline(&self) -> bool {
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
match self {
Self::PrependChild { .. }
| Self::AppendChild { .. }
| Self::InsertSiblingAfter { .. }
| Self::InsertSiblingBefore { .. } => true,
_ => false,
EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
_ => None,
}
}
}
@ -1137,18 +1155,15 @@ impl Context {
.timer(Duration::from_millis(200))
.await;
if let Some(token_count) = cx.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})? {
let token_count = token_count.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify()
})?;
}
anyhow::Ok(())
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify()
})
}
.log_err()
});
@ -1304,7 +1319,24 @@ impl Context {
&self,
edit_step: &EditStep,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
) -> Task<Option<()>> {
#[derive(Debug, Deserialize, JsonSchema)]
struct EditTool {
/// A sequence of operations to apply to the codebase.
/// When multiple operations are required for a step, be sure to include multiple operations in this list.
operations: Vec<EditOperation>,
}
impl LanguageModelTool for EditTool {
fn name() -> String {
"edit".into()
}
fn description() -> String {
"suggest edits to one or more locations in the codebase".into()
}
}
let mut request = self.to_completion_request(cx);
let edit_step_range = edit_step.source_range.clone();
let step_text = self
@ -1313,160 +1345,41 @@ impl Context {
.text_for_range(edit_step_range.clone())
.collect::<String>();
cx.spawn(|this, mut cx| async move {
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
cx.spawn(|this, mut cx| {
async move {
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
let mut prompt = prompt_store.operations_prompt();
prompt.push_str(&step_text);
let mut prompt = prompt_store.operations_prompt();
prompt.push_str(&step_text);
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
let raw_output = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
let tool_use = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.use_tool::<EditTool>(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
let step_index = this
.edit_steps
.binary_search_by(|step| {
step.source_range
.cmp(&edit_step_range, this.buffer.read(cx))
})
.map_err(|_| anyhow!("edit step not found"))?;
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
cx.emit(ContextEvent::EditStepsChanged);
}
anyhow::Ok(())
})?
.await?;
let operations = Self::parse_edit_operations(&raw_output);
this.update(&mut cx, |this, cx| {
let step_index = this
.edit_steps
.binary_search_by(|step| {
step.source_range
.cmp(&edit_step_range, this.buffer.read(cx))
})
.map_err(|_| anyhow!("edit step not found"))?;
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
edit_step.operations = Some(EditStepOperations::Parsed {
operations,
raw_output,
});
cx.emit(ContextEvent::EditStepsChanged);
}
anyhow::Ok(())
})?
})
}
fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
let Some(start_ix) = xml.find("<operations>") else {
return Vec::new();
};
let Some(end_ix) = xml[start_ix..].find("</operations>") else {
return Vec::new();
};
let end_ix = end_ix + start_ix + "</operations>".len();
let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
doc.map_or(Vec::new(), |doc| {
doc.root_element()
.children()
.map(|node| {
let tag_name = node.tag_name().name();
let path = node
.attribute("path")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'path'")
})?
.to_string();
let kind = match tag_name {
"update" => EditOperationKind::Update {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"create" => EditOperationKind::Create {
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"prepend_child" => EditOperationKind::PrependChild {
symbol: node.attribute("symbol").map(String::from),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"append_child" => EditOperationKind::AppendChild {
symbol: node.attribute("symbol").map(String::from),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"delete" => EditOperationKind::Delete {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
},
_ => return Err(anyhow!("invalid node {node:?}")),
};
anyhow::Ok(EditOperation { path, kind })
})
.filter_map(|op| op.log_err())
.collect()
}
.log_err()
})
}
@ -3083,55 +2996,6 @@ mod tests {
}
}
#[test]
fn test_parse_edit_operations() {
let operations = indoc! {r#"
Here are the operations to make all fields of the Canvas struct private:
<operations>
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
</operations>
"#};
let parsed_operations = Context::parse_edit_operations(operations);
assert_eq!(
parsed_operations,
vec![
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub pixels".to_string(),
description: "Remove pub keyword from pixels field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub size".to_string(),
description: "Remove pub keyword from size field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub stride".to_string(),
description: "Remove pub keyword from stride field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub format".to_string(),
description: "Remove pub keyword from format field".to_string(),
},
},
]
);
}
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);