assistant2: Rework how tool results are stored and referred to (#25817)

This PR reworks how we store tool results and refer to them later.

We now maintain a mapping of the tool uses to their corresponding
results, with separate mappings for the messages and the tool uses they
correspond to.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-02-28 11:33:08 -05:00 committed by GitHub
parent 508b581215
commit b445e4ce24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 56 additions and 64 deletions

View file

@ -255,12 +255,7 @@ impl ActiveThread {
let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(
tool_use.assistant_message_id,
tool_use.id.clone(),
task,
cx,
);
thread.insert_tool_output(tool_use.id.clone(), task, cx);
});
}
}

View file

@ -88,8 +88,9 @@ pub struct Thread {
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
tools: Arc<ToolWorkingSet>,
tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
@ -107,8 +108,9 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
tool_uses_by_message: HashMap::default(),
tool_results_by_message: HashMap::default(),
tool_uses_by_assistant_message: HashMap::default(),
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
}
}
@ -141,8 +143,9 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
tool_uses_by_message: HashMap::default(),
tool_results_by_message: HashMap::default(),
tool_uses_by_assistant_message: HashMap::default(),
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
}
}
@ -209,26 +212,14 @@ impl Thread {
}
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
let Some(tool_uses_for_message) = &self.tool_uses_by_message.get(&id) else {
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
return Vec::new();
};
// The tool use was requested by an Assistant message, so we need to
// look for the tool results on the next user message.
let next_user_message = MessageId(id.0 + 1);
let empty = Vec::new();
let tool_results_for_message = self
.tool_results_by_message
.get(&next_user_message)
.unwrap_or_else(|| &empty);
let mut tool_uses = Vec::new();
for tool_use in tool_uses_for_message.iter() {
let tool_result = tool_results_for_message
.iter()
.find(|tool_result| tool_result.tool_use_id == tool_use.id);
let tool_result = self.tool_results.get(&tool_use.id);
let status = (|| {
if let Some(tool_result) = tool_result {
@ -264,7 +255,7 @@ impl Thread {
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_results_by_message
self.tool_uses_by_user_message
.get(&message_id)
.map_or(false, |results| !results.is_empty())
}
@ -369,13 +360,15 @@ impl Thread {
content: Vec::new(),
cache: false,
};
if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) {
match request_kind {
RequestKind::Chat => {
for tool_result in tool_results {
request_message
.content
.push(MessageContent::ToolResult(tool_result.clone()));
for tool_use_id in tool_uses {
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
request_message
.content
.push(MessageContent::ToolResult(tool_result.clone()));
}
}
}
RequestKind::Summarize => {
@ -390,7 +383,7 @@ impl Thread {
.push(MessageContent::Text(message.text.clone()));
}
if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) {
match request_kind {
RequestKind::Chat => {
for tool_use in tool_uses {
@ -477,11 +470,22 @@ impl Thread {
.rfind(|message| message.role == Role::Assistant)
{
thread
.tool_uses_by_message
.tool_uses_by_assistant_message
.entry(last_assistant_message.id)
.or_default()
.push(tool_use.clone());
// The tool use is being requested by the
// Assistant, so we want to attach the tool
// results to the next user message.
let next_user_message_id =
MessageId(last_assistant_message.id.0 + 1);
thread
.tool_uses_by_user_message
.entry(next_user_message_id)
.or_default()
.push(tool_use.id.clone());
thread.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
PendingToolUse {
@ -611,7 +615,6 @@ impl Thread {
pub fn insert_tool_output(
&mut self,
assistant_message_id: MessageId,
tool_use_id: LanguageModelToolUseId,
output: Task<Result<String>>,
cx: &mut Context<Self>,
@ -621,44 +624,38 @@ impl Thread {
async move {
let output = output.await;
thread
.update(&mut cx, |thread, cx| {
// The tool use was requested by an Assistant message,
// so we want to attach the tool results to the next
// user message.
let next_user_message = MessageId(assistant_message_id.0 + 1);
let tool_results = thread
.tool_results_by_message
.entry(next_user_message)
.or_default();
match output {
Ok(output) => {
tool_results.push(LanguageModelToolResult {
.update(&mut cx, |thread, cx| match output {
Ok(output) => {
thread.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
content: output.into(),
is_error: false,
});
thread.pending_tool_uses_by_id.remove(&tool_use_id);
},
);
thread.pending_tool_uses_by_id.remove(&tool_use_id);
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
}
Err(err) => {
tool_results.push(LanguageModelToolResult {
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
}
Err(err) => {
thread.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
content: err.to_string().into(),
is_error: true,
});
},
);
if let Some(tool_use) =
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
{
tool_use.status =
PendingToolUseStatus::Error(err.to_string().into());
}
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
if let Some(tool_use) =
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
{
tool_use.status =
PendingToolUseStatus::Error(err.to_string().into());
}
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
}
})
.ok();