Improve model selection in the assistant (#12472)

https://github.com/zed-industries/zed/assets/482957/3b017850-b7b6-457a-9b2f-324d5533442e


Release Notes:

- Improved the UX for selecting a model in the assistant panel. You can
now switch model using just the keyboard by pressing `alt-m`. Also, when
switching models via the UI, settings will now be updated automatically.
This commit is contained in:
Antonio Scandurra 2024-05-30 12:36:07 +02:00 committed by GitHub
parent 5a149b970c
commit 6ff01b17ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 517 additions and 295 deletions

View file

@ -12,8 +12,11 @@ use serde::{
Deserialize, Deserializer, Serialize, Serializer,
};
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
#[derive(Clone, Debug, Default, PartialEq)]
use crate::LanguageModel;
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum ZedDotDevModel {
Gpt3Point5Turbo,
Gpt4,
@ -53,13 +56,10 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
where
E: de::Error,
{
match value {
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
"gpt-4" => Ok(ZedDotDevModel::Gpt4),
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
"gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni),
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
}
let model = ZedDotDevModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
Ok(model)
}
}
@ -73,24 +73,23 @@ impl JsonSchema for ZedDotDevModel {
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = vec![
"gpt-3.5-turbo".to_owned(),
"gpt-4".to_owned(),
"gpt-4-turbo-preview".to_owned(),
"gpt-4o".to_owned(),
];
let variants = ZedDotDevModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(serde_json::json!("gpt-4-turbo-preview")),
examples: vec![
serde_json::json!("gpt-3.5-turbo"),
serde_json::json!("gpt-4"),
serde_json::json!("gpt-4-turbo-preview"),
serde_json::json!("custom-model-name"),
],
default: Some(ZedDotDevModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
@ -145,51 +144,55 @@ pub enum AssistantDockPosition {
Bottom,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
#[serde(rename = "zed.dev")]
ZedDotDev {
#[serde(default)]
default_model: ZedDotDevModel,
model: ZedDotDevModel,
},
#[serde(rename = "openai")]
OpenAi {
#[serde(default)]
default_model: OpenAiModel,
#[serde(default = "open_ai_url")]
model: OpenAiModel,
api_url: String,
#[serde(default)]
low_speed_timeout_in_seconds: Option<u64>,
},
#[serde(rename = "anthropic")]
Anthropic {
#[serde(default)]
default_model: AnthropicModel,
#[serde(default = "anthropic_api_url")]
model: AnthropicModel,
api_url: String,
#[serde(default)]
low_speed_timeout_in_seconds: Option<u64>,
},
}
impl Default for AssistantProvider {
fn default() -> Self {
Self::ZedDotDev {
default_model: ZedDotDevModel::default(),
Self::OpenAi {
model: OpenAiModel::default(),
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
}
}
}
fn open_ai_url() -> String {
open_ai::OPEN_AI_API_URL.to_string()
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
#[serde(rename = "zed.dev")]
ZedDotDev {
default_model: Option<ZedDotDevModel>,
},
#[serde(rename = "openai")]
OpenAi {
default_model: Option<OpenAiModel>,
api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
},
#[serde(rename = "anthropic")]
Anthropic {
default_model: Option<AnthropicModel>,
api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
},
}
fn anthropic_api_url() -> String {
anthropic::ANTHROPIC_API_URL.to_string()
}
#[derive(Default, Debug, Deserialize, Serialize)]
#[derive(Debug, Default)]
pub struct AssistantSettings {
pub enabled: bool,
pub button: bool,
@ -240,16 +243,16 @@ impl AssistantSettingsContent {
default_width: settings.default_width,
default_height: settings.default_height,
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
Some(AssistantProvider::OpenAi {
default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
api_url: open_ai_api_url.clone(),
Some(AssistantProviderContent::OpenAi {
default_model: settings.default_open_ai_model.clone(),
api_url: Some(open_ai_api_url.clone()),
low_speed_timeout_in_seconds: None,
})
} else {
settings.default_open_ai_model.clone().map(|open_ai_model| {
AssistantProvider::OpenAi {
default_model: open_ai_model,
api_url: open_ai_url(),
AssistantProviderContent::OpenAi {
default_model: Some(open_ai_model),
api_url: None,
low_speed_timeout_in_seconds: None,
}
})
@ -270,6 +273,64 @@ impl AssistantSettingsContent {
}
}
}
pub fn set_model(&mut self, new_model: LanguageModel) {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
Some(AssistantProviderContent::ZedDotDev {
default_model: model,
}) => {
if let LanguageModel::ZedDotDev(new_model) = new_model {
*model = Some(new_model);
}
}
Some(AssistantProviderContent::OpenAi {
default_model: model,
..
}) => {
if let LanguageModel::OpenAi(new_model) = new_model {
*model = Some(new_model);
}
}
Some(AssistantProviderContent::Anthropic {
default_model: model,
..
}) => {
if let LanguageModel::Anthropic(new_model) = new_model {
*model = Some(new_model);
}
}
provider => match new_model {
LanguageModel::ZedDotDev(model) => {
*provider = Some(AssistantProviderContent::ZedDotDev {
default_model: Some(model),
})
}
LanguageModel::OpenAi(model) => {
*provider = Some(AssistantProviderContent::OpenAi {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
LanguageModel::Anthropic(model) => {
*provider = Some(AssistantProviderContent::Anthropic {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
},
},
},
AssistantSettingsContent::Legacy(settings) => {
if let LanguageModel::OpenAi(model) = new_model {
settings.default_open_ai_model = Some(model);
}
}
}
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@ -318,7 +379,7 @@ pub struct AssistantSettingsContentV1 {
///
/// This can either be the internal `zed.dev` service or an external `openai` service,
/// each with their respective default models and configurations.
provider: Option<AssistantProvider>,
provider: Option<AssistantProviderContent>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@ -376,31 +437,82 @@ impl Settings for AssistantSettings {
if let Some(provider) = value.provider.clone() {
match (&mut settings.provider, provider) {
(
AssistantProvider::ZedDotDev { default_model },
AssistantProvider::ZedDotDev {
default_model: default_model_override,
AssistantProvider::ZedDotDev { model },
AssistantProviderContent::ZedDotDev {
default_model: model_override,
},
) => {
*default_model = default_model_override;
merge(model, model_override);
}
(
AssistantProvider::OpenAi {
default_model,
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProvider::OpenAi {
default_model: default_model_override,
AssistantProviderContent::OpenAi {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
*default_model = default_model_override;
*api_url = api_url_override;
*low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(merged, provider_override) => {
*merged = provider_override;
(
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Anthropic {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(provider, provider_override) => {
*provider = match provider_override {
AssistantProviderContent::ZedDotDev {
default_model: model,
} => AssistantProvider::ZedDotDev {
model: model.unwrap_or_default(),
},
AssistantProviderContent::OpenAi {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::OpenAi {
model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Anthropic {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::Anthropic {
model: model.unwrap_or_default(),
api_url: api_url
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
low_speed_timeout_in_seconds,
},
};
}
}
}
@ -410,7 +522,7 @@ impl Settings for AssistantSettings {
}
}
fn merge<T: Copy>(target: &mut T, value: Option<T>) {
fn merge<T>(target: &mut T, value: Option<T>) {
if let Some(value) = value {
*target = value;
}
@ -433,8 +545,8 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::FourOmni,
api_url: open_ai_url(),
model: OpenAiModel::FourOmni,
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
}
);
@ -455,7 +567,7 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::FourOmni,
model: OpenAiModel::FourOmni,
api_url: "test-url".into(),
low_speed_timeout_in_seconds: None,
}
@ -475,8 +587,8 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::Four,
api_url: open_ai_url(),
model: OpenAiModel::Four,
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
}
);
@ -501,7 +613,7 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
default_model: ZedDotDevModel::Custom("custom".into())
model: ZedDotDevModel::Custom("custom".into())
}
);
}