diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs
index 2b34546ba2..f6f9798351 100644
--- a/crates/collab/src/rpc.rs
+++ b/crates/collab/src/rpc.rs
@@ -105,18 +105,6 @@ impl Response {
     }
 }
 
-struct StreamingResponse<R: RequestMessage> {
-    peer: Arc<Peer>,
-    receipt: Receipt<R>,
-}
-
-impl<R: RequestMessage> StreamingResponse<R> {
-    fn send(&self, payload: R::Response) -> Result<()> {
-        self.peer.respond(self.receipt, payload)?;
-        Ok(())
-    }
-}
-
 #[derive(Clone, Debug)]
 pub enum Principal {
     User(User),
@@ -630,31 +618,6 @@ impl Server {
             ))
             .add_message_handler(broadcast_project_message_from_host::)
             .add_message_handler(update_context)
-            .add_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    let app_state = app_state.clone();
-                    async move {
-                        complete_with_language_model(request, response, session, &app_state.config)
-                            .await
-                    }
-                }
-            })
-            .add_streaming_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    let app_state = app_state.clone();
-                    async move {
-                        stream_complete_with_language_model(
-                            request,
-                            response,
-                            session,
-                            &app_state.config,
-                        )
-                        .await
-                    }
-                }
-            })
             .add_request_handler({
                 let app_state = app_state.clone();
                 move |request, response, session| {
                     let app_state = app_state.clone();
                     async move {
@@ -948,40 +911,6 @@ impl Server {
         })
     }
 
-    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
-    where
-        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
-        Fut: Send + Future<Output = Result<()>>,
-        M: RequestMessage,
-    {
-        let handler = Arc::new(handler);
-        self.add_handler(move |envelope, session| {
-            let receipt = envelope.receipt();
-            let handler = handler.clone();
-            async move {
-                let peer = session.peer.clone();
-                let response = StreamingResponse {
-                    peer: peer.clone(),
-                    receipt,
-                };
-                match (handler)(envelope.payload, response, session).await {
-                    Ok(()) => {
-                        peer.end_stream(receipt)?;
-                        Ok(())
-                    }
-                    Err(error) => {
-                        let proto_err = match &error {
-                            Error::Internal(err) => err.to_proto(),
-                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
-                        };
-                        peer.respond_with_error(receipt, proto_err)?;
-                        Err(error)
-                    }
-                }
-            }
-        })
-    }
-
     #[allow(clippy::too_many_arguments)]
     pub fn handle_connection(
         self: &Arc<Self>,
@@ -4561,202 +4490,6 @@ async fn acknowledge_buffer_version(
     Ok(())
 }
 
-struct ZedProCompleteWithLanguageModelRateLimit;
-
-impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(120) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "zed-pro:complete-with-language-model"
-    }
-}
-
-struct FreeCompleteWithLanguageModelRateLimit;
-
-impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(120 / 10) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "free:complete-with-language-model"
-    }
-}
-
-async fn complete_with_language_model(
-    request: proto::CompleteWithLanguageModel,
-    response: Response<proto::CompleteWithLanguageModel>,
-    session: Session,
-    config: &Config,
-) -> Result<()> {
-    let Some(session) = session.for_user() else {
-        return Err(anyhow!("user not found"))?;
-    };
-    authorize_access_to_language_models(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
-        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
-        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
-        Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key = config
-                .anthropic_api_key
-                .as_ref()
-                .context("no Anthropic AI API key configured on the server")?;
-            anthropic::complete(
-                session.http_client.as_ref(),
-                anthropic::ANTHROPIC_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-            )
-            .await?
-        }
-        _ => return Err(anyhow!("unsupported provider"))?,
-    };
-
-    response.send(proto::CompleteWithLanguageModelResponse {
-        completion: serde_json::to_string(&result)?,
-    })?;
-
-    Ok(())
-}
-
-async fn stream_complete_with_language_model(
-    request: proto::StreamCompleteWithLanguageModel,
-    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
-    session: Session,
-    config: &Config,
-) -> Result<()> {
-    let Some(session) = session.for_user() else {
-        return Err(anyhow!("user not found"))?;
-    };
-    authorize_access_to_language_models(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
-        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
-        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    match proto::LanguageModelProvider::from_i32(request.provider) {
-        Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key = config
-                .anthropic_api_key
-                .as_ref()
-                .context("no Anthropic AI API key configured on the server")?;
-            let mut chunks = anthropic::stream_completion(
-                session.http_client.as_ref(),
-                anthropic::ANTHROPIC_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = chunks.next().await {
-                let chunk = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&chunk)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::OpenAi) => {
-            let api_key = config
-                .openai_api_key
-                .as_ref()
-                .context("no OpenAI API key configured on the server")?;
-            let mut events = open_ai::stream_completion(
-                session.http_client.as_ref(),
-                open_ai::OPEN_AI_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::Google) => {
-            let api_key = config
-                .google_ai_api_key
-                .as_ref()
-                .context("no Google AI API key configured on the server")?;
-            let mut events = google_ai::stream_generate_content(
-                session.http_client.as_ref(),
-                google_ai::API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::Zed) => {
-            let api_key = config
-                .qwen2_7b_api_key
-                .as_ref()
-                .context("no Qwen2-7B API key configured on the server")?;
-            let api_url = config
-                .qwen2_7b_api_url
-                .as_ref()
-                .context("no Qwen2-7B URL configured on the server")?;
-            let mut events = open_ai::stream_completion(
-                session.http_client.as_ref(),
-                &api_url,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        None => return Err(anyhow!("unknown provider"))?,
-    }
-
-    Ok(())
-}
-
 async fn count_language_model_tokens(
     request: proto::CountLanguageModelTokens,
     response: Response<proto::CountLanguageModelTokens>,
diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto
index 22407553a4..995ebcd341 100644
--- a/crates/proto/proto/zed.proto
+++ b/crates/proto/proto/zed.proto
@@ -197,10 +197,6 @@ message Envelope {
 
         JoinHostedProject join_hosted_project = 164;
 
-        CompleteWithLanguageModel complete_with_language_model = 226;
-        CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
-        StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
-        StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
         CountLanguageModelTokens count_language_model_tokens = 230;
         CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
         GetCachedEmbeddings get_cached_embeddings = 189;
@@ -279,7 +275,7 @@ message Envelope {
         reserved 158 to 161;
         reserved 166 to 169;
-        reserved 224 to 225;
+        reserved 224 to 229;
     }
 }
 
 // Messages
@@ -2084,24 +2080,6 @@ enum LanguageModelRole {
     reserved 3;
 }
 
-message CompleteWithLanguageModel {
-    LanguageModelProvider provider = 1;
-    string request = 2;
-}
-
-message CompleteWithLanguageModelResponse {
-    string completion = 1;
-}
-
-message StreamCompleteWithLanguageModel {
-    LanguageModelProvider provider = 1;
-    string request = 2;
-}
-
-message StreamCompleteWithLanguageModelResponse {
-    string event = 1;
-}
-
 message CountLanguageModelTokens {
     LanguageModelProvider provider = 1;
     string request = 2;
diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs
index a2282bc6a7..139ee8fdf9 100644
--- a/crates/proto/src/proto.rs
+++ b/crates/proto/src/proto.rs
@@ -298,10 +298,6 @@ messages!(
     (PrepareRename, Background),
     (PrepareRenameResponse, Background),
     (ProjectEntryResponse, Foreground),
-    (CompleteWithLanguageModel, Background),
-    (CompleteWithLanguageModelResponse, Background),
-    (StreamCompleteWithLanguageModel, Background),
-    (StreamCompleteWithLanguageModelResponse, Background),
     (CountLanguageModelTokens, Background),
     (CountLanguageModelTokensResponse, Background),
     (RefreshInlayHints, Foreground),
@@ -476,11 +472,6 @@ request_messages!(
     (PerformRename, PerformRenameResponse),
     (Ping, Ack),
     (PrepareRename, PrepareRenameResponse),
-    (CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
-    (
-        StreamCompleteWithLanguageModel,
-        StreamCompleteWithLanguageModelResponse
-    ),
     (CountLanguageModelTokens, CountLanguageModelTokensResponse),
     (RefreshInlayHints, Ack),
     (RejoinChannelBuffers, RejoinChannelBuffersResponse),