From c3d9cdecab264849c8e4da950e20a997786b1666 Mon Sep 17 00:00:00 2001
From: Max Brunsfeld
Date: Sun, 4 May 2025 10:37:42 -0700
Subject: [PATCH] Change cloud language model provider JSON protocol to
 surface errors and usage information (#29830)

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo
Co-authored-by: Marshall Bowers
---
 Cargo.lock                                    |   4 +-
 Cargo.toml                                    |   2 +-
 crates/agent/src/thread.rs                    |  70 ++++----
 .../assistant_context_editor/src/context.rs   |   6 +-
 crates/eval/src/instance.rs                   |   4 +-
 crates/language_model/src/language_model.rs   |  58 ++-----
 crates/language_model/src/rate_limiter.rs     |  30 ----
 crates/language_models/src/provider/cloud.rs  | 151 ++++++++----------
 8 files changed, 128 insertions(+), 197 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index d3cbc2afb0..80aa1bec3d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -18825,9 +18825,9 @@ dependencies = [

 [[package]]
 name = "zed_llm_client"
-version = "0.7.2"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "226e0b479b3aed072d83db276866d54bce631e3a8600fcdf4f309d73389af9c7"
+checksum = "2adf9bc80def4ec93c190f06eb78111865edc2576019a9753eaef6fd7bc3b72c"
 dependencies = [
  "anyhow",
  "serde",
diff --git a/Cargo.toml b/Cargo.toml
index df699ef5a9..b3705be27d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -611,7 +611,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.7.2"
+zed_llm_client = "0.7.4"
 zstd = "0.11"

 [workspace.dependencies.async-stripe]
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 02fee57a47..e819dd9000 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -37,7 +37,7 @@ use settings::Settings;
 use thiserror::Error;
 use util::{ResultExt as _, TryFutureExt as _, post_inc};
 use uuid::Uuid;
-use zed_llm_client::CompletionMode;
+use zed_llm_client::{CompletionMode, CompletionRequestStatus};

 use crate::ThreadStore;
 use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
@@ -1356,20 +1356,17 @@ impl Thread {
         self.last_received_chunk_at = Some(Instant::now());

         let task = cx.spawn(async move |thread, cx| {
-            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
+            let stream_completion_future = model.stream_completion(request, &cx);
             let initial_token_usage =
                 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
             let stream_completion = async {
-                let (mut events, usage) = stream_completion_future.await?;
+                let mut events = stream_completion_future.await?;
                 let mut stop_reason = StopReason::EndTurn;
                 let mut current_token_usage = TokenUsage::default();

                 thread
                     .update(cx, |_thread, cx| {
-                        if let Some(usage) = usage {
-                            cx.emit(ThreadEvent::UsageUpdated(usage));
-                        }
                         cx.emit(ThreadEvent::NewRequest);
                     })
                     .ok();
@@ -1515,27 +1512,34 @@ impl Thread {
                             });
                         }
                     }
-                    LanguageModelCompletionEvent::QueueUpdate(status) => {
+                    LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                         if let Some(completion) = thread
                             .pending_completions
                             .iter_mut()
                             .find(|completion| completion.id == pending_completion_id)
                         {
-                            let queue_state = match status {
-                                language_model::CompletionRequestStatus::Queued {
+                            match status_update {
+                                CompletionRequestStatus::Queued {
                                     position,
-                                } => Some(QueueState::Queued { position }),
-                                language_model::CompletionRequestStatus::Started => {
-                                    Some(QueueState::Started)
+                                } => {
+                                    completion.queue_state = QueueState::Queued { position };
                                 }
-                                language_model::CompletionRequestStatus::ToolUseLimitReached => {
+                                CompletionRequestStatus::Started => {
+                                    completion.queue_state = QueueState::Started;
+                                }
+                                CompletionRequestStatus::Failed {
+                                    code, message
+                                } => {
+                                    return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
+                                }
+                                CompletionRequestStatus::UsageUpdated {
+                                    amount, limit
+                                } => {
+                                    cx.emit(ThreadEvent::UsageUpdated(RequestUsage { limit, amount: amount as i32 }));
+                                }
+                                CompletionRequestStatus::ToolUseLimitReached => {
                                     thread.tool_use_limit_reached = true;
-                                    None
                                 }
-                            };
-
-                            if let Some(queue_state) = queue_state {
-                                completion.queue_state = queue_state;
                             }
                         }
                     }
@@ -1690,19 +1694,27 @@ impl Thread {

         self.pending_summary = cx.spawn(async move |this, cx| {
             async move {
-                let stream = model.model.stream_completion_text_with_usage(request, &cx);
-                let (mut messages, usage) = stream.await?;
-
-                if let Some(usage) = usage {
-                    this.update(cx, |_thread, cx| {
-                        cx.emit(ThreadEvent::UsageUpdated(usage));
-                    })
-                    .ok();
-                }
+                let mut messages = model.model.stream_completion(request, &cx).await?;

                 let mut new_summary = String::new();
-                while let Some(message) = messages.stream.next().await {
-                    let text = message?;
+                while let Some(event) = messages.next().await {
+                    let event = event?;
+                    let text = match event {
+                        LanguageModelCompletionEvent::Text(text) => text,
+                        LanguageModelCompletionEvent::StatusUpdate(
+                            CompletionRequestStatus::UsageUpdated { amount, limit },
+                        ) => {
+                            this.update(cx, |_, cx| {
+                                cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
+                                    limit,
+                                    amount: amount as i32,
+                                }));
+                            })?;
+                            continue;
+                        }
+                        _ => continue,
+                    };
+
                     let mut lines = text.lines();
                     new_summary.extend(lines.next());
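The thread now reacts to five status variants instead of three, and two of them (`Failed`, `UsageUpdated`) carry payloads. A minimal sketch of what these messages could look like on the wire, assuming the `status`-tagged snake_case serde layout of the local enum this patch deletes from language_model.rs; the actual zed_llm_client 0.7.4 definition may differ, and `UsageUpdated` is elided here for that reason:

```rust
use serde::Deserialize;

// Stand-in mirror of zed_llm_client::CompletionRequestStatus (an assumption,
// not the real definition; UsageUpdated is omitted because the shape of its
// UsageLimit payload is not shown in this patch).
#[derive(Debug, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum StatusMessage {
    Queued { position: usize },
    Started,
    Failed { code: String, message: String },
    ToolUseLimitReached,
}

fn main() {
    for line in [
        r#"{"status":"queued","position":3}"#,
        r#"{"status":"started"}"#,
        r#"{"status":"failed","code":"overloaded","message":"upstream at capacity"}"#,
        r#"{"status":"tool_use_limit_reached"}"#,
    ] {
        // Each newline-delimited JSON line decodes to exactly one variant.
        println!("{:?}", serde_json::from_str::<StatusMessage>(line).unwrap());
    }
}
```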
diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs
index 6beeaf3461..fe607fb291 100644
--- a/crates/assistant_context_editor/src/context.rs
+++ b/crates/assistant_context_editor/src/context.rs
@@ -2371,7 +2371,7 @@ impl AssistantContext {
                     });

                     match event {
-                        LanguageModelCompletionEvent::QueueUpdate { .. } => {}
+                        LanguageModelCompletionEvent::StatusUpdate { .. } => {}
                         LanguageModelCompletionEvent::StartMessage { .. } => {}
                         LanguageModelCompletionEvent::Stop(reason) => {
                             stop_reason = reason;
@@ -2429,8 +2429,8 @@ impl AssistantContext {
                                 cx,
                             );
                         }
-                        LanguageModelCompletionEvent::ToolUse(_) => {}
-                        LanguageModelCompletionEvent::UsageUpdate(_) => {}
+                        LanguageModelCompletionEvent::ToolUse(_) |
+                        LanguageModelCompletionEvent::UsageUpdate(_) => {}
                     }
                 });
diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs
index d3c5fdb29c..f9c9b72e30 100644
--- a/crates/eval/src/instance.rs
+++ b/crates/eval/src/instance.rs
@@ -1018,7 +1018,7 @@ pub fn response_events_to_markdown(
             Ok(
                 LanguageModelCompletionEvent::UsageUpdate(_)
                 | LanguageModelCompletionEvent::StartMessage { .. }
-                | LanguageModelCompletionEvent::QueueUpdate { .. },
+                | LanguageModelCompletionEvent::StatusUpdate { .. },
             ) => {}
             Err(error) => {
                 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
@@ -1093,7 +1093,7 @@ impl ThreadDialog {

                         // Skip these
                         Ok(LanguageModelCompletionEvent::UsageUpdate(_))
-                        | Ok(LanguageModelCompletionEvent::QueueUpdate { .. })
+                        | Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
                        | Ok(LanguageModelCompletionEvent::StartMessage { .. })
                        | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
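Both of these consumers treat the renamed `StatusUpdate` events as no-ops: status traffic is protocol bookkeeping, not content. A sketch of that filtering discipline over a stand-in event type (not the real `LanguageModelCompletionEvent`):

```rust
#[derive(Debug)]
enum Event {
    StatusUpdate(String),
    Text(String),
    Stop,
}

// Collect only the user-visible text, dropping protocol bookkeeping.
fn visible_text(events: Vec<Event>) -> String {
    let mut out = String::new();
    for event in events {
        match event {
            Event::Text(text) => out.push_str(&text),
            // Status and stop events carry no renderable content.
            Event::StatusUpdate(_) | Event::Stop => {}
        }
    }
    out
}

fn main() {
    let events = vec![
        Event::StatusUpdate("queued".into()),
        Event::Text("Hello, ".into()),
        Event::StatusUpdate("started".into()),
        Event::Text("world".into()),
        Event::Stop,
    ];
    assert_eq!(visible_text(events), "Hello, world");
}
```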
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 7ff3fc86ec..1d73b7334b 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -26,7 +26,8 @@ use std::sync::Arc;
 use thiserror::Error;
 use util::serde::is_default;
 use zed_llm_client::{
-    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
+    CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
+    MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
 };

 pub use crate::model::*;
@@ -64,18 +65,10 @@ pub struct LanguageModelCacheConfiguration {
     pub min_total_token: usize,
 }

-#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
-#[serde(tag = "status", rename_all = "snake_case")]
-pub enum CompletionRequestStatus {
-    Queued { position: usize },
-    Started,
-    ToolUseLimitReached,
-}
-
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
-    QueueUpdate(CompletionRequestStatus),
+    StatusUpdate(CompletionRequestStatus),
     Stop(StopReason),
     Text(String),
     Thinking {
@@ -299,41 +292,15 @@ pub trait LanguageModel: Send + Sync {
         >,
     >;

-    fn stream_completion_with_usage(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<(
-            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
-            Option<RequestUsage>,
-        )>,
-    > {
-        self.stream_completion(request, cx)
-            .map(|result| result.map(|stream| (stream, None)))
-            .boxed()
-    }
-
     fn stream_completion_text(
         &self,
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
-        self.stream_completion_text_with_usage(request, cx)
-            .map(|result| result.map(|(stream, _usage)| stream))
-            .boxed()
-    }
-
-    fn stream_completion_text_with_usage(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<(LanguageModelTextStream, Option<RequestUsage>)>> {
-        let future = self.stream_completion_with_usage(request, cx);
+        let future = self.stream_completion(request, cx);

         async move {
-            let (events, usage) = future.await?;
+            let events = future.await?;
             let mut events = events.fuse();
             let mut message_id = None;
             let mut first_item_text = None;
@@ -358,7 +325,7 @@ pub trait LanguageModel: Send + Sync {
                     let last_token_usage = last_token_usage.clone();
                     async move {
                         match result {
-                            Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None,
+                            Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                             Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                             Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                             Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
@@ -375,14 +342,11 @@ pub trait LanguageModel: Send + Sync {
                 }))
                 .boxed();

-            Ok((
-                LanguageModelTextStream {
-                    message_id,
-                    stream,
-                    last_token_usage,
-                },
-                usage,
-            ))
+            Ok(LanguageModelTextStream {
+                message_id,
+                stream,
+                last_token_usage,
+            })
         }
         .boxed()
     }
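With `stream_completion_with_usage` gone from the trait, usage reporting rides the one remaining event stream. A sketch of draining such a stream into text while capturing usage on the side, using stand-in types rather than the real trait:

```rust
use futures::{StreamExt, executor::block_on, stream};

#[derive(Debug)]
enum Event {
    // Stand-ins for LanguageModelCompletionEvent::Text and
    // StatusUpdate(CompletionRequestStatus::UsageUpdated { .. }).
    Text(String),
    UsageUpdated { amount: usize, limit: Option<usize> },
}

fn main() {
    block_on(async {
        let mut events = stream::iter([
            Event::Text("Hello".into()),
            Event::UsageUpdated { amount: 42, limit: Some(500) },
            Event::Text(", world".into()),
        ]);
        let mut text = String::new();
        let mut usage = None;
        while let Some(event) = events.next().await {
            match event {
                Event::Text(chunk) => text.push_str(&chunk),
                // Usage is surfaced out of band, never as visible text.
                Event::UsageUpdated { amount, limit } => usage = Some((amount, limit)),
            }
        }
        assert_eq!(text, "Hello, world");
        println!("usage: {usage:?}");
    });
}
```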
diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model/src/rate_limiter.rs
index 7383dd56c9..a48d34488b 100644
--- a/crates/language_model/src/rate_limiter.rs
+++ b/crates/language_model/src/rate_limiter.rs
@@ -8,8 +8,6 @@ use std::{
     task::{Context, Poll},
 };

-use crate::RequestUsage;
-
 #[derive(Clone)]
 pub struct RateLimiter {
     semaphore: Arc<Semaphore>,
@@ -69,32 +67,4 @@ impl RateLimiter {
         })
     }
-
-    pub fn stream_with_usage<'a, Fut, T>(
-        &self,
-        future: Fut,
-    ) -> impl 'a
-    + Future<
-        Output = Result<(
-            impl Stream<Item = T::Item> + use<'a, Fut, T>,
-            Option<RequestUsage>,
-        )>,
-    >
-    where
-        Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
-        T: Stream,
-    {
-        let guard = self.semaphore.acquire_arc();
-        async move {
-            let guard = guard.await;
-            let (inner, usage) = future.await?;
-            Ok((
-                RateLimitGuard {
-                    inner,
-                    _guard: guard,
-                },
-                usage,
-            ))
-        }
-    }
 }
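The surviving `RateLimiter::stream` is all that is needed once usage stops flowing through the limiter: acquire a semaphore permit, build the stream, and keep the permit alive for the stream's lifetime. A self-contained sketch of that shape using async-lock (the real implementation wraps the inner stream in a `RateLimitGuard` type rather than a `map` closure):

```rust
use std::{future::Future, sync::Arc};

use async_lock::Semaphore;
use futures::{Stream, StreamExt, executor::block_on, stream};

// Acquire a permit, then build the inner stream. The permit is moved into
// the map closure, so it is released only when the stream is dropped, not
// when this function returns.
async fn rate_limited<S: Stream>(
    semaphore: Arc<Semaphore>,
    make_stream: impl Future<Output = S>,
) -> impl Stream<Item = S::Item> {
    let permit = semaphore.acquire_arc().await;
    let inner = make_stream.await;
    inner.map(move |item| {
        let _permit = &permit;
        item
    })
}

fn main() {
    block_on(async {
        let semaphore = Arc::new(Semaphore::new(4));
        let stream = rate_limited(semaphore, async { stream::iter(0..3) }).await;
        let items: Vec<i32> = stream.collect().await;
        assert_eq!(items, vec![0, 1, 2]);
    });
}
```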
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 68c494234e..75ea219fa5 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -9,12 +9,11 @@ use futures::{
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
-    AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
-    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
-    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
-    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
-    ZED_CLOUD_PROVIDER_ID,
+    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
+    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -36,9 +35,10 @@ use strum::IntoEnumIterator;
 use thiserror::Error;
 use ui::{TintColor, prelude::*};
 use zed_llm_client::{
-    CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
-    EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
-    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
+    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
+    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
     TOOL_USE_LIMIT_REACHED_HEADER_NAME,
 };
@@ -517,7 +517,7 @@ struct PerformLlmCompletionResponse {
     response: Response<AsyncBody>,
     usage: Option<RequestUsage>,
     tool_use_limit_reached: bool,
-    includes_queue_events: bool,
+    includes_status_messages: bool,
 }

 impl CloudLanguageModel {
@@ -545,25 +545,31 @@ impl CloudLanguageModel {
         let request = request_builder
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {token}"))
-            .header("x-zed-client-supports-queueing", "true")
+            .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
             .body(serde_json::to_string(&body)?.into())?;
         let mut response = http_client.send(request).await?;
         let status = response.status();
         if status.is_success() {
-            let includes_queue_events = response
+            let includes_status_messages = response
                 .headers()
-                .get("x-zed-server-supports-queueing")
+                .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                 .is_some();
+
             let tool_use_limit_reached = response
                 .headers()
                 .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                 .is_some();
-            let usage = RequestUsage::from_headers(response.headers()).ok();
+
+            let usage = if includes_status_messages {
+                None
+            } else {
+                RequestUsage::from_headers(response.headers()).ok()
+            };

             return Ok(PerformLlmCompletionResponse {
                 response,
                 usage,
-                includes_queue_events,
+                includes_status_messages,
                 tool_use_limit_reached,
             });
         } else if response
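The handshake above is a plain capability negotiation: the client advertises support via one request header, the server opts in per response, and the legacy usage headers are consulted only when the server has not opted in, so usage is never reported twice. The decision logic, sketched with stand-in types:

```rust
#[derive(Debug, PartialEq)]
enum UsageSource {
    // In-band CompletionRequestStatus::UsageUpdated messages on the stream.
    StatusMessages,
    // Legacy (amount, limit) response headers, if present.
    Headers(Option<(i32, i32)>),
}

fn usage_source(
    server_supports_status_messages: bool,
    header_usage: Option<(i32, i32)>,
) -> UsageSource {
    if server_supports_status_messages {
        // Ignore the headers entirely; the stream will carry usage updates.
        UsageSource::StatusMessages
    } else {
        UsageSource::Headers(header_usage)
    }
}

fn main() {
    assert_eq!(usage_source(true, Some((41, 500))), UsageSource::StatusMessages);
    assert_eq!(
        usage_source(false, Some((41, 500))),
        UsageSource::Headers(Some((41, 500)))
    );
}
```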
@@ -767,28 +773,12 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        cx: &AsyncApp,
+        _cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
         >,
-    > {
-        self.stream_completion_with_usage(request, cx)
-            .map(|result| result.map(|(stream, _)| stream))
-            .boxed()
-    }
-
-    fn stream_completion_with_usage(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<(
-            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
-            Option<RequestUsage>,
-        )>,
     > {
         let thread_id = request.thread_id.clone();
         let prompt_id = request.prompt_id.clone();
@@ -804,11 +794,11 @@ impl LanguageModel for CloudLanguageModel {
         );
         let client = self.client.clone();
         let llm_api_token = self.llm_api_token.clone();
-        let future = self.request_limiter.stream_with_usage(async move {
+        let future = self.request_limiter.stream(async move {
             let PerformLlmCompletionResponse {
                 response,
                 usage,
-                includes_queue_events,
+                includes_status_messages,
                 tool_use_limit_reached,
             } = Self::perform_llm_completion(
                 client.clone(),
@@ -840,32 +830,26 @@ impl LanguageModel for CloudLanguageModel {
                     })?;

                     let mut mapper = AnthropicEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
-                        usage,
+                        move |event| mapper.map_event(event),
                     ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream_with_usage(async move {
+                let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
                         usage,
-                        includes_queue_events,
+                        includes_status_messages,
                         tool_use_limit_reached,
                     } = Self::perform_llm_completion(
@@ -882,32 +866,26 @@ impl LanguageModel for CloudLanguageModel {
                     .await?;

                     let mut mapper = OpenAiEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
-                        usage,
+                        move |event| mapper.map_event(event),
                     ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream_with_usage(async move {
+                let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
                         usage,
-                        includes_queue_events,
+                        includes_status_messages,
                         tool_use_limit_reached,
                     } = Self::perform_llm_completion(
@@ -924,22 +902,16 @@ impl LanguageModel for CloudLanguageModel {
                     .await?;

                     let mut mapper = GoogleEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
-                        usage,
+                        move |event| mapper.map_event(event),
                     ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }
@@ -948,7 +920,7 @@ impl LanguageModel for CloudLanguageModel {

 #[derive(Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CloudCompletionEvent<T> {
-    System(CompletionRequestStatus),
+    Status(CompletionRequestStatus),
     Event(T),
 }
@@ -968,8 +940,8 @@ where
             Err(error) => {
                 vec![Err(LanguageModelCompletionError::Other(error))]
             }
-            Ok(CloudCompletionEvent::System(event)) => {
-                vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
+            Ok(CloudCompletionEvent::Status(event)) => {
+                vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
             }
             Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
         })
     })
     .boxed()
 }

+fn usage_updated_event<T>(
+    usage: Option<RequestUsage>,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+    futures::stream::iter(usage.map(|usage| {
+        Ok(CloudCompletionEvent::Status(
+            CompletionRequestStatus::UsageUpdated {
+                amount: usage.amount as usize,
+                limit: usage.limit,
+            },
+        ))
+    }))
+}
+
 fn tool_use_limit_reached_event<T>(
     tool_use_limit_reached: bool,
 ) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::iter(tool_use_limit_reached.then(|| {
-        Ok(CloudCompletionEvent::System(
+        Ok(CloudCompletionEvent::Status(
             CompletionRequestStatus::ToolUseLimitReached,
         ))
     }))
@@ -989,7 +974,7 @@ fn tool_use_limit_reached_event<T>(

 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
-    includes_queue_events: bool,
+    includes_status_messages: bool,
 ) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::try_unfold(
         (String::new(), BufReader::new(response.into_body())),
         move |(mut line, mut body)| async move {
             match body.read_line(&mut line).await {
                 Ok(0) => Ok(None),
                 Ok(_) => {
-                    let event = if includes_queue_events {
+                    let event = if includes_status_messages {
                         serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                     } else {
                         CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
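For reference, a self-contained sketch of the envelope `response_lines` decodes: newline-delimited JSON in which each line is either a status message or a raw provider event, externally tagged by the snake_case variant name of `CloudCompletionEvent`. The payload shapes here are stand-ins, and when the server has not opted in to status messages, each line is instead parsed directly as a bare provider event:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Envelope {
    // Mirrors CloudCompletionEvent's two arms; payloads are left untyped here.
    Status(serde_json::Value),
    Event(serde_json::Value),
}

fn main() -> serde_json::Result<()> {
    let body = concat!(
        r#"{"status":{"status":"started"}}"#, "\n",
        r#"{"event":{"text":"hello"}}"#, "\n",
    );
    for line in body.lines() {
        // One JSON document per line; the outer key selects the variant.
        match serde_json::from_str::<Envelope>(line)? {
            Envelope::Status(status) => println!("status: {status}"),
            Envelope::Event(event) => println!("event: {event}"),
        }
    }
    Ok(())
}
```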