diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index e886e38976..2c5acfc0f2 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -24,6 +24,16 @@ pub struct AnthropicModelCacheConfiguration { pub max_cache_anchors: usize, } +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub enum AnthropicModelMode { + #[default] + Default, + Thinking { + budget_tokens: Option, + }, +} + #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] pub enum Model { @@ -32,6 +42,11 @@ pub enum Model { Claude3_5Sonnet, #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")] Claude3_7Sonnet, + #[serde( + rename = "claude-3-7-sonnet-thinking", + alias = "claude-3-7-sonnet-thinking-latest" + )] + Claude3_7SonnetThinking, #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")] Claude3_5Haiku, #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")] @@ -54,6 +69,8 @@ pub enum Model { default_temperature: Option, #[serde(default)] extra_beta_headers: Vec, + #[serde(default)] + mode: AnthropicModelMode, }, } @@ -61,6 +78,8 @@ impl Model { pub fn from_id(id: &str) -> Result { if id.starts_with("claude-3-5-sonnet") { Ok(Self::Claude3_5Sonnet) + } else if id.starts_with("claude-3-7-sonnet-thinking") { + Ok(Self::Claude3_7SonnetThinking) } else if id.starts_with("claude-3-7-sonnet") { Ok(Self::Claude3_7Sonnet) } else if id.starts_with("claude-3-5-haiku") { @@ -80,6 +99,20 @@ impl Model { match self { Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest", Model::Claude3_7Sonnet => "claude-3-7-sonnet-latest", + Model::Claude3_7SonnetThinking => "claude-3-7-sonnet-thinking-latest", + Model::Claude3_5Haiku => "claude-3-5-haiku-latest", + Model::Claude3Opus => "claude-3-opus-latest", + Model::Claude3Sonnet => 
"claude-3-sonnet-20240229", + Model::Claude3Haiku => "claude-3-haiku-20240307", + Self::Custom { name, .. } => name, + } + } + + /// The id of the model that should be used for making API requests + pub fn request_id(&self) -> &str { + match self { + Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest", + Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => "claude-3-7-sonnet-latest", Model::Claude3_5Haiku => "claude-3-5-haiku-latest", Model::Claude3Opus => "claude-3-opus-latest", Model::Claude3Sonnet => "claude-3-sonnet-20240229", @@ -92,6 +125,7 @@ impl Model { match self { Self::Claude3_7Sonnet => "Claude 3.7 Sonnet", Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", + Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking", Self::Claude3_5Haiku => "Claude 3.5 Haiku", Self::Claude3Opus => "Claude 3 Opus", Self::Claude3Sonnet => "Claude 3 Sonnet", @@ -107,6 +141,7 @@ impl Model { Self::Claude3_5Sonnet | Self::Claude3_5Haiku | Self::Claude3_7Sonnet + | Self::Claude3_7SonnetThinking | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration { min_total_token: 2_048, should_speculate: true, @@ -125,6 +160,7 @@ impl Model { Self::Claude3_5Sonnet | Self::Claude3_5Haiku | Self::Claude3_7Sonnet + | Self::Claude3_7SonnetThinking | Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200_000, @@ -135,7 +171,10 @@ impl Model { pub fn max_output_tokens(&self) -> u32 { match self { Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096, - Self::Claude3_5Sonnet | Self::Claude3_7Sonnet | Self::Claude3_5Haiku => 8_192, + Self::Claude3_5Sonnet + | Self::Claude3_7Sonnet + | Self::Claude3_7SonnetThinking + | Self::Claude3_5Haiku => 8_192, Self::Custom { max_output_tokens, .. 
} => max_output_tokens.unwrap_or(4_096), @@ -146,6 +185,7 @@ impl Model { match self { Self::Claude3_5Sonnet | Self::Claude3_7Sonnet + | Self::Claude3_7SonnetThinking | Self::Claude3_5Haiku | Self::Claude3Opus | Self::Claude3Sonnet @@ -157,6 +197,21 @@ impl Model { } } + pub fn mode(&self) -> AnthropicModelMode { + match self { + Self::Claude3_5Sonnet + | Self::Claude3_7Sonnet + | Self::Claude3_5Haiku + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3Haiku => AnthropicModelMode::Default, + Self::Claude3_7SonnetThinking => AnthropicModelMode::Thinking { + budget_tokens: Some(4_096), + }, + Self::Custom { mode, .. } => mode.clone(), + } + } + pub const DEFAULT_BETA_HEADERS: &[&str] = &["prompt-caching-2024-07-31"]; pub fn beta_headers(&self) -> String { @@ -188,7 +243,7 @@ impl Model { { tool_override } else { - self.id() + self.request_id() } } } @@ -409,6 +464,8 @@ pub async fn extract_tool_args_from_events( Err(error) => Some(Err(error)), Ok(Event::ContentBlockDelta { index, delta }) => match delta { ContentDelta::TextDelta { .. } => None, + ContentDelta::ThinkingDelta { .. } => None, + ContentDelta::SignatureDelta { .. 
} => None, ContentDelta::InputJsonDelta { partial_json } => { if index == tool_use_index { Some(Ok(partial_json)) @@ -487,6 +544,10 @@ pub enum RequestContent { pub enum ResponseContent { #[serde(rename = "text")] Text { text: String }, + #[serde(rename = "thinking")] + Thinking { thinking: String }, + #[serde(rename = "redacted_thinking")] + RedactedThinking { data: String }, #[serde(rename = "tool_use")] ToolUse { id: String, @@ -518,6 +579,12 @@ pub enum ToolChoice { Tool { name: String }, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum Thinking { + Enabled { budget_tokens: Option }, +} + #[derive(Debug, Serialize, Deserialize)] pub struct Request { pub model: String, @@ -526,6 +593,8 @@ pub struct Request { #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tools: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] + pub thinking: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_choice: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub system: Option, @@ -609,6 +678,10 @@ pub enum Event { pub enum ContentDelta { #[serde(rename = "text_delta")] TextDelta { text: String }, + #[serde(rename = "thinking_delta")] + ThinkingDelta { thinking: String }, + #[serde(rename = "signature_delta")] + SignatureDelta { signature: String }, #[serde(rename = "input_json_delta")] InputJsonDelta { partial_json: String }, } diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index b3363fb73a..f640a663e6 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -1,5 +1,6 @@ use crate::thread::{ - LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ThreadFeedback, + LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError, + ThreadEvent, ThreadFeedback, }; use crate::thread_store::ThreadStore; use crate::tool_use::{ToolUse, 
ToolUseStatus}; @@ -7,10 +8,10 @@ use crate::ui::ContextPill; use collections::HashMap; use editor::{Editor, MultiBuffer}; use gpui::{ - list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App, - ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, - ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, - Transformation, UnderlineStyle, WeakEntity, + linear_color_stop, linear_gradient, list, percentage, pulsating_between, AbsoluteLength, + Animation, AnimationExt, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty, + Entity, Focusable, Length, ListAlignment, ListOffset, ListState, ScrollHandle, StyleRefinement, + Subscription, Task, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, }; use language::{Buffer, LanguageRegistry}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; @@ -35,15 +36,175 @@ pub struct ActiveThread { save_thread_task: Option>, messages: Vec, list_state: ListState, - rendered_messages_by_id: HashMap>, + rendered_messages_by_id: HashMap, rendered_scripting_tool_uses: HashMap>, rendered_tool_use_labels: HashMap>, editing_message: Option<(MessageId, EditMessageState)>, expanded_tool_uses: HashMap, + expanded_thinking_segments: HashMap<(MessageId, usize), bool>, last_error: Option, _subscriptions: Vec, } +struct RenderedMessage { + language_registry: Arc, + segments: Vec, +} + +impl RenderedMessage { + fn from_segments( + segments: &[MessageSegment], + language_registry: Arc, + window: &Window, + cx: &mut App, + ) -> Self { + let mut this = Self { + language_registry, + segments: Vec::with_capacity(segments.len()), + }; + for segment in segments { + this.push_segment(segment, window, cx); + } + this + } + + fn append_thinking(&mut self, text: &String, window: &Window, cx: &mut App) { + if let Some(RenderedMessageSegment::Thinking { + content, + scroll_handle, + }) = 
self.segments.last_mut() + { + content.update(cx, |markdown, cx| { + markdown.append(text, cx); + }); + scroll_handle.scroll_to_bottom(); + } else { + self.segments.push(RenderedMessageSegment::Thinking { + content: render_markdown(text.into(), self.language_registry.clone(), window, cx), + scroll_handle: ScrollHandle::default(), + }); + } + } + + fn append_text(&mut self, text: &String, window: &Window, cx: &mut App) { + if let Some(RenderedMessageSegment::Text(markdown)) = self.segments.last_mut() { + markdown.update(cx, |markdown, cx| markdown.append(text, cx)); + } else { + self.segments + .push(RenderedMessageSegment::Text(render_markdown( + SharedString::from(text), + self.language_registry.clone(), + window, + cx, + ))); + } + } + + fn push_segment(&mut self, segment: &MessageSegment, window: &Window, cx: &mut App) { + let rendered_segment = match segment { + MessageSegment::Thinking(text) => RenderedMessageSegment::Thinking { + content: render_markdown(text.into(), self.language_registry.clone(), window, cx), + scroll_handle: ScrollHandle::default(), + }, + MessageSegment::Text(text) => RenderedMessageSegment::Text(render_markdown( + text.into(), + self.language_registry.clone(), + window, + cx, + )), + }; + self.segments.push(rendered_segment); + } +} + +enum RenderedMessageSegment { + Thinking { + content: Entity, + scroll_handle: ScrollHandle, + }, + Text(Entity), +} + +fn render_markdown( + text: SharedString, + language_registry: Arc, + window: &Window, + cx: &mut App, +) -> Entity { + let theme_settings = ThemeSettings::get_global(cx); + let colors = cx.theme().colors(); + let ui_font_size = TextSize::Default.rems(cx); + let buffer_font_size = TextSize::Small.rems(cx); + let mut text_style = window.text_style(); + + text_style.refine(&TextStyleRefinement { + font_family: Some(theme_settings.ui_font.family.clone()), + font_fallbacks: theme_settings.ui_font.fallbacks.clone(), + font_features: Some(theme_settings.ui_font.features.clone()), + font_size: 
Some(ui_font_size.into()), + color: Some(cx.theme().colors().text), + ..Default::default() + }); + + let markdown_style = MarkdownStyle { + base_text_style: text_style, + syntax: cx.theme().syntax().clone(), + selection_background_color: cx.theme().players().local().selection, + code_block_overflow_x_scroll: true, + table_overflow_x_scroll: true, + code_block: StyleRefinement { + margin: EdgesRefinement { + top: Some(Length::Definite(rems(0.).into())), + left: Some(Length::Definite(rems(0.).into())), + right: Some(Length::Definite(rems(0.).into())), + bottom: Some(Length::Definite(rems(0.5).into())), + }, + padding: EdgesRefinement { + top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), + left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), + right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), + bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), + }, + background: Some(colors.editor_background.into()), + border_color: Some(colors.border_variant), + border_widths: EdgesRefinement { + top: Some(AbsoluteLength::Pixels(Pixels(1.))), + left: Some(AbsoluteLength::Pixels(Pixels(1.))), + right: Some(AbsoluteLength::Pixels(Pixels(1.))), + bottom: Some(AbsoluteLength::Pixels(Pixels(1.))), + }, + text: Some(TextStyleRefinement { + font_family: Some(theme_settings.buffer_font.family.clone()), + font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), + font_features: Some(theme_settings.buffer_font.features.clone()), + font_size: Some(buffer_font_size.into()), + ..Default::default() + }), + ..Default::default() + }, + inline_code: TextStyleRefinement { + font_family: Some(theme_settings.buffer_font.family.clone()), + font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), + font_features: Some(theme_settings.buffer_font.features.clone()), + font_size: Some(buffer_font_size.into()), + background_color: Some(colors.editor_foreground.opacity(0.1)), + ..Default::default() + }, + 
link: TextStyleRefinement { + background_color: Some(colors.editor_foreground.opacity(0.025)), + underline: Some(UnderlineStyle { + color: Some(colors.text_accent.opacity(0.5)), + thickness: px(1.), + ..Default::default() + }), + ..Default::default() + }, + ..Default::default() + }; + + cx.new(|cx| Markdown::new(text, markdown_style, Some(language_registry), None, cx)) +} + struct EditMessageState { editor: Entity, } @@ -75,6 +236,7 @@ impl ActiveThread { rendered_scripting_tool_uses: HashMap::default(), rendered_tool_use_labels: HashMap::default(), expanded_tool_uses: HashMap::default(), + expanded_thinking_segments: HashMap::default(), list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { let this = cx.entity().downgrade(); move |ix, window: &mut Window, cx: &mut App| { @@ -88,7 +250,7 @@ impl ActiveThread { }; for message in thread.read(cx).messages().cloned().collect::>() { - this.push_message(&message.id, message.text.clone(), window, cx); + this.push_message(&message.id, &message.segments, window, cx); for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) { this.render_tool_use_label_markdown( @@ -156,7 +318,7 @@ impl ActiveThread { fn push_message( &mut self, id: &MessageId, - text: String, + segments: &[MessageSegment], window: &mut Window, cx: &mut Context, ) { @@ -164,8 +326,9 @@ impl ActiveThread { self.messages.push(*id); self.list_state.splice(old_len..old_len, 1); - let markdown = self.render_markdown(text.into(), window, cx); - self.rendered_messages_by_id.insert(*id, markdown); + let rendered_message = + RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx); + self.rendered_messages_by_id.insert(*id, rendered_message); self.list_state.scroll_to(ListOffset { item_ix: old_len, offset_in_item: Pixels(0.0), @@ -175,7 +338,7 @@ impl ActiveThread { fn edited_message( &mut self, id: &MessageId, - text: String, + segments: &[MessageSegment], window: &mut Window, cx: &mut Context, ) { @@ -183,8 
+346,9 @@ impl ActiveThread { return; }; self.list_state.splice(index..index + 1, 1); - let markdown = self.render_markdown(text.into(), window, cx); - self.rendered_messages_by_id.insert(*id, markdown); + let rendered_message = + RenderedMessage::from_segments(segments, self.language_registry.clone(), window, cx); + self.rendered_messages_by_id.insert(*id, rendered_message); } fn deleted_message(&mut self, id: &MessageId) { @@ -196,94 +360,6 @@ impl ActiveThread { self.rendered_messages_by_id.remove(id); } - fn render_markdown( - &self, - text: SharedString, - window: &Window, - cx: &mut Context, - ) -> Entity { - let theme_settings = ThemeSettings::get_global(cx); - let colors = cx.theme().colors(); - let ui_font_size = TextSize::Default.rems(cx); - let buffer_font_size = TextSize::Small.rems(cx); - let mut text_style = window.text_style(); - - text_style.refine(&TextStyleRefinement { - font_family: Some(theme_settings.ui_font.family.clone()), - font_fallbacks: theme_settings.ui_font.fallbacks.clone(), - font_features: Some(theme_settings.ui_font.features.clone()), - font_size: Some(ui_font_size.into()), - color: Some(cx.theme().colors().text), - ..Default::default() - }); - - let markdown_style = MarkdownStyle { - base_text_style: text_style, - syntax: cx.theme().syntax().clone(), - selection_background_color: cx.theme().players().local().selection, - code_block_overflow_x_scroll: true, - table_overflow_x_scroll: true, - code_block: StyleRefinement { - margin: EdgesRefinement { - top: Some(Length::Definite(rems(0.).into())), - left: Some(Length::Definite(rems(0.).into())), - right: Some(Length::Definite(rems(0.).into())), - bottom: Some(Length::Definite(rems(0.5).into())), - }, - padding: EdgesRefinement { - top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), - left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), - right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), - bottom: 
Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), - }, - background: Some(colors.editor_background.into()), - border_color: Some(colors.border_variant), - border_widths: EdgesRefinement { - top: Some(AbsoluteLength::Pixels(Pixels(1.))), - left: Some(AbsoluteLength::Pixels(Pixels(1.))), - right: Some(AbsoluteLength::Pixels(Pixels(1.))), - bottom: Some(AbsoluteLength::Pixels(Pixels(1.))), - }, - text: Some(TextStyleRefinement { - font_family: Some(theme_settings.buffer_font.family.clone()), - font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), - font_features: Some(theme_settings.buffer_font.features.clone()), - font_size: Some(buffer_font_size.into()), - ..Default::default() - }), - ..Default::default() - }, - inline_code: TextStyleRefinement { - font_family: Some(theme_settings.buffer_font.family.clone()), - font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), - font_features: Some(theme_settings.buffer_font.features.clone()), - font_size: Some(buffer_font_size.into()), - background_color: Some(colors.editor_foreground.opacity(0.1)), - ..Default::default() - }, - link: TextStyleRefinement { - background_color: Some(colors.editor_foreground.opacity(0.025)), - underline: Some(UnderlineStyle { - color: Some(colors.text_accent.opacity(0.5)), - thickness: px(1.), - ..Default::default() - }), - ..Default::default() - }, - ..Default::default() - }; - - cx.new(|cx| { - Markdown::new( - text, - markdown_style, - Some(self.language_registry.clone()), - None, - cx, - ) - }) - } - /// Renders the input of a scripting tool use to Markdown. /// /// Does nothing if the tool use does not correspond to the scripting tool. 
@@ -303,8 +379,12 @@ impl ActiveThread { .map(|input| input.lua_script) .unwrap_or_default(); - let lua_script = - self.render_markdown(format!("```lua\n{lua_script}\n```").into(), window, cx); + let lua_script = render_markdown( + format!("```lua\n{lua_script}\n```").into(), + self.language_registry.clone(), + window, + cx, + ); self.rendered_scripting_tool_uses .insert(tool_use_id, lua_script); @@ -319,7 +399,12 @@ impl ActiveThread { ) { self.rendered_tool_use_labels.insert( tool_use_id, - self.render_markdown(tool_label.into(), window, cx), + render_markdown( + tool_label.into(), + self.language_registry.clone(), + window, + cx, + ), ); } @@ -339,33 +424,36 @@ impl ActiveThread { } ThreadEvent::DoneStreaming => {} ThreadEvent::StreamedAssistantText(message_id, text) => { - if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) { - markdown.update(cx, |markdown, cx| { - markdown.append(text, cx); - }); + if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) { + rendered_message.append_text(text, window, cx); + } + } + ThreadEvent::StreamedAssistantThinking(message_id, text) => { + if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) { + rendered_message.append_thinking(text, window, cx); } } ThreadEvent::MessageAdded(message_id) => { - if let Some(message_text) = self + if let Some(message_segments) = self .thread .read(cx) .message(*message_id) - .map(|message| message.text.clone()) + .map(|message| message.segments.clone()) { - self.push_message(message_id, message_text, window, cx); + self.push_message(message_id, &message_segments, window, cx); } self.save_thread(cx); cx.notify(); } ThreadEvent::MessageEdited(message_id) => { - if let Some(message_text) = self + if let Some(message_segments) = self .thread .read(cx) .message(*message_id) - .map(|message| message.text.clone()) + .map(|message| message.segments.clone()) { - self.edited_message(message_id, message_text, window, cx); + 
self.edited_message(message_id, &message_segments, window, cx); } self.save_thread(cx); @@ -490,10 +578,16 @@ impl ActiveThread { fn start_editing_message( &mut self, message_id: MessageId, - message_text: String, + message_segments: &[MessageSegment], window: &mut Window, cx: &mut Context, ) { + // User message should always consist of a single text segment, + // therefore we can skip returning early if it's not a text segment. + let Some(MessageSegment::Text(message_text)) = message_segments.first() else { + return; + }; + let buffer = cx.new(|cx| { MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx) }); @@ -534,7 +628,12 @@ impl ActiveThread { }; let edited_text = state.editor.read(cx).text(cx); self.thread.update(cx, |thread, cx| { - thread.edit_message(message_id, Role::User, edited_text, cx); + thread.edit_message( + message_id, + Role::User, + vec![MessageSegment::Text(edited_text)], + cx, + ); for message_id in self.messages_after(message_id) { thread.delete_message(*message_id, cx); } @@ -617,7 +716,7 @@ impl ActiveThread { return Empty.into_any(); }; - let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else { + let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else { return Empty.into_any(); }; @@ -759,7 +858,10 @@ impl ActiveThread { .min_h_6() .child(edit_message_editor) } else { - div().min_h_6().text_ui(cx).child(markdown.clone()) + div() + .min_h_6() + .text_ui(cx) + .child(self.render_message_content(message_id, rendered_message, cx)) }, ) .when_some(context, |parent, context| { @@ -869,11 +971,12 @@ impl ActiveThread { Button::new("edit-message", "Edit") .label_size(LabelSize::Small) .on_click(cx.listener({ - let message_text = message.text.clone(); + let message_segments = + message.segments.clone(); move |this, _, window, cx| { this.start_editing_message( message_id, - message_text.clone(), + &message_segments, window, cx, ); @@ -995,6 +1098,190 @@ impl ActiveThread { 
.into_any() } + fn render_message_content( + &self, + message_id: MessageId, + rendered_message: &RenderedMessage, + cx: &Context, + ) -> impl IntoElement { + let pending_thinking_segment_index = rendered_message + .segments + .iter() + .enumerate() + .last() + .filter(|(_, segment)| matches!(segment, RenderedMessageSegment::Thinking { .. })) + .map(|(index, _)| index); + + div() + .text_ui(cx) + .gap_2() + .children( + rendered_message.segments.iter().enumerate().map( + |(index, segment)| match segment { + RenderedMessageSegment::Thinking { + content, + scroll_handle, + } => self + .render_message_thinking_segment( + message_id, + index, + content.clone(), + &scroll_handle, + Some(index) == pending_thinking_segment_index, + cx, + ) + .into_any_element(), + RenderedMessageSegment::Text(markdown) => { + div().p_2p5().child(markdown.clone()).into_any_element() + } + }, + ), + ) + } + + fn render_message_thinking_segment( + &self, + message_id: MessageId, + ix: usize, + markdown: Entity, + scroll_handle: &ScrollHandle, + pending: bool, + cx: &Context, + ) -> impl IntoElement { + let is_open = self + .expanded_thinking_segments + .get(&(message_id, ix)) + .copied() + .unwrap_or_default(); + + let lighter_border = cx.theme().colors().border.opacity(0.5); + let editor_bg = cx.theme().colors().editor_background; + + v_flex() + .rounded_lg() + .border_1() + .border_color(lighter_border) + .child( + h_flex() + .justify_between() + .py_1() + .pl_1() + .pr_2() + .bg(cx.theme().colors().editor_foreground.opacity(0.025)) + .map(|this| { + if is_open { + this.rounded_t_md() + .border_b_1() + .border_color(lighter_border) + } else { + this.rounded_md() + } + }) + .child( + h_flex() + .gap_1() + .child(Disclosure::new("thinking-disclosure", is_open).on_click( + cx.listener({ + move |this, _event, _window, _cx| { + let is_open = this + .expanded_thinking_segments + .entry((message_id, ix)) + .or_insert(false); + + *is_open = !*is_open; + } + }), + )) + .child({ + if pending { + 
Label::new("Thinking…") + .size(LabelSize::Small) + .buffer_font(cx) + .with_animation( + "pulsating-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 0.8)), + |label, delta| label.alpha(delta), + ) + .into_any_element() + } else { + Label::new("Thought Process") + .size(LabelSize::Small) + .buffer_font(cx) + .into_any_element() + } + }), + ) + .child({ + let (icon_name, color, animated) = if pending { + (IconName::ArrowCircle, Color::Accent, true) + } else { + (IconName::Check, Color::Success, false) + }; + + let icon = Icon::new(icon_name).color(color).size(IconSize::Small); + + if animated { + icon.with_animation( + "arrow-circle", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| { + icon.transform(Transformation::rotate(percentage(delta))) + }, + ) + .into_any_element() + } else { + icon.into_any_element() + } + }), + ) + .when(pending && !is_open, |this| { + let gradient_overlay = div() + .rounded_b_lg() + .h_20() + .absolute() + .w_full() + .bottom_0() + .left_0() + .bg(linear_gradient( + 180., + linear_color_stop(editor_bg, 1.), + linear_color_stop(editor_bg.opacity(0.2), 0.), + )); + + this.child( + div() + .relative() + .bg(editor_bg) + .rounded_b_lg() + .text_ui_sm(cx) + .child( + div() + .id(("thinking-content", ix)) + .p_2() + .h_20() + .track_scroll(scroll_handle) + .child(markdown.clone()) + .overflow_hidden(), + ) + .child(gradient_overlay), + ) + }) + .when(is_open, |this| { + this.child( + div() + .id(("thinking-content", ix)) + .h_full() + .p_2() + .rounded_b_lg() + .bg(editor_bg) + .text_ui_sm(cx) + .child(markdown.clone()), + ) + }) + } + fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context) -> impl IntoElement { let is_open = self .expanded_tool_uses @@ -1258,8 +1545,9 @@ impl ActiveThread { } }), )) - .child(div().text_ui_sm(cx).child(self.render_markdown( + .child(div().text_ui_sm(cx).child(render_markdown( tool_use.ui_text.clone(), + self.language_registry.clone(), 
window, cx, ))) diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index c2cf4fe550..f9c2c40c03 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -29,7 +29,8 @@ use uuid::Uuid; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; use crate::thread_store::{ - SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse, + SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, + SerializedToolUse, }; use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState}; @@ -69,7 +70,47 @@ impl MessageId { pub struct Message { pub id: MessageId, pub role: Role, - pub text: String, + pub segments: Vec, +} + +impl Message { + pub fn push_thinking(&mut self, text: &str) { + if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() { + segment.push_str(text); + } else { + self.segments + .push(MessageSegment::Thinking(text.to_string())); + } + } + + pub fn push_text(&mut self, text: &str) { + if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() { + segment.push_str(text); + } else { + self.segments.push(MessageSegment::Text(text.to_string())); + } + } + + pub fn to_string(&self) -> String { + let mut result = String::new(); + for segment in &self.segments { + match segment { + MessageSegment::Text(text) => result.push_str(text), + MessageSegment::Thinking(text) => { + result.push_str(""); + result.push_str(text); + result.push_str(""); + } + } + } + result + } +} + +#[derive(Debug, Clone)] +pub enum MessageSegment { + Text(String), + Thinking(String), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -226,7 +267,16 @@ impl Thread { .map(|message| Message { id: message.id, role: message.role, - text: message.text, + segments: message + .segments + .into_iter() + .map(|segment| match segment { + SerializedMessageSegment::Text { text } => MessageSegment::Text(text), + SerializedMessageSegment::Thinking { text } => { 
+ MessageSegment::Thinking(text) + } + }) + .collect(), }) .collect(), next_message_id, @@ -419,7 +469,8 @@ impl Thread { checkpoint: Option, cx: &mut Context, ) -> MessageId { - let message_id = self.insert_message(Role::User, text, cx); + let message_id = + self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx); let context_ids = context.iter().map(|context| context.id).collect::>(); self.context .extend(context.into_iter().map(|context| (context.id, context))); @@ -433,15 +484,11 @@ impl Thread { pub fn insert_message( &mut self, role: Role, - text: impl Into, + segments: Vec, cx: &mut Context, ) -> MessageId { let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role, - text: text.into(), - }); + self.messages.push(Message { id, role, segments }); self.touch_updated_at(); cx.emit(ThreadEvent::MessageAdded(id)); id @@ -451,14 +498,14 @@ impl Thread { &mut self, id: MessageId, new_role: Role, - new_text: String, + new_segments: Vec, cx: &mut Context, ) -> bool { let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { return false; }; message.role = new_role; - message.text = new_text; + message.segments = new_segments; self.touch_updated_at(); cx.emit(ThreadEvent::MessageEdited(id)); true @@ -489,7 +536,14 @@ impl Thread { }); text.push('\n'); - text.push_str(&message.text); + for segment in &message.segments { + match segment { + MessageSegment::Text(content) => text.push_str(content), + MessageSegment::Thinking(content) => { + text.push_str(&format!("{}", content)) + } + } + } text.push('\n'); } @@ -502,6 +556,7 @@ impl Thread { cx.spawn(async move |this, cx| { let initial_project_snapshot = initial_project_snapshot.await; this.read_with(cx, |this, cx| SerializedThread { + version: SerializedThread::VERSION.to_string(), summary: this.summary_or_default(), updated_at: this.updated_at(), messages: this @@ -509,7 +564,18 @@ impl Thread { .map(|message| SerializedMessage { id: 
message.id, role: message.role, - text: message.text.clone(), + segments: message + .segments + .iter() + .map(|segment| match segment { + MessageSegment::Text(text) => { + SerializedMessageSegment::Text { text: text.clone() } + } + MessageSegment::Thinking(text) => { + SerializedMessageSegment::Thinking { text: text.clone() } + } + }) + .collect(), tool_uses: this .tool_uses_for_message(message.id, cx) .into_iter() @@ -733,10 +799,10 @@ impl Thread { } } - if !message.text.is_empty() { + if !message.segments.is_empty() { request_message .content - .push(MessageContent::Text(message.text.clone())); + .push(MessageContent::Text(message.to_string())); } match request_kind { @@ -826,7 +892,11 @@ impl Thread { thread.update(cx, |thread, cx| { match event { LanguageModelCompletionEvent::StartMessage { .. } => { - thread.insert_message(Role::Assistant, String::new(), cx); + thread.insert_message( + Role::Assistant, + vec![MessageSegment::Text(String::new())], + cx, + ); } LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; @@ -840,7 +910,7 @@ impl Thread { LanguageModelCompletionEvent::Text(chunk) => { if let Some(last_message) = thread.messages.last_mut() { if last_message.role == Role::Assistant { - last_message.text.push_str(&chunk); + last_message.push_text(&chunk); cx.emit(ThreadEvent::StreamedAssistantText( last_message.id, chunk, @@ -851,7 +921,33 @@ impl Thread { // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. 
- thread.insert_message(Role::Assistant, chunk, cx); + thread.insert_message( + Role::Assistant, + vec![MessageSegment::Text(chunk.to_string())], + cx, + ); + }; + } + } + LanguageModelCompletionEvent::Thinking(chunk) => { + if let Some(last_message) = thread.messages.last_mut() { + if last_message.role == Role::Assistant { + last_message.push_thinking(&chunk); + cx.emit(ThreadEvent::StreamedAssistantThinking( + last_message.id, + chunk, + )); + } else { + // If we won't have an Assistant message yet, assume this chunk marks the beginning + // of a new Assistant response. + // + // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it + // will result in duplicating the text of the chunk in the rendered Markdown. + thread.insert_message( + Role::Assistant, + vec![MessageSegment::Thinking(chunk.to_string())], + cx, + ); }; } } @@ -1357,7 +1453,14 @@ impl Thread { Role::System => "System", } )?; - writeln!(markdown, "{}\n", message.text)?; + for segment in &message.segments { + match segment { + MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, + MessageSegment::Thinking(text) => { + writeln!(markdown, "{}\n", text)? 
+ } + } + } for tool_use in self.tool_uses_for_message(message.id, cx) { writeln!( @@ -1416,6 +1519,7 @@ pub enum ThreadEvent { ShowError(ThreadError), StreamedCompletion, StreamedAssistantText(MessageId, String), + StreamedAssistantThinking(MessageId, String), DoneStreaming, MessageAdded(MessageId), MessageEdited(MessageId), diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index cfdeb674d9..7c54eef658 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::path::PathBuf; use std::sync::Arc; @@ -12,7 +13,7 @@ use futures::FutureExt as _; use gpui::{ prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, }; -use heed::types::{SerdeBincode, SerdeJson}; +use heed::types::SerdeBincode; use heed::Database; use language_model::{LanguageModelToolUseId, Role}; use project::Project; @@ -259,6 +260,7 @@ pub struct SerializedThreadMetadata { #[derive(Serialize, Deserialize)] pub struct SerializedThread { + pub version: String, pub summary: SharedString, pub updated_at: DateTime, pub messages: Vec, @@ -266,17 +268,55 @@ pub struct SerializedThread { pub initial_project_snapshot: Option>, } +impl SerializedThread { + pub const VERSION: &'static str = "0.1.0"; + + pub fn from_json(json: &[u8]) -> Result { + let saved_thread_json = serde_json::from_slice::(json)?; + match saved_thread_json.get("version") { + Some(serde_json::Value::String(version)) => match version.as_str() { + SerializedThread::VERSION => Ok(serde_json::from_value::( + saved_thread_json, + )?), + _ => Err(anyhow!( + "unrecognized serialized thread version: {}", + version + )), + }, + None => { + let saved_thread = + serde_json::from_value::(saved_thread_json)?; + Ok(saved_thread.upgrade()) + } + version => Err(anyhow!( + "unrecognized serialized thread version: {:?}", + version + )), + } + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct 
SerializedMessage { pub id: MessageId, pub role: Role, - pub text: String, + #[serde(default)] + pub segments: Vec, #[serde(default)] pub tool_uses: Vec, #[serde(default)] pub tool_results: Vec, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum SerializedMessageSegment { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "thinking")] + Thinking { text: String }, +} + #[derive(Debug, Serialize, Deserialize)] pub struct SerializedToolUse { pub id: LanguageModelToolUseId, @@ -291,6 +331,50 @@ pub struct SerializedToolResult { pub content: Arc, } +#[derive(Serialize, Deserialize)] +struct LegacySerializedThread { + pub summary: SharedString, + pub updated_at: DateTime, + pub messages: Vec, + #[serde(default)] + pub initial_project_snapshot: Option>, +} + +impl LegacySerializedThread { + pub fn upgrade(self) -> SerializedThread { + SerializedThread { + version: SerializedThread::VERSION.to_string(), + summary: self.summary, + updated_at: self.updated_at, + messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(), + initial_project_snapshot: self.initial_project_snapshot, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct LegacySerializedMessage { + pub id: MessageId, + pub role: Role, + pub text: String, + #[serde(default)] + pub tool_uses: Vec, + #[serde(default)] + pub tool_results: Vec, +} + +impl LegacySerializedMessage { + fn upgrade(self) -> SerializedMessage { + SerializedMessage { + id: self.id, + role: self.role, + segments: vec![SerializedMessageSegment::Text { text: self.text }], + tool_uses: self.tool_uses, + tool_results: self.tool_results, + } + } +} + struct GlobalThreadsDatabase( Shared, Arc>>>, ); @@ -300,7 +384,25 @@ impl Global for GlobalThreadsDatabase {} pub(crate) struct ThreadsDatabase { executor: BackgroundExecutor, env: heed::Env, - threads: Database, SerdeJson>, + threads: Database, SerializedThread>, +} + +impl heed::BytesEncode<'_> for SerializedThread { + type EItem = 
SerializedThread; + + fn bytes_encode(item: &Self::EItem) -> Result, heed::BoxedError> { + serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into) + } +} + +impl<'a> heed::BytesDecode<'a> for SerializedThread { + type DItem = SerializedThread; + + fn bytes_decode(bytes: &'a [u8]) -> Result { + // We implement this type manually because we want to call `SerializedThread::from_json`, + // instead of the Deserialize trait implementation for `SerializedThread`. + SerializedThread::from_json(bytes).map_err(Into::into) + } } impl ThreadsDatabase { diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs index ef15afd05c..1ebb2a9289 100644 --- a/crates/assistant_context_editor/src/context.rs +++ b/crates/assistant_context_editor/src/context.rs @@ -162,6 +162,11 @@ pub enum ContextOperation { section: SlashCommandOutputSection, version: clock::Global, }, + ThoughtProcessOutputSectionAdded { + timestamp: clock::Lamport, + section: ThoughtProcessOutputSection, + version: clock::Global, + }, BufferOperation(language::Operation), } @@ -259,6 +264,20 @@ impl ContextOperation { version: language::proto::deserialize_version(&message.version), }) } + proto::context_operation::Variant::ThoughtProcessOutputSectionAdded(message) => { + let section = message.section.context("missing section")?; + Ok(Self::ThoughtProcessOutputSectionAdded { + timestamp: language::proto::deserialize_timestamp( + message.timestamp.context("missing timestamp")?, + ), + section: ThoughtProcessOutputSection { + range: language::proto::deserialize_anchor_range( + section.range.context("invalid range")?, + )?, + }, + version: language::proto::deserialize_version(&message.version), + }) + } proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation( language::proto::deserialize_operation( op.operation.context("invalid buffer operation")?, @@ -370,6 +389,27 @@ impl ContextOperation { }, )), }, + 
Self::ThoughtProcessOutputSectionAdded { + timestamp, + section, + version, + } => proto::ContextOperation { + variant: Some( + proto::context_operation::Variant::ThoughtProcessOutputSectionAdded( + proto::context_operation::ThoughtProcessOutputSectionAdded { + timestamp: Some(language::proto::serialize_timestamp(*timestamp)), + section: Some({ + proto::ThoughtProcessOutputSection { + range: Some(language::proto::serialize_anchor_range( + section.range.clone(), + )), + } + }), + version: language::proto::serialize_version(version), + }, + ), + ), + }, Self::BufferOperation(operation) => proto::ContextOperation { variant: Some(proto::context_operation::Variant::BufferOperation( proto::context_operation::BufferOperation { @@ -387,7 +427,8 @@ impl ContextOperation { Self::UpdateSummary { summary, .. } => summary.timestamp, Self::SlashCommandStarted { id, .. } => id.0, Self::SlashCommandOutputSectionAdded { timestamp, .. } - | Self::SlashCommandFinished { timestamp, .. } => *timestamp, + | Self::SlashCommandFinished { timestamp, .. } + | Self::ThoughtProcessOutputSectionAdded { timestamp, .. } => *timestamp, Self::BufferOperation(_) => { panic!("reading the timestamp of a buffer operation is not supported") } @@ -402,7 +443,8 @@ impl ContextOperation { | Self::UpdateSummary { version, .. } | Self::SlashCommandStarted { version, .. } | Self::SlashCommandOutputSectionAdded { version, .. } - | Self::SlashCommandFinished { version, .. } => version, + | Self::SlashCommandFinished { version, .. } + | Self::ThoughtProcessOutputSectionAdded { version, .. 
} => version, Self::BufferOperation(_) => { panic!("reading the version of a buffer operation is not supported") } @@ -418,6 +460,8 @@ pub enum ContextEvent { MessagesEdited, SummaryChanged, StreamedCompletion, + StartedThoughtProcess(Range), + EndedThoughtProcess(language::Anchor), PatchesUpdated { removed: Vec>, updated: Vec>, @@ -498,6 +542,17 @@ impl MessageMetadata { } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct ThoughtProcessOutputSection { + pub range: Range, +} + +impl ThoughtProcessOutputSection { + pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool { + self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty() + } +} + #[derive(Clone, Debug)] pub struct Message { pub offset_range: Range, @@ -580,6 +635,7 @@ pub struct AssistantContext { edits_since_last_parse: language::Subscription, slash_commands: Arc, slash_command_output_sections: Vec>, + thought_process_output_sections: Vec>, message_anchors: Vec, contents: Vec, messages_metadata: HashMap, @@ -682,6 +738,7 @@ impl AssistantContext { parsed_slash_commands: Vec::new(), invoked_slash_commands: HashMap::default(), slash_command_output_sections: Vec::new(), + thought_process_output_sections: Vec::new(), edits_since_last_parse: edits_since_last_slash_command_parse, summary: None, pending_summary: Task::ready(None), @@ -764,6 +821,18 @@ impl AssistantContext { } }) .collect(), + thought_process_output_sections: self + .thought_process_output_sections + .iter() + .filter_map(|section| { + if section.is_valid(buffer) { + let range = section.range.to_offset(buffer); + Some(ThoughtProcessOutputSection { range }) + } else { + None + } + }) + .collect(), } } @@ -957,6 +1026,16 @@ impl AssistantContext { cx.emit(ContextEvent::SlashCommandOutputSectionAdded { section }); } } + ContextOperation::ThoughtProcessOutputSectionAdded { section, .. 
} => { + let buffer = self.buffer.read(cx); + if let Err(ix) = self + .thought_process_output_sections + .binary_search_by(|probe| probe.range.cmp(§ion.range, buffer)) + { + self.thought_process_output_sections + .insert(ix, section.clone()); + } + } ContextOperation::SlashCommandFinished { id, error_message, @@ -1020,6 +1099,9 @@ impl AssistantContext { ContextOperation::SlashCommandOutputSectionAdded { section, .. } => { self.has_received_operations_for_anchor_range(section.range.clone(), cx) } + ContextOperation::ThoughtProcessOutputSectionAdded { section, .. } => { + self.has_received_operations_for_anchor_range(section.range.clone(), cx) + } ContextOperation::SlashCommandFinished { .. } => true, ContextOperation::BufferOperation(_) => { panic!("buffer operations should always be applied") @@ -1128,6 +1210,12 @@ impl AssistantContext { &self.slash_command_output_sections } + pub fn thought_process_output_sections( + &self, + ) -> &[ThoughtProcessOutputSection] { + &self.thought_process_output_sections + } + pub fn contains_files(&self, cx: &App) -> bool { let buffer = self.buffer.read(cx); self.slash_command_output_sections.iter().any(|section| { @@ -2168,6 +2256,35 @@ impl AssistantContext { ); } + fn insert_thought_process_output_section( + &mut self, + section: ThoughtProcessOutputSection, + cx: &mut Context, + ) { + let buffer = self.buffer.read(cx); + let insertion_ix = match self + .thought_process_output_sections + .binary_search_by(|probe| probe.range.cmp(§ion.range, buffer)) + { + Ok(ix) | Err(ix) => ix, + }; + self.thought_process_output_sections + .insert(insertion_ix, section.clone()); + // cx.emit(ContextEvent::ThoughtProcessOutputSectionAdded { + // section: section.clone(), + // }); + let version = self.version.clone(); + let timestamp = self.next_timestamp(); + self.push_op( + ContextOperation::ThoughtProcessOutputSectionAdded { + timestamp, + section, + version, + }, + cx, + ); + } + pub fn completion_provider_changed(&mut self, cx: &mut 
Context) { self.count_remaining_tokens(cx); } @@ -2220,6 +2337,10 @@ impl AssistantContext { let request_start = Instant::now(); let mut events = stream.await?; let mut stop_reason = StopReason::EndTurn; + let mut thought_process_stack = Vec::new(); + + const THOUGHT_PROCESS_START_MARKER: &str = "\n"; + const THOUGHT_PROCESS_END_MARKER: &str = "\n"; while let Some(event) = events.next().await { if response_latency.is_none() { @@ -2227,6 +2348,9 @@ impl AssistantContext { } let event = event?; + let mut context_event = None; + let mut thought_process_output_section = None; + this.update(cx, |this, cx| { let message_ix = this .message_anchors @@ -2245,7 +2369,50 @@ impl AssistantContext { LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; } - LanguageModelCompletionEvent::Text(chunk) => { + LanguageModelCompletionEvent::Thinking(chunk) => { + if thought_process_stack.is_empty() { + let start = + buffer.anchor_before(message_old_end_offset); + thought_process_stack.push(start); + let chunk = + format!("{THOUGHT_PROCESS_START_MARKER}{chunk}{THOUGHT_PROCESS_END_MARKER}"); + let chunk_len = chunk.len(); + buffer.edit( + [( + message_old_end_offset..message_old_end_offset, + chunk, + )], + None, + cx, + ); + let end = buffer + .anchor_before(message_old_end_offset + chunk_len); + context_event = Some( + ContextEvent::StartedThoughtProcess(start..end), + ); + } else { + // This ensures that all the thinking chunks are inserted inside the thinking tag + let insertion_position = + message_old_end_offset - THOUGHT_PROCESS_END_MARKER.len(); + buffer.edit( + [(insertion_position..insertion_position, chunk)], + None, + cx, + ); + } + } + LanguageModelCompletionEvent::Text(mut chunk) => { + if let Some(start) = thought_process_stack.pop() { + let end = buffer.anchor_before(message_old_end_offset); + context_event = + Some(ContextEvent::EndedThoughtProcess(end)); + thought_process_output_section = + Some(ThoughtProcessOutputSection { + range: start..end, + }); + 
chunk.insert_str(0, "\n\n"); + } + buffer.edit( [( message_old_end_offset..message_old_end_offset, @@ -2260,6 +2427,13 @@ impl AssistantContext { } }); + if let Some(section) = thought_process_output_section.take() { + this.insert_thought_process_output_section(section, cx); + } + if let Some(context_event) = context_event.take() { + cx.emit(context_event); + } + cx.emit(ContextEvent::StreamedCompletion); Some(()) @@ -3127,6 +3301,8 @@ pub struct SavedContext { pub summary: String, pub slash_command_output_sections: Vec>, + #[serde(default)] + pub thought_process_output_sections: Vec>, } impl SavedContext { @@ -3228,6 +3404,20 @@ impl SavedContext { version.observe(timestamp); } + for section in self.thought_process_output_sections { + let timestamp = next_timestamp.tick(); + operations.push(ContextOperation::ThoughtProcessOutputSectionAdded { + timestamp, + section: ThoughtProcessOutputSection { + range: buffer.anchor_after(section.range.start) + ..buffer.anchor_before(section.range.end), + }, + version: version.clone(), + }); + + version.observe(timestamp); + } + let timestamp = next_timestamp.tick(); operations.push(ContextOperation::UpdateSummary { summary: ContextSummary { @@ -3302,6 +3492,7 @@ impl SavedContextV0_3_0 { .collect(), summary: self.summary, slash_command_output_sections: self.slash_command_output_sections, + thought_process_output_sections: Vec::new(), } } } diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index bdc5a51a92..8fc2433476 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -64,7 +64,10 @@ use workspace::{ Workspace, }; -use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker}; +use crate::{ + slash_command::SlashCommandCompletionProvider, slash_command_picker, + ThoughtProcessOutputSection, +}; use crate::{ AssistantContext, AssistantPatch, 
AssistantPatchStatus, CacheStatus, Content, ContextEvent, ContextId, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId, @@ -120,6 +123,11 @@ enum AssistError { Message(SharedString), } +pub enum ThoughtProcessStatus { + Pending, + Completed, +} + pub trait AssistantPanelDelegate { fn active_context_editor( &self, @@ -178,6 +186,7 @@ pub struct ContextEditor { project: Entity, lsp_adapter_delegate: Option>, editor: Entity, + pending_thought_process: Option<(CreaseId, language::Anchor)>, blocks: HashMap, image_blocks: HashSet, scroll_position: Option, @@ -253,7 +262,8 @@ impl ContextEditor { cx.observe_global_in::(window, Self::settings_changed), ]; - let sections = context.read(cx).slash_command_output_sections().to_vec(); + let slash_command_sections = context.read(cx).slash_command_output_sections().to_vec(); + let thought_process_sections = context.read(cx).thought_process_output_sections().to_vec(); let patch_ranges = context.read(cx).patch_ranges().collect::>(); let slash_commands = context.read(cx).slash_commands().clone(); let mut this = Self { @@ -265,6 +275,7 @@ impl ContextEditor { image_blocks: Default::default(), scroll_position: None, remote_id: None, + pending_thought_process: None, fs: fs.clone(), workspace, project, @@ -294,7 +305,14 @@ impl ContextEditor { }; this.update_message_headers(cx); this.update_image_blocks(cx); - this.insert_slash_command_output_sections(sections, false, window, cx); + this.insert_slash_command_output_sections(slash_command_sections, false, window, cx); + this.insert_thought_process_output_sections( + thought_process_sections + .into_iter() + .map(|section| (section, ThoughtProcessStatus::Completed)), + window, + cx, + ); this.patches_updated(&Vec::new(), &patch_ranges, window, cx); this } @@ -599,6 +617,47 @@ impl ContextEditor { context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx); }); } + ContextEvent::StartedThoughtProcess(range) => { + let creases = 
self.insert_thought_process_output_sections( + [( + ThoughtProcessOutputSection { + range: range.clone(), + }, + ThoughtProcessStatus::Pending, + )], + window, + cx, + ); + self.pending_thought_process = Some((creases[0], range.start)); + } + ContextEvent::EndedThoughtProcess(end) => { + if let Some((crease_id, start)) = self.pending_thought_process.take() { + self.editor.update(cx, |editor, cx| { + let multi_buffer_snapshot = editor.buffer().read(cx).snapshot(cx); + let (excerpt_id, _, _) = multi_buffer_snapshot.as_singleton().unwrap(); + let start_anchor = multi_buffer_snapshot + .anchor_in_excerpt(*excerpt_id, start) + .unwrap(); + + editor.display_map.update(cx, |display_map, cx| { + display_map.unfold_intersecting( + vec![start_anchor..start_anchor], + true, + cx, + ); + }); + editor.remove_creases(vec![crease_id], cx); + }); + self.insert_thought_process_output_sections( + [( + ThoughtProcessOutputSection { range: start..*end }, + ThoughtProcessStatus::Completed, + )], + window, + cx, + ); + } + } ContextEvent::StreamedCompletion => { self.editor.update(cx, |editor, cx| { if let Some(scroll_position) = self.scroll_position { @@ -946,6 +1005,62 @@ impl ContextEditor { self.update_active_patch(window, cx); } + fn insert_thought_process_output_sections( + &mut self, + sections: impl IntoIterator< + Item = ( + ThoughtProcessOutputSection, + ThoughtProcessStatus, + ), + >, + window: &mut Window, + cx: &mut Context, + ) -> Vec { + self.editor.update(cx, |editor, cx| { + let buffer = editor.buffer().read(cx).snapshot(cx); + let excerpt_id = *buffer.as_singleton().unwrap().0; + let mut buffer_rows_to_fold = BTreeSet::new(); + let mut creases = Vec::new(); + for (section, status) in sections { + let start = buffer + .anchor_in_excerpt(excerpt_id, section.range.start) + .unwrap(); + let end = buffer + .anchor_in_excerpt(excerpt_id, section.range.end) + .unwrap(); + let buffer_row = MultiBufferRow(start.to_point(&buffer).row); + buffer_rows_to_fold.insert(buffer_row); + 
creases.push( + Crease::inline( + start..end, + FoldPlaceholder { + render: render_thought_process_fold_icon_button( + cx.entity().downgrade(), + status, + ), + merge_adjacent: false, + ..Default::default() + }, + render_slash_command_output_toggle, + |_, _, _, _| Empty.into_any_element(), + ) + .with_metadata(CreaseMetadata { + icon: IconName::Ai, + label: "Thinking Process".into(), + }), + ); + } + + let creases = editor.insert_creases(creases, cx); + + for buffer_row in buffer_rows_to_fold.into_iter().rev() { + editor.fold_at(&FoldAt { buffer_row }, window, cx); + } + + creases + }) + } + fn insert_slash_command_output_sections( &mut self, sections: impl IntoIterator>, @@ -2652,6 +2767,52 @@ fn find_surrounding_code_block(snapshot: &BufferSnapshot, offset: usize) -> Opti None } +fn render_thought_process_fold_icon_button( + editor: WeakEntity, + status: ThoughtProcessStatus, +) -> Arc, &mut App) -> AnyElement> { + Arc::new(move |fold_id, fold_range, _cx| { + let editor = editor.clone(); + + let button = ButtonLike::new(fold_id).layer(ElevationIndex::ElevatedSurface); + let button = match status { + ThoughtProcessStatus::Pending => button + .child( + Icon::new(IconName::Brain) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child( + Label::new("Thinking…").color(Color::Muted).with_animation( + "pulsating-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 0.8)), + |label, delta| label.alpha(delta), + ), + ), + ThoughtProcessStatus::Completed => button + .style(ButtonStyle::Filled) + .child(Icon::new(IconName::Brain).size(IconSize::Small)) + .child(Label::new("Thought Process").single_line()), + }; + + button + .on_click(move |_, window, cx| { + editor + .update(cx, |editor, cx| { + let buffer_start = fold_range + .start + .to_point(&editor.buffer().read(cx).read(cx)); + let buffer_row = MultiBufferRow(buffer_start.row); + editor.unfold_at(&UnfoldAt { buffer_row }, window, cx); + }) + .ok(); + }) + 
.into_any_element() + }) +} + fn render_fold_icon_button( editor: WeakEntity, icon: IconName, diff --git a/crates/assistant_eval/src/eval.rs b/crates/assistant_eval/src/eval.rs index 2268cf78ab..8801472e4d 100644 --- a/crates/assistant_eval/src/eval.rs +++ b/crates/assistant_eval/src/eval.rs @@ -120,7 +120,7 @@ impl Eval { .count(); Ok(EvalOutput { diff, - last_message: last_message.text.clone(), + last_message: last_message.to_string(), elapsed_time, assistant_response_count, tool_use_counts: assistant.tool_use_counts.clone(), diff --git a/crates/assistant_eval/src/headless_assistant.rs b/crates/assistant_eval/src/headless_assistant.rs index 0eb63d84ac..31bbc69f31 100644 --- a/crates/assistant_eval/src/headless_assistant.rs +++ b/crates/assistant_eval/src/headless_assistant.rs @@ -89,7 +89,7 @@ impl HeadlessAssistant { ThreadEvent::DoneStreaming => { let thread = thread.read(cx); if let Some(message) = thread.messages().last() { - println!("Message: {}", message.text,); + println!("Message: {}", message.to_string()); } if thread.all_tools_finished() { self.done_tx.send_blocking(Ok(())).unwrap() diff --git a/crates/gpui/src/elements/div.rs b/crates/gpui/src/elements/div.rs index 3e8f497b1a..2f573570f0 100644 --- a/crates/gpui/src/elements/div.rs +++ b/crates/gpui/src/elements/div.rs @@ -1240,20 +1240,11 @@ impl Element for Div { let mut state = scroll_handle.0.borrow_mut(); state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len()); state.bounds = bounds; - let requested = state.requested_scroll_top.take(); - - for (ix, child_layout_id) in request_layout.child_layout_ids.iter().enumerate() { + for child_layout_id in &request_layout.child_layout_ids { let child_bounds = window.layout_bounds(*child_layout_id); child_min = child_min.min(&child_bounds.origin); child_max = child_max.max(&child_bounds.bottom_right()); state.child_bounds.push(child_bounds); - - if let Some(requested) = requested.as_ref() { - if requested.0 == ix { - 
*state.offset.borrow_mut() = - bounds.origin - (child_bounds.origin - point(px(0.), requested.1)); - } - } } (child_max - child_min).into() } else { @@ -1528,8 +1519,11 @@ impl Interactivity { _cx: &mut App, ) -> Point { if let Some(scroll_offset) = self.scroll_offset.as_ref() { + let mut scroll_to_bottom = false; if let Some(scroll_handle) = &self.tracked_scroll_handle { - scroll_handle.0.borrow_mut().overflow = style.overflow; + let mut state = scroll_handle.0.borrow_mut(); + state.overflow = style.overflow; + scroll_to_bottom = mem::take(&mut state.scroll_to_bottom); } let rem_size = window.rem_size(); @@ -1555,8 +1549,14 @@ impl Interactivity { // Clamp scroll offset in case scroll max is smaller now (e.g., if children // were removed or the bounds became larger). let mut scroll_offset = scroll_offset.borrow_mut(); + scroll_offset.x = scroll_offset.x.clamp(-scroll_max.width, px(0.)); - scroll_offset.y = scroll_offset.y.clamp(-scroll_max.height, px(0.)); + if scroll_to_bottom { + scroll_offset.y = -scroll_max.height; + } else { + scroll_offset.y = scroll_offset.y.clamp(-scroll_max.height, px(0.)); + } + *scroll_offset } else { Point::default() @@ -2861,12 +2861,13 @@ impl ScrollAnchor { }); } } + #[derive(Default, Debug)] struct ScrollHandleState { offset: Rc>>, bounds: Bounds, child_bounds: Vec>, - requested_scroll_top: Option<(usize, Pixels)>, + scroll_to_bottom: bool, overflow: Point, } @@ -2955,6 +2956,12 @@ impl ScrollHandle { } } + /// Scrolls to the bottom. + pub fn scroll_to_bottom(&self) { + let mut state = self.0.borrow_mut(); + state.scroll_to_bottom = true; + } + /// Set the offset explicitly. The offset is the distance from the top left of the /// parent container to the top left of the first child. /// As you scroll further down the offset becomes more negative. @@ -2978,11 +2985,6 @@ impl ScrollHandle { } } - /// Set the logical scroll top, based on a child index and a pixel offset. 
- pub fn set_logical_scroll_top(&self, ix: usize, px: Pixels) { - self.0.borrow_mut().requested_scroll_top = Some((ix, px)); - } - /// Get the count of children for scrollable item. pub fn children_count(&self) -> usize { self.0.borrow().child_bounds.len() diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 8ed190e731..a016ae0b37 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -59,6 +59,7 @@ pub struct LanguageModelCacheConfiguration { pub enum LanguageModelCompletionEvent { Stop(StopReason), Text(String), + Thinking(String), ToolUse(LanguageModelToolUse), StartMessage { message_id: String }, UsageUpdate(TokenUsage), @@ -217,6 +218,7 @@ pub trait LanguageModel: Send + Sync { match result { Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None, Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)), + Ok(LanguageModelCompletionEvent::Thinking(_)) => None, Ok(LanguageModelCompletionEvent::Stop(_)) => None, Ok(LanguageModelCompletionEvent::ToolUse(_)) => None, Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None, diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 44e2199a84..3a610cad73 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -72,7 +72,9 @@ impl CloudModel { pub fn availability(&self) -> LanguageModelAvailability { match self { Self::Anthropic(model) => match model { - anthropic::Model::Claude3_5Sonnet | anthropic::Model::Claude3_7Sonnet => { + anthropic::Model::Claude3_5Sonnet + | anthropic::Model::Claude3_7Sonnet + | anthropic::Model::Claude3_7SonnetThinking => { LanguageModelAvailability::RequiresPlan(Plan::Free) } anthropic::Model::Claude3Opus diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 3bc6460ac7..7d2660571d 100644 
--- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1,6 +1,6 @@ use crate::ui::InstructionListItem; use crate::AllLanguageModelSettings; -use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage}; +use anthropic::{AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, Usage}; use anyhow::{anyhow, Context as _, Result}; use collections::{BTreeMap, HashMap}; use credentials_provider::CredentialsProvider; @@ -55,6 +55,37 @@ pub struct AvailableModel { pub default_temperature: Option, #[serde(default)] pub extra_beta_headers: Vec, + /// The model's mode (e.g. thinking) + pub mode: Option, +} + +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ModelMode { + #[default] + Default, + Thinking { + /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. + budget_tokens: Option, + }, +} + +impl From for AnthropicModelMode { + fn from(value: ModelMode) -> Self { + match value { + ModelMode::Default => AnthropicModelMode::Default, + ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens }, + } + } +} + +impl From for ModelMode { + fn from(value: AnthropicModelMode) -> Self { + match value { + AnthropicModelMode::Default => ModelMode::Default, + AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens }, + } + } } pub struct AnthropicLanguageModelProvider { @@ -228,6 +259,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { max_output_tokens: model.max_output_tokens, default_temperature: model.default_temperature, extra_beta_headers: model.extra_beta_headers.clone(), + mode: model.mode.clone().unwrap_or_default().into(), }, ); } @@ -399,9 +431,10 @@ impl LanguageModel for AnthropicModel { ) -> BoxFuture<'static, Result>>> { let request = into_anthropic( request, - 
self.model.id().into(), + self.model.request_id().into(), self.model.default_temperature(), self.model.max_output_tokens(), + self.model.mode(), ); let request = self.stream_completion(request, cx); let future = self.request_limiter.stream(async move { @@ -434,6 +467,7 @@ impl LanguageModel for AnthropicModel { self.model.tool_model_id().into(), self.model.default_temperature(), self.model.max_output_tokens(), + self.model.mode(), ); request.tool_choice = Some(anthropic::ToolChoice::Tool { name: tool_name.clone(), @@ -464,6 +498,7 @@ pub fn into_anthropic( model: String, default_temperature: f32, max_output_tokens: u32, + mode: AnthropicModelMode, ) -> anthropic::Request { let mut new_messages: Vec = Vec::new(); let mut system_message = String::new(); @@ -552,6 +587,11 @@ pub fn into_anthropic( messages: new_messages, max_tokens: max_output_tokens, system: Some(system_message), + thinking: if let AnthropicModelMode::Thinking { budget_tokens } = mode { + Some(anthropic::Thinking::Enabled { budget_tokens }) + } else { + None + }, tools: request .tools .into_iter() @@ -607,6 +647,16 @@ pub fn map_to_language_model_completion_events( state, )); } + ResponseContent::Thinking { thinking } => { + return Some(( + vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))], + state, + )); + } + ResponseContent::RedactedThinking { .. } => { + // Redacted thinking is encrypted and not accessible to the user, see: + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production + } ResponseContent::ToolUse { id, name, .. } => { state.tool_uses_by_index.insert( index, @@ -625,6 +675,13 @@ pub fn map_to_language_model_completion_events( state, )); } + ContentDelta::ThinkingDelta { thinking } => { + return Some(( + vec![Ok(LanguageModelCompletionEvent::Thinking(thinking))], + state, + )); + } + ContentDelta::SignatureDelta { .. 
} => {} ContentDelta::InputJsonDelta { partial_json } => { if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) { tool_use.input_json.push_str(&partial_json); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 58eef9aafc..798a25084c 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,4 +1,4 @@ -use anthropic::AnthropicError; +use anthropic::{AnthropicError, AnthropicModelMode}; use anyhow::{anyhow, Result}; use client::{ zed_urls, Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME, @@ -91,6 +91,28 @@ pub struct AvailableModel { /// Any extra beta headers to provide when using the model. #[serde(default)] pub extra_beta_headers: Vec, + /// The model's mode (e.g. thinking) + pub mode: Option, +} + +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ModelMode { + #[default] + Default, + Thinking { + /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. 
+ budget_tokens: Option, + }, +} + +impl From for AnthropicModelMode { + fn from(value: ModelMode) -> Self { + match value { + ModelMode::Default => AnthropicModelMode::Default, + ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens }, + } + } } pub struct CloudLanguageModelProvider { @@ -299,6 +321,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { default_temperature: model.default_temperature, max_output_tokens: model.max_output_tokens, extra_beta_headers: model.extra_beta_headers.clone(), + mode: model.mode.unwrap_or_default().into(), }), AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { name: model.name.clone(), @@ -567,9 +590,10 @@ impl LanguageModel for CloudLanguageModel { CloudModel::Anthropic(model) => { let request = into_anthropic( request, - model.id().into(), + model.request_id().into(), model.default_temperature(), model.max_output_tokens(), + model.mode(), ); let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); @@ -669,6 +693,7 @@ impl LanguageModel for CloudLanguageModel { model.tool_model_id().into(), model.default_temperature(), model.max_output_tokens(), + model.mode(), ); request.tool_choice = Some(anthropic::ToolChoice::Tool { name: tool_name.clone(), diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index a274d8e262..6e25ca3c40 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -109,6 +109,7 @@ impl AnthropicSettingsContent { max_output_tokens, default_temperature, extra_beta_headers, + mode, } => Some(provider::anthropic::AvailableModel { name, display_name, @@ -124,6 +125,7 @@ impl AnthropicSettingsContent { max_output_tokens, default_temperature, extra_beta_headers, + mode: Some(mode.into()), }), _ => None, }) diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 6573a57c86..f9314a97b5 100644 --- a/crates/proto/proto/zed.proto 
+++ b/crates/proto/proto/zed.proto @@ -2503,6 +2503,10 @@ message SlashCommandOutputSection { optional string metadata = 4; } +message ThoughtProcessOutputSection { + AnchorRange range = 1; +} + message ContextOperation { oneof variant { InsertMessage insert_message = 1; @@ -2512,6 +2516,7 @@ message ContextOperation { SlashCommandStarted slash_command_started = 6; SlashCommandOutputSectionAdded slash_command_output_section_added = 7; SlashCommandCompleted slash_command_completed = 8; + ThoughtProcessOutputSectionAdded thought_process_output_section_added = 9; } reserved 4; @@ -2556,6 +2561,12 @@ message ContextOperation { repeated VectorClockEntry version = 5; } + message ThoughtProcessOutputSectionAdded { + LamportTimestamp timestamp = 1; + ThoughtProcessOutputSection section = 2; + repeated VectorClockEntry version = 3; + } + message BufferOperation { Operation operation = 1; } diff --git a/docs/src/assistant/configuration.md b/docs/src/assistant/configuration.md index 354a581c66..90d215281c 100644 --- a/docs/src/assistant/configuration.md +++ b/docs/src/assistant/configuration.md @@ -68,6 +68,21 @@ You can add custom models to the Anthropic provider by adding the following to y Custom models will be listed in the model dropdown in the assistant panel. +You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it), +by changing the mode in your model's configuration to `thinking`, for example: + +```json +{ + "name": "claude-3-7-sonnet-latest", + "display_name": "claude-3-7-sonnet-thinking", + "max_tokens": 200000, + "mode": { + "type": "thinking", + "budget_tokens": 4096 + } +} +``` + ### GitHub Copilot Chat {#github-copilot-chat} You can use GitHub Copilot chat with the Zed assistant by choosing it via the model dropdown in the assistant panel.