WIP
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
This commit is contained in:
parent
205b1371aa
commit
d83210d978
8 changed files with 319 additions and 104 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -211,6 +211,7 @@ dependencies = [
|
|||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"git",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
|
|
|
@ -27,7 +27,10 @@ pub trait AgentConnection {
|
|||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn list_threads(&self, _cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
|
||||
fn list_threads(
|
||||
&self,
|
||||
_cx: &mut App,
|
||||
) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
|
||||
return None;
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ collections.workspace = true
|
|||
context_server.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
git.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
html_to_markdown.workspace = true
|
||||
|
@ -72,6 +73,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
|
|||
editor = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
git = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language = { workspace = true, "features" = ["test-support"] }
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||
UserMessageContent, WebSearchTool, templates::Templates,
|
||||
};
|
||||
use crate::{DbThread, ThreadId, ThreadsDatabase};
|
||||
use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id};
|
||||
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
|
@ -44,6 +44,8 @@ const RULES_FILE_NAMES: [&'static str; 9] = [
|
|||
"GEMINI.md",
|
||||
];
|
||||
|
||||
const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500);
|
||||
|
||||
pub struct RulesLoadingError {
|
||||
pub message: SharedString,
|
||||
}
|
||||
|
@ -54,7 +56,8 @@ struct Session {
|
|||
thread: Entity<Thread>,
|
||||
/// The ACP thread that handles protocol communication
|
||||
acp_thread: WeakEntity<acp_thread::AcpThread>,
|
||||
_subscription: Subscription,
|
||||
save_task: Task<Result<()>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub struct LanguageModels {
|
||||
|
@ -169,8 +172,9 @@ pub struct NativeAgent {
|
|||
models: LanguageModels,
|
||||
project: Entity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_database: Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
|
||||
history_listeners: Vec<UnboundedSender<Vec<AcpThreadMetadata>>>,
|
||||
thread_database: Arc<ThreadsDatabase>,
|
||||
history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
|
||||
load_history: Task<Result<()>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
@ -189,6 +193,11 @@ impl NativeAgent {
|
|||
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
|
||||
.await;
|
||||
|
||||
let thread_database = cx
|
||||
.update(|cx| ThreadsDatabase::connect(cx))?
|
||||
.await
|
||||
.map_err(|e| anyhow!(e))?;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut subscriptions = vec![
|
||||
cx.subscribe(&project, Self::handle_project_event),
|
||||
|
@ -203,7 +212,7 @@ impl NativeAgent {
|
|||
|
||||
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
|
||||
watch::channel(());
|
||||
Self {
|
||||
let this = Self {
|
||||
sessions: HashMap::new(),
|
||||
project_context: Rc::new(RefCell::new(project_context)),
|
||||
project_context_needs_refresh: project_context_needs_refresh_tx,
|
||||
|
@ -213,18 +222,85 @@ impl NativeAgent {
|
|||
context_server_registry: cx.new(|cx| {
|
||||
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
|
||||
}),
|
||||
thread_database: ThreadsDatabase::connect(cx),
|
||||
thread_database,
|
||||
templates,
|
||||
models: LanguageModels::new(cx),
|
||||
project,
|
||||
prompt_store,
|
||||
fs,
|
||||
history_listeners: Vec::new(),
|
||||
history: watch::channel(None).0,
|
||||
load_history: Task::ready(Ok(())),
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
};
|
||||
this.reload_history(cx);
|
||||
this
|
||||
})
|
||||
}
|
||||
|
||||
pub fn insert_session(
|
||||
&mut self,
|
||||
thread: Entity<Thread>,
|
||||
acp_thread: Entity<AcpThread>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let id = thread.read(cx).id().clone();
|
||||
self.sessions.insert(
|
||||
id,
|
||||
Session {
|
||||
thread: thread.clone(),
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
save_task: Task::ready(()),
|
||||
_subscriptions: vec![
|
||||
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
cx.observe(&thread, |this, thread, cx| {
|
||||
this.save_thread(thread.clone(), cx)
|
||||
}),
|
||||
],
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
|
||||
let id = thread.read(cx).id().clone();
|
||||
let Some(session) = self.sessions.get_mut(&id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let thread = thread.downgrade();
|
||||
let thread_database = self.thread_database.clone();
|
||||
session.save_task = cx.spawn(async move |this, cx| {
|
||||
cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
|
||||
let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await;
|
||||
thread_database.save_thread(id, db_thread).await?;
|
||||
this.update(cx, |this, cx| this.reload_history(cx))?;
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
fn reload_history(&mut self, cx: &mut Context<Self>) {
|
||||
let thread_database = self.thread_database.clone();
|
||||
self.load_history = cx.spawn(async move |this, cx| {
|
||||
let results = cx
|
||||
.background_spawn(async move {
|
||||
let results = thread_database.list_threads().await?;
|
||||
Ok(results
|
||||
.into_iter()
|
||||
.map(|thread| AcpThreadMetadata {
|
||||
agent: NATIVE_AGENT_SERVER_NAME.clone(),
|
||||
id: thread.id.into(),
|
||||
title: thread.title,
|
||||
updated_at: thread.updated_at,
|
||||
})
|
||||
.collect())
|
||||
})
|
||||
.await?;
|
||||
this.update(cx, |this, cx| this.history.send(Some(results)))?;
|
||||
anyhow::Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
pub fn models(&self) -> &LanguageModels {
|
||||
&self.models
|
||||
}
|
||||
|
@ -699,7 +775,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
log::debug!("Starting thread creation in async context");
|
||||
|
||||
// Generate session ID
|
||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||
let session_id = generate_session_id();
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
|
||||
// Create AcpThread
|
||||
|
@ -743,6 +819,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
session_id.clone(),
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
agent.context_server_registry.clone(),
|
||||
|
@ -761,16 +838,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
agent.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread,
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
},
|
||||
);
|
||||
agent.insert_session(thread, acp_thread.clone(), cx)
|
||||
})?;
|
||||
|
||||
Ok(acp_thread)
|
||||
|
@ -785,35 +853,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn list_threads(&self, cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
|
||||
let (mut tx, rx) = futures::channel::mpsc::unbounded();
|
||||
let database = self.0.update(cx, |this, _| {
|
||||
this.history_listeners.push(tx.clone());
|
||||
this.thread_database.clone()
|
||||
});
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
dbg!("listing!");
|
||||
let database = database.await.map_err(|e| anyhow!(e))?;
|
||||
let results = database.list_threads().await?;
|
||||
|
||||
dbg!(&results);
|
||||
tx.send(
|
||||
results
|
||||
.into_iter()
|
||||
.map(|thread| AcpThreadMetadata {
|
||||
agent: NATIVE_AGENT_SERVER_NAME.clone(),
|
||||
id: thread.id.into(),
|
||||
title: thread.title,
|
||||
updated_at: thread.updated_at,
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
.await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
Some(rx)
|
||||
fn list_threads(
|
||||
&self,
|
||||
cx: &mut App,
|
||||
) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
|
||||
Some(self.0.read(cx).history.receiver())
|
||||
}
|
||||
|
||||
fn load_thread(
|
||||
|
@ -890,16 +934,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
agent.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread: thread.clone(),
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
},
|
||||
);
|
||||
agent.insert_session(session_id, thread, acp_thread, cx)
|
||||
})?;
|
||||
|
||||
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
|
||||
|
|
|
@ -15,3 +15,9 @@ pub use native_agent_server::NativeAgentServer;
|
|||
pub use templates::*;
|
||||
pub use thread::*;
|
||||
pub use tools::*;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
|
||||
pub fn generate_session_id() -> acp::SessionId {
|
||||
acp::SessionId(uuid::Uuid::new_v4().to_string().into())
|
||||
}
|
||||
|
|
|
@ -709,9 +709,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
|||
);
|
||||
}
|
||||
|
||||
async fn expect_tool_call(
|
||||
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||
) -> acp::ToolCall {
|
||||
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
|
||||
let event = events
|
||||
.next()
|
||||
.await
|
||||
|
@ -1501,6 +1499,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
project_context.clone(),
|
||||
context_server_registry,
|
||||
|
|
|
@ -1,25 +1,35 @@
|
|||
use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
|
||||
use crate::{
|
||||
ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates,
|
||||
};
|
||||
use acp_thread::{MentionUri, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||
use collections::IndexMap;
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
FutureExt,
|
||||
channel::{mpsc, oneshot},
|
||||
future::Shared,
|
||||
stream::FuturesUnordered,
|
||||
};
|
||||
use git::repository::DiffType;
|
||||
use gpui::{App, AppContext, Context, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::{
|
||||
Project,
|
||||
git_store::{GitStore, RepositoryState},
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use schemars::{JsonSchema, Schema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -32,41 +42,6 @@ use uuid::Uuid;
|
|||
|
||||
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
|
||||
|
||||
#[derive(
|
||||
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
|
||||
)]
|
||||
pub struct ThreadId(pub(crate) Arc<str>);
|
||||
|
||||
impl ThreadId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4().to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ThreadId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for ThreadId {
|
||||
fn from(value: &str) -> Self {
|
||||
Self(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<acp::SessionId> for ThreadId {
|
||||
fn from(value: acp::SessionId) -> Self {
|
||||
Self(value.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ThreadId> for acp::SessionId {
|
||||
fn from(value: ThreadId) -> Self {
|
||||
Self(value.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// The ID of the user prompt that initiated a request.
|
||||
///
|
||||
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
|
||||
|
@ -461,9 +436,28 @@ pub struct ToolCallAuthorization {
|
|||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
enum ThreadTitle {
|
||||
None,
|
||||
Pending(Task<()>),
|
||||
Done(Result<SharedString>),
|
||||
}
|
||||
|
||||
impl ThreadTitle {
|
||||
pub fn unwrap_or_default(&self) -> SharedString {
|
||||
if let ThreadTitle::Done(Ok(title)) = self {
|
||||
title.clone()
|
||||
} else {
|
||||
"New Thread".into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
id: ThreadId,
|
||||
id: acp::SessionId,
|
||||
prompt_id: PromptId,
|
||||
updated_at: DateTime<Utc>,
|
||||
title: ThreadTitle,
|
||||
summary: DetailedSummaryState,
|
||||
messages: Vec<Message>,
|
||||
completion_mode: CompletionMode,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
|
@ -473,6 +467,9 @@ pub struct Thread {
|
|||
pending_message: Option<AgentMessage>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
tool_use_limit_reached: bool,
|
||||
request_token_usage: Vec<TokenUsage>,
|
||||
cumulative_token_usage: TokenUsage,
|
||||
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
profile_id: AgentProfileId,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
|
@ -484,6 +481,7 @@ pub struct Thread {
|
|||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
id: acp::SessionId,
|
||||
project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
|
@ -494,14 +492,25 @@ impl Thread {
|
|||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
id,
|
||||
prompt_id: PromptId::new(),
|
||||
updated_at: Utc::now(),
|
||||
title: ThreadTitle::None,
|
||||
summary: DetailedSummaryState::default(),
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
running_turn: None,
|
||||
pending_message: None,
|
||||
tools: BTreeMap::default(),
|
||||
tool_use_limit_reached: false,
|
||||
request_token_usage: Vec::new(),
|
||||
cumulative_token_usage: TokenUsage::default(),
|
||||
initial_project_snapshot: {
|
||||
let project_snapshot = Self::project_snapshot(project.clone(), cx);
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { Some(project_snapshot.await) })
|
||||
.shared()
|
||||
},
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
project_context,
|
||||
|
@ -512,8 +521,12 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &acp::SessionId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
pub fn from_db(
|
||||
id: ThreadId,
|
||||
id: acp::SessionId,
|
||||
db_thread: DbThread,
|
||||
project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
|
@ -529,12 +542,17 @@ impl Thread {
|
|||
Self {
|
||||
id,
|
||||
prompt_id: PromptId::new(),
|
||||
title: ThreadTitle::Done(Ok(db_thread.title.clone())),
|
||||
summary: db_thread.summary,
|
||||
messages: db_thread.messages,
|
||||
completion_mode: CompletionMode::Normal,
|
||||
running_turn: None,
|
||||
pending_message: None,
|
||||
tools: BTreeMap::default(),
|
||||
tool_use_limit_reached: false,
|
||||
request_token_usage: db_thread.request_token_usage.clone(),
|
||||
cumulative_token_usage: db_thread.cumulative_token_usage.clone(),
|
||||
initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
project_context,
|
||||
|
@ -542,9 +560,35 @@ impl Thread {
|
|||
model,
|
||||
project,
|
||||
action_log,
|
||||
updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_db(&self, cx: &App) -> Task<DbThread> {
|
||||
let initial_project_snapshot = self.initial_project_snapshot.clone();
|
||||
let mut thread = DbThread {
|
||||
title: self.title.unwrap_or_default(),
|
||||
messages: self.messages.clone(),
|
||||
updated_at: self.updated_at.clone(),
|
||||
summary: self.summary.clone(),
|
||||
initial_project_snapshot: None,
|
||||
cumulative_token_usage: self.cumulative_token_usage.clone(),
|
||||
request_token_usage: self.request_token_usage.clone(),
|
||||
model: Some(DbLanguageModel {
|
||||
provider: self.model.provider_id().to_string(),
|
||||
model: self.model.name().0.to_string(),
|
||||
}),
|
||||
completion_mode: Some(self.completion_mode.into()),
|
||||
profile: Some(self.profile_id.clone()),
|
||||
};
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let initial_project_snapshot = initial_project_snapshot.await;
|
||||
thread.initial_project_snapshot = initial_project_snapshot;
|
||||
thread
|
||||
})
|
||||
}
|
||||
|
||||
pub fn replay(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -630,6 +674,122 @@ impl Thread {
|
|||
);
|
||||
}
|
||||
|
||||
/// Create a snapshot of the current project state including git information and unsaved buffers.
|
||||
fn project_snapshot(
|
||||
project: Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Arc<agent::thread::ProjectSnapshot>> {
|
||||
let git_store = project.read(cx).git_store().clone();
|
||||
let worktree_snapshots: Vec<_> = project
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
|
||||
.collect();
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
|
||||
|
||||
let mut unsaved_buffers = Vec::new();
|
||||
cx.update(|app_cx| {
|
||||
let buffer_store = project.read(app_cx).buffer_store();
|
||||
for buffer_handle in buffer_store.read(app_cx).buffers() {
|
||||
let buffer = buffer_handle.read(app_cx);
|
||||
if buffer.is_dirty() {
|
||||
if let Some(file) = buffer.file() {
|
||||
let path = file.path().to_string_lossy().to_string();
|
||||
unsaved_buffers.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
|
||||
Arc::new(ProjectSnapshot {
|
||||
worktree_snapshots,
|
||||
unsaved_buffer_paths: unsaved_buffers,
|
||||
timestamp: Utc::now(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn worktree_snapshot(
|
||||
worktree: Entity<project::Worktree>,
|
||||
git_store: Entity<GitStore>,
|
||||
cx: &App,
|
||||
) -> Task<agent::thread::WorktreeSnapshot> {
|
||||
cx.spawn(async move |cx| {
|
||||
// Get worktree path and snapshot
|
||||
let worktree_info = cx.update(|app_cx| {
|
||||
let worktree = worktree.read(app_cx);
|
||||
let path = worktree.abs_path().to_string_lossy().to_string();
|
||||
let snapshot = worktree.snapshot();
|
||||
(path, snapshot)
|
||||
});
|
||||
|
||||
let Ok((worktree_path, _snapshot)) = worktree_info else {
|
||||
return WorktreeSnapshot {
|
||||
worktree_path: String::new(),
|
||||
git_state: None,
|
||||
};
|
||||
};
|
||||
|
||||
let git_state = git_store
|
||||
.update(cx, |git_store, cx| {
|
||||
git_store
|
||||
.repositories()
|
||||
.values()
|
||||
.find(|repo| {
|
||||
repo.read(cx)
|
||||
.abs_path_to_repo_path(&worktree.read(cx).abs_path())
|
||||
.is_some()
|
||||
})
|
||||
.cloned()
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|repo| {
|
||||
repo.update(cx, |repo, _| {
|
||||
let current_branch =
|
||||
repo.branch.as_ref().map(|branch| branch.name().to_owned());
|
||||
repo.send_job(None, |state, _| async move {
|
||||
let RepositoryState::Local { backend, .. } = state else {
|
||||
return GitState {
|
||||
remote_url: None,
|
||||
head_sha: None,
|
||||
current_branch,
|
||||
diff: None,
|
||||
};
|
||||
};
|
||||
|
||||
let remote_url = backend.remote_url("origin");
|
||||
let head_sha = backend.head_sha().await;
|
||||
let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
|
||||
|
||||
GitState {
|
||||
remote_url,
|
||||
head_sha,
|
||||
current_branch,
|
||||
diff,
|
||||
}
|
||||
})
|
||||
})
|
||||
});
|
||||
|
||||
let git_state = match git_state {
|
||||
Some(git_state) => match git_state.ok() {
|
||||
Some(git_state) => git_state.await.ok(),
|
||||
None => None,
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
WorktreeSnapshot {
|
||||
worktree_path,
|
||||
git_state,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn project(&self) -> &Entity<Project> {
|
||||
&self.project
|
||||
}
|
||||
|
|
|
@ -522,7 +522,7 @@ fn resolve_path(
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ContextServerRegistry, Templates};
|
||||
use crate::{ContextServerRegistry, Templates, generate_session_id};
|
||||
use action_log::ActionLog;
|
||||
use client::TelemetrySettings;
|
||||
use fs::Fs;
|
||||
|
@ -547,6 +547,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
|
@ -748,6 +749,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
|
@ -890,6 +892,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
|
@ -1019,6 +1022,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
|
@ -1157,6 +1161,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
|
@ -1267,6 +1272,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
|
@ -1349,6 +1355,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
|
@ -1434,6 +1441,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
|
@ -1516,6 +1524,7 @@ mod tests {
|
|||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue