collab: Add support for more providers to the LLM service (#15832)

This PR adds support for additional providers to the LLM service:

- OpenAI
- Google
- Custom Zed models (through Hugging Face)

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-05 21:16:18 -04:00 committed by GitHub
parent 8e9c2b1125
commit ca9511393b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 331 additions and 98 deletions

View file

@ -12,7 +12,7 @@ use axum::{
};
use futures::StreamExt as _;
use http_client::IsahcHttpClient;
use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;
pub use token::*;
@ -94,29 +94,118 @@ async fn perform_completion(
Extension(_claims): Extension<LlmTokenClaims>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
let api_key = state
.config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?;
let chunks = anthropic::stream_completion(
&state.http_client,
anthropic::ANTHROPIC_API_URL,
api_key,
serde_json::from_str(&params.provider_request.get())?,
None,
)
.await?;
match params.provider {
LanguageModelProvider::Anthropic => {
let api_key = state
.config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?;
let chunks = anthropic::stream_completion(
&state.http_client,
anthropic::ANTHROPIC_API_URL,
api_key,
serde_json::from_str(&params.provider_request.get())?,
None,
)
.await?;
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
Ok(Response::new(Body::wrap_stream(stream)))
Ok(Response::new(Body::wrap_stream(stream)))
}
LanguageModelProvider::OpenAi => {
let api_key = state
.config
.openai_api_key
.as_ref()
.context("no OpenAI API key configured on the server")?;
let chunks = open_ai::stream_completion(
&state.http_client,
open_ai::OPEN_AI_API_URL,
api_key,
serde_json::from_str(&params.provider_request.get())?,
None,
)
.await?;
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
Ok(Response::new(Body::wrap_stream(stream)))
}
LanguageModelProvider::Google => {
let api_key = state
.config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
let chunks = google_ai::stream_generate_content(
&state.http_client,
google_ai::API_URL,
api_key,
serde_json::from_str(&params.provider_request.get())?,
)
.await?;
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
Ok(Response::new(Body::wrap_stream(stream)))
}
LanguageModelProvider::Zed => {
let api_key = state
.config
.qwen2_7b_api_key
.as_ref()
.context("no Qwen2-7B API key configured on the server")?;
let api_url = state
.config
.qwen2_7b_api_url
.as_ref()
.context("no Qwen2-7B URL configured on the server")?;
let chunks = open_ai::stream_completion(
&state.http_client,
&api_url,
api_key,
serde_json::from_str(&params.provider_request.get())?,
None,
)
.await?;
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
Ok(Response::new(Body::wrap_stream(stream)))
}
}
}