collab: Rework model name checks (#16365)

This PR reworks how we do checks for model names in the LLM service.

We now normalize the model names using the models defined in the
database.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-16 13:54:28 -04:00 committed by GitHub
parent 463ac7f5e4
commit 7a5acc0b0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 32 deletions

View file

@ -26,9 +26,7 @@ fn authorize_access_to_model(
}
match (provider, model) {
(LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
Ok(())
}
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
_ => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),

View file

@ -67,6 +67,21 @@ impl LlmDatabase {
Ok(())
}
/// Returns the names of the known models for the given [`LanguageModelProvider`].
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
self.models
.keys()
.filter_map(|(model_provider, model_name)| {
if model_provider == &provider {
Some(model_name)
} else {
None
}
})
.cloned()
.collect::<Vec<_>>()
}
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
Ok(self
.models