collab: Remove LLM completions over RPC (#16114)
This PR removes the LLM completion messages from the RPC protocol, as these now go through the LLM service as of #16113. Release Notes: - N/A
This commit is contained in:
parent
f992cfdc7f
commit
f952126319
3 changed files with 1 additions and 299 deletions
|
@ -105,18 +105,6 @@ impl<R: RequestMessage> Response<R> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct StreamingResponse<R: RequestMessage> {
|
|
||||||
peer: Arc<Peer>,
|
|
||||||
receipt: Receipt<R>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<R: RequestMessage> StreamingResponse<R> {
|
|
||||||
fn send(&self, payload: R::Response) -> Result<()> {
|
|
||||||
self.peer.respond(self.receipt, payload)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub enum Principal {
|
pub enum Principal {
|
||||||
User(User),
|
User(User),
|
||||||
|
@ -630,31 +618,6 @@ impl Server {
|
||||||
))
|
))
|
||||||
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
||||||
.add_message_handler(update_context)
|
.add_message_handler(update_context)
|
||||||
.add_request_handler({
|
|
||||||
let app_state = app_state.clone();
|
|
||||||
move |request, response, session| {
|
|
||||||
let app_state = app_state.clone();
|
|
||||||
async move {
|
|
||||||
complete_with_language_model(request, response, session, &app_state.config)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.add_streaming_request_handler({
|
|
||||||
let app_state = app_state.clone();
|
|
||||||
move |request, response, session| {
|
|
||||||
let app_state = app_state.clone();
|
|
||||||
async move {
|
|
||||||
stream_complete_with_language_model(
|
|
||||||
request,
|
|
||||||
response,
|
|
||||||
session,
|
|
||||||
&app_state.config,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.add_request_handler({
|
.add_request_handler({
|
||||||
let app_state = app_state.clone();
|
let app_state = app_state.clone();
|
||||||
move |request, response, session| {
|
move |request, response, session| {
|
||||||
|
@ -948,40 +911,6 @@ impl Server {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
|
||||||
where
|
|
||||||
F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
|
|
||||||
Fut: Send + Future<Output = Result<()>>,
|
|
||||||
M: RequestMessage,
|
|
||||||
{
|
|
||||||
let handler = Arc::new(handler);
|
|
||||||
self.add_handler(move |envelope, session| {
|
|
||||||
let receipt = envelope.receipt();
|
|
||||||
let handler = handler.clone();
|
|
||||||
async move {
|
|
||||||
let peer = session.peer.clone();
|
|
||||||
let response = StreamingResponse {
|
|
||||||
peer: peer.clone(),
|
|
||||||
receipt,
|
|
||||||
};
|
|
||||||
match (handler)(envelope.payload, response, session).await {
|
|
||||||
Ok(()) => {
|
|
||||||
peer.end_stream(receipt)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
let proto_err = match &error {
|
|
||||||
Error::Internal(err) => err.to_proto(),
|
|
||||||
_ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
|
|
||||||
};
|
|
||||||
peer.respond_with_error(receipt, proto_err)?;
|
|
||||||
Err(error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn handle_connection(
|
pub fn handle_connection(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
|
@ -4561,202 +4490,6 @@ async fn acknowledge_buffer_version(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ZedProCompleteWithLanguageModelRateLimit;
|
|
||||||
|
|
||||||
impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
|
|
||||||
fn capacity(&self) -> usize {
|
|
||||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
|
|
||||||
.ok()
|
|
||||||
.and_then(|v| v.parse().ok())
|
|
||||||
.unwrap_or(120) // Picked arbitrarily
|
|
||||||
}
|
|
||||||
|
|
||||||
fn refill_duration(&self) -> chrono::Duration {
|
|
||||||
chrono::Duration::hours(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn db_name(&self) -> &'static str {
|
|
||||||
"zed-pro:complete-with-language-model"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct FreeCompleteWithLanguageModelRateLimit;
|
|
||||||
|
|
||||||
impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
|
|
||||||
fn capacity(&self) -> usize {
|
|
||||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
|
|
||||||
.ok()
|
|
||||||
.and_then(|v| v.parse().ok())
|
|
||||||
.unwrap_or(120 / 10) // Picked arbitrarily
|
|
||||||
}
|
|
||||||
|
|
||||||
fn refill_duration(&self) -> chrono::Duration {
|
|
||||||
chrono::Duration::hours(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn db_name(&self) -> &'static str {
|
|
||||||
"free:complete-with-language-model"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn complete_with_language_model(
|
|
||||||
request: proto::CompleteWithLanguageModel,
|
|
||||||
response: Response<proto::CompleteWithLanguageModel>,
|
|
||||||
session: Session,
|
|
||||||
config: &Config,
|
|
||||||
) -> Result<()> {
|
|
||||||
let Some(session) = session.for_user() else {
|
|
||||||
return Err(anyhow!("user not found"))?;
|
|
||||||
};
|
|
||||||
authorize_access_to_language_models(&session).await?;
|
|
||||||
|
|
||||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
|
||||||
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
|
|
||||||
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
|
|
||||||
};
|
|
||||||
|
|
||||||
session
|
|
||||||
.app_state
|
|
||||||
.rate_limiter
|
|
||||||
.check(&*rate_limit, session.user_id())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
|
||||||
Some(proto::LanguageModelProvider::Anthropic) => {
|
|
||||||
let api_key = config
|
|
||||||
.anthropic_api_key
|
|
||||||
.as_ref()
|
|
||||||
.context("no Anthropic AI API key configured on the server")?;
|
|
||||||
anthropic::complete(
|
|
||||||
session.http_client.as_ref(),
|
|
||||||
anthropic::ANTHROPIC_API_URL,
|
|
||||||
api_key,
|
|
||||||
serde_json::from_str(&request.request)?,
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
}
|
|
||||||
_ => return Err(anyhow!("unsupported provider"))?,
|
|
||||||
};
|
|
||||||
|
|
||||||
response.send(proto::CompleteWithLanguageModelResponse {
|
|
||||||
completion: serde_json::to_string(&result)?,
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn stream_complete_with_language_model(
|
|
||||||
request: proto::StreamCompleteWithLanguageModel,
|
|
||||||
response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
|
|
||||||
session: Session,
|
|
||||||
config: &Config,
|
|
||||||
) -> Result<()> {
|
|
||||||
let Some(session) = session.for_user() else {
|
|
||||||
return Err(anyhow!("user not found"))?;
|
|
||||||
};
|
|
||||||
authorize_access_to_language_models(&session).await?;
|
|
||||||
|
|
||||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
|
||||||
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
|
|
||||||
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
|
|
||||||
};
|
|
||||||
|
|
||||||
session
|
|
||||||
.app_state
|
|
||||||
.rate_limiter
|
|
||||||
.check(&*rate_limit, session.user_id())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
match proto::LanguageModelProvider::from_i32(request.provider) {
|
|
||||||
Some(proto::LanguageModelProvider::Anthropic) => {
|
|
||||||
let api_key = config
|
|
||||||
.anthropic_api_key
|
|
||||||
.as_ref()
|
|
||||||
.context("no Anthropic AI API key configured on the server")?;
|
|
||||||
let mut chunks = anthropic::stream_completion(
|
|
||||||
session.http_client.as_ref(),
|
|
||||||
anthropic::ANTHROPIC_API_URL,
|
|
||||||
api_key,
|
|
||||||
serde_json::from_str(&request.request)?,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
while let Some(event) = chunks.next().await {
|
|
||||||
let chunk = event?;
|
|
||||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
|
||||||
event: serde_json::to_string(&chunk)?,
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(proto::LanguageModelProvider::OpenAi) => {
|
|
||||||
let api_key = config
|
|
||||||
.openai_api_key
|
|
||||||
.as_ref()
|
|
||||||
.context("no OpenAI API key configured on the server")?;
|
|
||||||
let mut events = open_ai::stream_completion(
|
|
||||||
session.http_client.as_ref(),
|
|
||||||
open_ai::OPEN_AI_API_URL,
|
|
||||||
api_key,
|
|
||||||
serde_json::from_str(&request.request)?,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
while let Some(event) = events.next().await {
|
|
||||||
let event = event?;
|
|
||||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
|
||||||
event: serde_json::to_string(&event)?,
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(proto::LanguageModelProvider::Google) => {
|
|
||||||
let api_key = config
|
|
||||||
.google_ai_api_key
|
|
||||||
.as_ref()
|
|
||||||
.context("no Google AI API key configured on the server")?;
|
|
||||||
let mut events = google_ai::stream_generate_content(
|
|
||||||
session.http_client.as_ref(),
|
|
||||||
google_ai::API_URL,
|
|
||||||
api_key,
|
|
||||||
serde_json::from_str(&request.request)?,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
while let Some(event) = events.next().await {
|
|
||||||
let event = event?;
|
|
||||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
|
||||||
event: serde_json::to_string(&event)?,
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(proto::LanguageModelProvider::Zed) => {
|
|
||||||
let api_key = config
|
|
||||||
.qwen2_7b_api_key
|
|
||||||
.as_ref()
|
|
||||||
.context("no Qwen2-7B API key configured on the server")?;
|
|
||||||
let api_url = config
|
|
||||||
.qwen2_7b_api_url
|
|
||||||
.as_ref()
|
|
||||||
.context("no Qwen2-7B URL configured on the server")?;
|
|
||||||
let mut events = open_ai::stream_completion(
|
|
||||||
session.http_client.as_ref(),
|
|
||||||
&api_url,
|
|
||||||
api_key,
|
|
||||||
serde_json::from_str(&request.request)?,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
while let Some(event) = events.next().await {
|
|
||||||
let event = event?;
|
|
||||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
|
||||||
event: serde_json::to_string(&event)?,
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => return Err(anyhow!("unknown provider"))?,
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn count_language_model_tokens(
|
async fn count_language_model_tokens(
|
||||||
request: proto::CountLanguageModelTokens,
|
request: proto::CountLanguageModelTokens,
|
||||||
response: Response<proto::CountLanguageModelTokens>,
|
response: Response<proto::CountLanguageModelTokens>,
|
||||||
|
|
|
@ -197,10 +197,6 @@ message Envelope {
|
||||||
|
|
||||||
JoinHostedProject join_hosted_project = 164;
|
JoinHostedProject join_hosted_project = 164;
|
||||||
|
|
||||||
CompleteWithLanguageModel complete_with_language_model = 226;
|
|
||||||
CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
|
|
||||||
StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
|
|
||||||
StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
|
|
||||||
CountLanguageModelTokens count_language_model_tokens = 230;
|
CountLanguageModelTokens count_language_model_tokens = 230;
|
||||||
CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
|
CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
|
||||||
GetCachedEmbeddings get_cached_embeddings = 189;
|
GetCachedEmbeddings get_cached_embeddings = 189;
|
||||||
|
@ -279,7 +275,7 @@ message Envelope {
|
||||||
|
|
||||||
reserved 158 to 161;
|
reserved 158 to 161;
|
||||||
reserved 166 to 169;
|
reserved 166 to 169;
|
||||||
reserved 224 to 225;
|
reserved 224 to 229;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Messages
|
// Messages
|
||||||
|
@ -2084,24 +2080,6 @@ enum LanguageModelRole {
|
||||||
reserved 3;
|
reserved 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message CompleteWithLanguageModel {
|
|
||||||
LanguageModelProvider provider = 1;
|
|
||||||
string request = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message CompleteWithLanguageModelResponse {
|
|
||||||
string completion = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message StreamCompleteWithLanguageModel {
|
|
||||||
LanguageModelProvider provider = 1;
|
|
||||||
string request = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message StreamCompleteWithLanguageModelResponse {
|
|
||||||
string event = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message CountLanguageModelTokens {
|
message CountLanguageModelTokens {
|
||||||
LanguageModelProvider provider = 1;
|
LanguageModelProvider provider = 1;
|
||||||
string request = 2;
|
string request = 2;
|
||||||
|
|
|
@ -298,10 +298,6 @@ messages!(
|
||||||
(PrepareRename, Background),
|
(PrepareRename, Background),
|
||||||
(PrepareRenameResponse, Background),
|
(PrepareRenameResponse, Background),
|
||||||
(ProjectEntryResponse, Foreground),
|
(ProjectEntryResponse, Foreground),
|
||||||
(CompleteWithLanguageModel, Background),
|
|
||||||
(CompleteWithLanguageModelResponse, Background),
|
|
||||||
(StreamCompleteWithLanguageModel, Background),
|
|
||||||
(StreamCompleteWithLanguageModelResponse, Background),
|
|
||||||
(CountLanguageModelTokens, Background),
|
(CountLanguageModelTokens, Background),
|
||||||
(CountLanguageModelTokensResponse, Background),
|
(CountLanguageModelTokensResponse, Background),
|
||||||
(RefreshInlayHints, Foreground),
|
(RefreshInlayHints, Foreground),
|
||||||
|
@ -476,11 +472,6 @@ request_messages!(
|
||||||
(PerformRename, PerformRenameResponse),
|
(PerformRename, PerformRenameResponse),
|
||||||
(Ping, Ack),
|
(Ping, Ack),
|
||||||
(PrepareRename, PrepareRenameResponse),
|
(PrepareRename, PrepareRenameResponse),
|
||||||
(CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
|
|
||||||
(
|
|
||||||
StreamCompleteWithLanguageModel,
|
|
||||||
StreamCompleteWithLanguageModelResponse
|
|
||||||
),
|
|
||||||
(CountLanguageModelTokens, CountLanguageModelTokensResponse),
|
(CountLanguageModelTokens, CountLanguageModelTokensResponse),
|
||||||
(RefreshInlayHints, Ack),
|
(RefreshInlayHints, Ack),
|
||||||
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
|
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue