Improve model selection in the assistant (#12472)

https://github.com/zed-industries/zed/assets/482957/3b017850-b7b6-457a-9b2f-324d5533442e


Release Notes:

- Improved the UX for selecting a model in the assistant panel. You can
now switch model using just the keyboard by pressing `alt-m`. Also, when
switching models via the UI, settings will now be updated automatically.
This commit is contained in:
Antonio Scandurra 2024-05-30 12:36:07 +02:00 committed by GitHub
parent 5a149b970c
commit 6ff01b17ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 517 additions and 295 deletions

View file

@ -25,31 +25,26 @@ use std::time::Duration;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { default_model } => {
CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(),
client.clone(),
settings_version,
cx,
))
}
AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
),
AssistantProvider::OpenAi {
default_model,
model,
api_url,
low_speed_timeout_in_seconds,
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(),
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
)),
AssistantProvider::Anthropic {
default_model,
model,
api_url,
low_speed_timeout_in_seconds,
} => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
default_model.clone(),
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
@ -65,13 +60,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
(
CompletionProvider::OpenAi(provider),
AssistantProvider::OpenAi {
default_model,
model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
provider.update(
default_model.clone(),
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
@ -80,13 +75,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
(
CompletionProvider::Anthropic(provider),
AssistantProvider::Anthropic {
default_model,
model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
provider.update(
default_model.clone(),
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
@ -94,13 +89,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
}
(
CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { default_model },
AssistantProvider::ZedDotDev { model },
) => {
provider.update(default_model.clone(), settings_version);
provider.update(model.clone(), settings_version);
}
(_, AssistantProvider::ZedDotDev { default_model }) => {
(_, AssistantProvider::ZedDotDev { model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(),
model.clone(),
client.clone(),
settings_version,
cx,
@ -109,13 +104,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
(
_,
AssistantProvider::OpenAi {
default_model,
model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(),
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
@ -125,13 +120,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
(
_,
AssistantProvider::Anthropic {
default_model,
model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
*provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
default_model.clone(),
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
@ -159,6 +154,25 @@ impl CompletionProvider {
cx.global::<Self>()
}
pub fn available_models(&self) -> Vec<LanguageModel> {
match self {
CompletionProvider::OpenAi(provider) => provider
.available_models()
.map(LanguageModel::OpenAi)
.collect(),
CompletionProvider::Anthropic(provider) => provider
.available_models()
.map(LanguageModel::Anthropic)
.collect(),
CompletionProvider::ZedDotDev(provider) => provider
.available_models()
.map(LanguageModel::ZedDotDev)
.collect(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn settings_version(&self) -> usize {
match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
@ -209,17 +223,13 @@ impl CompletionProvider {
}
}
pub fn default_model(&self) -> LanguageModel {
pub fn model(&self) -> LanguageModel {
match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
CompletionProvider::Anthropic(provider) => {
LanguageModel::Anthropic(provider.default_model())
}
CompletionProvider::ZedDotDev(provider) => {
LanguageModel::ZedDotDev(provider.default_model())
}
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
CompletionProvider::Fake(_) => LanguageModel::default(),
}
}