assistant2: Restructure storage of tool uses and results (#21194)
This PR restructures the storage of the tool uses and results in `assistant2` so that they don't live on the individual messages. It also introduces a `LanguageModelToolUseId` newtype for better type safety. Release Notes: - N/A
This commit is contained in:
parent
7e418cc8af
commit
968ffaa3fd
9 changed files with 136 additions and 77 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -465,6 +465,7 @@ dependencies = [
|
|||
"language_model",
|
||||
"language_model_selector",
|
||||
"proto",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
|
|
|
@ -1925,7 +1925,7 @@ impl ContextEditor {
|
|||
Content::ToolUse {
|
||||
range: tool_use.source_range.clone(),
|
||||
tool_use: LanguageModelToolUse {
|
||||
id: tool_use.id.to_string(),
|
||||
id: tool_use.id.clone(),
|
||||
name: tool_use.name.clone(),
|
||||
input: tool_use.input.clone(),
|
||||
},
|
||||
|
|
|
@ -27,8 +27,8 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
|
|||
use language_model::{
|
||||
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
||||
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
|
||||
StopReason,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
|
||||
LanguageModelToolUseId, MessageContent, Role, StopReason,
|
||||
};
|
||||
use language_models::{
|
||||
provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
|
||||
|
@ -385,7 +385,7 @@ pub enum ContextEvent {
|
|||
},
|
||||
UsePendingTools,
|
||||
ToolFinished {
|
||||
tool_use_id: Arc<str>,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
output_range: Range<language::Anchor>,
|
||||
},
|
||||
Operation(ContextOperation),
|
||||
|
@ -479,7 +479,7 @@ pub enum Content {
|
|||
},
|
||||
ToolResult {
|
||||
range: Range<language::Anchor>,
|
||||
tool_use_id: Arc<str>,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -546,7 +546,7 @@ pub struct Context {
|
|||
pub(crate) slash_commands: Arc<SlashCommandWorkingSet>,
|
||||
pub(crate) tools: Arc<ToolWorkingSet>,
|
||||
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
|
||||
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
message_anchors: Vec<MessageAnchor>,
|
||||
contents: Vec<Content>,
|
||||
messages_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
|
@ -1126,7 +1126,7 @@ impl Context {
|
|||
self.pending_tool_uses_by_id.values().collect()
|
||||
}
|
||||
|
||||
pub fn get_tool_use_by_id(&self, id: &Arc<str>) -> Option<&PendingToolUse> {
|
||||
pub fn get_tool_use_by_id(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
|
||||
self.pending_tool_uses_by_id.get(id)
|
||||
}
|
||||
|
||||
|
@ -2153,7 +2153,7 @@ impl Context {
|
|||
|
||||
pub fn insert_tool_output(
|
||||
&mut self,
|
||||
tool_use_id: Arc<str>,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
output: Task<Result<String>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
|
@ -2340,11 +2340,10 @@ impl Context {
|
|||
let source_range = buffer.anchor_after(start_ix)
|
||||
..buffer.anchor_after(end_ix);
|
||||
|
||||
let tool_use_id: Arc<str> = tool_use.id.into();
|
||||
this.pending_tool_uses_by_id.insert(
|
||||
tool_use_id.clone(),
|
||||
tool_use.id.clone(),
|
||||
PendingToolUse {
|
||||
id: tool_use_id,
|
||||
id: tool_use.id,
|
||||
name: tool_use.name,
|
||||
input: tool_use.input,
|
||||
status: PendingToolUseStatus::Idle,
|
||||
|
@ -3203,7 +3202,7 @@ pub enum PendingSlashCommandStatus {
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PendingToolUse {
|
||||
pub id: Arc<str>,
|
||||
pub id: LanguageModelToolUseId,
|
||||
pub name: String,
|
||||
pub input: serde_json::Value,
|
||||
pub status: PendingToolUseStatus,
|
||||
|
|
|
@ -25,6 +25,7 @@ language_model.workspace = true
|
|||
language_model_selector.workspace = true
|
||||
proto.workspace = true
|
||||
settings.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
theme.workspace = true
|
||||
|
|
|
@ -102,7 +102,12 @@ impl AssistantPanel {
|
|||
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
|
||||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||
thread.insert_tool_output(
|
||||
tool_use.assistant_message_id,
|
||||
tool_use.id.clone(),
|
||||
task,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,8 +8,10 @@ use futures::{FutureExt as _, StreamExt as _};
|
|||
use gpui::{AppContext, EventEmitter, ModelContext, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
|
||||
StopReason,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::post_inc;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
@ -17,34 +19,46 @@ pub enum RequestKind {
|
|||
Chat,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct MessageId(usize);
|
||||
|
||||
impl MessageId {
|
||||
fn post_inc(&mut self) -> Self {
|
||||
Self(post_inc(&mut self.0))
|
||||
}
|
||||
}
|
||||
|
||||
/// A message in a [`Thread`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Message {
|
||||
pub id: MessageId,
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
pub tool_uses: Vec<LanguageModelToolUse>,
|
||||
pub tool_results: Vec<LanguageModelToolResult>,
|
||||
}
|
||||
|
||||
/// A thread of conversation with the LLM.
|
||||
pub struct Thread {
|
||||
messages: Vec<Message>,
|
||||
next_message_id: MessageId,
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
|
||||
completed_tool_uses_by_id: HashMap<Arc<str>, String>,
|
||||
tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
|
||||
Self {
|
||||
tools,
|
||||
messages: Vec::new(),
|
||||
next_message_id: MessageId(0),
|
||||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
tools,
|
||||
tool_uses_by_message: HashMap::default(),
|
||||
tool_results_by_message: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
completed_tool_uses_by_id: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,22 +75,11 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn insert_user_message(&mut self, text: impl Into<String>) {
|
||||
let mut message = Message {
|
||||
self.messages.push(Message {
|
||||
id: self.next_message_id.post_inc(),
|
||||
role: Role::User,
|
||||
text: text.into(),
|
||||
tool_uses: Vec::new(),
|
||||
tool_results: Vec::new(),
|
||||
};
|
||||
|
||||
for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
|
||||
message.tool_results.push(LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
content: tool_output,
|
||||
is_error: false,
|
||||
});
|
||||
}
|
||||
|
||||
self.messages.push(message);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn to_completion_request(
|
||||
|
@ -98,10 +101,12 @@ impl Thread {
|
|||
cache: false,
|
||||
};
|
||||
|
||||
for tool_result in &message.tool_results {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(tool_result.clone()));
|
||||
if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
|
||||
for tool_result in tool_results {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(tool_result.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if !message.text.is_empty() {
|
||||
|
@ -110,10 +115,12 @@ impl Thread {
|
|||
.push(MessageContent::Text(message.text.clone()));
|
||||
}
|
||||
|
||||
for tool_use in &message.tool_uses {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
|
||||
for tool_use in tool_uses {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
request.messages.push(request_message);
|
||||
|
@ -143,10 +150,9 @@ impl Thread {
|
|||
match event {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
thread.messages.push(Message {
|
||||
id: thread.next_message_id.post_inc(),
|
||||
role: Role::Assistant,
|
||||
text: String::new(),
|
||||
tool_uses: Vec::new(),
|
||||
tool_results: Vec::new(),
|
||||
});
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
|
@ -160,22 +166,28 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.tool_uses.push(tool_use.clone());
|
||||
}
|
||||
}
|
||||
if let Some(last_assistant_message) = thread
|
||||
.messages
|
||||
.iter()
|
||||
.rfind(|message| message.role == Role::Assistant)
|
||||
{
|
||||
thread
|
||||
.tool_uses_by_message
|
||||
.entry(last_assistant_message.id)
|
||||
.or_default()
|
||||
.push(tool_use.clone());
|
||||
|
||||
let tool_use_id: Arc<str> = tool_use.id.into();
|
||||
thread.pending_tool_uses_by_id.insert(
|
||||
tool_use_id.clone(),
|
||||
PendingToolUse {
|
||||
id: tool_use_id,
|
||||
name: tool_use.name,
|
||||
input: tool_use.input,
|
||||
status: PendingToolUseStatus::Idle,
|
||||
},
|
||||
);
|
||||
thread.pending_tool_uses_by_id.insert(
|
||||
tool_use.id.clone(),
|
||||
PendingToolUse {
|
||||
assistant_message_id: last_assistant_message.id,
|
||||
id: tool_use.id,
|
||||
name: tool_use.name,
|
||||
input: tool_use.input,
|
||||
status: PendingToolUseStatus::Idle,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -235,7 +247,8 @@ impl Thread {
|
|||
|
||||
pub fn insert_tool_output(
|
||||
&mut self,
|
||||
tool_use_id: Arc<str>,
|
||||
assistant_message_id: MessageId,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
output: Task<Result<String>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
|
@ -244,19 +257,39 @@ impl Thread {
|
|||
async move {
|
||||
let output = output.await;
|
||||
thread
|
||||
.update(&mut cx, |thread, cx| match output {
|
||||
Ok(output) => {
|
||||
thread
|
||||
.completed_tool_uses_by_id
|
||||
.insert(tool_use_id.clone(), output);
|
||||
.update(&mut cx, |thread, cx| {
|
||||
// The tool use was requested by an Assistant message,
|
||||
// so we want to attach the tool results to the next
|
||||
// user message.
|
||||
let next_user_message = MessageId(assistant_message_id.0 + 1);
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||
}
|
||||
Err(err) => {
|
||||
if let Some(tool_use) =
|
||||
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
|
||||
{
|
||||
tool_use.status = PendingToolUseStatus::Error(err.to_string());
|
||||
let tool_results = thread
|
||||
.tool_results_by_message
|
||||
.entry(next_user_message)
|
||||
.or_default();
|
||||
|
||||
match output {
|
||||
Ok(output) => {
|
||||
tool_results.push(LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
content: output,
|
||||
is_error: false,
|
||||
});
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||
}
|
||||
Err(err) => {
|
||||
tool_results.push(LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
content: err.to_string(),
|
||||
is_error: true,
|
||||
});
|
||||
|
||||
if let Some(tool_use) =
|
||||
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
|
||||
{
|
||||
tool_use.status = PendingToolUseStatus::Error(err.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -278,7 +311,7 @@ pub enum ThreadEvent {
|
|||
UsePendingTools,
|
||||
ToolFinished {
|
||||
#[allow(unused)]
|
||||
tool_use_id: Arc<str>,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -291,7 +324,9 @@ struct PendingCompletion {
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PendingToolUse {
|
||||
pub id: Arc<str>,
|
||||
pub id: LanguageModelToolUseId,
|
||||
/// The ID of the Assistant message in which the tool use was requested.
|
||||
pub assistant_message_id: MessageId,
|
||||
pub name: String,
|
||||
pub input: serde_json::Value,
|
||||
pub status: PendingToolUseStatus,
|
||||
|
|
|
@ -63,9 +63,27 @@ pub enum StopReason {
|
|||
ToolUse,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelToolUseId(Arc<str>);
|
||||
|
||||
impl fmt::Display for LanguageModelToolUseId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for LanguageModelToolUseId
|
||||
where
|
||||
T: Into<Arc<str>>,
|
||||
{
|
||||
fn from(value: T) -> Self {
|
||||
Self(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelToolUse {
|
||||
pub id: String,
|
||||
pub id: LanguageModelToolUseId,
|
||||
pub name: String,
|
||||
pub input: serde_json::Value,
|
||||
}
|
||||
|
|
|
@ -347,7 +347,7 @@ impl LanguageModelRequest {
|
|||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
Some(anthropic::RequestContent::ToolUse {
|
||||
id: tool_use.id,
|
||||
id: tool_use.id.to_string(),
|
||||
name: tool_use.name,
|
||||
input: tool_use.input,
|
||||
cache_control,
|
||||
|
|
|
@ -498,7 +498,7 @@ pub fn map_to_language_model_completion_events(
|
|||
Some(maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id,
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name,
|
||||
input: if tool_use.input_json.is_empty() {
|
||||
serde_json::Value::Null
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue