Allow using a custom model when using zed.dev (#14933)

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2024-07-22 12:25:53 +02:00 committed by GitHub
parent a334c69e05
commit 0155435142
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 114 additions and 110 deletions

View file

@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
request: proto::CompleteWithLanguageModel,
mut request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
@ -4530,18 +4530,43 @@ async fn complete_with_language_model(
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
if request.model.starts_with("gpt") {
let api_key =
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
complete_with_open_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("gemini") {
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("claude") {
let api_key = anthropic_api_key
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
complete_with_anthropic(request, response, session, api_key).await?;
let mut provider_and_model = request.model.split('/');
let (provider, model) = match (
provider_and_model.next().unwrap(),
provider_and_model.next(),
) {
(provider, Some(model)) => (provider, model),
(model, None) => {
if model.starts_with("gpt") {
("openai", model)
} else if model.starts_with("gemini") {
("google", model)
} else if model.starts_with("claude") {
("anthropic", model)
} else {
("unknown", model)
}
}
};
let provider = provider.to_string();
request.model = model.to_string();
match provider.as_str() {
"openai" => {
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
complete_with_open_ai(request, response, session, api_key).await?;
}
"anthropic" => {
let api_key =
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
complete_with_anthropic(request, response, session, api_key).await?;
}
"google" => {
let api_key =
google_ai_api_key.context("no Google AI API key configured on the server")?;
complete_with_google_ai(request, response, session, api_key).await?;
}
provider => return Err(anyhow!("unknown provider {:?}", provider))?,
}
Ok(())