assistant: Stream tool uses as structured data (#17322)

This PR adjusts the approach we use to encode tool uses in the
completion response, using a structured format rather than simply
injecting them into the response stream as text.

In #17170 we encoded the tool uses as XML and inserted them into the
stream as text, which then required re-parsing the tool uses back out
of the buffer in order to use them.

The approach taken in this PR is to make `stream_completion` return a
stream of `LanguageModelCompletionEvent`s, each of which is either a
chunk of text or a tool use.
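For illustration only (not part of this commit), here is a minimal, self-contained sketch of consuming the new event stream. The enum and struct mirror the definitions added to `language_model` below; the stand-in stream, the tool name, and the use of the `futures` and `serde_json` crates are assumptions of the example:

    use futures::{executor::block_on, stream, StreamExt};

    // These shapes mirror the types added in this PR.
    #[derive(Debug)]
    pub enum LanguageModelCompletionEvent {
        Text(String),
        ToolUse(LanguageModelToolUse),
    }

    #[derive(Debug)]
    pub struct LanguageModelToolUse {
        pub id: String,
        pub name: String,
        pub input: serde_json::Value,
    }

    fn main() {
        // Stand-in for the stream returned by `stream_completion`.
        let mut events = stream::iter(vec![
            Ok::<_, String>(LanguageModelCompletionEvent::Text("Hello ".into())),
            Ok(LanguageModelCompletionEvent::Text("world.".into())),
            Ok(LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                id: "tool-1".into(),        // hypothetical id
                name: "get_weather".into(), // hypothetical tool name
                input: serde_json::json!({ "city": "Boston" }),
            })),
        ]);

        block_on(async {
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::Text(chunk)) => print!("{chunk}"),
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
                        println!("\n[tool use] {}({})", tool_use.name, tool_use.input);
                    }
                    Err(err) => eprintln!("error: {err}"),
                }
            }
        });
    }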

A new `stream_completion_text` method has been added to `LanguageModel`
for scenarios where we only care about textual content (currently,
everywhere that isn't the Assistant context editor).
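`stream_completion_text` is implemented as a default trait method (see the `language_model` diff below) that filters the event stream down to text. A compact sketch of that filtering with stand-in types, assuming only the `futures` crate:

    use futures::{executor::block_on, stream, StreamExt};

    enum Event {
        Text(String),
        ToolUse,
    }

    fn main() {
        let events = stream::iter(vec![
            Ok(Event::Text("Hi".into())),
            Ok(Event::ToolUse),
            Err("api error"),
        ]);

        // Mirrors the default method: drop tool uses, pass text and errors through.
        let text_only = events.filter_map(|result| async move {
            match result {
                Ok(Event::Text(text)) => Some(Ok(text)),
                Ok(Event::ToolUse) => None,
                Err(err) => Some(Err(err)),
            }
        });

        let collected: Vec<Result<String, &str>> = block_on(text_only.collect());
        assert_eq!(collected, vec![Ok("Hi".to_string()), Err("api error")]);
    }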

Release Notes:

- N/A
Commit: 452272e5df (parent: 132e8e8064)
Author: Marshall Bowers
Date: 2024-09-03 15:04:51 -04:00

14 changed files with 235 additions and 83 deletions

Cargo.lock (generated)

@@ -243,6 +243,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "chrono",
+ "collections",
  "futures 0.3.30",
  "http_client",
  "isahc",

crates/anthropic/Cargo.toml

@@ -18,6 +18,7 @@ path = "src/anthropic.rs"
 [dependencies]
 anyhow.workspace = true
 chrono.workspace = true
+collections.workspace = true
 futures.workspace = true
 http_client.workspace = true
 isahc.workspace = true

crates/anthropic/src/anthropic.rs

@@ -1,17 +1,19 @@
 mod supported_countries;
 
+use std::time::Duration;
+use std::{pin::Pin, str::FromStr};
+
 use anyhow::{anyhow, Context, Result};
 use chrono::{DateTime, Utc};
+use collections::HashMap;
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use isahc::http::{HeaderMap, HeaderValue};
 use serde::{Deserialize, Serialize};
-use std::time::Duration;
-use std::{pin::Pin, str::FromStr};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
-use util::ResultExt as _;
+use util::{maybe, ResultExt as _};
 
 pub use supported_countries::*;

@@ -332,19 +334,22 @@ pub async fn stream_completion_with_rate_limit_info(
 
 pub fn extract_content_from_events(
     events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<String, AnthropicError>> {
+) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
+    struct RawToolUse {
+        id: String,
+        name: String,
+        input_json: String,
+    }
+
     struct State {
         events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-        current_tool_use_index: Option<usize>,
+        tool_uses_by_index: HashMap<usize, RawToolUse>,
     }
 
-    const INDENT: &str = "    ";
-    const NEWLINE: char = '\n';
-
     futures::stream::unfold(
         State {
             events,
-            current_tool_use_index: None,
+            tool_uses_by_index: HashMap::default(),
         },
         |mut state| async move {
             while let Some(event) = state.events.next().await {

@@ -355,62 +360,56 @@ pub fn extract_content_from_events(
                             content_block,
                         } => match content_block {
                             ResponseContent::Text { text } => {
-                                return Some((Ok(text), state));
+                                return Some((Some(Ok(ResponseContent::Text { text })), state));
                             }
                             ResponseContent::ToolUse { id, name, .. } => {
-                                state.current_tool_use_index = Some(index);
-
-                                let mut text = String::new();
-                                text.push(NEWLINE);
-                                text.push_str("<tool_use>");
-                                text.push(NEWLINE);
-                                text.push_str(INDENT);
-                                text.push_str("<id>");
-                                text.push_str(&id);
-                                text.push_str("</id>");
-                                text.push(NEWLINE);
-                                text.push_str(INDENT);
-                                text.push_str("<name>");
-                                text.push_str(&name);
-                                text.push_str("</name>");
-                                text.push(NEWLINE);
-                                text.push_str(INDENT);
-                                text.push_str("<input>");
-
-                                return Some((Ok(text), state));
+                                state.tool_uses_by_index.insert(
+                                    index,
+                                    RawToolUse {
+                                        id,
+                                        name,
+                                        input_json: String::new(),
+                                    },
+                                );
+
+                                return Some((None, state));
                             }
                         },
                         Event::ContentBlockDelta { index, delta } => match delta {
                             ContentDelta::TextDelta { text } => {
-                                return Some((Ok(text), state));
+                                return Some((Some(Ok(ResponseContent::Text { text })), state));
                             }
                             ContentDelta::InputJsonDelta { partial_json } => {
-                                if Some(index) == state.current_tool_use_index {
-                                    return Some((Ok(partial_json), state));
+                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+                                    tool_use.input_json.push_str(&partial_json);
+                                    return Some((None, state));
                                 }
                             }
                         },
                         Event::ContentBlockStop { index } => {
-                            if Some(index) == state.current_tool_use_index.take() {
-                                let mut text = String::new();
-                                text.push_str("</input>");
-                                text.push(NEWLINE);
-                                text.push_str("</tool_use>");
-
-                                return Some((Ok(text), state));
+                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+                                return Some((
+                                    Some(maybe!({
+                                        Ok(ResponseContent::ToolUse {
+                                            id: tool_use.id,
+                                            name: tool_use.name,
+                                            input: serde_json::Value::from_str(
+                                                &tool_use.input_json,
+                                            )
+                                            .map_err(|err| anyhow!(err))?,
+                                        })
+                                    })),
+                                    state,
+                                ));
                             }
                         }
                         Event::Error { error } => {
-                            return Some((Err(AnthropicError::ApiError(error)), state));
+                            return Some((Some(Err(AnthropicError::ApiError(error))), state));
                        }
                         _ => {}
                     },
                     Err(err) => {
-                        return Some((Err(err), state));
+                        return Some((Some(Err(err)), state));
                     }
                 }
             }

@@ -418,6 +417,7 @@ pub fn extract_content_from_events(
             None
         },
     )
+    .filter_map(|event| async move { event })
 }
 
 pub async fn extract_tool_args_from_events(
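A note on the shape of the rewritten `extract_content_from_events`: `futures::stream::unfold` has to yield an item on every turn, but tool-use deltas only accumulate JSON in the state, so the closure yields an `Option<Result<…>>` and the trailing `filter_map` strips the `None`s from the public stream (`maybe!`, from Zed's `util` crate, wraps the block so `?` can short-circuit into a `Result`). A self-contained demonstration of the same unfold-then-filter_map pattern, assuming only the `futures` crate:

    use futures::{executor::block_on, stream, StreamExt};

    fn main() {
        // `unfold` must yield something each turn; yield `None` when the turn
        // only updates internal state, then strip the `None`s with `filter_map`.
        let evens = stream::unfold(0u32, |n| async move {
            if n < 10 {
                let item = if n % 2 == 0 { Some(n) } else { None };
                Some((item, n + 1))
            } else {
                None // underlying stream is exhausted
            }
        })
        .filter_map(|item| async move { item });

        let collected: Vec<u32> = block_on(evens.collect());
        assert_eq!(collected, vec![0, 2, 4, 6, 8]);
    }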

crates/assistant/src/context.rs

@@ -25,8 +25,9 @@ use gpui::{
 use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
 use language_model::{
-    LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
-    LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
+    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
+    LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+    MessageContent, Role,
 };
 use open_ai::Model as OpenAiModel;
 use paths::{context_images_dir, contexts_dir};

@@ -1950,13 +1951,13 @@ impl Context {
         let mut response_latency = None;
         let stream_completion = async {
             let request_start = Instant::now();
-            let mut chunks = stream.await?;
-            while let Some(chunk) = chunks.next().await {
+            let mut events = stream.await?;
+            while let Some(event) = events.next().await {
                 if response_latency.is_none() {
                     response_latency = Some(request_start.elapsed());
                 }
-                let chunk = chunk?;
+                let event = event?;
 
                 this.update(&mut cx, |this, cx| {
                     let message_ix = this

@@ -1970,11 +1971,36 @@ impl Context {
                         .map_or(buffer.len(), |message| {
                             message.start.to_offset(buffer).saturating_sub(1)
                         });
-                    buffer.edit(
-                        [(message_old_end_offset..message_old_end_offset, chunk)],
-                        None,
-                        cx,
-                    );
+                    match event {
+                        LanguageModelCompletionEvent::Text(chunk) => {
+                            buffer.edit(
+                                [(
+                                    message_old_end_offset..message_old_end_offset,
+                                    chunk,
+                                )],
+                                None,
+                                cx,
+                            );
+                        }
+                        LanguageModelCompletionEvent::ToolUse(tool_use) => {
+                            let mut text = String::new();
+                            text.push('\n');
+                            text.push_str(
+                                &serde_json::to_string_pretty(&tool_use)
+                                    .expect("failed to serialize tool use to JSON"),
+                            );
+
+                            buffer.edit(
+                                [(
+                                    message_old_end_offset..message_old_end_offset,
+                                    text,
+                                )],
+                                None,
+                                cx,
+                            );
+                        }
+                    }
                 });
 
                 cx.emit(ContextEvent::StreamedCompletion);

@@ -2406,7 +2432,7 @@ impl Context {
         self.pending_summary = cx.spawn(|this, mut cx| {
             async move {
-                let stream = model.stream_completion(request, &cx);
+                let stream = model.stream_completion_text(request, &cx);
                 let mut messages = stream.await?;
 
                 let mut replaced = !replace_old;
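For reference (not in the commit): the `ToolUse` arm above inserts the pretty-printed JSON of the `LanguageModelToolUse` into the context buffer. A sketch with a hypothetical tool call, assuming the `serde` and `serde_json` crates:

    use serde::Serialize;

    // Same shape as the struct added in `language_model`.
    #[derive(Serialize)]
    struct LanguageModelToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    }

    fn main() {
        let tool_use = LanguageModelToolUse {
            id: "tool-1".into(),        // hypothetical id
            name: "get_weather".into(), // hypothetical tool name
            input: serde_json::json!({ "city": "Boston" }),
        };

        // Prints roughly what lands in the buffer after the leading newline:
        // {
        //   "id": "tool-1",
        //   "name": "get_weather",
        //   "input": {
        //     "city": "Boston"
        //   }
        // }
        println!("{}", serde_json::to_string_pretty(&tool_use).unwrap());
    }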

crates/assistant/src/inline_assistant.rs

@@ -2344,7 +2344,7 @@ impl Codegen {
             self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
 
         let chunks =
-            cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
+            cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
         async move { Ok(chunks.await?.boxed()) }.boxed_local()
     };
     self.handle_stream(telemetry_id, edit_range, chunks, cx);

crates/assistant/src/terminal_inline_assistant.rs

@@ -1010,7 +1010,7 @@ impl Codegen {
         self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
         self.generation = cx.spawn(|this, mut cx| async move {
             let model_telemetry_id = model.telemetry_id();
-            let response = model.stream_completion(prompt, &cx).await;
+            let response = model.stream_completion_text(prompt, &cx).await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);

crates/language_model/src/language_model.rs

@@ -8,7 +8,8 @@ pub mod settings;
 
 use anyhow::Result;
 use client::{Client, UserStore};
-use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
+use futures::FutureExt;
+use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
 };

@@ -51,6 +52,20 @@ pub struct LanguageModelCacheConfiguration {
     pub min_total_token: usize,
 }
 
+/// A completion event from a language model.
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub enum LanguageModelCompletionEvent {
+    Text(String),
+    ToolUse(LanguageModelToolUse),
+}
+
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUse {
+    pub id: String,
+    pub name: String,
+    pub input: serde_json::Value,
+}
+
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;

@@ -82,7 +97,29 @@ pub trait LanguageModel: Send + Sync {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
+
+    fn stream_completion_text(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let events = self.stream_completion(request, cx);
+
+        async move {
+            Ok(events
+                .await?
+                .filter_map(|result| async move {
+                    match result {
+                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+                        Err(err) => Some(Err(err)),
+                    }
+                })
+                .boxed())
+        }
+        .boxed()
+    }
 
     fn use_any_tool(
         &self,

crates/language_model/src/provider/anthropic.rs

@@ -3,6 +3,7 @@ use crate::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
 use anthropic::AnthropicError;
 use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;

@@ -364,7 +365,7 @@ impl LanguageModel for AnthropicModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let request =
             request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
         let request = self.stream_completion(request, cx);

@@ -375,7 +376,22 @@ impl LanguageModel for AnthropicModel {
         async move {
             Ok(future
                 .await?
-                .map(|result| result.map_err(|err| anyhow!(err)))
+                .map(|result| {
+                    result
+                        .map(|content| match content {
+                            anthropic::ResponseContent::Text { text } => {
+                                LanguageModelCompletionEvent::Text(text)
+                            }
+                            anthropic::ResponseContent::ToolUse { id, name, input } => {
+                                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                                    id,
+                                    name,
+                                    input,
+                                })
+                            }
+                        })
+                        .map_err(|err| anyhow!(err))
+                })
                 .boxed())
         }
         .boxed()

crates/language_model/src/provider/cloud.rs

@@ -33,7 +33,10 @@ use std::{
 use strum::IntoEnumIterator;
 use ui::{prelude::*, TintColor};
 
-use crate::{LanguageModelAvailability, LanguageModelProvider};
+use crate::{
+    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
+    LanguageModelToolUse,
+};
 
 use super::anthropic::count_anthropic_tokens;

@@ -496,7 +499,7 @@ impl LanguageModel for CloudLanguageModel {
         &self,
         request: LanguageModelRequest,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = request.into_anthropic(model.id().into(), model.max_output_tokens());

@@ -522,7 +525,20 @@ impl LanguageModel for CloudLanguageModel {
                 async move {
                     Ok(future
                         .await?
-                        .map(|result| result.map_err(|err| anyhow!(err)))
+                        .map(|result| {
+                            result
+                                .map(|content| match content {
+                                    anthropic::ResponseContent::Text { text } => {
+                                        LanguageModelCompletionEvent::Text(text)
+                                    }
+                                    anthropic::ResponseContent::ToolUse { id, name, input } => {
+                                        LanguageModelCompletionEvent::ToolUse(
+                                            LanguageModelToolUse { id, name, input },
+                                        )
+                                    }
+                                })
+                                .map_err(|err| anyhow!(err))
+                        })
                         .boxed())
                 }
                 .boxed()

@@ -546,7 +562,13 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();

@@ -569,7 +591,13 @@ impl LanguageModel for CloudLanguageModel {
                         response,
                     )))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();

@@ -591,7 +619,13 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
         }
     }

crates/language_model/src/provider/copilot_chat.rs

@@ -24,11 +24,11 @@ use ui::{
 };
 
 use crate::settings::AllLanguageModelSettings;
-use crate::LanguageModelProviderState;
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
 };
+use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};
 
 use super::open_ai::count_open_ai_tokens;

@@ -192,7 +192,7 @@ impl LanguageModel for CopilotChatLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         if let Some(message) = request.messages.last() {
             if message.contents_empty() {
                 const EMPTY_PROMPT_MSG: &str =

@@ -243,7 +243,13 @@ impl LanguageModel for CopilotChatLanguageModel {
             }).await
         });
 
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(

crates/language_model/src/provider/fake.rs

@@ -1,7 +1,7 @@
 use crate::{
-    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest,
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest,
 };
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};

@@ -170,10 +170,15 @@ impl LanguageModel for FakeLanguageModel {
         &self,
         request: LanguageModelRequest,
         _: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let (tx, rx) = mpsc::unbounded();
         self.current_completion_txs.lock().push((request, tx));
-        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+        async move {
+            Ok(rx
+                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(

crates/language_model/src/provider/google.rs

@@ -17,6 +17,7 @@ use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,

@@ -281,7 +282,10 @@ impl LanguageModel for GoogleLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<
+        'static,
+        Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+    > {
         let request = request.into_google(self.model.id().to_string());
         let http_client = self.http_client.clone();

@@ -299,7 +303,13 @@ impl LanguageModel for GoogleLanguageModel {
             let events = response.await?;
             Ok(google_ai::extract_text_from_events(events).boxed())
         });
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(

crates/language_model/src/provider/ollama.rs

@@ -13,6 +13,7 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, Indicator};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,

@@ -302,7 +303,7 @@ impl LanguageModel for OllamaLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let request = self.to_ollama_request(request);
 
         let http_client = self.http_client.clone();

@@ -335,7 +336,13 @@ impl LanguageModel for OllamaLanguageModel {
             Ok(stream)
         });
 
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(

crates/language_model/src/provider/open_ai.rs

@@ -19,6 +19,7 @@ use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,

@@ -293,10 +294,18 @@ impl LanguageModel for OpenAiLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<
+        'static,
+        Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+    > {
         let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
-        async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
+        async move {
+            Ok(open_ai::extract_text_from_events(completions.await?)
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(