Introduce a separate backend service for LLM calls (#15831)

This PR introduces a separate backend service for making LLM calls.

It exposes an HTTP interface that can be called by Zed clients. To call
these endpoints, the client must provide a `Bearer` token. These tokens
are issued/refreshed by the collab service over RPC.

We're adding this in a backwards-compatible way. Right now the access
tokens can only be minted for Zed staff, and calling this separate LLM
service is behind the `llm-service` feature flag (which is not
automatically enabled for Zed staff).

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Max Brunsfeld 2024-08-05 17:26:21 -07:00 committed by GitHub
parent 4ed43e6e6f
commit 8e9c2b1125
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 478 additions and 102 deletions

View file

@ -1,6 +1,7 @@
mod connection_pool;
use crate::api::CloudflareIpCountryHeader;
use crate::llm::LlmTokenClaims;
use crate::{
auth,
db::{
@ -11,7 +12,7 @@ use crate::{
ServerId, UpdatedChannelMessage, User, UserId,
},
executor::Executor,
AppState, Config, Error, RateLimit, RateLimiter, Result,
AppState, Config, Error, RateLimit, Result,
};
use anyhow::{anyhow, bail, Context as _};
use async_tungstenite::tungstenite::{
@ -149,10 +150,9 @@ struct Session {
db: Arc<tokio::sync::Mutex<DbHandle>>,
peer: Arc<Peer>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
app_state: Arc<AppState>,
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<IsahcHttpClient>,
rate_limiter: Arc<RateLimiter>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
@ -615,6 +615,7 @@ impl Server {
.add_message_handler(user_message_handler(unfollow))
.add_message_handler(user_message_handler(update_followers))
.add_request_handler(user_handler(get_private_user_info))
.add_request_handler(user_handler(get_llm_api_token))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
@ -1046,9 +1047,8 @@ impl Server {
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
live_kit_client: this.app_state.live_kit_client.clone(),
app_state: this.app_state.clone(),
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
geoip_country_code,
_executor: executor.clone(),
supermaven_client,
@ -1559,7 +1559,7 @@ async fn create_room(
let live_kit_room = nanoid::nanoid!(30);
let live_kit_connection_info = util::maybe!(async {
let live_kit = session.live_kit_client.as_ref();
let live_kit = session.app_state.live_kit_client.as_ref();
let live_kit = live_kit?;
let user_id = session.user_id().to_string();
@ -1630,25 +1630,26 @@ async fn join_room(
.trace_err();
}
let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(token) = live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()
{
Some(proto::LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish: true,
})
let live_kit_connection_info =
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
if let Some(token) = live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()
{
Some(proto::LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish: true,
})
} else {
None
}
} else {
None
}
} else {
None
};
};
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room),
@ -1877,7 +1878,7 @@ async fn set_room_participant_role(
(live_kit_room, can_publish)
};
if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
live_kit
.update_participant(
live_kit_room.clone(),
@ -4048,35 +4049,40 @@ async fn join_channel_internal(
.join_channel(channel_id, session.user_id(), session.connection_id)
.await?;
let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
let (can_publish, token) = if role == ChannelRole::Guest {
(
false,
live_kit
.guest_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
let live_kit_connection_info =
session
.app_state
.live_kit_client
.as_ref()
.and_then(|live_kit| {
let (can_publish, token) = if role == ChannelRole::Guest {
(
false,
live_kit
.guest_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()?,
)
.trace_err()?,
)
} else {
(
true,
live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
} else {
(
true,
live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()?,
)
.trace_err()?,
)
};
};
Some(LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish,
})
});
Some(LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish,
})
});
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room.clone()),
@ -4610,6 +4616,7 @@ async fn complete_with_language_model(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4655,6 +4662,7 @@ async fn stream_complete_with_language_model(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4766,6 +4774,7 @@ async fn count_language_model_tokens(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4885,6 +4894,7 @@ async fn compute_embeddings(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -5143,6 +5153,24 @@ async fn get_private_user_info(
Ok(())
}
/// Handles a `GetLlmToken` request by issuing an LLM API token to the caller.
///
/// Sessions for which `is_staff()` is false are rejected with a
/// "permission denied" error. Otherwise the token is produced by
/// `LlmTokenClaims::create` from the session's user id, the result of
/// `current_plan()`, and the app config, and is sent back in a
/// `GetLlmTokenResponse`.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: UserSession,
) -> Result<()> {
    // Guard clause: only staff sessions may obtain a token.
    if !session.is_staff() {
        return Err(anyhow!("permission denied").into());
    }

    // Gather the inputs in the same order the claims constructor takes them.
    let user_id = session.user_id();
    let plan = session.current_plan().await?;
    let config = &session.app_state.config;

    let token = LlmTokenClaims::create(user_id, plan, config)?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
let message = match message {
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
@ -5486,7 +5514,7 @@ async fn leave_room_for_session(session: &UserSession, connection_id: Connection
update_user_contacts(contact_user_id, &session).await?;
}
if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
live_kit
.remove_participant(live_kit_room.clone(), session.user_id().to_string())
.await