diff --git a/Cargo.lock b/Cargo.lock index a5dd655ac7..0676d76561 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19847,9 +19847,9 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16d993fc42f9ec43ab76fa46c6eb579a66e116bb08cd2bc9a67f3afcaa05d39d" +checksum = "9be71e2f9b271e1eb8eb3e0d986075e770d1a0a299fb036abc3f1fc13a2fa7eb" dependencies = [ "anyhow", "serde", diff --git a/Cargo.toml b/Cargo.toml index be7087de61..bc9a203d5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -615,7 +615,7 @@ wasmtime-wasi = "29" which = "6.0.0" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "0.8.1" +zed_llm_client = "0.8.2" zstd = "0.11" [workspace.dependencies.async-stripe] diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 611ab52106..0eefe47540 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -14,7 +14,7 @@ use license_detection::LICENSE_FILES_TO_CHECK; pub use license_detection::is_license_eligible_for_data_collection; pub use rate_completion_modal::*; -use anyhow::{Context as _, Result}; +use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; use client::{Client, UserStore}; use collections::{HashMap, HashSet, VecDeque}; @@ -23,7 +23,7 @@ use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion, Subscription, Task, WeakEntity, actions, }; -use http_client::{HttpClient, Method}; +use http_client::{AsyncBody, HttpClient, Method, Request, Response}; use input_excerpt::excerpt_for_cursor_position; use language::{ Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff, @@ -54,8 +54,8 @@ use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use worktree::Worktree; use zed_llm_client::{ - EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody, - PredictEditsResponse, ZED_VERSION_HEADER_NAME, + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, + PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, }; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; @@ -823,6 +823,74 @@ and then another } } + fn accept_edit_prediction( + &mut self, + request_id: InlineCompletionId, + cx: &mut Context, + ) -> Task> { + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let app_version = AppVersion::global(cx); + cx.spawn(async move |this, cx| { + let http_client = client.http_client(); + let mut response = llm_token_retry(&llm_token, &client, |token| { + let request_builder = http_client::Request::builder().method(Method::POST); + let request_builder = + if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") { + request_builder.uri(accept_prediction_url) + } else { + request_builder.uri( + http_client + .build_zed_llm_url("/predict_edits/accept", &[])? + .as_ref(), + ) + }; + Ok(request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) + .body( + serde_json::to_string(&AcceptEditPredictionBody { + request_id: request_id.0, + })? + .into(), + )?) + }) + .await?; + + if let Some(minimum_required_version) = response + .headers() + .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) + .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok()) + { + if app_version < minimum_required_version { + return Err(anyhow!(ZedUpdateRequiredError { + minimum_version: minimum_required_version + })); + } + } + + if response.status().is_success() { + if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { + this.update(cx, |this, cx| { + this.last_usage = Some(usage); + cx.notify(); + })?; + } + + Ok(()) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + Err(anyhow!( + "error accepting edit prediction.\nStatus: {:?}\nBody: {}", + response.status(), + body + )) + } + }) + } + fn process_completion_response( prediction_response: PredictEditsResponse, buffer: Entity, @@ -1381,6 +1449,34 @@ impl ProviderDataCollection { } } +async fn llm_token_retry( + llm_token: &LlmApiToken, + client: &Arc, + build_request: impl Fn(String) -> Result>, +) -> Result> { + let mut did_retry = false; + let http_client = client.http_client(); + let mut token = llm_token.acquire(client).await?; + loop { + let request = build_request(token.clone())?; + let response = http_client.send(request).await?; + + if !did_retry + && !response.status().is_success() + && response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + { + did_retry = true; + token = llm_token.refresh(client).await?; + continue; + } + + return Ok(response); + } +} + pub struct ZetaInlineCompletionProvider { zeta: Entity, pending_completions: ArrayVec, @@ -1597,7 +1693,18 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider // Right now we don't support cycling. } - fn accept(&mut self, _cx: &mut Context) { + fn accept(&mut self, cx: &mut Context) { + let completion_id = self + .current_completion + .as_ref() + .map(|completion| completion.completion.id); + if let Some(completion_id) = completion_id { + self.zeta + .update(cx, |zeta, cx| { + zeta.accept_edit_prediction(completion_id, cx) + }) + .detach(); + } self.pending_completions.clear(); }