agent: Extract usage information from response headers (#29002)
This PR updates the Agent to extract usage information from the response headers when they are present. For now the information is only logged, but it will soon be used to populate some UI.

Release Notes:

- N/A
This commit is contained in:
parent
b402007de6
commit
d93141bded
7 changed files with 141 additions and 22 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -125,6 +125,7 @@ dependencies = [
|
||||||
"workspace",
|
"workspace",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
"zed_actions",
|
"zed_actions",
|
||||||
|
"zed_llm_client",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -7654,6 +7655,7 @@ dependencies = [
|
||||||
"thiserror 2.0.12",
|
"thiserror 2.0.12",
|
||||||
"util",
|
"util",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
|
"zed_llm_client",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -90,6 +90,7 @@ uuid.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
zed_actions.workspace = true
|
zed_actions.workspace = true
|
||||||
|
zed_llm_client.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
buffer_diff = { workspace = true, features = ["test-support"] }
|
buffer_diff = { workspace = true, features = ["test-support"] }
|
||||||
|
|
|
@ -31,6 +31,7 @@ use settings::Settings;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use util::{ResultExt as _, TryFutureExt as _, post_inc};
|
use util::{ResultExt as _, TryFutureExt as _, post_inc};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
use zed_llm_client::UsageLimit;
|
||||||
|
|
||||||
use crate::context::{AssistantContext, ContextId, format_context_as_string};
|
use crate::context::{AssistantContext, ContextId, format_context_as_string};
|
||||||
use crate::thread_store::{
|
use crate::thread_store::{
|
||||||
|
@ -1070,14 +1071,22 @@ impl Thread {
|
||||||
) {
|
) {
|
||||||
let pending_completion_id = post_inc(&mut self.completion_count);
|
let pending_completion_id = post_inc(&mut self.completion_count);
|
||||||
let task = cx.spawn(async move |thread, cx| {
|
let task = cx.spawn(async move |thread, cx| {
|
||||||
let stream = model.stream_completion(request, &cx);
|
let stream_completion_future = model.stream_completion_with_usage(request, &cx);
|
||||||
let initial_token_usage =
|
let initial_token_usage =
|
||||||
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
|
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
|
||||||
let stream_completion = async {
|
let stream_completion = async {
|
||||||
let mut events = stream.await?;
|
let (mut events, usage) = stream_completion_future.await?;
|
||||||
let mut stop_reason = StopReason::EndTurn;
|
let mut stop_reason = StopReason::EndTurn;
|
||||||
let mut current_token_usage = TokenUsage::default();
|
let mut current_token_usage = TokenUsage::default();
|
||||||
|
|
||||||
|
if let Some(usage) = usage {
|
||||||
|
let limit = match usage.limit {
|
||||||
|
UsageLimit::Limited(limit) => limit.to_string(),
|
||||||
|
UsageLimit::Unlimited => "unlimited".to_string(),
|
||||||
|
};
|
||||||
|
log::info!("model request usage: {} / {}", usage.amount, limit);
|
||||||
|
}
|
||||||
|
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
let event = event?;
|
let event = event?;
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,7 @@ telemetry_events.workspace = true
|
||||||
thiserror.workspace = true
|
thiserror.workspace = true
|
||||||
util.workspace = true
|
util.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
zed_llm_client.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
gpui = { workspace = true, features = ["test-support"] }
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
|
|
|
@ -8,11 +8,12 @@ mod telemetry;
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
pub mod fake_provider;
|
pub mod fake_provider;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::{Result, anyhow};
|
||||||
use client::Client;
|
use client::Client;
|
||||||
use futures::FutureExt;
|
use futures::FutureExt;
|
||||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||||
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
|
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||||
|
use http_client::http::{HeaderMap, HeaderValue};
|
||||||
use icons::IconName;
|
use icons::IconName;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use proto::Plan;
|
use proto::Plan;
|
||||||
|
@ -20,9 +21,13 @@ use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ops::{Add, Sub};
|
use std::ops::{Add, Sub};
|
||||||
|
use std::str::FromStr as _;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use util::serde::is_default;
|
use util::serde::is_default;
|
||||||
|
use zed_llm_client::{
|
||||||
|
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
||||||
|
};
|
||||||
|
|
||||||
pub use crate::model::*;
|
pub use crate::model::*;
|
||||||
pub use crate::rate_limiter::*;
|
pub use crate::rate_limiter::*;
|
||||||
|
@ -83,6 +88,28 @@ pub enum StopReason {
|
||||||
ToolUse,
|
ToolUse,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct RequestUsage {
|
||||||
|
pub limit: UsageLimit,
|
||||||
|
pub amount: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestUsage {
|
||||||
|
pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
|
||||||
|
let limit = headers
|
||||||
|
.get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
|
||||||
|
.ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
|
||||||
|
let limit = UsageLimit::from_str(limit.to_str()?)?;
|
||||||
|
|
||||||
|
let amount = headers
|
||||||
|
.get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
|
||||||
|
.ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
|
||||||
|
let amount = amount.to_str()?.parse::<i32>()?;
|
||||||
|
|
||||||
|
Ok(Self { limit, amount })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
|
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
|
||||||
pub struct TokenUsage {
|
pub struct TokenUsage {
|
||||||
#[serde(default, skip_serializing_if = "is_default")]
|
#[serde(default, skip_serializing_if = "is_default")]
|
||||||
|
@ -214,6 +241,22 @@ pub trait LanguageModel: Send + Sync {
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
|
||||||
|
|
||||||
|
fn stream_completion_with_usage(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AsyncApp,
|
||||||
|
) -> BoxFuture<
|
||||||
|
'static,
|
||||||
|
Result<(
|
||||||
|
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
|
||||||
|
Option<RequestUsage>,
|
||||||
|
)>,
|
||||||
|
> {
|
||||||
|
self.stream_completion(request, cx)
|
||||||
|
.map(|result| result.map(|stream| (stream, None)))
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
|
||||||
fn stream_completion_text(
|
fn stream_completion_text(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
|
|
|
@ -8,6 +8,8 @@ use std::{
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::RequestUsage;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct RateLimiter {
|
pub struct RateLimiter {
|
||||||
semaphore: Arc<Semaphore>,
|
semaphore: Arc<Semaphore>,
|
||||||
|
@ -67,4 +69,32 @@ impl RateLimiter {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn stream_with_usage<'a, Fut, T>(
|
||||||
|
&self,
|
||||||
|
future: Fut,
|
||||||
|
) -> impl 'a
|
||||||
|
+ Future<
|
||||||
|
Output = Result<(
|
||||||
|
impl Stream<Item = T::Item> + use<Fut, T>,
|
||||||
|
Option<RequestUsage>,
|
||||||
|
)>,
|
||||||
|
>
|
||||||
|
where
|
||||||
|
Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
|
||||||
|
T: Stream,
|
||||||
|
{
|
||||||
|
let guard = self.semaphore.acquire_arc();
|
||||||
|
async move {
|
||||||
|
let guard = guard.await;
|
||||||
|
let (inner, usage) = future.await?;
|
||||||
|
Ok((
|
||||||
|
RateLimitGuard {
|
||||||
|
inner,
|
||||||
|
_guard: guard,
|
||||||
|
},
|
||||||
|
usage,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ use language_model::{
|
||||||
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||||
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
||||||
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
|
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
|
||||||
ZED_CLOUD_PROVIDER_ID,
|
ZED_CLOUD_PROVIDER_ID,
|
||||||
};
|
};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
|
@ -518,7 +518,7 @@ impl CloudLanguageModel {
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
llm_api_token: LlmApiToken,
|
llm_api_token: LlmApiToken,
|
||||||
body: CompletionBody,
|
body: CompletionBody,
|
||||||
) -> Result<Response<AsyncBody>> {
|
) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
|
||||||
let http_client = &client.http_client();
|
let http_client = &client.http_client();
|
||||||
|
|
||||||
let mut token = llm_api_token.acquire(&client).await?;
|
let mut token = llm_api_token.acquire(&client).await?;
|
||||||
|
@ -540,7 +540,9 @@ impl CloudLanguageModel {
|
||||||
let mut response = http_client.send(request).await?;
|
let mut response = http_client.send(request).await?;
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
if status.is_success() {
|
if status.is_success() {
|
||||||
return Ok(response);
|
let usage = RequestUsage::from_headers(response.headers()).ok();
|
||||||
|
|
||||||
|
return Ok((response, usage));
|
||||||
} else if response
|
} else if response
|
||||||
.headers()
|
.headers()
|
||||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||||
|
@ -708,8 +710,24 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
_cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||||
|
self.stream_completion_with_usage(request, cx)
|
||||||
|
.map(|result| result.map(|(stream, _)| stream))
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_completion_with_usage(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
_cx: &AsyncApp,
|
||||||
|
) -> BoxFuture<
|
||||||
|
'static,
|
||||||
|
Result<(
|
||||||
|
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
|
||||||
|
Option<RequestUsage>,
|
||||||
|
)>,
|
||||||
|
> {
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let request = into_anthropic(
|
let request = into_anthropic(
|
||||||
|
@ -721,8 +739,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
);
|
);
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream_with_usage(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let (response, usage) = Self::perform_llm_completion(
|
||||||
client.clone(),
|
client.clone(),
|
||||||
llm_api_token,
|
llm_api_token,
|
||||||
CompletionBody {
|
CompletionBody {
|
||||||
|
@ -748,20 +766,25 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
Err(err) => anyhow!(err),
|
Err(err) => anyhow!(err),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(
|
Ok((
|
||||||
crate::provider::anthropic::map_to_language_model_completion_events(
|
crate::provider::anthropic::map_to_language_model_completion_events(
|
||||||
Box::pin(response_lines(response).map_err(AnthropicError::Other)),
|
Box::pin(response_lines(response).map_err(AnthropicError::Other)),
|
||||||
),
|
),
|
||||||
)
|
usage,
|
||||||
|
))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move {
|
||||||
|
let (stream, usage) = future.await?;
|
||||||
|
Ok((stream.boxed(), usage))
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = into_open_ai(request, model, model.max_output_tokens());
|
let request = into_open_ai(request, model, model.max_output_tokens());
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream_with_usage(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let (response, usage) = Self::perform_llm_completion(
|
||||||
client.clone(),
|
client.clone(),
|
||||||
llm_api_token,
|
llm_api_token,
|
||||||
CompletionBody {
|
CompletionBody {
|
||||||
|
@ -771,20 +794,25 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(
|
Ok((
|
||||||
crate::provider::open_ai::map_to_language_model_completion_events(
|
crate::provider::open_ai::map_to_language_model_completion_events(
|
||||||
Box::pin(response_lines(response)),
|
Box::pin(response_lines(response)),
|
||||||
),
|
),
|
||||||
)
|
usage,
|
||||||
|
))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move {
|
||||||
|
let (stream, usage) = future.await?;
|
||||||
|
Ok((stream.boxed(), usage))
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = into_google(request, model.id().into());
|
let request = into_google(request, model.id().into());
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream_with_usage(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let (response, usage) = Self::perform_llm_completion(
|
||||||
client.clone(),
|
client.clone(),
|
||||||
llm_api_token,
|
llm_api_token,
|
||||||
CompletionBody {
|
CompletionBody {
|
||||||
|
@ -794,13 +822,18 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(
|
Ok((
|
||||||
crate::provider::google::map_to_language_model_completion_events(Box::pin(
|
crate::provider::google::map_to_language_model_completion_events(Box::pin(
|
||||||
response_lines(response),
|
response_lines(response),
|
||||||
)),
|
)),
|
||||||
)
|
usage,
|
||||||
|
))
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move {
|
||||||
|
let (stream, usage) = future.await?;
|
||||||
|
Ok((stream.boxed(), usage))
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue