diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index a0e62c29e3..033c5dd93c 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -691,6 +691,7 @@ pub struct AcpThread { pub enum AcpThreadEvent { NewEntry, + TitleUpdated, EntryUpdated(usize), EntriesRemoved(Range), ToolAuthorizationRequired, @@ -934,6 +935,12 @@ impl AcpThread { cx.emit(AcpThreadEvent::NewEntry); } + pub fn update_title(&mut self, title: SharedString, cx: &mut Context) -> Result<()> { + self.title = title; + cx.emit(AcpThreadEvent::TitleUpdated); + Ok(()) + } + pub fn update_tool_call( &mut self, update: impl Into, diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index c7f0840062..3e382d3864 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -255,6 +255,9 @@ impl NativeAgent { this.sessions.remove(acp_thread.session_id()); }), cx.observe(&thread, |this, thread, cx| { + thread.update(cx, |thread, cx| { + thread.generate_title_if_needed(cx); + }); this.save_thread(thread.clone(), cx) }), ], @@ -262,13 +265,14 @@ impl NativeAgent { ); } - fn save_thread(&mut self, thread: Entity, cx: &mut Context) { - let id = thread.read(cx).id().clone(); + fn save_thread(&mut self, thread_handle: Entity, cx: &mut Context) { + let thread = thread_handle.read(cx); + let id = thread.id().clone(); let Some(session) = self.sessions.get_mut(&id) else { return; }; - let thread = thread.downgrade(); + let thread = thread_handle.downgrade(); let thread_database = self.thread_database.clone(); session.save_task = cx.spawn(async move |this, cx| { cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; @@ -507,7 +511,7 @@ impl NativeAgent { fn handle_models_updated_event( &mut self, - _registry: Entity, + registry: Entity, _event: &language_model::Event, cx: &mut Context, ) { @@ -518,6 +522,11 @@ impl NativeAgent { if let Some(model) = self.models.model_from_id(&model_id) { thread.set_model(model.clone(), cx); } + let summarization_model = registry + .read(cx) + .thread_summary_model() + .map(|model| model.model.clone()); + thread.set_summarization_model(summarization_model, cx); }); } } @@ -641,6 +650,10 @@ impl NativeAgentConnection { thread.update_tool_call(update, cx) })??; } + ThreadEvent::TitleUpdate(title) => { + acp_thread + .update(cx, |thread, cx| thread.update_title(title, cx))??; + } ThreadEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); @@ -821,6 +834,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ) })?; + let summarization_model = registry.thread_summary_model().map(|c| c.model); + let thread = cx.new(|cx| { let mut thread = Thread::new( session_id.clone(), @@ -830,6 +845,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { action_log.clone(), agent.templates.clone(), default_model, + summarization_model, cx, ); Self::register_tools(&mut thread, project, action_log, cx); @@ -894,7 +910,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Create Thread let thread = agent.update(cx, |agent, cx| { - let configured_model = LanguageModelRegistry::global(cx) + let language_model_registry = LanguageModelRegistry::global(cx); + let configured_model = language_model_registry .update(cx, |registry, cx| { db_thread .model @@ -915,6 +932,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .model_from_id(&LanguageModels::model_id(&configured_model.model)) .context("no model by id")?; + let summarization_model = language_model_registry + .read(cx) + .thread_summary_model() + .map(|c| c.model); + let thread = cx.new(|cx| { let mut thread = Thread::from_db( session_id, @@ -925,6 +947,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { action_log.clone(), agent.templates.clone(), model, + summarization_model, cx, ); Self::register_tools(&mut thread, project, action_log, cx); @@ -1047,12 +1070,13 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume { #[cfg(test)] mod tests { - use crate::{HistoryEntry, HistoryStore}; + use crate::HistoryStore; use super::*; use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; use fs::FakeFs; use gpui::TestAppContext; + use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; use util::path; @@ -1245,13 +1269,6 @@ mod tests { ) .await .unwrap(); - let model = cx.update(|cx| { - LanguageModelRegistry::global(cx) - .read(cx) - .default_model() - .unwrap() - .model - }); let connection = NativeAgentConnection(agent.clone()); let history_store = cx.new(|cx| { let mut store = HistoryStore::new(cx); @@ -1268,6 +1285,16 @@ mod tests { let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let selector = connection.model_selector().unwrap(); + let summarization_model: Arc = + Arc::new(FakeLanguageModel::default()) as _; + + agent.update(cx, |agent, cx| { + let thread = agent.sessions.get(&session_id).unwrap().thread.clone(); + thread.update(cx, |thread, cx| { + thread.set_summarization_model(Some(summarization_model.clone()), cx); + }) + }); + let model = cx .update(|cx| selector.selected_model(&session_id, cx)) .await @@ -1283,11 +1310,16 @@ mod tests { model.send_last_completion_stream_text_chunk("Hey"); model.end_last_completion_stream(); send.await.unwrap(); + + summarization_model + .as_fake() + .send_last_completion_stream_text_chunk("Saying Hello"); + summarization_model.as_fake().end_last_completion_stream(); cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE); let history = history_store.update(cx, |store, cx| store.entries(cx)); assert_eq!(history.len(), 1); - assert_eq!(history[0].title(), "Hi"); + assert_eq!(history[0].title(), "Saying Hello"); } fn init_test(cx: &mut TestAppContext) { diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index d5d882bcde..67dc8c5e98 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -386,8 +386,6 @@ impl ThreadsDatabase { #[cfg(test)] mod tests { - use crate::NativeAgent; - use crate::Templates; use super::*; use agent::MessageSegment; diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs index d7d0ba2874..0622dd4f58 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent2/src/history_store.rs @@ -1,12 +1,11 @@ use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; use agent_client_protocol as acp; -use anyhow::{Context as _, Result}; use assistant_context::SavedContextMetadata; use chrono::{DateTime, Utc}; use collections::HashMap; use gpui::{SharedString, Task, prelude::*}; use serde::{Deserialize, Serialize}; -use smol::stream::StreamExt; + use std::{path::Path, sync::Arc, time::Duration}; const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 2a4d306290..a7fc8d907a 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1506,6 +1506,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { action_log, templates, model.clone(), + None, cx, ) }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 7ea5ff7cc6..9048f7099b 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -5,7 +5,7 @@ use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use chrono::{DateTime, Utc}; @@ -24,7 +24,7 @@ use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, TokenUsage, }; use project::{ Project, @@ -75,6 +75,18 @@ impl Message { } } + pub fn to_request(&self) -> Vec { + match self { + Message::User(message) => vec![message.to_request()], + Message::Agent(message) => message.to_request(), + Message::Resume => vec![LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: false, + }], + } + } + pub fn to_markdown(&self) -> String { match self { Message::User(message) => message.to_markdown(), @@ -82,6 +94,13 @@ impl Message { Message::Resume => "[resumed after tool use limit was reached]".into(), } } + + pub fn role(&self) -> Role { + match self { + Message::User(_) | Message::Resume => Role::User, + Message::Agent(_) => Role::Assistant, + } + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -426,6 +445,7 @@ pub enum ThreadEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + TitleUpdate(SharedString), Stop(acp::StopReason), } @@ -475,6 +495,7 @@ pub struct Thread { project_context: Rc>, templates: Arc, model: Arc, + summarization_model: Option>, project: Entity, action_log: Entity, } @@ -488,6 +509,7 @@ impl Thread { action_log: Entity, templates: Arc, model: Arc, + summarization_model: Option>, cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); @@ -516,11 +538,37 @@ impl Thread { project_context, templates, model, + summarization_model, project, action_log, } } + #[cfg(any(test, feature = "test-support"))] + pub fn test( + model: Arc, + project: Entity, + action_log: Entity, + cx: &mut Context, + ) -> Self { + use crate::generate_session_id; + + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + + Self::new( + generate_session_id(), + project, + Rc::default(), + context_server_registry, + action_log, + Templates::new(), + model, + None, + cx, + ) + } + pub fn id(&self) -> &acp::SessionId { &self.id } @@ -534,6 +582,7 @@ impl Thread { action_log: Entity, templates: Arc, model: Arc, + summarization_model: Option>, cx: &mut Context, ) -> Self { let profile_id = db_thread @@ -558,6 +607,7 @@ impl Thread { project_context, templates, model, + summarization_model, project, action_log, updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list) @@ -807,6 +857,15 @@ impl Thread { cx.notify() } + pub fn set_summarization_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + self.summarization_model = model; + cx.notify() + } + pub fn completion_mode(&self) -> CompletionMode { self.completion_mode } @@ -1018,6 +1077,86 @@ impl Thread { events_rx } + pub fn generate_title_if_needed(&mut self, cx: &mut Context) { + if !matches!(self.title, ThreadTitle::None) { + return; + } + + // todo!() copy logic from agent1 re: tool calls, etc.? + if self.messages.len() < 2 { + return; + } + + self.generate_title(cx); + } + + fn generate_title(&mut self, cx: &mut Context) { + let Some(model) = self.summarization_model.clone() else { + println!("No thread summary model"); + return; + }; + let mut request = LanguageModelRequest { + intent: Some(CompletionIntent::ThreadSummarization), + temperature: AgentSettings::temperature_for_model(&model, cx), + ..Default::default() + }; + + for message in &self.messages { + request.messages.extend(message.to_request()); + } + + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(SUMMARIZE_THREAD_PROMPT.into())], + cache: false, + }); + + let task = cx.spawn(async move |this, cx| { + let result = async { + let mut messages = model.stream_completion(request, &cx).await?; + + let mut new_summary = String::new(); + while let Some(event) = messages.next().await { + let Ok(event) = event else { + continue; + }; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { .. }, + ) => { + // this.update(cx, |thread, cx| { + // thread.update_model_request_usage(amount as u32, limit, cx); + // })?; + // todo!()? not sure if this is the right place to do this. + continue; + } + _ => continue, + }; + + let mut lines = text.lines(); + new_summary.extend(lines.next()); + + // Stop if the LLM generated multiple lines. + if lines.next().is_some() { + break; + } + } + + anyhow::Ok(new_summary.into()) + } + .await; + + this.update(cx, |this, cx| { + this.title = ThreadTitle::Done(result); + cx.notify(); + }) + .log_err(); + }); + + self.title = ThreadTitle::Pending(task); + } + pub fn build_system_message(&self) -> LanguageModelRequestMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { @@ -1373,15 +1512,7 @@ impl Thread { ); let mut messages = vec![self.build_system_message()]; for message in &self.messages { - match message { - Message::User(message) => messages.push(message.to_request()), - Message::Agent(message) => messages.extend(message.to_request()), - Message::Resume => messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec!["Continue where you left off".into()], - cache: false, - }), - } + messages.extend(message.to_request()); } if let Some(message) = self.pending_message.as_ref() { @@ -1924,7 +2055,7 @@ impl From for acp::ContentBlock { annotations: None, uri: None, }), - UserMessageContent::Mention { uri, content } => { + UserMessageContent::Mention { .. } => { todo!() } } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index f48ea7e86a..f540349f82 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -521,7 +521,6 @@ fn resolve_path( #[cfg(test)] mod tests { use super::*; - use crate::{ContextServerRegistry, Templates, generate_session_id}; use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; @@ -529,7 +528,6 @@ mod tests { use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; - use std::rc::Rc; use util::path; #[gpui::test] @@ -541,21 +539,8 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log, - Templates::new(), - model, - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); let result = cx .update(|cx| { let input = EditFileToolInput { @@ -743,21 +728,8 @@ mod tests { }); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log.clone(), cx)); // First, test with format_on_save enabled cx.update(|cx| { @@ -885,22 +857,9 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log, cx)); // First, test with remove_trailing_whitespace_on_save enabled cx.update(|cx| { @@ -1015,22 +974,10 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); fs.insert_tree("/root", json!({})).await; @@ -1154,22 +1101,10 @@ mod tests { fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test global config paths - these should require confirmation if they exist and are outside the project @@ -1266,21 +1201,9 @@ mod tests { .await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test files in different worktrees @@ -1349,21 +1272,9 @@ mod tests { let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test edge cases @@ -1435,21 +1346,9 @@ mod tests { let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test different EditFileMode values @@ -1518,21 +1417,9 @@ mod tests { let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project.clone(), - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); assert_eq!( diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index 7cc3ed3b9b..344790f26e 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -1,15 +1,12 @@ -use crate::{AgentPanel, RemoveSelectedThread}; +use crate::RemoveSelectedThread; use agent_servers::AgentServer; -use agent2::{ - NativeAgentServer, - history_store::{HistoryEntry, HistoryStore}, -}; +use agent2::{HistoryEntry, HistoryStore, NativeAgentServer}; use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use editor::{Editor, EditorEvent}; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, - UniformListScrollHandle, WeakEntity, Window, uniform_list, + UniformListScrollHandle, Window, uniform_list, }; use project::Project; use std::{fmt::Display, ops::Range, sync::Arc}; @@ -72,7 +69,7 @@ impl AcpThreadHistory { window: &mut Window, cx: &mut Context, ) -> Self { - let history_store = cx.new(|cx| agent2::history_store::HistoryStore::new(cx)); + let history_store = cx.new(|cx| agent2::HistoryStore::new(cx)); let agent = NativeAgentServer::new(project.read(cx).fs().clone()); diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 40517e49a0..959c152525 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -687,6 +687,7 @@ impl AcpThreadView { AcpThreadEvent::ServerExited(status) => { self.thread_state = ThreadState::ServerExited { status: *status }; } + AcpThreadEvent::TitleUpdated => {} } cx.notify(); } diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index b9e1ea5d0a..3d43c6883d 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -199,24 +199,21 @@ impl AgentDiffPane { let action_log = thread.action_log(cx).clone(); let mut this = Self { - _subscriptions: [ - Some( - cx.observe_in(&action_log, window, |this, _action_log, window, cx| { - this.update_excerpts(window, cx) - }), - ), + _subscriptions: vec![ + cx.observe_in(&action_log, window, |this, _action_log, window, cx| { + this.update_excerpts(window, cx) + }), match &thread { - AgentDiffThread::Native(thread) => { - Some(cx.subscribe(&thread, |this, _thread, event, cx| { - this.handle_thread_event(event, cx) - })) - } - AgentDiffThread::AcpThread(_) => None, + AgentDiffThread::Native(thread) => cx + .subscribe(&thread, |this, _thread, event, cx| { + this.handle_native_thread_event(event, cx) + }), + AgentDiffThread::AcpThread(thread) => cx + .subscribe(&thread, |this, _thread, event, cx| { + this.handle_acp_thread_event(event, cx) + }), }, - ] - .into_iter() - .flatten() - .collect(), + ], title: SharedString::default(), multibuffer, editor, @@ -324,13 +321,20 @@ impl AgentDiffPane { } } - fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { + fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { match event { ThreadEvent::SummaryGenerated => self.update_title(cx), _ => {} } } + fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context) { + match event { + AcpThreadEvent::TitleUpdated => self.update_title(cx), + _ => {} + } + } + pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) { if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) { self.editor.update(cx, |editor, cx| { @@ -1521,7 +1525,8 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::EntriesRemoved(_) + AcpThreadEvent::TitleUpdated + | AcpThreadEvent::EntriesRemoved(_) | AcpThreadEvent::Stopped | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::Error