Introduce a separate backend service for LLM calls (#15831)

This PR introduces a separate backend service for making LLM calls.

It exposes an HTTP interface that can be called by Zed clients. To call
these endpoints, the client must provide a `Bearer` token. These tokens
are issued/refreshed by the collab service over RPC.
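
For illustration, here's a rough sketch of what a call to the new service's `/completion` endpoint looks like from the client side. The `/completion` path, the `Authorization: Bearer` header, and the `PerformCompletionParams` body shape come from the diff below; the host URL and the use of `reqwest` (standing in for Zed's internal `http_client` crate) are placeholders.

```rust
// Hypothetical client-side sketch of calling the LLM service's completion endpoint.
use anyhow::Result;
use serde::Serialize;
use serde_json::value::RawValue;

#[derive(Serialize)]
struct PerformCompletionParams {
    // The provider-specific request (e.g. an Anthropic request) is passed
    // through to the LLM service as raw JSON.
    provider_request: Box<RawValue>,
}

async fn perform_completion(
    llm_api_token: &str,
    provider_request: String,
) -> Result<reqwest::Response> {
    let body = serde_json::to_string(&PerformCompletionParams {
        provider_request: RawValue::from_string(provider_request)?,
    })?;

    let response = reqwest::Client::new()
        .post("https://llm.zed.dev/completion") // placeholder URL
        .header("Content-Type", "application/json")
        // The access token minted by the collab service authorizes the request.
        .header("Authorization", format!("Bearer {llm_api_token}"))
        .body(body)
        .send()
        .await?;

    anyhow::ensure!(
        response.status().is_success(),
        "completion request failed with status {}",
        response.status()
    );
    Ok(response)
}
```

In the actual client code (see the diff below), the token is cached behind an async `RwLock` and the request is retried once with a freshly fetched token when the response carries the expired-token header.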

We're adding this in a backwards-compatible way. Right now the access
tokens can only be minted for Zed staff, and calling this separate LLM
service is behind the `llm-service` feature flag (which is not
automatically enabled for Zed staff).
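
As a rough, self-contained sketch of that gating (the real code lives in Zed's `feature_flags` crate and is checked via `cx.has_flag::<LlmServiceFeatureFlag>()`; the `use_llm_service` helper and its inputs are invented here for illustration):

```rust
// Sketch of the `llm-service` flag. Zed feature flags are enabled for staff
// by default, so this flag explicitly opts out until the service is ready.
trait FeatureFlag {
    const NAME: &'static str;

    fn enabled_for_staff() -> bool {
        true
    }
}

struct LlmServiceFeatureFlag;

impl FeatureFlag for LlmServiceFeatureFlag {
    const NAME: &'static str = "llm-service";

    fn enabled_for_staff() -> bool {
        false
    }
}

// Hypothetical helper: decide whether a completion goes to the new LLM
// service or falls back to the existing RPC path through collab.
fn use_llm_service(enabled_flags: &[&str], is_staff: bool) -> bool {
    enabled_flags.contains(&LlmServiceFeatureFlag::NAME)
        || (is_staff && LlmServiceFeatureFlag::enabled_for_staff())
}

fn main() {
    // Staff don't get the flag automatically, so they stay on the RPC path.
    assert!(!use_llm_service(&[], true));
    // Only explicitly enabling "llm-service" routes completions to the new service.
    assert!(use_llm_service(&["llm-service"], false));
}
```

When the flag is off, `stream_completion` keeps using the existing `StreamCompleteWithLanguageModel` RPC message, as shown in the `else` branch of the diff below.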

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Commit 8e9c2b1125 (parent 4ed43e6e6f) by Max Brunsfeld, 2024-08-05 17:26:21 -07:00
20 changed files with 478 additions and 102 deletions

@@ -5,13 +5,20 @@ use crate::{
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, Context as _, Result};
use client::{Client, UserStore};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
use http_client::{HttpClient, Method};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
io::BufReader,
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@@ -46,6 +53,7 @@ pub struct AvailableModel {
pub struct CloudLanguageModelProvider {
client: Arc<Client>,
llm_api_token: LlmApiToken,
state: gpui::Model<State>,
_maintain_client_status: Task<()>,
}
@@ -104,6 +112,7 @@ impl CloudLanguageModelProvider {
Self {
client,
state,
llm_api_token: LlmApiToken::default(),
_maintain_client_status: maintain_client_status,
}
}
@@ -181,6 +190,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
llm_api_token: self.llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
@@ -208,13 +218,27 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
}
struct LlmServiceFeatureFlag;
impl FeatureFlag for LlmServiceFeatureFlag {
const NAME: &'static str = "llm-service";
fn enabled_for_staff() -> bool {
false
}
}
pub struct CloudLanguageModel {
id: LanguageModelId,
model: CloudModel,
llm_api_token: LlmApiToken,
client: Arc<Client>,
request_limiter: RateLimiter,
}
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LanguageModel for CloudLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -279,25 +303,88 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
_: &AsyncAppContext,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let request = request.into_anthropic(model.id().into());
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
Ok(anthropic::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
))
});
async move { Ok(future.await?.boxed()) }.boxed()
let client = self.client.clone();
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let http_client = self.client.http_client();
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let mut token = llm_api_token.acquire(&client).await?;
let mut did_retry = false;
let response = loop {
let request = http_client::Request::builder()
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(
serde_json::to_string(&PerformCompletionParams {
provider_request: RawValue::from_string(request.clone())?,
})?
.into(),
)?;
let response = http_client.send(request).await?;
if response.status().is_success() {
break response;
} else if !did_retry
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
did_retry = true;
token = llm_api_token.refresh(&client).await?;
} else {
break Err(anyhow!(
"cloud language model completion failed with status {}",
response.status()
))?;
}
};
let body = BufReader::new(response.into_body());
let stream =
futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: anthropic::Event =
serde_json::from_str(&buffer)?;
Ok(Some((event, body)))
}
Err(e) => Err(e.into()),
}
});
Ok(anthropic::extract_text_from_events(stream))
});
async move { Ok(future.await?.boxed()) }.boxed()
} else {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?
.map(|event| Ok(serde_json::from_str(&event?.event)?));
Ok(anthropic::extract_text_from_events(stream))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
@@ -417,6 +504,30 @@ impl LanguageModel for CloudLanguageModel {
}
}
impl LlmApiToken {
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
} else {
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
}
}
async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
Self::fetch(self.0.write().await, &client).await
}
async fn fetch<'a>(
mut lock: RwLockWriteGuard<'a, Option<String>>,
client: &Arc<Client>,
) -> Result<String> {
let response = client.request(proto::GetLlmToken {}).await?;
*lock = Some(response.token.clone());
Ok(response.token.clone())
}
}
struct ConfigurationView {
state: gpui::Model<State>,
}
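
To make the token handling above easier to follow, here is a small self-contained sketch of the same caching pattern. The lock/refresh structure mirrors `LlmApiToken` from the diff; the `fetch_token_from_server` stub and the `main` function are hypothetical stand-ins for the `GetLlmToken` collab RPC and the surrounding application.

```rust
// Sketch of the LlmApiToken caching pattern: readers share the cached token,
// and a writer fetches a new one only when the cache is empty or stale.
use std::sync::Arc;

use anyhow::Result;
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};

#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl LlmApiToken {
    // Returns the cached token, fetching one only if none is cached yet.
    async fn acquire(&self) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.clone())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await).await
        }
    }

    // Unconditionally replaces the cached token, e.g. after the service
    // reports that the current one has expired.
    async fn refresh(&self) -> Result<String> {
        Self::fetch(self.0.write().await).await
    }

    async fn fetch(mut lock: RwLockWriteGuard<'_, Option<String>>) -> Result<String> {
        let token = fetch_token_from_server().await?;
        *lock = Some(token.clone());
        Ok(token)
    }
}

// Hypothetical stub standing in for `client.request(proto::GetLlmToken {})`.
async fn fetch_token_from_server() -> Result<String> {
    Ok("example-token".to_string())
}

fn main() -> Result<()> {
    smol::block_on(async {
        let token = LlmApiToken::default();
        assert_eq!(token.acquire().await?, "example-token"); // fetches once
        assert_eq!(token.acquire().await?, "example-token"); // served from cache
        token.refresh().await?; // forces a new fetch
        Ok(())
    })
}
```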