agent: Preserve thinking blocks between requests (#29055)

Looks like the required backend component of this was deployed.

https://github.com/zed-industries/monorepo/actions/runs/14541199197

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2025-04-19 22:12:03 +02:00 committed by GitHub
parent f737c4d01e
commit bafc086d27
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 236 additions and 68 deletions

View file

@ -133,18 +133,23 @@ impl RenderedMessage {
} }
fn push_segment(&mut self, segment: &MessageSegment, cx: &mut App) { fn push_segment(&mut self, segment: &MessageSegment, cx: &mut App) {
let rendered_segment = match segment { match segment {
MessageSegment::Thinking(text) => RenderedMessageSegment::Thinking { MessageSegment::Thinking { text, .. } => {
content: parse_markdown(text.into(), self.language_registry.clone(), cx), self.segments.push(RenderedMessageSegment::Thinking {
scroll_handle: ScrollHandle::default(), content: parse_markdown(text.into(), self.language_registry.clone(), cx),
}, scroll_handle: ScrollHandle::default(),
MessageSegment::Text(text) => RenderedMessageSegment::Text(parse_markdown( })
text.into(), }
self.language_registry.clone(), MessageSegment::Text(text) => {
cx, self.segments
)), .push(RenderedMessageSegment::Text(parse_markdown(
text.into(),
self.language_registry.clone(),
cx,
)))
}
MessageSegment::RedactedThinking(_) => {}
}; };
self.segments.push(rendered_segment);
} }
} }

View file

@ -113,12 +113,21 @@ impl Message {
self.segments.iter().all(|segment| segment.should_display()) self.segments.iter().all(|segment| segment.should_display())
} }
pub fn push_thinking(&mut self, text: &str) { pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() { if let Some(MessageSegment::Thinking {
text: segment,
signature: current_signature,
}) = self.segments.last_mut()
{
if let Some(signature) = signature {
*current_signature = Some(signature);
}
segment.push_str(text); segment.push_str(text);
} else { } else {
self.segments self.segments.push(MessageSegment::Thinking {
.push(MessageSegment::Thinking(text.to_string())); text: text.to_string(),
signature,
});
} }
} }
@ -140,11 +149,12 @@ impl Message {
for segment in &self.segments { for segment in &self.segments {
match segment { match segment {
MessageSegment::Text(text) => result.push_str(text), MessageSegment::Text(text) => result.push_str(text),
MessageSegment::Thinking(text) => { MessageSegment::Thinking { text, .. } => {
result.push_str("<think>"); result.push_str("<think>\n");
result.push_str(text); result.push_str(text);
result.push_str("</think>"); result.push_str("\n</think>");
} }
MessageSegment::RedactedThinking(_) => {}
} }
} }
@ -155,24 +165,22 @@ impl Message {
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment { pub enum MessageSegment {
Text(String), Text(String),
Thinking(String), Thinking {
text: String,
signature: Option<String>,
},
RedactedThinking(Vec<u8>),
} }
impl MessageSegment { impl MessageSegment {
pub fn text_mut(&mut self) -> &mut String {
match self {
Self::Text(text) => text,
Self::Thinking(text) => text,
}
}
pub fn should_display(&self) -> bool { pub fn should_display(&self) -> bool {
// We add USING_TOOL_MARKER when making a request that includes tool uses // We add USING_TOOL_MARKER when making a request that includes tool uses
// without non-whitespace text around them, and this can cause the model // without non-whitespace text around them, and this can cause the model
// to mimic the pattern, so we consider those segments not displayable. // to mimic the pattern, so we consider those segments not displayable.
match self { match self {
Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER, Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER, Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
Self::RedactedThinking(_) => false,
} }
} }
} }
@ -408,8 +416,11 @@ impl Thread {
.into_iter() .into_iter()
.map(|segment| match segment { .map(|segment| match segment {
SerializedMessageSegment::Text { text } => MessageSegment::Text(text), SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
SerializedMessageSegment::Thinking { text } => { SerializedMessageSegment::Thinking { text, signature } => {
MessageSegment::Thinking(text) MessageSegment::Thinking { text, signature }
}
SerializedMessageSegment::RedactedThinking { data } => {
MessageSegment::RedactedThinking(data)
} }
}) })
.collect(), .collect(),
@ -862,9 +873,10 @@ impl Thread {
for segment in &message.segments { for segment in &message.segments {
match segment { match segment {
MessageSegment::Text(content) => text.push_str(content), MessageSegment::Text(content) => text.push_str(content),
MessageSegment::Thinking(content) => { MessageSegment::Thinking { text: content, .. } => {
text.push_str(&format!("<think>{}</think>", content)) text.push_str(&format!("<think>{}</think>", content))
} }
MessageSegment::RedactedThinking(_) => {}
} }
} }
text.push('\n'); text.push('\n');
@ -894,8 +906,16 @@ impl Thread {
MessageSegment::Text(text) => { MessageSegment::Text(text) => {
SerializedMessageSegment::Text { text: text.clone() } SerializedMessageSegment::Text { text: text.clone() }
} }
MessageSegment::Thinking(text) => { MessageSegment::Thinking { text, signature } => {
SerializedMessageSegment::Thinking { text: text.clone() } SerializedMessageSegment::Thinking {
text: text.clone(),
signature: signature.clone(),
}
}
MessageSegment::RedactedThinking(data) => {
SerializedMessageSegment::RedactedThinking {
data: data.clone(),
}
} }
}) })
.collect(), .collect(),
@ -1038,10 +1058,35 @@ impl Thread {
} }
} }
if !message.segments.is_empty() { if !message.context.is_empty() {
request_message request_message
.content .content
.push(MessageContent::Text(message.to_string())); .push(MessageContent::Text(message.context.to_string()));
}
for segment in &message.segments {
match segment {
MessageSegment::Text(text) => {
if !text.is_empty() {
request_message
.content
.push(MessageContent::Text(text.into()));
}
}
MessageSegment::Thinking { text, signature } => {
if !text.is_empty() {
request_message.content.push(MessageContent::Thinking {
text: text.into(),
signature: signature.clone(),
});
}
}
MessageSegment::RedactedThinking(data) => {
request_message
.content
.push(MessageContent::RedactedThinking(data.clone()));
}
};
} }
match request_kind { match request_kind {
@ -1187,10 +1232,13 @@ impl Thread {
}; };
} }
} }
LanguageModelCompletionEvent::Thinking(chunk) => { LanguageModelCompletionEvent::Thinking {
text: chunk,
signature,
} => {
if let Some(last_message) = thread.messages.last_mut() { if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant { if last_message.role == Role::Assistant {
last_message.push_thinking(&chunk); last_message.push_thinking(&chunk, signature);
cx.emit(ThreadEvent::StreamedAssistantThinking( cx.emit(ThreadEvent::StreamedAssistantThinking(
last_message.id, last_message.id,
chunk, chunk,
@ -1203,7 +1251,10 @@ impl Thread {
// will result in duplicating the text of the chunk in the rendered Markdown. // will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message( thread.insert_message(
Role::Assistant, Role::Assistant,
vec![MessageSegment::Thinking(chunk.to_string())], vec![MessageSegment::Thinking {
text: chunk.to_string(),
signature,
}],
cx, cx,
); );
}; };
@ -1893,9 +1944,10 @@ impl Thread {
for segment in &message.segments { for segment in &message.segments {
match segment { match segment {
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
MessageSegment::Thinking(text) => { MessageSegment::Thinking { text, .. } => {
writeln!(markdown, "<think>{}</think>\n", text)? writeln!(markdown, "<think>\n{}\n</think>\n", text)?
} }
MessageSegment::RedactedThinking(_) => {}
} }
} }

View file

@ -660,9 +660,18 @@ pub struct SerializedMessage {
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum SerializedMessageSegment { pub enum SerializedMessageSegment {
#[serde(rename = "text")] #[serde(rename = "text")]
Text { text: String }, Text {
text: String,
},
#[serde(rename = "thinking")] #[serde(rename = "thinking")]
Thinking { text: String }, Thinking {
text: String,
#[serde(skip_serializing_if = "Option::is_none")]
signature: Option<String>,
},
RedactedThinking {
data: Vec<u8>,
},
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]

View file

@ -507,6 +507,15 @@ pub enum RequestContent {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>, cache_control: Option<CacheControl>,
}, },
#[serde(rename = "thinking")]
Thinking {
thinking: String,
signature: String,
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
#[serde(rename = "redacted_thinking")]
RedactedThinking { data: String },
#[serde(rename = "image")] #[serde(rename = "image")]
Image { Image {
source: ImageSource, source: ImageSource,

View file

@ -2373,7 +2373,7 @@ impl AssistantContext {
LanguageModelCompletionEvent::Stop(reason) => { LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason; stop_reason = reason;
} }
LanguageModelCompletionEvent::Thinking(chunk) => { LanguageModelCompletionEvent::Thinking { text: chunk, .. } => {
if thought_process_stack.is_empty() { if thought_process_stack.is_empty() {
let start = let start =
buffer.anchor_before(message_old_end_offset); buffer.anchor_before(message_old_end_offset);

View file

@ -916,6 +916,20 @@ impl RequestMarkdown {
MessageContent::Image(_) => { MessageContent::Image(_) => {
messages.push_str("[IMAGE DATA]\n\n"); messages.push_str("[IMAGE DATA]\n\n");
} }
MessageContent::Thinking { text, signature } => {
messages.push_str("**Thinking**:\n\n");
if let Some(sig) = signature {
messages.push_str(&format!("Signature: {}\n\n", sig));
}
messages.push_str(text);
messages.push_str("\n");
}
MessageContent::RedactedThinking(items) => {
messages.push_str(&format!(
"**Redacted Thinking**: {} item(s)\n\n",
items.len()
));
}
MessageContent::ToolUse(tool_use) => { MessageContent::ToolUse(tool_use) => {
messages.push_str(&format!( messages.push_str(&format!(
"**Tool Use**: {} (ID: {})\n", "**Tool Use**: {} (ID: {})\n",
@ -970,7 +984,7 @@ fn response_events_to_markdown(
Ok(LanguageModelCompletionEvent::Text(text)) => { Ok(LanguageModelCompletionEvent::Text(text)) => {
text_buffer.push_str(text); text_buffer.push_str(text);
} }
Ok(LanguageModelCompletionEvent::Thinking(text)) => { Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => {
thinking_buffer.push_str(text); thinking_buffer.push_str(text);
} }
Ok(LanguageModelCompletionEvent::Stop(reason)) => { Ok(LanguageModelCompletionEvent::Stop(reason)) => {

View file

@ -65,9 +65,14 @@ pub struct LanguageModelCacheConfiguration {
pub enum LanguageModelCompletionEvent { pub enum LanguageModelCompletionEvent {
Stop(StopReason), Stop(StopReason),
Text(String), Text(String),
Thinking(String), Thinking {
text: String,
signature: Option<String>,
},
ToolUse(LanguageModelToolUse), ToolUse(LanguageModelToolUse),
StartMessage { message_id: String }, StartMessage {
message_id: String,
},
UsageUpdate(TokenUsage), UsageUpdate(TokenUsage),
} }
@ -302,7 +307,7 @@ pub trait LanguageModel: Send + Sync {
match result { match result {
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None, Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)), Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::Thinking(_)) => None, Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
Ok(LanguageModelCompletionEvent::Stop(_)) => None, Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None, Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {

View file

@ -175,6 +175,11 @@ pub struct LanguageModelToolResult {
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent { pub enum MessageContent {
Text(String), Text(String),
Thinking {
text: String,
signature: Option<String>,
},
RedactedThinking(Vec<u8>),
Image(LanguageModelImage), Image(LanguageModelImage),
ToolUse(LanguageModelToolUse), ToolUse(LanguageModelToolUse),
ToolResult(LanguageModelToolResult), ToolResult(LanguageModelToolResult),
@ -204,6 +209,8 @@ impl LanguageModelRequestMessage {
let mut buffer = String::new(); let mut buffer = String::new();
for string in self.content.iter().filter_map(|content| match content { for string in self.content.iter().filter_map(|content| match content {
MessageContent::Text(text) => Some(text.as_str()), MessageContent::Text(text) => Some(text.as_str()),
MessageContent::Thinking { text, .. } => Some(text.as_str()),
MessageContent::RedactedThinking(_) => None,
MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()), MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()),
MessageContent::ToolUse(_) | MessageContent::Image(_) => None, MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
}) { }) {
@ -220,10 +227,15 @@ impl LanguageModelRequestMessage {
.first() .first()
.map(|content| match content { .map(|content| match content {
MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
MessageContent::Thinking { text, .. } => {
text.chars().all(|c| c.is_whitespace())
}
MessageContent::ToolResult(tool_result) => { MessageContent::ToolResult(tool_result) => {
tool_result.content.chars().all(|c| c.is_whitespace()) tool_result.content.chars().all(|c| c.is_whitespace())
} }
MessageContent::ToolUse(_) | MessageContent::Image(_) => true, MessageContent::RedactedThinking(_)
| MessageContent::ToolUse(_)
| MessageContent::Image(_) => true,
}) })
.unwrap_or(false) .unwrap_or(false)
} }

View file

@ -336,6 +336,12 @@ pub fn count_anthropic_tokens(
MessageContent::Text(text) => { MessageContent::Text(text) => {
string_contents.push_str(&text); string_contents.push_str(&text);
} }
MessageContent::Thinking { .. } => {
// Thinking blocks are not included in the input token count.
}
MessageContent::RedactedThinking(_) => {
// Thinking blocks are not included in the input token count.
}
MessageContent::Image(image) => { MessageContent::Image(image) => {
tokens_from_images += image.estimate_tokens(); tokens_from_images += image.estimate_tokens();
} }
@ -515,6 +521,29 @@ pub fn into_anthropic(
None None
} }
} }
MessageContent::Thinking {
text: thinking,
signature,
} => {
if !thinking.is_empty() {
Some(anthropic::RequestContent::Thinking {
thinking,
signature: signature.unwrap_or_default(),
cache_control,
})
} else {
None
}
}
MessageContent::RedactedThinking(data) => {
if !data.is_empty() {
Some(anthropic::RequestContent::RedactedThinking {
data: String::from_utf8(data).ok()?,
})
} else {
None
}
}
MessageContent::Image(image) => Some(anthropic::RequestContent::Image { MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
source: anthropic::ImageSource { source: anthropic::ImageSource {
source_type: "base64".to_string(), source_type: "base64".to_string(),
@ -637,7 +666,10 @@ pub fn map_to_language_model_completion_events(
} }
ResponseContent::Thinking { thinking } => { ResponseContent::Thinking { thinking } => {
return Some(( return Some((
vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))], vec![Ok(LanguageModelCompletionEvent::Thinking {
text: thinking,
signature: None,
})],
state, state,
)); ));
} }
@ -665,11 +697,22 @@ pub fn map_to_language_model_completion_events(
} }
ContentDelta::ThinkingDelta { thinking } => { ContentDelta::ThinkingDelta { thinking } => {
return Some(( return Some((
vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))], vec![Ok(LanguageModelCompletionEvent::Thinking {
text: thinking,
signature: None,
})],
state,
));
}
ContentDelta::SignatureDelta { signature } => {
return Some((
vec![Ok(LanguageModelCompletionEvent::Thinking {
text: "".to_string(),
signature: Some(signature),
})],
state, state,
)); ));
} }
ContentDelta::SignatureDelta { .. } => {}
ContentDelta::InputJsonDelta { partial_json } => { ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) { if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json); tool_use.input_json.push_str(&partial_json);

View file

@ -742,9 +742,10 @@ pub fn get_bedrock_tokens(
for content in message.content { for content in message.content {
match content { match content {
MessageContent::Text(text) => { MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
string_contents.push_str(&text); string_contents.push_str(&text);
} }
MessageContent::RedactedThinking(_) => {}
MessageContent::Image(image) => { MessageContent::Image(image) => {
tokens_from_images += image.estimate_tokens(); tokens_from_images += image.estimate_tokens();
} }
@ -830,25 +831,36 @@ pub fn map_to_language_model_completion_events(
redacted, redacted,
) => { ) => {
let thinking_event = let thinking_event =
LanguageModelCompletionEvent::Thinking( LanguageModelCompletionEvent::Thinking {
String::from_utf8( text: String::from_utf8(
redacted.into_inner(), redacted.into_inner(),
) )
.unwrap_or("REDACTED".to_string()), .unwrap_or("REDACTED".to_string()),
); signature: None,
};
return Some(( return Some((
Some(Ok(thinking_event)), Some(Ok(thinking_event)),
state, state,
)); ));
} }
ReasoningContentBlockDelta::Signature(_sig) => { ReasoningContentBlockDelta::Signature(
signature,
) => {
return Some((
Some(Ok(LanguageModelCompletionEvent::Thinking {
text: "".to_string(),
signature: Some(signature)
})),
state,
));
} }
ReasoningContentBlockDelta::Text(thoughts) => { ReasoningContentBlockDelta::Text(thoughts) => {
let thinking_event = let thinking_event =
LanguageModelCompletionEvent::Thinking( LanguageModelCompletionEvent::Thinking {
thoughts.to_string(), text: thoughts.to_string(),
); signature: None
};
return Some(( return Some((
Some(Ok(thinking_event)), Some(Ok(thinking_event)),

View file

@ -424,8 +424,11 @@ impl CopilotChatLanguageModel {
let text_content = { let text_content = {
let mut buffer = String::new(); let mut buffer = String::new();
for string in message.content.iter().filter_map(|content| match content { for string in message.content.iter().filter_map(|content| match content {
MessageContent::Text(text) => Some(text.as_str()), MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
Some(text.as_str())
}
MessageContent::ToolUse(_) MessageContent::ToolUse(_)
| MessageContent::RedactedThinking(_)
| MessageContent::ToolResult(_) | MessageContent::ToolResult(_)
| MessageContent::Image(_) => None, | MessageContent::Image(_) => None,
}) { }) {

View file

@ -368,13 +368,15 @@ pub fn into_google(
content content
.into_iter() .into_iter()
.filter_map(|content| match content { .filter_map(|content| match content {
language_model::MessageContent::Text(text) => { language_model::MessageContent::Text(text)
| language_model::MessageContent::Thinking { text, .. } => {
if !text.is_empty() { if !text.is_empty() {
Some(Part::TextPart(google_ai::TextPart { text })) Some(Part::TextPart(google_ai::TextPart { text }))
} else { } else {
None None
} }
} }
language_model::MessageContent::RedactedThinking(_) => None,
language_model::MessageContent::Image(_) => None, language_model::MessageContent::Image(_) => None,
language_model::MessageContent::ToolUse(tool_use) => { language_model::MessageContent::ToolUse(tool_use) => {
Some(Part::FunctionCallPart(google_ai::FunctionCallPart { Some(Part::FunctionCallPart(google_ai::FunctionCallPart {

View file

@ -342,14 +342,16 @@ pub fn into_open_ai(
for message in request.messages { for message in request.messages {
for content in message.content { for content in message.content {
match content { match content {
MessageContent::Text(text) => messages.push(match message.role { MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
Role::User => open_ai::RequestMessage::User { content: text }, .push(match message.role {
Role::Assistant => open_ai::RequestMessage::Assistant { Role::User => open_ai::RequestMessage::User { content: text },
content: Some(text), Role::Assistant => open_ai::RequestMessage::Assistant {
tool_calls: Vec::new(), content: Some(text),
}, tool_calls: Vec::new(),
Role::System => open_ai::RequestMessage::System { content: text }, },
}), Role::System => open_ai::RequestMessage::System { content: text },
}),
MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {} MessageContent::Image(_) => {}
MessageContent::ToolUse(tool_use) => { MessageContent::ToolUse(tool_use) => {
let tool_call = open_ai::ToolCall { let tool_call = open_ai::ToolCall {