agent: Extract usage information from response headers (#29002)

This PR updates the Agent to extract the usage information from the
response headers, if they are present.

For now we just log the information, but we'll be using this soon to
populate some UI.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-17 16:11:07 -04:00 committed by GitHub
parent b402007de6
commit d93141bded
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 141 additions and 22 deletions

2
Cargo.lock generated
View file

@ -125,6 +125,7 @@ dependencies = [
"workspace",
"workspace-hack",
"zed_actions",
"zed_llm_client",
]
[[package]]
@ -7654,6 +7655,7 @@ dependencies = [
"thiserror 2.0.12",
"util",
"workspace-hack",
"zed_llm_client",
]
[[package]]

View file

@ -90,6 +90,7 @@ uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
buffer_diff = { workspace = true, features = ["test-support"] }

View file

@ -31,6 +31,7 @@ use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::UsageLimit;
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
@ -1070,14 +1071,22 @@ impl Thread {
) {
let pending_completion_id = post_inc(&mut self.completion_count);
let task = cx.spawn(async move |thread, cx| {
let stream = model.stream_completion(request, &cx);
let stream_completion_future = model.stream_completion_with_usage(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
let mut events = stream.await?;
let (mut events, usage) = stream_completion_future.await?;
let mut stop_reason = StopReason::EndTurn;
let mut current_token_usage = TokenUsage::default();
if let Some(usage) = usage {
let limit = match usage.limit {
UsageLimit::Limited(limit) => limit.to_string(),
UsageLimit::Unlimited => "unlimited".to_string(),
};
log::info!("model request usage: {} / {}", usage.amount, limit);
}
while let Some(event) = events.next().await {
let event = event?;

View file

@ -40,6 +40,7 @@ telemetry_events.workspace = true
thiserror.workspace = true
util.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View file

@ -8,11 +8,12 @@ mod telemetry;
#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;
use anyhow::Result;
use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http::{HeaderMap, HeaderValue};
use icons::IconName;
use parking_lot::Mutex;
use proto::Plan;
@ -20,9 +21,13 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt;
use std::ops::{Add, Sub};
use std::str::FromStr as _;
use std::sync::Arc;
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
pub use crate::model::*;
pub use crate::rate_limiter::*;
@ -83,6 +88,28 @@ pub enum StopReason {
ToolUse,
}
#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
pub limit: UsageLimit,
pub amount: i32,
}
impl RequestUsage {
pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
let limit = headers
.get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
.ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
let limit = UsageLimit::from_str(limit.to_str()?)?;
let amount = headers
.get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
.ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
let amount = amount.to_str()?.parse::<i32>()?;
Ok(Self { limit, amount })
}
}
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
#[serde(default, skip_serializing_if = "is_default")]
@ -214,6 +241,22 @@ pub trait LanguageModel: Send + Sync {
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
fn stream_completion_with_usage(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
Option<RequestUsage>,
)>,
> {
self.stream_completion(request, cx)
.map(|result| result.map(|stream| (stream, None)))
.boxed()
}
fn stream_completion_text(
&self,
request: LanguageModelRequest,

View file

@ -8,6 +8,8 @@ use std::{
task::{Context, Poll},
};
use crate::RequestUsage;
#[derive(Clone)]
pub struct RateLimiter {
semaphore: Arc<Semaphore>,
@ -67,4 +69,32 @@ impl RateLimiter {
})
}
}
pub fn stream_with_usage<'a, Fut, T>(
&self,
future: Fut,
) -> impl 'a
+ Future<
Output = Result<(
impl Stream<Item = T::Item> + use<Fut, T>,
Option<RequestUsage>,
)>,
>
where
Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
T: Stream,
{
let guard = self.semaphore.acquire_arc();
async move {
let guard = guard.await;
let (inner, usage) = future.await?;
Ok((
RateLimitGuard {
inner,
_guard: guard,
},
usage,
))
}
}
}

View file

@ -13,7 +13,7 @@ use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
@ -518,7 +518,7 @@ impl CloudLanguageModel {
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: CompletionBody,
) -> Result<Response<AsyncBody>> {
) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
@ -540,7 +540,9 @@ impl CloudLanguageModel {
let mut response = http_client.send(request).await?;
let status = response.status();
if status.is_success() {
return Ok(response);
let usage = RequestUsage::from_headers(response.headers()).ok();
return Ok((response, usage));
} else if response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
@ -708,8 +710,24 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
_cx: &AsyncApp,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
self.stream_completion_with_usage(request, cx)
.map(|result| result.map(|(stream, _)| stream))
.boxed()
}
fn stream_completion_with_usage(
&self,
request: LanguageModelRequest,
_cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
Option<RequestUsage>,
)>,
> {
match &self.model {
CloudModel::Anthropic(model) => {
let request = into_anthropic(
@ -721,8 +739,8 @@ impl LanguageModel for CloudLanguageModel {
);
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
let future = self.request_limiter.stream_with_usage(async move {
let (response, usage) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@ -748,20 +766,25 @@ impl LanguageModel for CloudLanguageModel {
Err(err) => anyhow!(err),
})?;
Ok(
Ok((
crate::provider::anthropic::map_to_language_model_completion_events(
Box::pin(response_lines(response).map_err(AnthropicError::Other)),
),
)
usage,
))
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
let (stream, usage) = future.await?;
Ok((stream.boxed(), usage))
}
.boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = into_open_ai(request, model, model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
let future = self.request_limiter.stream_with_usage(async move {
let (response, usage) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@ -771,20 +794,25 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
Ok(
Ok((
crate::provider::open_ai::map_to_language_model_completion_events(
Box::pin(response_lines(response)),
),
)
usage,
))
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
let (stream, usage) = future.await?;
Ok((stream.boxed(), usage))
}
.boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = into_google(request, model.id().into());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
let future = self.request_limiter.stream_with_usage(async move {
let (response, usage) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@ -794,13 +822,18 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
Ok(
Ok((
crate::provider::google::map_to_language_model_completion_events(Box::pin(
response_lines(response),
)),
)
usage,
))
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
let (stream, usage) = future.await?;
Ok((stream.boxed(), usage))
}
.boxed()
}
}
}