collab: Rework model name checks (#16365)
This PR reworks how we do checks for model names in the LLM service. We now normalize the model names using the models defined in the database. Release Notes: - N/A
This commit is contained in:
parent
463ac7f5e4
commit
7a5acc0b0c
3 changed files with 34 additions and 32 deletions
|
@ -26,9 +26,7 @@ fn authorize_access_to_model(
|
|||
}
|
||||
|
||||
match (provider, model) {
|
||||
(LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
|
||||
Ok(())
|
||||
}
|
||||
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
|
||||
_ => Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to model {model:?} is not included in your plan"),
|
||||
|
|
|
@ -67,6 +67,21 @@ impl LlmDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the names of the known models for the given [`LanguageModelProvider`].
|
||||
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
|
||||
self.models
|
||||
.keys()
|
||||
.filter_map(|(model_provider, model_name)| {
|
||||
if model_provider == &provider {
|
||||
Some(model_name)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
|
||||
Ok(self
|
||||
.models
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue