agent: Preserve thinking blocks between requests (#29055)
Looks like the required backend component of this was deployed.

https://github.com/zed-industries/monorepo/actions/runs/14541199197

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
f737c4d01e
commit
bafc086d27
13 changed files with 236 additions and 68 deletions
|
@ -133,18 +133,23 @@ impl RenderedMessage {
|
|||
}
|
||||
|
||||
fn push_segment(&mut self, segment: &MessageSegment, cx: &mut App) {
|
||||
let rendered_segment = match segment {
|
||||
MessageSegment::Thinking(text) => RenderedMessageSegment::Thinking {
|
||||
match segment {
|
||||
MessageSegment::Thinking { text, .. } => {
|
||||
self.segments.push(RenderedMessageSegment::Thinking {
|
||||
content: parse_markdown(text.into(), self.language_registry.clone(), cx),
|
||||
scroll_handle: ScrollHandle::default(),
|
||||
},
|
||||
MessageSegment::Text(text) => RenderedMessageSegment::Text(parse_markdown(
|
||||
})
|
||||
}
|
||||
MessageSegment::Text(text) => {
|
||||
self.segments
|
||||
.push(RenderedMessageSegment::Text(parse_markdown(
|
||||
text.into(),
|
||||
self.language_registry.clone(),
|
||||
cx,
|
||||
)),
|
||||
)))
|
||||
}
|
||||
MessageSegment::RedactedThinking(_) => {}
|
||||
};
|
||||
self.segments.push(rendered_segment);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -113,12 +113,21 @@ impl Message {
|
|||
self.segments.iter().all(|segment| segment.should_display())
|
||||
}
|
||||
|
||||
pub fn push_thinking(&mut self, text: &str) {
|
||||
if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
|
||||
pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
|
||||
if let Some(MessageSegment::Thinking {
|
||||
text: segment,
|
||||
signature: current_signature,
|
||||
}) = self.segments.last_mut()
|
||||
{
|
||||
if let Some(signature) = signature {
|
||||
*current_signature = Some(signature);
|
||||
}
|
||||
segment.push_str(text);
|
||||
} else {
|
||||
self.segments
|
||||
.push(MessageSegment::Thinking(text.to_string()));
|
||||
self.segments.push(MessageSegment::Thinking {
|
||||
text: text.to_string(),
|
||||
signature,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,11 +149,12 @@ impl Message {
|
|||
for segment in &self.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => result.push_str(text),
|
||||
MessageSegment::Thinking(text) => {
|
||||
result.push_str("<think>");
|
||||
MessageSegment::Thinking { text, .. } => {
|
||||
result.push_str("<think>\n");
|
||||
result.push_str(text);
|
||||
result.push_str("</think>");
|
||||
result.push_str("\n</think>");
|
||||
}
|
||||
MessageSegment::RedactedThinking(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -155,24 +165,22 @@ impl Message {
|
|||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum MessageSegment {
|
||||
Text(String),
|
||||
Thinking(String),
|
||||
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
RedactedThinking(Vec<u8>),
|
||||
}
|
||||
|
||||
impl MessageSegment {
|
||||
pub fn text_mut(&mut self) -> &mut String {
|
||||
match self {
|
||||
Self::Text(text) => text,
|
||||
Self::Thinking(text) => text,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn should_display(&self) -> bool {
|
||||
// We add USING_TOOL_MARKER when making a request that includes tool uses
|
||||
// without non-whitespace text around them, and this can cause the model
|
||||
// to mimic the pattern, so we consider those segments not displayable.
|
||||
match self {
|
||||
Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
|
||||
Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
|
||||
Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
|
||||
Self::RedactedThinking(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -408,8 +416,11 @@ impl Thread {
|
|||
.into_iter()
|
||||
.map(|segment| match segment {
|
||||
SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
|
||||
SerializedMessageSegment::Thinking { text } => {
|
||||
MessageSegment::Thinking(text)
|
||||
SerializedMessageSegment::Thinking { text, signature } => {
|
||||
MessageSegment::Thinking { text, signature }
|
||||
}
|
||||
SerializedMessageSegment::RedactedThinking { data } => {
|
||||
MessageSegment::RedactedThinking(data)
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
|
@ -862,9 +873,10 @@ impl Thread {
|
|||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(content) => text.push_str(content),
|
||||
MessageSegment::Thinking(content) => {
|
||||
MessageSegment::Thinking { text: content, .. } => {
|
||||
text.push_str(&format!("<think>{}</think>", content))
|
||||
}
|
||||
MessageSegment::RedactedThinking(_) => {}
|
||||
}
|
||||
}
|
||||
text.push('\n');
|
||||
|
@ -894,8 +906,16 @@ impl Thread {
|
|||
MessageSegment::Text(text) => {
|
||||
SerializedMessageSegment::Text { text: text.clone() }
|
||||
}
|
||||
MessageSegment::Thinking(text) => {
|
||||
SerializedMessageSegment::Thinking { text: text.clone() }
|
||||
MessageSegment::Thinking { text, signature } => {
|
||||
SerializedMessageSegment::Thinking {
|
||||
text: text.clone(),
|
||||
signature: signature.clone(),
|
||||
}
|
||||
}
|
||||
MessageSegment::RedactedThinking(data) => {
|
||||
SerializedMessageSegment::RedactedThinking {
|
||||
data: data.clone(),
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
|
@ -1038,10 +1058,35 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
if !message.segments.is_empty() {
|
||||
if !message.context.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(message.to_string()));
|
||||
.push(MessageContent::Text(message.context.to_string()));
|
||||
}
|
||||
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(text.into()));
|
||||
}
|
||||
}
|
||||
MessageSegment::Thinking { text, signature } => {
|
||||
if !text.is_empty() {
|
||||
request_message.content.push(MessageContent::Thinking {
|
||||
text: text.into(),
|
||||
signature: signature.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageSegment::RedactedThinking(data) => {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::RedactedThinking(data.clone()));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
match request_kind {
|
||||
|
@ -1187,10 +1232,13 @@ impl Thread {
|
|||
};
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::Thinking(chunk) => {
|
||||
LanguageModelCompletionEvent::Thinking {
|
||||
text: chunk,
|
||||
signature,
|
||||
} => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.push_thinking(&chunk);
|
||||
last_message.push_thinking(&chunk, signature);
|
||||
cx.emit(ThreadEvent::StreamedAssistantThinking(
|
||||
last_message.id,
|
||||
chunk,
|
||||
|
@ -1203,7 +1251,10 @@ impl Thread {
|
|||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Thinking(chunk.to_string())],
|
||||
vec![MessageSegment::Thinking {
|
||||
text: chunk.to_string(),
|
||||
signature,
|
||||
}],
|
||||
cx,
|
||||
);
|
||||
};
|
||||
|
@ -1893,9 +1944,10 @@ impl Thread {
|
|||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
|
||||
MessageSegment::Thinking(text) => {
|
||||
writeln!(markdown, "<think>{}</think>\n", text)?
|
||||
MessageSegment::Thinking { text, .. } => {
|
||||
writeln!(markdown, "<think>\n{}\n</think>\n", text)?
|
||||
}
|
||||
MessageSegment::RedactedThinking(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -660,9 +660,18 @@ pub struct SerializedMessage {
|
|||
#[serde(tag = "type")]
|
||||
pub enum SerializedMessageSegment {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
#[serde(rename = "thinking")]
|
||||
Thinking { text: String },
|
||||
Thinking {
|
||||
text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
signature: Option<String>,
|
||||
},
|
||||
RedactedThinking {
|
||||
data: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
|
|
@ -507,6 +507,15 @@ pub enum RequestContent {
|
|||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<CacheControl>,
|
||||
},
|
||||
#[serde(rename = "thinking")]
|
||||
Thinking {
|
||||
thinking: String,
|
||||
signature: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<CacheControl>,
|
||||
},
|
||||
#[serde(rename = "redacted_thinking")]
|
||||
RedactedThinking { data: String },
|
||||
#[serde(rename = "image")]
|
||||
Image {
|
||||
source: ImageSource,
|
||||
|
|
|
@ -2373,7 +2373,7 @@ impl AssistantContext {
|
|||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
stop_reason = reason;
|
||||
}
|
||||
LanguageModelCompletionEvent::Thinking(chunk) => {
|
||||
LanguageModelCompletionEvent::Thinking { text: chunk, .. } => {
|
||||
if thought_process_stack.is_empty() {
|
||||
let start =
|
||||
buffer.anchor_before(message_old_end_offset);
|
||||
|
|
|
@ -916,6 +916,20 @@ impl RequestMarkdown {
|
|||
MessageContent::Image(_) => {
|
||||
messages.push_str("[IMAGE DATA]\n\n");
|
||||
}
|
||||
MessageContent::Thinking { text, signature } => {
|
||||
messages.push_str("**Thinking**:\n\n");
|
||||
if let Some(sig) = signature {
|
||||
messages.push_str(&format!("Signature: {}\n\n", sig));
|
||||
}
|
||||
messages.push_str(text);
|
||||
messages.push_str("\n");
|
||||
}
|
||||
MessageContent::RedactedThinking(items) => {
|
||||
messages.push_str(&format!(
|
||||
"**Redacted Thinking**: {} item(s)\n\n",
|
||||
items.len()
|
||||
));
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
messages.push_str(&format!(
|
||||
"**Tool Use**: {} (ID: {})\n",
|
||||
|
@ -970,7 +984,7 @@ fn response_events_to_markdown(
|
|||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
text_buffer.push_str(text);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Thinking(text)) => {
|
||||
Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => {
|
||||
thinking_buffer.push_str(text);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||
|
|
|
@ -65,9 +65,14 @@ pub struct LanguageModelCacheConfiguration {
|
|||
pub enum LanguageModelCompletionEvent {
|
||||
Stop(StopReason),
|
||||
Text(String),
|
||||
Thinking(String),
|
||||
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
ToolUse(LanguageModelToolUse),
|
||||
StartMessage { message_id: String },
|
||||
StartMessage {
|
||||
message_id: String,
|
||||
},
|
||||
UsageUpdate(TokenUsage),
|
||||
}
|
||||
|
||||
|
@ -302,7 +307,7 @@ pub trait LanguageModel: Send + Sync {
|
|||
match result {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
||||
Ok(LanguageModelCompletionEvent::Thinking(_)) => None,
|
||||
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
|
||||
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
|
|
|
@ -175,6 +175,11 @@ pub struct LanguageModelToolResult {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
pub enum MessageContent {
|
||||
Text(String),
|
||||
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
RedactedThinking(Vec<u8>),
|
||||
Image(LanguageModelImage),
|
||||
ToolUse(LanguageModelToolUse),
|
||||
ToolResult(LanguageModelToolResult),
|
||||
|
@ -204,6 +209,8 @@ impl LanguageModelRequestMessage {
|
|||
let mut buffer = String::new();
|
||||
for string in self.content.iter().filter_map(|content| match content {
|
||||
MessageContent::Text(text) => Some(text.as_str()),
|
||||
MessageContent::Thinking { text, .. } => Some(text.as_str()),
|
||||
MessageContent::RedactedThinking(_) => None,
|
||||
MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()),
|
||||
MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
|
||||
}) {
|
||||
|
@ -220,10 +227,15 @@ impl LanguageModelRequestMessage {
|
|||
.first()
|
||||
.map(|content| match content {
|
||||
MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
|
||||
MessageContent::Thinking { text, .. } => {
|
||||
text.chars().all(|c| c.is_whitespace())
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
tool_result.content.chars().all(|c| c.is_whitespace())
|
||||
}
|
||||
MessageContent::ToolUse(_) | MessageContent::Image(_) => true,
|
||||
MessageContent::RedactedThinking(_)
|
||||
| MessageContent::ToolUse(_)
|
||||
| MessageContent::Image(_) => true,
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
|
|
@ -336,6 +336,12 @@ pub fn count_anthropic_tokens(
|
|||
MessageContent::Text(text) => {
|
||||
string_contents.push_str(&text);
|
||||
}
|
||||
MessageContent::Thinking { .. } => {
|
||||
// Thinking blocks are not included in the input token count.
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {
|
||||
// Thinking blocks are not included in the input token count.
|
||||
}
|
||||
MessageContent::Image(image) => {
|
||||
tokens_from_images += image.estimate_tokens();
|
||||
}
|
||||
|
@ -515,6 +521,29 @@ pub fn into_anthropic(
|
|||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Thinking {
|
||||
text: thinking,
|
||||
signature,
|
||||
} => {
|
||||
if !thinking.is_empty() {
|
||||
Some(anthropic::RequestContent::Thinking {
|
||||
thinking,
|
||||
signature: signature.unwrap_or_default(),
|
||||
cache_control,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::RedactedThinking(data) => {
|
||||
if !data.is_empty() {
|
||||
Some(anthropic::RequestContent::RedactedThinking {
|
||||
data: String::from_utf8(data).ok()?,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
|
||||
source: anthropic::ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
|
@ -637,7 +666,10 @@ pub fn map_to_language_model_completion_events(
|
|||
}
|
||||
ResponseContent::Thinking { thinking } => {
|
||||
return Some((
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))],
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: thinking,
|
||||
signature: None,
|
||||
})],
|
||||
state,
|
||||
));
|
||||
}
|
||||
|
@ -665,11 +697,22 @@ pub fn map_to_language_model_completion_events(
|
|||
}
|
||||
ContentDelta::ThinkingDelta { thinking } => {
|
||||
return Some((
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))],
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: thinking,
|
||||
signature: None,
|
||||
})],
|
||||
state,
|
||||
));
|
||||
}
|
||||
ContentDelta::SignatureDelta { signature } => {
|
||||
return Some((
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "".to_string(),
|
||||
signature: Some(signature),
|
||||
})],
|
||||
state,
|
||||
));
|
||||
}
|
||||
ContentDelta::SignatureDelta { .. } => {}
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
|
||||
tool_use.input_json.push_str(&partial_json);
|
||||
|
|
|
@ -742,9 +742,10 @@ pub fn get_bedrock_tokens(
|
|||
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
string_contents.push_str(&text);
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(image) => {
|
||||
tokens_from_images += image.estimate_tokens();
|
||||
}
|
||||
|
@ -830,25 +831,36 @@ pub fn map_to_language_model_completion_events(
|
|||
redacted,
|
||||
) => {
|
||||
let thinking_event =
|
||||
LanguageModelCompletionEvent::Thinking(
|
||||
String::from_utf8(
|
||||
LanguageModelCompletionEvent::Thinking {
|
||||
text: String::from_utf8(
|
||||
redacted.into_inner(),
|
||||
)
|
||||
.unwrap_or("REDACTED".to_string()),
|
||||
);
|
||||
signature: None,
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(thinking_event)),
|
||||
state,
|
||||
));
|
||||
}
|
||||
ReasoningContentBlockDelta::Signature(_sig) => {
|
||||
ReasoningContentBlockDelta::Signature(
|
||||
signature,
|
||||
) => {
|
||||
return Some((
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "".to_string(),
|
||||
signature: Some(signature)
|
||||
})),
|
||||
state,
|
||||
));
|
||||
}
|
||||
ReasoningContentBlockDelta::Text(thoughts) => {
|
||||
let thinking_event =
|
||||
LanguageModelCompletionEvent::Thinking(
|
||||
thoughts.to_string(),
|
||||
);
|
||||
LanguageModelCompletionEvent::Thinking {
|
||||
text: thoughts.to_string(),
|
||||
signature: None
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(thinking_event)),
|
||||
|
|
|
@ -424,8 +424,11 @@ impl CopilotChatLanguageModel {
|
|||
let text_content = {
|
||||
let mut buffer = String::new();
|
||||
for string in message.content.iter().filter_map(|content| match content {
|
||||
MessageContent::Text(text) => Some(text.as_str()),
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
Some(text.as_str())
|
||||
}
|
||||
MessageContent::ToolUse(_)
|
||||
| MessageContent::RedactedThinking(_)
|
||||
| MessageContent::ToolResult(_)
|
||||
| MessageContent::Image(_) => None,
|
||||
}) {
|
||||
|
|
|
@ -368,13 +368,15 @@ pub fn into_google(
|
|||
content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
language_model::MessageContent::Text(text) => {
|
||||
language_model::MessageContent::Text(text)
|
||||
| language_model::MessageContent::Thinking { text, .. } => {
|
||||
if !text.is_empty() {
|
||||
Some(Part::TextPart(google_ai::TextPart { text }))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
language_model::MessageContent::RedactedThinking(_) => None,
|
||||
language_model::MessageContent::Image(_) => None,
|
||||
language_model::MessageContent::ToolUse(tool_use) => {
|
||||
Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
|
||||
|
|
|
@ -342,7 +342,8 @@ pub fn into_open_ai(
|
|||
for message in request.messages {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => messages.push(match message.role {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
|
||||
.push(match message.role {
|
||||
Role::User => open_ai::RequestMessage::User { content: text },
|
||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
||||
content: Some(text),
|
||||
|
@ -350,6 +351,7 @@ pub fn into_open_ai(
|
|||
},
|
||||
Role::System => open_ai::RequestMessage::System { content: text },
|
||||
}),
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_) => {}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
let tool_call = open_ai::ToolCall {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue