assistant: Add support for claude-3-7-sonnet-thinking (#27085)

Closes #25671

Release Notes:

- Added support for `claude-3-7-sonnet-thinking` in the assistant panel

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
Bennet Bo Fenner 2025-03-21 13:29:07 +01:00 committed by GitHub
parent 2ffce4f516
commit a709d4c7c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1212 additions and 177 deletions

View file

@ -1,5 +1,6 @@
use crate::thread::{
LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ThreadFeedback,
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
@ -7,10 +8,10 @@ use crate::ui::ContextPill;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
use gpui::{
list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App,
ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment,
ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement,
Transformation, UnderlineStyle, WeakEntity,
linear_color_stop, linear_gradient, list, percentage, pulsating_between, AbsoluteLength,
Animation, AnimationExt, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, ScrollHandle, StyleRefinement,
Subscription, Task, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@ -35,15 +36,175 @@ pub struct ActiveThread {
save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>,
list_state: ListState,
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
rendered_tool_use_labels: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>,
}
struct RenderedMessage {
language_registry: Arc<LanguageRegistry>,
segments: Vec<RenderedMessageSegment>,
}
impl RenderedMessage {
fn from_segments(
segments: &[MessageSegment],
language_registry: Arc<LanguageRegistry>,
window: &Window,
cx: &mut App,
) -> Self {
let mut this = Self {
language_registry,
segments: Vec::with_capacity(segments.len()),
};
for segment in segments {
this.push_segment(segment, window, cx);
}
this
}
fn append_thinking(&mut self, text: &String, window: &Window, cx: &mut App) {
if let Some(RenderedMessageSegment::Thinking {
content,
scroll_handle,
}) = self.segments.last_mut()
{
content.update(cx, |markdown, cx| {
markdown.append(text, cx);
});
scroll_handle.scroll_to_bottom();
} else {
self.segments.push(RenderedMessageSegment::Thinking {
content: render_markdown(text.into(), self.language_registry.clone(), window, cx),
scroll_handle: ScrollHandle::default(),
});
}
}
fn append_text(&mut self, text: &String, window: &Window, cx: &mut App) {
if let Some(RenderedMessageSegment::Text(markdown)) = self.segments.last_mut() {
markdown.update(cx, |markdown, cx| markdown.append(text, cx));
} else {
self.segments
.push(RenderedMessageSegment::Text(render_markdown(
SharedString::from(text),
self.language_registry.clone(),
window,
cx,
)));
}
}
fn push_segment(&mut self, segment: &MessageSegment, window: &Window, cx: &mut App) {
let rendered_segment = match segment {
MessageSegment::Thinking(text) => RenderedMessageSegment::Thinking {
content: render_markdown(text.into(), self.language_registry.clone(), window, cx),
scroll_handle: ScrollHandle::default(),
},
MessageSegment::Text(text) => RenderedMessageSegment::Text(render_markdown(
text.into(),
self.language_registry.clone(),
window,
cx,
)),
};
self.segments.push(rendered_segment);
}
}
enum RenderedMessageSegment {
Thinking {
content: Entity<Markdown>,
scroll_handle: ScrollHandle,
},
Text(Entity<Markdown>),
}
fn render_markdown(
text: SharedString,
language_registry: Arc<LanguageRegistry>,
window: &Window,
cx: &mut App,
) -> Entity<Markdown> {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = TextSize::Small.rems(cx);
let mut text_style = window.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
code_block: StyleRefinement {
margin: EdgesRefinement {
top: Some(Length::Definite(rems(0.).into())),
left: Some(Length::Definite(rems(0.).into())),
right: Some(Length::Definite(rems(0.).into())),
bottom: Some(Length::Definite(rems(0.5).into())),
},
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
},
background: Some(colors.editor_background.into()),
border_color: Some(colors.border_variant),
border_widths: EdgesRefinement {
top: Some(AbsoluteLength::Pixels(Pixels(1.))),
left: Some(AbsoluteLength::Pixels(Pixels(1.))),
right: Some(AbsoluteLength::Pixels(Pixels(1.))),
bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
},
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
background_color: Some(colors.editor_foreground.opacity(0.1)),
..Default::default()
},
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {
color: Some(colors.text_accent.opacity(0.5)),
thickness: px(1.),
..Default::default()
}),
..Default::default()
},
..Default::default()
};
cx.new(|cx| Markdown::new(text, markdown_style, Some(language_registry), None, cx))
}
struct EditMessageState {
editor: Entity<Editor>,
}
@ -75,6 +236,7 @@ impl ActiveThread {
rendered_scripting_tool_uses: HashMap::default(),
rendered_tool_use_labels: HashMap::default(),
expanded_tool_uses: HashMap::default(),
expanded_thinking_segments: HashMap::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade();
move |ix, window: &mut Window, cx: &mut App| {
@ -88,7 +250,7 @@ impl ActiveThread {
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
this.push_message(&message.id, message.text.clone(), window, cx);
this.push_message(&message.id, &message.segments, window, cx);
for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
this.render_tool_use_label_markdown(
@ -156,7 +318,7 @@ impl ActiveThread {
fn push_message(
&mut self,
id: &MessageId,
text: String,
segments: &[MessageSegment],
window: &mut Window,
cx: &mut Context<Self>,
) {
@ -164,8 +326,9 @@ impl ActiveThread {
self.messages.push(*id);
self.list_state.splice(old_len..old_len, 1);
let markdown = self.render_markdown(text.into(), window, cx);
self.rendered_messages_by_id.insert(*id, markdown);
let rendered_message =
RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx);
self.rendered_messages_by_id.insert(*id, rendered_message);
self.list_state.scroll_to(ListOffset {
item_ix: old_len,
offset_in_item: Pixels(0.0),
@ -175,7 +338,7 @@ impl ActiveThread {
fn edited_message(
&mut self,
id: &MessageId,
text: String,
segments: &[MessageSegment],
window: &mut Window,
cx: &mut Context<Self>,
) {
@ -183,8 +346,9 @@ impl ActiveThread {
return;
};
self.list_state.splice(index..index + 1, 1);
let markdown = self.render_markdown(text.into(), window, cx);
self.rendered_messages_by_id.insert(*id, markdown);
let rendered_message =
RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx);
self.rendered_messages_by_id.insert(*id, rendered_message);
}
fn deleted_message(&mut self, id: &MessageId) {
@ -196,94 +360,6 @@ impl ActiveThread {
self.rendered_messages_by_id.remove(id);
}
fn render_markdown(
&self,
text: SharedString,
window: &Window,
cx: &mut Context<Self>,
) -> Entity<Markdown> {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = TextSize::Small.rems(cx);
let mut text_style = window.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
code_block: StyleRefinement {
margin: EdgesRefinement {
top: Some(Length::Definite(rems(0.).into())),
left: Some(Length::Definite(rems(0.).into())),
right: Some(Length::Definite(rems(0.).into())),
bottom: Some(Length::Definite(rems(0.5).into())),
},
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
},
background: Some(colors.editor_background.into()),
border_color: Some(colors.border_variant),
border_widths: EdgesRefinement {
top: Some(AbsoluteLength::Pixels(Pixels(1.))),
left: Some(AbsoluteLength::Pixels(Pixels(1.))),
right: Some(AbsoluteLength::Pixels(Pixels(1.))),
bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
},
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
background_color: Some(colors.editor_foreground.opacity(0.1)),
..Default::default()
},
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {
color: Some(colors.text_accent.opacity(0.5)),
thickness: px(1.),
..Default::default()
}),
..Default::default()
},
..Default::default()
};
cx.new(|cx| {
Markdown::new(
text,
markdown_style,
Some(self.language_registry.clone()),
None,
cx,
)
})
}
/// Renders the input of a scripting tool use to Markdown.
///
/// Does nothing if the tool use does not correspond to the scripting tool.
@ -303,8 +379,12 @@ impl ActiveThread {
.map(|input| input.lua_script)
.unwrap_or_default();
let lua_script =
self.render_markdown(format!("```lua\n{lua_script}\n```").into(), window, cx);
let lua_script = render_markdown(
format!("```lua\n{lua_script}\n```").into(),
self.language_registry.clone(),
window,
cx,
);
self.rendered_scripting_tool_uses
.insert(tool_use_id, lua_script);
@ -319,7 +399,12 @@ impl ActiveThread {
) {
self.rendered_tool_use_labels.insert(
tool_use_id,
self.render_markdown(tool_label.into(), window, cx),
render_markdown(
tool_label.into(),
self.language_registry.clone(),
window,
cx,
),
);
}
@ -339,33 +424,36 @@ impl ActiveThread {
}
ThreadEvent::DoneStreaming => {}
ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
markdown.update(cx, |markdown, cx| {
markdown.append(text, cx);
});
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
rendered_message.append_text(text, window, cx);
}
}
ThreadEvent::StreamedAssistantThinking(message_id, text) => {
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
rendered_message.append_thinking(text, window, cx);
}
}
ThreadEvent::MessageAdded(message_id) => {
if let Some(message_text) = self
if let Some(message_segments) = self
.thread
.read(cx)
.message(*message_id)
.map(|message| message.text.clone())
.map(|message| message.segments.clone())
{
self.push_message(message_id, message_text, window, cx);
self.push_message(message_id, &message_segments, window, cx);
}
self.save_thread(cx);
cx.notify();
}
ThreadEvent::MessageEdited(message_id) => {
if let Some(message_text) = self
if let Some(message_segments) = self
.thread
.read(cx)
.message(*message_id)
.map(|message| message.text.clone())
.map(|message| message.segments.clone())
{
self.edited_message(message_id, message_text, window, cx);
self.edited_message(message_id, &message_segments, window, cx);
}
self.save_thread(cx);
@ -490,10 +578,16 @@ impl ActiveThread {
fn start_editing_message(
&mut self,
message_id: MessageId,
message_text: String,
message_segments: &[MessageSegment],
window: &mut Window,
cx: &mut Context<Self>,
) {
// User message should always consist of a single text segment,
// therefore we can skip returning early if it's not a text segment.
let Some(MessageSegment::Text(message_text)) = message_segments.first() else {
return;
};
let buffer = cx.new(|cx| {
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
});
@ -534,7 +628,12 @@ impl ActiveThread {
};
let edited_text = state.editor.read(cx).text(cx);
self.thread.update(cx, |thread, cx| {
thread.edit_message(message_id, Role::User, edited_text, cx);
thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
cx,
);
for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
@ -617,7 +716,7 @@ impl ActiveThread {
return Empty.into_any();
};
let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any();
};
@ -759,7 +858,10 @@ impl ActiveThread {
.min_h_6()
.child(edit_message_editor)
} else {
div().min_h_6().text_ui(cx).child(markdown.clone())
div()
.min_h_6()
.text_ui(cx)
.child(self.render_message_content(message_id, rendered_message, cx))
},
)
.when_some(context, |parent, context| {
@ -869,11 +971,12 @@ impl ActiveThread {
Button::new("edit-message", "Edit")
.label_size(LabelSize::Small)
.on_click(cx.listener({
let message_text = message.text.clone();
let message_segments =
message.segments.clone();
move |this, _, window, cx| {
this.start_editing_message(
message_id,
message_text.clone(),
&message_segments,
window,
cx,
);
@ -995,6 +1098,190 @@ impl ActiveThread {
.into_any()
}
fn render_message_content(
&self,
message_id: MessageId,
rendered_message: &RenderedMessage,
cx: &Context<Self>,
) -> impl IntoElement {
let pending_thinking_segment_index = rendered_message
.segments
.iter()
.enumerate()
.last()
.filter(|(_, segment)| matches!(segment, RenderedMessageSegment::Thinking { .. }))
.map(|(index, _)| index);
div()
.text_ui(cx)
.gap_2()
.children(
rendered_message.segments.iter().enumerate().map(
|(index, segment)| match segment {
RenderedMessageSegment::Thinking {
content,
scroll_handle,
} => self
.render_message_thinking_segment(
message_id,
index,
content.clone(),
&scroll_handle,
Some(index) == pending_thinking_segment_index,
cx,
)
.into_any_element(),
RenderedMessageSegment::Text(markdown) => {
div().p_2p5().child(markdown.clone()).into_any_element()
}
},
),
)
}
fn render_message_thinking_segment(
&self,
message_id: MessageId,
ix: usize,
markdown: Entity<Markdown>,
scroll_handle: &ScrollHandle,
pending: bool,
cx: &Context<Self>,
) -> impl IntoElement {
let is_open = self
.expanded_thinking_segments
.get(&(message_id, ix))
.copied()
.unwrap_or_default();
let lighter_border = cx.theme().colors().border.opacity(0.5);
let editor_bg = cx.theme().colors().editor_background;
v_flex()
.rounded_lg()
.border_1()
.border_color(lighter_border)
.child(
h_flex()
.justify_between()
.py_1()
.pl_1()
.pr_2()
.bg(cx.theme().colors().editor_foreground.opacity(0.025))
.map(|this| {
if is_open {
this.rounded_t_md()
.border_b_1()
.border_color(lighter_border)
} else {
this.rounded_md()
}
})
.child(
h_flex()
.gap_1()
.child(Disclosure::new("thinking-disclosure", is_open).on_click(
cx.listener({
move |this, _event, _window, _cx| {
let is_open = this
.expanded_thinking_segments
.entry((message_id, ix))
.or_insert(false);
*is_open = !*is_open;
}
}),
))
.child({
if pending {
Label::new("Thinking…")
.size(LabelSize::Small)
.buffer_font(cx)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 0.8)),
|label, delta| label.alpha(delta),
)
.into_any_element()
} else {
Label::new("Thought Process")
.size(LabelSize::Small)
.buffer_font(cx)
.into_any_element()
}
}),
)
.child({
let (icon_name, color, animated) = if pending {
(IconName::ArrowCircle, Color::Accent, true)
} else {
(IconName::Check, Color::Success, false)
};
let icon = Icon::new(icon_name).color(color).size(IconSize::Small);
if animated {
icon.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(delta)))
},
)
.into_any_element()
} else {
icon.into_any_element()
}
}),
)
.when(pending && !is_open, |this| {
let gradient_overlay = div()
.rounded_b_lg()
.h_20()
.absolute()
.w_full()
.bottom_0()
.left_0()
.bg(linear_gradient(
180.,
linear_color_stop(editor_bg, 1.),
linear_color_stop(editor_bg.opacity(0.2), 0.),
));
this.child(
div()
.relative()
.bg(editor_bg)
.rounded_b_lg()
.text_ui_sm(cx)
.child(
div()
.id(("thinking-content", ix))
.p_2()
.h_20()
.track_scroll(scroll_handle)
.child(markdown.clone())
.overflow_hidden(),
)
.child(gradient_overlay),
)
})
.when(is_open, |this| {
this.child(
div()
.id(("thinking-content", ix))
.h_full()
.p_2()
.rounded_b_lg()
.bg(editor_bg)
.text_ui_sm(cx)
.child(markdown.clone()),
)
})
}
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
let is_open = self
.expanded_tool_uses
@ -1258,8 +1545,9 @@ impl ActiveThread {
}
}),
))
.child(div().text_ui_sm(cx).child(self.render_markdown(
.child(div().text_ui_sm(cx).child(render_markdown(
tool_use.ui_text.clone(),
self.language_registry.clone(),
window,
cx,
)))

View file

@ -29,7 +29,8 @@ use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::{
SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
@ -69,7 +70,47 @@ impl MessageId {
pub struct Message {
pub id: MessageId,
pub role: Role,
pub text: String,
pub segments: Vec<MessageSegment>,
}
impl Message {
pub fn push_thinking(&mut self, text: &str) {
if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
segment.push_str(text);
} else {
self.segments
.push(MessageSegment::Thinking(text.to_string()));
}
}
pub fn push_text(&mut self, text: &str) {
if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
segment.push_str(text);
} else {
self.segments.push(MessageSegment::Text(text.to_string()));
}
}
pub fn to_string(&self) -> String {
let mut result = String::new();
for segment in &self.segments {
match segment {
MessageSegment::Text(text) => result.push_str(text),
MessageSegment::Thinking(text) => {
result.push_str("<think>");
result.push_str(text);
result.push_str("</think>");
}
}
}
result
}
}
#[derive(Debug, Clone)]
pub enum MessageSegment {
Text(String),
Thinking(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -226,7 +267,16 @@ impl Thread {
.map(|message| Message {
id: message.id,
role: message.role,
text: message.text,
segments: message
.segments
.into_iter()
.map(|segment| match segment {
SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
SerializedMessageSegment::Thinking { text } => {
MessageSegment::Thinking(text)
}
})
.collect(),
})
.collect(),
next_message_id,
@ -419,7 +469,8 @@ impl Thread {
checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
let message_id = self.insert_message(Role::User, text, cx);
let message_id =
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
@ -433,15 +484,11 @@ impl Thread {
pub fn insert_message(
&mut self,
role: Role,
text: impl Into<String>,
segments: Vec<MessageSegment>,
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
self.messages.push(Message {
id,
role,
text: text.into(),
});
self.messages.push(Message { id, role, segments });
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
id
@ -451,14 +498,14 @@ impl Thread {
&mut self,
id: MessageId,
new_role: Role,
new_text: String,
new_segments: Vec<MessageSegment>,
cx: &mut Context<Self>,
) -> bool {
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
return false;
};
message.role = new_role;
message.text = new_text;
message.segments = new_segments;
self.touch_updated_at();
cx.emit(ThreadEvent::MessageEdited(id));
true
@ -489,7 +536,14 @@ impl Thread {
});
text.push('\n');
text.push_str(&message.text);
for segment in &message.segments {
match segment {
MessageSegment::Text(content) => text.push_str(content),
MessageSegment::Thinking(content) => {
text.push_str(&format!("<think>{}</think>", content))
}
}
}
text.push('\n');
}
@ -502,6 +556,7 @@ impl Thread {
cx.spawn(async move |this, cx| {
let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(cx, |this, cx| SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: this.summary_or_default(),
updated_at: this.updated_at(),
messages: this
@ -509,7 +564,18 @@ impl Thread {
.map(|message| SerializedMessage {
id: message.id,
role: message.role,
text: message.text.clone(),
segments: message
.segments
.iter()
.map(|segment| match segment {
MessageSegment::Text(text) => {
SerializedMessageSegment::Text { text: text.clone() }
}
MessageSegment::Thinking(text) => {
SerializedMessageSegment::Thinking { text: text.clone() }
}
})
.collect(),
tool_uses: this
.tool_uses_for_message(message.id, cx)
.into_iter()
@ -733,10 +799,10 @@ impl Thread {
}
}
if !message.text.is_empty() {
if !message.segments.is_empty() {
request_message
.content
.push(MessageContent::Text(message.text.clone()));
.push(MessageContent::Text(message.to_string()));
}
match request_kind {
@ -826,7 +892,11 @@ impl Thread {
thread.update(cx, |thread, cx| {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
thread.insert_message(Role::Assistant, String::new(), cx);
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text(String::new())],
cx,
);
}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
@ -840,7 +910,7 @@ impl Thread {
LanguageModelCompletionEvent::Text(chunk) => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk);
last_message.push_text(&chunk);
cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id,
chunk,
@ -851,7 +921,33 @@ impl Thread {
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(Role::Assistant, chunk, cx);
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text(chunk.to_string())],
cx,
);
};
}
}
LanguageModelCompletionEvent::Thinking(chunk) => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.push_thinking(&chunk);
cx.emit(ThreadEvent::StreamedAssistantThinking(
last_message.id,
chunk,
));
} else {
// If we won't have an Assistant message yet, assume this chunk marks the beginning
// of a new Assistant response.
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Thinking(chunk.to_string())],
cx,
);
};
}
}
@ -1357,7 +1453,14 @@ impl Thread {
Role::System => "System",
}
)?;
writeln!(markdown, "{}\n", message.text)?;
for segment in &message.segments {
match segment {
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
MessageSegment::Thinking(text) => {
writeln!(markdown, "<think>{}</think>\n", text)?
}
}
}
for tool_use in self.tool_uses_for_message(message.id, cx) {
writeln!(
@ -1416,6 +1519,7 @@ pub enum ThreadEvent {
ShowError(ThreadError),
StreamedCompletion,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
DoneStreaming,
MessageAdded(MessageId),
MessageEdited(MessageId),

View file

@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::path::PathBuf;
use std::sync::Arc;
@ -12,7 +13,7 @@ use futures::FutureExt as _;
use gpui::{
prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
};
use heed::types::{SerdeBincode, SerdeJson};
use heed::types::SerdeBincode;
use heed::Database;
use language_model::{LanguageModelToolUseId, Role};
use project::Project;
@ -259,6 +260,7 @@ pub struct SerializedThreadMetadata {
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
pub version: String,
pub summary: SharedString,
pub updated_at: DateTime<Utc>,
pub messages: Vec<SerializedMessage>,
@ -266,17 +268,55 @@ pub struct SerializedThread {
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
impl SerializedThread {
pub const VERSION: &'static str = "0.1.0";
pub fn from_json(json: &[u8]) -> Result<Self> {
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
match saved_thread_json.get("version") {
Some(serde_json::Value::String(version)) => match version.as_str() {
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
saved_thread_json,
)?),
_ => Err(anyhow!(
"unrecognized serialized thread version: {}",
version
)),
},
None => {
let saved_thread =
serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
Ok(saved_thread.upgrade())
}
version => Err(anyhow!(
"unrecognized serialized thread version: {:?}",
version
)),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
pub id: MessageId,
pub role: Role,
pub text: String,
#[serde(default)]
pub segments: Vec<SerializedMessageSegment>,
#[serde(default)]
pub tool_uses: Vec<SerializedToolUse>,
#[serde(default)]
pub tool_results: Vec<SerializedToolResult>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "thinking")]
Thinking { text: String },
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
pub id: LanguageModelToolUseId,
@ -291,6 +331,50 @@ pub struct SerializedToolResult {
pub content: Arc<str>,
}
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
pub summary: SharedString,
pub updated_at: DateTime<Utc>,
pub messages: Vec<LegacySerializedMessage>,
#[serde(default)]
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
impl LegacySerializedThread {
pub fn upgrade(self) -> SerializedThread {
SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: self.summary,
updated_at: self.updated_at,
messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
initial_project_snapshot: self.initial_project_snapshot,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
pub id: MessageId,
pub role: Role,
pub text: String,
#[serde(default)]
pub tool_uses: Vec<SerializedToolUse>,
#[serde(default)]
pub tool_results: Vec<SerializedToolResult>,
}
impl LegacySerializedMessage {
fn upgrade(self) -> SerializedMessage {
SerializedMessage {
id: self.id,
role: self.role,
segments: vec![SerializedMessageSegment::Text { text: self.text }],
tool_uses: self.tool_uses,
tool_results: self.tool_results,
}
}
}
struct GlobalThreadsDatabase(
Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
@ -300,7 +384,25 @@ impl Global for GlobalThreadsDatabase {}
pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor,
env: heed::Env,
threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
impl heed::BytesEncode<'_> for SerializedThread {
type EItem = SerializedThread;
fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
}
}
impl<'a> heed::BytesDecode<'a> for SerializedThread {
type DItem = SerializedThread;
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
// We implement this type manually because we want to call `SerializedThread::from_json`,
// instead of the Deserialize trait implementation for `SerializedThread`.
SerializedThread::from_json(bytes).map_err(Into::into)
}
}
impl ThreadsDatabase {