Add the ability to customize available models for OpenAI-compatible services (#13276)

Closes #11984, closes #11075.

Release Notes:

- Added the ability to customize available models for OpenAI-compatible
services ([#11984](https://github.com/zed-industries/zed/issues/11984))
([#11075](https://github.com/zed-industries/zed/issues/11075)).


![image](https://github.com/zed-industries/zed/assets/32017007/01057e7b-1f21-49ad-a3ad-abc5282ffaf0)
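
To make the change concrete, here is a minimal, hypothetical sketch of the kind of provider configuration this feature enables. Only the field names (`model`, `api_url`, `low_speed_timeout_in_seconds`, `available_models`) come from the diff below; the struct itself, the string-typed model names, and the example values (a local OpenAI-compatible endpoint) are illustrative assumptions, not Zed's actual settings schema.

```rust
// Hypothetical sketch only: a string-typed stand-in for the provider settings,
// not Zed's real settings types. Only the field names mirror the diff below.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct OpenAiProviderSettings {
    model: Option<String>,
    api_url: Option<String>,
    low_speed_timeout_in_seconds: Option<u64>,
    /// The new knob: the models the OpenAI-compatible endpoint should offer.
    #[serde(default)]
    available_models: Vec<String>,
}

fn main() {
    // Example values (local endpoint, model names) are made up for illustration.
    let json = r#"{
        "api_url": "http://localhost:11434/v1",
        "model": "llama3",
        "available_models": ["llama3", "mistral", "codellama"]
    }"#;
    let settings: OpenAiProviderSettings =
        serde_json::from_str(json).expect("valid provider settings JSON");
    println!("configured models: {:?}", settings.available_models);
}
```

In the actual change, `available_models` holds `open_ai::Model` values rather than plain strings, as the diff below shows.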
ᴀᴍᴛᴏᴀᴇʀ 2024-06-26 04:37:02 +08:00 committed by GitHub
parent 9f88460870
commit 922fcaf5a6
6 changed files with 79 additions and 11 deletions

@@ -24,6 +24,20 @@ use settings::{Settings, SettingsStore};
 use std::sync::Arc;
 use std::time::Duration;
 
+/// Choose which model to use for openai provider.
+/// If the model is not available, try to use the first available model, or fallback to the original model.
+fn choose_openai_model(
+    model: &::open_ai::Model,
+    available_models: &[::open_ai::Model],
+) -> ::open_ai::Model {
+    available_models
+        .iter()
+        .find(|&m| m == model)
+        .or_else(|| available_models.first())
+        .unwrap_or_else(|| model)
+        .clone()
+}
+
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     let mut settings_version = 0;
     let provider = match &AssistantSettings::get_global(cx).provider {
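
The hunk above introduces `choose_openai_model`: keep the configured model if it appears in `available_models`, otherwise fall back to the first listed model, and if nothing is listed keep the configured model. Below is a standalone sketch of that selection logic; the plain-`String` stand-in for Zed's `open_ai::Model` type and the model names are assumptions for illustration only.

```rust
// Sketch of the selection rules in `choose_openai_model`, with `String`
// substituted for Zed's `open_ai::Model` type.
fn choose_model(model: &String, available_models: &[String]) -> String {
    available_models
        .iter()
        .find(|&m| m == model) // keep the configured model if it is listed
        .or_else(|| available_models.first()) // otherwise take the first listed model
        .unwrap_or(model) // with no list at all, keep the configured model
        .clone()
}

fn main() {
    let available = vec!["gpt-4o".to_string(), "gpt-3.5-turbo".to_string()];

    // Configured model is in the list: it wins.
    assert_eq!(choose_model(&"gpt-4o".to_string(), &available), "gpt-4o");
    // Configured model is not listed: the first listed model is used instead.
    assert_eq!(choose_model(&"gpt-4-32k".to_string(), &available), "gpt-4o");
    // Nothing listed: the configured model is kept unchanged.
    assert_eq!(choose_model(&"gpt-4-32k".to_string(), &[]), "gpt-4-32k");
    println!("model selection fallback behaves as expected");
}
```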
@@ -34,8 +48,9 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             model,
             api_url,
             low_speed_timeout_in_seconds,
+            available_models,
         } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-            model.clone(),
+            choose_openai_model(model, available_models),
             api_url.clone(),
             client.http_client(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -77,10 +92,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                     model,
                     api_url,
                     low_speed_timeout_in_seconds,
+                    available_models,
                 },
             ) => {
                 provider.update(
-                    model.clone(),
+                    choose_openai_model(model, available_models),
                     api_url.clone(),
                     low_speed_timeout_in_seconds.map(Duration::from_secs),
                     settings_version,
@@ -136,10 +152,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                     model,
                     api_url,
                     low_speed_timeout_in_seconds,
+                    available_models,
                 },
             ) => {
                 *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-                    model.clone(),
+                    choose_openai_model(model, available_models),
                     api_url.clone(),
                     client.http_client(),
                     low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -201,10 +218,10 @@ impl CompletionProvider {
         cx.global::<Self>()
     }
 
-    pub fn available_models(&self) -> Vec<LanguageModel> {
+    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
         match self {
             CompletionProvider::OpenAi(provider) => provider
-                .available_models()
+                .available_models(cx)
                 .map(LanguageModel::OpenAi)
                 .collect(),
             CompletionProvider::Anthropic(provider) => provider