Use tool calling instead of XML parsing to generate edit operations (#15385)
Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
f6012cd86e
commit
6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions
|
@ -1232,12 +1232,16 @@ impl ContextEditor {
|
|||
|
||||
fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
||||
if let Some(step) = self.active_edit_step.as_ref() {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
for assist_id in &step.assist_ids {
|
||||
assistant.start_assist(*assist_id, cx);
|
||||
}
|
||||
!step.assist_ids.is_empty()
|
||||
})
|
||||
let assist_ids = step.assist_ids.clone();
|
||||
cx.window_context().defer(|cx| {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
for assist_id in assist_ids {
|
||||
assistant.start_assist(assist_id, cx);
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
!step.assist_ids.is_empty()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
@ -1286,11 +1290,7 @@ impl ContextEditor {
|
|||
.collect::<String>()
|
||||
));
|
||||
match &step.operations {
|
||||
Some(EditStepOperations::Parsed {
|
||||
operations,
|
||||
raw_output,
|
||||
}) => {
|
||||
output.push_str(&format!("Raw Output:\n{raw_output}\n"));
|
||||
Some(EditStepOperations::Ready(operations)) => {
|
||||
output.push_str("Parsed Operations:\n");
|
||||
for op in operations {
|
||||
output.push_str(&format!(" {:?}\n", op));
|
||||
|
@ -1794,13 +1794,12 @@ impl ContextEditor {
|
|||
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
|
||||
.unwrap()
|
||||
};
|
||||
let initial_text = suggestion.prepend_newline.then(|| "\n".into());
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
assist_ids.push(assistant.suggest_assist(
|
||||
&editor,
|
||||
range,
|
||||
description,
|
||||
initial_text,
|
||||
suggestion.initial_insertion,
|
||||
Some(workspace.clone()),
|
||||
assistant_panel.upgrade().as_ref(),
|
||||
cx,
|
||||
|
@ -1862,9 +1861,11 @@ impl ContextEditor {
|
|||
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
|
||||
.unwrap()
|
||||
};
|
||||
let initial_text =
|
||||
suggestion.prepend_newline.then(|| "\n".to_string());
|
||||
inline_assist_suggestions.push((range, description, initial_text));
|
||||
inline_assist_suggestions.push((
|
||||
range,
|
||||
description,
|
||||
suggestion.initial_insertion,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1875,12 +1876,12 @@ impl ContextEditor {
|
|||
.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
|
||||
cx.update(|cx| {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
for (range, description, initial_text) in inline_assist_suggestions {
|
||||
for (range, description, initial_insertion) in inline_assist_suggestions {
|
||||
assist_ids.push(assistant.suggest_assist(
|
||||
&editor,
|
||||
range,
|
||||
description,
|
||||
initial_text,
|
||||
initial_insertion,
|
||||
Some(workspace.clone()),
|
||||
assistant_panel.upgrade().as_ref(),
|
||||
cx,
|
||||
|
@ -2188,7 +2189,7 @@ impl ContextEditor {
|
|||
let button_text = match self.edit_step_for_cursor(cx) {
|
||||
Some(edit_step) => match &edit_step.operations {
|
||||
Some(EditStepOperations::Pending(_)) => "Computing Changes...",
|
||||
Some(EditStepOperations::Parsed { .. }) => "Apply Changes",
|
||||
Some(EditStepOperations::Ready(_)) => "Apply Changes",
|
||||
None => "Send",
|
||||
},
|
||||
None => "Send",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
|
||||
MessageId, MessageStatus,
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
|
||||
LanguageModelCompletionProvider, MessageId, MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
|
@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
|
|||
use language::{
|
||||
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
||||
};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::{LanguageModelRequest, Role};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use paths::contexts_dir;
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
cmp,
|
||||
|
@ -352,7 +352,7 @@ pub struct EditSuggestion {
|
|||
pub range: Range<language::Anchor>,
|
||||
/// If None, assume this is a suggestion to delete the range rather than transform it.
|
||||
pub description: Option<String>,
|
||||
pub prepend_newline: bool,
|
||||
pub initial_insertion: Option<InitialInsertion>,
|
||||
}
|
||||
|
||||
impl EditStep {
|
||||
|
@ -361,7 +361,7 @@ impl EditStep {
|
|||
project: &Model<Project>,
|
||||
cx: &AppContext,
|
||||
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
|
||||
let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
|
||||
let Some(EditStepOperations::Ready(operations)) = &self.operations else {
|
||||
return Task::ready(HashMap::default());
|
||||
};
|
||||
|
||||
|
@ -471,32 +471,28 @@ impl EditStep {
|
|||
}
|
||||
|
||||
pub enum EditStepOperations {
|
||||
Pending(Task<Result<()>>),
|
||||
Parsed {
|
||||
operations: Vec<EditOperation>,
|
||||
raw_output: String,
|
||||
},
|
||||
Pending(Task<Option<()>>),
|
||||
Ready(Vec<EditOperation>),
|
||||
}
|
||||
|
||||
impl Debug for EditStepOperations {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
|
||||
EditStepOperations::Parsed {
|
||||
operations,
|
||||
raw_output,
|
||||
} => f
|
||||
EditStepOperations::Ready(operations) => f
|
||||
.debug_struct("EditStepOperations::Parsed")
|
||||
.field("operations", operations)
|
||||
.field("raw_output", raw_output)
|
||||
.finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
/// A description of an operation to apply to one location in the codebase.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
|
||||
pub struct EditOperation {
|
||||
/// The path to the file containing the relevant operation
|
||||
pub path: String,
|
||||
#[serde(flatten)]
|
||||
pub kind: EditOperationKind,
|
||||
}
|
||||
|
||||
|
@ -523,7 +519,7 @@ impl EditOperation {
|
|||
parse_status.changed().await?;
|
||||
}
|
||||
|
||||
let prepend_newline = kind.prepend_newline();
|
||||
let initial_insertion = kind.initial_insertion();
|
||||
let suggestion_range = if let Some(symbol) = kind.symbol() {
|
||||
let outline = buffer
|
||||
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
|
||||
|
@ -601,39 +597,61 @@ impl EditOperation {
|
|||
EditSuggestion {
|
||||
range: suggestion_range,
|
||||
description: kind.description().map(ToString::to_string),
|
||||
prepend_newline,
|
||||
initial_insertion,
|
||||
},
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum EditOperationKind {
|
||||
/// Rewrite the specified symbol in its entirely based on the given description.
|
||||
Update {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Create a new file with the given path based on the given description.
|
||||
Create {
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol based on the given description before the specified symbol.
|
||||
InsertSiblingBefore {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol based on the given description after the specified symbol.
|
||||
InsertSiblingAfter {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol as a child of the specified symbol at the start.
|
||||
PrependChild {
|
||||
/// An optional full path to the symbol to be rewritten from the provided list.
|
||||
/// If not provided, the edit should be applied at the top of the file.
|
||||
symbol: Option<String>,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol as a child of the specified symbol at the end.
|
||||
AppendChild {
|
||||
/// An optional full path to the symbol to be rewritten from the provided list.
|
||||
/// If not provided, the edit should be applied at the top of the file.
|
||||
symbol: Option<String>,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Delete the specified symbol.
|
||||
Delete {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
},
|
||||
}
|
||||
|
@ -663,13 +681,13 @@ impl EditOperationKind {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn prepend_newline(&self) -> bool {
|
||||
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
|
||||
match self {
|
||||
Self::PrependChild { .. }
|
||||
| Self::AppendChild { .. }
|
||||
| Self::InsertSiblingAfter { .. }
|
||||
| Self::InsertSiblingBefore { .. } => true,
|
||||
_ => false,
|
||||
EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
|
||||
EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
|
||||
EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
|
||||
EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1137,18 +1155,15 @@ impl Context {
|
|||
.timer(Duration::from_millis(200))
|
||||
.await;
|
||||
|
||||
if let Some(token_count) = cx.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})? {
|
||||
let token_count = token_count.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
})?;
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
})
|
||||
}
|
||||
.log_err()
|
||||
});
|
||||
|
@ -1304,7 +1319,24 @@ impl Context {
|
|||
&self,
|
||||
edit_step: &EditStep,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
) -> Task<Option<()>> {
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct EditTool {
|
||||
/// A sequence of operations to apply to the codebase.
|
||||
/// When multiple operations are required for a step, be sure to include multiple operations in this list.
|
||||
operations: Vec<EditOperation>,
|
||||
}
|
||||
|
||||
impl LanguageModelTool for EditTool {
|
||||
fn name() -> String {
|
||||
"edit".into()
|
||||
}
|
||||
|
||||
fn description() -> String {
|
||||
"suggest edits to one or more locations in the codebase".into()
|
||||
}
|
||||
}
|
||||
|
||||
let mut request = self.to_completion_request(cx);
|
||||
let edit_step_range = edit_step.source_range.clone();
|
||||
let step_text = self
|
||||
|
@ -1313,160 +1345,41 @@ impl Context {
|
|||
.text_for_range(edit_step_range.clone())
|
||||
.collect::<String>();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
|
||||
cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
|
||||
|
||||
let mut prompt = prompt_store.operations_prompt();
|
||||
prompt.push_str(&step_text);
|
||||
let mut prompt = prompt_store.operations_prompt();
|
||||
prompt.push_str(&step_text);
|
||||
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
let raw_output = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
|
||||
let tool_use = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.use_tool::<EditTool>(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let step_index = this
|
||||
.edit_steps
|
||||
.binary_search_by(|step| {
|
||||
step.source_range
|
||||
.cmp(&edit_step_range, this.buffer.read(cx))
|
||||
})
|
||||
.map_err(|_| anyhow!("edit step not found"))?;
|
||||
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
|
||||
edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
|
||||
cx.emit(ContextEvent::EditStepsChanged);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let operations = Self::parse_edit_operations(&raw_output);
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let step_index = this
|
||||
.edit_steps
|
||||
.binary_search_by(|step| {
|
||||
step.source_range
|
||||
.cmp(&edit_step_range, this.buffer.read(cx))
|
||||
})
|
||||
.map_err(|_| anyhow!("edit step not found"))?;
|
||||
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
|
||||
edit_step.operations = Some(EditStepOperations::Parsed {
|
||||
operations,
|
||||
raw_output,
|
||||
});
|
||||
cx.emit(ContextEvent::EditStepsChanged);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
|
||||
let Some(start_ix) = xml.find("<operations>") else {
|
||||
return Vec::new();
|
||||
};
|
||||
let Some(end_ix) = xml[start_ix..].find("</operations>") else {
|
||||
return Vec::new();
|
||||
};
|
||||
let end_ix = end_ix + start_ix + "</operations>".len();
|
||||
|
||||
let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
|
||||
doc.map_or(Vec::new(), |doc| {
|
||||
doc.root_element()
|
||||
.children()
|
||||
.map(|node| {
|
||||
let tag_name = node.tag_name().name();
|
||||
let path = node
|
||||
.attribute("path")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'path'")
|
||||
})?
|
||||
.to_string();
|
||||
let kind = match tag_name {
|
||||
"update" => EditOperationKind::Update {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"create" => EditOperationKind::Create {
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"prepend_child" => EditOperationKind::PrependChild {
|
||||
symbol: node.attribute("symbol").map(String::from),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"append_child" => EditOperationKind::AppendChild {
|
||||
symbol: node.attribute("symbol").map(String::from),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"delete" => EditOperationKind::Delete {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
_ => return Err(anyhow!("invalid node {node:?}")),
|
||||
};
|
||||
anyhow::Ok(EditOperation { path, kind })
|
||||
})
|
||||
.filter_map(|op| op.log_err())
|
||||
.collect()
|
||||
}
|
||||
.log_err()
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3083,55 +2996,6 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_edit_operations() {
|
||||
let operations = indoc! {r#"
|
||||
Here are the operations to make all fields of the Canvas struct private:
|
||||
|
||||
<operations>
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
|
||||
</operations>
|
||||
"#};
|
||||
|
||||
let parsed_operations = Context::parse_edit_operations(operations);
|
||||
assert_eq!(
|
||||
parsed_operations,
|
||||
vec![
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub pixels".to_string(),
|
||||
description: "Remove pub keyword from pixels field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub size".to_string(),
|
||||
description: "Remove pub keyword from size field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub stride".to_string(),
|
||||
description: "Remove pub keyword from stride field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub format".to_string(),
|
||||
description: "Remove pub keyword from format field".to_string(),
|
||||
},
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_serialization(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
|
|
|
@ -17,7 +17,7 @@ use editor::{
|
|||
use fs::Fs;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
future::LocalBoxFuture,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
stream::{self, BoxStream},
|
||||
SinkExt, Stream, StreamExt,
|
||||
};
|
||||
|
@ -36,7 +36,7 @@ use similar::TextDiff;
|
|||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
cmp,
|
||||
future::Future,
|
||||
future::{self, Future},
|
||||
mem,
|
||||
ops::{Range, RangeInclusive},
|
||||
pin::Pin,
|
||||
|
@ -46,7 +46,7 @@ use std::{
|
|||
};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{prelude::*, IconButtonShape, Tooltip};
|
||||
use util::RangeExt;
|
||||
use util::{RangeExt, ResultExt};
|
||||
use workspace::{notifications::NotificationId, Toast, Workspace};
|
||||
|
||||
pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
|
||||
|
@ -187,7 +187,13 @@ impl InlineAssistant {
|
|||
let [prompt_block_id, end_block_id] =
|
||||
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
|
||||
|
||||
assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id));
|
||||
assists.push((
|
||||
assist_id,
|
||||
range,
|
||||
prompt_editor,
|
||||
prompt_block_id,
|
||||
end_block_id,
|
||||
));
|
||||
}
|
||||
|
||||
let editor_assists = self
|
||||
|
@ -195,7 +201,7 @@ impl InlineAssistant {
|
|||
.entry(editor.downgrade())
|
||||
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
|
||||
let mut assist_group = InlineAssistGroup::new();
|
||||
for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists {
|
||||
for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
|
||||
self.assists.insert(
|
||||
assist_id,
|
||||
InlineAssist::new(
|
||||
|
@ -206,6 +212,7 @@ impl InlineAssistant {
|
|||
&prompt_editor,
|
||||
prompt_block_id,
|
||||
end_block_id,
|
||||
range,
|
||||
prompt_editor.read(cx).codegen.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
|
@ -227,7 +234,7 @@ impl InlineAssistant {
|
|||
editor: &View<Editor>,
|
||||
mut range: Range<Anchor>,
|
||||
initial_prompt: String,
|
||||
initial_insertion: Option<String>,
|
||||
initial_insertion: Option<InitialInsertion>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
assistant_panel: Option<&View<AssistantPanel>>,
|
||||
cx: &mut WindowContext,
|
||||
|
@ -239,22 +246,30 @@ impl InlineAssistant {
|
|||
let assist_id = self.next_assist_id.post_inc();
|
||||
|
||||
let buffer = editor.read(cx).buffer().clone();
|
||||
let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.start_transaction(cx);
|
||||
buffer.edit([(range.start..range.start, initial_insertion)], None, cx);
|
||||
buffer.end_transaction(cx)
|
||||
})
|
||||
});
|
||||
{
|
||||
let snapshot = buffer.read(cx).read(cx);
|
||||
|
||||
range.start = range.start.bias_left(&buffer.read(cx).read(cx));
|
||||
range.end = range.end.bias_right(&buffer.read(cx).read(cx));
|
||||
let mut point_range = range.to_point(&snapshot);
|
||||
if point_range.is_empty() {
|
||||
point_range.start.column = 0;
|
||||
point_range.end.column = 0;
|
||||
} else {
|
||||
point_range.start.column = 0;
|
||||
if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
|
||||
point_range.end.row -= 1;
|
||||
}
|
||||
point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
|
||||
}
|
||||
|
||||
range.start = snapshot.anchor_before(point_range.start);
|
||||
range.end = snapshot.anchor_after(point_range.end);
|
||||
}
|
||||
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
editor.read(cx).buffer().clone(),
|
||||
range.clone(),
|
||||
prepend_transaction_id,
|
||||
initial_insertion,
|
||||
self.telemetry.clone(),
|
||||
cx,
|
||||
)
|
||||
|
@ -295,6 +310,7 @@ impl InlineAssistant {
|
|||
&prompt_editor,
|
||||
prompt_block_id,
|
||||
end_block_id,
|
||||
range,
|
||||
prompt_editor.read(cx).codegen.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
|
@ -445,7 +461,7 @@ impl InlineAssistant {
|
|||
let buffer = editor.buffer().read(cx).snapshot(cx);
|
||||
for assist_id in &editor_assists.assist_ids {
|
||||
let assist = &self.assists[assist_id];
|
||||
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
|
||||
let assist_range = assist.range.to_offset(&buffer);
|
||||
if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
|
||||
{
|
||||
if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
|
||||
|
@ -473,7 +489,7 @@ impl InlineAssistant {
|
|||
let buffer = editor.buffer().read(cx).snapshot(cx);
|
||||
for assist_id in &editor_assists.assist_ids {
|
||||
let assist = &self.assists[assist_id];
|
||||
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
|
||||
let assist_range = assist.range.to_offset(&buffer);
|
||||
if assist.decorations.is_some()
|
||||
&& assist_range.contains(&selection.start)
|
||||
&& assist_range.contains(&selection.end)
|
||||
|
@ -551,7 +567,7 @@ impl InlineAssistant {
|
|||
assist.codegen.read(cx).status,
|
||||
CodegenStatus::Error(_) | CodegenStatus::Done
|
||||
) {
|
||||
let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot);
|
||||
let assist_range = assist.range.to_offset(&snapshot);
|
||||
if edited_ranges
|
||||
.iter()
|
||||
.any(|range| range.overlaps(&assist_range))
|
||||
|
@ -721,7 +737,7 @@ impl InlineAssistant {
|
|||
});
|
||||
}
|
||||
|
||||
let position = assist.codegen.read(cx).range.start;
|
||||
let position = assist.range.start;
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.change_selections(None, cx, |selections| {
|
||||
selections.select_anchor_ranges([position..position])
|
||||
|
@ -740,8 +756,7 @@ impl InlineAssistant {
|
|||
.0 as f32;
|
||||
} else {
|
||||
let snapshot = editor.snapshot(cx);
|
||||
let codegen = assist.codegen.read(cx);
|
||||
let start_row = codegen
|
||||
let start_row = assist
|
||||
.range
|
||||
.start
|
||||
.to_display_point(&snapshot.display_snapshot)
|
||||
|
@ -829,11 +844,7 @@ impl InlineAssistant {
|
|||
return;
|
||||
}
|
||||
|
||||
let Some(user_prompt) = assist
|
||||
.decorations
|
||||
.as_ref()
|
||||
.map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
|
||||
else {
|
||||
let Some(user_prompt) = assist.user_prompt(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
|
@ -843,139 +854,19 @@ impl InlineAssistant {
|
|||
self.prompt_history.pop_front();
|
||||
}
|
||||
|
||||
let codegen = assist.codegen.clone();
|
||||
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||
} else {
|
||||
let request = self.request_for_inline_assist(assist_id, cx);
|
||||
let mut cx = cx.to_async();
|
||||
async move {
|
||||
let request = request.await?;
|
||||
let chunks = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.stream_completion(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
Ok(chunks.boxed())
|
||||
}
|
||||
.boxed_local()
|
||||
};
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(telemetry_id, chunks, cx);
|
||||
});
|
||||
}
|
||||
let assistant_panel_context = assist.assistant_panel_context(cx);
|
||||
|
||||
fn request_for_inline_assist(
|
||||
&self,
|
||||
assist_id: InlineAssistId,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<LanguageModelRequest>> {
|
||||
cx.spawn(|mut cx| async move {
|
||||
let (user_prompt, context_request, project_name, buffer, range) =
|
||||
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
|
||||
let assist = this.assists.get(&assist_id).context("invalid assist")?;
|
||||
let decorations = assist.decorations.as_ref().context("invalid assist")?;
|
||||
let editor = assist.editor.upgrade().context("invalid assist")?;
|
||||
let user_prompt = decorations.prompt_editor.read(cx).prompt(cx);
|
||||
let context_request = if assist.include_context {
|
||||
assist.workspace.as_ref().and_then(|workspace| {
|
||||
let workspace = workspace.upgrade()?.read(cx);
|
||||
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
|
||||
Some(
|
||||
assistant_panel
|
||||
.read(cx)
|
||||
.active_context(cx)?
|
||||
.read(cx)
|
||||
.to_completion_request(cx),
|
||||
)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let project_name = assist.workspace.as_ref().and_then(|workspace| {
|
||||
let workspace = workspace.upgrade()?;
|
||||
Some(
|
||||
workspace
|
||||
.read(cx)
|
||||
.project()
|
||||
.read(cx)
|
||||
.worktree_root_names(cx)
|
||||
.collect::<Vec<&str>>()
|
||||
.join("/"),
|
||||
)
|
||||
});
|
||||
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let range = assist.codegen.read(cx).range.clone();
|
||||
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
|
||||
})??;
|
||||
|
||||
let language = buffer.language_at(range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
|
||||
None
|
||||
} else {
|
||||
Some(language.name())
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Higher Temperature increases the randomness of model outputs.
|
||||
// If Markdown or No Language is Known, increase the randomness for more creative output
|
||||
// If Code, decrease temperature to get more deterministic outputs
|
||||
let temperature = if let Some(language) = language_name.clone() {
|
||||
if language.as_ref() == "Markdown" {
|
||||
1.0
|
||||
} else {
|
||||
0.5
|
||||
}
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let prompt = cx
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
let language_name = language_name.as_deref();
|
||||
let start = buffer.point_to_buffer_offset(range.start);
|
||||
let end = buffer.point_to_buffer_offset(range.end);
|
||||
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
|
||||
let (start_buffer, start_buffer_offset) = start;
|
||||
let (end_buffer, end_buffer_offset) = end;
|
||||
if start_buffer.remote_id() == end_buffer.remote_id() {
|
||||
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
|
||||
} else {
|
||||
return Err(anyhow!("invalid transformation range"));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("invalid transformation range"));
|
||||
};
|
||||
generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
if let Some(context_request) = context_request {
|
||||
messages = context_request.messages;
|
||||
}
|
||||
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
messages,
|
||||
stop: vec!["|END|>".to_string()],
|
||||
temperature,
|
||||
assist
|
||||
.codegen
|
||||
.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
assist.range.clone(),
|
||||
user_prompt,
|
||||
assistant_panel_context,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
|
||||
|
@ -1006,12 +897,11 @@ impl InlineAssistant {
|
|||
let codegen = assist.codegen.read(cx);
|
||||
foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
|
||||
|
||||
if codegen.edit_position != codegen.range.end {
|
||||
gutter_pending_ranges.push(codegen.edit_position..codegen.range.end);
|
||||
}
|
||||
gutter_pending_ranges
|
||||
.push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
|
||||
|
||||
if codegen.range.start != codegen.edit_position {
|
||||
gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position);
|
||||
if let Some(edit_position) = codegen.edit_position {
|
||||
gutter_transformed_ranges.push(assist.range.start..edit_position);
|
||||
}
|
||||
|
||||
if assist.decorations.is_some() {
|
||||
|
@ -1268,6 +1158,12 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
|
|||
})
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub enum InitialInsertion {
|
||||
NewlineBefore,
|
||||
NewlineAfter,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct InlineAssistId(usize);
|
||||
|
||||
|
@ -1629,24 +1525,20 @@ impl PromptEditor {
|
|||
let assist_id = self.id;
|
||||
self.pending_token_count = cx.spawn(|this, mut cx| async move {
|
||||
cx.background_executor().timer(Duration::from_secs(1)).await;
|
||||
let request = cx
|
||||
let token_count = cx
|
||||
.update_global(|inline_assistant: &mut InlineAssistant, cx| {
|
||||
inline_assistant.request_for_inline_assist(assist_id, cx)
|
||||
})?
|
||||
let assist = inline_assistant
|
||||
.assists
|
||||
.get(&assist_id)
|
||||
.context("assist not found")?;
|
||||
anyhow::Ok(assist.count_tokens(cx))
|
||||
})??
|
||||
.await?;
|
||||
|
||||
if let Some(token_count) = cx.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})? {
|
||||
let token_count = token_count.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify();
|
||||
})
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1855,6 +1747,7 @@ impl PromptEditor {
|
|||
|
||||
struct InlineAssist {
|
||||
group_id: InlineAssistGroupId,
|
||||
range: Range<Anchor>,
|
||||
editor: WeakView<Editor>,
|
||||
decorations: Option<InlineAssistDecorations>,
|
||||
codegen: Model<Codegen>,
|
||||
|
@ -1873,6 +1766,7 @@ impl InlineAssist {
|
|||
prompt_editor: &View<PromptEditor>,
|
||||
prompt_block_id: CustomBlockId,
|
||||
end_block_id: CustomBlockId,
|
||||
range: Range<Anchor>,
|
||||
codegen: Model<Codegen>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
cx: &mut WindowContext,
|
||||
|
@ -1888,6 +1782,7 @@ impl InlineAssist {
|
|||
removed_line_block_ids: HashSet::default(),
|
||||
end_block_id,
|
||||
}),
|
||||
range,
|
||||
codegen: codegen.clone(),
|
||||
workspace: workspace.clone(),
|
||||
_subscriptions: vec![
|
||||
|
@ -1963,6 +1858,41 @@ impl InlineAssist {
|
|||
],
|
||||
}
|
||||
}
|
||||
|
||||
fn user_prompt(&self, cx: &AppContext) -> Option<String> {
|
||||
let decorations = self.decorations.as_ref()?;
|
||||
Some(decorations.prompt_editor.read(cx).prompt(cx))
|
||||
}
|
||||
|
||||
fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
|
||||
if self.include_context {
|
||||
let workspace = self.workspace.as_ref()?;
|
||||
let workspace = workspace.upgrade()?.read(cx);
|
||||
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
|
||||
Some(
|
||||
assistant_panel
|
||||
.read(cx)
|
||||
.active_context(cx)?
|
||||
.read(cx)
|
||||
.to_completion_request(cx),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
|
||||
let Some(user_prompt) = self.user_prompt(cx) else {
|
||||
return future::ready(Err(anyhow!("no user prompt"))).boxed();
|
||||
};
|
||||
let assistant_panel_context = self.assistant_panel_context(cx);
|
||||
self.codegen.read(cx).count_tokens(
|
||||
self.range.clone(),
|
||||
user_prompt,
|
||||
assistant_panel_context,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
struct InlineAssistDecorations {
|
||||
|
@ -1982,16 +1912,15 @@ pub struct Codegen {
|
|||
buffer: Model<MultiBuffer>,
|
||||
old_buffer: Model<Buffer>,
|
||||
snapshot: MultiBufferSnapshot,
|
||||
range: Range<Anchor>,
|
||||
edit_position: Anchor,
|
||||
edit_position: Option<Anchor>,
|
||||
last_equal_ranges: Vec<Range<Anchor>>,
|
||||
prepend_transaction_id: Option<TransactionId>,
|
||||
generation_transaction_id: Option<TransactionId>,
|
||||
transaction_id: Option<TransactionId>,
|
||||
status: CodegenStatus,
|
||||
generation: Task<()>,
|
||||
diff: Diff,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
initial_insertion: Option<InitialInsertion>,
|
||||
}
|
||||
|
||||
enum CodegenStatus {
|
||||
|
@ -2015,7 +1944,7 @@ impl Codegen {
|
|||
pub fn new(
|
||||
buffer: Model<MultiBuffer>,
|
||||
range: Range<Anchor>,
|
||||
prepend_transaction_id: Option<TransactionId>,
|
||||
initial_insertion: Option<InitialInsertion>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
|
@ -2044,17 +1973,16 @@ impl Codegen {
|
|||
Self {
|
||||
buffer: buffer.clone(),
|
||||
old_buffer,
|
||||
edit_position: range.start,
|
||||
range,
|
||||
edit_position: None,
|
||||
snapshot,
|
||||
last_equal_ranges: Default::default(),
|
||||
prepend_transaction_id,
|
||||
generation_transaction_id: None,
|
||||
transaction_id: None,
|
||||
status: CodegenStatus::Idle,
|
||||
generation: Task::ready(()),
|
||||
diff: Diff::default(),
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
initial_insertion,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2065,13 +1993,8 @@ impl Codegen {
|
|||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
|
||||
if self.generation_transaction_id == Some(*transaction_id) {
|
||||
self.generation_transaction_id = None;
|
||||
self.generation = Task::ready(());
|
||||
cx.emit(CodegenEvent::Undone);
|
||||
} else if self.prepend_transaction_id == Some(*transaction_id) {
|
||||
self.prepend_transaction_id = None;
|
||||
self.generation_transaction_id = None;
|
||||
if self.transaction_id == Some(*transaction_id) {
|
||||
self.transaction_id = None;
|
||||
self.generation = Task::ready(());
|
||||
cx.emit(CodegenEvent::Undone);
|
||||
}
|
||||
|
@ -2082,19 +2005,152 @@ impl Codegen {
|
|||
&self.last_equal_ranges
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
edit_range: Range<Anchor>,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
}
|
||||
|
||||
pub fn start(
|
||||
&mut self,
|
||||
telemetry_id: String,
|
||||
mut edit_range: Range<Anchor>,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Result<()> {
|
||||
self.undo(cx);
|
||||
|
||||
// Handle initial insertion
|
||||
self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
|
||||
self.buffer.update(cx, |buffer, cx| {
|
||||
buffer.start_transaction(cx);
|
||||
let offset = edit_range.start.to_offset(&self.snapshot);
|
||||
let edit_position;
|
||||
match initial_insertion {
|
||||
InitialInsertion::NewlineBefore => {
|
||||
buffer.edit([(offset..offset, "\n\n")], None, cx);
|
||||
self.snapshot = buffer.snapshot(cx);
|
||||
edit_position = self.snapshot.anchor_after(offset + 1);
|
||||
}
|
||||
InitialInsertion::NewlineAfter => {
|
||||
buffer.edit([(offset..offset, "\n")], None, cx);
|
||||
self.snapshot = buffer.snapshot(cx);
|
||||
edit_position = self.snapshot.anchor_after(offset);
|
||||
}
|
||||
}
|
||||
self.edit_position = Some(edit_position);
|
||||
edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
|
||||
buffer.end_transaction(cx)
|
||||
})
|
||||
} else {
|
||||
self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
|
||||
None
|
||||
};
|
||||
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model_telemetry_id()
|
||||
.context("no active model")?;
|
||||
|
||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
|
||||
.trim()
|
||||
.to_lowercase()
|
||||
== "delete"
|
||||
{
|
||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||
} else {
|
||||
let request =
|
||||
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
|
||||
let chunks =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
||||
};
|
||||
self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request(
|
||||
&self,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
edit_range: Range<Anchor>,
|
||||
cx: &AppContext,
|
||||
) -> LanguageModelRequest {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(edit_range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
|
||||
None
|
||||
} else {
|
||||
Some(language.name())
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Higher Temperature increases the randomness of model outputs.
|
||||
// If Markdown or No Language is Known, increase the randomness for more creative output
|
||||
// If Code, decrease temperature to get more deterministic outputs
|
||||
let temperature = if let Some(language) = language_name.clone() {
|
||||
if language.as_ref() == "Markdown" {
|
||||
1.0
|
||||
} else {
|
||||
0.5
|
||||
}
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let language_name = language_name.as_deref();
|
||||
let start = buffer.point_to_buffer_offset(edit_range.start);
|
||||
let end = buffer.point_to_buffer_offset(edit_range.end);
|
||||
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
|
||||
let (start_buffer, start_buffer_offset) = start;
|
||||
let (end_buffer, end_buffer_offset) = end;
|
||||
if start_buffer.remote_id() == end_buffer.remote_id() {
|
||||
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
|
||||
} else {
|
||||
panic!("invalid transformation range");
|
||||
}
|
||||
} else {
|
||||
panic!("invalid transformation range");
|
||||
};
|
||||
let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
|
||||
|
||||
let mut messages = Vec::new();
|
||||
if let Some(context_request) = assistant_panel_context {
|
||||
messages = context_request.messages;
|
||||
}
|
||||
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
LanguageModelRequest {
|
||||
messages,
|
||||
stop: vec!["|END|>".to_string()],
|
||||
temperature,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_stream(
|
||||
&mut self,
|
||||
model_telemetry_id: String,
|
||||
edit_range: Range<Anchor>,
|
||||
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let range = self.range.clone();
|
||||
let snapshot = self.snapshot.clone();
|
||||
let selected_text = snapshot
|
||||
.text_for_range(range.start..range.end)
|
||||
.text_for_range(edit_range.start..edit_range.end)
|
||||
.collect::<Rope>();
|
||||
|
||||
let selection_start = range.start.to_point(&snapshot);
|
||||
let selection_start = edit_range.start.to_point(&snapshot);
|
||||
|
||||
// Start with the indentation of the first line in the selection
|
||||
let mut suggested_line_indent = snapshot
|
||||
|
@ -2105,7 +2161,7 @@ impl Codegen {
|
|||
|
||||
// If the first line in the selection does not have indentation, check the following lines
|
||||
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
|
||||
for row in selection_start.row..=range.end.to_point(&snapshot).row {
|
||||
for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
|
||||
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
|
||||
// Prefer tabs if a line in the selection uses tabs as indentation
|
||||
if line_indent.kind == IndentKind::Tab {
|
||||
|
@ -2116,19 +2172,13 @@ impl Codegen {
|
|||
}
|
||||
|
||||
let telemetry = self.telemetry.clone();
|
||||
self.edit_position = range.start;
|
||||
self.diff = Diff::default();
|
||||
self.status = CodegenStatus::Pending;
|
||||
if let Some(transaction_id) = self.generation_transaction_id.take() {
|
||||
self.buffer
|
||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
||||
}
|
||||
let mut edit_start = edit_range.start.to_offset(&snapshot);
|
||||
self.generation = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let chunks = stream.await;
|
||||
let generate = async {
|
||||
let mut edit_start = range.start.to_offset(&snapshot);
|
||||
|
||||
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
||||
let diff: Task<anyhow::Result<()>> =
|
||||
cx.background_executor().spawn(async move {
|
||||
|
@ -2218,7 +2268,7 @@ impl Codegen {
|
|||
telemetry.report_assistant_event(
|
||||
None,
|
||||
telemetry_events::AssistantKind::Inline,
|
||||
telemetry_id,
|
||||
model_telemetry_id,
|
||||
response_latency,
|
||||
error_message,
|
||||
);
|
||||
|
@ -2262,13 +2312,13 @@ impl Codegen {
|
|||
None,
|
||||
cx,
|
||||
);
|
||||
this.edit_position = snapshot.anchor_after(edit_start);
|
||||
this.edit_position = Some(snapshot.anchor_after(edit_start));
|
||||
|
||||
buffer.end_transaction(cx)
|
||||
});
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
if let Some(first_transaction) = this.generation_transaction_id {
|
||||
if let Some(first_transaction) = this.transaction_id {
|
||||
// Group all assistant edits into the first transaction.
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
buffer.merge_transactions(
|
||||
|
@ -2278,14 +2328,14 @@ impl Codegen {
|
|||
)
|
||||
});
|
||||
} else {
|
||||
this.generation_transaction_id = Some(transaction);
|
||||
this.transaction_id = Some(transaction);
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
buffer.finalize_last_transaction(cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
this.update_diff(cx);
|
||||
this.update_diff(edit_range.clone(), cx);
|
||||
cx.notify();
|
||||
})?;
|
||||
}
|
||||
|
@ -2321,27 +2371,22 @@ impl Codegen {
|
|||
}
|
||||
|
||||
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
|
||||
if let Some(transaction_id) = self.prepend_transaction_id.take() {
|
||||
self.buffer
|
||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
||||
}
|
||||
|
||||
if let Some(transaction_id) = self.generation_transaction_id.take() {
|
||||
if let Some(transaction_id) = self.transaction_id.take() {
|
||||
self.buffer
|
||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
||||
}
|
||||
}
|
||||
|
||||
fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
|
||||
fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
|
||||
if self.diff.task.is_some() {
|
||||
self.diff.should_update = true;
|
||||
} else {
|
||||
self.diff.should_update = false;
|
||||
|
||||
let old_snapshot = self.snapshot.clone();
|
||||
let old_range = self.range.to_point(&old_snapshot);
|
||||
let old_range = edit_range.to_point(&old_snapshot);
|
||||
let new_snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let new_range = self.range.to_point(&new_snapshot);
|
||||
let new_range = edit_range.to_point(&new_snapshot);
|
||||
|
||||
self.diff.task = Some(cx.spawn(|this, mut cx| async move {
|
||||
let (deleted_row_ranges, inserted_row_ranges) = cx
|
||||
|
@ -2422,7 +2467,7 @@ impl Codegen {
|
|||
this.diff.inserted_row_ranges = inserted_row_ranges;
|
||||
this.diff.task = None;
|
||||
if this.diff.should_update {
|
||||
this.update_diff(cx);
|
||||
this.update_diff(edit_range, cx);
|
||||
}
|
||||
cx.notify();
|
||||
})
|
||||
|
@ -2629,12 +2674,14 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range,
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
|
@ -2690,12 +2737,14 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
|
@ -2755,12 +2804,14 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
|
@ -2819,12 +2870,14 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
|
|
|
@ -734,29 +734,27 @@ impl PromptLibrary {
|
|||
const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1);
|
||||
|
||||
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
||||
if let Some(token_count) = cx.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(
|
||||
LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: body.to_string(),
|
||||
}],
|
||||
stop: Vec::new(),
|
||||
temperature: 1.,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})? {
|
||||
let token_count = token_count.await?;
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(
|
||||
LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: body.to_string(),
|
||||
}],
|
||||
stop: Vec::new(),
|
||||
temperature: 1.,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
|
||||
prompt_editor.token_count = Some(token_count);
|
||||
cx.notify();
|
||||
})
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
|
||||
prompt_editor.token_count = Some(token_count);
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
.log_err()
|
||||
});
|
||||
|
|
|
@ -6,8 +6,7 @@ pub fn generate_content_prompt(
|
|||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
_project_name: Option<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
let content_type = match language_name {
|
||||
|
@ -15,14 +14,16 @@ pub fn generate_content_prompt(
|
|||
writeln!(
|
||||
prompt,
|
||||
"Here's a file of text that I'm going to ask you to make an edit to."
|
||||
)?;
|
||||
)
|
||||
.unwrap();
|
||||
"text"
|
||||
}
|
||||
Some(language_name) => {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Here's a file of {language_name} that I'm going to ask you to make an edit to."
|
||||
)?;
|
||||
)
|
||||
.unwrap();
|
||||
"code"
|
||||
}
|
||||
};
|
||||
|
@ -70,7 +71,7 @@ pub fn generate_content_prompt(
|
|||
write!(prompt, "</document>\n\n").unwrap();
|
||||
|
||||
if is_truncated {
|
||||
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
|
||||
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
|
||||
}
|
||||
|
||||
if range.is_empty() {
|
||||
|
@ -107,7 +108,7 @@ pub fn generate_content_prompt(
|
|||
prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
prompt
|
||||
}
|
||||
|
||||
pub fn generate_terminal_assistant_prompt(
|
||||
|
|
|
@ -707,18 +707,15 @@ impl PromptEditor {
|
|||
inline_assistant.request_for_inline_assist(assist_id, cx)
|
||||
})??;
|
||||
|
||||
if let Some(token_count) = cx.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})? {
|
||||
let token_count = token_count.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify();
|
||||
})
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue