assistant2: Restructure storage of tool uses and results (#21194)

This PR restructures how tool uses and tool results are stored in
`assistant2` so that they no longer live on the individual messages.
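
Concretely, tool uses and tool results now live in side tables on the `Thread`, keyed by the ID of the message they belong to, rather than in `tool_uses`/`tool_results` fields on each `Message`. A condensed sketch of the resulting shape, pulled from the `Thread` hunks below (unrelated fields elided):

    pub struct Thread {
        messages: Vec<Message>,
        next_message_id: MessageId,
        // Looked up per message when building a completion request.
        tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
        tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
        // Keyed by the provider-assigned tool use ID while a tool is still running.
        pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
        // ...remaining fields unchanged
    }

When a tool finishes, its result is recorded against the message that follows the requesting Assistant message (`MessageId(assistant_message_id.0 + 1)`), so it is included in the next request sent to the model.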

It also introduces a `LanguageModelToolUseId` newtype for better type
safety.
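
The newtype wraps the provider-assigned ID that was previously passed around as a bare `String`/`Arc<str>`. A minimal sketch of what the diff below adds to the `language_model` crate (the `Display` impl is elided):

    #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
    pub struct LanguageModelToolUseId(Arc<str>);

    impl<T> From<T> for LanguageModelToolUseId
    where
        T: Into<Arc<str>>,
    {
        fn from(value: T) -> Self {
            Self(value.into())
        }
    }

Call sites construct it with `.into()` at the provider boundary and use `.to_string()` where a plain string is still required, as in the Anthropic request mapping further down.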

Release Notes:

- N/A
Marshall Bowers, 2024-11-25 21:53:27 -05:00 (committed by GitHub)
Parent: 7e418cc8af
Commit: 968ffaa3fd
9 changed files with 136 additions and 77 deletions

Cargo.lock (generated)

@@ -465,6 +465,7 @@ dependencies = [
  "language_model",
  "language_model_selector",
  "proto",
+ "serde",
  "serde_json",
  "settings",
  "smol",


@@ -1925,7 +1925,7 @@ impl ContextEditor {
                 Content::ToolUse {
                     range: tool_use.source_range.clone(),
                     tool_use: LanguageModelToolUse {
-                        id: tool_use.id.to_string(),
+                        id: tool_use.id.clone(),
                         name: tool_use.name.clone(),
                         input: tool_use.input.clone(),
                     },


@@ -27,8 +27,8 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
 use language_model::{
     LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
     LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
-    StopReason,
+    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
+    LanguageModelToolUseId, MessageContent, Role, StopReason,
 };
 use language_models::{
     provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
@@ -385,7 +385,7 @@ pub enum ContextEvent {
     },
     UsePendingTools,
     ToolFinished {
-        tool_use_id: Arc<str>,
+        tool_use_id: LanguageModelToolUseId,
         output_range: Range<language::Anchor>,
     },
     Operation(ContextOperation),
@@ -479,7 +479,7 @@ pub enum Content {
     },
     ToolResult {
         range: Range<language::Anchor>,
-        tool_use_id: Arc<str>,
+        tool_use_id: LanguageModelToolUseId,
     },
 }
@@ -546,7 +546,7 @@ pub struct Context {
     pub(crate) slash_commands: Arc<SlashCommandWorkingSet>,
     pub(crate) tools: Arc<ToolWorkingSet>,
     slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
-    pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
+    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
     message_anchors: Vec<MessageAnchor>,
     contents: Vec<Content>,
     messages_metadata: HashMap<MessageId, MessageMetadata>,
@@ -1126,7 +1126,7 @@ impl Context {
         self.pending_tool_uses_by_id.values().collect()
     }

-    pub fn get_tool_use_by_id(&self, id: &Arc<str>) -> Option<&PendingToolUse> {
+    pub fn get_tool_use_by_id(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
         self.pending_tool_uses_by_id.get(id)
     }
@@ -2153,7 +2153,7 @@ impl Context {
     pub fn insert_tool_output(
         &mut self,
-        tool_use_id: Arc<str>,
+        tool_use_id: LanguageModelToolUseId,
         output: Task<Result<String>>,
         cx: &mut ModelContext<Self>,
     ) {
@@ -2340,11 +2340,10 @@ impl Context {
                         let source_range = buffer.anchor_after(start_ix)
                             ..buffer.anchor_after(end_ix);

-                        let tool_use_id: Arc<str> = tool_use.id.into();
                         this.pending_tool_uses_by_id.insert(
-                            tool_use_id.clone(),
+                            tool_use.id.clone(),
                             PendingToolUse {
-                                id: tool_use_id,
+                                id: tool_use.id,
                                 name: tool_use.name,
                                 input: tool_use.input,
                                 status: PendingToolUseStatus::Idle,
@@ -3203,7 +3202,7 @@ pub enum PendingSlashCommandStatus {
 #[derive(Debug, Clone)]
 pub struct PendingToolUse {
-    pub id: Arc<str>,
+    pub id: LanguageModelToolUseId,
     pub name: String,
     pub input: serde_json::Value,
     pub status: PendingToolUseStatus,


@@ -25,6 +25,7 @@ language_model.workspace = true
 language_model_selector.workspace = true
 proto.workspace = true
 settings.workspace = true
+serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
 theme.workspace = true


@@ -102,7 +102,12 @@ impl AssistantPanel {
                 let task = tool.run(tool_use.input, self.workspace.clone(), cx);

                 self.thread.update(cx, |thread, cx| {
-                    thread.insert_tool_output(tool_use.id.clone(), task, cx);
+                    thread.insert_tool_output(
+                        tool_use.assistant_message_id,
+                        tool_use.id.clone(),
+                        task,
+                        cx,
+                    );
                 });
             }
         }


@@ -8,8 +8,10 @@ use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AppContext, EventEmitter, ModelContext, Task};
 use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
+    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
+    StopReason,
 };
+use serde::{Deserialize, Serialize};
 use util::post_inc;

 #[derive(Debug, Clone, Copy)]
@@ -17,34 +19,46 @@ pub enum RequestKind {
     Chat,
 }

+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
+pub struct MessageId(usize);
+
+impl MessageId {
+    fn post_inc(&mut self) -> Self {
+        Self(post_inc(&mut self.0))
+    }
+}
+
 /// A message in a [`Thread`].
 #[derive(Debug, Clone)]
 pub struct Message {
+    pub id: MessageId,
     pub role: Role,
     pub text: String,
-    pub tool_uses: Vec<LanguageModelToolUse>,
-    pub tool_results: Vec<LanguageModelToolResult>,
 }

 /// A thread of conversation with the LLM.
 pub struct Thread {
     messages: Vec<Message>,
+    next_message_id: MessageId,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     tools: Arc<ToolWorkingSet>,
-    pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
-    completed_tool_uses_by_id: HashMap<Arc<str>, String>,
+    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
+    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
+    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 }

 impl Thread {
     pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
         Self {
-            tools,
             messages: Vec::new(),
+            next_message_id: MessageId(0),
             completion_count: 0,
             pending_completions: Vec::new(),
+            tools,
+            tool_uses_by_message: HashMap::default(),
+            tool_results_by_message: HashMap::default(),
             pending_tool_uses_by_id: HashMap::default(),
-            completed_tool_uses_by_id: HashMap::default(),
         }
     }
@@ -61,24 +75,13 @@ impl Thread {
     }

     pub fn insert_user_message(&mut self, text: impl Into<String>) {
-        let mut message = Message {
+        self.messages.push(Message {
+            id: self.next_message_id.post_inc(),
             role: Role::User,
             text: text.into(),
-            tool_uses: Vec::new(),
-            tool_results: Vec::new(),
-        };
-
-        for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
-            message.tool_results.push(LanguageModelToolResult {
-                tool_use_id: tool_use_id.to_string(),
-                content: tool_output,
-                is_error: false,
-            });
-        }
-
-        self.messages.push(message);
+        });
     }

     pub fn to_completion_request(
         &self,
         _request_kind: RequestKind,
@@ -98,11 +101,13 @@ impl Thread {
                 cache: false,
             };

-            for tool_result in &message.tool_results {
-                request_message
-                    .content
-                    .push(MessageContent::ToolResult(tool_result.clone()));
+            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
+                for tool_result in tool_results {
+                    request_message
+                        .content
+                        .push(MessageContent::ToolResult(tool_result.clone()));
+                }
             }

             if !message.text.is_empty() {
                 request_message
@@ -110,11 +115,13 @@ impl Thread {
                     .push(MessageContent::Text(message.text.clone()));
             }

-            for tool_use in &message.tool_uses {
-                request_message
-                    .content
-                    .push(MessageContent::ToolUse(tool_use.clone()));
+            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
+                for tool_use in tool_uses {
+                    request_message
+                        .content
+                        .push(MessageContent::ToolUse(tool_use.clone()));
+                }
             }

             request.messages.push(request_message);
         }
@@ -143,10 +150,9 @@ impl Thread {
                     match event {
                         LanguageModelCompletionEvent::StartMessage { .. } => {
                             thread.messages.push(Message {
+                                id: thread.next_message_id.post_inc(),
                                 role: Role::Assistant,
                                 text: String::new(),
-                                tool_uses: Vec::new(),
-                                tool_results: Vec::new(),
                             });
                         }
                         LanguageModelCompletionEvent::Stop(reason) => {
@@ -160,17 +166,22 @@ impl Thread {
                             }
                         }
                         LanguageModelCompletionEvent::ToolUse(tool_use) => {
-                            if let Some(last_message) = thread.messages.last_mut() {
-                                if last_message.role == Role::Assistant {
-                                    last_message.tool_uses.push(tool_use.clone());
-                                }
-                            }
-
-                            let tool_use_id: Arc<str> = tool_use.id.into();
-                            thread.pending_tool_uses_by_id.insert(
-                                tool_use_id.clone(),
-                                PendingToolUse {
-                                    id: tool_use_id,
-                                    name: tool_use.name,
-                                    input: tool_use.input,
-                                    status: PendingToolUseStatus::Idle,
+                            if let Some(last_assistant_message) = thread
+                                .messages
+                                .iter()
+                                .rfind(|message| message.role == Role::Assistant)
+                            {
+                                thread
+                                    .tool_uses_by_message
+                                    .entry(last_assistant_message.id)
+                                    .or_default()
+                                    .push(tool_use.clone());
+
+                                thread.pending_tool_uses_by_id.insert(
+                                    tool_use.id.clone(),
+                                    PendingToolUse {
+                                        assistant_message_id: last_assistant_message.id,
+                                        id: tool_use.id,
+                                        name: tool_use.name,
+                                        input: tool_use.input,
+                                        status: PendingToolUseStatus::Idle,
@@ -178,6 +189,7 @@ impl Thread {
                                 );
                             }
                         }
+                    }

                     cx.emit(ThreadEvent::StreamedCompletion);
                     cx.notify();
@@ -235,7 +247,8 @@ impl Thread {
     pub fn insert_tool_output(
         &mut self,
-        tool_use_id: Arc<str>,
+        assistant_message_id: MessageId,
+        tool_use_id: LanguageModelToolUseId,
         output: Task<Result<String>>,
         cx: &mut ModelContext<Self>,
     ) {
@@ -244,21 +257,41 @@ impl Thread {
             async move {
                 let output = output.await;
                 thread
-                    .update(&mut cx, |thread, cx| match output {
+                    .update(&mut cx, |thread, cx| {
+                        // The tool use was requested by an Assistant message,
+                        // so we want to attach the tool results to the next
+                        // user message.
+                        let next_user_message = MessageId(assistant_message_id.0 + 1);
+
+                        let tool_results = thread
+                            .tool_results_by_message
+                            .entry(next_user_message)
+                            .or_default();
+
+                        match output {
                             Ok(output) => {
-                                thread
-                                    .completed_tool_uses_by_id
-                                    .insert(tool_use_id.clone(), output);
+                                tool_results.push(LanguageModelToolResult {
+                                    tool_use_id: tool_use_id.to_string(),
+                                    content: output,
+                                    is_error: false,
+                                });

                                 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
                             }
                             Err(err) => {
+                                tool_results.push(LanguageModelToolResult {
+                                    tool_use_id: tool_use_id.to_string(),
+                                    content: err.to_string(),
+                                    is_error: true,
+                                });
+
                                 if let Some(tool_use) =
                                     thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
                                 {
                                     tool_use.status = PendingToolUseStatus::Error(err.to_string());
                                 }
                             }
+                        }
                     })
                     .ok();
             }
@@ -278,7 +311,7 @@ pub enum ThreadEvent {
     UsePendingTools,
     ToolFinished {
         #[allow(unused)]
-        tool_use_id: Arc<str>,
+        tool_use_id: LanguageModelToolUseId,
     },
 }
@@ -291,7 +324,9 @@ struct PendingCompletion {
 #[derive(Debug, Clone)]
 pub struct PendingToolUse {
-    pub id: Arc<str>,
+    pub id: LanguageModelToolUseId,
+    /// The ID of the Assistant message in which the tool use was requested.
+    pub assistant_message_id: MessageId,
     pub name: String,
     pub input: serde_json::Value,
     pub status: PendingToolUseStatus,


@@ -63,9 +63,27 @@ pub enum StopReason {
     ToolUse,
 }

+#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUseId(Arc<str>);
+
+impl fmt::Display for LanguageModelToolUseId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl<T> From<T> for LanguageModelToolUseId
+where
+    T: Into<Arc<str>>,
+{
+    fn from(value: T) -> Self {
+        Self(value.into())
+    }
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 pub struct LanguageModelToolUse {
-    pub id: String,
+    pub id: LanguageModelToolUseId,
     pub name: String,
     pub input: serde_json::Value,
 }


@@ -347,7 +347,7 @@ impl LanguageModelRequest {
                     }
                     MessageContent::ToolUse(tool_use) => {
                         Some(anthropic::RequestContent::ToolUse {
-                            id: tool_use.id,
+                            id: tool_use.id.to_string(),
                             name: tool_use.name,
                             input: tool_use.input,
                             cache_control,


@@ -498,7 +498,7 @@ pub fn map_to_language_model_completion_events(
                     Some(maybe!({
                         Ok(LanguageModelCompletionEvent::ToolUse(
                             LanguageModelToolUse {
-                                id: tool_use.id,
+                                id: tool_use.id.into(),
                                 name: tool_use.name,
                                 input: if tool_use.input_json.is_empty() {
                                     serde_json::Value::Null