Saving history with thread titles

This commit is contained in:
Conrad Irwin 2025-08-18 14:16:03 -06:00
parent cc196427f0
commit 5d88de13da
10 changed files with 241 additions and 183 deletions

View file

@ -691,6 +691,7 @@ pub struct AcpThread {
pub enum AcpThreadEvent { pub enum AcpThreadEvent {
NewEntry, NewEntry,
TitleUpdated,
EntryUpdated(usize), EntryUpdated(usize),
EntriesRemoved(Range<usize>), EntriesRemoved(Range<usize>),
ToolAuthorizationRequired, ToolAuthorizationRequired,
@ -934,6 +935,12 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry); cx.emit(AcpThreadEvent::NewEntry);
} }
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
self.title = title;
cx.emit(AcpThreadEvent::TitleUpdated);
Ok(())
}
pub fn update_tool_call( pub fn update_tool_call(
&mut self, &mut self,
update: impl Into<ToolCallUpdate>, update: impl Into<ToolCallUpdate>,

View file

@ -255,6 +255,9 @@ impl NativeAgent {
this.sessions.remove(acp_thread.session_id()); this.sessions.remove(acp_thread.session_id());
}), }),
cx.observe(&thread, |this, thread, cx| { cx.observe(&thread, |this, thread, cx| {
thread.update(cx, |thread, cx| {
thread.generate_title_if_needed(cx);
});
this.save_thread(thread.clone(), cx) this.save_thread(thread.clone(), cx)
}), }),
], ],
@ -262,13 +265,14 @@ impl NativeAgent {
); );
} }
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) { fn save_thread(&mut self, thread_handle: Entity<Thread>, cx: &mut Context<Self>) {
let id = thread.read(cx).id().clone(); let thread = thread_handle.read(cx);
let id = thread.id().clone();
let Some(session) = self.sessions.get_mut(&id) else { let Some(session) = self.sessions.get_mut(&id) else {
return; return;
}; };
let thread = thread.downgrade(); let thread = thread_handle.downgrade();
let thread_database = self.thread_database.clone(); let thread_database = self.thread_database.clone();
session.save_task = cx.spawn(async move |this, cx| { session.save_task = cx.spawn(async move |this, cx| {
cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
@ -507,7 +511,7 @@ impl NativeAgent {
fn handle_models_updated_event( fn handle_models_updated_event(
&mut self, &mut self,
_registry: Entity<LanguageModelRegistry>, registry: Entity<LanguageModelRegistry>,
_event: &language_model::Event, _event: &language_model::Event,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
@ -518,6 +522,11 @@ impl NativeAgent {
if let Some(model) = self.models.model_from_id(&model_id) { if let Some(model) = self.models.model_from_id(&model_id) {
thread.set_model(model.clone(), cx); thread.set_model(model.clone(), cx);
} }
let summarization_model = registry
.read(cx)
.thread_summary_model()
.map(|model| model.model.clone());
thread.set_summarization_model(summarization_model, cx);
}); });
} }
} }
@ -641,6 +650,10 @@ impl NativeAgentConnection {
thread.update_tool_call(update, cx) thread.update_tool_call(update, cx)
})??; })??;
} }
ThreadEvent::TitleUpdate(title) => {
acp_thread
.update(cx, |thread, cx| thread.update_title(title, cx))??;
}
ThreadEvent::Stop(stop_reason) => { ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason); log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason }); return Ok(acp::PromptResponse { stop_reason });
@ -821,6 +834,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
) )
})?; })?;
let summarization_model = registry.thread_summary_model().map(|c| c.model);
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let mut thread = Thread::new( let mut thread = Thread::new(
session_id.clone(), session_id.clone(),
@ -830,6 +845,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
action_log.clone(), action_log.clone(),
agent.templates.clone(), agent.templates.clone(),
default_model, default_model,
summarization_model,
cx, cx,
); );
Self::register_tools(&mut thread, project, action_log, cx); Self::register_tools(&mut thread, project, action_log, cx);
@ -894,7 +910,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Create Thread // Create Thread
let thread = agent.update(cx, |agent, cx| { let thread = agent.update(cx, |agent, cx| {
let configured_model = LanguageModelRegistry::global(cx) let language_model_registry = LanguageModelRegistry::global(cx);
let configured_model = language_model_registry
.update(cx, |registry, cx| { .update(cx, |registry, cx| {
db_thread db_thread
.model .model
@ -915,6 +932,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.model_from_id(&LanguageModels::model_id(&configured_model.model)) .model_from_id(&LanguageModels::model_id(&configured_model.model))
.context("no model by id")?; .context("no model by id")?;
let summarization_model = language_model_registry
.read(cx)
.thread_summary_model()
.map(|c| c.model);
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let mut thread = Thread::from_db( let mut thread = Thread::from_db(
session_id, session_id,
@ -925,6 +947,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
action_log.clone(), action_log.clone(),
agent.templates.clone(), agent.templates.clone(),
model, model,
summarization_model,
cx, cx,
); );
Self::register_tools(&mut thread, project, action_log, cx); Self::register_tools(&mut thread, project, action_log, cx);
@ -1047,12 +1070,13 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{HistoryEntry, HistoryStore}; use crate::HistoryStore;
use super::*; use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
use fs::FakeFs; use fs::FakeFs;
use gpui::TestAppContext; use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
use serde_json::json; use serde_json::json;
use settings::SettingsStore; use settings::SettingsStore;
use util::path; use util::path;
@ -1245,13 +1269,6 @@ mod tests {
) )
.await .await
.unwrap(); .unwrap();
let model = cx.update(|cx| {
LanguageModelRegistry::global(cx)
.read(cx)
.default_model()
.unwrap()
.model
});
let connection = NativeAgentConnection(agent.clone()); let connection = NativeAgentConnection(agent.clone());
let history_store = cx.new(|cx| { let history_store = cx.new(|cx| {
let mut store = HistoryStore::new(cx); let mut store = HistoryStore::new(cx);
@ -1268,6 +1285,16 @@ mod tests {
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
let selector = connection.model_selector().unwrap(); let selector = connection.model_selector().unwrap();
let summarization_model: Arc<dyn LanguageModel> =
Arc::new(FakeLanguageModel::default()) as _;
agent.update(cx, |agent, cx| {
let thread = agent.sessions.get(&session_id).unwrap().thread.clone();
thread.update(cx, |thread, cx| {
thread.set_summarization_model(Some(summarization_model.clone()), cx);
})
});
let model = cx let model = cx
.update(|cx| selector.selected_model(&session_id, cx)) .update(|cx| selector.selected_model(&session_id, cx))
.await .await
@ -1283,11 +1310,16 @@ mod tests {
model.send_last_completion_stream_text_chunk("Hey"); model.send_last_completion_stream_text_chunk("Hey");
model.end_last_completion_stream(); model.end_last_completion_stream();
send.await.unwrap(); send.await.unwrap();
summarization_model
.as_fake()
.send_last_completion_stream_text_chunk("Saying Hello");
summarization_model.as_fake().end_last_completion_stream();
cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE); cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);
let history = history_store.update(cx, |store, cx| store.entries(cx)); let history = history_store.update(cx, |store, cx| store.entries(cx));
assert_eq!(history.len(), 1); assert_eq!(history.len(), 1);
assert_eq!(history[0].title(), "Hi"); assert_eq!(history[0].title(), "Saying Hello");
} }
fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {

View file

@ -386,8 +386,6 @@ impl ThreadsDatabase {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::NativeAgent;
use crate::Templates;
use super::*; use super::*;
use agent::MessageSegment; use agent::MessageSegment;

View file

@ -1,12 +1,11 @@
use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use assistant_context::SavedContextMetadata; use assistant_context::SavedContextMetadata;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::HashMap; use collections::HashMap;
use gpui::{SharedString, Task, prelude::*}; use gpui::{SharedString, Task, prelude::*};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use smol::stream::StreamExt;
use std::{path::Path, sync::Arc, time::Duration}; use std::{path::Path, sync::Arc, time::Duration};
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;

View file

@ -1506,6 +1506,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
action_log, action_log,
templates, templates,
model.clone(), model.clone(),
None,
cx, cx,
) )
}); });

View file

@ -5,7 +5,7 @@ use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog; use action_log::ActionLog;
use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot}; use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format; use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
@ -24,7 +24,7 @@ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, TokenUsage,
}; };
use project::{ use project::{
Project, Project,
@ -75,6 +75,18 @@ impl Message {
} }
} }
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
match self {
Message::User(message) => vec![message.to_request()],
Message::Agent(message) => message.to_request(),
Message::Resume => vec![LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
cache: false,
}],
}
}
pub fn to_markdown(&self) -> String { pub fn to_markdown(&self) -> String {
match self { match self {
Message::User(message) => message.to_markdown(), Message::User(message) => message.to_markdown(),
@ -82,6 +94,13 @@ impl Message {
Message::Resume => "[resumed after tool use limit was reached]".into(), Message::Resume => "[resumed after tool use limit was reached]".into(),
} }
} }
pub fn role(&self) -> Role {
match self {
Message::User(_) | Message::Resume => Role::User,
Message::Agent(_) => Role::Assistant,
}
}
} }
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -426,6 +445,7 @@ pub enum ThreadEvent {
ToolCall(acp::ToolCall), ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization), ToolCallAuthorization(ToolCallAuthorization),
TitleUpdate(SharedString),
Stop(acp::StopReason), Stop(acp::StopReason),
} }
@ -475,6 +495,7 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>, project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>, templates: Arc<Templates>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
summarization_model: Option<Arc<dyn LanguageModel>>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
} }
@ -488,6 +509,7 @@ impl Thread {
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
templates: Arc<Templates>, templates: Arc<Templates>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
summarization_model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone(); let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@ -516,11 +538,37 @@ impl Thread {
project_context, project_context,
templates, templates,
model, model,
summarization_model,
project, project,
action_log, action_log,
} }
} }
#[cfg(any(test, feature = "test-support"))]
pub fn test(
model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut Context<Self>,
) -> Self {
use crate::generate_session_id;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
Self::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
action_log,
Templates::new(),
model,
None,
cx,
)
}
pub fn id(&self) -> &acp::SessionId { pub fn id(&self) -> &acp::SessionId {
&self.id &self.id
} }
@ -534,6 +582,7 @@ impl Thread {
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
templates: Arc<Templates>, templates: Arc<Templates>,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
summarization_model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let profile_id = db_thread let profile_id = db_thread
@ -558,6 +607,7 @@ impl Thread {
project_context, project_context,
templates, templates,
model, model,
summarization_model,
project, project,
action_log, action_log,
updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list) updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
@ -807,6 +857,15 @@ impl Thread {
cx.notify() cx.notify()
} }
pub fn set_summarization_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
) {
self.summarization_model = model;
cx.notify()
}
pub fn completion_mode(&self) -> CompletionMode { pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode self.completion_mode
} }
@ -1018,6 +1077,86 @@ impl Thread {
events_rx events_rx
} }
pub fn generate_title_if_needed(&mut self, cx: &mut Context<Self>) {
if !matches!(self.title, ThreadTitle::None) {
return;
}
// todo!() copy logic from agent1 re: tool calls, etc.?
if self.messages.len() < 2 {
return;
}
self.generate_title(cx);
}
fn generate_title(&mut self, cx: &mut Context<Self>) {
let Some(model) = self.summarization_model.clone() else {
println!("No thread summary model");
return;
};
let mut request = LanguageModelRequest {
intent: Some(CompletionIntent::ThreadSummarization),
temperature: AgentSettings::temperature_for_model(&model, cx),
..Default::default()
};
for message in &self.messages {
request.messages.extend(message.to_request());
}
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text(SUMMARIZE_THREAD_PROMPT.into())],
cache: false,
});
let task = cx.spawn(async move |this, cx| {
let result = async {
let mut messages = model.stream_completion(request, &cx).await?;
let mut new_summary = String::new();
while let Some(event) = messages.next().await {
let Ok(event) = event else {
continue;
};
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { .. },
) => {
// this.update(cx, |thread, cx| {
// thread.update_model_request_usage(amount as u32, limit, cx);
// })?;
// todo!()? not sure if this is the right place to do this.
continue;
}
_ => continue,
};
let mut lines = text.lines();
new_summary.extend(lines.next());
// Stop if the LLM generated multiple lines.
if lines.next().is_some() {
break;
}
}
anyhow::Ok(new_summary.into())
}
.await;
this.update(cx, |this, cx| {
this.title = ThreadTitle::Done(result);
cx.notify();
})
.log_err();
});
self.title = ThreadTitle::Pending(task);
}
pub fn build_system_message(&self) -> LanguageModelRequestMessage { pub fn build_system_message(&self) -> LanguageModelRequestMessage {
log::debug!("Building system message"); log::debug!("Building system message");
let prompt = SystemPromptTemplate { let prompt = SystemPromptTemplate {
@ -1373,15 +1512,7 @@ impl Thread {
); );
let mut messages = vec![self.build_system_message()]; let mut messages = vec![self.build_system_message()];
for message in &self.messages { for message in &self.messages {
match message { messages.extend(message.to_request());
Message::User(message) => messages.push(message.to_request()),
Message::Agent(message) => messages.extend(message.to_request()),
Message::Resume => messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
cache: false,
}),
}
} }
if let Some(message) = self.pending_message.as_ref() { if let Some(message) = self.pending_message.as_ref() {
@ -1924,7 +2055,7 @@ impl From<UserMessageContent> for acp::ContentBlock {
annotations: None, annotations: None,
uri: None, uri: None,
}), }),
UserMessageContent::Mention { uri, content } => { UserMessageContent::Mention { .. } => {
todo!() todo!()
} }
} }

View file

@ -521,7 +521,6 @@ fn resolve_path(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::{ContextServerRegistry, Templates, generate_session_id};
use action_log::ActionLog; use action_log::ActionLog;
use client::TelemetrySettings; use client::TelemetrySettings;
use fs::Fs; use fs::Fs;
@ -529,7 +528,6 @@ mod tests {
use language_model::fake_provider::FakeLanguageModel; use language_model::fake_provider::FakeLanguageModel;
use serde_json::json; use serde_json::json;
use settings::SettingsStore; use settings::SettingsStore;
use std::rc::Rc;
use util::path; use util::path;
#[gpui::test] #[gpui::test]
@ -541,21 +539,8 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
action_log,
Templates::new(),
model,
cx,
)
});
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = EditFileToolInput { let input = EditFileToolInput {
@ -743,21 +728,8 @@ mod tests {
}); });
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log.clone(), cx));
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
// First, test with format_on_save enabled // First, test with format_on_save enabled
cx.update(|cx| { cx.update(|cx| {
@ -885,22 +857,9 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log, cx));
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
// First, test with remove_trailing_whitespace_on_save enabled // First, test with remove_trailing_whitespace_on_save enabled
cx.update(|cx| { cx.update(|cx| {
@ -1015,22 +974,10 @@ mod tests {
let fs = project::FakeFs::new(cx.executor()); let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await; fs.insert_tree("/root", json!({})).await;
@ -1154,22 +1101,10 @@ mod tests {
fs.insert_tree("/project", json!({})).await; fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project // Test global config paths - these should require confirmation if they exist and are outside the project
@ -1266,21 +1201,9 @@ mod tests {
.await; .await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
Thread::new(
generate_session_id(),
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees // Test files in different worktrees
@ -1349,21 +1272,9 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
Thread::new(
generate_session_id(),
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases // Test edge cases
@ -1435,21 +1346,9 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
Thread::new(
generate_session_id(),
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values // Test different EditFileMode values
@ -1518,21 +1417,9 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
Thread::new(
generate_session_id(),
project.clone(),
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!( assert_eq!(

View file

@ -1,15 +1,12 @@
use crate::{AgentPanel, RemoveSelectedThread}; use crate::RemoveSelectedThread;
use agent_servers::AgentServer; use agent_servers::AgentServer;
use agent2::{ use agent2::{HistoryEntry, HistoryStore, NativeAgentServer};
NativeAgentServer,
history_store::{HistoryEntry, HistoryStore},
};
use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
use editor::{Editor, EditorEvent}; use editor::{Editor, EditorEvent};
use fuzzy::{StringMatch, StringMatchCandidate}; use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{ use gpui::{
App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
UniformListScrollHandle, WeakEntity, Window, uniform_list, UniformListScrollHandle, Window, uniform_list,
}; };
use project::Project; use project::Project;
use std::{fmt::Display, ops::Range, sync::Arc}; use std::{fmt::Display, ops::Range, sync::Arc};
@ -72,7 +69,7 @@ impl AcpThreadHistory {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let history_store = cx.new(|cx| agent2::history_store::HistoryStore::new(cx)); let history_store = cx.new(|cx| agent2::HistoryStore::new(cx));
let agent = NativeAgentServer::new(project.read(cx).fs().clone()); let agent = NativeAgentServer::new(project.read(cx).fs().clone());

View file

@ -687,6 +687,7 @@ impl AcpThreadView {
AcpThreadEvent::ServerExited(status) => { AcpThreadEvent::ServerExited(status) => {
self.thread_state = ThreadState::ServerExited { status: *status }; self.thread_state = ThreadState::ServerExited { status: *status };
} }
AcpThreadEvent::TitleUpdated => {}
} }
cx.notify(); cx.notify();
} }

View file

@ -199,24 +199,21 @@ impl AgentDiffPane {
let action_log = thread.action_log(cx).clone(); let action_log = thread.action_log(cx).clone();
let mut this = Self { let mut this = Self {
_subscriptions: [ _subscriptions: vec![
Some( cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
cx.observe_in(&action_log, window, |this, _action_log, window, cx| { this.update_excerpts(window, cx)
this.update_excerpts(window, cx) }),
}),
),
match &thread { match &thread {
AgentDiffThread::Native(thread) => { AgentDiffThread::Native(thread) => cx
Some(cx.subscribe(&thread, |this, _thread, event, cx| { .subscribe(&thread, |this, _thread, event, cx| {
this.handle_thread_event(event, cx) this.handle_native_thread_event(event, cx)
})) }),
} AgentDiffThread::AcpThread(thread) => cx
AgentDiffThread::AcpThread(_) => None, .subscribe(&thread, |this, _thread, event, cx| {
this.handle_acp_thread_event(event, cx)
}),
}, },
] ],
.into_iter()
.flatten()
.collect(),
title: SharedString::default(), title: SharedString::default(),
multibuffer, multibuffer,
editor, editor,
@ -324,13 +321,20 @@ impl AgentDiffPane {
} }
} }
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) { fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
match event { match event {
ThreadEvent::SummaryGenerated => self.update_title(cx), ThreadEvent::SummaryGenerated => self.update_title(cx),
_ => {} _ => {}
} }
} }
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
match event {
AcpThreadEvent::TitleUpdated => self.update_title(cx),
_ => {}
}
}
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) { pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) { if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
self.editor.update(cx, |editor, cx| { self.editor.update(cx, |editor, cx| {
@ -1521,7 +1525,8 @@ impl AgentDiff {
self.update_reviewing_editors(workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
} }
} }
AcpThreadEvent::EntriesRemoved(_) AcpThreadEvent::TitleUpdated
| AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::Stopped | AcpThreadEvent::Stopped
| AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Error | AcpThreadEvent::Error