Allow AI interactions to be proxied through Zed's server so you don't need an API key (#7367)

Co-authored-by: Antonio <antonio@zed.dev>

Resurrected this from some assistant work I did in Spring of 2023.
- [x] Resurrect streaming responses
- [x] Use streaming responses to enable AI via Zed's servers by default
(but preserve the API key option for now; a condensed sketch of the new
server-side flow follows this list)
- [x] Simplify protobuf
- [x] Proxy to OpenAI on zed.dev
- [x] Proxy to Gemini on zed.dev
- [x] Improve UX for switching between OpenAI and Google models
- We currently disallow cycling when setting a custom model, but we need a
better solution to keep OpenAI models available while testing the Google
ones
- [x] Show remaining tokens correctly for Google models
- [x] Remove semantic index
- [x] Delete `ai` crate
- [x] CloudFront so we can block abuse
- [x] Rate-limiting
- [x] Fix panic when using inline assistant
- [x] Double-check that the upgraded `AssistantSettings` are
backwards-compatible
- [x] Add hosted LLM interaction behind a `language-models` feature
flag.
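
As a quick orientation before the diff, here is a condensed sketch of the new
server-side flow. Everything in it is taken from the change below (the
surrounding types such as `Peer`, `Receipt`, `Session`, and `RateLimiter` live
in the collab crate, so this is not a standalone program):

```rust
// Streaming handlers receive a `StreamingResponse` instead of a `Response`.
// Each `send` pushes one chunk to the client; the wrapper installed by
// `add_streaming_request_handler` closes the stream (or reports the error)
// once the handler future resolves.
struct StreamingResponse<R: RequestMessage> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
}

impl<R: RequestMessage> StreamingResponse<R> {
    fn send(&self, payload: R::Response) -> Result<()> {
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}

// A completion request is gated on the `language-models` feature flag,
// rate-limited per user (120 requests/hour by default, overridable via
// COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR), and then routed to a
// provider by model-name prefix, streaming each delta back as it arrives.
async fn complete_with_language_model(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: Session,
    open_ai_api_key: Option<Arc<str>>,
    google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
    authorize_access_to_language_models(&session).await?;
    session
        .rate_limiter
        .check::<CompleteWithLanguageModelRateLimit>(session.user_id)
        .await?;
    if request.model.starts_with("gpt") {
        let api_key = open_ai_api_key
            .ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
        complete_with_open_ai(request, response, session, api_key).await?;
    } else if request.model.starts_with("gemini") {
        let api_key = google_ai_api_key
            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
        complete_with_google_ai(request, response, session, api_key).await?;
    }
    Ok(())
}
```

Token counting follows the same shape, but it uses a regular (non-streaming)
`Response` and currently supports only Gemini models.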

Release Notes:

- We are temporarily removing the semantic index in order to redesign it
from scratch.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Max <max@zed.dev>
Nathan Sobo 2024-03-19 12:22:26 -06:00 committed by GitHub
parent 905a24079a
commit 8ae5a3b61a
87 changed files with 3647 additions and 8937 deletions

@@ -9,9 +9,9 @@ use crate::{
User, UserId,
},
executor::Executor,
AppState, Error, Result,
AppState, Error, RateLimit, RateLimiter, Result,
};
use anyhow::anyhow;
use anyhow::{anyhow, Context as _};
use async_tungstenite::tungstenite::{
protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
};
@@ -30,6 +30,8 @@ use axum::{
};
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use futures::{
channel::oneshot,
future::{self, BoxFuture},
@@ -39,15 +41,14 @@ use futures::{
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
},
Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
};
use serde::{Serialize, Serializer};
use std::{
any::TypeId,
fmt,
future::Future,
marker::PhantomData,
mem,
@@ -64,7 +65,7 @@ use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder;
use tracing::{field, info_span, instrument, Instrument};
use util::SemanticVersion;
use util::{http::IsahcHttpClient, SemanticVersion};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
@@ -92,6 +93,18 @@ impl<R: RequestMessage> Response<R> {
}
}
struct StreamingResponse<R: RequestMessage> {
peer: Arc<Peer>,
receipt: Receipt<R>,
}
impl<R: RequestMessage> StreamingResponse<R> {
fn send(&self, payload: R::Response) -> Result<()> {
self.peer.respond(self.receipt, payload)?;
Ok(())
}
}
#[derive(Clone)]
struct Session {
user_id: UserId,
@@ -100,6 +113,8 @@ struct Session {
peer: Arc<Peer>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
http_client: IsahcHttpClient,
rate_limiter: Arc<RateLimiter>,
_executor: Executor,
}
@@ -124,8 +139,8 @@ impl Session {
}
}
impl fmt::Debug for Session {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
impl Debug for Session {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Session")
.field("user_id", &self.user_id)
.field("connection_id", &self.connection_id)
@@ -148,7 +163,6 @@ pub struct Server {
peer: Arc<Peer>,
pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
app_state: Arc<AppState>,
executor: Executor,
handlers: HashMap<TypeId, MessageHandler>,
teardown: watch::Sender<bool>,
}
@@ -175,12 +189,11 @@ where
}
impl Server {
pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
let mut server = Self {
id: parking_lot::Mutex::new(id),
peer: Peer::new(id.0 as u32),
app_state,
executor,
app_state: app_state.clone(),
connection_pool: Default::default(),
handlers: Default::default(),
teardown: watch::channel(false).0,
@@ -280,7 +293,30 @@ impl Server {
.add_message_handler(update_followers)
.add_request_handler(get_private_user_info)
.add_message_handler(acknowledge_channel_message)
.add_message_handler(acknowledge_buffer_version);
.add_message_handler(acknowledge_buffer_version)
.add_streaming_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
complete_with_language_model(
request,
response,
session,
app_state.config.openai_api_key.clone(),
app_state.config.google_ai_api_key.clone(),
)
}
})
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
count_tokens_with_language_model(
request,
response,
session,
app_state.config.google_ai_api_key.clone(),
)
}
});
Arc::new(server)
}
@@ -289,12 +325,12 @@ impl Server {
let server_id = *self.id.lock();
let app_state = self.app_state.clone();
let peer = self.peer.clone();
let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
let pool = self.connection_pool.clone();
let live_kit_client = self.app_state.live_kit_client.clone();
let span = info_span!("start server");
self.executor.spawn_detached(
self.app_state.executor.spawn_detached(
async move {
tracing::info!("waiting for cleanup timeout");
timeout.await;
@@ -536,6 +572,40 @@ impl Server {
})
}
fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
Fut: Send + Future<Output = Result<()>>,
M: RequestMessage,
{
let handler = Arc::new(handler);
self.add_handler(move |envelope, session| {
let receipt = envelope.receipt();
let handler = handler.clone();
async move {
let peer = session.peer.clone();
let response = StreamingResponse {
peer: peer.clone(),
receipt,
};
match (handler)(envelope.payload, response, session).await {
Ok(()) => {
peer.end_stream(receipt)?;
Ok(())
}
Err(error) => {
let proto_err = match &error {
Error::Internal(err) => err.to_proto(),
_ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
};
peer.respond_with_error(receipt, proto_err)?;
Err(error)
}
}
}
})
}
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
self: &Arc<Self>,
@@ -569,6 +639,14 @@ impl Server {
tracing::Span::current().record("connection_id", format!("{}", connection_id));
tracing::info!("connection opened");
let http_client = match IsahcHttpClient::new() {
Ok(http_client) => http_client,
Err(error) => {
tracing::error!(?error, "failed to create HTTP client");
return;
}
};
let session = Session {
user_id,
connection_id,
@@ -576,7 +654,9 @@
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
live_kit_client: this.app_state.live_kit_client.clone(),
_executor: executor.clone()
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
_executor: executor.clone(),
};
if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
@@ -3220,6 +3300,207 @@ async fn acknowledge_buffer_version(
Ok(())
}
struct CompleteWithLanguageModelRateLimit;
impl RateLimit for CompleteWithLanguageModelRateLimit {
fn capacity() -> usize {
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"complete-with-language-model"
}
}
async fn complete_with_language_model(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id)
.await?;
if request.model.starts_with("gpt") {
let api_key =
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
complete_with_open_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("gemini") {
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
}
Ok(())
}
async fn complete_with_open_ai(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
api_key: Arc<str>,
) -> Result<()> {
const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
let mut completion_stream = open_ai::stream_completion(
&session.http_client,
OPEN_AI_API_URL,
&api_key,
crate::ai::language_model_request_to_open_ai(request)?,
)
.await
.context("open_ai::stream_completion request failed")?;
while let Some(event) = completion_stream.next().await {
let event = event?;
response.send(proto::LanguageModelResponse {
choices: event
.choices
.into_iter()
.map(|choice| proto::LanguageModelChoiceDelta {
index: choice.index,
delta: Some(proto::LanguageModelResponseMessage {
role: choice.delta.role.map(|role| match role {
open_ai::Role::User => LanguageModelRole::LanguageModelUser,
open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
} as i32),
content: choice.delta.content,
}),
finish_reason: choice.finish_reason,
})
.collect(),
})?;
}
Ok(())
}
async fn complete_with_google_ai(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
api_key: Arc<str>,
) -> Result<()> {
let mut stream = google_ai::stream_generate_content(
&session.http_client,
google_ai::API_URL,
api_key.as_ref(),
crate::ai::language_model_request_to_google_ai(request)?,
)
.await
.context("google_ai::stream_generate_content request failed")?;
while let Some(event) = stream.next().await {
let event = event?;
response.send(proto::LanguageModelResponse {
choices: event
.candidates
.unwrap_or_default()
.into_iter()
.map(|candidate| proto::LanguageModelChoiceDelta {
index: candidate.index as u32,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(match candidate.content.role {
google_ai::Role::User => LanguageModelRole::LanguageModelUser,
google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
} as i32),
content: Some(
candidate
.content
.parts
.into_iter()
.filter_map(|part| match part {
google_ai::Part::TextPart(part) => Some(part.text),
google_ai::Part::InlineDataPart(_) => None,
})
.collect(),
),
}),
finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
})
.collect(),
})?;
}
Ok(())
}
struct CountTokensWithLanguageModelRateLimit;
impl RateLimit for CountTokensWithLanguageModelRateLimit {
fn capacity() -> usize {
std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"count-tokens-with-language-model"
}
}
async fn count_tokens_with_language_model(
request: proto::CountTokensWithLanguageModel,
response: Response<proto::CountTokensWithLanguageModel>,
session: Session,
google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
if !request.model.starts_with("gemini") {
return Err(anyhow!(
"counting tokens for model: {:?} is not supported",
request.model
))?;
}
session
.rate_limiter
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
.await?;
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
let tokens_response = google_ai::count_tokens(
&session.http_client,
google_ai::API_URL,
&api_key,
crate::ai::count_tokens_request_to_google_ai(request)?,
)
.await?;
response.send(proto::CountTokensResponse {
token_count: tokens_response.total_tokens as u32,
})?;
Ok(())
}
async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
let db = session.db().await;
let flags = db.get_user_flags(session.user_id).await?;
if flags.iter().any(|flag| flag == "language-models") {
Ok(())
} else {
Err(anyhow!("permission denied"))?
}
}
/// Start receiving chat updates for a channel
async fn join_channel_chat(
request: proto::JoinChannelChat,