Add support for queuing status updates in cloud language model provider (#29818)

This sets us up to display queue position information to the user, once
our language model backend is updated to support request queuing.

The JSON returned by the LLM backend will need to be a stream of newline-delimited events that look like this:

```json
{"queue": {"status": "queued", "position": 1}}
{"queue": {"status": "started"}}
{"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}} 
```
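
For reference, these lines correspond to the externally tagged `CloudCompletionEvent` and the `"status"`-tagged `QueueState` types introduced in the diff below. A minimal decoding sketch (assuming the `serde` and `serde_json` crates; `Upstream` is a hypothetical stand-in for a real upstream provider event type):

```rust
use serde::{Deserialize, Serialize};

// Wire-format mirrors of the types added in this PR: `QueueState` is
// internally tagged by a "status" field; `CloudCompletionEvent` is
// externally tagged, producing the {"queue": ...} / {"event": ...}
// shapes shown above.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum QueueState {
    Queued { position: usize },
    Started,
}

#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum CloudCompletionEvent<T> {
    Queue(QueueState),
    Event(T),
}

// Hypothetical stand-in for an upstream model provider event.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Upstream {
    text: String,
}

fn main() {
    let line = r#"{"queue": {"status": "queued", "position": 1}}"#;
    let event: CloudCompletionEvent<Upstream> = serde_json::from_str(line).unwrap();
    assert_eq!(
        event,
        CloudCompletionEvent::Queue(QueueState::Queued { position: 1 })
    );

    let line = r#"{"event": {"text": "..."}}"#;
    let event: CloudCompletionEvent<Upstream> = serde_json::from_str(line).unwrap();
    assert_eq!(
        event,
        CloudCompletionEvent::Event(Upstream { text: "...".into() })
    );
}
```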

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Commit: 04772bf17d (parent: 4d1df7bcd7)
Author: Max Brunsfeld, 2025-05-02 13:36:39 -07:00, committed by GitHub
9 changed files with 492 additions and 430 deletions

```diff
@@ -4,8 +4,8 @@ use crate::context_store::ContextStore;
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::message_editor::insert_message_creases;
 use crate::thread::{
-    LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError,
-    ThreadEvent, ThreadFeedback,
+    LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, QueueState, Thread,
+    ThreadError, ThreadEvent, ThreadFeedback,
 };
 use crate::thread_store::{RulesLoadingError, ThreadStore};
 use crate::tool_use::{PendingToolUseStatus, ToolUse};
@@ -1733,8 +1733,27 @@ impl ActiveThread {
         let show_feedback = thread.is_turn_end(ix);
-        let generating_label = (is_generating && is_last_message)
-            .then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
+        let generating_label = is_last_message
+            .then(|| match (thread.queue_state(), is_generating) {
+                (Some(QueueState::Sending), _) => Some(
+                    AnimatedLabel::new("Sending")
+                        .size(LabelSize::Small)
+                        .into_any_element(),
+                ),
+                (Some(QueueState::Queued { position }), _) => Some(
+                    Label::new(format!("Queue position: {position}"))
+                        .size(LabelSize::Small)
+                        .color(Color::Muted)
+                        .into_any_element(),
+                ),
+                (_, true) => Some(
+                    AnimatedLabel::new("Generating")
+                        .size(LabelSize::Small)
+                        .into_any_element(),
+                ),
+                _ => None,
+            })
+            .flatten();

         let editing_message_state = self
             .editing_message
@@ -2105,7 +2124,7 @@ impl ActiveThread {
                 parent.child(self.render_rules_item(cx))
             })
             .child(styled_message)
-            .when(generating_label.is_some(), |this| {
+            .when_some(generating_label, |this, generating_label| {
                 this.child(
                     h_flex()
                         .h_8()
@@ -2113,7 +2132,7 @@
                         .mb_4()
                         .ml_4()
                         .py_1p5()
-                        .child(generating_label.unwrap()),
+                        .child(generating_label),
                 )
             })
             .when(show_feedback, move |parent| {
```
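
The match above gives queue feedback precedence over the generic spinner: a sending or queued state always wins, and "Generating" only covers the case where the completion has already started streaming. A dependency-free sketch of the same decision table (the function shape and string labels here are illustrative, not Zed's UI API):

```rust
#[derive(Clone, Copy)]
enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}

fn status_label(queue_state: Option<QueueState>, is_generating: bool) -> Option<String> {
    match (queue_state, is_generating) {
        // Queue feedback wins over the generic animation...
        (Some(QueueState::Sending), _) => Some("Sending".to_string()),
        (Some(QueueState::Queued { position }), _) => {
            Some(format!("Queue position: {position}"))
        }
        // ...and "Generating" covers the started/streaming case.
        (_, true) => Some("Generating".to_string()),
        _ => None,
    }
}

fn main() {
    assert_eq!(
        status_label(Some(QueueState::Queued { position: 3 }), true).as_deref(),
        Some("Queue position: 3")
    );
    assert_eq!(
        status_label(Some(QueueState::Started), true).as_deref(),
        Some("Generating")
    );
    assert_eq!(status_label(None, false), None);
}
```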

```diff
@@ -320,6 +320,13 @@ fn default_completion_mode(cx: &App) -> CompletionMode {
     }
 }

+#[derive(Debug, Clone, Copy)]
+pub enum QueueState {
+    Sending,
+    Queued { position: usize },
+    Started,
+}
+
 /// A thread of conversation with the LLM.
 pub struct Thread {
     id: ThreadId,
@@ -625,6 +632,12 @@ impl Thread {
         !self.pending_completions.is_empty() || !self.all_tools_finished()
     }

+    pub fn queue_state(&self) -> Option<QueueState> {
+        self.pending_completions
+            .first()
+            .map(|pending_completion| pending_completion.queue_state)
+    }
+
     pub fn tools(&self) -> &Entity<ToolWorkingSet> {
         &self.tools
     }
@@ -1470,6 +1483,20 @@ impl Thread {
                                 });
                             }
                         }
+                        LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
+                            if let Some(completion) = thread
+                                .pending_completions
+                                .iter_mut()
+                                .find(|completion| completion.id == pending_completion_id)
+                            {
+                                completion.queue_state = match queue_event {
+                                    language_model::QueueState::Queued { position } => {
+                                        QueueState::Queued { position }
+                                    }
+                                    language_model::QueueState::Started => QueueState::Started,
+                                }
+                            }
+                        }
                     }

                     thread.touch_updated_at();
@@ -1590,6 +1617,7 @@ impl Thread {
         self.pending_completions.push(PendingCompletion {
             id: pending_completion_id,
+            queue_state: QueueState::Sending,
             _task: task,
         });
     }
@@ -2499,6 +2527,7 @@ impl EventEmitter<ThreadEvent> for Thread {}
 struct PendingCompletion {
     id: usize,
+    queue_state: QueueState,
     _task: Task<()>,
 }
```

```diff
@@ -2371,6 +2371,7 @@ impl AssistantContext {
                         });

                         match event {
+                            LanguageModelCompletionEvent::QueueUpdate { .. } => {}
                             LanguageModelCompletionEvent::StartMessage { .. } => {}
                             LanguageModelCompletionEvent::Stop(reason) => {
                                 stop_reason = reason;
```

```diff
@@ -1017,7 +1017,8 @@ pub fn response_events_to_markdown(
         }
         Ok(
             LanguageModelCompletionEvent::UsageUpdate(_)
-            | LanguageModelCompletionEvent::StartMessage { .. },
+            | LanguageModelCompletionEvent::StartMessage { .. }
+            | LanguageModelCompletionEvent::QueueUpdate { .. },
         ) => {}
         Err(error) => {
             flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
@@ -1092,6 +1093,7 @@ impl ThreadDialog {
                     // Skip these
                     Ok(LanguageModelCompletionEvent::UsageUpdate(_))
+                    | Ok(LanguageModelCompletionEvent::QueueUpdate { .. })
                    | Ok(LanguageModelCompletionEvent::StartMessage { .. })
                    | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
```

```diff
@@ -64,9 +64,17 @@ pub struct LanguageModelCacheConfiguration {
     pub min_total_token: usize,
 }

+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
+#[serde(tag = "status", rename_all = "snake_case")]
+pub enum QueueState {
+    Queued { position: usize },
+    Started,
+}
+
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
+    QueueUpdate(QueueState),
     Stop(StopReason),
     Text(String),
     Thinking {
@@ -349,6 +357,7 @@ pub trait LanguageModel: Send + Sync {
         let last_token_usage = last_token_usage.clone();
         async move {
             match result {
+                Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None,
                 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
```
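
The new `QueueUpdate` variant is deliberately invisible to plain-text consumers: the trait's text-stream adapter maps it to `None`, as the hunk above shows. A self-contained sketch of that filtering idiom, with a simplified event enum and only the `futures` crate assumed:

```rust
use futures::{StreamExt, executor::block_on, stream};

// Simplified stand-in for the real completion event type.
enum Event {
    QueueUpdate,
    StartMessage,
    Text(String),
    Stop,
}

fn main() {
    let events = stream::iter(vec![
        Event::QueueUpdate,
        Event::StartMessage,
        Event::Text("hello ".into()),
        Event::Text("world".into()),
        Event::Stop,
    ]);

    // filter_map drops every non-Text event, so queue chatter never
    // reaches consumers that only want the streamed text.
    let chunks: Vec<String> = block_on(
        events
            .filter_map(|event| async move {
                match event {
                    Event::Text(text) => Some(text),
                    _ => None,
                }
            })
            .collect(),
    );

    assert_eq!(chunks.concat(), "hello world");
}
```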

```diff
@@ -469,7 +469,7 @@ impl LanguageModel for AnthropicModel {
                     Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
                     Err(err) => anyhow!(err),
                 })?;
-            Ok(map_to_language_model_completion_events(response))
+            Ok(AnthropicEventMapper::new().map_stream(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -629,58 +629,59 @@ pub fn into_anthropic(
     }
 }

-pub fn map_to_language_model_completion_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-    struct RawToolUse {
-        id: String,
-        name: String,
-        input_json: String,
-    }
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-        tool_uses_by_index: HashMap<usize, RawToolUse>,
-        usage: Usage,
-        stop_reason: StopReason,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            tool_uses_by_index: HashMap::default(),
-            usage: Usage::default(),
-            stop_reason: StopReason::EndTurn,
-        },
-        |mut state| async move {
-            while let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => match event {
+pub struct AnthropicEventMapper {
+    tool_uses_by_index: HashMap<usize, RawToolUse>,
+    usage: Usage,
+    stop_reason: StopReason,
+}
+
+impl AnthropicEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_uses_by_index: HashMap::default(),
+            usage: Usage::default(),
+            stop_reason: StopReason::EndTurn,
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: Event,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        match event {
             Event::ContentBlockStart {
                 index,
                 content_block,
             } => match content_block {
                 ResponseContent::Text { text } => {
-                    return Some((
-                        vec![Ok(LanguageModelCompletionEvent::Text(text))],
-                        state,
-                    ));
+                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                 }
                 ResponseContent::Thinking { thinking } => {
-                    return Some((
-                        vec![Ok(LanguageModelCompletionEvent::Thinking {
-                            text: thinking,
-                            signature: None,
-                        })],
-                        state,
-                    ));
+                    vec![Ok(LanguageModelCompletionEvent::Thinking {
+                        text: thinking,
+                        signature: None,
+                    })]
                 }
                 ResponseContent::RedactedThinking { .. } => {
                     // Redacted thinking is encrypted and not accessible to the user, see:
                     // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
+                    Vec::new()
                 }
                 ResponseContent::ToolUse { id, name, .. } => {
-                    state.tool_uses_by_index.insert(
+                    self.tool_uses_by_index.insert(
                         index,
                         RawToolUse {
                             id,
@@ -688,35 +689,27 @@ pub fn map_to_language_model_completion_events(
                             input_json: String::new(),
                         },
                     );
+                    Vec::new()
                 }
             },
             Event::ContentBlockDelta { index, delta } => match delta {
                 ContentDelta::TextDelta { text } => {
-                    return Some((
-                        vec![Ok(LanguageModelCompletionEvent::Text(text))],
-                        state,
-                    ));
+                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                 }
                 ContentDelta::ThinkingDelta { thinking } => {
-                    return Some((
-                        vec![Ok(LanguageModelCompletionEvent::Thinking {
-                            text: thinking,
-                            signature: None,
-                        })],
-                        state,
-                    ));
+                    vec![Ok(LanguageModelCompletionEvent::Thinking {
+                        text: thinking,
+                        signature: None,
+                    })]
                 }
                 ContentDelta::SignatureDelta { signature } => {
-                    return Some((
-                        vec![Ok(LanguageModelCompletionEvent::Thinking {
-                            text: "".to_string(),
-                            signature: Some(signature),
-                        })],
-                        state,
-                    ));
+                    vec![Ok(LanguageModelCompletionEvent::Thinking {
+                        text: "".to_string(),
+                        signature: Some(signature),
+                    })]
                 }
                 ContentDelta::InputJsonDelta { partial_json } => {
-                    if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
                         tool_use.input_json.push_str(&partial_json);

                         // Try to convert invalid (incomplete) JSON into
@@ -726,8 +719,7 @@ pub fn map_to_language_model_completion_events(
                         if let Ok(input) = serde_json::Value::from_str(
                             &partial_json_fixer::fix_json(&tool_use.input_json),
                         ) {
-                            return Some((
-                                vec![Ok(LanguageModelCompletionEvent::ToolUse(
+                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
                                 LanguageModelToolUse {
                                     id: tool_use.id.clone().into(),
                                     name: tool_use.name.clone().into(),
@@ -735,15 +727,14 @@ pub fn map_to_language_model_completion_events(
                                     raw_input: tool_use.input_json.clone(),
                                     input,
                                 },
-                            ))],
-                                state,
-                            ));
+                            ))];
                         }
                     }
+                    return vec![];
                 }
             },
             Event::ContentBlockStop { index } => {
-                if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
                     let input_json = tool_use.input_json.trim();
                     let input_value = if input_json.is_empty() {
                         Ok(serde_json::Value::Object(serde_json::Map::default()))
@@ -760,84 +751,64 @@ pub fn map_to_language_model_completion_events(
                                 raw_input: tool_use.input_json.clone(),
                             },
                         )),
-                        Err(json_parse_err) => {
-                            Err(LanguageModelCompletionError::BadInputJson {
-                                id: tool_use.id.into(),
-                                tool_name: tool_use.name.into(),
-                                raw_input: input_json.into(),
-                                json_parse_error: json_parse_err.to_string(),
-                            })
-                        }
+                        Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_use.id.into(),
+                            tool_name: tool_use.name.into(),
+                            raw_input: input_json.into(),
+                            json_parse_error: json_parse_err.to_string(),
+                        }),
                     };
-                    return Some((vec![event_result], state));
+                    vec![event_result]
+                } else {
+                    Vec::new()
                 }
             }
             Event::MessageStart { message } => {
-                update_usage(&mut state.usage, &message.usage);
-                return Some((
-                    vec![
-                        Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
-                            &state.usage,
-                        ))),
-                        Ok(LanguageModelCompletionEvent::StartMessage {
-                            message_id: message.id,
-                        }),
-                    ],
-                    state,
-                ));
+                update_usage(&mut self.usage, &message.usage);
+                vec![
+                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+                        &self.usage,
+                    ))),
+                    Ok(LanguageModelCompletionEvent::StartMessage {
+                        message_id: message.id,
+                    }),
+                ]
             }
             Event::MessageDelta { delta, usage } => {
-                update_usage(&mut state.usage, &usage);
+                update_usage(&mut self.usage, &usage);
                 if let Some(stop_reason) = delta.stop_reason.as_deref() {
-                    state.stop_reason = match stop_reason {
+                    self.stop_reason = match stop_reason {
                         "end_turn" => StopReason::EndTurn,
                         "max_tokens" => StopReason::MaxTokens,
                         "tool_use" => StopReason::ToolUse,
                         _ => {
-                            log::error!(
-                                "Unexpected anthropic stop_reason: {stop_reason}"
-                            );
+                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
                             StopReason::EndTurn
                         }
                     };
                 }
-                return Some((
-                    vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
-                        convert_usage(&state.usage),
-                    ))],
-                    state,
-                ));
+                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+                    convert_usage(&self.usage),
+                ))]
             }
             Event::MessageStop => {
-                return Some((
-                    vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
-                    state,
-                ));
+                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
             }
             Event::Error { error } => {
-                return Some((
-                    vec![Err(LanguageModelCompletionError::Other(anyhow!(
-                        AnthropicError::ApiError(error)
-                    )))],
-                    state,
-                ));
+                vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                    AnthropicError::ApiError(error)
+                )))]
             }
-            _ => {}
-                },
-                Err(err) => {
-                    return Some((
-                        vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
-                        state,
-                    ));
-                }
-            }
-        }
-        None
-    },
-    )
-    .flat_map(futures::stream::iter)
+            _ => Vec::new(),
+        }
+    }
+}
+
+struct RawToolUse {
+    id: String,
+    name: String,
+    input_json: String,
 }

 pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
```
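
The `AnthropicEventMapper` above follows a simple shape: a struct holds accumulated state, `map_event` turns one provider event into zero or more output events, and `map_stream` lifts that over a stream with `flat_map`. The same shape is repeated for the Google and OpenAI providers below. A minimal stand-alone sketch of the pattern, with a hypothetical `CountingMapper` and only the `futures` crate assumed:

```rust
use futures::{Stream, StreamExt, executor::block_on, stream};

// Hypothetical mapper: it threads state across events and may emit
// zero or more outputs per input, like the provider mappers above.
struct CountingMapper {
    seen: usize,
}

impl CountingMapper {
    fn map_event(&mut self, event: u32) -> Vec<String> {
        self.seen += 1;
        vec![format!("event #{}: {event}", self.seen)]
    }

    fn map_stream(mut self, events: impl Stream<Item = u32>) -> impl Stream<Item = String> {
        // flat_map lets one input fan out into several outputs while the
        // mapper's state is moved into (and mutated by) the closure.
        events.flat_map(move |event| stream::iter(self.map_event(event)))
    }
}

fn main() {
    let out: Vec<String> = block_on(
        CountingMapper { seen: 0 }
            .map_stream(stream::iter([10u32, 20]))
            .collect(),
    );
    assert_eq!(out, ["event #1: 10", "event #2: 20"]);
}
```

Separating `map_event` from the stream plumbing is what lets the cloud provider reuse each mapper on its own wrapped event stream instead of duplicating the parsing logic.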

```diff
@@ -1,11 +1,10 @@
-use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
+use anthropic::{AnthropicModelMode, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::{Client, UserStore, zed_urls};
 use collections::BTreeMap;
 use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
 use futures::{
-    AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
-    stream::BoxStream,
+    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
 };
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
@@ -14,7 +13,7 @@ use language_model::{
     LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
-    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+    ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -26,6 +25,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use settings::{Settings, SettingsStore};
 use smol::Timer;
 use smol::io::{AsyncReadExt, BufReader};
+use std::pin::Pin;
 use std::str::FromStr as _;
 use std::{
     sync::{Arc, LazyLock},
@@ -41,9 +41,9 @@ use zed_llm_client::{
 };

 use crate::AllLanguageModelSettings;
-use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
-use crate::provider::google::into_google;
-use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
+use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
+use crate::provider::google::{GoogleEventMapper, into_google};
+use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

 pub const PROVIDER_NAME: &str = "Zed";
@@ -518,7 +518,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
+    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
         let http_client = &client.http_client();

         let mut token = llm_api_token.acquire(&client).await?;
@@ -536,13 +536,18 @@
             let request = request_builder
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {token}"))
+                .header("x-zed-client-supports-queueing", "true")
                 .body(serde_json::to_string(&body)?.into())?;
             let mut response = http_client.send(request).await?;
             let status = response.status();
             if status.is_success() {
+                let includes_queue_events = response
+                    .headers()
+                    .get("x-zed-server-supports-queueing")
+                    .is_some();
                 let usage = RequestUsage::from_headers(response.headers()).ok();

-                return Ok((response, usage));
+                return Ok((response, usage, includes_queue_events));
             } else if response
                 .headers()
                 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
@@ -782,7 +787,7 @@ impl LanguageModel for CloudLanguageModel {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -811,9 +816,11 @@
                         Err(err) => anyhow!(err),
                     })?;

+                    let mut mapper = AnthropicEventMapper::new();
                     Ok((
-                        crate::provider::anthropic::map_to_language_model_completion_events(
-                            Box::pin(response_lines(response).map_err(AnthropicError::Other)),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
                         ),
                         usage,
                     ))
@@ -829,7 +836,7 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -842,9 +849,12 @@
                         },
                     )
                     .await?;
+
+                    let mut mapper = OpenAiEventMapper::new();
                     Ok((
-                        crate::provider::open_ai::map_to_language_model_completion_events(
-                            Box::pin(response_lines(response)),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
                         ),
                         usage,
                     ))
@@ -860,7 +870,7 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -873,10 +883,12 @@
                         },
                     )
                     .await?;
+
+                    let mut mapper = GoogleEventMapper::new();
                     Ok((
-                        crate::provider::google::map_to_language_model_completion_events(Box::pin(
-                            response_lines(response),
-                        )),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
+                        ),
                         usage,
                     ))
                 });
@@ -890,16 +902,54 @@
     }
 }

+#[derive(Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CloudCompletionEvent<T> {
+    Queue(QueueState),
+    Event(T),
+}
+
+fn map_cloud_completion_events<T, F>(
+    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
+    mut map_callback: F,
+) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+where
+    T: DeserializeOwned + 'static,
+    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+        + Send
+        + 'static,
+{
+    stream
+        .flat_map(move |event| {
+            futures::stream::iter(match event {
+                Err(error) => {
+                    vec![Err(LanguageModelCompletionError::Other(error))]
+                }
+                Ok(CloudCompletionEvent::Queue(event)) => {
+                    vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
+                }
+                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
+            })
+        })
+        .boxed()
+}
+
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
-) -> impl Stream<Item = Result<T>> {
+    includes_queue_events: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::try_unfold(
         (String::new(), BufReader::new(response.into_body())),
-        move |(mut line, mut body)| async {
+        move |(mut line, mut body)| async move {
             match body.read_line(&mut line).await {
                 Ok(0) => Ok(None),
                 Ok(_) => {
-                    let event: T = serde_json::from_str(&line)?;
+                    let event = if includes_queue_events {
+                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
+                    } else {
+                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
+                    };
                     line.clear();
                     Ok(Some((event, (line, body))))
                 }
```
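
The protocol upgrade above is negotiated entirely through headers: the client always sends `x-zed-client-supports-queueing: true`, and it only treats response lines as wrapped `CloudCompletionEvent`s when the server answers with `x-zed-server-supports-queueing`, so old servers keep working unchanged. A minimal sketch of the per-line gating (using `serde_json::Value` as the upstream event type; `parse_line` is an illustrative helper, not the PR's API):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum QueueState {
    Queued { position: usize },
    Started,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum CloudCompletionEvent<T> {
    Queue(QueueState),
    Event(T),
}

// When the server didn't opt in, every line is a bare upstream event
// and gets wrapped in `Event` after parsing, so downstream code sees
// a single stream type either way.
fn parse_line(
    line: &str,
    includes_queue_events: bool,
) -> serde_json::Result<CloudCompletionEvent<serde_json::Value>> {
    if includes_queue_events {
        serde_json::from_str(line)
    } else {
        Ok(CloudCompletionEvent::Event(serde_json::from_str(line)?))
    }
}

fn main() {
    let old_server = parse_line(r#"{"text": "hi"}"#, false).unwrap();
    assert!(matches!(old_server, CloudCompletionEvent::Event(_)));

    let new_server = parse_line(r#"{"queue": {"status": "started"}}"#, true).unwrap();
    assert!(matches!(
        new_server,
        CloudCompletionEvent::Queue(QueueState::Started)
    ));
}
```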

```diff
@@ -24,7 +24,10 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
 use std::pin::Pin;
-use std::sync::Arc;
+use std::sync::{
+    Arc,
+    atomic::{self, AtomicU64},
+};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
@@ -371,7 +374,7 @@ impl LanguageModel for GoogleLanguageModel {
             let response = request
                 .await
                 .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
-            Ok(map_to_language_model_completion_events(response))
+            Ok(GoogleEventMapper::new().map_stream(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -486,47 +489,54 @@ pub fn into_google(
     }
 }

-pub fn map_to_language_model_completion_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-    use std::sync::atomic::{AtomicU64, Ordering};
-    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
-        usage: UsageMetadata,
-        stop_reason: StopReason,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            usage: UsageMetadata::default(),
-            stop_reason: StopReason::EndTurn,
-        },
-        |mut state| async move {
-            if let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => {
+pub struct GoogleEventMapper {
+    usage: UsageMetadata,
+    stop_reason: StopReason,
+}
+
+impl GoogleEventMapper {
+    pub fn new() -> Self {
+        Self {
+            usage: UsageMetadata::default(),
+            stop_reason: StopReason::EndTurn,
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: GenerateContentResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
         let mut events: Vec<_> = Vec::new();
         let mut wants_to_use_tool = false;
         if let Some(usage_metadata) = event.usage_metadata {
-            update_usage(&mut state.usage, &usage_metadata);
+            update_usage(&mut self.usage, &usage_metadata);
             events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
-                convert_usage(&state.usage),
+                convert_usage(&self.usage),
             )))
         }
         if let Some(candidates) = event.candidates {
             for candidate in candidates {
                 if let Some(finish_reason) = candidate.finish_reason.as_deref() {
-                    state.stop_reason = match finish_reason {
+                    self.stop_reason = match finish_reason {
                         "STOP" => StopReason::EndTurn,
                         "MAX_TOKENS" => StopReason::MaxTokens,
                         _ => {
-                            log::error!(
-                                "Unexpected google finish_reason: {finish_reason}"
-                            );
+                            log::error!("Unexpected google finish_reason: {finish_reason}");
                             StopReason::EndTurn
                         }
                     };
@@ -536,16 +546,15 @@ pub fn map_to_language_model_completion_events(
                     .parts
                     .into_iter()
                     .for_each(|part| match part {
-                        Part::TextPart(text_part) => events.push(Ok(
-                            LanguageModelCompletionEvent::Text(text_part.text),
-                        )),
+                        Part::TextPart(text_part) => {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
+                        }
                         Part::InlineDataPart(_) => {}
                         Part::FunctionCallPart(function_call_part) => {
                             wants_to_use_tool = true;
-                            let name: Arc<str> =
-                                function_call_part.function_call.name.into();
+                            let name: Arc<str> = function_call_part.function_call.name.into();
                             let next_tool_id =
-                                TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
+                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                             let id: LanguageModelToolUseId =
                                 format!("{}-{}", name, next_tool_id).into();

@@ -554,10 +563,7 @@ pub fn map_to_language_model_completion_events(
                                     id,
                                     name,
                                     is_input_complete: true,
-                                    raw_input: function_call_part
-                                        .function_call
-                                        .args
-                                        .to_string(),
+                                    raw_input: function_call_part.function_call.args.to_string(),
                                     input: function_call_part.function_call.args,
                                 },
                             )));
@@ -570,24 +576,11 @@ pub fn map_to_language_model_completion_events(
         // Even when Gemini wants to use a Tool, the API
         // responds with `finish_reason: STOP`
         if wants_to_use_tool {
-            state.stop_reason = StopReason::ToolUse;
+            self.stop_reason = StopReason::ToolUse;
         }
-        events.push(Ok(LanguageModelCompletionEvent::Stop(state.stop_reason)));
-        return Some((events, state));
-    }
-    Err(err) => {
-        return Some((
-            vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
-            state,
-        ));
-    }
-    }
-    }
-    None
-    },
-    )
-    .flat_map(futures::stream::iter)
+        events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
+        events
+    }
 }

 pub fn count_google_tokens(
```
pub fn count_google_tokens( pub fn count_google_tokens(

View file

@ -330,7 +330,10 @@ impl LanguageModel for OpenAiLanguageModel {
> { > {
let request = into_open_ai(request, &self.model, self.max_output_tokens()); let request = into_open_ai(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx); let completions = self.stream_completion(request, cx);
async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) } async move {
let mapper = OpenAiEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
}
.boxed() .boxed()
} }
} }
@ -422,37 +425,38 @@ pub fn into_open_ai(
} }
} }
pub fn map_to_language_model_completion_events( pub struct OpenAiEventMapper {
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)]
struct RawToolCall {
id: String,
name: String,
arguments: String,
}
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
tool_calls_by_index: HashMap<usize, RawToolCall>, tool_calls_by_index: HashMap<usize, RawToolCall>,
} }
futures::stream::unfold( impl OpenAiEventMapper {
State { pub fn new() -> Self {
events, Self {
tool_calls_by_index: HashMap::default(), tool_calls_by_index: HashMap::default(),
}, }
|mut state| async move { }
if let Some(event) = state.events.next().await {
match event { pub fn map_stream(
Ok(event) => { mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
})
})
}
pub fn map_event(
&mut self,
event: ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else { let Some(choice) = event.choices.first() else {
return Some(( return vec![Err(LanguageModelCompletionError::Other(anyhow!(
vec![Err(LanguageModelCompletionError::Other(anyhow!(
"Response contained no choices" "Response contained no choices"
)))], )))];
state,
));
}; };
let mut events = Vec::new(); let mut events = Vec::new();
@ -462,10 +466,7 @@ pub fn map_to_language_model_completion_events(
if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
for tool_call in tool_calls { for tool_call in tool_calls {
let entry = state let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
.tool_calls_by_index
.entry(tool_call.index)
.or_default();
if let Some(tool_id) = tool_call.id.clone() { if let Some(tool_id) = tool_call.id.clone() {
entry.id = tool_id; entry.id = tool_id;
@ -485,15 +486,11 @@ pub fn map_to_language_model_completion_events(
match choice.finish_reason.as_deref() { match choice.finish_reason.as_deref() {
Some("stop") => { Some("stop") => {
events.push(Ok(LanguageModelCompletionEvent::Stop( events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
StopReason::EndTurn,
)));
} }
Some("tool_calls") => { Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map( events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
|(_, tool_call)| match serde_json::Value::from_str( match serde_json::Value::from_str(&tool_call.arguments) {
&tool_call.arguments,
) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse { LanguageModelToolUse {
id: tool_call.id.clone().into(), id: tool_call.id.clone().into(),
@ -503,42 +500,33 @@ pub fn map_to_language_model_completion_events(
raw_input: tool_call.arguments.clone(), raw_input: tool_call.arguments.clone(),
}, },
)), )),
Err(error) => { Err(error) => Err(LanguageModelCompletionError::BadInputJson {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(), id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(), tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(), raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(), json_parse_error: error.to_string(),
}) }),
} }
}, }));
));
events.push(Ok(LanguageModelCompletionEvent::Stop( events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
StopReason::ToolUse,
)));
} }
Some(stop_reason) => { Some(stop_reason) => {
log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
events.push(Ok(LanguageModelCompletionEvent::Stop( events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
StopReason::EndTurn,
)));
} }
None => {} None => {}
} }
return Some((events, state)); events
}
Err(err) => {
return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
}
} }
} }
None #[derive(Default)]
}, struct RawToolCall {
) id: String,
.flat_map(futures::stream::iter) name: String,
arguments: String,
} }
pub fn count_open_ai_tokens( pub fn count_open_ai_tokens(