Introduce a separate backend service for LLM calls (#15831)
This PR introduces a separate backend service for making LLM calls. It exposes an HTTP interface that can be called by Zed clients. To call these endpoints, the client must provide a `Bearer` token. These tokens are issued/refreshed by the collab service over RPC.

We're adding this in a backwards-compatible way. Right now the access tokens can only be minted for Zed staff, and calling this separate LLM service is behind the `llm-service` feature flag (which is not automatically enabled for Zed staff).

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Parent: 4ed43e6e6f · Commit: 8e9c2b1125
20 changed files with 478 additions and 102 deletions
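For orientation, here is a minimal sketch of the call flow the description outlines: POST a provider request to the LLM service with a `Bearer` token, and refresh the token once if the service signals that it has expired. This is illustrative only and not part of the diff; `reqwest` stands in for Zed's internal HTTP client, and the base URL, header name, and `fetch_llm_token` helper are placeholder assumptions (the real client-side code is in the hunks below).

```rust
// Illustrative sketch only: `reqwest`, the base URL, the header name, and
// `fetch_llm_token` are stand-ins, not Zed's actual crates or constants.
use anyhow::{anyhow, Result};

const LLM_SERVICE_URL: &str = "https://llm.example.com"; // hypothetical base URL
const EXPIRED_LLM_TOKEN_HEADER: &str = "x-llm-token-expired"; // hypothetical header name

/// Stand-in for the token exchange; in Zed the token is issued/refreshed by
/// the collab service over RPC (`proto::GetLlmToken` in the diff below).
async fn fetch_llm_token() -> Result<String> {
    Ok("example-token".into())
}

/// POST a provider-specific request to the LLM service, refreshing the Bearer
/// token and retrying exactly once if the service reports it has expired.
async fn perform_completion(provider_request: serde_json::Value) -> Result<reqwest::Response> {
    let http = reqwest::Client::new();
    let mut token = fetch_llm_token().await?;
    let mut did_retry = false;

    loop {
        let response = http
            .post(format!("{}/completion", LLM_SERVICE_URL))
            .bearer_auth(&token)
            .json(&serde_json::json!({ "provider_request": provider_request }))
            .send()
            .await?;

        if response.status().is_success() {
            return Ok(response);
        } else if !did_retry && response.headers().contains_key(EXPIRED_LLM_TOKEN_HEADER) {
            // The service flags an expired token via a response header; mint a
            // fresh token and retry once before giving up.
            did_retry = true;
            token = fetch_llm_token().await?;
        } else {
            return Err(anyhow!("completion failed with status {}", response.status()));
        }
    }
}
```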
@@ -5,13 +5,20 @@ use crate::{
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anyhow::{anyhow, Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt};
+use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
+use http_client::{HttpClient, Method};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
+use serde_json::value::RawValue;
 use settings::{Settings, SettingsStore};
+use smol::{
+    io::BufReader,
+    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
+};
 use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
@@ -46,6 +53,7 @@ pub struct AvailableModel {
 
 pub struct CloudLanguageModelProvider {
     client: Arc<Client>,
+    llm_api_token: LlmApiToken,
     state: gpui::Model<State>,
     _maintain_client_status: Task<()>,
 }
@@ -104,6 +112,7 @@ impl CloudLanguageModelProvider {
         Self {
             client,
             state,
+            llm_api_token: LlmApiToken::default(),
             _maintain_client_status: maintain_client_status,
         }
     }
@@ -181,6 +190,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             Arc::new(CloudLanguageModel {
                 id: LanguageModelId::from(model.id().to_string()),
                 model,
+                llm_api_token: self.llm_api_token.clone(),
                 client: self.client.clone(),
                 request_limiter: RateLimiter::new(4),
             }) as Arc<dyn LanguageModel>
@@ -208,13 +218,27 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     }
 }
 
+struct LlmServiceFeatureFlag;
+
+impl FeatureFlag for LlmServiceFeatureFlag {
+    const NAME: &'static str = "llm-service";
+
+    fn enabled_for_staff() -> bool {
+        false
+    }
+}
+
 pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
+    llm_api_token: LlmApiToken,
     client: Arc<Client>,
     request_limiter: RateLimiter,
 }
 
+#[derive(Clone, Default)]
+struct LlmApiToken(Arc<RwLock<Option<String>>>);
+
 impl LanguageModel for CloudLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@ -279,25 +303,88 @@ impl LanguageModel for CloudLanguageModel {
|
|||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_: &AsyncAppContext,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_anthropic(model.id().into());
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
Ok(anthropic::extract_text_from_events(
|
||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||
))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
let client = self.client.clone();
|
||||
|
||||
if cx
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let http_client = self.client.http_client();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut token = llm_api_token.acquire(&client).await?;
|
||||
let mut did_retry = false;
|
||||
|
||||
let response = loop {
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(
|
||||
serde_json::to_string(&PerformCompletionParams {
|
||||
provider_request: RawValue::from_string(request.clone())?,
|
||||
})?
|
||||
.into(),
|
||||
)?;
|
||||
let response = http_client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
break response;
|
||||
} else if !did_retry
|
||||
&& response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
did_retry = true;
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
} else {
|
||||
break Err(anyhow!(
|
||||
"cloud language model completion failed with status {}",
|
||||
response.status()
|
||||
))?;
|
||||
}
|
||||
};
|
||||
|
||||
let body = BufReader::new(response.into_body());
|
||||
|
||||
let stream =
|
||||
futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: anthropic::Event =
|
||||
serde_json::from_str(&buffer)?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(anthropic::extract_text_from_events(stream))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
} else {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
request,
|
||||
})
|
||||
.await?
|
||||
.map(|event| Ok(serde_json::from_str(&event?.event)?));
|
||||
Ok(anthropic::extract_text_from_events(stream))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let client = self.client.clone();
|
||||
|
@ -417,6 +504,30 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
impl LlmApiToken {
|
||||
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
|
||||
let lock = self.0.upgradable_read().await;
|
||||
if let Some(token) = lock.as_ref() {
|
||||
Ok(token.to_string())
|
||||
} else {
|
||||
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
|
||||
Self::fetch(self.0.write().await, &client).await
|
||||
}
|
||||
|
||||
async fn fetch<'a>(
|
||||
mut lock: RwLockWriteGuard<'a, Option<String>>,
|
||||
client: &Arc<Client>,
|
||||
) -> Result<String> {
|
||||
let response = client.request(proto::GetLlmToken {}).await?;
|
||||
*lock = Some(response.token.clone());
|
||||
Ok(response.token.clone())
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
state: gpui::Model<State>,
|
||||
}
|
||||
|
|