assistant: Add display_name for OpenAI and Gemini (#17508)

This commit is contained in:
Peter Tripp 2024-09-10 13:41:06 -04:00 committed by GitHub
parent 85f4c96fef
commit fb9d01b0d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 34 additions and 16 deletions

View file

@ -254,11 +254,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}),
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
}),
AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
}),
};

View file

@ -37,6 +37,7 @@ pub struct GoogleSettings {
/// A user-configured custom model entry for the Google AI provider,
/// deserialized from the `GoogleSettings` section of the settings file.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// Model identifier passed through to `google_ai::Model::Custom` (see the
/// provider hunk below, where it is also used as the registration key).
name: String,
/// Optional human-readable label; presumably shown in the UI in place of
/// `name` when present — TODO confirm fallback behavior at the call site.
display_name: Option<String>,
/// Token limit forwarded to `google_ai::Model::Custom`; presumably the
/// model's context-window size — verify against provider documentation.
max_tokens: usize,
}
@ -170,6 +171,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model.name.clone(),
google_ai::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
},
);

View file

@ -40,6 +40,7 @@ pub struct OpenAiSettings {
/// A user-configured custom model entry for the OpenAI provider,
/// deserialized from the `OpenAiSettings` section of the settings file.
/// NOTE(review): unlike the Google counterpart above, these fields are
/// `pub` — confirm whether the asymmetry is intentional.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// Model identifier passed through to `open_ai::Model::Custom` (see the
/// provider hunk below, where it is also used as the registration key).
pub name: String,
/// Optional human-readable label; presumably shown in the UI in place of
/// `name` when present — TODO confirm fallback behavior at the call site.
pub display_name: Option<String>,
/// Token limit forwarded to `open_ai::Model::Custom`; presumably the
/// model's context-window size — verify against provider documentation.
pub max_tokens: usize,
/// Optional cap on generated tokens, forwarded as `max_output_tokens`.
pub max_output_tokens: Option<u32>,
}
@ -171,6 +172,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
model.name.clone(),
open_ai::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
},
@ -368,11 +370,7 @@ pub fn count_open_ai_tokens(
})
.collect::<Vec<_>>();
if let open_ai::Model::Custom { .. } = model {
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
} else {
tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
}
tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
})
.boxed()
}

View file

@ -175,12 +175,14 @@ impl OpenAiSettingsContent {
.filter_map(|model| match model {
open_ai::Model::Custom {
name,
display_name,
max_tokens,
max_output_tokens,
} => Some(provider::open_ai::AvailableModel {
name,
max_tokens,
max_output_tokens,
display_name,
}),
_ => None,
})