Improve model selection in the assistant (#12472)
https://github.com/zed-industries/zed/assets/482957/3b017850-b7b6-457a-9b2f-324d5533442e Release Notes: - Improved the UX for selecting a model in the assistant panel. You can now switch model using just the keyboard by pressing `alt-m`. Also, when switching models via the UI, settings will now be updated automatically.
This commit is contained in:
parent
5a149b970c
commit
6ff01b17ca
17 changed files with 517 additions and 295 deletions
|
@ -1,7 +1,7 @@
|
|||
use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager};
|
||||
use crate::slash_command::{rustdoc_command, search_command, tabs_command};
|
||||
use crate::{
|
||||
assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel},
|
||||
assistant_settings::{AssistantDockPosition, AssistantSettings},
|
||||
codegen::{self, Codegen, CodegenKind},
|
||||
search::*,
|
||||
slash_command::{
|
||||
|
@ -9,10 +9,11 @@ use crate::{
|
|||
SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
|
||||
},
|
||||
ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist,
|
||||
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata,
|
||||
MessageStatus, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata,
|
||||
SavedMessage, Split, ToggleFocus, ToggleHistory,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
|
||||
QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
|
||||
Split, ToggleFocus, ToggleHistory,
|
||||
};
|
||||
use crate::{ModelSelector, ToggleModelSelector};
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
|
||||
use client::telemetry::Telemetry;
|
||||
|
@ -64,8 +65,8 @@ use std::{
|
|||
use telemetry_events::AssistantKind;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{
|
||||
popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, Tab, TabBar,
|
||||
Tooltip,
|
||||
popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding,
|
||||
PopoverMenuHandle, Tab, TabBar, Tooltip,
|
||||
};
|
||||
use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
|
||||
use uuid::Uuid;
|
||||
|
@ -119,8 +120,8 @@ pub struct AssistantPanel {
|
|||
pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
|
||||
inline_prompt_history: VecDeque<String>,
|
||||
_watch_saved_conversations: Task<Result<()>>,
|
||||
model: LanguageModel,
|
||||
authentication_prompt: Option<AnyView>,
|
||||
model_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
}
|
||||
|
||||
struct ActiveConversationEditor {
|
||||
|
@ -203,7 +204,6 @@ impl AssistantPanel {
|
|||
}
|
||||
}),
|
||||
];
|
||||
let model = CompletionProvider::global(cx).default_model();
|
||||
|
||||
cx.observe_global::<FileIcons>(|_, cx| {
|
||||
cx.notify();
|
||||
|
@ -244,8 +244,8 @@ impl AssistantPanel {
|
|||
pending_inline_assist_ids_by_editor: Default::default(),
|
||||
inline_prompt_history: Default::default(),
|
||||
_watch_saved_conversations,
|
||||
model,
|
||||
authentication_prompt: None,
|
||||
model_menu_handle: PopoverMenuHandle::default(),
|
||||
}
|
||||
})
|
||||
})
|
||||
|
@ -277,12 +277,20 @@ impl AssistantPanel {
|
|||
if self.is_authenticated(cx) {
|
||||
self.authentication_prompt = None;
|
||||
|
||||
let model = CompletionProvider::global(cx).default_model();
|
||||
self.set_model(model, cx);
|
||||
if let Some(editor) = self.active_conversation_editor() {
|
||||
editor.update(cx, |active_conversation, cx| {
|
||||
active_conversation
|
||||
.conversation
|
||||
.update(cx, |conversation, cx| {
|
||||
conversation.completion_provider_changed(cx)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
if self.active_conversation_editor().is_none() {
|
||||
self.new_conversation(cx);
|
||||
}
|
||||
cx.notify();
|
||||
} else if self.authentication_prompt.is_none()
|
||||
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
|
||||
{
|
||||
|
@ -290,6 +298,7 @@ impl AssistantPanel {
|
|||
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
provider.authentication_prompt(cx)
|
||||
}));
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -734,7 +743,7 @@ impl AssistantPanel {
|
|||
.map(|message| message.to_request_message(buffer)),
|
||||
);
|
||||
}
|
||||
let model = self.model.clone();
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
// I Don't know if we want to return a ? here.
|
||||
|
@ -809,7 +818,6 @@ impl AssistantPanel {
|
|||
|
||||
let editor = cx.new_view(|cx| {
|
||||
ConversationEditor::new(
|
||||
self.model.clone(),
|
||||
self.languages.clone(),
|
||||
self.slash_commands.clone(),
|
||||
self.fs.clone(),
|
||||
|
@ -850,53 +858,6 @@ impl AssistantPanel {
|
|||
cx.notify();
|
||||
}
|
||||
|
||||
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let next_model = match &self.model {
|
||||
LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model {
|
||||
open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four,
|
||||
open_ai::Model::Four => open_ai::Model::FourTurbo,
|
||||
open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
|
||||
open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
|
||||
}),
|
||||
LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
|
||||
anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
|
||||
anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
|
||||
anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
|
||||
}),
|
||||
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
|
||||
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
|
||||
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
|
||||
ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni,
|
||||
ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus,
|
||||
ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
|
||||
ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
|
||||
ZedDotDevModel::Claude3Haiku => {
|
||||
match CompletionProvider::global(cx).default_model() {
|
||||
LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom,
|
||||
_ => ZedDotDevModel::Gpt3Point5Turbo,
|
||||
}
|
||||
}
|
||||
ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
|
||||
}),
|
||||
};
|
||||
|
||||
self.set_model(next_model, cx);
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext<Self>) {
|
||||
self.model = model.clone();
|
||||
if let Some(editor) = self.active_conversation_editor() {
|
||||
editor.update(cx, |active_conversation, cx| {
|
||||
active_conversation
|
||||
.conversation
|
||||
.update(cx, |conversation, cx| {
|
||||
conversation.set_model(model, cx);
|
||||
})
|
||||
})
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_conversation_editor_event(
|
||||
&mut self,
|
||||
_: View<ConversationEditor>,
|
||||
|
@ -978,6 +939,10 @@ impl AssistantPanel {
|
|||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
|
||||
self.model_menu_handle.toggle(cx);
|
||||
}
|
||||
|
||||
fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
|
||||
Some(&self.active_conversation_editor.as_ref()?.editor)
|
||||
}
|
||||
|
@ -1133,10 +1098,8 @@ impl AssistantPanel {
|
|||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
|
||||
let model = this.update(&mut cx, |this, _| this.model.clone())?;
|
||||
let conversation = Conversation::deserialize(
|
||||
saved_conversation,
|
||||
model,
|
||||
path.clone(),
|
||||
languages,
|
||||
slash_commands,
|
||||
|
@ -1206,7 +1169,10 @@ impl AssistantPanel {
|
|||
this.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(self.render_model(&conversation, cx))
|
||||
.child(ModelSelector::new(
|
||||
self.model_menu_handle.clone(),
|
||||
self.fs.clone(),
|
||||
))
|
||||
.children(self.render_remaining_tokens(&conversation, cx)),
|
||||
)
|
||||
.child(
|
||||
|
@ -1256,6 +1222,7 @@ impl AssistantPanel {
|
|||
.on_action(cx.listener(AssistantPanel::select_prev_match))
|
||||
.on_action(cx.listener(AssistantPanel::handle_editor_cancel))
|
||||
.on_action(cx.listener(AssistantPanel::reset_credentials))
|
||||
.on_action(cx.listener(AssistantPanel::toggle_model_selector))
|
||||
.track_focus(&self.focus_handle)
|
||||
.child(header)
|
||||
.children(if self.toolbar.read(cx).hidden() {
|
||||
|
@ -1314,23 +1281,12 @@ impl AssistantPanel {
|
|||
))
|
||||
}
|
||||
|
||||
fn render_model(
|
||||
&self,
|
||||
conversation: &Model<Conversation>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> impl IntoElement {
|
||||
Button::new("current_model", conversation.read(cx).model.display_name())
|
||||
.style(ButtonStyle::Filled)
|
||||
.tooltip(move |cx| Tooltip::text("Change Model", cx))
|
||||
.on_click(cx.listener(|this, _, cx| this.cycle_model(cx)))
|
||||
}
|
||||
|
||||
fn render_remaining_tokens(
|
||||
&self,
|
||||
conversation: &Model<Conversation>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Option<impl IntoElement> {
|
||||
let remaining_tokens = conversation.read(cx).remaining_tokens()?;
|
||||
let remaining_tokens = conversation.read(cx).remaining_tokens(cx)?;
|
||||
let remaining_tokens_color = if remaining_tokens <= 0 {
|
||||
Color::Error
|
||||
} else if remaining_tokens <= 500 {
|
||||
|
@ -1486,7 +1442,6 @@ pub struct Conversation {
|
|||
pending_summary: Task<Option<()>>,
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
model: LanguageModel,
|
||||
token_count: Option<usize>,
|
||||
pending_token_count: Task<Option<()>>,
|
||||
pending_edit_suggestion_parse: Option<Task<()>>,
|
||||
|
@ -1502,7 +1457,6 @@ impl EventEmitter<ConversationEvent> for Conversation {}
|
|||
|
||||
impl Conversation {
|
||||
fn new(
|
||||
model: LanguageModel,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
slash_command_registry: Arc<SlashCommandRegistry>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
|
@ -1530,7 +1484,6 @@ impl Conversation {
|
|||
token_count: None,
|
||||
pending_token_count: Task::ready(None),
|
||||
pending_edit_suggestion_parse: None,
|
||||
model,
|
||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||
pending_save: Task::ready(Ok(())),
|
||||
path: None,
|
||||
|
@ -1583,7 +1536,6 @@ impl Conversation {
|
|||
#[allow(clippy::too_many_arguments)]
|
||||
async fn deserialize(
|
||||
saved_conversation: SavedConversation,
|
||||
model: LanguageModel,
|
||||
path: PathBuf,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
slash_command_registry: Arc<SlashCommandRegistry>,
|
||||
|
@ -1640,7 +1592,6 @@ impl Conversation {
|
|||
token_count: None,
|
||||
pending_edit_suggestion_parse: None,
|
||||
pending_token_count: Task::ready(None),
|
||||
model,
|
||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||
pending_save: Task::ready(Ok(())),
|
||||
path: Some(path),
|
||||
|
@ -1938,12 +1889,12 @@ impl Conversation {
|
|||
}
|
||||
}
|
||||
|
||||
fn remaining_tokens(&self) -> Option<isize> {
|
||||
Some(self.model.max_token_count() as isize - self.token_count? as isize)
|
||||
fn remaining_tokens(&self, cx: &AppContext) -> Option<isize> {
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
Some(model.max_token_count() as isize - self.token_count? as isize)
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: LanguageModel, cx: &mut ModelContext<Self>) {
|
||||
self.model = model;
|
||||
fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
|
||||
self.count_remaining_tokens(cx);
|
||||
}
|
||||
|
||||
|
@ -2079,10 +2030,11 @@ impl Conversation {
|
|||
}
|
||||
|
||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
telemetry.report_assistant_event(
|
||||
this.id.clone(),
|
||||
AssistantKind::Panel,
|
||||
this.model.telemetry_id(),
|
||||
model.telemetry_id(),
|
||||
response_latency,
|
||||
error_message,
|
||||
);
|
||||
|
@ -2111,7 +2063,7 @@ impl Conversation {
|
|||
.map(|message| message.to_request_message(self.buffer.read(cx)));
|
||||
|
||||
LanguageModelRequest {
|
||||
model: self.model.clone(),
|
||||
model: CompletionProvider::global(cx).model(),
|
||||
messages: messages.collect(),
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
|
@ -2300,7 +2252,7 @@ impl Conversation {
|
|||
.into(),
|
||||
}));
|
||||
let request = LanguageModelRequest {
|
||||
model: self.model.clone(),
|
||||
model: CompletionProvider::global(cx).model(),
|
||||
messages: messages.collect(),
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
|
@ -2605,7 +2557,6 @@ pub struct ConversationEditor {
|
|||
|
||||
impl ConversationEditor {
|
||||
fn new(
|
||||
model: LanguageModel,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
slash_command_registry: Arc<SlashCommandRegistry>,
|
||||
fs: Arc<dyn Fs>,
|
||||
|
@ -2618,7 +2569,6 @@ impl ConversationEditor {
|
|||
|
||||
let conversation = cx.new_model(|cx| {
|
||||
Conversation::new(
|
||||
model,
|
||||
language_registry,
|
||||
slash_command_registry,
|
||||
Some(telemetry),
|
||||
|
@ -3847,15 +3797,8 @@ mod tests {
|
|||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
||||
let conversation = cx.new_model(|cx| {
|
||||
Conversation::new(
|
||||
LanguageModel::default(),
|
||||
registry,
|
||||
Default::default(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let conversation =
|
||||
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||
|
@ -3986,15 +3929,8 @@ mod tests {
|
|||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
||||
let conversation = cx.new_model(|cx| {
|
||||
Conversation::new(
|
||||
LanguageModel::default(),
|
||||
registry,
|
||||
Default::default(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let conversation =
|
||||
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||
|
@ -4092,15 +4028,8 @@ mod tests {
|
|||
cx.set_global(settings_store);
|
||||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
let conversation = cx.new_model(|cx| {
|
||||
Conversation::new(
|
||||
LanguageModel::default(),
|
||||
registry,
|
||||
Default::default(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let conversation =
|
||||
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||
|
@ -4209,15 +4138,8 @@ mod tests {
|
|||
));
|
||||
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let conversation = cx.new_model(|cx| {
|
||||
Conversation::new(
|
||||
LanguageModel::default(),
|
||||
registry.clone(),
|
||||
slash_command_registry,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let conversation = cx
|
||||
.new_model(|cx| Conversation::new(registry.clone(), slash_command_registry, None, cx));
|
||||
|
||||
let output_ranges = Rc::new(RefCell::new(HashSet::default()));
|
||||
conversation.update(cx, |_, cx| {
|
||||
|
@ -4390,15 +4312,8 @@ mod tests {
|
|||
cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
|
||||
cx.update(init);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let conversation = cx.new_model(|cx| {
|
||||
Conversation::new(
|
||||
LanguageModel::default(),
|
||||
registry.clone(),
|
||||
Default::default(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let conversation =
|
||||
cx.new_model(|cx| Conversation::new(registry.clone(), Default::default(), None, cx));
|
||||
let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
|
||||
let message_0 =
|
||||
conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
|
||||
|
@ -4434,7 +4349,6 @@ mod tests {
|
|||
|
||||
let deserialized_conversation = Conversation::deserialize(
|
||||
conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
|
||||
LanguageModel::default(),
|
||||
Default::default(),
|
||||
registry.clone(),
|
||||
Default::default(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue