Wire up history completely
Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
d83210d978
commit
4b1a48e4de
7 changed files with 93 additions and 98 deletions
|
@ -2,7 +2,6 @@ use crate::{AcpThread, AcpThreadMetadata};
|
|||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use gpui::{Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -27,6 +26,8 @@ pub trait AgentConnection {
|
|||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
// todo!(expose a history trait, and include list_threads and load_thread)
|
||||
// todo!(write a test)
|
||||
fn list_threads(
|
||||
&self,
|
||||
_cx: &mut App,
|
||||
|
|
|
@ -5,16 +5,15 @@ use crate::{
|
|||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||
UserMessageContent, WebSearchTool, templates::Templates,
|
||||
};
|
||||
use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id};
|
||||
use crate::{ThreadsDatabase, generate_session_id};
|
||||
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
|
||||
use futures::future::Shared;
|
||||
use futures::{SinkExt, StreamExt, future};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||
};
|
||||
|
@ -30,6 +29,7 @@ use std::collections::HashMap;
|
|||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use util::ResultExt;
|
||||
|
||||
const RULES_FILE_NAMES: [&'static str; 9] = [
|
||||
|
@ -174,7 +174,7 @@ pub struct NativeAgent {
|
|||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_database: Arc<ThreadsDatabase>,
|
||||
history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
|
||||
load_history: Task<Result<()>>,
|
||||
load_history: Task<()>,
|
||||
fs: Arc<dyn Fs>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
@ -212,7 +212,7 @@ impl NativeAgent {
|
|||
|
||||
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
|
||||
watch::channel(());
|
||||
let this = Self {
|
||||
let mut this = Self {
|
||||
sessions: HashMap::new(),
|
||||
project_context: Rc::new(RefCell::new(project_context)),
|
||||
project_context_needs_refresh: project_context_needs_refresh_tx,
|
||||
|
@ -229,7 +229,7 @@ impl NativeAgent {
|
|||
prompt_store,
|
||||
fs,
|
||||
history: watch::channel(None).0,
|
||||
load_history: Task::ready(Ok(())),
|
||||
load_history: Task::ready(()),
|
||||
_subscriptions: subscriptions,
|
||||
};
|
||||
this.reload_history(cx);
|
||||
|
@ -249,7 +249,7 @@ impl NativeAgent {
|
|||
Session {
|
||||
thread: thread.clone(),
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
save_task: Task::ready(()),
|
||||
save_task: Task::ready(Ok(())),
|
||||
_subscriptions: vec![
|
||||
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
|
@ -280,24 +280,30 @@ impl NativeAgent {
|
|||
}
|
||||
|
||||
fn reload_history(&mut self, cx: &mut Context<Self>) {
|
||||
dbg!("");
|
||||
let thread_database = self.thread_database.clone();
|
||||
self.load_history = cx.spawn(async move |this, cx| {
|
||||
let results = cx
|
||||
.background_spawn(async move {
|
||||
let results = thread_database.list_threads().await?;
|
||||
Ok(results
|
||||
.into_iter()
|
||||
.map(|thread| AcpThreadMetadata {
|
||||
agent: NATIVE_AGENT_SERVER_NAME.clone(),
|
||||
id: thread.id.into(),
|
||||
title: thread.title,
|
||||
updated_at: thread.updated_at,
|
||||
})
|
||||
.collect())
|
||||
dbg!(&results);
|
||||
anyhow::Ok(
|
||||
results
|
||||
.into_iter()
|
||||
.map(|thread| AcpThreadMetadata {
|
||||
agent: NATIVE_AGENT_SERVER_NAME.clone(),
|
||||
id: thread.id.into(),
|
||||
title: thread.title,
|
||||
updated_at: thread.updated_at,
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
this.update(cx, |this, cx| this.history.send(Some(results)))?;
|
||||
anyhow::Ok(())
|
||||
.await;
|
||||
if let Some(results) = results.log_err() {
|
||||
this.update(cx, |this, _| this.history.send(Some(results)))
|
||||
.ok();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -509,10 +515,10 @@ impl NativeAgent {
|
|||
) {
|
||||
self.models.refresh_list(cx);
|
||||
for session in self.sessions.values_mut() {
|
||||
session.thread.update(cx, |thread, _| {
|
||||
session.thread.update(cx, |thread, cx| {
|
||||
let model_id = LanguageModels::model_id(&thread.model());
|
||||
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||
thread.set_model(model.clone());
|
||||
thread.set_model(model.clone(), cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -715,8 +721,8 @@ impl AgentModelSelector for NativeAgentConnection {
|
|||
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
||||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.set_model(model.clone());
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_model(model.clone(), cx);
|
||||
});
|
||||
|
||||
update_settings_file::<AgentSettings>(
|
||||
|
@ -867,12 +873,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
session_id: acp::SessionId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||
let thread_id = ThreadId::from(session_id.clone());
|
||||
let database = self.0.update(cx, |this, _| this.thread_database.clone());
|
||||
cx.spawn(async move |cx| {
|
||||
let database = database.await.map_err(|e| anyhow!(e))?;
|
||||
let db_thread = database
|
||||
.load_thread(thread_id.clone())
|
||||
.load_thread(session_id.clone())
|
||||
.await?
|
||||
.context("no such thread found")?;
|
||||
|
||||
|
@ -915,7 +919,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::from_db(
|
||||
thread_id,
|
||||
session_id,
|
||||
db_thread,
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
|
@ -934,7 +938,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
agent.insert_session(session_id, thread, acp_thread, cx)
|
||||
agent.insert_session(thread.clone(), acp_thread.clone(), cx)
|
||||
})?;
|
||||
|
||||
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
|
||||
|
@ -995,7 +999,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
log::info!("Cancelling on session: {}", session_id);
|
||||
self.0.update(cx, |agent, cx| {
|
||||
if let Some(agent) = agent.sessions.get(session_id) {
|
||||
agent.thread.update(cx, |thread, _cx| thread.cancel());
|
||||
agent.thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -1022,7 +1026,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
|
|||
|
||||
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
|
||||
Task::ready(
|
||||
self.0
|
||||
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent};
|
||||
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
||||
use agent::thread_store;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{AgentProfileId, CompletionMode};
|
||||
|
@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel;
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DbThreadMetadata {
|
||||
pub id: ThreadId,
|
||||
pub id: acp::SessionId,
|
||||
#[serde(alias = "summary")]
|
||||
pub title: SharedString,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
|
@ -323,7 +323,7 @@ impl ThreadsDatabase {
|
|||
|
||||
for (id, summary, updated_at) in rows {
|
||||
threads.push(DbThreadMetadata {
|
||||
id: ThreadId(id),
|
||||
id: acp::SessionId(id),
|
||||
title: summary.into(),
|
||||
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
|
||||
});
|
||||
|
@ -333,7 +333,7 @@ impl ThreadsDatabase {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn load_thread(&self, id: ThreadId) -> Task<Result<Option<DbThread>>> {
|
||||
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
|
|
|
@ -1,17 +1,13 @@
|
|||
use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
|
||||
use agent::{ThreadId, thread_store::ThreadStore};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
use assistant_context::SavedContextMetadata;
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
|
||||
use itertools::Itertools;
|
||||
use paths::contexts_dir;
|
||||
use gpui::{SharedString, Task, prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
|
||||
use util::ResultExt as _;
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
|
||||
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
|
||||
const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
|
||||
|
@ -64,16 +60,16 @@ enum SerializedRecentOpen {
|
|||
}
|
||||
|
||||
pub struct AgentHistory {
|
||||
entries: HashMap<acp::SessionId, AcpThreadMetadata>,
|
||||
_task: Task<Result<()>>,
|
||||
entries: watch::Receiver<Option<Vec<AcpThreadMetadata>>>,
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
pub struct HistoryStore {
|
||||
agents: HashMap<AgentServerName, AgentHistory>,
|
||||
agents: HashMap<AgentServerName, AgentHistory>, // todo!() text threads
|
||||
}
|
||||
|
||||
impl HistoryStore {
|
||||
pub fn new(cx: &mut Context<Self>) -> Self {
|
||||
pub fn new(_cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
agents: HashMap::default(),
|
||||
}
|
||||
|
@ -88,33 +84,18 @@ impl HistoryStore {
|
|||
let Some(mut history) = connection.list_threads(cx) else {
|
||||
return;
|
||||
};
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
while let Some(updated_history) = history.next().await {
|
||||
dbg!(&updated_history);
|
||||
this.update(cx, |this, cx| {
|
||||
for entry in updated_history {
|
||||
let agent = this
|
||||
.agents
|
||||
.get_mut(&entry.agent)
|
||||
.context("agent not found")?;
|
||||
agent.entries.insert(entry.id.clone(), entry);
|
||||
}
|
||||
cx.notify();
|
||||
anyhow::Ok(())
|
||||
})??
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
self.agents.insert(
|
||||
agent_name,
|
||||
AgentHistory {
|
||||
entries: Default::default(),
|
||||
_task: task,
|
||||
},
|
||||
);
|
||||
let history = AgentHistory {
|
||||
entries: history.clone(),
|
||||
_task: cx.spawn(async move |this, cx| {
|
||||
while history.changed().await.is_ok() {
|
||||
this.update(cx, |_, cx| cx.notify()).ok();
|
||||
}
|
||||
}),
|
||||
};
|
||||
self.agents.insert(agent_name.clone(), history);
|
||||
}
|
||||
|
||||
pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
|
||||
pub fn entries(&mut self, _cx: &mut Context<Self>) -> Vec<HistoryEntry> {
|
||||
let mut history_entries = Vec::new();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
|
@ -124,9 +105,8 @@ impl HistoryStore {
|
|||
|
||||
history_entries.extend(
|
||||
self.agents
|
||||
.values()
|
||||
.flat_map(|agent| agent.entries.values())
|
||||
.cloned()
|
||||
.values_mut()
|
||||
.flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?")
|
||||
.map(HistoryEntry::Thread),
|
||||
);
|
||||
// todo!() include the text threads in here.
|
||||
|
@ -135,7 +115,7 @@ impl HistoryStore {
|
|||
history_entries
|
||||
}
|
||||
|
||||
pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
|
||||
pub fn recent_entries(&mut self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
|
||||
self.entries(cx).into_iter().take(limit).collect()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -938,7 +938,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
|||
|
||||
// Cancel the current send and ensure that the event stream is closed, even
|
||||
// if one of the tools is still running.
|
||||
thread.update(cx, |thread, _cx| thread.cancel());
|
||||
thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||
let events = events.collect::<Vec<_>>().await;
|
||||
let last_event = events.last();
|
||||
assert!(
|
||||
|
@ -1113,7 +1113,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
|||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _cx| thread.truncate(message_id))
|
||||
.update(cx, |thread, cx| thread.truncate(message_id, cx))
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
|
|
|
@ -802,16 +802,18 @@ impl Thread {
|
|||
&self.model
|
||||
}
|
||||
|
||||
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
|
||||
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
|
||||
self.model = model;
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
pub fn completion_mode(&self) -> CompletionMode {
|
||||
self.completion_mode
|
||||
}
|
||||
|
||||
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
|
||||
pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
|
||||
self.completion_mode = mode;
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
|
@ -839,21 +841,22 @@ impl Thread {
|
|||
self.profile_id = profile_id;
|
||||
}
|
||||
|
||||
pub fn cancel(&mut self) {
|
||||
pub fn cancel(&mut self, cx: &mut Context<Self>) {
|
||||
if let Some(running_turn) = self.running_turn.take() {
|
||||
running_turn.cancel();
|
||||
}
|
||||
self.flush_pending_message();
|
||||
self.flush_pending_message(cx);
|
||||
}
|
||||
|
||||
pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
|
||||
self.cancel();
|
||||
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
|
||||
self.cancel(cx);
|
||||
let Some(position) = self.messages.iter().position(
|
||||
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
|
||||
) else {
|
||||
return Err(anyhow!("Message not found"));
|
||||
};
|
||||
self.messages.truncate(position);
|
||||
cx.notify();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -900,7 +903,7 @@ impl Thread {
|
|||
}
|
||||
|
||||
fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
|
||||
self.cancel();
|
||||
self.cancel(cx);
|
||||
|
||||
let model = self.model.clone();
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
|
||||
|
@ -938,8 +941,8 @@ impl Thread {
|
|||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.update(cx, |this, cx| {
|
||||
this.flush_pending_message(cx);
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
return Ok(());
|
||||
|
@ -991,7 +994,7 @@ impl Thread {
|
|||
log::info!("No tool uses found, completing turn");
|
||||
return Ok(());
|
||||
} else {
|
||||
this.update(cx, |this, _| this.flush_pending_message())?;
|
||||
this.update(cx, |this, cx| this.flush_pending_message(cx))?;
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
}
|
||||
|
@ -1005,8 +1008,8 @@ impl Thread {
|
|||
log::info!("Turn execution completed successfully");
|
||||
}
|
||||
|
||||
this.update(cx, |this, _| {
|
||||
this.flush_pending_message();
|
||||
this.update(cx, |this, cx| {
|
||||
this.flush_pending_message(cx);
|
||||
this.running_turn.take();
|
||||
})
|
||||
.ok();
|
||||
|
@ -1046,7 +1049,7 @@ impl Thread {
|
|||
|
||||
match event {
|
||||
StartMessage { .. } => {
|
||||
self.flush_pending_message();
|
||||
self.flush_pending_message(cx);
|
||||
self.pending_message = Some(AgentMessage::default());
|
||||
}
|
||||
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
|
||||
|
@ -1255,7 +1258,7 @@ impl Thread {
|
|||
self.pending_message.get_or_insert_default()
|
||||
}
|
||||
|
||||
fn flush_pending_message(&mut self) {
|
||||
fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(mut message) = self.pending_message.take() else {
|
||||
return;
|
||||
};
|
||||
|
@ -1280,6 +1283,7 @@ impl Thread {
|
|||
}
|
||||
|
||||
self.messages.push(Message::Agent(message));
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
pub(crate) fn build_completion_request(
|
||||
|
|
|
@ -2487,12 +2487,15 @@ impl AcpThreadView {
|
|||
return;
|
||||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
let current_mode = thread.completion_mode();
|
||||
thread.set_completion_mode(match current_mode {
|
||||
CompletionMode::Burn => CompletionMode::Normal,
|
||||
CompletionMode::Normal => CompletionMode::Burn,
|
||||
});
|
||||
thread.set_completion_mode(
|
||||
match current_mode {
|
||||
CompletionMode::Burn => CompletionMode::Normal,
|
||||
CompletionMode::Normal => CompletionMode::Burn,
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -3274,8 +3277,8 @@ impl AcpThreadView {
|
|||
.tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
|
||||
.on_click({
|
||||
cx.listener(move |this, _, _window, cx| {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.set_completion_mode(CompletionMode::Burn);
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(CompletionMode::Burn, cx);
|
||||
});
|
||||
this.resume_chat(cx);
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue