Introduce a separate backend service for LLM calls (#15831)

This PR introduces a separate backend service for making LLM calls.

It exposes an HTTP interface that can be called by Zed clients. To call
these endpoints, the client must provide a `Bearer` token. These tokens
are issued/refreshed by the collab service over RPC.

We're adding this in a backwards-compatible way. Right now the access
tokens can only be minted for Zed staff, and calling this separate LLM
service is behind the `llm-service` feature flag (which is not
automatically enabled for Zed staff).

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Max Brunsfeld 2024-08-05 17:26:21 -07:00 committed by GitHub
parent 4ed43e6e6f
commit 8e9c2b1125
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 478 additions and 102 deletions

View file

@ -81,14 +81,14 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?
.strip_prefix("token ")
.ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"invalid authorization header".to_string(),
)
@ -97,7 +97,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
let state = req.extensions().get::<Arc<AppState>>().unwrap();
if token != state.config.api_token {
Err(Error::Http(
Err(Error::http(
StatusCode::UNAUTHORIZED,
"invalid authorization token".to_string(),
))?
@ -185,13 +185,13 @@ async fn create_access_token(
if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
impersonated_user_id = Some(impersonated_user.id);
} else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::UNPROCESSABLE_ENTITY,
format!("user {impersonate} does not exist"),
));
}
} else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::UNAUTHORIZED,
"you do not have permission to impersonate other users".to_string(),
));

View file

@ -120,7 +120,7 @@ async fn create_billing_subscription(
.zip(app.config.stripe_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
@ -201,7 +201,7 @@ async fn manage_billing_subscription(
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?

View file

@ -206,14 +206,14 @@ pub async fn post_hang(
body: Bytes,
) -> Result<()> {
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
};
if checksum != expected {
return Err(Error::Http(
return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid checksum".into(),
))?;
@ -265,25 +265,25 @@ pub async fn post_panic(
body: Bytes,
) -> Result<()> {
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
};
if checksum != expected {
return Err(Error::Http(
return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid checksum".into(),
))?;
}
let report: telemetry_events::PanicRequest = serde_json::from_slice(&body)
.map_err(|_| Error::Http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
.map_err(|_| Error::http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
let panic = report.panic;
if panic.os_name == "Linux" && panic.os_version == Some("1.0.0".to_string()) {
return Err(Error::Http(
return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid os version".into(),
))?;
@ -362,14 +362,14 @@ pub async fn post_events(
body: Bytes,
) -> Result<()> {
let Some(clickhouse_client) = app.clickhouse_client.clone() else {
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
@ -385,7 +385,7 @@ pub async fn post_events(
let mut to_upload = ToUpload::default();
let Some(last_event) = request_body.events.last() else {
return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
return Err(Error::http(StatusCode::BAD_REQUEST, "no events".into()))?;
};
let country_code = country_code_header.map(|h| h.to_string());

View file

@ -185,7 +185,7 @@ async fn download_extension(
.clone()
.zip(app.config.blob_store_bucket.clone())
else {
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
@ -202,7 +202,7 @@ async fn download_extension(
.await?;
if !version_exists {
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_FOUND,
"unknown extension version".into(),
))?;

View file

@ -33,7 +33,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::UNAUTHORIZED,
"missing authorization header".to_string(),
)
@ -45,14 +45,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
let first = auth_header.next().unwrap_or("");
if first == "dev-server-token" {
let dev_server_token = auth_header.next().ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"missing dev-server-token token in authorization header".to_string(),
)
})?;
let dev_server = verify_dev_server_token(dev_server_token, &state.db)
.await
.map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
.map_err(|e| Error::http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
req.extensions_mut()
.insert(Principal::DevServer(dev_server));
@ -60,14 +60,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
}
let user_id = UserId(first.parse().map_err(|_| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"missing user id in authorization header".to_string(),
)
})?);
let access_token = auth_header.next().ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"missing access token in authorization header".to_string(),
)
@ -111,7 +111,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
}
}
Err(Error::Http(
Err(Error::http(
StatusCode::UNAUTHORIZED,
"invalid credentials".to_string(),
))

View file

@ -13,7 +13,10 @@ mod tests;
use anyhow::anyhow;
use aws_config::{BehaviorVersion, Region};
use axum::{http::StatusCode, response::IntoResponse};
use axum::{
http::{HeaderMap, StatusCode},
response::IntoResponse,
};
use db::{ChannelId, Database};
use executor::Executor;
pub use rate_limiter::*;
@ -24,7 +27,7 @@ use util::ResultExt;
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
Http(StatusCode, String),
Http(StatusCode, String, HeaderMap),
Database(sea_orm::error::DbErr),
Internal(anyhow::Error),
Stripe(stripe::StripeError),
@ -66,12 +69,18 @@ impl From<serde_json::Error> for Error {
}
}
impl Error {
    /// Builds an [`Error::Http`] carrying `code` and `message` and an empty
    /// set of response headers.
    ///
    /// Construct `Error::Http` directly when the response needs headers.
    fn http(code: StatusCode, message: String) -> Self {
        let headers = HeaderMap::new();
        Self::Http(code, message, headers)
    }
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
match self {
Error::Http(code, message) => {
Error::Http(code, message, headers) => {
log::error!("HTTP error {}: {}", code, &message);
(code, message).into_response()
(code, headers, message).into_response()
}
Error::Database(error) => {
log::error!(
@ -104,7 +113,7 @@ impl IntoResponse for Error {
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Http(code, message) => (code, message).fmt(f),
Error::Http(code, message, _headers) => (code, message).fmt(f),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
@ -115,7 +124,7 @@ impl std::fmt::Debug for Error {
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Http(code, message) => write!(f, "{code}: {message}"),
Error::Http(code, message, _) => write!(f, "{code}: {message}"),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
@ -141,6 +150,7 @@ pub struct Config {
pub live_kit_server: Option<String>,
pub live_kit_key: Option<String>,
pub live_kit_secret: Option<String>,
pub llm_api_secret: Option<String>,
pub rust_log: Option<String>,
pub log_json: Option<bool>,
pub blob_store_url: Option<String>,

View file

@ -1,16 +1,122 @@
mod token;
use crate::{executor::Executor, Config, Error, Result};
use anyhow::Context as _;
use axum::{
body::Body,
http::{self, HeaderName, HeaderValue, Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::post,
Extension, Json, Router,
};
use futures::StreamExt as _;
use http_client::IsahcHttpClient;
use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;
use crate::{executor::Executor, Config, Result};
pub use token::*;
/// State for the LLM service.
///
/// Built by [`LlmState::new`] and handed out as an `Arc<LlmState>`; the
/// request handlers in this module read it from an `Extension`.
pub struct LlmState {
/// Server configuration; read here for `llm_api_secret` and
/// `anthropic_api_key`.
pub config: Config,
/// Executor passed in by the caller of `new`; not used in this module.
pub executor: Executor,
/// HTTP client passed to `anthropic::stream_completion` for upstream calls.
pub http_client: IsahcHttpClient,
}
impl LlmState {
    /// Creates the shared LLM service state.
    ///
    /// Builds an HTTP client whose default `User-Agent` header is
    /// `Zed Server/<crate version>` and bundles it with `config` and
    /// `executor`.
    ///
    /// # Errors
    ///
    /// Returns an error with the context "failed to construct http client"
    /// if the client builder fails.
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        // `config` and `executor` are moved exactly once, into the final
        // value that also includes `http_client`.
        Ok(Arc::new(Self {
            config,
            executor,
            http_client,
        }))
    }
}
/// Returns the router for the LLM service.
///
/// It exposes `POST /completion`, wrapped in the `validate_api_token`
/// middleware.
pub fn routes() -> Router<(), Body> {
    let router = Router::new().route("/completion", post(perform_completion));
    router.layer(middleware::from_fn(validate_api_token))
}
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let token = req
.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?
.strip_prefix("Bearer ")
.ok_or_else(|| {
Error::http(
StatusCode::BAD_REQUEST,
"invalid authorization header".to_string(),
)
})?;
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
match LlmTokenClaims::validate(&token, &state.config) {
Ok(claims) => {
req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response())
}
Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
[(
HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
)),
Err(_err) => Err(Error::http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
)),
}
}
/// Handles `POST /completion` by streaming a completion from Anthropic.
///
/// The `provider_request` JSON in the request body is deserialized and passed
/// to `anthropic::stream_completion`. Each successful event is written to the
/// response body as one JSON document followed by a newline.
///
/// # Errors
///
/// Fails if no Anthropic API key is configured, if `provider_request` cannot
/// be deserialized, or if starting the upstream stream fails.
///
/// # Panics
///
/// Panics if an event cannot be serialized to JSON.
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(_claims): Extension<LlmTokenClaims>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let api_key = state
        .config
        .anthropic_api_key
        .as_ref()
        .context("no Anthropic AI API key configured on the server")?;

    let events = anthropic::stream_completion(
        &state.http_client,
        anthropic::ANTHROPIC_API_URL,
        api_key,
        serde_json::from_str(&params.provider_request.get())?,
        None,
    )
    .await?;

    // One newline-terminated JSON document per event.
    let body = Body::wrap_stream(events.map(|event| {
        event.map(|chunk| {
            let mut line = serde_json::to_vec(&chunk).unwrap();
            line.push(b'\n');
            line
        })
    }));

    Ok(Response::new(body))
}

View file

@ -0,0 +1,75 @@
use crate::{db::UserId, Config};
use anyhow::{anyhow, Result};
use chrono::Utc;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
/// Claims stored in the signed token produced by `LlmTokenClaims::create` and
/// recovered by `LlmTokenClaims::validate`.
///
/// Field names are serialized in camelCase, so `user_id` is written as
/// `userId`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
/// Issued-at time; `create` sets it from `Utc::now().timestamp()`.
pub iat: u64,
/// Expiry time; `create` sets it to the issue time plus `LLM_TOKEN_LIFETIME`.
pub exp: u64,
/// Token identifier; `create` fills it with a freshly generated v4 UUID string.
pub jti: String,
/// The user the token was issued for, taken from `UserId::to_proto`.
pub user_id: u64,
/// The plan value supplied to `create`.
pub plan: rpc::proto::Plan,
}
/// Validity period of a newly created LLM token: 60 * 60 seconds (one hour).
/// `LlmTokenClaims::create` adds it to the issue time to compute `exp`.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
impl LlmTokenClaims {
pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
let secret = config
.llm_api_secret
.as_ref()
.ok_or_else(|| anyhow!("no LLM API secret"))?;
let now = Utc::now();
let claims = Self {
iat: now.timestamp() as u64,
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user_id.to_proto(),
plan,
};
Ok(jsonwebtoken::encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_ref()),
)?)
}
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
let secret = config
.llm_api_secret
.as_ref()
.ok_or_else(|| anyhow!("no LLM API secret"))?;
match jsonwebtoken::decode::<Self>(
token,
&DecodingKey::from_secret(secret.as_ref()),
&Validation::default(),
) {
Ok(token) => Ok(token.claims),
Err(e) => {
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
Err(ValidateLlmTokenError::Expired)
} else {
Err(ValidateLlmTokenError::JwtError(e))
}
}
}
}
}
/// Reasons `LlmTokenClaims::validate` can reject a token.
#[derive(Error, Debug)]
pub enum ValidateLlmTokenError {
/// Decoding failed with `ErrorKind::ExpiredSignature`.
#[error("access token is expired")]
Expired,
/// Decoding failed for any reason other than expiry.
#[error("access token validation error: {0}")]
JwtError(#[from] jsonwebtoken::errors::Error),
/// A non-JWT failure; `validate` produces it when no `llm_api_secret` is
/// configured.
#[error("{0}")]
Other(#[from] anyhow::Error),
}

View file

@ -83,7 +83,9 @@ async fn main() -> Result<()> {
if mode.is_llm() {
let state = LlmState::new(config.clone(), Executor::Production).await?;
app = app.layer(Extension(state.clone()));
app = app
.merge(collab::llm::routes())
.layer(Extension(state.clone()));
}
if mode.is_collab() || mode.is_api() {

View file

@ -1,6 +1,7 @@
mod connection_pool;
use crate::api::CloudflareIpCountryHeader;
use crate::llm::LlmTokenClaims;
use crate::{
auth,
db::{
@ -11,7 +12,7 @@ use crate::{
ServerId, UpdatedChannelMessage, User, UserId,
},
executor::Executor,
AppState, Config, Error, RateLimit, RateLimiter, Result,
AppState, Config, Error, RateLimit, Result,
};
use anyhow::{anyhow, bail, Context as _};
use async_tungstenite::tungstenite::{
@ -149,10 +150,9 @@ struct Session {
db: Arc<tokio::sync::Mutex<DbHandle>>,
peer: Arc<Peer>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
app_state: Arc<AppState>,
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<IsahcHttpClient>,
rate_limiter: Arc<RateLimiter>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
@ -615,6 +615,7 @@ impl Server {
.add_message_handler(user_message_handler(unfollow))
.add_message_handler(user_message_handler(update_followers))
.add_request_handler(user_handler(get_private_user_info))
.add_request_handler(user_handler(get_llm_api_token))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
@ -1046,9 +1047,8 @@ impl Server {
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
live_kit_client: this.app_state.live_kit_client.clone(),
app_state: this.app_state.clone(),
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
geoip_country_code,
_executor: executor.clone(),
supermaven_client,
@ -1559,7 +1559,7 @@ async fn create_room(
let live_kit_room = nanoid::nanoid!(30);
let live_kit_connection_info = util::maybe!(async {
let live_kit = session.live_kit_client.as_ref();
let live_kit = session.app_state.live_kit_client.as_ref();
let live_kit = live_kit?;
let user_id = session.user_id().to_string();
@ -1630,25 +1630,26 @@ async fn join_room(
.trace_err();
}
let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(token) = live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()
{
Some(proto::LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish: true,
})
let live_kit_connection_info =
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
if let Some(token) = live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()
{
Some(proto::LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish: true,
})
} else {
None
}
} else {
None
}
} else {
None
};
};
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room),
@ -1877,7 +1878,7 @@ async fn set_room_participant_role(
(live_kit_room, can_publish)
};
if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
live_kit
.update_participant(
live_kit_room.clone(),
@ -4048,35 +4049,40 @@ async fn join_channel_internal(
.join_channel(channel_id, session.user_id(), session.connection_id)
.await?;
let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
let (can_publish, token) = if role == ChannelRole::Guest {
(
false,
live_kit
.guest_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
let live_kit_connection_info =
session
.app_state
.live_kit_client
.as_ref()
.and_then(|live_kit| {
let (can_publish, token) = if role == ChannelRole::Guest {
(
false,
live_kit
.guest_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()?,
)
.trace_err()?,
)
} else {
(
true,
live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
} else {
(
true,
live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()?,
)
.trace_err()?,
)
};
};
Some(LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish,
})
});
Some(LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish,
})
});
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room.clone()),
@ -4610,6 +4616,7 @@ async fn complete_with_language_model(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4655,6 +4662,7 @@ async fn stream_complete_with_language_model(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4766,6 +4774,7 @@ async fn count_language_model_tokens(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4885,6 +4894,7 @@ async fn compute_embeddings(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -5143,6 +5153,24 @@ async fn get_private_user_info(
Ok(())
}
/// RPC handler that issues an LLM API token to the requesting user.
///
/// Only staff sessions are served; any other session gets a
/// "permission denied" error. The token is created from the session's user id
/// and current plan using the server configuration, then sent back in a
/// `GetLlmTokenResponse`.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: UserSession,
) -> Result<()> {
    if !session.is_staff() {
        return Err(anyhow!("permission denied").into());
    }

    let user_id = session.user_id();
    let plan = session.current_plan().await?;
    let token = LlmTokenClaims::create(user_id, plan, &session.app_state.config)?;

    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
let message = match message {
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
@ -5486,7 +5514,7 @@ async fn leave_room_for_session(session: &UserSession, connection_id: Connection
update_user_contacts(contact_user_id, &session).await?;
}
if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
live_kit
.remove_participant(live_kit_room.clone(), session.user_id().to_string())
.await

View file

@ -651,6 +651,7 @@ impl TestServer {
live_kit_server: None,
live_kit_key: None,
live_kit_secret: None,
llm_api_secret: None,
rust_log: None,
log_json: None,
zed_environment: "test".into(),