collab: Rework model name checks (#16365)

This PR reworks how we do checks for model names in the LLM service.

We now normalize model names against the set of known models defined in the
database, instead of against hardcoded per-provider prefix lists.

Release Notes:

- N/A
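
As a rough illustration of the normalization this introduces (the model names and the `known_models` list below are invented for the example, not read from the actual database seed): the incoming model name is matched against the provider's known names by prefix, the longest match wins, and names that match nothing pass through unchanged.

// Illustrative sketch only; assumes these are the names the database knows for Anthropic.
let known_models = vec![
    "claude-3-sonnet".to_string(),
    "claude-3-5-sonnet".to_string(),
];

// A dated model ID from a client normalizes to the longest known prefix...
assert_eq!(
    normalize_model_name(known_models.clone(), "claude-3-5-sonnet-20240620".into()),
    "claude-3-5-sonnet"
);

// ...while names that match no known model pass through unchanged.
assert_eq!(
    normalize_model_name(known_models, "some-unknown-model".into()),
    "some-unknown-model"
);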
Marshall Bowers, 2024-08-16 13:54:28 -04:00 (committed by GitHub)
parent 463ac7f5e4
commit 7a5acc0b0c
3 changed files with 34 additions and 32 deletions


@@ -169,7 +169,10 @@ async fn perform_completion(
     country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
-    let model = normalize_model_name(params.provider, params.model);
+    let model = normalize_model_name(
+        state.db.model_names_for_provider(params.provider),
+        params.model,
+    );
 
     authorize_access_to_language_model(
         &state.config,
@ -200,14 +203,18 @@ async fn perform_completion(
let mut request: anthropic::Request = let mut request: anthropic::Request =
serde_json::from_str(&params.provider_request.get())?; serde_json::from_str(&params.provider_request.get())?;
// Parse the model, throw away the version that was included, and then set a specific // Override the model on the request with the latest version of the model that is
// version that we control on the server. // known to the server.
//
// Right now, we use the version that's defined in `model.id()`, but we will likely // Right now, we use the version that's defined in `model.id()`, but we will likely
// want to change this code once a new version of an Anthropic model is released, // want to change this code once a new version of an Anthropic model is released,
// so that users can use the new version, without having to update Zed. // so that users can use the new version, without having to update Zed.
request.model = match anthropic::Model::from_id(&request.model) { request.model = match model.as_str() {
Ok(model) => model.id().to_string(), "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
Err(_) => request.model, "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
"claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
"claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
_ => request.model,
}; };
let chunks = anthropic::stream_completion( let chunks = anthropic::stream_completion(
@@ -369,31 +376,13 @@ async fn perform_completion(
     })))
 }
 
-fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
-    let prefixes: &[_] = match provider {
-        LanguageModelProvider::Anthropic => &[
-            "claude-3-5-sonnet",
-            "claude-3-haiku",
-            "claude-3-opus",
-            "claude-3-sonnet",
-        ],
-        LanguageModelProvider::OpenAi => &[
-            "gpt-3.5-turbo",
-            "gpt-4-turbo-preview",
-            "gpt-4o-mini",
-            "gpt-4o",
-            "gpt-4",
-        ],
-        LanguageModelProvider::Google => &[],
-        LanguageModelProvider::Zed => &[],
-    };
-
-    if let Some(prefix) = prefixes
+fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
+    if let Some(known_model_name) = known_models
         .iter()
-        .filter(|&&prefix| name.starts_with(prefix))
-        .max_by_key(|&&prefix| prefix.len())
+        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
+        .max_by_key(|known_model_name| known_model_name.len())
     {
-        prefix.to_string()
+        known_model_name.to_string()
     } else {
         name
     }


@@ -26,9 +26,7 @@ fn authorize_access_to_model(
     }
 
     match (provider, model) {
-        (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
-            Ok(())
-        }
+        (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
         _ => Err(Error::http(
             StatusCode::FORBIDDEN,
             format!("access to model {model:?} is not included in your plan"),
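
The `starts_with` guard is no longer needed in this match because the model name has already been normalized against the database's known names before authorization runs, so an exact match on the normalized name suffices. A minimal sketch of that interaction, with an invented dated suffix:

// Hypothetical flow, assuming "claude-3-5-sonnet" is among the provider's known names.
let known = vec!["claude-3-5-sonnet".to_string()];
let model = normalize_model_name(known, "claude-3-5-sonnet-20240620".to_string());
// Authorization only ever sees the normalized name, so the exact-match arm above applies.
assert_eq!(model, "claude-3-5-sonnet");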


@@ -67,6 +67,21 @@ impl LlmDatabase {
         Ok(())
     }
 
+    /// Returns the names of the known models for the given [`LanguageModelProvider`].
+    pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
+        self.models
+            .keys()
+            .filter_map(|(model_provider, model_name)| {
+                if model_provider == &provider {
+                    Some(model_name)
+                } else {
+                    None
+                }
+            })
+            .cloned()
+            .collect::<Vec<_>>()
+    }
+
     pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
         Ok(self
             .models
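
The `.keys()` call and the tuple destructuring suggest `self.models` is a map keyed by `(LanguageModelProvider, model name)` pairs. Under that assumption, a minimal usage sketch mirroring the call site added in `perform_completion` above (the concrete names returned depend entirely on what is seeded in the database):

// Hypothetical usage; `db` is an LlmDatabase whose models cache is already loaded.
let known_anthropic_models = db.model_names_for_provider(LanguageModelProvider::Anthropic);
// e.g. ["claude-3-5-sonnet", "claude-3-opus", ...] -- whatever the database defines.
let model = normalize_model_name(known_anthropic_models, params.model);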