collab: Rework model name checks (#16365)
This PR reworks how we do checks for model names in the LLM service. We now normalize the model names using the models defined in the database. Release Notes: - N/A
This commit is contained in:
parent
463ac7f5e4
commit
7a5acc0b0c
3 changed files with 34 additions and 32 deletions
|
@ -169,7 +169,10 @@ async fn perform_completion(
|
|||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
Json(params): Json<PerformCompletionParams>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let model = normalize_model_name(params.provider, params.model);
|
||||
let model = normalize_model_name(
|
||||
state.db.model_names_for_provider(params.provider),
|
||||
params.model,
|
||||
);
|
||||
|
||||
authorize_access_to_language_model(
|
||||
&state.config,
|
||||
|
@ -200,14 +203,18 @@ async fn perform_completion(
|
|||
let mut request: anthropic::Request =
|
||||
serde_json::from_str(¶ms.provider_request.get())?;
|
||||
|
||||
// Parse the model, throw away the version that was included, and then set a specific
|
||||
// version that we control on the server.
|
||||
// Override the model on the request with the latest version of the model that is
|
||||
// known to the server.
|
||||
//
|
||||
// Right now, we use the version that's defined in `model.id()`, but we will likely
|
||||
// want to change this code once a new version of an Anthropic model is released,
|
||||
// so that users can use the new version, without having to update Zed.
|
||||
request.model = match anthropic::Model::from_id(&request.model) {
|
||||
Ok(model) => model.id().to_string(),
|
||||
Err(_) => request.model,
|
||||
request.model = match model.as_str() {
|
||||
"claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
|
||||
"claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
|
||||
"claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
|
||||
"claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
|
||||
_ => request.model,
|
||||
};
|
||||
|
||||
let chunks = anthropic::stream_completion(
|
||||
|
@ -369,31 +376,13 @@ async fn perform_completion(
|
|||
})))
|
||||
}
|
||||
|
||||
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
||||
let prefixes: &[_] = match provider {
|
||||
LanguageModelProvider::Anthropic => &[
|
||||
"claude-3-5-sonnet",
|
||||
"claude-3-haiku",
|
||||
"claude-3-opus",
|
||||
"claude-3-sonnet",
|
||||
],
|
||||
LanguageModelProvider::OpenAi => &[
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4",
|
||||
],
|
||||
LanguageModelProvider::Google => &[],
|
||||
LanguageModelProvider::Zed => &[],
|
||||
};
|
||||
|
||||
if let Some(prefix) = prefixes
|
||||
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
|
||||
if let Some(known_model_name) = known_models
|
||||
.iter()
|
||||
.filter(|&&prefix| name.starts_with(prefix))
|
||||
.max_by_key(|&&prefix| prefix.len())
|
||||
.filter(|known_model_name| name.starts_with(known_model_name.as_str()))
|
||||
.max_by_key(|known_model_name| known_model_name.len())
|
||||
{
|
||||
prefix.to_string()
|
||||
known_model_name.to_string()
|
||||
} else {
|
||||
name
|
||||
}
|
||||
|
|
|
@ -26,9 +26,7 @@ fn authorize_access_to_model(
|
|||
}
|
||||
|
||||
match (provider, model) {
|
||||
(LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
|
||||
Ok(())
|
||||
}
|
||||
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
|
||||
_ => Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to model {model:?} is not included in your plan"),
|
||||
|
|
|
@ -67,6 +67,21 @@ impl LlmDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the names of the known models for the given [`LanguageModelProvider`].
|
||||
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
|
||||
self.models
|
||||
.keys()
|
||||
.filter_map(|(model_provider, model_name)| {
|
||||
if model_provider == &provider {
|
||||
Some(model_name)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
|
||||
Ok(self
|
||||
.models
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue