diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 151b6a247c..9eb17ef976 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -169,7 +169,10 @@ async fn perform_completion( country_code_header: Option>, Json(params): Json, ) -> Result { - let model = normalize_model_name(params.provider, params.model); + let model = normalize_model_name( + state.db.model_names_for_provider(params.provider), + params.model, + ); authorize_access_to_language_model( &state.config, @@ -200,14 +203,18 @@ async fn perform_completion( let mut request: anthropic::Request = serde_json::from_str(¶ms.provider_request.get())?; - // Parse the model, throw away the version that was included, and then set a specific - // version that we control on the server. + // Override the model on the request with the latest version of the model that is + // known to the server. + // // Right now, we use the version that's defined in `model.id()`, but we will likely // want to change this code once a new version of an Anthropic model is released, // so that users can use the new version, without having to update Zed. - request.model = match anthropic::Model::from_id(&request.model) { - Ok(model) => model.id().to_string(), - Err(_) => request.model, + request.model = match model.as_str() { + "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(), + "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(), + "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(), + "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(), + _ => request.model, }; let chunks = anthropic::stream_completion( @@ -369,31 +376,13 @@ async fn perform_completion( }))) } -fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String { - let prefixes: &[_] = match provider { - LanguageModelProvider::Anthropic => &[ - "claude-3-5-sonnet", - "claude-3-haiku", - "claude-3-opus", - "claude-3-sonnet", - ], - LanguageModelProvider::OpenAi => &[ - "gpt-3.5-turbo", - "gpt-4-turbo-preview", - "gpt-4o-mini", - "gpt-4o", - "gpt-4", - ], - LanguageModelProvider::Google => &[], - LanguageModelProvider::Zed => &[], - }; - - if let Some(prefix) = prefixes +fn normalize_model_name(known_models: Vec, name: String) -> String { + if let Some(known_model_name) = known_models .iter() - .filter(|&&prefix| name.starts_with(prefix)) - .max_by_key(|&&prefix| prefix.len()) + .filter(|known_model_name| name.starts_with(known_model_name.as_str())) + .max_by_key(|known_model_name| known_model_name.len()) { - prefix.to_string() + known_model_name.to_string() } else { name } diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs index 865ca97f7a..0b62dd4e0a 100644 --- a/crates/collab/src/llm/authorization.rs +++ b/crates/collab/src/llm/authorization.rs @@ -26,9 +26,7 @@ fn authorize_access_to_model( } match (provider, model) { - (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => { - Ok(()) - } + (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()), _ => Err(Error::http( StatusCode::FORBIDDEN, format!("access to model {model:?} is not included in your plan"), diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index b3144eeecd..f76a722471 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -67,6 +67,21 @@ impl LlmDatabase { Ok(()) } + /// Returns the names of the known models for the given [`LanguageModelProvider`]. + pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec { + self.models + .keys() + .filter_map(|(model_provider, model_name)| { + if model_provider == &provider { + Some(model_name) + } else { + None + } + }) + .cloned() + .collect::>() + } + pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> { Ok(self .models