Allow AI interactions to be proxied through Zed's server so you don't need an API key (#7367)
Resurrected this from some assistant work I did in Spring of 2023.

- [x] Resurrect streaming responses
- [x] Use streaming responses to enable AI via Zed's servers by default (but preserve the API key option for now)
- [x] Simplify protobuf
- [x] Proxy to OpenAI on zed.dev
- [x] Proxy to Gemini on zed.dev
- [x] Improve UX for switching between OpenAI and Google models
  - We currently disallow cycling when setting a custom model, but we need a better solution to keep OpenAI models available while testing the Google ones
- [x] Show remaining tokens correctly for Google models
- [x] Remove semantic index
- [x] Delete `ai` crate
- [x] CloudFront so we can ban abuse
- [x] Rate limiting
- [x] Fix panic when using the inline assistant
- [x] Double-check that the upgraded `AssistantSettings` are backwards-compatible
- [x] Add hosted LLM interaction behind a `language-models` feature flag

Release Notes:

- We are temporarily removing the semantic index in order to redesign it from scratch.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Max <max@zed.dev>
Parent: 905a24079a
Commit: 8ae5a3b61a
87 changed files with 3647 additions and 8937 deletions
In the protobuf schema (package `zed.messages`), the `Envelope` gains an `EndStream` control message and the new language-model request/response messages:

```diff
@@ -1,7 +1,7 @@
 syntax = "proto3";
 package zed.messages;
 
-// Looking for a number? Search "// Current max"
+// Looking for a number? Search "// current max"
 
 message PeerId {
     uint32 owner_id = 1;
@@ -26,6 +26,7 @@ message Envelope {
         Error error = 6;
         Ping ping = 7;
         Test test = 8;
+        EndStream end_stream = 165;
 
         CreateRoom create_room = 9;
         CreateRoomResponse create_room_response = 10;
@@ -198,6 +199,11 @@ message Envelope {
         GetImplementationResponse get_implementation_response = 163;
 
         JoinHostedProject join_hosted_project = 164;
+
+        CompleteWithLanguageModel complete_with_language_model = 166;
+        LanguageModelResponse language_model_response = 167;
+        CountTokensWithLanguageModel count_tokens_with_language_model = 168;
+        CountTokensResponse count_tokens_response = 169; // current max
     }
 
     reserved 158 to 161;
@@ -236,6 +242,8 @@ enum ErrorCode {
     reserved 6;
 }
 
+message EndStream {}
+
 message Test {
     uint64 id = 1;
 }
@@ -1718,3 +1726,45 @@ message SetRoomParticipantRole {
     uint64 user_id = 2;
     ChannelRole role = 3;
 }
+
+message CompleteWithLanguageModel {
+    string model = 1;
+    repeated LanguageModelRequestMessage messages = 2;
+    repeated string stop = 3;
+    float temperature = 4;
+}
+
+message LanguageModelRequestMessage {
+    LanguageModelRole role = 1;
+    string content = 2;
+}
+
+enum LanguageModelRole {
+    LanguageModelUser = 0;
+    LanguageModelAssistant = 1;
+    LanguageModelSystem = 2;
+}
+
+message LanguageModelResponseMessage {
+    optional LanguageModelRole role = 1;
+    optional string content = 2;
+}
+
+message LanguageModelResponse {
+    repeated LanguageModelChoiceDelta choices = 1;
+}
+
+message LanguageModelChoiceDelta {
+    uint32 index = 1;
+    LanguageModelResponseMessage delta = 2;
+    optional string finish_reason = 3;
+}
+
+message CountTokensWithLanguageModel {
+    string model = 1;
+    repeated LanguageModelRequestMessage messages = 2;
+}
+
+message CountTokensResponse {
+    uint32 token_count = 1;
+}
```
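For orientation, here is a minimal sketch of building one of the new requests, assuming prost-style generated Rust types; the `rpc::proto` path and the model name are illustrative, not taken from this diff:

```rust
use rpc::proto; // assumed path to the prost-generated types

// Build a completion request; field names and types mirror the schema above.
// Enum-typed fields are `i32` in prost, hence the `as i32` casts.
fn completion_request() -> proto::CompleteWithLanguageModel {
    proto::CompleteWithLanguageModel {
        model: "gpt-4".into(), // illustrative model name
        messages: vec![
            proto::LanguageModelRequestMessage {
                role: proto::LanguageModelRole::LanguageModelSystem as i32,
                content: "You are a helpful coding assistant.".into(),
            },
            proto::LanguageModelRequestMessage {
                role: proto::LanguageModelRole::LanguageModelUser as i32,
                content: "Explain what this function does.".into(),
            },
        ],
        stop: Vec::new(),
        temperature: 1.0,
    }
}
```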
In the `ErrorExt` trait:

```diff
@@ -80,7 +80,7 @@ pub trait ErrorExt
     fn error_tag(&self, k: &str) -> Option<&str>;
     /// to_proto() converts the error into a proto::Error
     fn to_proto(&self) -> proto::Error;
     ///
     /// Clones the error and turns into an [anyhow::Error].
     fn cloned(&self) -> anyhow::Error;
 }
```
In the RPC `Peer` implementation, the imports gain `Stream`, `Future`, and the atomic and timing types used below:

```diff
@@ -9,19 +9,21 @@ use collections::HashMap;
 use futures::{
     channel::{mpsc, oneshot},
     stream::BoxStream,
-    FutureExt, SinkExt, StreamExt, TryFutureExt,
+    FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
 };
 use parking_lot::{Mutex, RwLock};
 use serde::{ser::SerializeStruct, Serialize};
-use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
+use std::{
+    fmt, future,
+    future::Future,
+    marker::PhantomData,
+    sync::atomic::Ordering::SeqCst,
+    sync::{
+        atomic::{self, AtomicU32},
+        Arc,
+    },
+    time::Duration,
+    time::Instant,
+};
 use tracing::instrument;
```
Each `ConnectionState` now tracks, per outstanding request id, an unbounded channel for streamed responses alongside the existing one-shot `response_channels`:

```diff
@@ -118,6 +120,15 @@ pub struct ConnectionState {
             >,
         >,
     >,
+    #[allow(clippy::type_complexity)]
+    #[serde(skip)]
+    stream_response_channels: Arc<
+        Mutex<
+            Option<
+                HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
+            >,
+        >,
+    >,
 }
 
 const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
```
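Written out with hypothetical type aliases, the new field stores the sending half of an unbounded channel per request id; each item pairs an envelope result with a one-shot "barrier" sender:

```rust
use std::sync::Arc;

use anyhow::Result;
use collections::HashMap; // Zed's HashMap re-export, as in the imports above
use futures::channel::{mpsc, oneshot};
use parking_lot::Mutex;
use rpc::proto; // assumed path to the generated protobuf types

// Hypothetical aliases for readability; the real field spells this out inline.
type StreamResponse = (Result<proto::Envelope>, oneshot::Sender<()>);
type StreamResponseChannels =
    Arc<Mutex<Option<HashMap<u32, mpsc::UnboundedSender<StreamResponse>>>>>;
```

The outer `Option` is what connection teardown `take()`s; once it is `None`, later registrations and sends fail with "connection was closed".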
When a connection is set up, the new map is initialized alongside `response_channels`, and the teardown guard drains it so that any in-flight streams observe a "connection closed" error:

```diff
@@ -171,17 +182,28 @@ impl Peer {
             outgoing_tx,
             next_message_id: Default::default(),
             response_channels: Arc::new(Mutex::new(Some(Default::default()))),
+            stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
         };
         let mut writer = MessageStream::new(connection.tx);
         let mut reader = MessageStream::new(connection.rx);
 
         let this = self.clone();
         let response_channels = connection_state.response_channels.clone();
+        let stream_response_channels = connection_state.stream_response_channels.clone();
 
         let handle_io = async move {
             tracing::trace!(%connection_id, "handle io future: start");
 
             let _end_connection = util::defer(|| {
                 response_channels.lock().take();
+                if let Some(channels) = stream_response_channels.lock().take() {
+                    for channel in channels.values() {
+                        let _ = channel.unbounded_send((
+                            Err(anyhow!("connection closed")),
+                            oneshot::channel().0,
+                        ));
+                    }
+                }
                 this.connections.write().remove(&connection_id);
                 tracing::trace!(%connection_id, "handle io future: end");
             });
```
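`util::defer` is Zed's drop-guard helper; a minimal stand-in showing the pattern (run cleanup when the guard drops, even if the I/O future is cancelled) might look like:

```rust
// Minimal stand-in for util::defer: runs the closure when dropped.
struct Defer<F: FnOnce()>(Option<F>);

impl<F: FnOnce()> Drop for Defer<F> {
    fn drop(&mut self) {
        if let Some(f) = self.0.take() {
            f();
        }
    }
}

fn defer<F: FnOnce()>(f: F) -> Defer<F> {
    Defer(Some(f))
}
```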
The incoming-message handler clones the new map into its closure:

```diff
@@ -273,12 +295,14 @@ impl Peer {
         };
 
         let response_channels = connection_state.response_channels.clone();
+        let stream_response_channels = connection_state.stream_response_channels.clone();
         self.connections
             .write()
             .insert(connection_id, connection_state);
 
         let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
             let response_channels = response_channels.clone();
+            let stream_response_channels = stream_response_channels.clone();
             async move {
                 let message_id = incoming.id;
                 tracing::trace!(?incoming, "incoming message future: start");
```
When a response envelope arrives, both maps are consulted. Note the asymmetry: a one-shot response channel is `remove`d from its map, while a stream channel is only `get(...).cloned()`, so later chunks can still find it:

```diff
@@ -293,8 +317,15 @@ impl Peer {
                     responding_to,
                     "incoming response: received"
                 );
-                let channel = response_channels.lock().as_mut()?.remove(&responding_to);
-                if let Some(tx) = channel {
+                let response_channel =
+                    response_channels.lock().as_mut()?.remove(&responding_to);
+                let stream_response_channel = stream_response_channels
+                    .lock()
+                    .as_ref()?
+                    .get(&responding_to)
+                    .cloned();
+
+                if let Some(tx) = response_channel {
                     let requester_resumed = oneshot::channel();
                     if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
                         tracing::trace!(
```
Stream responses are forwarded chunk by chunk, and the handler awaits a barrier so it does not outrun the requester:

```diff
@@ -319,6 +350,31 @@ impl Peer {
                         responding_to,
                         "incoming response: requester resumed"
                     );
+                } else if let Some(tx) = stream_response_channel {
+                    let requester_resumed = oneshot::channel();
+                    if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
+                        tracing::debug!(
+                            %connection_id,
+                            message_id,
+                            responding_to = responding_to,
+                            ?error,
+                            "incoming stream response: request future dropped",
+                        );
+                    }
+
+                    tracing::debug!(
+                        %connection_id,
+                        message_id,
+                        responding_to,
+                        "incoming stream response: waiting to resume requester"
+                    );
+                    let _ = requester_resumed.1.await;
+                    tracing::debug!(
+                        %connection_id,
+                        message_id,
+                        responding_to,
+                        "incoming stream response: requester resumed"
+                    );
                 } else {
                     let message_type =
                         proto::build_typed_envelope(connection_id, received_at, incoming)
```
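The `oneshot::Sender<()>` in each item is a backpressure barrier: the connection's reader does not process the next incoming message until the requester has consumed (or dropped) the current chunk. A self-contained sketch of that handshake, using only the `futures` crate:

```rust
use futures::{
    channel::{mpsc, oneshot},
    executor::block_on,
    StreamExt,
};

fn main() {
    block_on(async {
        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<()>)>();

        // Stand-in for the connection's reader loop.
        let reader = async move {
            for i in 0..3 {
                let (resumed_tx, resumed_rx) = oneshot::channel();
                tx.unbounded_send((format!("chunk {i}"), resumed_tx)).unwrap();
                // Block until the requester consumes the chunk (or drops it).
                let _ = resumed_rx.await;
            }
        };

        // Stand-in for the requester consuming the response stream.
        let requester = async move {
            while let Some((chunk, _barrier)) = rx.next().await {
                println!("got {chunk}");
                // `_barrier` drops here, resuming the reader.
            }
        };

        futures::join!(reader, requester);
    });
}
```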
`request_stream` is the client-side entry point: it registers the sender half, sends the request, and returns a stream that yields decoded responses until an `EndStream` payload unregisters the channel:

```diff
@@ -451,6 +507,66 @@ impl Peer {
         }
     }
 
+    pub fn request_stream<T: RequestMessage>(
+        &self,
+        receiver_id: ConnectionId,
+        request: T,
+    ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
+        let (tx, rx) = mpsc::unbounded();
+        let send = self.connection_state(receiver_id).and_then(|connection| {
+            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
+            let stream_response_channels = connection.stream_response_channels.clone();
+            stream_response_channels
+                .lock()
+                .as_mut()
+                .ok_or_else(|| anyhow!("connection was closed"))?
+                .insert(message_id, tx);
+            connection
+                .outgoing_tx
+                .unbounded_send(proto::Message::Envelope(
+                    request.into_envelope(message_id, None, None),
+                ))
+                .map_err(|_| anyhow!("connection was closed"))?;
+            Ok((message_id, stream_response_channels))
+        });
+
+        async move {
+            let (message_id, stream_response_channels) = send?;
+            let stream_response_channels = Arc::downgrade(&stream_response_channels);
+
+            Ok(rx.filter_map(move |(response, _barrier)| {
+                let stream_response_channels = stream_response_channels.clone();
+                future::ready(match response {
+                    Ok(response) => {
+                        if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
+                            Some(Err(anyhow!(
+                                "RPC request {} failed - {}",
+                                T::NAME,
+                                error.message
+                            )))
+                        } else if let Some(proto::envelope::Payload::EndStream(_)) =
+                            &response.payload
+                        {
+                            // Remove the transmitting end of the response channel to end the stream.
+                            if let Some(channels) = stream_response_channels.upgrade() {
+                                if let Some(channels) = channels.lock().as_mut() {
+                                    channels.remove(&message_id);
+                                }
+                            }
+                            None
+                        } else {
+                            Some(
+                                T::Response::from_envelope(response)
+                                    .ok_or_else(|| anyhow!("received response of the wrong type")),
+                            )
+                        }
+                    }
+                    Err(error) => Some(Err(error)),
+                })
+            }))
+        }
+    }
+
     pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
         let connection = self.connection_state(receiver_id)?;
         let message_id = connection
```
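A sketch of how a caller might consume this, assuming the `CompleteWithLanguageModel` → `LanguageModelResponse` pairing registered further below; everything not shown in this diff (the import paths, the model name, the surrounding function) is illustrative:

```rust
use anyhow::Result;
use futures::StreamExt;
use rpc::proto; // assumed paths; inside Zed these are crate-local

// Hypothetical async caller; `peer` and `connection_id` come from elsewhere.
async fn stream_completion(
    peer: &rpc::Peer,
    connection_id: rpc::ConnectionId,
    messages: Vec<proto::LanguageModelRequestMessage>,
) -> Result<()> {
    let mut stream = peer
        .request_stream(
            connection_id,
            proto::CompleteWithLanguageModel {
                model: "gpt-4".into(), // illustrative
                messages,
                stop: Vec::new(),
                temperature: 1.0,
            },
        )
        .await?;

    // Items keep arriving until the responder sends EndStream.
    while let Some(response) = stream.next().await {
        for choice in response?.choices {
            if let Some(content) = choice.delta.and_then(|delta| delta.content) {
                print!("{content}");
            }
        }
    }
    Ok(())
}
```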
`end_stream` is the responding side's counterpart, terminating the stream for a given receipt:

```diff
@@ -503,6 +619,24 @@ impl Peer {
         Ok(())
     }
 
+    pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
+        let connection = self.connection_state(receipt.sender_id)?;
+        let message_id = connection
+            .next_message_id
+            .fetch_add(1, atomic::Ordering::SeqCst);
+
+        let message = proto::EndStream {};
+
+        connection
+            .outgoing_tx
+            .unbounded_send(proto::Message::Envelope(message.into_envelope(
+                message_id,
+                Some(receipt.message_id),
+                None,
+            )))?;
+        Ok(())
+    }
+
     pub fn respond_with_error<T: RequestMessage>(
         &self,
         receipt: Receipt<T>,
```
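A responder would stream several chunks for one receipt before closing. A sketch, assuming the pre-existing single-response `Peer::respond` method (not shown in this hunk) can be called repeatedly for a stream request and that `Receipt` is copyable:

```rust
use anyhow::Result;
use rpc::proto; // assumed paths, as above

// Hypothetical responder: `receipt` identifies the original request; each
// intermediate response reuses it, and `end_stream` closes the stream.
fn stream_chunks(
    peer: &rpc::Peer,
    receipt: rpc::Receipt<proto::CompleteWithLanguageModel>,
    chunks: Vec<Vec<proto::LanguageModelChoiceDelta>>,
) -> Result<()> {
    for choices in chunks {
        peer.respond(receipt, proto::LanguageModelResponse { choices })?;
    }
    peer.end_stream(receipt)?; // emits the EndStream envelope shown above
    Ok(())
}
```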
Finally, the new messages are registered in the `messages!` and `request_messages!` macros; language-model traffic is handled at background priority:

```diff
@@ -149,7 +149,10 @@ messages!(
     (CallCanceled, Foreground),
     (CancelCall, Foreground),
     (ChannelMessageSent, Foreground),
+    (CompleteWithLanguageModel, Background),
     (CopyProjectEntry, Foreground),
+    (CountTokensWithLanguageModel, Background),
+    (CountTokensResponse, Background),
     (CreateBufferForPeer, Foreground),
     (CreateChannel, Foreground),
     (CreateChannelResponse, Foreground),
@@ -160,6 +163,7 @@ messages!(
     (DeleteChannel, Foreground),
     (DeleteNotification, Foreground),
     (DeleteProjectEntry, Foreground),
+    (EndStream, Foreground),
     (Error, Foreground),
     (ExpandProjectEntry, Foreground),
     (ExpandProjectEntryResponse, Foreground),
@@ -211,6 +215,7 @@ messages!(
     (JoinProjectResponse, Foreground),
     (JoinRoom, Foreground),
     (JoinRoomResponse, Foreground),
+    (LanguageModelResponse, Background),
     (LeaveChannelBuffer, Background),
     (LeaveChannelChat, Foreground),
     (LeaveProject, Foreground),
@@ -300,6 +305,8 @@ request_messages!(
     (Call, Ack),
     (CancelCall, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
+    (CompleteWithLanguageModel, LanguageModelResponse),
+    (CountTokensWithLanguageModel, CountTokensResponse),
     (CreateChannel, CreateChannelResponse),
     (CreateProjectEntry, ProjectEntryResponse),
     (CreateRoom, CreateRoomResponse),
```