Change cloud language model provider JSON protocol to surface errors and usage information (#29830)
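
The cloud provider now reports request status (queue position, start, failure, usage) through the same event stream as completion chunks, instead of returning usage out of band and swallowing errors. For context, the status variants look roughly like this; a minimal sketch inferred from the call sites in the diff below, where the field names and types are assumptions rather than the canonical `zed_llm_client::CompletionRequestStatus` definition:

```rust
/// Sketch of the request-status updates surfaced by the new protocol.
/// Inferred from the match arms in this change; field names and types are
/// assumptions, not the real `zed_llm_client` definition.
#[derive(Debug, Clone)]
pub enum CompletionRequestStatus {
    /// The request is waiting in the provider's queue at `position`.
    Queued { position: usize },
    /// The provider has started streaming the completion.
    Started,
    /// The request failed; code and message are surfaced to the caller
    /// instead of being silently dropped.
    Failed { code: String, message: String },
    /// Cumulative usage reported against the account's limit.
    UsageUpdated { amount: usize, limit: usize },
    /// The per-request tool-use limit was reached.
    ToolUseLimitReached,
}
```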

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Max Brunsfeld 2025-05-04 10:37:42 -07:00 committed by GitHub
parent 3984531a45
commit c3d9cdecab
8 changed files with 128 additions and 197 deletions


@@ -37,7 +37,7 @@ use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionMode;
use zed_llm_client::{CompletionMode, CompletionRequestStatus};
use crate::ThreadStore;
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
@@ -1356,20 +1356,17 @@ impl Thread {
self.last_received_chunk_at = Some(Instant::now());
let task = cx.spawn(async move |thread, cx| {
let stream_completion_future = model.stream_completion_with_usage(request, &cx);
let stream_completion_future = model.stream_completion(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
let (mut events, usage) = stream_completion_future.await?;
let mut events = stream_completion_future.await?;
let mut stop_reason = StopReason::EndTurn;
let mut current_token_usage = TokenUsage::default();
thread
.update(cx, |_thread, cx| {
if let Some(usage) = usage {
cx.emit(ThreadEvent::UsageUpdated(usage));
}
cx.emit(ThreadEvent::NewRequest);
})
.ok();
@@ -1515,27 +1512,34 @@ impl Thread {
});
}
}
LanguageModelCompletionEvent::QueueUpdate(status) => {
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let Some(completion) = thread
.pending_completions
.iter_mut()
.find(|completion| completion.id == pending_completion_id)
{
let queue_state = match status {
language_model::CompletionRequestStatus::Queued {
match status_update {
CompletionRequestStatus::Queued {
position,
} => Some(QueueState::Queued { position }),
language_model::CompletionRequestStatus::Started => {
Some(QueueState::Started)
} => {
completion.queue_state = QueueState::Queued { position };
}
language_model::CompletionRequestStatus::ToolUseLimitReached => {
CompletionRequestStatus::Started => {
completion.queue_state = QueueState::Started;
}
CompletionRequestStatus::Failed {
code, message
} => {
return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
}
CompletionRequestStatus::UsageUpdated {
amount, limit
} => {
cx.emit(ThreadEvent::UsageUpdated(RequestUsage { limit, amount: amount as i32 }));
}
CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true;
None
}
};
if let Some(queue_state) = queue_state {
completion.queue_state = queue_state;
}
}
}
@@ -1690,19 +1694,27 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.model.stream_completion_text_with_usage(request, &cx);
let (mut messages, usage) = stream.await?;
if let Some(usage) = usage {
this.update(cx, |_thread, cx| {
cx.emit(ThreadEvent::UsageUpdated(usage));
})
.ok();
}
let mut messages = model.model.stream_completion(request, &cx).await?;
let mut new_summary = String::new();
while let Some(message) = messages.stream.next().await {
let text = message?;
while let Some(event) = messages.next().await {
let event = event?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit },
) => {
this.update(cx, |_, cx| {
cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
limit,
amount: amount as i32,
}));
})?;
continue;
}
_ => continue,
};
let mut lines = text.lines();
new_summary.extend(lines.next());
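
Both call sites in this change end up with the same consumer shape: drain the unified event stream, forward `UsageUpdated` statuses as `ThreadEvent::UsageUpdated`, and treat the remaining events as control flow. A minimal sketch of that pattern, using hypothetical simplified `CompletionEvent` and `Usage` types in place of the real `LanguageModelCompletionEvent`:

```rust
use futures::{Stream, StreamExt};

// Hypothetical, simplified stand-ins for the real event and usage types.
enum CompletionEvent {
    Text(String),
    UsageUpdated { amount: usize, limit: usize },
    Other,
}

struct Usage {
    amount: usize,
    limit: usize,
}

// Drain a completion stream, collecting text chunks and forwarding usage
// updates through `on_usage`, mirroring the loops in the hunks above.
async fn collect_text<S>(mut events: S, mut on_usage: impl FnMut(Usage)) -> String
where
    S: Stream<Item = CompletionEvent> + Unpin,
{
    let mut text = String::new();
    while let Some(event) = events.next().await {
        match event {
            CompletionEvent::Text(chunk) => text.push_str(&chunk),
            CompletionEvent::UsageUpdated { amount, limit } => on_usage(Usage { amount, limit }),
            CompletionEvent::Other => {}
        }
    }
    text
}
```

Fed with `futures::stream::iter` over fake events, a helper like this can be exercised in a unit test without a real provider.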