collab: Add support for more providers to the LLM service (#15832)
This PR adds support for additional providers to the LLM service:

- OpenAI
- Google
- Custom Zed models (through Hugging Face)

Release Notes:

- N/A
parent 8e9c2b1125
commit ca9511393b

3 changed files with 331 additions and 98 deletions
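At a high level, the client now tags each completion request with a provider and a model, and ships the provider-specific body as raw JSON that the collab endpoint dispatches on. As a rough sketch of how a caller might assemble such a request against the `PerformCompletionParams` shape added below (the Anthropic payload and the `claude-3-5-sonnet` model name are illustrative, not taken from this diff):

use serde_json::value::RawValue;

// A sketch only: `LanguageModelProvider` and `PerformCompletionParams` come
// from the rpc changes below; the request body and model name are illustrative.
fn build_params() -> anyhow::Result<rpc::PerformCompletionParams> {
    let body = serde_json::json!({
        "model": "claude-3-5-sonnet",
        "max_tokens": 1024,
        "messages": [{ "role": "user", "content": "Hello" }],
    });
    Ok(rpc::PerformCompletionParams {
        provider: rpc::LanguageModelProvider::Anthropic,
        model: "claude-3-5-sonnet".into(),
        // The provider-specific payload travels as raw JSON, so the collab
        // server can forward it without knowing every provider's schema.
        provider_request: RawValue::from_string(body.to_string())?,
    })
}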
@@ -12,7 +12,7 @@ use axum::{
 };
 use futures::StreamExt as _;
 use http_client::IsahcHttpClient;
-use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use std::sync::Arc;
 
 pub use token::*;
@@ -94,29 +94,118 @@ async fn perform_completion(
     Extension(_claims): Extension<LlmTokenClaims>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
-    let api_key = state
-        .config
-        .anthropic_api_key
-        .as_ref()
-        .context("no Anthropic AI API key configured on the server")?;
-    let chunks = anthropic::stream_completion(
-        &state.http_client,
-        anthropic::ANTHROPIC_API_URL,
-        api_key,
-        serde_json::from_str(&params.provider_request.get())?,
-        None,
-    )
-    .await?;
-
-    let stream = chunks.map(|event| {
-        let mut buffer = Vec::new();
-        event.map(|chunk| {
-            buffer.clear();
-            serde_json::to_writer(&mut buffer, &chunk).unwrap();
-            buffer.push(b'\n');
-            buffer
-        })
-    });
-
-    Ok(Response::new(Body::wrap_stream(stream)))
+    match params.provider {
+        LanguageModelProvider::Anthropic => {
+            let api_key = state
+                .config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
+            let chunks = anthropic::stream_completion(
+                &state.http_client,
+                anthropic::ANTHROPIC_API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::OpenAi => {
+            let api_key = state
+                .config
+                .openai_api_key
+                .as_ref()
+                .context("no OpenAI API key configured on the server")?;
+            let chunks = open_ai::stream_completion(
+                &state.http_client,
+                open_ai::OPEN_AI_API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::Google => {
+            let api_key = state
+                .config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            let chunks = google_ai::stream_generate_content(
+                &state.http_client,
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::Zed => {
+            let api_key = state
+                .config
+                .qwen2_7b_api_key
+                .as_ref()
+                .context("no Qwen2-7B API key configured on the server")?;
+            let api_url = state
+                .config
+                .qwen2_7b_api_url
+                .as_ref()
+                .context("no Qwen2-7B URL configured on the server")?;
+            let chunks = open_ai::stream_completion(
+                &state.http_client,
+                &api_url,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+    }
 }
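On the wire, each arm above emits one serialized provider event per line (`serde_json::to_writer` followed by `b'\n'`), i.e. newline-delimited JSON. A minimal consumer sketch, assuming an async body that implements `futures::AsyncRead` and leaving events as untyped `serde_json::Value`s (the typed client-side decoding lives in `CloudLanguageModel`, shown later in this commit):

use futures::{io::BufReader, AsyncBufReadExt, AsyncRead, StreamExt};

// Sketch: read the endpoint's newline-delimited JSON, one event per line.
async fn read_events<R>(body: R) -> anyhow::Result<()>
where
    R: AsyncRead + Unpin,
{
    let mut lines = BufReader::new(body).lines();
    while let Some(line) = lines.next().await {
        let line = line?;
        let event: serde_json::Value = serde_json::from_str(&line)?;
        println!("{event}");
    }
    Ok(())
}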
@@ -10,7 +10,7 @@ use collections::BTreeMap;
 use feature_flags::{FeatureFlag, FeatureFlagAppExt};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
-use http_client::{HttpClient, Method};
+use http_client::{AsyncBody, HttpClient, Method, Response};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::value::RawValue;
@@ -239,6 +239,47 @@ pub struct CloudLanguageModel {
 #[derive(Clone, Default)]
 struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
+impl CloudLanguageModel {
+    async fn perform_llm_completion(
+        client: Arc<Client>,
+        llm_api_token: LlmApiToken,
+        body: PerformCompletionParams,
+    ) -> Result<Response<AsyncBody>> {
+        let http_client = &client.http_client();
+
+        let mut token = llm_api_token.acquire(&client).await?;
+        let mut did_retry = false;
+
+        let response = loop {
+            let request = http_client::Request::builder()
+                .method(Method::POST)
+                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
+                .header("Content-Type", "application/json")
+                .header("Authorization", format!("Bearer {token}"))
+                .body(serde_json::to_string(&body)?.into())?;
+            let response = http_client.send(request).await?;
+            if response.status().is_success() {
+                break response;
+            } else if !did_retry
+                && response
+                    .headers()
+                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                    .is_some()
+            {
+                did_retry = true;
+                token = llm_api_token.refresh(&client).await?;
+            } else {
+                break Err(anyhow!(
+                    "cloud language model completion failed with status {}",
+                    response.status()
+                ))?;
+            }
+        };
+
+        Ok(response)
+    }
+}
+
 impl LanguageModel for CloudLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
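`perform_llm_completion` relies on `LlmApiToken::acquire` and `LlmApiToken::refresh`, which this diff does not show. Assuming the token is an async read-through cache over a token-issuing RPC, they might look roughly like this (`fetch_new_llm_token` is a hypothetical stand-in for that RPC, and an async-aware `RwLock` is assumed):

impl LlmApiToken {
    // Sketch: return the cached token, or fetch one if none is cached yet.
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        if let Some(token) = self.0.read().await.as_ref() {
            return Ok(token.clone());
        }
        self.refresh(client).await
    }

    // Sketch: unconditionally fetch a fresh token and cache it, as the
    // expired-token retry path in perform_llm_completion requires.
    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        let token = fetch_new_llm_token(client).await?; // hypothetical RPC helper
        *self.0.write().await = Some(token.clone());
        Ok(token)
    }
}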
@@ -314,46 +355,21 @@ impl LanguageModel for CloudLanguageModel {
             .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
             .unwrap_or(false)
         {
-            let http_client = self.client.http_client();
             let llm_api_token = self.llm_api_token.clone();
             let future = self.request_limiter.stream(async move {
-                let request = serde_json::to_string(&request)?;
-                let mut token = llm_api_token.acquire(&client).await?;
-                let mut did_retry = false;
-
-                let response = loop {
-                    let request = http_client::Request::builder()
-                        .method(Method::POST)
-                        .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
-                        .header("Content-Type", "application/json")
-                        .header("Authorization", format!("Bearer {token}"))
-                        .body(
-                            serde_json::to_string(&PerformCompletionParams {
-                                provider_request: RawValue::from_string(request.clone())?,
-                            })?
-                            .into(),
-                        )?;
-                    let response = http_client.send(request).await?;
-                    if response.status().is_success() {
-                        break response;
-                    } else if !did_retry
-                        && response
-                            .headers()
-                            .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                            .is_some()
-                    {
-                        did_retry = true;
-                        token = llm_api_token.refresh(&client).await?;
-                    } else {
-                        break Err(anyhow!(
-                            "cloud language model completion failed with status {}",
-                            response.status()
-                        ))?;
-                    }
-                };
-
+                let response = Self::perform_llm_completion(
+                    client.clone(),
+                    llm_api_token,
+                    PerformCompletionParams {
+                        provider: client::LanguageModelProvider::Anthropic,
+                        model: request.model.clone(),
+                        provider_request: RawValue::from_string(serde_json::to_string(
+                            &request,
+                        )?)?,
+                    },
+                )
+                .await?;
                 let body = BufReader::new(response.into_body());
                 let stream =
                     futures::stream::try_unfold(body, move |mut body| async move {
                         let mut buffer = String::new();
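Note how the refactored Anthropic branch serializes the request once and wraps it in `serde_json::value::RawValue`: the raw JSON is spliced verbatim into the outer `PerformCompletionParams` body rather than being re-parsed and re-encoded. A self-contained illustration of that behavior (the `Envelope` type is hypothetical; `RawValue` requires serde_json's `raw_value` feature):

use serde::Serialize;
use serde_json::value::RawValue;

#[derive(Serialize)]
struct Envelope {
    provider: &'static str,
    payload: Box<RawValue>, // spliced into the output verbatim
}

fn main() -> anyhow::Result<()> {
    let payload = RawValue::from_string(r#"{"max_tokens":1024}"#.to_string())?;
    let envelope = Envelope { provider: "anthropic", payload };
    // Prints: {"provider":"anthropic","payload":{"max_tokens":1024}}
    println!("{}", serde_json::to_string(&envelope)?);
    Ok(())
}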
@@ -389,54 +405,171 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::OpenAi as i32,
-                            request,
-                        })
-                        .await?;
-                    Ok(open_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: open_ai::ResponseStreamEvent =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(open_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::OpenAi as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(open_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Google as i32,
-                            request,
-                        })
-                        .await?;
-                    Ok(google_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Google,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: google_ai::GenerateContentResponse =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(google_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Google as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(google_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
                 let mut request = request.into_open_ai(model.id().into());
                 request.max_tokens = Some(4000);
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Zed as i32,
-                            request,
-                        })
-                        .await?;
-                    Ok(open_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Zed,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: open_ai::ResponseStreamEvent =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(open_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Zed as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(open_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
         }
     }
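Each feature-flagged branch above repeats the same `try_unfold` line-decoding loop, with only the event type changing between providers. A generic helper along these lines could factor that out (a sketch, not part of this diff; `decode_ndjson` is a hypothetical name):

use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncRead};
use serde::de::DeserializeOwned;

// Sketch: turn a newline-delimited JSON body into a typed event stream.
fn decode_ndjson<T, R>(body: R) -> BoxStream<'static, anyhow::Result<T>>
where
    T: DeserializeOwned + Send + 'static,
    R: AsyncRead + Send + Unpin + 'static,
{
    let body = BufReader::new(body);
    Box::pin(futures::stream::try_unfold(body, |mut body| async move {
        let mut buffer = String::new();
        match body.read_line(&mut buffer).await {
            Ok(0) => Ok(None), // end of stream
            Ok(_) => Ok(Some((serde_json::from_str::<T>(&buffer)?, body))),
            Err(e) => Err(e.into()),
        }
    }))
}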
@@ -2,7 +2,18 @@ use serde::{Deserialize, Serialize};
 
 pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum LanguageModelProvider {
+    Anthropic,
+    OpenAi,
+    Google,
+    Zed,
+}
+
 #[derive(Serialize, Deserialize)]
 pub struct PerformCompletionParams {
+    pub provider: LanguageModelProvider,
+    pub model: String,
     pub provider_request: Box<serde_json::value::RawValue>,
 }
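Because of `#[serde(rename_all = "snake_case")]`, the provider tags travel as "anthropic", "open_ai", "google", and "zed" on the wire. A quick check against the enum defined above:

// Sketch: verify the snake_case wire form of the new provider enum.
fn main() -> serde_json::Result<()> {
    use rpc::LanguageModelProvider;
    assert_eq!(serde_json::to_string(&LanguageModelProvider::OpenAi)?, r#""open_ai""#);
    assert_eq!(serde_json::to_string(&LanguageModelProvider::Anthropic)?, r#""anthropic""#);
    Ok(())
}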