assistant: Edit files tool (#26506)
Exposes a new "edit files" tool that the model can use to apply modifications to files in the project. The main model provides instructions and the tool uses a separate "editor" model (Claude 3.5 by default) to generate search/replace blocks like Aider does: ````markdown mathweb/flask/app.py ```python <<<<<<< SEARCH from flask import Flask ======= import math from flask import Flask >>>>>>> REPLACE ``` ```` The search/replace blocks are parsed and applied as they stream in. If a block fails to parse, the tool will apply the other edits and report an error pointing to the part of the input where it occurred. This should allow the model to fix it. Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
f3f97895a9
commit
47a89ad243
21 changed files with 1216 additions and 7 deletions
7
Cargo.lock
generated
7
Cargo.lock
generated
|
@ -658,6 +658,7 @@ dependencies = [
|
||||||
"collections",
|
"collections",
|
||||||
"derive_more",
|
"derive_more",
|
||||||
"gpui",
|
"gpui",
|
||||||
|
"language_model",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"project",
|
"project",
|
||||||
"serde",
|
"serde",
|
||||||
|
@ -671,11 +672,16 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"assistant_tool",
|
"assistant_tool",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
"collections",
|
||||||
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
|
"language_model",
|
||||||
"project",
|
"project",
|
||||||
|
"rand 0.8.5",
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"util",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3128,6 +3134,7 @@ dependencies = [
|
||||||
"extension",
|
"extension",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
|
"language_model",
|
||||||
"log",
|
"log",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"postage",
|
"postage",
|
||||||
|
|
|
@ -600,6 +600,13 @@
|
||||||
"provider": "zed.dev",
|
"provider": "zed.dev",
|
||||||
// The model to use.
|
// The model to use.
|
||||||
"model": "claude-3-5-sonnet-latest"
|
"model": "claude-3-5-sonnet-latest"
|
||||||
|
},
|
||||||
|
// The model to use when applying edits from the assistant.
|
||||||
|
"editor_model": {
|
||||||
|
// The provider to use.
|
||||||
|
"provider": "zed.dev",
|
||||||
|
// The model to use.
|
||||||
|
"model": "claude-3-5-sonnet-latest"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// The settings for slash commands.
|
// The settings for slash commands.
|
||||||
|
|
|
@ -186,8 +186,12 @@ fn init_language_model_settings(cx: &mut App) {
|
||||||
|
|
||||||
fn update_active_language_model_from_settings(cx: &mut App) {
|
fn update_active_language_model_from_settings(cx: &mut App) {
|
||||||
let settings = AssistantSettings::get_global(cx);
|
let settings = AssistantSettings::get_global(cx);
|
||||||
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
|
let active_model_provider_name =
|
||||||
let model_id = LanguageModelId::from(settings.default_model.model.clone());
|
LanguageModelProviderId::from(settings.default_model.provider.clone());
|
||||||
|
let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
|
||||||
|
let editor_provider_name =
|
||||||
|
LanguageModelProviderId::from(settings.editor_model.provider.clone());
|
||||||
|
let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
|
||||||
let inline_alternatives = settings
|
let inline_alternatives = settings
|
||||||
.inline_alternatives
|
.inline_alternatives
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -199,7 +203,8 @@ fn update_active_language_model_from_settings(cx: &mut App) {
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||||
registry.select_active_model(&provider_name, &model_id, cx);
|
registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
|
||||||
|
registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
|
||||||
registry.select_inline_alternative_models(inline_alternatives, cx);
|
registry.select_inline_alternative_models(inline_alternatives, cx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -297,7 +297,8 @@ impl AssistantPanel {
|
||||||
&LanguageModelRegistry::global(cx),
|
&LanguageModelRegistry::global(cx),
|
||||||
window,
|
window,
|
||||||
|this, _, event: &language_model::Event, window, cx| match event {
|
|this, _, event: &language_model::Event, window, cx| match event {
|
||||||
language_model::Event::ActiveModelChanged => {
|
language_model::Event::ActiveModelChanged
|
||||||
|
| language_model::Event::EditorModelChanged => {
|
||||||
this.completion_provider_changed(window, cx);
|
this.completion_provider_changed(window, cx);
|
||||||
}
|
}
|
||||||
language_model::Event::ProviderStateChanged => {
|
language_model::Event::ProviderStateChanged => {
|
||||||
|
|
|
@ -652,7 +652,7 @@ impl ActiveThread {
|
||||||
)
|
)
|
||||||
.child(message_content),
|
.child(message_content),
|
||||||
),
|
),
|
||||||
Role::Assistant => div()
|
Role::Assistant => v_flex()
|
||||||
.id(("message-container", ix))
|
.id(("message-container", ix))
|
||||||
.child(message_content)
|
.child(message_content)
|
||||||
.map(|parent| {
|
.map(|parent| {
|
||||||
|
|
|
@ -623,6 +623,7 @@ impl Thread {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
|
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
|
||||||
|
let request = self.to_completion_request(RequestKind::Chat, cx);
|
||||||
let pending_tool_uses = self
|
let pending_tool_uses = self
|
||||||
.tool_use
|
.tool_use
|
||||||
.pending_tool_uses()
|
.pending_tool_uses()
|
||||||
|
@ -633,7 +634,7 @@ impl Thread {
|
||||||
|
|
||||||
for tool_use in pending_tool_uses {
|
for tool_use in pending_tool_uses {
|
||||||
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
||||||
let task = tool.run(tool_use.input, self.project.clone(), cx);
|
let task = tool.run(tool_use.input, &request.messages, self.project.clone(), cx);
|
||||||
|
|
||||||
self.insert_tool_output(tool_use.id.clone(), task, cx);
|
self.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,6 +62,7 @@ pub struct AssistantSettings {
|
||||||
pub default_width: Pixels,
|
pub default_width: Pixels,
|
||||||
pub default_height: Pixels,
|
pub default_height: Pixels,
|
||||||
pub default_model: LanguageModelSelection,
|
pub default_model: LanguageModelSelection,
|
||||||
|
pub editor_model: LanguageModelSelection,
|
||||||
pub inline_alternatives: Vec<LanguageModelSelection>,
|
pub inline_alternatives: Vec<LanguageModelSelection>,
|
||||||
pub using_outdated_settings_version: bool,
|
pub using_outdated_settings_version: bool,
|
||||||
pub enable_experimental_live_diffs: bool,
|
pub enable_experimental_live_diffs: bool,
|
||||||
|
@ -162,6 +163,7 @@ impl AssistantSettingsContent {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
|
editor_model: None,
|
||||||
inline_alternatives: None,
|
inline_alternatives: None,
|
||||||
enable_experimental_live_diffs: None,
|
enable_experimental_live_diffs: None,
|
||||||
},
|
},
|
||||||
|
@ -182,6 +184,7 @@ impl AssistantSettingsContent {
|
||||||
.id()
|
.id()
|
||||||
.to_string(),
|
.to_string(),
|
||||||
}),
|
}),
|
||||||
|
editor_model: None,
|
||||||
inline_alternatives: None,
|
inline_alternatives: None,
|
||||||
enable_experimental_live_diffs: None,
|
enable_experimental_live_diffs: None,
|
||||||
},
|
},
|
||||||
|
@ -310,6 +313,7 @@ impl Default for VersionedAssistantSettingsContent {
|
||||||
default_width: None,
|
default_width: None,
|
||||||
default_height: None,
|
default_height: None,
|
||||||
default_model: None,
|
default_model: None,
|
||||||
|
editor_model: None,
|
||||||
inline_alternatives: None,
|
inline_alternatives: None,
|
||||||
enable_experimental_live_diffs: None,
|
enable_experimental_live_diffs: None,
|
||||||
})
|
})
|
||||||
|
@ -340,6 +344,8 @@ pub struct AssistantSettingsContentV2 {
|
||||||
default_height: Option<f32>,
|
default_height: Option<f32>,
|
||||||
/// The default model to use when creating new chats.
|
/// The default model to use when creating new chats.
|
||||||
default_model: Option<LanguageModelSelection>,
|
default_model: Option<LanguageModelSelection>,
|
||||||
|
/// The model to use when applying edits from the assistant.
|
||||||
|
editor_model: Option<LanguageModelSelection>,
|
||||||
/// Additional models with which to generate alternatives when performing inline assists.
|
/// Additional models with which to generate alternatives when performing inline assists.
|
||||||
inline_alternatives: Option<Vec<LanguageModelSelection>>,
|
inline_alternatives: Option<Vec<LanguageModelSelection>>,
|
||||||
/// Enable experimental live diffs in the assistant panel.
|
/// Enable experimental live diffs in the assistant panel.
|
||||||
|
@ -470,6 +476,7 @@ impl Settings for AssistantSettings {
|
||||||
value.default_height.map(Into::into),
|
value.default_height.map(Into::into),
|
||||||
);
|
);
|
||||||
merge(&mut settings.default_model, value.default_model);
|
merge(&mut settings.default_model, value.default_model);
|
||||||
|
merge(&mut settings.editor_model, value.editor_model);
|
||||||
merge(&mut settings.inline_alternatives, value.inline_alternatives);
|
merge(&mut settings.inline_alternatives, value.inline_alternatives);
|
||||||
merge(
|
merge(
|
||||||
&mut settings.enable_experimental_live_diffs,
|
&mut settings.enable_experimental_live_diffs,
|
||||||
|
@ -528,6 +535,10 @@ mod tests {
|
||||||
provider: "test-provider".into(),
|
provider: "test-provider".into(),
|
||||||
model: "gpt-99".into(),
|
model: "gpt-99".into(),
|
||||||
}),
|
}),
|
||||||
|
editor_model: Some(LanguageModelSelection {
|
||||||
|
provider: "test-provider".into(),
|
||||||
|
model: "gpt-99".into(),
|
||||||
|
}),
|
||||||
inline_alternatives: None,
|
inline_alternatives: None,
|
||||||
enabled: None,
|
enabled: None,
|
||||||
button: None,
|
button: None,
|
||||||
|
|
|
@ -15,6 +15,7 @@ path = "src/assistant_tool.rs"
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
derive_more.workspace = true
|
derive_more.workspace = true
|
||||||
|
language_model.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
|
|
|
@ -5,6 +5,7 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::{App, Entity, SharedString, Task};
|
use gpui::{App, Entity, SharedString, Task};
|
||||||
|
use language_model::LanguageModelRequestMessage;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
|
||||||
pub use crate::tool_registry::*;
|
pub use crate::tool_registry::*;
|
||||||
|
@ -44,6 +45,7 @@ pub trait Tool: 'static + Send + Sync {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
messages: &[LanguageModelRequestMessage],
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<String>>;
|
) -> Task<Result<String>>;
|
||||||
|
|
|
@ -15,8 +15,18 @@ path = "src/assistant_tools.rs"
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
assistant_tool.workspace = true
|
assistant_tool.workspace = true
|
||||||
chrono.workspace = true
|
chrono.workspace = true
|
||||||
|
collections.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
language_model.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
util.workspace = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rand.workspace = true
|
||||||
|
collections = { workspace = true, features = ["test-support"] }
|
||||||
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
|
project = { workspace = true, features = ["test-support"] }
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
mod edit_files_tool;
|
||||||
mod list_worktrees_tool;
|
mod list_worktrees_tool;
|
||||||
mod now_tool;
|
mod now_tool;
|
||||||
mod read_file_tool;
|
mod read_file_tool;
|
||||||
|
@ -5,6 +6,7 @@ mod read_file_tool;
|
||||||
use assistant_tool::ToolRegistry;
|
use assistant_tool::ToolRegistry;
|
||||||
use gpui::App;
|
use gpui::App;
|
||||||
|
|
||||||
|
use crate::edit_files_tool::EditFilesTool;
|
||||||
use crate::list_worktrees_tool::ListWorktreesTool;
|
use crate::list_worktrees_tool::ListWorktreesTool;
|
||||||
use crate::now_tool::NowTool;
|
use crate::now_tool::NowTool;
|
||||||
use crate::read_file_tool::ReadFileTool;
|
use crate::read_file_tool::ReadFileTool;
|
||||||
|
@ -16,4 +18,5 @@ pub fn init(cx: &mut App) {
|
||||||
registry.register_tool(NowTool);
|
registry.register_tool(NowTool);
|
||||||
registry.register_tool(ListWorktreesTool);
|
registry.register_tool(ListWorktreesTool);
|
||||||
registry.register_tool(ReadFileTool);
|
registry.register_tool(ReadFileTool);
|
||||||
|
registry.register_tool(EditFilesTool);
|
||||||
}
|
}
|
||||||
|
|
155
crates/assistant_tools/src/edit_files_tool.rs
Normal file
155
crates/assistant_tools/src/edit_files_tool.rs
Normal file
|
@ -0,0 +1,155 @@
|
||||||
|
mod edit_action;
|
||||||
|
|
||||||
|
use collections::HashSet;
|
||||||
|
use std::{path::Path, sync::Arc};
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use assistant_tool::Tool;
|
||||||
|
use edit_action::{EditAction, EditActionParser};
|
||||||
|
use futures::StreamExt;
|
||||||
|
use gpui::{App, Entity, Task};
|
||||||
|
use language_model::{
|
||||||
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||||
|
};
|
||||||
|
use project::{Project, ProjectPath, WorktreeId};
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
pub struct EditFilesToolInput {
|
||||||
|
/// The ID of the worktree in which the files reside.
|
||||||
|
pub worktree_id: usize,
|
||||||
|
/// Instruct how to modify the files.
|
||||||
|
pub edit_instructions: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EditFilesTool;
|
||||||
|
|
||||||
|
impl Tool for EditFilesTool {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"edit-files".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> String {
|
||||||
|
include_str!("./edit_files_tool/description.md").into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
let schema = schemars::schema_for!(EditFilesToolInput);
|
||||||
|
serde_json::to_value(&schema).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
messages: &[LanguageModelRequestMessage],
|
||||||
|
project: Entity<Project>,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<String>> {
|
||||||
|
let input = match serde_json::from_value::<EditFilesToolInput>(input) {
|
||||||
|
Ok(input) => input,
|
||||||
|
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||||
|
};
|
||||||
|
|
||||||
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
let Some(model) = model_registry.editor_model() else {
|
||||||
|
return Task::ready(Err(anyhow!("No editor model configured")));
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut messages = messages.to_vec();
|
||||||
|
if let Some(last_message) = messages.last_mut() {
|
||||||
|
// Strip out tool use from the last message because we're in the middle of executing a tool call.
|
||||||
|
last_message
|
||||||
|
.content
|
||||||
|
.retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
|
||||||
|
}
|
||||||
|
messages.push(LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec![
|
||||||
|
include_str!("./edit_files_tool/edit_prompt.md").into(),
|
||||||
|
input.edit_instructions.into(),
|
||||||
|
],
|
||||||
|
cache: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
let request = LanguageModelRequest {
|
||||||
|
messages,
|
||||||
|
tools: vec![],
|
||||||
|
stop: vec![],
|
||||||
|
temperature: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
|
||||||
|
let stream = model.stream_completion_text(request, &cx);
|
||||||
|
let mut chunks = stream.await?;
|
||||||
|
|
||||||
|
let mut changed_buffers = HashSet::default();
|
||||||
|
let mut applied_edits = 0;
|
||||||
|
|
||||||
|
while let Some(chunk) = chunks.stream.next().await {
|
||||||
|
for action in parser.parse_chunk(&chunk?) {
|
||||||
|
let project_path = ProjectPath {
|
||||||
|
worktree_id: WorktreeId::from_usize(input.worktree_id),
|
||||||
|
path: Path::new(action.file_path()).into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let buffer = project
|
||||||
|
.update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let diff = buffer
|
||||||
|
.read_with(&cx, |buffer, cx| {
|
||||||
|
let new_text = match action {
|
||||||
|
EditAction::Replace { old, new, .. } => {
|
||||||
|
// TODO: Replace in background?
|
||||||
|
buffer.text().replace(&old, &new)
|
||||||
|
}
|
||||||
|
EditAction::Write { content, .. } => content,
|
||||||
|
};
|
||||||
|
|
||||||
|
buffer.diff(new_text, cx)
|
||||||
|
})?
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let _clock =
|
||||||
|
buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
|
||||||
|
|
||||||
|
changed_buffers.insert(buffer);
|
||||||
|
|
||||||
|
applied_edits += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save each buffer once at the end
|
||||||
|
for buffer in changed_buffers {
|
||||||
|
project
|
||||||
|
.update(&mut cx, |project, cx| project.save_buffer(buffer, cx))?
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let errors = parser.errors();
|
||||||
|
|
||||||
|
if errors.is_empty() {
|
||||||
|
Ok("Successfully applied all edits".into())
|
||||||
|
} else {
|
||||||
|
let error_message = errors
|
||||||
|
.iter()
|
||||||
|
.map(|e| e.to_string())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
if applied_edits > 0 {
|
||||||
|
Err(anyhow!(
|
||||||
|
"Applied {} edit(s), but some blocks failed to parse:\n{}",
|
||||||
|
applied_edits,
|
||||||
|
error_message
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Err(anyhow!(error_message))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
Edit files in a worktree by providing its id and a description of how to modify the code to complete the request.
|
||||||
|
|
||||||
|
Make instructions unambiguous and complete. Explain all needed code changes clearly and completely, but concisely. Just show the changes needed. DO NOT show the entire updated function/file/etc!
|
807
crates/assistant_tools/src/edit_files_tool/edit_action.rs
Normal file
807
crates/assistant_tools/src/edit_files_tool/edit_action.rs
Normal file
|
@ -0,0 +1,807 @@
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
/// Represents an edit action to be performed on a file.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum EditAction {
|
||||||
|
/// Replace specific content in a file with new content
|
||||||
|
Replace {
|
||||||
|
file_path: String,
|
||||||
|
old: String,
|
||||||
|
new: String,
|
||||||
|
},
|
||||||
|
/// Write content to a file (create or overwrite)
|
||||||
|
Write { file_path: String, content: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EditAction {
|
||||||
|
pub fn file_path(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
EditAction::Replace { file_path, .. } => file_path,
|
||||||
|
EditAction::Write { file_path, .. } => file_path,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parses edit actions from an LLM response.
|
||||||
|
/// See system.md for more details on the format.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EditActionParser {
|
||||||
|
state: State,
|
||||||
|
pre_fence_line: Vec<u8>,
|
||||||
|
marker_ix: usize,
|
||||||
|
line: usize,
|
||||||
|
column: usize,
|
||||||
|
old_bytes: Vec<u8>,
|
||||||
|
new_bytes: Vec<u8>,
|
||||||
|
errors: Vec<ParseError>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
enum State {
|
||||||
|
/// Anywhere outside an action
|
||||||
|
Default,
|
||||||
|
/// After opening ```, in optional language tag
|
||||||
|
OpenFence,
|
||||||
|
/// In SEARCH marker
|
||||||
|
SearchMarker,
|
||||||
|
/// In search block or divider
|
||||||
|
SearchBlock,
|
||||||
|
/// In replace block or REPLACE marker
|
||||||
|
ReplaceBlock,
|
||||||
|
/// In closing ```
|
||||||
|
CloseFence,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EditActionParser {
|
||||||
|
/// Creates a new `EditActionParser`
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
state: State::Default,
|
||||||
|
pre_fence_line: Vec::new(),
|
||||||
|
marker_ix: 0,
|
||||||
|
line: 1,
|
||||||
|
column: 0,
|
||||||
|
old_bytes: Vec::new(),
|
||||||
|
new_bytes: Vec::new(),
|
||||||
|
errors: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Processes a chunk of input text and returns any completed edit actions.
|
||||||
|
///
|
||||||
|
/// This method can be called repeatedly with fragments of input. The parser
|
||||||
|
/// maintains its state between calls, allowing you to process streaming input
|
||||||
|
/// as it becomes available. Actions are only inserted once they are fully parsed.
|
||||||
|
///
|
||||||
|
/// If a block fails to parse, it will simply be skipped and an error will be recorded.
|
||||||
|
/// All errors can be accessed through the `EditActionsParser::errors` method.
|
||||||
|
pub fn parse_chunk(&mut self, input: &str) -> Vec<EditAction> {
|
||||||
|
use State::*;
|
||||||
|
|
||||||
|
const FENCE: &[u8] = b"\n```";
|
||||||
|
const SEARCH_MARKER: &[u8] = b"<<<<<<< SEARCH\n";
|
||||||
|
const DIVIDER: &[u8] = b"=======\n";
|
||||||
|
const NL_DIVIDER: &[u8] = b"\n=======\n";
|
||||||
|
const REPLACE_MARKER: &[u8] = b">>>>>>> REPLACE";
|
||||||
|
const NL_REPLACE_MARKER: &[u8] = b"\n>>>>>>> REPLACE";
|
||||||
|
|
||||||
|
let mut actions = Vec::new();
|
||||||
|
|
||||||
|
for byte in input.bytes() {
|
||||||
|
// Update line and column tracking
|
||||||
|
if byte == b'\n' {
|
||||||
|
self.line += 1;
|
||||||
|
self.column = 0;
|
||||||
|
} else {
|
||||||
|
self.column += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.state {
|
||||||
|
Default => match match_marker(byte, FENCE, &mut self.marker_ix) {
|
||||||
|
MarkerMatch::Complete => {
|
||||||
|
self.to_state(OpenFence);
|
||||||
|
}
|
||||||
|
MarkerMatch::Partial => {}
|
||||||
|
MarkerMatch::None => {
|
||||||
|
if self.marker_ix > 0 {
|
||||||
|
self.marker_ix = 0;
|
||||||
|
self.pre_fence_line.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
if byte != b'\n' {
|
||||||
|
self.pre_fence_line.push(byte);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
OpenFence => {
|
||||||
|
// skip language tag
|
||||||
|
if byte == b'\n' {
|
||||||
|
self.to_state(SearchMarker);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SearchMarker => {
|
||||||
|
if self.expect_marker(byte, SEARCH_MARKER) {
|
||||||
|
self.to_state(SearchBlock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SearchBlock => {
|
||||||
|
if collect_until_marker(
|
||||||
|
byte,
|
||||||
|
DIVIDER,
|
||||||
|
NL_DIVIDER,
|
||||||
|
&mut self.marker_ix,
|
||||||
|
&mut self.old_bytes,
|
||||||
|
) {
|
||||||
|
self.to_state(ReplaceBlock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ReplaceBlock => {
|
||||||
|
if collect_until_marker(
|
||||||
|
byte,
|
||||||
|
REPLACE_MARKER,
|
||||||
|
NL_REPLACE_MARKER,
|
||||||
|
&mut self.marker_ix,
|
||||||
|
&mut self.new_bytes,
|
||||||
|
) {
|
||||||
|
self.to_state(CloseFence);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CloseFence => {
|
||||||
|
if self.expect_marker(byte, FENCE) {
|
||||||
|
if let Some(action) = self.action() {
|
||||||
|
actions.push(action);
|
||||||
|
}
|
||||||
|
self.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
actions
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a reference to the errors encountered during parsing.
|
||||||
|
pub fn errors(&self) -> &[ParseError] {
|
||||||
|
&self.errors
|
||||||
|
}
|
||||||
|
|
||||||
|
fn action(&mut self) -> Option<EditAction> {
|
||||||
|
if self.old_bytes.is_empty() && self.new_bytes.is_empty() {
|
||||||
|
self.push_error(ParseErrorKind::NoOp);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let file_path = String::from_utf8(std::mem::take(&mut self.pre_fence_line)).log_err()?;
|
||||||
|
let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?;
|
||||||
|
|
||||||
|
if self.old_bytes.is_empty() {
|
||||||
|
Some(EditAction::Write { file_path, content })
|
||||||
|
} else {
|
||||||
|
let old = String::from_utf8(std::mem::take(&mut self.old_bytes)).log_err()?;
|
||||||
|
|
||||||
|
Some(EditAction::Replace {
|
||||||
|
file_path,
|
||||||
|
old,
|
||||||
|
new: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expect_marker(&mut self, byte: u8, marker: &'static [u8]) -> bool {
|
||||||
|
match match_marker(byte, marker, &mut self.marker_ix) {
|
||||||
|
MarkerMatch::Complete => true,
|
||||||
|
MarkerMatch::Partial => false,
|
||||||
|
MarkerMatch::None => {
|
||||||
|
self.push_error(ParseErrorKind::ExpectedMarker {
|
||||||
|
expected: marker,
|
||||||
|
found: byte,
|
||||||
|
});
|
||||||
|
self.reset();
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_state(&mut self, state: State) {
|
||||||
|
self.state = state;
|
||||||
|
self.marker_ix = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset(&mut self) {
|
||||||
|
self.pre_fence_line.clear();
|
||||||
|
self.old_bytes.clear();
|
||||||
|
self.new_bytes.clear();
|
||||||
|
self.to_state(State::Default);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push_error(&mut self, kind: ParseErrorKind) {
|
||||||
|
self.errors.push(ParseError {
|
||||||
|
line: self.line,
|
||||||
|
column: self.column,
|
||||||
|
kind,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum MarkerMatch {
|
||||||
|
None,
|
||||||
|
Partial,
|
||||||
|
Complete,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn match_marker(byte: u8, marker: &[u8], marker_ix: &mut usize) -> MarkerMatch {
|
||||||
|
if byte == marker[*marker_ix] {
|
||||||
|
*marker_ix += 1;
|
||||||
|
|
||||||
|
if *marker_ix >= marker.len() {
|
||||||
|
MarkerMatch::Complete
|
||||||
|
} else {
|
||||||
|
MarkerMatch::Partial
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MarkerMatch::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collect_until_marker(
|
||||||
|
byte: u8,
|
||||||
|
marker: &[u8],
|
||||||
|
nl_marker: &[u8],
|
||||||
|
marker_ix: &mut usize,
|
||||||
|
buf: &mut Vec<u8>,
|
||||||
|
) -> bool {
|
||||||
|
let marker = if buf.is_empty() {
|
||||||
|
// do not require another newline if block is empty
|
||||||
|
marker
|
||||||
|
} else {
|
||||||
|
nl_marker
|
||||||
|
};
|
||||||
|
|
||||||
|
match match_marker(byte, marker, marker_ix) {
|
||||||
|
MarkerMatch::Complete => true,
|
||||||
|
MarkerMatch::Partial => false,
|
||||||
|
MarkerMatch::None => {
|
||||||
|
if *marker_ix > 0 {
|
||||||
|
buf.extend_from_slice(&marker[..*marker_ix]);
|
||||||
|
*marker_ix = 0;
|
||||||
|
|
||||||
|
// The beginning of marker might match current byte
|
||||||
|
match match_marker(byte, marker, marker_ix) {
|
||||||
|
MarkerMatch::Complete => return true,
|
||||||
|
MarkerMatch::Partial => return false,
|
||||||
|
MarkerMatch::None => { /* no match, keep collecting */ }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.push(byte);
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct ParseError {
|
||||||
|
line: usize,
|
||||||
|
column: usize,
|
||||||
|
kind: ParseErrorKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum ParseErrorKind {
|
||||||
|
ExpectedMarker { expected: &'static [u8], found: u8 },
|
||||||
|
NoOp,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ParseErrorKind {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
ParseErrorKind::ExpectedMarker { expected, found } => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Expected marker {:?}, found {:?}",
|
||||||
|
String::from_utf8_lossy(expected),
|
||||||
|
*found as char
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ParseErrorKind::NoOp => {
|
||||||
|
write!(f, "No search or replace")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ParseError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "input:{}:{}: {}", self.line, self.column, self.kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_simple_edit_action() {
|
||||||
|
let input = r#"src/main.rs
|
||||||
|
```
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn original() {}
|
||||||
|
=======
|
||||||
|
fn replacement() {}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
assert_eq!(actions.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/main.rs".to_string(),
|
||||||
|
old: "fn original() {}".to_string(),
|
||||||
|
new: "fn replacement() {}".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_with_language_tag() {
|
||||||
|
let input = r#"src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn original() {}
|
||||||
|
=======
|
||||||
|
fn replacement() {}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
assert_eq!(actions.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/main.rs".to_string(),
|
||||||
|
old: "fn original() {}".to_string(),
|
||||||
|
new: "fn replacement() {}".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_with_surrounding_text() {
|
||||||
|
let input = r#"Here's a modification I'd like to make to the file:
|
||||||
|
|
||||||
|
src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn original() {}
|
||||||
|
=======
|
||||||
|
fn replacement() {}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
|
||||||
|
This change makes the function better.
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
assert_eq!(actions.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/main.rs".to_string(),
|
||||||
|
old: "fn original() {}".to_string(),
|
||||||
|
new: "fn replacement() {}".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multiple_edit_actions() {
|
||||||
|
let input = r#"First change:
|
||||||
|
src/main.rs
|
||||||
|
```
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn original() {}
|
||||||
|
=======
|
||||||
|
fn replacement() {}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
|
||||||
|
Second change:
|
||||||
|
src/utils.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn old_util() -> bool { false }
|
||||||
|
=======
|
||||||
|
fn new_util() -> bool { true }
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
assert_eq!(actions.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/main.rs".to_string(),
|
||||||
|
old: "fn original() {}".to_string(),
|
||||||
|
new: "fn replacement() {}".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
actions[1],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/utils.rs".to_string(),
|
||||||
|
old: "fn old_util() -> bool { false }".to_string(),
|
||||||
|
new: "fn new_util() -> bool { true }".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multiline() {
|
||||||
|
let input = r#"src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn original() {
|
||||||
|
println!("This is the original function");
|
||||||
|
let x = 42;
|
||||||
|
if x > 0 {
|
||||||
|
println!("Positive number");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
=======
|
||||||
|
fn replacement() {
|
||||||
|
println!("This is the replacement function");
|
||||||
|
let x = 100;
|
||||||
|
if x > 50 {
|
||||||
|
println!("Large number");
|
||||||
|
} else {
|
||||||
|
println!("Small number");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
assert_eq!(actions.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/main.rs".to_string(),
|
||||||
|
old: "fn original() {\n println!(\"This is the original function\");\n let x = 42;\n if x > 0 {\n println!(\"Positive number\");\n }\n}".to_string(),
|
||||||
|
new: "fn replacement() {\n println!(\"This is the replacement function\");\n let x = 100;\n if x > 50 {\n println!(\"Large number\");\n } else {\n println!(\"Small number\");\n }\n}".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_write_action() {
|
||||||
|
let input = r#"Create a new main.rs file:
|
||||||
|
|
||||||
|
src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
=======
|
||||||
|
fn new_function() {
|
||||||
|
println!("This function is being added");
|
||||||
|
}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
assert_eq!(actions.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Write {
|
||||||
|
file_path: "src/main.rs".to_string(),
|
||||||
|
content: "fn new_function() {\n println!(\"This function is being added\");\n}"
|
||||||
|
.to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_replace() {
|
||||||
|
let input = r#"src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn this_will_be_deleted() {
|
||||||
|
println!("Deleting this function");
|
||||||
|
}
|
||||||
|
=======
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
assert_eq!(actions.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/main.rs".to_string(),
|
||||||
|
old: "fn this_will_be_deleted() {\n println!(\"Deleting this function\");\n}"
|
||||||
|
.to_string(),
|
||||||
|
new: "".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_both() {
|
||||||
|
let input = r#"src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
=======
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
// Should not create an action when both sections are empty
|
||||||
|
assert_eq!(actions.len(), 0);
|
||||||
|
|
||||||
|
// Check that the NoOp error was added
|
||||||
|
assert_eq!(parser.errors().len(), 1);
|
||||||
|
match parser.errors()[0].kind {
|
||||||
|
ParseErrorKind::NoOp => {}
|
||||||
|
_ => panic!("Expected NoOp error"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resumability() {
|
||||||
|
let input_part1 = r#"src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn ori"#;
|
||||||
|
|
||||||
|
let input_part2 = r#"ginal() {}
|
||||||
|
=======
|
||||||
|
fn replacement() {}"#;
|
||||||
|
|
||||||
|
let input_part3 = r#"
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions1 = parser.parse_chunk(input_part1);
|
||||||
|
assert_eq!(actions1.len(), 0);
|
||||||
|
|
||||||
|
let actions2 = parser.parse_chunk(input_part2);
|
||||||
|
// No actions should be complete yet
|
||||||
|
assert_eq!(actions2.len(), 0);
|
||||||
|
|
||||||
|
let actions3 = parser.parse_chunk(input_part3);
|
||||||
|
// The third chunk should complete the action
|
||||||
|
assert_eq!(actions3.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
actions3[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/main.rs".to_string(),
|
||||||
|
old: "fn original() {}".to_string(),
|
||||||
|
new: "fn replacement() {}".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parser_state_preservation() {
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions1 = parser.parse_chunk("src/main.rs\n```rust\n<<<<<<< SEARCH\n");
|
||||||
|
|
||||||
|
// Check parser is in the correct state
|
||||||
|
assert_eq!(parser.state, State::SearchBlock);
|
||||||
|
assert_eq!(parser.pre_fence_line, b"src/main.rs");
|
||||||
|
|
||||||
|
// Continue parsing
|
||||||
|
let actions2 = parser.parse_chunk("original code\n=======\n");
|
||||||
|
assert_eq!(parser.state, State::ReplaceBlock);
|
||||||
|
assert_eq!(parser.old_bytes, b"original code");
|
||||||
|
|
||||||
|
let actions3 = parser.parse_chunk("replacement code\n>>>>>>> REPLACE\n```\n");
|
||||||
|
|
||||||
|
// After complete parsing, state should reset
|
||||||
|
assert_eq!(parser.state, State::Default);
|
||||||
|
assert!(parser.pre_fence_line.is_empty());
|
||||||
|
assert!(parser.old_bytes.is_empty());
|
||||||
|
assert!(parser.new_bytes.is_empty());
|
||||||
|
|
||||||
|
assert_eq!(actions1.len(), 0);
|
||||||
|
assert_eq!(actions2.len(), 0);
|
||||||
|
assert_eq!(actions3.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_search_marker() {
|
||||||
|
let input = r#"src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< WRONG_MARKER
|
||||||
|
fn original() {}
|
||||||
|
=======
|
||||||
|
fn replacement() {}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
assert_eq!(actions.len(), 0);
|
||||||
|
|
||||||
|
assert_eq!(parser.errors().len(), 1);
|
||||||
|
let error = &parser.errors()[0];
|
||||||
|
|
||||||
|
assert_eq!(error.line, 3);
|
||||||
|
assert_eq!(error.column, 9);
|
||||||
|
assert_eq!(
|
||||||
|
error.kind,
|
||||||
|
ParseErrorKind::ExpectedMarker {
|
||||||
|
expected: b"<<<<<<< SEARCH\n",
|
||||||
|
found: b'W'
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_missing_closing_fence() {
|
||||||
|
let input = r#"src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn original() {}
|
||||||
|
=======
|
||||||
|
fn replacement() {}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
<!-- Missing closing fence -->
|
||||||
|
|
||||||
|
src/utils.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
fn utils_func() {}
|
||||||
|
=======
|
||||||
|
fn new_utils_func() {}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(input);
|
||||||
|
|
||||||
|
// Only the second block should be parsed
|
||||||
|
assert_eq!(actions.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "src/utils.rs".to_string(),
|
||||||
|
old: "fn utils_func() {}".to_string(),
|
||||||
|
new: "fn new_utils_func() {}".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// The parser should continue after an error
|
||||||
|
assert_eq!(parser.state, State::Default);
|
||||||
|
}
|
||||||
|
|
||||||
|
const SYSTEM_PROMPT: &str = include_str!("./edit_prompt.md");
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_examples_in_system_prompt() {
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let actions = parser.parse_chunk(SYSTEM_PROMPT);
|
||||||
|
assert_examples_in_system_prompt(&actions, parser.errors());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test(iterations = 10)]
|
||||||
|
fn test_random_chunking_of_system_prompt(mut rng: StdRng) {
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
let mut remaining = SYSTEM_PROMPT;
|
||||||
|
let mut actions = Vec::with_capacity(5);
|
||||||
|
|
||||||
|
while !remaining.is_empty() {
|
||||||
|
let chunk_size = rng.gen_range(1..=std::cmp::min(remaining.len(), 100));
|
||||||
|
|
||||||
|
let (chunk, rest) = remaining.split_at(chunk_size);
|
||||||
|
|
||||||
|
actions.extend(parser.parse_chunk(chunk));
|
||||||
|
remaining = rest;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_examples_in_system_prompt(&actions, parser.errors());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assert_examples_in_system_prompt(actions: &[EditAction], errors: &[ParseError]) {
|
||||||
|
assert_eq!(actions.len(), 5);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
actions[0],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "mathweb/flask/app.py".to_string(),
|
||||||
|
old: "from flask import Flask".to_string(),
|
||||||
|
new: "import math\nfrom flask import Flask".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
actions[1],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "mathweb/flask/app.py".to_string(),
|
||||||
|
old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(),
|
||||||
|
new: "".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
actions[2],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "mathweb/flask/app.py".to_string(),
|
||||||
|
old: " return str(factorial(n))".to_string(),
|
||||||
|
new: " return str(math.factorial(n))".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
actions[3],
|
||||||
|
EditAction::Write {
|
||||||
|
file_path: "hello.py".to_string(),
|
||||||
|
content: "def hello():\n \"print a greeting\"\n\n print(\"hello\")"
|
||||||
|
.to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
actions[4],
|
||||||
|
EditAction::Replace {
|
||||||
|
file_path: "main.py".to_string(),
|
||||||
|
old: "def hello():\n \"print a greeting\"\n\n print(\"hello\")".to_string(),
|
||||||
|
new: "from hello import hello".to_string(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// Ensure we have no parsing errors
|
||||||
|
assert!(errors.is_empty(), "Parsing errors found: {:?}", errors);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_print_error() {
|
||||||
|
let input = r#"src/main.rs
|
||||||
|
```rust
|
||||||
|
<<<<<<< WRONG_MARKER
|
||||||
|
fn original() {}
|
||||||
|
=======
|
||||||
|
fn replacement() {}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut parser = EditActionParser::new();
|
||||||
|
parser.parse_chunk(input);
|
||||||
|
|
||||||
|
assert_eq!(parser.errors().len(), 1);
|
||||||
|
let error = &parser.errors()[0];
|
||||||
|
let expected_error = r#"input:3:9: Expected marker "<<<<<<< SEARCH\n", found 'W'"#;
|
||||||
|
|
||||||
|
assert_eq!(format!("{}", error), expected_error);
|
||||||
|
}
|
||||||
|
}
|
134
crates/assistant_tools/src/edit_files_tool/edit_prompt.md
Normal file
134
crates/assistant_tools/src/edit_files_tool/edit_prompt.md
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
Act as an expert software developer.
|
||||||
|
Always use best practices when coding.
|
||||||
|
Respect and use existing conventions, libraries, etc that are already present in the code base.
|
||||||
|
|
||||||
|
Take requests for changes to the supplied code.
|
||||||
|
Always reply to the user in the same language they are using.
|
||||||
|
|
||||||
|
Once you understand the request, you MUST:
|
||||||
|
|
||||||
|
1. Decide if you need to propose *SEARCH/REPLACE* edits to any files that haven't been added to the chat. You can create new files without asking!
|
||||||
|
|
||||||
|
2. Think step-by-step and explain the needed changes in a few short sentences.
|
||||||
|
|
||||||
|
3. Describe each change with a *SEARCH/REPLACE block* per the examples below.
|
||||||
|
|
||||||
|
All changes to files must use this *SEARCH/REPLACE block* format.
|
||||||
|
ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
|
||||||
|
|
||||||
|
# Example conversations:
|
||||||
|
|
||||||
|
## USER: Change get_factorial() to use math.factorial
|
||||||
|
|
||||||
|
## ASSISTANT: To make this change we need to modify `mathweb/flask/app.py` to:
|
||||||
|
|
||||||
|
1. Import the math package.
|
||||||
|
2. Remove the existing factorial() function.
|
||||||
|
3. Update get_factorial() to call math.factorial instead.
|
||||||
|
|
||||||
|
Here are the *SEARCH/REPLACE* blocks:
|
||||||
|
|
||||||
|
mathweb/flask/app.py
|
||||||
|
```python
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
from flask import Flask
|
||||||
|
=======
|
||||||
|
import math
|
||||||
|
from flask import Flask
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
|
||||||
|
mathweb/flask/app.py
|
||||||
|
```python
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
def factorial(n):
|
||||||
|
"compute factorial"
|
||||||
|
|
||||||
|
if n == 0:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return n * factorial(n-1)
|
||||||
|
|
||||||
|
=======
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
|
||||||
|
mathweb/flask/app.py
|
||||||
|
```python
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
return str(factorial(n))
|
||||||
|
=======
|
||||||
|
return str(math.factorial(n))
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## USER: Refactor hello() into its own file.
|
||||||
|
|
||||||
|
## ASSISTANT: To make this change we need to modify `main.py` and make a new file `hello.py`:
|
||||||
|
|
||||||
|
1. Make a new hello.py file with hello() in it.
|
||||||
|
2. Remove hello() from main.py and replace it with an import.
|
||||||
|
|
||||||
|
Here are the *SEARCH/REPLACE* blocks:
|
||||||
|
|
||||||
|
hello.py
|
||||||
|
```python
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
=======
|
||||||
|
def hello():
|
||||||
|
"print a greeting"
|
||||||
|
|
||||||
|
print("hello")
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
|
||||||
|
main.py
|
||||||
|
```python
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
def hello():
|
||||||
|
"print a greeting"
|
||||||
|
|
||||||
|
print("hello")
|
||||||
|
=======
|
||||||
|
from hello import hello
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
# *SEARCH/REPLACE block* Rules:
|
||||||
|
|
||||||
|
Every *SEARCH/REPLACE block* must use this format:
|
||||||
|
1. The *FULL* file path alone on a line, verbatim. No bold asterisks, no quotes around it, no escaping of characters, etc.
|
||||||
|
2. The opening fence and code language, eg: ```python
|
||||||
|
3. The start of search block: <<<<<<< SEARCH
|
||||||
|
4. A contiguous chunk of lines to search for in the existing source code
|
||||||
|
5. The dividing line: =======
|
||||||
|
6. The lines to replace into the source code
|
||||||
|
7. The end of the replace block: >>>>>>> REPLACE
|
||||||
|
8. The closing fence: ```
|
||||||
|
|
||||||
|
Use the *FULL* file path, as shown to you by the user.
|
||||||
|
|
||||||
|
Every *SEARCH* section must *EXACTLY MATCH* the existing file content, character for character, including all comments, docstrings, etc.
|
||||||
|
If the file contains code or other data wrapped/escaped in json/xml/quotes or other containers, you need to propose edits to the literal contents of the file, including the container markup.
|
||||||
|
|
||||||
|
*SEARCH/REPLACE* blocks will *only* replace the first match occurrence.
|
||||||
|
Including multiple unique *SEARCH/REPLACE* blocks if needed.
|
||||||
|
Include enough lines in each SEARCH section to uniquely match each set of lines that need to change.
|
||||||
|
|
||||||
|
Keep *SEARCH/REPLACE* blocks concise.
|
||||||
|
Break large *SEARCH/REPLACE* blocks into a series of smaller blocks that each change a small portion of the file.
|
||||||
|
Include just the changing lines, and a few surrounding lines if needed for uniqueness.
|
||||||
|
Do not include long runs of unchanging lines in *SEARCH/REPLACE* blocks.
|
||||||
|
|
||||||
|
Only create *SEARCH/REPLACE* blocks for files that the user has added to the chat!
|
||||||
|
|
||||||
|
To move code within a file, use 2 *SEARCH/REPLACE* blocks: 1 to delete it from its current location, 1 to insert it in the new location.
|
||||||
|
|
||||||
|
Pay attention to which filenames the user wants you to edit, especially if they are asking you to create a new file.
|
||||||
|
|
||||||
|
If you want to put code in a new file, use a *SEARCH/REPLACE block* with:
|
||||||
|
- A new file path, including dir name if needed
|
||||||
|
- An empty `SEARCH` section
|
||||||
|
- The new file's contents in the `REPLACE` section
|
||||||
|
|
||||||
|
ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
|
|
@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use assistant_tool::Tool;
|
use assistant_tool::Tool;
|
||||||
use gpui::{App, Entity, Task};
|
use gpui::{App, Entity, Task};
|
||||||
|
use language_model::LanguageModelRequestMessage;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -34,6 +35,7 @@ impl Tool for ListWorktreesTool {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
_input: serde_json::Value,
|
_input: serde_json::Value,
|
||||||
|
_messages: &[LanguageModelRequestMessage],
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<String>> {
|
) -> Task<Result<String>> {
|
||||||
|
|
|
@ -4,6 +4,7 @@ use anyhow::{anyhow, Result};
|
||||||
use assistant_tool::Tool;
|
use assistant_tool::Tool;
|
||||||
use chrono::{Local, Utc};
|
use chrono::{Local, Utc};
|
||||||
use gpui::{App, Entity, Task};
|
use gpui::{App, Entity, Task};
|
||||||
|
use language_model::LanguageModelRequestMessage;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -42,6 +43,7 @@ impl Tool for NowTool {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
_messages: &[LanguageModelRequestMessage],
|
||||||
_project: Entity<Project>,
|
_project: Entity<Project>,
|
||||||
_cx: &mut App,
|
_cx: &mut App,
|
||||||
) -> Task<Result<String>> {
|
) -> Task<Result<String>> {
|
||||||
|
|
|
@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use assistant_tool::Tool;
|
use assistant_tool::Tool;
|
||||||
use gpui::{App, Entity, Task};
|
use gpui::{App, Entity, Task};
|
||||||
|
use language_model::LanguageModelRequestMessage;
|
||||||
use project::{Project, ProjectPath, WorktreeId};
|
use project::{Project, ProjectPath, WorktreeId};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -37,6 +38,7 @@ impl Tool for ReadFileTool {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
_messages: &[LanguageModelRequestMessage],
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<String>> {
|
) -> Task<Result<String>> {
|
||||||
|
@ -56,7 +58,16 @@ impl Tool for ReadFileTool {
|
||||||
})?
|
})?
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
cx.update(|cx| buffer.read(cx).text())
|
buffer.read_with(&cx, |buffer, _cx| {
|
||||||
|
if buffer
|
||||||
|
.file()
|
||||||
|
.map_or(false, |file| file.disk_state().exists())
|
||||||
|
{
|
||||||
|
Ok(buffer.text())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("File does not exist"))
|
||||||
|
}
|
||||||
|
})?
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ context_server_settings.workspace = true
|
||||||
extension.workspace = true
|
extension.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
language_model.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
postage.workspace = true
|
postage.workspace = true
|
||||||
|
|
|
@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||||
use anyhow::{anyhow, bail, Result};
|
use anyhow::{anyhow, bail, Result};
|
||||||
use assistant_tool::{Tool, ToolSource};
|
use assistant_tool::{Tool, ToolSource};
|
||||||
use gpui::{App, Entity, Task};
|
use gpui::{App, Entity, Task};
|
||||||
|
use language_model::LanguageModelRequestMessage;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
|
||||||
use crate::manager::ContextServerManager;
|
use crate::manager::ContextServerManager;
|
||||||
|
@ -58,6 +59,7 @@ impl Tool for ContextServerTool {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
_messages: &[LanguageModelRequestMessage],
|
||||||
_project: Entity<Project>,
|
_project: Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<String>> {
|
) -> Task<Result<String>> {
|
||||||
|
|
|
@ -18,6 +18,7 @@ impl Global for GlobalLanguageModelRegistry {}
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct LanguageModelRegistry {
|
pub struct LanguageModelRegistry {
|
||||||
active_model: Option<ActiveModel>,
|
active_model: Option<ActiveModel>,
|
||||||
|
editor_model: Option<ActiveModel>,
|
||||||
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
||||||
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
|
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
|
||||||
}
|
}
|
||||||
|
@ -29,6 +30,7 @@ pub struct ActiveModel {
|
||||||
|
|
||||||
pub enum Event {
|
pub enum Event {
|
||||||
ActiveModelChanged,
|
ActiveModelChanged,
|
||||||
|
EditorModelChanged,
|
||||||
ProviderStateChanged,
|
ProviderStateChanged,
|
||||||
AddedProvider(LanguageModelProviderId),
|
AddedProvider(LanguageModelProviderId),
|
||||||
RemovedProvider(LanguageModelProviderId),
|
RemovedProvider(LanguageModelProviderId),
|
||||||
|
@ -128,6 +130,22 @@ impl LanguageModelRegistry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn select_editor_model(
|
||||||
|
&mut self,
|
||||||
|
provider: &LanguageModelProviderId,
|
||||||
|
model_id: &LanguageModelId,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let Some(provider) = self.provider(provider) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let models = provider.provided_models(cx);
|
||||||
|
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||||
|
self.set_editor_model(Some(model), cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_active_provider(
|
pub fn set_active_provider(
|
||||||
&mut self,
|
&mut self,
|
||||||
provider: Option<Arc<dyn LanguageModelProvider>>,
|
provider: Option<Arc<dyn LanguageModelProvider>>,
|
||||||
|
@ -162,6 +180,28 @@ impl LanguageModelRegistry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_editor_model(
|
||||||
|
&mut self,
|
||||||
|
model: Option<Arc<dyn LanguageModel>>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
if let Some(model) = model {
|
||||||
|
let provider_id = model.provider_id();
|
||||||
|
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||||
|
self.editor_model = Some(ActiveModel {
|
||||||
|
provider,
|
||||||
|
model: Some(model),
|
||||||
|
});
|
||||||
|
cx.emit(Event::EditorModelChanged);
|
||||||
|
} else {
|
||||||
|
log::warn!("Active model's provider not found in registry");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.editor_model = None;
|
||||||
|
cx.emit(Event::EditorModelChanged);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||||
Some(self.active_model.as_ref()?.provider.clone())
|
Some(self.active_model.as_ref()?.provider.clone())
|
||||||
}
|
}
|
||||||
|
@ -170,6 +210,10 @@ impl LanguageModelRegistry {
|
||||||
self.active_model.as_ref()?.model.clone()
|
self.active_model.as_ref()?.model.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||||
|
self.editor_model.as_ref()?.model.clone()
|
||||||
|
}
|
||||||
|
|
||||||
/// Selects and sets the inline alternatives for language models based on
|
/// Selects and sets the inline alternatives for language models based on
|
||||||
/// provider name and id.
|
/// provider name and id.
|
||||||
pub fn select_inline_alternative_models(
|
pub fn select_inline_alternative_models(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue