diff --git a/.github/actions/run_tests_windows/action.yml b/.github/actions/run_tests_windows/action.yml
index e3e3b7142e..0a550c7d32 100644
--- a/.github/actions/run_tests_windows/action.yml
+++ b/.github/actions/run_tests_windows/action.yml
@@ -56,7 +56,6 @@ runs:
$env:COMPlus_CreateDumpDiagnostics = "1"
cargo nextest run --workspace --no-fail-fast
- continue-on-error: true
- name: Analyze crash dumps
if: always()
diff --git a/Cargo.lock b/Cargo.lock
index f44f513f78..6dbbccedaa 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7,7 +7,6 @@ name = "acp_thread"
version = "0.1.0"
dependencies = [
"action_log",
- "agent",
"agent-client-protocol",
"anyhow",
"buffer_diff",
@@ -20,6 +19,7 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"language",
+ "language_model",
"markdown",
"parking_lot",
"project",
@@ -130,7 +130,6 @@ dependencies = [
"component",
"context_server",
"convert_case 0.8.0",
- "feature_flags",
"fs",
"futures 0.3.31",
"git",
@@ -191,10 +190,12 @@ version = "0.1.0"
dependencies = [
"acp_thread",
"action_log",
+ "agent",
"agent-client-protocol",
"agent_servers",
"agent_settings",
"anyhow",
+ "assistant_context",
"assistant_tool",
"assistant_tools",
"chrono",
@@ -204,10 +205,12 @@ dependencies = [
"collections",
"context_server",
"ctor",
+ "db",
"editor",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
+ "git",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
@@ -221,6 +224,7 @@ dependencies = [
"log",
"lsp",
"open",
+ "parking_lot",
"paths",
"portable-pty",
"pretty_assertions",
@@ -233,6 +237,7 @@ dependencies = [
"serde_json",
"settings",
"smol",
+ "sqlez",
"task",
"tempfile",
"terminal",
@@ -249,6 +254,7 @@ dependencies = [
"workspace-hack",
"worktree",
"zlog",
+ "zstd",
]
[[package]]
@@ -256,7 +262,9 @@ name = "agent_servers"
version = "0.1.0"
dependencies = [
"acp_thread",
+ "action_log",
"agent-client-protocol",
+ "agent_settings",
"agentic-coding-protocol",
"anyhow",
"collections",
@@ -267,6 +275,8 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"language",
+ "language_model",
+ "language_models",
"libc",
"log",
"nix 0.29.0",
@@ -274,6 +284,7 @@ dependencies = [
"project",
"rand 0.8.5",
"schemars",
+ "semver",
"serde",
"serde_json",
"settings",
@@ -3070,6 +3081,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
+ "serde_urlencoded",
"settings",
"sha2",
"smol",
@@ -3861,7 +3873,7 @@ dependencies = [
"jni",
"js-sys",
"libc",
- "mach2",
+ "mach2 0.4.2",
"ndk",
"ndk-context",
"num-derive",
@@ -4011,7 +4023,7 @@ checksum = "031ed29858d90cfdf27fe49fae28028a1f20466db97962fa2f4ea34809aeebf3"
dependencies = [
"cfg-if",
"libc",
- "mach2",
+ "mach2 0.4.2",
]
[[package]]
@@ -4023,7 +4035,7 @@ dependencies = [
"cfg-if",
"crash-context",
"libc",
- "mach2",
+ "mach2 0.4.2",
"parking_lot",
]
@@ -4033,6 +4045,7 @@ version = "0.1.0"
dependencies = [
"crash-handler",
"log",
+ "mach2 0.5.0",
"minidumper",
"paths",
"release_channel",
@@ -7477,6 +7490,7 @@ dependencies = [
"slotmap",
"smallvec",
"smol",
+ "stacksafe",
"strum 0.27.1",
"sum_tree",
"taffy",
@@ -9854,6 +9868,15 @@ dependencies = [
"libc",
]
+[[package]]
+name = "mach2"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6a1b95cd5421ec55b445b5ae102f5ea0e768de1f82bd3001e11f426c269c3aea"
+dependencies = [
+ "libc",
+]
+
[[package]]
name = "malloc_buf"
version = "0.0.6"
@@ -10190,7 +10213,7 @@ dependencies = [
"goblin",
"libc",
"log",
- "mach2",
+ "mach2 0.4.2",
"memmap2",
"memoffset",
"minidump-common",
@@ -15536,6 +15559,40 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
+[[package]]
+name = "stacker"
+version = "0.1.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b"
+dependencies = [
+ "cc",
+ "cfg-if",
+ "libc",
+ "psm",
+ "windows-sys 0.59.0",
+]
+
+[[package]]
+name = "stacksafe"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d9c1172965d317e87ddb6d364a040d958b40a1db82b6ef97da26253a8b3d090"
+dependencies = [
+ "stacker",
+ "stacksafe-macro",
+]
+
+[[package]]
+name = "stacksafe-macro"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "172175341049678163e979d9107ca3508046d4d2a7c6682bee46ac541b17db69"
+dependencies = [
+ "proc-macro-error2",
+ "quote",
+ "syn 2.0.101",
+]
+
[[package]]
name = "static_assertions"
version = "1.1.0"
@@ -18247,7 +18304,7 @@ dependencies = [
"indexmap",
"libc",
"log",
- "mach2",
+ "mach2 0.4.2",
"memfd",
"object",
"once_cell",
@@ -20197,8 +20254,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "yawc"
-version = "0.2.4"
-source = "git+https://github.com/deviant-forks/yawc?rev=1899688f3e69ace4545aceb97b2a13881cf26142#1899688f3e69ace4545aceb97b2a13881cf26142"
+version = "0.2.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "19a5d82922135b4ae73a079a4ffb5501e9aadb4d785b8c660eaa0a8b899028c5"
dependencies = [
"base64 0.22.1",
"bytes 1.10.1",
diff --git a/Cargo.toml b/Cargo.toml
index 14691cf8a4..dc14c8ebd9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -515,6 +515,7 @@ libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
linkify = "0.10.0"
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" }
+mach2 = "0.5"
markup5ever_rcdom = "0.3.0"
metal = "0.29"
minidumper = "0.8"
@@ -582,6 +583,7 @@ serde_json_lenient = { version = "0.2", features = [
"raw_value",
] }
serde_repr = "0.1"
+serde_urlencoded = "0.7"
sha2 = "0.10"
shellexpand = "2.1.0"
shlex = "1.3.0"
@@ -589,6 +591,7 @@ simplelog = "0.12.2"
smallvec = { version = "1.6", features = ["union"] }
smol = "2.0"
sqlformat = "0.2"
+stacksafe = "0.1"
streaming-iterator = "0.1"
strsim = "0.11"
strum = { version = "0.27.0", features = ["derive"] }
@@ -659,9 +662,7 @@ which = "6.0.0"
windows-core = "0.61"
wit-component = "0.221"
workspace-hack = "0.1.0"
-# We can switch back to the published version once https://github.com/infinitefield/yawc/pull/16 is merged and a new
-# version is released.
-yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" }
+yawc = "0.2.5"
zstd = "0.11"
[workspace.dependencies.windows]
@@ -821,10 +822,30 @@ single_range_in_vec_init = "allow"
style = { level = "allow", priority = -1 }
# Temporary list of style lints that we've fixed so far.
+comparison_to_empty = "warn"
+into_iter_on_ref = "warn"
+iter_cloned_collect = "warn"
+iter_next_slice = "warn"
+iter_nth = "warn"
+iter_nth_zero = "warn"
+iter_skip_next = "warn"
+let_and_return = "warn"
module_inception = { level = "deny" }
question_mark = { level = "deny" }
+single_match = "warn"
redundant_closure = { level = "deny" }
+redundant_static_lifetimes = { level = "warn" }
+redundant_pattern_matching = "warn"
+redundant_field_names = "warn"
declare_interior_mutable_const = { level = "deny" }
+collapsible_if = { level = "warn"}
+collapsible_else_if = { level = "warn" }
+needless_borrow = { level = "warn"}
+needless_return = { level = "warn" }
+unnecessary_mut_passed = {level = "warn"}
+unnecessary_map_or = { level = "warn" }
+unused_unit = "warn"
+
# Individual rules that have violations in the codebase:
type_complexity = "allow"
# We often return trait objects from `new` functions.
@@ -833,6 +854,8 @@ new_ret_no_self = { level = "allow" }
# compared to Iterator::next. Yet, clippy complains about those.
should_implement_trait = { level = "allow" }
let_underscore_future = "allow"
+# It doesn't make sense to implement `Default` unilaterally.
+new_without_default = "allow"
# in Rust it can be very tedious to reduce argument count without
# running afoul of the borrow checker.
@@ -841,6 +864,10 @@ too_many_arguments = "allow"
# We often have large enum variants yet we rarely actually bother with splitting them up.
large_enum_variant = "allow"
+# `enum_variant_names` fires for all enums, even when they derive serde traits.
+# Adhering to this lint would be a breaking change.
+enum_variant_names = "allow"
+
[workspace.metadata.cargo-machete]
ignored = [
"bindgen",
diff --git a/assets/icons/menu_alt.svg b/assets/icons/menu_alt.svg
index f73102e286..87add13216 100644
--- a/assets/icons/menu_alt.svg
+++ b/assets/icons/menu_alt.svg
@@ -1 +1,3 @@
-
+
diff --git a/assets/icons/menu_alt_temp.svg b/assets/icons/menu_alt_temp.svg
new file mode 100644
index 0000000000..87add13216
--- /dev/null
+++ b/assets/icons/menu_alt_temp.svg
@@ -0,0 +1,3 @@
+
diff --git a/assets/icons/x_circle_filled.svg b/assets/icons/x_circle_filled.svg
new file mode 100644
index 0000000000..52215acda8
--- /dev/null
+++ b/assets/icons/x_circle_filled.svg
@@ -0,0 +1,3 @@
+
diff --git a/assets/icons/zed_agent.svg b/assets/icons/zed_agent.svg
new file mode 100644
index 0000000000..b6e120a0b6
--- /dev/null
+++ b/assets/icons/zed_agent.svg
@@ -0,0 +1,27 @@
+
diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json
index 01c0b4e969..b4efa70572 100644
--- a/assets/keymaps/default-linux.json
+++ b/assets/keymaps/default-linux.json
@@ -327,7 +327,7 @@
}
},
{
- "context": "AcpThread > Editor",
+ "context": "AcpThread > Editor && !use_modifier_to_send",
"use_key_equivalents": true,
"bindings": {
"enter": "agent::Chat",
@@ -336,6 +336,16 @@
"ctrl-shift-n": "agent::RejectAll"
}
},
+ {
+ "context": "AcpThread > Editor && use_modifier_to_send",
+ "use_key_equivalents": true,
+ "bindings": {
+ "ctrl-enter": "agent::Chat",
+ "shift-ctrl-r": "agent::OpenAgentDiff",
+ "ctrl-shift-y": "agent::KeepAll",
+ "ctrl-shift-n": "agent::RejectAll"
+ }
+ },
{
"context": "ThreadHistory",
"bindings": {
diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json
index e5b7fff9e1..ad2ab2ba89 100644
--- a/assets/keymaps/default-macos.json
+++ b/assets/keymaps/default-macos.json
@@ -379,7 +379,7 @@
}
},
{
- "context": "AcpThread > Editor",
+ "context": "AcpThread > Editor && !use_modifier_to_send",
"use_key_equivalents": true,
"bindings": {
"enter": "agent::Chat",
@@ -388,6 +388,16 @@
"cmd-shift-n": "agent::RejectAll"
}
},
+ {
+ "context": "AcpThread > Editor && use_modifier_to_send",
+ "use_key_equivalents": true,
+ "bindings": {
+ "cmd-enter": "agent::Chat",
+ "shift-ctrl-r": "agent::OpenAgentDiff",
+ "cmd-shift-y": "agent::KeepAll",
+ "cmd-shift-n": "agent::RejectAll"
+ }
+ },
{
"context": "ThreadHistory",
"bindings": {
diff --git a/assets/settings/default.json b/assets/settings/default.json
index 72e4dcbf4f..c290baf003 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -717,7 +717,7 @@
// Can be 'never', 'always', or 'when_in_call',
// or a boolean (interpreted as 'never'/'always').
"button": "when_in_call",
- // Where to the chat panel. Can be 'left' or 'right'.
+ // Where to dock the chat panel. Can be 'left' or 'right'.
"dock": "right",
// Default width of the chat panel.
"default_width": 240
@@ -725,7 +725,7 @@
"git_panel": {
// Whether to show the git panel button in the status bar.
"button": true,
- // Where to show the git panel. Can be 'left' or 'right'.
+ // Where to dock the git panel. Can be 'left' or 'right'.
"dock": "left",
// Default width of the git panel.
"default_width": 360,
diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml
index 2b9a6513c8..eab756db51 100644
--- a/crates/acp_thread/Cargo.toml
+++ b/crates/acp_thread/Cargo.toml
@@ -18,7 +18,6 @@ test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"]
[dependencies]
action_log.workspace = true
agent-client-protocol.workspace = true
-agent.workspace = true
anyhow.workspace = true
buffer_diff.workspace = true
collections.workspace = true
@@ -28,6 +27,7 @@ futures.workspace = true
gpui.workspace = true
itertools.workspace = true
language.workspace = true
+language_model.workspace = true
markdown.workspace = true
parking_lot = { workspace = true, optional = true }
project.workspace = true
diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs
index fb31265326..5d3b35d018 100644
--- a/crates/acp_thread/src/acp_thread.rs
+++ b/crates/acp_thread/src/acp_thread.rs
@@ -3,9 +3,13 @@ mod diff;
mod mention;
mod terminal;
+use collections::HashSet;
pub use connection::*;
pub use diff::*;
+use language::language_settings::FormatOnSave;
pub use mention::*;
+use project::lsp_store::{FormatTrigger, LspFormatTarget};
+use serde::{Deserialize, Serialize};
pub use terminal::*;
use action_log::ActionLog;
@@ -24,6 +28,7 @@ use std::fmt::{Formatter, Write};
use std::ops::Range;
use std::process::ExitStatus;
use std::rc::Rc;
+use std::time::{Duration, Instant};
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
use ui::App;
use util::ResultExt;
@@ -48,7 +53,7 @@ impl UserMessage {
if self
.checkpoint
.as_ref()
- .map_or(false, |checkpoint| checkpoint.show)
+ .is_some_and(|checkpoint| checkpoint.show)
{
writeln!(markdown, "## User (checkpoint)").unwrap();
} else {
@@ -248,14 +253,13 @@ impl ToolCall {
}
if let Some(raw_output) = raw_output {
- if self.content.is_empty() {
- if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
- {
- self.content
- .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
- markdown,
- }));
- }
+ if self.content.is_empty()
+ && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
+ {
+ self.content
+ .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
+ markdown,
+ }));
}
self.raw_output = Some(raw_output);
}
@@ -429,11 +433,11 @@ impl ContentBlock {
language_registry: &Arc,
cx: &mut App,
) {
- if matches!(self, ContentBlock::Empty) {
- if let acp::ContentBlock::ResourceLink(resource_link) = block {
- *self = ContentBlock::ResourceLink { resource_link };
- return;
- }
+ if matches!(self, ContentBlock::Empty)
+ && let acp::ContentBlock::ResourceLink(resource_link) = block
+ {
+ *self = ContentBlock::ResourceLink { resource_link };
+ return;
}
let new_content = self.block_string_contents(block);
@@ -485,7 +489,7 @@ impl ContentBlock {
}
fn resource_link_md(uri: &str) -> String {
- if let Some(uri) = MentionUri::parse(&uri).log_err() {
+ if let Some(uri) = MentionUri::parse(uri).log_err() {
uri.as_link().to_string()
} else {
uri.to_string()
@@ -537,9 +541,15 @@ impl ToolCallContent {
acp::ToolCallContent::Content { content } => {
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
}
- acp::ToolCallContent::Diff { diff } => {
- Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
- }
+ acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
+ Diff::finalized(
+ diff.path,
+ diff.old_text,
+ diff.new_text,
+ language_registry,
+ cx,
+ )
+ })),
}
}
@@ -658,6 +668,21 @@ impl PlanEntry {
}
}
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct TokenUsage {
+ pub max_tokens: u64,
+ pub used_tokens: u64,
+}
+
+#[derive(Debug, Clone)]
+pub struct RetryStatus {
+ pub last_error: SharedString,
+ pub attempt: usize,
+ pub max_attempts: usize,
+ pub started_at: Instant,
+ pub duration: Duration,
+}
+
pub struct AcpThread {
title: SharedString,
entries: Vec,
@@ -668,16 +693,21 @@ pub struct AcpThread {
send_task: Option>,
connection: Rc,
session_id: acp::SessionId,
+ token_usage: Option,
}
+#[derive(Debug)]
pub enum AcpThreadEvent {
NewEntry,
+ TitleUpdated,
+ TokenUsageUpdated,
EntryUpdated(usize),
EntriesRemoved(Range),
ToolAuthorizationRequired,
+ Retry(RetryStatus),
Stopped,
Error,
- ServerExited(ExitStatus),
+ LoadError(LoadError),
}
impl EventEmitter for AcpThread {}
@@ -691,20 +721,30 @@ pub enum ThreadStatus {
#[derive(Debug, Clone)]
pub enum LoadError {
+ NotInstalled {
+ error_message: SharedString,
+ install_message: SharedString,
+ install_command: String,
+ },
Unsupported {
error_message: SharedString,
upgrade_message: SharedString,
upgrade_command: String,
},
- Exited(i32),
+ Exited {
+ status: ExitStatus,
+ },
Other(SharedString),
}
impl Display for LoadError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
- LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
- LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
+ LoadError::NotInstalled { error_message, .. }
+ | LoadError::Unsupported { error_message, .. } => {
+ write!(f, "{error_message}")
+ }
+ LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
LoadError::Other(msg) => write!(f, "{}", msg),
}
}
@@ -717,11 +757,9 @@ impl AcpThread {
title: impl Into,
connection: Rc,
project: Entity,
+ action_log: Entity,
session_id: acp::SessionId,
- cx: &mut Context,
) -> Self {
- let action_log = cx.new(|_| ActionLog::new(project.clone()));
-
Self {
action_log,
shared_buffers: Default::default(),
@@ -732,6 +770,7 @@ impl AcpThread {
send_task: None,
connection,
session_id,
+ token_usage: None,
}
}
@@ -771,6 +810,10 @@ impl AcpThread {
}
}
+ pub fn token_usage(&self) -> Option<&TokenUsage> {
+ self.token_usage.as_ref()
+ }
+
pub fn has_pending_edit_tool_calls(&self) -> bool {
for entry in self.entries.iter().rev() {
match entry {
@@ -915,6 +958,21 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry);
}
+ pub fn update_title(&mut self, title: SharedString, cx: &mut Context) -> Result<()> {
+ self.title = title;
+ cx.emit(AcpThreadEvent::TitleUpdated);
+ Ok(())
+ }
+
+ pub fn update_token_usage(&mut self, usage: Option, cx: &mut Context) {
+ self.token_usage = usage;
+ cx.emit(AcpThreadEvent::TokenUsageUpdated);
+ }
+
+ pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context) {
+ cx.emit(AcpThreadEvent::Retry(status));
+ }
+
pub fn update_tool_call(
&mut self,
update: impl Into,
@@ -1006,6 +1064,22 @@ impl AcpThread {
})
}
+ pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
+ self.entries
+ .iter()
+ .enumerate()
+ .rev()
+ .find_map(|(index, tool_call)| {
+ if let AgentThreadEntry::ToolCall(tool_call) = tool_call
+ && &tool_call.id == id
+ {
+ Some((index, tool_call))
+ } else {
+ None
+ }
+ })
+ }
+
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context) {
let project = self.project.clone();
let Some((_, tool_call)) = self.tool_call_mut(&id) else {
@@ -1199,17 +1273,21 @@ impl AcpThread {
} else {
None
};
- self.push_entry(
- AgentThreadEntry::UserMessage(UserMessage {
- id: message_id.clone(),
- content: block,
- chunks: message,
- checkpoint: None,
- }),
- cx,
- );
self.run_turn(cx, async move |this, cx| {
+ this.update(cx, |this, cx| {
+ this.push_entry(
+ AgentThreadEntry::UserMessage(UserMessage {
+ id: message_id.clone(),
+ content: block,
+ chunks: message,
+ checkpoint: None,
+ }),
+ cx,
+ );
+ })
+ .ok();
+
let old_checkpoint = git_store
.update(cx, |git, cx| git.checkpoint(cx))?
.await
@@ -1262,6 +1340,8 @@ impl AcpThread {
.await?;
this.update(cx, |this, cx| {
+ this.project
+ .update(cx, |project, cx| project.set_agent_location(None, cx));
match response {
Ok(Err(e)) => {
this.send_task.take();
@@ -1411,7 +1491,7 @@ impl AcpThread {
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
self.entries.iter().find_map(|entry| {
if let AgentThreadEntry::UserMessage(message) = entry {
- if message.id.as_ref() == Some(&id) {
+ if message.id.as_ref() == Some(id) {
Some(message)
} else {
None
@@ -1425,7 +1505,7 @@ impl AcpThread {
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
if let AgentThreadEntry::UserMessage(message) = entry {
- if message.id.as_ref() == Some(&id) {
+ if message.id.as_ref() == Some(id) {
Some((ix, message))
} else {
None
@@ -1550,30 +1630,59 @@ impl AcpThread {
.collect::>()
})
.await;
- cx.update(|cx| {
- project.update(cx, |project, cx| {
- project.set_agent_location(
- Some(AgentLocation {
- buffer: buffer.downgrade(),
- position: edits
- .last()
- .map(|(range, _)| range.end)
- .unwrap_or(Anchor::MIN),
- }),
- cx,
- );
- });
+ project.update(cx, |project, cx| {
+ project.set_agent_location(
+ Some(AgentLocation {
+ buffer: buffer.downgrade(),
+ position: edits
+ .last()
+ .map(|(range, _)| range.end)
+ .unwrap_or(Anchor::MIN),
+ }),
+ cx,
+ );
+ })?;
+
+ let format_on_save = cx.update(|cx| {
action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx);
});
- buffer.update(cx, |buffer, cx| {
+
+ let format_on_save = buffer.update(cx, |buffer, cx| {
buffer.edit(edits, None, cx);
+
+ let settings = language::language_settings::language_settings(
+ buffer.language().map(|l| l.name()),
+ buffer.file(),
+ cx,
+ );
+
+ settings.format_on_save != FormatOnSave::Off
});
action_log.update(cx, |action_log, cx| {
action_log.buffer_edited(buffer.clone(), cx);
});
+ format_on_save
})?;
+
+ if format_on_save {
+ let format_task = project.update(cx, |project, cx| {
+ project.format(
+ HashSet::from_iter([buffer.clone()]),
+ LspFormatTarget::Buffers,
+ false,
+ FormatTrigger::Save,
+ cx,
+ )
+ })?;
+ format_task.await.log_err();
+
+ action_log.update(cx, |action_log, cx| {
+ action_log.buffer_edited(buffer.clone(), cx);
+ })?;
+ }
+
project
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
.await
@@ -1584,8 +1693,8 @@ impl AcpThread {
self.entries.iter().map(|e| e.to_markdown(cx)).collect()
}
- pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context) {
- cx.emit(AcpThreadEvent::ServerExited(status));
+ pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context) {
+ cx.emit(AcpThreadEvent::LoadError(error));
}
}
@@ -1636,7 +1745,7 @@ mod tests {
use super::*;
use anyhow::anyhow;
use futures::{channel::mpsc, future::LocalBoxFuture, select};
- use gpui::{AsyncApp, TestAppContext, WeakEntity};
+ use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
use project::{FakeFs, Fs};
use rand::Rng as _;
@@ -2123,7 +2232,7 @@ mod tests {
"}
);
});
- assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
+ assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
.await
@@ -2153,7 +2262,10 @@ mod tests {
});
assert_eq!(
fs.files(),
- vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
+ vec![
+ Path::new(path!("/test/file-0")),
+ Path::new(path!("/test/file-1"))
+ ]
);
// Checkpoint isn't stored when there are no changes.
@@ -2194,7 +2306,10 @@ mod tests {
});
assert_eq!(
fs.files(),
- vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
+ vec![
+ Path::new(path!("/test/file-0")),
+ Path::new(path!("/test/file-1"))
+ ]
);
// Rewinding the conversation truncates the history and restores the checkpoint.
@@ -2222,7 +2337,7 @@ mod tests {
"}
);
});
- assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
+ assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
}
async fn run_until_first_tool_call(
@@ -2306,7 +2421,7 @@ mod tests {
self: Rc,
project: Entity,
_cwd: &Path,
- cx: &mut gpui::App,
+ cx: &mut App,
) -> Task>> {
let session_id = acp::SessionId(
rand::thread_rng()
@@ -2316,8 +2431,16 @@ mod tests {
.collect::()
.into(),
);
- let thread =
- cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let thread = cx.new(|_cx| {
+ AcpThread::new(
+ "Test",
+ self.clone(),
+ project,
+ action_log,
+ session_id.clone(),
+ )
+ });
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
@@ -2351,7 +2474,7 @@ mod tests {
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let sessions = self.sessions.lock();
- let thread = sessions.get(&session_id).unwrap().clone();
+ let thread = sessions.get(session_id).unwrap().clone();
cx.spawn(async move |cx| {
thread
diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs
index 7497d2309f..8cae975ce5 100644
--- a/crates/acp_thread/src/connection.rs
+++ b/crates/acp_thread/src/connection.rs
@@ -3,12 +3,14 @@ use agent_client_protocol::{self as acp};
use anyhow::Result;
use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
+use language_model::LanguageModelProviderId;
use project::Project;
+use serde::{Deserialize, Serialize};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct UserMessageId(Arc);
impl UserMessageId {
@@ -80,12 +82,34 @@ pub trait AgentSessionResume {
}
#[derive(Debug)]
-pub struct AuthRequired;
+pub struct AuthRequired {
+ pub description: Option,
+ pub provider_id: Option,
+}
+
+impl AuthRequired {
+ pub fn new() -> Self {
+ Self {
+ description: None,
+ provider_id: None,
+ }
+ }
+
+ pub fn with_description(mut self, description: String) -> Self {
+ self.description = Some(description);
+ self
+ }
+
+ pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
+ self.provider_id = Some(provider_id);
+ self
+ }
+}
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "AuthRequired")
+ write!(f, "Authentication required")
}
}
@@ -185,8 +209,9 @@ impl AgentModelList {
mod test_support {
use std::sync::Arc;
+ use action_log::ActionLog;
use collections::HashMap;
- use futures::future::try_join_all;
+ use futures::{channel::oneshot, future::try_join_all};
use gpui::{AppContext as _, WeakEntity};
use parking_lot::Mutex;
@@ -194,11 +219,16 @@ mod test_support {
#[derive(Clone, Default)]
pub struct StubAgentConnection {
- sessions: Arc>>>,
+ sessions: Arc>>,
permission_requests: HashMap>,
next_prompt_updates: Arc>>,
}
+ struct Session {
+ thread: WeakEntity,
+ response_tx: Option>,
+ }
+
impl StubAgentConnection {
pub fn new() -> Self {
Self {
@@ -226,15 +256,33 @@ mod test_support {
update: acp::SessionUpdate,
cx: &mut App,
) {
+ assert!(
+ self.next_prompt_updates.lock().is_empty(),
+ "Use either send_update or set_next_prompt_updates"
+ );
+
self.sessions
.lock()
.get(&session_id)
.unwrap()
+ .thread
.update(cx, |thread, cx| {
- thread.handle_session_update(update.clone(), cx).unwrap();
+ thread.handle_session_update(update, cx).unwrap();
})
.unwrap();
}
+
+ pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
+ self.sessions
+ .lock()
+ .get_mut(&session_id)
+ .unwrap()
+ .response_tx
+ .take()
+ .expect("No pending turn")
+ .send(stop_reason)
+ .unwrap();
+ }
}
impl AgentConnection for StubAgentConnection {
@@ -249,9 +297,23 @@ mod test_support {
cx: &mut gpui::App,
) -> Task>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
- let thread =
- cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
- self.sessions.lock().insert(session_id, thread.downgrade());
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let thread = cx.new(|_cx| {
+ AcpThread::new(
+ "Test",
+ self.clone(),
+ project,
+ action_log,
+ session_id.clone(),
+ )
+ });
+ self.sessions.lock().insert(
+ session_id,
+ Session {
+ thread: thread.downgrade(),
+ response_tx: None,
+ },
+ );
Task::ready(Ok(thread))
}
@@ -269,47 +331,70 @@ mod test_support {
params: acp::PromptRequest,
cx: &mut App,
) -> Task> {
- let sessions = self.sessions.lock();
- let thread = sessions.get(¶ms.session_id).unwrap();
+ let mut sessions = self.sessions.lock();
+ let Session {
+ thread,
+ response_tx,
+ } = sessions.get_mut(¶ms.session_id).unwrap();
let mut tasks = vec![];
- for update in self.next_prompt_updates.lock().drain(..) {
- let thread = thread.clone();
- let update = update.clone();
- let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
- && let Some(options) = self.permission_requests.get(&tool_call.id)
- {
- Some((tool_call.clone(), options.clone()))
- } else {
- None
- };
- let task = cx.spawn(async move |cx| {
- if let Some((tool_call, options)) = permission_request {
- let permission = thread.update(cx, |thread, cx| {
- thread.request_tool_call_authorization(
- tool_call.clone().into(),
- options.clone(),
- cx,
- )
- })?;
- permission?.await?;
- }
- thread.update(cx, |thread, cx| {
- thread.handle_session_update(update.clone(), cx).unwrap();
- })?;
- anyhow::Ok(())
- });
- tasks.push(task);
- }
- cx.spawn(async move |_| {
- try_join_all(tasks).await?;
- Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::EndTurn,
+ if self.next_prompt_updates.lock().is_empty() {
+ let (tx, rx) = oneshot::channel();
+ response_tx.replace(tx);
+ cx.spawn(async move |_| {
+ let stop_reason = rx.await?;
+ Ok(acp::PromptResponse { stop_reason })
})
- })
+ } else {
+ for update in self.next_prompt_updates.lock().drain(..) {
+ let thread = thread.clone();
+ let update = update.clone();
+ let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
+ &update
+ && let Some(options) = self.permission_requests.get(&tool_call.id)
+ {
+ Some((tool_call.clone(), options.clone()))
+ } else {
+ None
+ };
+ let task = cx.spawn(async move |cx| {
+ if let Some((tool_call, options)) = permission_request {
+ let permission = thread.update(cx, |thread, cx| {
+ thread.request_tool_call_authorization(
+ tool_call.clone().into(),
+ options.clone(),
+ cx,
+ )
+ })?;
+ permission?.await?;
+ }
+ thread.update(cx, |thread, cx| {
+ thread.handle_session_update(update.clone(), cx).unwrap();
+ })?;
+ anyhow::Ok(())
+ });
+ tasks.push(task);
+ }
+
+ cx.spawn(async move |_| {
+ try_join_all(tasks).await?;
+ Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::EndTurn,
+ })
+ })
+ }
}
- fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
- unimplemented!()
+ fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
+ if let Some(end_turn_tx) = self
+ .sessions
+ .lock()
+ .get_mut(session_id)
+ .unwrap()
+ .response_tx
+ .take()
+ {
+ end_turn_tx.send(acp::StopReason::Canceled).unwrap();
+ }
}
fn session_editor(
diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs
index a2c2d6c322..4b779931c5 100644
--- a/crates/acp_thread/src/diff.rs
+++ b/crates/acp_thread/src/diff.rs
@@ -1,4 +1,3 @@
-use agent_client_protocol as acp;
use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey};
@@ -21,17 +20,13 @@ pub enum Diff {
}
impl Diff {
- pub fn from_acp(
- diff: acp::Diff,
+ pub fn finalized(
+ path: PathBuf,
+ old_text: Option,
+ new_text: String,
language_registry: Arc,
cx: &mut Context,
) -> Self {
- let acp::Diff {
- path,
- old_text,
- new_text,
- } = diff;
-
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
@@ -71,8 +66,8 @@ impl Diff {
let hunk_ranges = {
let buffer = new_buffer.read(cx);
let diff = buffer_diff.read(cx);
- diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
- .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
+ diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx)
+ .map(|diff_hunk| diff_hunk.buffer_range.to_point(buffer))
.collect::>()
};
@@ -306,13 +301,13 @@ impl PendingDiff {
let buffer = self.buffer.read(cx);
let diff = self.diff.read(cx);
let mut ranges = diff
- .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
- .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
+ .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx)
+ .map(|diff_hunk| diff_hunk.buffer_range.to_point(buffer))
.collect::>();
ranges.extend(
self.revealed_ranges
.iter()
- .map(|range| range.to_point(&buffer)),
+ .map(|range| range.to_point(buffer)),
);
ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs
index b9b021c4ca..a1e713cffa 100644
--- a/crates/acp_thread/src/mention.rs
+++ b/crates/acp_thread/src/mention.rs
@@ -1,7 +1,8 @@
-use agent::ThreadId;
+use agent_client_protocol as acp;
use anyhow::{Context as _, Result, bail};
use file_icons::FileIcons;
use prompt_store::{PromptId, UserPromptId};
+use serde::{Deserialize, Serialize};
use std::{
fmt,
ops::Range,
@@ -11,11 +12,13 @@ use std::{
use ui::{App, IconName, SharedString};
use url::Url;
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum MentionUri {
File {
abs_path: PathBuf,
- is_directory: bool,
+ },
+ Directory {
+ abs_path: PathBuf,
},
Symbol {
path: PathBuf,
@@ -23,7 +26,7 @@ pub enum MentionUri {
line_range: Range,
},
Thread {
- id: ThreadId,
+ id: acp::SessionId,
name: String,
},
TextThread {
@@ -49,6 +52,7 @@ impl MentionUri {
let path = url.path();
match url.scheme() {
"file" => {
+ let path = url.to_file_path().ok().context("Extracting file path")?;
if let Some(fragment) = url.fragment() {
let range = fragment
.strip_prefix("L")
@@ -69,31 +73,23 @@ impl MentionUri {
if let Some(name) = single_query_param(&url, "symbol")? {
Ok(Self::Symbol {
name,
- path: path.into(),
+ path,
line_range,
})
} else {
- Ok(Self::Selection {
- path: path.into(),
- line_range,
- })
+ Ok(Self::Selection { path, line_range })
}
+ } else if input.ends_with("/") {
+ Ok(Self::Directory { abs_path: path })
} else {
- let file_path =
- PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
- let is_directory = input.ends_with("/");
-
- Ok(Self::File {
- abs_path: file_path,
- is_directory,
- })
+ Ok(Self::File { abs_path: path })
}
}
"zed" => {
if let Some(thread_id) = path.strip_prefix("/agent/thread/") {
let name = single_query_param(&url, "name")?.context("Missing thread name")?;
Ok(Self::Thread {
- id: thread_id.into(),
+ id: acp::SessionId(thread_id.into()),
name,
})
} else if let Some(path) = path.strip_prefix("/agent/text-thread/") {
@@ -120,7 +116,7 @@ impl MentionUri {
pub fn name(&self) -> String {
match self {
- MentionUri::File { abs_path, .. } => abs_path
+ MentionUri::File { abs_path, .. } | MentionUri::Directory { abs_path, .. } => abs_path
.file_name()
.unwrap_or_default()
.to_string_lossy()
@@ -138,18 +134,11 @@ impl MentionUri {
pub fn icon_path(&self, cx: &mut App) -> SharedString {
match self {
- MentionUri::File {
- abs_path,
- is_directory,
- } => {
- if *is_directory {
- FileIcons::get_folder_icon(false, cx)
- .unwrap_or_else(|| IconName::Folder.path().into())
- } else {
- FileIcons::get_icon(&abs_path, cx)
- .unwrap_or_else(|| IconName::File.path().into())
- }
+ MentionUri::File { abs_path } => {
+ FileIcons::get_icon(abs_path, cx).unwrap_or_else(|| IconName::File.path().into())
}
+ MentionUri::Directory { .. } => FileIcons::get_folder_icon(false, cx)
+ .unwrap_or_else(|| IconName::Folder.path().into()),
MentionUri::Symbol { .. } => IconName::Code.path().into(),
MentionUri::Thread { .. } => IconName::Thread.path().into(),
MentionUri::TextThread { .. } => IconName::Thread.path().into(),
@@ -165,25 +154,18 @@ impl MentionUri {
pub fn to_uri(&self) -> Url {
match self {
- MentionUri::File {
- abs_path,
- is_directory,
- } => {
- let mut url = Url::parse("file:///").unwrap();
- let mut path = abs_path.to_string_lossy().to_string();
- if *is_directory && !path.ends_with("/") {
- path.push_str("/");
- }
- url.set_path(&path);
- url
+ MentionUri::File { abs_path } => {
+ Url::from_file_path(abs_path).expect("mention path should be absolute")
+ }
+ MentionUri::Directory { abs_path } => {
+ Url::from_directory_path(abs_path).expect("mention path should be absolute")
}
MentionUri::Symbol {
path,
name,
line_range,
} => {
- let mut url = Url::parse("file:///").unwrap();
- url.set_path(&path.to_string_lossy());
+ let mut url = Url::from_file_path(path).expect("mention path should be absolute");
url.query_pairs_mut().append_pair("symbol", name);
url.set_fragment(Some(&format!(
"L{}:{}",
@@ -193,8 +175,7 @@ impl MentionUri {
url
}
MentionUri::Selection { path, line_range } => {
- let mut url = Url::parse("file:///").unwrap();
- url.set_path(&path.to_string_lossy());
+ let mut url = Url::from_file_path(path).expect("mention path should be absolute");
url.set_fragment(Some(&format!(
"L{}:{}",
line_range.start + 1,
@@ -267,19 +248,17 @@ pub fn selection_name(path: &Path, line_range: &Range) -> String {
#[cfg(test)]
mod tests {
+ use util::{path, uri};
+
use super::*;
#[test]
fn test_parse_file_uri() {
- let file_uri = "file:///path/to/file.rs";
+ let file_uri = uri!("file:///path/to/file.rs");
let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed {
- MentionUri::File {
- abs_path,
- is_directory,
- } => {
- assert_eq!(abs_path.to_str().unwrap(), "/path/to/file.rs");
- assert!(!is_directory);
+ MentionUri::File { abs_path } => {
+ assert_eq!(abs_path.to_str().unwrap(), path!("/path/to/file.rs"));
}
_ => panic!("Expected File variant"),
}
@@ -288,42 +267,38 @@ mod tests {
#[test]
fn test_parse_directory_uri() {
- let file_uri = "file:///path/to/dir/";
+ let file_uri = uri!("file:///path/to/dir/");
let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed {
- MentionUri::File {
- abs_path,
- is_directory,
- } => {
- assert_eq!(abs_path.to_str().unwrap(), "/path/to/dir/");
- assert!(is_directory);
+ MentionUri::Directory { abs_path } => {
+ assert_eq!(abs_path.to_str().unwrap(), path!("/path/to/dir/"));
}
- _ => panic!("Expected File variant"),
+ _ => panic!("Expected Directory variant"),
}
assert_eq!(parsed.to_uri().to_string(), file_uri);
}
#[test]
fn test_to_directory_uri_with_slash() {
- let uri = MentionUri::File {
- abs_path: PathBuf::from("/path/to/dir/"),
- is_directory: true,
+ let uri = MentionUri::Directory {
+ abs_path: PathBuf::from(path!("/path/to/dir/")),
};
- assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
+ let expected = uri!("file:///path/to/dir/");
+ assert_eq!(uri.to_uri().to_string(), expected);
}
#[test]
fn test_to_directory_uri_without_slash() {
- let uri = MentionUri::File {
- abs_path: PathBuf::from("/path/to/dir"),
- is_directory: true,
+ let uri = MentionUri::Directory {
+ abs_path: PathBuf::from(path!("/path/to/dir")),
};
- assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
+ let expected = uri!("file:///path/to/dir/");
+ assert_eq!(uri.to_uri().to_string(), expected);
}
#[test]
fn test_parse_symbol_uri() {
- let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20";
+ let symbol_uri = uri!("file:///path/to/file.rs?symbol=MySymbol#L10:20");
let parsed = MentionUri::parse(symbol_uri).unwrap();
match &parsed {
MentionUri::Symbol {
@@ -331,7 +306,7 @@ mod tests {
name,
line_range,
} => {
- assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
+ assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs"));
assert_eq!(name, "MySymbol");
assert_eq!(line_range.start, 9);
assert_eq!(line_range.end, 19);
@@ -343,11 +318,11 @@ mod tests {
#[test]
fn test_parse_selection_uri() {
- let selection_uri = "file:///path/to/file.rs#L5:15";
+ let selection_uri = uri!("file:///path/to/file.rs#L5:15");
let parsed = MentionUri::parse(selection_uri).unwrap();
match &parsed {
MentionUri::Selection { path, line_range } => {
- assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
+ assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs"));
assert_eq!(line_range.start, 4);
assert_eq!(line_range.end, 14);
}
@@ -429,32 +404,35 @@ mod tests {
#[test]
fn test_invalid_line_range_format() {
// Missing L prefix
- assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err());
+ assert!(MentionUri::parse(uri!("file:///path/to/file.rs#10:20")).is_err());
// Missing colon separator
- assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err());
+ assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L1020")).is_err());
// Invalid numbers
- assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err());
- assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err());
+ assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L10:abc")).is_err());
+ assert!(MentionUri::parse(uri!("file:///path/to/file.rs#Labc:20")).is_err());
}
#[test]
fn test_invalid_query_parameters() {
// Invalid query parameter name
- assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err());
+ assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L10:20?invalid=test")).is_err());
// Too many query parameters
assert!(
- MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err()
+ MentionUri::parse(uri!(
+ "file:///path/to/file.rs#L10:20?symbol=test&another=param"
+ ))
+ .is_err()
);
}
#[test]
fn test_zero_based_line_numbers() {
// Test that 0-based line numbers are rejected (should be 1-based)
- assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err());
- assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err());
- assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err());
+ assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L0:10")).is_err());
+ assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L1:0")).is_err());
+ assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L0:0")).is_err());
}
}
diff --git a/crates/action_log/src/action_log.rs b/crates/action_log/src/action_log.rs
index c4eaffc228..1c3cad386d 100644
--- a/crates/action_log/src/action_log.rs
+++ b/crates/action_log/src/action_log.rs
@@ -116,7 +116,7 @@ impl ActionLog {
} else if buffer
.read(cx)
.file()
- .map_or(false, |file| file.disk_state().exists())
+ .is_some_and(|file| file.disk_state().exists())
{
TrackedBufferStatus::Created {
existing_file_content: Some(buffer.read(cx).as_rope().clone()),
@@ -215,7 +215,7 @@ impl ActionLog {
if buffer
.read(cx)
.file()
- .map_or(false, |file| file.disk_state() == DiskState::Deleted)
+ .is_some_and(|file| file.disk_state() == DiskState::Deleted)
{
// If the buffer had been edited by a tool, but it got
// deleted externally, we want to stop tracking it.
@@ -227,7 +227,7 @@ impl ActionLog {
if buffer
.read(cx)
.file()
- .map_or(false, |file| file.disk_state() != DiskState::Deleted)
+ .is_some_and(|file| file.disk_state() != DiskState::Deleted)
{
// If the buffer had been deleted by a tool, but it got
// resurrected externally, we want to clear the edits we
@@ -264,15 +264,14 @@ impl ActionLog {
if let Some((git_diff, (buffer_repo, _))) = git_diff.as_ref().zip(buffer_repo) {
cx.update(|cx| {
let mut old_head = buffer_repo.read(cx).head_commit.clone();
- Some(cx.subscribe(git_diff, move |_, event, cx| match event {
- buffer_diff::BufferDiffEvent::DiffChanged { .. } => {
+ Some(cx.subscribe(git_diff, move |_, event, cx| {
+ if let buffer_diff::BufferDiffEvent::DiffChanged { .. } = event {
let new_head = buffer_repo.read(cx).head_commit.clone();
if new_head != old_head {
old_head = new_head;
git_diff_updates_tx.send(()).ok();
}
}
- _ => {}
}))
})?
} else {
@@ -290,7 +289,7 @@ impl ActionLog {
}
_ = git_diff_updates_rx.changed().fuse() => {
if let Some(git_diff) = git_diff.as_ref() {
- Self::keep_committed_edits(&this, &buffer, &git_diff, cx).await?;
+ Self::keep_committed_edits(&this, &buffer, git_diff, cx).await?;
}
}
}
@@ -498,7 +497,7 @@ impl ActionLog {
new: new_range,
},
&new_diff_base,
- &buffer_snapshot.as_rope(),
+ buffer_snapshot.as_rope(),
));
}
unreviewed_edits
@@ -614,10 +613,10 @@ impl ActionLog {
false
}
});
- if tracked_buffer.unreviewed_edits.is_empty() {
- if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status {
- tracked_buffer.status = TrackedBufferStatus::Modified;
- }
+ if tracked_buffer.unreviewed_edits.is_empty()
+ && let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status
+ {
+ tracked_buffer.status = TrackedBufferStatus::Modified;
}
tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx);
}
@@ -811,7 +810,7 @@ impl ActionLog {
tracked.version != buffer.version
&& buffer
.file()
- .map_or(false, |file| file.disk_state() != DiskState::Deleted)
+ .is_some_and(|file| file.disk_state() != DiskState::Deleted)
})
.map(|(buffer, _)| buffer)
}
@@ -847,7 +846,7 @@ fn apply_non_conflicting_edits(
conflict = true;
if new_edits
.peek()
- .map_or(false, |next_edit| next_edit.old.overlaps(&old_edit.new))
+ .is_some_and(|next_edit| next_edit.old.overlaps(&old_edit.new))
{
new_edit = new_edits.next().unwrap();
} else {
@@ -964,7 +963,7 @@ impl TrackedBuffer {
fn has_edits(&self, cx: &App) -> bool {
self.diff
.read(cx)
- .hunks(&self.buffer.read(cx), cx)
+ .hunks(self.buffer.read(cx), cx)
.next()
.is_some()
}
@@ -2268,7 +2267,7 @@ mod tests {
log::info!("quiescing...");
cx.run_until_parked();
action_log.update(cx, |log, cx| {
- let tracked_buffer = log.tracked_buffers.get(&buffer).unwrap();
+ let tracked_buffer = log.tracked_buffers.get(buffer).unwrap();
let mut old_text = tracked_buffer.diff_base.clone();
let new_text = buffer.read(cx).as_rope();
for edit in tracked_buffer.unreviewed_edits.edits() {
diff --git a/crates/activity_indicator/src/activity_indicator.rs b/crates/activity_indicator/src/activity_indicator.rs
index 7c562aaba4..324480f5b4 100644
--- a/crates/activity_indicator/src/activity_indicator.rs
+++ b/crates/activity_indicator/src/activity_indicator.rs
@@ -103,26 +103,21 @@ impl ActivityIndicator {
cx.subscribe_in(
&workspace_handle,
window,
- |activity_indicator, _, event, window, cx| match event {
- workspace::Event::ClearActivityIndicator { .. } => {
- if activity_indicator.statuses.pop().is_some() {
- activity_indicator.dismiss_error_message(
- &DismissErrorMessage,
- window,
- cx,
- );
- cx.notify();
- }
+ |activity_indicator, _, event, window, cx| {
+ if let workspace::Event::ClearActivityIndicator { .. } = event
+ && activity_indicator.statuses.pop().is_some()
+ {
+ activity_indicator.dismiss_error_message(&DismissErrorMessage, window, cx);
+ cx.notify();
}
- _ => {}
},
)
.detach();
cx.subscribe(
&project.read(cx).lsp_store(),
- |activity_indicator, _, event, cx| match event {
- LspStoreEvent::LanguageServerUpdate { name, message, .. } => {
+ |activity_indicator, _, event, cx| {
+ if let LspStoreEvent::LanguageServerUpdate { name, message, .. } = event {
if let proto::update_language_server::Variant::StatusUpdate(status_update) =
message
{
@@ -191,7 +186,6 @@ impl ActivityIndicator {
}
cx.notify()
}
- _ => {}
},
)
.detach();
@@ -206,9 +200,10 @@ impl ActivityIndicator {
cx.subscribe(
&project.read(cx).git_store().clone(),
- |_, _, event: &GitStoreEvent, cx| match event {
- project::git_store::GitStoreEvent::JobsUpdated => cx.notify(),
- _ => {}
+ |_, _, event: &GitStoreEvent, cx| {
+ if let project::git_store::GitStoreEvent::JobsUpdated = event {
+ cx.notify()
+ }
},
)
.detach();
@@ -458,26 +453,24 @@ impl ActivityIndicator {
.map(|r| r.read(cx))
.and_then(Repository::current_job);
// Show any long-running git command
- if let Some(job_info) = current_job {
- if Instant::now() - job_info.start >= GIT_OPERATION_DELAY {
- return Some(Content {
- icon: Some(
- Icon::new(IconName::ArrowCircle)
- .size(IconSize::Small)
- .with_animation(
- "arrow-circle",
- Animation::new(Duration::from_secs(2)).repeat(),
- |icon, delta| {
- icon.transform(Transformation::rotate(percentage(delta)))
- },
- )
- .into_any_element(),
- ),
- message: job_info.message.into(),
- on_click: None,
- tooltip_message: None,
- });
- }
+ if let Some(job_info) = current_job
+ && Instant::now() - job_info.start >= GIT_OPERATION_DELAY
+ {
+ return Some(Content {
+ icon: Some(
+ Icon::new(IconName::ArrowCircle)
+ .size(IconSize::Small)
+ .with_animation(
+ "arrow-circle",
+ Animation::new(Duration::from_secs(2)).repeat(),
+ |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
+ )
+ .into_any_element(),
+ ),
+ message: job_info.message.into(),
+ on_click: None,
+ tooltip_message: None,
+ });
}
// Show any language server installation info.
@@ -702,7 +695,7 @@ impl ActivityIndicator {
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
})),
- tooltip_message: Some(Self::version_tooltip_message(&version)),
+ tooltip_message: Some(Self::version_tooltip_message(version)),
}),
AutoUpdateStatus::Installing { version } => Some(Content {
icon: Some(
@@ -714,13 +707,13 @@ impl ActivityIndicator {
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
})),
- tooltip_message: Some(Self::version_tooltip_message(&version)),
+ tooltip_message: Some(Self::version_tooltip_message(version)),
}),
AutoUpdateStatus::Updated { version } => Some(Content {
icon: None,
message: "Click to restart and update Zed".to_string(),
on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))),
- tooltip_message: Some(Self::version_tooltip_message(&version)),
+ tooltip_message: Some(Self::version_tooltip_message(version)),
}),
AutoUpdateStatus::Errored => Some(Content {
icon: Some(
@@ -740,21 +733,20 @@ impl ActivityIndicator {
if let Some(extension_store) =
ExtensionStore::try_global(cx).map(|extension_store| extension_store.read(cx))
+ && let Some(extension_id) = extension_store.outstanding_operations().keys().next()
{
- if let Some(extension_id) = extension_store.outstanding_operations().keys().next() {
- return Some(Content {
- icon: Some(
- Icon::new(IconName::Download)
- .size(IconSize::Small)
- .into_any_element(),
- ),
- message: format!("Updating {extension_id} extension…"),
- on_click: Some(Arc::new(|this, window, cx| {
- this.dismiss_error_message(&DismissErrorMessage, window, cx)
- })),
- tooltip_message: None,
- });
- }
+ return Some(Content {
+ icon: Some(
+ Icon::new(IconName::Download)
+ .size(IconSize::Small)
+ .into_any_element(),
+ ),
+ message: format!("Updating {extension_id} extension…"),
+ on_click: Some(Arc::new(|this, window, cx| {
+ this.dismiss_error_message(&DismissErrorMessage, window, cx)
+ })),
+ tooltip_message: None,
+ });
}
None
diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml
index 53ad2f4967..391abb38fe 100644
--- a/crates/agent/Cargo.toml
+++ b/crates/agent/Cargo.toml
@@ -31,7 +31,6 @@ collections.workspace = true
component.workspace = true
context_server.workspace = true
convert_case.workspace = true
-feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
git.workspace = true
diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs
index 38e697dd9b..1636508df6 100644
--- a/crates/agent/src/agent_profile.rs
+++ b/crates/agent/src/agent_profile.rs
@@ -90,7 +90,7 @@ impl AgentProfile {
return false;
};
- return Self::is_enabled(settings, source, tool_name);
+ Self::is_enabled(settings, source, tool_name)
}
fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs
index 8cdb87ef8d..9bb8fc0eae 100644
--- a/crates/agent/src/context.rs
+++ b/crates/agent/src/context.rs
@@ -201,24 +201,24 @@ impl FileContextHandle {
parse_status.changed().await.log_err();
}
- if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot()) {
- if let Some(outline) = snapshot.outline(None) {
- let items = outline
- .items
- .into_iter()
- .map(|item| item.to_point(&snapshot));
+ if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot())
+ && let Some(outline) = snapshot.outline(None)
+ {
+ let items = outline
+ .items
+ .into_iter()
+ .map(|item| item.to_point(&snapshot));
- if let Ok(outline_text) =
- outline::render_outline(items, None, 0, usize::MAX).await
- {
- let context = AgentContext::File(FileContext {
- handle: self,
- full_path,
- text: outline_text.into(),
- is_outline: true,
- });
- return Some((context, vec![buffer]));
- }
+ if let Ok(outline_text) =
+ outline::render_outline(items, None, 0, usize::MAX).await
+ {
+ let context = AgentContext::File(FileContext {
+ handle: self,
+ full_path,
+ text: outline_text.into(),
+ is_outline: true,
+ });
+ return Some((context, vec![buffer]));
}
}
}
diff --git a/crates/agent/src/context_store.rs b/crates/agent/src/context_store.rs
index 60ba5527dc..b531852a18 100644
--- a/crates/agent/src/context_store.rs
+++ b/crates/agent/src/context_store.rs
@@ -338,11 +338,9 @@ impl ContextStore {
image_task,
context_id: self.next_context_id.post_inc(),
});
- if self.has_context(&context) {
- if remove_if_exists {
- self.remove_context(&context, cx);
- return None;
- }
+ if self.has_context(&context) && remove_if_exists {
+ self.remove_context(&context, cx);
+ return None;
}
self.insert_context(context.clone(), cx);
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 5491842185..fc91e1bb62 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -9,14 +9,16 @@ use crate::{
tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
};
use action_log::ActionLog;
-use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
+use agent_settings::{
+ AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
+ SUMMARIZE_THREAD_PROMPT,
+};
use anyhow::{Result, anyhow};
use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::HashMap;
-use feature_flags::{self, FeatureFlagAppExt};
use futures::{FutureExt, StreamExt as _, future::Shared};
use git::repository::DiffType;
use gpui::{
@@ -108,7 +110,7 @@ impl std::fmt::Display for PromptId {
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
-pub struct MessageId(pub(crate) usize);
+pub struct MessageId(pub usize);
impl MessageId {
fn post_inc(&mut self) -> Self {
@@ -388,7 +390,6 @@ pub struct Thread {
feedback: Option,
retry_state: Option,
message_feedback: HashMap,
- last_auto_capture_at: Option,
last_received_chunk_at: Option,
request_callback: Option<
Box])>,
@@ -489,7 +490,6 @@ impl Thread {
feedback: None,
retry_state: None,
message_feedback: HashMap::default(),
- last_auto_capture_at: None,
last_error_context: None,
last_received_chunk_at: None,
request_callback: None,
@@ -614,7 +614,6 @@ impl Thread {
tool_use_limit_reached: serialized.tool_use_limit_reached,
feedback: None,
message_feedback: HashMap::default(),
- last_auto_capture_at: None,
last_error_context: None,
last_received_chunk_at: None,
request_callback: None,
@@ -1033,8 +1032,6 @@ impl Thread {
});
}
- self.auto_capture_telemetry(cx);
-
message_id
}
@@ -1651,15 +1648,13 @@ impl Thread {
self.tool_use
.request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
- let pending_tool_use = self.tool_use.insert_tool_output(
+ self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
tool_output,
self.configured_model.as_ref(),
self.completion_mode,
- );
-
- pending_tool_use
+ )
}
pub fn stream_completion(
@@ -1692,7 +1687,7 @@ impl Thread {
self.last_received_chunk_at = Some(Instant::now());
let task = cx.spawn(async move |thread, cx| {
- let stream_completion_future = model.stream_completion(request, &cx);
+ let stream_completion_future = model.stream_completion(request, cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
@@ -1824,7 +1819,7 @@ impl Thread {
let streamed_input = if tool_use.is_input_complete {
None
} else {
- Some((&tool_use.input).clone())
+ Some(tool_use.input.clone())
};
let ui_text = thread.tool_use.request_tool_use(
@@ -1906,7 +1901,6 @@ impl Thread {
cx.emit(ThreadEvent::StreamedCompletion);
cx.notify();
- thread.auto_capture_telemetry(cx);
Ok(())
})??;
@@ -1974,11 +1968,9 @@ impl Thread {
if let Some(prev_message) =
thread.messages.get(ix - 1)
- {
- if prev_message.role == Role::Assistant {
+ && prev_message.role == Role::Assistant {
break;
}
- }
}
}
@@ -2051,7 +2043,7 @@ impl Thread {
retry_scheduled = thread
.handle_retryable_error_with_delay(
- &completion_error,
+ completion_error,
Some(retry_strategy),
model.clone(),
intent,
@@ -2081,8 +2073,6 @@ impl Thread {
request_callback(request, response_events);
}
- thread.auto_capture_telemetry(cx);
-
if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage - initial_usage;
@@ -2130,7 +2120,7 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
let result = async {
- let mut messages = model.model.stream_completion(request, &cx).await?;
+ let mut messages = model.model.stream_completion(request, cx).await?;
let mut new_summary = String::new();
while let Some(event) = messages.next().await {
@@ -2438,12 +2428,10 @@ impl Thread {
return;
}
- let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");
-
let request = self.to_summarize_request(
&model,
CompletionIntent::ThreadContextSummarization,
- added_user_message.into(),
+ SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
cx,
);
@@ -2456,7 +2444,7 @@ impl Thread {
// which result to prefer (the old task could complete after the new one, resulting in a
// stale summary).
self.detailed_summary_task = cx.spawn(async move |thread, cx| {
- let stream = model.stream_completion_text(request, &cx);
+ let stream = model.stream_completion_text(request, cx);
let Some(mut messages) = stream.await.log_err() else {
thread
.update(cx, |thread, _cx| {
@@ -2485,13 +2473,13 @@ impl Thread {
.ok()?;
// Save thread so its summary can be reused later
- if let Some(thread) = thread.upgrade() {
- if let Ok(Ok(save_task)) = cx.update(|cx| {
+ if let Some(thread) = thread.upgrade()
+ && let Ok(Ok(save_task)) = cx.update(|cx| {
thread_store
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
- }) {
- save_task.await.log_err();
- }
+ })
+ {
+ save_task.await.log_err();
}
Some(())
@@ -2536,7 +2524,6 @@ impl Thread {
model: Arc,
cx: &mut Context,
) -> Vec {
- self.auto_capture_telemetry(cx);
let request =
Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
let pending_tool_uses = self
@@ -2740,13 +2727,11 @@ impl Thread {
window: Option,
cx: &mut Context,
) {
- if self.all_tools_finished() {
- if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
- if !canceled {
- self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
- }
- self.auto_capture_telemetry(cx);
- }
+ if self.all_tools_finished()
+ && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
+ && !canceled
+ {
+ self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
}
cx.emit(ThreadEvent::ToolFinished {
@@ -2933,11 +2918,11 @@ impl Thread {
let buffer_store = project.read(app_cx).buffer_store();
for buffer_handle in buffer_store.read(app_cx).buffers() {
let buffer = buffer_handle.read(app_cx);
- if buffer.is_dirty() {
- if let Some(file) = buffer.file() {
- let path = file.path().to_string_lossy().to_string();
- unsaved_buffers.push(path);
- }
+ if buffer.is_dirty()
+ && let Some(file) = buffer.file()
+ {
+ let path = file.path().to_string_lossy().to_string();
+ unsaved_buffers.push(path);
}
}
})
@@ -3147,50 +3132,6 @@ impl Thread {
&self.project
}
- pub fn auto_capture_telemetry(&mut self, cx: &mut Context) {
- if !cx.has_flag::() {
- return;
- }
-
- let now = Instant::now();
- if let Some(last) = self.last_auto_capture_at {
- if now.duration_since(last).as_secs() < 10 {
- return;
- }
- }
-
- self.last_auto_capture_at = Some(now);
-
- let thread_id = self.id().clone();
- let github_login = self
- .project
- .read(cx)
- .user_store()
- .read(cx)
- .current_user()
- .map(|user| user.github_login.clone());
- let client = self.project.read(cx).client();
- let serialize_task = self.serialize(cx);
-
- cx.background_executor()
- .spawn(async move {
- if let Ok(serialized_thread) = serialize_task.await {
- if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
- telemetry::event!(
- "Agent Thread Auto-Captured",
- thread_id = thread_id.to_string(),
- thread_data = thread_data,
- auto_capture_reason = "tracked_user",
- github_login = github_login
- );
-
- client.telemetry().flush_events().await;
- }
- }
- })
- .detach();
- }
-
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage
}
@@ -3233,13 +3174,13 @@ impl Thread {
.model
.max_token_count_for_mode(self.completion_mode().into());
- if let Some(exceeded_error) = &self.exceeded_window_error {
- if model.model.id() == exceeded_error.model_id {
- return Some(TotalTokenUsage {
- total: exceeded_error.token_count,
- max,
- });
- }
+ if let Some(exceeded_error) = &self.exceeded_window_error
+ && model.model.id() == exceeded_error.model_id
+ {
+ return Some(TotalTokenUsage {
+ total: exceeded_error.token_count,
+ max,
+ });
}
let total = self
@@ -4043,7 +3984,7 @@ fn main() {{
});
let fake_model = model.as_fake();
- simulate_successful_response(&fake_model, cx);
+ simulate_successful_response(fake_model, cx);
// Should start generating summary when there are >= 2 messages
thread.read_with(cx, |thread, _| {
@@ -4138,7 +4079,7 @@ fn main() {{
});
let fake_model = model.as_fake();
- simulate_successful_response(&fake_model, cx);
+ simulate_successful_response(fake_model, cx);
thread.read_with(cx, |thread, _| {
// State is still Error, not Generating
@@ -5420,7 +5361,7 @@ fn main() {{
});
let fake_model = model.as_fake();
- simulate_successful_response(&fake_model, cx);
+ simulate_successful_response(fake_model, cx);
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Generating));
diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs
index 12c94a522d..45e551dbdf 100644
--- a/crates/agent/src/thread_store.rs
+++ b/crates/agent/src/thread_store.rs
@@ -42,7 +42,7 @@ use std::{
use util::ResultExt as _;
pub static ZED_STATELESS: std::sync::LazyLock =
- std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
+ std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
@@ -74,7 +74,7 @@ impl Column for DataType {
}
}
-const RULES_FILE_NAMES: [&'static str; 9] = [
+const RULES_FILE_NAMES: [&str; 9] = [
".rules",
".cursorrules",
".windsurfrules",
@@ -581,33 +581,32 @@ impl ThreadStore {
return;
};
- if protocol.capable(context_server::protocol::ServerCapability::Tools) {
- if let Some(response) = protocol
+ if protocol.capable(context_server::protocol::ServerCapability::Tools)
+ && let Some(response) = protocol
.request::(())
.await
.log_err()
- {
- let tool_ids = tool_working_set
- .update(cx, |tool_working_set, cx| {
- tool_working_set.extend(
- response.tools.into_iter().map(|tool| {
- Arc::new(ContextServerTool::new(
- context_server_store.clone(),
- server.id(),
- tool,
- )) as Arc
- }),
- cx,
- )
- })
- .log_err();
+ {
+ let tool_ids = tool_working_set
+ .update(cx, |tool_working_set, cx| {
+ tool_working_set.extend(
+ response.tools.into_iter().map(|tool| {
+ Arc::new(ContextServerTool::new(
+ context_server_store.clone(),
+ server.id(),
+ tool,
+ )) as Arc
+ }),
+ cx,
+ )
+ })
+ .log_err();
- if let Some(tool_ids) = tool_ids {
- this.update(cx, |this, _| {
- this.context_server_tool_ids.insert(server_id, tool_ids);
- })
- .log_err();
- }
+ if let Some(tool_ids) = tool_ids {
+ this.update(cx, |this, _| {
+ this.context_server_tool_ids.insert(server_id, tool_ids);
+ })
+ .log_err();
}
}
})
@@ -697,13 +696,14 @@ impl SerializedThreadV0_1_0 {
let mut messages: Vec = Vec::with_capacity(self.0.messages.len());
for message in self.0.messages {
- if message.role == Role::User && !message.tool_results.is_empty() {
- if let Some(last_message) = messages.last_mut() {
- debug_assert!(last_message.role == Role::Assistant);
+ if message.role == Role::User
+ && !message.tool_results.is_empty()
+ && let Some(last_message) = messages.last_mut()
+ {
+ debug_assert!(last_message.role == Role::Assistant);
- last_message.tool_results = message.tool_results;
- continue;
- }
+ last_message.tool_results = message.tool_results;
+ continue;
}
messages.push(message);
@@ -893,7 +893,7 @@ impl ThreadsDatabase {
let needs_migration_from_heed = mdb_path.exists();
- let connection = if *ZED_STATELESS {
+ let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
} else {
Connection::open_file(&sqlite_path.to_string_lossy())
diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs
index 74dfaf9a85..962dca591f 100644
--- a/crates/agent/src/tool_use.rs
+++ b/crates/agent/src/tool_use.rs
@@ -112,19 +112,13 @@ impl ToolUseState {
},
);
- if let Some(window) = &mut window {
- if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
- if let Some(output) = tool_result.output.clone() {
- if let Some(card) = tool.deserialize_card(
- output,
- project.clone(),
- window,
- cx,
- ) {
- this.tool_result_cards.insert(tool_use_id, card);
- }
- }
- }
+ if let Some(window) = &mut window
+ && let Some(tool) = this.tools.read(cx).tool(tool_use, cx)
+ && let Some(output) = tool_result.output.clone()
+ && let Some(card) =
+ tool.deserialize_card(output, project.clone(), window, cx)
+ {
+ this.tool_result_cards.insert(tool_use_id, card);
}
}
}
@@ -281,7 +275,7 @@ impl ToolUseState {
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
- .map_or(false, |results| !results.is_empty())
+ .is_some_and(|results| !results.is_empty())
}
pub fn tool_result(
diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml
index ac1840e5e5..2a39440af8 100644
--- a/crates/agent2/Cargo.toml
+++ b/crates/agent2/Cargo.toml
@@ -8,24 +8,31 @@ license = "GPL-3.0-or-later"
[lib]
path = "src/agent2.rs"
+[features]
+test-support = ["db/test-support"]
+
[lints]
workspace = true
[dependencies]
acp_thread.workspace = true
action_log.workspace = true
+agent.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
agent_settings.workspace = true
anyhow.workspace = true
+assistant_context.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
chrono.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
context_server.workspace = true
+db.workspace = true
fs.workspace = true
futures.workspace = true
+git.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
@@ -37,6 +44,7 @@ language_model.workspace = true
language_models.workspace = true
log.workspace = true
open.workspace = true
+parking_lot.workspace = true
paths.workspace = true
portable-pty.workspace = true
project.workspace = true
@@ -47,6 +55,7 @@ serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
+sqlez.workspace = true
task.workspace = true
terminal.workspace = true
text.workspace = true
@@ -57,15 +66,20 @@ watch.workspace = true
web_search.workspace = true
which.workspace = true
workspace-hack.workspace = true
+zstd.workspace = true
[dev-dependencies]
+agent = { workspace = true, "features" = ["test-support"] }
+assistant_context = { workspace = true, "features" = ["test-support"] }
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
context_server = { workspace = true, "features" = ["test-support"] }
+db = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
+git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }
diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs
index af740d9901..3c605de803 100644
--- a/crates/agent2/src/agent.rs
+++ b/crates/agent2/src/agent.rs
@@ -1,10 +1,10 @@
+use crate::HistoryStore;
use crate::{
- AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
- DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
- MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
- ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
+ ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
+ UserMessageContent, templates::Templates,
};
-use acp_thread::AgentModelSelector;
+use acp_thread::{AcpThread, AgentModelSelector};
+use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
@@ -22,14 +22,13 @@ use prompt_store::{
};
use settings::update_settings_file;
use std::any::Any;
-use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use util::ResultExt;
-const RULES_FILE_NAMES: [&'static str; 9] = [
+const RULES_FILE_NAMES: [&str; 9] = [
".rules",
".cursorrules",
".windsurfrules",
@@ -51,7 +50,8 @@ struct Session {
thread: Entity,
/// The ACP thread that handles protocol communication
acp_thread: WeakEntity,
- _subscription: Subscription,
+ pending_save: Task<()>,
+ _subscriptions: Vec,
}
pub struct LanguageModels {
@@ -91,7 +91,7 @@ impl LanguageModels {
for provider in &providers {
for model in provider.recommended_models(cx) {
recommended_models.insert(model.id());
- recommended.push(Self::map_language_model_to_info(&model, &provider));
+ recommended.push(Self::map_language_model_to_info(&model, provider));
}
}
if !recommended.is_empty() {
@@ -155,8 +155,9 @@ impl LanguageModels {
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap,
+ history: Entity,
/// Shared project context for all threads
- project_context: Rc>,
+ project_context: Entity,
project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task>,
context_server_registry: Entity,
@@ -173,6 +174,7 @@ pub struct NativeAgent {
impl NativeAgent {
pub async fn new(
project: Entity,
+ history: Entity,
templates: Arc,
prompt_store: Option>,
fs: Arc,
@@ -200,7 +202,8 @@ impl NativeAgent {
watch::channel(());
Self {
sessions: HashMap::new(),
- project_context: Rc::new(RefCell::new(project_context)),
+ history,
+ project_context: cx.new(|_| project_context),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
@@ -218,6 +221,55 @@ impl NativeAgent {
})
}
+ fn register_session(
+ &mut self,
+ thread_handle: Entity,
+ cx: &mut Context,
+ ) -> Entity {
+ let connection = Rc::new(NativeAgentConnection(cx.entity()));
+ let registry = LanguageModelRegistry::read_global(cx);
+ let summarization_model = registry.thread_summary_model().map(|c| c.model);
+
+ thread_handle.update(cx, |thread, cx| {
+ thread.set_summarization_model(summarization_model, cx);
+ thread.add_default_tools(cx)
+ });
+
+ let thread = thread_handle.read(cx);
+ let session_id = thread.id().clone();
+ let title = thread.title();
+ let project = thread.project.clone();
+ let action_log = thread.action_log.clone();
+ let acp_thread = cx.new(|_cx| {
+ acp_thread::AcpThread::new(
+ title,
+ connection,
+ project.clone(),
+ action_log.clone(),
+ session_id.clone(),
+ )
+ });
+ let subscriptions = vec![
+ cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
+ this.sessions.remove(acp_thread.session_id());
+ }),
+ cx.observe(&thread_handle, move |this, thread, cx| {
+ this.save_thread(thread.clone(), cx)
+ }),
+ ];
+
+ self.sessions.insert(
+ session_id,
+ Session {
+ thread: thread_handle,
+ acp_thread: acp_thread.downgrade(),
+ _subscriptions: subscriptions,
+ pending_save: Task::ready(()),
+ },
+ );
+ acp_thread
+ }
+
pub fn models(&self) -> &LanguageModels {
&self.models
}
@@ -233,7 +285,9 @@ impl NativeAgent {
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
})?
.await;
- this.update(cx, |this, _| this.project_context.replace(project_context))?;
+ this.update(cx, |this, cx| {
+ this.project_context = cx.new(|_| project_context);
+ })?;
}
Ok(())
@@ -426,21 +480,101 @@ impl NativeAgent {
) {
self.models.refresh_list(cx);
- let default_model = LanguageModelRegistry::read_global(cx)
- .default_model()
- .map(|m| m.model.clone());
+ let registry = LanguageModelRegistry::read_global(cx);
+ let default_model = registry.default_model().map(|m| m.model.clone());
+ let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
if thread.model().is_none()
&& let Some(model) = default_model.clone()
{
- thread.set_model(model);
+ thread.set_model(model, cx);
cx.notify();
}
+ thread.set_summarization_model(summarization_model.clone(), cx);
});
}
}
+
+ pub fn open_thread(
+ &mut self,
+ id: acp::SessionId,
+ cx: &mut Context,
+ ) -> Task>> {
+ let database_future = ThreadsDatabase::connect(cx);
+ cx.spawn(async move |this, cx| {
+ let database = database_future.await.map_err(|err| anyhow!(err))?;
+ let db_thread = database
+ .load_thread(id.clone())
+ .await?
+ .with_context(|| format!("no thread found with ID: {id:?}"))?;
+
+ let thread = this.update(cx, |this, cx| {
+ let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
+ cx.new(|cx| {
+ Thread::from_db(
+ id.clone(),
+ db_thread,
+ this.project.clone(),
+ this.project_context.clone(),
+ this.context_server_registry.clone(),
+ action_log.clone(),
+ this.templates.clone(),
+ cx,
+ )
+ })
+ })?;
+ let acp_thread =
+ this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
+ let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
+ cx.update(|cx| {
+ NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
+ })?
+ .await?;
+ Ok(acp_thread)
+ })
+ }
+
+ pub fn thread_summary(
+ &mut self,
+ id: acp::SessionId,
+ cx: &mut Context,
+ ) -> Task> {
+ let thread = self.open_thread(id.clone(), cx);
+ cx.spawn(async move |this, cx| {
+ let acp_thread = thread.await?;
+ let result = this
+ .update(cx, |this, cx| {
+ this.sessions
+ .get(&id)
+ .unwrap()
+ .thread
+ .update(cx, |thread, cx| thread.summary(cx))
+ })?
+ .await?;
+ drop(acp_thread);
+ Ok(result)
+ })
+ }
+
+ fn save_thread(&mut self, thread: Entity, cx: &mut Context) {
+ let database_future = ThreadsDatabase::connect(cx);
+ let (id, db_thread) =
+ thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
+ let Some(session) = self.sessions.get_mut(&id) else {
+ return;
+ };
+ let history = self.history.clone();
+ session.pending_save = cx.spawn(async move |_, cx| {
+ let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
+ return;
+ };
+ let db_thread = db_thread.await;
+ database.save_thread(id, db_thread).await.log_err();
+ history.update(cx, |history, cx| history.reload(cx)).ok();
+ });
+ }
}
/// Wrapper struct that implements the AgentConnection trait
@@ -461,10 +595,7 @@ impl NativeAgentConnection {
session_id: acp::SessionId,
cx: &mut App,
f: impl 'static
- + FnOnce(
- Entity,
- &mut App,
- ) -> Result>>,
+ + FnOnce(Entity, &mut App) -> Result>>,
) -> Task> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent
@@ -476,19 +607,38 @@ impl NativeAgentConnection {
};
log::debug!("Found session for: {}", session_id);
- let mut response_stream = match f(thread, cx) {
+ let response_stream = match f(thread, cx) {
Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)),
};
+ Self::handle_thread_events(response_stream, acp_thread, cx)
+ }
+
+ fn handle_thread_events(
+ mut events: mpsc::UnboundedReceiver>,
+ acp_thread: WeakEntity,
+ cx: &App,
+ ) -> Task> {
cx.spawn(async move |cx| {
// Handle response stream and forward to session.acp_thread
- while let Some(result) = response_stream.next().await {
+ while let Some(result) = events.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
- AgentResponseEvent::Text(text) => {
+ ThreadEvent::UserMessage(message) => {
+ acp_thread.update(cx, |thread, cx| {
+ for content in message.content {
+ thread.push_user_content_block(
+ Some(message.id.clone()),
+ content.into(),
+ cx,
+ );
+ }
+ })?;
+ }
+ ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@@ -500,7 +650,7 @@ impl NativeAgentConnection {
)
})?;
}
- AgentResponseEvent::Thinking(text) => {
+ ThreadEvent::AgentThinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@@ -512,7 +662,7 @@ impl NativeAgentConnection {
)
})?;
}
- AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
+ ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
@@ -535,17 +685,31 @@ impl NativeAgentConnection {
})
.detach();
}
- AgentResponseEvent::ToolCall(tool_call) => {
+ ThreadEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})??;
}
- AgentResponseEvent::ToolCallUpdate(update) => {
+ ThreadEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
- AgentResponseEvent::Stop(stop_reason) => {
+ ThreadEvent::TokenUsageUpdate(usage) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.update_token_usage(Some(usage), cx)
+ })?;
+ }
+ ThreadEvent::TitleUpdate(title) => {
+ acp_thread
+ .update(cx, |thread, cx| thread.update_title(title, cx))??;
+ }
+ ThreadEvent::Retry(status) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.update_retry_status(status, cx)
+ })?;
+ }
+ ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
@@ -598,8 +762,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
- thread.update(cx, |thread, _cx| {
- thread.set_model(model.clone());
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
});
update_settings_file::(
@@ -659,31 +823,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
- // Generate session ID
- let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
- log::info!("Created session with ID: {}", session_id);
-
- // Create AcpThread
- let acp_thread = cx.update(|cx| {
- cx.new(|cx| {
- acp_thread::AcpThread::new(
- "agent2",
- self.clone(),
- project.clone(),
- session_id.clone(),
- cx,
- )
- })
- })?;
- let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
-
+ let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
// Create Thread
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
-
// Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
@@ -695,7 +841,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
});
let thread = cx.new(|cx| {
- let mut thread = Thread::new(
+ Thread::new(
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
@@ -703,45 +849,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
agent.templates.clone(),
default_model,
cx,
- );
- thread.add_tool(CopyPathTool::new(project.clone()));
- thread.add_tool(CreateDirectoryTool::new(project.clone()));
- thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
- thread.add_tool(DiagnosticsTool::new(project.clone()));
- thread.add_tool(EditFileTool::new(cx.entity()));
- thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
- thread.add_tool(FindPathTool::new(project.clone()));
- thread.add_tool(GrepTool::new(project.clone()));
- thread.add_tool(ListDirectoryTool::new(project.clone()));
- thread.add_tool(MovePathTool::new(project.clone()));
- thread.add_tool(NowTool);
- thread.add_tool(OpenTool::new(project.clone()));
- thread.add_tool(ReadFileTool::new(project.clone(), action_log));
- thread.add_tool(TerminalTool::new(project.clone(), cx));
- thread.add_tool(ThinkingTool);
- thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
- thread
+ )
});
Ok(thread)
},
)??;
-
- // Store the session
- agent.update(cx, |agent, cx| {
- agent.sessions.insert(
- session_id,
- Session {
- thread,
- acp_thread: acp_thread.downgrade(),
- _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
- this.sessions.remove(acp_thread.session_id());
- }),
- },
- );
- })?;
-
- Ok(acp_thread)
+ agent.update(cx, |agent, cx| agent.register_session(thread, cx))
})
}
@@ -797,7 +911,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) {
- agent.thread.update(cx, |thread, _cx| thread.cancel());
+ agent.thread.update(cx, |thread, cx| thread.cancel(cx));
}
});
}
@@ -808,10 +922,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx: &mut App,
) -> Option> {
self.0.update(cx, |agent, _cx| {
- agent
- .sessions
- .get(session_id)
- .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
+ agent.sessions.get(session_id).map(|session| {
+ Rc::new(NativeAgentSessionEditor {
+ thread: session.thread.clone(),
+ acp_thread: session.acp_thread.clone(),
+ }) as _
+ })
})
}
@@ -820,11 +936,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}
}
-struct NativeAgentSessionEditor(Entity);
+struct NativeAgentSessionEditor {
+ thread: Entity,
+ acp_thread: WeakEntity,
+}
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> {
- Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
+ match self.thread.update(cx, |thread, cx| {
+ thread.truncate(message_id.clone(), cx)?;
+ Ok(thread.latest_token_usage())
+ }) {
+ Ok(usage) => {
+ self.acp_thread
+ .update(cx, |thread, cx| {
+ thread.update_token_usage(usage, cx);
+ })
+ .ok();
+ Task::ready(Ok(()))
+ }
+ Err(error) => Task::ready(Err(error)),
+ }
}
}
@@ -863,8 +995,11 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
let agent = NativeAgent::new(
project.clone(),
+ history_store,
Templates::new(),
None,
fs.clone(),
@@ -872,8 +1007,8 @@ mod tests {
)
.await
.unwrap();
- agent.read_with(cx, |agent, _| {
- assert_eq!(agent.project_context.borrow().worktrees, vec![])
+ agent.read_with(cx, |agent, cx| {
+ assert_eq!(agent.project_context.read(cx).worktrees, vec![])
});
let worktree = project
@@ -881,9 +1016,9 @@ mod tests {
.await
.unwrap();
cx.run_until_parked();
- agent.read_with(cx, |agent, _| {
+ agent.read_with(cx, |agent, cx| {
assert_eq!(
- agent.project_context.borrow().worktrees,
+ agent.project_context.read(cx).worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
@@ -898,7 +1033,7 @@ mod tests {
agent.read_with(cx, |agent, cx| {
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
assert_eq!(
- agent.project_context.borrow().worktrees,
+ agent.project_context.read(cx).worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
@@ -918,9 +1053,12 @@ mod tests {
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
let connection = NativeAgentConnection(
NativeAgent::new(
project.clone(),
+ history_store,
Templates::new(),
None,
fs.clone(),
@@ -971,9 +1109,13 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [], cx).await;
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
+
// Create the agent and connection
let agent = NativeAgent::new(
project.clone(),
+ history_store,
Templates::new(),
None,
fs.clone(),
diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs
index f13cd1bd67..1fc9c1cb95 100644
--- a/crates/agent2/src/agent2.rs
+++ b/crates/agent2/src/agent2.rs
@@ -1,13 +1,18 @@
mod agent;
+mod db;
+mod history_store;
mod native_agent_server;
mod templates;
mod thread;
+mod tool_schema;
mod tools;
#[cfg(test)]
mod tests;
pub use agent::*;
+pub use db::*;
+pub use history_store::*;
pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*;
diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs
new file mode 100644
index 0000000000..c6a6c38201
--- /dev/null
+++ b/crates/agent2/src/db.rs
@@ -0,0 +1,488 @@
+use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
+use acp_thread::UserMessageId;
+use agent::{thread::DetailedSummaryState, thread_store};
+use agent_client_protocol as acp;
+use agent_settings::{AgentProfileId, CompletionMode};
+use anyhow::{Result, anyhow};
+use chrono::{DateTime, Utc};
+use collections::{HashMap, IndexMap};
+use futures::{FutureExt, future::Shared};
+use gpui::{BackgroundExecutor, Global, Task};
+use indoc::indoc;
+use parking_lot::Mutex;
+use serde::{Deserialize, Serialize};
+use sqlez::{
+ bindable::{Bind, Column},
+ connection::Connection,
+ statement::Statement,
+};
+use std::sync::Arc;
+use ui::{App, SharedString};
+
+pub type DbMessage = crate::Message;
+pub type DbSummary = DetailedSummaryState;
+pub type DbLanguageModel = thread_store::SerializedLanguageModel;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbThreadMetadata {
+ pub id: acp::SessionId,
+ #[serde(alias = "summary")]
+ pub title: SharedString,
+ pub updated_at: DateTime,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct DbThread {
+ pub title: SharedString,
+ pub messages: Vec,
+ pub updated_at: DateTime,
+ #[serde(default)]
+ pub detailed_summary: Option,
+ #[serde(default)]
+ pub initial_project_snapshot: Option>,
+ #[serde(default)]
+ pub cumulative_token_usage: language_model::TokenUsage,
+ #[serde(default)]
+ pub request_token_usage: HashMap,
+ #[serde(default)]
+ pub model: Option,
+ #[serde(default)]
+ pub completion_mode: Option,
+ #[serde(default)]
+ pub profile: Option,
+}
+
+impl DbThread {
+ pub const VERSION: &'static str = "0.3.0";
+
+ pub fn from_json(json: &[u8]) -> Result {
+ let saved_thread_json = serde_json::from_slice::(json)?;
+ match saved_thread_json.get("version") {
+ Some(serde_json::Value::String(version)) => match version.as_str() {
+ Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
+ _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
+ },
+ _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
+ }
+ }
+
+ fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result {
+ let mut messages = Vec::new();
+ let mut request_token_usage = HashMap::default();
+
+ let mut last_user_message_id = None;
+ for (ix, msg) in thread.messages.into_iter().enumerate() {
+ let message = match msg.role {
+ language_model::Role::User => {
+ let mut content = Vec::new();
+
+ // Convert segments to content
+ for segment in msg.segments {
+ match segment {
+ thread_store::SerializedMessageSegment::Text { text } => {
+ content.push(UserMessageContent::Text(text));
+ }
+ thread_store::SerializedMessageSegment::Thinking { text, .. } => {
+ // User messages don't have thinking segments, but handle gracefully
+ content.push(UserMessageContent::Text(text));
+ }
+ thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
+ // User messages don't have redacted thinking, skip.
+ }
+ }
+ }
+
+ // If no content was added, add context as text if available
+ if content.is_empty() && !msg.context.is_empty() {
+ content.push(UserMessageContent::Text(msg.context));
+ }
+
+ let id = UserMessageId::new();
+ last_user_message_id = Some(id.clone());
+
+ crate::Message::User(UserMessage {
+ // MessageId from old format can't be meaningfully converted, so generate a new one
+ id,
+ content,
+ })
+ }
+ language_model::Role::Assistant => {
+ let mut content = Vec::new();
+
+ // Convert segments to content
+ for segment in msg.segments {
+ match segment {
+ thread_store::SerializedMessageSegment::Text { text } => {
+ content.push(AgentMessageContent::Text(text));
+ }
+ thread_store::SerializedMessageSegment::Thinking {
+ text,
+ signature,
+ } => {
+ content.push(AgentMessageContent::Thinking { text, signature });
+ }
+ thread_store::SerializedMessageSegment::RedactedThinking { data } => {
+ content.push(AgentMessageContent::RedactedThinking(data));
+ }
+ }
+ }
+
+ // Convert tool uses
+ let mut tool_names_by_id = HashMap::default();
+ for tool_use in msg.tool_uses {
+ tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
+ content.push(AgentMessageContent::ToolUse(
+ language_model::LanguageModelToolUse {
+ id: tool_use.id,
+ name: tool_use.name.into(),
+ raw_input: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ input: tool_use.input,
+ is_input_complete: true,
+ },
+ ));
+ }
+
+ // Convert tool results
+ let mut tool_results = IndexMap::default();
+ for tool_result in msg.tool_results {
+ let name = tool_names_by_id
+ .remove(&tool_result.tool_use_id)
+ .unwrap_or_else(|| SharedString::from("unknown"));
+ tool_results.insert(
+ tool_result.tool_use_id.clone(),
+ language_model::LanguageModelToolResult {
+ tool_use_id: tool_result.tool_use_id,
+ tool_name: name.into(),
+ is_error: tool_result.is_error,
+ content: tool_result.content,
+ output: tool_result.output,
+ },
+ );
+ }
+
+ if let Some(last_user_message_id) = &last_user_message_id
+ && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
+ {
+ request_token_usage.insert(last_user_message_id.clone(), token_usage);
+ }
+
+ crate::Message::Agent(AgentMessage {
+ content,
+ tool_results,
+ })
+ }
+ language_model::Role::System => {
+ // Skip system messages as they're not supported in the new format
+ continue;
+ }
+ };
+
+ messages.push(message);
+ }
+
+ Ok(Self {
+ title: thread.summary,
+ messages,
+ updated_at: thread.updated_at,
+ detailed_summary: match thread.detailed_summary_state {
+ DetailedSummaryState::NotGenerated | DetailedSummaryState::Generating { .. } => {
+ None
+ }
+ DetailedSummaryState::Generated { text, .. } => Some(text),
+ },
+ initial_project_snapshot: thread.initial_project_snapshot,
+ cumulative_token_usage: thread.cumulative_token_usage,
+ request_token_usage,
+ model: thread.model,
+ completion_mode: thread.completion_mode,
+ profile: thread.profile,
+ })
+ }
+}
+
+pub static ZED_STATELESS: std::sync::LazyLock =
+ std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum DataType {
+ #[serde(rename = "json")]
+ Json,
+ #[serde(rename = "zstd")]
+ Zstd,
+}
+
+impl Bind for DataType {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result {
+ let value = match self {
+ DataType::Json => "json",
+ DataType::Zstd => "zstd",
+ };
+ value.bind(statement, start_index)
+ }
+}
+
+impl Column for DataType {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let (value, next_index) = String::column(statement, start_index)?;
+ let data_type = match value.as_str() {
+ "json" => DataType::Json,
+ "zstd" => DataType::Zstd,
+ _ => anyhow::bail!("Unknown data type: {}", value),
+ };
+ Ok((data_type, next_index))
+ }
+}
+
+pub(crate) struct ThreadsDatabase {
+ executor: BackgroundExecutor,
+ connection: Arc>,
+}
+
+struct GlobalThreadsDatabase(Shared, Arc>>>);
+
+impl Global for GlobalThreadsDatabase {}
+
+impl ThreadsDatabase {
+ pub fn connect(cx: &mut App) -> Shared, Arc>>> {
+ if cx.has_global::() {
+ return cx.global::().0.clone();
+ }
+ let executor = cx.background_executor().clone();
+ let task = executor
+ .spawn({
+ let executor = executor.clone();
+ async move {
+ match ThreadsDatabase::new(executor) {
+ Ok(db) => Ok(Arc::new(db)),
+ Err(err) => Err(Arc::new(err)),
+ }
+ }
+ })
+ .shared();
+
+ cx.set_global(GlobalThreadsDatabase(task.clone()));
+ task
+ }
+
+ pub fn new(executor: BackgroundExecutor) -> Result {
+ let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
+ Connection::open_memory(Some("THREAD_FALLBACK_DB"))
+ } else {
+ let threads_dir = paths::data_dir().join("threads");
+ std::fs::create_dir_all(&threads_dir)?;
+ let sqlite_path = threads_dir.join("threads.db");
+ Connection::open_file(&sqlite_path.to_string_lossy())
+ };
+
+ connection.exec(indoc! {"
+ CREATE TABLE IF NOT EXISTS threads (
+ id TEXT PRIMARY KEY,
+ summary TEXT NOT NULL,
+ updated_at TEXT NOT NULL,
+ data_type TEXT NOT NULL,
+ data BLOB NOT NULL
+ )
+ "})?()
+ .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
+
+ let db = Self {
+ executor: executor.clone(),
+ connection: Arc::new(Mutex::new(connection)),
+ };
+
+ Ok(db)
+ }
+
+ fn save_thread_sync(
+ connection: &Arc>,
+ id: acp::SessionId,
+ thread: DbThread,
+ ) -> Result<()> {
+ const COMPRESSION_LEVEL: i32 = 3;
+
+ #[derive(Serialize)]
+ struct SerializedThread {
+ #[serde(flatten)]
+ thread: DbThread,
+ version: &'static str,
+ }
+
+ let title = thread.title.to_string();
+ let updated_at = thread.updated_at.to_rfc3339();
+ let json_data = serde_json::to_string(&SerializedThread {
+ thread,
+ version: DbThread::VERSION,
+ })?;
+
+ let connection = connection.lock();
+
+ let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
+ let data_type = DataType::Zstd;
+ let data = compressed;
+
+ let mut insert = connection.exec_bound::<(Arc, String, String, DataType, Vec)>(indoc! {"
+ INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
+ "})?;
+
+ insert((id.0.clone(), title, updated_at, data_type, data))?;
+
+ Ok(())
+ }
+
+ pub fn list_threads(&self) -> Task>> {
+ let connection = self.connection.clone();
+
+ self.executor.spawn(async move {
+ let connection = connection.lock();
+
+ let mut select =
+ connection.select_bound::<(), (Arc, String, String)>(indoc! {"
+ SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
+ "})?;
+
+ let rows = select(())?;
+ let mut threads = Vec::new();
+
+ for (id, summary, updated_at) in rows {
+ threads.push(DbThreadMetadata {
+ id: acp::SessionId(id),
+ title: summary.into(),
+ updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
+ });
+ }
+
+ Ok(threads)
+ })
+ }
+
+ pub fn load_thread(&self, id: acp::SessionId) -> Task>> {
+ let connection = self.connection.clone();
+
+ self.executor.spawn(async move {
+ let connection = connection.lock();
+ let mut select = connection.select_bound::, (DataType, Vec)>(indoc! {"
+ SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
+ "})?;
+
+ let rows = select(id.0)?;
+ if let Some((data_type, data)) = rows.into_iter().next() {
+ let json_data = match data_type {
+ DataType::Zstd => {
+ let decompressed = zstd::decode_all(&data[..])?;
+ String::from_utf8(decompressed)?
+ }
+ DataType::Json => String::from_utf8(data)?,
+ };
+ let thread = DbThread::from_json(json_data.as_bytes())?;
+ Ok(Some(thread))
+ } else {
+ Ok(None)
+ }
+ })
+ }
+
+ pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task> {
+ let connection = self.connection.clone();
+
+ self.executor
+ .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
+ }
+
+ pub fn delete_thread(&self, id: acp::SessionId) -> Task> {
+ let connection = self.connection.clone();
+
+ self.executor.spawn(async move {
+ let connection = connection.lock();
+
+ let mut delete = connection.exec_bound::>(indoc! {"
+ DELETE FROM threads WHERE id = ?
+ "})?;
+
+ delete(id.0)?;
+
+ Ok(())
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use super::*;
+ use agent::MessageSegment;
+ use agent::context::LoadedContext;
+ use client::Client;
+ use fs::FakeFs;
+ use gpui::AppContext;
+ use gpui::TestAppContext;
+ use http_client::FakeHttpClient;
+ use language_model::Role;
+ use project::Project;
+ use settings::SettingsStore;
+
+ fn init_test(cx: &mut TestAppContext) {
+ env_logger::try_init().ok();
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ language::init(cx);
+
+ let http_client = FakeHttpClient::with_404_response();
+ let clock = Arc::new(clock::FakeSystemClock::new());
+ let client = Client::new(clock, http_client, cx);
+ agent::init(cx);
+ agent_settings::init(cx);
+ language_model::init(client.clone(), cx);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+
+ // Save a thread using the old agent.
+ let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
+ let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
+ thread.update(cx, |thread, cx| {
+ thread.insert_message(
+ Role::User,
+ vec![MessageSegment::Text("Hey!".into())],
+ LoadedContext::default(),
+ vec![],
+ false,
+ cx,
+ );
+ thread.insert_message(
+ Role::Assistant,
+ vec![MessageSegment::Text("How're you doing?".into())],
+ LoadedContext::default(),
+ vec![],
+ false,
+ cx,
+ )
+ });
+ thread_store
+ .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
+ .await
+ .unwrap();
+
+ // Open that same thread using the new agent.
+ let db = cx.update(ThreadsDatabase::connect).await.unwrap();
+ let threads = db.list_threads().await.unwrap();
+ assert_eq!(threads.len(), 1);
+ let thread = db
+ .load_thread(threads[0].id.clone())
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
+ assert_eq!(
+ thread.messages[1].to_markdown(),
+ "## Assistant\n\nHow're you doing?\n"
+ );
+ }
+}
diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs
new file mode 100644
index 0000000000..7eb7da94ba
--- /dev/null
+++ b/crates/agent2/src/history_store.rs
@@ -0,0 +1,345 @@
+use crate::{DbThreadMetadata, ThreadsDatabase};
+use acp_thread::MentionUri;
+use agent_client_protocol as acp;
+use anyhow::{Context as _, Result, anyhow};
+use assistant_context::{AssistantContext, SavedContextMetadata};
+use chrono::{DateTime, Utc};
+use db::kvp::KEY_VALUE_STORE;
+use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
+use itertools::Itertools;
+use paths::contexts_dir;
+use serde::{Deserialize, Serialize};
+use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
+use util::ResultExt as _;
+
+const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
+const RECENTLY_OPENED_THREADS_KEY: &str = "recent-agent-threads";
+const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
+
+const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
+
+#[derive(Clone, Debug)]
+pub enum HistoryEntry {
+ AcpThread(DbThreadMetadata),
+ TextThread(SavedContextMetadata),
+}
+
+impl HistoryEntry {
+ pub fn updated_at(&self) -> DateTime {
+ match self {
+ HistoryEntry::AcpThread(thread) => thread.updated_at,
+ HistoryEntry::TextThread(context) => context.mtime.to_utc(),
+ }
+ }
+
+ pub fn id(&self) -> HistoryEntryId {
+ match self {
+ HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()),
+ HistoryEntry::TextThread(context) => HistoryEntryId::TextThread(context.path.clone()),
+ }
+ }
+
+ pub fn mention_uri(&self) -> MentionUri {
+ match self {
+ HistoryEntry::AcpThread(thread) => MentionUri::Thread {
+ id: thread.id.clone(),
+ name: thread.title.to_string(),
+ },
+ HistoryEntry::TextThread(context) => MentionUri::TextThread {
+ path: context.path.as_ref().to_owned(),
+ name: context.title.to_string(),
+ },
+ }
+ }
+
+ pub fn title(&self) -> &SharedString {
+ match self {
+ HistoryEntry::AcpThread(thread) if thread.title.is_empty() => DEFAULT_TITLE,
+ HistoryEntry::AcpThread(thread) => &thread.title,
+ HistoryEntry::TextThread(context) => &context.title,
+ }
+ }
+}
+
+/// Generic identifier for a history entry.
+#[derive(Clone, PartialEq, Eq, Debug, Hash)]
+pub enum HistoryEntryId {
+ AcpThread(acp::SessionId),
+ TextThread(Arc),
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+enum SerializedRecentOpen {
+ AcpThread(String),
+ TextThread(String),
+}
+
+pub struct HistoryStore {
+ threads: Vec,
+ context_store: Entity,
+ recently_opened_entries: VecDeque,
+ _subscriptions: Vec,
+ _save_recently_opened_entries_task: Task<()>,
+}
+
+impl HistoryStore {
+ pub fn new(
+ context_store: Entity,
+ cx: &mut Context,
+ ) -> Self {
+ let subscriptions = vec![cx.observe(&context_store, |_, _, cx| cx.notify())];
+
+ cx.spawn(async move |this, cx| {
+ let entries = Self::load_recently_opened_entries(cx).await;
+ this.update(cx, |this, cx| {
+ if let Some(entries) = entries.log_err() {
+ this.recently_opened_entries = entries;
+ }
+
+ this.reload(cx);
+ })
+ .ok();
+ })
+ .detach();
+
+ Self {
+ context_store,
+ recently_opened_entries: VecDeque::default(),
+ threads: Vec::default(),
+ _subscriptions: subscriptions,
+ _save_recently_opened_entries_task: Task::ready(()),
+ }
+ }
+
+ pub fn delete_thread(
+ &mut self,
+ id: acp::SessionId,
+ cx: &mut Context,
+ ) -> Task> {
+ let database_future = ThreadsDatabase::connect(cx);
+ cx.spawn(async move |this, cx| {
+ let database = database_future.await.map_err(|err| anyhow!(err))?;
+ database.delete_thread(id.clone()).await?;
+ this.update(cx, |this, cx| this.reload(cx))
+ })
+ }
+
+ pub fn delete_text_thread(
+ &mut self,
+ path: Arc,
+ cx: &mut Context,
+ ) -> Task> {
+ self.context_store.update(cx, |context_store, cx| {
+ context_store.delete_local_context(path, cx)
+ })
+ }
+
+ pub fn load_text_thread(
+ &self,
+ path: Arc,
+ cx: &mut Context,
+ ) -> Task>> {
+ self.context_store.update(cx, |context_store, cx| {
+ context_store.open_local_context(path, cx)
+ })
+ }
+
+ pub fn reload(&self, cx: &mut Context) {
+ let database_future = ThreadsDatabase::connect(cx);
+ cx.spawn(async move |this, cx| {
+ let threads = database_future
+ .await
+ .map_err(|err| anyhow!(err))?
+ .list_threads()
+ .await?;
+
+ this.update(cx, |this, cx| {
+ if this.recently_opened_entries.len() < MAX_RECENTLY_OPENED_ENTRIES {
+ for thread in threads
+ .iter()
+ .take(MAX_RECENTLY_OPENED_ENTRIES - this.recently_opened_entries.len())
+ .rev()
+ {
+ this.push_recently_opened_entry(
+ HistoryEntryId::AcpThread(thread.id.clone()),
+ cx,
+ )
+ }
+ }
+ this.threads = threads;
+ cx.notify();
+ })
+ })
+ .detach_and_log_err(cx);
+ }
+
+ pub fn entries(&self, cx: &App) -> Vec {
+ let mut history_entries = Vec::new();
+
+ #[cfg(debug_assertions)]
+ if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
+ return history_entries;
+ }
+
+ history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread));
+ history_entries.extend(
+ self.context_store
+ .read(cx)
+ .unordered_contexts()
+ .cloned()
+ .map(HistoryEntry::TextThread),
+ );
+
+ history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
+ history_entries
+ }
+
+ pub fn is_empty(&self, cx: &App) -> bool {
+ self.threads.is_empty()
+ && self
+ .context_store
+ .read(cx)
+ .unordered_contexts()
+ .next()
+ .is_none()
+ }
+
+ pub fn recently_opened_entries(&self, cx: &App) -> Vec {
+ #[cfg(debug_assertions)]
+ if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
+ return Vec::new();
+ }
+
+ let thread_entries = self.threads.iter().flat_map(|thread| {
+ self.recently_opened_entries
+ .iter()
+ .enumerate()
+ .flat_map(|(index, entry)| match entry {
+ HistoryEntryId::AcpThread(id) if &thread.id == id => {
+ Some((index, HistoryEntry::AcpThread(thread.clone())))
+ }
+ _ => None,
+ })
+ });
+
+ let context_entries =
+ self.context_store
+ .read(cx)
+ .unordered_contexts()
+ .flat_map(|context| {
+ self.recently_opened_entries
+ .iter()
+ .enumerate()
+ .flat_map(|(index, entry)| match entry {
+ HistoryEntryId::TextThread(path) if &context.path == path => {
+ Some((index, HistoryEntry::TextThread(context.clone())))
+ }
+ _ => None,
+ })
+ });
+
+ thread_entries
+ .chain(context_entries)
+ // optimization to halt iteration early
+ .take(self.recently_opened_entries.len())
+ .sorted_unstable_by_key(|(index, _)| *index)
+ .map(|(_, entry)| entry)
+ .collect()
+ }
+
+ fn save_recently_opened_entries(&mut self, cx: &mut Context) {
+ let serialized_entries = self
+ .recently_opened_entries
+ .iter()
+ .filter_map(|entry| match entry {
+ HistoryEntryId::TextThread(path) => path.file_name().map(|file| {
+ SerializedRecentOpen::TextThread(file.to_string_lossy().to_string())
+ }),
+ HistoryEntryId::AcpThread(id) => {
+ Some(SerializedRecentOpen::AcpThread(id.to_string()))
+ }
+ })
+ .collect::>();
+
+ self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
+ let content = serde_json::to_string(&serialized_entries).unwrap();
+ cx.background_executor()
+ .timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
+ .await;
+
+ if cfg!(any(feature = "test-support", test)) {
+ return;
+ }
+ KEY_VALUE_STORE
+ .write_kvp(RECENTLY_OPENED_THREADS_KEY.to_owned(), content)
+ .await
+ .log_err();
+ });
+ }
+
+ fn load_recently_opened_entries(cx: &AsyncApp) -> Task>> {
+ cx.background_spawn(async move {
+ if cfg!(any(feature = "test-support", test)) {
+ anyhow::bail!("history store does not persist in tests");
+ }
+ let json = KEY_VALUE_STORE
+ .read_kvp(RECENTLY_OPENED_THREADS_KEY)?
+ .unwrap_or("[]".to_string());
+ let entries = serde_json::from_str::>(&json)
+ .context("deserializing persisted agent panel navigation history")?
+ .into_iter()
+ .take(MAX_RECENTLY_OPENED_ENTRIES)
+ .flat_map(|entry| match entry {
+ SerializedRecentOpen::AcpThread(id) => Some(HistoryEntryId::AcpThread(
+ acp::SessionId(id.as_str().into()),
+ )),
+ SerializedRecentOpen::TextThread(file_name) => Some(
+ HistoryEntryId::TextThread(contexts_dir().join(file_name).into()),
+ ),
+ })
+ .collect();
+ Ok(entries)
+ })
+ }
+
+ pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context) {
+ self.recently_opened_entries
+ .retain(|old_entry| old_entry != &entry);
+ self.recently_opened_entries.push_front(entry);
+ self.recently_opened_entries
+ .truncate(MAX_RECENTLY_OPENED_ENTRIES);
+ self.save_recently_opened_entries(cx);
+ }
+
+ pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context) {
+ self.recently_opened_entries.retain(|entry| match entry {
+ HistoryEntryId::AcpThread(thread_id) if thread_id == &id => false,
+ _ => true,
+ });
+ self.save_recently_opened_entries(cx);
+ }
+
+ pub fn replace_recently_opened_text_thread(
+ &mut self,
+ old_path: &Path,
+ new_path: &Arc,
+ cx: &mut Context,
+ ) {
+ for entry in &mut self.recently_opened_entries {
+ match entry {
+ HistoryEntryId::TextThread(path) if path.as_ref() == old_path => {
+ *entry = HistoryEntryId::TextThread(new_path.clone());
+ break;
+ }
+ _ => {}
+ }
+ }
+ self.save_recently_opened_entries(cx);
+ }
+
+ pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context) {
+ self.recently_opened_entries
+ .retain(|old_entry| old_entry != entry);
+ self.save_recently_opened_entries(cx);
+ }
+}
diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs
index cadd88a846..74d24efb13 100644
--- a/crates/agent2/src/native_agent_server.rs
+++ b/crates/agent2/src/native_agent_server.rs
@@ -1,4 +1,4 @@
-use std::{path::Path, rc::Rc, sync::Arc};
+use std::{any::Any, path::Path, rc::Rc, sync::Arc};
use agent_servers::AgentServer;
use anyhow::Result;
@@ -7,16 +7,17 @@ use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
-use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
+use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
#[derive(Clone)]
pub struct NativeAgentServer {
fs: Arc,
+ history: Entity,
}
impl NativeAgentServer {
- pub fn new(fs: Arc) -> Self {
- Self { fs }
+ pub fn new(fs: Arc, history: Entity) -> Self {
+ Self { fs, history }
}
}
@@ -26,16 +27,15 @@ impl AgentServer for NativeAgentServer {
}
fn empty_state_headline(&self) -> &'static str {
- "Native Agent"
+ ""
}
fn empty_state_message(&self) -> &'static str {
- "How can I help you today?"
+ ""
}
fn logo(&self) -> ui::IconName {
- // Using the ZedAssistant icon as it's the native built-in agent
- ui::IconName::ZedAssistant
+ ui::IconName::ZedAgent
}
fn connect(
@@ -50,6 +50,7 @@ impl AgentServer for NativeAgentServer {
);
let project = project.clone();
let fs = self.fs.clone();
+ let history = self.history.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
@@ -57,7 +58,8 @@ impl AgentServer for NativeAgentServer {
let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity");
- let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
+ let agent =
+ NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);
@@ -66,4 +68,8 @@ impl AgentServer for NativeAgentServer {
Ok(Rc::new(connection) as Rc)
})
}
+
+ fn into_any(self: Rc) -> Rc {
+ self
+ }
}
diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs
index a63f0ad206..72a8f6633c 100644
--- a/crates/agent2/src/templates.rs
+++ b/crates/agent2/src/templates.rs
@@ -62,7 +62,7 @@ fn contains(
handlebars::RenderError::new("contains: missing or invalid query parameter")
})?;
- if list.contains(&query) {
+ if list.contains(query) {
out.write("true")?;
}
diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs
index e3e3050d49..55bfa6f0b5 100644
--- a/crates/agent2/src/tests/mod.rs
+++ b/crates/agent2/src/tests/mod.rs
@@ -6,15 +6,16 @@ use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
-use futures::channel::mpsc::UnboundedReceiver;
+use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
use gpui::{
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
- LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
- LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
- Role, StopReason, fake_provider::FakeLanguageModel,
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
+ LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
+ LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
+ fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
@@ -24,8 +25,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
-use smol::stream::StreamExt;
-use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
+use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
mod test_tools;
@@ -101,7 +101,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- project_context.borrow_mut().shell = "test-shell".into();
+ project_context.update(cx, |project_context, _cx| {
+ project_context.shell = "test-shell".into()
+ });
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread
.update(cx, |thread, cx| {
@@ -343,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await {
- if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
+ if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
let message = thread.last_message().unwrap();
@@ -733,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
);
}
-async fn expect_tool_call(
- events: &mut UnboundedReceiver>,
-) -> acp::ToolCall {
+async fn expect_tool_call(events: &mut UnboundedReceiver>) -> acp::ToolCall {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
match event {
- AgentResponseEvent::ToolCall(tool_call) => return tool_call,
+ ThreadEvent::ToolCall(tool_call) => tool_call,
event => {
panic!("Unexpected event {event:?}");
}
@@ -750,7 +750,7 @@ async fn expect_tool_call(
}
async fn expect_tool_call_update_fields(
- events: &mut UnboundedReceiver>,
+ events: &mut UnboundedReceiver>,
) -> acp::ToolCallUpdate {
let event = events
.next()
@@ -758,9 +758,7 @@ async fn expect_tool_call_update_fields(
.expect("no tool call authorization event received")
.unwrap();
match event {
- AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
- return update;
- }
+ ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
event => {
panic!("Unexpected event {event:?}");
}
@@ -768,7 +766,7 @@ async fn expect_tool_call_update_fields(
}
async fn next_tool_call_authorization(
- events: &mut UnboundedReceiver>,
+ events: &mut UnboundedReceiver>,
) -> ToolCallAuthorization {
loop {
let event = events
@@ -776,7 +774,7 @@ async fn next_tool_call_authorization(
.await
.expect("no tool call authorization event received")
.unwrap();
- if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
+ if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization
.options
.iter()
@@ -943,13 +941,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
let mut echo_completed = false;
while let Some(event) = events.next().await {
match event.unwrap() {
- AgentResponseEvent::ToolCall(tool_call) => {
+ ThreadEvent::ToolCall(tool_call) => {
assert_eq!(tool_call.title, expected_tools.remove(0));
if tool_call.title == "Echo" {
echo_id = Some(tool_call.id);
}
}
- AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
+ ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
acp::ToolCallUpdate {
id,
fields:
@@ -971,13 +969,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
- thread.update(cx, |thread, _cx| thread.cancel());
+ thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::>().await;
let last_event = events.last();
assert!(
matches!(
last_event,
- Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
+ Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
),
"unexpected event {last_event:?}"
);
@@ -1119,7 +1117,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
}
#[gpui::test]
-async fn test_truncate(cx: &mut TestAppContext) {
+async fn test_truncate_first_message(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@@ -1139,9 +1137,18 @@ async fn test_truncate(cx: &mut TestAppContext) {
Hello
"}
);
+ assert_eq!(thread.latest_token_usage(), None);
});
fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 32_000,
+ output_tokens: 16_000,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@@ -1156,14 +1163,22 @@ async fn test_truncate(cx: &mut TestAppContext) {
Hey!
"}
);
+ assert_eq!(
+ thread.latest_token_usage(),
+ Some(acp_thread::TokenUsage {
+ used_tokens: 32_000 + 16_000,
+ max_tokens: 1_000_000,
+ })
+ );
});
thread
- .update(cx, |thread, _cx| thread.truncate(message_id))
+ .update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(thread.to_markdown(), "");
+ assert_eq!(thread.latest_token_usage(), None);
});
// Ensure we can still send a new message after truncation.
@@ -1184,6 +1199,14 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 40_000,
+ output_tokens: 20_000,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@@ -1198,9 +1221,171 @@ async fn test_truncate(cx: &mut TestAppContext) {
Ahoy!
"}
);
+
+ assert_eq!(
+ thread.latest_token_usage(),
+ Some(acp_thread::TokenUsage {
+ used_tokens: 40_000 + 20_000,
+ max_tokens: 1_000_000,
+ })
+ );
});
}
+#[gpui::test]
+async fn test_truncate_second_message(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Message 1"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Message 1 response");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 32_000,
+ output_tokens: 16_000,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let assert_first_message_state = |cx: &mut TestAppContext| {
+ thread.clone().read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Message 1
+
+ ## Assistant
+
+ Message 1 response
+ "}
+ );
+
+ assert_eq!(
+ thread.latest_token_usage(),
+ Some(acp_thread::TokenUsage {
+ used_tokens: 32_000 + 16_000,
+ max_tokens: 1_000_000,
+ })
+ );
+ });
+ };
+
+ assert_first_message_state(cx);
+
+ let second_message_id = UserMessageId::new();
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(second_message_id.clone(), ["Message 2"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("Message 2 response");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 40_000,
+ output_tokens: 20_000,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Message 1
+
+ ## Assistant
+
+ Message 1 response
+
+ ## User
+
+ Message 2
+
+ ## Assistant
+
+ Message 2 response
+ "}
+ );
+
+ assert_eq!(
+ thread.latest_token_usage(),
+ Some(acp_thread::TokenUsage {
+ used_tokens: 40_000 + 20_000,
+ max_tokens: 1_000_000,
+ })
+ );
+ });
+
+ thread
+ .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
+ .unwrap();
+ cx.run_until_parked();
+
+ assert_first_message_state(cx);
+}
+
+#[gpui::test]
+async fn test_title_generation(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let summary_model = Arc::new(FakeLanguageModel::default());
+ thread.update(cx, |thread, cx| {
+ thread.set_summarization_model(Some(summary_model.clone()), cx)
+ });
+
+ let send = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+ thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
+
+ // Ensure the summary model has been invoked to generate a title.
+ summary_model.send_last_completion_stream_text_chunk("Hello ");
+ summary_model.send_last_completion_stream_text_chunk("world\nG");
+ summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
+ summary_model.end_last_completion_stream();
+ send.collect::>().await;
+ thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+
+ // Send another message, ensuring no title is generated this time.
+ let send = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello again"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Hey again!");
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+ assert_eq!(summary_model.pending_completions(), Vec::new());
+ send.collect::>().await;
+ thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+}
+
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init);
@@ -1228,10 +1413,13 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
// Create agent and connection
let agent = NativeAgent::new(
project.clone(),
+ history_store,
templates.clone(),
None,
fake_fs.clone(),
@@ -1433,12 +1621,168 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
);
}
+#[gpui::test]
+async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
+ let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
+ thread.send(UserMessageId::new(), ["Hello!"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.end_last_completion_stream();
+
+ let mut retry_events = Vec::new();
+ while let Some(Ok(event)) = events.next().await {
+ match event {
+ ThreadEvent::Retry(retry_status) => {
+ retry_events.push(retry_status);
+ }
+ ThreadEvent::Stop(..) => break,
+ _ => {}
+ }
+ }
+
+ assert_eq!(retry_events.len(), 0);
+ thread.read_with(cx, |thread, _cx| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Hello!
+
+ ## Assistant
+
+ Hey!
+ "}
+ )
+ });
+}
+
+#[gpui::test]
+async fn test_send_retry_on_error(cx: &mut TestAppContext) {
+ let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
+ thread.send(UserMessageId::new(), ["Hello!"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
+ provider: LanguageModelProviderName::new("Anthropic"),
+ retry_after: Some(Duration::from_secs(3)),
+ });
+ fake_model.end_last_completion_stream();
+
+ cx.executor().advance_clock(Duration::from_secs(3));
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.end_last_completion_stream();
+
+ let mut retry_events = Vec::new();
+ while let Some(Ok(event)) = events.next().await {
+ match event {
+ ThreadEvent::Retry(retry_status) => {
+ retry_events.push(retry_status);
+ }
+ ThreadEvent::Stop(..) => break,
+ _ => {}
+ }
+ }
+
+ assert_eq!(retry_events.len(), 1);
+ assert!(matches!(
+ retry_events[0],
+ acp_thread::RetryStatus { attempt: 1, .. }
+ ));
+ thread.read_with(cx, |thread, _cx| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Hello!
+
+ ## Assistant
+
+ Hey!
+ "}
+ )
+ });
+}
+
+#[gpui::test]
+async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
+ let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
+ thread.send(UserMessageId::new(), ["Hello!"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
+ fake_model.send_last_completion_stream_error(
+ LanguageModelCompletionError::ServerOverloaded {
+ provider: LanguageModelProviderName::new("Anthropic"),
+ retry_after: Some(Duration::from_secs(3)),
+ },
+ );
+ fake_model.end_last_completion_stream();
+ cx.executor().advance_clock(Duration::from_secs(3));
+ cx.run_until_parked();
+ }
+
+ let mut errors = Vec::new();
+ let mut retry_events = Vec::new();
+ while let Some(event) = events.next().await {
+ match event {
+ Ok(ThreadEvent::Retry(retry_status)) => {
+ retry_events.push(retry_status);
+ }
+ Ok(ThreadEvent::Stop(..)) => break,
+ Err(error) => errors.push(error),
+ _ => {}
+ }
+ }
+
+ assert_eq!(
+ retry_events.len(),
+ crate::thread::MAX_RETRY_ATTEMPTS as usize
+ );
+ for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
+ assert_eq!(retry_events[i].attempt, i + 1);
+ }
+ assert_eq!(errors.len(), 1);
+ let error = errors[0]
+ .downcast_ref::()
+ .unwrap();
+ assert!(matches!(
+ error,
+ LanguageModelCompletionError::ServerOverloaded { .. }
+ ));
+}
+
/// Filters out the stop events for asserting against in tests
-fn stop_events(result_events: Vec>) -> Vec {
+fn stop_events(result_events: Vec>) -> Vec {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
- AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
+ ThreadEvent::Stop(stop_reason) => Some(stop_reason),
_ => None,
})
.collect()
@@ -1447,7 +1791,7 @@ fn stop_events(result_events: Vec>) -> Vec,
thread: Entity,
- project_context: Rc>,
+ project_context: Entity,
fs: Arc,
}
@@ -1543,7 +1887,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
})
.await;
- let project_context = Rc::new(RefCell::new(ProjectContext::default()));
+ let project_context = cx.new(|_cx| ProjectContext::default());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs
index 429832010b..c1778bf38b 100644
--- a/crates/agent2/src/thread.rs
+++ b/crates/agent2/src/thread.rs
@@ -1,57 +1,58 @@
-use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
+use crate::{
+ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
+ DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
+ ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
+ Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
+};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
+use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
use agent_client_protocol as acp;
-use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
+use agent_settings::{
+ AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
+ SUMMARIZE_THREAD_PROMPT,
+};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
+use chrono::{DateTime, Utc};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
-use collections::IndexMap;
+use collections::{HashMap, IndexMap};
use fs::Fs;
use futures::{
+ FutureExt,
channel::{mpsc, oneshot},
+ future::Shared,
stream::FuturesUnordered,
};
-use gpui::{App, Context, Entity, SharedString, Task};
+use git::repository::DiffType;
+use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language_model::{
- LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
- LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
- LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
- LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
+ LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+ LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
+ LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
+};
+use project::{
+ Project,
+ git_store::{GitStore, RepositoryState},
};
-use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
-use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
+use std::{
+ collections::BTreeMap,
+ path::Path,
+ sync::Arc,
+ time::{Duration, Instant},
+};
use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
-#[derive(
- Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
-)]
-pub struct ThreadId(Arc);
-
-impl ThreadId {
- pub fn new() -> Self {
- Self(Uuid::new_v4().to_string().into())
- }
-}
-
-impl std::fmt::Display for ThreadId {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-impl From<&str> for ThreadId {
- fn from(value: &str) -> Self {
- Self(value.into())
- }
-}
+const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
/// The ID of the user prompt that initiated a request.
///
@@ -71,7 +72,22 @@ impl std::fmt::Display for PromptId {
}
}
-#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
+pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
+
+#[derive(Debug, Clone)]
+enum RetryStrategy {
+ ExponentialBackoff {
+ initial_delay: Duration,
+ max_attempts: u8,
+ },
+ Fixed {
+ delay: Duration,
+ max_attempts: u8,
+ },
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
User(UserMessage),
Agent(AgentMessage),
@@ -86,6 +102,18 @@ impl Message {
}
}
+ pub fn to_request(&self) -> Vec {
+ match self {
+ Message::User(message) => vec![message.to_request()],
+ Message::Agent(message) => message.to_request(),
+ Message::Resume => vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Continue where you left off".into()],
+ cache: false,
+ }],
+ }
+ }
+
pub fn to_markdown(&self) -> String {
match self {
Message::User(message) => message.to_markdown(),
@@ -93,15 +121,22 @@ impl Message {
Message::Resume => "[resumed after tool use limit was reached]".into(),
}
}
+
+ pub fn role(&self) -> Role {
+ match self {
+ Message::User(_) | Message::Resume => Role::User,
+ Message::Agent(_) => Role::Assistant,
+ }
+ }
}
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
pub id: UserMessageId,
pub content: Vec,
}
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
Text(String),
Mention { uri: MentionUri, content: String },
@@ -146,6 +181,7 @@ impl UserMessage {
They are up-to-date and don't need to be re-read.\n\n";
const OPEN_FILES_TAG: &str = "";
+ const OPEN_DIRECTORIES_TAG: &str = "";
const OPEN_SYMBOLS_TAG: &str = "";
const OPEN_THREADS_TAG: &str = "";
const OPEN_FETCH_TAG: &str = "";
@@ -153,6 +189,7 @@ impl UserMessage {
"\nThe user has specified the following rules that should be applied:\n";
let mut file_context = OPEN_FILES_TAG.to_string();
+ let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
let mut thread_context = OPEN_THREADS_TAG.to_string();
let mut fetch_context = OPEN_FETCH_TAG.to_string();
@@ -168,17 +205,20 @@ impl UserMessage {
}
UserMessageContent::Mention { uri, content } => {
match uri {
- MentionUri::File { abs_path, .. } => {
+ MentionUri::File { abs_path } => {
write!(
&mut symbol_context,
"\n{}",
MarkdownCodeBlock {
- tag: &codeblock_tag(&abs_path, None),
+ tag: &codeblock_tag(abs_path, None),
text: &content.to_string(),
}
)
.ok();
}
+ MentionUri::Directory { .. } => {
+ write!(&mut directory_context, "\n{}\n", content).ok();
+ }
MentionUri::Symbol {
path, line_range, ..
}
@@ -189,8 +229,8 @@ impl UserMessage {
&mut rules_context,
"\n{}",
MarkdownCodeBlock {
- tag: &codeblock_tag(&path, Some(line_range)),
- text: &content
+ tag: &codeblock_tag(path, Some(line_range)),
+ text: content
}
)
.ok();
@@ -207,7 +247,7 @@ impl UserMessage {
"\n{}",
MarkdownCodeBlock {
tag: "",
- text: &content
+ text: content
}
)
.ok();
@@ -233,6 +273,13 @@ impl UserMessage {
.push(language_model::MessageContent::Text(file_context));
}
+ if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
+ directory_context.push_str("\n");
+ message
+ .content
+ .push(language_model::MessageContent::Text(directory_context));
+ }
+
if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
symbol_context.push_str("\n");
message
@@ -313,9 +360,6 @@ impl AgentMessage {
AgentMessageContent::RedactedThinking(_) => {
markdown.push_str("\n")
}
- AgentMessageContent::Image(_) => {
- markdown.push_str("\n");
- }
AgentMessageContent::ToolUse(tool_use) => {
markdown.push_str(&format!(
"**Tool Use**: {} (ID: {})\n",
@@ -386,9 +430,6 @@ impl AgentMessage {
AgentMessageContent::ToolUse(value) => {
language_model::MessageContent::ToolUse(value.clone())
}
- AgentMessageContent::Image(value) => {
- language_model::MessageContent::Image(value.clone())
- }
};
assistant_message.content.push(chunk);
}
@@ -418,13 +459,13 @@ impl AgentMessage {
}
}
-#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
pub content: Vec,
pub tool_results: IndexMap,
}
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
Text(String),
Thinking {
@@ -432,17 +473,20 @@ pub enum AgentMessageContent {
signature: Option,
},
RedactedThinking(String),
- Image(LanguageModelImage),
ToolUse(LanguageModelToolUse),
}
#[derive(Debug)]
-pub enum AgentResponseEvent {
- Text(String),
- Thinking(String),
+pub enum ThreadEvent {
+ UserMessage(UserMessage),
+ AgentText(String),
+ AgentThinking(String),
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
+ TokenUsageUpdate(acp_thread::TokenUsage),
+ TitleUpdate(SharedString),
+ Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
}
@@ -454,8 +498,11 @@ pub struct ToolCallAuthorization {
}
pub struct Thread {
- id: ThreadId,
+ id: acp::SessionId,
prompt_id: PromptId,
+ updated_at: DateTime,
+ title: Option,
+ summary: Option,
messages: Vec,
completion_mode: CompletionMode,
/// Holds the task that handles agent interaction until the end of the turn.
@@ -465,19 +512,25 @@ pub struct Thread {
pending_message: Option,
tools: BTreeMap>,
tool_use_limit_reached: bool,
+ request_token_usage: HashMap,
+ #[allow(unused)]
+ cumulative_token_usage: TokenUsage,
+ #[allow(unused)]
+ initial_project_snapshot: Shared>>>,
context_server_registry: Entity,
profile_id: AgentProfileId,
- project_context: Rc>,
+ project_context: Entity,
templates: Arc,
model: Option>,
- project: Entity,
- action_log: Entity,
+ summarization_model: Option>,
+ pub(crate) project: Entity,
+ pub(crate) action_log: Entity,
}
impl Thread {
pub fn new(
project: Entity,
- project_context: Rc>,
+ project_context: Entity,
context_server_registry: Entity,
action_log: Entity,
templates: Arc,
@@ -486,24 +539,327 @@ impl Thread {
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
Self {
- id: ThreadId::new(),
+ id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
prompt_id: PromptId::new(),
+ updated_at: Utc::now(),
+ title: None,
+ summary: None,
messages: Vec::new(),
- completion_mode: CompletionMode::Normal,
+ completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
+ request_token_usage: HashMap::default(),
+ cumulative_token_usage: TokenUsage::default(),
+ initial_project_snapshot: {
+ let project_snapshot = Self::project_snapshot(project.clone(), cx);
+ cx.foreground_executor()
+ .spawn(async move { Some(project_snapshot.await) })
+ .shared()
+ },
context_server_registry,
profile_id,
project_context,
templates,
model,
+ summarization_model: None,
project,
action_log,
}
}
+ pub fn id(&self) -> &acp::SessionId {
+ &self.id
+ }
+
+ pub fn replay(
+ &mut self,
+ cx: &mut Context,
+ ) -> mpsc::UnboundedReceiver> {
+ let (tx, rx) = mpsc::unbounded();
+ let stream = ThreadEventStream(tx);
+ for message in &self.messages {
+ match message {
+ Message::User(user_message) => stream.send_user_message(user_message),
+ Message::Agent(assistant_message) => {
+ for content in &assistant_message.content {
+ match content {
+ AgentMessageContent::Text(text) => stream.send_text(text),
+ AgentMessageContent::Thinking { text, .. } => {
+ stream.send_thinking(text)
+ }
+ AgentMessageContent::RedactedThinking(_) => {}
+ AgentMessageContent::ToolUse(tool_use) => {
+ self.replay_tool_call(
+ tool_use,
+ assistant_message.tool_results.get(&tool_use.id),
+ &stream,
+ cx,
+ );
+ }
+ }
+ }
+ }
+ Message::Resume => {}
+ }
+ }
+ rx
+ }
+
+ fn replay_tool_call(
+ &self,
+ tool_use: &LanguageModelToolUse,
+ tool_result: Option<&LanguageModelToolResult>,
+ stream: &ThreadEventStream,
+ cx: &mut Context,
+ ) {
+ let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
+ stream
+ .0
+ .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
+ id: acp::ToolCallId(tool_use.id.to_string().into()),
+ title: tool_use.name.to_string(),
+ kind: acp::ToolKind::Other,
+ status: acp::ToolCallStatus::Failed,
+ content: Vec::new(),
+ locations: Vec::new(),
+ raw_input: Some(tool_use.input.clone()),
+ raw_output: None,
+ })))
+ .ok();
+ return;
+ };
+
+ let title = tool.initial_title(tool_use.input.clone());
+ let kind = tool.kind();
+ stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
+
+ let output = tool_result
+ .as_ref()
+ .and_then(|result| result.output.clone());
+ if let Some(output) = output.clone() {
+ let tool_event_stream = ToolCallEventStream::new(
+ tool_use.id.clone(),
+ stream.clone(),
+ Some(self.project.read(cx).fs().clone()),
+ );
+ tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
+ .log_err();
+ }
+
+ stream.update_tool_call_fields(
+ &tool_use.id,
+ acp::ToolCallUpdateFields {
+ status: Some(acp::ToolCallStatus::Completed),
+ raw_output: output,
+ ..Default::default()
+ },
+ );
+ }
+
+ pub fn from_db(
+ id: acp::SessionId,
+ db_thread: DbThread,
+ project: Entity,
+ project_context: Entity,
+ context_server_registry: Entity,
+ action_log: Entity,
+ templates: Arc,
+ cx: &mut Context,
+ ) -> Self {
+ let profile_id = db_thread
+ .profile
+ .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
+ let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ db_thread
+ .model
+ .and_then(|model| {
+ let model = SelectedModel {
+ provider: model.provider.clone().into(),
+ model: model.model.clone().into(),
+ };
+ registry.select_model(&model, cx)
+ })
+ .or_else(|| registry.default_model())
+ .map(|model| model.model)
+ });
+
+ Self {
+ id,
+ prompt_id: PromptId::new(),
+ title: if db_thread.title.is_empty() {
+ None
+ } else {
+ Some(db_thread.title.clone())
+ },
+ summary: db_thread.detailed_summary,
+ messages: db_thread.messages,
+ completion_mode: db_thread.completion_mode.unwrap_or_default(),
+ running_turn: None,
+ pending_message: None,
+ tools: BTreeMap::default(),
+ tool_use_limit_reached: false,
+ request_token_usage: db_thread.request_token_usage.clone(),
+ cumulative_token_usage: db_thread.cumulative_token_usage,
+ initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
+ context_server_registry,
+ profile_id,
+ project_context,
+ templates,
+ model,
+ summarization_model: None,
+ project,
+ action_log,
+ updated_at: db_thread.updated_at,
+ }
+ }
+
+ pub fn to_db(&self, cx: &App) -> Task {
+ let initial_project_snapshot = self.initial_project_snapshot.clone();
+ let mut thread = DbThread {
+ title: self.title.clone().unwrap_or_default(),
+ messages: self.messages.clone(),
+ updated_at: self.updated_at,
+ detailed_summary: self.summary.clone(),
+ initial_project_snapshot: None,
+ cumulative_token_usage: self.cumulative_token_usage,
+ request_token_usage: self.request_token_usage.clone(),
+ model: self.model.as_ref().map(|model| DbLanguageModel {
+ provider: model.provider_id().to_string(),
+ model: model.name().0.to_string(),
+ }),
+ completion_mode: Some(self.completion_mode),
+ profile: Some(self.profile_id.clone()),
+ };
+
+ cx.background_spawn(async move {
+ let initial_project_snapshot = initial_project_snapshot.await;
+ thread.initial_project_snapshot = initial_project_snapshot;
+ thread
+ })
+ }
+
+ /// Create a snapshot of the current project state including git information and unsaved buffers.
+ fn project_snapshot(
+ project: Entity,
+ cx: &mut Context,
+ ) -> Task> {
+ let git_store = project.read(cx).git_store().clone();
+ let worktree_snapshots: Vec<_> = project
+ .read(cx)
+ .visible_worktrees(cx)
+ .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
+ .collect();
+
+ cx.spawn(async move |_, cx| {
+ let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
+
+ let mut unsaved_buffers = Vec::new();
+ cx.update(|app_cx| {
+ let buffer_store = project.read(app_cx).buffer_store();
+ for buffer_handle in buffer_store.read(app_cx).buffers() {
+ let buffer = buffer_handle.read(app_cx);
+ if buffer.is_dirty()
+ && let Some(file) = buffer.file()
+ {
+ let path = file.path().to_string_lossy().to_string();
+ unsaved_buffers.push(path);
+ }
+ }
+ })
+ .ok();
+
+ Arc::new(ProjectSnapshot {
+ worktree_snapshots,
+ unsaved_buffer_paths: unsaved_buffers,
+ timestamp: Utc::now(),
+ })
+ })
+ }
+
+ fn worktree_snapshot(
+ worktree: Entity,
+ git_store: Entity,
+ cx: &App,
+ ) -> Task {
+ cx.spawn(async move |cx| {
+ // Get worktree path and snapshot
+ let worktree_info = cx.update(|app_cx| {
+ let worktree = worktree.read(app_cx);
+ let path = worktree.abs_path().to_string_lossy().to_string();
+ let snapshot = worktree.snapshot();
+ (path, snapshot)
+ });
+
+ let Ok((worktree_path, _snapshot)) = worktree_info else {
+ return WorktreeSnapshot {
+ worktree_path: String::new(),
+ git_state: None,
+ };
+ };
+
+ let git_state = git_store
+ .update(cx, |git_store, cx| {
+ git_store
+ .repositories()
+ .values()
+ .find(|repo| {
+ repo.read(cx)
+ .abs_path_to_repo_path(&worktree.read(cx).abs_path())
+ .is_some()
+ })
+ .cloned()
+ })
+ .ok()
+ .flatten()
+ .map(|repo| {
+ repo.update(cx, |repo, _| {
+ let current_branch =
+ repo.branch.as_ref().map(|branch| branch.name().to_owned());
+ repo.send_job(None, |state, _| async move {
+ let RepositoryState::Local { backend, .. } = state else {
+ return GitState {
+ remote_url: None,
+ head_sha: None,
+ current_branch,
+ diff: None,
+ };
+ };
+
+ let remote_url = backend.remote_url("origin");
+ let head_sha = backend.head_sha().await;
+ let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
+
+ GitState {
+ remote_url,
+ head_sha,
+ current_branch,
+ diff,
+ }
+ })
+ })
+ });
+
+ let git_state = match git_state {
+ Some(git_state) => match git_state.ok() {
+ Some(git_state) => git_state.await.ok(),
+ None => None,
+ },
+ None => None,
+ };
+
+ WorktreeSnapshot {
+ worktree_path,
+ git_state,
+ }
+ })
+ }
+
+ pub fn project_context(&self) -> &Entity {
+ &self.project_context
+ }
+
pub fn project(&self) -> &Entity {
&self.project
}
@@ -516,16 +872,27 @@ impl Thread {
self.model.as_ref()
}
- pub fn set_model(&mut self, model: Arc) {
+ pub fn set_model(&mut self, model: Arc, cx: &mut Context) {
self.model = Some(model);
+ cx.notify()
+ }
+
+ pub fn set_summarization_model(
+ &mut self,
+ model: Option>,
+ cx: &mut Context,
+ ) {
+ self.summarization_model = model;
+ cx.notify()
}
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
- pub fn set_completion_mode(&mut self, mode: CompletionMode) {
+ pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) {
self.completion_mode = mode;
+ cx.notify()
}
#[cfg(any(test, feature = "test-support"))]
@@ -537,6 +904,32 @@ impl Thread {
}
}
+ pub fn add_default_tools(&mut self, cx: &mut Context) {
+ let language_registry = self.project.read(cx).languages().clone();
+ self.add_tool(CopyPathTool::new(self.project.clone()));
+ self.add_tool(CreateDirectoryTool::new(self.project.clone()));
+ self.add_tool(DeletePathTool::new(
+ self.project.clone(),
+ self.action_log.clone(),
+ ));
+ self.add_tool(DiagnosticsTool::new(self.project.clone()));
+ self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
+ self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
+ self.add_tool(FindPathTool::new(self.project.clone()));
+ self.add_tool(GrepTool::new(self.project.clone()));
+ self.add_tool(ListDirectoryTool::new(self.project.clone()));
+ self.add_tool(MovePathTool::new(self.project.clone()));
+ self.add_tool(NowTool);
+ self.add_tool(OpenTool::new(self.project.clone()));
+ self.add_tool(ReadFileTool::new(
+ self.project.clone(),
+ self.action_log.clone(),
+ ));
+ self.add_tool(TerminalTool::new(self.project.clone(), cx));
+ self.add_tool(ThinkingTool);
+ self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
+ }
+
pub fn add_tool(&mut self, tool: impl AgentTool) {
self.tools.insert(tool.name(), tool.erase());
}
@@ -553,29 +946,58 @@ impl Thread {
self.profile_id = profile_id;
}
- pub fn cancel(&mut self) {
+ pub fn cancel(&mut self, cx: &mut Context) {
if let Some(running_turn) = self.running_turn.take() {
running_turn.cancel();
}
- self.flush_pending_message();
+ self.flush_pending_message(cx);
}
- pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
- self.cancel();
+ pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
+ let Some(last_user_message) = self.last_user_message() else {
+ return;
+ };
+
+ self.request_token_usage
+ .insert(last_user_message.id.clone(), update);
+ }
+
+ pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> {
+ self.cancel(cx);
let Some(position) = self.messages.iter().position(
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
) else {
return Err(anyhow!("Message not found"));
};
- self.messages.truncate(position);
+
+ for message in self.messages.drain(position..) {
+ match message {
+ Message::User(message) => {
+ self.request_token_usage.remove(&message.id);
+ }
+ Message::Agent(_) | Message::Resume => {}
+ }
+ }
+ self.summary = None;
+ cx.notify();
Ok(())
}
+ pub fn latest_token_usage(&self) -> Option {
+ let last_user_message = self.last_user_message()?;
+ let tokens = self.request_token_usage.get(&last_user_message.id)?;
+ let model = self.model.clone()?;
+
+ Some(acp_thread::TokenUsage {
+ max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
+ used_tokens: tokens.total_tokens(),
+ })
+ }
+
pub fn resume(
&mut self,
cx: &mut Context,
- ) -> Result>> {
- anyhow::ensure!(self.model.is_some(), "Model not set");
+ ) -> Result>> {
anyhow::ensure!(
self.tool_use_limit_reached,
"can only resume after tool use limit is reached"
@@ -596,7 +1018,7 @@ impl Thread {
id: UserMessageId,
content: impl IntoIterator- ,
cx: &mut Context,
- ) -> Result>>
+ ) -> Result>>
where
T: Into