diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index a7f7ca5164..a5aa993e43 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::{anyhow, Context as _, Result}; use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME}; use collections::BTreeMap; -use feature_flags::{FeatureFlag, FeatureFlagAppExt}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels}; use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task}; use http_client::{AsyncBody, HttpClient, Method, Response}; @@ -142,46 +142,56 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn provided_models(&self, cx: &AppContext) -> Vec> { let mut models = BTreeMap::default(); - for model in anthropic::Model::iter() { - if !matches!(model, anthropic::Model::Custom { .. }) { - models.insert(model.id().to_string(), CloudModel::Anthropic(model)); + let is_user = !cx.has_flag::(); + if is_user { + models.insert( + anthropic::Model::Claude3_5Sonnet.id().to_string(), + CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet), + ); + } else { + for model in anthropic::Model::iter() { + if !matches!(model, anthropic::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::Anthropic(model)); + } } - } - for model in open_ai::Model::iter() { - if !matches!(model, open_ai::Model::Custom { .. }) { - models.insert(model.id().to_string(), CloudModel::OpenAi(model)); + for model in open_ai::Model::iter() { + if !matches!(model, open_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::OpenAi(model)); + } } - } - for model in google_ai::Model::iter() { - if !matches!(model, google_ai::Model::Custom { .. }) { - models.insert(model.id().to_string(), CloudModel::Google(model)); + for model in google_ai::Model::iter() { + if !matches!(model, google_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::Google(model)); + } + } + for model in ZedModel::iter() { + models.insert(model.id().to_string(), CloudModel::Zed(model)); } - } - for model in ZedModel::iter() { - models.insert(model.id().to_string(), CloudModel::Zed(model)); - } - // Override with available models from settings - for model in &AllLanguageModelSettings::get_global(cx) - .zed_dot_dev - .available_models - { - let model = match model.provider { - AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom { - name: model.name.clone(), - max_tokens: model.max_tokens, - tool_override: model.tool_override.clone(), - }), - AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { - name: model.name.clone(), - max_tokens: model.max_tokens, - }), - AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom { - name: model.name.clone(), - max_tokens: model.max_tokens, - }), - }; - models.insert(model.id().to_string(), model.clone()); + // Override with available models from settings + for model in &AllLanguageModelSettings::get_global(cx) + .zed_dot_dev + .available_models + { + let model = match model.provider { + AvailableProvider::Anthropic => { + CloudModel::Anthropic(anthropic::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + tool_override: model.tool_override.clone(), + }) + } + AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + }; + models.insert(model.id().to_string(), model.clone()); + } } models