Re-add history entries for native agent threads (#36500)
Closes #ISSUE Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
6b6eb11643
commit
6ba52a3a42
16 changed files with 2007 additions and 119 deletions
|
@ -1,10 +1,9 @@
|
|||
use crate::{
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
|
||||
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
|
||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||
UserMessageContent, WebSearchTool, templates::Templates,
|
||||
ContextServerRegistry, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent,
|
||||
templates::Templates,
|
||||
};
|
||||
use acp_thread::AgentModelSelector;
|
||||
use crate::{HistoryStore, ThreadsDatabase};
|
||||
use acp_thread::{AcpThread, AgentModelSelector};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
|
@ -51,7 +50,8 @@ struct Session {
|
|||
thread: Entity<Thread>,
|
||||
/// The ACP thread that handles protocol communication
|
||||
acp_thread: WeakEntity<acp_thread::AcpThread>,
|
||||
_subscription: Subscription,
|
||||
pending_save: Task<()>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub struct LanguageModels {
|
||||
|
@ -155,6 +155,7 @@ impl LanguageModels {
|
|||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
history: Entity<HistoryStore>,
|
||||
/// Shared project context for all threads
|
||||
project_context: Entity<ProjectContext>,
|
||||
project_context_needs_refresh: watch::Sender<()>,
|
||||
|
@ -173,6 +174,7 @@ pub struct NativeAgent {
|
|||
impl NativeAgent {
|
||||
pub async fn new(
|
||||
project: Entity<Project>,
|
||||
history: Entity<HistoryStore>,
|
||||
templates: Arc<Templates>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
|
@ -200,6 +202,7 @@ impl NativeAgent {
|
|||
watch::channel(());
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
history,
|
||||
project_context: cx.new(|_| project_context),
|
||||
project_context_needs_refresh: project_context_needs_refresh_tx,
|
||||
_maintain_project_context: cx.spawn(async move |this, cx| {
|
||||
|
@ -218,6 +221,55 @@ impl NativeAgent {
|
|||
})
|
||||
}
|
||||
|
||||
fn register_session(
|
||||
&mut self,
|
||||
thread_handle: Entity<Thread>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Entity<AcpThread> {
|
||||
let connection = Rc::new(NativeAgentConnection(cx.entity()));
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
||||
|
||||
thread_handle.update(cx, |thread, cx| {
|
||||
thread.set_summarization_model(summarization_model, cx);
|
||||
thread.add_default_tools(cx)
|
||||
});
|
||||
|
||||
let thread = thread_handle.read(cx);
|
||||
let session_id = thread.id().clone();
|
||||
let title = thread.title();
|
||||
let project = thread.project.clone();
|
||||
let action_log = thread.action_log.clone();
|
||||
let acp_thread = cx.new(|_cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
title,
|
||||
connection,
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
session_id.clone(),
|
||||
)
|
||||
});
|
||||
let subscriptions = vec![
|
||||
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
cx.observe(&thread_handle, move |this, thread, cx| {
|
||||
this.save_thread(thread.clone(), cx)
|
||||
}),
|
||||
];
|
||||
|
||||
self.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread: thread_handle,
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscriptions: subscriptions,
|
||||
pending_save: Task::ready(()),
|
||||
},
|
||||
);
|
||||
acp_thread
|
||||
}
|
||||
|
||||
pub fn models(&self) -> &LanguageModels {
|
||||
&self.models
|
||||
}
|
||||
|
@ -444,6 +496,63 @@ impl NativeAgent {
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open_thread(
|
||||
&mut self,
|
||||
id: acp::SessionId,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let database_future = ThreadsDatabase::connect(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let database = database_future.await.map_err(|err| anyhow!(err))?;
|
||||
let db_thread = database
|
||||
.load_thread(id.clone())
|
||||
.await?
|
||||
.with_context(|| format!("no thread found with ID: {id:?}"))?;
|
||||
|
||||
let thread = this.update(cx, |this, cx| {
|
||||
let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
|
||||
cx.new(|cx| {
|
||||
Thread::from_db(
|
||||
id.clone(),
|
||||
db_thread,
|
||||
this.project.clone(),
|
||||
this.project_context.clone(),
|
||||
this.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
this.templates.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
let acp_thread =
|
||||
this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
|
||||
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
|
||||
cx.update(|cx| {
|
||||
NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
|
||||
})?
|
||||
.await?;
|
||||
Ok(acp_thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
|
||||
let database_future = ThreadsDatabase::connect(cx);
|
||||
let (id, db_thread) =
|
||||
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
|
||||
let Some(session) = self.sessions.get_mut(&id) else {
|
||||
return;
|
||||
};
|
||||
let history = self.history.clone();
|
||||
session.pending_save = cx.spawn(async move |_, cx| {
|
||||
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
|
||||
return;
|
||||
};
|
||||
let db_thread = db_thread.await;
|
||||
database.save_thread(id, db_thread).await.log_err();
|
||||
history.update(cx, |history, cx| history.reload(cx)).ok();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct that implements the AgentConnection trait
|
||||
|
@ -476,13 +585,21 @@ impl NativeAgentConnection {
|
|||
};
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
let mut response_stream = match f(thread, cx) {
|
||||
let response_stream = match f(thread, cx) {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
};
|
||||
Self::handle_thread_events(response_stream, acp_thread, cx)
|
||||
}
|
||||
|
||||
fn handle_thread_events(
|
||||
mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
||||
acp_thread: WeakEntity<AcpThread>,
|
||||
cx: &App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
cx.spawn(async move |cx| {
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
while let Some(result) = events.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
@ -686,8 +803,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||
// Fetch default model from registry settings
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
|
||||
// Log available models for debugging
|
||||
let available_count = registry.available_models(cx).count();
|
||||
log::debug!("Total available models: {}", available_count);
|
||||
|
@ -697,72 +812,23 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
.models
|
||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||
});
|
||||
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
agent.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
default_model,
|
||||
summarization_model,
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
thread.add_tool(ListDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(MovePathTool::new(project.clone()));
|
||||
thread.add_tool(NowTool);
|
||||
thread.add_tool(OpenTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(TerminalTool::new(project.clone(), cx));
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
||||
thread
|
||||
)
|
||||
});
|
||||
|
||||
Ok(thread)
|
||||
},
|
||||
)??;
|
||||
|
||||
let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|_cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
"agent2",
|
||||
self.clone(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
session_id.clone(),
|
||||
)
|
||||
})
|
||||
})?;
|
||||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
agent.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread,
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
|
||||
Ok(acp_thread)
|
||||
agent.update(cx, |agent, cx| agent.register_session(thread, cx))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -887,8 +953,11 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
|
||||
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
history_store,
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
|
@ -942,9 +1011,12 @@ mod tests {
|
|||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/", json!({ "a": {} })).await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
|
||||
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
|
||||
let connection = NativeAgentConnection(
|
||||
NativeAgent::new(
|
||||
project.clone(),
|
||||
history_store,
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
|
@ -995,9 +1067,13 @@ mod tests {
|
|||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
|
||||
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
|
||||
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
|
||||
|
||||
// Create the agent and connection
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
history_store,
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue