Compare commits
27 commits
main
...
agent2-his
Author | SHA1 | Date | |
---|---|---|---|
![]() |
67e7d1426c | ||
![]() |
3b7ad6236d | ||
![]() |
8998cdee26 | ||
![]() |
3ed2b7691b | ||
![]() |
999449424e | ||
![]() |
8373884cdb | ||
![]() |
5d88de13da | ||
![]() |
cc196427f0 | ||
![]() |
fc076e84ca | ||
![]() |
4b1a48e4de | ||
![]() |
d83210d978 | ||
![]() |
205b1371aa | ||
![]() |
e6e23d04f8 | ||
![]() |
199256e43e | ||
![]() |
3a0e55d9b6 | ||
![]() |
5259c8692d | ||
![]() |
501e72e8f0 | ||
![]() |
a231fd3ee5 | ||
![]() |
fae5900749 | ||
![]() |
fa6c0a1a49 | ||
![]() |
eebe425c1d | ||
![]() |
1b793331b3 | ||
![]() |
e72f6f99c8 | ||
![]() |
296e3fcf69 | ||
![]() |
251baacdab | ||
![]() |
fd8ea2acfc | ||
![]() |
6b6b7e66e1 |
31 changed files with 2860 additions and 453 deletions
7
Cargo.lock
generated
7
Cargo.lock
generated
|
@ -11,6 +11,7 @@ dependencies = [
|
||||||
"agent-client-protocol",
|
"agent-client-protocol",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"buffer_diff",
|
"buffer_diff",
|
||||||
|
"chrono",
|
||||||
"collections",
|
"collections",
|
||||||
"editor",
|
"editor",
|
||||||
"env_logger 0.11.8",
|
"env_logger 0.11.8",
|
||||||
|
@ -191,10 +192,12 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"acp_thread",
|
"acp_thread",
|
||||||
"action_log",
|
"action_log",
|
||||||
|
"agent",
|
||||||
"agent-client-protocol",
|
"agent-client-protocol",
|
||||||
"agent_servers",
|
"agent_servers",
|
||||||
"agent_settings",
|
"agent_settings",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"assistant_context",
|
||||||
"assistant_tool",
|
"assistant_tool",
|
||||||
"assistant_tools",
|
"assistant_tools",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
@ -208,6 +211,7 @@ dependencies = [
|
||||||
"env_logger 0.11.8",
|
"env_logger 0.11.8",
|
||||||
"fs",
|
"fs",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
|
"git",
|
||||||
"gpui",
|
"gpui",
|
||||||
"gpui_tokio",
|
"gpui_tokio",
|
||||||
"handlebars 4.5.0",
|
"handlebars 4.5.0",
|
||||||
|
@ -221,6 +225,7 @@ dependencies = [
|
||||||
"log",
|
"log",
|
||||||
"lsp",
|
"lsp",
|
||||||
"open",
|
"open",
|
||||||
|
"parking_lot",
|
||||||
"paths",
|
"paths",
|
||||||
"portable-pty",
|
"portable-pty",
|
||||||
"pretty_assertions",
|
"pretty_assertions",
|
||||||
|
@ -233,6 +238,7 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"settings",
|
"settings",
|
||||||
"smol",
|
"smol",
|
||||||
|
"sqlez",
|
||||||
"task",
|
"task",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"terminal",
|
"terminal",
|
||||||
|
@ -249,6 +255,7 @@ dependencies = [
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
"worktree",
|
"worktree",
|
||||||
"zlog",
|
"zlog",
|
||||||
|
"zstd",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -21,6 +21,7 @@ agent-client-protocol.workspace = true
|
||||||
agent.workspace = true
|
agent.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
buffer_diff.workspace = true
|
buffer_diff.workspace = true
|
||||||
|
chrono.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
editor.workspace = true
|
editor.workspace = true
|
||||||
file_icons.workspace = true
|
file_icons.workspace = true
|
||||||
|
|
|
@ -6,11 +6,13 @@ mod terminal;
|
||||||
pub use connection::*;
|
pub use connection::*;
|
||||||
pub use diff::*;
|
pub use diff::*;
|
||||||
pub use mention::*;
|
pub use mention::*;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
pub use terminal::*;
|
pub use terminal::*;
|
||||||
|
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
use editor::Bias;
|
use editor::Bias;
|
||||||
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
|
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
|
||||||
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||||
|
@ -537,9 +539,15 @@ impl ToolCallContent {
|
||||||
acp::ToolCallContent::Content { content } => {
|
acp::ToolCallContent::Content { content } => {
|
||||||
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
|
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
|
||||||
}
|
}
|
||||||
acp::ToolCallContent::Diff { diff } => {
|
acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
|
||||||
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
|
Diff::finalized(
|
||||||
}
|
diff.path,
|
||||||
|
diff.old_text,
|
||||||
|
diff.new_text,
|
||||||
|
language_registry,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -658,6 +666,17 @@ impl PlanEntry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||||
|
pub struct AgentServerName(pub SharedString);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AcpThreadMetadata {
|
||||||
|
pub agent: AgentServerName,
|
||||||
|
pub id: acp::SessionId,
|
||||||
|
pub title: SharedString,
|
||||||
|
pub updated_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct AcpThread {
|
pub struct AcpThread {
|
||||||
title: SharedString,
|
title: SharedString,
|
||||||
entries: Vec<AgentThreadEntry>,
|
entries: Vec<AgentThreadEntry>,
|
||||||
|
@ -673,6 +692,7 @@ pub struct AcpThread {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum AcpThreadEvent {
|
pub enum AcpThreadEvent {
|
||||||
NewEntry,
|
NewEntry,
|
||||||
|
TitleUpdated,
|
||||||
EntryUpdated(usize),
|
EntryUpdated(usize),
|
||||||
EntriesRemoved(Range<usize>),
|
EntriesRemoved(Range<usize>),
|
||||||
ToolAuthorizationRequired,
|
ToolAuthorizationRequired,
|
||||||
|
@ -916,6 +936,12 @@ impl AcpThread {
|
||||||
cx.emit(AcpThreadEvent::NewEntry);
|
cx.emit(AcpThreadEvent::NewEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
|
||||||
|
self.title = title;
|
||||||
|
cx.emit(AcpThreadEvent::TitleUpdated);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn update_tool_call(
|
pub fn update_tool_call(
|
||||||
&mut self,
|
&mut self,
|
||||||
update: impl Into<ToolCallUpdate>,
|
update: impl Into<ToolCallUpdate>,
|
||||||
|
@ -1641,7 +1667,7 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
||||||
use gpui::{AsyncApp, TestAppContext, WeakEntity};
|
use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use project::{FakeFs, Fs};
|
use project::{FakeFs, Fs};
|
||||||
use rand::Rng as _;
|
use rand::Rng as _;
|
||||||
|
@ -2311,7 +2337,7 @@ mod tests {
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
_cwd: &Path,
|
_cwd: &Path,
|
||||||
cx: &mut gpui::App,
|
cx: &mut App,
|
||||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||||
let session_id = acp::SessionId(
|
let session_id = acp::SessionId(
|
||||||
rand::thread_rng()
|
rand::thread_rng()
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
use crate::AcpThread;
|
use crate::{AcpThread, AcpThreadMetadata};
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol::{self as acp};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use collections::IndexMap;
|
use collections::IndexMap;
|
||||||
|
use futures::channel::mpsc::UnboundedReceiver;
|
||||||
use gpui::{Entity, SharedString, Task};
|
use gpui::{Entity, SharedString, Task};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||||
use ui::{App, IconName};
|
use ui::{App, IconName};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct UserMessageId(Arc<str>);
|
pub struct UserMessageId(Arc<str>);
|
||||||
|
|
||||||
impl UserMessageId {
|
impl UserMessageId {
|
||||||
|
@ -62,6 +64,10 @@ pub trait AgentConnection {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn history(self: Rc<Self>) -> Option<Rc<dyn AgentHistory>> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,6 +85,18 @@ pub trait AgentSessionResume {
|
||||||
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
|
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait AgentHistory {
|
||||||
|
fn list_threads(&self, cx: &mut App) -> Task<Result<Vec<AcpThreadMetadata>>>;
|
||||||
|
fn observe_history(&self, cx: &mut App) -> UnboundedReceiver<AcpThreadMetadata>;
|
||||||
|
fn load_thread(
|
||||||
|
self: Rc<Self>,
|
||||||
|
_project: Entity<Project>,
|
||||||
|
_cwd: &Path,
|
||||||
|
_session_id: acp::SessionId,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Task<Result<Entity<AcpThread>>>;
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AuthRequired;
|
pub struct AuthRequired;
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use agent_client_protocol as acp;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||||
use editor::{MultiBuffer, PathKey};
|
use editor::{MultiBuffer, PathKey};
|
||||||
|
@ -21,17 +20,13 @@ pub enum Diff {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Diff {
|
impl Diff {
|
||||||
pub fn from_acp(
|
pub fn finalized(
|
||||||
diff: acp::Diff,
|
path: PathBuf,
|
||||||
|
old_text: Option<String>,
|
||||||
|
new_text: String,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let acp::Diff {
|
|
||||||
path,
|
|
||||||
old_text,
|
|
||||||
new_text,
|
|
||||||
} = diff;
|
|
||||||
|
|
||||||
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||||
|
|
||||||
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
||||||
|
|
|
@ -2,6 +2,7 @@ use agent::ThreadId;
|
||||||
use anyhow::{Context as _, Result, bail};
|
use anyhow::{Context as _, Result, bail};
|
||||||
use file_icons::FileIcons;
|
use file_icons::FileIcons;
|
||||||
use prompt_store::{PromptId, UserPromptId};
|
use prompt_store::{PromptId, UserPromptId};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{
|
||||||
fmt,
|
fmt,
|
||||||
ops::Range,
|
ops::Range,
|
||||||
|
@ -11,7 +12,7 @@ use std::{
|
||||||
use ui::{App, IconName, SharedString};
|
use ui::{App, IconName, SharedString};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub enum MentionUri {
|
pub enum MentionUri {
|
||||||
File {
|
File {
|
||||||
abs_path: PathBuf,
|
abs_path: PathBuf,
|
||||||
|
|
|
@ -62,7 +62,7 @@ enum SerializedRecentOpen {
|
||||||
|
|
||||||
pub struct HistoryStore {
|
pub struct HistoryStore {
|
||||||
thread_store: Entity<ThreadStore>,
|
thread_store: Entity<ThreadStore>,
|
||||||
context_store: Entity<assistant_context::ContextStore>,
|
pub context_store: Entity<assistant_context::ContextStore>,
|
||||||
recently_opened_entries: VecDeque<HistoryEntryId>,
|
recently_opened_entries: VecDeque<HistoryEntryId>,
|
||||||
_subscriptions: Vec<gpui::Subscription>,
|
_subscriptions: Vec<gpui::Subscription>,
|
||||||
_save_recently_opened_entries_task: Task<()>,
|
_save_recently_opened_entries_task: Task<()>,
|
||||||
|
|
|
@ -893,7 +893,7 @@ impl ThreadsDatabase {
|
||||||
|
|
||||||
let needs_migration_from_heed = mdb_path.exists();
|
let needs_migration_from_heed = mdb_path.exists();
|
||||||
|
|
||||||
let connection = if *ZED_STATELESS {
|
let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
|
||||||
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
|
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
|
||||||
} else {
|
} else {
|
||||||
Connection::open_file(&sqlite_path.to_string_lossy())
|
Connection::open_file(&sqlite_path.to_string_lossy())
|
||||||
|
|
|
@ -17,6 +17,7 @@ action_log.workspace = true
|
||||||
agent-client-protocol.workspace = true
|
agent-client-protocol.workspace = true
|
||||||
agent_servers.workspace = true
|
agent_servers.workspace = true
|
||||||
agent_settings.workspace = true
|
agent_settings.workspace = true
|
||||||
|
agent.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
assistant_tool.workspace = true
|
assistant_tool.workspace = true
|
||||||
assistant_tools.workspace = true
|
assistant_tools.workspace = true
|
||||||
|
@ -26,6 +27,7 @@ collections.workspace = true
|
||||||
context_server.workspace = true
|
context_server.workspace = true
|
||||||
fs.workspace = true
|
fs.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
|
git.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||||
html_to_markdown.workspace = true
|
html_to_markdown.workspace = true
|
||||||
|
@ -37,6 +39,7 @@ language_model.workspace = true
|
||||||
language_models.workspace = true
|
language_models.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
open.workspace = true
|
open.workspace = true
|
||||||
|
parking_lot.workspace = true
|
||||||
paths.workspace = true
|
paths.workspace = true
|
||||||
portable-pty.workspace = true
|
portable-pty.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
|
@ -46,6 +49,7 @@ schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
|
sqlez.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
task.workspace = true
|
task.workspace = true
|
||||||
terminal.workspace = true
|
terminal.workspace = true
|
||||||
|
@ -57,8 +61,12 @@ watch.workspace = true
|
||||||
web_search.workspace = true
|
web_search.workspace = true
|
||||||
which.workspace = true
|
which.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
zstd.workspace = true
|
||||||
|
assistant_context.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
agent = { workspace = true, "features" = ["test-support"] }
|
||||||
|
acp_thread = { workspace = true, "features" = ["test-support"] }
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
client = { workspace = true, "features" = ["test-support"] }
|
client = { workspace = true, "features" = ["test-support"] }
|
||||||
clock = { workspace = true, "features" = ["test-support"] }
|
clock = { workspace = true, "features" = ["test-support"] }
|
||||||
|
@ -66,6 +74,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
|
||||||
editor = { workspace = true, "features" = ["test-support"] }
|
editor = { workspace = true, "features" = ["test-support"] }
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
fs = { workspace = true, "features" = ["test-support"] }
|
fs = { workspace = true, "features" = ["test-support"] }
|
||||||
|
git = { workspace = true, "features" = ["test-support"] }
|
||||||
gpui = { workspace = true, "features" = ["test-support"] }
|
gpui = { workspace = true, "features" = ["test-support"] }
|
||||||
gpui_tokio.workspace = true
|
gpui_tokio.workspace = true
|
||||||
language = { workspace = true, "features" = ["test-support"] }
|
language = { workspace = true, "features" = ["test-support"] }
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
|
use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
|
||||||
use crate::{
|
use crate::{
|
||||||
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
|
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
|
||||||
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
|
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
|
||||||
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
|
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||||
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
|
UserMessageContent, WebSearchTool, templates::Templates,
|
||||||
};
|
};
|
||||||
use acp_thread::AgentModelSelector;
|
use crate::{ThreadsDatabase, generate_session_id};
|
||||||
|
use acp_thread::{AcpThread, AcpThreadMetadata, AgentHistory, AgentModelSelector};
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::AgentSettings;
|
use agent_settings::AgentSettings;
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
|
@ -15,7 +17,7 @@ use futures::{StreamExt, future};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||||
};
|
};
|
||||||
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
|
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
|
||||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||||
use prompt_store::{
|
use prompt_store::{
|
||||||
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
||||||
|
@ -27,6 +29,7 @@ use std::collections::HashMap;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
const RULES_FILE_NAMES: [&'static str; 9] = [
|
const RULES_FILE_NAMES: [&'static str; 9] = [
|
||||||
|
@ -41,6 +44,8 @@ const RULES_FILE_NAMES: [&'static str; 9] = [
|
||||||
"GEMINI.md",
|
"GEMINI.md",
|
||||||
];
|
];
|
||||||
|
|
||||||
|
const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500);
|
||||||
|
|
||||||
pub struct RulesLoadingError {
|
pub struct RulesLoadingError {
|
||||||
pub message: SharedString,
|
pub message: SharedString,
|
||||||
}
|
}
|
||||||
|
@ -51,7 +56,8 @@ struct Session {
|
||||||
thread: Entity<Thread>,
|
thread: Entity<Thread>,
|
||||||
/// The ACP thread that handles protocol communication
|
/// The ACP thread that handles protocol communication
|
||||||
acp_thread: WeakEntity<acp_thread::AcpThread>,
|
acp_thread: WeakEntity<acp_thread::AcpThread>,
|
||||||
_subscription: Subscription,
|
save_task: Task<()>,
|
||||||
|
_subscriptions: Vec<Subscription>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LanguageModels {
|
pub struct LanguageModels {
|
||||||
|
@ -166,6 +172,8 @@ pub struct NativeAgent {
|
||||||
models: LanguageModels,
|
models: LanguageModels,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
prompt_store: Option<Entity<PromptStore>>,
|
prompt_store: Option<Entity<PromptStore>>,
|
||||||
|
thread_database: Arc<ThreadsDatabase>,
|
||||||
|
history_watchers: Vec<mpsc::UnboundedSender<AcpThreadMetadata>>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
_subscriptions: Vec<Subscription>,
|
_subscriptions: Vec<Subscription>,
|
||||||
}
|
}
|
||||||
|
@ -184,6 +192,11 @@ impl NativeAgent {
|
||||||
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
|
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
let thread_database = cx
|
||||||
|
.update(|cx| ThreadsDatabase::connect(cx))?
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!(e))?;
|
||||||
|
|
||||||
cx.new(|cx| {
|
cx.new(|cx| {
|
||||||
let mut subscriptions = vec![
|
let mut subscriptions = vec![
|
||||||
cx.subscribe(&project, Self::handle_project_event),
|
cx.subscribe(&project, Self::handle_project_event),
|
||||||
|
@ -208,16 +221,87 @@ impl NativeAgent {
|
||||||
context_server_registry: cx.new(|cx| {
|
context_server_registry: cx.new(|cx| {
|
||||||
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
|
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
|
||||||
}),
|
}),
|
||||||
|
thread_database,
|
||||||
templates,
|
templates,
|
||||||
models: LanguageModels::new(cx),
|
models: LanguageModels::new(cx),
|
||||||
project,
|
project,
|
||||||
prompt_store,
|
prompt_store,
|
||||||
fs,
|
fs,
|
||||||
|
history_watchers: Vec::new(),
|
||||||
_subscriptions: subscriptions,
|
_subscriptions: subscriptions,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn insert_session(
|
||||||
|
&mut self,
|
||||||
|
thread: Entity<Thread>,
|
||||||
|
acp_thread: Entity<AcpThread>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let id = thread.read(cx).id().clone();
|
||||||
|
let weak_thread = acp_thread.downgrade();
|
||||||
|
self.sessions.insert(
|
||||||
|
id,
|
||||||
|
Session {
|
||||||
|
thread: thread.clone(),
|
||||||
|
acp_thread: weak_thread.clone(),
|
||||||
|
save_task: Task::ready(()),
|
||||||
|
_subscriptions: vec![
|
||||||
|
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||||
|
this.sessions.remove(acp_thread.session_id());
|
||||||
|
}),
|
||||||
|
cx.observe(&thread, move |this, thread, cx| {
|
||||||
|
if let Some(response_stream) =
|
||||||
|
thread.update(cx, |thread, cx| thread.generate_title_if_needed(cx))
|
||||||
|
{
|
||||||
|
NativeAgentConnection::handle_thread_events(
|
||||||
|
response_stream,
|
||||||
|
weak_thread.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
}
|
||||||
|
this.save_thread(thread.clone(), cx)
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_thread(&mut self, thread_handle: Entity<Thread>, cx: &mut Context<Self>) {
|
||||||
|
let thread = thread_handle.read(cx);
|
||||||
|
let id = thread.id().clone();
|
||||||
|
let Some(session) = self.sessions.get_mut(&id) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let thread = thread_handle.downgrade();
|
||||||
|
let thread_database = self.thread_database.clone();
|
||||||
|
session.save_task = cx.spawn(async move |this, cx| {
|
||||||
|
cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
|
||||||
|
|
||||||
|
let Some(task) = thread.update(cx, |thread, cx| thread.to_db(cx)).ok() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let db_thread = task.await;
|
||||||
|
let metadata = thread_database
|
||||||
|
.save_thread(id.clone(), db_thread)
|
||||||
|
.await
|
||||||
|
.log_err();
|
||||||
|
if let Some(metadata) = metadata {
|
||||||
|
this.update(cx, |this, _| {
|
||||||
|
for watcher in this.history_watchers.iter_mut() {
|
||||||
|
watcher
|
||||||
|
.unbounded_send(metadata.clone().to_acp(NATIVE_AGENT_SERVER_NAME))
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub fn models(&self) -> &LanguageModels {
|
pub fn models(&self) -> &LanguageModels {
|
||||||
&self.models
|
&self.models
|
||||||
}
|
}
|
||||||
|
@ -420,7 +504,7 @@ impl NativeAgent {
|
||||||
|
|
||||||
fn handle_models_updated_event(
|
fn handle_models_updated_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
_registry: Entity<LanguageModelRegistry>,
|
registry: Entity<LanguageModelRegistry>,
|
||||||
_event: &language_model::Event,
|
_event: &language_model::Event,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
|
@ -435,9 +519,14 @@ impl NativeAgent {
|
||||||
if thread.model().is_none()
|
if thread.model().is_none()
|
||||||
&& let Some(model) = default_model.clone()
|
&& let Some(model) = default_model.clone()
|
||||||
{
|
{
|
||||||
thread.set_model(model);
|
thread.set_model(model, cx);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
let summarization_model = registry
|
||||||
|
.read(cx)
|
||||||
|
.thread_summary_model()
|
||||||
|
.map(|model| model.model.clone());
|
||||||
|
thread.set_summarization_model(summarization_model, cx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -461,10 +550,7 @@ impl NativeAgentConnection {
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
f: impl 'static
|
f: impl 'static
|
||||||
+ FnOnce(
|
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
|
||||||
Entity<Thread>,
|
|
||||||
&mut App,
|
|
||||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
|
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
||||||
agent
|
agent
|
||||||
|
@ -476,10 +562,18 @@ impl NativeAgentConnection {
|
||||||
};
|
};
|
||||||
log::debug!("Found session for: {}", session_id);
|
log::debug!("Found session for: {}", session_id);
|
||||||
|
|
||||||
let mut response_stream = match f(thread, cx) {
|
let response_stream = match f(thread, cx) {
|
||||||
Ok(stream) => stream,
|
Ok(stream) => stream,
|
||||||
Err(err) => return Task::ready(Err(err)),
|
Err(err) => return Task::ready(Err(err)),
|
||||||
};
|
};
|
||||||
|
Self::handle_thread_events(response_stream, acp_thread, cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_thread_events(
|
||||||
|
mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
|
acp_thread: WeakEntity<AcpThread>,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
// Handle response stream and forward to session.acp_thread
|
// Handle response stream and forward to session.acp_thread
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
|
@ -488,7 +582,18 @@ impl NativeAgentConnection {
|
||||||
log::trace!("Received completion event: {:?}", event);
|
log::trace!("Received completion event: {:?}", event);
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::Text(text) => {
|
ThreadEvent::UserMessage(message) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
for content in message.content {
|
||||||
|
thread.push_user_content_block(
|
||||||
|
Some(message.id.clone()),
|
||||||
|
content.into(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
ThreadEvent::AgentText(text) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.push_assistant_content_block(
|
thread.push_assistant_content_block(
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
@ -500,7 +605,7 @@ impl NativeAgentConnection {
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Thinking(text) => {
|
ThreadEvent::AgentThinking(text) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.push_assistant_content_block(
|
thread.push_assistant_content_block(
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
@ -512,7 +617,7 @@ impl NativeAgentConnection {
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||||
tool_call,
|
tool_call,
|
||||||
options,
|
options,
|
||||||
response,
|
response,
|
||||||
|
@ -535,17 +640,21 @@ impl NativeAgentConnection {
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCall(tool_call) => {
|
ThreadEvent::ToolCall(tool_call) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.upsert_tool_call(tool_call, cx)
|
thread.upsert_tool_call(tool_call, cx)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
ThreadEvent::ToolCallUpdate(update) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.update_tool_call(update, cx)
|
thread.update_tool_call(update, cx)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Stop(stop_reason) => {
|
ThreadEvent::TitleUpdate(title) => {
|
||||||
|
acp_thread
|
||||||
|
.update(cx, |thread, cx| thread.update_title(title, cx))??;
|
||||||
|
}
|
||||||
|
ThreadEvent::Stop(stop_reason) => {
|
||||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||||
return Ok(acp::PromptResponse { stop_reason });
|
return Ok(acp::PromptResponse { stop_reason });
|
||||||
}
|
}
|
||||||
|
@ -564,6 +673,31 @@ impl NativeAgentConnection {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn register_tools(
|
||||||
|
thread: &mut Thread,
|
||||||
|
project: Entity<Project>,
|
||||||
|
action_log: Entity<action_log::ActionLog>,
|
||||||
|
cx: &mut Context<Thread>,
|
||||||
|
) {
|
||||||
|
let language_registry = project.read(cx).languages().clone();
|
||||||
|
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||||
|
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||||
|
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||||
|
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||||
|
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
|
||||||
|
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||||
|
thread.add_tool(FindPathTool::new(project.clone()));
|
||||||
|
thread.add_tool(GrepTool::new(project.clone()));
|
||||||
|
thread.add_tool(ListDirectoryTool::new(project.clone()));
|
||||||
|
thread.add_tool(MovePathTool::new(project.clone()));
|
||||||
|
thread.add_tool(NowTool);
|
||||||
|
thread.add_tool(OpenTool::new(project.clone()));
|
||||||
|
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||||
|
thread.add_tool(TerminalTool::new(project.clone(), cx));
|
||||||
|
thread.add_tool(ThinkingTool);
|
||||||
|
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentModelSelector for NativeAgentConnection {
|
impl AgentModelSelector for NativeAgentConnection {
|
||||||
|
@ -598,8 +732,8 @@ impl AgentModelSelector for NativeAgentConnection {
|
||||||
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
||||||
};
|
};
|
||||||
|
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.set_model(model.clone());
|
thread.set_model(model.clone(), cx);
|
||||||
});
|
});
|
||||||
|
|
||||||
update_settings_file::<AgentSettings>(
|
update_settings_file::<AgentSettings>(
|
||||||
|
@ -660,7 +794,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
log::debug!("Starting thread creation in async context");
|
log::debug!("Starting thread creation in async context");
|
||||||
|
|
||||||
// Generate session ID
|
// Generate session ID
|
||||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
let session_id = generate_session_id();
|
||||||
log::info!("Created session with ID: {}", session_id);
|
log::info!("Created session with ID: {}", session_id);
|
||||||
|
|
||||||
// Create AcpThread
|
// Create AcpThread
|
||||||
|
@ -694,32 +828,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let mut thread = Thread::new(
|
let mut thread = Thread::new(
|
||||||
|
session_id.clone(),
|
||||||
project.clone(),
|
project.clone(),
|
||||||
agent.project_context.clone(),
|
agent.project_context.clone(),
|
||||||
agent.context_server_registry.clone(),
|
agent.context_server_registry.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
agent.templates.clone(),
|
agent.templates.clone(),
|
||||||
default_model,
|
default_model,
|
||||||
|
summarization_model,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
Self::register_tools(&mut thread, project, action_log, cx);
|
||||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
|
||||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
|
||||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
|
||||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
|
||||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
|
||||||
thread.add_tool(FindPathTool::new(project.clone()));
|
|
||||||
thread.add_tool(GrepTool::new(project.clone()));
|
|
||||||
thread.add_tool(ListDirectoryTool::new(project.clone()));
|
|
||||||
thread.add_tool(MovePathTool::new(project.clone()));
|
|
||||||
thread.add_tool(NowTool);
|
|
||||||
thread.add_tool(OpenTool::new(project.clone()));
|
|
||||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
|
||||||
thread.add_tool(TerminalTool::new(project.clone(), cx));
|
|
||||||
thread.add_tool(ThinkingTool);
|
|
||||||
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
|
||||||
thread
|
thread
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -729,16 +852,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
|
|
||||||
// Store the session
|
// Store the session
|
||||||
agent.update(cx, |agent, cx| {
|
agent.update(cx, |agent, cx| {
|
||||||
agent.sessions.insert(
|
agent.insert_session(thread, acp_thread.clone(), cx)
|
||||||
session_id,
|
|
||||||
Session {
|
|
||||||
thread,
|
|
||||||
acp_thread: acp_thread.downgrade(),
|
|
||||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
|
||||||
this.sessions.remove(acp_thread.session_id());
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(acp_thread)
|
Ok(acp_thread)
|
||||||
|
@ -797,7 +911,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
log::info!("Cancelling on session: {}", session_id);
|
log::info!("Cancelling on session: {}", session_id);
|
||||||
self.0.update(cx, |agent, cx| {
|
self.0.update(cx, |agent, cx| {
|
||||||
if let Some(agent) = agent.sessions.get(session_id) {
|
if let Some(agent) = agent.sessions.get(session_id) {
|
||||||
agent.thread.update(cx, |thread, _cx| thread.cancel());
|
agent.thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -815,6 +929,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn history(self: Rc<Self>) -> Option<Rc<dyn AgentHistory>> {
|
||||||
|
Some(self)
|
||||||
|
}
|
||||||
|
|
||||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -824,7 +942,121 @@ struct NativeAgentSessionEditor(Entity<Thread>);
|
||||||
|
|
||||||
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||||
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||||
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
|
Task::ready(
|
||||||
|
self.0
|
||||||
|
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl acp_thread::AgentHistory for NativeAgentConnection {
|
||||||
|
fn list_threads(&self, cx: &mut App) -> Task<Result<Vec<AcpThreadMetadata>>> {
|
||||||
|
let database = self.0.read(cx).thread_database.clone();
|
||||||
|
cx.background_executor().spawn(async move {
|
||||||
|
let threads = database.list_threads().await?;
|
||||||
|
anyhow::Ok(
|
||||||
|
threads
|
||||||
|
.into_iter()
|
||||||
|
.map(|thread| thread.to_acp(NATIVE_AGENT_SERVER_NAME))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn observe_history(&self, cx: &mut App) -> mpsc::UnboundedReceiver<AcpThreadMetadata> {
|
||||||
|
let (tx, rx) = mpsc::unbounded();
|
||||||
|
self.0.update(cx, |this, _| this.history_watchers.push(tx));
|
||||||
|
rx
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_thread(
|
||||||
|
self: Rc<Self>,
|
||||||
|
project: Entity<Project>,
|
||||||
|
_cwd: &Path,
|
||||||
|
session_id: acp::SessionId,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||||
|
let database = self.0.update(cx, |this, _| this.thread_database.clone());
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
let db_thread = database
|
||||||
|
.load_thread(session_id.clone())
|
||||||
|
.await?
|
||||||
|
.context("no such thread found")?;
|
||||||
|
|
||||||
|
let acp_thread = cx.update(|cx| {
|
||||||
|
cx.new(|cx| {
|
||||||
|
acp_thread::AcpThread::new(
|
||||||
|
db_thread.title.clone(),
|
||||||
|
self.clone(),
|
||||||
|
project.clone(),
|
||||||
|
session_id.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
||||||
|
let agent = self.0.clone();
|
||||||
|
|
||||||
|
// Create Thread
|
||||||
|
let thread = agent.update(cx, |agent, cx| {
|
||||||
|
let language_model_registry = LanguageModelRegistry::global(cx);
|
||||||
|
let configured_model = language_model_registry
|
||||||
|
.update(cx, |registry, cx| {
|
||||||
|
db_thread
|
||||||
|
.model
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|model| {
|
||||||
|
let model = SelectedModel {
|
||||||
|
provider: model.provider.clone().into(),
|
||||||
|
model: model.model.clone().into(),
|
||||||
|
};
|
||||||
|
registry.select_model(&model, cx)
|
||||||
|
})
|
||||||
|
.or_else(|| registry.default_model())
|
||||||
|
})
|
||||||
|
.context("no default model configured")?;
|
||||||
|
|
||||||
|
let model = agent
|
||||||
|
.models
|
||||||
|
.model_from_id(&LanguageModels::model_id(&configured_model.model))
|
||||||
|
.context("no model by id")?;
|
||||||
|
|
||||||
|
let summarization_model = language_model_registry
|
||||||
|
.read(cx)
|
||||||
|
.thread_summary_model()
|
||||||
|
.map(|c| c.model);
|
||||||
|
|
||||||
|
let thread = cx.new(|cx| {
|
||||||
|
let mut thread = Thread::from_db(
|
||||||
|
session_id,
|
||||||
|
db_thread,
|
||||||
|
project.clone(),
|
||||||
|
agent.project_context.clone(),
|
||||||
|
agent.context_server_registry.clone(),
|
||||||
|
action_log.clone(),
|
||||||
|
agent.templates.clone(),
|
||||||
|
model,
|
||||||
|
summarization_model,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
Self::register_tools(&mut thread, project, action_log, cx);
|
||||||
|
thread
|
||||||
|
});
|
||||||
|
|
||||||
|
anyhow::Ok(thread)
|
||||||
|
})??;
|
||||||
|
|
||||||
|
// Store the session
|
||||||
|
agent.update(cx, |agent, cx| {
|
||||||
|
agent.insert_session(thread.clone(), acp_thread.clone(), cx)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
|
||||||
|
cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(acp_thread)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -844,12 +1076,16 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use crate::HistoryStore;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
|
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
|
||||||
use fs::FakeFs;
|
use fs::FakeFs;
|
||||||
use gpui::TestAppContext;
|
use gpui::TestAppContext;
|
||||||
|
use language_model::fake_provider::FakeLanguageModel;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
|
use util::path;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
|
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
|
||||||
|
@ -1024,6 +1260,80 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_history(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
let project = Project::test(fs.clone(), [], cx).await;
|
||||||
|
|
||||||
|
let agent = NativeAgent::new(
|
||||||
|
project.clone(),
|
||||||
|
Templates::new(),
|
||||||
|
None,
|
||||||
|
fs.clone(),
|
||||||
|
&mut cx.to_async(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
||||||
|
let history = connection.clone().history().unwrap();
|
||||||
|
let history_store = cx.new(|cx| HistoryStore::get_or_init(cx));
|
||||||
|
|
||||||
|
history_store
|
||||||
|
.update(cx, |history_store, cx| {
|
||||||
|
history_store.load_history(NATIVE_AGENT_SERVER_NAME.clone(), history.as_ref(), cx)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let acp_thread = cx
|
||||||
|
.update(|cx| {
|
||||||
|
connection
|
||||||
|
.clone()
|
||||||
|
.new_thread(project.clone(), Path::new(path!("")), cx)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
||||||
|
let selector = connection.model_selector().unwrap();
|
||||||
|
|
||||||
|
let summarization_model: Arc<dyn LanguageModel> =
|
||||||
|
Arc::new(FakeLanguageModel::default()) as _;
|
||||||
|
|
||||||
|
agent.update(cx, |agent, cx| {
|
||||||
|
let thread = agent.sessions.get(&session_id).unwrap().thread.clone();
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.set_summarization_model(Some(summarization_model.clone()), cx);
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
let model = cx
|
||||||
|
.update(|cx| selector.selected_model(&session_id, cx))
|
||||||
|
.await
|
||||||
|
.expect("selected_model should succeed");
|
||||||
|
let model = cx
|
||||||
|
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
|
||||||
|
.unwrap();
|
||||||
|
let model = model.as_fake();
|
||||||
|
|
||||||
|
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Hi", cx));
|
||||||
|
let send = cx.foreground_executor().spawn(send);
|
||||||
|
cx.run_until_parked();
|
||||||
|
model.send_last_completion_stream_text_chunk("Hey");
|
||||||
|
model.end_last_completion_stream();
|
||||||
|
send.await.unwrap();
|
||||||
|
|
||||||
|
summarization_model
|
||||||
|
.as_fake()
|
||||||
|
.send_last_completion_stream_text_chunk("Saying Hello");
|
||||||
|
summarization_model.as_fake().end_last_completion_stream();
|
||||||
|
cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);
|
||||||
|
|
||||||
|
let history = history_store.update(cx, |store, cx| store.entries(cx));
|
||||||
|
assert_eq!(history.len(), 1);
|
||||||
|
assert_eq!(history[0].title(), "Saying Hello");
|
||||||
|
}
|
||||||
|
|
||||||
fn init_test(cx: &mut TestAppContext) {
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
env_logger::try_init().ok();
|
env_logger::try_init().ok();
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
mod agent;
|
mod agent;
|
||||||
|
mod db;
|
||||||
|
mod history_store;
|
||||||
mod native_agent_server;
|
mod native_agent_server;
|
||||||
mod templates;
|
mod templates;
|
||||||
mod thread;
|
mod thread;
|
||||||
|
@ -8,7 +10,15 @@ mod tools;
|
||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
pub use agent::*;
|
pub use agent::*;
|
||||||
|
pub use db::*;
|
||||||
|
pub use history_store::*;
|
||||||
pub use native_agent_server::NativeAgentServer;
|
pub use native_agent_server::NativeAgentServer;
|
||||||
pub use templates::*;
|
pub use templates::*;
|
||||||
pub use thread::*;
|
pub use thread::*;
|
||||||
pub use tools::*;
|
pub use tools::*;
|
||||||
|
|
||||||
|
use agent_client_protocol as acp;
|
||||||
|
|
||||||
|
pub fn generate_session_id() -> acp::SessionId {
|
||||||
|
acp::SessionId(uuid::Uuid::new_v4().to_string().into())
|
||||||
|
}
|
||||||
|
|
488
crates/agent2/src/db.rs
Normal file
488
crates/agent2/src/db.rs
Normal file
|
@ -0,0 +1,488 @@
|
||||||
|
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
||||||
|
use acp_thread::{AcpThreadMetadata, AgentServerName};
|
||||||
|
use agent::thread_store;
|
||||||
|
use agent_client_protocol as acp;
|
||||||
|
use agent_settings::{AgentProfileId, CompletionMode};
|
||||||
|
use anyhow::{Result, anyhow};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use collections::{HashMap, IndexMap};
|
||||||
|
use futures::{FutureExt, future::Shared};
|
||||||
|
use gpui::{BackgroundExecutor, Global, Task};
|
||||||
|
use indoc::indoc;
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sqlez::{
|
||||||
|
bindable::{Bind, Column},
|
||||||
|
connection::Connection,
|
||||||
|
statement::Statement,
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use ui::{App, SharedString};
|
||||||
|
|
||||||
|
pub type DbMessage = crate::Message;
|
||||||
|
pub type DbSummary = agent::thread::DetailedSummaryState;
|
||||||
|
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DbThreadMetadata {
|
||||||
|
pub id: acp::SessionId,
|
||||||
|
#[serde(alias = "summary")]
|
||||||
|
pub title: SharedString,
|
||||||
|
pub updated_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DbThreadMetadata {
|
||||||
|
pub fn to_acp(self, agent: AgentServerName) -> AcpThreadMetadata {
|
||||||
|
AcpThreadMetadata {
|
||||||
|
agent,
|
||||||
|
id: self.id,
|
||||||
|
title: self.title,
|
||||||
|
updated_at: self.updated_at,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct DbThread {
|
||||||
|
pub title: SharedString,
|
||||||
|
pub messages: Vec<DbMessage>,
|
||||||
|
pub updated_at: DateTime<Utc>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub summary: DbSummary,
|
||||||
|
#[serde(default)]
|
||||||
|
pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub cumulative_token_usage: language_model::TokenUsage,
|
||||||
|
#[serde(default)]
|
||||||
|
pub request_token_usage: Vec<language_model::TokenUsage>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub model: Option<DbLanguageModel>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub completion_mode: Option<CompletionMode>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub profile: Option<AgentProfileId>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DbThread {
|
||||||
|
pub const VERSION: &'static str = "0.3.0";
|
||||||
|
|
||||||
|
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||||
|
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||||
|
match saved_thread_json.get("version") {
|
||||||
|
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||||
|
Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
|
||||||
|
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
|
||||||
|
},
|
||||||
|
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
for msg in thread.messages {
|
||||||
|
let message = match msg.role {
|
||||||
|
language_model::Role::User => {
|
||||||
|
let mut content = Vec::new();
|
||||||
|
|
||||||
|
// Convert segments to content
|
||||||
|
for segment in msg.segments {
|
||||||
|
match segment {
|
||||||
|
thread_store::SerializedMessageSegment::Text { text } => {
|
||||||
|
content.push(UserMessageContent::Text(text));
|
||||||
|
}
|
||||||
|
thread_store::SerializedMessageSegment::Thinking { text, .. } => {
|
||||||
|
// User messages don't have thinking segments, but handle gracefully
|
||||||
|
content.push(UserMessageContent::Text(text));
|
||||||
|
}
|
||||||
|
thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
|
||||||
|
// User messages don't have redacted thinking, skip.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no content was added, add context as text if available
|
||||||
|
if content.is_empty() && !msg.context.is_empty() {
|
||||||
|
content.push(UserMessageContent::Text(msg.context));
|
||||||
|
}
|
||||||
|
|
||||||
|
crate::Message::User(UserMessage {
|
||||||
|
// MessageId from old format can't be meaningfully converted, so generate a new one
|
||||||
|
id: acp_thread::UserMessageId::new(),
|
||||||
|
content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
language_model::Role::Assistant => {
|
||||||
|
let mut content = Vec::new();
|
||||||
|
|
||||||
|
// Convert segments to content
|
||||||
|
for segment in msg.segments {
|
||||||
|
match segment {
|
||||||
|
thread_store::SerializedMessageSegment::Text { text } => {
|
||||||
|
content.push(AgentMessageContent::Text(text));
|
||||||
|
}
|
||||||
|
thread_store::SerializedMessageSegment::Thinking {
|
||||||
|
text,
|
||||||
|
signature,
|
||||||
|
} => {
|
||||||
|
content.push(AgentMessageContent::Thinking { text, signature });
|
||||||
|
}
|
||||||
|
thread_store::SerializedMessageSegment::RedactedThinking { data } => {
|
||||||
|
content.push(AgentMessageContent::RedactedThinking(data));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert tool uses
|
||||||
|
let mut tool_names_by_id = HashMap::default();
|
||||||
|
for tool_use in msg.tool_uses {
|
||||||
|
tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
|
||||||
|
content.push(AgentMessageContent::ToolUse(
|
||||||
|
language_model::LanguageModelToolUse {
|
||||||
|
id: tool_use.id,
|
||||||
|
name: tool_use.name.into(),
|
||||||
|
raw_input: serde_json::to_string(&tool_use.input)
|
||||||
|
.unwrap_or_default(),
|
||||||
|
input: tool_use.input,
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert tool results
|
||||||
|
let mut tool_results = IndexMap::default();
|
||||||
|
for tool_result in msg.tool_results {
|
||||||
|
let name = tool_names_by_id
|
||||||
|
.remove(&tool_result.tool_use_id)
|
||||||
|
.unwrap_or_else(|| SharedString::from("unknown"));
|
||||||
|
tool_results.insert(
|
||||||
|
tool_result.tool_use_id.clone(),
|
||||||
|
language_model::LanguageModelToolResult {
|
||||||
|
tool_use_id: tool_result.tool_use_id,
|
||||||
|
tool_name: name.into(),
|
||||||
|
is_error: tool_result.is_error,
|
||||||
|
content: tool_result.content,
|
||||||
|
output: tool_result.output,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
crate::Message::Agent(AgentMessage {
|
||||||
|
content,
|
||||||
|
tool_results,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
language_model::Role::System => {
|
||||||
|
// Skip system messages as they're not supported in the new format
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
messages.push(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
title: thread.summary,
|
||||||
|
messages,
|
||||||
|
updated_at: thread.updated_at,
|
||||||
|
summary: thread.detailed_summary_state,
|
||||||
|
initial_project_snapshot: thread.initial_project_snapshot,
|
||||||
|
cumulative_token_usage: thread.cumulative_token_usage,
|
||||||
|
request_token_usage: thread.request_token_usage,
|
||||||
|
model: thread.model,
|
||||||
|
completion_mode: thread.completion_mode,
|
||||||
|
profile: thread.profile,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
|
||||||
|
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum DataType {
|
||||||
|
#[serde(rename = "json")]
|
||||||
|
Json,
|
||||||
|
#[serde(rename = "zstd")]
|
||||||
|
Zstd,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Bind for DataType {
|
||||||
|
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||||
|
let value = match self {
|
||||||
|
DataType::Json => "json",
|
||||||
|
DataType::Zstd => "zstd",
|
||||||
|
};
|
||||||
|
value.bind(statement, start_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Column for DataType {
|
||||||
|
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||||
|
let (value, next_index) = String::column(statement, start_index)?;
|
||||||
|
let data_type = match value.as_str() {
|
||||||
|
"json" => DataType::Json,
|
||||||
|
"zstd" => DataType::Zstd,
|
||||||
|
_ => anyhow::bail!("Unknown data type: {}", value),
|
||||||
|
};
|
||||||
|
Ok((data_type, next_index))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct ThreadsDatabase {
|
||||||
|
executor: BackgroundExecutor,
|
||||||
|
connection: Arc<Mutex<Connection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
|
||||||
|
|
||||||
|
impl Global for GlobalThreadsDatabase {}
|
||||||
|
|
||||||
|
impl ThreadsDatabase {
|
||||||
|
fn connection(&self) -> Arc<Mutex<Connection>> {
|
||||||
|
self.connection.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
const COMPRESSION_LEVEL: i32 = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThreadsDatabase {
|
||||||
|
pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
|
||||||
|
if cx.has_global::<GlobalThreadsDatabase>() {
|
||||||
|
return cx.global::<GlobalThreadsDatabase>().0.clone();
|
||||||
|
}
|
||||||
|
let executor = cx.background_executor().clone();
|
||||||
|
let task = executor
|
||||||
|
.spawn({
|
||||||
|
let executor = executor.clone();
|
||||||
|
async move {
|
||||||
|
match ThreadsDatabase::new(executor) {
|
||||||
|
Ok(db) => Ok(Arc::new(db)),
|
||||||
|
Err(err) => Err(Arc::new(err)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.shared();
|
||||||
|
|
||||||
|
cx.set_global(GlobalThreadsDatabase(task.clone()));
|
||||||
|
task
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(executor: BackgroundExecutor) -> Result<Self> {
|
||||||
|
let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
|
||||||
|
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
|
||||||
|
} else {
|
||||||
|
let threads_dir = paths::data_dir().join("threads");
|
||||||
|
std::fs::create_dir_all(&threads_dir)?;
|
||||||
|
let sqlite_path = threads_dir.join("threads.db");
|
||||||
|
Connection::open_file(&sqlite_path.to_string_lossy())
|
||||||
|
};
|
||||||
|
|
||||||
|
connection.exec(indoc! {"
|
||||||
|
CREATE TABLE IF NOT EXISTS threads (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
summary TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL,
|
||||||
|
data_type TEXT NOT NULL,
|
||||||
|
data BLOB NOT NULL
|
||||||
|
)
|
||||||
|
"})?()
|
||||||
|
.map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
|
||||||
|
|
||||||
|
let db = Self {
|
||||||
|
executor: executor.clone(),
|
||||||
|
connection: Arc::new(Mutex::new(connection)),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_thread_sync(
|
||||||
|
connection: &Arc<Mutex<Connection>>,
|
||||||
|
id: acp::SessionId,
|
||||||
|
thread: DbThread,
|
||||||
|
) -> Result<DbThreadMetadata> {
|
||||||
|
let json_data = serde_json::to_string(&thread)?;
|
||||||
|
let title = thread.title.to_string();
|
||||||
|
let updated_at = thread.updated_at.to_rfc3339();
|
||||||
|
|
||||||
|
let connection = connection.lock();
|
||||||
|
|
||||||
|
let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
|
||||||
|
let data_type = DataType::Zstd;
|
||||||
|
let data = compressed;
|
||||||
|
|
||||||
|
let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
|
||||||
|
INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
|
||||||
|
"})?;
|
||||||
|
|
||||||
|
insert((id.0.clone(), title, updated_at, data_type, data))?;
|
||||||
|
|
||||||
|
Ok(DbThreadMetadata {
|
||||||
|
id,
|
||||||
|
title: thread.title,
|
||||||
|
updated_at: thread.updated_at,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
|
||||||
|
let connection = self.connection.clone();
|
||||||
|
|
||||||
|
self.executor.spawn(async move {
|
||||||
|
let connection = connection.lock();
|
||||||
|
let mut select =
|
||||||
|
connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
|
||||||
|
SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
|
||||||
|
"})?;
|
||||||
|
|
||||||
|
let rows = select(())?;
|
||||||
|
let mut threads = Vec::new();
|
||||||
|
|
||||||
|
for (id, summary, updated_at) in rows {
|
||||||
|
threads.push(DbThreadMetadata {
|
||||||
|
id: acp::SessionId(id),
|
||||||
|
title: summary.into(),
|
||||||
|
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(threads)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
|
||||||
|
let connection = self.connection.clone();
|
||||||
|
|
||||||
|
self.executor.spawn(async move {
|
||||||
|
let connection = connection.lock();
|
||||||
|
let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
|
||||||
|
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
|
||||||
|
"})?;
|
||||||
|
|
||||||
|
let rows = select(id.0)?;
|
||||||
|
if let Some((data_type, data)) = rows.into_iter().next() {
|
||||||
|
let json_data = match data_type {
|
||||||
|
DataType::Zstd => {
|
||||||
|
let decompressed = zstd::decode_all(&data[..])?;
|
||||||
|
String::from_utf8(decompressed)?
|
||||||
|
}
|
||||||
|
DataType::Json => String::from_utf8(data)?,
|
||||||
|
};
|
||||||
|
dbg!(&json_data);
|
||||||
|
|
||||||
|
let thread = dbg!(DbThread::from_json(json_data.as_bytes()))?;
|
||||||
|
Ok(Some(thread))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_thread(
|
||||||
|
&self,
|
||||||
|
id: acp::SessionId,
|
||||||
|
thread: DbThread,
|
||||||
|
) -> Task<Result<DbThreadMetadata>> {
|
||||||
|
let connection = self.connection.clone();
|
||||||
|
|
||||||
|
self.executor
|
||||||
|
.spawn(async move { Self::save_thread_sync(&connection, id, thread) })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
|
||||||
|
let connection = self.connection.clone();
|
||||||
|
|
||||||
|
self.executor.spawn(async move {
|
||||||
|
let connection = connection.lock();
|
||||||
|
|
||||||
|
let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
|
||||||
|
DELETE FROM threads WHERE id = ?
|
||||||
|
"})?;
|
||||||
|
|
||||||
|
delete(id.0)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use agent::MessageSegment;
|
||||||
|
use agent::context::LoadedContext;
|
||||||
|
use client::Client;
|
||||||
|
use fs::FakeFs;
|
||||||
|
use gpui::AppContext;
|
||||||
|
use gpui::TestAppContext;
|
||||||
|
use http_client::FakeHttpClient;
|
||||||
|
use language_model::Role;
|
||||||
|
use project::Project;
|
||||||
|
use settings::SettingsStore;
|
||||||
|
|
||||||
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
|
env_logger::try_init().ok();
|
||||||
|
cx.update(|cx| {
|
||||||
|
let settings_store = SettingsStore::test(cx);
|
||||||
|
cx.set_global(settings_store);
|
||||||
|
Project::init_settings(cx);
|
||||||
|
language::init(cx);
|
||||||
|
|
||||||
|
let http_client = FakeHttpClient::with_404_response();
|
||||||
|
let clock = Arc::new(clock::FakeSystemClock::new());
|
||||||
|
let client = Client::new(clock, http_client, cx);
|
||||||
|
agent::init(cx);
|
||||||
|
agent_settings::init(cx);
|
||||||
|
language_model::init(client.clone(), cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
let project = Project::test(fs, [], cx).await;
|
||||||
|
|
||||||
|
// Save a thread using the old agent.
|
||||||
|
{
|
||||||
|
let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
|
||||||
|
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.insert_message(
|
||||||
|
Role::User,
|
||||||
|
vec![MessageSegment::Text("Hey!".into())],
|
||||||
|
LoadedContext::default(),
|
||||||
|
vec![],
|
||||||
|
false,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
thread.insert_message(
|
||||||
|
Role::Assistant,
|
||||||
|
vec![MessageSegment::Text("How're you doing?".into())],
|
||||||
|
LoadedContext::default(),
|
||||||
|
vec![],
|
||||||
|
false,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
thread_store
|
||||||
|
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let db = cx.update(|cx| ThreadsDatabase::connect(cx)).await.unwrap();
|
||||||
|
let threads = db.list_threads().await.unwrap();
|
||||||
|
assert_eq!(threads.len(), 1);
|
||||||
|
let thread = db
|
||||||
|
.load_thread(threads[0].id.clone())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
|
||||||
|
assert_eq!(
|
||||||
|
thread.messages[1].to_markdown(),
|
||||||
|
"## Assistant\n\nHow're you doing?\n"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
174
crates/agent2/src/history_store.rs
Normal file
174
crates/agent2/src/history_store.rs
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
|
||||||
|
use agent_client_protocol as acp;
|
||||||
|
use agent_servers::AgentServer;
|
||||||
|
use assistant_context::SavedContextMetadata;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use collections::HashMap;
|
||||||
|
use gpui::{Entity, Global, SharedString, Task, prelude::*};
|
||||||
|
use project::Project;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use ui::App;
|
||||||
|
|
||||||
|
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||||
|
|
||||||
|
use crate::NativeAgentServer;
|
||||||
|
|
||||||
|
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
|
||||||
|
const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
|
||||||
|
const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
|
||||||
|
|
||||||
|
// todo!(put this in the UI)
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum HistoryEntry {
|
||||||
|
AcpThread(AcpThreadMetadata),
|
||||||
|
TextThread(SavedContextMetadata),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HistoryEntry {
|
||||||
|
pub fn updated_at(&self) -> DateTime<Utc> {
|
||||||
|
match self {
|
||||||
|
HistoryEntry::AcpThread(thread) => thread.updated_at,
|
||||||
|
HistoryEntry::TextThread(context) => context.mtime.to_utc(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> HistoryEntryId {
|
||||||
|
match self {
|
||||||
|
HistoryEntry::AcpThread(thread) => {
|
||||||
|
HistoryEntryId::Thread(thread.agent.clone(), thread.id.clone())
|
||||||
|
}
|
||||||
|
HistoryEntry::TextThread(context) => HistoryEntryId::Context(context.path.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn title(&self) -> &SharedString {
|
||||||
|
match self {
|
||||||
|
HistoryEntry::AcpThread(thread) => &thread.title,
|
||||||
|
HistoryEntry::TextThread(context) => &context.title,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generic identifier for a history entry.
|
||||||
|
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||||
|
pub enum HistoryEntryId {
|
||||||
|
Thread(AgentServerName, acp::SessionId),
|
||||||
|
Context(Arc<Path>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
enum SerializedRecentOpen {
|
||||||
|
Thread(String),
|
||||||
|
ContextName(String),
|
||||||
|
/// Old format which stores the full path
|
||||||
|
Context(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct AgentHistory {
|
||||||
|
entries: HashMap<acp::SessionId, AcpThreadMetadata>,
|
||||||
|
loaded: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct HistoryStore {
|
||||||
|
agents: HashMap<AgentServerName, AgentHistory>, // todo!() text threads
|
||||||
|
}
|
||||||
|
// note, we have to share the history store between all windows
|
||||||
|
// because we only get updates from one connection at a time.
|
||||||
|
struct GlobalHistoryStore(Entity<HistoryStore>);
|
||||||
|
impl Global for GlobalHistoryStore {}
|
||||||
|
|
||||||
|
impl HistoryStore {
|
||||||
|
pub fn get_or_init(project: &Entity<Project>, cx: &mut App) -> Entity<Self> {
|
||||||
|
if cx.has_global::<GlobalHistoryStore>() {
|
||||||
|
return cx.global::<GlobalHistoryStore>().0.clone();
|
||||||
|
}
|
||||||
|
let history_store = cx.new(|cx| HistoryStore::new(cx));
|
||||||
|
cx.set_global(GlobalHistoryStore(history_store.clone()));
|
||||||
|
let root_dir = project
|
||||||
|
.read(cx)
|
||||||
|
.visible_worktrees(cx)
|
||||||
|
.next()
|
||||||
|
.map(|worktree| worktree.read(cx).abs_path())
|
||||||
|
.unwrap_or_else(|| paths::home_dir().as_path().into());
|
||||||
|
|
||||||
|
let agent = NativeAgentServer::new(project.read(cx).fs().clone());
|
||||||
|
let connect = agent.connect(&root_dir, project, cx);
|
||||||
|
cx.spawn({
|
||||||
|
let history_store = history_store.clone();
|
||||||
|
async move |cx| {
|
||||||
|
let connection = connect.await?.history().unwrap();
|
||||||
|
history_store
|
||||||
|
.update(cx, |history_store, cx| {
|
||||||
|
history_store.load_history(agent.name(), connection.as_ref(), cx)
|
||||||
|
})?
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
history_store
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(_cx: &mut Context<Self>) -> Self {
|
||||||
|
Self {
|
||||||
|
agents: HashMap::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_history(&mut self, entry: AcpThreadMetadata, cx: &mut Context<Self>) {
|
||||||
|
let agent = self
|
||||||
|
.agents
|
||||||
|
.entry(entry.agent.clone())
|
||||||
|
.or_insert(Default::default());
|
||||||
|
|
||||||
|
agent.entries.insert(entry.id.clone(), entry);
|
||||||
|
cx.notify()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_history(
|
||||||
|
&mut self,
|
||||||
|
agent_name: AgentServerName,
|
||||||
|
connection: &dyn acp_thread::AgentHistory,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Task<anyhow::Result<()>> {
|
||||||
|
let threads = connection.list_threads(cx);
|
||||||
|
cx.spawn(async move |this, cx| {
|
||||||
|
let threads = threads.await?;
|
||||||
|
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
this.agents.insert(
|
||||||
|
agent_name,
|
||||||
|
AgentHistory {
|
||||||
|
loaded: true,
|
||||||
|
entries: threads.into_iter().map(|t| (t.id.clone(), t)).collect(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
cx.notify()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn entries(&mut self, _cx: &mut Context<Self>) -> Vec<HistoryEntry> {
|
||||||
|
let mut history_entries = Vec::new();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
|
||||||
|
return history_entries;
|
||||||
|
}
|
||||||
|
|
||||||
|
history_entries.extend(
|
||||||
|
self.agents
|
||||||
|
.values_mut()
|
||||||
|
.flat_map(|history| history.entries.values().cloned()) // todo!("surface the loading state?")
|
||||||
|
.map(HistoryEntry::AcpThread),
|
||||||
|
);
|
||||||
|
// todo!() include the text threads in here.
|
||||||
|
|
||||||
|
history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
|
||||||
|
history_entries
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn recent_entries(&mut self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
|
||||||
|
self.entries(cx).into_iter().take(limit).collect()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,11 +1,13 @@
|
||||||
use std::{path::Path, rc::Rc, sync::Arc};
|
use std::{path::Path, rc::Rc, sync::Arc};
|
||||||
|
|
||||||
|
use acp_thread::AgentServerName;
|
||||||
use agent_servers::AgentServer;
|
use agent_servers::AgentServer;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use gpui::{App, Entity, Task};
|
use gpui::{App, Entity, Task};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::PromptStore;
|
use prompt_store::PromptStore;
|
||||||
|
use ui::SharedString;
|
||||||
|
|
||||||
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
|
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
|
||||||
|
|
||||||
|
@ -20,9 +22,12 @@ impl NativeAgentServer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const NATIVE_AGENT_SERVER_NAME: AgentServerName =
|
||||||
|
AgentServerName(SharedString::new_static("Native Agent"));
|
||||||
|
|
||||||
impl AgentServer for NativeAgentServer {
|
impl AgentServer for NativeAgentServer {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> AgentServerName {
|
||||||
"Native Agent"
|
NATIVE_AGENT_SERVER_NAME.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn empty_state_headline(&self) -> &'static str {
|
fn empty_state_headline(&self) -> &'static str {
|
||||||
|
|
|
@ -343,7 +343,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let mut saw_partial_tool_use = false;
|
let mut saw_partial_tool_use = false;
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
// Look for a tool use in the thread's last message
|
// Look for a tool use in the thread's last message
|
||||||
let message = thread.last_message().unwrap();
|
let message = thread.last_message().unwrap();
|
||||||
|
@ -733,16 +733,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn expect_tool_call(
|
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
|
||||||
) -> acp::ToolCall {
|
|
||||||
let event = events
|
let event = events
|
||||||
.next()
|
.next()
|
||||||
.await
|
.await
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
|
ThreadEvent::ToolCall(tool_call) => return tool_call,
|
||||||
event => {
|
event => {
|
||||||
panic!("Unexpected event {event:?}");
|
panic!("Unexpected event {event:?}");
|
||||||
}
|
}
|
||||||
|
@ -750,7 +748,7 @@ async fn expect_tool_call(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn expect_tool_call_update_fields(
|
async fn expect_tool_call_update_fields(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
) -> acp::ToolCallUpdate {
|
) -> acp::ToolCallUpdate {
|
||||||
let event = events
|
let event = events
|
||||||
.next()
|
.next()
|
||||||
|
@ -758,7 +756,7 @@ async fn expect_tool_call_update_fields(
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
||||||
return update;
|
return update;
|
||||||
}
|
}
|
||||||
event => {
|
event => {
|
||||||
|
@ -768,7 +766,7 @@ async fn expect_tool_call_update_fields(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn next_tool_call_authorization(
|
async fn next_tool_call_authorization(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
) -> ToolCallAuthorization {
|
) -> ToolCallAuthorization {
|
||||||
loop {
|
loop {
|
||||||
let event = events
|
let event = events
|
||||||
|
@ -776,7 +774,7 @@ async fn next_tool_call_authorization(
|
||||||
.await
|
.await
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
||||||
let permission_kinds = tool_call_authorization
|
let permission_kinds = tool_call_authorization
|
||||||
.options
|
.options
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -943,13 +941,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
let mut echo_completed = false;
|
let mut echo_completed = false;
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
match event.unwrap() {
|
match event.unwrap() {
|
||||||
AgentResponseEvent::ToolCall(tool_call) => {
|
ThreadEvent::ToolCall(tool_call) => {
|
||||||
assert_eq!(tool_call.title, expected_tools.remove(0));
|
assert_eq!(tool_call.title, expected_tools.remove(0));
|
||||||
if tool_call.title == "Echo" {
|
if tool_call.title == "Echo" {
|
||||||
echo_id = Some(tool_call.id);
|
echo_id = Some(tool_call.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
||||||
acp::ToolCallUpdate {
|
acp::ToolCallUpdate {
|
||||||
id,
|
id,
|
||||||
fields:
|
fields:
|
||||||
|
@ -971,13 +969,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
// Cancel the current send and ensure that the event stream is closed, even
|
// Cancel the current send and ensure that the event stream is closed, even
|
||||||
// if one of the tools is still running.
|
// if one of the tools is still running.
|
||||||
thread.update(cx, |thread, _cx| thread.cancel());
|
thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||||
let events = events.collect::<Vec<_>>().await;
|
let events = events.collect::<Vec<_>>().await;
|
||||||
let last_event = events.last();
|
let last_event = events.last();
|
||||||
assert!(
|
assert!(
|
||||||
matches!(
|
matches!(
|
||||||
last_event,
|
last_event,
|
||||||
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
|
Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
|
||||||
),
|
),
|
||||||
"unexpected event {last_event:?}"
|
"unexpected event {last_event:?}"
|
||||||
);
|
);
|
||||||
|
@ -1159,7 +1157,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, _cx| thread.truncate(message_id))
|
.update(cx, |thread, cx| thread.truncate(message_id, cx))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
|
@ -1434,11 +1432,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Filters out the stop events for asserting against in tests
|
/// Filters out the stop events for asserting against in tests
|
||||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
|
||||||
result_events
|
result_events
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|event| match event.unwrap() {
|
.filter_map(|event| match event.unwrap() {
|
||||||
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
|
ThreadEvent::Stop(stop_reason) => Some(stop_reason),
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
@ -1549,12 +1547,14 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
|
generate_session_id(),
|
||||||
project,
|
project,
|
||||||
project_context.clone(),
|
project_context.clone(),
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
action_log,
|
action_log,
|
||||||
templates,
|
templates,
|
||||||
Some(model.clone()),
|
Some(model.clone()),
|
||||||
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
_input: serde_json::Value,
|
||||||
|
_output: serde_json::Value,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
||||||
use cloud_llm_client::CompletionIntent;
|
use cloud_llm_client::CompletionIntent;
|
||||||
use collections::HashSet;
|
use collections::HashSet;
|
||||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||||
use indoc::formatdoc;
|
use indoc::formatdoc;
|
||||||
use language::ToPoint;
|
|
||||||
use language::language_settings::{self, FormatOnSave};
|
use language::language_settings::{self, FormatOnSave};
|
||||||
|
use language::{LanguageRegistry, ToPoint};
|
||||||
use language_model::LanguageModelToolResultContent;
|
use language_model::LanguageModelToolResultContent;
|
||||||
use paths;
|
use paths;
|
||||||
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
||||||
|
@ -98,11 +98,13 @@ pub enum EditFileMode {
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct EditFileToolOutput {
|
pub struct EditFileToolOutput {
|
||||||
|
#[serde(alias = "original_path")]
|
||||||
input_path: PathBuf,
|
input_path: PathBuf,
|
||||||
project_path: PathBuf,
|
|
||||||
new_text: String,
|
new_text: String,
|
||||||
old_text: Arc<String>,
|
old_text: Arc<String>,
|
||||||
|
#[serde(default)]
|
||||||
diff: String,
|
diff: String,
|
||||||
|
#[serde(alias = "raw_output")]
|
||||||
edit_agent_output: EditAgentOutput,
|
edit_agent_output: EditAgentOutput,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct EditFileTool {
|
pub struct EditFileTool {
|
||||||
thread: Entity<Thread>,
|
thread: WeakEntity<Thread>,
|
||||||
|
language_registry: Arc<LanguageRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EditFileTool {
|
impl EditFileTool {
|
||||||
pub fn new(thread: Entity<Thread>) -> Self {
|
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
|
||||||
Self { thread }
|
Self {
|
||||||
|
thread,
|
||||||
|
language_registry,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authorize(
|
fn authorize(
|
||||||
|
@ -167,8 +173,11 @@ impl EditFileTool {
|
||||||
|
|
||||||
// Check if path is inside the global config directory
|
// Check if path is inside the global config directory
|
||||||
// First check if it's already inside project - if not, try to canonicalize
|
// First check if it's already inside project - if not, try to canonicalize
|
||||||
let thread = self.thread.read(cx);
|
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
|
||||||
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
|
thread.project().read(cx).find_project_path(&input.path, cx)
|
||||||
|
}) else {
|
||||||
|
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||||
|
};
|
||||||
|
|
||||||
// If the path is inside the project, and it's not one of the above edge cases,
|
// If the path is inside the project, and it's not one of the above edge cases,
|
||||||
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
||||||
|
@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
|
||||||
event_stream: ToolCallEventStream,
|
event_stream: ToolCallEventStream,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Self::Output>> {
|
) -> Task<Result<Self::Output>> {
|
||||||
let project = self.thread.read(cx).project().clone();
|
let Ok(project) = self
|
||||||
|
.thread
|
||||||
|
.read_with(cx, |thread, _cx| thread.project().clone())
|
||||||
|
else {
|
||||||
|
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||||
|
};
|
||||||
let project_path = match resolve_path(&input, project.clone(), cx) {
|
let project_path = match resolve_path(&input, project.clone(), cx) {
|
||||||
Ok(path) => path,
|
Ok(path) => path,
|
||||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||||
|
@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(request) = self.thread.update(cx, |thread, cx| {
|
|
||||||
thread
|
|
||||||
.build_completion_request(CompletionIntent::ToolResults, cx)
|
|
||||||
.ok()
|
|
||||||
}) else {
|
|
||||||
return Task::ready(Err(anyhow!("Failed to build completion request")));
|
|
||||||
};
|
|
||||||
let thread = self.thread.read(cx);
|
|
||||||
let Some(model) = thread.model().cloned() else {
|
|
||||||
return Task::ready(Err(anyhow!("No language model configured")));
|
|
||||||
};
|
|
||||||
let action_log = thread.action_log().clone();
|
|
||||||
|
|
||||||
let authorize = self.authorize(&input, &event_stream, cx);
|
let authorize = self.authorize(&input, &event_stream, cx);
|
||||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||||
authorize.await?;
|
authorize.await?;
|
||||||
|
|
||||||
|
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
|
||||||
|
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
|
||||||
|
(request, thread.model().cloned(), thread.action_log().clone())
|
||||||
|
})?;
|
||||||
|
let request = request?;
|
||||||
|
let model = model.context("No language model configured")?;
|
||||||
|
|
||||||
let edit_format = EditFormat::from_model(model.clone())?;
|
let edit_format = EditFormat::from_model(model.clone())?;
|
||||||
let edit_agent = EditAgent::new(
|
let edit_agent = EditAgent::new(
|
||||||
model,
|
model,
|
||||||
|
@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
|
||||||
|
|
||||||
Ok(EditFileToolOutput {
|
Ok(EditFileToolOutput {
|
||||||
input_path: input.path,
|
input_path: input.path,
|
||||||
project_path: project_path.path.to_path_buf(),
|
|
||||||
new_text: new_text.clone(),
|
new_text: new_text.clone(),
|
||||||
old_text,
|
old_text,
|
||||||
diff: unified_diff,
|
diff: unified_diff,
|
||||||
|
@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
_input: Self::Input,
|
||||||
|
output: Self::Output,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
event_stream.update_diff(cx.new(|cx| {
|
||||||
|
Diff::finalized(
|
||||||
|
output.input_path,
|
||||||
|
Some(output.old_text.to_string()),
|
||||||
|
output.new_text,
|
||||||
|
self.language_registry.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate that the file path is valid, meaning:
|
/// Validate that the file path is valid, meaning:
|
||||||
|
@ -497,7 +523,6 @@ fn resolve_path(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{ContextServerRegistry, Templates};
|
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use client::TelemetrySettings;
|
use client::TelemetrySettings;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
|
@ -505,7 +530,6 @@ mod tests {
|
||||||
use language_model::fake_provider::FakeLanguageModel;
|
use language_model::fake_provider::FakeLanguageModel;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use std::rc::Rc;
|
|
||||||
use util::path;
|
use util::path;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
|
@ -515,21 +539,10 @@ mod tests {
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
fs.insert_tree("/root", json!({})).await;
|
fs.insert_tree("/root", json!({})).await;
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||||
Thread::new(
|
|
||||||
project,
|
|
||||||
Rc::default(),
|
|
||||||
context_server_registry,
|
|
||||||
action_log,
|
|
||||||
Templates::new(),
|
|
||||||
Some(model),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let result = cx
|
let result = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
let input = EditFileToolInput {
|
let input = EditFileToolInput {
|
||||||
|
@ -537,7 +550,11 @@ mod tests {
|
||||||
path: "root/nonexistent_file.txt".into(),
|
path: "root/nonexistent_file.txt".into(),
|
||||||
mode: EditFileMode::Edit,
|
mode: EditFileMode::Edit,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||||
|
input,
|
||||||
|
ToolCallEventStream::test().0,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -713,20 +730,8 @@ mod tests {
|
||||||
});
|
});
|
||||||
|
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log.clone(), cx));
|
||||||
Thread::new(
|
|
||||||
project,
|
|
||||||
Rc::default(),
|
|
||||||
context_server_registry,
|
|
||||||
action_log.clone(),
|
|
||||||
Templates::new(),
|
|
||||||
Some(model.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
// First, test with format_on_save enabled
|
// First, test with format_on_save enabled
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
|
@ -750,9 +755,10 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(
|
||||||
thread: thread.clone(),
|
thread.downgrade(),
|
||||||
})
|
language_registry.clone(),
|
||||||
|
))
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
.run(input, ToolCallEventStream::test().0, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -806,7 +812,11 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||||
|
input,
|
||||||
|
ToolCallEventStream::test().0,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Stream the unformatted content
|
// Stream the unformatted content
|
||||||
|
@ -848,21 +858,10 @@ mod tests {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
let context_server_registry =
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log, cx));
|
||||||
Thread::new(
|
|
||||||
project,
|
|
||||||
Rc::default(),
|
|
||||||
context_server_registry,
|
|
||||||
action_log.clone(),
|
|
||||||
Templates::new(),
|
|
||||||
Some(model.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
// First, test with remove_trailing_whitespace_on_save enabled
|
// First, test with remove_trailing_whitespace_on_save enabled
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
|
@ -887,9 +886,10 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(
|
||||||
thread: thread.clone(),
|
thread.downgrade(),
|
||||||
})
|
language_registry.clone(),
|
||||||
|
))
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
.run(input, ToolCallEventStream::test().0, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -938,10 +938,11 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
|
||||||
thread: thread.clone(),
|
input,
|
||||||
})
|
ToolCallEventStream::test().0,
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Stream the content with trailing whitespace
|
// Stream the content with trailing whitespace
|
||||||
|
@ -974,22 +975,12 @@ mod tests {
|
||||||
init_test(cx);
|
init_test(cx);
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
let context_server_registry =
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||||
Thread::new(
|
|
||||||
project,
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
Rc::default(),
|
|
||||||
context_server_registry,
|
|
||||||
action_log.clone(),
|
|
||||||
Templates::new(),
|
|
||||||
Some(model.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
|
||||||
fs.insert_tree("/root", json!({})).await;
|
fs.insert_tree("/root", json!({})).await;
|
||||||
|
|
||||||
// Test 1: Path with .zed component should require confirmation
|
// Test 1: Path with .zed component should require confirmation
|
||||||
|
@ -1111,22 +1102,12 @@ mod tests {
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
fs.insert_tree("/project", json!({})).await;
|
fs.insert_tree("/project", json!({})).await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
let context_server_registry =
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||||
Thread::new(
|
|
||||||
project,
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
Rc::default(),
|
|
||||||
context_server_registry,
|
|
||||||
action_log.clone(),
|
|
||||||
Templates::new(),
|
|
||||||
Some(model.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
|
||||||
|
|
||||||
// Test global config paths - these should require confirmation if they exist and are outside the project
|
// Test global config paths - these should require confirmation if they exist and are outside the project
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1220,23 +1201,12 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||||
Thread::new(
|
|
||||||
project.clone(),
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
Rc::default(),
|
|
||||||
context_server_registry.clone(),
|
|
||||||
action_log.clone(),
|
|
||||||
Templates::new(),
|
|
||||||
Some(model.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
|
||||||
|
|
||||||
// Test files in different worktrees
|
// Test files in different worktrees
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1302,22 +1272,12 @@ mod tests {
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||||
Thread::new(
|
|
||||||
project.clone(),
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
Rc::default(),
|
|
||||||
context_server_registry.clone(),
|
|
||||||
action_log.clone(),
|
|
||||||
Templates::new(),
|
|
||||||
Some(model.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
|
||||||
|
|
||||||
// Test edge cases
|
// Test edge cases
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1386,22 +1346,12 @@ mod tests {
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||||
Thread::new(
|
|
||||||
project.clone(),
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
Rc::default(),
|
|
||||||
context_server_registry.clone(),
|
|
||||||
action_log.clone(),
|
|
||||||
Templates::new(),
|
|
||||||
Some(model.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
|
||||||
|
|
||||||
// Test different EditFileMode values
|
// Test different EditFileMode values
|
||||||
let modes = vec![
|
let modes = vec![
|
||||||
|
@ -1467,22 +1417,12 @@ mod tests {
|
||||||
init_test(cx);
|
init_test(cx);
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let context_server_registry =
|
|
||||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||||
Thread::new(
|
|
||||||
project.clone(),
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||||
Rc::default(),
|
|
||||||
context_server_registry,
|
|
||||||
action_log.clone(),
|
|
||||||
Templates::new(),
|
|
||||||
Some(model.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tool.initial_title(Err(json!({
|
tool.initial_title(Err(json!({
|
||||||
|
|
|
@ -319,7 +319,7 @@ mod tests {
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use util::test::TempTree;
|
use util::test::TempTree;
|
||||||
|
|
||||||
use crate::AgentResponseEvent;
|
use crate::ThreadEvent;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
@ -396,7 +396,7 @@ mod tests {
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let event = stream_rx.try_next();
|
let event = stream_rx.try_next();
|
||||||
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event {
|
if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
|
||||||
auth.response.send(auth.options[0].id.clone()).unwrap();
|
auth.response.send(auth.options[0].id.clone()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let result_text = if response.results.len() == 1 {
|
emit_update(&response, &event_stream);
|
||||||
"1 result".to_string()
|
|
||||||
} else {
|
|
||||||
format!("{} results", response.results.len())
|
|
||||||
};
|
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
title: Some(format!("Searched the web: {result_text}")),
|
|
||||||
content: Some(
|
|
||||||
response
|
|
||||||
.results
|
|
||||||
.iter()
|
|
||||||
.map(|result| acp::ToolCallContent::Content {
|
|
||||||
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
|
||||||
name: result.title.clone(),
|
|
||||||
uri: result.url.clone(),
|
|
||||||
title: Some(result.title.clone()),
|
|
||||||
description: Some(result.text.clone()),
|
|
||||||
mime_type: None,
|
|
||||||
annotations: None,
|
|
||||||
size: None,
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
Ok(WebSearchToolOutput(response))
|
Ok(WebSearchToolOutput(response))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
_input: Self::Input,
|
||||||
|
output: Self::Output,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
emit_update(&output.0, &event_stream);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
|
||||||
|
let result_text = if response.results.len() == 1 {
|
||||||
|
"1 result".to_string()
|
||||||
|
} else {
|
||||||
|
format!("{} results", response.results.len())
|
||||||
|
};
|
||||||
|
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||||
|
title: Some(format!("Searched the web: {result_text}")),
|
||||||
|
content: Some(
|
||||||
|
response
|
||||||
|
.results
|
||||||
|
.iter()
|
||||||
|
.map(|result| acp::ToolCallContent::Content {
|
||||||
|
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||||
|
name: result.title.clone(),
|
||||||
|
uri: result.url.clone(),
|
||||||
|
title: Some(result.title.clone()),
|
||||||
|
description: Some(result.text.clone()),
|
||||||
|
mime_type: None,
|
||||||
|
annotations: None,
|
||||||
|
size: None,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use std::{path::Path, rc::Rc};
|
use std::{path::Path, rc::Rc};
|
||||||
|
|
||||||
use crate::AgentServerCommand;
|
use crate::AgentServerCommand;
|
||||||
use acp_thread::AgentConnection;
|
use acp_thread::{AgentConnection, AgentServerName};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::AsyncApp;
|
use gpui::AsyncApp;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
@ -14,12 +14,12 @@ mod v1;
|
||||||
pub struct UnsupportedVersion;
|
pub struct UnsupportedVersion;
|
||||||
|
|
||||||
pub async fn connect(
|
pub async fn connect(
|
||||||
server_name: &'static str,
|
server_name: AgentServerName,
|
||||||
command: AgentServerCommand,
|
command: AgentServerCommand,
|
||||||
root_dir: &Path,
|
root_dir: &Path,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Result<Rc<dyn AgentConnection>> {
|
) -> Result<Rc<dyn AgentConnection>> {
|
||||||
let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await;
|
let conn = v1::AcpConnection::stdio(server_name.clone(), command.clone(), &root_dir, cx).await;
|
||||||
|
|
||||||
match conn {
|
match conn {
|
||||||
Ok(conn) => Ok(Rc::new(conn) as _),
|
Ok(conn) => Ok(Rc::new(conn) as _),
|
||||||
|
|
|
@ -10,7 +10,7 @@ use ui::App;
|
||||||
use util::ResultExt as _;
|
use util::ResultExt as _;
|
||||||
|
|
||||||
use crate::AgentServerCommand;
|
use crate::AgentServerCommand;
|
||||||
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
use acp_thread::{AcpThread, AgentConnection, AgentServerName, AuthRequired};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct OldAcpClientDelegate {
|
struct OldAcpClientDelegate {
|
||||||
|
@ -354,7 +354,7 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AcpConnection {
|
pub struct AcpConnection {
|
||||||
pub name: &'static str,
|
pub name: AgentServerName,
|
||||||
pub connection: acp_old::AgentConnection,
|
pub connection: acp_old::AgentConnection,
|
||||||
pub _child_status: Task<Result<()>>,
|
pub _child_status: Task<Result<()>>,
|
||||||
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
||||||
|
@ -362,7 +362,7 @@ pub struct AcpConnection {
|
||||||
|
|
||||||
impl AcpConnection {
|
impl AcpConnection {
|
||||||
pub fn stdio(
|
pub fn stdio(
|
||||||
name: &'static str,
|
name: AgentServerName,
|
||||||
command: AgentServerCommand,
|
command: AgentServerCommand,
|
||||||
root_dir: &Path,
|
root_dir: &Path,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
|
@ -443,7 +443,7 @@ impl AgentConnection for AcpConnection {
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||||
AcpThread::new(self.name, self.clone(), project, session_id, cx)
|
AcpThread::new(self.name.0.clone(), self.clone(), project, session_id, cx)
|
||||||
});
|
});
|
||||||
current_thread.replace(thread.downgrade());
|
current_thread.replace(thread.downgrade());
|
||||||
thread
|
thread
|
||||||
|
|
|
@ -13,10 +13,10 @@ use anyhow::{Context as _, Result};
|
||||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||||
|
|
||||||
use crate::{AgentServerCommand, acp::UnsupportedVersion};
|
use crate::{AgentServerCommand, acp::UnsupportedVersion};
|
||||||
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
use acp_thread::{AcpThread, AgentConnection, AgentServerName, AuthRequired};
|
||||||
|
|
||||||
pub struct AcpConnection {
|
pub struct AcpConnection {
|
||||||
server_name: &'static str,
|
server_name: AgentServerName,
|
||||||
connection: Rc<acp::ClientSideConnection>,
|
connection: Rc<acp::ClientSideConnection>,
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
auth_methods: Vec<acp::AuthMethod>,
|
auth_methods: Vec<acp::AuthMethod>,
|
||||||
|
@ -31,7 +31,7 @@ const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
|
||||||
|
|
||||||
impl AcpConnection {
|
impl AcpConnection {
|
||||||
pub async fn stdio(
|
pub async fn stdio(
|
||||||
server_name: &'static str,
|
server_name: AgentServerName,
|
||||||
command: AgentServerCommand,
|
command: AgentServerCommand,
|
||||||
root_dir: &Path,
|
root_dir: &Path,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
|
@ -150,7 +150,7 @@ impl AgentConnection for AcpConnection {
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
AcpThread::new(
|
AcpThread::new(
|
||||||
self.server_name,
|
self.server_name.0.clone(),
|
||||||
self.clone(),
|
self.clone(),
|
||||||
project,
|
project,
|
||||||
session_id.clone(),
|
session_id.clone(),
|
||||||
|
|
|
@ -10,7 +10,7 @@ pub use claude::*;
|
||||||
pub use gemini::*;
|
pub use gemini::*;
|
||||||
pub use settings::*;
|
pub use settings::*;
|
||||||
|
|
||||||
use acp_thread::AgentConnection;
|
use acp_thread::{AgentConnection, AgentServerName};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use gpui::{App, AsyncApp, Entity, SharedString, Task};
|
use gpui::{App, AsyncApp, Entity, SharedString, Task};
|
||||||
|
@ -30,7 +30,7 @@ pub fn init(cx: &mut App) {
|
||||||
|
|
||||||
pub trait AgentServer: Send {
|
pub trait AgentServer: Send {
|
||||||
fn logo(&self) -> ui::IconName;
|
fn logo(&self) -> ui::IconName;
|
||||||
fn name(&self) -> &'static str;
|
fn name(&self) -> AgentServerName;
|
||||||
fn empty_state_headline(&self) -> &'static str;
|
fn empty_state_headline(&self) -> &'static str;
|
||||||
fn empty_state_message(&self) -> &'static str;
|
fn empty_state_message(&self) -> &'static str;
|
||||||
|
|
||||||
|
|
|
@ -30,18 +30,18 @@ use util::{ResultExt, debug_panic};
|
||||||
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
|
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
|
||||||
use crate::claude::tools::ClaudeTool;
|
use crate::claude::tools::ClaudeTool;
|
||||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
||||||
use acp_thread::{AcpThread, AgentConnection};
|
use acp_thread::{AcpThread, AgentConnection, AgentServerName};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ClaudeCode;
|
pub struct ClaudeCode;
|
||||||
|
|
||||||
impl AgentServer for ClaudeCode {
|
impl AgentServer for ClaudeCode {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> AgentServerName {
|
||||||
"Claude Code"
|
AgentServerName("Claude Code".into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn empty_state_headline(&self) -> &'static str {
|
fn empty_state_headline(&self) -> &'static str {
|
||||||
self.name()
|
"Claude Code"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn empty_state_message(&self) -> &'static str {
|
fn empty_state_message(&self) -> &'static str {
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::path::Path;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
use crate::{AgentServer, AgentServerCommand};
|
use crate::{AgentServer, AgentServerCommand};
|
||||||
use acp_thread::{AgentConnection, LoadError};
|
use acp_thread::{AgentConnection, AgentServerName, LoadError};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::{Entity, Task};
|
use gpui::{Entity, Task};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
@ -17,8 +17,8 @@ pub struct Gemini;
|
||||||
const ACP_ARG: &str = "--experimental-acp";
|
const ACP_ARG: &str = "--experimental-acp";
|
||||||
|
|
||||||
impl AgentServer for Gemini {
|
impl AgentServer for Gemini {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> AgentServerName {
|
||||||
"Gemini"
|
AgentServerName("Gemini".into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn empty_state_headline(&self) -> &'static str {
|
fn empty_state_headline(&self) -> &'static str {
|
||||||
|
|
|
@ -3,8 +3,10 @@ mod entry_view_state;
|
||||||
mod message_editor;
|
mod message_editor;
|
||||||
mod model_selector;
|
mod model_selector;
|
||||||
mod model_selector_popover;
|
mod model_selector_popover;
|
||||||
|
mod thread_history;
|
||||||
mod thread_view;
|
mod thread_view;
|
||||||
|
|
||||||
pub use model_selector::AcpModelSelector;
|
pub use model_selector::AcpModelSelector;
|
||||||
pub use model_selector_popover::AcpModelSelectorPopover;
|
pub use model_selector_popover::AcpModelSelectorPopover;
|
||||||
|
pub use thread_history::{AcpThreadHistory, ThreadHistoryEvent};
|
||||||
pub use thread_view::AcpThreadView;
|
pub use thread_view::AcpThreadView;
|
||||||
|
|
796
crates/agent_ui/src/acp/thread_history.rs
Normal file
796
crates/agent_ui/src/acp/thread_history.rs
Normal file
|
@ -0,0 +1,796 @@
|
||||||
|
use crate::RemoveSelectedThread;
|
||||||
|
use agent_servers::AgentServer;
|
||||||
|
use agent2::{HistoryEntry, HistoryStore, NativeAgentServer};
|
||||||
|
use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
|
||||||
|
use editor::{Editor, EditorEvent};
|
||||||
|
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||||
|
use gpui::{
|
||||||
|
App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ScrollStrategy, Stateful,
|
||||||
|
Task, UniformListScrollHandle, Window, uniform_list,
|
||||||
|
};
|
||||||
|
use project::Project;
|
||||||
|
use std::{fmt::Display, ops::Range, sync::Arc};
|
||||||
|
use time::{OffsetDateTime, UtcOffset};
|
||||||
|
use ui::{
|
||||||
|
HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState,
|
||||||
|
Tooltip, prelude::*,
|
||||||
|
};
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
pub struct AcpThreadHistory {
|
||||||
|
pub(crate) history_store: Entity<HistoryStore>,
|
||||||
|
scroll_handle: UniformListScrollHandle,
|
||||||
|
selected_index: usize,
|
||||||
|
hovered_index: Option<usize>,
|
||||||
|
search_editor: Entity<Editor>,
|
||||||
|
all_entries: Arc<Vec<HistoryEntry>>,
|
||||||
|
// When the search is empty, we display date separators between history entries
|
||||||
|
// This vector contains an enum of either a separator or an actual entry
|
||||||
|
separated_items: Vec<ListItemType>,
|
||||||
|
// Maps entry indexes to list item indexes
|
||||||
|
separated_item_indexes: Vec<u32>,
|
||||||
|
_separated_items_task: Option<Task<()>>,
|
||||||
|
search_state: SearchState,
|
||||||
|
scrollbar_visibility: bool,
|
||||||
|
scrollbar_state: ScrollbarState,
|
||||||
|
local_timezone: UtcOffset,
|
||||||
|
_subscriptions: Vec<gpui::Subscription>,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum SearchState {
|
||||||
|
Empty,
|
||||||
|
Searching {
|
||||||
|
query: SharedString,
|
||||||
|
_task: Task<()>,
|
||||||
|
},
|
||||||
|
Searched {
|
||||||
|
query: SharedString,
|
||||||
|
matches: Vec<StringMatch>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ListItemType {
|
||||||
|
BucketSeparator(TimeBucket),
|
||||||
|
Entry {
|
||||||
|
index: usize,
|
||||||
|
format: EntryTimeFormat,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum ThreadHistoryEvent {
|
||||||
|
Open(HistoryEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventEmitter<ThreadHistoryEvent> for AcpThreadHistory {}
|
||||||
|
|
||||||
|
impl AcpThreadHistory {
|
||||||
|
pub(crate) fn new(
|
||||||
|
project: &Entity<Project>,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Self {
|
||||||
|
let search_editor = cx.new(|cx| {
|
||||||
|
let mut editor = Editor::single_line(window, cx);
|
||||||
|
editor.set_placeholder_text("Search threads...", cx);
|
||||||
|
editor
|
||||||
|
});
|
||||||
|
let history_store = HistoryStore::get_or_init(project, cx);
|
||||||
|
|
||||||
|
let search_editor_subscription =
|
||||||
|
cx.subscribe(&search_editor, |this, search_editor, event, cx| {
|
||||||
|
if let EditorEvent::BufferEdited = event {
|
||||||
|
let query = search_editor.read(cx).text(cx);
|
||||||
|
this.search(query.into(), cx);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let history_store_subscription = cx.observe(&history_store, |this, _, cx| {
|
||||||
|
this.update_all_entries(cx);
|
||||||
|
});
|
||||||
|
|
||||||
|
let scroll_handle = UniformListScrollHandle::default();
|
||||||
|
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
|
||||||
|
|
||||||
|
let mut this = Self {
|
||||||
|
history_store,
|
||||||
|
scroll_handle,
|
||||||
|
selected_index: 0,
|
||||||
|
hovered_index: None,
|
||||||
|
search_state: SearchState::Empty,
|
||||||
|
all_entries: Default::default(),
|
||||||
|
separated_items: Default::default(),
|
||||||
|
separated_item_indexes: Default::default(),
|
||||||
|
search_editor,
|
||||||
|
scrollbar_visibility: true,
|
||||||
|
scrollbar_state,
|
||||||
|
local_timezone: UtcOffset::from_whole_seconds(
|
||||||
|
chrono::Local::now().offset().local_minus_utc(),
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
_subscriptions: vec![search_editor_subscription, history_store_subscription],
|
||||||
|
_separated_items_task: None,
|
||||||
|
};
|
||||||
|
this.update_all_entries(cx);
|
||||||
|
this
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_all_entries(&mut self, cx: &mut Context<Self>) {
|
||||||
|
let new_entries: Arc<Vec<HistoryEntry>> = self
|
||||||
|
.history_store
|
||||||
|
.update(cx, |store, cx| store.entries(cx))
|
||||||
|
.into();
|
||||||
|
|
||||||
|
self._separated_items_task.take();
|
||||||
|
|
||||||
|
let mut items = Vec::with_capacity(new_entries.len() + 1);
|
||||||
|
let mut indexes = Vec::with_capacity(new_entries.len() + 1);
|
||||||
|
|
||||||
|
let bg_task = cx.background_spawn(async move {
|
||||||
|
let mut bucket = None;
|
||||||
|
let today = Local::now().naive_local().date();
|
||||||
|
|
||||||
|
for (index, entry) in new_entries.iter().enumerate() {
|
||||||
|
let entry_date = entry
|
||||||
|
.updated_at()
|
||||||
|
.with_timezone(&Local)
|
||||||
|
.naive_local()
|
||||||
|
.date();
|
||||||
|
let entry_bucket = TimeBucket::from_dates(today, entry_date);
|
||||||
|
|
||||||
|
if Some(entry_bucket) != bucket {
|
||||||
|
bucket = Some(entry_bucket);
|
||||||
|
items.push(ListItemType::BucketSeparator(entry_bucket));
|
||||||
|
}
|
||||||
|
|
||||||
|
indexes.push(items.len() as u32);
|
||||||
|
items.push(ListItemType::Entry {
|
||||||
|
index,
|
||||||
|
format: entry_bucket.into(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(new_entries, items, indexes)
|
||||||
|
});
|
||||||
|
|
||||||
|
let task = cx.spawn(async move |this, cx| {
|
||||||
|
let (new_entries, items, indexes) = bg_task.await;
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
let previously_selected_entry =
|
||||||
|
this.all_entries.get(this.selected_index).map(|e| e.id());
|
||||||
|
|
||||||
|
this.all_entries = new_entries;
|
||||||
|
this.separated_items = items;
|
||||||
|
this.separated_item_indexes = indexes;
|
||||||
|
|
||||||
|
match &this.search_state {
|
||||||
|
SearchState::Empty => {
|
||||||
|
if this.selected_index >= this.all_entries.len() {
|
||||||
|
this.set_selected_entry_index(
|
||||||
|
this.all_entries.len().saturating_sub(1),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
} else if let Some(prev_id) = previously_selected_entry {
|
||||||
|
if let Some(new_ix) = this
|
||||||
|
.all_entries
|
||||||
|
.iter()
|
||||||
|
.position(|probe| probe.id() == prev_id)
|
||||||
|
{
|
||||||
|
this.set_selected_entry_index(new_ix, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {
|
||||||
|
this.search(query.clone(), cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
.log_err();
|
||||||
|
});
|
||||||
|
self._separated_items_task = Some(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search(&mut self, query: SharedString, cx: &mut Context<Self>) {
|
||||||
|
if query.is_empty() {
|
||||||
|
self.search_state = SearchState::Empty;
|
||||||
|
cx.notify();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let all_entries = self.all_entries.clone();
|
||||||
|
|
||||||
|
let fuzzy_search_task = cx.background_spawn({
|
||||||
|
let query = query.clone();
|
||||||
|
let executor = cx.background_executor().clone();
|
||||||
|
async move {
|
||||||
|
let mut candidates = Vec::with_capacity(all_entries.len());
|
||||||
|
|
||||||
|
for (idx, entry) in all_entries.iter().enumerate() {
|
||||||
|
match entry {
|
||||||
|
HistoryEntry::AcpThread(thread) => {
|
||||||
|
candidates.push(StringMatchCandidate::new(idx, &thread.title));
|
||||||
|
}
|
||||||
|
HistoryEntry::TextThread(context) => {
|
||||||
|
candidates.push(StringMatchCandidate::new(idx, &context.title));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const MAX_MATCHES: usize = 100;
|
||||||
|
|
||||||
|
fuzzy::match_strings(
|
||||||
|
&candidates,
|
||||||
|
&query,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
MAX_MATCHES,
|
||||||
|
&Default::default(),
|
||||||
|
executor,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let task = cx.spawn({
|
||||||
|
let query = query.clone();
|
||||||
|
async move |this, cx| {
|
||||||
|
let matches = fuzzy_search_task.await;
|
||||||
|
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
let SearchState::Searching {
|
||||||
|
query: current_query,
|
||||||
|
_task,
|
||||||
|
} = &this.search_state
|
||||||
|
else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
if &query == current_query {
|
||||||
|
this.search_state = SearchState::Searched {
|
||||||
|
query: query.clone(),
|
||||||
|
matches,
|
||||||
|
};
|
||||||
|
|
||||||
|
this.set_selected_entry_index(0, cx);
|
||||||
|
cx.notify();
|
||||||
|
};
|
||||||
|
})
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
self.search_state = SearchState::Searching { query, _task: task };
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matched_count(&self) -> usize {
|
||||||
|
match &self.search_state {
|
||||||
|
SearchState::Empty => self.all_entries.len(),
|
||||||
|
SearchState::Searching { .. } => 0,
|
||||||
|
SearchState::Searched { matches, .. } => matches.len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_item_count(&self) -> usize {
|
||||||
|
match &self.search_state {
|
||||||
|
SearchState::Empty => self.separated_items.len(),
|
||||||
|
SearchState::Searching { .. } => 0,
|
||||||
|
SearchState::Searched { matches, .. } => matches.len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search_produced_no_matches(&self) -> bool {
|
||||||
|
match &self.search_state {
|
||||||
|
SearchState::Empty => false,
|
||||||
|
SearchState::Searching { .. } => false,
|
||||||
|
SearchState::Searched { matches, .. } => matches.is_empty(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_match(&self, ix: usize) -> Option<&HistoryEntry> {
|
||||||
|
match &self.search_state {
|
||||||
|
SearchState::Empty => self.all_entries.get(ix),
|
||||||
|
SearchState::Searching { .. } => None,
|
||||||
|
SearchState::Searched { matches, .. } => matches
|
||||||
|
.get(ix)
|
||||||
|
.and_then(|m| self.all_entries.get(m.candidate_id)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn select_previous(
|
||||||
|
&mut self,
|
||||||
|
_: &menu::SelectPrevious,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let count = self.matched_count();
|
||||||
|
if count > 0 {
|
||||||
|
if self.selected_index == 0 {
|
||||||
|
self.set_selected_entry_index(count - 1, cx);
|
||||||
|
} else {
|
||||||
|
self.set_selected_entry_index(self.selected_index - 1, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn select_next(
|
||||||
|
&mut self,
|
||||||
|
_: &menu::SelectNext,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let count = self.matched_count();
|
||||||
|
if count > 0 {
|
||||||
|
if self.selected_index == count - 1 {
|
||||||
|
self.set_selected_entry_index(0, cx);
|
||||||
|
} else {
|
||||||
|
self.set_selected_entry_index(self.selected_index + 1, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn select_first(
|
||||||
|
&mut self,
|
||||||
|
_: &menu::SelectFirst,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let count = self.matched_count();
|
||||||
|
if count > 0 {
|
||||||
|
self.set_selected_entry_index(0, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
|
||||||
|
let count = self.matched_count();
|
||||||
|
if count > 0 {
|
||||||
|
self.set_selected_entry_index(count - 1, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_selected_entry_index(&mut self, entry_index: usize, cx: &mut Context<Self>) {
|
||||||
|
self.selected_index = entry_index;
|
||||||
|
|
||||||
|
let scroll_ix = match self.search_state {
|
||||||
|
SearchState::Empty | SearchState::Searching { .. } => self
|
||||||
|
.separated_item_indexes
|
||||||
|
.get(entry_index)
|
||||||
|
.map(|ix| *ix as usize)
|
||||||
|
.unwrap_or(entry_index + 1),
|
||||||
|
SearchState::Searched { .. } => entry_index,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.scroll_handle
|
||||||
|
.scroll_to_item(scroll_ix, ScrollStrategy::Top);
|
||||||
|
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
|
||||||
|
if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(
|
||||||
|
div()
|
||||||
|
.occlude()
|
||||||
|
.id("thread-history-scroll")
|
||||||
|
.h_full()
|
||||||
|
.bg(cx.theme().colors().panel_background.opacity(0.8))
|
||||||
|
.border_l_1()
|
||||||
|
.border_color(cx.theme().colors().border_variant)
|
||||||
|
.absolute()
|
||||||
|
.right_1()
|
||||||
|
.top_0()
|
||||||
|
.bottom_0()
|
||||||
|
.w_4()
|
||||||
|
.pl_1()
|
||||||
|
.cursor_default()
|
||||||
|
.on_mouse_move(cx.listener(|_, _, _window, cx| {
|
||||||
|
cx.notify();
|
||||||
|
cx.stop_propagation()
|
||||||
|
}))
|
||||||
|
.on_hover(|_, _window, cx| {
|
||||||
|
cx.stop_propagation();
|
||||||
|
})
|
||||||
|
.on_any_mouse_down(|_, _window, cx| {
|
||||||
|
cx.stop_propagation();
|
||||||
|
})
|
||||||
|
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
|
||||||
|
cx.notify();
|
||||||
|
}))
|
||||||
|
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
|
||||||
|
self.confirm_entry(self.selected_index, cx);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn confirm_entry(&mut self, ix: usize, cx: &mut Context<Self>) {
|
||||||
|
let Some(entry) = self.get_match(ix) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
cx.emit(ThreadHistoryEvent::Open(entry.clone()));
|
||||||
|
// let task_result = match entry {
|
||||||
|
// HistoryEntry::Thread(thread) => {
|
||||||
|
// self.agent_panel.update(cx, move |agent_panel, cx| todo!())
|
||||||
|
// }
|
||||||
|
// HistoryEntry::Context(context) => {
|
||||||
|
// self.agent_panel.update(cx, move |agent_panel, cx| {
|
||||||
|
// agent_panel.open_saved_prompt_editor(context.path.clone(), window, cx)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
|
||||||
|
// if let Some(task) = task_result.log_err() {
|
||||||
|
// task.detach_and_log_err(cx);
|
||||||
|
// };
|
||||||
|
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_selected_thread(
|
||||||
|
&mut self,
|
||||||
|
_: &RemoveSelectedThread,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
self.remove_thread(self.selected_index, cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_thread(&mut self, ix: usize, cx: &mut Context<Self>) {
|
||||||
|
let Some(entry) = self.get_match(ix) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
todo!();
|
||||||
|
// let task_result = match entry {
|
||||||
|
// HistoryEntry::Thread(thread) => todo!(),
|
||||||
|
// HistoryEntry::Context(context) => self
|
||||||
|
// .agent_panel
|
||||||
|
// .update(cx, |this, cx| this.delete_context(context.path.clone(), cx)),
|
||||||
|
// };
|
||||||
|
|
||||||
|
// if let Some(task) = task_result.log_err() {
|
||||||
|
// task.detach_and_log_err(cx);
|
||||||
|
// };
|
||||||
|
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_items(
|
||||||
|
&mut self,
|
||||||
|
range: Range<usize>,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Vec<AnyElement> {
|
||||||
|
match &self.search_state {
|
||||||
|
SearchState::Empty => self
|
||||||
|
.separated_items
|
||||||
|
.get(range)
|
||||||
|
.iter()
|
||||||
|
.flat_map(|items| {
|
||||||
|
items
|
||||||
|
.iter()
|
||||||
|
.map(|item| self.render_list_item(item, vec![], cx))
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
SearchState::Searched { matches, .. } => matches[range]
|
||||||
|
.iter()
|
||||||
|
.filter_map(|m| {
|
||||||
|
let entry = self.all_entries.get(m.candidate_id)?;
|
||||||
|
Some(self.render_history_entry(
|
||||||
|
entry,
|
||||||
|
EntryTimeFormat::DateAndTime,
|
||||||
|
m.candidate_id,
|
||||||
|
m.positions.clone(),
|
||||||
|
cx,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
SearchState::Searching { .. } => {
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_list_item(
|
||||||
|
&self,
|
||||||
|
item: &ListItemType,
|
||||||
|
highlight_positions: Vec<usize>,
|
||||||
|
cx: &Context<Self>,
|
||||||
|
) -> AnyElement {
|
||||||
|
match item {
|
||||||
|
ListItemType::Entry { index, format } => match self.all_entries.get(*index) {
|
||||||
|
Some(entry) => self
|
||||||
|
.render_history_entry(entry, *format, *index, highlight_positions, cx)
|
||||||
|
.into_any(),
|
||||||
|
None => Empty.into_any_element(),
|
||||||
|
},
|
||||||
|
ListItemType::BucketSeparator(bucket) => div()
|
||||||
|
.px(DynamicSpacing::Base06.rems(cx))
|
||||||
|
.pt_2()
|
||||||
|
.pb_1()
|
||||||
|
.child(
|
||||||
|
Label::new(bucket.to_string())
|
||||||
|
.size(LabelSize::XSmall)
|
||||||
|
.color(Color::Muted),
|
||||||
|
)
|
||||||
|
.into_any_element(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_history_entry(
|
||||||
|
&self,
|
||||||
|
entry: &HistoryEntry,
|
||||||
|
format: EntryTimeFormat,
|
||||||
|
list_entry_ix: usize,
|
||||||
|
highlight_positions: Vec<usize>,
|
||||||
|
cx: &Context<Self>,
|
||||||
|
) -> AnyElement {
|
||||||
|
let selected = list_entry_ix == self.selected_index;
|
||||||
|
let hovered = Some(list_entry_ix) == self.hovered_index;
|
||||||
|
let timestamp = entry.updated_at().timestamp();
|
||||||
|
let thread_timestamp = format.format_timestamp(timestamp, self.local_timezone);
|
||||||
|
|
||||||
|
h_flex()
|
||||||
|
.w_full()
|
||||||
|
.pb_1()
|
||||||
|
.child(
|
||||||
|
ListItem::new(list_entry_ix)
|
||||||
|
.rounded()
|
||||||
|
.toggle_state(selected)
|
||||||
|
.spacing(ListItemSpacing::Sparse)
|
||||||
|
.start_slot(
|
||||||
|
h_flex()
|
||||||
|
.w_full()
|
||||||
|
.gap_2()
|
||||||
|
.justify_between()
|
||||||
|
.child(
|
||||||
|
HighlightedLabel::new(entry.title(), highlight_positions)
|
||||||
|
.size(LabelSize::Small)
|
||||||
|
.truncate(),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
Label::new(thread_timestamp)
|
||||||
|
.color(Color::Muted)
|
||||||
|
.size(LabelSize::XSmall),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.on_hover(cx.listener(move |this, is_hovered, _window, cx| {
|
||||||
|
if *is_hovered {
|
||||||
|
this.hovered_index = Some(list_entry_ix);
|
||||||
|
} else if this.hovered_index == Some(list_entry_ix) {
|
||||||
|
this.hovered_index = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
cx.notify();
|
||||||
|
}))
|
||||||
|
.end_slot::<IconButton>(if hovered || selected {
|
||||||
|
Some(
|
||||||
|
IconButton::new("delete", IconName::Trash)
|
||||||
|
.shape(IconButtonShape::Square)
|
||||||
|
.icon_size(IconSize::XSmall)
|
||||||
|
.icon_color(Color::Muted)
|
||||||
|
.tooltip(move |window, cx| {
|
||||||
|
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
|
||||||
|
})
|
||||||
|
.on_click(cx.listener(move |this, _, _, cx| {
|
||||||
|
this.remove_thread(list_entry_ix, cx)
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
})
|
||||||
|
.on_click(
|
||||||
|
cx.listener(move |this, _, _, cx| this.confirm_entry(list_entry_ix, cx)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_any_element()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Focusable for AcpThreadHistory {
|
||||||
|
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||||
|
self.search_editor.focus_handle(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Render for AcpThreadHistory {
|
||||||
|
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||||
|
v_flex()
|
||||||
|
.key_context("ThreadHistory")
|
||||||
|
.size_full()
|
||||||
|
.on_action(cx.listener(Self::select_previous))
|
||||||
|
.on_action(cx.listener(Self::select_next))
|
||||||
|
.on_action(cx.listener(Self::select_first))
|
||||||
|
.on_action(cx.listener(Self::select_last))
|
||||||
|
.on_action(cx.listener(Self::confirm))
|
||||||
|
.on_action(cx.listener(Self::remove_selected_thread))
|
||||||
|
.when(!self.all_entries.is_empty(), |parent| {
|
||||||
|
parent.child(
|
||||||
|
h_flex()
|
||||||
|
.h(px(41.)) // Match the toolbar perfectly
|
||||||
|
.w_full()
|
||||||
|
.py_1()
|
||||||
|
.px_2()
|
||||||
|
.gap_2()
|
||||||
|
.justify_between()
|
||||||
|
.border_b_1()
|
||||||
|
.border_color(cx.theme().colors().border)
|
||||||
|
.child(
|
||||||
|
Icon::new(IconName::MagnifyingGlass)
|
||||||
|
.color(Color::Muted)
|
||||||
|
.size(IconSize::Small),
|
||||||
|
)
|
||||||
|
.child(self.search_editor.clone()),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.child({
|
||||||
|
let view = v_flex()
|
||||||
|
.id("list-container")
|
||||||
|
.relative()
|
||||||
|
.overflow_hidden()
|
||||||
|
.flex_grow();
|
||||||
|
|
||||||
|
if self.all_entries.is_empty() {
|
||||||
|
view.justify_center()
|
||||||
|
.child(
|
||||||
|
h_flex().w_full().justify_center().child(
|
||||||
|
Label::new("You don't have any past threads yet.")
|
||||||
|
.size(LabelSize::Small),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
} else if self.search_produced_no_matches() {
|
||||||
|
view.justify_center().child(
|
||||||
|
h_flex().w_full().justify_center().child(
|
||||||
|
Label::new("No threads match your search.").size(LabelSize::Small),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
view.pr_5()
|
||||||
|
.child(
|
||||||
|
uniform_list(
|
||||||
|
"thread-history",
|
||||||
|
self.list_item_count(),
|
||||||
|
cx.processor(|this, range: Range<usize>, window, cx| {
|
||||||
|
this.list_items(range, window, cx)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.p_1()
|
||||||
|
.track_scroll(self.scroll_handle.clone())
|
||||||
|
.flex_grow(),
|
||||||
|
)
|
||||||
|
.when_some(self.render_scrollbar(cx), |div, scrollbar| {
|
||||||
|
div.child(scrollbar)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub enum EntryTimeFormat {
|
||||||
|
DateAndTime,
|
||||||
|
TimeOnly,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EntryTimeFormat {
|
||||||
|
fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String {
|
||||||
|
let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap();
|
||||||
|
|
||||||
|
match self {
|
||||||
|
EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp(
|
||||||
|
timestamp,
|
||||||
|
OffsetDateTime::now_utc(),
|
||||||
|
timezone,
|
||||||
|
time_format::TimestampFormat::EnhancedAbsolute,
|
||||||
|
),
|
||||||
|
EntryTimeFormat::TimeOnly => time_format::format_time(timestamp),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<TimeBucket> for EntryTimeFormat {
|
||||||
|
fn from(bucket: TimeBucket) -> Self {
|
||||||
|
match bucket {
|
||||||
|
TimeBucket::Today => EntryTimeFormat::TimeOnly,
|
||||||
|
TimeBucket::Yesterday => EntryTimeFormat::TimeOnly,
|
||||||
|
TimeBucket::ThisWeek => EntryTimeFormat::DateAndTime,
|
||||||
|
TimeBucket::PastWeek => EntryTimeFormat::DateAndTime,
|
||||||
|
TimeBucket::All => EntryTimeFormat::DateAndTime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
|
||||||
|
enum TimeBucket {
|
||||||
|
Today,
|
||||||
|
Yesterday,
|
||||||
|
ThisWeek,
|
||||||
|
PastWeek,
|
||||||
|
All,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TimeBucket {
|
||||||
|
fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self {
|
||||||
|
if date == reference {
|
||||||
|
return TimeBucket::Today;
|
||||||
|
}
|
||||||
|
|
||||||
|
if date == reference - TimeDelta::days(1) {
|
||||||
|
return TimeBucket::Yesterday;
|
||||||
|
}
|
||||||
|
|
||||||
|
let week = date.iso_week();
|
||||||
|
|
||||||
|
if reference.iso_week() == week {
|
||||||
|
return TimeBucket::ThisWeek;
|
||||||
|
}
|
||||||
|
|
||||||
|
let last_week = (reference - TimeDelta::days(7)).iso_week();
|
||||||
|
|
||||||
|
if week == last_week {
|
||||||
|
return TimeBucket::PastWeek;
|
||||||
|
}
|
||||||
|
|
||||||
|
TimeBucket::All
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for TimeBucket {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TimeBucket::Today => write!(f, "Today"),
|
||||||
|
TimeBucket::Yesterday => write!(f, "Yesterday"),
|
||||||
|
TimeBucket::ThisWeek => write!(f, "This Week"),
|
||||||
|
TimeBucket::PastWeek => write!(f, "Past Week"),
|
||||||
|
TimeBucket::All => write!(f, "All"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use chrono::NaiveDate;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_time_bucket_from_dates() {
|
||||||
|
let today = NaiveDate::from_ymd_opt(2023, 1, 15).unwrap();
|
||||||
|
|
||||||
|
let date = today;
|
||||||
|
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Today);
|
||||||
|
|
||||||
|
let date = NaiveDate::from_ymd_opt(2023, 1, 14).unwrap();
|
||||||
|
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Yesterday);
|
||||||
|
|
||||||
|
let date = NaiveDate::from_ymd_opt(2023, 1, 13).unwrap();
|
||||||
|
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek);
|
||||||
|
|
||||||
|
let date = NaiveDate::from_ymd_opt(2023, 1, 11).unwrap();
|
||||||
|
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek);
|
||||||
|
|
||||||
|
let date = NaiveDate::from_ymd_opt(2023, 1, 8).unwrap();
|
||||||
|
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek);
|
||||||
|
|
||||||
|
let date = NaiveDate::from_ymd_opt(2023, 1, 5).unwrap();
|
||||||
|
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek);
|
||||||
|
|
||||||
|
// All: not in this week or last week
|
||||||
|
let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap();
|
||||||
|
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::All);
|
||||||
|
|
||||||
|
// Test year boundary cases
|
||||||
|
let new_year = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap();
|
||||||
|
|
||||||
|
let date = NaiveDate::from_ymd_opt(2022, 12, 31).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
TimeBucket::from_dates(new_year, date),
|
||||||
|
TimeBucket::Yesterday
|
||||||
|
);
|
||||||
|
|
||||||
|
let date = NaiveDate::from_ymd_opt(2022, 12, 28).unwrap();
|
||||||
|
assert_eq!(TimeBucket::from_dates(new_year, date), TimeBucket::ThisWeek);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
use acp_thread::{
|
use acp_thread::{
|
||||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
AcpThread, AcpThreadEvent, AcpThreadMetadata, AgentThreadEntry, AssistantMessage,
|
||||||
LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId,
|
AssistantMessageChunk, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent,
|
||||||
|
ToolCallStatus, UserMessageId,
|
||||||
};
|
};
|
||||||
use acp_thread::{AgentConnection, Plan};
|
use acp_thread::{AgentConnection, Plan};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
|
@ -17,6 +18,7 @@ use editor::scroll::Autoscroll;
|
||||||
use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
|
use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
|
||||||
use file_icons::FileIcons;
|
use file_icons::FileIcons;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
|
use futures::StreamExt;
|
||||||
use gpui::{
|
use gpui::{
|
||||||
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement,
|
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement,
|
||||||
Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
|
Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
|
||||||
|
@ -122,6 +124,7 @@ pub struct AcpThreadView {
|
||||||
editor_expanded: bool,
|
editor_expanded: bool,
|
||||||
terminal_expanded: bool,
|
terminal_expanded: bool,
|
||||||
editing_message: Option<usize>,
|
editing_message: Option<usize>,
|
||||||
|
history_store: Entity<agent2::HistoryStore>,
|
||||||
_cancel_task: Option<Task<()>>,
|
_cancel_task: Option<Task<()>>,
|
||||||
_subscriptions: [Subscription; 3],
|
_subscriptions: [Subscription; 3],
|
||||||
}
|
}
|
||||||
|
@ -133,6 +136,7 @@ enum ThreadState {
|
||||||
Ready {
|
Ready {
|
||||||
thread: Entity<AcpThread>,
|
thread: Entity<AcpThread>,
|
||||||
_subscription: [Subscription; 2],
|
_subscription: [Subscription; 2],
|
||||||
|
_history_task: Option<Task<()>>,
|
||||||
},
|
},
|
||||||
LoadError(LoadError),
|
LoadError(LoadError),
|
||||||
Unauthenticated {
|
Unauthenticated {
|
||||||
|
@ -148,8 +152,10 @@ impl AcpThreadView {
|
||||||
agent: Rc<dyn AgentServer>,
|
agent: Rc<dyn AgentServer>,
|
||||||
workspace: WeakEntity<Workspace>,
|
workspace: WeakEntity<Workspace>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
history_store: Entity<agent2::HistoryStore>,
|
||||||
thread_store: Entity<ThreadStore>,
|
thread_store: Entity<ThreadStore>,
|
||||||
text_thread_store: Entity<TextThreadStore>,
|
text_thread_store: Entity<TextThreadStore>,
|
||||||
|
restore_thread: Option<AcpThreadMetadata>,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -191,7 +197,16 @@ impl AcpThreadView {
|
||||||
workspace: workspace.clone(),
|
workspace: workspace.clone(),
|
||||||
project: project.clone(),
|
project: project.clone(),
|
||||||
entry_view_state,
|
entry_view_state,
|
||||||
thread_state: Self::initial_state(agent, workspace, project, window, cx),
|
thread_state: Self::initial_state(
|
||||||
|
agent,
|
||||||
|
restore_thread,
|
||||||
|
history_store.clone(),
|
||||||
|
workspace,
|
||||||
|
project,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
),
|
||||||
|
history_store,
|
||||||
message_editor,
|
message_editor,
|
||||||
model_selector: None,
|
model_selector: None,
|
||||||
profile_selector: None,
|
profile_selector: None,
|
||||||
|
@ -215,6 +230,8 @@ impl AcpThreadView {
|
||||||
|
|
||||||
fn initial_state(
|
fn initial_state(
|
||||||
agent: Rc<dyn AgentServer>,
|
agent: Rc<dyn AgentServer>,
|
||||||
|
restore_thread: Option<AcpThreadMetadata>,
|
||||||
|
history_store: Entity<agent2::HistoryStore>,
|
||||||
workspace: WeakEntity<Workspace>,
|
workspace: WeakEntity<Workspace>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
|
@ -241,6 +258,25 @@ impl AcpThreadView {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut history_task = None;
|
||||||
|
let history = connection.clone().history();
|
||||||
|
if let Some(history) = history.clone() {
|
||||||
|
if let Some(mut history) = cx.update(|_, cx| history.observe_history(cx)).ok() {
|
||||||
|
history_task = Some(cx.spawn(async move |cx| {
|
||||||
|
while let Some(update) = history.next().await {
|
||||||
|
if !history_store
|
||||||
|
.update(cx, |history_store, cx| {
|
||||||
|
history_store.update_history(update, cx)
|
||||||
|
})
|
||||||
|
.is_ok()
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// this.update_in(cx, |_this, _window, cx| {
|
// this.update_in(cx, |_this, _window, cx| {
|
||||||
// let status = connection.exit_status(cx);
|
// let status = connection.exit_status(cx);
|
||||||
// cx.spawn(async move |this, cx| {
|
// cx.spawn(async move |this, cx| {
|
||||||
|
@ -254,19 +290,24 @@ impl AcpThreadView {
|
||||||
// .detach();
|
// .detach();
|
||||||
// })
|
// })
|
||||||
// .ok();
|
// .ok();
|
||||||
|
let history = connection.clone().history();
|
||||||
let Some(result) = cx
|
let task = cx.update(|_, cx| {
|
||||||
.update(|_, cx| {
|
if let Some(restore_thread) = restore_thread
|
||||||
|
&& let Some(history) = history
|
||||||
|
{
|
||||||
|
history.load_thread(project.clone(), &root_dir, restore_thread.id, cx)
|
||||||
|
} else {
|
||||||
connection
|
connection
|
||||||
.clone()
|
.clone()
|
||||||
.new_thread(project.clone(), &root_dir, cx)
|
.new_thread(project.clone(), &root_dir, cx)
|
||||||
})
|
}
|
||||||
.log_err()
|
});
|
||||||
else {
|
|
||||||
|
let Ok(task) = task else {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = match result.await {
|
let result = match task.await {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let mut cx = cx.clone();
|
let mut cx = cx.clone();
|
||||||
if e.is::<acp_thread::AuthRequired>() {
|
if e.is::<acp_thread::AuthRequired>() {
|
||||||
|
@ -293,8 +334,13 @@ impl AcpThreadView {
|
||||||
let action_log_subscription =
|
let action_log_subscription =
|
||||||
cx.observe(&action_log, |_, _, cx| cx.notify());
|
cx.observe(&action_log, |_, _, cx| cx.notify());
|
||||||
|
|
||||||
this.list_state
|
let count = thread.read(cx).entries().len();
|
||||||
.splice(0..0, thread.read(cx).entries().len());
|
this.list_state.splice(0..0, count);
|
||||||
|
this.entry_view_state.update(cx, |view_state, cx| {
|
||||||
|
for ix in 0..count {
|
||||||
|
view_state.sync_entry(ix, &thread, window, cx);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
|
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
|
||||||
|
|
||||||
|
@ -319,6 +365,7 @@ impl AcpThreadView {
|
||||||
this.thread_state = ThreadState::Ready {
|
this.thread_state = ThreadState::Ready {
|
||||||
thread,
|
thread,
|
||||||
_subscription: [thread_subscription, action_log_subscription],
|
_subscription: [thread_subscription, action_log_subscription],
|
||||||
|
_history_task: history_task,
|
||||||
};
|
};
|
||||||
|
|
||||||
this.profile_selector = this.as_native_thread(cx).map(|thread| {
|
this.profile_selector = this.as_native_thread(cx).map(|thread| {
|
||||||
|
@ -698,6 +745,7 @@ impl AcpThreadView {
|
||||||
AcpThreadEvent::ServerExited(status) => {
|
AcpThreadEvent::ServerExited(status) => {
|
||||||
self.thread_state = ThreadState::ServerExited { status: *status };
|
self.thread_state = ThreadState::ServerExited { status: *status };
|
||||||
}
|
}
|
||||||
|
AcpThreadEvent::TitleUpdated => {}
|
||||||
}
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
@ -726,6 +774,8 @@ impl AcpThreadView {
|
||||||
} else {
|
} else {
|
||||||
this.thread_state = Self::initial_state(
|
this.thread_state = Self::initial_state(
|
||||||
agent,
|
agent,
|
||||||
|
None, // todo!()
|
||||||
|
this.history_store.clone(),
|
||||||
this.workspace.clone(),
|
this.workspace.clone(),
|
||||||
project.clone(),
|
project.clone(),
|
||||||
window,
|
window,
|
||||||
|
@ -2546,12 +2596,15 @@ impl AcpThreadView {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
let current_mode = thread.completion_mode();
|
let current_mode = thread.completion_mode();
|
||||||
thread.set_completion_mode(match current_mode {
|
thread.set_completion_mode(
|
||||||
CompletionMode::Burn => CompletionMode::Normal,
|
match current_mode {
|
||||||
CompletionMode::Normal => CompletionMode::Burn,
|
CompletionMode::Burn => CompletionMode::Normal,
|
||||||
});
|
CompletionMode::Normal => CompletionMode::Burn,
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3265,8 +3318,8 @@ impl AcpThreadView {
|
||||||
.tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
|
.tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
|
||||||
.on_click({
|
.on_click({
|
||||||
cx.listener(move |this, _, _window, cx| {
|
cx.listener(move |this, _, _window, cx| {
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(CompletionMode::Burn);
|
thread.set_completion_mode(CompletionMode::Burn, cx);
|
||||||
});
|
});
|
||||||
this.resume_chat(cx);
|
this.resume_chat(cx);
|
||||||
})
|
})
|
||||||
|
@ -3587,7 +3640,7 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) mod tests {
|
pub(crate) mod tests {
|
||||||
use acp_thread::StubAgentConnection;
|
use acp_thread::{AgentServerName, StubAgentConnection};
|
||||||
use agent::{TextThreadStore, ThreadStore};
|
use agent::{TextThreadStore, ThreadStore};
|
||||||
use agent_client_protocol::SessionId;
|
use agent_client_protocol::SessionId;
|
||||||
use editor::EditorSettings;
|
use editor::EditorSettings;
|
||||||
|
@ -3727,6 +3780,8 @@ pub(crate) mod tests {
|
||||||
cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
|
cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
|
||||||
let text_thread_store =
|
let text_thread_store =
|
||||||
cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
|
cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
|
||||||
|
let history_store =
|
||||||
|
cx.update(|_window, cx| cx.new(|cx| agent2::HistoryStore::get_or_init(cx)));
|
||||||
|
|
||||||
let thread_view = cx.update(|window, cx| {
|
let thread_view = cx.update(|window, cx| {
|
||||||
cx.new(|cx| {
|
cx.new(|cx| {
|
||||||
|
@ -3734,8 +3789,10 @@ pub(crate) mod tests {
|
||||||
Rc::new(agent),
|
Rc::new(agent),
|
||||||
workspace.downgrade(),
|
workspace.downgrade(),
|
||||||
project,
|
project,
|
||||||
|
history_store.clone(),
|
||||||
thread_store.clone(),
|
thread_store.clone(),
|
||||||
text_thread_store.clone(),
|
text_thread_store.clone(),
|
||||||
|
None,
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -3817,8 +3874,8 @@ pub(crate) mod tests {
|
||||||
ui::IconName::Ai
|
ui::IconName::Ai
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> AgentServerName {
|
||||||
"Test"
|
AgentServerName("Test".into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn empty_state_headline(&self) -> &'static str {
|
fn empty_state_headline(&self) -> &'static str {
|
||||||
|
@ -3925,6 +3982,8 @@ pub(crate) mod tests {
|
||||||
cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
|
cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
|
||||||
let text_thread_store =
|
let text_thread_store =
|
||||||
cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
|
cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
|
||||||
|
let history_store =
|
||||||
|
cx.update(|_window, cx| cx.new(|cx| agent2::HistoryStore::get_or_init(cx)));
|
||||||
|
|
||||||
let connection = Rc::new(StubAgentConnection::new());
|
let connection = Rc::new(StubAgentConnection::new());
|
||||||
let thread_view = cx.update(|window, cx| {
|
let thread_view = cx.update(|window, cx| {
|
||||||
|
@ -3933,8 +3992,10 @@ pub(crate) mod tests {
|
||||||
Rc::new(StubAgentServer::new(connection.as_ref().clone())),
|
Rc::new(StubAgentServer::new(connection.as_ref().clone())),
|
||||||
workspace.downgrade(),
|
workspace.downgrade(),
|
||||||
project.clone(),
|
project.clone(),
|
||||||
|
history_store,
|
||||||
thread_store.clone(),
|
thread_store.clone(),
|
||||||
text_thread_store.clone(),
|
text_thread_store.clone(),
|
||||||
|
None,
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
|
|
@ -199,24 +199,21 @@ impl AgentDiffPane {
|
||||||
let action_log = thread.action_log(cx).clone();
|
let action_log = thread.action_log(cx).clone();
|
||||||
|
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
_subscriptions: [
|
_subscriptions: vec![
|
||||||
Some(
|
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
|
||||||
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
|
this.update_excerpts(window, cx)
|
||||||
this.update_excerpts(window, cx)
|
}),
|
||||||
}),
|
|
||||||
),
|
|
||||||
match &thread {
|
match &thread {
|
||||||
AgentDiffThread::Native(thread) => {
|
AgentDiffThread::Native(thread) => cx
|
||||||
Some(cx.subscribe(&thread, |this, _thread, event, cx| {
|
.subscribe(&thread, |this, _thread, event, cx| {
|
||||||
this.handle_thread_event(event, cx)
|
this.handle_native_thread_event(event, cx)
|
||||||
}))
|
}),
|
||||||
}
|
AgentDiffThread::AcpThread(thread) => cx
|
||||||
AgentDiffThread::AcpThread(_) => None,
|
.subscribe(&thread, |this, _thread, event, cx| {
|
||||||
|
this.handle_acp_thread_event(event, cx)
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
]
|
],
|
||||||
.into_iter()
|
|
||||||
.flatten()
|
|
||||||
.collect(),
|
|
||||||
title: SharedString::default(),
|
title: SharedString::default(),
|
||||||
multibuffer,
|
multibuffer,
|
||||||
editor,
|
editor,
|
||||||
|
@ -324,13 +321,20 @@ impl AgentDiffPane {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
|
fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
|
||||||
match event {
|
match event {
|
||||||
ThreadEvent::SummaryGenerated => self.update_title(cx),
|
ThreadEvent::SummaryGenerated => self.update_title(cx),
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
|
||||||
|
match event {
|
||||||
|
AcpThreadEvent::TitleUpdated => self.update_title(cx),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
|
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
|
||||||
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
|
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
|
||||||
self.editor.update(cx, |editor, cx| {
|
self.editor.update(cx, |editor, cx| {
|
||||||
|
@ -1521,7 +1525,8 @@ impl AgentDiff {
|
||||||
self.update_reviewing_editors(workspace, window, cx);
|
self.update_reviewing_editors(workspace, window, cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AcpThreadEvent::EntriesRemoved(_)
|
AcpThreadEvent::TitleUpdated
|
||||||
|
| AcpThreadEvent::EntriesRemoved(_)
|
||||||
| AcpThreadEvent::Stopped
|
| AcpThreadEvent::Stopped
|
||||||
| AcpThreadEvent::ToolAuthorizationRequired
|
| AcpThreadEvent::ToolAuthorizationRequired
|
||||||
| AcpThreadEvent::Error
|
| AcpThreadEvent::Error
|
||||||
|
|
|
@ -4,11 +4,13 @@ use std::rc::Rc;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use acp_thread::AcpThreadMetadata;
|
||||||
use agent_servers::AgentServer;
|
use agent_servers::AgentServer;
|
||||||
|
use agent2::HistoryEntry;
|
||||||
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::NewExternalAgentThread;
|
use crate::acp::{AcpThreadHistory, ThreadHistoryEvent};
|
||||||
use crate::agent_diff::AgentDiffThread;
|
use crate::agent_diff::AgentDiffThread;
|
||||||
use crate::{
|
use crate::{
|
||||||
AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode,
|
AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode,
|
||||||
|
@ -29,6 +31,7 @@ use crate::{
|
||||||
thread_history::{HistoryEntryElement, ThreadHistory},
|
thread_history::{HistoryEntryElement, ThreadHistory},
|
||||||
ui::{AgentOnboardingModal, EndTrialUpsell},
|
ui::{AgentOnboardingModal, EndTrialUpsell},
|
||||||
};
|
};
|
||||||
|
use crate::{ExternalAgent, NewExternalAgentThread};
|
||||||
use agent::{
|
use agent::{
|
||||||
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio,
|
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio,
|
||||||
context_store::ContextStore,
|
context_store::ContextStore,
|
||||||
|
@ -119,7 +122,7 @@ pub fn init(cx: &mut App) {
|
||||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||||
panel.update(cx, |panel, cx| {
|
panel.update(cx, |panel, cx| {
|
||||||
panel.new_external_thread(action.agent, window, cx)
|
panel.new_external_thread(action.agent, None, window, cx)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -478,6 +481,7 @@ pub struct AgentPanel {
|
||||||
previous_view: Option<ActiveView>,
|
previous_view: Option<ActiveView>,
|
||||||
history_store: Entity<HistoryStore>,
|
history_store: Entity<HistoryStore>,
|
||||||
history: Entity<ThreadHistory>,
|
history: Entity<ThreadHistory>,
|
||||||
|
acp_history: Entity<AcpThreadHistory>,
|
||||||
hovered_recent_history_item: Option<usize>,
|
hovered_recent_history_item: Option<usize>,
|
||||||
new_thread_menu_handle: PopoverMenuHandle<ContextMenu>,
|
new_thread_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||||
agent_panel_menu_handle: PopoverMenuHandle<ContextMenu>,
|
agent_panel_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||||
|
@ -744,6 +748,27 @@ impl AgentPanel {
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let acp_history = cx.new(|cx| AcpThreadHistory::new(&project, window, cx));
|
||||||
|
cx.subscribe_in(
|
||||||
|
&acp_history,
|
||||||
|
window,
|
||||||
|
|this, _, event, window, cx| match event {
|
||||||
|
ThreadHistoryEvent::Open(HistoryEntry::AcpThread(thread)) => {
|
||||||
|
let agent_choice = match thread.agent.0.as_ref() {
|
||||||
|
"Claude Code" => Some(ExternalAgent::ClaudeCode),
|
||||||
|
"Gemini" => Some(ExternalAgent::Gemini),
|
||||||
|
"Native Agent" => Some(ExternalAgent::NativeAgent),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
this.new_external_thread(agent_choice, Some(thread.clone()), window, cx);
|
||||||
|
}
|
||||||
|
ThreadHistoryEvent::Open(HistoryEntry::TextThread(thread)) => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.detach();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
active_view,
|
active_view,
|
||||||
workspace,
|
workspace,
|
||||||
|
@ -765,6 +790,7 @@ impl AgentPanel {
|
||||||
previous_view: None,
|
previous_view: None,
|
||||||
history_store: history_store.clone(),
|
history_store: history_store.clone(),
|
||||||
history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)),
|
history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)),
|
||||||
|
acp_history,
|
||||||
hovered_recent_history_item: None,
|
hovered_recent_history_item: None,
|
||||||
new_thread_menu_handle: PopoverMenuHandle::default(),
|
new_thread_menu_handle: PopoverMenuHandle::default(),
|
||||||
agent_panel_menu_handle: PopoverMenuHandle::default(),
|
agent_panel_menu_handle: PopoverMenuHandle::default(),
|
||||||
|
@ -954,6 +980,7 @@ impl AgentPanel {
|
||||||
fn new_external_thread(
|
fn new_external_thread(
|
||||||
&mut self,
|
&mut self,
|
||||||
agent_choice: Option<crate::ExternalAgent>,
|
agent_choice: Option<crate::ExternalAgent>,
|
||||||
|
restore_thread: Option<AcpThreadMetadata>,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
|
@ -1004,13 +1031,16 @@ impl AgentPanel {
|
||||||
};
|
};
|
||||||
|
|
||||||
this.update_in(cx, |this, window, cx| {
|
this.update_in(cx, |this, window, cx| {
|
||||||
|
let acp_history_store = this.acp_history.read(cx).history_store.clone();
|
||||||
let thread_view = cx.new(|cx| {
|
let thread_view = cx.new(|cx| {
|
||||||
crate::acp::AcpThreadView::new(
|
crate::acp::AcpThreadView::new(
|
||||||
server,
|
server,
|
||||||
workspace.clone(),
|
workspace.clone(),
|
||||||
project,
|
project,
|
||||||
|
acp_history_store,
|
||||||
thread_store.clone(),
|
thread_store.clone(),
|
||||||
text_thread_store.clone(),
|
text_thread_store.clone(),
|
||||||
|
restore_thread,
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -1669,13 +1699,13 @@ impl AgentPanel {
|
||||||
window.dispatch_action(NewTextThread.boxed_clone(), cx);
|
window.dispatch_action(NewTextThread.boxed_clone(), cx);
|
||||||
}
|
}
|
||||||
AgentType::NativeAgent => {
|
AgentType::NativeAgent => {
|
||||||
self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), window, cx)
|
self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), None, window, cx)
|
||||||
}
|
}
|
||||||
AgentType::Gemini => {
|
AgentType::Gemini => {
|
||||||
self.new_external_thread(Some(crate::ExternalAgent::Gemini), window, cx)
|
self.new_external_thread(Some(crate::ExternalAgent::Gemini), None, window, cx)
|
||||||
}
|
}
|
||||||
AgentType::ClaudeCode => {
|
AgentType::ClaudeCode => {
|
||||||
self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), window, cx)
|
self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), None, window, cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1686,7 +1716,14 @@ impl Focusable for AgentPanel {
|
||||||
match &self.active_view {
|
match &self.active_view {
|
||||||
ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx),
|
ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx),
|
||||||
ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx),
|
ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx),
|
||||||
ActiveView::History => self.history.focus_handle(cx),
|
ActiveView::History => {
|
||||||
|
if cx.has_flag::<feature_flags::AcpFeatureFlag>() {
|
||||||
|
self.acp_history.focus_handle(cx)
|
||||||
|
} else {
|
||||||
|
self.history.focus_handle(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
|
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
|
||||||
ActiveView::Configuration => {
|
ActiveView::Configuration => {
|
||||||
if let Some(configuration) = self.configuration.as_ref() {
|
if let Some(configuration) = self.configuration.as_ref() {
|
||||||
|
@ -3517,7 +3554,13 @@ impl Render for AgentPanel {
|
||||||
ActiveView::ExternalAgentThread { thread_view, .. } => parent
|
ActiveView::ExternalAgentThread { thread_view, .. } => parent
|
||||||
.child(thread_view.clone())
|
.child(thread_view.clone())
|
||||||
.child(self.render_drag_target(cx)),
|
.child(self.render_drag_target(cx)),
|
||||||
ActiveView::History => parent.child(self.history.clone()),
|
ActiveView::History => {
|
||||||
|
if cx.has_flag::<feature_flags::AcpFeatureFlag>() {
|
||||||
|
parent.child(self.acp_history.clone())
|
||||||
|
} else {
|
||||||
|
parent.child(self.history.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
ActiveView::TextThread {
|
ActiveView::TextThread {
|
||||||
context_editor,
|
context_editor,
|
||||||
buffer_search_bar,
|
buffer_search_bar,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue