assistant: Add support for claude-3-7-sonnet-thinking (#27085)

Closes #25671

Release Notes:

- Added support for `claude-3-7-sonnet-thinking` in the assistant panel

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
Bennet Bo Fenner 2025-03-21 13:29:07 +01:00 committed by GitHub
parent 2ffce4f516
commit a709d4c7c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1212 additions and 177 deletions

View file

@ -24,6 +24,16 @@ pub struct AnthropicModelCacheConfiguration {
pub max_cache_anchors: usize, pub max_cache_anchors: usize,
} }
/// The mode a model runs in: plain completion or thinking.
///
/// `Default` is the `Default::default()` value, so a custom model whose `mode`
/// field is omitted (it is marked `#[serde(default)]`) deserializes to `Default`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum AnthropicModelMode {
/// Regular mode without thinking.
#[default]
Default,
/// Thinking mode.
Thinking {
// Optional token budget for thinking.
// NOTE(review): presumably forwarded as `Thinking::Enabled { budget_tokens }`
// when the request is built — confirm against the request construction code.
budget_tokens: Option<u32>,
},
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model { pub enum Model {
@ -32,6 +42,11 @@ pub enum Model {
Claude3_5Sonnet, Claude3_5Sonnet,
#[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")] #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
Claude3_7Sonnet, Claude3_7Sonnet,
#[serde(
rename = "claude-3-7-sonnet-thinking",
alias = "claude-3-7-sonnet-thinking-latest"
)]
Claude3_7SonnetThinking,
#[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")] #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
Claude3_5Haiku, Claude3_5Haiku,
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")] #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
@ -54,6 +69,8 @@ pub enum Model {
default_temperature: Option<f32>, default_temperature: Option<f32>,
#[serde(default)] #[serde(default)]
extra_beta_headers: Vec<String>, extra_beta_headers: Vec<String>,
#[serde(default)]
mode: AnthropicModelMode,
}, },
} }
@ -61,6 +78,8 @@ impl Model {
pub fn from_id(id: &str) -> Result<Self> { pub fn from_id(id: &str) -> Result<Self> {
if id.starts_with("claude-3-5-sonnet") { if id.starts_with("claude-3-5-sonnet") {
Ok(Self::Claude3_5Sonnet) Ok(Self::Claude3_5Sonnet)
} else if id.starts_with("claude-3-7-sonnet-thinking") {
Ok(Self::Claude3_7SonnetThinking)
} else if id.starts_with("claude-3-7-sonnet") { } else if id.starts_with("claude-3-7-sonnet") {
Ok(Self::Claude3_7Sonnet) Ok(Self::Claude3_7Sonnet)
} else if id.starts_with("claude-3-5-haiku") { } else if id.starts_with("claude-3-5-haiku") {
@ -80,6 +99,20 @@ impl Model {
match self { match self {
Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest", Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
Model::Claude3_7Sonnet => "claude-3-7-sonnet-latest", Model::Claude3_7Sonnet => "claude-3-7-sonnet-latest",
Model::Claude3_7SonnetThinking => "claude-3-7-sonnet-thinking-latest",
Model::Claude3_5Haiku => "claude-3-5-haiku-latest",
Model::Claude3Opus => "claude-3-opus-latest",
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
Model::Claude3Haiku => "claude-3-haiku-20240307",
Self::Custom { name, .. } => name,
}
}
/// The id of the model that should be used for making API requests
pub fn request_id(&self) -> &str {
match self {
Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => "claude-3-7-sonnet-latest",
Model::Claude3_5Haiku => "claude-3-5-haiku-latest", Model::Claude3_5Haiku => "claude-3-5-haiku-latest",
Model::Claude3Opus => "claude-3-opus-latest", Model::Claude3Opus => "claude-3-opus-latest",
Model::Claude3Sonnet => "claude-3-sonnet-20240229", Model::Claude3Sonnet => "claude-3-sonnet-20240229",
@ -92,6 +125,7 @@ impl Model {
match self { match self {
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet", Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
Self::Claude3_5Haiku => "Claude 3.5 Haiku", Self::Claude3_5Haiku => "Claude 3.5 Haiku",
Self::Claude3Opus => "Claude 3 Opus", Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet", Self::Claude3Sonnet => "Claude 3 Sonnet",
@ -107,6 +141,7 @@ impl Model {
Self::Claude3_5Sonnet Self::Claude3_5Sonnet
| Self::Claude3_5Haiku | Self::Claude3_5Haiku
| Self::Claude3_7Sonnet | Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking
| Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration { | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
min_total_token: 2_048, min_total_token: 2_048,
should_speculate: true, should_speculate: true,
@ -125,6 +160,7 @@ impl Model {
Self::Claude3_5Sonnet Self::Claude3_5Sonnet
| Self::Claude3_5Haiku | Self::Claude3_5Haiku
| Self::Claude3_7Sonnet | Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking
| Self::Claude3Opus | Self::Claude3Opus
| Self::Claude3Sonnet | Self::Claude3Sonnet
| Self::Claude3Haiku => 200_000, | Self::Claude3Haiku => 200_000,
@ -135,7 +171,10 @@ impl Model {
pub fn max_output_tokens(&self) -> u32 { pub fn max_output_tokens(&self) -> u32 {
match self { match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096, Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
Self::Claude3_5Sonnet | Self::Claude3_7Sonnet | Self::Claude3_5Haiku => 8_192, Self::Claude3_5Sonnet
| Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking
| Self::Claude3_5Haiku => 8_192,
Self::Custom { Self::Custom {
max_output_tokens, .. max_output_tokens, ..
} => max_output_tokens.unwrap_or(4_096), } => max_output_tokens.unwrap_or(4_096),
@ -146,6 +185,7 @@ impl Model {
match self { match self {
Self::Claude3_5Sonnet Self::Claude3_5Sonnet
| Self::Claude3_7Sonnet | Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking
| Self::Claude3_5Haiku | Self::Claude3_5Haiku
| Self::Claude3Opus | Self::Claude3Opus
| Self::Claude3Sonnet | Self::Claude3Sonnet
@ -157,6 +197,21 @@ impl Model {
} }
} }
/// Returns the [`AnthropicModelMode`] this model runs in.
///
/// Custom models report a clone of their configured `mode`; the built-in
/// thinking variant reports `Thinking` with a fixed budget; every other
/// built-in model reports `Default`.
pub fn mode(&self) -> AnthropicModelMode {
    // Thinking budget used by the built-in `Claude3_7SonnetThinking` variant.
    const BUILT_IN_THINKING_BUDGET: u32 = 4_096;

    match self {
        Self::Custom { mode, .. } => mode.clone(),
        Self::Claude3_7SonnetThinking => AnthropicModelMode::Thinking {
            budget_tokens: Some(BUILT_IN_THINKING_BUDGET),
        },
        // Listed explicitly (no wildcard) so that adding a model variant
        // forces a decision here.
        Self::Claude3Haiku
        | Self::Claude3Sonnet
        | Self::Claude3Opus
        | Self::Claude3_5Haiku
        | Self::Claude3_5Sonnet
        | Self::Claude3_7Sonnet => AnthropicModelMode::Default,
    }
}
pub const DEFAULT_BETA_HEADERS: &[&str] = &["prompt-caching-2024-07-31"]; pub const DEFAULT_BETA_HEADERS: &[&str] = &["prompt-caching-2024-07-31"];
pub fn beta_headers(&self) -> String { pub fn beta_headers(&self) -> String {
@ -188,7 +243,7 @@ impl Model {
{ {
tool_override tool_override
} else { } else {
self.id() self.request_id()
} }
} }
} }
@ -409,6 +464,8 @@ pub async fn extract_tool_args_from_events(
Err(error) => Some(Err(error)), Err(error) => Some(Err(error)),
Ok(Event::ContentBlockDelta { index, delta }) => match delta { Ok(Event::ContentBlockDelta { index, delta }) => match delta {
ContentDelta::TextDelta { .. } => None, ContentDelta::TextDelta { .. } => None,
ContentDelta::ThinkingDelta { .. } => None,
ContentDelta::SignatureDelta { .. } => None,
ContentDelta::InputJsonDelta { partial_json } => { ContentDelta::InputJsonDelta { partial_json } => {
if index == tool_use_index { if index == tool_use_index {
Some(Ok(partial_json)) Some(Ok(partial_json))
@ -487,6 +544,10 @@ pub enum RequestContent {
pub enum ResponseContent { pub enum ResponseContent {
#[serde(rename = "text")] #[serde(rename = "text")]
Text { text: String }, Text { text: String },
#[serde(rename = "thinking")]
Thinking { thinking: String },
#[serde(rename = "redacted_thinking")]
RedactedThinking { data: String },
#[serde(rename = "tool_use")] #[serde(rename = "tool_use")]
ToolUse { ToolUse {
id: String, id: String,
@ -518,6 +579,12 @@ pub enum ToolChoice {
Tool { name: String }, Tool { name: String },
} }
/// Value of the optional `thinking` field of a [`Request`].
///
/// Serialized as an internally tagged object: the variant name is written,
/// lowercased, under the `"type"` key (so `Enabled` becomes `"type": "enabled"`),
/// next to the variant's own fields.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Thinking {
/// Thinking is enabled, with an optional token budget.
Enabled { budget_tokens: Option<u32> },
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Request { pub struct Request {
pub model: String, pub model: String,
@ -526,6 +593,8 @@ pub struct Request {
#[serde(default, skip_serializing_if = "Vec::is_empty")] #[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Tool>, pub tools: Vec<Tool>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub thinking: Option<Thinking>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>, pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub system: Option<String>, pub system: Option<String>,
@ -609,6 +678,10 @@ pub enum Event {
pub enum ContentDelta { pub enum ContentDelta {
#[serde(rename = "text_delta")] #[serde(rename = "text_delta")]
TextDelta { text: String }, TextDelta { text: String },
#[serde(rename = "thinking_delta")]
ThinkingDelta { thinking: String },
#[serde(rename = "signature_delta")]
SignatureDelta { signature: String },
#[serde(rename = "input_json_delta")] #[serde(rename = "input_json_delta")]
InputJsonDelta { partial_json: String }, InputJsonDelta { partial_json: String },
} }

View file

@ -1,5 +1,6 @@
use crate::thread::{ use crate::thread::{
LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ThreadFeedback, LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
ThreadEvent, ThreadFeedback,
}; };
use crate::thread_store::ThreadStore; use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus}; use crate::tool_use::{ToolUse, ToolUseStatus};
@ -7,10 +8,10 @@ use crate::ui::ContextPill;
use collections::HashMap; use collections::HashMap;
use editor::{Editor, MultiBuffer}; use editor::{Editor, MultiBuffer};
use gpui::{ use gpui::{
list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App, linear_color_stop, linear_gradient, list, percentage, pulsating_between, AbsoluteLength,
ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, Animation, AnimationExt, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Entity, Focusable, Length, ListAlignment, ListOffset, ListState, ScrollHandle, StyleRefinement,
Transformation, UnderlineStyle, WeakEntity, Subscription, Task, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity,
}; };
use language::{Buffer, LanguageRegistry}; use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@ -35,15 +36,175 @@ pub struct ActiveThread {
save_thread_task: Option<Task<()>>, save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>, messages: Vec<MessageId>,
list_state: ListState, list_state: ListState,
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>, rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>, rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
rendered_tool_use_labels: HashMap<LanguageModelToolUseId, Entity<Markdown>>, rendered_tool_use_labels: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>, editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>, expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
last_error: Option<ThreadError>, last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
/// The rendered form of a single message: one rendered entry per message segment.
struct RenderedMessage {
// Handed to `render_markdown` every time a new segment is rendered.
language_registry: Arc<LanguageRegistry>,
// Rendered segments, in the order they were pushed or appended.
segments: Vec<RenderedMessageSegment>,
}
impl RenderedMessage {
    /// Builds a rendered message by rendering each entry of `segments` in order.
    fn from_segments(
        segments: &[MessageSegment],
        language_registry: Arc<LanguageRegistry>,
        window: &Window,
        cx: &mut App,
    ) -> Self {
        let mut rendered = Self {
            language_registry,
            segments: Vec::with_capacity(segments.len()),
        };
        segments
            .iter()
            .for_each(|segment| rendered.push_segment(segment, window, cx));
        rendered
    }

    /// Appends streamed thinking output.
    ///
    /// If the last segment is already a thinking segment, the text is appended
    /// to its markdown and its scroll handle is moved to the bottom; otherwise
    /// a new thinking segment is started with `text` as its initial content.
    // NOTE(review): `text` is `&String` rather than `&str`, presumably because the
    // conversion into `SharedString` below needs it — confirm before changing.
    fn append_thinking(&mut self, text: &String, window: &Window, cx: &mut App) {
        match self.segments.last_mut() {
            Some(RenderedMessageSegment::Thinking {
                content,
                scroll_handle,
            }) => {
                content.update(cx, |markdown, cx| {
                    markdown.append(text, cx);
                });
                scroll_handle.scroll_to_bottom();
            }
            _ => {
                let content =
                    render_markdown(text.into(), self.language_registry.clone(), window, cx);
                self.segments.push(RenderedMessageSegment::Thinking {
                    content,
                    scroll_handle: ScrollHandle::default(),
                });
            }
        }
    }

    /// Appends streamed text output.
    ///
    /// Extends the trailing text segment when there is one; otherwise starts a
    /// new text segment with `text` as its initial content.
    fn append_text(&mut self, text: &String, window: &Window, cx: &mut App) {
        match self.segments.last_mut() {
            Some(RenderedMessageSegment::Text(markdown)) => {
                markdown.update(cx, |markdown, cx| markdown.append(text, cx));
            }
            _ => {
                let markdown = render_markdown(
                    SharedString::from(text),
                    self.language_registry.clone(),
                    window,
                    cx,
                );
                self.segments.push(RenderedMessageSegment::Text(markdown));
            }
        }
    }

    /// Renders `segment` and pushes it as a new entry, never merging it with
    /// the previous one.
    fn push_segment(&mut self, segment: &MessageSegment, window: &Window, cx: &mut App) {
        let registry = self.language_registry.clone();
        let rendered = match segment {
            MessageSegment::Text(text) => {
                RenderedMessageSegment::Text(render_markdown(text.into(), registry, window, cx))
            }
            MessageSegment::Thinking(text) => RenderedMessageSegment::Thinking {
                content: render_markdown(text.into(), registry, window, cx),
                scroll_handle: ScrollHandle::default(),
            },
        };
        self.segments.push(rendered);
    }
}
/// A rendered piece of a message; the variants mirror `MessageSegment::Thinking`
/// and `MessageSegment::Text`.
enum RenderedMessageSegment {
/// Rendered thinking output.
Thinking {
// Markdown entity holding the thinking text.
content: Entity<Markdown>,
// Scroll handle for the thinking content; `RenderedMessage::append_thinking`
// scrolls it to the bottom whenever more text is appended.
scroll_handle: ScrollHandle,
},
/// Rendered regular text.
Text(Entity<Markdown>),
}
fn render_markdown(
text: SharedString,
language_registry: Arc<LanguageRegistry>,
window: &Window,
cx: &mut App,
) -> Entity<Markdown> {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = TextSize::Small.rems(cx);
let mut text_style = window.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
code_block: StyleRefinement {
margin: EdgesRefinement {
top: Some(Length::Definite(rems(0.).into())),
left: Some(Length::Definite(rems(0.).into())),
right: Some(Length::Definite(rems(0.).into())),
bottom: Some(Length::Definite(rems(0.5).into())),
},
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
},
background: Some(colors.editor_background.into()),
border_color: Some(colors.border_variant),
border_widths: EdgesRefinement {
top: Some(AbsoluteLength::Pixels(Pixels(1.))),
left: Some(AbsoluteLength::Pixels(Pixels(1.))),
right: Some(AbsoluteLength::Pixels(Pixels(1.))),
bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
},
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
background_color: Some(colors.editor_foreground.opacity(0.1)),
..Default::default()
},
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {
color: Some(colors.text_accent.opacity(0.5)),
thickness: px(1.),
..Default::default()
}),
..Default::default()
},
..Default::default()
};
cx.new(|cx| Markdown::new(text, markdown_style, Some(language_registry), None, cx))
}
struct EditMessageState { struct EditMessageState {
editor: Entity<Editor>, editor: Entity<Editor>,
} }
@ -75,6 +236,7 @@ impl ActiveThread {
rendered_scripting_tool_uses: HashMap::default(), rendered_scripting_tool_uses: HashMap::default(),
rendered_tool_use_labels: HashMap::default(), rendered_tool_use_labels: HashMap::default(),
expanded_tool_uses: HashMap::default(), expanded_tool_uses: HashMap::default(),
expanded_thinking_segments: HashMap::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade(); let this = cx.entity().downgrade();
move |ix, window: &mut Window, cx: &mut App| { move |ix, window: &mut Window, cx: &mut App| {
@ -88,7 +250,7 @@ impl ActiveThread {
}; };
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() { for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
this.push_message(&message.id, message.text.clone(), window, cx); this.push_message(&message.id, &message.segments, window, cx);
for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) { for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
this.render_tool_use_label_markdown( this.render_tool_use_label_markdown(
@ -156,7 +318,7 @@ impl ActiveThread {
fn push_message( fn push_message(
&mut self, &mut self,
id: &MessageId, id: &MessageId,
text: String, segments: &[MessageSegment],
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
@ -164,8 +326,9 @@ impl ActiveThread {
self.messages.push(*id); self.messages.push(*id);
self.list_state.splice(old_len..old_len, 1); self.list_state.splice(old_len..old_len, 1);
let markdown = self.render_markdown(text.into(), window, cx); let rendered_message =
self.rendered_messages_by_id.insert(*id, markdown); RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx);
self.rendered_messages_by_id.insert(*id, rendered_message);
self.list_state.scroll_to(ListOffset { self.list_state.scroll_to(ListOffset {
item_ix: old_len, item_ix: old_len,
offset_in_item: Pixels(0.0), offset_in_item: Pixels(0.0),
@ -175,7 +338,7 @@ impl ActiveThread {
fn edited_message( fn edited_message(
&mut self, &mut self,
id: &MessageId, id: &MessageId,
text: String, segments: &[MessageSegment],
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
@ -183,8 +346,9 @@ impl ActiveThread {
return; return;
}; };
self.list_state.splice(index..index + 1, 1); self.list_state.splice(index..index + 1, 1);
let markdown = self.render_markdown(text.into(), window, cx); let rendered_message =
self.rendered_messages_by_id.insert(*id, markdown); RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx);
self.rendered_messages_by_id.insert(*id, rendered_message);
} }
fn deleted_message(&mut self, id: &MessageId) { fn deleted_message(&mut self, id: &MessageId) {
@ -196,94 +360,6 @@ impl ActiveThread {
self.rendered_messages_by_id.remove(id); self.rendered_messages_by_id.remove(id);
} }
fn render_markdown(
&self,
text: SharedString,
window: &Window,
cx: &mut Context<Self>,
) -> Entity<Markdown> {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = TextSize::Small.rems(cx);
let mut text_style = window.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
code_block: StyleRefinement {
margin: EdgesRefinement {
top: Some(Length::Definite(rems(0.).into())),
left: Some(Length::Definite(rems(0.).into())),
right: Some(Length::Definite(rems(0.).into())),
bottom: Some(Length::Definite(rems(0.5).into())),
},
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
},
background: Some(colors.editor_background.into()),
border_color: Some(colors.border_variant),
border_widths: EdgesRefinement {
top: Some(AbsoluteLength::Pixels(Pixels(1.))),
left: Some(AbsoluteLength::Pixels(Pixels(1.))),
right: Some(AbsoluteLength::Pixels(Pixels(1.))),
bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
},
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
background_color: Some(colors.editor_foreground.opacity(0.1)),
..Default::default()
},
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {
color: Some(colors.text_accent.opacity(0.5)),
thickness: px(1.),
..Default::default()
}),
..Default::default()
},
..Default::default()
};
cx.new(|cx| {
Markdown::new(
text,
markdown_style,
Some(self.language_registry.clone()),
None,
cx,
)
})
}
/// Renders the input of a scripting tool use to Markdown. /// Renders the input of a scripting tool use to Markdown.
/// ///
/// Does nothing if the tool use does not correspond to the scripting tool. /// Does nothing if the tool use does not correspond to the scripting tool.
@ -303,8 +379,12 @@ impl ActiveThread {
.map(|input| input.lua_script) .map(|input| input.lua_script)
.unwrap_or_default(); .unwrap_or_default();
let lua_script = let lua_script = render_markdown(
self.render_markdown(format!("```lua\n{lua_script}\n```").into(), window, cx); format!("```lua\n{lua_script}\n```").into(),
self.language_registry.clone(),
window,
cx,
);
self.rendered_scripting_tool_uses self.rendered_scripting_tool_uses
.insert(tool_use_id, lua_script); .insert(tool_use_id, lua_script);
@ -319,7 +399,12 @@ impl ActiveThread {
) { ) {
self.rendered_tool_use_labels.insert( self.rendered_tool_use_labels.insert(
tool_use_id, tool_use_id,
self.render_markdown(tool_label.into(), window, cx), render_markdown(
tool_label.into(),
self.language_registry.clone(),
window,
cx,
),
); );
} }
@ -339,33 +424,36 @@ impl ActiveThread {
} }
ThreadEvent::DoneStreaming => {} ThreadEvent::DoneStreaming => {}
ThreadEvent::StreamedAssistantText(message_id, text) => { ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) { if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
markdown.update(cx, |markdown, cx| { rendered_message.append_text(text, window, cx);
markdown.append(text, cx); }
}); }
ThreadEvent::StreamedAssistantThinking(message_id, text) => {
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
rendered_message.append_thinking(text, window, cx);
} }
} }
ThreadEvent::MessageAdded(message_id) => { ThreadEvent::MessageAdded(message_id) => {
if let Some(message_text) = self if let Some(message_segments) = self
.thread .thread
.read(cx) .read(cx)
.message(*message_id) .message(*message_id)
.map(|message| message.text.clone()) .map(|message| message.segments.clone())
{ {
self.push_message(message_id, message_text, window, cx); self.push_message(message_id, &message_segments, window, cx);
} }
self.save_thread(cx); self.save_thread(cx);
cx.notify(); cx.notify();
} }
ThreadEvent::MessageEdited(message_id) => { ThreadEvent::MessageEdited(message_id) => {
if let Some(message_text) = self if let Some(message_segments) = self
.thread .thread
.read(cx) .read(cx)
.message(*message_id) .message(*message_id)
.map(|message| message.text.clone()) .map(|message| message.segments.clone())
{ {
self.edited_message(message_id, message_text, window, cx); self.edited_message(message_id, &message_segments, window, cx);
} }
self.save_thread(cx); self.save_thread(cx);
@ -490,10 +578,16 @@ impl ActiveThread {
fn start_editing_message( fn start_editing_message(
&mut self, &mut self,
message_id: MessageId, message_id: MessageId,
message_text: String, message_segments: &[MessageSegment],
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
// User message should always consist of a single text segment,
// therefore we can simply return early if it's not a text segment.
let Some(MessageSegment::Text(message_text)) = message_segments.first() else {
return;
};
let buffer = cx.new(|cx| { let buffer = cx.new(|cx| {
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx) MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
}); });
@ -534,7 +628,12 @@ impl ActiveThread {
}; };
let edited_text = state.editor.read(cx).text(cx); let edited_text = state.editor.read(cx).text(cx);
self.thread.update(cx, |thread, cx| { self.thread.update(cx, |thread, cx| {
thread.edit_message(message_id, Role::User, edited_text, cx); thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
cx,
);
for message_id in self.messages_after(message_id) { for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx); thread.delete_message(*message_id, cx);
} }
@ -617,7 +716,7 @@ impl ActiveThread {
return Empty.into_any(); return Empty.into_any();
}; };
let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else { let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any(); return Empty.into_any();
}; };
@ -759,7 +858,10 @@ impl ActiveThread {
.min_h_6() .min_h_6()
.child(edit_message_editor) .child(edit_message_editor)
} else { } else {
div().min_h_6().text_ui(cx).child(markdown.clone()) div()
.min_h_6()
.text_ui(cx)
.child(self.render_message_content(message_id, rendered_message, cx))
}, },
) )
.when_some(context, |parent, context| { .when_some(context, |parent, context| {
@ -869,11 +971,12 @@ impl ActiveThread {
Button::new("edit-message", "Edit") Button::new("edit-message", "Edit")
.label_size(LabelSize::Small) .label_size(LabelSize::Small)
.on_click(cx.listener({ .on_click(cx.listener({
let message_text = message.text.clone(); let message_segments =
message.segments.clone();
move |this, _, window, cx| { move |this, _, window, cx| {
this.start_editing_message( this.start_editing_message(
message_id, message_id,
message_text.clone(), &message_segments,
window, window,
cx, cx,
); );
@ -995,6 +1098,190 @@ impl ActiveThread {
.into_any() .into_any()
} }
/// Renders all segments of `rendered_message` as children of a single container.
///
/// Text segments are rendered as padded markdown. Thinking segments are
/// delegated to `render_message_thinking_segment`; only a thinking segment
/// that is the very last segment of the message is flagged as pending.
fn render_message_content(
    &self,
    message_id: MessageId,
    rendered_message: &RenderedMessage,
    cx: &Context<Self>,
) -> impl IntoElement {
    let segments = &rendered_message.segments;

    // Index of the trailing segment, but only when that segment is a thinking one.
    let pending_thinking_segment_index = match segments.last() {
        Some(RenderedMessageSegment::Thinking { .. }) => Some(segments.len() - 1),
        _ => None,
    };

    let rendered_segments = segments
        .iter()
        .enumerate()
        .map(|(index, segment)| match segment {
            RenderedMessageSegment::Text(markdown) => {
                div().p_2p5().child(markdown.clone()).into_any_element()
            }
            RenderedMessageSegment::Thinking {
                content,
                scroll_handle,
            } => {
                let is_pending = pending_thinking_segment_index == Some(index);
                self.render_message_thinking_segment(
                    message_id,
                    index,
                    content.clone(),
                    scroll_handle,
                    is_pending,
                    cx,
                )
                .into_any_element()
            }
        });

    div().text_ui(cx).gap_2().children(rendered_segments)
}
/// Renders one thinking segment as a bordered card with a header and an
/// optional content area.
///
/// * `message_id`, `ix` — together they key the expanded/collapsed state in
///   `expanded_thinking_segments`; `ix` also keys the element ids.
/// * `markdown` — the rendered thinking text.
/// * `scroll_handle` — tracks scrolling of the collapsed preview.
/// * `pending` — whether this segment is still being streamed; it selects the
///   animated "Thinking…" header over the static "Thought Process" one.
///
/// Content visibility:
/// * expanded: the full content is shown;
/// * collapsed and pending: a fixed-height preview with a gradient overlay;
/// * collapsed and not pending: only the header.
fn render_message_thinking_segment(
    &self,
    message_id: MessageId,
    ix: usize,
    markdown: Entity<Markdown>,
    scroll_handle: &ScrollHandle,
    pending: bool,
    cx: &Context<Self>,
) -> impl IntoElement {
    // Segments that were never toggled are collapsed.
    let is_open = self
        .expanded_thinking_segments
        .get(&(message_id, ix))
        .copied()
        .unwrap_or_default();

    let lighter_border = cx.theme().colors().border.opacity(0.5);
    let editor_bg = cx.theme().colors().editor_background;

    v_flex()
        .rounded_lg()
        .border_1()
        .border_color(lighter_border)
        .child(
            h_flex()
                .justify_between()
                .py_1()
                .pl_1()
                .pr_2()
                .bg(cx.theme().colors().editor_foreground.opacity(0.025))
                .map(|this| {
                    if is_open {
                        this.rounded_t_md()
                            .border_b_1()
                            .border_color(lighter_border)
                    } else {
                        this.rounded_md()
                    }
                })
                .child(
                    h_flex()
                        .gap_1()
                        .child(
                            // Keyed by `ix`, like the content ids below, so several
                            // thinking segments in one message do not share an
                            // element id.
                            Disclosure::new(("thinking-disclosure", ix), is_open).on_click(
                                cx.listener(move |this, _event, _window, cx| {
                                    let is_open = this
                                        .expanded_thinking_segments
                                        .entry((message_id, ix))
                                        .or_insert(false);

                                    *is_open = !*is_open;
                                    // The toggled state changes what is rendered, so
                                    // request a re-render explicitly.
                                    cx.notify();
                                }),
                            ),
                        )
                        .child({
                            if pending {
                                Label::new("Thinking…")
                                    .size(LabelSize::Small)
                                    .buffer_font(cx)
                                    .with_animation(
                                        "pulsating-label",
                                        Animation::new(Duration::from_secs(2))
                                            .repeat()
                                            .with_easing(pulsating_between(0.4, 0.8)),
                                        |label, delta| label.alpha(delta),
                                    )
                                    .into_any_element()
                            } else {
                                Label::new("Thought Process")
                                    .size(LabelSize::Small)
                                    .buffer_font(cx)
                                    .into_any_element()
                            }
                        }),
                )
                .child({
                    // Spinning arrow while streaming, check mark once done.
                    let (icon_name, color, animated) = if pending {
                        (IconName::ArrowCircle, Color::Accent, true)
                    } else {
                        (IconName::Check, Color::Success, false)
                    };

                    let icon = Icon::new(icon_name).color(color).size(IconSize::Small);

                    if animated {
                        icon.with_animation(
                            "arrow-circle",
                            Animation::new(Duration::from_secs(2)).repeat(),
                            |icon, delta| {
                                icon.transform(Transformation::rotate(percentage(delta)))
                            },
                        )
                        .into_any_element()
                    } else {
                        icon.into_any_element()
                    }
                }),
        )
        .when(pending && !is_open, |this| {
            // Collapsed while streaming: fixed-height preview with a gradient
            // overlay.
            let gradient_overlay = div()
                .rounded_b_lg()
                .h_20()
                .absolute()
                .w_full()
                .bottom_0()
                .left_0()
                .bg(linear_gradient(
                    180.,
                    linear_color_stop(editor_bg, 1.),
                    linear_color_stop(editor_bg.opacity(0.2), 0.),
                ));

            this.child(
                div()
                    .relative()
                    .bg(editor_bg)
                    .rounded_b_lg()
                    .text_ui_sm(cx)
                    .child(
                        div()
                            .id(("thinking-content", ix))
                            .p_2()
                            .h_20()
                            .track_scroll(scroll_handle)
                            .child(markdown.clone())
                            .overflow_hidden(),
                    )
                    .child(gradient_overlay),
            )
        })
        .when(is_open, |this| {
            // Expanded: show the full content.
            this.child(
                div()
                    .id(("thinking-content", ix))
                    .h_full()
                    .p_2()
                    .rounded_b_lg()
                    .bg(editor_bg)
                    .text_ui_sm(cx)
                    .child(markdown.clone()),
            )
        })
}
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement { fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
let is_open = self let is_open = self
.expanded_tool_uses .expanded_tool_uses
@ -1258,8 +1545,9 @@ impl ActiveThread {
} }
}), }),
)) ))
.child(div().text_ui_sm(cx).child(self.render_markdown( .child(div().text_ui_sm(cx).child(render_markdown(
tool_use.ui_text.clone(), tool_use.ui_text.clone(),
self.language_registry.clone(),
window, window,
cx, cx,
))) )))

View file

@ -29,7 +29,8 @@ use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::{ use crate::thread_store::{
SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse, SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
}; };
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState}; use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
@ -69,7 +70,47 @@ impl MessageId {
pub struct Message { pub struct Message {
pub id: MessageId, pub id: MessageId,
pub role: Role, pub role: Role,
pub text: String, pub segments: Vec<MessageSegment>,
}
impl Message {
    /// Appends a chunk of model reasoning to this message.
    ///
    /// The chunk is merged into the trailing segment when that segment is already a
    /// `Thinking` segment; otherwise a new `Thinking` segment is started.
    pub fn push_thinking(&mut self, text: &str) {
        match self.segments.last_mut() {
            Some(MessageSegment::Thinking(existing)) => existing.push_str(text),
            _ => self
                .segments
                .push(MessageSegment::Thinking(text.to_string())),
        }
    }

    /// Appends a chunk of plain text to this message.
    ///
    /// The chunk is merged into the trailing segment when that segment is already a
    /// `Text` segment; otherwise a new `Text` segment is started.
    pub fn push_text(&mut self, text: &str) {
        match self.segments.last_mut() {
            Some(MessageSegment::Text(existing)) => existing.push_str(text),
            _ => self.segments.push(MessageSegment::Text(text.to_string())),
        }
    }

    /// Flattens all segments, in order, into a single string.
    ///
    /// `Text` segments are copied verbatim; `Thinking` segments are wrapped in
    /// `<think>`…`</think>` tags. No separator is inserted between segments.
    pub fn to_string(&self) -> String {
        self.segments
            .iter()
            .fold(String::new(), |mut rendered, segment| {
                match segment {
                    MessageSegment::Text(content) => rendered.push_str(content),
                    MessageSegment::Thinking(content) => {
                        rendered.push_str("<think>");
                        rendered.push_str(content);
                        rendered.push_str("</think>");
                    }
                }
                rendered
            })
    }
}
/// One contiguous piece of a `Message`'s content.
///
/// Consecutive chunks of the same kind are merged into a single segment by
/// `Message::push_text` / `Message::push_thinking`.
#[derive(Debug, Clone)]
pub enum MessageSegment {
/// Plain message text; rendered verbatim by `Message::to_string`.
Text(String),
/// Model reasoning; `Message::to_string` wraps it in `<think>`…`</think>` tags.
Thinking(String),
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -226,7 +267,16 @@ impl Thread {
.map(|message| Message { .map(|message| Message {
id: message.id, id: message.id,
role: message.role, role: message.role,
text: message.text, segments: message
.segments
.into_iter()
.map(|segment| match segment {
SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
SerializedMessageSegment::Thinking { text } => {
MessageSegment::Thinking(text)
}
})
.collect(),
}) })
.collect(), .collect(),
next_message_id, next_message_id,
@ -419,7 +469,8 @@ impl Thread {
checkpoint: Option<GitStoreCheckpoint>, checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> MessageId { ) -> MessageId {
let message_id = self.insert_message(Role::User, text, cx); let message_id =
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>(); let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
self.context self.context
.extend(context.into_iter().map(|context| (context.id, context))); .extend(context.into_iter().map(|context| (context.id, context)));
@ -433,15 +484,11 @@ impl Thread {
pub fn insert_message( pub fn insert_message(
&mut self, &mut self,
role: Role, role: Role,
text: impl Into<String>, segments: Vec<MessageSegment>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> MessageId { ) -> MessageId {
let id = self.next_message_id.post_inc(); let id = self.next_message_id.post_inc();
self.messages.push(Message { self.messages.push(Message { id, role, segments });
id,
role,
text: text.into(),
});
self.touch_updated_at(); self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id)); cx.emit(ThreadEvent::MessageAdded(id));
id id
@ -451,14 +498,14 @@ impl Thread {
&mut self, &mut self,
id: MessageId, id: MessageId,
new_role: Role, new_role: Role,
new_text: String, new_segments: Vec<MessageSegment>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> bool { ) -> bool {
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
return false; return false;
}; };
message.role = new_role; message.role = new_role;
message.text = new_text; message.segments = new_segments;
self.touch_updated_at(); self.touch_updated_at();
cx.emit(ThreadEvent::MessageEdited(id)); cx.emit(ThreadEvent::MessageEdited(id));
true true
@ -489,7 +536,14 @@ impl Thread {
}); });
text.push('\n'); text.push('\n');
text.push_str(&message.text); for segment in &message.segments {
match segment {
MessageSegment::Text(content) => text.push_str(content),
MessageSegment::Thinking(content) => {
text.push_str(&format!("<think>{}</think>", content))
}
}
}
text.push('\n'); text.push('\n');
} }
@ -502,6 +556,7 @@ impl Thread {
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let initial_project_snapshot = initial_project_snapshot.await; let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(cx, |this, cx| SerializedThread { this.read_with(cx, |this, cx| SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: this.summary_or_default(), summary: this.summary_or_default(),
updated_at: this.updated_at(), updated_at: this.updated_at(),
messages: this messages: this
@ -509,7 +564,18 @@ impl Thread {
.map(|message| SerializedMessage { .map(|message| SerializedMessage {
id: message.id, id: message.id,
role: message.role, role: message.role,
text: message.text.clone(), segments: message
.segments
.iter()
.map(|segment| match segment {
MessageSegment::Text(text) => {
SerializedMessageSegment::Text { text: text.clone() }
}
MessageSegment::Thinking(text) => {
SerializedMessageSegment::Thinking { text: text.clone() }
}
})
.collect(),
tool_uses: this tool_uses: this
.tool_uses_for_message(message.id, cx) .tool_uses_for_message(message.id, cx)
.into_iter() .into_iter()
@ -733,10 +799,10 @@ impl Thread {
} }
} }
if !message.text.is_empty() { if !message.segments.is_empty() {
request_message request_message
.content .content
.push(MessageContent::Text(message.text.clone())); .push(MessageContent::Text(message.to_string()));
} }
match request_kind { match request_kind {
@ -826,7 +892,11 @@ impl Thread {
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
match event { match event {
LanguageModelCompletionEvent::StartMessage { .. } => { LanguageModelCompletionEvent::StartMessage { .. } => {
thread.insert_message(Role::Assistant, String::new(), cx); thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text(String::new())],
cx,
);
} }
LanguageModelCompletionEvent::Stop(reason) => { LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason; stop_reason = reason;
@ -840,7 +910,7 @@ impl Thread {
LanguageModelCompletionEvent::Text(chunk) => { LanguageModelCompletionEvent::Text(chunk) => {
if let Some(last_message) = thread.messages.last_mut() { if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant { if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk); last_message.push_text(&chunk);
cx.emit(ThreadEvent::StreamedAssistantText( cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id, last_message.id,
chunk, chunk,
@ -851,7 +921,33 @@ impl Thread {
// //
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown. // will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(Role::Assistant, chunk, cx); thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text(chunk.to_string())],
cx,
);
};
}
}
LanguageModelCompletionEvent::Thinking(chunk) => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.push_thinking(&chunk);
cx.emit(ThreadEvent::StreamedAssistantThinking(
last_message.id,
chunk,
));
} else {
// If we don't have an Assistant message yet, assume this chunk marks the beginning
// of a new Assistant response.
//
// Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Thinking(chunk.to_string())],
cx,
);
}; };
} }
} }
@ -1357,7 +1453,14 @@ impl Thread {
Role::System => "System", Role::System => "System",
} }
)?; )?;
writeln!(markdown, "{}\n", message.text)?; for segment in &message.segments {
match segment {
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
MessageSegment::Thinking(text) => {
writeln!(markdown, "<think>{}</think>\n", text)?
}
}
}
for tool_use in self.tool_uses_for_message(message.id, cx) { for tool_use in self.tool_uses_for_message(message.id, cx) {
writeln!( writeln!(
@ -1416,6 +1519,7 @@ pub enum ThreadEvent {
ShowError(ThreadError), ShowError(ThreadError),
StreamedCompletion, StreamedCompletion,
StreamedAssistantText(MessageId, String), StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
DoneStreaming, DoneStreaming,
MessageAdded(MessageId), MessageAdded(MessageId),
MessageEdited(MessageId), MessageEdited(MessageId),

View file

@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
@ -12,7 +13,7 @@ use futures::FutureExt as _;
use gpui::{ use gpui::{
prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
}; };
use heed::types::{SerdeBincode, SerdeJson}; use heed::types::SerdeBincode;
use heed::Database; use heed::Database;
use language_model::{LanguageModelToolUseId, Role}; use language_model::{LanguageModelToolUseId, Role};
use project::Project; use project::Project;
@ -259,6 +260,7 @@ pub struct SerializedThreadMetadata {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct SerializedThread { pub struct SerializedThread {
pub version: String,
pub summary: SharedString, pub summary: SharedString,
pub updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
pub messages: Vec<SerializedMessage>, pub messages: Vec<SerializedMessage>,
@ -266,17 +268,55 @@ pub struct SerializedThread {
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>, pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
} }
impl SerializedThread {
    /// Format version written into the `version` field of newly serialized threads.
    pub const VERSION: &'static str = "0.1.0";

    /// Parses a stored thread from JSON bytes, dispatching on its `version` field.
    ///
    /// * No `version` key: the payload is parsed as a `LegacySerializedThread` and
    ///   upgraded to the current format.
    /// * `version` is a string equal to [`Self::VERSION`]: the payload is parsed directly.
    /// * Any other `version` value (different string, or a non-string): an error is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the bytes are not valid JSON, if the version is not
    /// recognized, or if the payload does not match the selected format.
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let value: serde_json::Value = serde_json::from_slice(json)?;

        // Threads saved before versioning was introduced carry no `version` key.
        let Some(raw_version) = value.get("version") else {
            let legacy: LegacySerializedThread = serde_json::from_value(value)?;
            return Ok(legacy.upgrade());
        };

        let serde_json::Value::String(version) = raw_version else {
            return Err(anyhow!(
                "unrecognized serialized thread version: {:?}",
                Some(raw_version)
            ));
        };

        if version.as_str() != SerializedThread::VERSION {
            return Err(anyhow!(
                "unrecognized serialized thread version: {}",
                version
            ));
        }

        Ok(serde_json::from_value::<SerializedThread>(value)?)
    }
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage { pub struct SerializedMessage {
pub id: MessageId, pub id: MessageId,
pub role: Role, pub role: Role,
pub text: String, #[serde(default)]
pub segments: Vec<SerializedMessageSegment>,
#[serde(default)] #[serde(default)]
pub tool_uses: Vec<SerializedToolUse>, pub tool_uses: Vec<SerializedToolUse>,
#[serde(default)] #[serde(default)]
pub tool_results: Vec<SerializedToolResult>, pub tool_results: Vec<SerializedToolResult>,
} }
/// On-disk form of a message segment.
///
/// Serialized as an internally tagged object: the `type` field holds the variant
/// name (`"text"` or `"thinking"`) next to the variant's `text` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SerializedMessageSegment {
    /// Plain message text.
    Text { text: String },
    /// Model reasoning text.
    Thinking { text: String },
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse { pub struct SerializedToolUse {
pub id: LanguageModelToolUseId, pub id: LanguageModelToolUseId,
@ -291,6 +331,50 @@ pub struct SerializedToolResult {
pub content: Arc<str>, pub content: Arc<str>,
} }
/// Layout of a stored thread that has no `version` field.
///
/// `SerializedThread::from_json` parses payloads without a `version` key into this
/// type and converts them with `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
pub summary: SharedString,
pub updated_at: DateTime<Utc>,
// Messages in the legacy shape (a single `text` string instead of segments).
pub messages: Vec<LegacySerializedMessage>,
// Optional in stored data; defaults to `None` when absent.
#[serde(default)]
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
impl LegacySerializedThread {
pub fn upgrade(self) -> SerializedThread {
SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: self.summary,
updated_at: self.updated_at,
messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
initial_project_snapshot: self.initial_project_snapshot,
}
}
}
/// Layout of a stored message inside a `LegacySerializedThread`.
///
/// Holds its content as one `text` string; `upgrade` turns it into a
/// `SerializedMessage` with a single text segment.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
pub id: MessageId,
pub role: Role,
// Entire message content as one string.
pub text: String,
// Defaults to an empty list when absent from the stored data.
#[serde(default)]
pub tool_uses: Vec<SerializedToolUse>,
// Defaults to an empty list when absent from the stored data.
#[serde(default)]
pub tool_results: Vec<SerializedToolResult>,
}
impl LegacySerializedMessage {
    /// Converts this legacy message into the current `SerializedMessage` format.
    ///
    /// The legacy `text` becomes the sole `Text` segment (even when it is empty);
    /// `id`, `role`, `tool_uses` and `tool_results` carry over unchanged.
    fn upgrade(self) -> SerializedMessage {
        let Self {
            id,
            role,
            text,
            tool_uses,
            tool_results,
        } = self;
        let segments = vec![SerializedMessageSegment::Text { text }];

        SerializedMessage {
            id,
            role,
            segments,
            tool_uses,
            tool_results,
        }
    }
}
struct GlobalThreadsDatabase( struct GlobalThreadsDatabase(
Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>, Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
); );
@ -300,7 +384,25 @@ impl Global for GlobalThreadsDatabase {}
pub(crate) struct ThreadsDatabase { pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor, executor: BackgroundExecutor,
env: heed::Env, env: heed::Env,
threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>, threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
/// Encodes a `SerializedThread` as JSON bytes for storage in the threads database.
impl heed::BytesEncode<'_> for SerializedThread {
    type EItem = SerializedThread;

    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
        // `?` boxes the serde_json error into `heed::BoxedError`.
        let json = serde_json::to_vec(item)?;
        Ok(Cow::Owned(json))
    }
}
/// Decodes a stored thread from its JSON bytes.
///
/// Decoding goes through `SerializedThread::from_json`, which checks the `version`
/// field and upgrades payloads that have none.
impl<'a> heed::BytesDecode<'a> for SerializedThread {
type DItem = SerializedThread;
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
// We implement this trait manually because we want to call `SerializedThread::from_json`,
// instead of the Deserialize trait implementation for `SerializedThread`.
SerializedThread::from_json(bytes).map_err(Into::into)
}
} }
impl ThreadsDatabase { impl ThreadsDatabase {

View file

@ -162,6 +162,11 @@ pub enum ContextOperation {
section: SlashCommandOutputSection<language::Anchor>, section: SlashCommandOutputSection<language::Anchor>,
version: clock::Global, version: clock::Global,
}, },
ThoughtProcessOutputSectionAdded {
timestamp: clock::Lamport,
section: ThoughtProcessOutputSection<language::Anchor>,
version: clock::Global,
},
BufferOperation(language::Operation), BufferOperation(language::Operation),
} }
@ -259,6 +264,20 @@ impl ContextOperation {
version: language::proto::deserialize_version(&message.version), version: language::proto::deserialize_version(&message.version),
}) })
} }
proto::context_operation::Variant::ThoughtProcessOutputSectionAdded(message) => {
let section = message.section.context("missing section")?;
Ok(Self::ThoughtProcessOutputSectionAdded {
timestamp: language::proto::deserialize_timestamp(
message.timestamp.context("missing timestamp")?,
),
section: ThoughtProcessOutputSection {
range: language::proto::deserialize_anchor_range(
section.range.context("invalid range")?,
)?,
},
version: language::proto::deserialize_version(&message.version),
})
}
proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation( proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation(
language::proto::deserialize_operation( language::proto::deserialize_operation(
op.operation.context("invalid buffer operation")?, op.operation.context("invalid buffer operation")?,
@ -370,6 +389,27 @@ impl ContextOperation {
}, },
)), )),
}, },
Self::ThoughtProcessOutputSectionAdded {
timestamp,
section,
version,
} => proto::ContextOperation {
variant: Some(
proto::context_operation::Variant::ThoughtProcessOutputSectionAdded(
proto::context_operation::ThoughtProcessOutputSectionAdded {
timestamp: Some(language::proto::serialize_timestamp(*timestamp)),
section: Some({
proto::ThoughtProcessOutputSection {
range: Some(language::proto::serialize_anchor_range(
section.range.clone(),
)),
}
}),
version: language::proto::serialize_version(version),
},
),
),
},
Self::BufferOperation(operation) => proto::ContextOperation { Self::BufferOperation(operation) => proto::ContextOperation {
variant: Some(proto::context_operation::Variant::BufferOperation( variant: Some(proto::context_operation::Variant::BufferOperation(
proto::context_operation::BufferOperation { proto::context_operation::BufferOperation {
@ -387,7 +427,8 @@ impl ContextOperation {
Self::UpdateSummary { summary, .. } => summary.timestamp, Self::UpdateSummary { summary, .. } => summary.timestamp,
Self::SlashCommandStarted { id, .. } => id.0, Self::SlashCommandStarted { id, .. } => id.0,
Self::SlashCommandOutputSectionAdded { timestamp, .. } Self::SlashCommandOutputSectionAdded { timestamp, .. }
| Self::SlashCommandFinished { timestamp, .. } => *timestamp, | Self::SlashCommandFinished { timestamp, .. }
| Self::ThoughtProcessOutputSectionAdded { timestamp, .. } => *timestamp,
Self::BufferOperation(_) => { Self::BufferOperation(_) => {
panic!("reading the timestamp of a buffer operation is not supported") panic!("reading the timestamp of a buffer operation is not supported")
} }
@ -402,7 +443,8 @@ impl ContextOperation {
| Self::UpdateSummary { version, .. } | Self::UpdateSummary { version, .. }
| Self::SlashCommandStarted { version, .. } | Self::SlashCommandStarted { version, .. }
| Self::SlashCommandOutputSectionAdded { version, .. } | Self::SlashCommandOutputSectionAdded { version, .. }
| Self::SlashCommandFinished { version, .. } => version, | Self::SlashCommandFinished { version, .. }
| Self::ThoughtProcessOutputSectionAdded { version, .. } => version,
Self::BufferOperation(_) => { Self::BufferOperation(_) => {
panic!("reading the version of a buffer operation is not supported") panic!("reading the version of a buffer operation is not supported")
} }
@ -418,6 +460,8 @@ pub enum ContextEvent {
MessagesEdited, MessagesEdited,
SummaryChanged, SummaryChanged,
StreamedCompletion, StreamedCompletion,
StartedThoughtProcess(Range<language::Anchor>),
EndedThoughtProcess(language::Anchor),
PatchesUpdated { PatchesUpdated {
removed: Vec<Range<language::Anchor>>, removed: Vec<Range<language::Anchor>>,
updated: Vec<Range<language::Anchor>>, updated: Vec<Range<language::Anchor>>,
@ -498,6 +542,17 @@ impl MessageMetadata {
} }
} }
/// A range of the context buffer occupied by a model's thought-process output.
///
/// `T` is the position type: `language::Anchor` for sections tracked on a live
/// context, `usize` offsets for sections stored in `SavedContext`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ThoughtProcessOutputSection<T> {
// Span of the section within the buffer.
pub range: Range<T>,
}
impl ThoughtProcessOutputSection<language::Anchor> {
    /// Returns `true` when the section's start anchor is valid in `buffer` and the
    /// section resolves to a non-empty offset range there.
    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
        // Only resolve offsets once the start anchor is known to be valid.
        if !self.range.start.is_valid(buffer) {
            return false;
        }
        let offsets = self.range.to_offset(buffer);
        !offsets.is_empty()
    }
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Message { pub struct Message {
pub offset_range: Range<usize>, pub offset_range: Range<usize>,
@ -580,6 +635,7 @@ pub struct AssistantContext {
edits_since_last_parse: language::Subscription, edits_since_last_parse: language::Subscription,
slash_commands: Arc<SlashCommandWorkingSet>, slash_commands: Arc<SlashCommandWorkingSet>,
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>, slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
thought_process_output_sections: Vec<ThoughtProcessOutputSection<language::Anchor>>,
message_anchors: Vec<MessageAnchor>, message_anchors: Vec<MessageAnchor>,
contents: Vec<Content>, contents: Vec<Content>,
messages_metadata: HashMap<MessageId, MessageMetadata>, messages_metadata: HashMap<MessageId, MessageMetadata>,
@ -682,6 +738,7 @@ impl AssistantContext {
parsed_slash_commands: Vec::new(), parsed_slash_commands: Vec::new(),
invoked_slash_commands: HashMap::default(), invoked_slash_commands: HashMap::default(),
slash_command_output_sections: Vec::new(), slash_command_output_sections: Vec::new(),
thought_process_output_sections: Vec::new(),
edits_since_last_parse: edits_since_last_slash_command_parse, edits_since_last_parse: edits_since_last_slash_command_parse,
summary: None, summary: None,
pending_summary: Task::ready(None), pending_summary: Task::ready(None),
@ -764,6 +821,18 @@ impl AssistantContext {
} }
}) })
.collect(), .collect(),
thought_process_output_sections: self
.thought_process_output_sections
.iter()
.filter_map(|section| {
if section.is_valid(buffer) {
let range = section.range.to_offset(buffer);
Some(ThoughtProcessOutputSection { range })
} else {
None
}
})
.collect(),
} }
} }
@ -957,6 +1026,16 @@ impl AssistantContext {
cx.emit(ContextEvent::SlashCommandOutputSectionAdded { section }); cx.emit(ContextEvent::SlashCommandOutputSectionAdded { section });
} }
} }
ContextOperation::ThoughtProcessOutputSectionAdded { section, .. } => {
let buffer = self.buffer.read(cx);
if let Err(ix) = self
.thought_process_output_sections
.binary_search_by(|probe| probe.range.cmp(&section.range, buffer))
{
self.thought_process_output_sections
.insert(ix, section.clone());
}
}
ContextOperation::SlashCommandFinished { ContextOperation::SlashCommandFinished {
id, id,
error_message, error_message,
@ -1020,6 +1099,9 @@ impl AssistantContext {
ContextOperation::SlashCommandOutputSectionAdded { section, .. } => { ContextOperation::SlashCommandOutputSectionAdded { section, .. } => {
self.has_received_operations_for_anchor_range(section.range.clone(), cx) self.has_received_operations_for_anchor_range(section.range.clone(), cx)
} }
ContextOperation::ThoughtProcessOutputSectionAdded { section, .. } => {
self.has_received_operations_for_anchor_range(section.range.clone(), cx)
}
ContextOperation::SlashCommandFinished { .. } => true, ContextOperation::SlashCommandFinished { .. } => true,
ContextOperation::BufferOperation(_) => { ContextOperation::BufferOperation(_) => {
panic!("buffer operations should always be applied") panic!("buffer operations should always be applied")
@ -1128,6 +1210,12 @@ impl AssistantContext {
&self.slash_command_output_sections &self.slash_command_output_sections
} }
/// Returns the thought-process output sections currently tracked by this context.
pub fn thought_process_output_sections(
    &self,
) -> &[ThoughtProcessOutputSection<language::Anchor>] {
    self.thought_process_output_sections.as_slice()
}
pub fn contains_files(&self, cx: &App) -> bool { pub fn contains_files(&self, cx: &App) -> bool {
let buffer = self.buffer.read(cx); let buffer = self.buffer.read(cx);
self.slash_command_output_sections.iter().any(|section| { self.slash_command_output_sections.iter().any(|section| {
@ -2168,6 +2256,35 @@ impl AssistantContext {
); );
} }
/// Records a new thought-process section and pushes the matching
/// `ContextOperation::ThoughtProcessOutputSectionAdded` via `push_op`.
///
/// The section is inserted at the index found by binary-searching the existing
/// sections by buffer range. A section whose range compares equal to an existing
/// one is still inserted (at that index) rather than being deduplicated.
fn insert_thought_process_output_section(
    &mut self,
    section: ThoughtProcessOutputSection<language::Anchor>,
    cx: &mut Context<Self>,
) {
    let buffer = self.buffer.read(cx);
    let insertion_ix = self
        .thought_process_output_sections
        .binary_search_by(|existing| existing.range.cmp(&section.range, buffer))
        .unwrap_or_else(|ix| ix);
    self.thought_process_output_sections
        .insert(insertion_ix, section.clone());

    // NOTE: this function itself emits no `ContextEvent` for the new section;
    // only the operation below is pushed.

    // Snapshot the version before allocating the timestamp, in that order.
    let version = self.version.clone();
    let timestamp = self.next_timestamp();
    let operation = ContextOperation::ThoughtProcessOutputSectionAdded {
        timestamp,
        section,
        version,
    };
    self.push_op(operation, cx);
}
pub fn completion_provider_changed(&mut self, cx: &mut Context<Self>) { pub fn completion_provider_changed(&mut self, cx: &mut Context<Self>) {
self.count_remaining_tokens(cx); self.count_remaining_tokens(cx);
} }
@ -2220,6 +2337,10 @@ impl AssistantContext {
let request_start = Instant::now(); let request_start = Instant::now();
let mut events = stream.await?; let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn; let mut stop_reason = StopReason::EndTurn;
let mut thought_process_stack = Vec::new();
const THOUGHT_PROCESS_START_MARKER: &str = "<think>\n";
const THOUGHT_PROCESS_END_MARKER: &str = "\n</think>";
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
if response_latency.is_none() { if response_latency.is_none() {
@ -2227,6 +2348,9 @@ impl AssistantContext {
} }
let event = event?; let event = event?;
let mut context_event = None;
let mut thought_process_output_section = None;
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
let message_ix = this let message_ix = this
.message_anchors .message_anchors
@ -2245,7 +2369,50 @@ impl AssistantContext {
LanguageModelCompletionEvent::Stop(reason) => { LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason; stop_reason = reason;
} }
LanguageModelCompletionEvent::Text(chunk) => { LanguageModelCompletionEvent::Thinking(chunk) => {
if thought_process_stack.is_empty() {
let start =
buffer.anchor_before(message_old_end_offset);
thought_process_stack.push(start);
let chunk =
format!("{THOUGHT_PROCESS_START_MARKER}{chunk}{THOUGHT_PROCESS_END_MARKER}");
let chunk_len = chunk.len();
buffer.edit(
[(
message_old_end_offset..message_old_end_offset,
chunk,
)],
None,
cx,
);
let end = buffer
.anchor_before(message_old_end_offset + chunk_len);
context_event = Some(
ContextEvent::StartedThoughtProcess(start..end),
);
} else {
// This ensures that all the thinking chunks are inserted inside the thinking tag
let insertion_position =
message_old_end_offset - THOUGHT_PROCESS_END_MARKER.len();
buffer.edit(
[(insertion_position..insertion_position, chunk)],
None,
cx,
);
}
}
LanguageModelCompletionEvent::Text(mut chunk) => {
if let Some(start) = thought_process_stack.pop() {
let end = buffer.anchor_before(message_old_end_offset);
context_event =
Some(ContextEvent::EndedThoughtProcess(end));
thought_process_output_section =
Some(ThoughtProcessOutputSection {
range: start..end,
});
chunk.insert_str(0, "\n\n");
}
buffer.edit( buffer.edit(
[( [(
message_old_end_offset..message_old_end_offset, message_old_end_offset..message_old_end_offset,
@ -2260,6 +2427,13 @@ impl AssistantContext {
} }
}); });
if let Some(section) = thought_process_output_section.take() {
this.insert_thought_process_output_section(section, cx);
}
if let Some(context_event) = context_event.take() {
cx.emit(context_event);
}
cx.emit(ContextEvent::StreamedCompletion); cx.emit(ContextEvent::StreamedCompletion);
Some(()) Some(())
@ -3127,6 +3301,8 @@ pub struct SavedContext {
pub summary: String, pub summary: String,
pub slash_command_output_sections: pub slash_command_output_sections:
Vec<assistant_slash_command::SlashCommandOutputSection<usize>>, Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
#[serde(default)]
pub thought_process_output_sections: Vec<ThoughtProcessOutputSection<usize>>,
} }
impl SavedContext { impl SavedContext {
@ -3228,6 +3404,20 @@ impl SavedContext {
version.observe(timestamp); version.observe(timestamp);
} }
for section in self.thought_process_output_sections {
let timestamp = next_timestamp.tick();
operations.push(ContextOperation::ThoughtProcessOutputSectionAdded {
timestamp,
section: ThoughtProcessOutputSection {
range: buffer.anchor_after(section.range.start)
..buffer.anchor_before(section.range.end),
},
version: version.clone(),
});
version.observe(timestamp);
}
let timestamp = next_timestamp.tick(); let timestamp = next_timestamp.tick();
operations.push(ContextOperation::UpdateSummary { operations.push(ContextOperation::UpdateSummary {
summary: ContextSummary { summary: ContextSummary {
@ -3302,6 +3492,7 @@ impl SavedContextV0_3_0 {
.collect(), .collect(),
summary: self.summary, summary: self.summary,
slash_command_output_sections: self.slash_command_output_sections, slash_command_output_sections: self.slash_command_output_sections,
thought_process_output_sections: Vec::new(),
} }
} }
} }

View file

@ -64,7 +64,10 @@ use workspace::{
Workspace, Workspace,
}; };
use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker}; use crate::{
slash_command::SlashCommandCompletionProvider, slash_command_picker,
ThoughtProcessOutputSection,
};
use crate::{ use crate::{
AssistantContext, AssistantPatch, AssistantPatchStatus, CacheStatus, Content, ContextEvent, AssistantContext, AssistantPatch, AssistantPatchStatus, CacheStatus, Content, ContextEvent,
ContextId, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId, ContextId, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId,
@ -120,6 +123,11 @@ enum AssistError {
Message(SharedString), Message(SharedString),
} }
/// State of a thought-process section shown in the context editor; passed to
/// `render_thought_process_fold_icon_button` when the section's crease is built.
pub enum ThoughtProcessStatus {
// Used for the crease created when `ContextEvent::StartedThoughtProcess` is handled.
Pending,
// Used for sections loaded at construction and when `ContextEvent::EndedThoughtProcess` is handled.
Completed,
}
pub trait AssistantPanelDelegate { pub trait AssistantPanelDelegate {
fn active_context_editor( fn active_context_editor(
&self, &self,
@ -178,6 +186,7 @@ pub struct ContextEditor {
project: Entity<Project>, project: Entity<Project>,
lsp_adapter_delegate: Option<Arc<dyn LspAdapterDelegate>>, lsp_adapter_delegate: Option<Arc<dyn LspAdapterDelegate>>,
editor: Entity<Editor>, editor: Entity<Editor>,
pending_thought_process: Option<(CreaseId, language::Anchor)>,
blocks: HashMap<MessageId, (MessageHeader, CustomBlockId)>, blocks: HashMap<MessageId, (MessageHeader, CustomBlockId)>,
image_blocks: HashSet<CustomBlockId>, image_blocks: HashSet<CustomBlockId>,
scroll_position: Option<ScrollPosition>, scroll_position: Option<ScrollPosition>,
@ -253,7 +262,8 @@ impl ContextEditor {
cx.observe_global_in::<SettingsStore>(window, Self::settings_changed), cx.observe_global_in::<SettingsStore>(window, Self::settings_changed),
]; ];
let sections = context.read(cx).slash_command_output_sections().to_vec(); let slash_command_sections = context.read(cx).slash_command_output_sections().to_vec();
let thought_process_sections = context.read(cx).thought_process_output_sections().to_vec();
let patch_ranges = context.read(cx).patch_ranges().collect::<Vec<_>>(); let patch_ranges = context.read(cx).patch_ranges().collect::<Vec<_>>();
let slash_commands = context.read(cx).slash_commands().clone(); let slash_commands = context.read(cx).slash_commands().clone();
let mut this = Self { let mut this = Self {
@ -265,6 +275,7 @@ impl ContextEditor {
image_blocks: Default::default(), image_blocks: Default::default(),
scroll_position: None, scroll_position: None,
remote_id: None, remote_id: None,
pending_thought_process: None,
fs: fs.clone(), fs: fs.clone(),
workspace, workspace,
project, project,
@ -294,7 +305,14 @@ impl ContextEditor {
}; };
this.update_message_headers(cx); this.update_message_headers(cx);
this.update_image_blocks(cx); this.update_image_blocks(cx);
this.insert_slash_command_output_sections(sections, false, window, cx); this.insert_slash_command_output_sections(slash_command_sections, false, window, cx);
this.insert_thought_process_output_sections(
thought_process_sections
.into_iter()
.map(|section| (section, ThoughtProcessStatus::Completed)),
window,
cx,
);
this.patches_updated(&Vec::new(), &patch_ranges, window, cx); this.patches_updated(&Vec::new(), &patch_ranges, window, cx);
this this
} }
@ -599,6 +617,47 @@ impl ContextEditor {
context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx); context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
}); });
} }
ContextEvent::StartedThoughtProcess(range) => {
let creases = self.insert_thought_process_output_sections(
[(
ThoughtProcessOutputSection {
range: range.clone(),
},
ThoughtProcessStatus::Pending,
)],
window,
cx,
);
self.pending_thought_process = Some((creases[0], range.start));
}
ContextEvent::EndedThoughtProcess(end) => {
if let Some((crease_id, start)) = self.pending_thought_process.take() {
self.editor.update(cx, |editor, cx| {
let multi_buffer_snapshot = editor.buffer().read(cx).snapshot(cx);
let (excerpt_id, _, _) = multi_buffer_snapshot.as_singleton().unwrap();
let start_anchor = multi_buffer_snapshot
.anchor_in_excerpt(*excerpt_id, start)
.unwrap();
editor.display_map.update(cx, |display_map, cx| {
display_map.unfold_intersecting(
vec![start_anchor..start_anchor],
true,
cx,
);
});
editor.remove_creases(vec![crease_id], cx);
});
self.insert_thought_process_output_sections(
[(
ThoughtProcessOutputSection { range: start..*end },
ThoughtProcessStatus::Completed,
)],
window,
cx,
);
}
}
ContextEvent::StreamedCompletion => { ContextEvent::StreamedCompletion => {
self.editor.update(cx, |editor, cx| { self.editor.update(cx, |editor, cx| {
if let Some(scroll_position) = self.scroll_position { if let Some(scroll_position) = self.scroll_position {
@ -946,6 +1005,62 @@ impl ContextEditor {
self.update_active_patch(window, cx); self.update_active_patch(window, cx);
} }
/// Creates one inline crease per `(section, status)` pair, inserts them all into the
/// editor, and folds every buffer row on which one of the sections starts.
///
/// Returns the ids the editor assigned to the inserted creases.
///
/// # Panics
///
/// Panics if the editor's buffer snapshot is not a singleton, or if a section anchor
/// cannot be resolved inside that excerpt (both lookups are unwrapped).
fn insert_thought_process_output_sections(
    &mut self,
    sections: impl IntoIterator<
        Item = (
            ThoughtProcessOutputSection<language::Anchor>,
            ThoughtProcessStatus,
        ),
    >,
    window: &mut Window,
    cx: &mut Context<Self>,
) -> Vec<CreaseId> {
    self.editor.update(cx, |editor, cx| {
        let snapshot = editor.buffer().read(cx).snapshot(cx);
        let (excerpt_id, _, _) = snapshot.as_singleton().unwrap();
        let excerpt_id = *excerpt_id;

        // Build the crease and remember its starting row in a single pass. Collecting the
        // rows into a `BTreeSet` deduplicates them and keeps them sorted.
        let (rows_to_fold, new_creases): (BTreeSet<_>, Vec<_>) = sections
            .into_iter()
            .map(|(section, status)| {
                let range_start = snapshot
                    .anchor_in_excerpt(excerpt_id, section.range.start)
                    .unwrap();
                let range_end = snapshot
                    .anchor_in_excerpt(excerpt_id, section.range.end)
                    .unwrap();
                let start_row = MultiBufferRow(range_start.to_point(&snapshot).row);
                let placeholder = FoldPlaceholder {
                    render: render_thought_process_fold_icon_button(
                        cx.entity().downgrade(),
                        status,
                    ),
                    merge_adjacent: false,
                    ..Default::default()
                };
                let crease = Crease::inline(
                    range_start..range_end,
                    placeholder,
                    render_slash_command_output_toggle,
                    |_, _, _, _| Empty.into_any_element(),
                )
                .with_metadata(CreaseMetadata {
                    icon: IconName::Ai,
                    label: "Thinking Process".into(),
                });
                (start_row, crease)
            })
            .unzip();

        let crease_ids = editor.insert_creases(new_creases, cx);
        // Fold starting from the highest row and working upwards.
        for buffer_row in rows_to_fold.into_iter().rev() {
            editor.fold_at(&FoldAt { buffer_row }, window, cx);
        }
        crease_ids
    })
}
fn insert_slash_command_output_sections( fn insert_slash_command_output_sections(
&mut self, &mut self,
sections: impl IntoIterator<Item = SlashCommandOutputSection<language::Anchor>>, sections: impl IntoIterator<Item = SlashCommandOutputSection<language::Anchor>>,
@ -2652,6 +2767,52 @@ fn find_surrounding_code_block(snapshot: &BufferSnapshot, offset: usize) -> Opti
None None
} }
/// Builds the render callback for the placeholder drawn in place of a folded
/// thought-process section.
///
/// For `ThoughtProcessStatus::Pending` the placeholder is a muted brain icon followed by
/// a pulsating "Thinking…" label; for `ThoughtProcessStatus::Completed` it is a filled
/// button with a brain icon and the label "Thought Process". Clicking it unfolds the
/// buffer row on which the fold starts.
fn render_thought_process_fold_icon_button(
    editor: WeakEntity<Editor>,
    status: ThoughtProcessStatus,
) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
    Arc::new(move |fold_id, fold_range, _cx| {
        let weak_editor = editor.clone();
        let brain_icon = Icon::new(IconName::Brain).size(IconSize::Small);
        let base = ButtonLike::new(fold_id).layer(ElevationIndex::ElevatedSurface);

        let placeholder = match status {
            ThoughtProcessStatus::Pending => {
                let pulse = Animation::new(Duration::from_secs(2))
                    .repeat()
                    .with_easing(pulsating_between(0.4, 0.8));
                let thinking_label = Label::new("Thinking…")
                    .color(Color::Muted)
                    .with_animation("pulsating-label", pulse, |label, delta| {
                        label.alpha(delta)
                    });
                base.child(brain_icon.color(Color::Muted))
                    .child(thinking_label)
            }
            ThoughtProcessStatus::Completed => base
                .style(ButtonStyle::Filled)
                .child(brain_icon)
                .child(Label::new("Thought Process").single_line()),
        };

        placeholder
            .on_click(move |_, window, cx| {
                // The outcome of the update is deliberately discarded.
                weak_editor
                    .update(cx, |editor, cx| {
                        // Resolve the row inside one expression so the snapshot borrow ends
                        // before `cx` is handed to `unfold_at`.
                        let start_row = fold_range
                            .start
                            .to_point(&editor.buffer().read(cx).read(cx))
                            .row;
                        editor.unfold_at(
                            &UnfoldAt {
                                buffer_row: MultiBufferRow(start_row),
                            },
                            window,
                            cx,
                        );
                    })
                    .ok();
            })
            .into_any_element()
    })
}
fn render_fold_icon_button( fn render_fold_icon_button(
editor: WeakEntity<Editor>, editor: WeakEntity<Editor>,
icon: IconName, icon: IconName,

View file

@ -120,7 +120,7 @@ impl Eval {
.count(); .count();
Ok(EvalOutput { Ok(EvalOutput {
diff, diff,
last_message: last_message.text.clone(), last_message: last_message.to_string(),
elapsed_time, elapsed_time,
assistant_response_count, assistant_response_count,
tool_use_counts: assistant.tool_use_counts.clone(), tool_use_counts: assistant.tool_use_counts.clone(),

View file

@ -89,7 +89,7 @@ impl HeadlessAssistant {
ThreadEvent::DoneStreaming => { ThreadEvent::DoneStreaming => {
let thread = thread.read(cx); let thread = thread.read(cx);
if let Some(message) = thread.messages().last() { if let Some(message) = thread.messages().last() {
println!("Message: {}", message.text,); println!("Message: {}", message.to_string());
} }
if thread.all_tools_finished() { if thread.all_tools_finished() {
self.done_tx.send_blocking(Ok(())).unwrap() self.done_tx.send_blocking(Ok(())).unwrap()

View file

@ -1240,20 +1240,11 @@ impl Element for Div {
let mut state = scroll_handle.0.borrow_mut(); let mut state = scroll_handle.0.borrow_mut();
state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len()); state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len());
state.bounds = bounds; state.bounds = bounds;
let requested = state.requested_scroll_top.take(); for child_layout_id in &request_layout.child_layout_ids {
for (ix, child_layout_id) in request_layout.child_layout_ids.iter().enumerate() {
let child_bounds = window.layout_bounds(*child_layout_id); let child_bounds = window.layout_bounds(*child_layout_id);
child_min = child_min.min(&child_bounds.origin); child_min = child_min.min(&child_bounds.origin);
child_max = child_max.max(&child_bounds.bottom_right()); child_max = child_max.max(&child_bounds.bottom_right());
state.child_bounds.push(child_bounds); state.child_bounds.push(child_bounds);
if let Some(requested) = requested.as_ref() {
if requested.0 == ix {
*state.offset.borrow_mut() =
bounds.origin - (child_bounds.origin - point(px(0.), requested.1));
}
}
} }
(child_max - child_min).into() (child_max - child_min).into()
} else { } else {
@ -1528,8 +1519,11 @@ impl Interactivity {
_cx: &mut App, _cx: &mut App,
) -> Point<Pixels> { ) -> Point<Pixels> {
if let Some(scroll_offset) = self.scroll_offset.as_ref() { if let Some(scroll_offset) = self.scroll_offset.as_ref() {
let mut scroll_to_bottom = false;
if let Some(scroll_handle) = &self.tracked_scroll_handle { if let Some(scroll_handle) = &self.tracked_scroll_handle {
scroll_handle.0.borrow_mut().overflow = style.overflow; let mut state = scroll_handle.0.borrow_mut();
state.overflow = style.overflow;
scroll_to_bottom = mem::take(&mut state.scroll_to_bottom);
} }
let rem_size = window.rem_size(); let rem_size = window.rem_size();
@ -1555,8 +1549,14 @@ impl Interactivity {
// Clamp scroll offset in case scroll max is smaller now (e.g., if children // Clamp scroll offset in case scroll max is smaller now (e.g., if children
// were removed or the bounds became larger). // were removed or the bounds became larger).
let mut scroll_offset = scroll_offset.borrow_mut(); let mut scroll_offset = scroll_offset.borrow_mut();
scroll_offset.x = scroll_offset.x.clamp(-scroll_max.width, px(0.)); scroll_offset.x = scroll_offset.x.clamp(-scroll_max.width, px(0.));
scroll_offset.y = scroll_offset.y.clamp(-scroll_max.height, px(0.)); if scroll_to_bottom {
scroll_offset.y = -scroll_max.height;
} else {
scroll_offset.y = scroll_offset.y.clamp(-scroll_max.height, px(0.));
}
*scroll_offset *scroll_offset
} else { } else {
Point::default() Point::default()
@ -2861,12 +2861,13 @@ impl ScrollAnchor {
}); });
} }
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
struct ScrollHandleState { struct ScrollHandleState {
offset: Rc<RefCell<Point<Pixels>>>, offset: Rc<RefCell<Point<Pixels>>>,
bounds: Bounds<Pixels>, bounds: Bounds<Pixels>,
child_bounds: Vec<Bounds<Pixels>>, child_bounds: Vec<Bounds<Pixels>>,
requested_scroll_top: Option<(usize, Pixels)>, scroll_to_bottom: bool,
overflow: Point<Overflow>, overflow: Point<Overflow>,
} }
@ -2955,6 +2956,12 @@ impl ScrollHandle {
} }
} }
/// Requests a scroll to the bottom.
///
/// This only raises the `scroll_to_bottom` flag on the handle's state. The flag is taken
/// (and thereby reset) the next time the scroll offset is clamped, at which point the
/// vertical offset is set to `-scroll_max.height`.
pub fn scroll_to_bottom(&self) {
    self.0.borrow_mut().scroll_to_bottom = true;
}
/// Set the offset explicitly. The offset is the distance from the top left of the /// Set the offset explicitly. The offset is the distance from the top left of the
/// parent container to the top left of the first child. /// parent container to the top left of the first child.
/// As you scroll further down the offset becomes more negative. /// As you scroll further down the offset becomes more negative.
@ -2978,11 +2985,6 @@ impl ScrollHandle {
} }
} }
/// Set the logical scroll top, based on a child index and a pixel offset.
pub fn set_logical_scroll_top(&self, ix: usize, px: Pixels) {
self.0.borrow_mut().requested_scroll_top = Some((ix, px));
}
/// Get the count of children for scrollable item. /// Get the count of children for scrollable item.
pub fn children_count(&self) -> usize { pub fn children_count(&self) -> usize {
self.0.borrow().child_bounds.len() self.0.borrow().child_bounds.len()

View file

@ -59,6 +59,7 @@ pub struct LanguageModelCacheConfiguration {
pub enum LanguageModelCompletionEvent { pub enum LanguageModelCompletionEvent {
Stop(StopReason), Stop(StopReason),
Text(String), Text(String),
Thinking(String),
ToolUse(LanguageModelToolUse), ToolUse(LanguageModelToolUse),
StartMessage { message_id: String }, StartMessage { message_id: String },
UsageUpdate(TokenUsage), UsageUpdate(TokenUsage),
@ -217,6 +218,7 @@ pub trait LanguageModel: Send + Sync {
match result { match result {
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None, Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)), Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::Thinking(_)) => None,
Ok(LanguageModelCompletionEvent::Stop(_)) => None, Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None, Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None, Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None,

View file

@ -72,7 +72,9 @@ impl CloudModel {
pub fn availability(&self) -> LanguageModelAvailability { pub fn availability(&self) -> LanguageModelAvailability {
match self { match self {
Self::Anthropic(model) => match model { Self::Anthropic(model) => match model {
anthropic::Model::Claude3_5Sonnet | anthropic::Model::Claude3_7Sonnet => { anthropic::Model::Claude3_5Sonnet
| anthropic::Model::Claude3_7Sonnet
| anthropic::Model::Claude3_7SonnetThinking => {
LanguageModelAvailability::RequiresPlan(Plan::Free) LanguageModelAvailability::RequiresPlan(Plan::Free)
} }
anthropic::Model::Claude3Opus anthropic::Model::Claude3Opus

View file

@ -1,6 +1,6 @@
use crate::ui::InstructionListItem; use crate::ui::InstructionListItem;
use crate::AllLanguageModelSettings; use crate::AllLanguageModelSettings;
use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage}; use anthropic::{AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, Usage};
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap}; use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider; use credentials_provider::CredentialsProvider;
@ -55,6 +55,37 @@ pub struct AvailableModel {
pub default_temperature: Option<f32>, pub default_temperature: Option<f32>,
#[serde(default)] #[serde(default)]
pub extra_beta_headers: Vec<String>, pub extra_beta_headers: Vec<String>,
/// The model's mode (e.g. thinking)
pub mode: Option<ModelMode>,
}
// The mode a configured Anthropic model runs in.
//
// Serialized as an internally tagged object: the `type` key carries the lowercased
// variant name (`"default"` or `"thinking"`), and the fields of `Thinking` sit next to it.
// (Plain `//` comments are used here so the derived JSON schema descriptions stay as-is.)
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
// The regular mode; this is the variant produced by `ModelMode::default()`.
#[default]
Default,
// Thinking mode, with an optional token budget for reasoning.
Thinking {
/// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
budget_tokens: Option<u32>,
},
}
impl From<ModelMode> for AnthropicModelMode {
fn from(value: ModelMode) -> Self {
match value {
ModelMode::Default => AnthropicModelMode::Default,
ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
}
}
}
impl From<AnthropicModelMode> for ModelMode {
fn from(value: AnthropicModelMode) -> Self {
match value {
AnthropicModelMode::Default => ModelMode::Default,
AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
}
}
} }
pub struct AnthropicLanguageModelProvider { pub struct AnthropicLanguageModelProvider {
@ -228,6 +259,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
max_output_tokens: model.max_output_tokens, max_output_tokens: model.max_output_tokens,
default_temperature: model.default_temperature, default_temperature: model.default_temperature,
extra_beta_headers: model.extra_beta_headers.clone(), extra_beta_headers: model.extra_beta_headers.clone(),
mode: model.mode.clone().unwrap_or_default().into(),
}, },
); );
} }
@ -399,9 +431,10 @@ impl LanguageModel for AnthropicModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> { ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = into_anthropic( let request = into_anthropic(
request, request,
self.model.id().into(), self.model.request_id().into(),
self.model.default_temperature(), self.model.default_temperature(),
self.model.max_output_tokens(), self.model.max_output_tokens(),
self.model.mode(),
); );
let request = self.stream_completion(request, cx); let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
@ -434,6 +467,7 @@ impl LanguageModel for AnthropicModel {
self.model.tool_model_id().into(), self.model.tool_model_id().into(),
self.model.default_temperature(), self.model.default_temperature(),
self.model.max_output_tokens(), self.model.max_output_tokens(),
self.model.mode(),
); );
request.tool_choice = Some(anthropic::ToolChoice::Tool { request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(), name: tool_name.clone(),
@ -464,6 +498,7 @@ pub fn into_anthropic(
model: String, model: String,
default_temperature: f32, default_temperature: f32,
max_output_tokens: u32, max_output_tokens: u32,
mode: AnthropicModelMode,
) -> anthropic::Request { ) -> anthropic::Request {
let mut new_messages: Vec<anthropic::Message> = Vec::new(); let mut new_messages: Vec<anthropic::Message> = Vec::new();
let mut system_message = String::new(); let mut system_message = String::new();
@ -552,6 +587,11 @@ pub fn into_anthropic(
messages: new_messages, messages: new_messages,
max_tokens: max_output_tokens, max_tokens: max_output_tokens,
system: Some(system_message), system: Some(system_message),
thinking: if let AnthropicModelMode::Thinking { budget_tokens } = mode {
Some(anthropic::Thinking::Enabled { budget_tokens })
} else {
None
},
tools: request tools: request
.tools .tools
.into_iter() .into_iter()
@ -607,6 +647,16 @@ pub fn map_to_language_model_completion_events(
state, state,
)); ));
} }
ResponseContent::Thinking { thinking } => {
return Some((
vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))],
state,
));
}
ResponseContent::RedactedThinking { .. } => {
// Redacted thinking is encrypted and not accessible to the user, see:
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
}
ResponseContent::ToolUse { id, name, .. } => { ResponseContent::ToolUse { id, name, .. } => {
state.tool_uses_by_index.insert( state.tool_uses_by_index.insert(
index, index,
@ -625,6 +675,13 @@ pub fn map_to_language_model_completion_events(
state, state,
)); ));
} }
ContentDelta::ThinkingDelta { thinking } => {
return Some((
vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))],
state,
));
}
ContentDelta::SignatureDelta { .. } => {}
ContentDelta::InputJsonDelta { partial_json } => { ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) { if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json); tool_use.input_json.push_str(&partial_json);

View file

@ -1,4 +1,4 @@
use anthropic::AnthropicError; use anthropic::{AnthropicError, AnthropicModelMode};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use client::{ use client::{
zed_urls, Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME, zed_urls, Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
@ -91,6 +91,28 @@ pub struct AvailableModel {
/// Any extra beta headers to provide when using the model. /// Any extra beta headers to provide when using the model.
#[serde(default)] #[serde(default)]
pub extra_beta_headers: Vec<String>, pub extra_beta_headers: Vec<String>,
/// The model's mode (e.g. thinking)
pub mode: Option<ModelMode>,
}
// The mode a configured cloud model runs in.
//
// Serialized as an internally tagged object: the `type` key carries the lowercased
// variant name (`"default"` or `"thinking"`), and the fields of `Thinking` sit next to it.
// (Plain `//` comments are used here so the derived JSON schema descriptions stay as-is.)
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
// The regular mode; this is the variant produced by `ModelMode::default()`.
#[default]
Default,
// Thinking mode, with an optional token budget for reasoning.
Thinking {
/// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
budget_tokens: Option<u32>,
},
}
impl From<ModelMode> for AnthropicModelMode {
fn from(value: ModelMode) -> Self {
match value {
ModelMode::Default => AnthropicModelMode::Default,
ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
}
}
} }
pub struct CloudLanguageModelProvider { pub struct CloudLanguageModelProvider {
@ -299,6 +321,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
default_temperature: model.default_temperature, default_temperature: model.default_temperature,
max_output_tokens: model.max_output_tokens, max_output_tokens: model.max_output_tokens,
extra_beta_headers: model.extra_beta_headers.clone(), extra_beta_headers: model.extra_beta_headers.clone(),
mode: model.mode.unwrap_or_default().into(),
}), }),
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(), name: model.name.clone(),
@ -567,9 +590,10 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Anthropic(model) => { CloudModel::Anthropic(model) => {
let request = into_anthropic( let request = into_anthropic(
request, request,
model.id().into(), model.request_id().into(),
model.default_temperature(), model.default_temperature(),
model.max_output_tokens(), model.max_output_tokens(),
model.mode(),
); );
let client = self.client.clone(); let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone(); let llm_api_token = self.llm_api_token.clone();
@ -669,6 +693,7 @@ impl LanguageModel for CloudLanguageModel {
model.tool_model_id().into(), model.tool_model_id().into(),
model.default_temperature(), model.default_temperature(),
model.max_output_tokens(), model.max_output_tokens(),
model.mode(),
); );
request.tool_choice = Some(anthropic::ToolChoice::Tool { request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(), name: tool_name.clone(),

View file

@ -109,6 +109,7 @@ impl AnthropicSettingsContent {
max_output_tokens, max_output_tokens,
default_temperature, default_temperature,
extra_beta_headers, extra_beta_headers,
mode,
} => Some(provider::anthropic::AvailableModel { } => Some(provider::anthropic::AvailableModel {
name, name,
display_name, display_name,
@ -124,6 +125,7 @@ impl AnthropicSettingsContent {
max_output_tokens, max_output_tokens,
default_temperature, default_temperature,
extra_beta_headers, extra_beta_headers,
mode: Some(mode.into()),
}), }),
_ => None, _ => None,
}) })

View file

@ -2503,6 +2503,10 @@ message SlashCommandOutputSection {
optional string metadata = 4; optional string metadata = 4;
} }
// A thought-process output section, described by the anchor range it spans.
message ThoughtProcessOutputSection {
AnchorRange range = 1;
}
message ContextOperation { message ContextOperation {
oneof variant { oneof variant {
InsertMessage insert_message = 1; InsertMessage insert_message = 1;
@ -2512,6 +2516,7 @@ message ContextOperation {
SlashCommandStarted slash_command_started = 6; SlashCommandStarted slash_command_started = 6;
SlashCommandOutputSectionAdded slash_command_output_section_added = 7; SlashCommandOutputSectionAdded slash_command_output_section_added = 7;
SlashCommandCompleted slash_command_completed = 8; SlashCommandCompleted slash_command_completed = 8;
ThoughtProcessOutputSectionAdded thought_process_output_section_added = 9;
} }
reserved 4; reserved 4;
@ -2556,6 +2561,12 @@ message ContextOperation {
repeated VectorClockEntry version = 5; repeated VectorClockEntry version = 5;
} }
// Payload of the `thought_process_output_section_added` variant of `ContextOperation`:
// the added section together with the operation's Lamport timestamp and version
// (a list of vector clock entries).
message ThoughtProcessOutputSectionAdded {
LamportTimestamp timestamp = 1;
ThoughtProcessOutputSection section = 2;
repeated VectorClockEntry version = 3;
}
message BufferOperation { message BufferOperation {
Operation operation = 1; Operation operation = 1;
} }

View file

@ -68,6 +68,21 @@ You can add custom models to the Anthropic provider by adding the following to y
Custom models will be listed in the model dropdown in the assistant panel. Custom models will be listed in the model dropdown in the assistant panel.
You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it),
by changing the mode in your model's configuration to `thinking`, for example:
```json
{
"name": "claude-3-7-sonnet-latest",
"display_name": "claude-3-7-sonnet-thinking",
"max_tokens": 200000,
"mode": {
"type": "thinking",
"budget_tokens": 4096
}
}
```
### GitHub Copilot Chat {#github-copilot-chat} ### GitHub Copilot Chat {#github-copilot-chat}
You can use GitHub Copilot chat with the Zed assistant by choosing it via the model dropdown in the assistant panel. You can use GitHub Copilot chat with the Zed assistant by choosing it via the model dropdown in the assistant panel.