Improve model selection in the assistant (#12472)

https://github.com/zed-industries/zed/assets/482957/3b017850-b7b6-457a-9b2f-324d5533442e


Release Notes:

- Improved the UX for selecting a model in the assistant panel. You can
now switch model using just the keyboard by pressing `alt-m`. Also, when
switching models via the UI, settings will now be updated automatically.
This commit is contained in:
Antonio Scandurra 2024-05-30 12:36:07 +02:00 committed by GitHub
parent 5a149b970c
commit 6ff01b17ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 517 additions and 295 deletions

View file

@ -1,7 +1,7 @@
use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager};
use crate::slash_command::{rustdoc_command, search_command, tabs_command};
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel},
assistant_settings::{AssistantDockPosition, AssistantSettings},
codegen::{self, Codegen, CodegenKind},
search::*,
slash_command::{
@ -9,10 +9,11 @@ use crate::{
SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
},
ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist,
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata,
MessageStatus, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata,
SavedMessage, Split, ToggleFocus, ToggleHistory,
LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
Split, ToggleFocus, ToggleHistory,
};
use crate::{ModelSelector, ToggleModelSelector};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
use client::telemetry::Telemetry;
@ -64,8 +65,8 @@ use std::{
use telemetry_events::AssistantKind;
use theme::ThemeSettings;
use ui::{
popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, Tab, TabBar,
Tooltip,
popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding,
PopoverMenuHandle, Tab, TabBar, Tooltip,
};
use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
use uuid::Uuid;
@ -119,8 +120,8 @@ pub struct AssistantPanel {
pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>,
model: LanguageModel,
authentication_prompt: Option<AnyView>,
model_menu_handle: PopoverMenuHandle<ContextMenu>,
}
struct ActiveConversationEditor {
@ -203,7 +204,6 @@ impl AssistantPanel {
}
}),
];
let model = CompletionProvider::global(cx).default_model();
cx.observe_global::<FileIcons>(|_, cx| {
cx.notify();
@ -244,8 +244,8 @@ impl AssistantPanel {
pending_inline_assist_ids_by_editor: Default::default(),
inline_prompt_history: Default::default(),
_watch_saved_conversations,
model,
authentication_prompt: None,
model_menu_handle: PopoverMenuHandle::default(),
}
})
})
@ -277,12 +277,20 @@ impl AssistantPanel {
if self.is_authenticated(cx) {
self.authentication_prompt = None;
let model = CompletionProvider::global(cx).default_model();
self.set_model(model, cx);
if let Some(editor) = self.active_conversation_editor() {
editor.update(cx, |active_conversation, cx| {
active_conversation
.conversation
.update(cx, |conversation, cx| {
conversation.completion_provider_changed(cx)
})
})
}
if self.active_conversation_editor().is_none() {
self.new_conversation(cx);
}
cx.notify();
} else if self.authentication_prompt.is_none()
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
{
@ -290,6 +298,7 @@ impl AssistantPanel {
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.authentication_prompt(cx)
}));
cx.notify();
}
}
@ -734,7 +743,7 @@ impl AssistantPanel {
.map(|message| message.to_request_message(buffer)),
);
}
let model = self.model.clone();
let model = CompletionProvider::global(cx).model();
cx.spawn(|_, mut cx| async move {
// I Don't know if we want to return a ? here.
@ -809,7 +818,6 @@ impl AssistantPanel {
let editor = cx.new_view(|cx| {
ConversationEditor::new(
self.model.clone(),
self.languages.clone(),
self.slash_commands.clone(),
self.fs.clone(),
@ -850,53 +858,6 @@ impl AssistantPanel {
cx.notify();
}
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
let next_model = match &self.model {
LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model {
open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four,
open_ai::Model::Four => open_ai::Model::FourTurbo,
open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
}),
LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
}),
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni,
ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus,
ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
ZedDotDevModel::Claude3Haiku => {
match CompletionProvider::global(cx).default_model() {
LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom,
_ => ZedDotDevModel::Gpt3Point5Turbo,
}
}
ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
}),
};
self.set_model(next_model, cx);
}
fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext<Self>) {
self.model = model.clone();
if let Some(editor) = self.active_conversation_editor() {
editor.update(cx, |active_conversation, cx| {
active_conversation
.conversation
.update(cx, |conversation, cx| {
conversation.set_model(model, cx);
})
})
}
cx.notify();
}
fn handle_conversation_editor_event(
&mut self,
_: View<ConversationEditor>,
@ -978,6 +939,10 @@ impl AssistantPanel {
.detach_and_log_err(cx);
}
fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
self.model_menu_handle.toggle(cx);
}
fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
Some(&self.active_conversation_editor.as_ref()?.editor)
}
@ -1133,10 +1098,8 @@ impl AssistantPanel {
cx.spawn(|this, mut cx| async move {
let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
let model = this.update(&mut cx, |this, _| this.model.clone())?;
let conversation = Conversation::deserialize(
saved_conversation,
model,
path.clone(),
languages,
slash_commands,
@ -1206,7 +1169,10 @@ impl AssistantPanel {
this.child(
h_flex()
.gap_1()
.child(self.render_model(&conversation, cx))
.child(ModelSelector::new(
self.model_menu_handle.clone(),
self.fs.clone(),
))
.children(self.render_remaining_tokens(&conversation, cx)),
)
.child(
@ -1256,6 +1222,7 @@ impl AssistantPanel {
.on_action(cx.listener(AssistantPanel::select_prev_match))
.on_action(cx.listener(AssistantPanel::handle_editor_cancel))
.on_action(cx.listener(AssistantPanel::reset_credentials))
.on_action(cx.listener(AssistantPanel::toggle_model_selector))
.track_focus(&self.focus_handle)
.child(header)
.children(if self.toolbar.read(cx).hidden() {
@ -1314,23 +1281,12 @@ impl AssistantPanel {
))
}
fn render_model(
&self,
conversation: &Model<Conversation>,
cx: &mut ViewContext<Self>,
) -> impl IntoElement {
Button::new("current_model", conversation.read(cx).model.display_name())
.style(ButtonStyle::Filled)
.tooltip(move |cx| Tooltip::text("Change Model", cx))
.on_click(cx.listener(|this, _, cx| this.cycle_model(cx)))
}
fn render_remaining_tokens(
&self,
conversation: &Model<Conversation>,
cx: &mut ViewContext<Self>,
) -> Option<impl IntoElement> {
let remaining_tokens = conversation.read(cx).remaining_tokens()?;
let remaining_tokens = conversation.read(cx).remaining_tokens(cx)?;
let remaining_tokens_color = if remaining_tokens <= 0 {
Color::Error
} else if remaining_tokens <= 500 {
@ -1486,7 +1442,6 @@ pub struct Conversation {
pending_summary: Task<Option<()>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
model: LanguageModel,
token_count: Option<usize>,
pending_token_count: Task<Option<()>>,
pending_edit_suggestion_parse: Option<Task<()>>,
@ -1502,7 +1457,6 @@ impl EventEmitter<ConversationEvent> for Conversation {}
impl Conversation {
fn new(
model: LanguageModel,
language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>,
telemetry: Option<Arc<Telemetry>>,
@ -1530,7 +1484,6 @@ impl Conversation {
token_count: None,
pending_token_count: Task::ready(None),
pending_edit_suggestion_parse: None,
model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
@ -1583,7 +1536,6 @@ impl Conversation {
#[allow(clippy::too_many_arguments)]
async fn deserialize(
saved_conversation: SavedConversation,
model: LanguageModel,
path: PathBuf,
language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>,
@ -1640,7 +1592,6 @@ impl Conversation {
token_count: None,
pending_edit_suggestion_parse: None,
pending_token_count: Task::ready(None),
model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
@ -1938,12 +1889,12 @@ impl Conversation {
}
}
fn remaining_tokens(&self) -> Option<isize> {
Some(self.model.max_token_count() as isize - self.token_count? as isize)
fn remaining_tokens(&self, cx: &AppContext) -> Option<isize> {
let model = CompletionProvider::global(cx).model();
Some(model.max_token_count() as isize - self.token_count? as isize)
}
fn set_model(&mut self, model: LanguageModel, cx: &mut ModelContext<Self>) {
self.model = model;
fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
self.count_remaining_tokens(cx);
}
@ -2079,10 +2030,11 @@ impl Conversation {
}
if let Some(telemetry) = this.telemetry.as_ref() {
let model = CompletionProvider::global(cx).model();
telemetry.report_assistant_event(
this.id.clone(),
AssistantKind::Panel,
this.model.telemetry_id(),
model.telemetry_id(),
response_latency,
error_message,
);
@ -2111,7 +2063,7 @@ impl Conversation {
.map(|message| message.to_request_message(self.buffer.read(cx)));
LanguageModelRequest {
model: self.model.clone(),
model: CompletionProvider::global(cx).model(),
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
@ -2300,7 +2252,7 @@ impl Conversation {
.into(),
}));
let request = LanguageModelRequest {
model: self.model.clone(),
model: CompletionProvider::global(cx).model(),
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
@ -2605,7 +2557,6 @@ pub struct ConversationEditor {
impl ConversationEditor {
fn new(
model: LanguageModel,
language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>,
fs: Arc<dyn Fs>,
@ -2618,7 +2569,6 @@ impl ConversationEditor {
let conversation = cx.new_model(|cx| {
Conversation::new(
model,
language_registry,
slash_command_registry,
Some(telemetry),
@ -3847,15 +3797,8 @@ mod tests {
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let conversation = cx.new_model(|cx| {
Conversation::new(
LanguageModel::default(),
registry,
Default::default(),
None,
cx,
)
});
let conversation =
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3986,15 +3929,8 @@ mod tests {
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let conversation = cx.new_model(|cx| {
Conversation::new(
LanguageModel::default(),
registry,
Default::default(),
None,
cx,
)
});
let conversation =
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -4092,15 +4028,8 @@ mod tests {
cx.set_global(settings_store);
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let conversation = cx.new_model(|cx| {
Conversation::new(
LanguageModel::default(),
registry,
Default::default(),
None,
cx,
)
});
let conversation =
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -4209,15 +4138,8 @@ mod tests {
));
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let conversation = cx.new_model(|cx| {
Conversation::new(
LanguageModel::default(),
registry.clone(),
slash_command_registry,
None,
cx,
)
});
let conversation = cx
.new_model(|cx| Conversation::new(registry.clone(), slash_command_registry, None, cx));
let output_ranges = Rc::new(RefCell::new(HashSet::default()));
conversation.update(cx, |_, cx| {
@ -4390,15 +4312,8 @@ mod tests {
cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
cx.update(init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let conversation = cx.new_model(|cx| {
Conversation::new(
LanguageModel::default(),
registry.clone(),
Default::default(),
None,
cx,
)
});
let conversation =
cx.new_model(|cx| Conversation::new(registry.clone(), Default::default(), None, cx));
let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
let message_0 =
conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
@ -4434,7 +4349,6 @@ mod tests {
let deserialized_conversation = Conversation::deserialize(
conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
LanguageModel::default(),
Default::default(),
registry.clone(),
Default::default(),