diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index a45787f039..c748f22275 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1020,10 +1020,19 @@ impl AcpThread { cx.emit(AcpThreadEvent::NewEntry); } - pub fn update_title(&mut self, title: SharedString, cx: &mut Context) -> Result<()> { - self.title = title; - cx.emit(AcpThreadEvent::TitleUpdated); - Ok(()) + pub fn can_set_title(&mut self, cx: &mut Context) -> bool { + self.connection.set_title(&self.session_id, cx).is_some() + } + + pub fn set_title(&mut self, title: SharedString, cx: &mut Context) -> Task> { + if title != self.title { + self.title = title.clone(); + cx.emit(AcpThreadEvent::TitleUpdated); + if let Some(set_title) = self.connection.set_title(&self.session_id, cx) { + return set_title.run(title, cx); + } + } + Task::ready(Ok(())) } pub fn update_token_usage(&mut self, usage: Option, cx: &mut Context) { @@ -1326,11 +1335,7 @@ impl AcpThread { }; let git_store = self.project.read(cx).git_store().clone(); - let message_id = if self - .connection - .session_editor(&self.session_id, cx) - .is_some() - { + let message_id = if self.connection.truncate(&self.session_id, cx).is_some() { Some(UserMessageId::new()) } else { None @@ -1476,7 +1481,7 @@ impl AcpThread { /// Rewinds this thread to before the entry at `index`, removing it and all /// subsequent entries while reverting any changes made from that point. pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { - let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else { + let Some(truncate) = self.connection.truncate(&self.session_id, cx) else { return Task::ready(Err(anyhow!("not supported"))); }; let Some(message) = self.user_message(&id) else { @@ -1496,8 +1501,7 @@ impl AcpThread { .await?; } - cx.update(|cx| session_editor.truncate(id.clone(), cx))? - .await?; + cx.update(|cx| truncate.run(id.clone(), cx))?.await?; this.update(cx, |this, cx| { if let Some((ix, _)) = this.user_message_mut(&id) { let range = ix..this.entries.len(); @@ -2652,11 +2656,11 @@ mod tests { .detach(); } - fn session_editor( + fn truncate( &self, session_id: &acp::SessionId, _cx: &mut App, - ) -> Option> { + ) -> Option> { Some(Rc::new(FakeAgentSessionEditor { _session_id: session_id.clone(), })) @@ -2671,8 +2675,8 @@ mod tests { _session_id: acp::SessionId, } - impl AgentSessionEditor for FakeAgentSessionEditor { - fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task> { + impl AgentSessionTruncate for FakeAgentSessionEditor { + fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task> { Task::ready(Ok(())) } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 2bbd364873..91e46dbac1 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -50,11 +50,19 @@ pub trait AgentConnection { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); - fn session_editor( + fn truncate( &self, _session_id: &acp::SessionId, _cx: &mut App, - ) -> Option> { + ) -> Option> { + None + } + + fn set_title( + &self, + _session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { None } @@ -79,14 +87,18 @@ impl dyn AgentConnection { } } -pub trait AgentSessionEditor { - fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task>; +pub trait AgentSessionTruncate { + fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task>; } pub trait AgentSessionResume { fn run(&self, cx: &mut App) -> Task>; } +pub trait AgentSessionSetTitle { + fn run(&self, title: SharedString, cx: &mut App) -> Task>; +} + pub trait AgentTelemetry { /// The name of the agent used for telemetry. fn agent_name(&self) -> String; @@ -424,11 +436,11 @@ mod test_support { } } - fn session_editor( + fn truncate( &self, _session_id: &agent_client_protocol::SessionId, _cx: &mut App, - ) -> Option> { + ) -> Option> { Some(Rc::new(StubAgentSessionEditor)) } @@ -439,8 +451,8 @@ mod test_support { struct StubAgentSessionEditor; - impl AgentSessionEditor for StubAgentSessionEditor { - fn truncate(&self, _: UserMessageId, _: &mut App) -> Task> { + impl AgentSessionTruncate for StubAgentSessionEditor { + fn run(&self, _: UserMessageId, _: &mut App) -> Task> { Task::ready(Ok(())) } } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index d5bc0fea63..bbc30b74bc 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -2,7 +2,7 @@ use crate::{ ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization, UserMessageContent, templates::Templates, }; -use crate::{HistoryStore, TokenUsageUpdated}; +use crate::{HistoryStore, TitleUpdated, TokenUsageUpdated}; use acp_thread::{AcpThread, AgentModelSelector}; use action_log::ActionLog; use agent_client_protocol as acp; @@ -253,6 +253,7 @@ impl NativeAgent { cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); }), + cx.subscribe(&thread_handle, Self::handle_thread_title_updated), cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated), cx.observe(&thread_handle, move |this, thread, cx| { this.save_thread(thread, cx) @@ -441,6 +442,26 @@ impl NativeAgent { }) } + fn handle_thread_title_updated( + &mut self, + thread: Entity, + _: &TitleUpdated, + cx: &mut Context, + ) { + let session_id = thread.read(cx).id(); + let Some(session) = self.sessions.get(session_id) else { + return; + }; + let thread = thread.downgrade(); + let acp_thread = session.acp_thread.clone(); + cx.spawn(async move |_, cx| { + let title = thread.read_with(cx, |thread, _| thread.title())?; + let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?; + task.await + }) + .detach_and_log_err(cx); + } + fn handle_thread_token_usage_updated( &mut self, thread: Entity, @@ -717,10 +738,6 @@ impl NativeAgentConnection { thread.update_tool_call(update, cx) })??; } - ThreadEvent::TitleUpdate(title) => { - acp_thread - .update(cx, |thread, cx| thread.update_title(title, cx))??; - } ThreadEvent::Retry(status) => { acp_thread.update(cx, |thread, cx| { thread.update_retry_status(status, cx) @@ -856,8 +873,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .models .model_from_id(&LanguageModels::model_id(&default_model.model)) }); - - let thread = cx.new(|cx| { + Ok(cx.new(|cx| { Thread::new( project.clone(), agent.project_context.clone(), @@ -867,9 +883,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { default_model, cx, ) - }); - - Ok(thread) + })) }, )??; agent.update(cx, |agent, cx| agent.register_session(thread, cx)) @@ -941,11 +955,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { }); } - fn session_editor( + fn truncate( &self, session_id: &agent_client_protocol::SessionId, cx: &mut App, - ) -> Option> { + ) -> Option> { self.0.update(cx, |agent, _cx| { agent.sessions.get(session_id).map(|session| { Rc::new(NativeAgentSessionEditor { @@ -956,6 +970,17 @@ impl acp_thread::AgentConnection for NativeAgentConnection { }) } + fn set_title( + &self, + session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(NativeAgentSessionSetTitle { + connection: self.clone(), + session_id: session_id.clone(), + }) as _) + } + fn telemetry(&self) -> Option> { Some(Rc::new(self.clone()) as Rc) } @@ -991,8 +1016,8 @@ struct NativeAgentSessionEditor { acp_thread: WeakEntity, } -impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { - fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { +impl acp_thread::AgentSessionTruncate for NativeAgentSessionEditor { + fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { match self.thread.update(cx, |thread, cx| { thread.truncate(message_id.clone(), cx)?; Ok(thread.latest_token_usage()) @@ -1024,6 +1049,22 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume { } } +struct NativeAgentSessionSetTitle { + connection: NativeAgentConnection, + session_id: acp::SessionId, +} + +impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle { + fn run(&self, title: SharedString, cx: &mut App) -> Task> { + let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else { + return Task::ready(Err(anyhow!("session not found"))); + }; + let thread = session.thread.clone(); + thread.update(cx, |thread, cx| thread.set_title(title, cx)); + Task::ready(Ok(())) + } +} + #[cfg(test)] mod tests { use crate::HistoryEntryId; @@ -1323,6 +1364,8 @@ mod tests { ) }); + cx.run_until_parked(); + // Drop the ACP thread, which should cause the session to be dropped as well. cx.update(|_| { drop(thread); diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index edba227da7..e7e28f495e 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1383,6 +1383,7 @@ async fn test_title_generation(cx: &mut TestAppContext) { summary_model.send_last_completion_stream_text_chunk("oodnight Moon"); summary_model.end_last_completion_stream(); send.collect::>().await; + cx.run_until_parked(); thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world")); // Send another message, ensuring no title is generated this time. diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 6f560cd390..f6ef11c20b 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -487,7 +487,6 @@ pub enum ThreadEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), - TitleUpdate(SharedString), Retry(acp_thread::RetryStatus), Stop(acp::StopReason), } @@ -514,6 +513,7 @@ pub struct Thread { prompt_id: PromptId, updated_at: DateTime, title: Option, + pending_title_generation: Option>, summary: Option, messages: Vec, completion_mode: CompletionMode, @@ -555,6 +555,7 @@ impl Thread { prompt_id: PromptId::new(), updated_at: Utc::now(), title: None, + pending_title_generation: None, summary: None, messages: Vec::new(), completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, @@ -705,6 +706,7 @@ impl Thread { } else { Some(db_thread.title.clone()) }, + pending_title_generation: None, summary: db_thread.detailed_summary, messages: db_thread.messages, completion_mode: db_thread.completion_mode.unwrap_or_default(), @@ -1086,7 +1088,7 @@ impl Thread { event_stream: event_stream.clone(), _task: cx.spawn(async move |this, cx| { log::info!("Starting agent turn execution"); - let mut update_title = None; + let turn_result: Result<()> = async { let mut intent = CompletionIntent::UserPrompt; loop { @@ -1095,8 +1097,8 @@ impl Thread { let mut end_turn = true; this.update(cx, |this, cx| { // Generate title if needed. - if this.title.is_none() && update_title.is_none() { - update_title = Some(this.update_title(&event_stream, cx)); + if this.title.is_none() && this.pending_title_generation.is_none() { + this.generate_title(cx); } // End the turn if the model didn't use tools. @@ -1120,10 +1122,6 @@ impl Thread { .await; _ = this.update(cx, |this, cx| this.flush_pending_message(cx)); - if let Some(update_title) = update_title { - update_title.await.context("update title failed").log_err(); - } - match turn_result { Ok(()) => { log::info!("Turn execution completed"); @@ -1607,19 +1605,15 @@ impl Thread { }) } - fn update_title( - &mut self, - event_stream: &ThreadEventStream, - cx: &mut Context, - ) -> Task> { + fn generate_title(&mut self, cx: &mut Context) { + let Some(model) = self.summarization_model.clone() else { + return; + }; + log::info!( "Generating title with model: {:?}", self.summarization_model.as_ref().map(|model| model.name()) ); - let Some(model) = self.summarization_model.clone() else { - return Task::ready(Ok(())); - }; - let event_stream = event_stream.clone(); let mut request = LanguageModelRequest { intent: Some(CompletionIntent::ThreadSummarization), temperature: AgentSettings::temperature_for_model(&model, cx), @@ -1635,42 +1629,51 @@ impl Thread { content: vec![SUMMARIZE_THREAD_PROMPT.into()], cache: false, }); - cx.spawn(async move |this, cx| { + self.pending_title_generation = Some(cx.spawn(async move |this, cx| { let mut title = String::new(); - let mut messages = model.stream_completion(request, cx).await?; - while let Some(event) = messages.next().await { - let event = event?; - let text = match event { - LanguageModelCompletionEvent::Text(text) => text, - LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::UsageUpdated { amount, limit }, - ) => { - this.update(cx, |thread, cx| { - thread.update_model_request_usage(amount, limit, cx); - })?; - continue; + + let generate = async { + let mut messages = model.stream_completion(request, cx).await?; + while let Some(event) = messages.next().await { + let event = event?; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { amount, limit }, + ) => { + this.update(cx, |thread, cx| { + thread.update_model_request_usage(amount, limit, cx); + })?; + continue; + } + _ => continue, + }; + + let mut lines = text.lines(); + title.extend(lines.next()); + + // Stop if the LLM generated multiple lines. + if lines.next().is_some() { + break; } - _ => continue, - }; - - let mut lines = text.lines(); - title.extend(lines.next()); - - // Stop if the LLM generated multiple lines. - if lines.next().is_some() { - break; } + anyhow::Ok(()) + }; + + if generate.await.context("failed to generate title").is_ok() { + _ = this.update(cx, |this, cx| this.set_title(title.into(), cx)); } + _ = this.update(cx, |this, _| this.pending_title_generation = None); + })); + } - log::info!("Setting title: {}", title); - - this.update(cx, |this, cx| { - let title = SharedString::from(title); - event_stream.send_title_update(title.clone()); - this.title = Some(title); - cx.notify(); - }) - }) + pub fn set_title(&mut self, title: SharedString, cx: &mut Context) { + self.pending_title_generation = None; + if Some(&title) != self.title.as_ref() { + self.title = Some(title); + cx.emit(TitleUpdated); + cx.notify(); + } } fn last_user_message(&self) -> Option<&UserMessage> { @@ -1975,6 +1978,10 @@ pub struct TokenUsageUpdated(pub Option); impl EventEmitter for Thread {} +pub struct TitleUpdated; + +impl EventEmitter for Thread {} + pub trait AgentTool where Self: 'static + Sized, @@ -2132,12 +2139,6 @@ where struct ThreadEventStream(mpsc::UnboundedSender>); impl ThreadEventStream { - fn send_title_update(&self, text: SharedString) { - self.0 - .unbounded_send(Ok(ThreadEvent::TitleUpdate(text))) - .ok(); - } - fn send_user_message(&self, message: &UserMessage) { self.0 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone()))) diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 05d31051b2..936f987864 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -15,7 +15,7 @@ use buffer_diff::BufferDiff; use client::zed_urls; use collections::{HashMap, HashSet}; use editor::scroll::Autoscroll; -use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects}; +use editor::{Editor, EditorEvent, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; use fs::Fs; use gpui::{ @@ -281,7 +281,8 @@ enum ThreadState { }, Ready { thread: Entity, - _subscription: [Subscription; 2], + title_editor: Option>, + _subscriptions: Vec, }, LoadError(LoadError), Unauthenticated { @@ -445,12 +446,7 @@ impl AcpThreadView { this.update_in(cx, |this, window, cx| { match result { Ok(thread) => { - let thread_subscription = - cx.subscribe_in(&thread, window, Self::handle_thread_event); - let action_log = thread.read(cx).action_log().clone(); - let action_log_subscription = - cx.observe(&action_log, |_, _, cx| cx.notify()); let count = thread.read(cx).entries().len(); this.list_state.splice(0..0, count); @@ -489,9 +485,31 @@ impl AcpThreadView { }) }); + let mut subscriptions = vec![ + cx.subscribe_in(&thread, window, Self::handle_thread_event), + cx.observe(&action_log, |_, _, cx| cx.notify()), + ]; + + let title_editor = + if thread.update(cx, |thread, cx| thread.can_set_title(cx)) { + let editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_text(thread.read(cx).title(), window, cx); + editor + }); + subscriptions.push(cx.subscribe_in( + &editor, + window, + Self::handle_title_editor_event, + )); + Some(editor) + } else { + None + }; this.thread_state = ThreadState::Ready { thread, - _subscription: [thread_subscription, action_log_subscription], + title_editor, + _subscriptions: subscriptions, }; this.profile_selector = this.as_native_thread(cx).map(|thread| { @@ -618,6 +636,14 @@ impl AcpThreadView { } } + pub fn title_editor(&self) -> Option> { + if let ThreadState::Ready { title_editor, .. } = &self.thread_state { + title_editor.clone() + } else { + None + } + } + pub fn cancel_generation(&mut self, cx: &mut Context) { self.thread_error.take(); self.thread_retry_status.take(); @@ -662,6 +688,35 @@ impl AcpThreadView { cx.notify(); } + pub fn handle_title_editor_event( + &mut self, + title_editor: &Entity, + event: &EditorEvent, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread() else { return }; + + match event { + EditorEvent::BufferEdited => { + let new_title = title_editor.read(cx).text(cx); + thread.update(cx, |thread, cx| { + thread + .set_title(new_title.into(), cx) + .detach_and_log_err(cx); + }) + } + EditorEvent::Blurred => { + if title_editor.read(cx).text(cx).is_empty() { + title_editor.update(cx, |editor, cx| { + editor.set_text("New Thread", window, cx); + }); + } + } + _ => {} + } + } + pub fn handle_message_editor_event( &mut self, _: &Entity, @@ -1009,7 +1064,17 @@ impl AcpThreadView { self.thread_retry_status.take(); self.thread_state = ThreadState::LoadError(error.clone()); } - AcpThreadEvent::TitleUpdated | AcpThreadEvent::TokenUsageUpdated => {} + AcpThreadEvent::TitleUpdated => { + let title = thread.read(cx).title(); + if let Some(title_editor) = self.title_editor() { + title_editor.update(cx, |editor, cx| { + if editor.text(cx) != title { + editor.set_text(title, window, cx); + } + }); + } + } + AcpThreadEvent::TokenUsageUpdated => {} } cx.notify(); } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 65a9da573a..d2ff6aa4f3 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -905,7 +905,7 @@ impl AgentPanel { fn active_thread_view(&self) -> Option<&Entity> { match &self.active_view { - ActiveView::ExternalAgentThread { thread_view } => Some(thread_view), + ActiveView::ExternalAgentThread { thread_view, .. } => Some(thread_view), ActiveView::Thread { .. } | ActiveView::TextThread { .. } | ActiveView::History @@ -2075,9 +2075,32 @@ impl AgentPanel { } } ActiveView::ExternalAgentThread { thread_view } => { - Label::new(thread_view.read(cx).title(cx)) - .truncate() - .into_any_element() + if let Some(title_editor) = thread_view.read(cx).title_editor() { + div() + .w_full() + .on_action({ + let thread_view = thread_view.downgrade(); + move |_: &menu::Confirm, window, cx| { + if let Some(thread_view) = thread_view.upgrade() { + thread_view.focus_handle(cx).focus(window); + } + } + }) + .on_action({ + let thread_view = thread_view.downgrade(); + move |_: &editor::actions::Cancel, window, cx| { + if let Some(thread_view) = thread_view.upgrade() { + thread_view.focus_handle(cx).focus(window); + } + } + }) + .child(title_editor) + .into_any_element() + } else { + Label::new(thread_view.read(cx).title(cx)) + .truncate() + .into_any_element() + } } ActiveView::TextThread { title_editor,