From be63d51eb7b187579aba34ec4ccb07a8660eb3e7 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 17 Apr 2025 15:07:40 -0400 Subject: [PATCH] zeta: Extract usage information from response headers (#28999) This PR updates the Zeta provider to extract the usage information from the response headers, if they are present. For now we just log the information, but we'll need to figure out where this needs to get threaded through to in order to display it in the UI. Release Notes: - N/A --- Cargo.lock | 5 +++-- Cargo.toml | 2 +- crates/zeta/src/zeta.rs | 50 ++++++++++++++++++++++++++++++++++++----- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 015f7888fd..b5fa56f45b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18361,10 +18361,11 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ee4d410dbc030c3e6e3af78fc76296f6bebe20dcb6d7d3fa24bca306fc8c1ce" +checksum = "b91b8b05f1028157205026e525869eb860fa89bec87ea60b445efc91d05df31f" dependencies = [ + "anyhow", "serde", "serde_json", "strum 0.27.1", diff --git a/Cargo.toml b/Cargo.toml index 80aae32998..106f023231 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -604,7 +604,7 @@ wasmtime-wasi = "29" which = "6.0.0" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "0.5.1" +zed_llm_client = "0.6.0" zstd = "0.11" metal = "0.29" diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 14aa2820cb..b5367816fd 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,6 +8,7 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; +use http_client::http::{HeaderMap, HeaderValue}; pub use init::*; use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; @@ -54,8 +55,9 @@ use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use worktree::Worktree; use zed_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody, - PredictEditsResponse, + PredictEditsResponse, UsageLimit, }; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; @@ -74,6 +76,32 @@ const MAX_EVENT_COUNT: usize = 16; actions!(edit_prediction, [ClearHistory]); +#[derive(Debug, Clone, Copy)] +pub struct Usage { + pub limit: UsageLimit, + pub amount: i32, +} + +impl Usage { + pub fn from_headers(headers: &HeaderMap) -> Result { + let limit = headers + .get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME) + .ok_or_else(|| { + anyhow!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header") + })?; + let limit = UsageLimit::from_str(limit.to_str()?)?; + + let amount = headers + .get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME) + .ok_or_else(|| { + anyhow!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header") + })?; + let amount = amount.to_str()?.parse::()?; + + Ok(Self { limit, amount }) + } +} + #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] pub struct InlineCompletionId(Uuid); @@ -359,7 +387,7 @@ impl Zeta { ) -> Task>> where F: FnOnce(PerformPredictEditsParams) -> R + 'static, - R: Future> + Send + 'static, + R: Future)>> + Send + 'static, { let snapshot = self.report_changes_for_buffer(&buffer, cx); let diagnostic_groups = snapshot.diagnostic_groups(None); @@ -467,7 +495,7 @@ impl Zeta { body, }) .await; - let response = match response { + let (response, usage) = match response { Ok(response) => response, Err(err) => { if err.is::() { @@ -503,6 +531,14 @@ impl Zeta { log::debug!("completion response: {}", &response.output_excerpt); + if let Some(usage) = usage { + let limit = match usage.limit { + UsageLimit::Limited(limit) => limit.to_string(), + UsageLimit::Unlimited => "unlimited".to_string(), + }; + log::info!("edit prediction usage: {} / {}", usage.amount, limit); + } + Self::process_completion_response( response, buffer, @@ -685,7 +721,7 @@ and then another use std::future::ready; self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { - ready(Ok(response)) + ready(Ok((response, None))) }) } @@ -714,7 +750,7 @@ and then another fn perform_predict_edits( params: PerformPredictEditsParams, - ) -> impl Future> { + ) -> impl Future)>> { async move { let PerformPredictEditsParams { client, @@ -760,9 +796,11 @@ and then another } if response.status().is_success() { + let usage = Usage::from_headers(response.headers()).ok(); + let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; - return Ok(serde_json::from_str(&body)?); + return Ok((serde_json::from_str(&body)?, usage)); } else if !did_retry && response .headers()