Change cloud language model provider JSON protocol to surface errors and usage information (#29830)

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Max Brunsfeld 2025-05-04 10:37:42 -07:00, committed via GitHub
commit c3d9cdecab (parent 3984531a45)
8 changed files with 128 additions and 197 deletions
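The change in a nutshell: previously, usage information rode alongside the completion stream as a separate return value plucked from HTTP response headers; after this commit the server interleaves status messages (queue position, errors, usage, tool-use limits) in the response stream itself. A hedged sketch of the tagged status format follows. The real enum lives in `zed_llm_client` 0.7.4; the `Failed` and `UsageUpdated` variants and their field names are reconstructed here from the old local definition and the new match arms in this diff, and on the wire each status is additionally wrapped in the `CloudCompletionEvent` envelope shown in the last file.

```rust
// Sketch of the status-message format, reconstructed from this diff.
// `limit` is a zed_llm_client::UsageLimit upstream; a plain usize stands in here.
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum CompletionRequestStatus {
    Queued { position: usize },
    Started,
    // New in this protocol revision: errors and usage travel in-band.
    Failed { code: String, message: String },
    UsageUpdated { amount: usize, limit: usize },
    ToolUseLimitReached,
}

fn main() -> serde_json::Result<()> {
    // One NDJSON line per status message:
    let line = r#"{"status":"usage_updated","amount":41,"limit":500}"#;
    let status: CompletionRequestStatus = serde_json::from_str(line)?;
    println!("{status:?}"); // UsageUpdated { amount: 41, limit: 500 }
    Ok(())
}
```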

Cargo.lock (generated)

@@ -18825,9 +18825,9 @@ dependencies = [
 [[package]]
 name = "zed_llm_client"
-version = "0.7.2"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "226e0b479b3aed072d83db276866d54bce631e3a8600fcdf4f309d73389af9c7"
+checksum = "2adf9bc80def4ec93c190f06eb78111865edc2576019a9753eaef6fd7bc3b72c"
 dependencies = [
  "anyhow",
  "serde",


@@ -611,7 +611,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.7.2"
+zed_llm_client = "0.7.4"
 zstd = "0.11"
 
 [workspace.dependencies.async-stripe]


@@ -37,7 +37,7 @@ use settings::Settings;
 use thiserror::Error;
 use util::{ResultExt as _, TryFutureExt as _, post_inc};
 use uuid::Uuid;
-use zed_llm_client::CompletionMode;
+use zed_llm_client::{CompletionMode, CompletionRequestStatus};
 
 use crate::ThreadStore;
 use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
@@ -1356,20 +1356,17 @@ impl Thread {
         self.last_received_chunk_at = Some(Instant::now());
 
         let task = cx.spawn(async move |thread, cx| {
-            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
+            let stream_completion_future = model.stream_completion(request, &cx);
             let initial_token_usage =
                 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
             let stream_completion = async {
-                let (mut events, usage) = stream_completion_future.await?;
+                let mut events = stream_completion_future.await?;
                 let mut stop_reason = StopReason::EndTurn;
                 let mut current_token_usage = TokenUsage::default();
 
                 thread
                     .update(cx, |_thread, cx| {
-                        if let Some(usage) = usage {
-                            cx.emit(ThreadEvent::UsageUpdated(usage));
-                        }
                         cx.emit(ThreadEvent::NewRequest);
                     })
                     .ok();
@@ -1515,27 +1512,34 @@ impl Thread {
                                 });
                             }
                         }
-                        LanguageModelCompletionEvent::QueueUpdate(status) => {
+                        LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                             if let Some(completion) = thread
                                 .pending_completions
                                 .iter_mut()
                                 .find(|completion| completion.id == pending_completion_id)
                             {
-                                let queue_state = match status {
-                                    language_model::CompletionRequestStatus::Queued {
-                                        position,
-                                    } => Some(QueueState::Queued { position }),
-                                    language_model::CompletionRequestStatus::Started => {
-                                        Some(QueueState::Started)
+                                match status_update {
+                                    CompletionRequestStatus::Queued { position } => {
+                                        completion.queue_state = QueueState::Queued { position };
                                     }
-                                    language_model::CompletionRequestStatus::ToolUseLimitReached => {
+                                    CompletionRequestStatus::Started => {
+                                        completion.queue_state = QueueState::Started;
+                                    }
+                                    CompletionRequestStatus::Failed { code, message } => {
+                                        return Err(anyhow!(
+                                            "completion request failed. code: {code}, message: {message}"
+                                        ));
+                                    }
+                                    CompletionRequestStatus::UsageUpdated { amount, limit } => {
+                                        cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
+                                            limit,
+                                            amount: amount as i32,
+                                        }));
+                                    }
+                                    CompletionRequestStatus::ToolUseLimitReached => {
                                         thread.tool_use_limit_reached = true;
-                                        None
                                     }
-                                };
-                                if let Some(queue_state) = queue_state {
-                                    completion.queue_state = queue_state;
                                 }
                             }
                         }
@@ -1690,19 +1694,27 @@ impl Thread {
         self.pending_summary = cx.spawn(async move |this, cx| {
             async move {
-                let stream = model.model.stream_completion_text_with_usage(request, &cx);
-                let (mut messages, usage) = stream.await?;
-
-                if let Some(usage) = usage {
-                    this.update(cx, |_thread, cx| {
-                        cx.emit(ThreadEvent::UsageUpdated(usage));
-                    })
-                    .ok();
-                }
+                let mut messages = model.model.stream_completion(request, &cx).await?;
 
                 let mut new_summary = String::new();
-                while let Some(message) = messages.stream.next().await {
-                    let text = message?;
+                while let Some(event) = messages.next().await {
+                    let event = event?;
+                    let text = match event {
+                        LanguageModelCompletionEvent::Text(text) => text,
+                        LanguageModelCompletionEvent::StatusUpdate(
+                            CompletionRequestStatus::UsageUpdated { amount, limit },
+                        ) => {
+                            this.update(cx, |_, cx| {
+                                cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
+                                    limit,
+                                    amount: amount as i32,
+                                }));
+                            })?;
+                            continue;
+                        }
+                        _ => continue,
+                    };
+
                     let mut lines = text.lines();
                     new_summary.extend(lines.next());
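
The summarization path above shows the new consumption pattern: a single event stream in which text chunks and status messages interleave. A minimal, self-contained sketch of that loop, using local stand-in types for the `language_model` ones:

```rust
// Sketch of the summary loop: pull text out of a mixed event stream and
// surface usage updates as they fly past. Event stands in for
// LanguageModelCompletionEvent.
use futures::{StreamExt, executor::block_on, stream};

enum Event {
    Text(String),
    UsageUpdated { amount: usize, limit: usize },
    Stop,
}

fn main() {
    block_on(async {
        let mut events = stream::iter(vec![
            Event::UsageUpdated { amount: 41, limit: 500 },
            Event::Text("Hello, ".into()),
            Event::Text("world".into()),
            Event::Stop,
        ]);

        let mut summary = String::new();
        while let Some(event) = events.next().await {
            match event {
                Event::Text(chunk) => summary.push_str(&chunk),
                // The real code emits ThreadEvent::UsageUpdated(RequestUsage { .. }).
                Event::UsageUpdated { amount, limit } => eprintln!("usage {amount}/{limit}"),
                Event::Stop => break,
            }
        }
        assert_eq!(summary, "Hello, world");
    });
}
```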


@@ -2371,7 +2371,7 @@ impl AssistantContext {
         });
 
         match event {
-            LanguageModelCompletionEvent::QueueUpdate { .. } => {}
+            LanguageModelCompletionEvent::StatusUpdate { .. } => {}
             LanguageModelCompletionEvent::StartMessage { .. } => {}
             LanguageModelCompletionEvent::Stop(reason) => {
                 stop_reason = reason;
@@ -2429,8 +2429,8 @@ impl AssistantContext {
                             cx,
                         );
                     }
-                    LanguageModelCompletionEvent::ToolUse(_) => {}
-                    LanguageModelCompletionEvent::UsageUpdate(_) => {}
+                    LanguageModelCompletionEvent::ToolUse(_) |
+                    LanguageModelCompletionEvent::UsageUpdate(_) => {}
                 }
             });


@@ -1018,7 +1018,7 @@ pub fn response_events_to_markdown(
             Ok(
                 LanguageModelCompletionEvent::UsageUpdate(_)
                 | LanguageModelCompletionEvent::StartMessage { .. }
-                | LanguageModelCompletionEvent::QueueUpdate { .. },
+                | LanguageModelCompletionEvent::StatusUpdate { .. },
             ) => {}
             Err(error) => {
                 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
@@ -1093,7 +1093,7 @@ impl ThreadDialog {
                     // Skip these
                     Ok(LanguageModelCompletionEvent::UsageUpdate(_))
-                    | Ok(LanguageModelCompletionEvent::QueueUpdate { .. })
+                    | Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
                     | Ok(LanguageModelCompletionEvent::StartMessage { .. })
                     | Ok(LanguageModelCompletionEvent::Stop(_)) => {}


@@ -26,7 +26,8 @@ use std::sync::Arc;
 use thiserror::Error;
 use util::serde::is_default;
 use zed_llm_client::{
-    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
+    CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
+    MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
 };
 
 pub use crate::model::*;
@@ -64,18 +65,10 @@ pub struct LanguageModelCacheConfiguration {
     pub min_total_token: usize,
 }
 
-#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
-#[serde(tag = "status", rename_all = "snake_case")]
-pub enum CompletionRequestStatus {
-    Queued { position: usize },
-    Started,
-    ToolUseLimitReached,
-}
-
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
-    QueueUpdate(CompletionRequestStatus),
+    StatusUpdate(CompletionRequestStatus),
     Stop(StopReason),
     Text(String),
     Thinking {
@@ -299,41 +292,15 @@ pub trait LanguageModel: Send + Sync {
         >,
     >;
 
-    fn stream_completion_with_usage(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<(
-            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
-            Option<RequestUsage>,
-        )>,
-    > {
-        self.stream_completion(request, cx)
-            .map(|result| result.map(|stream| (stream, None)))
-            .boxed()
-    }
-
     fn stream_completion_text(
         &self,
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
-        self.stream_completion_text_with_usage(request, cx)
-            .map(|result| result.map(|(stream, _usage)| stream))
-            .boxed()
-    }
-
-    fn stream_completion_text_with_usage(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<(LanguageModelTextStream, Option<RequestUsage>)>> {
-        let future = self.stream_completion_with_usage(request, cx);
+        let future = self.stream_completion(request, cx);
 
         async move {
-            let (events, usage) = future.await?;
+            let events = future.await?;
             let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
@@ -358,7 +325,7 @@ pub trait LanguageModel: Send + Sync {
                 let last_token_usage = last_token_usage.clone();
                 async move {
                     match result {
-                        Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None,
+                        Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                         Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                         Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                         Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
@@ -375,14 +342,11 @@ pub trait LanguageModel: Send + Sync {
                 }))
                 .boxed();
 
-            Ok((
-                LanguageModelTextStream {
-                    message_id,
-                    stream,
-                    last_token_usage,
-                },
-                usage,
-            ))
+            Ok(LanguageModelTextStream {
+                message_id,
+                stream,
+                last_token_usage,
+            })
         }
         .boxed()
     }


@@ -8,8 +8,6 @@ use std::{
     task::{Context, Poll},
 };
 
-use crate::RequestUsage;
-
 #[derive(Clone)]
 pub struct RateLimiter {
     semaphore: Arc<Semaphore>,
@@ -69,32 +67,4 @@ impl RateLimiter {
             })
         }
     }
-
-    pub fn stream_with_usage<'a, Fut, T>(
-        &self,
-        future: Fut,
-    ) -> impl 'a
-    + Future<
-        Output = Result<(
-            impl Stream<Item = T::Item> + use<Fut, T>,
-            Option<RequestUsage>,
-        )>,
-    >
-    where
-        Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
-        T: Stream,
-    {
-        let guard = self.semaphore.acquire_arc();
-        async move {
-            let guard = guard.await;
-            let (inner, usage) = future.await?;
-            Ok((
-                RateLimitGuard {
-                    inner,
-                    _guard: guard,
-                },
-                usage,
-            ))
-        }
-    }
 }
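
`stream_with_usage` existed only to thread the header-derived usage tuple through the rate limiter; with usage in-band, the surviving `stream` method is pure concurrency control. The essence, sketched with the `async-lock` semaphore (an assumption; the diff only shows `acquire_arc` being called) and with the `RateLimitGuard` wrapper replaced by capturing the permit in a closure:

```rust
// Take a semaphore permit before the request starts and keep it alive for the
// whole life of the stream; the permit is released when the stream is dropped.
use std::sync::Arc;

use async_lock::Semaphore;
use futures::{Stream, StreamExt, stream};

async fn rate_limited<S: Stream>(
    semaphore: Arc<Semaphore>,
    inner: S,
) -> impl Stream<Item = S::Item> {
    let guard = semaphore.acquire_arc().await;
    inner.map(move |item| {
        let _held = &guard; // permit stays captured until the stream is dropped
        item
    })
}

fn main() {
    futures::executor::block_on(async {
        let semaphore = Arc::new(Semaphore::new(2)); // at most 2 concurrent streams
        let events = rate_limited(semaphore, stream::iter(1..=3)).await;
        assert_eq!(events.collect::<Vec<_>>().await, vec![1, 2, 3]);
    });
}
```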


@@ -9,12 +9,11 @@ use futures::{
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
-    AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
-    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
-    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
-    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
-    ZED_CLOUD_PROVIDER_ID,
+    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
+    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -36,9 +35,10 @@ use strum::IntoEnumIterator;
 use thiserror::Error;
 use ui::{TintColor, prelude::*};
 use zed_llm_client::{
-    CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
-    EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
-    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
+    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
+    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
     TOOL_USE_LIMIT_REACHED_HEADER_NAME,
 };
@@ -517,7 +517,7 @@ struct PerformLlmCompletionResponse {
     response: Response<AsyncBody>,
     usage: Option<RequestUsage>,
     tool_use_limit_reached: bool,
-    includes_queue_events: bool,
+    includes_status_messages: bool,
 }
 
 impl CloudLanguageModel {
@@ -545,25 +545,31 @@ impl CloudLanguageModel {
             let request = request_builder
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {token}"))
-                .header("x-zed-client-supports-queueing", "true")
+                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                 .body(serde_json::to_string(&body)?.into())?;
             let mut response = http_client.send(request).await?;
             let status = response.status();
             if status.is_success() {
-                let includes_queue_events = response
+                let includes_status_messages = response
                     .headers()
-                    .get("x-zed-server-supports-queueing")
+                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                     .is_some();
                 let tool_use_limit_reached = response
                     .headers()
                     .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                     .is_some();
-                let usage = RequestUsage::from_headers(response.headers()).ok();
+
+                let usage = if includes_status_messages {
+                    None
+                } else {
+                    RequestUsage::from_headers(response.headers()).ok()
+                };
+
                 return Ok(PerformLlmCompletionResponse {
                     response,
                     usage,
-                    includes_queue_events,
+                    includes_status_messages,
                     tool_use_limit_reached,
                 });
             } else if response
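
The negotiation is deliberately backward compatible in both directions: the client advertises support with a request header, and only if the server echoes its own header does the client stop reading usage from response headers. A toy version of that decision (the literal header string here is hypothetical; the real name is whatever the `SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME` constant holds):

```rust
// Capability negotiation: fall back to header-derived usage for old servers.
use std::collections::HashMap;

#[derive(Debug, PartialEq)]
enum UsageSource {
    StatusMessages,  // new servers: usage arrives as in-band JSON lines
    ResponseHeaders, // old servers: usage parsed from response headers
}

fn negotiate(response_headers: &HashMap<String, String>) -> UsageSource {
    if response_headers.contains_key("x-zed-server-supports-status-messages") {
        UsageSource::StatusMessages
    } else {
        UsageSource::ResponseHeaders
    }
}

fn main() {
    let old_server = HashMap::new();
    assert_eq!(negotiate(&old_server), UsageSource::ResponseHeaders);

    let mut new_server = HashMap::new();
    new_server.insert("x-zed-server-supports-status-messages".into(), "true".into());
    assert_eq!(negotiate(&new_server), UsageSource::StatusMessages);
}
```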
@@ -767,28 +773,12 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        cx: &AsyncApp,
+        _cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
         Result<
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
         >,
-    > {
-        self.stream_completion_with_usage(request, cx)
-            .map(|result| result.map(|(stream, _)| stream))
-            .boxed()
-    }
-
-    fn stream_completion_with_usage(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<(
-            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
-            Option<RequestUsage>,
-        )>,
     > {
         let thread_id = request.thread_id.clone();
         let prompt_id = request.prompt_id.clone();
@@ -804,11 +794,11 @@ impl LanguageModel for CloudLanguageModel {
                 );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream_with_usage(async move {
+                let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
                         usage,
-                        includes_queue_events,
+                        includes_status_messages,
                         tool_use_limit_reached,
                     } = Self::perform_llm_completion(
                         client.clone(),
@@ -840,32 +830,26 @@ impl LanguageModel for CloudLanguageModel {
                     })?;
 
                     let mut mapper = AnthropicEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
-                        ),
-                        usage,
-                    ))
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream_with_usage(async move {
+                let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
                         usage,
-                        includes_queue_events,
+                        includes_status_messages,
                         tool_use_limit_reached,
                     } = Self::perform_llm_completion(
                         client.clone(),
@@ -882,32 +866,26 @@ impl LanguageModel for CloudLanguageModel {
                     .await?;
 
                     let mut mapper = OpenAiEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
-                        ),
-                        usage,
-                    ))
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream_with_usage(async move {
+                let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
                         usage,
-                        includes_queue_events,
+                        includes_status_messages,
                         tool_use_limit_reached,
                     } = Self::perform_llm_completion(
                         client.clone(),
@@ -924,22 +902,16 @@ impl LanguageModel for CloudLanguageModel {
                     .await?;
 
                     let mut mapper = GoogleEventMapper::new();
-                    Ok((
-                        map_cloud_completion_events(
-                            Box::pin(
-                                response_lines(response, includes_queue_events)
-                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
-                            ),
-                            move |event| mapper.map_event(event),
-                        ),
-                        usage,
-                    ))
+                    Ok(map_cloud_completion_events(
+                        Box::pin(
+                            response_lines(response, includes_status_messages)
+                                .chain(usage_updated_event(usage))
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
+                        move |event| mapper.map_event(event),
+                    ))
                 });
-                async move {
-                    let (stream, usage) = future.await?;
-                    Ok((stream.boxed(), usage))
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }
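
All three provider branches end with the same trick: usage recovered from headers (for servers that predate status messages) is synthesized into a one-element stream and chained onto the response, so downstream code sees a single uniform event stream regardless of where usage came from. A minimal sketch of the chaining:

```rust
// Append an optional header-derived usage event to the main event stream.
// stream::iter over an Option yields zero or one items.
use futures::{StreamExt, stream};

#[derive(Debug, PartialEq)]
enum Event {
    Text(&'static str),
    UsageUpdated { amount: usize, limit: usize },
}

fn main() {
    futures::executor::block_on(async {
        let header_usage = Some((41, 500)); // None when the server streams status messages
        let usage_event = stream::iter(
            header_usage.map(|(amount, limit)| Event::UsageUpdated { amount, limit }),
        );
        let events: Vec<_> = stream::iter([Event::Text("hi")])
            .chain(usage_event)
            .collect()
            .await;
        assert_eq!(events.len(), 2);
    });
}
```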
@@ -948,7 +920,7 @@ impl LanguageModel for CloudLanguageModel {
 #[derive(Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CloudCompletionEvent<T> {
-    System(CompletionRequestStatus),
+    Status(CompletionRequestStatus),
     Event(T),
 }
@@ -968,8 +940,8 @@ where
             Err(error) => {
                 vec![Err(LanguageModelCompletionError::Other(error))]
             }
-            Ok(CloudCompletionEvent::System(event)) => {
-                vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
+            Ok(CloudCompletionEvent::Status(event)) => {
+                vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
             }
             Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
         })
@@ -977,11 +949,24 @@ where
         .boxed()
 }
 
+fn usage_updated_event<T>(
+    usage: Option<RequestUsage>,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+    futures::stream::iter(usage.map(|usage| {
+        Ok(CloudCompletionEvent::Status(
+            CompletionRequestStatus::UsageUpdated {
+                amount: usage.amount as usize,
+                limit: usage.limit,
+            },
+        ))
+    }))
+}
+
 fn tool_use_limit_reached_event<T>(
     tool_use_limit_reached: bool,
 ) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::iter(tool_use_limit_reached.then(|| {
-        Ok(CloudCompletionEvent::System(
+        Ok(CloudCompletionEvent::Status(
             CompletionRequestStatus::ToolUseLimitReached,
         ))
     }))
@@ -989,7 +974,7 @@ fn tool_use_limit_reached_event<T>(
 
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
-    includes_queue_events: bool,
+    includes_status_messages: bool,
 ) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::try_unfold(
         (String::new(), BufReader::new(response.into_body())),
@@ -997,7 +982,7 @@ fn response_lines<T: DeserializeOwned>(
             match body.read_line(&mut line).await {
                 Ok(0) => Ok(None),
                 Ok(_) => {
-                    let event = if includes_queue_events {
+                    let event = if includes_status_messages {
                         serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                     } else {
                         CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
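
Line decoding is the counterpart of the negotiation above: when the server opted in, each NDJSON line is the externally tagged `CloudCompletionEvent` envelope, which is either a status message or a provider event; otherwise every line is a bare provider event. A self-contained sketch, where `InnerEvent` is a hypothetical stand-in for the provider-specific payload `T`:

```rust
// Per-line decoding: envelope format for new servers, bare events for old ones.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum CloudCompletionEvent<T> {
    Status(serde_json::Value), // stand-in for CompletionRequestStatus
    Event(T),
}

#[derive(Debug, Deserialize)]
struct InnerEvent {
    text: String, // hypothetical provider payload
}

fn decode_line(
    line: &str,
    includes_status_messages: bool,
) -> serde_json::Result<CloudCompletionEvent<InnerEvent>> {
    if includes_status_messages {
        serde_json::from_str::<CloudCompletionEvent<InnerEvent>>(line)
    } else {
        Ok(CloudCompletionEvent::Event(serde_json::from_str::<InnerEvent>(line)?))
    }
}

fn main() {
    // New protocol: lines are wrapped in the envelope.
    let wrapped = r#"{"event":{"text":"hi"}}"#;
    println!("{:?}", decode_line(wrapped, true).unwrap());

    // Old protocol: lines are the provider event itself.
    let bare = r#"{"text":"hi"}"#;
    println!("{:?}", decode_line(bare, false).unwrap());
}
```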