Allow AI interactions to be proxied through Zed's server so you don't need an API key (#7367)

Resurrected this from some assistant work I did in Spring of 2023.
- [x] Resurrect streaming responses
- [x] Use streaming responses to enable AI via Zed's servers by default
(but preserve API key option for now)
- [x] Simplify protobuf
- [x] Proxy to OpenAI on zed.dev
- [x] Proxy to Gemini on zed.dev
- [x] Improve UX for switching between OpenAI and Google models
- We currently disallow cycling when setting a custom model, but we need a
better solution to keep OpenAI models available while testing the Google
ones
- [x] Show remaining tokens correctly for Google models
- [x] Remove semantic index
- [x] Delete `ai` crate
- [x] Cloud front so we can ban abuse
- [x] Rate-limiting
- [x] Fix panic when using inline assistant
- [x] Double-check that the upgraded `AssistantSettings` are
backwards-compatible
- [x] Add hosted LLM interaction behind a `language-models` feature
flag.

Release Notes:

- We are temporarily removing the semantic index in order to redesign it
from scratch.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Max <max@zed.dev>
Nathan Sobo 2024-03-19 12:22:26 -06:00 committed by GitHub
parent 905a24079a
commit 8ae5a3b61a
87 changed files with 3647 additions and 8937 deletions

@@ -1,7 +1,7 @@
syntax = "proto3";
package zed.messages;
// Looking for a number? Search "// Current max"
// Looking for a number? Search "// current max"
message PeerId {
uint32 owner_id = 1;
@@ -26,6 +26,7 @@ message Envelope {
Error error = 6;
Ping ping = 7;
Test test = 8;
EndStream end_stream = 165;
CreateRoom create_room = 9;
CreateRoomResponse create_room_response = 10;
@@ -198,6 +199,11 @@ message Envelope {
GetImplementationResponse get_implementation_response = 163;
JoinHostedProject join_hosted_project = 164;
CompleteWithLanguageModel complete_with_language_model = 166;
LanguageModelResponse language_model_response = 167;
CountTokensWithLanguageModel count_tokens_with_language_model = 168;
CountTokensResponse count_tokens_response = 169; // current max
}
reserved 158 to 161;
@@ -236,6 +242,8 @@ enum ErrorCode {
reserved 6;
}
message EndStream {}
message Test {
uint64 id = 1;
}
@@ -1718,3 +1726,45 @@ message SetRoomParticipantRole {
uint64 user_id = 2;
ChannelRole role = 3;
}
message CompleteWithLanguageModel {
string model = 1;
repeated LanguageModelRequestMessage messages = 2;
repeated string stop = 3;
float temperature = 4;
}
message LanguageModelRequestMessage {
LanguageModelRole role = 1;
string content = 2;
}
enum LanguageModelRole {
LanguageModelUser = 0;
LanguageModelAssistant = 1;
LanguageModelSystem = 2;
}
message LanguageModelResponseMessage {
optional LanguageModelRole role = 1;
optional string content = 2;
}
message LanguageModelResponse {
repeated LanguageModelChoiceDelta choices = 1;
}
message LanguageModelChoiceDelta {
uint32 index = 1;
LanguageModelResponseMessage delta = 2;
optional string finish_reason = 3;
}
message CountTokensWithLanguageModel {
string model = 1;
repeated LanguageModelRequestMessage messages = 2;
}
message CountTokensResponse {
uint32 token_count = 1;
}
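For orientation, here is a rough client-side sketch of how these messages fit together with the `request_stream` method added to `Peer` further down in this diff. It is not part of the commit; the module paths, model name, and prost-style generated types are assumptions.

// Hypothetical usage, not from this commit: stream a completion through
// Zed's server and concatenate the streamed deltas.
use futures::StreamExt;

async fn stream_completion(
    peer: &rpc::Peer,
    server: rpc::ConnectionId,
) -> anyhow::Result<String> {
    let request = proto::CompleteWithLanguageModel {
        model: "gpt-4".into(),
        messages: vec![proto::LanguageModelRequestMessage {
            role: proto::LanguageModelRole::LanguageModelUser as i32,
            content: "Say hello.".into(),
        }],
        stop: Vec::new(),
        temperature: 0.7,
    };

    // Each LanguageModelResponse frame carries a batch of deltas; the stream
    // ends once the server sends EndStream.
    let mut frames = peer.request_stream(server, request).await?;
    let mut text = String::new();
    while let Some(frame) = frames.next().await {
        for choice in frame?.choices {
            if let Some(content) = choice.delta.and_then(|delta| delta.content) {
                text.push_str(&content);
            }
        }
    }
    Ok(text)
}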

@@ -80,7 +80,7 @@ pub trait ErrorExt {
fn error_tag(&self, k: &str) -> Option<&str>;
/// to_proto() converts the error into a proto::Error
fn to_proto(&self) -> proto::Error;
///
/// Clones the error and turns into an [anyhow::Error].
fn cloned(&self) -> anyhow::Error;
}

@@ -9,19 +9,21 @@ use collections::HashMap;
use futures::{
channel::{mpsc, oneshot},
stream::BoxStream,
FutureExt, SinkExt, StreamExt, TryFutureExt,
FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
};
use parking_lot::{Mutex, RwLock};
use serde::{ser::SerializeStruct, Serialize};
use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
use std::{
fmt, future,
future::Future,
marker::PhantomData,
sync::atomic::Ordering::SeqCst,
sync::{
atomic::{self, AtomicU32},
Arc,
},
time::Duration,
time::Instant,
};
use tracing::instrument;
@@ -118,6 +120,15 @@ pub struct ConnectionState {
>,
>,
>,
#[allow(clippy::type_complexity)]
#[serde(skip)]
stream_response_channels: Arc<
Mutex<
Option<
HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
>,
>,
>,
}
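The new `stream_response_channels` field is dense; as a paraphrase (not code from the commit), its shape boils down to the following alias.

// Hypothetical alias restating the field above: one entry per in-flight
// streaming request, keyed by the outgoing message id. Each incoming frame is
// delivered together with a oneshot "resume" sender so the connection's read
// loop can wait until the requester has picked the frame up.
type StreamResponseChannels = HashMap<
    u32,
    mpsc::UnboundedSender<(anyhow::Result<proto::Envelope>, oneshot::Sender<()>)>,
>;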
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
@@ -171,17 +182,28 @@ impl Peer {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
let mut writer = MessageStream::new(connection.tx);
let mut reader = MessageStream::new(connection.rx);
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
let stream_response_channels = connection_state.stream_response_channels.clone();
let handle_io = async move {
tracing::trace!(%connection_id, "handle io future: start");
let _end_connection = util::defer(|| {
response_channels.lock().take();
if let Some(channels) = stream_response_channels.lock().take() {
for channel in channels.values() {
let _ = channel.unbounded_send((
Err(anyhow!("connection closed")),
oneshot::channel().0,
));
}
}
this.connections.write().remove(&connection_id);
tracing::trace!(%connection_id, "handle io future: end");
});
@@ -273,12 +295,14 @@ impl Peer {
};
let response_channels = connection_state.response_channels.clone();
let stream_response_channels = connection_state.stream_response_channels.clone();
self.connections
.write()
.insert(connection_id, connection_state);
let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
let response_channels = response_channels.clone();
let stream_response_channels = stream_response_channels.clone();
async move {
let message_id = incoming.id;
tracing::trace!(?incoming, "incoming message future: start");
@@ -293,8 +317,15 @@ impl Peer {
responding_to,
"incoming response: received"
);
let channel = response_channels.lock().as_mut()?.remove(&responding_to);
if let Some(tx) = channel {
let response_channel =
response_channels.lock().as_mut()?.remove(&responding_to);
let stream_response_channel = stream_response_channels
.lock()
.as_ref()?
.get(&responding_to)
.cloned();
if let Some(tx) = response_channel {
let requester_resumed = oneshot::channel();
if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
tracing::trace!(
@@ -319,6 +350,31 @@ impl Peer {
responding_to,
"incoming response: requester resumed"
);
} else if let Some(tx) = stream_response_channel {
let requester_resumed = oneshot::channel();
if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
tracing::debug!(
%connection_id,
message_id,
responding_to = responding_to,
?error,
"incoming stream response: request future dropped",
);
}
tracing::debug!(
%connection_id,
message_id,
responding_to,
"incoming stream response: waiting to resume requester"
);
let _ = requester_resumed.1.await;
tracing::debug!(
%connection_id,
message_id,
responding_to,
"incoming stream response: requester resumed"
);
} else {
let message_type =
proto::build_typed_envelope(connection_id, received_at, incoming)
@@ -451,6 +507,66 @@ impl Peer {
}
}
pub fn request_stream<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
let (tx, rx) = mpsc::unbounded();
let send = self.connection_state(receiver_id).and_then(|connection| {
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
let stream_response_channels = connection.stream_response_channels.clone();
stream_response_channels
.lock()
.as_mut()
.ok_or_else(|| anyhow!("connection was closed"))?
.insert(message_id, tx);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(
request.into_envelope(message_id, None, None),
))
.map_err(|_| anyhow!("connection was closed"))?;
Ok((message_id, stream_response_channels))
});
async move {
let (message_id, stream_response_channels) = send?;
let stream_response_channels = Arc::downgrade(&stream_response_channels);
Ok(rx.filter_map(move |(response, _barrier)| {
let stream_response_channels = stream_response_channels.clone();
future::ready(match response {
Ok(response) => {
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
Some(Err(anyhow!(
"RPC request {} failed - {}",
T::NAME,
error.message
)))
} else if let Some(proto::envelope::Payload::EndStream(_)) =
&response.payload
{
// Remove the transmitting end of the response channel to end the stream.
if let Some(channels) = stream_response_channels.upgrade() {
if let Some(channels) = channels.lock().as_mut() {
channels.remove(&message_id);
}
}
None
} else {
Some(
T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong type")),
)
}
}
Err(error) => Some(Err(error)),
})
}))
}
}
pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
let connection = self.connection_state(receiver_id)?;
let message_id = connection
@@ -503,6 +619,24 @@ impl Peer {
Ok(())
}
pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
let connection = self.connection_state(receipt.sender_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
let message = proto::EndStream {};
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(message.into_envelope(
message_id,
Some(receipt.message_id),
None,
)))?;
Ok(())
}
pub fn respond_with_error<T: RequestMessage>(
&self,
receipt: Receipt<T>,
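The responding side of this path is not shown in the diff. A handler on zed.dev can keep answering the same receipt with the pre-existing `respond` method, one LanguageModelResponse per upstream chunk, and finish with the new `end_stream`. The sketch below is illustrative only and assumes `Receipt` is `Copy` and that the upstream provider's deltas have already been converted to proto types.

// Illustrative server-side sketch, not from this commit: forward upstream
// deltas to the requesting peer, then send the EndStream sentinel.
use futures::{Stream, StreamExt};

async fn forward_completion(
    peer: &rpc::Peer,
    receipt: rpc::Receipt<proto::CompleteWithLanguageModel>,
    mut chunks: impl Stream<Item = anyhow::Result<proto::LanguageModelChoiceDelta>> + Unpin,
) -> anyhow::Result<()> {
    while let Some(delta) = chunks.next().await {
        // Each respond() call becomes one streamed frame on the requester's side.
        peer.respond(
            receipt,
            proto::LanguageModelResponse {
                choices: vec![delta?],
            },
        )?;
    }
    // After the upstream stream is exhausted, tell the requester the stream is complete.
    peer.end_stream(receipt)?;
    Ok(())
}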

@@ -149,7 +149,10 @@ messages!(
(CallCanceled, Foreground),
(CancelCall, Foreground),
(ChannelMessageSent, Foreground),
(CompleteWithLanguageModel, Background),
(CopyProjectEntry, Foreground),
(CountTokensWithLanguageModel, Background),
(CountTokensResponse, Background),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
@@ -160,6 +163,7 @@ messages!(
(DeleteChannel, Foreground),
(DeleteNotification, Foreground),
(DeleteProjectEntry, Foreground),
(EndStream, Foreground),
(Error, Foreground),
(ExpandProjectEntry, Foreground),
(ExpandProjectEntryResponse, Foreground),
@@ -211,6 +215,7 @@ messages!(
(JoinProjectResponse, Foreground),
(JoinRoom, Foreground),
(JoinRoomResponse, Foreground),
(LanguageModelResponse, Background),
(LeaveChannelBuffer, Background),
(LeaveChannelChat, Foreground),
(LeaveProject, Foreground),
@@ -300,6 +305,8 @@ request_messages!(
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
(CompleteWithLanguageModel, LanguageModelResponse),
(CountTokensWithLanguageModel, CountTokensResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
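These `request_messages!` pairings tie each request to its response type: completions stream LanguageModelResponse frames, while token counting stays an ordinary one-shot request. A hypothetical call for the latter (names and model string assumed, not from the commit):

// Hypothetical one-shot request: token counting does not stream, so the
// existing `request` path is sufficient.
async fn count_tokens(
    peer: &rpc::Peer,
    server: rpc::ConnectionId,
    messages: Vec<proto::LanguageModelRequestMessage>,
) -> anyhow::Result<u32> {
    let response = peer
        .request(
            server,
            proto::CountTokensWithLanguageModel {
                model: "gemini-pro".into(),
                messages,
            },
        )
        .await?;
    Ok(response.token_count)
}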