Re-add history entries for native agent threads (#36500)

Closes #ISSUE

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Conrad Irwin 2025-08-19 12:08:11 -06:00 committed by GitHub
parent 6b6eb11643
commit 6ba52a3a42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 2007 additions and 119 deletions

View file

@ -1,10 +1,9 @@
use crate::{
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
UserMessageContent, WebSearchTool, templates::Templates,
ContextServerRegistry, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent,
templates::Templates,
};
use acp_thread::AgentModelSelector;
use crate::{HistoryStore, ThreadsDatabase};
use acp_thread::{AcpThread, AgentModelSelector};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
@ -51,7 +50,8 @@ struct Session {
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication
acp_thread: WeakEntity<acp_thread::AcpThread>,
_subscription: Subscription,
pending_save: Task<()>,
_subscriptions: Vec<Subscription>,
}
pub struct LanguageModels {
@ -155,6 +155,7 @@ impl LanguageModels {
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
history: Entity<HistoryStore>,
/// Shared project context for all threads
project_context: Entity<ProjectContext>,
project_context_needs_refresh: watch::Sender<()>,
@ -173,6 +174,7 @@ pub struct NativeAgent {
impl NativeAgent {
pub async fn new(
project: Entity<Project>,
history: Entity<HistoryStore>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
@ -200,6 +202,7 @@ impl NativeAgent {
watch::channel(());
Self {
sessions: HashMap::new(),
history,
project_context: cx.new(|_| project_context),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
@ -218,6 +221,55 @@ impl NativeAgent {
})
}
fn register_session(
&mut self,
thread_handle: Entity<Thread>,
cx: &mut Context<Self>,
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
let registry = LanguageModelRegistry::read_global(cx);
let summarization_model = registry.thread_summary_model().map(|c| c.model);
thread_handle.update(cx, |thread, cx| {
thread.set_summarization_model(summarization_model, cx);
thread.add_default_tools(cx)
});
let thread = thread_handle.read(cx);
let session_id = thread.id().clone();
let title = thread.title();
let project = thread.project.clone();
let action_log = thread.action_log.clone();
let acp_thread = cx.new(|_cx| {
acp_thread::AcpThread::new(
title,
connection,
project.clone(),
action_log.clone(),
session_id.clone(),
)
});
let subscriptions = vec![
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
cx.observe(&thread_handle, move |this, thread, cx| {
this.save_thread(thread.clone(), cx)
}),
];
self.sessions.insert(
session_id,
Session {
thread: thread_handle,
acp_thread: acp_thread.downgrade(),
_subscriptions: subscriptions,
pending_save: Task::ready(()),
},
);
acp_thread
}
pub fn models(&self) -> &LanguageModels {
&self.models
}
@ -444,6 +496,63 @@ impl NativeAgent {
});
}
}
pub fn open_thread(
&mut self,
id: acp::SessionId,
cx: &mut Context<Self>,
) -> Task<Result<Entity<AcpThread>>> {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
let db_thread = database
.load_thread(id.clone())
.await?
.with_context(|| format!("no thread found with ID: {id:?}"))?;
let thread = this.update(cx, |this, cx| {
let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
cx.new(|cx| {
Thread::from_db(
id.clone(),
db_thread,
this.project.clone(),
this.project_context.clone(),
this.context_server_registry.clone(),
action_log.clone(),
this.templates.clone(),
cx,
)
})
})?;
let acp_thread =
this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
cx.update(|cx| {
NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
})?
.await?;
Ok(acp_thread)
})
}
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
let database_future = ThreadsDatabase::connect(cx);
let (id, db_thread) =
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
let Some(session) = self.sessions.get_mut(&id) else {
return;
};
let history = self.history.clone();
session.pending_save = cx.spawn(async move |_, cx| {
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
return;
};
let db_thread = db_thread.await;
database.save_thread(id, db_thread).await.log_err();
history.update(cx, |history, cx| history.reload(cx)).ok();
});
}
}
/// Wrapper struct that implements the AgentConnection trait
@ -476,13 +585,21 @@ impl NativeAgentConnection {
};
log::debug!("Found session for: {}", session_id);
let mut response_stream = match f(thread, cx) {
let response_stream = match f(thread, cx) {
Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)),
};
Self::handle_thread_events(response_stream, acp_thread, cx)
}
fn handle_thread_events(
mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
acp_thread: WeakEntity<AcpThread>,
cx: &App,
) -> Task<Result<acp::PromptResponse>> {
cx.spawn(async move |cx| {
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
while let Some(result) = events.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
@ -686,8 +803,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
let language_registry = project.read(cx).languages().clone();
// Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
@ -697,72 +812,23 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.models
.model_from_id(&LanguageModels::model_id(&default_model.model))
});
let summarization_model = registry.thread_summary_model().map(|c| c.model);
let thread = cx.new(|cx| {
let mut thread = Thread::new(
Thread::new(
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
action_log.clone(),
agent.templates.clone(),
default_model,
summarization_model,
cx,
);
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
thread.add_tool(ListDirectoryTool::new(project.clone()));
thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(NowTool);
thread.add_tool(OpenTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
thread.add_tool(TerminalTool::new(project.clone(), cx));
thread.add_tool(ThinkingTool);
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
thread
)
});
Ok(thread)
},
)??;
let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|_cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
action_log.clone(),
session_id.clone(),
)
})
})?;
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
session_id,
Session {
thread,
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
},
);
})?;
Ok(acp_thread)
agent.update(cx, |agent, cx| agent.register_session(thread, cx))
})
}
@ -887,8 +953,11 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
let agent = NativeAgent::new(
project.clone(),
history_store,
Templates::new(),
None,
fs.clone(),
@ -942,9 +1011,12 @@ mod tests {
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
let connection = NativeAgentConnection(
NativeAgent::new(
project.clone(),
history_store,
Templates::new(),
None,
fs.clone(),
@ -995,9 +1067,13 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
// Create the agent and connection
let agent = NativeAgent::new(
project.clone(),
history_store,
Templates::new(),
None,
fs.clone(),