diff --git a/Cargo.lock b/Cargo.lock
index b5fa56f45b..6866438deb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -125,6 +125,7 @@ dependencies = [
  "workspace",
  "workspace-hack",
  "zed_actions",
+ "zed_llm_client",
 ]
 
 [[package]]
@@ -7654,6 +7655,7 @@ dependencies = [
  "thiserror 2.0.12",
  "util",
  "workspace-hack",
+ "zed_llm_client",
 ]
 
 [[package]]
diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml
index ae184a1f38..cd8b9af0ee 100644
--- a/crates/agent/Cargo.toml
+++ b/crates/agent/Cargo.toml
@@ -90,6 +90,7 @@ uuid.workspace = true
 workspace-hack.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
+zed_llm_client.workspace = true
 
 [dev-dependencies]
 buffer_diff = { workspace = true, features = ["test-support"] }
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 62c43b877e..94882f8cbd 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -31,6 +31,7 @@ use settings::Settings;
 use thiserror::Error;
 use util::{ResultExt as _, TryFutureExt as _, post_inc};
 use uuid::Uuid;
+use zed_llm_client::UsageLimit;
 
 use crate::context::{AssistantContext, ContextId, format_context_as_string};
 use crate::thread_store::{
@@ -1070,14 +1071,22 @@
     ) {
         let pending_completion_id = post_inc(&mut self.completion_count);
         let task = cx.spawn(async move |thread, cx| {
-            let stream = model.stream_completion(request, &cx);
+            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
             let initial_token_usage =
                 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
             let stream_completion = async {
-                let mut events = stream.await?;
+                let (mut events, usage) = stream_completion_future.await?;
                 let mut stop_reason = StopReason::EndTurn;
                 let mut current_token_usage = TokenUsage::default();
 
+                if let Some(usage) = usage {
+                    let limit = match usage.limit {
+                        UsageLimit::Limited(limit) => limit.to_string(),
+                        UsageLimit::Unlimited => "unlimited".to_string(),
+                    };
+                    log::info!("model request usage: {} / {}", usage.amount, limit);
+                }
+
                 while let Some(event) = events.next().await {
                     let event = event?;
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 4580d9f701..c468ff8297 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -40,6 +40,7 @@ telemetry_events.workspace = true
 thiserror.workspace = true
 util.workspace = true
 workspace-hack.workspace = true
+zed_llm_client.workspace = true
 
 [dev-dependencies]
 gpui = { workspace = true, features = ["test-support"] }
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 35bf5d6094..88115c43fb 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -8,11 +8,12 @@ mod telemetry;
 #[cfg(any(test, feature = "test-support"))]
 pub mod fake_provider;
 
-use anyhow::Result;
+use anyhow::{Result, anyhow};
 use client::Client;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
+use http_client::http::{HeaderMap, HeaderValue};
 use icons::IconName;
 use parking_lot::Mutex;
 use proto::Plan;
@@ -20,9 +21,13 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use std::fmt;
 use std::ops::{Add, Sub};
+use std::str::FromStr as _;
 use std::sync::Arc;
 use thiserror::Error;
 use util::serde::is_default;
+use zed_llm_client::{
+    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
+};
 
 pub use crate::model::*;
 pub use crate::rate_limiter::*;
@@ -83,6 +88,28 @@ pub enum StopReason {
     ToolUse,
 }
 
+#[derive(Debug, Clone, Copy)]
+pub struct RequestUsage {
+    pub limit: UsageLimit,
+    pub amount: i32,
+}
+
+impl RequestUsage {
+    pub fn from_headers(headers: &HeaderMap) -> Result<Self> {
+        let limit = headers
+            .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
+            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
+        let limit = UsageLimit::from_str(limit.to_str()?)?;
+
+        let amount = headers
+            .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
+            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
+        let amount = amount.to_str()?.parse::<i32>()?;
+
+        Ok(Self { limit, amount })
+    }
+}
+
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
 pub struct TokenUsage {
     #[serde(default, skip_serializing_if = "is_default")]
@@ -214,6 +241,22 @@
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
 
+    fn stream_completion_with_usage(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AsyncApp,
+    ) -> BoxFuture<
+        'static,
+        Result<(
+            BoxStream<'static, Result<LanguageModelCompletionEvent>>,
+            Option<RequestUsage>,
+        )>,
+    > {
+        self.stream_completion(request, cx)
+            .map(|result| result.map(|stream| (stream, None)))
+            .boxed()
+    }
+
     fn stream_completion_text(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model/src/rate_limiter.rs
index a48d34488b..7383dd56c9 100644
--- a/crates/language_model/src/rate_limiter.rs
+++ b/crates/language_model/src/rate_limiter.rs
@@ -8,6 +8,8 @@ use std::{
     task::{Context, Poll},
 };
 
+use crate::RequestUsage;
+
 #[derive(Clone)]
 pub struct RateLimiter {
     semaphore: Arc<Semaphore>,
 }
@@ -67,4 +69,32 @@
             })
         }
     }
+
+    pub fn stream_with_usage<'a, Fut, T>(
+        &self,
+        future: Fut,
+    ) -> impl 'a
+    + Future<
+        Output = Result<(
+            impl Stream<Item = T::Item> + use<Fut, T>,
+            Option<RequestUsage>,
+        )>,
+    >
+    where
+        Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
+        T: Stream,
+    {
+        let guard = self.semaphore.acquire_arc();
+        async move {
+            let guard = guard.await;
+            let (inner, usage) = future.await?;
+            Ok((
+                RateLimitGuard {
+                    inner,
+                    _guard: guard,
+                },
+                usage,
+            ))
+        }
+    }
 }
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 8286d0a1f2..25a048537b 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -13,7 +13,7 @@ use language_model::{
     AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
     LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
-    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
+    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
     ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
@@ -518,7 +518,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<Response<AsyncBody>> {
+    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
@@ -540,7 +540,9 @@
         let mut response = http_client.send(request).await?;
         let status = response.status();
         if status.is_success() {
-            return Ok(response);
+            let usage = RequestUsage::from_headers(response.headers()).ok();
+
+            return Ok((response, usage));
         } else if response
             .headers()
             .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
@@ -708,8 +710,24 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        _cx: &AsyncApp,
+        cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+        self.stream_completion_with_usage(request, cx)
+            .map(|result| result.map(|(stream, _)| stream))
+            .boxed()
+    }
+
+    fn stream_completion_with_usage(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AsyncApp,
+    ) -> BoxFuture<
+        'static,
+        Result<(
+            BoxStream<'static, Result<LanguageModelCompletionEvent>>,
+            Option<RequestUsage>,
+        )>,
+    > {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = into_anthropic(
@@ -721,8 +739,8 @@
                 );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let response = Self::perform_llm_completion(
+                let future = self.request_limiter.stream_with_usage(async move {
+                    let (response, usage) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -748,20 +766,25 @@
                         Err(err) => anyhow!(err),
                     })?;
 
-                    Ok(
+                    Ok((
                         crate::provider::anthropic::map_to_language_model_completion_events(
                             Box::pin(response_lines(response).map_err(AnthropicError::Other)),
                         ),
-                    )
+                        usage,
+                    ))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    let (stream, usage) = future.await?;
+                    Ok((stream.boxed(), usage))
+                }
+                .boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let response = Self::perform_llm_completion(
+                let future = self.request_limiter.stream_with_usage(async move {
+                    let (response, usage) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -771,20 +794,25 @@
                         },
                     )
                     .await?;
-                    Ok(
+                    Ok((
                         crate::provider::open_ai::map_to_language_model_completion_events(
                             Box::pin(response_lines(response)),
                         ),
-                    )
+                        usage,
+                    ))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    let (stream, usage) = future.await?;
+                    Ok((stream.boxed(), usage))
+                }
+                .boxed()
            }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let response = Self::perform_llm_completion(
+                let future = self.request_limiter.stream_with_usage(async move {
+                    let (response, usage) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -794,13 +822,18 @@
                         },
                     )
                     .await?;
-                    Ok(
+                    Ok((
                         crate::provider::google::map_to_language_model_completion_events(Box::pin(
                             response_lines(response),
                         )),
-                    )
+                        usage,
+                    ))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    let (stream, usage) = future.await?;
+                    Ok((stream.boxed(), usage))
+                }
+                .boxed()
             }
         }
     }
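The net effect: `stream_completion_with_usage` is a new trait method whose default implementation delegates to `stream_completion` and reports `None` for usage, so every existing provider stays source-compatible; only `CloudLanguageModel` overrides it, threading the usage headers parsed by `RequestUsage::from_headers` out of `perform_llm_completion`. Below is a minimal sketch of how a caller consumes the new method, mirroring the `thread.rs` change above; the helper name and surrounding bindings are illustrative, not part of this diff.

```rust
// Sketch only: consuming `stream_completion_with_usage` the way thread.rs does.
use futures::StreamExt as _;
use language_model::{LanguageModel, LanguageModelRequest};
use zed_llm_client::UsageLimit;

// Hypothetical helper; `model`, `request`, and `cx` come from the caller.
async fn stream_and_log_usage(
    model: &dyn LanguageModel,
    request: LanguageModelRequest,
    cx: &gpui::AsyncApp,
) -> anyhow::Result<()> {
    // Providers that don't override the default trait method yield `usage = None`.
    let (mut events, usage) = model.stream_completion_with_usage(request, cx).await?;

    if let Some(usage) = usage {
        let limit = match usage.limit {
            UsageLimit::Limited(limit) => limit.to_string(),
            UsageLimit::Unlimited => "unlimited".to_string(),
        };
        log::info!("model request usage: {} / {}", usage.amount, limit);
    }

    // Event handling is unchanged from `stream_completion`.
    while let Some(event) = events.next().await {
        let _event = event?;
    }

    Ok(())
}
```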