assistant: Add support for claude-3-7-sonnet-thinking
(#27085)
Closes #25671 Release Notes: - Added support for `claude-3-7-sonnet-thinking` in the assistant panel --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
parent
2ffce4f516
commit
a709d4c7c6
16 changed files with 1212 additions and 177 deletions
|
@ -29,7 +29,8 @@ use uuid::Uuid;
|
|||
|
||||
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
|
||||
use crate::thread_store::{
|
||||
SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
|
||||
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
|
||||
SerializedToolUse,
|
||||
};
|
||||
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
|
||||
|
||||
|
@ -69,7 +70,47 @@ impl MessageId {
|
|||
pub struct Message {
|
||||
pub id: MessageId,
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
pub segments: Vec<MessageSegment>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn push_thinking(&mut self, text: &str) {
|
||||
if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
|
||||
segment.push_str(text);
|
||||
} else {
|
||||
self.segments
|
||||
.push(MessageSegment::Thinking(text.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_text(&mut self, text: &str) {
|
||||
if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
|
||||
segment.push_str(text);
|
||||
} else {
|
||||
self.segments.push(MessageSegment::Text(text.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> String {
|
||||
let mut result = String::new();
|
||||
for segment in &self.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => result.push_str(text),
|
||||
MessageSegment::Thinking(text) => {
|
||||
result.push_str("<think>");
|
||||
result.push_str(text);
|
||||
result.push_str("</think>");
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MessageSegment {
|
||||
Text(String),
|
||||
Thinking(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
@ -226,7 +267,16 @@ impl Thread {
|
|||
.map(|message| Message {
|
||||
id: message.id,
|
||||
role: message.role,
|
||||
text: message.text,
|
||||
segments: message
|
||||
.segments
|
||||
.into_iter()
|
||||
.map(|segment| match segment {
|
||||
SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
|
||||
SerializedMessageSegment::Thinking { text } => {
|
||||
MessageSegment::Thinking(text)
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect(),
|
||||
next_message_id,
|
||||
|
@ -419,7 +469,8 @@ impl Thread {
|
|||
checkpoint: Option<GitStoreCheckpoint>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
let message_id = self.insert_message(Role::User, text, cx);
|
||||
let message_id =
|
||||
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
|
||||
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
|
||||
self.context
|
||||
.extend(context.into_iter().map(|context| (context.id, context)));
|
||||
|
@ -433,15 +484,11 @@ impl Thread {
|
|||
pub fn insert_message(
|
||||
&mut self,
|
||||
role: Role,
|
||||
text: impl Into<String>,
|
||||
segments: Vec<MessageSegment>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
let id = self.next_message_id.post_inc();
|
||||
self.messages.push(Message {
|
||||
id,
|
||||
role,
|
||||
text: text.into(),
|
||||
});
|
||||
self.messages.push(Message { id, role, segments });
|
||||
self.touch_updated_at();
|
||||
cx.emit(ThreadEvent::MessageAdded(id));
|
||||
id
|
||||
|
@ -451,14 +498,14 @@ impl Thread {
|
|||
&mut self,
|
||||
id: MessageId,
|
||||
new_role: Role,
|
||||
new_text: String,
|
||||
new_segments: Vec<MessageSegment>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
|
||||
return false;
|
||||
};
|
||||
message.role = new_role;
|
||||
message.text = new_text;
|
||||
message.segments = new_segments;
|
||||
self.touch_updated_at();
|
||||
cx.emit(ThreadEvent::MessageEdited(id));
|
||||
true
|
||||
|
@ -489,7 +536,14 @@ impl Thread {
|
|||
});
|
||||
text.push('\n');
|
||||
|
||||
text.push_str(&message.text);
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(content) => text.push_str(content),
|
||||
MessageSegment::Thinking(content) => {
|
||||
text.push_str(&format!("<think>{}</think>", content))
|
||||
}
|
||||
}
|
||||
}
|
||||
text.push('\n');
|
||||
}
|
||||
|
||||
|
@ -502,6 +556,7 @@ impl Thread {
|
|||
cx.spawn(async move |this, cx| {
|
||||
let initial_project_snapshot = initial_project_snapshot.await;
|
||||
this.read_with(cx, |this, cx| SerializedThread {
|
||||
version: SerializedThread::VERSION.to_string(),
|
||||
summary: this.summary_or_default(),
|
||||
updated_at: this.updated_at(),
|
||||
messages: this
|
||||
|
@ -509,7 +564,18 @@ impl Thread {
|
|||
.map(|message| SerializedMessage {
|
||||
id: message.id,
|
||||
role: message.role,
|
||||
text: message.text.clone(),
|
||||
segments: message
|
||||
.segments
|
||||
.iter()
|
||||
.map(|segment| match segment {
|
||||
MessageSegment::Text(text) => {
|
||||
SerializedMessageSegment::Text { text: text.clone() }
|
||||
}
|
||||
MessageSegment::Thinking(text) => {
|
||||
SerializedMessageSegment::Thinking { text: text.clone() }
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
tool_uses: this
|
||||
.tool_uses_for_message(message.id, cx)
|
||||
.into_iter()
|
||||
|
@ -733,10 +799,10 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
if !message.text.is_empty() {
|
||||
if !message.segments.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(message.text.clone()));
|
||||
.push(MessageContent::Text(message.to_string()));
|
||||
}
|
||||
|
||||
match request_kind {
|
||||
|
@ -826,7 +892,11 @@ impl Thread {
|
|||
thread.update(cx, |thread, cx| {
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
thread.insert_message(Role::Assistant, String::new(), cx);
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Text(String::new())],
|
||||
cx,
|
||||
);
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
stop_reason = reason;
|
||||
|
@ -840,7 +910,7 @@ impl Thread {
|
|||
LanguageModelCompletionEvent::Text(chunk) => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.text.push_str(&chunk);
|
||||
last_message.push_text(&chunk);
|
||||
cx.emit(ThreadEvent::StreamedAssistantText(
|
||||
last_message.id,
|
||||
chunk,
|
||||
|
@ -851,7 +921,33 @@ impl Thread {
|
|||
//
|
||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(Role::Assistant, chunk, cx);
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Text(chunk.to_string())],
|
||||
cx,
|
||||
);
|
||||
};
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::Thinking(chunk) => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.push_thinking(&chunk);
|
||||
cx.emit(ThreadEvent::StreamedAssistantThinking(
|
||||
last_message.id,
|
||||
chunk,
|
||||
));
|
||||
} else {
|
||||
// If we won't have an Assistant message yet, assume this chunk marks the beginning
|
||||
// of a new Assistant response.
|
||||
//
|
||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Thinking(chunk.to_string())],
|
||||
cx,
|
||||
);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -1357,7 +1453,14 @@ impl Thread {
|
|||
Role::System => "System",
|
||||
}
|
||||
)?;
|
||||
writeln!(markdown, "{}\n", message.text)?;
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
|
||||
MessageSegment::Thinking(text) => {
|
||||
writeln!(markdown, "<think>{}</think>\n", text)?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for tool_use in self.tool_uses_for_message(message.id, cx) {
|
||||
writeln!(
|
||||
|
@ -1416,6 +1519,7 @@ pub enum ThreadEvent {
|
|||
ShowError(ThreadError),
|
||||
StreamedCompletion,
|
||||
StreamedAssistantText(MessageId, String),
|
||||
StreamedAssistantThinking(MessageId, String),
|
||||
DoneStreaming,
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue