diff --git a/Cargo.lock b/Cargo.lock index 015f7888fd..b5fa56f45b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18361,10 +18361,11 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ee4d410dbc030c3e6e3af78fc76296f6bebe20dcb6d7d3fa24bca306fc8c1ce" +checksum = "b91b8b05f1028157205026e525869eb860fa89bec87ea60b445efc91d05df31f" dependencies = [ + "anyhow", "serde", "serde_json", "strum 0.27.1", diff --git a/Cargo.toml b/Cargo.toml index 80aae32998..106f023231 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -604,7 +604,7 @@ wasmtime-wasi = "29" which = "6.0.0" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "0.5.1" +zed_llm_client = "0.6.0" zstd = "0.11" metal = "0.29" diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 14aa2820cb..b5367816fd 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,6 +8,7 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; +use http_client::http::{HeaderMap, HeaderValue}; pub use init::*; use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; @@ -54,8 +55,9 @@ use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use worktree::Worktree; use zed_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody, - PredictEditsResponse, + PredictEditsResponse, UsageLimit, }; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; @@ -74,6 +76,32 @@ const MAX_EVENT_COUNT: usize = 16; actions!(edit_prediction, [ClearHistory]); +#[derive(Debug, Clone, Copy)] +pub struct Usage { + pub limit: UsageLimit, + pub amount: i32, +} + +impl Usage { + pub fn from_headers(headers: &HeaderMap) -> Result { + let limit = headers + .get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME) + .ok_or_else(|| { + anyhow!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header") + })?; + let limit = UsageLimit::from_str(limit.to_str()?)?; + + let amount = headers + .get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME) + .ok_or_else(|| { + anyhow!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header") + })?; + let amount = amount.to_str()?.parse::()?; + + Ok(Self { limit, amount }) + } +} + #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] pub struct InlineCompletionId(Uuid); @@ -359,7 +387,7 @@ impl Zeta { ) -> Task>> where F: FnOnce(PerformPredictEditsParams) -> R + 'static, - R: Future> + Send + 'static, + R: Future)>> + Send + 'static, { let snapshot = self.report_changes_for_buffer(&buffer, cx); let diagnostic_groups = snapshot.diagnostic_groups(None); @@ -467,7 +495,7 @@ impl Zeta { body, }) .await; - let response = match response { + let (response, usage) = match response { Ok(response) => response, Err(err) => { if err.is::() { @@ -503,6 +531,14 @@ impl Zeta { log::debug!("completion response: {}", &response.output_excerpt); + if let Some(usage) = usage { + let limit = match usage.limit { + UsageLimit::Limited(limit) => limit.to_string(), + UsageLimit::Unlimited => "unlimited".to_string(), + }; + log::info!("edit prediction usage: {} / {}", usage.amount, limit); + } + Self::process_completion_response( response, buffer, @@ -685,7 +721,7 @@ and then another use std::future::ready; self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { - ready(Ok(response)) + ready(Ok((response, None))) }) } @@ -714,7 +750,7 @@ and then another fn perform_predict_edits( params: PerformPredictEditsParams, - ) -> impl Future> { + ) -> impl Future)>> { async move { let PerformPredictEditsParams { client, @@ -760,9 +796,11 @@ and then another } if response.status().is_success() { + let usage = Usage::from_headers(response.headers()).ok(); + let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; - return Ok(serde_json::from_str(&body)?); + return Ok((serde_json::from_str(&body)?, usage)); } else if !did_retry && response .headers()