agent: Do not create user messages for tool results in thread (#29354)

We used to insert empty user messages into the `Thread::messages` `Vec`
when tools finished running, and then attach the tool results to them
when creating the request. This approach was easy to get wrong during
state handling, leading to empty user messages being displayed in the
conversation and to API failures.

Instead, we no longer insert these placeholder user messages into the
`Thread` at all; we only build them on the fly when creating the model
request. This simplifies a lot of code and should fix the errors
mentioned above.
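
As an illustration, here is a minimal, self-contained sketch of the new flow.
All names in it (`SketchMessage`, `SketchToolUseState`, `build_request`, ...)
are hypothetical stand-ins rather than Zed's actual types; only the shape of
the logic mirrors the change: tool results are keyed by the assistant message
that requested them, and the synthetic user message carrying them is appended
only while building the outgoing request.

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct MessageId(usize);

#[derive(Clone, Copy, PartialEq)]
enum Role {
    User,
    Assistant,
}

struct SketchMessage {
    id: MessageId,
    role: Role,
    text: String,
}

struct SketchRequestMessage {
    role: Role,
    content: Vec<String>,
}

#[derive(Default)]
struct SketchToolUseState {
    // Tool results are keyed by the *assistant* message that requested them,
    // not by a placeholder user message stored in the thread.
    results_by_assistant_message: HashMap<MessageId, Vec<String>>,
}

impl SketchToolUseState {
    /// Build a synthetic user message carrying tool results, or `None` if the
    /// given assistant message produced no tool results.
    fn tool_results_message(&self, assistant_id: MessageId) -> Option<SketchRequestMessage> {
        let results = self.results_by_assistant_message.get(&assistant_id)?;
        if results.is_empty() {
            return None;
        }
        Some(SketchRequestMessage {
            role: Role::User,
            content: results.clone(),
        })
    }
}

/// Build the outgoing request: thread messages are pushed as-is, and tool
/// results are attached on the fly right after the assistant message that
/// requested them. No placeholder user message ever lives in the thread.
fn build_request(
    messages: &[SketchMessage],
    tool_use: &SketchToolUseState,
) -> Vec<SketchRequestMessage> {
    let mut request = Vec::new();
    for message in messages {
        request.push(SketchRequestMessage {
            role: message.role,
            content: vec![message.text.clone()],
        });
        if let Some(tool_results) = tool_use.tool_results_message(message.id) {
            request.push(tool_results);
        }
    }
    request
}

fn main() {
    let messages = vec![
        SketchMessage {
            id: MessageId(0),
            role: Role::User,
            text: "Read foo.rs".into(),
        },
        SketchMessage {
            id: MessageId(1),
            role: Role::Assistant,
            text: "Using a read-file tool".into(),
        },
    ];

    let mut tool_use = SketchToolUseState::default();
    tool_use
        .results_by_assistant_message
        .insert(MessageId(1), vec!["contents of foo.rs".into()]);

    // The request ends up with three messages (user, assistant, synthesized
    // tool-result user message) while the thread itself still stores two.
    let request = build_request(&messages, &tool_use);
    assert_eq!(request.len(), 3);
    assert!(request[2].role == Role::User);
}
```

Because the placeholder never lives in `Thread::messages`, the UI and
persistence layers only ever see real messages; the tool-result message
exists solely in the request sent to the model.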

Release Notes:

- agent: Improve reliability of LLM requests when including tool results

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
Agus Zubiaga 2025-04-24 13:30:15 -03:00 committed by GitHub
parent 952fe34aaa
commit f81e65ae7c
4 changed files with 144 additions and 113 deletions


@@ -1485,39 +1485,13 @@ impl ActiveThread {
         let is_first_message = ix == 0;
         let is_last_message = ix == self.messages.len() - 1;

-        let show_feedback = (!is_generating && is_last_message && message.role != Role::User)
-            || self.messages.get(ix + 1).map_or(false, |next_id| {
-                self.thread
-                    .read(cx)
-                    .message(*next_id)
-                    .map_or(false, |next_message| {
-                        next_message.role == Role::User
-                            && thread.tool_uses_for_message(*next_id, cx).is_empty()
-                            && thread.tool_results_for_message(*next_id).is_empty()
-                    })
-            });
+        let show_feedback = thread.is_turn_end(ix);

         let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation);

         let generating_label = (is_generating && is_last_message)
             .then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));

-        // Don't render user messages that are just there for returning tool results.
-        if message.role == Role::User && thread.message_has_tool_results(message_id) {
-            if let Some(generating_label) = generating_label {
-                return h_flex()
-                    .w_full()
-                    .h_10()
-                    .py_1p5()
-                    .pl_4()
-                    .pb_3()
-                    .child(generating_label)
-                    .into_any_element();
-            }
-
-            return Empty.into_any();
-        }
-
         let edit_message_editor = self
             .editing_message
             .as_ref()


@@ -391,8 +391,7 @@ impl Thread {
                 .map(|message| message.id.0 + 1)
                 .unwrap_or(0),
         );
-        let tool_use =
-            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
+        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);

         Self {
             id,
@@ -524,7 +523,12 @@ impl Thread {
     }

     pub fn message(&self, id: MessageId) -> Option<&Message> {
-        self.messages.iter().find(|message| message.id == id)
+        let index = self
+            .messages
+            .binary_search_by(|message| message.id.cmp(&id))
+            .ok()?;
+        self.messages.get(index)
     }

     pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
@@ -673,6 +677,32 @@ impl Thread {
         })
     }

+    pub fn is_turn_end(&self, ix: usize) -> bool {
+        if self.messages.is_empty() {
+            return false;
+        }
+
+        if !self.is_generating() && ix == self.messages.len() - 1 {
+            return true;
+        }
+
+        let Some(message) = self.messages.get(ix) else {
+            return false;
+        };
+
+        if message.role != Role::Assistant {
+            return false;
+        }
+
+        self.messages
+            .get(ix + 1)
+            .and_then(|message| {
+                self.message(message.id)
+                    .map(|next_message| next_message.role == Role::User)
+            })
+            .unwrap_or(false)
+    }
+
     /// Returns whether all of the tool uses have finished running.
     pub fn all_tools_finished(&self) -> bool {
         // If the only pending tool uses left are the ones with errors, then
@@ -687,8 +717,11 @@ impl Thread {
         self.tool_use.tool_uses_for_message(id, cx)
     }

-    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
-        self.tool_use.tool_results_for_message(id)
+    pub fn tool_results_for_message(
+        &self,
+        assistant_message_id: MessageId,
+    ) -> Vec<&LanguageModelToolResult> {
+        self.tool_use.tool_results_for_message(assistant_message_id)
     }

     pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
@@ -703,10 +736,6 @@ impl Thread {
         self.tool_use.tool_result_card(id).cloned()
     }

-    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
-        self.tool_use.message_has_tool_results(message_id)
-    }
-
     /// Filter out contexts that have already been included in previous messages
     pub fn filter_new_context<'a>(
         &self,
@@ -1051,9 +1080,6 @@ impl Thread {
                 cache: false,
             };

-            self.tool_use
-                .attach_tool_results(message.id, &mut request_message);
-
             if !message.context.is_empty() {
                 request_message
                     .content
@@ -1104,6 +1130,10 @@ impl Thread {
                 .attach_tool_uses(message.id, &mut request_message);

             request.messages.push(request_message);
+
+            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
+                request.messages.push(tool_results_message);
+            }
         }

         // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
@@ -1133,11 +1163,6 @@ impl Thread {
                 cache: false,
             };

-            // Skip tool results during summarization.
-            if self.tool_use.message_has_tool_results(message.id) {
-                continue;
-            }
-
             for segment in &message.segments {
                 match segment {
                     MessageSegment::Text(text) => request_message
@@ -1272,7 +1297,9 @@ impl Thread {
                             LanguageModelCompletionEvent::Text(chunk) => {
                                 cx.emit(ThreadEvent::ReceivedTextChunk);
                                 if let Some(last_message) = thread.messages.last_mut() {
-                                    if last_message.role == Role::Assistant {
+                                    if last_message.role == Role::Assistant
+                                        && !thread.tool_use.has_tool_results(last_message.id)
+                                    {
                                         last_message.push_text(&chunk);
                                         cx.emit(ThreadEvent::StreamedAssistantText(
                                             last_message.id,
@@ -1297,7 +1324,9 @@ impl Thread {
                                 signature,
                             } => {
                                 if let Some(last_message) = thread.messages.last_mut() {
-                                    if last_message.role == Role::Assistant {
+                                    if last_message.role == Role::Assistant
+                                        && !thread.tool_use.has_tool_results(last_message.id)
+                                    {
                                         last_message.push_thinking(&chunk, signature);
                                         cx.emit(ThreadEvent::StreamedAssistantThinking(
                                             last_message.id,
@@ -1725,10 +1754,10 @@ impl Thread {
         if self.all_tools_finished() {
             let model_registry = LanguageModelRegistry::read_global(cx);
             if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
-                self.attach_tool_results(cx);
                 if !canceled {
                     self.send_to_model(model, window, cx);
                 }
+                self.auto_capture_telemetry(cx);
             }
         }
@@ -1738,14 +1767,6 @@ impl Thread {
         });
     }

-    /// Insert an empty message to be populated with tool results upon send.
-    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
-        // Tool results are assumed to be waiting on the next message id, so they will populate
-        // this empty message before sending to model. Would prefer this to be more straightforward.
-        self.insert_message(Role::User, vec![], cx);
-        self.auto_capture_telemetry(cx);
-    }
-
     /// Cancels the last pending completion, if there are any pending.
     ///
     /// Returns whether a completion was canceled.
@@ -2050,7 +2071,7 @@ impl Thread {
             }

             for tool_result in self.tool_results_for_message(message.id) {
-                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
+                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                 if tool_result.is_error {
                     write!(markdown, " (Error)")?;
                 }


@@ -639,12 +639,17 @@ pub struct SerializedThread {
 }

 impl SerializedThread {
-    pub const VERSION: &'static str = "0.1.0";
+    pub const VERSION: &'static str = "0.2.0";

     pub fn from_json(json: &[u8]) -> Result<Self> {
         let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
         match saved_thread_json.get("version") {
             Some(serde_json::Value::String(version)) => match version.as_str() {
+                SerializedThreadV0_1_0::VERSION => {
+                    let saved_thread =
+                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
+                    Ok(saved_thread.upgrade())
+                }
                 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
                     saved_thread_json,
                 )?),
@@ -666,6 +671,38 @@ impl SerializedThread {
     }
 }

+#[derive(Serialize, Deserialize, Debug)]
+pub struct SerializedThreadV0_1_0(
+    // The structure did not change, so we are reusing the latest SerializedThread.
+    // When making the next version, make sure this points to SerializedThreadV0_2_0
+    SerializedThread,
+);
+
+impl SerializedThreadV0_1_0 {
+    pub const VERSION: &'static str = "0.1.0";
+
+    pub fn upgrade(self) -> SerializedThread {
+        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
+
+        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
+
+        for message in self.0.messages {
+            if message.role == Role::User && !message.tool_results.is_empty() {
+                if let Some(last_message) = messages.last_mut() {
+                    debug_assert!(last_message.role == Role::Assistant);
+
+                    last_message.tool_results = message.tool_results;
+                    continue;
+                }
+            }
+
+            messages.push(message);
+        }
+
+        SerializedThread { messages, ..self.0 }
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub struct SerializedMessage {
     pub id: MessageId,


@@ -30,7 +30,6 @@ pub struct ToolUse {
 pub struct ToolUseState {
     tools: Entity<ToolWorkingSet>,
     tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
-    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
     tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
     pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
     tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
@@ -42,7 +41,6 @@ impl ToolUseState {
         Self {
             tools,
             tool_uses_by_assistant_message: HashMap::default(),
-            tool_uses_by_user_message: HashMap::default(),
            tool_results: HashMap::default(),
             pending_tool_uses_by_id: HashMap::default(),
             tool_result_cards: HashMap::default(),
@@ -56,7 +54,6 @@ impl ToolUseState {
     pub fn from_serialized_messages(
         tools: Entity<ToolWorkingSet>,
         messages: &[SerializedMessage],
-        mut filter_by_tool_name: impl FnMut(&str) -> bool,
     ) -> Self {
         let mut this = Self::new(tools);
         let mut tool_names_by_id = HashMap::default();
@@ -68,7 +65,6 @@ impl ToolUseState {
                     let tool_uses = message
                         .tool_uses
                         .iter()
-                        .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
                         .map(|tool_use| LanguageModelToolUse {
                             id: tool_use.id.clone(),
                             name: tool_use.name.clone().into(),
@@ -86,14 +82,6 @@ impl ToolUseState {
                         this.tool_uses_by_assistant_message
                             .insert(message.id, tool_uses);
-                    }
-                }
-                Role::User => {
-                    if !message.tool_results.is_empty() {
-                        let tool_uses_by_user_message = this
-                            .tool_uses_by_user_message
-                            .entry(message.id)
-                            .or_default();

                         for tool_result in &message.tool_results {
                             let tool_use_id = tool_result.tool_use_id.clone();
@@ -102,11 +90,6 @@ impl ToolUseState {
                                 continue;
                             };

-                            if !(filter_by_tool_name)(tool_use.as_ref()) {
-                                continue;
-                            }
-
-                            tool_uses_by_user_message.push(tool_use_id.clone());
                             this.tool_results.insert(
                                 tool_use_id.clone(),
                                 LanguageModelToolResult {
@@ -119,7 +102,7 @@ impl ToolUseState {
                         }
                     }
                 }
-                Role::System => {}
+                Role::System | Role::User => {}
             }
         }
@@ -229,20 +212,26 @@ impl ToolUseState {
         }
     }

-    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
-        let empty = Vec::new();
-
-        self.tool_uses_by_user_message
-            .get(&message_id)
-            .unwrap_or(&empty)
+    pub fn tool_results_for_message(
+        &self,
+        assistant_message_id: MessageId,
+    ) -> Vec<&LanguageModelToolResult> {
+        let Some(tool_uses) = self
+            .tool_uses_by_assistant_message
+            .get(&assistant_message_id)
+        else {
+            return Vec::new();
+        };
+
+        tool_uses
             .iter()
-            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
+            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
             .collect()
     }

-    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
-        self.tool_uses_by_user_message
-            .get(&message_id)
+    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
+        self.tool_uses_by_assistant_message
+            .get(&assistant_message_id)
             .map_or(false, |results| !results.is_empty())
     }
@@ -294,14 +283,6 @@ impl ToolUseState {
             self.tool_use_metadata_by_id
                 .insert(tool_use.id.clone(), metadata);

-            // The tool use is being requested by the Assistant, so we want to
-            // attach the tool results to the next user message.
-            let next_user_message_id = MessageId(assistant_message_id.0 + 1);
-            self.tool_uses_by_user_message
-                .entry(next_user_message_id)
-                .or_default()
-                .push(tool_use.id.clone());
-
             PendingToolUseStatus::Idle
         } else {
             PendingToolUseStatus::InputStillStreaming
@@ -467,17 +448,35 @@ impl ToolUseState {
         }
     }

-    pub fn attach_tool_results(
+    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
+        self.tool_uses_by_assistant_message
+            .contains_key(&assistant_message_id)
+    }
+
+    pub fn tool_results_message(
         &self,
-        message_id: MessageId,
-        request_message: &mut LanguageModelRequestMessage,
-    ) {
-        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
-            for tool_use_id in tool_uses {
-                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
-                    request_message.content.push(MessageContent::ToolResult(
-                        LanguageModelToolResult {
-                            tool_use_id: tool_use_id.clone(),
+        assistant_message_id: MessageId,
+    ) -> Option<LanguageModelRequestMessage> {
+        let tool_uses = self
+            .tool_uses_by_assistant_message
+            .get(&assistant_message_id)?;
+
+        if tool_uses.is_empty() {
+            return None;
+        }
+
+        let mut request_message = LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec![],
+            cache: false,
+        };
+
+        for tool_use in tool_uses {
+            if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
+                request_message
+                    .content
+                    .push(MessageContent::ToolResult(LanguageModelToolResult {
+                        tool_use_id: tool_use.id.clone(),
                         tool_name: tool_result.tool_name.clone(),
                         is_error: tool_result.is_error,
                         content: if tool_result.content.is_empty() {
@@ -487,11 +486,11 @@ impl ToolUseState {
                         } else {
                             tool_result.content.clone()
                         },
-                        },
-                    ));
-                }
+                    }));
             }
         }
+
+        Some(request_message)
     }
 }