diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index eb6e51be27..31ed77e961 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1485,39 +1485,13 @@ impl ActiveThread { let is_first_message = ix == 0; let is_last_message = ix == self.messages.len() - 1; - let show_feedback = (!is_generating && is_last_message && message.role != Role::User) - || self.messages.get(ix + 1).map_or(false, |next_id| { - self.thread - .read(cx) - .message(*next_id) - .map_or(false, |next_message| { - next_message.role == Role::User - && thread.tool_uses_for_message(*next_id, cx).is_empty() - && thread.tool_results_for_message(*next_id).is_empty() - }) - }); + let show_feedback = thread.is_turn_end(ix); let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation); let generating_label = (is_generating && is_last_message) .then(|| AnimatedLabel::new("Generating").size(LabelSize::Small)); - // Don't render user messages that are just there for returning tool results. - if message.role == Role::User && thread.message_has_tool_results(message_id) { - if let Some(generating_label) = generating_label { - return h_flex() - .w_full() - .h_10() - .py_1p5() - .pl_4() - .pb_3() - .child(generating_label) - .into_any_element(); - } - - return Empty.into_any(); - } - let edit_message_editor = self .editing_message .as_ref() diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 5cbe49c8bc..34ccb01d60 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -391,8 +391,7 @@ impl Thread { .map(|message| message.id.0 + 1) .unwrap_or(0), ); - let tool_use = - ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true); + let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages); Self { id, @@ -524,7 +523,12 @@ impl Thread { } pub fn message(&self, id: MessageId) -> Option<&Message> { - self.messages.iter().find(|message| message.id == id) + let index = self + .messages + .binary_search_by(|message| message.id.cmp(&id)) + .ok()?; + + self.messages.get(index) } pub fn messages(&self) -> impl ExactSizeIterator { @@ -673,6 +677,32 @@ impl Thread { }) } + pub fn is_turn_end(&self, ix: usize) -> bool { + if self.messages.is_empty() { + return false; + } + + if !self.is_generating() && ix == self.messages.len() - 1 { + return true; + } + + let Some(message) = self.messages.get(ix) else { + return false; + }; + + if message.role != Role::Assistant { + return false; + } + + self.messages + .get(ix + 1) + .and_then(|message| { + self.message(message.id) + .map(|next_message| next_message.role == Role::User) + }) + .unwrap_or(false) + } + /// Returns whether all of the tool uses have finished running. pub fn all_tools_finished(&self) -> bool { // If the only pending tool uses left are the ones with errors, then @@ -687,8 +717,11 @@ impl Thread { self.tool_use.tool_uses_for_message(id, cx) } - pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> { - self.tool_use.tool_results_for_message(id) + pub fn tool_results_for_message( + &self, + assistant_message_id: MessageId, + ) -> Vec<&LanguageModelToolResult> { + self.tool_use.tool_results_for_message(assistant_message_id) } pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> { @@ -703,10 +736,6 @@ impl Thread { self.tool_use.tool_result_card(id).cloned() } - pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { - self.tool_use.message_has_tool_results(message_id) - } - /// Filter out contexts that have already been included in previous messages pub fn filter_new_context<'a>( &self, @@ -1051,9 +1080,6 @@ impl Thread { cache: false, }; - self.tool_use - .attach_tool_results(message.id, &mut request_message); - if !message.context.is_empty() { request_message .content @@ -1104,6 +1130,10 @@ impl Thread { .attach_tool_uses(message.id, &mut request_message); request.messages.push(request_message); + + if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) { + request.messages.push(tool_results_message); + } } // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching @@ -1133,11 +1163,6 @@ impl Thread { cache: false, }; - // Skip tool results during summarization. - if self.tool_use.message_has_tool_results(message.id) { - continue; - } - for segment in &message.segments { match segment { MessageSegment::Text(text) => request_message @@ -1272,7 +1297,9 @@ impl Thread { LanguageModelCompletionEvent::Text(chunk) => { cx.emit(ThreadEvent::ReceivedTextChunk); if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant { + if last_message.role == Role::Assistant + && !thread.tool_use.has_tool_results(last_message.id) + { last_message.push_text(&chunk); cx.emit(ThreadEvent::StreamedAssistantText( last_message.id, @@ -1297,7 +1324,9 @@ impl Thread { signature, } => { if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant { + if last_message.role == Role::Assistant + && !thread.tool_use.has_tool_results(last_message.id) + { last_message.push_thinking(&chunk, signature); cx.emit(ThreadEvent::StreamedAssistantThinking( last_message.id, @@ -1725,10 +1754,10 @@ impl Thread { if self.all_tools_finished() { let model_registry = LanguageModelRegistry::read_global(cx); if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() { - self.attach_tool_results(cx); if !canceled { self.send_to_model(model, window, cx); } + self.auto_capture_telemetry(cx); } } @@ -1738,14 +1767,6 @@ impl Thread { }); } - /// Insert an empty message to be populated with tool results upon send. - pub fn attach_tool_results(&mut self, cx: &mut Context) { - // Tool results are assumed to be waiting on the next message id, so they will populate - // this empty message before sending to model. Would prefer this to be more straightforward. - self.insert_message(Role::User, vec![], cx); - self.auto_capture_telemetry(cx); - } - /// Cancels the last pending completion, if there are any pending. /// /// Returns whether a completion was canceled. @@ -2050,7 +2071,7 @@ impl Thread { } for tool_result in self.tool_results_for_message(message.id) { - write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?; + write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?; if tool_result.is_error { write!(markdown, " (Error)")?; } diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 1f57c08cc7..117be3675c 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -639,12 +639,17 @@ pub struct SerializedThread { } impl SerializedThread { - pub const VERSION: &'static str = "0.1.0"; + pub const VERSION: &'static str = "0.2.0"; pub fn from_json(json: &[u8]) -> Result { let saved_thread_json = serde_json::from_slice::(json)?; match saved_thread_json.get("version") { Some(serde_json::Value::String(version)) => match version.as_str() { + SerializedThreadV0_1_0::VERSION => { + let saved_thread = + serde_json::from_value::(saved_thread_json)?; + Ok(saved_thread.upgrade()) + } SerializedThread::VERSION => Ok(serde_json::from_value::( saved_thread_json, )?), @@ -666,6 +671,38 @@ impl SerializedThread { } } +#[derive(Serialize, Deserialize, Debug)] +pub struct SerializedThreadV0_1_0( + // The structure did not change, so we are reusing the latest SerializedThread. + // When making the next version, make sure this points to SerializedThreadV0_2_0 + SerializedThread, +); + +impl SerializedThreadV0_1_0 { + pub const VERSION: &'static str = "0.1.0"; + + pub fn upgrade(self) -> SerializedThread { + debug_assert_eq!(SerializedThread::VERSION, "0.2.0"); + + let mut messages: Vec = Vec::with_capacity(self.0.messages.len()); + + for message in self.0.messages { + if message.role == Role::User && !message.tool_results.is_empty() { + if let Some(last_message) = messages.last_mut() { + debug_assert!(last_message.role == Role::Assistant); + + last_message.tool_results = message.tool_results; + continue; + } + } + + messages.push(message); + } + + SerializedThread { messages, ..self.0 } + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct SerializedMessage { pub id: MessageId, diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index a6441bb5ae..9b5b2f02d9 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -30,7 +30,6 @@ pub struct ToolUse { pub struct ToolUseState { tools: Entity, tool_uses_by_assistant_message: HashMap>, - tool_uses_by_user_message: HashMap>, tool_results: HashMap, pending_tool_uses_by_id: HashMap, tool_result_cards: HashMap, @@ -42,7 +41,6 @@ impl ToolUseState { Self { tools, tool_uses_by_assistant_message: HashMap::default(), - tool_uses_by_user_message: HashMap::default(), tool_results: HashMap::default(), pending_tool_uses_by_id: HashMap::default(), tool_result_cards: HashMap::default(), @@ -56,7 +54,6 @@ impl ToolUseState { pub fn from_serialized_messages( tools: Entity, messages: &[SerializedMessage], - mut filter_by_tool_name: impl FnMut(&str) -> bool, ) -> Self { let mut this = Self::new(tools); let mut tool_names_by_id = HashMap::default(); @@ -68,7 +65,6 @@ impl ToolUseState { let tool_uses = message .tool_uses .iter() - .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref())) .map(|tool_use| LanguageModelToolUse { id: tool_use.id.clone(), name: tool_use.name.clone().into(), @@ -86,14 +82,6 @@ impl ToolUseState { this.tool_uses_by_assistant_message .insert(message.id, tool_uses); - } - } - Role::User => { - if !message.tool_results.is_empty() { - let tool_uses_by_user_message = this - .tool_uses_by_user_message - .entry(message.id) - .or_default(); for tool_result in &message.tool_results { let tool_use_id = tool_result.tool_use_id.clone(); @@ -102,11 +90,6 @@ impl ToolUseState { continue; }; - if !(filter_by_tool_name)(tool_use.as_ref()) { - continue; - } - - tool_uses_by_user_message.push(tool_use_id.clone()); this.tool_results.insert( tool_use_id.clone(), LanguageModelToolResult { @@ -119,7 +102,7 @@ impl ToolUseState { } } } - Role::System => {} + Role::System | Role::User => {} } } @@ -229,20 +212,26 @@ impl ToolUseState { } } - pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> { - let empty = Vec::new(); + pub fn tool_results_for_message( + &self, + assistant_message_id: MessageId, + ) -> Vec<&LanguageModelToolResult> { + let Some(tool_uses) = self + .tool_uses_by_assistant_message + .get(&assistant_message_id) + else { + return Vec::new(); + }; - self.tool_uses_by_user_message - .get(&message_id) - .unwrap_or(&empty) + tool_uses .iter() - .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id)) + .filter_map(|tool_use| self.tool_results.get(&tool_use.id)) .collect() } - pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { - self.tool_uses_by_user_message - .get(&message_id) + pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool { + self.tool_uses_by_assistant_message + .get(&assistant_message_id) .map_or(false, |results| !results.is_empty()) } @@ -294,14 +283,6 @@ impl ToolUseState { self.tool_use_metadata_by_id .insert(tool_use.id.clone(), metadata); - // The tool use is being requested by the Assistant, so we want to - // attach the tool results to the next user message. - let next_user_message_id = MessageId(assistant_message_id.0 + 1); - self.tool_uses_by_user_message - .entry(next_user_message_id) - .or_default() - .push(tool_use.id.clone()); - PendingToolUseStatus::Idle } else { PendingToolUseStatus::InputStillStreaming @@ -467,31 +448,49 @@ impl ToolUseState { } } - pub fn attach_tool_results( + pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool { + self.tool_uses_by_assistant_message + .contains_key(&assistant_message_id) + } + + pub fn tool_results_message( &self, - message_id: MessageId, - request_message: &mut LanguageModelRequestMessage, - ) { - if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) { - for tool_use_id in tool_uses { - if let Some(tool_result) = self.tool_results.get(tool_use_id) { - request_message.content.push(MessageContent::ToolResult( - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name: tool_result.tool_name.clone(), - is_error: tool_result.is_error, - content: if tool_result.content.is_empty() { - // Surprisingly, the API fails if we return an empty string here. - // It thinks we are sending a tool use without a tool result. - "".into() - } else { - tool_result.content.clone() - }, + assistant_message_id: MessageId, + ) -> Option { + let tool_uses = self + .tool_uses_by_assistant_message + .get(&assistant_message_id)?; + + if tool_uses.is_empty() { + return None; + } + + let mut request_message = LanguageModelRequestMessage { + role: Role::User, + content: vec![], + cache: false, + }; + + for tool_use in tool_uses { + if let Some(tool_result) = self.tool_results.get(&tool_use.id) { + request_message + .content + .push(MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: tool_use.id.clone(), + tool_name: tool_result.tool_name.clone(), + is_error: tool_result.is_error, + content: if tool_result.content.is_empty() { + // Surprisingly, the API fails if we return an empty string here. + // It thinks we are sending a tool use without a tool result. + "".into() + } else { + tool_result.content.clone() }, - )); - } + })); } } + + Some(request_message) } }