assistant: Add support for claude-3-7-sonnet-thinking
(#27085)
Closes #25671 Release Notes: - Added support for `claude-3-7-sonnet-thinking` in the assistant panel --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
parent
2ffce4f516
commit
a709d4c7c6
16 changed files with 1212 additions and 177 deletions
|
@ -1,5 +1,6 @@
|
|||
use crate::thread::{
|
||||
LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ThreadFeedback,
|
||||
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
|
||||
ThreadEvent, ThreadFeedback,
|
||||
};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::tool_use::{ToolUse, ToolUseStatus};
|
||||
|
@ -7,10 +8,10 @@ use crate::ui::ContextPill;
|
|||
use collections::HashMap;
|
||||
use editor::{Editor, MultiBuffer};
|
||||
use gpui::{
|
||||
list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App,
|
||||
ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment,
|
||||
ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement,
|
||||
Transformation, UnderlineStyle, WeakEntity,
|
||||
linear_color_stop, linear_gradient, list, percentage, pulsating_between, AbsoluteLength,
|
||||
Animation, AnimationExt, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
|
||||
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, ScrollHandle, StyleRefinement,
|
||||
Subscription, Task, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity,
|
||||
};
|
||||
use language::{Buffer, LanguageRegistry};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||
|
@ -35,15 +36,175 @@ pub struct ActiveThread {
|
|||
save_thread_task: Option<Task<()>>,
|
||||
messages: Vec<MessageId>,
|
||||
list_state: ListState,
|
||||
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
|
||||
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
|
||||
rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
|
||||
rendered_tool_use_labels: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
|
||||
editing_message: Option<(MessageId, EditMessageState)>,
|
||||
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
||||
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
|
||||
last_error: Option<ThreadError>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
struct RenderedMessage {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
segments: Vec<RenderedMessageSegment>,
|
||||
}
|
||||
|
||||
impl RenderedMessage {
|
||||
fn from_segments(
|
||||
segments: &[MessageSegment],
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
window: &Window,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let mut this = Self {
|
||||
language_registry,
|
||||
segments: Vec::with_capacity(segments.len()),
|
||||
};
|
||||
for segment in segments {
|
||||
this.push_segment(segment, window, cx);
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
fn append_thinking(&mut self, text: &String, window: &Window, cx: &mut App) {
|
||||
if let Some(RenderedMessageSegment::Thinking {
|
||||
content,
|
||||
scroll_handle,
|
||||
}) = self.segments.last_mut()
|
||||
{
|
||||
content.update(cx, |markdown, cx| {
|
||||
markdown.append(text, cx);
|
||||
});
|
||||
scroll_handle.scroll_to_bottom();
|
||||
} else {
|
||||
self.segments.push(RenderedMessageSegment::Thinking {
|
||||
content: render_markdown(text.into(), self.language_registry.clone(), window, cx),
|
||||
scroll_handle: ScrollHandle::default(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn append_text(&mut self, text: &String, window: &Window, cx: &mut App) {
|
||||
if let Some(RenderedMessageSegment::Text(markdown)) = self.segments.last_mut() {
|
||||
markdown.update(cx, |markdown, cx| markdown.append(text, cx));
|
||||
} else {
|
||||
self.segments
|
||||
.push(RenderedMessageSegment::Text(render_markdown(
|
||||
SharedString::from(text),
|
||||
self.language_registry.clone(),
|
||||
window,
|
||||
cx,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
fn push_segment(&mut self, segment: &MessageSegment, window: &Window, cx: &mut App) {
|
||||
let rendered_segment = match segment {
|
||||
MessageSegment::Thinking(text) => RenderedMessageSegment::Thinking {
|
||||
content: render_markdown(text.into(), self.language_registry.clone(), window, cx),
|
||||
scroll_handle: ScrollHandle::default(),
|
||||
},
|
||||
MessageSegment::Text(text) => RenderedMessageSegment::Text(render_markdown(
|
||||
text.into(),
|
||||
self.language_registry.clone(),
|
||||
window,
|
||||
cx,
|
||||
)),
|
||||
};
|
||||
self.segments.push(rendered_segment);
|
||||
}
|
||||
}
|
||||
|
||||
enum RenderedMessageSegment {
|
||||
Thinking {
|
||||
content: Entity<Markdown>,
|
||||
scroll_handle: ScrollHandle,
|
||||
},
|
||||
Text(Entity<Markdown>),
|
||||
}
|
||||
|
||||
fn render_markdown(
|
||||
text: SharedString,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
window: &Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<Markdown> {
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
let colors = cx.theme().colors();
|
||||
let ui_font_size = TextSize::Default.rems(cx);
|
||||
let buffer_font_size = TextSize::Small.rems(cx);
|
||||
let mut text_style = window.text_style();
|
||||
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(theme_settings.ui_font.family.clone()),
|
||||
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.ui_font.features.clone()),
|
||||
font_size: Some(ui_font_size.into()),
|
||||
color: Some(cx.theme().colors().text),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
let markdown_style = MarkdownStyle {
|
||||
base_text_style: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
selection_background_color: cx.theme().players().local().selection,
|
||||
code_block_overflow_x_scroll: true,
|
||||
table_overflow_x_scroll: true,
|
||||
code_block: StyleRefinement {
|
||||
margin: EdgesRefinement {
|
||||
top: Some(Length::Definite(rems(0.).into())),
|
||||
left: Some(Length::Definite(rems(0.).into())),
|
||||
right: Some(Length::Definite(rems(0.).into())),
|
||||
bottom: Some(Length::Definite(rems(0.5).into())),
|
||||
},
|
||||
padding: EdgesRefinement {
|
||||
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
},
|
||||
background: Some(colors.editor_background.into()),
|
||||
border_color: Some(colors.border_variant),
|
||||
border_widths: EdgesRefinement {
|
||||
top: Some(AbsoluteLength::Pixels(Pixels(1.))),
|
||||
left: Some(AbsoluteLength::Pixels(Pixels(1.))),
|
||||
right: Some(AbsoluteLength::Pixels(Pixels(1.))),
|
||||
bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
|
||||
},
|
||||
text: Some(TextStyleRefinement {
|
||||
font_family: Some(theme_settings.buffer_font.family.clone()),
|
||||
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.buffer_font.features.clone()),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
inline_code: TextStyleRefinement {
|
||||
font_family: Some(theme_settings.buffer_font.family.clone()),
|
||||
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.buffer_font.features.clone()),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
background_color: Some(colors.editor_foreground.opacity(0.1)),
|
||||
..Default::default()
|
||||
},
|
||||
link: TextStyleRefinement {
|
||||
background_color: Some(colors.editor_foreground.opacity(0.025)),
|
||||
underline: Some(UnderlineStyle {
|
||||
color: Some(colors.text_accent.opacity(0.5)),
|
||||
thickness: px(1.),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
cx.new(|cx| Markdown::new(text, markdown_style, Some(language_registry), None, cx))
|
||||
}
|
||||
|
||||
struct EditMessageState {
|
||||
editor: Entity<Editor>,
|
||||
}
|
||||
|
@ -75,6 +236,7 @@ impl ActiveThread {
|
|||
rendered_scripting_tool_uses: HashMap::default(),
|
||||
rendered_tool_use_labels: HashMap::default(),
|
||||
expanded_tool_uses: HashMap::default(),
|
||||
expanded_thinking_segments: HashMap::default(),
|
||||
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
|
||||
let this = cx.entity().downgrade();
|
||||
move |ix, window: &mut Window, cx: &mut App| {
|
||||
|
@ -88,7 +250,7 @@ impl ActiveThread {
|
|||
};
|
||||
|
||||
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
|
||||
this.push_message(&message.id, message.text.clone(), window, cx);
|
||||
this.push_message(&message.id, &message.segments, window, cx);
|
||||
|
||||
for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
|
||||
this.render_tool_use_label_markdown(
|
||||
|
@ -156,7 +318,7 @@ impl ActiveThread {
|
|||
fn push_message(
|
||||
&mut self,
|
||||
id: &MessageId,
|
||||
text: String,
|
||||
segments: &[MessageSegment],
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
|
@ -164,8 +326,9 @@ impl ActiveThread {
|
|||
self.messages.push(*id);
|
||||
self.list_state.splice(old_len..old_len, 1);
|
||||
|
||||
let markdown = self.render_markdown(text.into(), window, cx);
|
||||
self.rendered_messages_by_id.insert(*id, markdown);
|
||||
let rendered_message =
|
||||
RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx);
|
||||
self.rendered_messages_by_id.insert(*id, rendered_message);
|
||||
self.list_state.scroll_to(ListOffset {
|
||||
item_ix: old_len,
|
||||
offset_in_item: Pixels(0.0),
|
||||
|
@ -175,7 +338,7 @@ impl ActiveThread {
|
|||
fn edited_message(
|
||||
&mut self,
|
||||
id: &MessageId,
|
||||
text: String,
|
||||
segments: &[MessageSegment],
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
|
@ -183,8 +346,9 @@ impl ActiveThread {
|
|||
return;
|
||||
};
|
||||
self.list_state.splice(index..index + 1, 1);
|
||||
let markdown = self.render_markdown(text.into(), window, cx);
|
||||
self.rendered_messages_by_id.insert(*id, markdown);
|
||||
let rendered_message =
|
||||
RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx);
|
||||
self.rendered_messages_by_id.insert(*id, rendered_message);
|
||||
}
|
||||
|
||||
fn deleted_message(&mut self, id: &MessageId) {
|
||||
|
@ -196,94 +360,6 @@ impl ActiveThread {
|
|||
self.rendered_messages_by_id.remove(id);
|
||||
}
|
||||
|
||||
fn render_markdown(
|
||||
&self,
|
||||
text: SharedString,
|
||||
window: &Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Entity<Markdown> {
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
let colors = cx.theme().colors();
|
||||
let ui_font_size = TextSize::Default.rems(cx);
|
||||
let buffer_font_size = TextSize::Small.rems(cx);
|
||||
let mut text_style = window.text_style();
|
||||
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(theme_settings.ui_font.family.clone()),
|
||||
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.ui_font.features.clone()),
|
||||
font_size: Some(ui_font_size.into()),
|
||||
color: Some(cx.theme().colors().text),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
let markdown_style = MarkdownStyle {
|
||||
base_text_style: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
selection_background_color: cx.theme().players().local().selection,
|
||||
code_block_overflow_x_scroll: true,
|
||||
table_overflow_x_scroll: true,
|
||||
code_block: StyleRefinement {
|
||||
margin: EdgesRefinement {
|
||||
top: Some(Length::Definite(rems(0.).into())),
|
||||
left: Some(Length::Definite(rems(0.).into())),
|
||||
right: Some(Length::Definite(rems(0.).into())),
|
||||
bottom: Some(Length::Definite(rems(0.5).into())),
|
||||
},
|
||||
padding: EdgesRefinement {
|
||||
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
},
|
||||
background: Some(colors.editor_background.into()),
|
||||
border_color: Some(colors.border_variant),
|
||||
border_widths: EdgesRefinement {
|
||||
top: Some(AbsoluteLength::Pixels(Pixels(1.))),
|
||||
left: Some(AbsoluteLength::Pixels(Pixels(1.))),
|
||||
right: Some(AbsoluteLength::Pixels(Pixels(1.))),
|
||||
bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
|
||||
},
|
||||
text: Some(TextStyleRefinement {
|
||||
font_family: Some(theme_settings.buffer_font.family.clone()),
|
||||
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.buffer_font.features.clone()),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
inline_code: TextStyleRefinement {
|
||||
font_family: Some(theme_settings.buffer_font.family.clone()),
|
||||
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.buffer_font.features.clone()),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
background_color: Some(colors.editor_foreground.opacity(0.1)),
|
||||
..Default::default()
|
||||
},
|
||||
link: TextStyleRefinement {
|
||||
background_color: Some(colors.editor_foreground.opacity(0.025)),
|
||||
underline: Some(UnderlineStyle {
|
||||
color: Some(colors.text_accent.opacity(0.5)),
|
||||
thickness: px(1.),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
cx.new(|cx| {
|
||||
Markdown::new(
|
||||
text,
|
||||
markdown_style,
|
||||
Some(self.language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Renders the input of a scripting tool use to Markdown.
|
||||
///
|
||||
/// Does nothing if the tool use does not correspond to the scripting tool.
|
||||
|
@ -303,8 +379,12 @@ impl ActiveThread {
|
|||
.map(|input| input.lua_script)
|
||||
.unwrap_or_default();
|
||||
|
||||
let lua_script =
|
||||
self.render_markdown(format!("```lua\n{lua_script}\n```").into(), window, cx);
|
||||
let lua_script = render_markdown(
|
||||
format!("```lua\n{lua_script}\n```").into(),
|
||||
self.language_registry.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
self.rendered_scripting_tool_uses
|
||||
.insert(tool_use_id, lua_script);
|
||||
|
@ -319,7 +399,12 @@ impl ActiveThread {
|
|||
) {
|
||||
self.rendered_tool_use_labels.insert(
|
||||
tool_use_id,
|
||||
self.render_markdown(tool_label.into(), window, cx),
|
||||
render_markdown(
|
||||
tool_label.into(),
|
||||
self.language_registry.clone(),
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -339,33 +424,36 @@ impl ActiveThread {
|
|||
}
|
||||
ThreadEvent::DoneStreaming => {}
|
||||
ThreadEvent::StreamedAssistantText(message_id, text) => {
|
||||
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
|
||||
markdown.update(cx, |markdown, cx| {
|
||||
markdown.append(text, cx);
|
||||
});
|
||||
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
|
||||
rendered_message.append_text(text, window, cx);
|
||||
}
|
||||
}
|
||||
ThreadEvent::StreamedAssistantThinking(message_id, text) => {
|
||||
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
|
||||
rendered_message.append_thinking(text, window, cx);
|
||||
}
|
||||
}
|
||||
ThreadEvent::MessageAdded(message_id) => {
|
||||
if let Some(message_text) = self
|
||||
if let Some(message_segments) = self
|
||||
.thread
|
||||
.read(cx)
|
||||
.message(*message_id)
|
||||
.map(|message| message.text.clone())
|
||||
.map(|message| message.segments.clone())
|
||||
{
|
||||
self.push_message(message_id, message_text, window, cx);
|
||||
self.push_message(message_id, &message_segments, window, cx);
|
||||
}
|
||||
|
||||
self.save_thread(cx);
|
||||
cx.notify();
|
||||
}
|
||||
ThreadEvent::MessageEdited(message_id) => {
|
||||
if let Some(message_text) = self
|
||||
if let Some(message_segments) = self
|
||||
.thread
|
||||
.read(cx)
|
||||
.message(*message_id)
|
||||
.map(|message| message.text.clone())
|
||||
.map(|message| message.segments.clone())
|
||||
{
|
||||
self.edited_message(message_id, message_text, window, cx);
|
||||
self.edited_message(message_id, &message_segments, window, cx);
|
||||
}
|
||||
|
||||
self.save_thread(cx);
|
||||
|
@ -490,10 +578,16 @@ impl ActiveThread {
|
|||
fn start_editing_message(
|
||||
&mut self,
|
||||
message_id: MessageId,
|
||||
message_text: String,
|
||||
message_segments: &[MessageSegment],
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// User message should always consist of a single text segment,
|
||||
// therefore we can skip returning early if it's not a text segment.
|
||||
let Some(MessageSegment::Text(message_text)) = message_segments.first() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let buffer = cx.new(|cx| {
|
||||
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
|
||||
});
|
||||
|
@ -534,7 +628,12 @@ impl ActiveThread {
|
|||
};
|
||||
let edited_text = state.editor.read(cx).text(cx);
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.edit_message(message_id, Role::User, edited_text, cx);
|
||||
thread.edit_message(
|
||||
message_id,
|
||||
Role::User,
|
||||
vec![MessageSegment::Text(edited_text)],
|
||||
cx,
|
||||
);
|
||||
for message_id in self.messages_after(message_id) {
|
||||
thread.delete_message(*message_id, cx);
|
||||
}
|
||||
|
@ -617,7 +716,7 @@ impl ActiveThread {
|
|||
return Empty.into_any();
|
||||
};
|
||||
|
||||
let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
|
||||
let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
|
||||
return Empty.into_any();
|
||||
};
|
||||
|
||||
|
@ -759,7 +858,10 @@ impl ActiveThread {
|
|||
.min_h_6()
|
||||
.child(edit_message_editor)
|
||||
} else {
|
||||
div().min_h_6().text_ui(cx).child(markdown.clone())
|
||||
div()
|
||||
.min_h_6()
|
||||
.text_ui(cx)
|
||||
.child(self.render_message_content(message_id, rendered_message, cx))
|
||||
},
|
||||
)
|
||||
.when_some(context, |parent, context| {
|
||||
|
@ -869,11 +971,12 @@ impl ActiveThread {
|
|||
Button::new("edit-message", "Edit")
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let message_text = message.text.clone();
|
||||
let message_segments =
|
||||
message.segments.clone();
|
||||
move |this, _, window, cx| {
|
||||
this.start_editing_message(
|
||||
message_id,
|
||||
message_text.clone(),
|
||||
&message_segments,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
@ -995,6 +1098,190 @@ impl ActiveThread {
|
|||
.into_any()
|
||||
}
|
||||
|
||||
fn render_message_content(
|
||||
&self,
|
||||
message_id: MessageId,
|
||||
rendered_message: &RenderedMessage,
|
||||
cx: &Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let pending_thinking_segment_index = rendered_message
|
||||
.segments
|
||||
.iter()
|
||||
.enumerate()
|
||||
.last()
|
||||
.filter(|(_, segment)| matches!(segment, RenderedMessageSegment::Thinking { .. }))
|
||||
.map(|(index, _)| index);
|
||||
|
||||
div()
|
||||
.text_ui(cx)
|
||||
.gap_2()
|
||||
.children(
|
||||
rendered_message.segments.iter().enumerate().map(
|
||||
|(index, segment)| match segment {
|
||||
RenderedMessageSegment::Thinking {
|
||||
content,
|
||||
scroll_handle,
|
||||
} => self
|
||||
.render_message_thinking_segment(
|
||||
message_id,
|
||||
index,
|
||||
content.clone(),
|
||||
&scroll_handle,
|
||||
Some(index) == pending_thinking_segment_index,
|
||||
cx,
|
||||
)
|
||||
.into_any_element(),
|
||||
RenderedMessageSegment::Text(markdown) => {
|
||||
div().p_2p5().child(markdown.clone()).into_any_element()
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_message_thinking_segment(
|
||||
&self,
|
||||
message_id: MessageId,
|
||||
ix: usize,
|
||||
markdown: Entity<Markdown>,
|
||||
scroll_handle: &ScrollHandle,
|
||||
pending: bool,
|
||||
cx: &Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let is_open = self
|
||||
.expanded_thinking_segments
|
||||
.get(&(message_id, ix))
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
|
||||
let lighter_border = cx.theme().colors().border.opacity(0.5);
|
||||
let editor_bg = cx.theme().colors().editor_background;
|
||||
|
||||
v_flex()
|
||||
.rounded_lg()
|
||||
.border_1()
|
||||
.border_color(lighter_border)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
.py_1()
|
||||
.pl_1()
|
||||
.pr_2()
|
||||
.bg(cx.theme().colors().editor_foreground.opacity(0.025))
|
||||
.map(|this| {
|
||||
if is_open {
|
||||
this.rounded_t_md()
|
||||
.border_b_1()
|
||||
.border_color(lighter_border)
|
||||
} else {
|
||||
this.rounded_md()
|
||||
}
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(Disclosure::new("thinking-disclosure", is_open).on_click(
|
||||
cx.listener({
|
||||
move |this, _event, _window, _cx| {
|
||||
let is_open = this
|
||||
.expanded_thinking_segments
|
||||
.entry((message_id, ix))
|
||||
.or_insert(false);
|
||||
|
||||
*is_open = !*is_open;
|
||||
}
|
||||
}),
|
||||
))
|
||||
.child({
|
||||
if pending {
|
||||
Label::new("Thinking…")
|
||||
.size(LabelSize::Small)
|
||||
.buffer_font(cx)
|
||||
.with_animation(
|
||||
"pulsating-label",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 0.8)),
|
||||
|label, delta| label.alpha(delta),
|
||||
)
|
||||
.into_any_element()
|
||||
} else {
|
||||
Label::new("Thought Process")
|
||||
.size(LabelSize::Small)
|
||||
.buffer_font(cx)
|
||||
.into_any_element()
|
||||
}
|
||||
}),
|
||||
)
|
||||
.child({
|
||||
let (icon_name, color, animated) = if pending {
|
||||
(IconName::ArrowCircle, Color::Accent, true)
|
||||
} else {
|
||||
(IconName::Check, Color::Success, false)
|
||||
};
|
||||
|
||||
let icon = Icon::new(icon_name).color(color).size(IconSize::Small);
|
||||
|
||||
if animated {
|
||||
icon.with_animation(
|
||||
"arrow-circle",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|icon, delta| {
|
||||
icon.transform(Transformation::rotate(percentage(delta)))
|
||||
},
|
||||
)
|
||||
.into_any_element()
|
||||
} else {
|
||||
icon.into_any_element()
|
||||
}
|
||||
}),
|
||||
)
|
||||
.when(pending && !is_open, |this| {
|
||||
let gradient_overlay = div()
|
||||
.rounded_b_lg()
|
||||
.h_20()
|
||||
.absolute()
|
||||
.w_full()
|
||||
.bottom_0()
|
||||
.left_0()
|
||||
.bg(linear_gradient(
|
||||
180.,
|
||||
linear_color_stop(editor_bg, 1.),
|
||||
linear_color_stop(editor_bg.opacity(0.2), 0.),
|
||||
));
|
||||
|
||||
this.child(
|
||||
div()
|
||||
.relative()
|
||||
.bg(editor_bg)
|
||||
.rounded_b_lg()
|
||||
.text_ui_sm(cx)
|
||||
.child(
|
||||
div()
|
||||
.id(("thinking-content", ix))
|
||||
.p_2()
|
||||
.h_20()
|
||||
.track_scroll(scroll_handle)
|
||||
.child(markdown.clone())
|
||||
.overflow_hidden(),
|
||||
)
|
||||
.child(gradient_overlay),
|
||||
)
|
||||
})
|
||||
.when(is_open, |this| {
|
||||
this.child(
|
||||
div()
|
||||
.id(("thinking-content", ix))
|
||||
.h_full()
|
||||
.p_2()
|
||||
.rounded_b_lg()
|
||||
.bg(editor_bg)
|
||||
.text_ui_sm(cx)
|
||||
.child(markdown.clone()),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let is_open = self
|
||||
.expanded_tool_uses
|
||||
|
@ -1258,8 +1545,9 @@ impl ActiveThread {
|
|||
}
|
||||
}),
|
||||
))
|
||||
.child(div().text_ui_sm(cx).child(self.render_markdown(
|
||||
.child(div().text_ui_sm(cx).child(render_markdown(
|
||||
tool_use.ui_text.clone(),
|
||||
self.language_registry.clone(),
|
||||
window,
|
||||
cx,
|
||||
)))
|
||||
|
|
|
@ -29,7 +29,8 @@ use uuid::Uuid;
|
|||
|
||||
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
|
||||
use crate::thread_store::{
|
||||
SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
|
||||
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
|
||||
SerializedToolUse,
|
||||
};
|
||||
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
|
||||
|
||||
|
@ -69,7 +70,47 @@ impl MessageId {
|
|||
pub struct Message {
|
||||
pub id: MessageId,
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
pub segments: Vec<MessageSegment>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn push_thinking(&mut self, text: &str) {
|
||||
if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
|
||||
segment.push_str(text);
|
||||
} else {
|
||||
self.segments
|
||||
.push(MessageSegment::Thinking(text.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_text(&mut self, text: &str) {
|
||||
if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
|
||||
segment.push_str(text);
|
||||
} else {
|
||||
self.segments.push(MessageSegment::Text(text.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> String {
|
||||
let mut result = String::new();
|
||||
for segment in &self.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => result.push_str(text),
|
||||
MessageSegment::Thinking(text) => {
|
||||
result.push_str("<think>");
|
||||
result.push_str(text);
|
||||
result.push_str("</think>");
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MessageSegment {
|
||||
Text(String),
|
||||
Thinking(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
@ -226,7 +267,16 @@ impl Thread {
|
|||
.map(|message| Message {
|
||||
id: message.id,
|
||||
role: message.role,
|
||||
text: message.text,
|
||||
segments: message
|
||||
.segments
|
||||
.into_iter()
|
||||
.map(|segment| match segment {
|
||||
SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
|
||||
SerializedMessageSegment::Thinking { text } => {
|
||||
MessageSegment::Thinking(text)
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect(),
|
||||
next_message_id,
|
||||
|
@ -419,7 +469,8 @@ impl Thread {
|
|||
checkpoint: Option<GitStoreCheckpoint>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
let message_id = self.insert_message(Role::User, text, cx);
|
||||
let message_id =
|
||||
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
|
||||
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
|
||||
self.context
|
||||
.extend(context.into_iter().map(|context| (context.id, context)));
|
||||
|
@ -433,15 +484,11 @@ impl Thread {
|
|||
pub fn insert_message(
|
||||
&mut self,
|
||||
role: Role,
|
||||
text: impl Into<String>,
|
||||
segments: Vec<MessageSegment>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
let id = self.next_message_id.post_inc();
|
||||
self.messages.push(Message {
|
||||
id,
|
||||
role,
|
||||
text: text.into(),
|
||||
});
|
||||
self.messages.push(Message { id, role, segments });
|
||||
self.touch_updated_at();
|
||||
cx.emit(ThreadEvent::MessageAdded(id));
|
||||
id
|
||||
|
@ -451,14 +498,14 @@ impl Thread {
|
|||
&mut self,
|
||||
id: MessageId,
|
||||
new_role: Role,
|
||||
new_text: String,
|
||||
new_segments: Vec<MessageSegment>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
|
||||
return false;
|
||||
};
|
||||
message.role = new_role;
|
||||
message.text = new_text;
|
||||
message.segments = new_segments;
|
||||
self.touch_updated_at();
|
||||
cx.emit(ThreadEvent::MessageEdited(id));
|
||||
true
|
||||
|
@ -489,7 +536,14 @@ impl Thread {
|
|||
});
|
||||
text.push('\n');
|
||||
|
||||
text.push_str(&message.text);
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(content) => text.push_str(content),
|
||||
MessageSegment::Thinking(content) => {
|
||||
text.push_str(&format!("<think>{}</think>", content))
|
||||
}
|
||||
}
|
||||
}
|
||||
text.push('\n');
|
||||
}
|
||||
|
||||
|
@ -502,6 +556,7 @@ impl Thread {
|
|||
cx.spawn(async move |this, cx| {
|
||||
let initial_project_snapshot = initial_project_snapshot.await;
|
||||
this.read_with(cx, |this, cx| SerializedThread {
|
||||
version: SerializedThread::VERSION.to_string(),
|
||||
summary: this.summary_or_default(),
|
||||
updated_at: this.updated_at(),
|
||||
messages: this
|
||||
|
@ -509,7 +564,18 @@ impl Thread {
|
|||
.map(|message| SerializedMessage {
|
||||
id: message.id,
|
||||
role: message.role,
|
||||
text: message.text.clone(),
|
||||
segments: message
|
||||
.segments
|
||||
.iter()
|
||||
.map(|segment| match segment {
|
||||
MessageSegment::Text(text) => {
|
||||
SerializedMessageSegment::Text { text: text.clone() }
|
||||
}
|
||||
MessageSegment::Thinking(text) => {
|
||||
SerializedMessageSegment::Thinking { text: text.clone() }
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
tool_uses: this
|
||||
.tool_uses_for_message(message.id, cx)
|
||||
.into_iter()
|
||||
|
@ -733,10 +799,10 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
if !message.text.is_empty() {
|
||||
if !message.segments.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(message.text.clone()));
|
||||
.push(MessageContent::Text(message.to_string()));
|
||||
}
|
||||
|
||||
match request_kind {
|
||||
|
@ -826,7 +892,11 @@ impl Thread {
|
|||
thread.update(cx, |thread, cx| {
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
thread.insert_message(Role::Assistant, String::new(), cx);
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Text(String::new())],
|
||||
cx,
|
||||
);
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
stop_reason = reason;
|
||||
|
@ -840,7 +910,7 @@ impl Thread {
|
|||
LanguageModelCompletionEvent::Text(chunk) => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.text.push_str(&chunk);
|
||||
last_message.push_text(&chunk);
|
||||
cx.emit(ThreadEvent::StreamedAssistantText(
|
||||
last_message.id,
|
||||
chunk,
|
||||
|
@ -851,7 +921,33 @@ impl Thread {
|
|||
//
|
||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(Role::Assistant, chunk, cx);
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Text(chunk.to_string())],
|
||||
cx,
|
||||
);
|
||||
};
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::Thinking(chunk) => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.push_thinking(&chunk);
|
||||
cx.emit(ThreadEvent::StreamedAssistantThinking(
|
||||
last_message.id,
|
||||
chunk,
|
||||
));
|
||||
} else {
|
||||
// If we won't have an Assistant message yet, assume this chunk marks the beginning
|
||||
// of a new Assistant response.
|
||||
//
|
||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Thinking(chunk.to_string())],
|
||||
cx,
|
||||
);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -1357,7 +1453,14 @@ impl Thread {
|
|||
Role::System => "System",
|
||||
}
|
||||
)?;
|
||||
writeln!(markdown, "{}\n", message.text)?;
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
|
||||
MessageSegment::Thinking(text) => {
|
||||
writeln!(markdown, "<think>{}</think>\n", text)?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for tool_use in self.tool_uses_for_message(message.id, cx) {
|
||||
writeln!(
|
||||
|
@ -1416,6 +1519,7 @@ pub enum ThreadEvent {
|
|||
ShowError(ThreadError),
|
||||
StreamedCompletion,
|
||||
StreamedAssistantText(MessageId, String),
|
||||
StreamedAssistantThinking(MessageId, String),
|
||||
DoneStreaming,
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use std::borrow::Cow;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
|
@ -12,7 +13,7 @@ use futures::FutureExt as _;
|
|||
use gpui::{
|
||||
prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
|
||||
};
|
||||
use heed::types::{SerdeBincode, SerdeJson};
|
||||
use heed::types::SerdeBincode;
|
||||
use heed::Database;
|
||||
use language_model::{LanguageModelToolUseId, Role};
|
||||
use project::Project;
|
||||
|
@ -259,6 +260,7 @@ pub struct SerializedThreadMetadata {
|
|||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SerializedThread {
|
||||
pub version: String,
|
||||
pub summary: SharedString,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub messages: Vec<SerializedMessage>,
|
||||
|
@ -266,17 +268,55 @@ pub struct SerializedThread {
|
|||
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
|
||||
}
|
||||
|
||||
impl SerializedThread {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
|
||||
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||
match saved_thread_json.get("version") {
|
||||
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
|
||||
saved_thread_json,
|
||||
)?),
|
||||
_ => Err(anyhow!(
|
||||
"unrecognized serialized thread version: {}",
|
||||
version
|
||||
)),
|
||||
},
|
||||
None => {
|
||||
let saved_thread =
|
||||
serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
|
||||
Ok(saved_thread.upgrade())
|
||||
}
|
||||
version => Err(anyhow!(
|
||||
"unrecognized serialized thread version: {:?}",
|
||||
version
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SerializedMessage {
|
||||
pub id: MessageId,
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
#[serde(default)]
|
||||
pub segments: Vec<SerializedMessageSegment>,
|
||||
#[serde(default)]
|
||||
pub tool_uses: Vec<SerializedToolUse>,
|
||||
#[serde(default)]
|
||||
pub tool_results: Vec<SerializedToolResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum SerializedMessageSegment {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "thinking")]
|
||||
Thinking { text: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SerializedToolUse {
|
||||
pub id: LanguageModelToolUseId,
|
||||
|
@ -291,6 +331,50 @@ pub struct SerializedToolResult {
|
|||
pub content: Arc<str>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct LegacySerializedThread {
|
||||
pub summary: SharedString,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub messages: Vec<LegacySerializedMessage>,
|
||||
#[serde(default)]
|
||||
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
|
||||
}
|
||||
|
||||
impl LegacySerializedThread {
|
||||
pub fn upgrade(self) -> SerializedThread {
|
||||
SerializedThread {
|
||||
version: SerializedThread::VERSION.to_string(),
|
||||
summary: self.summary,
|
||||
updated_at: self.updated_at,
|
||||
messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
|
||||
initial_project_snapshot: self.initial_project_snapshot,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LegacySerializedMessage {
|
||||
pub id: MessageId,
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
#[serde(default)]
|
||||
pub tool_uses: Vec<SerializedToolUse>,
|
||||
#[serde(default)]
|
||||
pub tool_results: Vec<SerializedToolResult>,
|
||||
}
|
||||
|
||||
impl LegacySerializedMessage {
|
||||
fn upgrade(self) -> SerializedMessage {
|
||||
SerializedMessage {
|
||||
id: self.id,
|
||||
role: self.role,
|
||||
segments: vec![SerializedMessageSegment::Text { text: self.text }],
|
||||
tool_uses: self.tool_uses,
|
||||
tool_results: self.tool_results,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct GlobalThreadsDatabase(
|
||||
Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
|
||||
);
|
||||
|
@ -300,7 +384,25 @@ impl Global for GlobalThreadsDatabase {}
|
|||
pub(crate) struct ThreadsDatabase {
|
||||
executor: BackgroundExecutor,
|
||||
env: heed::Env,
|
||||
threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
|
||||
threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
|
||||
}
|
||||
|
||||
impl heed::BytesEncode<'_> for SerializedThread {
|
||||
type EItem = SerializedThread;
|
||||
|
||||
fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
|
||||
serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> heed::BytesDecode<'a> for SerializedThread {
|
||||
type DItem = SerializedThread;
|
||||
|
||||
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
|
||||
// We implement this type manually because we want to call `SerializedThread::from_json`,
|
||||
// instead of the Deserialize trait implementation for `SerializedThread`.
|
||||
SerializedThread::from_json(bytes).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
impl ThreadsDatabase {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue