From bafc086d27a3a5f129f48cedd0eedbe2a0574313 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Sat, 19 Apr 2025 22:12:03 +0200 Subject: [PATCH] agent: Preserve thinking blocks between requests (#29055) Looks like the required backend component of this was deployed. https://github.com/zed-industries/monorepo/actions/runs/14541199197 Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra Co-authored-by: Agus Zubiaga Co-authored-by: Richard Feldman Co-authored-by: Nathan Sobo --- crates/agent/src/active_thread.rs | 27 +++-- crates/agent/src/thread.rs | 108 +++++++++++++----- crates/agent/src/thread_store.rs | 13 ++- crates/anthropic/src/anthropic.rs | 9 ++ .../assistant_context_editor/src/context.rs | 2 +- crates/eval/src/example.rs | 16 ++- crates/language_model/src/language_model.rs | 11 +- crates/language_model/src/request.rs | 14 ++- .../language_models/src/provider/anthropic.rs | 49 +++++++- .../language_models/src/provider/bedrock.rs | 28 +++-- .../src/provider/copilot_chat.rs | 5 +- crates/language_models/src/provider/google.rs | 4 +- .../language_models/src/provider/open_ai.rs | 18 +-- 13 files changed, 236 insertions(+), 68 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 885d481eb2..4c3fbe878e 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -133,18 +133,23 @@ impl RenderedMessage { } fn push_segment(&mut self, segment: &MessageSegment, cx: &mut App) { - let rendered_segment = match segment { - MessageSegment::Thinking(text) => RenderedMessageSegment::Thinking { - content: parse_markdown(text.into(), self.language_registry.clone(), cx), - scroll_handle: ScrollHandle::default(), - }, - MessageSegment::Text(text) => RenderedMessageSegment::Text(parse_markdown( - text.into(), - self.language_registry.clone(), - cx, - )), + match segment { + MessageSegment::Thinking { text, .. } => { + self.segments.push(RenderedMessageSegment::Thinking { + content: parse_markdown(text.into(), self.language_registry.clone(), cx), + scroll_handle: ScrollHandle::default(), + }) + } + MessageSegment::Text(text) => { + self.segments + .push(RenderedMessageSegment::Text(parse_markdown( + text.into(), + self.language_registry.clone(), + cx, + ))) + } + MessageSegment::RedactedThinking(_) => {} }; - self.segments.push(rendered_segment); } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index c50c001cd2..1929daf2ed 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -113,12 +113,21 @@ impl Message { self.segments.iter().all(|segment| segment.should_display()) } - pub fn push_thinking(&mut self, text: &str) { - if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() { + pub fn push_thinking(&mut self, text: &str, signature: Option) { + if let Some(MessageSegment::Thinking { + text: segment, + signature: current_signature, + }) = self.segments.last_mut() + { + if let Some(signature) = signature { + *current_signature = Some(signature); + } segment.push_str(text); } else { - self.segments - .push(MessageSegment::Thinking(text.to_string())); + self.segments.push(MessageSegment::Thinking { + text: text.to_string(), + signature, + }); } } @@ -140,11 +149,12 @@ impl Message { for segment in &self.segments { match segment { MessageSegment::Text(text) => result.push_str(text), - MessageSegment::Thinking(text) => { - result.push_str(""); + MessageSegment::Thinking { text, .. } => { + result.push_str("\n"); result.push_str(text); - result.push_str(""); + result.push_str("\n"); } + MessageSegment::RedactedThinking(_) => {} } } @@ -155,24 +165,22 @@ impl Message { #[derive(Debug, Clone, PartialEq, Eq)] pub enum MessageSegment { Text(String), - Thinking(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking(Vec), } impl MessageSegment { - pub fn text_mut(&mut self) -> &mut String { - match self { - Self::Text(text) => text, - Self::Thinking(text) => text, - } - } - pub fn should_display(&self) -> bool { // We add USING_TOOL_MARKER when making a request that includes tool uses // without non-whitespace text around them, and this can cause the model // to mimic the pattern, so we consider those segments not displayable. match self { Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER, - Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER, + Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER, + Self::RedactedThinking(_) => false, } } } @@ -408,8 +416,11 @@ impl Thread { .into_iter() .map(|segment| match segment { SerializedMessageSegment::Text { text } => MessageSegment::Text(text), - SerializedMessageSegment::Thinking { text } => { - MessageSegment::Thinking(text) + SerializedMessageSegment::Thinking { text, signature } => { + MessageSegment::Thinking { text, signature } + } + SerializedMessageSegment::RedactedThinking { data } => { + MessageSegment::RedactedThinking(data) } }) .collect(), @@ -862,9 +873,10 @@ impl Thread { for segment in &message.segments { match segment { MessageSegment::Text(content) => text.push_str(content), - MessageSegment::Thinking(content) => { + MessageSegment::Thinking { text: content, .. } => { text.push_str(&format!("{}", content)) } + MessageSegment::RedactedThinking(_) => {} } } text.push('\n'); @@ -894,8 +906,16 @@ impl Thread { MessageSegment::Text(text) => { SerializedMessageSegment::Text { text: text.clone() } } - MessageSegment::Thinking(text) => { - SerializedMessageSegment::Thinking { text: text.clone() } + MessageSegment::Thinking { text, signature } => { + SerializedMessageSegment::Thinking { + text: text.clone(), + signature: signature.clone(), + } + } + MessageSegment::RedactedThinking(data) => { + SerializedMessageSegment::RedactedThinking { + data: data.clone(), + } } }) .collect(), @@ -1038,10 +1058,35 @@ impl Thread { } } - if !message.segments.is_empty() { + if !message.context.is_empty() { request_message .content - .push(MessageContent::Text(message.to_string())); + .push(MessageContent::Text(message.context.to_string())); + } + + for segment in &message.segments { + match segment { + MessageSegment::Text(text) => { + if !text.is_empty() { + request_message + .content + .push(MessageContent::Text(text.into())); + } + } + MessageSegment::Thinking { text, signature } => { + if !text.is_empty() { + request_message.content.push(MessageContent::Thinking { + text: text.into(), + signature: signature.clone(), + }); + } + } + MessageSegment::RedactedThinking(data) => { + request_message + .content + .push(MessageContent::RedactedThinking(data.clone())); + } + }; } match request_kind { @@ -1187,10 +1232,13 @@ impl Thread { }; } } - LanguageModelCompletionEvent::Thinking(chunk) => { + LanguageModelCompletionEvent::Thinking { + text: chunk, + signature, + } => { if let Some(last_message) = thread.messages.last_mut() { if last_message.role == Role::Assistant { - last_message.push_thinking(&chunk); + last_message.push_thinking(&chunk, signature); cx.emit(ThreadEvent::StreamedAssistantThinking( last_message.id, chunk, @@ -1203,7 +1251,10 @@ impl Thread { // will result in duplicating the text of the chunk in the rendered Markdown. thread.insert_message( Role::Assistant, - vec![MessageSegment::Thinking(chunk.to_string())], + vec![MessageSegment::Thinking { + text: chunk.to_string(), + signature, + }], cx, ); }; @@ -1893,9 +1944,10 @@ impl Thread { for segment in &message.segments { match segment { MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, - MessageSegment::Thinking(text) => { - writeln!(markdown, "{}\n", text)? + MessageSegment::Thinking { text, .. } => { + writeln!(markdown, "\n{}\n\n", text)? } + MessageSegment::RedactedThinking(_) => {} } } diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index e2f9e3d2de..646a26d26d 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -660,9 +660,18 @@ pub struct SerializedMessage { #[serde(tag = "type")] pub enum SerializedMessageSegment { #[serde(rename = "text")] - Text { text: String }, + Text { + text: String, + }, #[serde(rename = "thinking")] - Thinking { text: String }, + Thinking { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + signature: Option, + }, + RedactedThinking { + data: Vec, + }, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 266d3c7642..684feaca3b 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -507,6 +507,15 @@ pub enum RequestContent { #[serde(skip_serializing_if = "Option::is_none")] cache_control: Option, }, + #[serde(rename = "thinking")] + Thinking { + thinking: String, + signature: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, + #[serde(rename = "redacted_thinking")] + RedactedThinking { data: String }, #[serde(rename = "image")] Image { source: ImageSource, diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs index 097cdb6ad5..398a6659e1 100644 --- a/crates/assistant_context_editor/src/context.rs +++ b/crates/assistant_context_editor/src/context.rs @@ -2373,7 +2373,7 @@ impl AssistantContext { LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; } - LanguageModelCompletionEvent::Thinking(chunk) => { + LanguageModelCompletionEvent::Thinking { text: chunk, .. } => { if thought_process_stack.is_empty() { let start = buffer.anchor_before(message_old_end_offset); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 6c12311b6a..982daeaed7 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -916,6 +916,20 @@ impl RequestMarkdown { MessageContent::Image(_) => { messages.push_str("[IMAGE DATA]\n\n"); } + MessageContent::Thinking { text, signature } => { + messages.push_str("**Thinking**:\n\n"); + if let Some(sig) = signature { + messages.push_str(&format!("Signature: {}\n\n", sig)); + } + messages.push_str(text); + messages.push_str("\n"); + } + MessageContent::RedactedThinking(items) => { + messages.push_str(&format!( + "**Redacted Thinking**: {} item(s)\n\n", + items.len() + )); + } MessageContent::ToolUse(tool_use) => { messages.push_str(&format!( "**Tool Use**: {} (ID: {})\n", @@ -970,7 +984,7 @@ fn response_events_to_markdown( Ok(LanguageModelCompletionEvent::Text(text)) => { text_buffer.push_str(text); } - Ok(LanguageModelCompletionEvent::Thinking(text)) => { + Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => { thinking_buffer.push_str(text); } Ok(LanguageModelCompletionEvent::Stop(reason)) => { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index aa08370edf..206958e82f 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -65,9 +65,14 @@ pub struct LanguageModelCacheConfiguration { pub enum LanguageModelCompletionEvent { Stop(StopReason), Text(String), - Thinking(String), + Thinking { + text: String, + signature: Option, + }, ToolUse(LanguageModelToolUse), - StartMessage { message_id: String }, + StartMessage { + message_id: String, + }, UsageUpdate(TokenUsage), } @@ -302,7 +307,7 @@ pub trait LanguageModel: Send + Sync { match result { Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None, Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)), - Ok(LanguageModelCompletionEvent::Thinking(_)) => None, + Ok(LanguageModelCompletionEvent::Thinking { .. }) => None, Ok(LanguageModelCompletionEvent::Stop(_)) => None, Ok(LanguageModelCompletionEvent::ToolUse(_)) => None, Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 754b1671bb..0f1e97af5a 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -175,6 +175,11 @@ pub struct LanguageModelToolResult { #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub enum MessageContent { Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking(Vec), Image(LanguageModelImage), ToolUse(LanguageModelToolUse), ToolResult(LanguageModelToolResult), @@ -204,6 +209,8 @@ impl LanguageModelRequestMessage { let mut buffer = String::new(); for string in self.content.iter().filter_map(|content| match content { MessageContent::Text(text) => Some(text.as_str()), + MessageContent::Thinking { text, .. } => Some(text.as_str()), + MessageContent::RedactedThinking(_) => None, MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()), MessageContent::ToolUse(_) | MessageContent::Image(_) => None, }) { @@ -220,10 +227,15 @@ impl LanguageModelRequestMessage { .first() .map(|content| match content { MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), + MessageContent::Thinking { text, .. } => { + text.chars().all(|c| c.is_whitespace()) + } MessageContent::ToolResult(tool_result) => { tool_result.content.chars().all(|c| c.is_whitespace()) } - MessageContent::ToolUse(_) | MessageContent::Image(_) => true, + MessageContent::RedactedThinking(_) + | MessageContent::ToolUse(_) + | MessageContent::Image(_) => true, }) .unwrap_or(false) } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index da796a3f2b..6a29976504 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -336,6 +336,12 @@ pub fn count_anthropic_tokens( MessageContent::Text(text) => { string_contents.push_str(&text); } + MessageContent::Thinking { .. } => { + // Thinking blocks are not included in the input token count. + } + MessageContent::RedactedThinking(_) => { + // Thinking blocks are not included in the input token count. + } MessageContent::Image(image) => { tokens_from_images += image.estimate_tokens(); } @@ -515,6 +521,29 @@ pub fn into_anthropic( None } } + MessageContent::Thinking { + text: thinking, + signature, + } => { + if !thinking.is_empty() { + Some(anthropic::RequestContent::Thinking { + thinking, + signature: signature.unwrap_or_default(), + cache_control, + }) + } else { + None + } + } + MessageContent::RedactedThinking(data) => { + if !data.is_empty() { + Some(anthropic::RequestContent::RedactedThinking { + data: String::from_utf8(data).ok()?, + }) + } else { + None + } + } MessageContent::Image(image) => Some(anthropic::RequestContent::Image { source: anthropic::ImageSource { source_type: "base64".to_string(), @@ -637,7 +666,10 @@ pub fn map_to_language_model_completion_events( } ResponseContent::Thinking { thinking } => { return Some(( - vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))], + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })], state, )); } @@ -665,11 +697,22 @@ pub fn map_to_language_model_completion_events( } ContentDelta::ThinkingDelta { thinking } => { return Some(( - vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))], + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })], + state, + )); + } + ContentDelta::SignatureDelta { signature } => { + return Some(( + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature), + })], state, )); } - ContentDelta::SignatureDelta { .. } => {} ContentDelta::InputJsonDelta { partial_json } => { if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) { tool_use.input_json.push_str(&partial_json); diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 48b7e62757..c4ef48404f 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -742,9 +742,10 @@ pub fn get_bedrock_tokens( for content in message.content { match content { - MessageContent::Text(text) => { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { string_contents.push_str(&text); } + MessageContent::RedactedThinking(_) => {} MessageContent::Image(image) => { tokens_from_images += image.estimate_tokens(); } @@ -830,25 +831,36 @@ pub fn map_to_language_model_completion_events( redacted, ) => { let thinking_event = - LanguageModelCompletionEvent::Thinking( - String::from_utf8( + LanguageModelCompletionEvent::Thinking { + text: String::from_utf8( redacted.into_inner(), ) .unwrap_or("REDACTED".to_string()), - ); + signature: None, + }; return Some(( Some(Ok(thinking_event)), state, )); } - ReasoningContentBlockDelta::Signature(_sig) => { + ReasoningContentBlockDelta::Signature( + signature, + ) => { + return Some(( + Some(Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature) + })), + state, + )); } ReasoningContentBlockDelta::Text(thoughts) => { let thinking_event = - LanguageModelCompletionEvent::Thinking( - thoughts.to_string(), - ); + LanguageModelCompletionEvent::Thinking { + text: thoughts.to_string(), + signature: None + }; return Some(( Some(Ok(thinking_event)), diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 39c005babc..3d4924b890 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -424,8 +424,11 @@ impl CopilotChatLanguageModel { let text_content = { let mut buffer = String::new(); for string in message.content.iter().filter_map(|content| match content { - MessageContent::Text(text) => Some(text.as_str()), + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { + Some(text.as_str()) + } MessageContent::ToolUse(_) + | MessageContent::RedactedThinking(_) | MessageContent::ToolResult(_) | MessageContent::Image(_) => None, }) { diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index bec5a8768d..c754e63bbd 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -368,13 +368,15 @@ pub fn into_google( content .into_iter() .filter_map(|content| match content { - language_model::MessageContent::Text(text) => { + language_model::MessageContent::Text(text) + | language_model::MessageContent::Thinking { text, .. } => { if !text.is_empty() { Some(Part::TextPart(google_ai::TextPart { text })) } else { None } } + language_model::MessageContent::RedactedThinking(_) => None, language_model::MessageContent::Image(_) => None, language_model::MessageContent::ToolUse(tool_use) => { Some(Part::FunctionCallPart(google_ai::FunctionCallPart { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 6b14744a30..020c642520 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -342,14 +342,16 @@ pub fn into_open_ai( for message in request.messages { for content in message.content { match content { - MessageContent::Text(text) => messages.push(match message.role { - Role::User => open_ai::RequestMessage::User { content: text }, - Role::Assistant => open_ai::RequestMessage::Assistant { - content: Some(text), - tool_calls: Vec::new(), - }, - Role::System => open_ai::RequestMessage::System { content: text }, - }), + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages + .push(match message.role { + Role::User => open_ai::RequestMessage::User { content: text }, + Role::Assistant => open_ai::RequestMessage::Assistant { + content: Some(text), + tool_calls: Vec::new(), + }, + Role::System => open_ai::RequestMessage::System { content: text }, + }), + MessageContent::RedactedThinking(_) => {} MessageContent::Image(_) => {} MessageContent::ToolUse(tool_use) => { let tool_call = open_ai::ToolCall {