diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs
index af62fedc5c..58822247ef 100644
--- a/crates/agent/src/active_thread.rs
+++ b/crates/agent/src/active_thread.rs
@@ -4,8 +4,8 @@ use crate::context_store::ContextStore;
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::message_editor::insert_message_creases;
 use crate::thread::{
-    LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError,
-    ThreadEvent, ThreadFeedback,
+    LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, QueueState, Thread,
+    ThreadError, ThreadEvent, ThreadFeedback,
 };
 use crate::thread_store::{RulesLoadingError, ThreadStore};
 use crate::tool_use::{PendingToolUseStatus, ToolUse};
@@ -1733,8 +1733,27 @@ impl ActiveThread {
         let show_feedback = thread.is_turn_end(ix);
 
-        let generating_label = (is_generating && is_last_message)
-            .then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
+        let generating_label = is_last_message
+            .then(|| match (thread.queue_state(), is_generating) {
+                (Some(QueueState::Sending), _) => Some(
+                    AnimatedLabel::new("Sending")
+                        .size(LabelSize::Small)
+                        .into_any_element(),
+                ),
+                (Some(QueueState::Queued { position }), _) => Some(
+                    Label::new(format!("Queue position: {position}"))
+                        .size(LabelSize::Small)
+                        .color(Color::Muted)
+                        .into_any_element(),
+                ),
+                (_, true) => Some(
+                    AnimatedLabel::new("Generating")
+                        .size(LabelSize::Small)
+                        .into_any_element(),
+                ),
+                _ => None,
+            })
+            .flatten();
 
         let editing_message_state = self
             .editing_message
@@ -2105,7 +2124,7 @@ impl ActiveThread {
                 parent.child(self.render_rules_item(cx))
             })
             .child(styled_message)
-            .when(generating_label.is_some(), |this| {
+            .when_some(generating_label, |this, generating_label| {
                 this.child(
                     h_flex()
                         .h_8()
@@ -2113,7 +2132,7 @@
                         .mb_4()
                         .ml_4()
                         .py_1p5()
-                        .child(generating_label.unwrap()),
+                        .child(generating_label),
                 )
             })
             .when(show_feedback, move |parent| {
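The render change above consults the thread's queue state before the raw `is_generating` flag, so "Sending" and "Queue position: N" take precedence over "Generating" for the last message. A minimal standalone sketch of that precedence, using plain strings in place of the gpui label elements (the helper name is hypothetical):

```rust
#[derive(Debug, Clone, Copy)]
enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}

// Mirrors the match in `render`: a present queue state wins over the
// `is_generating` flag, so a queued request never shows "Generating".
fn status_text(queue_state: Option<QueueState>, is_generating: bool) -> Option<String> {
    match (queue_state, is_generating) {
        (Some(QueueState::Sending), _) => Some("Sending".into()),
        (Some(QueueState::Queued { position }), _) => Some(format!("Queue position: {position}")),
        (_, true) => Some("Generating".into()),
        _ => None,
    }
}

fn main() {
    assert_eq!(status_text(Some(QueueState::Queued { position: 2 }), true).as_deref(), Some("Queue position: 2"));
    assert_eq!(status_text(Some(QueueState::Started), true).as_deref(), Some("Generating"));
    assert_eq!(status_text(None, false), None);
}
```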
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 5fb85c4ede..5f0561bb66 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -320,6 +320,13 @@ fn default_completion_mode(cx: &App) -> CompletionMode {
     }
 }
 
+#[derive(Debug, Clone, Copy)]
+pub enum QueueState {
+    Sending,
+    Queued { position: usize },
+    Started,
+}
+
 /// A thread of conversation with the LLM.
 pub struct Thread {
     id: ThreadId,
@@ -625,6 +632,12 @@ impl Thread {
         !self.pending_completions.is_empty() || !self.all_tools_finished()
     }
 
+    pub fn queue_state(&self) -> Option<QueueState> {
+        self.pending_completions
+            .first()
+            .map(|pending_completion| pending_completion.queue_state)
+    }
+
     pub fn tools(&self) -> &Entity<ToolWorkingSet> {
         &self.tools
     }
@@ -1470,6 +1483,20 @@ impl Thread {
                         });
                     }
                 }
+                LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
+                    if let Some(completion) = thread
+                        .pending_completions
+                        .iter_mut()
+                        .find(|completion| completion.id == pending_completion_id)
+                    {
+                        completion.queue_state = match queue_event {
+                            language_model::QueueState::Queued { position } => {
+                                QueueState::Queued { position }
+                            }
+                            language_model::QueueState::Started => QueueState::Started,
+                        }
+                    }
+                }
             }
 
             thread.touch_updated_at();
@@ -1590,6 +1617,7 @@ impl Thread {
 
         self.pending_completions.push(PendingCompletion {
             id: pending_completion_id,
+            queue_state: QueueState::Sending,
             _task: task,
         });
     }
@@ -2499,6 +2527,7 @@ impl EventEmitter<ThreadEvent> for Thread {}
 
 struct PendingCompletion {
     id: usize,
+    queue_state: QueueState,
     _task: Task<()>,
 }
 
diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs
index f29bcbe753..6beeaf3461 100644
--- a/crates/assistant_context_editor/src/context.rs
+++ b/crates/assistant_context_editor/src/context.rs
@@ -2371,6 +2371,7 @@ impl AssistantContext {
                     });
 
                     match event {
+                        LanguageModelCompletionEvent::QueueUpdate { .. } => {}
                         LanguageModelCompletionEvent::StartMessage { .. } => {}
                         LanguageModelCompletionEvent::Stop(reason) => {
                             stop_reason = reason;
diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs
index 0c477143e6..d3c5fdb29c 100644
--- a/crates/eval/src/instance.rs
+++ b/crates/eval/src/instance.rs
@@ -1017,7 +1017,8 @@ pub fn response_events_to_markdown(
             }
             Ok(
                 LanguageModelCompletionEvent::UsageUpdate(_)
-                | LanguageModelCompletionEvent::StartMessage { .. },
+                | LanguageModelCompletionEvent::StartMessage { .. }
+                | LanguageModelCompletionEvent::QueueUpdate { .. },
             ) => {}
             Err(error) => {
                 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
@@ -1092,6 +1093,7 @@ impl ThreadDialog {
 
                         // Skip these
                         Ok(LanguageModelCompletionEvent::UsageUpdate(_))
+                        | Ok(LanguageModelCompletionEvent::QueueUpdate { .. })
                        | Ok(LanguageModelCompletionEvent::StartMessage { .. })
                        | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
 
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 4c9e918756..1146bbc137 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -64,9 +64,17 @@ pub struct LanguageModelCacheConfiguration {
     pub min_total_token: usize,
 }
 
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
+#[serde(tag = "status", rename_all = "snake_case")]
+pub enum QueueState {
+    Queued { position: usize },
+    Started,
+}
+
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
+    QueueUpdate(QueueState),
     Stop(StopReason),
     Text(String),
    Thinking {
@@ -349,6 +357,7 @@ pub trait LanguageModel: Send + Sync {
             let last_token_usage = last_token_usage.clone();
             async move {
                 match result {
+                    Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None,
                     Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                     Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                     Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
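The serde attributes on the `QueueState` enum added to language_model.rs above make it internally tagged, which pins down the JSON shape the server must emit for queue updates. A quick round-trip sketch (standalone, with its own copy of the enum; this follows from serde's documented tagging behavior rather than any server spec):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum QueueState {
    Queued { position: usize },
    Started,
}

fn main() {
    // `tag = "status"` folds the variant name into the object itself.
    let queued = serde_json::to_string(&QueueState::Queued { position: 3 }).unwrap();
    assert_eq!(queued, r#"{"status":"queued","position":3}"#);

    let started: QueueState = serde_json::from_str(r#"{"status":"started"}"#).unwrap();
    assert_eq!(started, QueueState::Started);
}
```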
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 52f9806832..a22a027534 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -469,7 +469,7 @@ impl LanguageModel for AnthropicModel {
                 Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
                 Err(err) => anyhow!(err),
             })?;
-            Ok(map_to_language_model_completion_events(response))
+            Ok(AnthropicEventMapper::new().map_stream(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -629,215 +629,186 @@ pub fn into_anthropic(
     }
 }
 
-pub fn map_to_language_model_completion_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-    struct RawToolUse {
-        id: String,
-        name: String,
-        input_json: String,
-    }
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-        tool_uses_by_index: HashMap<usize, RawToolUse>,
-        usage: Usage,
-        stop_reason: StopReason,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            tool_uses_by_index: HashMap::default(),
-            usage: Usage::default(),
-            stop_reason: StopReason::EndTurn,
-        },
-        |mut state| async move {
-            while let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => match event {
-                        Event::ContentBlockStart {
-                            index,
-                            content_block,
-                        } => match content_block {
-                            ResponseContent::Text { text } => {
-                                return Some((
-                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
-                                    state,
-                                ));
-                            }
-                            ResponseContent::Thinking { thinking } => {
-                                return Some((
-                                    vec![Ok(LanguageModelCompletionEvent::Thinking {
-                                        text: thinking,
-                                        signature: None,
-                                    })],
-                                    state,
-                                ));
-                            }
-                            ResponseContent::RedactedThinking { .. } => {
-                                // Redacted thinking is encrypted and not accessible to the user, see:
-                                // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
-                            }
-                            ResponseContent::ToolUse { id, name, .. } => {
-                                state.tool_uses_by_index.insert(
-                                    index,
-                                    RawToolUse {
-                                        id,
-                                        name,
-                                        input_json: String::new(),
-                                    },
-                                );
-                            }
-                        },
-                        Event::ContentBlockDelta { index, delta } => match delta {
-                            ContentDelta::TextDelta { text } => {
-                                return Some((
-                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
-                                    state,
-                                ));
-                            }
-                            ContentDelta::ThinkingDelta { thinking } => {
-                                return Some((
-                                    vec![Ok(LanguageModelCompletionEvent::Thinking {
-                                        text: thinking,
-                                        signature: None,
-                                    })],
-                                    state,
-                                ));
-                            }
-                            ContentDelta::SignatureDelta { signature } => {
-                                return Some((
-                                    vec![Ok(LanguageModelCompletionEvent::Thinking {
-                                        text: "".to_string(),
-                                        signature: Some(signature),
-                                    })],
-                                    state,
-                                ));
-                            }
-                            ContentDelta::InputJsonDelta { partial_json } => {
-                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
-                                    tool_use.input_json.push_str(&partial_json);
-
-                                    // Try to convert invalid (incomplete) JSON into
-                                    // valid JSON that serde can accept, e.g. by closing
-                                    // unclosed delimiters. This way, we can update the
-                                    // UI with whatever has been streamed back so far.
-                                    if let Ok(input) = serde_json::Value::from_str(
-                                        &partial_json_fixer::fix_json(&tool_use.input_json),
-                                    ) {
-                                        return Some((
-                                            vec![Ok(LanguageModelCompletionEvent::ToolUse(
-                                                LanguageModelToolUse {
-                                                    id: tool_use.id.clone().into(),
-                                                    name: tool_use.name.clone().into(),
-                                                    is_input_complete: false,
-                                                    raw_input: tool_use.input_json.clone(),
-                                                    input,
-                                                },
-                                            ))],
-                                            state,
-                                        ));
-                                    }
-                                }
-                            }
-                        },
-                        Event::ContentBlockStop { index } => {
-                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
-                                let input_json = tool_use.input_json.trim();
-                                let input_value = if input_json.is_empty() {
-                                    Ok(serde_json::Value::Object(serde_json::Map::default()))
-                                } else {
-                                    serde_json::Value::from_str(input_json)
-                                };
-                                let event_result = match input_value {
-                                    Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
-                                        LanguageModelToolUse {
-                                            id: tool_use.id.into(),
-                                            name: tool_use.name.into(),
-                                            is_input_complete: true,
-                                            input,
-                                            raw_input: tool_use.input_json.clone(),
-                                        },
-                                    )),
-                                    Err(json_parse_err) => {
-                                        Err(LanguageModelCompletionError::BadInputJson {
-                                            id: tool_use.id.into(),
-                                            tool_name: tool_use.name.into(),
-                                            raw_input: input_json.into(),
-                                            json_parse_error: json_parse_err.to_string(),
-                                        })
-                                    }
-                                };
-
-                                return Some((vec![event_result], state));
-                            }
-                        }
-                        Event::MessageStart { message } => {
-                            update_usage(&mut state.usage, &message.usage);
-                            return Some((
-                                vec![
-                                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
-                                        &state.usage,
-                                    ))),
-                                    Ok(LanguageModelCompletionEvent::StartMessage {
-                                        message_id: message.id,
-                                    }),
-                                ],
-                                state,
-                            ));
-                        }
-                        Event::MessageDelta { delta, usage } => {
-                            update_usage(&mut state.usage, &usage);
-                            if let Some(stop_reason) = delta.stop_reason.as_deref() {
-                                state.stop_reason = match stop_reason {
-                                    "end_turn" => StopReason::EndTurn,
-                                    "max_tokens" => StopReason::MaxTokens,
-                                    "tool_use" => StopReason::ToolUse,
-                                    _ => {
-                                        log::error!(
-                                            "Unexpected anthropic stop_reason: {stop_reason}"
-                                        );
-                                        StopReason::EndTurn
-                                    }
-                                };
-                            }
-                            return Some((
-                                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
-                                    convert_usage(&state.usage),
-                                ))],
-                                state,
-                            ));
-                        }
-                        Event::MessageStop => {
-                            return Some((
-                                vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
-                                state,
-                            ));
-                        }
-                        Event::Error { error } => {
-                            return Some((
-                                vec![Err(LanguageModelCompletionError::Other(anyhow!(
-                                    AnthropicError::ApiError(error)
-                                )))],
-                                state,
-                            ));
-                        }
-                        _ => {}
-                    },
-                    Err(err) => {
-                        return Some((
-                            vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
-                            state,
-                        ));
-                    }
-                }
-            }
-
-            None
-        },
-    )
-    .flat_map(futures::stream::iter)
-}
+pub struct AnthropicEventMapper {
+    tool_uses_by_index: HashMap<usize, RawToolUse>,
+    usage: Usage,
+    stop_reason: StopReason,
+}
+
+impl AnthropicEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_uses_by_index: HashMap::default(),
+            usage: Usage::default(),
+            stop_reason: StopReason::EndTurn,
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: Event,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        match event {
+            Event::ContentBlockStart {
+                index,
+                content_block,
+            } => match content_block {
+                ResponseContent::Text { text } => {
+                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
+                }
+                ResponseContent::Thinking { thinking } => {
+                    vec![Ok(LanguageModelCompletionEvent::Thinking {
+                        text: thinking,
+                        signature: None,
+                    })]
+                }
+                ResponseContent::RedactedThinking { .. } => {
+                    // Redacted thinking is encrypted and not accessible to the user, see:
+                    // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
+                    Vec::new()
+                }
+                ResponseContent::ToolUse { id, name, .. } => {
+                    self.tool_uses_by_index.insert(
+                        index,
+                        RawToolUse {
+                            id,
+                            name,
+                            input_json: String::new(),
+                        },
+                    );
+                    Vec::new()
+                }
+            },
+            Event::ContentBlockDelta { index, delta } => match delta {
+                ContentDelta::TextDelta { text } => {
+                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
+                }
+                ContentDelta::ThinkingDelta { thinking } => {
+                    vec![Ok(LanguageModelCompletionEvent::Thinking {
+                        text: thinking,
+                        signature: None,
+                    })]
+                }
+                ContentDelta::SignatureDelta { signature } => {
+                    vec![Ok(LanguageModelCompletionEvent::Thinking {
+                        text: "".to_string(),
+                        signature: Some(signature),
+                    })]
+                }
+                ContentDelta::InputJsonDelta { partial_json } => {
+                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
+                        tool_use.input_json.push_str(&partial_json);
+
+                        // Try to convert invalid (incomplete) JSON into
+                        // valid JSON that serde can accept, e.g. by closing
+                        // unclosed delimiters. This way, we can update the
+                        // UI with whatever has been streamed back so far.
+                        if let Ok(input) = serde_json::Value::from_str(
+                            &partial_json_fixer::fix_json(&tool_use.input_json),
+                        ) {
+                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
+                                LanguageModelToolUse {
+                                    id: tool_use.id.clone().into(),
+                                    name: tool_use.name.clone().into(),
+                                    is_input_complete: false,
+                                    raw_input: tool_use.input_json.clone(),
+                                    input,
+                                },
+                            ))];
+                        }
+                    }
+                    return vec![];
+                }
+            },
+            Event::ContentBlockStop { index } => {
+                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
+                    let input_json = tool_use.input_json.trim();
+                    let input_value = if input_json.is_empty() {
+                        Ok(serde_json::Value::Object(serde_json::Map::default()))
+                    } else {
+                        serde_json::Value::from_str(input_json)
+                    };
+                    let event_result = match input_value {
+                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: tool_use.id.into(),
+                                name: tool_use.name.into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: tool_use.input_json.clone(),
+                            },
+                        )),
+                        Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_use.id.into(),
+                            tool_name: tool_use.name.into(),
+                            raw_input: input_json.into(),
+                            json_parse_error: json_parse_err.to_string(),
+                        }),
+                    };
+
+                    vec![event_result]
+                } else {
+                    Vec::new()
+                }
+            }
+            Event::MessageStart { message } => {
+                update_usage(&mut self.usage, &message.usage);
+                vec![
+                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+                        &self.usage,
+                    ))),
+                    Ok(LanguageModelCompletionEvent::StartMessage {
+                        message_id: message.id,
+                    }),
+                ]
+            }
+            Event::MessageDelta { delta, usage } => {
+                update_usage(&mut self.usage, &usage);
+                if let Some(stop_reason) = delta.stop_reason.as_deref() {
+                    self.stop_reason = match stop_reason {
+                        "end_turn" => StopReason::EndTurn,
+                        "max_tokens" => StopReason::MaxTokens,
+                        "tool_use" => StopReason::ToolUse,
+                        _ => {
+                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
+                            StopReason::EndTurn
+                        }
+                    };
+                }
+                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+                    convert_usage(&self.usage),
+                ))]
+            }
+            Event::MessageStop => {
+                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
+            }
+            Event::Error { error } => {
+                vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                    AnthropicError::ApiError(error)
+                )))]
+            }
+            _ => Vec::new(),
+        }
+    }
+}
+
+struct RawToolUse {
+    id: String,
+    name: String,
+    input_json: String,
+}
 
 pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
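The payoff of this rewrite is that the mapping logic no longer owns the stream: `map_event` is a plain synchronous call on a stateful mapper, so any transport (the direct Anthropic client here, or Zed's cloud proxy below) can drive the same state machine, and `map_stream` is just `flat_map` over it. A toy, self-contained sketch of that shape (not Zed code; a running sum stands in for the tool-use/usage state the real mappers carry):

```rust
use futures::{executor::block_on, stream, StreamExt};

struct RunningSum {
    total: u32,
}

impl RunningSum {
    // Same shape as AnthropicEventMapper::map_event: mutate state,
    // return zero or more output events per input event.
    fn map_event(&mut self, n: u32) -> Vec<u32> {
        self.total += n;
        vec![self.total]
    }
}

fn main() {
    let mut mapper = RunningSum { total: 0 };
    // The `map_stream` pattern: flatten each per-event Vec into the stream.
    let outputs: Vec<u32> = block_on(
        stream::iter([1, 2, 3])
            .flat_map(move |n| stream::iter(mapper.map_event(n)))
            .collect(),
    );
    assert_eq!(outputs, vec![1, 3, 6]);
}
```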
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index d0f8ba275a..556be2c75d 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -1,11 +1,10 @@
-use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
+use anthropic::{AnthropicModelMode, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::{Client, UserStore, zed_urls};
 use collections::BTreeMap;
 use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
 use futures::{
-    AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
-    stream::BoxStream,
+    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
 };
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
@@ -14,7 +13,7 @@ use language_model::{
     LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
-    ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+    ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -26,6 +25,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use settings::{Settings, SettingsStore};
 use smol::Timer;
 use smol::io::{AsyncReadExt, BufReader};
+use std::pin::Pin;
 use std::str::FromStr as _;
 use std::{
     sync::{Arc, LazyLock},
@@ -41,9 +41,9 @@ use zed_llm_client::{
 };
 
 use crate::AllLanguageModelSettings;
-use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
-use crate::provider::google::into_google;
-use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
+use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
+use crate::provider::google::{GoogleEventMapper, into_google};
+use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
 
 pub const PROVIDER_NAME: &str = "Zed";
 
@@ -518,7 +518,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
+    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
@@ -536,13 +536,18 @@ impl CloudLanguageModel {
             let request = request_builder
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {token}"))
+                .header("x-zed-client-supports-queueing", "true")
                 .body(serde_json::to_string(&body)?.into())?;
             let mut response = http_client.send(request).await?;
             let status = response.status();
             if status.is_success() {
+                let includes_queue_events = response
+                    .headers()
+                    .get("x-zed-server-supports-queueing")
+                    .is_some();
                 let usage = RequestUsage::from_headers(response.headers()).ok();
-                return Ok((response, usage));
+                return Ok((response, usage, includes_queue_events));
             } else if response
                 .headers()
                 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
@@ -782,7 +787,7 @@ impl LanguageModel for CloudLanguageModel {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -811,9 +816,11 @@ impl LanguageModel for CloudLanguageModel {
                             Err(err) => anyhow!(err),
                         })?;
 
+                    let mut mapper = AnthropicEventMapper::new();
                     Ok((
-                        crate::provider::anthropic::map_to_language_model_completion_events(
-                            Box::pin(response_lines(response).map_err(AnthropicError::Other)),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
                         ),
                         usage,
                     ))
@@ -829,7 +836,7 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -842,9 +849,12 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
+
+                    let mut mapper = OpenAiEventMapper::new();
                     Ok((
-                        crate::provider::open_ai::map_to_language_model_completion_events(
-                            Box::pin(response_lines(response)),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
                         ),
                         usage,
                     ))
@@ -860,7 +870,7 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage) = Self::perform_llm_completion(
+                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -873,10 +883,12 @@ impl LanguageModel for CloudLanguageModel {
                         },
                    )
                    .await?;
+                    let mut mapper = GoogleEventMapper::new();
                     Ok((
-                        crate::provider::google::map_to_language_model_completion_events(Box::pin(
-                            response_lines(response),
-                        )),
+                        map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_queue_events)),
+                            move |event| mapper.map_event(event),
+                        ),
                         usage,
                     ))
                 });
@@ -890,16 +902,54 @@ impl LanguageModel for CloudLanguageModel {
     }
 }
 
+#[derive(Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CloudCompletionEvent<T> {
+    Queue(QueueState),
+    Event(T),
+}
+
+fn map_cloud_completion_events<T, F>(
+    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
+    mut map_callback: F,
+) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+where
+    T: DeserializeOwned + 'static,
+    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+        + Send
+        + 'static,
+{
+    stream
+        .flat_map(move |event| {
+            futures::stream::iter(match event {
+                Err(error) => {
+                    vec![Err(LanguageModelCompletionError::Other(error))]
+                }
+                Ok(CloudCompletionEvent::Queue(event)) => {
+                    vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
+                }
+                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
+            })
+        })
+        .boxed()
+}
+
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
-) -> impl Stream<Item = Result<T>> {
+    includes_queue_events: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::try_unfold(
         (String::new(), BufReader::new(response.into_body())),
-        move |(mut line, mut body)| async {
+        move |(mut line, mut body)| async move {
             match body.read_line(&mut line).await {
                 Ok(0) => Ok(None),
                 Ok(_) => {
-                    let event: T = serde_json::from_str(&line)?;
+                    let event = if includes_queue_events {
+                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
+                    } else {
+                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
+                    };
+
                     line.clear();
                     Ok(Some((event, (line, body))))
                 }
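Since `CloudCompletionEvent<T>` carries no serde tag attribute, it is externally tagged: each NDJSON line from a queue-aware server wraps its payload in a `"queue"` or `"event"` key, and when the `x-zed-server-supports-queueing` response header is absent every line is parsed as a bare provider event instead. A standalone parsing sketch (the wrapper keys are inferred from serde's externally-tagged default, not from a server spec):

```rust
use serde::{Deserialize, Serialize};

// Standalone copies of the enums from cloud.rs / language_model.rs.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum CloudCompletionEvent<T> {
    Queue(QueueState),
    Event(T),
}

#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum QueueState {
    Queued { position: usize },
    Started,
}

fn main() {
    // One NDJSON line from a queue-aware server:
    let line = r#"{"queue":{"status":"queued","position":2}}"#;
    let parsed: CloudCompletionEvent<serde_json::Value> = serde_json::from_str(line).unwrap();
    assert_eq!(
        parsed,
        CloudCompletionEvent::Queue(QueueState::Queued { position: 2 })
    );

    // Provider events arrive wrapped in the `event` variant:
    let line = r#"{"event":{"type":"message_stop"}}"#;
    let parsed: CloudCompletionEvent<serde_json::Value> = serde_json::from_str(line).unwrap();
    assert!(matches!(parsed, CloudCompletionEvent::Event(_)));
}
```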
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 1bb0df310e..b5751bb359 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -24,7 +24,10 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
 use std::pin::Pin;
-use std::sync::Arc;
+use std::sync::{
+    Arc,
+    atomic::{self, AtomicU64},
+};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
@@ -371,7 +374,7 @@ impl LanguageModel for GoogleLanguageModel {
             let response = request
                 .await
                 .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
-            Ok(map_to_language_model_completion_events(response))
+            Ok(GoogleEventMapper::new().map_stream(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -486,108 +489,98 @@ pub fn into_google(
     }
 }
 
-pub fn map_to_language_model_completion_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-    use std::sync::atomic::{AtomicU64, Ordering};
-
-    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
-        usage: UsageMetadata,
-        stop_reason: StopReason,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            usage: UsageMetadata::default(),
-            stop_reason: StopReason::EndTurn,
-        },
-        |mut state| async move {
-            if let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => {
-                        let mut events: Vec<_> = Vec::new();
-                        let mut wants_to_use_tool = false;
-                        if let Some(usage_metadata) = event.usage_metadata {
-                            update_usage(&mut state.usage, &usage_metadata);
-                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
-                                convert_usage(&state.usage),
-                            )))
-                        }
-                        if let Some(candidates) = event.candidates {
-                            for candidate in candidates {
-                                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
-                                    state.stop_reason = match finish_reason {
-                                        "STOP" => StopReason::EndTurn,
-                                        "MAX_TOKENS" => StopReason::MaxTokens,
-                                        _ => {
-                                            log::error!(
-                                                "Unexpected google finish_reason: {finish_reason}"
-                                            );
-                                            StopReason::EndTurn
-                                        }
-                                    };
-                                }
-                                candidate
-                                    .content
-                                    .parts
-                                    .into_iter()
-                                    .for_each(|part| match part {
-                                        Part::TextPart(text_part) => events.push(Ok(
-                                            LanguageModelCompletionEvent::Text(text_part.text),
-                                        )),
-                                        Part::InlineDataPart(_) => {}
-                                        Part::FunctionCallPart(function_call_part) => {
-                                            wants_to_use_tool = true;
-                                            let name: Arc<str> =
-                                                function_call_part.function_call.name.into();
-                                            let next_tool_id =
-                                                TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
-                                            let id: LanguageModelToolUseId =
-                                                format!("{}-{}", name, next_tool_id).into();
-
-                                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
-                                                LanguageModelToolUse {
-                                                    id,
-                                                    name,
-                                                    is_input_complete: true,
-                                                    raw_input: function_call_part
-                                                        .function_call
-                                                        .args
-                                                        .to_string(),
-                                                    input: function_call_part.function_call.args,
-                                                },
-                                            )));
-                                        }
-                                        Part::FunctionResponsePart(_) => {}
-                                    });
-                            }
-                        }
-
-                        // Even when Gemini wants to use a Tool, the API
-                        // responds with `finish_reason: STOP`
-                        if wants_to_use_tool {
-                            state.stop_reason = StopReason::ToolUse;
-                        }
-                        events.push(Ok(LanguageModelCompletionEvent::Stop(state.stop_reason)));
-                        return Some((events, state));
-                    }
-                    Err(err) => {
-                        return Some((
-                            vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
-                            state,
-                        ));
-                    }
-                }
-            }
-
-            None
-        },
-    )
-    .flat_map(futures::stream::iter)
+pub struct GoogleEventMapper {
+    usage: UsageMetadata,
+    stop_reason: StopReason,
+}
+
+impl GoogleEventMapper {
+    pub fn new() -> Self {
+        Self {
+            usage: UsageMetadata::default(),
+            stop_reason: StopReason::EndTurn,
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: GenerateContentResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+        let mut events: Vec<_> = Vec::new();
+        let mut wants_to_use_tool = false;
+        if let Some(usage_metadata) = event.usage_metadata {
+            update_usage(&mut self.usage, &usage_metadata);
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                convert_usage(&self.usage),
+            )))
+        }
+        if let Some(candidates) = event.candidates {
+            for candidate in candidates {
+                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
+                    self.stop_reason = match finish_reason {
+                        "STOP" => StopReason::EndTurn,
+                        "MAX_TOKENS" => StopReason::MaxTokens,
+                        _ => {
+                            log::error!("Unexpected google finish_reason: {finish_reason}");
+                            StopReason::EndTurn
+                        }
+                    };
+                }
+                candidate
+                    .content
+                    .parts
+                    .into_iter()
+                    .for_each(|part| match part {
+                        Part::TextPart(text_part) => {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
+                        }
+                        Part::InlineDataPart(_) => {}
+                        Part::FunctionCallPart(function_call_part) => {
+                            wants_to_use_tool = true;
+                            let name: Arc<str> = function_call_part.function_call.name.into();
+                            let next_tool_id =
+                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
+                            let id: LanguageModelToolUseId =
+                                format!("{}-{}", name, next_tool_id).into();
+
+                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                                LanguageModelToolUse {
+                                    id,
+                                    name,
+                                    is_input_complete: true,
+                                    raw_input: function_call_part.function_call.args.to_string(),
+                                    input: function_call_part.function_call.args,
+                                },
+                            )));
+                        }
+                        Part::FunctionResponsePart(_) => {}
+                    });
+            }
+        }
+
+        // Even when Gemini wants to use a Tool, the API
+        // responds with `finish_reason: STOP`
+        if wants_to_use_tool {
+            self.stop_reason = StopReason::ToolUse;
+        }
+        events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
+        events
+    }
 }
 
 pub fn count_google_tokens(
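One detail worth noting in the Google mapper: tool-use IDs are synthesized locally (presumably because Gemini function calls carry no ID of their own) by combining the function name with a process-wide `AtomicU64` counter, which keeps IDs unique even across concurrent streams. A self-contained sketch of that scheme, with a hypothetical helper name:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Process-wide counter, as in GoogleEventMapper::map_event.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

fn synthesize_tool_id(name: &str) -> String {
    // fetch_add returns the previous value, so IDs start at `name-0`.
    let n = TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
    format!("{name}-{n}")
}

fn main() {
    assert_eq!(synthesize_tool_id("search"), "search-0");
    assert_eq!(synthesize_tool_id("search"), "search-1");
}
```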
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 54f27b1727..78d183cb28 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -330,8 +330,11 @@ impl LanguageModel for OpenAiLanguageModel {
     > {
         let request = into_open_ai(request, &self.model, self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
-        async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
-            .boxed()
+        async move {
+            let mapper = OpenAiEventMapper::new();
+            Ok(mapper.map_stream(completions.await?).boxed())
+        }
+        .boxed()
     }
 }
 
@@ -422,123 +425,108 @@ pub fn into_open_ai(
     }
 }
 
-pub fn map_to_language_model_completion_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-    #[derive(Default)]
-    struct RawToolCall {
-        id: String,
-        name: String,
-        arguments: String,
-    }
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
-        tool_calls_by_index: HashMap<usize, RawToolCall>,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            tool_calls_by_index: HashMap::default(),
-        },
-        |mut state| async move {
-            if let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => {
-                        let Some(choice) = event.choices.first() else {
-                            return Some((
-                                vec![Err(LanguageModelCompletionError::Other(anyhow!(
-                                    "Response contained no choices"
-                                )))],
-                                state,
-                            ));
-                        };
-
-                        let mut events = Vec::new();
-                        if let Some(content) = choice.delta.content.clone() {
-                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
-                        }
-
-                        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
-                            for tool_call in tool_calls {
-                                let entry = state
-                                    .tool_calls_by_index
-                                    .entry(tool_call.index)
-                                    .or_default();
-
-                                if let Some(tool_id) = tool_call.id.clone() {
-                                    entry.id = tool_id;
-                                }
-
-                                if let Some(function) = tool_call.function.as_ref() {
-                                    if let Some(name) = function.name.clone() {
-                                        entry.name = name;
-                                    }
-
-                                    if let Some(arguments) = function.arguments.clone() {
-                                        entry.arguments.push_str(&arguments);
-                                    }
-                                }
-                            }
-                        }
-
-                        match choice.finish_reason.as_deref() {
-                            Some("stop") => {
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::EndTurn,
-                                )));
-                            }
-                            Some("tool_calls") => {
-                                events.extend(state.tool_calls_by_index.drain().map(
-                                    |(_, tool_call)| match serde_json::Value::from_str(
-                                        &tool_call.arguments,
-                                    ) {
-                                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
-                                            LanguageModelToolUse {
-                                                id: tool_call.id.clone().into(),
-                                                name: tool_call.name.as_str().into(),
-                                                is_input_complete: true,
-                                                input,
-                                                raw_input: tool_call.arguments.clone(),
-                                            },
-                                        )),
-                                        Err(error) => {
-                                            Err(LanguageModelCompletionError::BadInputJson {
-                                                id: tool_call.id.into(),
-                                                tool_name: tool_call.name.as_str().into(),
-                                                raw_input: tool_call.arguments.into(),
-                                                json_parse_error: error.to_string(),
-                                            })
-                                        }
-                                    },
-                                ));
-
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::ToolUse,
-                                )));
-                            }
-                            Some(stop_reason) => {
-                                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::EndTurn,
-                                )));
-                            }
-                            None => {}
-                        }
-
-                        return Some((events, state));
-                    }
-                    Err(err) => {
-                        return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
-                    }
-                }
-            }
-
-            None
-        },
-    )
-    .flat_map(futures::stream::iter)
-}
+pub struct OpenAiEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl OpenAiEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: ResponseStreamEvent,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.first() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
+        };
+
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content.clone() {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id.clone() {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function.as_ref() {
+                    if let Some(name) = function.name.clone() {
+                        entry.name = name;
+                    }
+
+                    if let Some(arguments) = function.arguments.clone() {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
+
+        match choice.finish_reason.as_deref() {
+            Some("stop") => {
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            Some("tool_calls") => {
+                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+                    match serde_json::Value::from_str(&tool_call.arguments) {
+                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: tool_call.id.clone().into(),
+                                name: tool_call.name.as_str().into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: tool_call.arguments.clone(),
+                            },
+                        )),
+                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_call.id.into(),
+                            tool_name: tool_call.name.as_str().into(),
+                            raw_input: tool_call.arguments.into(),
+                            json_parse_error: error.to_string(),
+                        }),
+                    }
+                }));
+
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+            }
+            Some(stop_reason) => {
+                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            None => {}
+        }
+
+        events
+    }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
 
 pub fn count_open_ai_tokens(