Use tool calling instead of XML parsing to generate edit operations (#15385)
Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
f6012cd86e
commit
6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
|
||||
MessageId, MessageStatus,
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
|
||||
LanguageModelCompletionProvider, MessageId, MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
|
@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
|
|||
use language::{
|
||||
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
||||
};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::{LanguageModelRequest, Role};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use paths::contexts_dir;
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
cmp,
|
||||
|
@ -352,7 +352,7 @@ pub struct EditSuggestion {
|
|||
pub range: Range<language::Anchor>,
|
||||
/// If None, assume this is a suggestion to delete the range rather than transform it.
|
||||
pub description: Option<String>,
|
||||
pub prepend_newline: bool,
|
||||
pub initial_insertion: Option<InitialInsertion>,
|
||||
}
|
||||
|
||||
impl EditStep {
|
||||
|
@ -361,7 +361,7 @@ impl EditStep {
|
|||
project: &Model<Project>,
|
||||
cx: &AppContext,
|
||||
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
|
||||
let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
|
||||
let Some(EditStepOperations::Ready(operations)) = &self.operations else {
|
||||
return Task::ready(HashMap::default());
|
||||
};
|
||||
|
||||
|
@ -471,32 +471,28 @@ impl EditStep {
|
|||
}
|
||||
|
||||
pub enum EditStepOperations {
|
||||
Pending(Task<Result<()>>),
|
||||
Parsed {
|
||||
operations: Vec<EditOperation>,
|
||||
raw_output: String,
|
||||
},
|
||||
Pending(Task<Option<()>>),
|
||||
Ready(Vec<EditOperation>),
|
||||
}
|
||||
|
||||
impl Debug for EditStepOperations {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
|
||||
EditStepOperations::Parsed {
|
||||
operations,
|
||||
raw_output,
|
||||
} => f
|
||||
EditStepOperations::Ready(operations) => f
|
||||
.debug_struct("EditStepOperations::Parsed")
|
||||
.field("operations", operations)
|
||||
.field("raw_output", raw_output)
|
||||
.finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
/// A description of an operation to apply to one location in the codebase.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
|
||||
pub struct EditOperation {
|
||||
/// The path to the file containing the relevant operation
|
||||
pub path: String,
|
||||
#[serde(flatten)]
|
||||
pub kind: EditOperationKind,
|
||||
}
|
||||
|
||||
|
@ -523,7 +519,7 @@ impl EditOperation {
|
|||
parse_status.changed().await?;
|
||||
}
|
||||
|
||||
let prepend_newline = kind.prepend_newline();
|
||||
let initial_insertion = kind.initial_insertion();
|
||||
let suggestion_range = if let Some(symbol) = kind.symbol() {
|
||||
let outline = buffer
|
||||
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
|
||||
|
@ -601,39 +597,61 @@ impl EditOperation {
|
|||
EditSuggestion {
|
||||
range: suggestion_range,
|
||||
description: kind.description().map(ToString::to_string),
|
||||
prepend_newline,
|
||||
initial_insertion,
|
||||
},
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum EditOperationKind {
|
||||
/// Rewrite the specified symbol in its entirely based on the given description.
|
||||
Update {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Create a new file with the given path based on the given description.
|
||||
Create {
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol based on the given description before the specified symbol.
|
||||
InsertSiblingBefore {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol based on the given description after the specified symbol.
|
||||
InsertSiblingAfter {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol as a child of the specified symbol at the start.
|
||||
PrependChild {
|
||||
/// An optional full path to the symbol to be rewritten from the provided list.
|
||||
/// If not provided, the edit should be applied at the top of the file.
|
||||
symbol: Option<String>,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol as a child of the specified symbol at the end.
|
||||
AppendChild {
|
||||
/// An optional full path to the symbol to be rewritten from the provided list.
|
||||
/// If not provided, the edit should be applied at the top of the file.
|
||||
symbol: Option<String>,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Delete the specified symbol.
|
||||
Delete {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
},
|
||||
}
|
||||
|
@ -663,13 +681,13 @@ impl EditOperationKind {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn prepend_newline(&self) -> bool {
|
||||
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
|
||||
match self {
|
||||
Self::PrependChild { .. }
|
||||
| Self::AppendChild { .. }
|
||||
| Self::InsertSiblingAfter { .. }
|
||||
| Self::InsertSiblingBefore { .. } => true,
|
||||
_ => false,
|
||||
EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
|
||||
EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
|
||||
EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
|
||||
EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1137,18 +1155,15 @@ impl Context {
|
|||
.timer(Duration::from_millis(200))
|
||||
.await;
|
||||
|
||||
if let Some(token_count) = cx.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})? {
|
||||
let token_count = token_count.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
})?;
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
})
|
||||
}
|
||||
.log_err()
|
||||
});
|
||||
|
@ -1304,7 +1319,24 @@ impl Context {
|
|||
&self,
|
||||
edit_step: &EditStep,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
) -> Task<Option<()>> {
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct EditTool {
|
||||
/// A sequence of operations to apply to the codebase.
|
||||
/// When multiple operations are required for a step, be sure to include multiple operations in this list.
|
||||
operations: Vec<EditOperation>,
|
||||
}
|
||||
|
||||
impl LanguageModelTool for EditTool {
|
||||
fn name() -> String {
|
||||
"edit".into()
|
||||
}
|
||||
|
||||
fn description() -> String {
|
||||
"suggest edits to one or more locations in the codebase".into()
|
||||
}
|
||||
}
|
||||
|
||||
let mut request = self.to_completion_request(cx);
|
||||
let edit_step_range = edit_step.source_range.clone();
|
||||
let step_text = self
|
||||
|
@ -1313,160 +1345,41 @@ impl Context {
|
|||
.text_for_range(edit_step_range.clone())
|
||||
.collect::<String>();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
|
||||
cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
|
||||
|
||||
let mut prompt = prompt_store.operations_prompt();
|
||||
prompt.push_str(&step_text);
|
||||
let mut prompt = prompt_store.operations_prompt();
|
||||
prompt.push_str(&step_text);
|
||||
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
let raw_output = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
|
||||
let tool_use = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.use_tool::<EditTool>(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let step_index = this
|
||||
.edit_steps
|
||||
.binary_search_by(|step| {
|
||||
step.source_range
|
||||
.cmp(&edit_step_range, this.buffer.read(cx))
|
||||
})
|
||||
.map_err(|_| anyhow!("edit step not found"))?;
|
||||
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
|
||||
edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
|
||||
cx.emit(ContextEvent::EditStepsChanged);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let operations = Self::parse_edit_operations(&raw_output);
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let step_index = this
|
||||
.edit_steps
|
||||
.binary_search_by(|step| {
|
||||
step.source_range
|
||||
.cmp(&edit_step_range, this.buffer.read(cx))
|
||||
})
|
||||
.map_err(|_| anyhow!("edit step not found"))?;
|
||||
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
|
||||
edit_step.operations = Some(EditStepOperations::Parsed {
|
||||
operations,
|
||||
raw_output,
|
||||
});
|
||||
cx.emit(ContextEvent::EditStepsChanged);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
|
||||
let Some(start_ix) = xml.find("<operations>") else {
|
||||
return Vec::new();
|
||||
};
|
||||
let Some(end_ix) = xml[start_ix..].find("</operations>") else {
|
||||
return Vec::new();
|
||||
};
|
||||
let end_ix = end_ix + start_ix + "</operations>".len();
|
||||
|
||||
let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
|
||||
doc.map_or(Vec::new(), |doc| {
|
||||
doc.root_element()
|
||||
.children()
|
||||
.map(|node| {
|
||||
let tag_name = node.tag_name().name();
|
||||
let path = node
|
||||
.attribute("path")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'path'")
|
||||
})?
|
||||
.to_string();
|
||||
let kind = match tag_name {
|
||||
"update" => EditOperationKind::Update {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"create" => EditOperationKind::Create {
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"prepend_child" => EditOperationKind::PrependChild {
|
||||
symbol: node.attribute("symbol").map(String::from),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"append_child" => EditOperationKind::AppendChild {
|
||||
symbol: node.attribute("symbol").map(String::from),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"delete" => EditOperationKind::Delete {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
_ => return Err(anyhow!("invalid node {node:?}")),
|
||||
};
|
||||
anyhow::Ok(EditOperation { path, kind })
|
||||
})
|
||||
.filter_map(|op| op.log_err())
|
||||
.collect()
|
||||
}
|
||||
.log_err()
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3083,55 +2996,6 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_edit_operations() {
|
||||
let operations = indoc! {r#"
|
||||
Here are the operations to make all fields of the Canvas struct private:
|
||||
|
||||
<operations>
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
|
||||
</operations>
|
||||
"#};
|
||||
|
||||
let parsed_operations = Context::parse_edit_operations(operations);
|
||||
assert_eq!(
|
||||
parsed_operations,
|
||||
vec![
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub pixels".to_string(),
|
||||
description: "Remove pub keyword from pixels field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub size".to_string(),
|
||||
description: "Remove pub keyword from size field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub stride".to_string(),
|
||||
description: "Remove pub keyword from stride field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub format".to_string(),
|
||||
description: "Remove pub keyword from format field".to_string(),
|
||||
},
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_serialization(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue