collab: Rework model name checks (#16365)
This PR reworks how we do checks for model names in the LLM service. We now normalize the model names using the models defined in the database. Release Notes: - N/A
This commit is contained in:
parent
463ac7f5e4
commit
7a5acc0b0c
3 changed files with 34 additions and 32 deletions
|
@ -169,7 +169,10 @@ async fn perform_completion(
|
||||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||||
Json(params): Json<PerformCompletionParams>,
|
Json(params): Json<PerformCompletionParams>,
|
||||||
) -> Result<impl IntoResponse> {
|
) -> Result<impl IntoResponse> {
|
||||||
let model = normalize_model_name(params.provider, params.model);
|
let model = normalize_model_name(
|
||||||
|
state.db.model_names_for_provider(params.provider),
|
||||||
|
params.model,
|
||||||
|
);
|
||||||
|
|
||||||
authorize_access_to_language_model(
|
authorize_access_to_language_model(
|
||||||
&state.config,
|
&state.config,
|
||||||
|
@ -200,14 +203,18 @@ async fn perform_completion(
|
||||||
let mut request: anthropic::Request =
|
let mut request: anthropic::Request =
|
||||||
serde_json::from_str(¶ms.provider_request.get())?;
|
serde_json::from_str(¶ms.provider_request.get())?;
|
||||||
|
|
||||||
// Parse the model, throw away the version that was included, and then set a specific
|
// Override the model on the request with the latest version of the model that is
|
||||||
// version that we control on the server.
|
// known to the server.
|
||||||
|
//
|
||||||
// Right now, we use the version that's defined in `model.id()`, but we will likely
|
// Right now, we use the version that's defined in `model.id()`, but we will likely
|
||||||
// want to change this code once a new version of an Anthropic model is released,
|
// want to change this code once a new version of an Anthropic model is released,
|
||||||
// so that users can use the new version, without having to update Zed.
|
// so that users can use the new version, without having to update Zed.
|
||||||
request.model = match anthropic::Model::from_id(&request.model) {
|
request.model = match model.as_str() {
|
||||||
Ok(model) => model.id().to_string(),
|
"claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
|
||||||
Err(_) => request.model,
|
"claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
|
||||||
|
"claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
|
||||||
|
"claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
|
||||||
|
_ => request.model,
|
||||||
};
|
};
|
||||||
|
|
||||||
let chunks = anthropic::stream_completion(
|
let chunks = anthropic::stream_completion(
|
||||||
|
@ -369,31 +376,13 @@ async fn perform_completion(
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
|
||||||
let prefixes: &[_] = match provider {
|
if let Some(known_model_name) = known_models
|
||||||
LanguageModelProvider::Anthropic => &[
|
|
||||||
"claude-3-5-sonnet",
|
|
||||||
"claude-3-haiku",
|
|
||||||
"claude-3-opus",
|
|
||||||
"claude-3-sonnet",
|
|
||||||
],
|
|
||||||
LanguageModelProvider::OpenAi => &[
|
|
||||||
"gpt-3.5-turbo",
|
|
||||||
"gpt-4-turbo-preview",
|
|
||||||
"gpt-4o-mini",
|
|
||||||
"gpt-4o",
|
|
||||||
"gpt-4",
|
|
||||||
],
|
|
||||||
LanguageModelProvider::Google => &[],
|
|
||||||
LanguageModelProvider::Zed => &[],
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(prefix) = prefixes
|
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|&&prefix| name.starts_with(prefix))
|
.filter(|known_model_name| name.starts_with(known_model_name.as_str()))
|
||||||
.max_by_key(|&&prefix| prefix.len())
|
.max_by_key(|known_model_name| known_model_name.len())
|
||||||
{
|
{
|
||||||
prefix.to_string()
|
known_model_name.to_string()
|
||||||
} else {
|
} else {
|
||||||
name
|
name
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,9 +26,7 @@ fn authorize_access_to_model(
|
||||||
}
|
}
|
||||||
|
|
||||||
match (provider, model) {
|
match (provider, model) {
|
||||||
(LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
|
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
_ => Err(Error::http(
|
_ => Err(Error::http(
|
||||||
StatusCode::FORBIDDEN,
|
StatusCode::FORBIDDEN,
|
||||||
format!("access to model {model:?} is not included in your plan"),
|
format!("access to model {model:?} is not included in your plan"),
|
||||||
|
|
|
@ -67,6 +67,21 @@ impl LlmDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the names of the known models for the given [`LanguageModelProvider`].
|
||||||
|
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
|
||||||
|
self.models
|
||||||
|
.keys()
|
||||||
|
.filter_map(|(model_provider, model_name)| {
|
||||||
|
if model_provider == &provider {
|
||||||
|
Some(model_name)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
|
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
|
||||||
Ok(self
|
Ok(self
|
||||||
.models
|
.models
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue