diff --git a/Cargo.lock b/Cargo.lock
index 0427f3cfb5..48be65f7b8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -243,6 +243,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "chrono",
+ "collections",
  "futures 0.3.30",
  "http_client",
  "isahc",
diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml
index 9e48ad0e57..ddab9dfd7c 100644
--- a/crates/anthropic/Cargo.toml
+++ b/crates/anthropic/Cargo.toml
@@ -18,6 +18,7 @@ path = "src/anthropic.rs"
 [dependencies]
 anyhow.workspace = true
 chrono.workspace = true
+collections.workspace = true
 futures.workspace = true
 http_client.workspace = true
 isahc.workspace = true
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index 3e2f065e95..03aec20568 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -1,17 +1,19 @@
 mod supported_countries;
 
+use std::time::Duration;
+use std::{pin::Pin, str::FromStr};
+
 use anyhow::{anyhow, Context, Result};
 use chrono::{DateTime, Utc};
+use collections::HashMap;
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use isahc::http::{HeaderMap, HeaderValue};
 use serde::{Deserialize, Serialize};
-use std::time::Duration;
-use std::{pin::Pin, str::FromStr};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
-use util::ResultExt as _;
+use util::{maybe, ResultExt as _};
 
 pub use supported_countries::*;
 
@@ -332,19 +334,22 @@ pub async fn stream_completion_with_rate_limit_info(
 
 pub fn extract_content_from_events(
     events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<String, AnthropicError>> {
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-        current_tool_use_index: Option<usize>,
+) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
+    struct RawToolUse {
+        id: String,
+        name: String,
+        input_json: String,
     }
 
-    const INDENT: &str = "    ";
-    const NEWLINE: char = '\n';
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+        tool_uses_by_index: HashMap<usize, RawToolUse>,
+    }
 
     futures::stream::unfold(
         State {
             events,
-            current_tool_use_index: None,
+            tool_uses_by_index: HashMap::default(),
         },
         |mut state| async move {
             while let Some(event) = state.events.next().await {
@@ -355,62 +360,56 @@ pub fn extract_content_from_events(
                             content_block,
                         } => match content_block {
                             ResponseContent::Text { text } => {
-                                return Some((Ok(text), state));
+                                return Some((Some(Ok(ResponseContent::Text { text })), state));
                             }
                             ResponseContent::ToolUse { id, name, .. } => {
-                                state.current_tool_use_index = Some(index);
+                                state.tool_uses_by_index.insert(
+                                    index,
+                                    RawToolUse {
+                                        id,
+                                        name,
+                                        input_json: String::new(),
+                                    },
+                                );
 
-                                let mut text = String::new();
-                                text.push(NEWLINE);
-
-                                text.push_str("<tool_use>");
-                                text.push(NEWLINE);
-
-                                text.push_str(INDENT);
-                                text.push_str("<id>");
-                                text.push_str(&id);
-                                text.push_str("</id>");
-                                text.push(NEWLINE);
-
-                                text.push_str(INDENT);
-                                text.push_str("<name>");
-                                text.push_str(&name);
-                                text.push_str("</name>");
-                                text.push(NEWLINE);
-
-                                text.push_str(INDENT);
-                                text.push_str("<input>");
-
-                                return Some((Ok(text), state));
+                                return Some((None, state));
                             }
                         },
                         Event::ContentBlockDelta { index, delta } => match delta {
                             ContentDelta::TextDelta { text } => {
-                                return Some((Ok(text), state));
+                                return Some((Some(Ok(ResponseContent::Text { text })), state));
                             }
                             ContentDelta::InputJsonDelta { partial_json } => {
-                                if Some(index) == state.current_tool_use_index {
-                                    return Some((Ok(partial_json), state));
+                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+                                    tool_use.input_json.push_str(&partial_json);
+                                    return Some((None, state));
                                 }
                             }
                         },
                         Event::ContentBlockStop { index } => {
-                            if Some(index) == state.current_tool_use_index.take() {
-                                let mut text = String::new();
-                                text.push_str("</input>");
-                                text.push(NEWLINE);
-                                text.push_str("</tool_use>");
-
-                                return Some((Ok(text), state));
+                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+                                return Some((
+                                    Some(maybe!({
+                                        Ok(ResponseContent::ToolUse {
+                                            id: tool_use.id,
+                                            name: tool_use.name,
+                                            input: serde_json::Value::from_str(
+                                                &tool_use.input_json,
+                                            )
+                                            .map_err(|err| anyhow!(err))?,
+                                        })
+                                    })),
+                                    state,
+                                ));
                             }
                         }
                         Event::Error { error } => {
-                            return Some((Err(AnthropicError::ApiError(error)), state));
+                            return Some((Some(Err(AnthropicError::ApiError(error))), state));
                         }
                         _ => {}
                     },
                     Err(err) => {
-                        return Some((Err(err), state));
+                        return Some((Some(Err(err)), state));
                     }
                 }
             }
 
@@ -418,6 +417,7 @@ pub fn extract_content_from_events(
             None
         },
     )
+    .filter_map(|event| async move { event })
 }
 
 pub async fn extract_tool_args_from_events(
diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs
index 9a0c985d46..08668c2797 100644
--- a/crates/assistant/src/context.rs
+++ b/crates/assistant/src/context.rs
@@ -25,8 +25,9 @@ use gpui::{
 use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
 use language_model::{
-    LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
-    LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
+    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
+    LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+    MessageContent, Role,
 };
 use open_ai::Model as OpenAiModel;
 use paths::{context_images_dir, contexts_dir};
@@ -1950,13 +1951,13 @@
             let mut response_latency = None;
             let stream_completion = async {
                 let request_start = Instant::now();
-                let mut chunks = stream.await?;
+                let mut events = stream.await?;
 
-                while let Some(chunk) = chunks.next().await {
+                while let Some(event) = events.next().await {
                     if response_latency.is_none() {
                         response_latency = Some(request_start.elapsed());
                     }
-                    let chunk = chunk?;
+                    let event = event?;
 
                     this.update(&mut cx, |this, cx| {
                         let message_ix = this
@@ -1970,11 +1971,36 @@
                             .map_or(buffer.len(), |message| {
                                 message.start.to_offset(buffer).saturating_sub(1)
                             });
-                        buffer.edit(
-                            [(message_old_end_offset..message_old_end_offset, chunk)],
-                            None,
-                            cx,
-                        );
+
+                        match event {
+                            LanguageModelCompletionEvent::Text(chunk) => {
+                                buffer.edit(
+                                    [(
+                                        message_old_end_offset..message_old_end_offset,
+                                        chunk,
+                                    )],
+                                    None,
+                                    cx,
+                                );
+                            }
+                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
+                                let mut text = String::new();
+                                text.push('\n');
+                                text.push_str(
+                                    &serde_json::to_string_pretty(&tool_use)
+                                        .expect("failed to serialize tool use to JSON"),
+                                );
+
+                                buffer.edit(
+                                    [(
+                                        message_old_end_offset..message_old_end_offset,
+                                        text,
+                                    )],
+                                    None,
+                                    cx,
+                                );
+                            }
+                        }
                     });
 
                     cx.emit(ContextEvent::StreamedCompletion);
@@ -2406,7 +2432,7 @@ impl Context {
         self.pending_summary = cx.spawn(|this, mut cx| {
             async move {
-                let stream = model.stream_completion(request, &cx);
+                let stream = model.stream_completion_text(request, &cx);
                 let mut messages = stream.await?;
 
                 let mut replaced = !replace_old;
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index 63c8e45bfd..26392b7654 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -2344,7 +2344,7 @@ impl Codegen {
             self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
 
             let chunks =
-                cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
+                cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
             async move { Ok(chunks.await?.boxed()) }.boxed_local()
         };
         self.handle_stream(telemetry_id, edit_range, chunks, cx);
diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs
index 1750195982..699786f4cc 100644
--- a/crates/assistant/src/terminal_inline_assistant.rs
+++ b/crates/assistant/src/terminal_inline_assistant.rs
@@ -1010,7 +1010,7 @@ impl Codegen {
         self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
         self.generation = cx.spawn(|this, mut cx| async move {
            let model_telemetry_id = model.telemetry_id();
-            let response = model.stream_completion(prompt, &cx).await;
+            let response = model.stream_completion_text(prompt, &cx).await;
 
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index c793168606..cd85ca7f53 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -8,7 +8,8 @@ pub mod settings;
 
 use anyhow::Result;
 use client::{Client, UserStore};
-use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
+use futures::FutureExt;
+use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
 };
@@ -51,6 +52,20 @@ pub struct LanguageModelCacheConfiguration {
     pub min_total_token: usize,
 }
 
+/// A completion event from a language model.
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub enum LanguageModelCompletionEvent {
+    Text(String),
+    ToolUse(LanguageModelToolUse),
+}
+
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUse {
+    pub id: String,
+    pub name: String,
+    pub input: serde_json::Value,
+}
+
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
@@ -82,7 +97,29 @@ pub trait LanguageModel: Send + Sync {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
+
+    fn stream_completion_text(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let events = self.stream_completion(request, cx);
+
+        async move {
+            Ok(events
+                .await?
+                .filter_map(|result| async move {
+                    match result {
+                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+                        Err(err) => Some(Err(err)),
+                    }
+                })
+                .boxed())
+        }
+        .boxed()
+    }
 
     fn use_any_tool(
         &self,
diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs
index e4bb94a738..8258768a6a 100644
--- a/crates/language_model/src/provider/anthropic.rs
+++ b/crates/language_model/src/provider/anthropic.rs
@@ -3,6 +3,7 @@ use crate::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
 use anthropic::AnthropicError;
 use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;
@@ -364,7 +365,7 @@ impl LanguageModel for AnthropicModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let request =
             request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
         let request = self.stream_completion(request, cx);
@@ -375,7 +376,22 @@
         async move {
             Ok(future
                 .await?
-                .map(|result| result.map_err(|err| anyhow!(err)))
+                .map(|result| {
+                    result
+                        .map(|content| match content {
+                            anthropic::ResponseContent::Text { text } => {
+                                LanguageModelCompletionEvent::Text(text)
+                            }
+                            anthropic::ResponseContent::ToolUse { id, name, input } => {
+                                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                                    id,
+                                    name,
+                                    input,
+                                })
+                            }
+                        })
+                        .map_err(|err| anyhow!(err))
+                })
                 .boxed())
         }
         .boxed()
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index d4166fdfb5..a9b0008bbd 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -33,7 +33,10 @@ use std::{
 use strum::IntoEnumIterator;
 use ui::{prelude::*, TintColor};
 
-use crate::{LanguageModelAvailability, LanguageModelProvider};
+use crate::{
+    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
+    LanguageModelToolUse,
+};
 
 use super::anthropic::count_anthropic_tokens;
@@ -496,7 +499,7 @@ impl LanguageModel for CloudLanguageModel {
         &self,
         request: LanguageModelRequest,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
@@ -522,7 +525,20 @@
                 async move {
                     Ok(future
                         .await?
-                        .map(|result| result.map_err(|err| anyhow!(err)))
+                        .map(|result| {
+                            result
+                                .map(|content| match content {
+                                    anthropic::ResponseContent::Text { text } => {
+                                        LanguageModelCompletionEvent::Text(text)
+                                    }
+                                    anthropic::ResponseContent::ToolUse { id, name, input } => {
+                                        LanguageModelCompletionEvent::ToolUse(
+                                            LanguageModelToolUse { id, name, input },
+                                        )
+                                    }
+                                })
+                                .map_err(|err| anyhow!(err))
+                        })
                         .boxed())
                 }
                 .boxed()
@@ -546,7 +562,13 @@
                         .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
@@ -569,7 +591,13 @@
                         response,
                     )))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
@@ -591,7 +619,13 @@
                         .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
         }
     }
diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs
index 49241e16eb..e21060e54d 100644
--- a/crates/language_model/src/provider/copilot_chat.rs
+++ b/crates/language_model/src/provider/copilot_chat.rs
@@ -24,11 +24,11 @@ use ui::{
 };
 
 use crate::settings::AllLanguageModelSettings;
-use crate::LanguageModelProviderState;
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
 };
+use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};
 
 use super::open_ai::count_open_ai_tokens;
@@ -192,7 +192,7 @@ impl LanguageModel for CopilotChatLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         if let Some(message) = request.messages.last() {
             if message.contents_empty() {
                 const EMPTY_PROMPT_MSG: &str =
@@ -243,7 +243,13 @@ impl LanguageModel for CopilotChatLanguageModel {
             }).await
         });
 
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(
diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs
index f62539aef2..2044ae520d 100644
--- a/crates/language_model/src/provider/fake.rs
+++ b/crates/language_model/src/provider/fake.rs
@@ -1,7 +1,7 @@
 use crate::{
-    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest,
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest,
 };
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};
@@ -170,10 +170,15 @@ impl LanguageModel for FakeLanguageModel {
         &self,
         request: LanguageModelRequest,
         _: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let (tx, rx) = mpsc::unbounded();
         self.current_completion_txs.lock().push((request, tx));
-        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+        async move {
+            Ok(rx
+                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(
diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs
index 8a9dc760a4..b59d97e036 100644
--- a/crates/language_model/src/provider/google.rs
+++ b/crates/language_model/src/provider/google.rs
@@ -17,6 +17,7 @@ use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -281,7 +282,10 @@ impl LanguageModel for GoogleLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<
+        'static,
+        Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+    > {
         let request = request.into_google(self.model.id().to_string());
         let http_client = self.http_client.clone();
 
@@ -299,7 +303,13 @@ impl LanguageModel for GoogleLanguageModel {
             let events = response.await?;
             Ok(google_ai::extract_text_from_events(events).boxed())
         });
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(
diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs
index bdbbe9657a..cfcca1fb7a 100644
--- a/crates/language_model/src/provider/ollama.rs
+++ b/crates/language_model/src/provider/ollama.rs
@@ -13,6 +13,7 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, Indicator};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -302,7 +303,7 @@ impl LanguageModel for OllamaLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let request = self.to_ollama_request(request);
 
         let http_client = self.http_client.clone();
@@ -335,7 +336,13 @@ impl LanguageModel for OllamaLanguageModel {
             Ok(stream)
         });
 
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index 9933c96b94..56d83fb653 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -19,6 +19,7 @@ use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -293,10 +294,18 @@ impl LanguageModel for OpenAiLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<
+        'static,
+        Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+    > {
         let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
-        async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
+        async move {
+            Ok(open_ai::extract_text_from_events(completions.await?)
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
    }

    fn use_any_tool(
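
Reviewer note, not part of the patch: a minimal sketch of how a caller might consume the new API, assuming a `model: Arc<dyn LanguageModel>`, a `request: LanguageModelRequest`, and a `cx: &AsyncAppContext` are already in scope (the function name below is hypothetical). Text-only callers can keep using the `stream_completion_text` default method added to the trait above; tool-aware callers match on the new `LanguageModelCompletionEvent` enum:

```rust
use std::sync::Arc;

use anyhow::Result;
use futures::StreamExt;
use gpui::AsyncAppContext;
use language_model::{LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest};

async fn print_completion(
    model: Arc<dyn LanguageModel>,
    request: LanguageModelRequest,
    cx: &AsyncAppContext,
) -> Result<()> {
    // Full event stream: text chunks interleaved with completed tool uses.
    let mut events = model.stream_completion(request, cx).await?;
    while let Some(event) = events.next().await {
        match event? {
            // Streamed assistant text, in order.
            LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
            // Emitted once per tool use, only after its input JSON has fully
            // streamed and parsed (see `extract_content_from_events` above).
            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                println!("\ntool use {} ({}): {}", tool_use.name, tool_use.id, tool_use.input);
            }
        }
    }
    Ok(())
}
```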