diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 542f6c2df4..f15c4dfe22 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -520,6 +520,13 @@ "alt-enter": "editor::Newline" } }, + { + "context": "PromptEditor", + "bindings": { + "ctrl-[": "assistant::CyclePreviousInlineAssist", + "ctrl-]": "assistant::CycleNextInlineAssist" + } + }, { "context": "ProjectSearchBar && !in_replace", "bindings": { diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 77fac3254b..a58112b3c0 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -527,6 +527,13 @@ "ctrl-enter": "assistant::InlineAssist" } }, + { + "context": "PromptEditor", + "bindings": { + "ctrl-[": "assistant::CyclePreviousInlineAssist", + "ctrl-]": "assistant::CycleNextInlineAssist" + } + }, { "context": "ProjectSearchBar && !in_replace", "bindings": { diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index d7466878c9..8b9c66ee55 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -69,6 +69,8 @@ actions!( ConfirmCommand, NewContext, ToggleModelSelector, + CycleNextInlineAssist, + CyclePreviousInlineAssist ] ); @@ -359,8 +361,19 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) { let settings = AssistantSettings::get_global(cx); let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone()); let model_id = LanguageModelId::from(settings.default_model.model.clone()); + let inline_alternatives = settings + .inline_alternatives + .iter() + .map(|alternative| { + ( + LanguageModelProviderId::from(alternative.provider.clone()), + LanguageModelId::from(alternative.model.clone()), + ) + }) + .collect::>(); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.select_active_model(&provider_name, &model_id, cx); + registry.select_inline_alternative_models(inline_alternatives, cx); }); } diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index e2c6a8eb24..5aa379bae3 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -59,6 +59,7 @@ pub struct AssistantSettings { pub default_width: Pixels, pub default_height: Pixels, pub default_model: LanguageModelSelection, + pub inline_alternatives: Vec, pub using_outdated_settings_version: bool, } @@ -236,6 +237,7 @@ impl AssistantSettingsContent { }) } }), + inline_alternatives: None, }, VersionedAssistantSettingsContent::V2(settings) => settings.clone(), }, @@ -254,6 +256,7 @@ impl AssistantSettingsContent { .id() .to_string(), }), + inline_alternatives: None, }, } } @@ -369,6 +372,7 @@ impl Default for VersionedAssistantSettingsContent { default_width: None, default_height: None, default_model: None, + inline_alternatives: None, }) } } @@ -397,6 +401,8 @@ pub struct AssistantSettingsContentV2 { default_height: Option, /// The default model to use when creating new contexts. default_model: Option, + /// Additional models with which to generate alternatives when performing inline assists. + inline_alternatives: Option>, } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] @@ -517,10 +523,8 @@ impl Settings for AssistantSettings { &mut settings.default_height, value.default_height.map(Into::into), ); - merge( - &mut settings.default_model, - value.default_model.map(Into::into), - ); + merge(&mut settings.default_model, value.default_model); + merge(&mut settings.inline_alternatives, value.inline_alternatives); // merge(&mut settings.infer_context, value.infer_context); TODO re-enable this once we ship context inference } @@ -574,6 +578,7 @@ mod tests { provider: "test-provider".into(), model: "gpt-99".into(), }), + inline_alternatives: None, enabled: None, button: None, dock: None, diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index c9360213ae..428b33f3bb 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1,7 +1,7 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, prompts::PromptBuilder, - AssistantPanel, AssistantPanelEvent, CharOperation, LineDiff, LineOperation, ModelSelector, - StreamingDiff, + AssistantPanel, AssistantPanelEvent, CharOperation, CycleNextInlineAssist, + CyclePreviousInlineAssist, LineDiff, LineOperation, ModelSelector, StreamingDiff, }; use anyhow::{anyhow, Context as _, Result}; use client::{telemetry::Telemetry, ErrorExt}; @@ -25,13 +25,13 @@ use futures::{ SinkExt, Stream, StreamExt, }; use gpui::{ - anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView, - FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, - UpdateGlobal, View, ViewContext, WeakView, WindowContext, + anchored, deferred, point, AnyElement, AppContext, ClickEvent, EventEmitter, FocusHandle, + FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, + TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext, }; use language::{Buffer, IndentKind, Point, Selection, TransactionId}; use language_model::{ - LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, + LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; @@ -41,7 +41,7 @@ use smol::future::FutureExt; use std::{ cmp, future::{self, Future}, - mem, + iter, mem, ops::{Range, RangeInclusive}, pin::Pin, sync::Arc, @@ -85,7 +85,7 @@ pub struct InlineAssistant { async_watch::Receiver, ), >, - confirmed_assists: HashMap>, + confirmed_assists: HashMap>, prompt_history: VecDeque, prompt_builder: Arc, telemetry: Option>, @@ -157,7 +157,7 @@ impl InlineAssistant { if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) { for assist_id in editor_assists.assist_ids.clone() { let assist = &self.assists[&assist_id]; - if let CodegenStatus::Done = &assist.codegen.read(cx).status { + if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) { self.finish_assist(assist_id, false, cx) } } @@ -553,7 +553,7 @@ impl InlineAssistant { let assist_range = assist.range.to_offset(&buffer); if assist_range.contains(&selection.start) && assist_range.contains(&selection.end) { - if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) { + if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) { self.dismiss_assist(*assist_id, cx); } else { self.finish_assist(*assist_id, false, cx); @@ -671,7 +671,7 @@ impl InlineAssistant { for assist_id in editor_assists.assist_ids.clone() { let assist = &self.assists[&assist_id]; if matches!( - assist.codegen.read(cx).status, + assist.codegen.read(cx).status(cx), CodegenStatus::Error(_) | CodegenStatus::Done ) { let assist_range = assist.range.to_offset(&snapshot); @@ -774,7 +774,9 @@ impl InlineAssistant { if undo { assist.codegen.update(cx, |codegen, cx| codegen.undo(cx)); } else { - self.confirmed_assists.insert(assist_id, assist.codegen); + let confirmed_alternative = assist.codegen.read(cx).active_alternative().clone(); + self.confirmed_assists + .insert(assist_id, confirmed_alternative); } } @@ -978,12 +980,7 @@ impl InlineAssistant { assist .codegen .update(cx, |codegen, cx| { - codegen.start( - assist.range.clone(), - user_prompt, - assistant_panel_context, - cx, - ) + codegen.start(user_prompt, assistant_panel_context, cx) }) .log_err(); @@ -1008,7 +1005,7 @@ impl InlineAssistant { pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus { if let Some(assist) = self.assists.get(&assist_id) { - match &assist.codegen.read(cx).status { + match assist.codegen.read(cx).status(cx) { CodegenStatus::Idle => InlineAssistStatus::Idle, CodegenStatus::Pending => InlineAssistStatus::Pending, CodegenStatus::Done => InlineAssistStatus::Done, @@ -1037,16 +1034,16 @@ impl InlineAssistant { for assist_id in assist_ids { if let Some(assist) = self.assists.get(assist_id) { let codegen = assist.codegen.read(cx); - let buffer = codegen.buffer.read(cx).read(cx); - foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned()); + let buffer = codegen.buffer(cx).read(cx).read(cx); + foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned()); let pending_range = - codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end; + codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end; if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) { gutter_pending_ranges.push(pending_range); } - if let Some(edit_position) = codegen.edit_position { + if let Some(edit_position) = codegen.edit_position(cx) { let edited_range = assist.range.start..edit_position; if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) { gutter_transformed_ranges.push(edited_range); @@ -1054,7 +1051,8 @@ impl InlineAssistant { } if assist.decorations.is_some() { - inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned()); + inserted_row_ranges + .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned()); } } } @@ -1125,9 +1123,9 @@ impl InlineAssistant { }; let codegen = assist.codegen.read(cx); - let old_snapshot = codegen.snapshot.clone(); - let old_buffer = codegen.old_buffer.clone(); - let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone(); + let old_snapshot = codegen.snapshot(cx); + let old_buffer = codegen.old_buffer(cx); + let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone(); editor.update(cx, |editor, cx| { let old_blocks = mem::take(&mut decorations.removed_line_block_ids); @@ -1406,8 +1404,15 @@ impl EventEmitter for PromptEditor {} impl Render for PromptEditor { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let gutter_dimensions = *self.gutter_dimensions.lock(); - let status = &self.codegen.read(cx).status; - let buttons = match status { + let codegen = self.codegen.read(cx); + + let mut buttons = Vec::new(); + if codegen.alternative_count(cx) > 1 { + buttons.push(self.render_cycle_controls(cx)); + } + + let status = codegen.status(cx); + buttons.extend(match status { CodegenStatus::Idle => { vec![ IconButton::new("cancel", IconName::Close) @@ -1416,14 +1421,16 @@ impl Render for PromptEditor { .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx)) .on_click( cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)), - ), + ) + .into_any_element(), IconButton::new("start", IconName::SparkleAlt) .icon_color(Color::Muted) .shape(IconButtonShape::Square) .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx)) .on_click( cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)), - ), + ) + .into_any_element(), ] } CodegenStatus::Pending => { @@ -1434,7 +1441,8 @@ impl Render for PromptEditor { .tooltip(|cx| Tooltip::text("Cancel Assist", cx)) .on_click( cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)), - ), + ) + .into_any_element(), IconButton::new("stop", IconName::Stop) .icon_color(Color::Error) .shape(IconButtonShape::Square) @@ -1446,9 +1454,8 @@ impl Render for PromptEditor { cx, ) }) - .on_click( - cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)), - ), + .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested))) + .into_any_element(), ] } CodegenStatus::Error(_) | CodegenStatus::Done => { @@ -1459,7 +1466,8 @@ impl Render for PromptEditor { .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx)) .on_click( cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)), - ), + ) + .into_any_element(), if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) { IconButton::new("restart", IconName::RotateCw) .icon_color(Color::Info) @@ -1475,6 +1483,7 @@ impl Render for PromptEditor { .on_click(cx.listener(|_, _, cx| { cx.emit(PromptEditorEvent::StartRequested); })) + .into_any_element() } else { IconButton::new("confirm", IconName::Check) .icon_color(Color::Info) @@ -1483,12 +1492,14 @@ impl Render for PromptEditor { .on_click(cx.listener(|_, _, cx| { cx.emit(PromptEditorEvent::ConfirmRequested); })) + .into_any_element() }, ] } - }; + }); h_flex() + .key_context("PromptEditor") .bg(cx.theme().colors().editor_background) .border_y_1() .border_color(cx.theme().status().info_border) @@ -1498,6 +1509,8 @@ impl Render for PromptEditor { .on_action(cx.listener(Self::cancel)) .on_action(cx.listener(Self::move_up)) .on_action(cx.listener(Self::move_down)) + .capture_action(cx.listener(Self::cycle_prev)) + .capture_action(cx.listener(Self::cycle_next)) .child( h_flex() .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)) @@ -1532,7 +1545,7 @@ impl Render for PromptEditor { ), ) .map(|el| { - let CodegenStatus::Error(error) = &self.codegen.read(cx).status else { + let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else { return el; }; @@ -1776,7 +1789,7 @@ impl PromptEditor { } fn handle_codegen_changed(&mut self, _: Model, cx: &mut ViewContext) { - match &self.codegen.read(cx).status { + match self.codegen.read(cx).status(cx) { CodegenStatus::Idle => { self.editor .update(cx, |editor, _| editor.set_read_only(false)); @@ -1807,7 +1820,7 @@ impl PromptEditor { } fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext) { - match &self.codegen.read(cx).status { + match self.codegen.read(cx).status(cx) { CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => { cx.emit(PromptEditorEvent::CancelRequested); } @@ -1818,7 +1831,7 @@ impl PromptEditor { } fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - match &self.codegen.read(cx).status { + match self.codegen.read(cx).status(cx) { CodegenStatus::Idle => { cx.emit(PromptEditorEvent::StartRequested); } @@ -1878,6 +1891,79 @@ impl PromptEditor { } } + fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext) { + self.codegen + .update(cx, |codegen, cx| codegen.cycle_prev(cx)); + } + + fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext) { + self.codegen + .update(cx, |codegen, cx| codegen.cycle_next(cx)); + } + + fn render_cycle_controls(&self, cx: &ViewContext) -> AnyElement { + let codegen = self.codegen.read(cx); + let disabled = matches!(codegen.status(cx), CodegenStatus::Idle); + + h_flex() + .child( + IconButton::new("previous", IconName::ChevronLeft) + .icon_color(Color::Muted) + .disabled(disabled) + .shape(IconButtonShape::Square) + .tooltip({ + let focus_handle = self.editor.focus_handle(cx); + move |cx| { + Tooltip::for_action_in( + "Previous Alternative", + &CyclePreviousInlineAssist, + &focus_handle, + cx, + ) + } + }) + .on_click(cx.listener(|this, _, cx| { + this.codegen + .update(cx, |codegen, cx| codegen.cycle_prev(cx)) + })), + ) + .child( + Label::new(format!( + "{}/{}", + codegen.active_alternative + 1, + codegen.alternative_count(cx) + )) + .size(LabelSize::Small) + .color(if disabled { + Color::Disabled + } else { + Color::Muted + }), + ) + .child( + IconButton::new("next", IconName::ChevronRight) + .icon_color(Color::Muted) + .disabled(disabled) + .shape(IconButtonShape::Square) + .tooltip({ + let focus_handle = self.editor.focus_handle(cx); + move |cx| { + Tooltip::for_action_in( + "Next Alternative", + &CycleNextInlineAssist, + &focus_handle, + cx, + ) + } + }) + .on_click(cx.listener(|this, _, cx| { + this.codegen + .update(cx, |codegen, cx| codegen.cycle_next(cx)) + })), + ) + .into_any_element() + } + fn render_token_count(&self, cx: &mut ViewContext) -> Option { let model = LanguageModelRegistry::read_global(cx).active_model()?; let token_counts = self.token_counts?; @@ -2124,7 +2210,7 @@ impl InlineAssist { return; }; - if let CodegenStatus::Error(error) = &codegen.read(cx).status { + if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) { if assist.decorations.is_none() { if let Some(workspace) = assist .workspace @@ -2185,12 +2271,9 @@ impl InlineAssist { return future::ready(Err(anyhow!("no user prompt"))).boxed(); }; let assistant_panel_context = self.assistant_panel_context(cx); - self.codegen.read(cx).count_tokens( - self.range.clone(), - user_prompt, - assistant_panel_context, - cx, - ) + self.codegen + .read(cx) + .count_tokens(user_prompt, assistant_panel_context, cx) } } @@ -2201,19 +2284,216 @@ struct InlineAssistDecorations { end_block_id: CustomBlockId, } -#[derive(Debug)] +#[derive(Copy, Clone, Debug)] pub enum CodegenEvent { Finished, Undone, } pub struct Codegen { + alternatives: Vec>, + active_alternative: usize, + subscriptions: Vec, + buffer: Model, + range: Range, + initial_transaction_id: Option, + telemetry: Option>, + builder: Arc, +} + +impl Codegen { + pub fn new( + buffer: Model, + range: Range, + initial_transaction_id: Option, + telemetry: Option>, + builder: Arc, + cx: &mut ModelContext, + ) -> Self { + let codegen = cx.new_model(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + false, + telemetry.clone(), + builder.clone(), + cx, + ) + }); + let mut this = Self { + alternatives: vec![codegen], + active_alternative: 0, + subscriptions: Vec::new(), + buffer, + range, + initial_transaction_id, + telemetry, + builder, + }; + this.activate(0, cx); + this + } + + fn subscribe_to_alternative(&mut self, cx: &mut ModelContext) { + let codegen = self.active_alternative().clone(); + self.subscriptions.clear(); + self.subscriptions + .push(cx.observe(&codegen, |_, _, cx| cx.notify())); + self.subscriptions + .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event))); + } + + fn active_alternative(&self) -> &Model { + &self.alternatives[self.active_alternative] + } + + fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus { + &self.active_alternative().read(cx).status + } + + fn alternative_count(&self, cx: &AppContext) -> usize { + LanguageModelRegistry::read_global(cx) + .inline_alternative_models() + .len() + + 1 + } + + pub fn cycle_prev(&mut self, cx: &mut ModelContext) { + let next_active_ix = if self.active_alternative == 0 { + self.alternatives.len() - 1 + } else { + self.active_alternative - 1 + }; + self.activate(next_active_ix, cx); + } + + pub fn cycle_next(&mut self, cx: &mut ModelContext) { + let next_active_ix = (self.active_alternative + 1) % self.alternatives.len(); + self.activate(next_active_ix, cx); + } + + fn activate(&mut self, index: usize, cx: &mut ModelContext) { + self.active_alternative() + .update(cx, |codegen, cx| codegen.set_active(false, cx)); + self.active_alternative = index; + self.active_alternative() + .update(cx, |codegen, cx| codegen.set_active(true, cx)); + self.subscribe_to_alternative(cx); + cx.notify(); + } + + pub fn start( + &mut self, + user_prompt: String, + assistant_panel_context: Option, + cx: &mut ModelContext, + ) -> Result<()> { + let alternative_models = LanguageModelRegistry::read_global(cx) + .inline_alternative_models() + .to_vec(); + + self.active_alternative() + .update(cx, |alternative, cx| alternative.undo(cx)); + self.activate(0, cx); + self.alternatives.truncate(1); + + for _ in 0..alternative_models.len() { + self.alternatives.push(cx.new_model(|cx| { + CodegenAlternative::new( + self.buffer.clone(), + self.range.clone(), + false, + self.telemetry.clone(), + self.builder.clone(), + cx, + ) + })); + } + + let primary_model = LanguageModelRegistry::read_global(cx) + .active_model() + .context("no active model")?; + + for (model, alternative) in iter::once(primary_model) + .chain(alternative_models) + .zip(&self.alternatives) + { + alternative.update(cx, |alternative, cx| { + alternative.start( + user_prompt.clone(), + assistant_panel_context.clone(), + model.clone(), + cx, + ) + })?; + } + + Ok(()) + } + + pub fn stop(&mut self, cx: &mut ModelContext) { + for codegen in &self.alternatives { + codegen.update(cx, |codegen, cx| codegen.stop(cx)); + } + } + + pub fn undo(&mut self, cx: &mut ModelContext) { + self.active_alternative() + .update(cx, |codegen, cx| codegen.undo(cx)); + + self.buffer.update(cx, |buffer, cx| { + if let Some(transaction_id) = self.initial_transaction_id.take() { + buffer.undo_transaction(transaction_id, cx); + buffer.refresh_preview(cx); + } + }); + } + + pub fn count_tokens( + &self, + user_prompt: String, + assistant_panel_context: Option, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + self.active_alternative() + .read(cx) + .count_tokens(user_prompt, assistant_panel_context, cx) + } + + pub fn buffer(&self, cx: &AppContext) -> Model { + self.active_alternative().read(cx).buffer.clone() + } + + pub fn old_buffer(&self, cx: &AppContext) -> Model { + self.active_alternative().read(cx).old_buffer.clone() + } + + pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot { + self.active_alternative().read(cx).snapshot.clone() + } + + pub fn edit_position(&self, cx: &AppContext) -> Option { + self.active_alternative().read(cx).edit_position + } + + fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff { + &self.active_alternative().read(cx).diff + } + + pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range] { + self.active_alternative().read(cx).last_equal_ranges() + } +} + +impl EventEmitter for Codegen {} + +pub struct CodegenAlternative { buffer: Model, old_buffer: Model, snapshot: MultiBufferSnapshot, edit_position: Option, + range: Range, last_equal_ranges: Vec>, - initial_transaction_id: Option, transformation_transaction_id: Option, status: CodegenStatus, generation: Task<()>, @@ -2221,6 +2501,9 @@ pub struct Codegen { telemetry: Option>, _subscription: gpui::Subscription, builder: Arc, + active: bool, + edits: Vec<(Range, String)>, + line_operations: Vec, } enum CodegenStatus { @@ -2242,13 +2525,13 @@ impl Diff { } } -impl EventEmitter for Codegen {} +impl EventEmitter for CodegenAlternative {} -impl Codegen { +impl CodegenAlternative { pub fn new( buffer: Model, range: Range, - initial_transaction_id: Option, + active: bool, telemetry: Option>, builder: Arc, cx: &mut ModelContext, @@ -2287,8 +2570,33 @@ impl Codegen { diff: Diff::default(), telemetry, _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), - initial_transaction_id, builder, + active, + edits: Vec::new(), + line_operations: Vec::new(), + range, + } + } + + fn set_active(&mut self, active: bool, cx: &mut ModelContext) { + if active != self.active { + self.active = active; + + if self.active { + let edits = self.edits.clone(); + self.apply_edits(edits, cx); + if matches!(self.status, CodegenStatus::Pending) { + let line_operations = self.line_operations.clone(); + self.reapply_line_based_diff(line_operations, cx); + } else { + self.reapply_batch_diff(cx).detach(); + } + } else if let Some(transaction_id) = self.transformation_transaction_id.take() { + self.buffer.update(cx, |buffer, cx| { + buffer.undo_transaction(transaction_id, cx); + buffer.forget_transaction(transaction_id, cx); + }); + } } } @@ -2313,14 +2621,12 @@ impl Codegen { pub fn count_tokens( &self, - edit_range: Range, user_prompt: String, assistant_panel_context: Option, cx: &AppContext, ) -> BoxFuture<'static, Result> { if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { - let request = - self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx); + let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx); match request { Ok(request) => { let total_count = model.count_tokens(request.clone(), cx); @@ -2345,39 +2651,31 @@ impl Codegen { pub fn start( &mut self, - edit_range: Range, user_prompt: String, assistant_panel_context: Option, + model: Arc, cx: &mut ModelContext, ) -> Result<()> { - let model = LanguageModelRegistry::read_global(cx) - .active_model() - .context("no active model")?; - if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() { self.buffer.update(cx, |buffer, cx| { buffer.undo_transaction(transformation_transaction_id, cx); }); } - self.edit_position = Some(edit_range.start.bias_right(&self.snapshot)); + self.edit_position = Some(self.range.start.bias_right(&self.snapshot)); let telemetry_id = model.telemetry_id(); - let chunks: LocalBoxFuture>>> = if user_prompt - .trim() - .to_lowercase() - == "delete" - { - async { Ok(stream::empty().boxed()) }.boxed_local() - } else { - let request = - self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?; + let chunks: LocalBoxFuture>>> = + if user_prompt.trim().to_lowercase() == "delete" { + async { Ok(stream::empty().boxed()) }.boxed_local() + } else { + let request = self.build_request(user_prompt, assistant_panel_context, cx)?; - let chunks = - cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await }); - async move { Ok(chunks.await?.boxed()) }.boxed_local() - }; - self.handle_stream(telemetry_id, edit_range, chunks, cx); + let chunks = cx + .spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await }); + async move { Ok(chunks.await?.boxed()) }.boxed_local() + }; + self.handle_stream(telemetry_id, chunks, cx); Ok(()) } @@ -2385,11 +2683,10 @@ impl Codegen { &self, user_prompt: String, assistant_panel_context: Option, - edit_range: Range, cx: &AppContext, ) -> Result { let buffer = self.buffer.read(cx).snapshot(cx); - let language = buffer.language_at(edit_range.start); + let language = buffer.language_at(self.range.start); let language_name = if let Some(language) = language.as_ref() { if Arc::ptr_eq(language, &language::PLAIN_TEXT) { None @@ -2401,8 +2698,8 @@ impl Codegen { }; let language_name = language_name.as_ref(); - let start = buffer.point_to_buffer_offset(edit_range.start); - let end = buffer.point_to_buffer_offset(edit_range.end); + let start = buffer.point_to_buffer_offset(self.range.start); + let end = buffer.point_to_buffer_offset(self.range.end); let (buffer, range) = if let Some((start, end)) = start.zip(end) { let (start_buffer, start_buffer_offset) = start; let (end_buffer, end_buffer_offset) = end; @@ -2442,16 +2739,15 @@ impl Codegen { pub fn handle_stream( &mut self, model_telemetry_id: String, - edit_range: Range, stream: impl 'static + Future>>>, cx: &mut ModelContext, ) { let snapshot = self.snapshot.clone(); let selected_text = snapshot - .text_for_range(edit_range.start..edit_range.end) + .text_for_range(self.range.start..self.range.end) .collect::(); - let selection_start = edit_range.start.to_point(&snapshot); + let selection_start = self.range.start.to_point(&snapshot); // Start with the indentation of the first line in the selection let mut suggested_line_indent = snapshot @@ -2462,7 +2758,7 @@ impl Codegen { // If the first line in the selection does not have indentation, check the following lines if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space { - for row in selection_start.row..=edit_range.end.to_point(&snapshot).row { + for row in selection_start.row..=self.range.end.to_point(&snapshot).row { let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row)); // Prefer tabs if a line in the selection uses tabs as indentation if line_indent.kind == IndentKind::Tab { @@ -2475,7 +2771,7 @@ impl Codegen { let telemetry = self.telemetry.clone(); self.diff = Diff::default(); self.status = CodegenStatus::Pending; - let mut edit_start = edit_range.start.to_offset(&snapshot); + let mut edit_start = self.range.start.to_offset(&snapshot); self.generation = cx.spawn(|codegen, mut cx| { async move { let chunks = stream.await; @@ -2597,68 +2893,42 @@ impl Codegen { Ok(()) }); - while let Some((char_ops, line_diff)) = diff_rx.next().await { + while let Some((char_ops, line_ops)) = diff_rx.next().await { codegen.update(&mut cx, |codegen, cx| { codegen.last_equal_ranges.clear(); - let transaction = codegen.buffer.update(cx, |buffer, cx| { - // Avoid grouping assistant edits with user edits. - buffer.finalize_last_transaction(cx); + let edits = char_ops + .into_iter() + .filter_map(|operation| match operation { + CharOperation::Insert { text } => { + let edit_start = snapshot.anchor_after(edit_start); + Some((edit_start..edit_start, text)) + } + CharOperation::Delete { bytes } => { + let edit_end = edit_start + bytes; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + Some((edit_range, String::new())) + } + CharOperation::Keep { bytes } => { + let edit_end = edit_start + bytes; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + codegen.last_equal_ranges.push(edit_range); + None + } + }) + .collect::>(); - buffer.start_transaction(cx); - buffer.edit( - char_ops - .into_iter() - .filter_map(|operation| match operation { - CharOperation::Insert { text } => { - let edit_start = snapshot.anchor_after(edit_start); - Some((edit_start..edit_start, text)) - } - CharOperation::Delete { bytes } => { - let edit_end = edit_start + bytes; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - Some((edit_range, String::new())) - } - CharOperation::Keep { bytes } => { - let edit_end = edit_start + bytes; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - codegen.last_equal_ranges.push(edit_range); - None - } - }), - None, - cx, - ); - codegen.edit_position = Some(snapshot.anchor_after(edit_start)); - - buffer.end_transaction(cx) - }); - - if let Some(transaction) = transaction { - if let Some(first_transaction) = - codegen.transformation_transaction_id - { - // Group all assistant edits into the first transaction. - codegen.buffer.update(cx, |buffer, cx| { - buffer.merge_transactions( - transaction, - first_transaction, - cx, - ) - }); - } else { - codegen.transformation_transaction_id = Some(transaction); - codegen.buffer.update(cx, |buffer, cx| { - buffer.finalize_last_transaction(cx) - }); - } + if codegen.active { + codegen.apply_edits(edits.iter().cloned(), cx); + codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx); } - - codegen.reapply_line_based_diff(edit_range.clone(), line_diff, cx); + codegen.edits.extend(edits); + codegen.line_operations = line_ops; + codegen.edit_position = Some(snapshot.anchor_after(edit_start)); cx.notify(); })?; @@ -2667,9 +2937,8 @@ impl Codegen { // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer. // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff. // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`. - let batch_diff_task = codegen.update(&mut cx, |codegen, cx| { - codegen.reapply_batch_diff(edit_range.clone(), cx) - })?; + let batch_diff_task = + codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?; let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task); line_based_stream_diff?; @@ -2713,24 +2982,45 @@ impl Codegen { buffer.undo_transaction(transaction_id, cx); buffer.refresh_preview(cx); } - - if let Some(transaction_id) = self.initial_transaction_id.take() { - buffer.undo_transaction(transaction_id, cx); - buffer.refresh_preview(cx); - } }); } + fn apply_edits( + &mut self, + edits: impl IntoIterator, String)>, + cx: &mut ModelContext, + ) { + let transaction = self.buffer.update(cx, |buffer, cx| { + // Avoid grouping assistant edits with user edits. + buffer.finalize_last_transaction(cx); + buffer.start_transaction(cx); + buffer.edit(edits, None, cx); + buffer.end_transaction(cx) + }); + + if let Some(transaction) = transaction { + if let Some(first_transaction) = self.transformation_transaction_id { + // Group all assistant edits into the first transaction. + self.buffer.update(cx, |buffer, cx| { + buffer.merge_transactions(transaction, first_transaction, cx) + }); + } else { + self.transformation_transaction_id = Some(transaction); + self.buffer + .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx)); + } + } + } + fn reapply_line_based_diff( &mut self, - edit_range: Range, - line_operations: Vec, + line_operations: impl IntoIterator, cx: &mut ModelContext, ) { let old_snapshot = self.snapshot.clone(); - let old_range = edit_range.to_point(&old_snapshot); + let old_range = self.range.to_point(&old_snapshot); let new_snapshot = self.buffer.read(cx).snapshot(cx); - let new_range = edit_range.to_point(&new_snapshot); + let new_range = self.range.to_point(&new_snapshot); let mut old_row = old_range.start.row; let mut new_row = new_range.start.row; @@ -2781,15 +3071,11 @@ impl Codegen { } } - fn reapply_batch_diff( - &mut self, - edit_range: Range, - cx: &mut ModelContext, - ) -> Task<()> { + fn reapply_batch_diff(&mut self, cx: &mut ModelContext) -> Task<()> { let old_snapshot = self.snapshot.clone(); - let old_range = edit_range.to_point(&old_snapshot); + let old_range = self.range.to_point(&old_snapshot); let new_snapshot = self.buffer.read(cx).snapshot(cx); - let new_range = edit_range.to_point(&new_snapshot); + let new_range = self.range.to_point(&new_snapshot); cx.spawn(|codegen, mut cx| async move { let (deleted_row_ranges, inserted_row_ranges) = cx @@ -3073,10 +3359,10 @@ mod tests { }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let codegen = cx.new_model(|cx| { - Codegen::new( + CodegenAlternative::new( buffer.clone(), range.clone(), - None, + true, None, prompt_builder, cx, @@ -3087,7 +3373,6 @@ mod tests { codegen.update(cx, |codegen, cx| { codegen.handle_stream( String::new(), - range, future::ready(Ok(chunks_rx.map(Ok).boxed())), cx, ) @@ -3145,10 +3430,10 @@ mod tests { }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let codegen = cx.new_model(|cx| { - Codegen::new( + CodegenAlternative::new( buffer.clone(), range.clone(), - None, + true, None, prompt_builder, cx, @@ -3159,7 +3444,6 @@ mod tests { codegen.update(cx, |codegen, cx| { codegen.handle_stream( String::new(), - range.clone(), future::ready(Ok(chunks_rx.map(Ok).boxed())), cx, ) @@ -3220,10 +3504,10 @@ mod tests { }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let codegen = cx.new_model(|cx| { - Codegen::new( + CodegenAlternative::new( buffer.clone(), range.clone(), - None, + true, None, prompt_builder, cx, @@ -3234,7 +3518,6 @@ mod tests { codegen.update(cx, |codegen, cx| { codegen.handle_stream( String::new(), - range.clone(), future::ready(Ok(chunks_rx.map(Ok).boxed())), cx, ) @@ -3294,10 +3577,10 @@ mod tests { }); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let codegen = cx.new_model(|cx| { - Codegen::new( + CodegenAlternative::new( buffer.clone(), range.clone(), - None, + true, None, prompt_builder, cx, @@ -3308,7 +3591,6 @@ mod tests { codegen.update(cx, |codegen, cx| { codegen.handle_stream( String::new(), - range.clone(), future::ready(Ok(chunks_rx.map(Ok).boxed())), cx, ) @@ -3338,6 +3620,78 @@ mod tests { ); } + #[gpui::test] + async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) { + cx.update(LanguageModelRegistry::test); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + let x = 0; + } + "}; + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + CodegenAlternative::new( + buffer.clone(), + range.clone(), + false, + None, + prompt_builder, + cx, + ) + }); + + let (chunks_tx, chunks_rx) = mpsc::unbounded(); + codegen.update(cx, |codegen, cx| { + codegen.handle_stream( + String::new(), + future::ready(Ok(chunks_rx.map(Ok).boxed())), + cx, + ) + }); + + chunks_tx + .unbounded_send("let mut x = 0;\nx += 1;".to_string()) + .unwrap(); + drop(chunks_tx); + cx.run_until_parked(); + + // The codegen is inactive, so the buffer doesn't get modified. + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + text + ); + + // Activating the codegen applies the changes. + codegen.update(cx, |codegen, cx| codegen.set_active(true, cx)); + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + x += 1; + } + "} + ); + + // Deactivating the codegen undoes the changes. + codegen.update(cx, |codegen, cx| codegen.set_active(false, cx)); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + text + ); + } + #[gpui::test] async fn test_strip_invalid_spans_from_codeblock() { assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await; diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index b3c8ef5f57..e1ba1c5886 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -76,6 +76,7 @@ impl Global for GlobalLanguageModelRegistry {} pub struct LanguageModelRegistry { active_model: Option, providers: BTreeMap>, + inline_alternatives: Vec>, } pub struct ActiveModel { @@ -229,6 +230,37 @@ impl LanguageModelRegistry { pub fn active_model(&self) -> Option> { self.active_model.as_ref()?.model.clone() } + + /// Selects and sets the inline alternatives for language models based on + /// provider name and id. + pub fn select_inline_alternative_models( + &mut self, + alternatives: impl IntoIterator, + cx: &mut ModelContext, + ) { + let mut selected_alternatives = Vec::new(); + + for (provider_id, model_id) in alternatives { + if let Some(provider) = self.providers.get(&provider_id) { + if let Some(model) = provider + .provided_models(cx) + .iter() + .find(|m| m.id() == model_id) + { + selected_alternatives.push(model.clone()); + } + } + } + + self.inline_alternatives = selected_alternatives; + } + + /// The models to use for inline assists. Returns the union of the active + /// model and all inline alternatives. When there are multiple models, the + /// user will be able to cycle through results. + pub fn inline_alternative_models(&self) -> &[Arc] { + &self.inline_alternatives + } } #[cfg(test)] diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index 29bd9a8068..c163dbc07a 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -1106,6 +1106,26 @@ impl MultiBuffer { } } + pub fn forget_transaction( + &mut self, + transaction_id: TransactionId, + cx: &mut ModelContext, + ) { + if let Some(buffer) = self.as_singleton() { + buffer.update(cx, |buffer, _| { + buffer.forget_transaction(transaction_id); + }); + } else if let Some(transaction) = self.history.forget(transaction_id) { + for (buffer_id, buffer_transaction_id) in transaction.buffer_transactions { + if let Some(state) = self.buffers.borrow_mut().get_mut(&buffer_id) { + state.buffer.update(cx, |buffer, _| { + buffer.forget_transaction(buffer_transaction_id); + }); + } + } + } + } + pub fn stream_excerpts_with_context_lines( &mut self, buffer: Model, diff --git a/docs/src/assistant/configuration.md b/docs/src/assistant/configuration.md index bcdf461e2c..17b52a27d8 100644 --- a/docs/src/assistant/configuration.md +++ b/docs/src/assistant/configuration.md @@ -20,6 +20,7 @@ To further customize providers, you can use `settings.json` to do that as follow - [Configuring endpoints](#custom-endpoint) - [Configuring timeouts](#provider-timeout) - [Configuring default model](#default-model) +- [Configuring alternative models for inline assists](#alternative-assists) ### Zed AI {#zed-ai} @@ -264,6 +265,31 @@ You can also manually edit the `default_model` object in your settings: } ``` +#### Configuring alternative models for inline assists {#alternative-assists} + +You can configure additional models that will be used to perform inline assists in parallel. When you do this, +the inline assist UI will surface controls to cycle between the alternatives generated by each model. The models +you specify here are always used in _addition_ to your default model. For example, the following configuration +will generate two outputs for every assist. One with Claude 3.5 Sonnet, and one with GPT-4o. + +```json +{ + "assistant": { + "default_model": { + "provider": "zed.dev", + "model": "claude-3-5-sonnet" + }, + "inline_alternatives": [ + { + "provider": "zed.dev", + "model": "gpt-4o" + } + ], + "version": "2" + } +} +``` + #### Common Panel Settings | key | type | default | description |