agent: Do not create user messages for tool results in thread (#29354)
We used to insert empty user messages into the `Thread::messages` `Vec` when tools finished running and then we would attach the results when creating the request. This approach was very easy to mess up during state handling, leading to empty user messages displayed in the conversation and API failures. Instead, we will no longer insert actual user messages for tool results to the `Thread`, and will only do this on the fly when creating the model request. This simplifies a lot of code and show fix the mentioned errors. Release Notes: - agent: Improve reliability of LLM requests when including tool results --------- Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de> Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
parent
952fe34aaa
commit
f81e65ae7c
4 changed files with 144 additions and 113 deletions
|
@ -1485,39 +1485,13 @@ impl ActiveThread {
|
|||
let is_first_message = ix == 0;
|
||||
let is_last_message = ix == self.messages.len() - 1;
|
||||
|
||||
let show_feedback = (!is_generating && is_last_message && message.role != Role::User)
|
||||
|| self.messages.get(ix + 1).map_or(false, |next_id| {
|
||||
self.thread
|
||||
.read(cx)
|
||||
.message(*next_id)
|
||||
.map_or(false, |next_message| {
|
||||
next_message.role == Role::User
|
||||
&& thread.tool_uses_for_message(*next_id, cx).is_empty()
|
||||
&& thread.tool_results_for_message(*next_id).is_empty()
|
||||
})
|
||||
});
|
||||
let show_feedback = thread.is_turn_end(ix);
|
||||
|
||||
let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation);
|
||||
|
||||
let generating_label = (is_generating && is_last_message)
|
||||
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
|
||||
|
||||
// Don't render user messages that are just there for returning tool results.
|
||||
if message.role == Role::User && thread.message_has_tool_results(message_id) {
|
||||
if let Some(generating_label) = generating_label {
|
||||
return h_flex()
|
||||
.w_full()
|
||||
.h_10()
|
||||
.py_1p5()
|
||||
.pl_4()
|
||||
.pb_3()
|
||||
.child(generating_label)
|
||||
.into_any_element();
|
||||
}
|
||||
|
||||
return Empty.into_any();
|
||||
}
|
||||
|
||||
let edit_message_editor = self
|
||||
.editing_message
|
||||
.as_ref()
|
||||
|
|
|
@ -391,8 +391,7 @@ impl Thread {
|
|||
.map(|message| message.id.0 + 1)
|
||||
.unwrap_or(0),
|
||||
);
|
||||
let tool_use =
|
||||
ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
|
||||
let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
|
||||
|
||||
Self {
|
||||
id,
|
||||
|
@ -524,7 +523,12 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn message(&self, id: MessageId) -> Option<&Message> {
|
||||
self.messages.iter().find(|message| message.id == id)
|
||||
let index = self
|
||||
.messages
|
||||
.binary_search_by(|message| message.id.cmp(&id))
|
||||
.ok()?;
|
||||
|
||||
self.messages.get(index)
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
|
||||
|
@ -673,6 +677,32 @@ impl Thread {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn is_turn_end(&self, ix: usize) -> bool {
|
||||
if self.messages.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if !self.is_generating() && ix == self.messages.len() - 1 {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(message) = self.messages.get(ix) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if message.role != Role::Assistant {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.messages
|
||||
.get(ix + 1)
|
||||
.and_then(|message| {
|
||||
self.message(message.id)
|
||||
.map(|next_message| next_message.role == Role::User)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Returns whether all of the tool uses have finished running.
|
||||
pub fn all_tools_finished(&self) -> bool {
|
||||
// If the only pending tool uses left are the ones with errors, then
|
||||
|
@ -687,8 +717,11 @@ impl Thread {
|
|||
self.tool_use.tool_uses_for_message(id, cx)
|
||||
}
|
||||
|
||||
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||
self.tool_use.tool_results_for_message(id)
|
||||
pub fn tool_results_for_message(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Vec<&LanguageModelToolResult> {
|
||||
self.tool_use.tool_results_for_message(assistant_message_id)
|
||||
}
|
||||
|
||||
pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
|
||||
|
@ -703,10 +736,6 @@ impl Thread {
|
|||
self.tool_use.tool_result_card(id).cloned()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_use.message_has_tool_results(message_id)
|
||||
}
|
||||
|
||||
/// Filter out contexts that have already been included in previous messages
|
||||
pub fn filter_new_context<'a>(
|
||||
&self,
|
||||
|
@ -1051,9 +1080,6 @@ impl Thread {
|
|||
cache: false,
|
||||
};
|
||||
|
||||
self.tool_use
|
||||
.attach_tool_results(message.id, &mut request_message);
|
||||
|
||||
if !message.context.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
|
@ -1104,6 +1130,10 @@ impl Thread {
|
|||
.attach_tool_uses(message.id, &mut request_message);
|
||||
|
||||
request.messages.push(request_message);
|
||||
|
||||
if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
|
||||
request.messages.push(tool_results_message);
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
||||
|
@ -1133,11 +1163,6 @@ impl Thread {
|
|||
cache: false,
|
||||
};
|
||||
|
||||
// Skip tool results during summarization.
|
||||
if self.tool_use.message_has_tool_results(message.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => request_message
|
||||
|
@ -1272,7 +1297,9 @@ impl Thread {
|
|||
LanguageModelCompletionEvent::Text(chunk) => {
|
||||
cx.emit(ThreadEvent::ReceivedTextChunk);
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
if last_message.role == Role::Assistant
|
||||
&& !thread.tool_use.has_tool_results(last_message.id)
|
||||
{
|
||||
last_message.push_text(&chunk);
|
||||
cx.emit(ThreadEvent::StreamedAssistantText(
|
||||
last_message.id,
|
||||
|
@ -1297,7 +1324,9 @@ impl Thread {
|
|||
signature,
|
||||
} => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
if last_message.role == Role::Assistant
|
||||
&& !thread.tool_use.has_tool_results(last_message.id)
|
||||
{
|
||||
last_message.push_thinking(&chunk, signature);
|
||||
cx.emit(ThreadEvent::StreamedAssistantThinking(
|
||||
last_message.id,
|
||||
|
@ -1725,10 +1754,10 @@ impl Thread {
|
|||
if self.all_tools_finished() {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
|
||||
self.attach_tool_results(cx);
|
||||
if !canceled {
|
||||
self.send_to_model(model, window, cx);
|
||||
}
|
||||
self.auto_capture_telemetry(cx);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1738,14 +1767,6 @@ impl Thread {
|
|||
});
|
||||
}
|
||||
|
||||
/// Insert an empty message to be populated with tool results upon send.
|
||||
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
|
||||
// Tool results are assumed to be waiting on the next message id, so they will populate
|
||||
// this empty message before sending to model. Would prefer this to be more straightforward.
|
||||
self.insert_message(Role::User, vec![], cx);
|
||||
self.auto_capture_telemetry(cx);
|
||||
}
|
||||
|
||||
/// Cancels the last pending completion, if there are any pending.
|
||||
///
|
||||
/// Returns whether a completion was canceled.
|
||||
|
@ -2050,7 +2071,7 @@ impl Thread {
|
|||
}
|
||||
|
||||
for tool_result in self.tool_results_for_message(message.id) {
|
||||
write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
|
||||
write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
|
||||
if tool_result.is_error {
|
||||
write!(markdown, " (Error)")?;
|
||||
}
|
||||
|
|
|
@ -639,12 +639,17 @@ pub struct SerializedThread {
|
|||
}
|
||||
|
||||
impl SerializedThread {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
pub const VERSION: &'static str = "0.2.0";
|
||||
|
||||
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||
match saved_thread_json.get("version") {
|
||||
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||
SerializedThreadV0_1_0::VERSION => {
|
||||
let saved_thread =
|
||||
serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
|
||||
Ok(saved_thread.upgrade())
|
||||
}
|
||||
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
|
||||
saved_thread_json,
|
||||
)?),
|
||||
|
@ -666,6 +671,38 @@ impl SerializedThread {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SerializedThreadV0_1_0(
|
||||
// The structure did not change, so we are reusing the latest SerializedThread.
|
||||
// When making the next version, make sure this points to SerializedThreadV0_2_0
|
||||
SerializedThread,
|
||||
);
|
||||
|
||||
impl SerializedThreadV0_1_0 {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
|
||||
pub fn upgrade(self) -> SerializedThread {
|
||||
debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
|
||||
|
||||
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
|
||||
|
||||
for message in self.0.messages {
|
||||
if message.role == Role::User && !message.tool_results.is_empty() {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
|
||||
last_message.tool_results = message.tool_results;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
SerializedThread { messages, ..self.0 }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SerializedMessage {
|
||||
pub id: MessageId,
|
||||
|
|
|
@ -30,7 +30,6 @@ pub struct ToolUse {
|
|||
pub struct ToolUseState {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
|
||||
|
@ -42,7 +41,6 @@ impl ToolUseState {
|
|||
Self {
|
||||
tools,
|
||||
tool_uses_by_assistant_message: HashMap::default(),
|
||||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
tool_result_cards: HashMap::default(),
|
||||
|
@ -56,7 +54,6 @@ impl ToolUseState {
|
|||
pub fn from_serialized_messages(
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
messages: &[SerializedMessage],
|
||||
mut filter_by_tool_name: impl FnMut(&str) -> bool,
|
||||
) -> Self {
|
||||
let mut this = Self::new(tools);
|
||||
let mut tool_names_by_id = HashMap::default();
|
||||
|
@ -68,7 +65,6 @@ impl ToolUseState {
|
|||
let tool_uses = message
|
||||
.tool_uses
|
||||
.iter()
|
||||
.filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
|
||||
.map(|tool_use| LanguageModelToolUse {
|
||||
id: tool_use.id.clone(),
|
||||
name: tool_use.name.clone().into(),
|
||||
|
@ -86,14 +82,6 @@ impl ToolUseState {
|
|||
|
||||
this.tool_uses_by_assistant_message
|
||||
.insert(message.id, tool_uses);
|
||||
}
|
||||
}
|
||||
Role::User => {
|
||||
if !message.tool_results.is_empty() {
|
||||
let tool_uses_by_user_message = this
|
||||
.tool_uses_by_user_message
|
||||
.entry(message.id)
|
||||
.or_default();
|
||||
|
||||
for tool_result in &message.tool_results {
|
||||
let tool_use_id = tool_result.tool_use_id.clone();
|
||||
|
@ -102,11 +90,6 @@ impl ToolUseState {
|
|||
continue;
|
||||
};
|
||||
|
||||
if !(filter_by_tool_name)(tool_use.as_ref()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tool_uses_by_user_message.push(tool_use_id.clone());
|
||||
this.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
|
@ -119,7 +102,7 @@ impl ToolUseState {
|
|||
}
|
||||
}
|
||||
}
|
||||
Role::System => {}
|
||||
Role::System | Role::User => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -229,20 +212,26 @@ impl ToolUseState {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||
let empty = Vec::new();
|
||||
pub fn tool_results_for_message(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Vec<&LanguageModelToolResult> {
|
||||
let Some(tool_uses) = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
.unwrap_or(&empty)
|
||||
tool_uses
|
||||
.iter()
|
||||
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
|
||||
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
.map_or(false, |results| !results.is_empty())
|
||||
}
|
||||
|
||||
|
@ -294,14 +283,6 @@ impl ToolUseState {
|
|||
self.tool_use_metadata_by_id
|
||||
.insert(tool_use.id.clone(), metadata);
|
||||
|
||||
// The tool use is being requested by the Assistant, so we want to
|
||||
// attach the tool results to the next user message.
|
||||
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
|
||||
self.tool_uses_by_user_message
|
||||
.entry(next_user_message_id)
|
||||
.or_default()
|
||||
.push(tool_use.id.clone());
|
||||
|
||||
PendingToolUseStatus::Idle
|
||||
} else {
|
||||
PendingToolUseStatus::InputStillStreaming
|
||||
|
@ -467,31 +448,49 @@ impl ToolUseState {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn attach_tool_results(
|
||||
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.contains_key(&assistant_message_id)
|
||||
}
|
||||
|
||||
pub fn tool_results_message(
|
||||
&self,
|
||||
message_id: MessageId,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
|
||||
for tool_use_id in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
|
||||
request_message.content.push(MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
assistant_message_id: MessageId,
|
||||
) -> Option<LanguageModelRequestMessage> {
|
||||
let tool_uses = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)?;
|
||||
|
||||
if tool_uses.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
for tool_use in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
));
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Some(request_message)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue