diff --git a/assets/settings/default.json b/assets/settings/default.json index d27f9a2fd1..8183c3d60e 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -581,7 +581,7 @@ // The provider to use. "provider": "zed.dev", // The model to use. - "model": "claude-3-5-sonnet" + "model": "claude-3-5-sonnet-latest" } }, // The settings for slash commands. diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index ecf2e2c421..e0791e0039 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -978,7 +978,7 @@ impl AssistantPanel { .active_provider() .map_or(true, |p| p.id() != provider.id()) { - if let Some(model) = provider.provided_models(cx).first().cloned() { + if let Some(model) = provider.default_model(cx) { update_settings_file::( this.fs.clone(), cx, diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index f83a6dc75d..fb94a18e99 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -431,7 +431,7 @@ impl AssistantPanel { active_provider.id() != provider.id() }) { - if let Some(model) = provider.provided_models(cx).first().cloned() { + if let Some(model) = provider.default_model(cx) { update_settings_file::( self.fs.clone(), cx, diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index 56e801fad3..5e044282b0 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -512,7 +512,7 @@ mod tests { AssistantSettings::get_global(cx).default_model, LanguageModelSelection { provider: "zed.dev".into(), - model: "claude-3-5-sonnet".into(), + model: "claude-3-5-sonnet-latest".into(), } ); }); diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index ace7ea22c4..e885599a0f 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -299,7 +299,7 @@ pub struct CountTokensResponse { } #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] +#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] pub enum Model { #[serde(rename = "gemini-1.5-pro")] Gemini15Pro, @@ -308,6 +308,7 @@ pub enum Model { #[serde(rename = "gemini-2.0-pro-exp")] Gemini20Pro, #[serde(rename = "gemini-2.0-flash")] + #[default] Gemini20Flash, #[serde(rename = "gemini-2.0-flash-thinking-exp")] Gemini20FlashThinking, diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index a955638b21..0e4c0748fc 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -46,6 +46,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider { provider_name() } + fn default_model(&self, _cx: &App) -> Option> { + Some(Arc::new(FakeLanguageModel::default())) + } + fn provided_models(&self, _: &App) -> Vec> { vec![Arc::new(FakeLanguageModel::default())] } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 6219fda739..7b50702a6e 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -247,6 +247,7 @@ pub trait LanguageModelProvider: 'static { fn icon(&self) -> IconName { IconName::ZedAssistant } + fn default_model(&self, cx: &App) -> Option>; fn provided_models(&self, cx: &App) -> Vec>; fn load_model(&self, _model: Arc, _cx: &App) {} fn is_authenticated(&self, cx: &App) -> bool; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index e3ca4998fe..9908929457 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -183,6 +183,17 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { IconName::AiAnthropic } + fn default_model(&self, _cx: &App) -> Option> { + let model = anthropic::Model::default(); + Some(Arc::new(AnthropicModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 05544f40db..236b78527b 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -272,6 +272,18 @@ impl LanguageModelProvider for CloudLanguageModelProvider { IconName::AiZed } + fn default_model(&self, cx: &App) -> Option> { + let llm_api_token = self.state.read(cx).llm_api_token.clone(); + let model = CloudModel::Anthropic(anthropic::Model::default()); + Some(Arc::new(CloudLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + llm_api_token: llm_api_token.clone(), + client: self.client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 1c4a4273ac..7bf2cfe4f6 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -89,6 +89,14 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { IconName::Copilot } + fn default_model(&self, _cx: &App) -> Option> { + let model = CopilotChatModel::default(); + Some(Arc::new(CopilotChatLanguageModel { + model, + request_limiter: RateLimiter::new(4), + }) as Arc) + } + fn provided_models(&self, _cx: &App) -> Vec> { CopilotChatModel::iter() .map(|model| { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 91cc02149d..830e94ecb5 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -163,6 +163,17 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { IconName::AiDeepSeek } + fn default_model(&self, _cx: &App) -> Option> { + let model = deepseek::Model::Chat; + Some(Arc::new(DeepSeekLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 9e313935c2..0bf5001f79 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -166,6 +166,17 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { IconName::AiGoogle } + fn default_model(&self, _cx: &App) -> Option> { + let model = google_ai::Model::default(); + Some(Arc::new(GoogleLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + rate_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 76832a44e1..edd07c053a 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -152,6 +152,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { IconName::AiLmStudio } + fn default_model(&self, cx: &App) -> Option> { + self.provided_models(cx).into_iter().next() + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models: BTreeMap = BTreeMap::default(); diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 032ee38c42..80a5988cff 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -167,6 +167,17 @@ impl LanguageModelProvider for MistralLanguageModelProvider { IconName::AiMistral } + fn default_model(&self, _cx: &App) -> Option> { + let model = mistral::Model::default(); + Some(Arc::new(MistralLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index a982eb3aa7..33ad0bcafd 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -157,6 +157,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { IconName::AiOllama } + fn default_model(&self, cx: &App) -> Option> { + self.provided_models(cx).into_iter().next() + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models: BTreeMap = BTreeMap::default(); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index ee277247b8..3e46983ebb 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -169,6 +169,17 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { IconName::AiOpenAi } + fn default_model(&self, _cx: &App) -> Option> { + let model = open_ai::Model::default(); + Some(Arc::new(OpenAiLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default();