diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 841482e482..c0d42bf5a2 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -6,7 +6,7 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::message_editor::insert_message_creases; use crate::thread::{ LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError, - ThreadEvent, ThreadFeedback, + ThreadEvent, ThreadFeedback, ThreadSummary, }; use crate::thread_store::{RulesLoadingError, TextThreadStore, ThreadStore}; use crate::tool_use::{PendingToolUseStatus, ToolUse}; @@ -823,12 +823,12 @@ impl ActiveThread { self.messages.is_empty() } - pub fn summary(&self, cx: &App) -> Option { + pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary { self.thread.read(cx).summary() } - pub fn summary_or_default(&self, cx: &App) -> SharedString { - self.thread.read(cx).summary_or_default() + pub fn regenerate_summary(&self, cx: &mut App) { + self.thread.update(cx, |thread, cx| thread.summarize(cx)) } pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool { @@ -1134,11 +1134,7 @@ impl ActiveThread { return; } - let title = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Panel".into()); + let title = self.thread.read(cx).summary().unwrap_or("Agent Panel"); match AssistantSettings::get_global(cx).notify_when_agent_waiting { NotifyWhenAgentWaiting::PrimaryScreen => { @@ -3442,10 +3438,7 @@ pub(crate) fn open_active_thread_as_markdown( workspace.update_in(cx, |workspace, window, cx| { let thread = thread.read(cx); let markdown = thread.to_markdown(cx)?; - let thread_summary = thread - .summary() - .map(|summary| summary.to_string()) - .unwrap_or_else(|| "Thread".to_string()); + let thread_summary = thread.summary().or_default().to_string(); let project = workspace.project().clone(); diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index 54e4f1f9aa..d83b2cf80d 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -215,11 +215,7 @@ impl AgentDiffPane { } fn update_title(&mut self, cx: &mut Context) { - let new_title = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Changes".into()); + let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); if new_title != self.title { self.title = new_title; cx.emit(EditorEvent::TitleChanged); @@ -469,11 +465,7 @@ impl Item for AgentDiffPane { } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { - let summary = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Changes".into()); + let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); Label::new(format!("Review: {}", summary)) .color(if params.selected { Color::Default diff --git a/crates/agent/src/agent_panel.rs b/crates/agent/src/agent_panel.rs index c48863ae41..eb1635bf4b 100644 --- a/crates/agent/src/agent_panel.rs +++ b/crates/agent/src/agent_panel.rs @@ -10,8 +10,8 @@ use serde::{Deserialize, Serialize}; use anyhow::{Result, anyhow}; use assistant_context_editor::{ AgentPanelDelegate, AssistantContext, ConfigurationError, ContextEditor, ContextEvent, - SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate, - render_remaining_tokens, + ContextSummary, SlashCommandCompletionProvider, humanize_token_count, + make_lsp_adapter_delegate, render_remaining_tokens, }; use assistant_settings::{AssistantDockPosition, AssistantSettings}; use assistant_slash_command::SlashCommandWorkingSet; @@ -59,7 +59,7 @@ use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent use crate::agent_diff::AgentDiff; use crate::history_store::{HistoryStore, RecentEntry}; use crate::message_editor::{MessageEditor, MessageEditorEvent}; -use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio}; +use crate::thread::{Thread, ThreadError, ThreadId, ThreadSummary, TokenUsageRatio}; use crate::thread_history::{HistoryEntryElement, ThreadHistory}; use crate::thread_store::ThreadStore; use crate::ui::AgentOnboardingModal; @@ -196,7 +196,7 @@ impl ActiveView { } pub fn thread(thread: Entity, window: &mut Window, cx: &mut App) -> Self { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); let editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); @@ -218,7 +218,7 @@ impl ActiveView { } EditorEvent::Blurred => { if editor.read(cx).text(cx).is_empty() { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -233,7 +233,7 @@ impl ActiveView { let editor = editor.clone(); move |thread, event, window, cx| match event { ThreadEvent::SummaryGenerated => { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -296,7 +296,8 @@ impl ActiveView { .read(cx) .context() .read(cx) - .summary_or_default(); + .summary() + .or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -311,7 +312,7 @@ impl ActiveView { let editor = editor.clone(); move |assistant_context, event, window, cx| match event { ContextEvent::SummaryGenerated => { - let summary = assistant_context.read(cx).summary_or_default(); + let summary = assistant_context.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -1452,23 +1453,45 @@ impl AgentPanel { .. } => { let active_thread = self.thread.read(cx); - let is_empty = active_thread.is_empty(); - - let summary = active_thread.summary(cx); - - if is_empty { - Label::new(Thread::DEFAULT_SUMMARY.clone()) - .truncate() - .into_any_element() - } else if summary.is_none() { - Label::new(LOADING_SUMMARY_PLACEHOLDER) - .truncate() - .into_any_element() + let state = if active_thread.is_empty() { + &ThreadSummary::Pending } else { - div() + active_thread.summary(cx) + }; + + match state { + ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone()) + .truncate() + .into_any_element(), + ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER) + .truncate() + .into_any_element(), + ThreadSummary::Ready(_) => div() .w_full() .child(change_title_editor.clone()) - .into_any_element() + .into_any_element(), + ThreadSummary::Error => h_flex() + .w_full() + .child(change_title_editor.clone()) + .child( + ui::IconButton::new("retry-summary-generation", IconName::RotateCcw) + .on_click({ + let active_thread = self.thread.clone(); + move |_, _window, cx| { + active_thread.update(cx, |thread, cx| { + thread.regenerate_summary(cx); + }); + } + }) + .tooltip(move |_window, cx| { + cx.new(|_| { + Tooltip::new("Failed to generate title") + .meta("Click to try again") + }) + .into() + }), + ) + .into_any_element(), } } ActiveView::PromptEditor { @@ -1476,14 +1499,13 @@ impl AgentPanel { context_editor, .. } => { - let context_editor = context_editor.read(cx); - let summary = context_editor.context().read(cx).summary(); + let summary = context_editor.read(cx).context().read(cx).summary(); match summary { - None => Label::new(AssistantContext::DEFAULT_SUMMARY.clone()) + ContextSummary::Pending => Label::new(ContextSummary::DEFAULT) .truncate() .into_any_element(), - Some(summary) => { + ContextSummary::Content(summary) => { if summary.done { div() .w_full() @@ -1495,6 +1517,28 @@ impl AgentPanel { .into_any_element() } } + ContextSummary::Error => h_flex() + .w_full() + .child(title_editor.clone()) + .child( + ui::IconButton::new("retry-summary-generation", IconName::RotateCcw) + .on_click({ + let context_editor = context_editor.clone(); + move |_, _window, cx| { + context_editor.update(cx, |context_editor, cx| { + context_editor.regenerate_summary(cx); + }); + } + }) + .tooltip(move |_window, cx| { + cx.new(|_| { + Tooltip::new("Failed to generate title") + .meta("Click to try again") + }) + .into() + }), + ) + .into_any_element(), } } ActiveView::History => Label::new("History").truncate().into_any_element(), diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index f458896b18..98437778aa 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -586,10 +586,7 @@ impl ThreadContextHandle { } pub fn title(&self, cx: &App) -> SharedString { - self.thread - .read(cx) - .summary() - .unwrap_or_else(|| "New thread".into()) + self.thread.read(cx).summary().or_default() } fn load(self, cx: &App) -> Task>)>> { @@ -597,9 +594,7 @@ impl ThreadContextHandle { let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?; let title = self .thread - .read_with(cx, |thread, _cx| { - thread.summary().unwrap_or_else(|| "New thread".into()) - }) + .read_with(cx, |thread, _cx| thread.summary().or_default()) .ok()?; let context = AgentContext::Thread(ThreadContext { title, @@ -642,7 +637,7 @@ impl TextThreadContextHandle { } pub fn title(&self, cx: &App) -> SharedString { - self.context.read(cx).summary_or_default() + self.context.read(cx).summary().or_default() } fn load(self, cx: &App) -> Task>)>> { diff --git a/crates/agent/src/context_strip.rs b/crates/agent/src/context_strip.rs index f9d9ff2781..8fe1a21d74 100644 --- a/crates/agent/src/context_strip.rs +++ b/crates/agent/src/context_strip.rs @@ -160,7 +160,7 @@ impl ContextStrip { } Some(SuggestedContext::Thread { - name: active_thread.summary_or_default(), + name: active_thread.summary().or_default(), thread: weak_active_thread, }) } else if let Some(active_context_editor) = panel.active_context_editor() { @@ -174,7 +174,7 @@ impl ContextStrip { } Some(SuggestedContext::TextThread { - name: context.summary_or_default(), + name: context.summary().or_default(), context: weak_context, }) } else { diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index 85fdac2bab..c8d9e9a263 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -71,8 +71,8 @@ impl Eq for RecentEntry {} impl RecentEntry { pub(crate) fn summary(&self, cx: &App) -> SharedString { match self { - RecentEntry::Thread(_, thread) => thread.read(cx).summary_or_default(), - RecentEntry::Context(context) => context.read(cx).summary_or_default(), + RecentEntry::Thread(_, thread) => thread.read(cx).summary().or_default(), + RecentEntry::Context(context) => context.read(cx).summary().or_default(), } } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index c6f2d74ff9..a65eda5b40 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -36,7 +36,7 @@ use serde::{Deserialize, Serialize}; use settings::Settings; use thiserror::Error; use ui::Window; -use util::{ResultExt as _, TryFutureExt as _, post_inc}; +use util::{ResultExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::CompletionRequestStatus; @@ -324,7 +324,7 @@ pub enum QueueState { pub struct Thread { id: ThreadId, updated_at: DateTime, - summary: Option, + summary: ThreadSummary, pending_summary: Task>, detailed_summary_task: Task>, detailed_summary_tx: postage::watch::Sender, @@ -361,6 +361,33 @@ pub struct Thread { configured_model: Option, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ThreadSummary { + Pending, + Generating, + Ready(SharedString), + Error, +} + +impl ThreadSummary { + pub const DEFAULT: SharedString = SharedString::new_static("New Thread"); + + pub fn or_default(&self) -> SharedString { + self.unwrap_or(Self::DEFAULT) + } + + pub fn unwrap_or(&self, message: impl Into) -> SharedString { + self.ready().unwrap_or_else(|| message.into()) + } + + pub fn ready(&self) -> Option { + match self { + ThreadSummary::Ready(summary) => Some(summary.clone()), + ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExceededWindowError { /// Model used when last message exceeded context window @@ -383,7 +410,7 @@ impl Thread { Self { id: ThreadId::new(), updated_at: Utc::now(), - summary: None, + summary: ThreadSummary::Pending, pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, @@ -471,7 +498,7 @@ impl Thread { Self { id, updated_at: serialized.updated_at, - summary: Some(serialized.summary), + summary: ThreadSummary::Ready(serialized.summary), pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, @@ -572,10 +599,6 @@ impl Thread { self.last_prompt_id = PromptId::new(); } - pub fn summary(&self) -> Option { - self.summary.clone() - } - pub fn project_context(&self) -> SharedProjectContext { self.project_context.clone() } @@ -596,26 +619,25 @@ impl Thread { cx.notify(); } - pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread"); - - pub fn summary_or_default(&self) -> SharedString { - self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY) + pub fn summary(&self) -> &ThreadSummary { + &self.summary } pub fn set_summary(&mut self, new_summary: impl Into, cx: &mut Context) { - let Some(current_summary) = &self.summary else { - // Don't allow setting summary until generated - return; + let current_summary = match &self.summary { + ThreadSummary::Pending | ThreadSummary::Generating => return, + ThreadSummary::Ready(summary) => summary, + ThreadSummary::Error => &ThreadSummary::DEFAULT, }; let mut new_summary = new_summary.into(); if new_summary.is_empty() { - new_summary = Self::DEFAULT_SUMMARY; + new_summary = ThreadSummary::DEFAULT; } if current_summary != &new_summary { - self.summary = Some(new_summary); + self.summary = ThreadSummary::Ready(new_summary); cx.emit(ThreadEvent::SummaryChanged); } } @@ -1029,7 +1051,7 @@ impl Thread { let initial_project_snapshot = initial_project_snapshot.await; this.read_with(cx, |this, cx| SerializedThread { version: SerializedThread::VERSION.to_string(), - summary: this.summary_or_default(), + summary: this.summary().or_default(), updated_at: this.updated_at(), messages: this .messages() @@ -1625,7 +1647,7 @@ impl Thread { // If there is a response without tool use, summarize the message. Otherwise, // allow two tool uses before summarizing. - if thread.summary.is_none() + if matches!(thread.summary, ThreadSummary::Pending) && thread.messages.len() >= 2 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6) { @@ -1739,6 +1761,7 @@ impl Thread { pub fn summarize(&mut self, cx: &mut Context) { let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else { + println!("No thread summary model"); return; }; @@ -1753,13 +1776,17 @@ impl Thread { let request = self.to_summarize_request(&model.model, added_user_message.into(), cx); + self.summary = ThreadSummary::Generating; + self.pending_summary = cx.spawn(async move |this, cx| { - async move { + let result = async { let mut messages = model.model.stream_completion(request, &cx).await?; let mut new_summary = String::new(); while let Some(event) = messages.next().await { - let event = event?; + let Ok(event) = event else { + continue; + }; let text = match event { LanguageModelCompletionEvent::Text(text) => text, LanguageModelCompletionEvent::StatusUpdate( @@ -1785,18 +1812,29 @@ impl Thread { } } - this.update(cx, |this, cx| { - if !new_summary.is_empty() { - this.summary = Some(new_summary.into()); - } - - cx.emit(ThreadEvent::SummaryGenerated); - })?; - - anyhow::Ok(()) + anyhow::Ok(new_summary) } - .log_err() - .await + .await; + + this.update(cx, |this, cx| { + match result { + Ok(new_summary) => { + if new_summary.is_empty() { + this.summary = ThreadSummary::Error; + } else { + this.summary = ThreadSummary::Ready(new_summary.into()); + } + } + Err(err) => { + this.summary = ThreadSummary::Error; + log::error!("Failed to generate thread summary: {}", err); + } + } + cx.emit(ThreadEvent::SummaryGenerated); + }) + .log_err()?; + + Some(()) }); } @@ -2406,9 +2444,8 @@ impl Thread { pub fn to_markdown(&self, cx: &App) -> Result { let mut markdown = Vec::new(); - if let Some(summary) = self.summary() { - writeln!(markdown, "# {summary}\n")?; - }; + let summary = self.summary().or_default(); + writeln!(markdown, "# {summary}\n")?; for message in self.messages() { writeln!( @@ -2725,7 +2762,7 @@ mod tests { use assistant_tool::ToolRegistry; use editor::EditorSettings; use gpui::TestAppContext; - use language_model::fake_provider::FakeLanguageModel; + use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; use project::{FakeFs, Project}; use prompt_store::PromptBuilder; use serde_json::json; @@ -3226,6 +3263,196 @@ fn main() {{ assert_eq!(request.temperature, None); } + #[gpui::test] + async fn test_thread_summary(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + // Initial state should be pending + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Pending)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + // Manually setting the summary should not be allowed in this state + thread.update(cx, |thread, cx| { + thread.set_summary("This should not work", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Pending)); + }); + + // Send a message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + // Should start generating summary when there are >= 2 messages + thread.read_with(cx, |thread, _| { + assert_eq!(*thread.summary(), ThreadSummary::Generating); + }); + + // Should not be able to set the summary while generating + thread.update(cx, |thread, cx| { + thread.set_summary("This should not work either", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("Brief".into()); + fake_model.stream_last_completion_response(" Introduction".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Summary should be set + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "Brief Introduction"); + }); + + // Now we should be able to set a summary + thread.update(cx, |thread, cx| { + thread.set_summary("Brief Intro", cx); + }); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.summary().or_default(), "Brief Intro"); + }); + + // Test setting an empty summary (should default to DEFAULT) + thread.update(cx, |thread, cx| { + thread.set_summary("", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + } + + #[gpui::test] + async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + test_summarize_error(&model, &thread, cx); + + // Now we should be able to set a summary + thread.update(cx, |thread, cx| { + thread.set_summary("Brief Intro", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "Brief Intro"); + }); + } + + #[gpui::test] + async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + test_summarize_error(&model, &thread, cx); + + // Sending another message should not trigger another summarize request + thread.update(cx, |thread, cx| { + thread.insert_user_message( + "How are you?", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + thread.read_with(cx, |thread, _| { + // State is still Error, not Generating + assert!(matches!(thread.summary(), ThreadSummary::Error)); + }); + + // But the summarize request can be invoked manually + thread.update(cx, |thread, cx| { + thread.summarize(cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("A successful summary".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "A successful summary"); + }); + } + + fn test_summarize_error( + model: &Arc, + thread: &Entity, + cx: &mut TestAppContext, + ) { + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + // Simulate summary request ending + cx.run_until_parked(); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // State is set to Error and default message + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Error)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + } + + fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { + cx.run_until_parked(); + fake_model.stream_last_completion_response("Assistant response".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + } + fn init_test_settings(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); @@ -3282,9 +3509,29 @@ fn main() {{ let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - let model = FakeLanguageModel::default(); + let provider = Arc::new(FakeLanguageModelProvider); + let model = provider.test_model(); let model: Arc = Arc::new(model); + cx.update(|_, cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model( + Some(ConfiguredModel { + provider: provider.clone(), + model: model.clone(), + }), + cx, + ); + registry.set_thread_summary_model( + Some(ConfiguredModel { + provider, + model: model.clone(), + }), + cx, + ); + }) + }); + (workspace, thread_store, thread, context_store, model) } diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs index 28f03de1a6..047ca89db0 100644 --- a/crates/assistant_context_editor/src/context.rs +++ b/crates/assistant_context_editor/src/context.rs @@ -1,7 +1,7 @@ #[cfg(test)] mod context_tests; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result, anyhow, bail}; use assistant_settings::AssistantSettings; use assistant_slash_command::{ SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection, @@ -133,7 +133,7 @@ pub enum ContextOperation { version: clock::Global, }, UpdateSummary { - summary: ContextSummary, + summary: ContextSummaryContent, version: clock::Global, }, SlashCommandStarted { @@ -203,7 +203,7 @@ impl ContextOperation { version: language::proto::deserialize_version(&update.version), }), proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary { - summary: ContextSummary { + summary: ContextSummaryContent { text: update.summary, done: update.done, timestamp: language::proto::deserialize_timestamp( @@ -467,11 +467,73 @@ pub enum ContextEvent { Operation(ContextOperation), } -#[derive(Clone, Default, Debug)] -pub struct ContextSummary { +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum ContextSummary { + Pending, + Content(ContextSummaryContent), + Error, +} + +#[derive(Default, Clone, Debug, Eq, PartialEq)] +pub struct ContextSummaryContent { pub text: String, pub done: bool, - timestamp: clock::Lamport, + pub timestamp: clock::Lamport, +} + +impl ContextSummary { + pub const DEFAULT: &str = "New Text Thread"; + + pub fn or_default(&self) -> SharedString { + self.unwrap_or(Self::DEFAULT) + } + + pub fn unwrap_or(&self, message: impl Into) -> SharedString { + self.content() + .map_or_else(|| message.into(), |content| content.text.clone().into()) + } + + pub fn content(&self) -> Option<&ContextSummaryContent> { + match self { + ContextSummary::Content(content) => Some(content), + ContextSummary::Pending | ContextSummary::Error => None, + } + } + + fn content_as_mut(&mut self) -> Option<&mut ContextSummaryContent> { + match self { + ContextSummary::Content(content) => Some(content), + ContextSummary::Pending | ContextSummary::Error => None, + } + } + + fn content_or_set_empty(&mut self) -> &mut ContextSummaryContent { + match self { + ContextSummary::Content(content) => content, + ContextSummary::Pending | ContextSummary::Error => { + let content = ContextSummaryContent::default(); + *self = ContextSummary::Content(content); + self.content_as_mut().unwrap() + } + } + } + + pub fn is_pending(&self) -> bool { + matches!(self, ContextSummary::Pending) + } + + fn timestamp(&self) -> Option { + match self { + ContextSummary::Content(content) => Some(content.timestamp), + ContextSummary::Pending | ContextSummary::Error => None, + } + } +} + +impl PartialOrd for ContextSummary { + fn partial_cmp(&self, other: &Self) -> Option { + self.timestamp().partial_cmp(&other.timestamp()) + } } #[derive(Clone, Debug, Eq, PartialEq)] @@ -607,7 +669,7 @@ pub struct AssistantContext { message_anchors: Vec, contents: Vec, messages_metadata: HashMap, - summary: Option, + summary: ContextSummary, summary_task: Task>, completion_count: usize, pending_completions: Vec, @@ -694,7 +756,7 @@ impl AssistantContext { slash_command_output_sections: Vec::new(), thought_process_output_sections: Vec::new(), edits_since_last_parse: edits_since_last_slash_command_parse, - summary: None, + summary: ContextSummary::Pending, summary_task: Task::ready(None), completion_count: Default::default(), pending_completions: Default::default(), @@ -753,7 +815,7 @@ impl AssistantContext { .collect(), summary: self .summary - .as_ref() + .content() .map(|summary| summary.text.clone()) .unwrap_or_default(), slash_command_output_sections: self @@ -939,12 +1001,10 @@ impl AssistantContext { summary: new_summary, .. } => { - if self - .summary - .as_ref() - .map_or(true, |summary| new_summary.timestamp > summary.timestamp) - { - self.summary = Some(new_summary); + if self.summary.timestamp().map_or(true, |current_timestamp| { + new_summary.timestamp > current_timestamp + }) { + self.summary = ContextSummary::Content(new_summary); summary_generated = true; } } @@ -1102,8 +1162,8 @@ impl AssistantContext { self.path.as_ref() } - pub fn summary(&self) -> Option<&ContextSummary> { - self.summary.as_ref() + pub fn summary(&self) -> &ContextSummary { + &self.summary } pub fn parsed_slash_commands(&self) -> &[ParsedSlashCommand] { @@ -2576,7 +2636,7 @@ impl AssistantContext { return; }; - if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) { + if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_pending()) { if !model.provider.is_authenticated(cx) { return; } @@ -2593,17 +2653,20 @@ impl AssistantContext { // If there is no summary, it is set with `done: false` so that "Loading Summary…" can // be displayed. - if self.summary.is_none() { - self.summary = Some(ContextSummary { - text: "".to_string(), - done: false, - timestamp: clock::Lamport::default(), - }); - replace_old = true; + match self.summary { + ContextSummary::Pending | ContextSummary::Error => { + self.summary = ContextSummary::Content(ContextSummaryContent { + text: "".to_string(), + done: false, + timestamp: clock::Lamport::default(), + }); + replace_old = true; + } + ContextSummary::Content(_) => {} } self.summary_task = cx.spawn(async move |this, cx| { - async move { + let result = async { let stream = model.model.stream_completion_text(request, &cx); let mut messages = stream.await?; @@ -2614,7 +2677,7 @@ impl AssistantContext { this.update(cx, |this, cx| { let version = this.version.clone(); let timestamp = this.next_timestamp(); - let summary = this.summary.get_or_insert(ContextSummary::default()); + let summary = this.summary.content_or_set_empty(); if !replaced && replace_old { summary.text.clear(); replaced = true; @@ -2636,10 +2699,19 @@ impl AssistantContext { } } + this.read_with(cx, |this, _cx| { + if let Some(summary) = this.summary.content() { + if summary.text.is_empty() { + bail!("Model generated an empty summary"); + } + } + Ok(()) + })??; + this.update(cx, |this, cx| { let version = this.version.clone(); let timestamp = this.next_timestamp(); - if let Some(summary) = this.summary.as_mut() { + if let Some(summary) = this.summary.content_as_mut() { summary.done = true; summary.timestamp = timestamp; let operation = ContextOperation::UpdateSummary { @@ -2654,8 +2726,18 @@ impl AssistantContext { anyhow::Ok(()) } - .log_err() - .await + .await; + + if let Err(err) = result { + this.update(cx, |this, cx| { + this.summary = ContextSummary::Error; + cx.emit(ContextEvent::SummaryChanged); + }) + .log_err(); + log::error!("Error generating context summary: {}", err); + } + + Some(()) }); } } @@ -2769,7 +2851,7 @@ impl AssistantContext { let (old_path, summary) = this.read_with(cx, |this, _| { let path = this.path.clone(); - let summary = if let Some(summary) = this.summary.as_ref() { + let summary = if let Some(summary) = this.summary.content() { if summary.done { Some(summary.text.clone()) } else { @@ -2823,21 +2905,12 @@ impl AssistantContext { pub fn set_custom_summary(&mut self, custom_summary: String, cx: &mut Context) { let timestamp = self.next_timestamp(); - let summary = self.summary.get_or_insert(ContextSummary::default()); + let summary = self.summary.content_or_set_empty(); summary.timestamp = timestamp; summary.done = true; summary.text = custom_summary; cx.emit(ContextEvent::SummaryChanged); } - - pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Text Thread"); - - pub fn summary_or_default(&self) -> SharedString { - self.summary - .as_ref() - .map(|summary| summary.text.clone().into()) - .unwrap_or(Self::DEFAULT_SUMMARY) - } } #[derive(Debug, Default)] @@ -3053,7 +3126,7 @@ impl SavedContext { let timestamp = next_timestamp.tick(); operations.push(ContextOperation::UpdateSummary { - summary: ContextSummary { + summary: ContextSummaryContent { text: self.summary, done: true, timestamp, diff --git a/crates/assistant_context_editor/src/context/context_tests.rs b/crates/assistant_context_editor/src/context/context_tests.rs index a8a5f3835d..3983a90158 100644 --- a/crates/assistant_context_editor/src/context/context_tests.rs +++ b/crates/assistant_context_editor/src/context/context_tests.rs @@ -1,5 +1,5 @@ use crate::{ - AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation, + AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation, ContextSummary, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, }; use anyhow::Result; @@ -16,7 +16,10 @@ use futures::{ }; use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*}; use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate}; -use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role}; +use language_model::{ + ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role, + fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}, +}; use parking_lot::Mutex; use pretty_assertions::assert_eq; use project::Project; @@ -1177,6 +1180,187 @@ fn test_mark_cache_anchors(cx: &mut App) { ); } +#[gpui::test] +async fn test_summarization(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + // Initial state should be pending + context.read_with(cx, |context, _| { + assert!(matches!(context.summary(), ContextSummary::Pending)); + assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT); + }); + + let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone()); + context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) + .unwrap(); + }); + + // Send a message + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&fake_model, cx); + + // Should start generating summary when there are >= 2 messages + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("Brief".into()); + fake_model.stream_last_completion_response(" Introduction".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Summary should be set + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Introduction"); + }); + + // We should be able to manually set a summary + context.update(cx, |context, cx| { + context.set_custom_summary("Brief Intro".into(), cx); + }); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Intro"); + }); +} + +#[gpui::test] +async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + test_summarize_error(&fake_model, &context, cx); + + // Now we should be able to set a summary + context.update(cx, |context, cx| { + context.set_custom_summary("Brief Intro".into(), cx); + }); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Intro"); + }); +} + +#[gpui::test] +async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + test_summarize_error(&fake_model, &context, cx); + + // Sending another message should not trigger another summarize request + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&fake_model, cx); + + context.read_with(cx, |context, _| { + // State is still Error, not Generating + assert!(matches!(context.summary(), ContextSummary::Error)); + }); + + // But the summarize request can be invoked manually + context.update(cx, |context, cx| { + context.summarize(true, cx); + }); + + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("A successful summary".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "A successful summary"); + }); +} + +fn test_summarize_error( + model: &Arc, + context: &Entity, + cx: &mut TestAppContext, +) { + let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone()); + context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) + .unwrap(); + }); + + // Send a message + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&model, cx); + + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + // Simulate summary request ending + cx.run_until_parked(); + model.end_last_completion_stream(); + cx.run_until_parked(); + + // State is set to Error and default message + context.read_with(cx, |context, _| { + assert_eq!(*context.summary(), ContextSummary::Error); + assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT); + }); +} + +fn setup_context_editor_with_fake_model( + cx: &mut TestAppContext, +) -> (Entity, Arc) { + let registry = Arc::new(LanguageRegistry::test(cx.executor().clone())); + + let fake_provider = Arc::new(FakeLanguageModelProvider); + let fake_model = Arc::new(fake_provider.test_model()); + + cx.update(|cx| { + init_test(cx); + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model( + Some(ConfiguredModel { + provider: fake_provider.clone(), + model: fake_model.clone(), + }), + cx, + ) + }) + }); + + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new(|cx| { + AssistantContext::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + cx, + ) + }); + + (context, fake_model) +} + +fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { + cx.run_until_parked(); + fake_model.stream_last_completion_response("Assistant response".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); +} + fn messages(context: &Entity, cx: &App) -> Vec<(MessageId, Role, Range)> { context .read(cx) diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index 37cb766986..21ec018dc8 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -1860,7 +1860,12 @@ impl ContextEditor { } pub fn title(&self, cx: &App) -> SharedString { - self.context.read(cx).summary_or_default() + self.context.read(cx).summary().or_default() + } + + pub fn regenerate_summary(&mut self, cx: &mut Context) { + self.context + .update(cx, |context, cx| context.summarize(true, cx)); } fn render_notice(&self, cx: &mut Context) -> Option { diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index fe89d76109..f1f3b501a6 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -648,7 +648,10 @@ impl ContextStore { if context.replica_id() == ReplicaId::default() { Some(proto::ContextMetadata { context_id: context.id().to_proto(), - summary: context.summary().map(|summary| summary.text.clone()), + summary: context + .summary() + .content() + .map(|summary| summary.text.clone()), }) } else { None