diff --git a/Cargo.lock b/Cargo.lock index c1a30686df..3634dbe4e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -94,6 +94,7 @@ dependencies = [ "parking_lot", "paths", "picker", + "postage", "project", "prompt_store", "proto", diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 4152843d22..83562a321b 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -60,6 +60,7 @@ ordered-float.workspace = true parking_lot.workspace = true paths.workspace = true picker.workspace = true +postage.workspace = true project.workspace = true rules_library.workspace = true prompt_store.workspace = true diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index 7b19f49313..3912adda37 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -542,12 +542,21 @@ impl ThreadContextHandle { } fn load(self, cx: &App) -> Task>)>> { - let context = AgentContext::Thread(ThreadContext { - title: self.title(cx), - text: self.thread.read(cx).latest_detailed_summary_or_text(), - handle: self, - }); - Task::ready(Some((context, vec![]))) + cx.spawn(async move |cx| { + let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?; + let title = self + .thread + .read_with(cx, |thread, _cx| { + thread.summary().unwrap_or_else(|| "New thread".into()) + }) + .ok()?; + let context = AgentContext::Thread(ThreadContext { + title, + text, + handle: self, + }); + Some((context, vec![])) + }) } } diff --git a/crates/agent/src/context_store.rs b/crates/agent/src/context_store.rs index 6ad9888d43..c5c0826323 100644 --- a/crates/agent/src/context_store.rs +++ b/crates/agent/src/context_store.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use anyhow::{Result, anyhow}; use collections::{HashSet, IndexSet}; -use futures::future::join_all; use futures::{self, FutureExt}; use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity}; use language::Buffer; @@ -13,7 +12,6 @@ use project::{Project, ProjectItem, ProjectPath, Symbol}; use prompt_store::UserPromptId; use ref_cast::RefCast as _; use text::{Anchor, OffsetRangeExt}; -use util::ResultExt as _; use crate::ThreadStore; use crate::context::{ @@ -27,7 +25,6 @@ use crate::thread::{Thread, ThreadId}; pub struct ContextStore { project: WeakEntity, thread_store: Option>, - thread_summary_tasks: Vec>, next_context_id: ContextId, context_set: IndexSet, context_thread_ids: HashSet, @@ -41,7 +38,6 @@ impl ContextStore { Self { project, thread_store, - thread_summary_tasks: Vec::new(), next_context_id: ContextId::zero(), context_set: IndexSet::default(), context_thread_ids: HashSet::default(), @@ -201,41 +197,6 @@ impl ContextStore { } } - fn start_summarizing_thread_if_needed( - &mut self, - thread: &Entity, - cx: &mut Context, - ) { - if let Some(summary_task) = - thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx)) - { - let thread = thread.clone(); - let thread_store = self.thread_store.clone(); - - self.thread_summary_tasks.push(cx.spawn(async move |_, cx| { - summary_task.await; - - if let Some(thread_store) = thread_store { - // Save thread so its summary can be reused later - let save_task = thread_store - .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)); - - if let Some(save_task) = save_task.ok() { - save_task.await.log_err(); - } - } - })); - } - } - - pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> { - let tasks = std::mem::take(&mut self.thread_summary_tasks); - - cx.spawn(async move |_cx| { - join_all(tasks).await; - }) - } - pub fn add_rules( &mut self, prompt_id: UserPromptId, @@ -331,9 +292,15 @@ impl ContextStore { fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context) -> bool { match &context { AgentContextHandle::Thread(thread_context) => { - self.context_thread_ids - .insert(thread_context.thread.read(cx).id().clone()); - self.start_summarizing_thread_if_needed(&thread_context.thread, cx); + if let Some(thread_store) = self.thread_store.clone() { + thread_context.thread.update(cx, |thread, cx| { + thread.start_generating_detailed_summary_if_needed(thread_store, cx); + }); + self.context_thread_ids + .insert(thread_context.thread.read(cx).id().clone()); + } else { + return false; + } } _ => {} } diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index cc4f439f68..17fbcfb5a0 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -60,11 +60,10 @@ pub struct MessageEditor { context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, last_loaded_context: Option, - context_load_task: Option>>, + load_context_task: Option>>, profile_selector: Entity, edits_expanded: bool, editor_is_expanded: bool, - waiting_for_summaries_to_send: bool, last_estimated_token_count: Option, update_token_count_task: Option>>, _subscriptions: Vec, @@ -149,7 +148,8 @@ impl MessageEditor { _ => {} }), cx.observe(&context_store, |this, _, cx| { - let _ = this.start_context_load(cx); + // When context changes, reload it for token counting. + let _ = this.reload_context(cx); }), ]; @@ -163,7 +163,7 @@ impl MessageEditor { prompt_store, context_strip, context_picker_menu_handle, - context_load_task: None, + load_context_task: None, last_loaded_context: None, model_selector: cx.new(|cx| { AssistantModelSelector::new( @@ -177,7 +177,6 @@ impl MessageEditor { }), edits_expanded: false, editor_is_expanded: false, - waiting_for_summaries_to_send: false, profile_selector: cx .new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)), last_estimated_token_count: None, @@ -289,7 +288,7 @@ impl MessageEditor { let thread = self.thread.clone(); let git_store = self.project.read(cx).git_store().clone(); let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx)); - let context_task = self.load_context(cx); + let context_task = self.reload_context(cx); let window_handle = window.window_handle(); cx.spawn(async move |_this, cx| { @@ -312,30 +311,6 @@ impl MessageEditor { .detach(); } - fn wait_for_summaries(&mut self, cx: &mut Context) -> Task<()> { - let context_store = self.context_store.clone(); - cx.spawn(async move |this, cx| { - if let Some(wait_for_summaries) = context_store - .update(cx, |context_store, cx| context_store.wait_for_summaries(cx)) - .ok() - { - this.update(cx, |this, cx| { - this.waiting_for_summaries_to_send = true; - cx.notify(); - }) - .ok(); - - wait_for_summaries.await; - - this.update(cx, |this, cx| { - this.waiting_for_summaries_to_send = false; - cx.notify(); - }) - .ok(); - } - }) - } - fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context) { let cancelled = self.thread.update(cx, |thread, cx| { thread.cancel_last_completion(Some(window.window_handle()), cx) @@ -664,31 +639,31 @@ impl MessageEditor { }) .when(!is_editor_empty, |parent| { parent.child( - IconButton::new("send-message", IconName::Send) - .icon_color(Color::Accent) - .style(ButtonStyle::Filled) - .disabled( - !is_model_selected - || self - .waiting_for_summaries_to_send, - ) - .on_click({ - let focus_handle = focus_handle.clone(); - move |_event, window, cx| { - focus_handle.dispatch_action( - &Chat, window, cx, - ); - } - }) - .tooltip(move |window, cx| { - Tooltip::for_action( - "Stop and Send New Message", - &Chat, - window, - cx, + IconButton::new( + "send-message", + IconName::Send, ) - }), - ) + .icon_color(Color::Accent) + .style(ButtonStyle::Filled) + .disabled(!is_model_selected) + .on_click({ + let focus_handle = + focus_handle.clone(); + move |_event, window, cx| { + focus_handle.dispatch_action( + &Chat, window, cx, + ); + } + }) + .tooltip(move |window, cx| { + Tooltip::for_action( + "Stop and Send New Message", + &Chat, + window, + cx, + ) + }), + ) }) } else { parent.child( @@ -696,10 +671,7 @@ impl MessageEditor { .icon_color(Color::Accent) .style(ButtonStyle::Filled) .disabled( - is_editor_empty - || !is_model_selected - || self - .waiting_for_summaries_to_send, + is_editor_empty || !is_model_selected, ) .on_click({ let focus_handle = focus_handle.clone(); @@ -1041,16 +1013,8 @@ impl MessageEditor { self.update_token_count_task.is_some() } - fn handle_message_changed(&mut self, cx: &mut Context) { - self.message_or_context_changed(true, cx); - } - - fn start_context_load(&mut self, cx: &mut Context) -> Shared> { - let summaries_task = self.wait_for_summaries(cx); + fn reload_context(&mut self, cx: &mut Context) -> Task> { let load_task = cx.spawn(async move |this, cx| { - // Waits for detailed summaries before `load_context`, as it directly reads these from - // the thread. TODO: Would be cleaner to have context loading await on summarization. - summaries_task.await; let Ok(load_task) = this.update(cx, |this, cx| { let new_context = this.context_store.read_with(cx, |context_store, cx| { context_store.new_context_for_thread(this.thread.read(cx)) @@ -1062,27 +1026,26 @@ impl MessageEditor { let result = load_task.await; this.update(cx, |this, cx| { this.last_loaded_context = Some(result); - this.context_load_task = None; + this.load_context_task = None; this.message_or_context_changed(false, cx); }) .ok(); }); // Replace existing load task, if any, causing it to be cancelled. let load_task = load_task.shared(); - self.context_load_task = Some(load_task.clone()); - load_task - } - - fn load_context(&mut self, cx: &mut Context) -> Task> { - let context_load_task = self.start_context_load(cx); + self.load_context_task = Some(load_task.clone()); cx.spawn(async move |this, cx| { - context_load_task.await; + load_task.await; this.read_with(cx, |this, _cx| this.last_loaded_context.clone()) .ok() .flatten() }) } + fn handle_message_changed(&mut self, cx: &mut Context) { + self.message_or_context_changed(true, cx); + } + fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context) { cx.emit(MessageEditorEvent::Changed); self.update_token_count_task.take(); @@ -1183,41 +1146,6 @@ impl Render for MessageEditor { v_flex() .size_full() - .when(self.waiting_for_summaries_to_send, |parent| { - parent.child( - h_flex().py_3().w_full().justify_center().child( - h_flex() - .flex_none() - .px_2() - .py_2() - .bg(cx.theme().colors().editor_background) - .border_1() - .border_color(cx.theme().colors().border_variant) - .rounded_lg() - .shadow_md() - .gap_1() - .child( - Icon::new(IconName::ArrowCircle) - .size(IconSize::XSmall) - .color(Color::Muted) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(gpui::Transformation::rotate( - gpui::percentage(delta), - )) - }, - ), - ) - .child( - Label::new("Summarizing context…") - .size(LabelSize::XSmall) - .color(Color::Muted), - ), - ), - ) - }) .when(changed_buffers.len() > 0, |parent| { parent.child(self.render_changed_buffers(&changed_buffers, window, cx)) }) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 4e62dbea29..88a56bf2a5 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -14,7 +14,8 @@ use futures::future::Shared; use futures::{FutureExt, StreamExt as _}; use git::repository::DiffType; use gpui::{ - AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, + AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, + WeakEntity, }; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -24,6 +25,7 @@ use language_model::{ ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason, TokenUsage, }; +use postage::stream::Stream as _; use project::Project; use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState}; use prompt_store::{ModelContext, PromptBuilder}; @@ -36,6 +38,7 @@ use util::{ResultExt as _, TryFutureExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::CompletionMode; +use crate::ThreadStore; use crate::context::{AgentContext, ContextLoadResult, LoadedContext}; use crate::thread_store::{ SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, @@ -243,6 +246,16 @@ pub enum DetailedSummaryState { }, } +impl DetailedSummaryState { + fn text(&self) -> Option { + if let Self::Generated { text, .. } = self { + Some(text.clone()) + } else { + None + } + } +} + #[derive(Default)] pub struct TotalTokenUsage { pub total: usize, @@ -290,7 +303,9 @@ pub struct Thread { updated_at: DateTime, summary: Option, pending_summary: Task>, - detailed_summary_state: DetailedSummaryState, + detailed_summary_task: Task>, + detailed_summary_tx: postage::watch::Sender, + detailed_summary_rx: postage::watch::Receiver, completion_mode: Option, messages: Vec, next_message_id: MessageId, @@ -335,12 +350,15 @@ impl Thread { system_prompt: SharedProjectContext, cx: &mut Context, ) -> Self { + let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); Self { id: ThreadId::new(), updated_at: Utc::now(), summary: None, pending_summary: Task::ready(None), - detailed_summary_state: DetailedSummaryState::NotGenerated, + detailed_summary_task: Task::ready(None), + detailed_summary_tx, + detailed_summary_rx, completion_mode: None, messages: Vec::new(), next_message_id: MessageId(0), @@ -390,13 +408,17 @@ impl Thread { .unwrap_or(0), ); let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages); + let (detailed_summary_tx, detailed_summary_rx) = + postage::watch::channel_with(serialized.detailed_summary_state); Self { id, updated_at: serialized.updated_at, summary: Some(serialized.summary), pending_summary: Task::ready(None), - detailed_summary_state: serialized.detailed_summary_state, + detailed_summary_task: Task::ready(None), + detailed_summary_tx, + detailed_summary_rx, completion_mode: None, messages: serialized .messages @@ -509,19 +531,6 @@ impl Thread { } } - pub fn latest_detailed_summary_or_text(&self) -> SharedString { - self.latest_detailed_summary() - .unwrap_or_else(|| self.text().into()) - } - - fn latest_detailed_summary(&self) -> Option { - if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state { - Some(text.clone()) - } else { - None - } - } - pub fn completion_mode(&self) -> Option { self.completion_mode } @@ -941,7 +950,7 @@ impl Thread { initial_project_snapshot, cumulative_token_usage: this.cumulative_token_usage, request_token_usage: this.request_token_usage.clone(), - detailed_summary_state: this.detailed_summary_state.clone(), + detailed_summary_state: this.detailed_summary_rx.borrow().clone(), exceeded_window_error: this.exceeded_window_error.clone(), }) }) @@ -1540,25 +1549,34 @@ impl Thread { }); } - pub fn generate_detailed_summary(&mut self, cx: &mut Context) -> Option> { - let last_message_id = self.messages.last().map(|message| message.id)?; + pub fn start_generating_detailed_summary_if_needed( + &mut self, + thread_store: WeakEntity, + cx: &mut Context, + ) { + let Some(last_message_id) = self.messages.last().map(|message| message.id) else { + return; + }; - match &self.detailed_summary_state { + match &*self.detailed_summary_rx.borrow() { DetailedSummaryState::Generating { message_id, .. } | DetailedSummaryState::Generated { message_id, .. } if *message_id == last_message_id => { // Already up-to-date - return None; + return; } _ => {} } - let ConfiguredModel { model, provider } = - LanguageModelRegistry::read_global(cx).thread_summary_model()?; + let Some(ConfiguredModel { model, provider }) = + LanguageModelRegistry::read_global(cx).thread_summary_model() + else { + return; + }; if !provider.is_authenticated(cx) { - return None; + return; } let added_user_message = "Generate a detailed summary of this conversation. Include:\n\ @@ -1570,16 +1588,24 @@ impl Thread { let request = self.to_summarize_request(added_user_message.into()); - let task = cx.spawn(async move |thread, cx| { + *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating { + message_id: last_message_id, + }; + + // Replace the detailed summarization task if there is one, cancelling it. It would probably + // be better to allow the old task to complete, but this would require logic for choosing + // which result to prefer (the old task could complete after the new one, resulting in a + // stale summary). + self.detailed_summary_task = cx.spawn(async move |thread, cx| { let stream = model.stream_completion_text(request, &cx); let Some(mut messages) = stream.await.log_err() else { thread - .update(cx, |this, _cx| { - this.detailed_summary_state = DetailedSummaryState::NotGenerated; + .update(cx, |thread, _cx| { + *thread.detailed_summary_tx.borrow_mut() = + DetailedSummaryState::NotGenerated; }) - .log_err(); - - return; + .ok()?; + return None; }; let mut new_detailed_summary = String::new(); @@ -1591,25 +1617,56 @@ impl Thread { } thread - .update(cx, |this, _cx| { - this.detailed_summary_state = DetailedSummaryState::Generated { + .update(cx, |thread, _cx| { + *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated { text: new_detailed_summary.into(), message_id: last_message_id, }; }) - .log_err(); + .ok()?; + + // Save thread so its summary can be reused later + if let Some(thread) = thread.upgrade() { + if let Ok(Ok(save_task)) = cx.update(|cx| { + thread_store + .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) + }) { + save_task.await.log_err(); + } + } + + Some(()) }); + } - self.detailed_summary_state = DetailedSummaryState::Generating { - message_id: last_message_id, - }; + pub async fn wait_for_detailed_summary_or_text( + this: &Entity, + cx: &mut AsyncApp, + ) -> Option { + let mut detailed_summary_rx = this + .read_with(cx, |this, _cx| this.detailed_summary_rx.clone()) + .ok()?; + loop { + match detailed_summary_rx.recv().await? { + DetailedSummaryState::Generating { .. } => {} + DetailedSummaryState::NotGenerated => { + return this.read_with(cx, |this, _cx| this.text().into()).ok(); + } + DetailedSummaryState::Generated { text, .. } => return Some(text), + } + } + } - Some(task) + pub fn latest_detailed_summary_or_text(&self) -> SharedString { + self.detailed_summary_rx + .borrow() + .text() + .unwrap_or_else(|| self.text().into()) } pub fn is_generating_detailed_summary(&self) -> bool { matches!( - self.detailed_summary_state, + &*self.detailed_summary_rx.borrow(), DetailedSummaryState::Generating { .. } ) }