assistant: Stream tool uses as structured data (#17322)

This PR adjusts the approach we use to encode tool uses in the
completion response, using a structured format rather than simply
injecting them into the response stream as text.

In #17170 we encoded the tool uses as XML and inserted them into the
response stream as text, which then required re-parsing them back out
of the buffer in order to use them.

The approach taken in this PR is to make `stream_completion` return a
stream of `LanguageModelCompletionEvent`s. Each of these events can be
either text or a tool use.
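
For reference, here is the event type introduced in this PR (copied from the diff below), followed by an illustrative consumer. The consumer is only a sketch, not code from this PR: `drain_events` is a hypothetical helper showing how a caller might handle the structured events.

```rust
use anyhow::Result;
use futures::{stream::BoxStream, StreamExt as _};
use serde::{Deserialize, Serialize};

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    Text(String),
    ToolUse(LanguageModelToolUse),
}

#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: String,
    pub name: String,
    pub input: serde_json::Value,
}

// Hypothetical consumer: drain the event stream, handling tool uses as
// structured data instead of re-parsing XML out of a text buffer.
async fn drain_events(
    mut events: BoxStream<'static, Result<LanguageModelCompletionEvent>>,
) -> Result<String> {
    let mut text = String::new();
    while let Some(event) = events.next().await {
        match event? {
            LanguageModelCompletionEvent::Text(chunk) => text.push_str(&chunk),
            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                // `input` arrives as already-structured JSON.
                println!("tool use `{}`: {}", tool_use.name, tool_use.input);
            }
        }
    }
    Ok(text)
}
```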

A new `stream_completion_text` method has been added to `LanguageModel`
for scenarios where we only care about textual content (currently,
everywhere that isn't the Assistant context editor).
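
A minimal usage sketch of the text-only method, assuming `model`, `request`, and `cx` are in scope (hypothetical stand-ins for an `Arc<dyn LanguageModel>`, a `LanguageModelRequest`, and an `AsyncAppContext`; not code from this PR):

```rust
use futures::StreamExt as _;

// Inside an async context: stream only the textual content.
let mut chunks = model.stream_completion_text(request, &cx).await?;
let mut completion = String::new();
while let Some(chunk) = chunks.next().await {
    // The default implementation filters out `ToolUse` events, so only
    // `Text` events reach this stream.
    completion.push_str(&chunk?);
}
```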

Release Notes:

- N/A
Marshall Bowers 2024-09-03 15:04:51 -04:00 committed by GitHub
parent 132e8e8064
commit 452272e5df
14 changed files with 235 additions and 83 deletions

@@ -8,7 +8,8 @@ pub mod settings;
use anyhow::Result;
use client::{Client, UserStore};
use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
use futures::FutureExt;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
};
@@ -51,6 +52,20 @@ pub struct LanguageModelCacheConfiguration {
pub min_total_token: usize,
}
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
Text(String),
ToolUse(LanguageModelToolUse),
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
pub id: String,
pub name: String,
pub input: serde_json::Value,
}
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
@@ -82,7 +97,29 @@ pub trait LanguageModel: Send + Sync {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
fn stream_completion_text(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let events = self.stream_completion(request, cx);
async move {
Ok(events
.await?
.filter_map(|result| async move {
match result {
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
Err(err) => Some(Err(err)),
}
})
.boxed())
}
.boxed()
}
fn use_any_tool(
&self,

@@ -3,6 +3,7 @@ use crate::{
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
use anthropic::AnthropicError;
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
@@ -364,7 +365,7 @@ impl LanguageModel for AnthropicModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request =
request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
let request = self.stream_completion(request, cx);
@@ -375,7 +376,22 @@ impl LanguageModel for AnthropicModel {
async move {
Ok(future
.await?
.map(|result| result.map_err(|err| anyhow!(err)))
.map(|result| {
result
.map(|content| match content {
anthropic::ResponseContent::Text { text } => {
LanguageModelCompletionEvent::Text(text)
}
anthropic::ResponseContent::ToolUse { id, name, input } => {
LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
id,
name,
input,
})
}
})
.map_err(|err| anyhow!(err))
})
.boxed())
}
.boxed()

@@ -33,7 +33,10 @@ use std::{
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};
use crate::{LanguageModelAvailability, LanguageModelProvider};
use crate::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
LanguageModelToolUse,
};
use super::anthropic::count_anthropic_tokens;
@@ -496,7 +499,7 @@ impl LanguageModel for CloudLanguageModel {
&self,
request: LanguageModelRequest,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
@@ -522,7 +525,20 @@ impl LanguageModel for CloudLanguageModel {
async move {
Ok(future
.await?
.map(|result| result.map_err(|err| anyhow!(err)))
.map(|result| {
result
.map(|content| match content {
anthropic::ResponseContent::Text { text } => {
LanguageModelCompletionEvent::Text(text)
}
anthropic::ResponseContent::ToolUse { id, name, input } => {
LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse { id, name, input },
)
}
})
.map_err(|err| anyhow!(err))
})
.boxed())
}
.boxed()
@@ -546,7 +562,13 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(open_ai::extract_text_from_events(response_lines(response)))
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
@@ -569,7 +591,13 @@ impl LanguageModel for CloudLanguageModel {
response,
)))
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
CloudModel::Zed(model) => {
let client = self.client.clone();
@@ -591,7 +619,13 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(open_ai::extract_text_from_events(response_lines(response)))
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
}
}

@@ -24,11 +24,11 @@ use ui::{
};
use crate::settings::AllLanguageModelSettings;
use crate::LanguageModelProviderState;
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
};
use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};
use super::open_ai::count_open_ai_tokens;
@@ -192,7 +192,7 @@ impl LanguageModel for CopilotChatLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
@@ -243,7 +243,13 @@ impl LanguageModel for CopilotChatLanguageModel {
}).await
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
fn use_any_tool(

@@ -1,7 +1,7 @@
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
@@ -170,10 +170,15 @@ impl LanguageModel for FakeLanguageModel {
&self,
request: LanguageModelRequest,
_: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move { Ok(rx.map(Ok).boxed()) }.boxed()
async move {
Ok(rx
.map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
.boxed())
}
.boxed()
}
fn use_any_tool(

@@ -17,6 +17,7 @@ use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;
use crate::LanguageModelCompletionEvent;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -281,7 +282,10 @@ impl LanguageModel for GoogleLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
let request = request.into_google(self.model.id().to_string());
let http_client = self.http_client.clone();
@@ -299,7 +303,13 @@ impl LanguageModel for GoogleLanguageModel {
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
fn use_any_tool(

@@ -13,6 +13,7 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
use crate::LanguageModelCompletionEvent;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -302,7 +303,7 @@ impl LanguageModel for OllamaLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
@@ -335,7 +336,13 @@ impl LanguageModel for OllamaLanguageModel {
Ok(stream)
});
async move { Ok(future.await?.boxed()) }.boxed()
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
fn use_any_tool(

@@ -19,6 +19,7 @@ use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;
use crate::LanguageModelCompletionEvent;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
@@ -293,10 +294,18 @@ impl LanguageModel for OpenAiLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
async move {
Ok(open_ai::extract_text_from_events(completions.await?)
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
fn use_any_tool(