assistant2: Rework how tool results are stored and referred to (#25817)
This PR reworks how we store tool results and refer to them later. We now maintain a mapping of the tool uses to their corresponding results, with separate mappings for the messages and the tool uses they correspond to. Release Notes: - N/A
This commit is contained in:
parent
508b581215
commit
b445e4ce24
2 changed files with 56 additions and 64 deletions
|
@ -255,12 +255,7 @@ impl ActiveThread {
|
|||
let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
|
||||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.insert_tool_output(
|
||||
tool_use.assistant_message_id,
|
||||
tool_use.id.clone(),
|
||||
task,
|
||||
cx,
|
||||
);
|
||||
thread.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,8 +88,9 @@ pub struct Thread {
|
|||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
|
||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
}
|
||||
|
||||
|
@ -107,8 +108,9 @@ impl Thread {
|
|||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
tools,
|
||||
tool_uses_by_message: HashMap::default(),
|
||||
tool_results_by_message: HashMap::default(),
|
||||
tool_uses_by_assistant_message: HashMap::default(),
|
||||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
@ -141,8 +143,9 @@ impl Thread {
|
|||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
tools,
|
||||
tool_uses_by_message: HashMap::default(),
|
||||
tool_results_by_message: HashMap::default(),
|
||||
tool_uses_by_assistant_message: HashMap::default(),
|
||||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
@ -209,26 +212,14 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
|
||||
let Some(tool_uses_for_message) = &self.tool_uses_by_message.get(&id) else {
|
||||
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
// The tool use was requested by an Assistant message, so we need to
|
||||
// look for the tool results on the next user message.
|
||||
let next_user_message = MessageId(id.0 + 1);
|
||||
|
||||
let empty = Vec::new();
|
||||
let tool_results_for_message = self
|
||||
.tool_results_by_message
|
||||
.get(&next_user_message)
|
||||
.unwrap_or_else(|| &empty);
|
||||
|
||||
let mut tool_uses = Vec::new();
|
||||
|
||||
for tool_use in tool_uses_for_message.iter() {
|
||||
let tool_result = tool_results_for_message
|
||||
.iter()
|
||||
.find(|tool_result| tool_result.tool_use_id == tool_use.id);
|
||||
let tool_result = self.tool_results.get(&tool_use.id);
|
||||
|
||||
let status = (|| {
|
||||
if let Some(tool_result) = tool_result {
|
||||
|
@ -264,7 +255,7 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_results_by_message
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
.map_or(false, |results| !results.is_empty())
|
||||
}
|
||||
|
@ -369,13 +360,15 @@ impl Thread {
|
|||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
|
||||
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) {
|
||||
match request_kind {
|
||||
RequestKind::Chat => {
|
||||
for tool_result in tool_results {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(tool_result.clone()));
|
||||
for tool_use_id in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(tool_result.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
RequestKind::Summarize => {
|
||||
|
@ -390,7 +383,7 @@ impl Thread {
|
|||
.push(MessageContent::Text(message.text.clone()));
|
||||
}
|
||||
|
||||
if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
|
||||
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) {
|
||||
match request_kind {
|
||||
RequestKind::Chat => {
|
||||
for tool_use in tool_uses {
|
||||
|
@ -477,11 +470,22 @@ impl Thread {
|
|||
.rfind(|message| message.role == Role::Assistant)
|
||||
{
|
||||
thread
|
||||
.tool_uses_by_message
|
||||
.tool_uses_by_assistant_message
|
||||
.entry(last_assistant_message.id)
|
||||
.or_default()
|
||||
.push(tool_use.clone());
|
||||
|
||||
// The tool use is being requested by the
|
||||
// Assistant, so we want to attach the tool
|
||||
// results to the next user message.
|
||||
let next_user_message_id =
|
||||
MessageId(last_assistant_message.id.0 + 1);
|
||||
thread
|
||||
.tool_uses_by_user_message
|
||||
.entry(next_user_message_id)
|
||||
.or_default()
|
||||
.push(tool_use.id.clone());
|
||||
|
||||
thread.pending_tool_uses_by_id.insert(
|
||||
tool_use.id.clone(),
|
||||
PendingToolUse {
|
||||
|
@ -611,7 +615,6 @@ impl Thread {
|
|||
|
||||
pub fn insert_tool_output(
|
||||
&mut self,
|
||||
assistant_message_id: MessageId,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
output: Task<Result<String>>,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -621,44 +624,38 @@ impl Thread {
|
|||
async move {
|
||||
let output = output.await;
|
||||
thread
|
||||
.update(&mut cx, |thread, cx| {
|
||||
// The tool use was requested by an Assistant message,
|
||||
// so we want to attach the tool results to the next
|
||||
// user message.
|
||||
let next_user_message = MessageId(assistant_message_id.0 + 1);
|
||||
|
||||
let tool_results = thread
|
||||
.tool_results_by_message
|
||||
.entry(next_user_message)
|
||||
.or_default();
|
||||
|
||||
match output {
|
||||
Ok(output) => {
|
||||
tool_results.push(LanguageModelToolResult {
|
||||
.update(&mut cx, |thread, cx| match output {
|
||||
Ok(output) => {
|
||||
thread.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
content: output.into(),
|
||||
is_error: false,
|
||||
});
|
||||
thread.pending_tool_uses_by_id.remove(&tool_use_id);
|
||||
},
|
||||
);
|
||||
thread.pending_tool_uses_by_id.remove(&tool_use_id);
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||
}
|
||||
Err(err) => {
|
||||
tool_results.push(LanguageModelToolResult {
|
||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||
}
|
||||
Err(err) => {
|
||||
thread.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
content: err.to_string().into(),
|
||||
is_error: true,
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
if let Some(tool_use) =
|
||||
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
|
||||
{
|
||||
tool_use.status =
|
||||
PendingToolUseStatus::Error(err.to_string().into());
|
||||
}
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||
if let Some(tool_use) =
|
||||
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
|
||||
{
|
||||
tool_use.status =
|
||||
PendingToolUseStatus::Error(err.to_string().into());
|
||||
}
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue