Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-18 17:30:12 +02:00
parent 205b1371aa
commit d83210d978
8 changed files with 319 additions and 104 deletions

1
Cargo.lock generated
View file

@ -211,6 +211,7 @@ dependencies = [
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"git",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",

View file

@ -27,7 +27,10 @@ pub trait AgentConnection {
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>>;
fn list_threads(&self, _cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
fn list_threads(
&self,
_cx: &mut App,
) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
return None;
}

View file

@ -27,6 +27,7 @@ collections.workspace = true
context_server.workspace = true
fs.workspace = true
futures.workspace = true
git.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
@ -72,6 +73,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }

View file

@ -5,7 +5,7 @@ use crate::{
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
UserMessageContent, WebSearchTool, templates::Templates,
};
use crate::{DbThread, ThreadId, ThreadsDatabase};
use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id};
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
@ -44,6 +44,8 @@ const RULES_FILE_NAMES: [&'static str; 9] = [
"GEMINI.md",
];
const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500);
pub struct RulesLoadingError {
pub message: SharedString,
}
@ -54,7 +56,8 @@ struct Session {
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication
acp_thread: WeakEntity<acp_thread::AcpThread>,
_subscription: Subscription,
save_task: Task<Result<()>>,
_subscriptions: Vec<Subscription>,
}
pub struct LanguageModels {
@ -169,8 +172,9 @@ pub struct NativeAgent {
models: LanguageModels,
project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
thread_database: Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
history_listeners: Vec<UnboundedSender<Vec<AcpThreadMetadata>>>,
thread_database: Arc<ThreadsDatabase>,
history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
load_history: Task<Result<()>>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
}
@ -189,6 +193,11 @@ impl NativeAgent {
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
.await;
let thread_database = cx
.update(|cx| ThreadsDatabase::connect(cx))?
.await
.map_err(|e| anyhow!(e))?;
cx.new(|cx| {
let mut subscriptions = vec![
cx.subscribe(&project, Self::handle_project_event),
@ -203,7 +212,7 @@ impl NativeAgent {
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(());
Self {
let this = Self {
sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)),
project_context_needs_refresh: project_context_needs_refresh_tx,
@ -213,18 +222,85 @@ impl NativeAgent {
context_server_registry: cx.new(|cx| {
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
}),
thread_database: ThreadsDatabase::connect(cx),
thread_database,
templates,
models: LanguageModels::new(cx),
project,
prompt_store,
fs,
history_listeners: Vec::new(),
history: watch::channel(None).0,
load_history: Task::ready(Ok(())),
_subscriptions: subscriptions,
}
};
this.reload_history(cx);
this
})
}
pub fn insert_session(
&mut self,
thread: Entity<Thread>,
acp_thread: Entity<AcpThread>,
cx: &mut Context<Self>,
) {
let id = thread.read(cx).id().clone();
self.sessions.insert(
id,
Session {
thread: thread.clone(),
acp_thread: acp_thread.downgrade(),
save_task: Task::ready(()),
_subscriptions: vec![
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
cx.observe(&thread, |this, thread, cx| {
this.save_thread(thread.clone(), cx)
}),
],
},
);
}
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
let id = thread.read(cx).id().clone();
let Some(session) = self.sessions.get_mut(&id) else {
return;
};
let thread = thread.downgrade();
let thread_database = self.thread_database.clone();
session.save_task = cx.spawn(async move |this, cx| {
cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await;
thread_database.save_thread(id, db_thread).await?;
this.update(cx, |this, cx| this.reload_history(cx))?;
Ok(())
});
}
fn reload_history(&mut self, cx: &mut Context<Self>) {
let thread_database = self.thread_database.clone();
self.load_history = cx.spawn(async move |this, cx| {
let results = cx
.background_spawn(async move {
let results = thread_database.list_threads().await?;
Ok(results
.into_iter()
.map(|thread| AcpThreadMetadata {
agent: NATIVE_AGENT_SERVER_NAME.clone(),
id: thread.id.into(),
title: thread.title,
updated_at: thread.updated_at,
})
.collect())
})
.await?;
this.update(cx, |this, cx| this.history.send(Some(results)))?;
anyhow::Ok(())
});
}
pub fn models(&self) -> &LanguageModels {
&self.models
}
@ -699,7 +775,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::debug!("Starting thread creation in async context");
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
let session_id = generate_session_id();
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
@ -743,6 +819,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let thread = cx.new(|cx| {
let mut thread = Thread::new(
session_id.clone(),
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
@ -761,16 +838,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
session_id,
Session {
thread,
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
},
);
agent.insert_session(thread, acp_thread.clone(), cx)
})?;
Ok(acp_thread)
@ -785,35 +853,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
Task::ready(Ok(()))
}
fn list_threads(&self, cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
let (mut tx, rx) = futures::channel::mpsc::unbounded();
let database = self.0.update(cx, |this, _| {
this.history_listeners.push(tx.clone());
this.thread_database.clone()
});
cx.background_executor()
.spawn(async move {
dbg!("listing!");
let database = database.await.map_err(|e| anyhow!(e))?;
let results = database.list_threads().await?;
dbg!(&results);
tx.send(
results
.into_iter()
.map(|thread| AcpThreadMetadata {
agent: NATIVE_AGENT_SERVER_NAME.clone(),
id: thread.id.into(),
title: thread.title,
updated_at: thread.updated_at,
})
.collect(),
)
.await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
Some(rx)
fn list_threads(
&self,
cx: &mut App,
) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
Some(self.0.read(cx).history.receiver())
}
fn load_thread(
@ -890,16 +934,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
session_id,
Session {
thread: thread.clone(),
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
},
);
agent.insert_session(session_id, thread, acp_thread, cx)
})?;
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;

View file

@ -15,3 +15,9 @@ pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*;
pub use tools::*;
use agent_client_protocol as acp;
pub fn generate_session_id() -> acp::SessionId {
acp::SessionId(uuid::Uuid::new_v4().to_string().into())
}

View file

@ -709,9 +709,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
);
}
async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCall {
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
let event = events
.next()
.await
@ -1501,6 +1499,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project,
project_context.clone(),
context_server_registry,

View file

@ -1,25 +1,35 @@
use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
use crate::{
ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates,
};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
use collections::IndexMap;
use fs::Fs;
use futures::{
FutureExt,
channel::{mpsc, oneshot},
future::Shared,
stream::FuturesUnordered,
};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
};
use project::{
Project,
git_store::{GitStore, RepositoryState},
};
use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
@ -32,41 +42,6 @@ use uuid::Uuid;
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(pub(crate) Arc<str>);
impl ThreadId {
pub fn new() -> Self {
Self(Uuid::new_v4().to_string().into())
}
}
impl std::fmt::Display for ThreadId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<&str> for ThreadId {
fn from(value: &str) -> Self {
Self(value.into())
}
}
impl From<acp::SessionId> for ThreadId {
fn from(value: acp::SessionId) -> Self {
Self(value.0)
}
}
impl From<ThreadId> for acp::SessionId {
fn from(value: ThreadId) -> Self {
Self(value.0)
}
}
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
@ -461,9 +436,28 @@ pub struct ToolCallAuthorization {
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
enum ThreadTitle {
None,
Pending(Task<()>),
Done(Result<SharedString>),
}
impl ThreadTitle {
pub fn unwrap_or_default(&self) -> SharedString {
if let ThreadTitle::Done(Ok(title)) = self {
title.clone()
} else {
"New Thread".into()
}
}
}
pub struct Thread {
id: ThreadId,
id: acp::SessionId,
prompt_id: PromptId,
updated_at: DateTime<Utc>,
title: ThreadTitle,
summary: DetailedSummaryState,
messages: Vec<Message>,
completion_mode: CompletionMode,
/// Holds the task that handles agent interaction until the end of the turn.
@ -473,6 +467,9 @@ pub struct Thread {
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool,
request_token_usage: Vec<TokenUsage>,
cumulative_token_usage: TokenUsage,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
context_server_registry: Entity<ContextServerRegistry>,
profile_id: AgentProfileId,
project_context: Rc<RefCell<ProjectContext>>,
@ -484,6 +481,7 @@ pub struct Thread {
impl Thread {
pub fn new(
id: acp::SessionId,
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
context_server_registry: Entity<ContextServerRegistry>,
@ -494,14 +492,25 @@ impl Thread {
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
Self {
id: ThreadId::new(),
id,
prompt_id: PromptId::new(),
updated_at: Utc::now(),
title: ThreadTitle::None,
summary: DetailedSummaryState::default(),
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
request_token_usage: Vec::new(),
cumulative_token_usage: TokenUsage::default(),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project.clone(), cx);
cx.foreground_executor()
.spawn(async move { Some(project_snapshot.await) })
.shared()
},
context_server_registry,
profile_id,
project_context,
@ -512,8 +521,12 @@ impl Thread {
}
}
pub fn id(&self) -> &acp::SessionId {
&self.id
}
pub fn from_db(
id: ThreadId,
id: acp::SessionId,
db_thread: DbThread,
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
@ -529,12 +542,17 @@ impl Thread {
Self {
id,
prompt_id: PromptId::new(),
title: ThreadTitle::Done(Ok(db_thread.title.clone())),
summary: db_thread.summary,
messages: db_thread.messages,
completion_mode: CompletionMode::Normal,
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
request_token_usage: db_thread.request_token_usage.clone(),
cumulative_token_usage: db_thread.cumulative_token_usage.clone(),
initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
context_server_registry,
profile_id,
project_context,
@ -542,9 +560,35 @@ impl Thread {
model,
project,
action_log,
updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
}
}
pub fn to_db(&self, cx: &App) -> Task<DbThread> {
let initial_project_snapshot = self.initial_project_snapshot.clone();
let mut thread = DbThread {
title: self.title.unwrap_or_default(),
messages: self.messages.clone(),
updated_at: self.updated_at.clone(),
summary: self.summary.clone(),
initial_project_snapshot: None,
cumulative_token_usage: self.cumulative_token_usage.clone(),
request_token_usage: self.request_token_usage.clone(),
model: Some(DbLanguageModel {
provider: self.model.provider_id().to_string(),
model: self.model.name().0.to_string(),
}),
completion_mode: Some(self.completion_mode.into()),
profile: Some(self.profile_id.clone()),
};
cx.background_spawn(async move {
let initial_project_snapshot = initial_project_snapshot.await;
thread.initial_project_snapshot = initial_project_snapshot;
thread
})
}
pub fn replay(
&mut self,
cx: &mut Context<Self>,
@ -630,6 +674,122 @@ impl Thread {
);
}
/// Create a snapshot of the current project state including git information and unsaved buffers.
fn project_snapshot(
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Arc<agent::thread::ProjectSnapshot>> {
let git_store = project.read(cx).git_store().clone();
let worktree_snapshots: Vec<_> = project
.read(cx)
.visible_worktrees(cx)
.map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
.collect();
cx.spawn(async move |_, cx| {
let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
let mut unsaved_buffers = Vec::new();
cx.update(|app_cx| {
let buffer_store = project.read(app_cx).buffer_store();
for buffer_handle in buffer_store.read(app_cx).buffers() {
let buffer = buffer_handle.read(app_cx);
if buffer.is_dirty() {
if let Some(file) = buffer.file() {
let path = file.path().to_string_lossy().to_string();
unsaved_buffers.push(path);
}
}
}
})
.ok();
Arc::new(ProjectSnapshot {
worktree_snapshots,
unsaved_buffer_paths: unsaved_buffers,
timestamp: Utc::now(),
})
})
}
fn worktree_snapshot(
worktree: Entity<project::Worktree>,
git_store: Entity<GitStore>,
cx: &App,
) -> Task<agent::thread::WorktreeSnapshot> {
cx.spawn(async move |cx| {
// Get worktree path and snapshot
let worktree_info = cx.update(|app_cx| {
let worktree = worktree.read(app_cx);
let path = worktree.abs_path().to_string_lossy().to_string();
let snapshot = worktree.snapshot();
(path, snapshot)
});
let Ok((worktree_path, _snapshot)) = worktree_info else {
return WorktreeSnapshot {
worktree_path: String::new(),
git_state: None,
};
};
let git_state = git_store
.update(cx, |git_store, cx| {
git_store
.repositories()
.values()
.find(|repo| {
repo.read(cx)
.abs_path_to_repo_path(&worktree.read(cx).abs_path())
.is_some()
})
.cloned()
})
.ok()
.flatten()
.map(|repo| {
repo.update(cx, |repo, _| {
let current_branch =
repo.branch.as_ref().map(|branch| branch.name().to_owned());
repo.send_job(None, |state, _| async move {
let RepositoryState::Local { backend, .. } = state else {
return GitState {
remote_url: None,
head_sha: None,
current_branch,
diff: None,
};
};
let remote_url = backend.remote_url("origin");
let head_sha = backend.head_sha().await;
let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
GitState {
remote_url,
head_sha,
current_branch,
diff,
}
})
})
});
let git_state = match git_state {
Some(git_state) => match git_state.ok() {
Some(git_state) => git_state.await.ok(),
None => None,
},
None => None,
};
WorktreeSnapshot {
worktree_path,
git_state,
}
})
}
pub fn project(&self) -> &Entity<Project> {
&self.project
}

View file

@ -522,7 +522,7 @@ fn resolve_path(
#[cfg(test)]
mod tests {
use super::*;
use crate::{ContextServerRegistry, Templates};
use crate::{ContextServerRegistry, Templates, generate_session_id};
use action_log::ActionLog;
use client::TelemetrySettings;
use fs::Fs;
@ -547,6 +547,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
@ -748,6 +749,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
@ -890,6 +892,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
@ -1019,6 +1022,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
@ -1157,6 +1161,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project,
Rc::default(),
context_server_registry,
@ -1267,6 +1272,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project.clone(),
Rc::default(),
context_server_registry.clone(),
@ -1349,6 +1355,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project.clone(),
Rc::default(),
context_server_registry.clone(),
@ -1434,6 +1441,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project.clone(),
Rc::default(),
context_server_registry.clone(),
@ -1516,6 +1524,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
generate_session_id(),
project.clone(),
Rc::default(),
context_server_registry,