From f517050548fb81aaad159b44b0d6183961fc305d Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Mon, 24 Feb 2025 08:29:55 +0100
Subject: [PATCH] Partially fix assistant onboarding (#25313)

While investigating #24896, I noticed two issues:

1. The default configuration for the `zed.dev` provider was using the wrong
   string for Claude 3.5 Sonnet. This meant the provider would always show up
   as not configured until the user selected it from the model picker, because
   we couldn't deserialize that string to a valid `anthropic::Model` enum
   variant.
2. When clicking on `Open New Chat`/`Start New Thread` in the provider
   configuration, we would select `Claude 3.5 Haiku` by default instead of
   Claude 3.5 Sonnet.

Release Notes:

- Fixed some issues that caused AI providers to sometimes be misconfigured.
---
 assets/settings/default.json                        |  2 +-
 crates/assistant/src/assistant_panel.rs             |  2 +-
 crates/assistant2/src/assistant_panel.rs            |  2 +-
 crates/assistant_settings/src/assistant_settings.rs |  2 +-
 crates/google_ai/src/google_ai.rs                   |  3 ++-
 crates/language_model/src/fake_provider.rs          |  4 ++++
 crates/language_model/src/language_model.rs         |  1 +
 crates/language_models/src/provider/anthropic.rs    | 11 +++++++++++
 crates/language_models/src/provider/cloud.rs        | 12 ++++++++++++
 crates/language_models/src/provider/copilot_chat.rs |  8 ++++++++
 crates/language_models/src/provider/deepseek.rs     | 11 +++++++++++
 crates/language_models/src/provider/google.rs       | 11 +++++++++++
 crates/language_models/src/provider/lmstudio.rs     |  4 ++++
 crates/language_models/src/provider/mistral.rs      | 11 +++++++++++
 crates/language_models/src/provider/ollama.rs       |  4 ++++
 crates/language_models/src/provider/open_ai.rs      | 11 +++++++++++
 16 files changed, 94 insertions(+), 5 deletions(-)

diff --git a/assets/settings/default.json b/assets/settings/default.json
index d27f9a2fd1..8183c3d60e 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -581,7 +581,7 @@
       // The provider to use.
       "provider": "zed.dev",
       // The model to use.
-      "model": "claude-3-5-sonnet"
+      "model": "claude-3-5-sonnet-latest"
     }
   },
   // The settings for slash commands.
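To illustrate issue 1 above: serde only accepts strings that match a variant's `rename` attribute, so the old settings value could never map onto an `anthropic::Model` variant. Below is a minimal, self-contained sketch of that failure mode; the variant name and rename string are assumptions mirroring the `anthropic` crate, not copied from it.

```rust
use serde::Deserialize;

// Assumed shape of the relevant variant in `anthropic::Model`.
#[derive(Deserialize)]
enum Model {
    #[serde(rename = "claude-3-5-sonnet-latest")]
    Claude3_5Sonnet,
}

fn main() {
    // The old default.json value has no matching `rename`, so deserialization
    // fails and the provider is treated as not configured.
    assert!(serde_json::from_str::<Model>("\"claude-3-5-sonnet\"").is_err());
    // The corrected value maps onto the variant.
    assert!(serde_json::from_str::<Model>("\"claude-3-5-sonnet-latest\"").is_ok());
}
```

With the corrected `claude-3-5-sonnet-latest` string in `default.json`, deserialization succeeds and the `zed.dev` provider is recognized as configured.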
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index ecf2e2c421..e0791e0039 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -978,7 +978,7 @@ impl AssistantPanel {
                 .active_provider()
                 .map_or(true, |p| p.id() != provider.id())
             {
-                if let Some(model) = provider.provided_models(cx).first().cloned() {
+                if let Some(model) = provider.default_model(cx) {
                     update_settings_file::<AssistantSettings>(
                         this.fs.clone(),
                         cx,
diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs
index f83a6dc75d..fb94a18e99 100644
--- a/crates/assistant2/src/assistant_panel.rs
+++ b/crates/assistant2/src/assistant_panel.rs
@@ -431,7 +431,7 @@ impl AssistantPanel {
                     active_provider.id() != provider.id()
                 })
             {
-                if let Some(model) = provider.provided_models(cx).first().cloned() {
+                if let Some(model) = provider.default_model(cx) {
                     update_settings_file::<AssistantSettings>(
                         self.fs.clone(),
                         cx,
diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs
index 56e801fad3..5e044282b0 100644
--- a/crates/assistant_settings/src/assistant_settings.rs
+++ b/crates/assistant_settings/src/assistant_settings.rs
@@ -512,7 +512,7 @@ mod tests {
                 AssistantSettings::get_global(cx).default_model,
                 LanguageModelSelection {
                     provider: "zed.dev".into(),
-                    model: "claude-3-5-sonnet".into(),
+                    model: "claude-3-5-sonnet-latest".into(),
                 }
             );
         });
diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs
index ace7ea22c4..e885599a0f 100644
--- a/crates/google_ai/src/google_ai.rs
+++ b/crates/google_ai/src/google_ai.rs
@@ -299,7 +299,7 @@ pub struct CountTokensResponse {
 }
 
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
+#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
 pub enum Model {
     #[serde(rename = "gemini-1.5-pro")]
     Gemini15Pro,
@@ -308,6 +308,7 @@ pub enum Model {
     #[serde(rename = "gemini-2.0-pro-exp")]
     Gemini20Pro,
     #[serde(rename = "gemini-2.0-flash")]
+    #[default]
     Gemini20Flash,
     #[serde(rename = "gemini-2.0-flash-thinking-exp")]
     Gemini20FlashThinking,
diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs
index a955638b21..0e4c0748fc 100644
--- a/crates/language_model/src/fake_provider.rs
+++ b/crates/language_model/src/fake_provider.rs
@@ -46,6 +46,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
         provider_name()
     }
 
+    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(Arc::new(FakeLanguageModel::default()))
+    }
+
     fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
         vec![Arc::new(FakeLanguageModel::default())]
     }
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 6219fda739..7b50702a6e 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -247,6 +247,7 @@ pub trait LanguageModelProvider: 'static {
     fn icon(&self) -> IconName {
         IconName::ZedAssistant
     }
+    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
     fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
     fn is_authenticated(&self, cx: &App) -> bool;
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index e3ca4998fe..9908929457 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -183,6 +183,17 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
         IconName::AiAnthropic
     }
 
+    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let model = anthropic::Model::default();
+        Some(Arc::new(AnthropicModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }))
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 05544f40db..236b78527b 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -272,6 +272,18 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         IconName::AiZed
     }
 
+    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let llm_api_token = self.state.read(cx).llm_api_token.clone();
+        let model = CloudModel::Anthropic(anthropic::Model::default());
+        Some(Arc::new(CloudLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            llm_api_token: llm_api_token.clone(),
+            client: self.client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }))
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 1c4a4273ac..7bf2cfe4f6 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -89,6 +89,14 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
         IconName::Copilot
     }
 
+    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let model = CopilotChatModel::default();
+        Some(Arc::new(CopilotChatLanguageModel {
+            model,
+            request_limiter: RateLimiter::new(4),
+        }) as Arc<dyn LanguageModel>)
+    }
+
     fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         CopilotChatModel::iter()
             .map(|model| {
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
index 91cc02149d..830e94ecb5 100644
--- a/crates/language_models/src/provider/deepseek.rs
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -163,6 +163,17 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
         IconName::AiDeepSeek
     }
 
+    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let model = deepseek::Model::Chat;
+        Some(Arc::new(DeepSeekLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }))
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 9e313935c2..0bf5001f79 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -166,6 +166,17 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
         IconName::AiGoogle
     }
 
+    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let model = google_ai::Model::default();
+        Some(Arc::new(GoogleLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            rate_limiter: RateLimiter::new(4),
+        }))
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index 76832a44e1..edd07c053a 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -152,6 +152,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
         IconName::AiLmStudio
     }
 
+    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        self.provided_models(cx).into_iter().next()
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
 
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index 032ee38c42..80a5988cff 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -167,6 +167,17 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
         IconName::AiMistral
     }
 
+    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let model = mistral::Model::default();
+        Some(Arc::new(MistralLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }))
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index a982eb3aa7..33ad0bcafd 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -157,6 +157,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
         IconName::AiOllama
     }
 
+    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        self.provided_models(cx).into_iter().next()
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
 
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index ee277247b8..3e46983ebb 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -169,6 +169,17 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         IconName::AiOpenAi
     }
 
+    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let model = open_ai::Model::default();
+        Some(Arc::new(OpenAiLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }))
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
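Taken together, every provider now implements `default_model`, and the assistant panels prefer it over the first entry of `provided_models` (which, for the `zed.dev` provider, was presumably Claude 3.5 Haiku because of how the model list is ordered). A rough usage sketch of the new trait method follows; `pick_model` is a hypothetical helper, not part of this patch, and the import paths are assumed from the crates touched above.

```rust
use std::sync::Arc;

use gpui::App;
use language_model::{LanguageModel, LanguageModelProvider};

// Hypothetical helper showing how a caller can prefer the provider's own
// default and only fall back to the first listed model when no default is
// available, mirroring the panel changes in this patch.
fn pick_model<P: LanguageModelProvider>(provider: &P, cx: &App) -> Option<Arc<dyn LanguageModel>> {
    provider
        .default_model(cx)
        .or_else(|| provider.provided_models(cx).into_iter().next())
}
```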