Partially fix assistant onboarding (#25313)

While investigating #24896, I noticed two issues:

1. The default configuration for the `zed.dev` provider used the wrong
string for Claude 3.5 Sonnet. Because that string couldn't be deserialized
into a valid `anthropic::Model` enum variant, the provider always appeared
as not configured until the user explicitly picked a model from the model
picker (see the sketch after this list).
2. When clicking `Open New Chat`/`Start New Thread` in the provider
configuration, we would select Claude 3.5 Haiku by default instead of
Claude 3.5 Sonnet.
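
For issue 1, here is a minimal, self-contained sketch of why the old string failed (the variants and `rename` values below are illustrative, not the actual `anthropic::Model` definition): serde only accepts the exact `rename` string, so a settings value of `"claude-3-5-sonnet"` no longer maps to any variant.

```rust
// Hypothetical, trimmed-down model enum; the real `anthropic::Model` lives in
// the anthropic crate and has more variants.
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq)]
enum Model {
    #[serde(rename = "claude-3-5-sonnet-latest")]
    Claude3_5Sonnet,
    #[serde(rename = "claude-3-5-haiku-latest")]
    Claude3_5Haiku,
}

fn main() {
    // The corrected default deserializes to a variant...
    assert_eq!(
        serde_json::from_str::<Model>(r#""claude-3-5-sonnet-latest""#).unwrap(),
        Model::Claude3_5Sonnet
    );
    // ...while the old default does not, which is why the provider kept
    // showing up as not configured.
    assert!(serde_json::from_str::<Model>(r#""claude-3-5-sonnet""#).is_err());
}
```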

Release Notes:

- Fixed some issues that caused AI providers to sometimes be
misconfigured.
Commit f517050548 (parent 535ba75bc7)
Antonio Scandurra, 2025-02-24 08:29:55 +01:00, committed by GitHub
16 changed files with 94 additions and 5 deletions


@@ -581,7 +581,7 @@
 // The provider to use.
 "provider": "zed.dev",
 // The model to use.
-"model": "claude-3-5-sonnet"
+"model": "claude-3-5-sonnet-latest"
 }
 },
 // The settings for slash commands.


@@ -978,7 +978,7 @@ impl AssistantPanel {
 .active_provider()
 .map_or(true, |p| p.id() != provider.id())
 {
-if let Some(model) = provider.provided_models(cx).first().cloned() {
+if let Some(model) = provider.default_model(cx) {
 update_settings_file::<AssistantSettings>(
 this.fs.clone(),
 cx,


@@ -431,7 +431,7 @@ impl AssistantPanel {
 active_provider.id() != provider.id()
 })
 {
-if let Some(model) = provider.provided_models(cx).first().cloned() {
+if let Some(model) = provider.default_model(cx) {
 update_settings_file::<AssistantSettings>(
 self.fs.clone(),
 cx,


@@ -512,7 +512,7 @@ mod tests {
 AssistantSettings::get_global(cx).default_model,
 LanguageModelSelection {
 provider: "zed.dev".into(),
-model: "claude-3-5-sonnet".into(),
+model: "claude-3-5-sonnet-latest".into(),
 }
 );
 });


@@ -299,7 +299,7 @@ pub struct CountTokensResponse {
 }
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
+#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
 pub enum Model {
 #[serde(rename = "gemini-1.5-pro")]
 Gemini15Pro,
@@ -308,6 +308,7 @@ pub enum Model {
 #[serde(rename = "gemini-2.0-pro-exp")]
 Gemini20Pro,
 #[serde(rename = "gemini-2.0-flash")]
+#[default]
 Gemini20Flash,
 #[serde(rename = "gemini-2.0-flash-thinking-exp")]
 Gemini20FlashThinking,


@@ -46,6 +46,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
 provider_name()
 }
+fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+Some(Arc::new(FakeLanguageModel::default()))
+}
 fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
 vec![Arc::new(FakeLanguageModel::default())]
 }


@@ -247,6 +247,7 @@ pub trait LanguageModelProvider: 'static {
 fn icon(&self) -> IconName {
 IconName::ZedAssistant
 }
+fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
 fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
 fn is_authenticated(&self, cx: &App) -> bool;
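
The remaining hunks implement this new required method for each provider. As a rough, self-contained sketch with stand-in types (not Zed's real `LanguageModelProvider` trait, whose method takes `&App` and returns `Option<Arc<dyn LanguageModel>>`), the commit uses two strategies: hosted providers construct their known flagship model directly, while local providers (Ollama, LM Studio) fall back to the first model they discovered at runtime.

```rust
// Simplified stand-ins for illustration only.
trait Provider {
    fn default_model(&self) -> Option<String>;
    fn provided_models(&self) -> Vec<String>;
}

// Strategy 1 (Anthropic, OpenAI, Google, ...): the flagship model is known
// statically, so the default can always be constructed.
struct Hosted;
impl Provider for Hosted {
    fn default_model(&self) -> Option<String> {
        Some("claude-3-5-sonnet-latest".into())
    }
    fn provided_models(&self) -> Vec<String> {
        vec![
            "claude-3-5-haiku-latest".into(),
            "claude-3-5-sonnet-latest".into(),
        ]
    }
}

// Strategy 2 (Ollama, LM Studio): models are only discovered at runtime, so
// the default is simply the first one found, if any.
struct Local(Vec<String>);
impl Provider for Local {
    fn default_model(&self) -> Option<String> {
        self.provided_models().into_iter().next()
    }
    fn provided_models(&self) -> Vec<String> {
        self.0.clone()
    }
}

fn main() {
    // The panels previously did the equivalent of `provided_models().first()`,
    // which could hand back Haiku instead of Sonnet; asking the provider for
    // its default fixes that.
    assert_eq!(
        Hosted.default_model().as_deref(),
        Some("claude-3-5-sonnet-latest")
    );
    assert_eq!(Local(vec![]).default_model(), None);
}
```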


@@ -183,6 +183,17 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
 IconName::AiAnthropic
 }
+fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+let model = anthropic::Model::default();
+Some(Arc::new(AnthropicModel {
+id: LanguageModelId::from(model.id().to_string()),
+model,
+state: self.state.clone(),
+http_client: self.http_client.clone(),
+request_limiter: RateLimiter::new(4),
+}))
+}
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 let mut models = BTreeMap::default();


@@ -272,6 +272,18 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
 IconName::AiZed
 }
+fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+let llm_api_token = self.state.read(cx).llm_api_token.clone();
+let model = CloudModel::Anthropic(anthropic::Model::default());
+Some(Arc::new(CloudLanguageModel {
+id: LanguageModelId::from(model.id().to_string()),
+model,
+llm_api_token: llm_api_token.clone(),
+client: self.client.clone(),
+request_limiter: RateLimiter::new(4),
+}))
+}
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 let mut models = BTreeMap::default();


@@ -89,6 +89,14 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 IconName::Copilot
 }
+fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+let model = CopilotChatModel::default();
+Some(Arc::new(CopilotChatLanguageModel {
+model,
+request_limiter: RateLimiter::new(4),
+}) as Arc<dyn LanguageModel>)
+}
 fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 CopilotChatModel::iter()
 .map(|model| {


@@ -163,6 +163,17 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
 IconName::AiDeepSeek
 }
+fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+let model = deepseek::Model::Chat;
+Some(Arc::new(DeepSeekLanguageModel {
+id: LanguageModelId::from(model.id().to_string()),
+model,
+state: self.state.clone(),
+http_client: self.http_client.clone(),
+request_limiter: RateLimiter::new(4),
+}))
+}
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 let mut models = BTreeMap::default();


@@ -166,6 +166,17 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
 IconName::AiGoogle
 }
+fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+let model = google_ai::Model::default();
+Some(Arc::new(GoogleLanguageModel {
+id: LanguageModelId::from(model.id().to_string()),
+model,
+state: self.state.clone(),
+http_client: self.http_client.clone(),
+rate_limiter: RateLimiter::new(4),
+}))
+}
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 let mut models = BTreeMap::default();


@@ -152,6 +152,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
 IconName::AiLmStudio
 }
+fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+self.provided_models(cx).into_iter().next()
+}
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();


@@ -167,6 +167,17 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
 IconName::AiMistral
 }
+fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+let model = mistral::Model::default();
+Some(Arc::new(MistralLanguageModel {
+id: LanguageModelId::from(model.id().to_string()),
+model,
+state: self.state.clone(),
+http_client: self.http_client.clone(),
+request_limiter: RateLimiter::new(4),
+}))
+}
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 let mut models = BTreeMap::default();


@@ -157,6 +157,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
 IconName::AiOllama
 }
+fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+self.provided_models(cx).into_iter().next()
+}
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();


@@ -169,6 +169,17 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
 IconName::AiOpenAi
 }
+fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+let model = open_ai::Model::default();
+Some(Arc::new(OpenAiLanguageModel {
+id: LanguageModelId::from(model.id().to_string()),
+model,
+state: self.state.clone(),
+http_client: self.http_client.clone(),
+request_limiter: RateLimiter::new(4),
+}))
+}
 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 let mut models = BTreeMap::default();