From d93141bdedd4967637dd49bad224a0c8162560f1 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 17 Apr 2025 16:11:07 -0400 Subject: [PATCH] agent: Extract usage information from response headers (#29002) This PR updates the Agent to extract the usage information from the response headers, if they are present. For now we just log the information, but we'll be using this soon to populate some UI. Release Notes: - N/A --- Cargo.lock | 2 + crates/agent/Cargo.toml | 1 + crates/agent/src/thread.rs | 13 +++- crates/language_model/Cargo.toml | 1 + crates/language_model/src/language_model.rs | 45 ++++++++++++- crates/language_model/src/rate_limiter.rs | 30 +++++++++ crates/language_models/src/provider/cloud.rs | 71 ++++++++++++++------ 7 files changed, 141 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b5fa56f45b..6866438deb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -125,6 +125,7 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", + "zed_llm_client", ] [[package]] @@ -7654,6 +7655,7 @@ dependencies = [ "thiserror 2.0.12", "util", "workspace-hack", + "zed_llm_client", ] [[package]] diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index ae184a1f38..cd8b9af0ee 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -90,6 +90,7 @@ uuid.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true +zed_llm_client.workspace = true [dev-dependencies] buffer_diff = { workspace = true, features = ["test-support"] } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 62c43b877e..94882f8cbd 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -31,6 +31,7 @@ use settings::Settings; use thiserror::Error; use util::{ResultExt as _, TryFutureExt as _, post_inc}; use uuid::Uuid; +use zed_llm_client::UsageLimit; use crate::context::{AssistantContext, ContextId, format_context_as_string}; use crate::thread_store::{ @@ -1070,14 +1071,22 @@ impl Thread { ) { let pending_completion_id = post_inc(&mut self.completion_count); let task = cx.spawn(async move |thread, cx| { - let stream = model.stream_completion(request, &cx); + let stream_completion_future = model.stream_completion_with_usage(request, &cx); let initial_token_usage = thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); let stream_completion = async { - let mut events = stream.await?; + let (mut events, usage) = stream_completion_future.await?; let mut stop_reason = StopReason::EndTurn; let mut current_token_usage = TokenUsage::default(); + if let Some(usage) = usage { + let limit = match usage.limit { + UsageLimit::Limited(limit) => limit.to_string(), + UsageLimit::Unlimited => "unlimited".to_string(), + }; + log::info!("model request usage: {} / {}", usage.amount, limit); + } + while let Some(event) = events.next().await { let event = event?; diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 4580d9f701..c468ff8297 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -40,6 +40,7 @@ telemetry_events.workspace = true thiserror.workspace = true util.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 35bf5d6094..88115c43fb 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -8,11 +8,12 @@ mod telemetry; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; -use anyhow::Result; +use anyhow::{Result, anyhow}; use client::Client; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; +use http_client::http::{HeaderMap, HeaderValue}; use icons::IconName; use parking_lot::Mutex; use proto::Plan; @@ -20,9 +21,13 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::fmt; use std::ops::{Add, Sub}; +use std::str::FromStr as _; use std::sync::Arc; use thiserror::Error; use util::serde::is_default; +use zed_llm_client::{ + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, +}; pub use crate::model::*; pub use crate::rate_limiter::*; @@ -83,6 +88,28 @@ pub enum StopReason { ToolUse, } +#[derive(Debug, Clone, Copy)] +pub struct RequestUsage { + pub limit: UsageLimit, + pub amount: i32, +} + +impl RequestUsage { + pub fn from_headers(headers: &HeaderMap) -> Result { + let limit = headers + .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME) + .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?; + let limit = UsageLimit::from_str(limit.to_str()?)?; + + let amount = headers + .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME) + .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?; + let amount = amount.to_str()?.parse::()?; + + Ok(Self { limit, amount }) + } +} + #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] pub struct TokenUsage { #[serde(default, skip_serializing_if = "is_default")] @@ -214,6 +241,22 @@ pub trait LanguageModel: Send + Sync { cx: &AsyncApp, ) -> BoxFuture<'static, Result>>>; + fn stream_completion_with_usage( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result<( + BoxStream<'static, Result>, + Option, + )>, + > { + self.stream_completion(request, cx) + .map(|result| result.map(|stream| (stream, None))) + .boxed() + } + fn stream_completion_text( &self, request: LanguageModelRequest, diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model/src/rate_limiter.rs index a48d34488b..7383dd56c9 100644 --- a/crates/language_model/src/rate_limiter.rs +++ b/crates/language_model/src/rate_limiter.rs @@ -8,6 +8,8 @@ use std::{ task::{Context, Poll}, }; +use crate::RequestUsage; + #[derive(Clone)] pub struct RateLimiter { semaphore: Arc, @@ -67,4 +69,32 @@ impl RateLimiter { }) } } + + pub fn stream_with_usage<'a, Fut, T>( + &self, + future: Fut, + ) -> impl 'a + + Future< + Output = Result<( + impl Stream + use, + Option, + )>, + > + where + Fut: 'a + Future)>>, + T: Stream, + { + let guard = self.semaphore.acquire_arc(); + async move { + let guard = guard.await; + let (inner, usage) = future.await?; + Ok(( + RateLimitGuard { + inner, + _guard: guard, + }, + usage, + )) + } + } } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 8286d0a1f2..25a048537b 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -13,7 +13,7 @@ use language_model::{ AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, - LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, + LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ @@ -518,7 +518,7 @@ impl CloudLanguageModel { client: Arc, llm_api_token: LlmApiToken, body: CompletionBody, - ) -> Result> { + ) -> Result<(Response, Option)> { let http_client = &client.http_client(); let mut token = llm_api_token.acquire(&client).await?; @@ -540,7 +540,9 @@ impl CloudLanguageModel { let mut response = http_client.send(request).await?; let status = response.status(); if status.is_success() { - return Ok(response); + let usage = RequestUsage::from_headers(response.headers()).ok(); + + return Ok((response, usage)); } else if response .headers() .get(EXPIRED_LLM_TOKEN_HEADER_NAME) @@ -708,8 +710,24 @@ impl LanguageModel for CloudLanguageModel { fn stream_completion( &self, request: LanguageModelRequest, - _cx: &AsyncApp, + cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { + self.stream_completion_with_usage(request, cx) + .map(|result| result.map(|(stream, _)| stream)) + .boxed() + } + + fn stream_completion_with_usage( + &self, + request: LanguageModelRequest, + _cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result<( + BoxStream<'static, Result>, + Option, + )>, + > { match &self.model { CloudModel::Anthropic(model) => { let request = into_anthropic( @@ -721,8 +739,8 @@ impl LanguageModel for CloudLanguageModel { ); let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let response = Self::perform_llm_completion( + let future = self.request_limiter.stream_with_usage(async move { + let (response, usage) = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -748,20 +766,25 @@ impl LanguageModel for CloudLanguageModel { Err(err) => anyhow!(err), })?; - Ok( + Ok(( crate::provider::anthropic::map_to_language_model_completion_events( Box::pin(response_lines(response).map_err(AnthropicError::Other)), ), - ) + usage, + )) }); - async move { Ok(future.await?.boxed()) }.boxed() + async move { + let (stream, usage) = future.await?; + Ok((stream.boxed(), usage)) + } + .boxed() } CloudModel::OpenAi(model) => { let client = self.client.clone(); let request = into_open_ai(request, model, model.max_output_tokens()); let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let response = Self::perform_llm_completion( + let future = self.request_limiter.stream_with_usage(async move { + let (response, usage) = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -771,20 +794,25 @@ impl LanguageModel for CloudLanguageModel { }, ) .await?; - Ok( + Ok(( crate::provider::open_ai::map_to_language_model_completion_events( Box::pin(response_lines(response)), ), - ) + usage, + )) }); - async move { Ok(future.await?.boxed()) }.boxed() + async move { + let (stream, usage) = future.await?; + Ok((stream.boxed(), usage)) + } + .boxed() } CloudModel::Google(model) => { let client = self.client.clone(); let request = into_google(request, model.id().into()); let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let response = Self::perform_llm_completion( + let future = self.request_limiter.stream_with_usage(async move { + let (response, usage) = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -794,13 +822,18 @@ impl LanguageModel for CloudLanguageModel { }, ) .await?; - Ok( + Ok(( crate::provider::google::map_to_language_model_completion_events(Box::pin( response_lines(response), )), - ) + usage, + )) }); - async move { Ok(future.await?.boxed()) }.boxed() + async move { + let (stream, usage) = future.await?; + Ok((stream.boxed(), usage)) + } + .boxed() } } }