Merge branch 'main' into ui-scrollbar-teardown

This commit is contained in:
MrSubidubi 2025-08-20 10:48:16 +02:00
commit 76842eed31
598 changed files with 18423 additions and 13109 deletions

View file

@ -56,7 +56,6 @@ runs:
$env:COMPlus_CreateDumpDiagnostics = "1" $env:COMPlus_CreateDumpDiagnostics = "1"
cargo nextest run --workspace --no-fail-fast cargo nextest run --workspace --no-fail-fast
continue-on-error: true
- name: Analyze crash dumps - name: Analyze crash dumps
if: always() if: always()

76
Cargo.lock generated
View file

@ -7,7 +7,6 @@ name = "acp_thread"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"action_log", "action_log",
"agent",
"agent-client-protocol", "agent-client-protocol",
"anyhow", "anyhow",
"buffer_diff", "buffer_diff",
@ -20,6 +19,7 @@ dependencies = [
"indoc", "indoc",
"itertools 0.14.0", "itertools 0.14.0",
"language", "language",
"language_model",
"markdown", "markdown",
"parking_lot", "parking_lot",
"project", "project",
@ -130,7 +130,6 @@ dependencies = [
"component", "component",
"context_server", "context_server",
"convert_case 0.8.0", "convert_case 0.8.0",
"feature_flags",
"fs", "fs",
"futures 0.3.31", "futures 0.3.31",
"git", "git",
@ -191,10 +190,12 @@ version = "0.1.0"
dependencies = [ dependencies = [
"acp_thread", "acp_thread",
"action_log", "action_log",
"agent",
"agent-client-protocol", "agent-client-protocol",
"agent_servers", "agent_servers",
"agent_settings", "agent_settings",
"anyhow", "anyhow",
"assistant_context",
"assistant_tool", "assistant_tool",
"assistant_tools", "assistant_tools",
"chrono", "chrono",
@ -204,10 +205,12 @@ dependencies = [
"collections", "collections",
"context_server", "context_server",
"ctor", "ctor",
"db",
"editor", "editor",
"env_logger 0.11.8", "env_logger 0.11.8",
"fs", "fs",
"futures 0.3.31", "futures 0.3.31",
"git",
"gpui", "gpui",
"gpui_tokio", "gpui_tokio",
"handlebars 4.5.0", "handlebars 4.5.0",
@ -221,6 +224,7 @@ dependencies = [
"log", "log",
"lsp", "lsp",
"open", "open",
"parking_lot",
"paths", "paths",
"portable-pty", "portable-pty",
"pretty_assertions", "pretty_assertions",
@ -233,6 +237,7 @@ dependencies = [
"serde_json", "serde_json",
"settings", "settings",
"smol", "smol",
"sqlez",
"task", "task",
"tempfile", "tempfile",
"terminal", "terminal",
@ -249,6 +254,7 @@ dependencies = [
"workspace-hack", "workspace-hack",
"worktree", "worktree",
"zlog", "zlog",
"zstd",
] ]
[[package]] [[package]]
@ -256,7 +262,9 @@ name = "agent_servers"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"acp_thread", "acp_thread",
"action_log",
"agent-client-protocol", "agent-client-protocol",
"agent_settings",
"agentic-coding-protocol", "agentic-coding-protocol",
"anyhow", "anyhow",
"collections", "collections",
@ -267,6 +275,8 @@ dependencies = [
"indoc", "indoc",
"itertools 0.14.0", "itertools 0.14.0",
"language", "language",
"language_model",
"language_models",
"libc", "libc",
"log", "log",
"nix 0.29.0", "nix 0.29.0",
@ -274,6 +284,7 @@ dependencies = [
"project", "project",
"rand 0.8.5", "rand 0.8.5",
"schemars", "schemars",
"semver",
"serde", "serde",
"serde_json", "serde_json",
"settings", "settings",
@ -3070,6 +3081,7 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded",
"settings", "settings",
"sha2", "sha2",
"smol", "smol",
@ -3861,7 +3873,7 @@ dependencies = [
"jni", "jni",
"js-sys", "js-sys",
"libc", "libc",
"mach2", "mach2 0.4.2",
"ndk", "ndk",
"ndk-context", "ndk-context",
"num-derive", "num-derive",
@ -4011,7 +4023,7 @@ checksum = "031ed29858d90cfdf27fe49fae28028a1f20466db97962fa2f4ea34809aeebf3"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"mach2", "mach2 0.4.2",
] ]
[[package]] [[package]]
@ -4023,7 +4035,7 @@ dependencies = [
"cfg-if", "cfg-if",
"crash-context", "crash-context",
"libc", "libc",
"mach2", "mach2 0.4.2",
"parking_lot", "parking_lot",
] ]
@ -4033,6 +4045,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"crash-handler", "crash-handler",
"log", "log",
"mach2 0.5.0",
"minidumper", "minidumper",
"paths", "paths",
"release_channel", "release_channel",
@ -7477,6 +7490,7 @@ dependencies = [
"slotmap", "slotmap",
"smallvec", "smallvec",
"smol", "smol",
"stacksafe",
"strum 0.27.1", "strum 0.27.1",
"sum_tree", "sum_tree",
"taffy", "taffy",
@ -9854,6 +9868,15 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "mach2"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a1b95cd5421ec55b445b5ae102f5ea0e768de1f82bd3001e11f426c269c3aea"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "malloc_buf" name = "malloc_buf"
version = "0.0.6" version = "0.0.6"
@ -10190,7 +10213,7 @@ dependencies = [
"goblin", "goblin",
"libc", "libc",
"log", "log",
"mach2", "mach2 0.4.2",
"memmap2", "memmap2",
"memoffset", "memoffset",
"minidump-common", "minidump-common",
@ -15536,6 +15559,40 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "stacker"
version = "0.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b"
dependencies = [
"cc",
"cfg-if",
"libc",
"psm",
"windows-sys 0.59.0",
]
[[package]]
name = "stacksafe"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d9c1172965d317e87ddb6d364a040d958b40a1db82b6ef97da26253a8b3d090"
dependencies = [
"stacker",
"stacksafe-macro",
]
[[package]]
name = "stacksafe-macro"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "172175341049678163e979d9107ca3508046d4d2a7c6682bee46ac541b17db69"
dependencies = [
"proc-macro-error2",
"quote",
"syn 2.0.101",
]
[[package]] [[package]]
name = "static_assertions" name = "static_assertions"
version = "1.1.0" version = "1.1.0"
@ -18247,7 +18304,7 @@ dependencies = [
"indexmap", "indexmap",
"libc", "libc",
"log", "log",
"mach2", "mach2 0.4.2",
"memfd", "memfd",
"object", "object",
"once_cell", "once_cell",
@ -20197,8 +20254,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]] [[package]]
name = "yawc" name = "yawc"
version = "0.2.4" version = "0.2.5"
source = "git+https://github.com/deviant-forks/yawc?rev=1899688f3e69ace4545aceb97b2a13881cf26142#1899688f3e69ace4545aceb97b2a13881cf26142" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19a5d82922135b4ae73a079a4ffb5501e9aadb4d785b8c660eaa0a8b899028c5"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"bytes 1.10.1", "bytes 1.10.1",

View file

@ -515,6 +515,7 @@ libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
linkify = "0.10.0" linkify = "0.10.0"
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" } lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" }
mach2 = "0.5"
markup5ever_rcdom = "0.3.0" markup5ever_rcdom = "0.3.0"
metal = "0.29" metal = "0.29"
minidumper = "0.8" minidumper = "0.8"
@ -582,6 +583,7 @@ serde_json_lenient = { version = "0.2", features = [
"raw_value", "raw_value",
] } ] }
serde_repr = "0.1" serde_repr = "0.1"
serde_urlencoded = "0.7"
sha2 = "0.10" sha2 = "0.10"
shellexpand = "2.1.0" shellexpand = "2.1.0"
shlex = "1.3.0" shlex = "1.3.0"
@ -589,6 +591,7 @@ simplelog = "0.12.2"
smallvec = { version = "1.6", features = ["union"] } smallvec = { version = "1.6", features = ["union"] }
smol = "2.0" smol = "2.0"
sqlformat = "0.2" sqlformat = "0.2"
stacksafe = "0.1"
streaming-iterator = "0.1" streaming-iterator = "0.1"
strsim = "0.11" strsim = "0.11"
strum = { version = "0.27.0", features = ["derive"] } strum = { version = "0.27.0", features = ["derive"] }
@ -659,9 +662,7 @@ which = "6.0.0"
windows-core = "0.61" windows-core = "0.61"
wit-component = "0.221" wit-component = "0.221"
workspace-hack = "0.1.0" workspace-hack = "0.1.0"
# We can switch back to the published version once https://github.com/infinitefield/yawc/pull/16 is merged and a new yawc = "0.2.5"
# version is released.
yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" }
zstd = "0.11" zstd = "0.11"
[workspace.dependencies.windows] [workspace.dependencies.windows]
@ -821,10 +822,30 @@ single_range_in_vec_init = "allow"
style = { level = "allow", priority = -1 } style = { level = "allow", priority = -1 }
# Temporary list of style lints that we've fixed so far. # Temporary list of style lints that we've fixed so far.
comparison_to_empty = "warn"
into_iter_on_ref = "warn"
iter_cloned_collect = "warn"
iter_next_slice = "warn"
iter_nth = "warn"
iter_nth_zero = "warn"
iter_skip_next = "warn"
let_and_return = "warn"
module_inception = { level = "deny" } module_inception = { level = "deny" }
question_mark = { level = "deny" } question_mark = { level = "deny" }
single_match = "warn"
redundant_closure = { level = "deny" } redundant_closure = { level = "deny" }
redundant_static_lifetimes = { level = "warn" }
redundant_pattern_matching = "warn"
redundant_field_names = "warn"
declare_interior_mutable_const = { level = "deny" } declare_interior_mutable_const = { level = "deny" }
collapsible_if = { level = "warn"}
collapsible_else_if = { level = "warn" }
needless_borrow = { level = "warn"}
needless_return = { level = "warn" }
unnecessary_mut_passed = {level = "warn"}
unnecessary_map_or = { level = "warn" }
unused_unit = "warn"
# Individual rules that have violations in the codebase: # Individual rules that have violations in the codebase:
type_complexity = "allow" type_complexity = "allow"
# We often return trait objects from `new` functions. # We often return trait objects from `new` functions.
@ -833,6 +854,8 @@ new_ret_no_self = { level = "allow" }
# compared to Iterator::next. Yet, clippy complains about those. # compared to Iterator::next. Yet, clippy complains about those.
should_implement_trait = { level = "allow" } should_implement_trait = { level = "allow" }
let_underscore_future = "allow" let_underscore_future = "allow"
# It doesn't make sense to implement `Default` unilaterally.
new_without_default = "allow"
# in Rust it can be very tedious to reduce argument count without # in Rust it can be very tedious to reduce argument count without
# running afoul of the borrow checker. # running afoul of the borrow checker.
@ -841,6 +864,10 @@ too_many_arguments = "allow"
# We often have large enum variants yet we rarely actually bother with splitting them up. # We often have large enum variants yet we rarely actually bother with splitting them up.
large_enum_variant = "allow" large_enum_variant = "allow"
# `enum_variant_names` fires for all enums, even when they derive serde traits.
# Adhering to this lint would be a breaking change.
enum_variant_names = "allow"
[workspace.metadata.cargo-machete] [workspace.metadata.cargo-machete]
ignored = [ ignored = [
"bindgen", "bindgen",

View file

@ -1 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="none"><path stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.2" d="M2.667 8h8M2.667 4h10.666M2.667 12H8"/></svg> <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M13.333 10H8M13.333 6H2.66701" stroke="black" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 210 B

After

Width:  |  Height:  |  Size: 227 B

Before After
Before After

View file

@ -0,0 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M13.333 10H8M13.333 6H2.66701" stroke="black" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 227 B

View file

@ -0,0 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 2C11.3137 2 14 4.68629 14 8C14 11.3137 11.3137 14 8 14C4.68629 14 2 11.3137 2 8C2 4.68629 4.68629 2 8 2ZM10.4238 5.57617C10.1895 5.34187 9.81049 5.3419 9.57617 5.57617L8 7.15234L6.42383 5.57617C6.18953 5.34187 5.81049 5.3419 5.57617 5.57617C5.34186 5.81049 5.34186 6.18951 5.57617 6.42383L7.15234 8L5.57617 9.57617C5.34186 9.81049 5.34186 10.1895 5.57617 10.4238C5.81049 10.6581 6.18954 10.6581 6.42383 10.4238L8 8.84766L9.57617 10.4238C9.81049 10.6581 10.1895 10.6581 10.4238 10.4238C10.6581 10.1895 10.658 9.81048 10.4238 9.57617L8.84766 8L10.4238 6.42383C10.6581 6.18954 10.658 5.81048 10.4238 5.57617Z" fill="black"/>
</svg>

After

Width:  |  Height:  |  Size: 737 B

View file

@ -0,0 +1,27 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M11 8.75V10.5C8.93097 10.5 8.06903 10.5 6 10.5V10L11 6V5.5H6V7.25" stroke="black" stroke-width="1.2"/>
<path d="M2 8.5C2.27614 8.5 2.5 8.27614 2.5 8C2.5 7.72386 2.27614 7.5 2 7.5C1.72386 7.5 1.5 7.72386 1.5 8C1.5 8.27614 1.72386 8.5 2 8.5Z" fill="black"/>
<path d="M2.99976 6.33002C3.2759 6.33002 3.49976 6.10616 3.49976 5.83002C3.49976 5.55387 3.2759 5.33002 2.99976 5.33002C2.72361 5.33002 2.49976 5.55387 2.49976 5.83002C2.49976 6.10616 2.72361 6.33002 2.99976 6.33002Z" fill="black"/>
<path d="M2.99976 10.66C3.2759 10.66 3.49976 10.4361 3.49976 10.16C3.49976 9.88383 3.2759 9.65997 2.99976 9.65997C2.72361 9.65997 2.49976 9.88383 2.49976 10.16C2.49976 10.4361 2.72361 10.66 2.99976 10.66Z" fill="black"/>
<path d="M15 8.5C15.2761 8.5 15.5 8.27614 15.5 8C15.5 7.72386 15.2761 7.5 15 7.5C14.7239 7.5 14.5 7.72386 14.5 8C14.5 8.27614 14.7239 8.5 15 8.5Z" fill="black"/>
<path d="M14 6.33002C14.2761 6.33002 14.5 6.10616 14.5 5.83002C14.5 5.55387 14.2761 5.33002 14 5.33002C13.7239 5.33002 13.5 5.55387 13.5 5.83002C13.5 6.10616 13.7239 6.33002 14 6.33002Z" fill="black"/>
<path d="M14 10.66C14.2761 10.66 14.5 10.4361 14.5 10.16C14.5 9.88383 14.2761 9.65997 14 9.65997C13.7239 9.65997 13.5 9.88383 13.5 10.16C13.5 10.4361 13.7239 10.66 14 10.66Z" fill="black"/>
<path d="M8.49219 2C8.76833 2 8.99219 1.77614 8.99219 1.5C8.99219 1.22386 8.76833 1 8.49219 1C8.21605 1 7.99219 1.22386 7.99219 1.5C7.99219 1.77614 8.21605 2 8.49219 2Z" fill="black"/>
<path d="M6 3C6.27614 3 6.5 2.77614 6.5 2.5C6.5 2.22386 6.27614 2 6 2C5.72386 2 5.5 2.22386 5.5 2.5C5.5 2.77614 5.72386 3 6 3Z" fill="black"/>
<path d="M4 4C4.27614 4 4.5 3.77614 4.5 3.5C4.5 3.22386 4.27614 3 4 3C3.72386 3 3.5 3.22386 3.5 3.5C3.5 3.77614 3.72386 4 4 4Z" fill="black"/>
<path d="M3.99976 13C4.2759 13 4.49976 12.7761 4.49976 12.5C4.49976 12.2239 4.2759 12 3.99976 12C3.72361 12 3.49976 12.2239 3.49976 12.5C3.49976 12.7761 3.72361 13 3.99976 13Z" fill="black"/>
<path d="M2 12.5C2.27614 12.5 2.5 12.2761 2.5 12C2.5 11.7239 2.27614 11.5 2 11.5C1.72386 11.5 1.5 11.7239 1.5 12C1.5 12.2761 1.72386 12.5 2 12.5Z" fill="black"/>
<path d="M2 4.5C2.27614 4.5 2.5 4.27614 2.5 4C2.5 3.72386 2.27614 3.5 2 3.5C1.72386 3.5 1.5 3.72386 1.5 4C1.5 4.27614 1.72386 4.5 2 4.5Z" fill="black"/>
<path d="M15 12.5C15.2761 12.5 15.5 12.2761 15.5 12C15.5 11.7239 15.2761 11.5 15 11.5C14.7239 11.5 14.5 11.7239 14.5 12C14.5 12.2761 14.7239 12.5 15 12.5Z" fill="black"/>
<path d="M15 4.5C15.2761 4.5 15.5 4.27614 15.5 4C15.5 3.72386 15.2761 3.5 15 3.5C14.7239 3.5 14.5 3.72386 14.5 4C14.5 4.27614 14.7239 4.5 15 4.5Z" fill="black"/>
<path d="M3.99976 15C4.2759 15 4.49976 14.7761 4.49976 14.5C4.49976 14.2239 4.2759 14 3.99976 14C3.72361 14 3.49976 14.2239 3.49976 14.5C3.49976 14.7761 3.72361 15 3.99976 15Z" fill="black"/>
<path d="M4 2C4.27614 2 4.5 1.77614 4.5 1.5C4.5 1.22386 4.27614 1 4 1C3.72386 1 3.5 1.22386 3.5 1.5C3.5 1.77614 3.72386 2 4 2Z" fill="black"/>
<path d="M13 15C13.2761 15 13.5 14.7761 13.5 14.5C13.5 14.2239 13.2761 14 13 14C12.7239 14 12.5 14.2239 12.5 14.5C12.5 14.7761 12.7239 15 13 15Z" fill="black"/>
<path d="M13 2C13.2761 2 13.5 1.77614 13.5 1.5C13.5 1.22386 13.2761 1 13 1C12.7239 1 12.5 1.22386 12.5 1.5C12.5 1.77614 12.7239 2 13 2Z" fill="black"/>
<path d="M13 4C13.2761 4 13.5 3.77614 13.5 3.5C13.5 3.22386 13.2761 3 13 3C12.7239 3 12.5 3.22386 12.5 3.5C12.5 3.77614 12.7239 4 13 4Z" fill="black"/>
<path d="M13 13C13.2761 13 13.5 12.7761 13.5 12.5C13.5 12.2239 13.2761 12 13 12C12.7239 12 12.5 12.2239 12.5 12.5C12.5 12.7761 12.7239 13 13 13Z" fill="black"/>
<path d="M11 3C11.2761 3 11.5 2.77614 11.5 2.5C11.5 2.22386 11.2761 2 11 2C10.7239 2 10.5 2.22386 10.5 2.5C10.5 2.77614 10.7239 3 11 3Z" fill="black"/>
<path d="M8.5 15C8.77614 15 9 14.7761 9 14.5C9 14.2239 8.77614 14 8.5 14C8.22386 14 8 14.2239 8 14.5C8 14.7761 8.22386 15 8.5 15Z" fill="black"/>
<path d="M6 14C6.27614 14 6.5 13.7761 6.5 13.5C6.5 13.2239 6.27614 13 6 13C5.72386 13 5.5 13.2239 5.5 13.5C5.5 13.7761 5.72386 14 6 14Z" fill="black"/>
<path d="M11 14C11.2761 14 11.5 13.7761 11.5 13.5C11.5 13.2239 11.2761 13 11 13C10.7239 13 10.5 13.2239 10.5 13.5C10.5 13.7761 10.7239 14 11 14Z" fill="black"/>
</svg>

After

Width:  |  Height:  |  Size: 4.2 KiB

View file

@ -327,7 +327,7 @@
} }
}, },
{ {
"context": "AcpThread > Editor", "context": "AcpThread > Editor && !use_modifier_to_send",
"use_key_equivalents": true, "use_key_equivalents": true,
"bindings": { "bindings": {
"enter": "agent::Chat", "enter": "agent::Chat",
@ -336,6 +336,16 @@
"ctrl-shift-n": "agent::RejectAll" "ctrl-shift-n": "agent::RejectAll"
} }
}, },
{
"context": "AcpThread > Editor && use_modifier_to_send",
"use_key_equivalents": true,
"bindings": {
"ctrl-enter": "agent::Chat",
"shift-ctrl-r": "agent::OpenAgentDiff",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll"
}
},
{ {
"context": "ThreadHistory", "context": "ThreadHistory",
"bindings": { "bindings": {

View file

@ -379,7 +379,7 @@
} }
}, },
{ {
"context": "AcpThread > Editor", "context": "AcpThread > Editor && !use_modifier_to_send",
"use_key_equivalents": true, "use_key_equivalents": true,
"bindings": { "bindings": {
"enter": "agent::Chat", "enter": "agent::Chat",
@ -388,6 +388,16 @@
"cmd-shift-n": "agent::RejectAll" "cmd-shift-n": "agent::RejectAll"
} }
}, },
{
"context": "AcpThread > Editor && use_modifier_to_send",
"use_key_equivalents": true,
"bindings": {
"cmd-enter": "agent::Chat",
"shift-ctrl-r": "agent::OpenAgentDiff",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll"
}
},
{ {
"context": "ThreadHistory", "context": "ThreadHistory",
"bindings": { "bindings": {

View file

@ -717,7 +717,7 @@
// Can be 'never', 'always', or 'when_in_call', // Can be 'never', 'always', or 'when_in_call',
// or a boolean (interpreted as 'never'/'always'). // or a boolean (interpreted as 'never'/'always').
"button": "when_in_call", "button": "when_in_call",
// Where to the chat panel. Can be 'left' or 'right'. // Where to dock the chat panel. Can be 'left' or 'right'.
"dock": "right", "dock": "right",
// Default width of the chat panel. // Default width of the chat panel.
"default_width": 240 "default_width": 240
@ -725,7 +725,7 @@
"git_panel": { "git_panel": {
// Whether to show the git panel button in the status bar. // Whether to show the git panel button in the status bar.
"button": true, "button": true,
// Where to show the git panel. Can be 'left' or 'right'. // Where to dock the git panel. Can be 'left' or 'right'.
"dock": "left", "dock": "left",
// Default width of the git panel. // Default width of the git panel.
"default_width": 360, "default_width": 360,

View file

@ -18,7 +18,6 @@ test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"]
[dependencies] [dependencies]
action_log.workspace = true action_log.workspace = true
agent-client-protocol.workspace = true agent-client-protocol.workspace = true
agent.workspace = true
anyhow.workspace = true anyhow.workspace = true
buffer_diff.workspace = true buffer_diff.workspace = true
collections.workspace = true collections.workspace = true
@ -28,6 +27,7 @@ futures.workspace = true
gpui.workspace = true gpui.workspace = true
itertools.workspace = true itertools.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true
markdown.workspace = true markdown.workspace = true
parking_lot = { workspace = true, optional = true } parking_lot = { workspace = true, optional = true }
project.workspace = true project.workspace = true

View file

@ -3,9 +3,13 @@ mod diff;
mod mention; mod mention;
mod terminal; mod terminal;
use collections::HashSet;
pub use connection::*; pub use connection::*;
pub use diff::*; pub use diff::*;
use language::language_settings::FormatOnSave;
pub use mention::*; pub use mention::*;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
use serde::{Deserialize, Serialize};
pub use terminal::*; pub use terminal::*;
use action_log::ActionLog; use action_log::ActionLog;
@ -24,6 +28,7 @@ use std::fmt::{Formatter, Write};
use std::ops::Range; use std::ops::Range;
use std::process::ExitStatus; use std::process::ExitStatus;
use std::rc::Rc; use std::rc::Rc;
use std::time::{Duration, Instant};
use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
use ui::App; use ui::App;
use util::ResultExt; use util::ResultExt;
@ -48,7 +53,7 @@ impl UserMessage {
if self if self
.checkpoint .checkpoint
.as_ref() .as_ref()
.map_or(false, |checkpoint| checkpoint.show) .is_some_and(|checkpoint| checkpoint.show)
{ {
writeln!(markdown, "## User (checkpoint)").unwrap(); writeln!(markdown, "## User (checkpoint)").unwrap();
} else { } else {
@ -248,14 +253,13 @@ impl ToolCall {
} }
if let Some(raw_output) = raw_output { if let Some(raw_output) = raw_output {
if self.content.is_empty() { if self.content.is_empty()
if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx) && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
{ {
self.content self.content
.push(ToolCallContent::ContentBlock(ContentBlock::Markdown { .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
markdown, markdown,
})); }));
}
} }
self.raw_output = Some(raw_output); self.raw_output = Some(raw_output);
} }
@ -429,11 +433,11 @@ impl ContentBlock {
language_registry: &Arc<LanguageRegistry>, language_registry: &Arc<LanguageRegistry>,
cx: &mut App, cx: &mut App,
) { ) {
if matches!(self, ContentBlock::Empty) { if matches!(self, ContentBlock::Empty)
if let acp::ContentBlock::ResourceLink(resource_link) = block { && let acp::ContentBlock::ResourceLink(resource_link) = block
*self = ContentBlock::ResourceLink { resource_link }; {
return; *self = ContentBlock::ResourceLink { resource_link };
} return;
} }
let new_content = self.block_string_contents(block); let new_content = self.block_string_contents(block);
@ -485,7 +489,7 @@ impl ContentBlock {
} }
fn resource_link_md(uri: &str) -> String { fn resource_link_md(uri: &str) -> String {
if let Some(uri) = MentionUri::parse(&uri).log_err() { if let Some(uri) = MentionUri::parse(uri).log_err() {
uri.as_link().to_string() uri.as_link().to_string()
} else { } else {
uri.to_string() uri.to_string()
@ -537,9 +541,15 @@ impl ToolCallContent {
acp::ToolCallContent::Content { content } => { acp::ToolCallContent::Content { content } => {
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx)) Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
} }
acp::ToolCallContent::Diff { diff } => { acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx))) Diff::finalized(
} diff.path,
diff.old_text,
diff.new_text,
language_registry,
cx,
)
})),
} }
} }
@ -658,6 +668,21 @@ impl PlanEntry {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
pub max_tokens: u64,
pub used_tokens: u64,
}
#[derive(Debug, Clone)]
pub struct RetryStatus {
pub last_error: SharedString,
pub attempt: usize,
pub max_attempts: usize,
pub started_at: Instant,
pub duration: Duration,
}
pub struct AcpThread { pub struct AcpThread {
title: SharedString, title: SharedString,
entries: Vec<AgentThreadEntry>, entries: Vec<AgentThreadEntry>,
@ -668,16 +693,21 @@ pub struct AcpThread {
send_task: Option<Task<()>>, send_task: Option<Task<()>>,
connection: Rc<dyn AgentConnection>, connection: Rc<dyn AgentConnection>,
session_id: acp::SessionId, session_id: acp::SessionId,
token_usage: Option<TokenUsage>,
} }
#[derive(Debug)]
pub enum AcpThreadEvent { pub enum AcpThreadEvent {
NewEntry, NewEntry,
TitleUpdated,
TokenUsageUpdated,
EntryUpdated(usize), EntryUpdated(usize),
EntriesRemoved(Range<usize>), EntriesRemoved(Range<usize>),
ToolAuthorizationRequired, ToolAuthorizationRequired,
Retry(RetryStatus),
Stopped, Stopped,
Error, Error,
ServerExited(ExitStatus), LoadError(LoadError),
} }
impl EventEmitter<AcpThreadEvent> for AcpThread {} impl EventEmitter<AcpThreadEvent> for AcpThread {}
@ -691,20 +721,30 @@ pub enum ThreadStatus {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum LoadError { pub enum LoadError {
NotInstalled {
error_message: SharedString,
install_message: SharedString,
install_command: String,
},
Unsupported { Unsupported {
error_message: SharedString, error_message: SharedString,
upgrade_message: SharedString, upgrade_message: SharedString,
upgrade_command: String, upgrade_command: String,
}, },
Exited(i32), Exited {
status: ExitStatus,
},
Other(SharedString), Other(SharedString),
} }
impl Display for LoadError { impl Display for LoadError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self { match self {
LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message), LoadError::NotInstalled { error_message, .. }
LoadError::Exited(status) => write!(f, "Server exited with status {}", status), | LoadError::Unsupported { error_message, .. } => {
write!(f, "{error_message}")
}
LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
LoadError::Other(msg) => write!(f, "{}", msg), LoadError::Other(msg) => write!(f, "{}", msg),
} }
} }
@ -717,11 +757,9 @@ impl AcpThread {
title: impl Into<SharedString>, title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>, connection: Rc<dyn AgentConnection>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>,
session_id: acp::SessionId, session_id: acp::SessionId,
cx: &mut Context<Self>,
) -> Self { ) -> Self {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self { Self {
action_log, action_log,
shared_buffers: Default::default(), shared_buffers: Default::default(),
@ -732,6 +770,7 @@ impl AcpThread {
send_task: None, send_task: None,
connection, connection,
session_id, session_id,
token_usage: None,
} }
} }
@ -771,6 +810,10 @@ impl AcpThread {
} }
} }
pub fn token_usage(&self) -> Option<&TokenUsage> {
self.token_usage.as_ref()
}
pub fn has_pending_edit_tool_calls(&self) -> bool { pub fn has_pending_edit_tool_calls(&self) -> bool {
for entry in self.entries.iter().rev() { for entry in self.entries.iter().rev() {
match entry { match entry {
@ -915,6 +958,21 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry); cx.emit(AcpThreadEvent::NewEntry);
} }
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
self.title = title;
cx.emit(AcpThreadEvent::TitleUpdated);
Ok(())
}
pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
self.token_usage = usage;
cx.emit(AcpThreadEvent::TokenUsageUpdated);
}
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::Retry(status));
}
pub fn update_tool_call( pub fn update_tool_call(
&mut self, &mut self,
update: impl Into<ToolCallUpdate>, update: impl Into<ToolCallUpdate>,
@ -1006,6 +1064,22 @@ impl AcpThread {
}) })
} }
pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
self.entries
.iter()
.enumerate()
.rev()
.find_map(|(index, tool_call)| {
if let AgentThreadEntry::ToolCall(tool_call) = tool_call
&& &tool_call.id == id
{
Some((index, tool_call))
} else {
None
}
})
}
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) { pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
let project = self.project.clone(); let project = self.project.clone();
let Some((_, tool_call)) = self.tool_call_mut(&id) else { let Some((_, tool_call)) = self.tool_call_mut(&id) else {
@ -1199,17 +1273,21 @@ impl AcpThread {
} else { } else {
None None
}; };
self.push_entry(
AgentThreadEntry::UserMessage(UserMessage {
id: message_id.clone(),
content: block,
chunks: message,
checkpoint: None,
}),
cx,
);
self.run_turn(cx, async move |this, cx| { self.run_turn(cx, async move |this, cx| {
this.update(cx, |this, cx| {
this.push_entry(
AgentThreadEntry::UserMessage(UserMessage {
id: message_id.clone(),
content: block,
chunks: message,
checkpoint: None,
}),
cx,
);
})
.ok();
let old_checkpoint = git_store let old_checkpoint = git_store
.update(cx, |git, cx| git.checkpoint(cx))? .update(cx, |git, cx| git.checkpoint(cx))?
.await .await
@ -1262,6 +1340,8 @@ impl AcpThread {
.await?; .await?;
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.project
.update(cx, |project, cx| project.set_agent_location(None, cx));
match response { match response {
Ok(Err(e)) => { Ok(Err(e)) => {
this.send_task.take(); this.send_task.take();
@ -1411,7 +1491,7 @@ impl AcpThread {
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
self.entries.iter().find_map(|entry| { self.entries.iter().find_map(|entry| {
if let AgentThreadEntry::UserMessage(message) = entry { if let AgentThreadEntry::UserMessage(message) = entry {
if message.id.as_ref() == Some(&id) { if message.id.as_ref() == Some(id) {
Some(message) Some(message)
} else { } else {
None None
@ -1425,7 +1505,7 @@ impl AcpThread {
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> { fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| { self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
if let AgentThreadEntry::UserMessage(message) = entry { if let AgentThreadEntry::UserMessage(message) = entry {
if message.id.as_ref() == Some(&id) { if message.id.as_ref() == Some(id) {
Some((ix, message)) Some((ix, message))
} else { } else {
None None
@ -1550,30 +1630,59 @@ impl AcpThread {
.collect::<Vec<_>>() .collect::<Vec<_>>()
}) })
.await; .await;
cx.update(|cx| {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: edits
.last()
.map(|(range, _)| range.end)
.unwrap_or(Anchor::MIN),
}),
cx,
);
});
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: edits
.last()
.map(|(range, _)| range.end)
.unwrap_or(Anchor::MIN),
}),
cx,
);
})?;
let format_on_save = cx.update(|cx| {
action_log.update(cx, |action_log, cx| { action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx); action_log.buffer_read(buffer.clone(), cx);
}); });
buffer.update(cx, |buffer, cx| {
let format_on_save = buffer.update(cx, |buffer, cx| {
buffer.edit(edits, None, cx); buffer.edit(edits, None, cx);
let settings = language::language_settings::language_settings(
buffer.language().map(|l| l.name()),
buffer.file(),
cx,
);
settings.format_on_save != FormatOnSave::Off
}); });
action_log.update(cx, |action_log, cx| { action_log.update(cx, |action_log, cx| {
action_log.buffer_edited(buffer.clone(), cx); action_log.buffer_edited(buffer.clone(), cx);
}); });
format_on_save
})?; })?;
if format_on_save {
let format_task = project.update(cx, |project, cx| {
project.format(
HashSet::from_iter([buffer.clone()]),
LspFormatTarget::Buffers,
false,
FormatTrigger::Save,
cx,
)
})?;
format_task.await.log_err();
action_log.update(cx, |action_log, cx| {
action_log.buffer_edited(buffer.clone(), cx);
})?;
}
project project
.update(cx, |project, cx| project.save_buffer(buffer, cx))? .update(cx, |project, cx| project.save_buffer(buffer, cx))?
.await .await
@ -1584,8 +1693,8 @@ impl AcpThread {
self.entries.iter().map(|e| e.to_markdown(cx)).collect() self.entries.iter().map(|e| e.to_markdown(cx)).collect()
} }
pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) { pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::ServerExited(status)); cx.emit(AcpThreadEvent::LoadError(error));
} }
} }
@ -1636,7 +1745,7 @@ mod tests {
use super::*; use super::*;
use anyhow::anyhow; use anyhow::anyhow;
use futures::{channel::mpsc, future::LocalBoxFuture, select}; use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{AsyncApp, TestAppContext, WeakEntity}; use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc; use indoc::indoc;
use project::{FakeFs, Fs}; use project::{FakeFs, Fs};
use rand::Rng as _; use rand::Rng as _;
@ -2123,7 +2232,7 @@ mod tests {
"} "}
); );
}); });
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx))) cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
.await .await
@ -2153,7 +2262,10 @@ mod tests {
}); });
assert_eq!( assert_eq!(
fs.files(), fs.files(),
vec![Path::new("/test/file-0"), Path::new("/test/file-1")] vec![
Path::new(path!("/test/file-0")),
Path::new(path!("/test/file-1"))
]
); );
// Checkpoint isn't stored when there are no changes. // Checkpoint isn't stored when there are no changes.
@ -2194,7 +2306,10 @@ mod tests {
}); });
assert_eq!( assert_eq!(
fs.files(), fs.files(),
vec![Path::new("/test/file-0"), Path::new("/test/file-1")] vec![
Path::new(path!("/test/file-0")),
Path::new(path!("/test/file-1"))
]
); );
// Rewinding the conversation truncates the history and restores the checkpoint. // Rewinding the conversation truncates the history and restores the checkpoint.
@ -2222,7 +2337,7 @@ mod tests {
"} "}
); );
}); });
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
} }
async fn run_until_first_tool_call( async fn run_until_first_tool_call(
@ -2306,7 +2421,7 @@ mod tests {
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
_cwd: &Path, _cwd: &Path,
cx: &mut gpui::App, cx: &mut App,
) -> Task<gpui::Result<Entity<AcpThread>>> { ) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId( let session_id = acp::SessionId(
rand::thread_rng() rand::thread_rng()
@ -2316,8 +2431,16 @@ mod tests {
.collect::<String>() .collect::<String>()
.into(), .into(),
); );
let thread = let action_log = cx.new(|_| ActionLog::new(project.clone()));
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
)
});
self.sessions.lock().insert(session_id, thread.downgrade()); self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread)) Task::ready(Ok(thread))
} }
@ -2351,7 +2474,7 @@ mod tests {
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let sessions = self.sessions.lock(); let sessions = self.sessions.lock();
let thread = sessions.get(&session_id).unwrap().clone(); let thread = sessions.get(session_id).unwrap().clone();
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
thread thread

View file

@ -3,12 +3,14 @@ use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;
use collections::IndexMap; use collections::IndexMap;
use gpui::{Entity, SharedString, Task}; use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId;
use project::Project; use project::Project;
use serde::{Deserialize, Serialize};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName}; use ui::{App, IconName};
use uuid::Uuid; use uuid::Uuid;
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct UserMessageId(Arc<str>); pub struct UserMessageId(Arc<str>);
impl UserMessageId { impl UserMessageId {
@ -80,12 +82,34 @@ pub trait AgentSessionResume {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct AuthRequired; pub struct AuthRequired {
pub description: Option<String>,
pub provider_id: Option<LanguageModelProviderId>,
}
impl AuthRequired {
pub fn new() -> Self {
Self {
description: None,
provider_id: None,
}
}
pub fn with_description(mut self, description: String) -> Self {
self.description = Some(description);
self
}
pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
self.provider_id = Some(provider_id);
self
}
}
impl Error for AuthRequired {} impl Error for AuthRequired {}
impl fmt::Display for AuthRequired { impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired") write!(f, "Authentication required")
} }
} }
@ -185,8 +209,9 @@ impl AgentModelList {
mod test_support { mod test_support {
use std::sync::Arc; use std::sync::Arc;
use action_log::ActionLog;
use collections::HashMap; use collections::HashMap;
use futures::future::try_join_all; use futures::{channel::oneshot, future::try_join_all};
use gpui::{AppContext as _, WeakEntity}; use gpui::{AppContext as _, WeakEntity};
use parking_lot::Mutex; use parking_lot::Mutex;
@ -194,11 +219,16 @@ mod test_support {
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct StubAgentConnection { pub struct StubAgentConnection {
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>, sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>, permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>, next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
} }
struct Session {
thread: WeakEntity<AcpThread>,
response_tx: Option<oneshot::Sender<acp::StopReason>>,
}
impl StubAgentConnection { impl StubAgentConnection {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@ -226,15 +256,33 @@ mod test_support {
update: acp::SessionUpdate, update: acp::SessionUpdate,
cx: &mut App, cx: &mut App,
) { ) {
assert!(
self.next_prompt_updates.lock().is_empty(),
"Use either send_update or set_next_prompt_updates"
);
self.sessions self.sessions
.lock() .lock()
.get(&session_id) .get(&session_id)
.unwrap() .unwrap()
.thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap(); thread.handle_session_update(update, cx).unwrap();
}) })
.unwrap(); .unwrap();
} }
pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
self.sessions
.lock()
.get_mut(&session_id)
.unwrap()
.response_tx
.take()
.expect("No pending turn")
.send(stop_reason)
.unwrap();
}
} }
impl AgentConnection for StubAgentConnection { impl AgentConnection for StubAgentConnection {
@ -249,9 +297,23 @@ mod test_support {
cx: &mut gpui::App, cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> { ) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into()); let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let thread = let action_log = cx.new(|_| ActionLog::new(project.clone()));
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); let thread = cx.new(|_cx| {
self.sessions.lock().insert(session_id, thread.downgrade()); AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
)
});
self.sessions.lock().insert(
session_id,
Session {
thread: thread.downgrade(),
response_tx: None,
},
);
Task::ready(Ok(thread)) Task::ready(Ok(thread))
} }
@ -269,47 +331,70 @@ mod test_support {
params: acp::PromptRequest, params: acp::PromptRequest,
cx: &mut App, cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> { ) -> Task<gpui::Result<acp::PromptResponse>> {
let sessions = self.sessions.lock(); let mut sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap(); let Session {
thread,
response_tx,
} = sessions.get_mut(&params.session_id).unwrap();
let mut tasks = vec![]; let mut tasks = vec![];
for update in self.next_prompt_updates.lock().drain(..) { if self.next_prompt_updates.lock().is_empty() {
let thread = thread.clone(); let (tx, rx) = oneshot::channel();
let update = update.clone(); response_tx.replace(tx);
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update cx.spawn(async move |_| {
&& let Some(options) = self.permission_requests.get(&tool_call.id) let stop_reason = rx.await?;
{ Ok(acp::PromptResponse { stop_reason })
Some((tool_call.clone(), options.clone()))
} else {
None
};
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
tool_call.clone().into(),
options.clone(),
cx,
)
})?;
permission?.await?;
}
thread.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})?;
anyhow::Ok(())
});
tasks.push(task);
}
cx.spawn(async move |_| {
try_join_all(tasks).await?;
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
}) })
}) } else {
for update in self.next_prompt_updates.lock().drain(..) {
let thread = thread.clone();
let update = update.clone();
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
&update
&& let Some(options) = self.permission_requests.get(&tool_call.id)
{
Some((tool_call.clone(), options.clone()))
} else {
None
};
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
tool_call.clone().into(),
options.clone(),
cx,
)
})?;
permission?.await?;
}
thread.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})?;
anyhow::Ok(())
});
tasks.push(task);
}
cx.spawn(async move |_| {
try_join_all(tasks).await?;
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
} }
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
unimplemented!() if let Some(end_turn_tx) = self
.sessions
.lock()
.get_mut(session_id)
.unwrap()
.response_tx
.take()
{
end_turn_tx.send(acp::StopReason::Canceled).unwrap();
}
} }
fn session_editor( fn session_editor(

View file

@ -1,4 +1,3 @@
use agent_client_protocol as acp;
use anyhow::Result; use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey}; use editor::{MultiBuffer, PathKey};
@ -21,17 +20,13 @@ pub enum Diff {
} }
impl Diff { impl Diff {
pub fn from_acp( pub fn finalized(
diff: acp::Diff, path: PathBuf,
old_text: Option<String>,
new_text: String,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
@ -71,8 +66,8 @@ impl Diff {
let hunk_ranges = { let hunk_ranges = {
let buffer = new_buffer.read(cx); let buffer = new_buffer.read(cx);
let diff = buffer_diff.read(cx); let diff = buffer_diff.read(cx);
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) .map(|diff_hunk| diff_hunk.buffer_range.to_point(buffer))
.collect::<Vec<_>>() .collect::<Vec<_>>()
}; };
@ -306,13 +301,13 @@ impl PendingDiff {
let buffer = self.buffer.read(cx); let buffer = self.buffer.read(cx);
let diff = self.diff.read(cx); let diff = self.diff.read(cx);
let mut ranges = diff let mut ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) .map(|diff_hunk| diff_hunk.buffer_range.to_point(buffer))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
ranges.extend( ranges.extend(
self.revealed_ranges self.revealed_ranges
.iter() .iter()
.map(|range| range.to_point(&buffer)), .map(|range| range.to_point(buffer)),
); );
ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end))); ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));

View file

@ -1,7 +1,8 @@
use agent::ThreadId; use agent_client_protocol as acp;
use anyhow::{Context as _, Result, bail}; use anyhow::{Context as _, Result, bail};
use file_icons::FileIcons; use file_icons::FileIcons;
use prompt_store::{PromptId, UserPromptId}; use prompt_store::{PromptId, UserPromptId};
use serde::{Deserialize, Serialize};
use std::{ use std::{
fmt, fmt,
ops::Range, ops::Range,
@ -11,11 +12,13 @@ use std::{
use ui::{App, IconName, SharedString}; use ui::{App, IconName, SharedString};
use url::Url; use url::Url;
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum MentionUri { pub enum MentionUri {
File { File {
abs_path: PathBuf, abs_path: PathBuf,
is_directory: bool, },
Directory {
abs_path: PathBuf,
}, },
Symbol { Symbol {
path: PathBuf, path: PathBuf,
@ -23,7 +26,7 @@ pub enum MentionUri {
line_range: Range<u32>, line_range: Range<u32>,
}, },
Thread { Thread {
id: ThreadId, id: acp::SessionId,
name: String, name: String,
}, },
TextThread { TextThread {
@ -49,6 +52,7 @@ impl MentionUri {
let path = url.path(); let path = url.path();
match url.scheme() { match url.scheme() {
"file" => { "file" => {
let path = url.to_file_path().ok().context("Extracting file path")?;
if let Some(fragment) = url.fragment() { if let Some(fragment) = url.fragment() {
let range = fragment let range = fragment
.strip_prefix("L") .strip_prefix("L")
@ -69,31 +73,23 @@ impl MentionUri {
if let Some(name) = single_query_param(&url, "symbol")? { if let Some(name) = single_query_param(&url, "symbol")? {
Ok(Self::Symbol { Ok(Self::Symbol {
name, name,
path: path.into(), path,
line_range, line_range,
}) })
} else { } else {
Ok(Self::Selection { Ok(Self::Selection { path, line_range })
path: path.into(),
line_range,
})
} }
} else if input.ends_with("/") {
Ok(Self::Directory { abs_path: path })
} else { } else {
let file_path = Ok(Self::File { abs_path: path })
PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
let is_directory = input.ends_with("/");
Ok(Self::File {
abs_path: file_path,
is_directory,
})
} }
} }
"zed" => { "zed" => {
if let Some(thread_id) = path.strip_prefix("/agent/thread/") { if let Some(thread_id) = path.strip_prefix("/agent/thread/") {
let name = single_query_param(&url, "name")?.context("Missing thread name")?; let name = single_query_param(&url, "name")?.context("Missing thread name")?;
Ok(Self::Thread { Ok(Self::Thread {
id: thread_id.into(), id: acp::SessionId(thread_id.into()),
name, name,
}) })
} else if let Some(path) = path.strip_prefix("/agent/text-thread/") { } else if let Some(path) = path.strip_prefix("/agent/text-thread/") {
@ -120,7 +116,7 @@ impl MentionUri {
pub fn name(&self) -> String { pub fn name(&self) -> String {
match self { match self {
MentionUri::File { abs_path, .. } => abs_path MentionUri::File { abs_path, .. } | MentionUri::Directory { abs_path, .. } => abs_path
.file_name() .file_name()
.unwrap_or_default() .unwrap_or_default()
.to_string_lossy() .to_string_lossy()
@ -138,18 +134,11 @@ impl MentionUri {
pub fn icon_path(&self, cx: &mut App) -> SharedString { pub fn icon_path(&self, cx: &mut App) -> SharedString {
match self { match self {
MentionUri::File { MentionUri::File { abs_path } => {
abs_path, FileIcons::get_icon(abs_path, cx).unwrap_or_else(|| IconName::File.path().into())
is_directory,
} => {
if *is_directory {
FileIcons::get_folder_icon(false, cx)
.unwrap_or_else(|| IconName::Folder.path().into())
} else {
FileIcons::get_icon(&abs_path, cx)
.unwrap_or_else(|| IconName::File.path().into())
}
} }
MentionUri::Directory { .. } => FileIcons::get_folder_icon(false, cx)
.unwrap_or_else(|| IconName::Folder.path().into()),
MentionUri::Symbol { .. } => IconName::Code.path().into(), MentionUri::Symbol { .. } => IconName::Code.path().into(),
MentionUri::Thread { .. } => IconName::Thread.path().into(), MentionUri::Thread { .. } => IconName::Thread.path().into(),
MentionUri::TextThread { .. } => IconName::Thread.path().into(), MentionUri::TextThread { .. } => IconName::Thread.path().into(),
@ -165,25 +154,18 @@ impl MentionUri {
pub fn to_uri(&self) -> Url { pub fn to_uri(&self) -> Url {
match self { match self {
MentionUri::File { MentionUri::File { abs_path } => {
abs_path, Url::from_file_path(abs_path).expect("mention path should be absolute")
is_directory, }
} => { MentionUri::Directory { abs_path } => {
let mut url = Url::parse("file:///").unwrap(); Url::from_directory_path(abs_path).expect("mention path should be absolute")
let mut path = abs_path.to_string_lossy().to_string();
if *is_directory && !path.ends_with("/") {
path.push_str("/");
}
url.set_path(&path);
url
} }
MentionUri::Symbol { MentionUri::Symbol {
path, path,
name, name,
line_range, line_range,
} => { } => {
let mut url = Url::parse("file:///").unwrap(); let mut url = Url::from_file_path(path).expect("mention path should be absolute");
url.set_path(&path.to_string_lossy());
url.query_pairs_mut().append_pair("symbol", name); url.query_pairs_mut().append_pair("symbol", name);
url.set_fragment(Some(&format!( url.set_fragment(Some(&format!(
"L{}:{}", "L{}:{}",
@ -193,8 +175,7 @@ impl MentionUri {
url url
} }
MentionUri::Selection { path, line_range } => { MentionUri::Selection { path, line_range } => {
let mut url = Url::parse("file:///").unwrap(); let mut url = Url::from_file_path(path).expect("mention path should be absolute");
url.set_path(&path.to_string_lossy());
url.set_fragment(Some(&format!( url.set_fragment(Some(&format!(
"L{}:{}", "L{}:{}",
line_range.start + 1, line_range.start + 1,
@ -267,19 +248,17 @@ pub fn selection_name(path: &Path, line_range: &Range<u32>) -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use util::{path, uri};
use super::*; use super::*;
#[test] #[test]
fn test_parse_file_uri() { fn test_parse_file_uri() {
let file_uri = "file:///path/to/file.rs"; let file_uri = uri!("file:///path/to/file.rs");
let parsed = MentionUri::parse(file_uri).unwrap(); let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed { match &parsed {
MentionUri::File { MentionUri::File { abs_path } => {
abs_path, assert_eq!(abs_path.to_str().unwrap(), path!("/path/to/file.rs"));
is_directory,
} => {
assert_eq!(abs_path.to_str().unwrap(), "/path/to/file.rs");
assert!(!is_directory);
} }
_ => panic!("Expected File variant"), _ => panic!("Expected File variant"),
} }
@ -288,42 +267,38 @@ mod tests {
#[test] #[test]
fn test_parse_directory_uri() { fn test_parse_directory_uri() {
let file_uri = "file:///path/to/dir/"; let file_uri = uri!("file:///path/to/dir/");
let parsed = MentionUri::parse(file_uri).unwrap(); let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed { match &parsed {
MentionUri::File { MentionUri::Directory { abs_path } => {
abs_path, assert_eq!(abs_path.to_str().unwrap(), path!("/path/to/dir/"));
is_directory,
} => {
assert_eq!(abs_path.to_str().unwrap(), "/path/to/dir/");
assert!(is_directory);
} }
_ => panic!("Expected File variant"), _ => panic!("Expected Directory variant"),
} }
assert_eq!(parsed.to_uri().to_string(), file_uri); assert_eq!(parsed.to_uri().to_string(), file_uri);
} }
#[test] #[test]
fn test_to_directory_uri_with_slash() { fn test_to_directory_uri_with_slash() {
let uri = MentionUri::File { let uri = MentionUri::Directory {
abs_path: PathBuf::from("/path/to/dir/"), abs_path: PathBuf::from(path!("/path/to/dir/")),
is_directory: true,
}; };
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/"); let expected = uri!("file:///path/to/dir/");
assert_eq!(uri.to_uri().to_string(), expected);
} }
#[test] #[test]
fn test_to_directory_uri_without_slash() { fn test_to_directory_uri_without_slash() {
let uri = MentionUri::File { let uri = MentionUri::Directory {
abs_path: PathBuf::from("/path/to/dir"), abs_path: PathBuf::from(path!("/path/to/dir")),
is_directory: true,
}; };
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/"); let expected = uri!("file:///path/to/dir/");
assert_eq!(uri.to_uri().to_string(), expected);
} }
#[test] #[test]
fn test_parse_symbol_uri() { fn test_parse_symbol_uri() {
let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20"; let symbol_uri = uri!("file:///path/to/file.rs?symbol=MySymbol#L10:20");
let parsed = MentionUri::parse(symbol_uri).unwrap(); let parsed = MentionUri::parse(symbol_uri).unwrap();
match &parsed { match &parsed {
MentionUri::Symbol { MentionUri::Symbol {
@ -331,7 +306,7 @@ mod tests {
name, name,
line_range, line_range,
} => { } => {
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs"));
assert_eq!(name, "MySymbol"); assert_eq!(name, "MySymbol");
assert_eq!(line_range.start, 9); assert_eq!(line_range.start, 9);
assert_eq!(line_range.end, 19); assert_eq!(line_range.end, 19);
@ -343,11 +318,11 @@ mod tests {
#[test] #[test]
fn test_parse_selection_uri() { fn test_parse_selection_uri() {
let selection_uri = "file:///path/to/file.rs#L5:15"; let selection_uri = uri!("file:///path/to/file.rs#L5:15");
let parsed = MentionUri::parse(selection_uri).unwrap(); let parsed = MentionUri::parse(selection_uri).unwrap();
match &parsed { match &parsed {
MentionUri::Selection { path, line_range } => { MentionUri::Selection { path, line_range } => {
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs"));
assert_eq!(line_range.start, 4); assert_eq!(line_range.start, 4);
assert_eq!(line_range.end, 14); assert_eq!(line_range.end, 14);
} }
@ -429,32 +404,35 @@ mod tests {
#[test] #[test]
fn test_invalid_line_range_format() { fn test_invalid_line_range_format() {
// Missing L prefix // Missing L prefix
assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err()); assert!(MentionUri::parse(uri!("file:///path/to/file.rs#10:20")).is_err());
// Missing colon separator // Missing colon separator
assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err()); assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L1020")).is_err());
// Invalid numbers // Invalid numbers
assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err()); assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L10:abc")).is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err()); assert!(MentionUri::parse(uri!("file:///path/to/file.rs#Labc:20")).is_err());
} }
#[test] #[test]
fn test_invalid_query_parameters() { fn test_invalid_query_parameters() {
// Invalid query parameter name // Invalid query parameter name
assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err()); assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L10:20?invalid=test")).is_err());
// Too many query parameters // Too many query parameters
assert!( assert!(
MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err() MentionUri::parse(uri!(
"file:///path/to/file.rs#L10:20?symbol=test&another=param"
))
.is_err()
); );
} }
#[test] #[test]
fn test_zero_based_line_numbers() { fn test_zero_based_line_numbers() {
// Test that 0-based line numbers are rejected (should be 1-based) // Test that 0-based line numbers are rejected (should be 1-based)
assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err()); assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L0:10")).is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err()); assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L1:0")).is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err()); assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L0:0")).is_err());
} }
} }

View file

@ -116,7 +116,7 @@ impl ActionLog {
} else if buffer } else if buffer
.read(cx) .read(cx)
.file() .file()
.map_or(false, |file| file.disk_state().exists()) .is_some_and(|file| file.disk_state().exists())
{ {
TrackedBufferStatus::Created { TrackedBufferStatus::Created {
existing_file_content: Some(buffer.read(cx).as_rope().clone()), existing_file_content: Some(buffer.read(cx).as_rope().clone()),
@ -215,7 +215,7 @@ impl ActionLog {
if buffer if buffer
.read(cx) .read(cx)
.file() .file()
.map_or(false, |file| file.disk_state() == DiskState::Deleted) .is_some_and(|file| file.disk_state() == DiskState::Deleted)
{ {
// If the buffer had been edited by a tool, but it got // If the buffer had been edited by a tool, but it got
// deleted externally, we want to stop tracking it. // deleted externally, we want to stop tracking it.
@ -227,7 +227,7 @@ impl ActionLog {
if buffer if buffer
.read(cx) .read(cx)
.file() .file()
.map_or(false, |file| file.disk_state() != DiskState::Deleted) .is_some_and(|file| file.disk_state() != DiskState::Deleted)
{ {
// If the buffer had been deleted by a tool, but it got // If the buffer had been deleted by a tool, but it got
// resurrected externally, we want to clear the edits we // resurrected externally, we want to clear the edits we
@ -264,15 +264,14 @@ impl ActionLog {
if let Some((git_diff, (buffer_repo, _))) = git_diff.as_ref().zip(buffer_repo) { if let Some((git_diff, (buffer_repo, _))) = git_diff.as_ref().zip(buffer_repo) {
cx.update(|cx| { cx.update(|cx| {
let mut old_head = buffer_repo.read(cx).head_commit.clone(); let mut old_head = buffer_repo.read(cx).head_commit.clone();
Some(cx.subscribe(git_diff, move |_, event, cx| match event { Some(cx.subscribe(git_diff, move |_, event, cx| {
buffer_diff::BufferDiffEvent::DiffChanged { .. } => { if let buffer_diff::BufferDiffEvent::DiffChanged { .. } = event {
let new_head = buffer_repo.read(cx).head_commit.clone(); let new_head = buffer_repo.read(cx).head_commit.clone();
if new_head != old_head { if new_head != old_head {
old_head = new_head; old_head = new_head;
git_diff_updates_tx.send(()).ok(); git_diff_updates_tx.send(()).ok();
} }
} }
_ => {}
})) }))
})? })?
} else { } else {
@ -290,7 +289,7 @@ impl ActionLog {
} }
_ = git_diff_updates_rx.changed().fuse() => { _ = git_diff_updates_rx.changed().fuse() => {
if let Some(git_diff) = git_diff.as_ref() { if let Some(git_diff) = git_diff.as_ref() {
Self::keep_committed_edits(&this, &buffer, &git_diff, cx).await?; Self::keep_committed_edits(&this, &buffer, git_diff, cx).await?;
} }
} }
} }
@ -498,7 +497,7 @@ impl ActionLog {
new: new_range, new: new_range,
}, },
&new_diff_base, &new_diff_base,
&buffer_snapshot.as_rope(), buffer_snapshot.as_rope(),
)); ));
} }
unreviewed_edits unreviewed_edits
@ -614,10 +613,10 @@ impl ActionLog {
false false
} }
}); });
if tracked_buffer.unreviewed_edits.is_empty() { if tracked_buffer.unreviewed_edits.is_empty()
if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status { && let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status
tracked_buffer.status = TrackedBufferStatus::Modified; {
} tracked_buffer.status = TrackedBufferStatus::Modified;
} }
tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx); tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx);
} }
@ -811,7 +810,7 @@ impl ActionLog {
tracked.version != buffer.version tracked.version != buffer.version
&& buffer && buffer
.file() .file()
.map_or(false, |file| file.disk_state() != DiskState::Deleted) .is_some_and(|file| file.disk_state() != DiskState::Deleted)
}) })
.map(|(buffer, _)| buffer) .map(|(buffer, _)| buffer)
} }
@ -847,7 +846,7 @@ fn apply_non_conflicting_edits(
conflict = true; conflict = true;
if new_edits if new_edits
.peek() .peek()
.map_or(false, |next_edit| next_edit.old.overlaps(&old_edit.new)) .is_some_and(|next_edit| next_edit.old.overlaps(&old_edit.new))
{ {
new_edit = new_edits.next().unwrap(); new_edit = new_edits.next().unwrap();
} else { } else {
@ -964,7 +963,7 @@ impl TrackedBuffer {
fn has_edits(&self, cx: &App) -> bool { fn has_edits(&self, cx: &App) -> bool {
self.diff self.diff
.read(cx) .read(cx)
.hunks(&self.buffer.read(cx), cx) .hunks(self.buffer.read(cx), cx)
.next() .next()
.is_some() .is_some()
} }
@ -2268,7 +2267,7 @@ mod tests {
log::info!("quiescing..."); log::info!("quiescing...");
cx.run_until_parked(); cx.run_until_parked();
action_log.update(cx, |log, cx| { action_log.update(cx, |log, cx| {
let tracked_buffer = log.tracked_buffers.get(&buffer).unwrap(); let tracked_buffer = log.tracked_buffers.get(buffer).unwrap();
let mut old_text = tracked_buffer.diff_base.clone(); let mut old_text = tracked_buffer.diff_base.clone();
let new_text = buffer.read(cx).as_rope(); let new_text = buffer.read(cx).as_rope();
for edit in tracked_buffer.unreviewed_edits.edits() { for edit in tracked_buffer.unreviewed_edits.edits() {

View file

@ -103,26 +103,21 @@ impl ActivityIndicator {
cx.subscribe_in( cx.subscribe_in(
&workspace_handle, &workspace_handle,
window, window,
|activity_indicator, _, event, window, cx| match event { |activity_indicator, _, event, window, cx| {
workspace::Event::ClearActivityIndicator { .. } => { if let workspace::Event::ClearActivityIndicator { .. } = event
if activity_indicator.statuses.pop().is_some() { && activity_indicator.statuses.pop().is_some()
activity_indicator.dismiss_error_message( {
&DismissErrorMessage, activity_indicator.dismiss_error_message(&DismissErrorMessage, window, cx);
window, cx.notify();
cx,
);
cx.notify();
}
} }
_ => {}
}, },
) )
.detach(); .detach();
cx.subscribe( cx.subscribe(
&project.read(cx).lsp_store(), &project.read(cx).lsp_store(),
|activity_indicator, _, event, cx| match event { |activity_indicator, _, event, cx| {
LspStoreEvent::LanguageServerUpdate { name, message, .. } => { if let LspStoreEvent::LanguageServerUpdate { name, message, .. } = event {
if let proto::update_language_server::Variant::StatusUpdate(status_update) = if let proto::update_language_server::Variant::StatusUpdate(status_update) =
message message
{ {
@ -191,7 +186,6 @@ impl ActivityIndicator {
} }
cx.notify() cx.notify()
} }
_ => {}
}, },
) )
.detach(); .detach();
@ -206,9 +200,10 @@ impl ActivityIndicator {
cx.subscribe( cx.subscribe(
&project.read(cx).git_store().clone(), &project.read(cx).git_store().clone(),
|_, _, event: &GitStoreEvent, cx| match event { |_, _, event: &GitStoreEvent, cx| {
project::git_store::GitStoreEvent::JobsUpdated => cx.notify(), if let project::git_store::GitStoreEvent::JobsUpdated = event {
_ => {} cx.notify()
}
}, },
) )
.detach(); .detach();
@ -458,26 +453,24 @@ impl ActivityIndicator {
.map(|r| r.read(cx)) .map(|r| r.read(cx))
.and_then(Repository::current_job); .and_then(Repository::current_job);
// Show any long-running git command // Show any long-running git command
if let Some(job_info) = current_job { if let Some(job_info) = current_job
if Instant::now() - job_info.start >= GIT_OPERATION_DELAY { && Instant::now() - job_info.start >= GIT_OPERATION_DELAY
return Some(Content { {
icon: Some( return Some(Content {
Icon::new(IconName::ArrowCircle) icon: Some(
.size(IconSize::Small) Icon::new(IconName::ArrowCircle)
.with_animation( .size(IconSize::Small)
"arrow-circle", .with_animation(
Animation::new(Duration::from_secs(2)).repeat(), "arrow-circle",
|icon, delta| { Animation::new(Duration::from_secs(2)).repeat(),
icon.transform(Transformation::rotate(percentage(delta))) |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
}, )
) .into_any_element(),
.into_any_element(), ),
), message: job_info.message.into(),
message: job_info.message.into(), on_click: None,
on_click: None, tooltip_message: None,
tooltip_message: None, });
});
}
} }
// Show any language server installation info. // Show any language server installation info.
@ -702,7 +695,7 @@ impl ActivityIndicator {
on_click: Some(Arc::new(|this, window, cx| { on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx) this.dismiss_error_message(&DismissErrorMessage, window, cx)
})), })),
tooltip_message: Some(Self::version_tooltip_message(&version)), tooltip_message: Some(Self::version_tooltip_message(version)),
}), }),
AutoUpdateStatus::Installing { version } => Some(Content { AutoUpdateStatus::Installing { version } => Some(Content {
icon: Some( icon: Some(
@ -714,13 +707,13 @@ impl ActivityIndicator {
on_click: Some(Arc::new(|this, window, cx| { on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx) this.dismiss_error_message(&DismissErrorMessage, window, cx)
})), })),
tooltip_message: Some(Self::version_tooltip_message(&version)), tooltip_message: Some(Self::version_tooltip_message(version)),
}), }),
AutoUpdateStatus::Updated { version } => Some(Content { AutoUpdateStatus::Updated { version } => Some(Content {
icon: None, icon: None,
message: "Click to restart and update Zed".to_string(), message: "Click to restart and update Zed".to_string(),
on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))), on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))),
tooltip_message: Some(Self::version_tooltip_message(&version)), tooltip_message: Some(Self::version_tooltip_message(version)),
}), }),
AutoUpdateStatus::Errored => Some(Content { AutoUpdateStatus::Errored => Some(Content {
icon: Some( icon: Some(
@ -740,21 +733,20 @@ impl ActivityIndicator {
if let Some(extension_store) = if let Some(extension_store) =
ExtensionStore::try_global(cx).map(|extension_store| extension_store.read(cx)) ExtensionStore::try_global(cx).map(|extension_store| extension_store.read(cx))
&& let Some(extension_id) = extension_store.outstanding_operations().keys().next()
{ {
if let Some(extension_id) = extension_store.outstanding_operations().keys().next() { return Some(Content {
return Some(Content { icon: Some(
icon: Some( Icon::new(IconName::Download)
Icon::new(IconName::Download) .size(IconSize::Small)
.size(IconSize::Small) .into_any_element(),
.into_any_element(), ),
), message: format!("Updating {extension_id} extension…"),
message: format!("Updating {extension_id} extension…"), on_click: Some(Arc::new(|this, window, cx| {
on_click: Some(Arc::new(|this, window, cx| { this.dismiss_error_message(&DismissErrorMessage, window, cx)
this.dismiss_error_message(&DismissErrorMessage, window, cx) })),
})), tooltip_message: None,
tooltip_message: None, });
});
}
} }
None None

View file

@ -31,7 +31,6 @@ collections.workspace = true
component.workspace = true component.workspace = true
context_server.workspace = true context_server.workspace = true
convert_case.workspace = true convert_case.workspace = true
feature_flags.workspace = true
fs.workspace = true fs.workspace = true
futures.workspace = true futures.workspace = true
git.workspace = true git.workspace = true

View file

@ -90,7 +90,7 @@ impl AgentProfile {
return false; return false;
}; };
return Self::is_enabled(settings, source, tool_name); Self::is_enabled(settings, source, tool_name)
} }
fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool { fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {

View file

@ -201,24 +201,24 @@ impl FileContextHandle {
parse_status.changed().await.log_err(); parse_status.changed().await.log_err();
} }
if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot()) { if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot())
if let Some(outline) = snapshot.outline(None) { && let Some(outline) = snapshot.outline(None)
let items = outline {
.items let items = outline
.into_iter() .items
.map(|item| item.to_point(&snapshot)); .into_iter()
.map(|item| item.to_point(&snapshot));
if let Ok(outline_text) = if let Ok(outline_text) =
outline::render_outline(items, None, 0, usize::MAX).await outline::render_outline(items, None, 0, usize::MAX).await
{ {
let context = AgentContext::File(FileContext { let context = AgentContext::File(FileContext {
handle: self, handle: self,
full_path, full_path,
text: outline_text.into(), text: outline_text.into(),
is_outline: true, is_outline: true,
}); });
return Some((context, vec![buffer])); return Some((context, vec![buffer]));
}
} }
} }
} }

View file

@ -338,11 +338,9 @@ impl ContextStore {
image_task, image_task,
context_id: self.next_context_id.post_inc(), context_id: self.next_context_id.post_inc(),
}); });
if self.has_context(&context) { if self.has_context(&context) && remove_if_exists {
if remove_if_exists { self.remove_context(&context, cx);
self.remove_context(&context, cx); return None;
return None;
}
} }
self.insert_context(context.clone(), cx); self.insert_context(context.clone(), cx);

View file

@ -9,14 +9,16 @@ use crate::{
tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}, tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
}; };
use action_log::ActionLog; use action_log::ActionLog;
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use agent_settings::{
AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet}; use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage}; use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::HashMap; use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt};
use futures::{FutureExt, StreamExt as _, future::Shared}; use futures::{FutureExt, StreamExt as _, future::Shared};
use git::repository::DiffType; use git::repository::DiffType;
use gpui::{ use gpui::{
@ -108,7 +110,7 @@ impl std::fmt::Display for PromptId {
} }
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize); pub struct MessageId(pub usize);
impl MessageId { impl MessageId {
fn post_inc(&mut self) -> Self { fn post_inc(&mut self) -> Self {
@ -388,7 +390,6 @@ pub struct Thread {
feedback: Option<ThreadFeedback>, feedback: Option<ThreadFeedback>,
retry_state: Option<RetryState>, retry_state: Option<RetryState>,
message_feedback: HashMap<MessageId, ThreadFeedback>, message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
last_received_chunk_at: Option<Instant>, last_received_chunk_at: Option<Instant>,
request_callback: Option< request_callback: Option<
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>, Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
@ -489,7 +490,6 @@ impl Thread {
feedback: None, feedback: None,
retry_state: None, retry_state: None,
message_feedback: HashMap::default(), message_feedback: HashMap::default(),
last_auto_capture_at: None,
last_error_context: None, last_error_context: None,
last_received_chunk_at: None, last_received_chunk_at: None,
request_callback: None, request_callback: None,
@ -614,7 +614,6 @@ impl Thread {
tool_use_limit_reached: serialized.tool_use_limit_reached, tool_use_limit_reached: serialized.tool_use_limit_reached,
feedback: None, feedback: None,
message_feedback: HashMap::default(), message_feedback: HashMap::default(),
last_auto_capture_at: None,
last_error_context: None, last_error_context: None,
last_received_chunk_at: None, last_received_chunk_at: None,
request_callback: None, request_callback: None,
@ -1033,8 +1032,6 @@ impl Thread {
}); });
} }
self.auto_capture_telemetry(cx);
message_id message_id
} }
@ -1651,15 +1648,13 @@ impl Thread {
self.tool_use self.tool_use
.request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx); .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
let pending_tool_use = self.tool_use.insert_tool_output( self.tool_use.insert_tool_output(
tool_use_id.clone(), tool_use_id.clone(),
tool_name, tool_name,
tool_output, tool_output,
self.configured_model.as_ref(), self.configured_model.as_ref(),
self.completion_mode, self.completion_mode,
); )
pending_tool_use
} }
pub fn stream_completion( pub fn stream_completion(
@ -1692,7 +1687,7 @@ impl Thread {
self.last_received_chunk_at = Some(Instant::now()); self.last_received_chunk_at = Some(Instant::now());
let task = cx.spawn(async move |thread, cx| { let task = cx.spawn(async move |thread, cx| {
let stream_completion_future = model.stream_completion(request, &cx); let stream_completion_future = model.stream_completion(request, cx);
let initial_token_usage = let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async { let stream_completion = async {
@ -1824,7 +1819,7 @@ impl Thread {
let streamed_input = if tool_use.is_input_complete { let streamed_input = if tool_use.is_input_complete {
None None
} else { } else {
Some((&tool_use.input).clone()) Some(tool_use.input.clone())
}; };
let ui_text = thread.tool_use.request_tool_use( let ui_text = thread.tool_use.request_tool_use(
@ -1906,7 +1901,6 @@ impl Thread {
cx.emit(ThreadEvent::StreamedCompletion); cx.emit(ThreadEvent::StreamedCompletion);
cx.notify(); cx.notify();
thread.auto_capture_telemetry(cx);
Ok(()) Ok(())
})??; })??;
@ -1974,11 +1968,9 @@ impl Thread {
if let Some(prev_message) = if let Some(prev_message) =
thread.messages.get(ix - 1) thread.messages.get(ix - 1)
{ && prev_message.role == Role::Assistant {
if prev_message.role == Role::Assistant {
break; break;
} }
}
} }
} }
@ -2051,7 +2043,7 @@ impl Thread {
retry_scheduled = thread retry_scheduled = thread
.handle_retryable_error_with_delay( .handle_retryable_error_with_delay(
&completion_error, completion_error,
Some(retry_strategy), Some(retry_strategy),
model.clone(), model.clone(),
intent, intent,
@ -2081,8 +2073,6 @@ impl Thread {
request_callback(request, response_events); request_callback(request, response_events);
} }
thread.auto_capture_telemetry(cx);
if let Ok(initial_usage) = initial_token_usage { if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage - initial_usage; let usage = thread.cumulative_token_usage - initial_usage;
@ -2130,7 +2120,7 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| { self.pending_summary = cx.spawn(async move |this, cx| {
let result = async { let result = async {
let mut messages = model.model.stream_completion(request, &cx).await?; let mut messages = model.model.stream_completion(request, cx).await?;
let mut new_summary = String::new(); let mut new_summary = String::new();
while let Some(event) = messages.next().await { while let Some(event) = messages.next().await {
@ -2438,12 +2428,10 @@ impl Thread {
return; return;
} }
let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");
let request = self.to_summarize_request( let request = self.to_summarize_request(
&model, &model,
CompletionIntent::ThreadContextSummarization, CompletionIntent::ThreadContextSummarization,
added_user_message.into(), SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
cx, cx,
); );
@ -2456,7 +2444,7 @@ impl Thread {
// which result to prefer (the old task could complete after the new one, resulting in a // which result to prefer (the old task could complete after the new one, resulting in a
// stale summary). // stale summary).
self.detailed_summary_task = cx.spawn(async move |thread, cx| { self.detailed_summary_task = cx.spawn(async move |thread, cx| {
let stream = model.stream_completion_text(request, &cx); let stream = model.stream_completion_text(request, cx);
let Some(mut messages) = stream.await.log_err() else { let Some(mut messages) = stream.await.log_err() else {
thread thread
.update(cx, |thread, _cx| { .update(cx, |thread, _cx| {
@ -2485,13 +2473,13 @@ impl Thread {
.ok()?; .ok()?;
// Save thread so its summary can be reused later // Save thread so its summary can be reused later
if let Some(thread) = thread.upgrade() { if let Some(thread) = thread.upgrade()
if let Ok(Ok(save_task)) = cx.update(|cx| { && let Ok(Ok(save_task)) = cx.update(|cx| {
thread_store thread_store
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
}) { })
save_task.await.log_err(); {
} save_task.await.log_err();
} }
Some(()) Some(())
@ -2536,7 +2524,6 @@ impl Thread {
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Vec<PendingToolUse> { ) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
let request = let request =
Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx)); Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
let pending_tool_uses = self let pending_tool_uses = self
@ -2740,13 +2727,11 @@ impl Thread {
window: Option<AnyWindowHandle>, window: Option<AnyWindowHandle>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if self.all_tools_finished() { if self.all_tools_finished()
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() { && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
if !canceled { && !canceled
self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx); {
} self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
self.auto_capture_telemetry(cx);
}
} }
cx.emit(ThreadEvent::ToolFinished { cx.emit(ThreadEvent::ToolFinished {
@ -2933,11 +2918,11 @@ impl Thread {
let buffer_store = project.read(app_cx).buffer_store(); let buffer_store = project.read(app_cx).buffer_store();
for buffer_handle in buffer_store.read(app_cx).buffers() { for buffer_handle in buffer_store.read(app_cx).buffers() {
let buffer = buffer_handle.read(app_cx); let buffer = buffer_handle.read(app_cx);
if buffer.is_dirty() { if buffer.is_dirty()
if let Some(file) = buffer.file() { && let Some(file) = buffer.file()
let path = file.path().to_string_lossy().to_string(); {
unsaved_buffers.push(path); let path = file.path().to_string_lossy().to_string();
} unsaved_buffers.push(path);
} }
} }
}) })
@ -3147,50 +3132,6 @@ impl Thread {
&self.project &self.project
} }
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
return;
}
let now = Instant::now();
if let Some(last) = self.last_auto_capture_at {
if now.duration_since(last).as_secs() < 10 {
return;
}
}
self.last_auto_capture_at = Some(now);
let thread_id = self.id().clone();
let github_login = self
.project
.read(cx)
.user_store()
.read(cx)
.current_user()
.map(|user| user.github_login.clone());
let client = self.project.read(cx).client();
let serialize_task = self.serialize(cx);
cx.background_executor()
.spawn(async move {
if let Ok(serialized_thread) = serialize_task.await {
if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
telemetry::event!(
"Agent Thread Auto-Captured",
thread_id = thread_id.to_string(),
thread_data = thread_data,
auto_capture_reason = "tracked_user",
github_login = github_login
);
client.telemetry().flush_events().await;
}
}
})
.detach();
}
pub fn cumulative_token_usage(&self) -> TokenUsage { pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage self.cumulative_token_usage
} }
@ -3233,13 +3174,13 @@ impl Thread {
.model .model
.max_token_count_for_mode(self.completion_mode().into()); .max_token_count_for_mode(self.completion_mode().into());
if let Some(exceeded_error) = &self.exceeded_window_error { if let Some(exceeded_error) = &self.exceeded_window_error
if model.model.id() == exceeded_error.model_id { && model.model.id() == exceeded_error.model_id
return Some(TotalTokenUsage { {
total: exceeded_error.token_count, return Some(TotalTokenUsage {
max, total: exceeded_error.token_count,
}); max,
} });
} }
let total = self let total = self
@ -4043,7 +3984,7 @@ fn main() {{
}); });
let fake_model = model.as_fake(); let fake_model = model.as_fake();
simulate_successful_response(&fake_model, cx); simulate_successful_response(fake_model, cx);
// Should start generating summary when there are >= 2 messages // Should start generating summary when there are >= 2 messages
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
@ -4138,7 +4079,7 @@ fn main() {{
}); });
let fake_model = model.as_fake(); let fake_model = model.as_fake();
simulate_successful_response(&fake_model, cx); simulate_successful_response(fake_model, cx);
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
// State is still Error, not Generating // State is still Error, not Generating
@ -5420,7 +5361,7 @@ fn main() {{
}); });
let fake_model = model.as_fake(); let fake_model = model.as_fake();
simulate_successful_response(&fake_model, cx); simulate_successful_response(fake_model, cx);
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Generating)); assert!(matches!(thread.summary(), ThreadSummary::Generating));

View file

@ -42,7 +42,7 @@ use std::{
use util::ResultExt as _; use util::ResultExt as _;
pub static ZED_STATELESS: std::sync::LazyLock<bool> = pub static ZED_STATELESS: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType { pub enum DataType {
@ -74,7 +74,7 @@ impl Column for DataType {
} }
} }
const RULES_FILE_NAMES: [&'static str; 9] = [ const RULES_FILE_NAMES: [&str; 9] = [
".rules", ".rules",
".cursorrules", ".cursorrules",
".windsurfrules", ".windsurfrules",
@ -581,33 +581,32 @@ impl ThreadStore {
return; return;
}; };
if protocol.capable(context_server::protocol::ServerCapability::Tools) { if protocol.capable(context_server::protocol::ServerCapability::Tools)
if let Some(response) = protocol && let Some(response) = protocol
.request::<context_server::types::requests::ListTools>(()) .request::<context_server::types::requests::ListTools>(())
.await .await
.log_err() .log_err()
{ {
let tool_ids = tool_working_set let tool_ids = tool_working_set
.update(cx, |tool_working_set, cx| { .update(cx, |tool_working_set, cx| {
tool_working_set.extend( tool_working_set.extend(
response.tools.into_iter().map(|tool| { response.tools.into_iter().map(|tool| {
Arc::new(ContextServerTool::new( Arc::new(ContextServerTool::new(
context_server_store.clone(), context_server_store.clone(),
server.id(), server.id(),
tool, tool,
)) as Arc<dyn Tool> )) as Arc<dyn Tool>
}), }),
cx, cx,
) )
}) })
.log_err(); .log_err();
if let Some(tool_ids) = tool_ids { if let Some(tool_ids) = tool_ids {
this.update(cx, |this, _| { this.update(cx, |this, _| {
this.context_server_tool_ids.insert(server_id, tool_ids); this.context_server_tool_ids.insert(server_id, tool_ids);
}) })
.log_err(); .log_err();
}
} }
} }
}) })
@ -697,13 +696,14 @@ impl SerializedThreadV0_1_0 {
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len()); let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
for message in self.0.messages { for message in self.0.messages {
if message.role == Role::User && !message.tool_results.is_empty() { if message.role == Role::User
if let Some(last_message) = messages.last_mut() { && !message.tool_results.is_empty()
debug_assert!(last_message.role == Role::Assistant); && let Some(last_message) = messages.last_mut()
{
debug_assert!(last_message.role == Role::Assistant);
last_message.tool_results = message.tool_results; last_message.tool_results = message.tool_results;
continue; continue;
}
} }
messages.push(message); messages.push(message);
@ -893,7 +893,7 @@ impl ThreadsDatabase {
let needs_migration_from_heed = mdb_path.exists(); let needs_migration_from_heed = mdb_path.exists();
let connection = if *ZED_STATELESS { let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
Connection::open_memory(Some("THREAD_FALLBACK_DB")) Connection::open_memory(Some("THREAD_FALLBACK_DB"))
} else { } else {
Connection::open_file(&sqlite_path.to_string_lossy()) Connection::open_file(&sqlite_path.to_string_lossy())

View file

@ -112,19 +112,13 @@ impl ToolUseState {
}, },
); );
if let Some(window) = &mut window { if let Some(window) = &mut window
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) { && let Some(tool) = this.tools.read(cx).tool(tool_use, cx)
if let Some(output) = tool_result.output.clone() { && let Some(output) = tool_result.output.clone()
if let Some(card) = tool.deserialize_card( && let Some(card) =
output, tool.deserialize_card(output, project.clone(), window, cx)
project.clone(), {
window, this.tool_result_cards.insert(tool_use_id, card);
cx,
) {
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
} }
} }
} }
@ -281,7 +275,7 @@ impl ToolUseState {
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool { pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message self.tool_uses_by_assistant_message
.get(&assistant_message_id) .get(&assistant_message_id)
.map_or(false, |results| !results.is_empty()) .is_some_and(|results| !results.is_empty())
} }
pub fn tool_result( pub fn tool_result(

View file

@ -8,24 +8,31 @@ license = "GPL-3.0-or-later"
[lib] [lib]
path = "src/agent2.rs" path = "src/agent2.rs"
[features]
test-support = ["db/test-support"]
[lints] [lints]
workspace = true workspace = true
[dependencies] [dependencies]
acp_thread.workspace = true acp_thread.workspace = true
action_log.workspace = true action_log.workspace = true
agent.workspace = true
agent-client-protocol.workspace = true agent-client-protocol.workspace = true
agent_servers.workspace = true agent_servers.workspace = true
agent_settings.workspace = true agent_settings.workspace = true
anyhow.workspace = true anyhow.workspace = true
assistant_context.workspace = true
assistant_tool.workspace = true assistant_tool.workspace = true
assistant_tools.workspace = true assistant_tools.workspace = true
chrono.workspace = true chrono.workspace = true
cloud_llm_client.workspace = true cloud_llm_client.workspace = true
collections.workspace = true collections.workspace = true
context_server.workspace = true context_server.workspace = true
db.workspace = true
fs.workspace = true fs.workspace = true
futures.workspace = true futures.workspace = true
git.workspace = true
gpui.workspace = true gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] } handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true html_to_markdown.workspace = true
@ -37,6 +44,7 @@ language_model.workspace = true
language_models.workspace = true language_models.workspace = true
log.workspace = true log.workspace = true
open.workspace = true open.workspace = true
parking_lot.workspace = true
paths.workspace = true paths.workspace = true
portable-pty.workspace = true portable-pty.workspace = true
project.workspace = true project.workspace = true
@ -47,6 +55,7 @@ serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
settings.workspace = true settings.workspace = true
smol.workspace = true smol.workspace = true
sqlez.workspace = true
task.workspace = true task.workspace = true
terminal.workspace = true terminal.workspace = true
text.workspace = true text.workspace = true
@ -57,15 +66,20 @@ watch.workspace = true
web_search.workspace = true web_search.workspace = true
which.workspace = true which.workspace = true
workspace-hack.workspace = true workspace-hack.workspace = true
zstd.workspace = true
[dev-dependencies] [dev-dependencies]
agent = { workspace = true, "features" = ["test-support"] }
assistant_context = { workspace = true, "features" = ["test-support"] }
ctor.workspace = true ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] } client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] }
context_server = { workspace = true, "features" = ["test-support"] } context_server = { workspace = true, "features" = ["test-support"] }
db = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] } fs = { workspace = true, "features" = ["test-support"] }
git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] } language = { workspace = true, "features" = ["test-support"] }

View file

@ -1,10 +1,10 @@
use crate::HistoryStore;
use crate::{ use crate::{
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, UserMessageContent, templates::Templates,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
}; };
use acp_thread::AgentModelSelector; use acp_thread::{AcpThread, AgentModelSelector};
use action_log::ActionLog;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::AgentSettings; use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
@ -22,14 +22,13 @@ use prompt_store::{
}; };
use settings::update_settings_file; use settings::update_settings_file;
use std::any::Any; use std::any::Any;
use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use util::ResultExt; use util::ResultExt;
const RULES_FILE_NAMES: [&'static str; 9] = [ const RULES_FILE_NAMES: [&str; 9] = [
".rules", ".rules",
".cursorrules", ".cursorrules",
".windsurfrules", ".windsurfrules",
@ -51,7 +50,8 @@ struct Session {
thread: Entity<Thread>, thread: Entity<Thread>,
/// The ACP thread that handles protocol communication /// The ACP thread that handles protocol communication
acp_thread: WeakEntity<acp_thread::AcpThread>, acp_thread: WeakEntity<acp_thread::AcpThread>,
_subscription: Subscription, pending_save: Task<()>,
_subscriptions: Vec<Subscription>,
} }
pub struct LanguageModels { pub struct LanguageModels {
@ -91,7 +91,7 @@ impl LanguageModels {
for provider in &providers { for provider in &providers {
for model in provider.recommended_models(cx) { for model in provider.recommended_models(cx) {
recommended_models.insert(model.id()); recommended_models.insert(model.id());
recommended.push(Self::map_language_model_to_info(&model, &provider)); recommended.push(Self::map_language_model_to_info(&model, provider));
} }
} }
if !recommended.is_empty() { if !recommended.is_empty() {
@ -155,8 +155,9 @@ impl LanguageModels {
pub struct NativeAgent { pub struct NativeAgent {
/// Session ID -> Session mapping /// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>, sessions: HashMap<acp::SessionId, Session>,
history: Entity<HistoryStore>,
/// Shared project context for all threads /// Shared project context for all threads
project_context: Rc<RefCell<ProjectContext>>, project_context: Entity<ProjectContext>,
project_context_needs_refresh: watch::Sender<()>, project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task<Result<()>>, _maintain_project_context: Task<Result<()>>,
context_server_registry: Entity<ContextServerRegistry>, context_server_registry: Entity<ContextServerRegistry>,
@ -173,6 +174,7 @@ pub struct NativeAgent {
impl NativeAgent { impl NativeAgent {
pub async fn new( pub async fn new(
project: Entity<Project>, project: Entity<Project>,
history: Entity<HistoryStore>,
templates: Arc<Templates>, templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
@ -200,7 +202,8 @@ impl NativeAgent {
watch::channel(()); watch::channel(());
Self { Self {
sessions: HashMap::new(), sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)), history,
project_context: cx.new(|_| project_context),
project_context_needs_refresh: project_context_needs_refresh_tx, project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| { _maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
@ -218,6 +221,55 @@ impl NativeAgent {
}) })
} }
fn register_session(
&mut self,
thread_handle: Entity<Thread>,
cx: &mut Context<Self>,
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
let registry = LanguageModelRegistry::read_global(cx);
let summarization_model = registry.thread_summary_model().map(|c| c.model);
thread_handle.update(cx, |thread, cx| {
thread.set_summarization_model(summarization_model, cx);
thread.add_default_tools(cx)
});
let thread = thread_handle.read(cx);
let session_id = thread.id().clone();
let title = thread.title();
let project = thread.project.clone();
let action_log = thread.action_log.clone();
let acp_thread = cx.new(|_cx| {
acp_thread::AcpThread::new(
title,
connection,
project.clone(),
action_log.clone(),
session_id.clone(),
)
});
let subscriptions = vec![
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
cx.observe(&thread_handle, move |this, thread, cx| {
this.save_thread(thread.clone(), cx)
}),
];
self.sessions.insert(
session_id,
Session {
thread: thread_handle,
acp_thread: acp_thread.downgrade(),
_subscriptions: subscriptions,
pending_save: Task::ready(()),
},
);
acp_thread
}
pub fn models(&self) -> &LanguageModels { pub fn models(&self) -> &LanguageModels {
&self.models &self.models
} }
@ -233,7 +285,9 @@ impl NativeAgent {
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx) Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
})? })?
.await; .await;
this.update(cx, |this, _| this.project_context.replace(project_context))?; this.update(cx, |this, cx| {
this.project_context = cx.new(|_| project_context);
})?;
} }
Ok(()) Ok(())
@ -426,21 +480,101 @@ impl NativeAgent {
) { ) {
self.models.refresh_list(cx); self.models.refresh_list(cx);
let default_model = LanguageModelRegistry::read_global(cx) let registry = LanguageModelRegistry::read_global(cx);
.default_model() let default_model = registry.default_model().map(|m| m.model.clone());
.map(|m| m.model.clone()); let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
for session in self.sessions.values_mut() { for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| { session.thread.update(cx, |thread, cx| {
if thread.model().is_none() if thread.model().is_none()
&& let Some(model) = default_model.clone() && let Some(model) = default_model.clone()
{ {
thread.set_model(model); thread.set_model(model, cx);
cx.notify(); cx.notify();
} }
thread.set_summarization_model(summarization_model.clone(), cx);
}); });
} }
} }
pub fn open_thread(
&mut self,
id: acp::SessionId,
cx: &mut Context<Self>,
) -> Task<Result<Entity<AcpThread>>> {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
let db_thread = database
.load_thread(id.clone())
.await?
.with_context(|| format!("no thread found with ID: {id:?}"))?;
let thread = this.update(cx, |this, cx| {
let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
cx.new(|cx| {
Thread::from_db(
id.clone(),
db_thread,
this.project.clone(),
this.project_context.clone(),
this.context_server_registry.clone(),
action_log.clone(),
this.templates.clone(),
cx,
)
})
})?;
let acp_thread =
this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
cx.update(|cx| {
NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
})?
.await?;
Ok(acp_thread)
})
}
pub fn thread_summary(
&mut self,
id: acp::SessionId,
cx: &mut Context<Self>,
) -> Task<Result<SharedString>> {
let thread = self.open_thread(id.clone(), cx);
cx.spawn(async move |this, cx| {
let acp_thread = thread.await?;
let result = this
.update(cx, |this, cx| {
this.sessions
.get(&id)
.unwrap()
.thread
.update(cx, |thread, cx| thread.summary(cx))
})?
.await?;
drop(acp_thread);
Ok(result)
})
}
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
let database_future = ThreadsDatabase::connect(cx);
let (id, db_thread) =
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
let Some(session) = self.sessions.get_mut(&id) else {
return;
};
let history = self.history.clone();
session.pending_save = cx.spawn(async move |_, cx| {
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
return;
};
let db_thread = db_thread.await;
database.save_thread(id, db_thread).await.log_err();
history.update(cx, |history, cx| history.reload(cx)).ok();
});
}
} }
/// Wrapper struct that implements the AgentConnection trait /// Wrapper struct that implements the AgentConnection trait
@ -461,10 +595,7 @@ impl NativeAgentConnection {
session_id: acp::SessionId, session_id: acp::SessionId,
cx: &mut App, cx: &mut App,
f: impl 'static f: impl 'static
+ FnOnce( + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
Entity<Thread>,
&mut App,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
) -> Task<Result<acp::PromptResponse>> { ) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent agent
@ -476,19 +607,38 @@ impl NativeAgentConnection {
}; };
log::debug!("Found session for: {}", session_id); log::debug!("Found session for: {}", session_id);
let mut response_stream = match f(thread, cx) { let response_stream = match f(thread, cx) {
Ok(stream) => stream, Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)), Err(err) => return Task::ready(Err(err)),
}; };
Self::handle_thread_events(response_stream, acp_thread, cx)
}
fn handle_thread_events(
mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
acp_thread: WeakEntity<AcpThread>,
cx: &App,
) -> Task<Result<acp::PromptResponse>> {
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
// Handle response stream and forward to session.acp_thread // Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await { while let Some(result) = events.next().await {
match result { match result {
Ok(event) => { Ok(event) => {
log::trace!("Received completion event: {:?}", event); log::trace!("Received completion event: {:?}", event);
match event { match event {
AgentResponseEvent::Text(text) => { ThreadEvent::UserMessage(message) => {
acp_thread.update(cx, |thread, cx| {
for content in message.content {
thread.push_user_content_block(
Some(message.id.clone()),
content.into(),
cx,
);
}
})?;
}
ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block( thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent { acp::ContentBlock::Text(acp::TextContent {
@ -500,7 +650,7 @@ impl NativeAgentConnection {
) )
})?; })?;
} }
AgentResponseEvent::Thinking(text) => { ThreadEvent::AgentThinking(text) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block( thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent { acp::ContentBlock::Text(acp::TextContent {
@ -512,7 +662,7 @@ impl NativeAgentConnection {
) )
})?; })?;
} }
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call, tool_call,
options, options,
response, response,
@ -535,17 +685,31 @@ impl NativeAgentConnection {
}) })
.detach(); .detach();
} }
AgentResponseEvent::ToolCall(tool_call) => { ThreadEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx) thread.upsert_tool_call(tool_call, cx)
})??; })??;
} }
AgentResponseEvent::ToolCallUpdate(update) => { ThreadEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx) thread.update_tool_call(update, cx)
})??; })??;
} }
AgentResponseEvent::Stop(stop_reason) => { ThreadEvent::TokenUsageUpdate(usage) => {
acp_thread.update(cx, |thread, cx| {
thread.update_token_usage(Some(usage), cx)
})?;
}
ThreadEvent::TitleUpdate(title) => {
acp_thread
.update(cx, |thread, cx| thread.update_title(title, cx))??;
}
ThreadEvent::Retry(status) => {
acp_thread.update(cx, |thread, cx| {
thread.update_retry_status(status, cx)
})?;
}
ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason); log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason }); return Ok(acp::PromptResponse { stop_reason });
} }
@ -598,8 +762,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
}; };
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, cx| {
thread.set_model(model.clone()); thread.set_model(model.clone(), cx);
}); });
update_settings_file::<AgentSettings>( update_settings_file::<AgentSettings>(
@ -659,31 +823,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context"); log::debug!("Starting thread creation in async context");
// Generate session ID let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
session_id.clone(),
cx,
)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
// Create Thread // Create Thread
let thread = agent.update( let thread = agent.update(
cx, cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> { |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings // Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx); let registry = LanguageModelRegistry::read_global(cx);
// Log available models for debugging // Log available models for debugging
let available_count = registry.available_models(cx).count(); let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count); log::debug!("Total available models: {}", available_count);
@ -695,7 +841,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}); });
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let mut thread = Thread::new( Thread::new(
project.clone(), project.clone(),
agent.project_context.clone(), agent.project_context.clone(),
agent.context_server_registry.clone(), agent.context_server_registry.clone(),
@ -703,45 +849,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
agent.templates.clone(), agent.templates.clone(),
default_model, default_model,
cx, cx,
); )
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.entity()));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
thread.add_tool(ListDirectoryTool::new(project.clone()));
thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(NowTool);
thread.add_tool(OpenTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(TerminalTool::new(project.clone(), cx));
thread.add_tool(ThinkingTool);
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
thread
}); });
Ok(thread) Ok(thread)
}, },
)??; )??;
agent.update(cx, |agent, cx| agent.register_session(thread, cx))
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
session_id,
Session {
thread,
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
},
);
})?;
Ok(acp_thread)
}) })
} }
@ -797,7 +911,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id); log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| { self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) { if let Some(agent) = agent.sessions.get(session_id) {
agent.thread.update(cx, |thread, _cx| thread.cancel()); agent.thread.update(cx, |thread, cx| thread.cancel(cx));
} }
}); });
} }
@ -808,10 +922,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx: &mut App, cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> { ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
self.0.update(cx, |agent, _cx| { self.0.update(cx, |agent, _cx| {
agent agent.sessions.get(session_id).map(|session| {
.sessions Rc::new(NativeAgentSessionEditor {
.get(session_id) thread: session.thread.clone(),
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) acp_thread: session.acp_thread.clone(),
}) as _
})
}) })
} }
@ -820,11 +936,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
} }
} }
struct NativeAgentSessionEditor(Entity<Thread>); struct NativeAgentSessionEditor {
thread: Entity<Thread>,
acp_thread: WeakEntity<AcpThread>,
}
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> { fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) match self.thread.update(cx, |thread, cx| {
thread.truncate(message_id.clone(), cx)?;
Ok(thread.latest_token_usage())
}) {
Ok(usage) => {
self.acp_thread
.update(cx, |thread, cx| {
thread.update_token_usage(usage, cx);
})
.ok();
Task::ready(Ok(()))
}
Err(error) => Task::ready(Err(error)),
}
} }
} }
@ -863,8 +995,11 @@ mod tests {
) )
.await; .await;
let project = Project::test(fs.clone(), [], cx).await; let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
let agent = NativeAgent::new( let agent = NativeAgent::new(
project.clone(), project.clone(),
history_store,
Templates::new(), Templates::new(),
None, None,
fs.clone(), fs.clone(),
@ -872,8 +1007,8 @@ mod tests {
) )
.await .await
.unwrap(); .unwrap();
agent.read_with(cx, |agent, _| { agent.read_with(cx, |agent, cx| {
assert_eq!(agent.project_context.borrow().worktrees, vec![]) assert_eq!(agent.project_context.read(cx).worktrees, vec![])
}); });
let worktree = project let worktree = project
@ -881,9 +1016,9 @@ mod tests {
.await .await
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
agent.read_with(cx, |agent, _| { agent.read_with(cx, |agent, cx| {
assert_eq!( assert_eq!(
agent.project_context.borrow().worktrees, agent.project_context.read(cx).worktrees,
vec![WorktreeContext { vec![WorktreeContext {
root_name: "a".into(), root_name: "a".into(),
abs_path: Path::new("/a").into(), abs_path: Path::new("/a").into(),
@ -898,7 +1033,7 @@ mod tests {
agent.read_with(cx, |agent, cx| { agent.read_with(cx, |agent, cx| {
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap(); let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
assert_eq!( assert_eq!(
agent.project_context.borrow().worktrees, agent.project_context.read(cx).worktrees,
vec![WorktreeContext { vec![WorktreeContext {
root_name: "a".into(), root_name: "a".into(),
abs_path: Path::new("/a").into(), abs_path: Path::new("/a").into(),
@ -918,9 +1053,12 @@ mod tests {
let fs = FakeFs::new(cx.executor()); let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await; fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await; let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
let connection = NativeAgentConnection( let connection = NativeAgentConnection(
NativeAgent::new( NativeAgent::new(
project.clone(), project.clone(),
history_store,
Templates::new(), Templates::new(),
None, None,
fs.clone(), fs.clone(),
@ -971,9 +1109,13 @@ mod tests {
.await; .await;
let project = Project::test(fs.clone(), [], cx).await; let project = Project::test(fs.clone(), [], cx).await;
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
// Create the agent and connection // Create the agent and connection
let agent = NativeAgent::new( let agent = NativeAgent::new(
project.clone(), project.clone(),
history_store,
Templates::new(), Templates::new(),
None, None,
fs.clone(), fs.clone(),

View file

@ -1,13 +1,18 @@
mod agent; mod agent;
mod db;
mod history_store;
mod native_agent_server; mod native_agent_server;
mod templates; mod templates;
mod thread; mod thread;
mod tool_schema;
mod tools; mod tools;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
pub use agent::*; pub use agent::*;
pub use db::*;
pub use history_store::*;
pub use native_agent_server::NativeAgentServer; pub use native_agent_server::NativeAgentServer;
pub use templates::*; pub use templates::*;
pub use thread::*; pub use thread::*;

488
crates/agent2/src/db.rs Normal file
View file

@ -0,0 +1,488 @@
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
use acp_thread::UserMessageId;
use agent::{thread::DetailedSummaryState, thread_store};
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Result, anyhow};
use chrono::{DateTime, Utc};
use collections::{HashMap, IndexMap};
use futures::{FutureExt, future::Shared};
use gpui::{BackgroundExecutor, Global, Task};
use indoc::indoc;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use sqlez::{
bindable::{Bind, Column},
connection::Connection,
statement::Statement,
};
use std::sync::Arc;
use ui::{App, SharedString};
pub type DbMessage = crate::Message;
pub type DbSummary = DetailedSummaryState;
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
pub id: acp::SessionId,
#[serde(alias = "summary")]
pub title: SharedString,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DbThread {
pub title: SharedString,
pub messages: Vec<DbMessage>,
pub updated_at: DateTime<Utc>,
#[serde(default)]
pub detailed_summary: Option<SharedString>,
#[serde(default)]
pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
#[serde(default)]
pub cumulative_token_usage: language_model::TokenUsage,
#[serde(default)]
pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
#[serde(default)]
pub model: Option<DbLanguageModel>,
#[serde(default)]
pub completion_mode: Option<CompletionMode>,
#[serde(default)]
pub profile: Option<AgentProfileId>,
}
impl DbThread {
pub const VERSION: &'static str = "0.3.0";
pub fn from_json(json: &[u8]) -> Result<Self> {
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
match saved_thread_json.get("version") {
Some(serde_json::Value::String(version)) => match version.as_str() {
Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
},
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
}
}
fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
let mut messages = Vec::new();
let mut request_token_usage = HashMap::default();
let mut last_user_message_id = None;
for (ix, msg) in thread.messages.into_iter().enumerate() {
let message = match msg.role {
language_model::Role::User => {
let mut content = Vec::new();
// Convert segments to content
for segment in msg.segments {
match segment {
thread_store::SerializedMessageSegment::Text { text } => {
content.push(UserMessageContent::Text(text));
}
thread_store::SerializedMessageSegment::Thinking { text, .. } => {
// User messages don't have thinking segments, but handle gracefully
content.push(UserMessageContent::Text(text));
}
thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
// User messages don't have redacted thinking, skip.
}
}
}
// If no content was added, add context as text if available
if content.is_empty() && !msg.context.is_empty() {
content.push(UserMessageContent::Text(msg.context));
}
let id = UserMessageId::new();
last_user_message_id = Some(id.clone());
crate::Message::User(UserMessage {
// MessageId from old format can't be meaningfully converted, so generate a new one
id,
content,
})
}
language_model::Role::Assistant => {
let mut content = Vec::new();
// Convert segments to content
for segment in msg.segments {
match segment {
thread_store::SerializedMessageSegment::Text { text } => {
content.push(AgentMessageContent::Text(text));
}
thread_store::SerializedMessageSegment::Thinking {
text,
signature,
} => {
content.push(AgentMessageContent::Thinking { text, signature });
}
thread_store::SerializedMessageSegment::RedactedThinking { data } => {
content.push(AgentMessageContent::RedactedThinking(data));
}
}
}
// Convert tool uses
let mut tool_names_by_id = HashMap::default();
for tool_use in msg.tool_uses {
tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
content.push(AgentMessageContent::ToolUse(
language_model::LanguageModelToolUse {
id: tool_use.id,
name: tool_use.name.into(),
raw_input: serde_json::to_string(&tool_use.input)
.unwrap_or_default(),
input: tool_use.input,
is_input_complete: true,
},
));
}
// Convert tool results
let mut tool_results = IndexMap::default();
for tool_result in msg.tool_results {
let name = tool_names_by_id
.remove(&tool_result.tool_use_id)
.unwrap_or_else(|| SharedString::from("unknown"));
tool_results.insert(
tool_result.tool_use_id.clone(),
language_model::LanguageModelToolResult {
tool_use_id: tool_result.tool_use_id,
tool_name: name.into(),
is_error: tool_result.is_error,
content: tool_result.content,
output: tool_result.output,
},
);
}
if let Some(last_user_message_id) = &last_user_message_id
&& let Some(token_usage) = thread.request_token_usage.get(ix).copied()
{
request_token_usage.insert(last_user_message_id.clone(), token_usage);
}
crate::Message::Agent(AgentMessage {
content,
tool_results,
})
}
language_model::Role::System => {
// Skip system messages as they're not supported in the new format
continue;
}
};
messages.push(message);
}
Ok(Self {
title: thread.summary,
messages,
updated_at: thread.updated_at,
detailed_summary: match thread.detailed_summary_state {
DetailedSummaryState::NotGenerated | DetailedSummaryState::Generating { .. } => {
None
}
DetailedSummaryState::Generated { text, .. } => Some(text),
},
initial_project_snapshot: thread.initial_project_snapshot,
cumulative_token_usage: thread.cumulative_token_usage,
request_token_usage,
model: thread.model,
completion_mode: thread.completion_mode,
profile: thread.profile,
})
}
}
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
#[serde(rename = "json")]
Json,
#[serde(rename = "zstd")]
Zstd,
}
impl Bind for DataType {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let value = match self {
DataType::Json => "json",
DataType::Zstd => "zstd",
};
value.bind(statement, start_index)
}
}
impl Column for DataType {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (value, next_index) = String::column(statement, start_index)?;
let data_type = match value.as_str() {
"json" => DataType::Json,
"zstd" => DataType::Zstd,
_ => anyhow::bail!("Unknown data type: {}", value),
};
Ok((data_type, next_index))
}
}
pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor,
connection: Arc<Mutex<Connection>>,
}
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
impl Global for GlobalThreadsDatabase {}
impl ThreadsDatabase {
pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
if cx.has_global::<GlobalThreadsDatabase>() {
return cx.global::<GlobalThreadsDatabase>().0.clone();
}
let executor = cx.background_executor().clone();
let task = executor
.spawn({
let executor = executor.clone();
async move {
match ThreadsDatabase::new(executor) {
Ok(db) => Ok(Arc::new(db)),
Err(err) => Err(Arc::new(err)),
}
}
})
.shared();
cx.set_global(GlobalThreadsDatabase(task.clone()));
task
}
pub fn new(executor: BackgroundExecutor) -> Result<Self> {
let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
} else {
let threads_dir = paths::data_dir().join("threads");
std::fs::create_dir_all(&threads_dir)?;
let sqlite_path = threads_dir.join("threads.db");
Connection::open_file(&sqlite_path.to_string_lossy())
};
connection.exec(indoc! {"
CREATE TABLE IF NOT EXISTS threads (
id TEXT PRIMARY KEY,
summary TEXT NOT NULL,
updated_at TEXT NOT NULL,
data_type TEXT NOT NULL,
data BLOB NOT NULL
)
"})?()
.map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
let db = Self {
executor: executor.clone(),
connection: Arc::new(Mutex::new(connection)),
};
Ok(db)
}
fn save_thread_sync(
connection: &Arc<Mutex<Connection>>,
id: acp::SessionId,
thread: DbThread,
) -> Result<()> {
const COMPRESSION_LEVEL: i32 = 3;
#[derive(Serialize)]
struct SerializedThread {
#[serde(flatten)]
thread: DbThread,
version: &'static str,
}
let title = thread.title.to_string();
let updated_at = thread.updated_at.to_rfc3339();
let json_data = serde_json::to_string(&SerializedThread {
thread,
version: DbThread::VERSION,
})?;
let connection = connection.lock();
let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
let data_type = DataType::Zstd;
let data = compressed;
let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
"})?;
insert((id.0.clone(), title, updated_at, data_type, data))?;
Ok(())
}
pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
let connection = self.connection.clone();
self.executor.spawn(async move {
let connection = connection.lock();
let mut select =
connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
"})?;
let rows = select(())?;
let mut threads = Vec::new();
for (id, summary, updated_at) in rows {
threads.push(DbThreadMetadata {
id: acp::SessionId(id),
title: summary.into(),
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
});
}
Ok(threads)
})
}
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
let connection = self.connection.clone();
self.executor.spawn(async move {
let connection = connection.lock();
let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
"})?;
let rows = select(id.0)?;
if let Some((data_type, data)) = rows.into_iter().next() {
let json_data = match data_type {
DataType::Zstd => {
let decompressed = zstd::decode_all(&data[..])?;
String::from_utf8(decompressed)?
}
DataType::Json => String::from_utf8(data)?,
};
let thread = DbThread::from_json(json_data.as_bytes())?;
Ok(Some(thread))
} else {
Ok(None)
}
})
}
pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
let connection = self.connection.clone();
self.executor
.spawn(async move { Self::save_thread_sync(&connection, id, thread) })
}
pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
let connection = self.connection.clone();
self.executor.spawn(async move {
let connection = connection.lock();
let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
DELETE FROM threads WHERE id = ?
"})?;
delete(id.0)?;
Ok(())
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use agent::MessageSegment;
use agent::context::LoadedContext;
use client::Client;
use fs::FakeFs;
use gpui::AppContext;
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use language_model::Role;
use project::Project;
use settings::SettingsStore;
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
let http_client = FakeHttpClient::with_404_response();
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
agent::init(cx);
agent_settings::init(cx);
language_model::init(client.clone(), cx);
});
}
#[gpui::test]
async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
// Save a thread using the old agent.
let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
thread.update(cx, |thread, cx| {
thread.insert_message(
Role::User,
vec![MessageSegment::Text("Hey!".into())],
LoadedContext::default(),
vec![],
false,
cx,
);
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text("How're you doing?".into())],
LoadedContext::default(),
vec![],
false,
cx,
)
});
thread_store
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
.await
.unwrap();
// Open that same thread using the new agent.
let db = cx.update(ThreadsDatabase::connect).await.unwrap();
let threads = db.list_threads().await.unwrap();
assert_eq!(threads.len(), 1);
let thread = db
.load_thread(threads[0].id.clone())
.await
.unwrap()
.unwrap();
assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
assert_eq!(
thread.messages[1].to_markdown(),
"## Assistant\n\nHow're you doing?\n"
);
}
}

View file

@ -0,0 +1,345 @@
use crate::{DbThreadMetadata, ThreadsDatabase};
use acp_thread::MentionUri;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use assistant_context::{AssistantContext, SavedContextMetadata};
use chrono::{DateTime, Utc};
use db::kvp::KEY_VALUE_STORE;
use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
use itertools::Itertools;
use paths::contexts_dir;
use serde::{Deserialize, Serialize};
use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
use util::ResultExt as _;
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
const RECENTLY_OPENED_THREADS_KEY: &str = "recent-agent-threads";
const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
#[derive(Clone, Debug)]
pub enum HistoryEntry {
AcpThread(DbThreadMetadata),
TextThread(SavedContextMetadata),
}
impl HistoryEntry {
pub fn updated_at(&self) -> DateTime<Utc> {
match self {
HistoryEntry::AcpThread(thread) => thread.updated_at,
HistoryEntry::TextThread(context) => context.mtime.to_utc(),
}
}
pub fn id(&self) -> HistoryEntryId {
match self {
HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()),
HistoryEntry::TextThread(context) => HistoryEntryId::TextThread(context.path.clone()),
}
}
pub fn mention_uri(&self) -> MentionUri {
match self {
HistoryEntry::AcpThread(thread) => MentionUri::Thread {
id: thread.id.clone(),
name: thread.title.to_string(),
},
HistoryEntry::TextThread(context) => MentionUri::TextThread {
path: context.path.as_ref().to_owned(),
name: context.title.to_string(),
},
}
}
pub fn title(&self) -> &SharedString {
match self {
HistoryEntry::AcpThread(thread) if thread.title.is_empty() => DEFAULT_TITLE,
HistoryEntry::AcpThread(thread) => &thread.title,
HistoryEntry::TextThread(context) => &context.title,
}
}
}
/// Generic identifier for a history entry.
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub enum HistoryEntryId {
AcpThread(acp::SessionId),
TextThread(Arc<Path>),
}
#[derive(Serialize, Deserialize, Debug)]
enum SerializedRecentOpen {
AcpThread(String),
TextThread(String),
}
pub struct HistoryStore {
threads: Vec<DbThreadMetadata>,
context_store: Entity<assistant_context::ContextStore>,
recently_opened_entries: VecDeque<HistoryEntryId>,
_subscriptions: Vec<gpui::Subscription>,
_save_recently_opened_entries_task: Task<()>,
}
impl HistoryStore {
pub fn new(
context_store: Entity<assistant_context::ContextStore>,
cx: &mut Context<Self>,
) -> Self {
let subscriptions = vec![cx.observe(&context_store, |_, _, cx| cx.notify())];
cx.spawn(async move |this, cx| {
let entries = Self::load_recently_opened_entries(cx).await;
this.update(cx, |this, cx| {
if let Some(entries) = entries.log_err() {
this.recently_opened_entries = entries;
}
this.reload(cx);
})
.ok();
})
.detach();
Self {
context_store,
recently_opened_entries: VecDeque::default(),
threads: Vec::default(),
_subscriptions: subscriptions,
_save_recently_opened_entries_task: Task::ready(()),
}
}
pub fn delete_thread(
&mut self,
id: acp::SessionId,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
database.delete_thread(id.clone()).await?;
this.update(cx, |this, cx| this.reload(cx))
})
}
pub fn delete_text_thread(
&mut self,
path: Arc<Path>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.context_store.update(cx, |context_store, cx| {
context_store.delete_local_context(path, cx)
})
}
pub fn load_text_thread(
&self,
path: Arc<Path>,
cx: &mut Context<Self>,
) -> Task<Result<Entity<AssistantContext>>> {
self.context_store.update(cx, |context_store, cx| {
context_store.open_local_context(path, cx)
})
}
pub fn reload(&self, cx: &mut Context<Self>) {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let threads = database_future
.await
.map_err(|err| anyhow!(err))?
.list_threads()
.await?;
this.update(cx, |this, cx| {
if this.recently_opened_entries.len() < MAX_RECENTLY_OPENED_ENTRIES {
for thread in threads
.iter()
.take(MAX_RECENTLY_OPENED_ENTRIES - this.recently_opened_entries.len())
.rev()
{
this.push_recently_opened_entry(
HistoryEntryId::AcpThread(thread.id.clone()),
cx,
)
}
}
this.threads = threads;
cx.notify();
})
})
.detach_and_log_err(cx);
}
pub fn entries(&self, cx: &App) -> Vec<HistoryEntry> {
let mut history_entries = Vec::new();
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
return history_entries;
}
history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread));
history_entries.extend(
self.context_store
.read(cx)
.unordered_contexts()
.cloned()
.map(HistoryEntry::TextThread),
);
history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
history_entries
}
pub fn is_empty(&self, cx: &App) -> bool {
self.threads.is_empty()
&& self
.context_store
.read(cx)
.unordered_contexts()
.next()
.is_none()
}
pub fn recently_opened_entries(&self, cx: &App) -> Vec<HistoryEntry> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
return Vec::new();
}
let thread_entries = self.threads.iter().flat_map(|thread| {
self.recently_opened_entries
.iter()
.enumerate()
.flat_map(|(index, entry)| match entry {
HistoryEntryId::AcpThread(id) if &thread.id == id => {
Some((index, HistoryEntry::AcpThread(thread.clone())))
}
_ => None,
})
});
let context_entries =
self.context_store
.read(cx)
.unordered_contexts()
.flat_map(|context| {
self.recently_opened_entries
.iter()
.enumerate()
.flat_map(|(index, entry)| match entry {
HistoryEntryId::TextThread(path) if &context.path == path => {
Some((index, HistoryEntry::TextThread(context.clone())))
}
_ => None,
})
});
thread_entries
.chain(context_entries)
// optimization to halt iteration early
.take(self.recently_opened_entries.len())
.sorted_unstable_by_key(|(index, _)| *index)
.map(|(_, entry)| entry)
.collect()
}
fn save_recently_opened_entries(&mut self, cx: &mut Context<Self>) {
let serialized_entries = self
.recently_opened_entries
.iter()
.filter_map(|entry| match entry {
HistoryEntryId::TextThread(path) => path.file_name().map(|file| {
SerializedRecentOpen::TextThread(file.to_string_lossy().to_string())
}),
HistoryEntryId::AcpThread(id) => {
Some(SerializedRecentOpen::AcpThread(id.to_string()))
}
})
.collect::<Vec<_>>();
self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
let content = serde_json::to_string(&serialized_entries).unwrap();
cx.background_executor()
.timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
.await;
if cfg!(any(feature = "test-support", test)) {
return;
}
KEY_VALUE_STORE
.write_kvp(RECENTLY_OPENED_THREADS_KEY.to_owned(), content)
.await
.log_err();
});
}
fn load_recently_opened_entries(cx: &AsyncApp) -> Task<Result<VecDeque<HistoryEntryId>>> {
cx.background_spawn(async move {
if cfg!(any(feature = "test-support", test)) {
anyhow::bail!("history store does not persist in tests");
}
let json = KEY_VALUE_STORE
.read_kvp(RECENTLY_OPENED_THREADS_KEY)?
.unwrap_or("[]".to_string());
let entries = serde_json::from_str::<Vec<SerializedRecentOpen>>(&json)
.context("deserializing persisted agent panel navigation history")?
.into_iter()
.take(MAX_RECENTLY_OPENED_ENTRIES)
.flat_map(|entry| match entry {
SerializedRecentOpen::AcpThread(id) => Some(HistoryEntryId::AcpThread(
acp::SessionId(id.as_str().into()),
)),
SerializedRecentOpen::TextThread(file_name) => Some(
HistoryEntryId::TextThread(contexts_dir().join(file_name).into()),
),
})
.collect();
Ok(entries)
})
}
pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context<Self>) {
self.recently_opened_entries
.retain(|old_entry| old_entry != &entry);
self.recently_opened_entries.push_front(entry);
self.recently_opened_entries
.truncate(MAX_RECENTLY_OPENED_ENTRIES);
self.save_recently_opened_entries(cx);
}
pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context<Self>) {
self.recently_opened_entries.retain(|entry| match entry {
HistoryEntryId::AcpThread(thread_id) if thread_id == &id => false,
_ => true,
});
self.save_recently_opened_entries(cx);
}
pub fn replace_recently_opened_text_thread(
&mut self,
old_path: &Path,
new_path: &Arc<Path>,
cx: &mut Context<Self>,
) {
for entry in &mut self.recently_opened_entries {
match entry {
HistoryEntryId::TextThread(path) if path.as_ref() == old_path => {
*entry = HistoryEntryId::TextThread(new_path.clone());
break;
}
_ => {}
}
}
self.save_recently_opened_entries(cx);
}
pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context<Self>) {
self.recently_opened_entries
.retain(|old_entry| old_entry != entry);
self.save_recently_opened_entries(cx);
}
}

View file

@ -1,4 +1,4 @@
use std::{path::Path, rc::Rc, sync::Arc}; use std::{any::Any, path::Path, rc::Rc, sync::Arc};
use agent_servers::AgentServer; use agent_servers::AgentServer;
use anyhow::Result; use anyhow::Result;
@ -7,16 +7,17 @@ use gpui::{App, Entity, Task};
use project::Project; use project::Project;
use prompt_store::PromptStore; use prompt_store::PromptStore;
use crate::{NativeAgent, NativeAgentConnection, templates::Templates}; use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
#[derive(Clone)] #[derive(Clone)]
pub struct NativeAgentServer { pub struct NativeAgentServer {
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
history: Entity<HistoryStore>,
} }
impl NativeAgentServer { impl NativeAgentServer {
pub fn new(fs: Arc<dyn Fs>) -> Self { pub fn new(fs: Arc<dyn Fs>, history: Entity<HistoryStore>) -> Self {
Self { fs } Self { fs, history }
} }
} }
@ -26,16 +27,15 @@ impl AgentServer for NativeAgentServer {
} }
fn empty_state_headline(&self) -> &'static str { fn empty_state_headline(&self) -> &'static str {
"Native Agent" ""
} }
fn empty_state_message(&self) -> &'static str { fn empty_state_message(&self) -> &'static str {
"How can I help you today?" ""
} }
fn logo(&self) -> ui::IconName { fn logo(&self) -> ui::IconName {
// Using the ZedAssistant icon as it's the native built-in agent ui::IconName::ZedAgent
ui::IconName::ZedAssistant
} }
fn connect( fn connect(
@ -50,6 +50,7 @@ impl AgentServer for NativeAgentServer {
); );
let project = project.clone(); let project = project.clone();
let fs = self.fs.clone(); let fs = self.fs.clone();
let history = self.history.clone();
let prompt_store = PromptStore::global(cx); let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent"); log::debug!("Creating templates for native agent");
@ -57,7 +58,8 @@ impl AgentServer for NativeAgentServer {
let prompt_store = prompt_store.await?; let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity"); log::debug!("Creating native agent entity");
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?; let agent =
NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?;
// Create the connection wrapper // Create the connection wrapper
let connection = NativeAgentConnection(agent); let connection = NativeAgentConnection(agent);
@ -66,4 +68,8 @@ impl AgentServer for NativeAgentServer {
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>) Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
}) })
} }
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
} }

View file

@ -62,7 +62,7 @@ fn contains(
handlebars::RenderError::new("contains: missing or invalid query parameter") handlebars::RenderError::new("contains: missing or invalid query parameter")
})?; })?;
if list.contains(&query) { if list.contains(query) {
out.write("true")?; out.write("true")?;
} }

View file

@ -6,15 +6,16 @@ use agent_settings::AgentProfileId;
use anyhow::Result; use anyhow::Result;
use client::{Client, UserStore}; use client::{Client, UserStore};
use fs::{FakeFs, Fs}; use fs::{FakeFs, Fs};
use futures::channel::mpsc::UnboundedReceiver; use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
use gpui::{ use gpui::{
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient, App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
}; };
use indoc::indoc; use indoc::indoc;
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent, LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
Role, StopReason, fake_provider::FakeLanguageModel, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
fake_provider::FakeLanguageModel,
}; };
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use project::Project; use project::Project;
@ -24,8 +25,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use settings::SettingsStore; use settings::SettingsStore;
use smol::stream::StreamExt; use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path; use util::path;
mod test_tools; mod test_tools;
@ -101,7 +101,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
} = setup(cx, TestModel::Fake).await; } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake(); let fake_model = model.as_fake();
project_context.borrow_mut().shell = "test-shell".into(); project_context.update(cx, |project_context, _cx| {
project_context.shell = "test-shell".into()
});
thread.update(cx, |thread, _| thread.add_tool(EchoTool)); thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
@ -343,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let mut saw_partial_tool_use = false; let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message // Look for a tool use in the thread's last message
let message = thread.last_message().unwrap(); let message = thread.last_message().unwrap();
@ -733,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
); );
} }
async fn expect_tool_call( async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
let event = events let event = events
.next() .next()
.await .await
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
match event { match event {
AgentResponseEvent::ToolCall(tool_call) => return tool_call, ThreadEvent::ToolCall(tool_call) => tool_call,
event => { event => {
panic!("Unexpected event {event:?}"); panic!("Unexpected event {event:?}");
} }
@ -750,7 +750,7 @@ async fn expect_tool_call(
} }
async fn expect_tool_call_update_fields( async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>, events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate { ) -> acp::ToolCallUpdate {
let event = events let event = events
.next() .next()
@ -758,9 +758,7 @@ async fn expect_tool_call_update_fields(
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
match event { match event {
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
return update;
}
event => { event => {
panic!("Unexpected event {event:?}"); panic!("Unexpected event {event:?}");
} }
@ -768,7 +766,7 @@ async fn expect_tool_call_update_fields(
} }
async fn next_tool_call_authorization( async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>, events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization { ) -> ToolCallAuthorization {
loop { loop {
let event = events let event = events
@ -776,7 +774,7 @@ async fn next_tool_call_authorization(
.await .await
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization let permission_kinds = tool_call_authorization
.options .options
.iter() .iter()
@ -943,13 +941,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
let mut echo_completed = false; let mut echo_completed = false;
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
match event.unwrap() { match event.unwrap() {
AgentResponseEvent::ToolCall(tool_call) => { ThreadEvent::ToolCall(tool_call) => {
assert_eq!(tool_call.title, expected_tools.remove(0)); assert_eq!(tool_call.title, expected_tools.remove(0));
if tool_call.title == "Echo" { if tool_call.title == "Echo" {
echo_id = Some(tool_call.id); echo_id = Some(tool_call.id);
} }
} }
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
acp::ToolCallUpdate { acp::ToolCallUpdate {
id, id,
fields: fields:
@ -971,13 +969,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even // Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running. // if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel()); thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::<Vec<_>>().await; let events = events.collect::<Vec<_>>().await;
let last_event = events.last(); let last_event = events.last();
assert!( assert!(
matches!( matches!(
last_event, last_event,
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
), ),
"unexpected event {last_event:?}" "unexpected event {last_event:?}"
); );
@ -1119,7 +1117,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
} }
#[gpui::test] #[gpui::test]
async fn test_truncate(cx: &mut TestAppContext) { async fn test_truncate_first_message(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake(); let fake_model = model.as_fake();
@ -1139,9 +1137,18 @@ async fn test_truncate(cx: &mut TestAppContext) {
Hello Hello
"} "}
); );
assert_eq!(thread.latest_token_usage(), None);
}); });
fake_model.send_last_completion_stream_text_chunk("Hey!"); fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 32_000,
output_tokens: 16_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
cx.run_until_parked(); cx.run_until_parked();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
assert_eq!( assert_eq!(
@ -1156,14 +1163,22 @@ async fn test_truncate(cx: &mut TestAppContext) {
Hey! Hey!
"} "}
); );
assert_eq!(
thread.latest_token_usage(),
Some(acp_thread::TokenUsage {
used_tokens: 32_000 + 16_000,
max_tokens: 1_000_000,
})
);
}); });
thread thread
.update(cx, |thread, _cx| thread.truncate(message_id)) .update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
assert_eq!(thread.to_markdown(), ""); assert_eq!(thread.to_markdown(), "");
assert_eq!(thread.latest_token_usage(), None);
}); });
// Ensure we can still send a new message after truncation. // Ensure we can still send a new message after truncation.
@ -1184,6 +1199,14 @@ async fn test_truncate(cx: &mut TestAppContext) {
}); });
cx.run_until_parked(); cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!"); fake_model.send_last_completion_stream_text_chunk("Ahoy!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 40_000,
output_tokens: 20_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
cx.run_until_parked(); cx.run_until_parked();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
assert_eq!( assert_eq!(
@ -1198,9 +1221,171 @@ async fn test_truncate(cx: &mut TestAppContext) {
Ahoy! Ahoy!
"} "}
); );
assert_eq!(
thread.latest_token_usage(),
Some(acp_thread::TokenUsage {
used_tokens: 40_000 + 20_000,
max_tokens: 1_000_000,
})
);
}); });
} }
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Message 1"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 32_000,
output_tokens: 16_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let assert_first_message_state = |cx: &mut TestAppContext| {
thread.clone().read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Message 1
## Assistant
Message 1 response
"}
);
assert_eq!(
thread.latest_token_usage(),
Some(acp_thread::TokenUsage {
used_tokens: 32_000 + 16_000,
max_tokens: 1_000_000,
})
);
});
};
assert_first_message_state(cx);
let second_message_id = UserMessageId::new();
thread
.update(cx, |thread, cx| {
thread.send(second_message_id.clone(), ["Message 2"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 40_000,
output_tokens: 20_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Message 1
## Assistant
Message 1 response
## User
Message 2
## Assistant
Message 2 response
"}
);
assert_eq!(
thread.latest_token_usage(),
Some(acp_thread::TokenUsage {
used_tokens: 40_000 + 20_000,
max_tokens: 1_000_000,
})
);
});
thread
.update(cx, |thread, cx| thread.truncate(second_message_id, cx))
.unwrap();
cx.run_until_parked();
assert_first_message_state(cx);
}
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let summary_model = Arc::new(FakeLanguageModel::default());
thread.update(cx, |thread, cx| {
thread.set_summarization_model(Some(summary_model.clone()), cx)
});
let send = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
// Ensure the summary model has been invoked to generate a title.
summary_model.send_last_completion_stream_text_chunk("Hello ");
summary_model.send_last_completion_stream_text_chunk("world\nG");
summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
summary_model.end_last_completion_stream();
send.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
// Send another message, ensuring no title is generated this time.
let send = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello again"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey again!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
assert_eq!(summary_model.pending_completions(), Vec::new());
send.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
#[gpui::test] #[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) { async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init); cx.update(settings::init);
@ -1228,10 +1413,13 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
fake_fs.insert_tree(path!("/test"), json!({})).await; fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await; let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
let cwd = Path::new("/test"); let cwd = Path::new("/test");
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
// Create agent and connection // Create agent and connection
let agent = NativeAgent::new( let agent = NativeAgent::new(
project.clone(), project.clone(),
history_store,
templates.clone(), templates.clone(),
None, None,
fake_fs.clone(), fake_fs.clone(),
@ -1433,12 +1621,168 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
); );
} }
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
ThreadEvent::Stop(..) => break,
_ => {}
}
}
assert_eq!(retry_events.len(), 0);
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello!
## Assistant
Hey!
"}
)
});
}
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
});
fake_model.end_last_completion_stream();
cx.executor().advance_clock(Duration::from_secs(3));
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
ThreadEvent::Stop(..) => break,
_ => {}
}
}
assert_eq!(retry_events.len(), 1);
assert!(matches!(
retry_events[0],
acp_thread::RetryStatus { attempt: 1, .. }
));
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello!
## Assistant
Hey!
"}
)
});
}
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
cx.run_until_parked();
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
fake_model.send_last_completion_stream_error(
LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
},
);
fake_model.end_last_completion_stream();
cx.executor().advance_clock(Duration::from_secs(3));
cx.run_until_parked();
}
let mut errors = Vec::new();
let mut retry_events = Vec::new();
while let Some(event) = events.next().await {
match event {
Ok(ThreadEvent::Retry(retry_status)) => {
retry_events.push(retry_status);
}
Ok(ThreadEvent::Stop(..)) => break,
Err(error) => errors.push(error),
_ => {}
}
}
assert_eq!(
retry_events.len(),
crate::thread::MAX_RETRY_ATTEMPTS as usize
);
for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
assert_eq!(retry_events[i].attempt, i + 1);
}
assert_eq!(errors.len(), 1);
let error = errors[0]
.downcast_ref::<LanguageModelCompletionError>()
.unwrap();
assert!(matches!(
error,
LanguageModelCompletionError::ServerOverloaded { .. }
));
}
/// Filters out the stop events for asserting against in tests /// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> { fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events result_events
.into_iter() .into_iter()
.filter_map(|event| match event.unwrap() { .filter_map(|event| match event.unwrap() {
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), ThreadEvent::Stop(stop_reason) => Some(stop_reason),
_ => None, _ => None,
}) })
.collect() .collect()
@ -1447,7 +1791,7 @@ fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopR
struct ThreadTest { struct ThreadTest {
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
thread: Entity<Thread>, thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>, project_context: Entity<ProjectContext>,
fs: Arc<FakeFs>, fs: Arc<FakeFs>,
} }
@ -1543,7 +1887,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
}) })
.await; .await;
let project_context = Rc::new(RefCell::new(ProjectContext::default())); let project_context = cx.new(|_cx| ProjectContext::default());
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,43 @@
use language_model::LanguageModelToolSchemaFormat;
use schemars::{
JsonSchema, Schema,
generate::SchemaSettings,
transform::{Transform, transform_subschemas},
};
pub(crate) fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
let mut generator = match format {
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
.with(|settings| {
settings.meta_schema = None;
settings.inline_subschemas = true;
})
.with_transform(ToJsonSchemaSubsetTransform)
.into_generator(),
};
generator.root_schema_for::<T>()
}
#[derive(Debug, Clone)]
struct ToJsonSchemaSubsetTransform;
impl Transform for ToJsonSchemaSubsetTransform {
fn transform(&mut self, schema: &mut Schema) {
// Ensure that the type field is not an array, this happens when we use
// Option<T>, the type will be [T, "null"].
if let Some(type_field) = schema.get_mut("type")
&& let Some(types) = type_field.as_array()
&& let Some(first_type) = types.first()
{
*type_field = first_type.clone();
}
// oneOf is not supported, use anyOf instead
if let Some(one_of) = schema.remove("oneOf") {
schema.insert("anyOf".to_string(), one_of);
}
transform_subschemas(self, schema);
}
}

View file

@ -103,7 +103,7 @@ impl ContextServerRegistry {
self.reload_tools_for_server(server_id.clone(), cx); self.reload_tools_for_server(server_id.clone(), cx);
} }
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
self.registered_servers.remove(&server_id); self.registered_servers.remove(server_id);
cx.notify(); cx.notify();
} }
} }
@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
}) })
}) })
} }
fn replay(
&self,
_input: serde_json::Value,
_output: serde_json::Value,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
Ok(())
}
} }

View file

@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent; use cloud_llm_client::CompletionIntent;
use collections::HashSet; use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task}; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc; use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave}; use language::language_settings::{self, FormatOnSave};
use language::{LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent; use language_model::LanguageModelToolResultContent;
use paths; use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget}; use project::lsp_store::{FormatTrigger, LspFormatTarget};
@ -98,11 +98,13 @@ pub enum EditFileMode {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput { pub struct EditFileToolOutput {
#[serde(alias = "original_path")]
input_path: PathBuf, input_path: PathBuf,
project_path: PathBuf,
new_text: String, new_text: String,
old_text: Arc<String>, old_text: Arc<String>,
#[serde(default)]
diff: String, diff: String,
#[serde(alias = "raw_output")]
edit_agent_output: EditAgentOutput, edit_agent_output: EditAgentOutput,
} }
@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
} }
pub struct EditFileTool { pub struct EditFileTool {
thread: Entity<Thread>, thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
} }
impl EditFileTool { impl EditFileTool {
pub fn new(thread: Entity<Thread>) -> Self { pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
Self { thread } Self {
thread,
language_registry,
}
} }
fn authorize( fn authorize(
@ -156,19 +162,22 @@ impl EditFileTool {
// It's also possible that the global config dir is configured to be inside the project, // It's also possible that the global config dir is configured to be inside the project,
// so check for that edge case too. // so check for that edge case too.
if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
if canonical_path.starts_with(paths::config_dir()) { && canonical_path.starts_with(paths::config_dir())
return event_stream.authorize( {
format!("{} (global settings)", input.display_description), return event_stream.authorize(
cx, format!("{} (global settings)", input.display_description),
); cx,
} );
} }
// Check if path is inside the global config directory // Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize // First check if it's already inside project - if not, try to canonicalize
let thread = self.thread.read(cx); let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
let project_path = thread.project().read(cx).find_project_path(&input.path, cx); thread.project().read(cx).find_project_path(&input.path, cx)
}) else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
// If the path is inside the project, and it's not one of the above edge cases, // If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary. // then no confirmation is necessary. Otherwise, confirmation is necessary.
@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<Self::Output>> { ) -> Task<Result<Self::Output>> {
let project = self.thread.read(cx).project().clone(); let Ok(project) = self
.thread
.read_with(cx, |thread, _cx| thread.project().clone())
else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
let project_path = match resolve_path(&input, project.clone(), cx) { let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path, Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))), Err(err) => return Task::ready(Err(anyhow!(err))),
@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
}); });
} }
let Some(request) = self.thread.update(cx, |thread, cx| {
thread
.build_completion_request(CompletionIntent::ToolResults, cx)
.ok()
}) else {
return Task::ready(Err(anyhow!("Failed to build completion request")));
};
let thread = self.thread.read(cx);
let Some(model) = thread.model().cloned() else {
return Task::ready(Err(anyhow!("No language model configured")));
};
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx); let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| { cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?; authorize.await?;
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
(request, thread.model().cloned(), thread.action_log().clone())
})?;
let request = request?;
let model = model.context("No language model configured")?;
let edit_format = EditFormat::from_model(model.clone())?; let edit_format = EditFormat::from_model(model.clone())?;
let edit_agent = EditAgent::new( let edit_agent = EditAgent::new(
model, model,
@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
Ok(EditFileToolOutput { Ok(EditFileToolOutput {
input_path: input.path, input_path: input.path,
project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(), new_text: new_text.clone(),
old_text, old_text,
diff: unified_diff, diff: unified_diff,
@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
}) })
}) })
} }
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
event_stream.update_diff(cx.new(|cx| {
Diff::finalized(
output.input_path,
Some(output.old_text.to_string()),
output.new_text,
self.language_registry.clone(),
cx,
)
}));
Ok(())
}
} }
/// Validate that the file path is valid, meaning: /// Validate that the file path is valid, meaning:
@ -471,7 +497,7 @@ fn resolve_path(
let parent_entry = parent_project_path let parent_entry = parent_project_path
.as_ref() .as_ref()
.and_then(|path| project.entry_for_path(&path, cx)) .and_then(|path| project.entry_for_path(path, cx))
.context("Can't create file: parent directory doesn't exist")?; .context("Can't create file: parent directory doesn't exist")?;
anyhow::ensure!( anyhow::ensure!(
@ -503,9 +529,9 @@ mod tests {
use fs::Fs; use fs::Fs;
use gpui::{TestAppContext, UpdateGlobal}; use gpui::{TestAppContext, UpdateGlobal};
use language_model::fake_provider::FakeLanguageModel; use language_model::fake_provider::FakeLanguageModel;
use prompt_store::ProjectContext;
use serde_json::json; use serde_json::json;
use settings::SettingsStore; use settings::SettingsStore;
use std::rc::Rc;
use util::path; use util::path;
#[gpui::test] #[gpui::test]
@ -515,6 +541,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor()); let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await; fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -522,7 +549,7 @@ mod tests {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project, project,
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry, context_server_registry,
action_log, action_log,
Templates::new(), Templates::new(),
@ -537,7 +564,11 @@ mod tests {
path: "root/nonexistent_file.txt".into(), path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit, mode: EditFileMode::Edit,
}; };
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
}) })
.await; .await;
assert_eq!( assert_eq!(
@ -624,8 +655,7 @@ mod tests {
mode: mode.clone(), mode: mode.clone(),
}; };
let result = cx.update(|cx| resolve_path(&input, project, cx)); cx.update(|cx| resolve_path(&input, project, cx))
result
} }
fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) { fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
@ -719,7 +749,7 @@ mod tests {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project, project,
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry, context_server_registry,
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
@ -750,9 +780,10 @@ mod tests {
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}; };
Arc::new(EditFileTool { Arc::new(EditFileTool::new(
thread: thread.clone(), thread.downgrade(),
}) language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx) .run(input, ToolCallEventStream::test().0, cx)
}); });
@ -806,7 +837,11 @@ mod tests {
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}; };
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
}); });
// Stream the unformatted content // Stream the unformatted content
@ -850,12 +885,13 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project, project,
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry, context_server_registry,
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
@ -887,9 +923,10 @@ mod tests {
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}; };
Arc::new(EditFileTool { Arc::new(EditFileTool::new(
thread: thread.clone(), thread.downgrade(),
}) language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx) .run(input, ToolCallEventStream::test().0, cx)
}); });
@ -938,10 +975,11 @@ mod tests {
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}; };
Arc::new(EditFileTool { Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
thread: thread.clone(), input,
}) ToolCallEventStream::test().0,
.run(input, ToolCallEventStream::test().0, cx) cx,
)
}); });
// Stream the content with trailing whitespace // Stream the content with trailing whitespace
@ -976,12 +1014,13 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project, project,
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry, context_server_registry,
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
@ -989,7 +1028,7 @@ mod tests {
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await; fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation // Test 1: Path with .zed component should require confirmation
@ -1111,6 +1150,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor()); let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await; fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
@ -1118,7 +1158,7 @@ mod tests {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project, project,
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry, context_server_registry,
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
@ -1126,7 +1166,7 @@ mod tests {
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project // Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![ let test_cases = vec![
@ -1220,7 +1260,7 @@ mod tests {
cx, cx,
) )
.await; .await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1228,7 +1268,7 @@ mod tests {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project.clone(), project.clone(),
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(), context_server_registry.clone(),
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
@ -1236,7 +1276,7 @@ mod tests {
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees // Test files in different worktrees
let test_cases = vec![ let test_cases = vec![
@ -1302,6 +1342,7 @@ mod tests {
) )
.await; .await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1309,7 +1350,7 @@ mod tests {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project.clone(), project.clone(),
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(), context_server_registry.clone(),
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
@ -1317,7 +1358,7 @@ mod tests {
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases // Test edge cases
let test_cases = vec![ let test_cases = vec![
@ -1386,6 +1427,7 @@ mod tests {
) )
.await; .await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1393,7 +1435,7 @@ mod tests {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project.clone(), project.clone(),
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(), context_server_registry.clone(),
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
@ -1401,7 +1443,7 @@ mod tests {
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values // Test different EditFileMode values
let modes = vec![ let modes = vec![
@ -1467,6 +1509,7 @@ mod tests {
init_test(cx); init_test(cx);
let fs = project::FakeFs::new(cx.executor()); let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1474,7 +1517,7 @@ mod tests {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
Thread::new( Thread::new(
project.clone(), project.clone(),
Rc::default(), cx.new(|_cx| ProjectContext::default()),
context_server_registry, context_server_registry,
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
@ -1482,7 +1525,7 @@ mod tests {
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!( assert_eq!(
tool.initial_title(Err(json!({ tool.initial_title(Err(json!({

View file

@ -179,15 +179,14 @@ impl AgentTool for GrepTool {
// Check if this file should be excluded based on its worktree settings // Check if this file should be excluded based on its worktree settings
if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| { if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| {
project.find_project_path(&path, cx) project.find_project_path(&path, cx)
}) { })
if cx.update(|cx| { && cx.update(|cx| {
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
worktree_settings.is_path_excluded(&project_path.path) worktree_settings.is_path_excluded(&project_path.path)
|| worktree_settings.is_path_private(&project_path.path) || worktree_settings.is_path_private(&project_path.path)
}).unwrap_or(false) { }).unwrap_or(false) {
continue; continue;
} }
}
while *parse_status.borrow() != ParseStatus::Idle { while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?; parse_status.changed().await?;
@ -275,12 +274,11 @@ impl AgentTool for GrepTool {
output.extend(snapshot.text_for_range(range)); output.extend(snapshot.text_for_range(range));
output.push_str("\n```\n"); output.push_str("\n```\n");
if let Some(ancestor_range) = ancestor_range { if let Some(ancestor_range) = ancestor_range
if end_row < ancestor_range.end.row { && end_row < ancestor_range.end.row {
let remaining_lines = ancestor_range.end.row - end_row; let remaining_lines = ancestor_range.end.row - end_row;
writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?; writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
} }
}
matches_found += 1; matches_found += 1;
} }

View file

@ -175,7 +175,7 @@ impl AgentTool for ReadFileTool {
buffer buffer
.file() .file()
.as_ref() .as_ref()
.map_or(true, |file| !file.disk_state().exists()) .is_none_or(|file| !file.disk_state().exists())
})? { })? {
anyhow::bail!("{file_path} not found"); anyhow::bail!("{file_path} not found");
} }

View file

@ -47,12 +47,9 @@ impl TerminalTool {
} }
if which::which("bash").is_ok() { if which::which("bash").is_ok() {
log::info!("agent selected bash for terminal tool");
"bash".into() "bash".into()
} else { } else {
let shell = get_system_shell(); get_system_shell()
log::info!("agent selected {shell} for terminal tool");
shell
} }
}); });
Self { Self {
@ -80,7 +77,7 @@ impl AgentTool for TerminalTool {
let first_line = lines.next().unwrap_or_default(); let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count(); let remaining_line_count = lines.count();
match remaining_line_count { match remaining_line_count {
0 => MarkdownInlineCode(&first_line).to_string().into(), 0 => MarkdownInlineCode(first_line).to_string().into(),
1 => MarkdownInlineCode(&format!( 1 => MarkdownInlineCode(&format!(
"{} - {} more line", "{} - {} more line",
first_line, remaining_line_count first_line, remaining_line_count
@ -271,7 +268,7 @@ fn working_dir(
let project = project.read(cx); let project = project.read(cx);
let cd = &input.cd; let cd = &input.cd;
if cd == "." || cd == "" { if cd == "." || cd.is_empty() {
// Accept "." or "" as meaning "the one worktree" if we only have one worktree. // Accept "." or "" as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx); let mut worktrees = project.worktrees(cx);
@ -296,10 +293,8 @@ fn working_dir(
{ {
return Ok(Some(input_path.into())); return Ok(Some(input_path.into()));
} }
} else { } else if let Some(worktree) = project.worktree_for_root_name(cd, cx) {
if let Some(worktree) = project.worktree_for_root_name(cd, cx) { return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
}
} }
anyhow::bail!("`cd` directory {cd:?} was not in any of the project's worktrees."); anyhow::bail!("`cd` directory {cd:?} was not in any of the project's worktrees.");
@ -319,7 +314,7 @@ mod tests {
use theme::ThemeSettings; use theme::ThemeSettings;
use util::test::TempTree; use util::test::TempTree;
use crate::AgentResponseEvent; use crate::ThreadEvent;
use super::*; use super::*;
@ -396,7 +391,7 @@ mod tests {
}); });
cx.run_until_parked(); cx.run_until_parked();
let event = stream_rx.try_next(); let event = stream_rx.try_next();
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
auth.response.send(auth.options[0].id.clone()).unwrap(); auth.response.send(auth.options[0].id.clone()).unwrap();
} }

View file

@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
} }
}; };
let result_text = if response.results.len() == 1 { emit_update(&response, &event_stream);
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
Ok(WebSearchToolOutput(response)) Ok(WebSearchToolOutput(response))
}) })
} }
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
emit_update(&output.0, &event_stream);
Ok(())
}
}
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
let result_text = if response.results.len() == 1 {
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
} }

View file

@ -18,7 +18,9 @@ doctest = false
[dependencies] [dependencies]
acp_thread.workspace = true acp_thread.workspace = true
action_log.workspace = true
agent-client-protocol.workspace = true agent-client-protocol.workspace = true
agent_settings.workspace = true
agentic-coding-protocol.workspace = true agentic-coding-protocol.workspace = true
anyhow.workspace = true anyhow.workspace = true
collections.workspace = true collections.workspace = true
@ -27,11 +29,15 @@ futures.workspace = true
gpui.workspace = true gpui.workspace = true
indoc.workspace = true indoc.workspace = true
itertools.workspace = true itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_models.workspace = true
log.workspace = true log.workspace = true
paths.workspace = true paths.workspace = true
project.workspace = true project.workspace = true
rand.workspace = true rand.workspace = true
schemars.workspace = true schemars.workspace = true
semver.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
settings.workspace = true settings.workspace = true

View file

@ -19,14 +19,14 @@ pub async fn connect(
root_dir: &Path, root_dir: &Path,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Result<Rc<dyn AgentConnection>> { ) -> Result<Rc<dyn AgentConnection>> {
let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await; let conn = v1::AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await;
match conn { match conn {
Ok(conn) => Ok(Rc::new(conn) as _), Ok(conn) => Ok(Rc::new(conn) as _),
Err(err) if err.is::<UnsupportedVersion>() => { Err(err) if err.is::<UnsupportedVersion>() => {
// Consider re-using initialize response and subprocess when adding another version here // Consider re-using initialize response and subprocess when adding another version here
let conn: Rc<dyn AgentConnection> = let conn: Rc<dyn AgentConnection> =
Rc::new(v0::AcpConnection::stdio(server_name, command, &root_dir, cx).await?); Rc::new(v0::AcpConnection::stdio(server_name, command, root_dir, cx).await?);
Ok(conn) Ok(conn)
} }
Err(err) => Err(err), Err(err) => Err(err),

View file

@ -1,4 +1,5 @@
// Translates old acp agents into the new schema // Translates old acp agents into the new schema
use action_log::ActionLog;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
@ -148,7 +149,7 @@ impl acp_old::Client for OldAcpClientDelegate {
Ok(acp_old::RequestToolCallConfirmationResponse { Ok(acp_old::RequestToolCallConfirmationResponse {
id: acp_old::ToolCallId(old_acp_id), id: acp_old::ToolCallId(old_acp_id),
outcome: outcome, outcome,
}) })
} }
@ -265,7 +266,7 @@ impl acp_old::Client for OldAcpClientDelegate {
fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall { fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall {
acp::ToolCall { acp::ToolCall {
id: id, id,
title: request.label, title: request.label,
kind: acp_kind_from_old_icon(request.icon), kind: acp_kind_from_old_icon(request.icon),
status: acp::ToolCallStatus::InProgress, status: acp::ToolCallStatus::InProgress,
@ -437,13 +438,14 @@ impl AgentConnection for AcpConnection {
let result = acp_old::InitializeParams::response_from_any(result)?; let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated { if !result.is_authenticated {
anyhow::bail!(AuthRequired) anyhow::bail!(AuthRequired::new())
} }
cx.update(|cx| { cx.update(|cx| {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into()); let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new(self.name, self.clone(), project, session_id, cx) let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(self.name, self.clone(), project, action_log, session_id)
}); });
current_thread.replace(thread.downgrade()); current_thread.replace(thread.downgrade());
thread thread

View file

@ -1,3 +1,4 @@
use action_log::ActionLog;
use agent_client_protocol::{self as acp, Agent as _}; use agent_client_protocol::{self as acp, Agent as _};
use anyhow::anyhow; use anyhow::anyhow;
use collections::HashMap; use collections::HashMap;
@ -13,7 +14,7 @@ use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::{AgentServerCommand, acp::UnsupportedVersion}; use crate::{AgentServerCommand, acp::UnsupportedVersion};
use acp_thread::{AcpThread, AgentConnection, AuthRequired}; use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
pub struct AcpConnection { pub struct AcpConnection {
server_name: &'static str, server_name: &'static str,
@ -86,7 +87,9 @@ impl AcpConnection {
for session in sessions.borrow().values() { for session in sessions.borrow().values() {
session session
.thread .thread
.update(cx, |thread, cx| thread.emit_server_exited(status, cx)) .update(cx, |thread, cx| {
thread.emit_load_error(LoadError::Exited { status }, cx)
})
.ok(); .ok();
} }
@ -140,21 +143,27 @@ impl AgentConnection for AcpConnection {
.await .await
.map_err(|err| { .map_err(|err| {
if err.code == acp::ErrorCode::AUTH_REQUIRED.code { if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
anyhow!(AuthRequired) let mut error = AuthRequired::new();
if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
error = error.with_description(err.message);
}
anyhow!(error)
} else { } else {
anyhow!(err) anyhow!(err)
} }
})?; })?;
let session_id = response.session_id; let session_id = response.session_id;
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|cx| { let thread = cx.new(|_cx| {
AcpThread::new( AcpThread::new(
self.server_name, self.server_name,
self.clone(), self.clone(),
project, project,
action_log,
session_id.clone(), session_id.clone(),
cx,
) )
})?; })?;

View file

@ -18,6 +18,7 @@ use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
any::Any,
path::{Path, PathBuf}, path::{Path, PathBuf},
rc::Rc, rc::Rc,
sync::Arc, sync::Arc,
@ -40,6 +41,14 @@ pub trait AgentServer: Send {
project: &Entity<Project>, project: &Entity<Project>,
cx: &mut App, cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>>; ) -> Task<Result<Rc<dyn AgentConnection>>>;
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
}
impl dyn AgentServer {
pub fn downcast<T: 'static + AgentServer + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
self.into_any().downcast().ok()
}
} }
impl std::fmt::Debug for AgentServerCommand { impl std::fmt::Debug for AgentServerCommand {
@ -95,7 +104,7 @@ impl AgentServerCommand {
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Option<Self> { ) -> Option<Self> {
if let Some(agent_settings) = settings { if let Some(agent_settings) = settings {
return Some(Self { Some(Self {
path: agent_settings.command.path, path: agent_settings.command.path,
args: agent_settings args: agent_settings
.command .command
@ -104,7 +113,7 @@ impl AgentServerCommand {
.chain(extra_args.iter().map(|arg| arg.to_string())) .chain(extra_args.iter().map(|arg| arg.to_string()))
.collect(), .collect(),
env: agent_settings.command.env, env: agent_settings.command.env,
}); })
} else { } else {
match find_bin_in_path(path_bin_name, project, cx).await { match find_bin_in_path(path_bin_name, project, cx).await {
Some(path) => Some(Self { Some(path) => Some(Self {

View file

@ -1,16 +1,23 @@
mod edit_tool;
mod mcp_server; mod mcp_server;
mod permission_tool;
mod read_tool;
pub mod tools; pub mod tools;
mod write_tool;
use action_log::ActionLog;
use collections::HashMap; use collections::HashMap;
use context_server::listener::McpServerTool; use context_server::listener::McpServerTool;
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use smol::process::Child; use smol::process::Child;
use std::any::Any; use std::any::Any;
use std::cell::RefCell; use std::cell::RefCell;
use std::fmt::Display; use std::fmt::Display;
use std::path::Path; use std::path::{Path, PathBuf};
use std::rc::Rc; use std::rc::Rc;
use util::command::new_smol_command;
use uuid::Uuid; use uuid::Uuid;
use agent_client_protocol as acp; use agent_client_protocol as acp;
@ -30,7 +37,7 @@ use util::{ResultExt, debug_panic};
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool; use crate::claude::tools::ClaudeTool;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpThread, AgentConnection}; use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError, MentionUri};
#[derive(Clone)] #[derive(Clone)]
pub struct ClaudeCode; pub struct ClaudeCode;
@ -64,6 +71,10 @@ impl AgentServer for ClaudeCode {
Task::ready(Ok(Rc::new(connection) as _)) Task::ready(Ok(Rc::new(connection) as _))
} }
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
} }
struct ClaudeAgentConnection { struct ClaudeAgentConnection {
@ -79,8 +90,43 @@ impl AgentConnection for ClaudeAgentConnection {
) -> Task<Result<Entity<AcpThread>>> { ) -> Task<Result<Entity<AcpThread>>> {
let cwd = cwd.to_owned(); let cwd = cwd.to_owned();
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).claude.clone()
})?;
let Some(command) = AgentServerCommand::resolve(
"claude",
&[],
Some(&util::paths::home_dir().join(".claude/local/claude")),
settings,
&project,
cx,
)
.await
else {
return Err(LoadError::NotInstalled {
error_message: "Failed to find Claude Code binary".into(),
install_message: "Install Claude Code".into(),
install_command: "npm install -g @anthropic-ai/claude-code@latest".into(),
}.into());
};
let api_key =
cx.update(AnthropicLanguageModelProvider::api_key)?
.await
.map_err(|err| {
if err.is::<language_model::AuthenticateError>() {
anyhow!(AuthRequired::new().with_language_model_provider(
language_model::ANTHROPIC_PROVIDER_ID
))
} else {
anyhow!(err)
}
})?;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
let mut mcp_servers = HashMap::default(); let mut mcp_servers = HashMap::default();
mcp_servers.insert( mcp_servers.insert(
@ -98,23 +144,6 @@ impl AgentConnection for ClaudeAgentConnection {
.await?; .await?;
mcp_config_file.flush().await?; mcp_config_file.flush().await?;
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).claude.clone()
})?;
let Some(command) = AgentServerCommand::resolve(
"claude",
&[],
Some(&util::paths::home_dir().join(".claude/local/claude")),
settings,
&project,
cx,
)
.await
else {
anyhow::bail!("Failed to find claude binary");
};
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
@ -126,6 +155,7 @@ impl AgentConnection for ClaudeAgentConnection {
&command, &command,
ClaudeSessionMode::Start, ClaudeSessionMode::Start,
session_id.clone(), session_id.clone(),
api_key,
&mcp_config_path, &mcp_config_path,
&cwd, &cwd,
)?; )?;
@ -183,20 +213,50 @@ impl AgentConnection for ClaudeAgentConnection {
.await .await
} }
if let Some(status) = child.status().await.log_err() { if let Some(status) = child.status().await.log_err()
if let Some(thread) = thread_rx.recv().await.ok() { && let Some(thread) = thread_rx.recv().await.ok()
thread {
.update(cx, |thread, cx| { let version = claude_version(command.path.clone(), cx).await.log_err();
thread.emit_server_exited(status, cx); let help = claude_help(command.path.clone(), cx).await.log_err();
}) thread
.ok(); .update(cx, |thread, cx| {
} let error = if let Some(version) = version
&& let Some(help) = help
&& (!help.contains("--input-format")
|| !help.contains("--session-id"))
{
LoadError::Unsupported {
error_message: format!(
"Your installed version of Claude Code ({}, version {}) does not have required features for use with Zed.",
command.path.to_string_lossy(),
version,
)
.into(),
upgrade_message: "Upgrade Claude Code to latest".into(),
upgrade_command: format!(
"{} update",
command.path.to_string_lossy()
),
}
} else {
LoadError::Exited { status }
};
thread.emit_load_error(error, cx);
})
.ok();
} }
} }
}); });
let thread = cx.new(|cx| { let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) let thread = cx.new(|_cx| {
AcpThread::new(
"Claude Code",
self.clone(),
project,
action_log,
session_id.clone(),
)
})?; })?;
thread_tx.send(thread.downgrade())?; thread_tx.send(thread.downgrade())?;
@ -239,27 +299,12 @@ impl AgentConnection for ClaudeAgentConnection {
let (end_tx, end_rx) = oneshot::channel(); let (end_tx, end_rx) = oneshot::channel();
session.turn_state.replace(TurnState::InProgress { end_tx }); session.turn_state.replace(TurnState::InProgress { end_tx });
let mut content = String::new(); let content = acp_content_to_claude(params.prompt);
for chunk in params.prompt {
match chunk {
acp::ContentBlock::Text(text_content) => {
content.push_str(&text_content.text);
}
acp::ContentBlock::ResourceLink(resource_link) => {
content.push_str(&format!("@{}", resource_link.uri));
}
acp::ContentBlock::Audio(_)
| acp::ContentBlock::Image(_)
| acp::ContentBlock::Resource(_) => {
// TODO
}
}
}
if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User { if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
message: Message { message: Message {
role: Role::User, role: Role::User,
content: Content::UntaggedText(content), content: Content::Chunks(content),
id: None, id: None,
model: None, model: None,
stop_reason: None, stop_reason: None,
@ -276,7 +321,7 @@ impl AgentConnection for ClaudeAgentConnection {
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
let sessions = self.sessions.borrow(); let sessions = self.sessions.borrow();
let Some(session) = sessions.get(&session_id) else { let Some(session) = sessions.get(session_id) else {
log::warn!("Attempted to cancel nonexistent session {}", session_id); log::warn!("Attempted to cancel nonexistent session {}", session_id);
return; return;
}; };
@ -320,6 +365,7 @@ fn spawn_claude(
command: &AgentServerCommand, command: &AgentServerCommand,
mode: ClaudeSessionMode, mode: ClaudeSessionMode,
session_id: acp::SessionId, session_id: acp::SessionId,
api_key: language_models::provider::anthropic::ApiKey,
mcp_config_path: &Path, mcp_config_path: &Path,
root_dir: &Path, root_dir: &Path,
) -> Result<Child> { ) -> Result<Child> {
@ -337,24 +383,24 @@ fn spawn_claude(
&format!( &format!(
"mcp__{}__{}", "mcp__{}__{}",
mcp_server::SERVER_NAME, mcp_server::SERVER_NAME,
mcp_server::PermissionTool::NAME, permission_tool::PermissionTool::NAME,
), ),
"--allowedTools", "--allowedTools",
&format!( &format!(
"mcp__{}__{},mcp__{}__{}", "mcp__{}__{}",
mcp_server::SERVER_NAME, mcp_server::SERVER_NAME,
mcp_server::EditTool::NAME, read_tool::ReadTool::NAME
mcp_server::SERVER_NAME,
mcp_server::ReadTool::NAME
), ),
"--disallowedTools", "--disallowedTools",
"Read,Edit", "Read,Write,Edit,MultiEdit",
]) ])
.args(match mode { .args(match mode {
ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()], ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()], ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
}) })
.args(command.args.iter().map(|arg| arg.as_str())) .args(command.args.iter().map(|arg| arg.as_str()))
.envs(command.env.iter().flatten())
.env("ANTHROPIC_API_KEY", api_key.key)
.current_dir(root_dir) .current_dir(root_dir)
.stdin(std::process::Stdio::piped()) .stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped())
@ -365,6 +411,27 @@ fn spawn_claude(
Ok(child) Ok(child)
} }
fn claude_version(path: PathBuf, cx: &mut AsyncApp) -> Task<Result<semver::Version>> {
cx.background_spawn(async move {
let output = new_smol_command(path).arg("--version").output().await?;
let output = String::from_utf8(output.stdout)?;
let version = output
.trim()
.strip_suffix(" (Claude Code)")
.context("parsing Claude version")?;
let version = semver::Version::parse(version)?;
anyhow::Ok(version)
})
}
fn claude_help(path: PathBuf, cx: &mut AsyncApp) -> Task<Result<String>> {
cx.background_spawn(async move {
let output = new_smol_command(path).arg("--help").output().await?;
let output = String::from_utf8(output.stdout)?;
anyhow::Ok(output)
})
}
struct ClaudeAgentSession { struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>, outgoing_tx: UnboundedSender<SdkMessage>,
turn_state: Rc<RefCell<TurnState>>, turn_state: Rc<RefCell<TurnState>>,
@ -454,9 +521,16 @@ impl ClaudeAgentSession {
let content = content.to_string(); let content = content.to_string();
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
let id = acp::ToolCallId(tool_use_id.into());
let set_new_content = !content.is_empty()
&& thread.tool_call(&id).is_none_or(|(_, tool_call)| {
// preserve rich diff if we have one
tool_call.diffs().next().is_none()
});
thread.update_tool_call( thread.update_tool_call(
acp::ToolCallUpdate { acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()), id,
fields: acp::ToolCallUpdateFields { fields: acp::ToolCallUpdateFields {
status: if turn_state.borrow().is_canceled() { status: if turn_state.borrow().is_canceled() {
// Do not set to completed if turn was canceled // Do not set to completed if turn was canceled
@ -464,7 +538,7 @@ impl ClaudeAgentSession {
} else { } else {
Some(acp::ToolCallStatus::Completed) Some(acp::ToolCallStatus::Completed)
}, },
content: (!content.is_empty()) content: set_new_content
.then(|| vec![content.into()]), .then(|| vec![content.into()]),
..Default::default() ..Default::default()
}, },
@ -482,10 +556,17 @@ impl ClaudeAgentSession {
chunk chunk
); );
} }
ContentChunk::Image { source } => {
if !turn_state.borrow().is_canceled() {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(None, source.into(), cx)
})
.log_err();
}
}
ContentChunk::Image ContentChunk::Document | ContentChunk::WebSearchToolResult => {
| ContentChunk::Document
| ContentChunk::WebSearchToolResult => {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.push_assistant_content_block( thread.push_assistant_content_block(
@ -571,7 +652,14 @@ impl ClaudeAgentSession {
"Should not get tool results with role: assistant. should we handle this?" "Should not get tool results with role: assistant. should we handle this?"
); );
} }
ContentChunk::Image | ContentChunk::Document => { ContentChunk::Image { source } => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(source.into(), false, cx)
})
.log_err();
}
ContentChunk::Document => {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.push_assistant_content_block( thread.push_assistant_content_block(
@ -737,14 +825,44 @@ enum ContentChunk {
thinking: String, thinking: String,
}, },
RedactedThinking, RedactedThinking,
Image {
source: ImageSource,
},
// TODO // TODO
Image,
Document, Document,
WebSearchToolResult, WebSearchToolResult,
#[serde(untagged)] #[serde(untagged)]
UntaggedText(String), UntaggedText(String),
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ImageSource {
Base64 { data: String, media_type: String },
Url { url: String },
}
impl Into<acp::ContentBlock> for ImageSource {
fn into(self) -> acp::ContentBlock {
match self {
ImageSource::Base64 { data, media_type } => {
acp::ContentBlock::Image(acp::ImageContent {
annotations: None,
data,
mime_type: media_type,
uri: None,
})
}
ImageSource::Url { url } => acp::ContentBlock::Image(acp::ImageContent {
annotations: None,
data: "".to_string(),
mime_type: "".to_string(),
uri: Some(url),
}),
}
}
}
impl Display for ContentChunk { impl Display for ContentChunk {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
@ -753,7 +871,7 @@ impl Display for ContentChunk {
ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"), ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
ContentChunk::UntaggedText(text) => write!(f, "{}", text), ContentChunk::UntaggedText(text) => write!(f, "{}", text),
ContentChunk::ToolResult { content, .. } => write!(f, "{}", content), ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
ContentChunk::Image ContentChunk::Image { .. }
| ContentChunk::Document | ContentChunk::Document
| ContentChunk::ToolUse { .. } | ContentChunk::ToolUse { .. }
| ContentChunk::WebSearchToolResult => { | ContentChunk::WebSearchToolResult => {
@ -865,6 +983,75 @@ impl Display for ResultErrorType {
} }
} }
fn acp_content_to_claude(prompt: Vec<acp::ContentBlock>) -> Vec<ContentChunk> {
let mut content = Vec::with_capacity(prompt.len());
let mut context = Vec::with_capacity(prompt.len());
for chunk in prompt {
match chunk {
acp::ContentBlock::Text(text_content) => {
content.push(ContentChunk::Text {
text: text_content.text,
});
}
acp::ContentBlock::ResourceLink(resource_link) => {
match MentionUri::parse(&resource_link.uri) {
Ok(uri) => {
content.push(ContentChunk::Text {
text: format!("{}", uri.as_link()),
});
}
Err(_) => {
content.push(ContentChunk::Text {
text: resource_link.uri,
});
}
}
}
acp::ContentBlock::Resource(resource) => match resource.resource {
acp::EmbeddedResourceResource::TextResourceContents(resource) => {
match MentionUri::parse(&resource.uri) {
Ok(uri) => {
content.push(ContentChunk::Text {
text: format!("{}", uri.as_link()),
});
}
Err(_) => {
content.push(ContentChunk::Text {
text: resource.uri.clone(),
});
}
}
context.push(ContentChunk::Text {
text: format!(
"\n<context ref=\"{}\">\n{}\n</context>",
resource.uri, resource.text
),
});
}
acp::EmbeddedResourceResource::BlobResourceContents(_) => {
// Unsupported by SDK
}
},
acp::ContentBlock::Image(acp::ImageContent {
data, mime_type, ..
}) => content.push(ContentChunk::Image {
source: ImageSource::Base64 {
data,
media_type: mime_type,
},
}),
acp::ContentBlock::Audio(_) => {
// Unsupported by SDK
}
}
}
content.extend(context);
content
}
fn new_request_id() -> String { fn new_request_id() -> String {
use rand::Rng; use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string, // In the Claude Code TS SDK they just generate a random 12 character string,
@ -1081,4 +1268,100 @@ pub(crate) mod tests {
_ => panic!("Expected ToolResult variant"), _ => panic!("Expected ToolResult variant"),
} }
} }
#[test]
fn test_acp_content_to_claude() {
let acp_content = vec![
acp::ContentBlock::Text(acp::TextContent {
text: "Hello world".to_string(),
annotations: None,
}),
acp::ContentBlock::Image(acp::ImageContent {
data: "base64data".to_string(),
mime_type: "image/png".to_string(),
annotations: None,
uri: None,
}),
acp::ContentBlock::ResourceLink(acp::ResourceLink {
uri: "file:///path/to/example.rs".to_string(),
name: "example.rs".to_string(),
annotations: None,
description: None,
mime_type: None,
size: None,
title: None,
}),
acp::ContentBlock::Resource(acp::EmbeddedResource {
annotations: None,
resource: acp::EmbeddedResourceResource::TextResourceContents(
acp::TextResourceContents {
mime_type: None,
text: "fn main() { println!(\"Hello!\"); }".to_string(),
uri: "file:///path/to/code.rs".to_string(),
},
),
}),
acp::ContentBlock::ResourceLink(acp::ResourceLink {
uri: "invalid_uri_format".to_string(),
name: "invalid.txt".to_string(),
annotations: None,
description: None,
mime_type: None,
size: None,
title: None,
}),
];
let claude_content = acp_content_to_claude(acp_content);
assert_eq!(claude_content.len(), 6);
match &claude_content[0] {
ContentChunk::Text { text } => assert_eq!(text, "Hello world"),
_ => panic!("Expected Text chunk"),
}
match &claude_content[1] {
ContentChunk::Image { source } => match source {
ImageSource::Base64 { data, media_type } => {
assert_eq!(data, "base64data");
assert_eq!(media_type, "image/png");
}
_ => panic!("Expected Base64 image source"),
},
_ => panic!("Expected Image chunk"),
}
match &claude_content[2] {
ContentChunk::Text { text } => {
assert!(text.contains("example.rs"));
assert!(text.contains("file:///path/to/example.rs"));
}
_ => panic!("Expected Text chunk for ResourceLink"),
}
match &claude_content[3] {
ContentChunk::Text { text } => {
assert!(text.contains("code.rs"));
assert!(text.contains("file:///path/to/code.rs"));
}
_ => panic!("Expected Text chunk for Resource"),
}
match &claude_content[4] {
ContentChunk::Text { text } => {
assert_eq!(text, "invalid_uri_format");
}
_ => panic!("Expected Text chunk for invalid URI"),
}
match &claude_content[5] {
ContentChunk::Text { text } => {
assert!(text.contains("<context ref=\"file:///path/to/code.rs\">"));
assert!(text.contains("fn main() { println!(\"Hello!\"); }"));
assert!(text.contains("</context>"));
}
_ => panic!("Expected Text chunk for context"),
}
}
} }

View file

@ -0,0 +1,178 @@
use acp_thread::AcpThread;
use anyhow::Result;
use context_server::{
listener::{McpServerTool, ToolResponse},
types::{ToolAnnotations, ToolResponseContent},
};
use gpui::{AsyncApp, WeakEntity};
use language::unified_diff;
use util::markdown::MarkdownCodeBlock;
use crate::tools::EditToolParams;
#[derive(Clone)]
pub struct EditTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl EditTool {
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
Self { thread_rx }
}
}
impl McpServerTool for EditTool {
type Input = EditToolParams;
type Output = ();
const NAME: &'static str = "Edit";
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Edit file".to_string()),
read_only_hint: Some(false),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: Some(false),
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
})?
.await?;
let (new_content, diff) = cx
.background_executor()
.spawn(async move {
let new_content = content.replace(&input.old_text, &input.new_text);
if new_content == content {
return Err(anyhow::anyhow!("Failed to find `old_text`",));
}
let diff = unified_diff(&content, &new_content);
Ok((new_content, diff))
})
.await?;
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.abs_path, new_content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: MarkdownCodeBlock {
tag: "diff",
text: diff.as_str().trim_end_matches('\n'),
}
.to_string(),
}],
structured_content: (),
})
}
}
#[cfg(test)]
mod tests {
use std::rc::Rc;
use acp_thread::{AgentConnection, StubAgentConnection};
use gpui::{Entity, TestAppContext};
use indoc::indoc;
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
use super::*;
#[gpui::test]
async fn old_text_not_found(cx: &mut TestAppContext) {
let (_thread, tool) = init_test(cx).await;
let result = tool
.run(
EditToolParams {
abs_path: path!("/root/file.txt").into(),
old_text: "hi".into(),
new_text: "bye".into(),
},
&mut cx.to_async(),
)
.await;
assert_eq!(result.unwrap_err().to_string(), "Failed to find `old_text`");
}
#[gpui::test]
async fn found_and_replaced(cx: &mut TestAppContext) {
let (_thread, tool) = init_test(cx).await;
let result = tool
.run(
EditToolParams {
abs_path: path!("/root/file.txt").into(),
old_text: "hello".into(),
new_text: "hi".into(),
},
&mut cx.to_async(),
)
.await;
assert_eq!(
result.unwrap().content[0].text().unwrap(),
indoc! {
r"
```diff
@@ -1,1 +1,1 @@
-hello
+hi
```
"
}
);
}
async fn init_test(cx: &mut TestAppContext) -> (Entity<AcpThread>, EditTool) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
let connection = Rc::new(StubAgentConnection::new());
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"file.txt": "hello"
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let thread = cx
.update(|cx| connection.new_thread(project, path!("/test").as_ref(), cx))
.await
.unwrap();
thread_tx.send(thread.downgrade()).unwrap();
(thread, EditTool::new(thread_rx))
}
}

View file

@ -1,18 +1,22 @@
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; use crate::claude::edit_tool::EditTool;
use crate::claude::permission_tool::PermissionTool;
use crate::claude::read_tool::ReadTool;
use crate::claude::write_tool::WriteTool;
use acp_thread::AcpThread; use acp_thread::AcpThread;
use agent_client_protocol as acp; #[cfg(not(test))]
use anyhow::{Context, Result}; use anyhow::Context as _;
use anyhow::Result;
use collections::HashMap; use collections::HashMap;
use context_server::listener::{McpServerTool, ToolResponse};
use context_server::types::{ use context_server::types::{
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, ToolsCapabilities, requests,
}; };
use gpui::{App, AsyncApp, Task, WeakEntity}; use gpui::{App, AsyncApp, Task, WeakEntity};
use schemars::JsonSchema; use project::Fs;
use serde::{Deserialize, Serialize}; use serde::Serialize;
pub struct ClaudeZedMcpServer { pub struct ClaudeZedMcpServer {
server: context_server::listener::McpServer, server: context_server::listener::McpServer,
@ -23,20 +27,16 @@ pub const SERVER_NAME: &str = "zed";
impl ClaudeZedMcpServer { impl ClaudeZedMcpServer {
pub async fn new( pub async fn new(
thread_rx: watch::Receiver<WeakEntity<AcpThread>>, thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
fs: Arc<dyn Fs>,
cx: &AsyncApp, cx: &AsyncApp,
) -> Result<Self> { ) -> Result<Self> {
let mut mcp_server = context_server::listener::McpServer::new(cx).await?; let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize); mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.add_tool(PermissionTool { mcp_server.add_tool(PermissionTool::new(fs.clone(), thread_rx.clone()));
thread_rx: thread_rx.clone(), mcp_server.add_tool(ReadTool::new(thread_rx.clone()));
}); mcp_server.add_tool(EditTool::new(thread_rx.clone()));
mcp_server.add_tool(ReadTool { mcp_server.add_tool(WriteTool::new(thread_rx.clone()));
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(EditTool {
thread_rx: thread_rx.clone(),
});
Ok(Self { server: mcp_server }) Ok(Self { server: mcp_server })
} }
@ -97,206 +97,3 @@ pub struct McpServerConfig {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub env: Option<HashMap<String, String>>, pub env: Option<HashMap<String, String>>,
} }
// Tools
#[derive(Clone)]
pub struct PermissionTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
#[derive(Deserialize, JsonSchema, Debug)]
pub struct PermissionToolParams {
tool_name: String,
input: serde_json::Value,
tool_use_id: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionToolResponse {
behavior: PermissionToolBehavior,
updated_input: serde_json::Value,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum PermissionToolBehavior {
Allow,
Deny,
}
impl McpServerTool for PermissionTool {
type Input = PermissionToolParams;
type Output = ();
const NAME: &'static str = "Confirmation";
fn description(&self) -> &'static str {
"Request permission for tool calls"
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
let allow_option_id = acp::PermissionOptionId("allow".into());
let reject_option_id = acp::PermissionOptionId("reject".into());
let chosen_option = thread
.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
claude_tool.as_acp(tool_call_id).into(),
vec![
acp::PermissionOption {
id: allow_option_id.clone(),
name: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
},
acp::PermissionOption {
id: reject_option_id.clone(),
name: "Reject".into(),
kind: acp::PermissionOptionKind::RejectOnce,
},
],
cx,
)
})??
.await?;
let response = if chosen_option == allow_option_id {
PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
}
} else {
debug_assert_eq!(chosen_option, reject_option_id);
PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: input.input,
}
};
Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
})
}
}
#[derive(Clone)]
pub struct ReadTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for ReadTool {
type Input = ReadToolParams;
type Output = ();
const NAME: &'static str = "Read";
fn description(&self) -> &'static str {
"Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents."
}
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Read file".to_string()),
read_only_hint: Some(true),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: None,
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![ToolResponseContent::Text { text: content }],
structured_content: (),
})
}
}
#[derive(Clone)]
pub struct EditTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for EditTool {
type Input = EditToolParams;
type Output = ();
const NAME: &'static str = "Edit";
fn description(&self) -> &'static str {
"Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better."
}
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Edit file".to_string()),
read_only_hint: Some(false),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: Some(false),
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
})?
.await?;
let new_content = content.replace(&input.old_text, &input.new_text);
if new_content == content {
return Err(anyhow::anyhow!("The old_text was not found in the content"));
}
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.abs_path, new_content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: (),
})
}
}

View file

@ -0,0 +1,158 @@
use std::sync::Arc;
use acp_thread::AcpThread;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use context_server::{
listener::{McpServerTool, ToolResponse},
types::ToolResponseContent,
};
use gpui::{AsyncApp, WeakEntity};
use project::Fs;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings as _, update_settings_file};
use util::debug_panic;
use crate::tools::ClaudeTool;
#[derive(Clone)]
pub struct PermissionTool {
fs: Arc<dyn Fs>,
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
/// Request permission for tool calls
#[derive(Deserialize, JsonSchema, Debug)]
pub struct PermissionToolParams {
tool_name: String,
input: serde_json::Value,
tool_use_id: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionToolResponse {
behavior: PermissionToolBehavior,
updated_input: serde_json::Value,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum PermissionToolBehavior {
Allow,
Deny,
}
impl PermissionTool {
pub fn new(fs: Arc<dyn Fs>, thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
Self { fs, thread_rx }
}
}
impl McpServerTool for PermissionTool {
type Input = PermissionToolParams;
type Output = ();
const NAME: &'static str = "Confirmation";
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
if agent_settings::AgentSettings::try_read_global(cx, |settings| {
settings.always_allow_tool_actions
})
.unwrap_or(false)
{
let response = PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
};
return Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
});
}
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
const ALWAYS_ALLOW: &str = "always_allow";
const ALLOW: &str = "allow";
const REJECT: &str = "reject";
let chosen_option = thread
.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
claude_tool.as_acp(tool_call_id).into(),
vec![
acp::PermissionOption {
id: acp::PermissionOptionId(ALWAYS_ALLOW.into()),
name: "Always Allow".into(),
kind: acp::PermissionOptionKind::AllowAlways,
},
acp::PermissionOption {
id: acp::PermissionOptionId(ALLOW.into()),
name: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
},
acp::PermissionOption {
id: acp::PermissionOptionId(REJECT.into()),
name: "Reject".into(),
kind: acp::PermissionOptionKind::RejectOnce,
},
],
cx,
)
})??
.await?;
let response = match chosen_option.0.as_ref() {
ALWAYS_ALLOW => {
cx.update(|cx| {
update_settings_file::<AgentSettings>(self.fs.clone(), cx, |settings, _| {
settings.set_always_allow_tool_actions(true);
});
})?;
PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
}
}
ALLOW => PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
},
REJECT => PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: input.input,
},
opt => {
debug_panic!("Unexpected option: {}", opt);
PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: input.input,
}
}
};
Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
})
}
}

View file

@ -0,0 +1,59 @@
use acp_thread::AcpThread;
use anyhow::Result;
use context_server::{
listener::{McpServerTool, ToolResponse},
types::{ToolAnnotations, ToolResponseContent},
};
use gpui::{AsyncApp, WeakEntity};
use crate::tools::ReadToolParams;
#[derive(Clone)]
pub struct ReadTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl ReadTool {
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
Self { thread_rx }
}
}
impl McpServerTool for ReadTool {
type Input = ReadToolParams;
type Output = ();
const NAME: &'static str = "Read";
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Read file".to_string()),
read_only_hint: Some(true),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: None,
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![ToolResponseContent::Text { text: content }],
structured_content: (),
})
}
}

View file

@ -34,6 +34,7 @@ impl ClaudeTool {
// Known tools // Known tools
"mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()), "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()),
"mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()), "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()),
"mcp__zed__Write" => Self::Write(serde_json::from_value(input).log_err()),
"MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()), "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()),
"Write" => Self::Write(serde_json::from_value(input).log_err()), "Write" => Self::Write(serde_json::from_value(input).log_err()),
"LS" => Self::Ls(serde_json::from_value(input).log_err()), "LS" => Self::Ls(serde_json::from_value(input).log_err()),
@ -93,7 +94,7 @@ impl ClaudeTool {
} }
Self::MultiEdit(None) => "Multi Edit".into(), Self::MultiEdit(None) => "Multi Edit".into(),
Self::Write(Some(params)) => { Self::Write(Some(params)) => {
format!("Write {}", params.file_path.display()) format!("Write {}", params.abs_path.display())
} }
Self::Write(None) => "Write".into(), Self::Write(None) => "Write".into(),
Self::Glob(Some(params)) => { Self::Glob(Some(params)) => {
@ -153,7 +154,7 @@ impl ClaudeTool {
}], }],
Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
diff: acp::Diff { diff: acp::Diff {
path: params.file_path.clone(), path: params.abs_path.clone(),
old_text: None, old_text: None,
new_text: params.content.clone(), new_text: params.content.clone(),
}, },
@ -229,7 +230,10 @@ impl ClaudeTool {
line: None, line: None,
}] }]
} }
Self::Write(Some(WriteToolParams { file_path, .. })) => { Self::Write(Some(WriteToolParams {
abs_path: file_path,
..
})) => {
vec![acp::ToolCallLocation { vec![acp::ToolCallLocation {
path: file_path.clone(), path: file_path.clone(),
line: None, line: None,
@ -302,6 +306,20 @@ impl ClaudeTool {
} }
} }
/// Edit a file.
///
/// In sessions with mcp__zed__Edit always use it instead of Edit as it will
/// allow the user to conveniently review changes.
///
/// File editing instructions:
/// - The `old_text` param must match existing file content, including indentation.
/// - The `old_text` param must come from the actual file, not an outline.
/// - The `old_text` section must not be empty.
/// - Be minimal with replacements:
/// - For unique lines, include only those lines.
/// - For non-unique lines, include enough context to identify them.
/// - Do not escape quotes, newlines, or other characters.
/// - Only edit the specified file.
#[derive(Deserialize, JsonSchema, Debug)] #[derive(Deserialize, JsonSchema, Debug)]
pub struct EditToolParams { pub struct EditToolParams {
/// The absolute path to the file to read. /// The absolute path to the file to read.
@ -312,6 +330,11 @@ pub struct EditToolParams {
pub new_text: String, pub new_text: String,
} }
/// Reads the content of the given file in the project.
///
/// Never attempt to read a path that hasn't been previously mentioned.
///
/// In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.
#[derive(Deserialize, JsonSchema, Debug)] #[derive(Deserialize, JsonSchema, Debug)]
pub struct ReadToolParams { pub struct ReadToolParams {
/// The absolute path to the file to read. /// The absolute path to the file to read.
@ -324,11 +347,15 @@ pub struct ReadToolParams {
pub limit: Option<u32>, pub limit: Option<u32>,
} }
/// Writes content to the specified file in the project.
///
/// In sessions with mcp__zed__Write always use it instead of Write as it will
/// allow the user to conveniently review changes.
#[derive(Deserialize, JsonSchema, Debug)] #[derive(Deserialize, JsonSchema, Debug)]
pub struct WriteToolParams { pub struct WriteToolParams {
/// Absolute path for new file /// The absolute path of the file to write.
pub file_path: PathBuf, pub abs_path: PathBuf,
/// File content /// The full content to write.
pub content: String, pub content: String,
} }

View file

@ -0,0 +1,59 @@
use acp_thread::AcpThread;
use anyhow::Result;
use context_server::{
listener::{McpServerTool, ToolResponse},
types::ToolAnnotations,
};
use gpui::{AsyncApp, WeakEntity};
use crate::tools::WriteToolParams;
#[derive(Clone)]
pub struct WriteTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl WriteTool {
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
Self { thread_rx }
}
}
impl McpServerTool for WriteTool {
type Input = WriteToolParams;
type Output = ();
const NAME: &'static str = "Write";
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Write file".to_string()),
read_only_hint: Some(false),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: Some(false),
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.abs_path, input.content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: (),
})
}
}

View file

@ -428,12 +428,9 @@ pub async fn new_test_thread(
.await .await
.unwrap(); .unwrap();
let thread = cx cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
.await .await
.unwrap(); .unwrap()
thread
} }
pub async fn run_until_first_tool_call( pub async fn run_until_first_tool_call(
@ -471,7 +468,7 @@ pub fn get_zed_path() -> PathBuf {
while zed_path while zed_path
.file_name() .file_name()
.map_or(true, |name| name.to_string_lossy() != "debug") .is_none_or(|name| name.to_string_lossy() != "debug")
{ {
if !zed_path.pop() { if !zed_path.pop() {
panic!("Could not find target directory"); panic!("Could not find target directory");

View file

@ -1,5 +1,5 @@
use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
use std::{any::Any, path::Path};
use crate::{AgentServer, AgentServerCommand}; use crate::{AgentServer, AgentServerCommand};
use acp_thread::{AgentConnection, LoadError}; use acp_thread::{AgentConnection, LoadError};
@ -26,7 +26,7 @@ impl AgentServer for Gemini {
} }
fn empty_state_message(&self) -> &'static str { fn empty_state_message(&self) -> &'static str {
"Ask questions, edit files, run commands.\nBe specific for the best results." "Ask questions, edit files, run commands"
} }
fn logo(&self) -> ui::IconName { fn logo(&self) -> ui::IconName {
@ -50,7 +50,11 @@ impl AgentServer for Gemini {
let Some(command) = let Some(command) =
AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
else { else {
anyhow::bail!("Failed to find gemini binary"); return Err(LoadError::NotInstalled {
error_message: "Failed to find Gemini CLI binary".into(),
install_message: "Install Gemini CLI".into(),
install_command: "npm install -g @google/gemini-cli@latest".into()
}.into());
}; };
let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await; let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
@ -75,10 +79,11 @@ impl AgentServer for Gemini {
if !supported { if !supported {
return Err(LoadError::Unsupported { return Err(LoadError::Unsupported {
error_message: format!( error_message: format!(
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", "Your installed version of Gemini CLI ({}, version {}) doesn't support the Agentic Coding Protocol (ACP).",
command.path.to_string_lossy(),
current_version current_version
).into(), ).into(),
upgrade_message: "Upgrade Gemini to Latest".into(), upgrade_message: "Upgrade Gemini CLI to latest".into(),
upgrade_command: "npm install -g @google/gemini-cli@latest".into(), upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
}.into()) }.into())
} }
@ -86,6 +91,10 @@ impl AgentServer for Gemini {
result result
}) })
} }
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
} }
#[cfg(test)] #[cfg(test)]

View file

@ -58,7 +58,7 @@ impl AgentProfileSettings {
|| self || self
.context_servers .context_servers
.get(server_id) .get(server_id)
.map_or(false, |preset| preset.tools.get(tool_name) == Some(&true)) .is_some_and(|preset| preset.tools.get(tool_name) == Some(&true))
} }
} }

View file

@ -15,6 +15,8 @@ pub use crate::agent_profile::*;
pub const SUMMARIZE_THREAD_PROMPT: &str = pub const SUMMARIZE_THREAD_PROMPT: &str =
include_str!("../../agent/src/prompts/summarize_thread_prompt.txt"); include_str!("../../agent/src/prompts/summarize_thread_prompt.txt");
pub const SUMMARIZE_THREAD_DETAILED_PROMPT: &str =
include_str!("../../agent/src/prompts/summarize_thread_detailed_prompt.txt");
pub fn init(cx: &mut App) { pub fn init(cx: &mut App) {
AgentSettings::register(cx); AgentSettings::register(cx);
@ -116,15 +118,15 @@ pub struct LanguageModelParameters {
impl LanguageModelParameters { impl LanguageModelParameters {
pub fn matches(&self, model: &Arc<dyn LanguageModel>) -> bool { pub fn matches(&self, model: &Arc<dyn LanguageModel>) -> bool {
if let Some(provider) = &self.provider { if let Some(provider) = &self.provider
if provider.0 != model.provider_id().0 { && provider.0 != model.provider_id().0
return false; {
} return false;
} }
if let Some(setting_model) = &self.model { if let Some(setting_model) = &self.model
if *setting_model != model.id().0 { && *setting_model != model.id().0
return false; {
} return false;
} }
true true
} }

View file

@ -104,9 +104,11 @@ zed_actions.workspace = true
[dev-dependencies] [dev-dependencies]
acp_thread = { workspace = true, features = ["test-support"] } acp_thread = { workspace = true, features = ["test-support"] }
agent = { workspace = true, features = ["test-support"] } agent = { workspace = true, features = ["test-support"] }
agent2 = { workspace = true, features = ["test-support"] }
assistant_context = { workspace = true, features = ["test-support"] } assistant_context = { workspace = true, features = ["test-support"] }
assistant_tools.workspace = true assistant_tools.workspace = true
buffer_diff = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] }
db = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true indoc.workspace = true

View file

@ -3,8 +3,10 @@ mod entry_view_state;
mod message_editor; mod message_editor;
mod model_selector; mod model_selector;
mod model_selector_popover; mod model_selector_popover;
mod thread_history;
mod thread_view; mod thread_view;
pub use model_selector::AcpModelSelector; pub use model_selector::AcpModelSelector;
pub use model_selector_popover::AcpModelSelectorPopover; pub use model_selector_popover::AcpModelSelectorPopover;
pub use thread_history::*;
pub use thread_view::AcpThreadView; pub use thread_view::AcpThreadView;

View file

@ -3,6 +3,7 @@ use std::sync::Arc;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
use acp_thread::MentionUri; use acp_thread::MentionUri;
use agent2::{HistoryEntry, HistoryStore};
use anyhow::Result; use anyhow::Result;
use editor::{CompletionProvider, Editor, ExcerptId}; use editor::{CompletionProvider, Editor, ExcerptId};
use fuzzy::{StringMatch, StringMatchCandidate}; use fuzzy::{StringMatch, StringMatchCandidate};
@ -18,25 +19,21 @@ use text::{Anchor, ToPoint as _};
use ui::prelude::*; use ui::prelude::*;
use workspace::Workspace; use workspace::Workspace;
use agent::thread_store::{TextThreadStore, ThreadStore}; use crate::AgentPanel;
use crate::acp::message_editor::MessageEditor; use crate::acp::message_editor::MessageEditor;
use crate::context_picker::file_context_picker::{FileMatch, search_files}; use crate::context_picker::file_context_picker::{FileMatch, search_files};
use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules}; use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules};
use crate::context_picker::symbol_context_picker::SymbolMatch; use crate::context_picker::symbol_context_picker::SymbolMatch;
use crate::context_picker::symbol_context_picker::search_symbols; use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_picker::thread_context_picker::{
ThreadContextEntry, ThreadMatch, search_threads,
};
use crate::context_picker::{ use crate::context_picker::{
ContextPickerAction, ContextPickerEntry, ContextPickerMode, RecentEntry, ContextPickerAction, ContextPickerEntry, ContextPickerMode, selection_ranges,
available_context_picker_entries, recent_context_picker_entries, selection_ranges,
}; };
pub(crate) enum Match { pub(crate) enum Match {
File(FileMatch), File(FileMatch),
Symbol(SymbolMatch), Symbol(SymbolMatch),
Thread(ThreadMatch), Thread(HistoryEntry),
RecentThread(HistoryEntry),
Fetch(SharedString), Fetch(SharedString),
Rules(RulesContextEntry), Rules(RulesContextEntry),
Entry(EntryMatch), Entry(EntryMatch),
@ -53,6 +50,7 @@ impl Match {
Match::File(file) => file.mat.score, Match::File(file) => file.mat.score,
Match::Entry(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.), Match::Entry(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.),
Match::Thread(_) => 1., Match::Thread(_) => 1.,
Match::RecentThread(_) => 1.,
Match::Symbol(_) => 1., Match::Symbol(_) => 1.,
Match::Rules(_) => 1., Match::Rules(_) => 1.,
Match::Fetch(_) => 1., Match::Fetch(_) => 1.,
@ -60,209 +58,25 @@ impl Match {
} }
} }
fn search(
mode: Option<ContextPickerMode>,
query: String,
cancellation_flag: Arc<AtomicBool>,
recent_entries: Vec<RecentEntry>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
text_thread_context_store: WeakEntity<assistant_context::ContextStore>,
workspace: Entity<Workspace>,
cx: &mut App,
) -> Task<Vec<Match>> {
match mode {
Some(ContextPickerMode::File) => {
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_files_task
.await
.into_iter()
.map(Match::File)
.collect()
})
}
Some(ContextPickerMode::Symbol) => {
let search_symbols_task =
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_symbols_task
.await
.into_iter()
.map(Match::Symbol)
.collect()
})
}
Some(ContextPickerMode::Thread) => {
if let Some((thread_store, context_store)) = thread_store
.upgrade()
.zip(text_thread_context_store.upgrade())
{
let search_threads_task = search_threads(
query.clone(),
cancellation_flag.clone(),
thread_store,
context_store,
cx,
);
cx.background_spawn(async move {
search_threads_task
.await
.into_iter()
.map(Match::Thread)
.collect()
})
} else {
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Fetch) => {
if !query.is_empty() {
Task::ready(vec![Match::Fetch(query.into())])
} else {
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Rules) => {
if let Some(prompt_store) = prompt_store.as_ref() {
let search_rules_task =
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
.into_iter()
.map(Match::Rules)
.collect::<Vec<_>>()
})
} else {
Task::ready(Vec::new())
}
}
None => {
if query.is_empty() {
let mut matches = recent_entries
.into_iter()
.map(|entry| match entry {
RecentEntry::File {
project_path,
path_prefix,
} => Match::File(FileMatch {
mat: fuzzy::PathMatch {
score: 1.,
positions: Vec::new(),
worktree_id: project_path.worktree_id.to_usize(),
path: project_path.path,
path_prefix,
is_dir: false,
distance_to_relative_ancestor: 0,
},
is_recent: true,
}),
RecentEntry::Thread(thread_context_entry) => Match::Thread(ThreadMatch {
thread: thread_context_entry,
is_recent: true,
}),
})
.collect::<Vec<_>>();
matches.extend(
available_context_picker_entries(
&prompt_store,
&Some(thread_store.clone()),
&workspace,
cx,
)
.into_iter()
.map(|mode| {
Match::Entry(EntryMatch {
entry: mode,
mat: None,
})
}),
);
Task::ready(matches)
} else {
let executor = cx.background_executor().clone();
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
let entries = available_context_picker_entries(
&prompt_store,
&Some(thread_store.clone()),
&workspace,
cx,
);
let entry_candidates = entries
.iter()
.enumerate()
.map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword()))
.collect::<Vec<_>>();
cx.background_spawn(async move {
let mut matches = search_files_task
.await
.into_iter()
.map(Match::File)
.collect::<Vec<_>>();
let entry_matches = fuzzy::match_strings(
&entry_candidates,
&query,
false,
true,
100,
&Arc::new(AtomicBool::default()),
executor,
)
.await;
matches.extend(entry_matches.into_iter().map(|mat| {
Match::Entry(EntryMatch {
entry: entries[mat.candidate_id],
mat: Some(mat),
})
}));
matches.sort_by(|a, b| {
b.score()
.partial_cmp(&a.score())
.unwrap_or(std::cmp::Ordering::Equal)
});
matches
})
}
}
}
}
pub struct ContextPickerCompletionProvider { pub struct ContextPickerCompletionProvider {
workspace: WeakEntity<Workspace>,
thread_store: WeakEntity<ThreadStore>,
text_thread_store: WeakEntity<TextThreadStore>,
message_editor: WeakEntity<MessageEditor>, message_editor: WeakEntity<MessageEditor>,
workspace: WeakEntity<Workspace>,
history_store: Entity<HistoryStore>,
prompt_store: Option<Entity<PromptStore>>,
} }
impl ContextPickerCompletionProvider { impl ContextPickerCompletionProvider {
pub fn new( pub fn new(
workspace: WeakEntity<Workspace>,
thread_store: WeakEntity<ThreadStore>,
text_thread_store: WeakEntity<TextThreadStore>,
message_editor: WeakEntity<MessageEditor>, message_editor: WeakEntity<MessageEditor>,
workspace: WeakEntity<Workspace>,
history_store: Entity<HistoryStore>,
prompt_store: Option<Entity<PromptStore>>,
) -> Self { ) -> Self {
Self { Self {
workspace,
thread_store,
text_thread_store,
message_editor, message_editor,
workspace,
history_store,
prompt_store,
} }
} }
@ -349,22 +163,13 @@ impl ContextPickerCompletionProvider {
} }
fn completion_for_thread( fn completion_for_thread(
thread_entry: ThreadContextEntry, thread_entry: HistoryEntry,
source_range: Range<Anchor>, source_range: Range<Anchor>,
recent: bool, recent: bool,
editor: WeakEntity<MessageEditor>, editor: WeakEntity<MessageEditor>,
cx: &mut App, cx: &mut App,
) -> Completion { ) -> Completion {
let uri = match &thread_entry { let uri = thread_entry.mention_uri();
ThreadContextEntry::Thread { id, title } => MentionUri::Thread {
id: id.clone(),
name: title.to_string(),
},
ThreadContextEntry::Context { path, title } => MentionUri::TextThread {
path: path.to_path_buf(),
name: title.to_string(),
},
};
let icon_for_completion = if recent { let icon_for_completion = if recent {
IconName::HistoryRerun.path().into() IconName::HistoryRerun.path().into()
@ -445,19 +250,20 @@ impl ContextPickerCompletionProvider {
let abs_path = project.read(cx).absolute_path(&project_path, cx)?; let abs_path = project.read(cx).absolute_path(&project_path, cx)?;
let file_uri = MentionUri::File { let uri = if is_directory {
abs_path, MentionUri::Directory { abs_path }
is_directory, } else {
MentionUri::File { abs_path }
}; };
let crease_icon_path = file_uri.icon_path(cx); let crease_icon_path = uri.icon_path(cx);
let completion_icon_path = if is_recent { let completion_icon_path = if is_recent {
IconName::HistoryRerun.path().into() IconName::HistoryRerun.path().into()
} else { } else {
crease_icon_path.clone() crease_icon_path.clone()
}; };
let new_text = format!("{} ", file_uri.as_link()); let new_text = format!("{} ", uri.as_link());
let new_text_len = new_text.len(); let new_text_len = new_text.len();
Some(Completion { Some(Completion {
replace_range: source_range.clone(), replace_range: source_range.clone(),
@ -472,7 +278,7 @@ impl ContextPickerCompletionProvider {
source_range.start, source_range.start,
new_text_len - 1, new_text_len - 1,
message_editor, message_editor,
file_uri, uri,
)), )),
}) })
} }
@ -546,17 +352,262 @@ impl ContextPickerCompletionProvider {
)), )),
}) })
} }
fn search(
&self,
mode: Option<ContextPickerMode>,
query: String,
cancellation_flag: Arc<AtomicBool>,
cx: &mut App,
) -> Task<Vec<Match>> {
let Some(workspace) = self.workspace.upgrade() else {
return Task::ready(Vec::default());
};
match mode {
Some(ContextPickerMode::File) => {
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_files_task
.await
.into_iter()
.map(Match::File)
.collect()
})
}
Some(ContextPickerMode::Symbol) => {
let search_symbols_task =
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_symbols_task
.await
.into_iter()
.map(Match::Symbol)
.collect()
})
}
Some(ContextPickerMode::Thread) => {
let search_threads_task = search_threads(
query.clone(),
cancellation_flag.clone(),
&self.history_store,
cx,
);
cx.background_spawn(async move {
search_threads_task
.await
.into_iter()
.map(Match::Thread)
.collect()
})
}
Some(ContextPickerMode::Fetch) => {
if !query.is_empty() {
Task::ready(vec![Match::Fetch(query.into())])
} else {
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Rules) => {
if let Some(prompt_store) = self.prompt_store.as_ref() {
let search_rules_task =
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
.into_iter()
.map(Match::Rules)
.collect::<Vec<_>>()
})
} else {
Task::ready(Vec::new())
}
}
None if query.is_empty() => {
let mut matches = self.recent_context_picker_entries(&workspace, cx);
matches.extend(
self.available_context_picker_entries(&workspace, cx)
.into_iter()
.map(|mode| {
Match::Entry(EntryMatch {
entry: mode,
mat: None,
})
}),
);
Task::ready(matches)
}
None => {
let executor = cx.background_executor().clone();
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
let entries = self.available_context_picker_entries(&workspace, cx);
let entry_candidates = entries
.iter()
.enumerate()
.map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword()))
.collect::<Vec<_>>();
cx.background_spawn(async move {
let mut matches = search_files_task
.await
.into_iter()
.map(Match::File)
.collect::<Vec<_>>();
let entry_matches = fuzzy::match_strings(
&entry_candidates,
&query,
false,
true,
100,
&Arc::new(AtomicBool::default()),
executor,
)
.await;
matches.extend(entry_matches.into_iter().map(|mat| {
Match::Entry(EntryMatch {
entry: entries[mat.candidate_id],
mat: Some(mat),
})
}));
matches.sort_by(|a, b| {
b.score()
.partial_cmp(&a.score())
.unwrap_or(std::cmp::Ordering::Equal)
});
matches
})
}
}
}
fn recent_context_picker_entries(
&self,
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Vec<Match> {
let mut recent = Vec::with_capacity(6);
let mut mentions = self
.message_editor
.read_with(cx, |message_editor, _cx| message_editor.mentions())
.unwrap_or_default();
let workspace = workspace.read(cx);
let project = workspace.project().read(cx);
if let Some(agent_panel) = workspace.panel::<AgentPanel>(cx)
&& let Some(thread) = agent_panel.read(cx).active_agent_thread(cx)
{
let thread = thread.read(cx);
mentions.insert(MentionUri::Thread {
id: thread.session_id().clone(),
name: thread.title().into(),
});
}
recent.extend(
workspace
.recent_navigation_history_iter(cx)
.filter(|(_, abs_path)| {
abs_path.as_ref().is_none_or(|path| {
!mentions.contains(&MentionUri::File {
abs_path: path.clone(),
})
})
})
.take(4)
.filter_map(|(project_path, _)| {
project
.worktree_for_id(project_path.worktree_id, cx)
.map(|worktree| {
let path_prefix = worktree.read(cx).root_name().into();
Match::File(FileMatch {
mat: fuzzy::PathMatch {
score: 1.,
positions: Vec::new(),
worktree_id: project_path.worktree_id.to_usize(),
path: project_path.path,
path_prefix,
is_dir: false,
distance_to_relative_ancestor: 0,
},
is_recent: true,
})
})
}),
);
const RECENT_COUNT: usize = 2;
let threads = self
.history_store
.read(cx)
.recently_opened_entries(cx)
.into_iter()
.filter(|thread| !mentions.contains(&thread.mention_uri()))
.take(RECENT_COUNT)
.collect::<Vec<_>>();
recent.extend(threads.into_iter().map(Match::RecentThread));
recent
}
fn available_context_picker_entries(
&self,
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Vec<ContextPickerEntry> {
let mut entries = vec![
ContextPickerEntry::Mode(ContextPickerMode::File),
ContextPickerEntry::Mode(ContextPickerMode::Symbol),
ContextPickerEntry::Mode(ContextPickerMode::Thread),
];
let has_selection = workspace
.read(cx)
.active_item(cx)
.and_then(|item| item.downcast::<Editor>())
.is_some_and(|editor| {
editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx))
});
if has_selection {
entries.push(ContextPickerEntry::Action(
ContextPickerAction::AddSelections,
));
}
if self.prompt_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
}
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Fetch));
entries
}
} }
fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: &App) -> CodeLabel { fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: &App) -> CodeLabel {
let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId); let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId);
let mut label = CodeLabel::default(); let mut label = CodeLabel::default();
label.push_str(&file_name, None); label.push_str(file_name, None);
label.push_str(" ", None); label.push_str(" ", None);
if let Some(directory) = directory { if let Some(directory) = directory {
label.push_str(&directory, comment_id); label.push_str(directory, comment_id);
} }
label.filter_range = 0..label.text().len(); label.filter_range = 0..label.text().len();
@ -595,45 +646,12 @@ impl CompletionProvider for ContextPickerCompletionProvider {
let source_range = snapshot.anchor_before(state.source_range.start) let source_range = snapshot.anchor_before(state.source_range.start)
..snapshot.anchor_after(state.source_range.end); ..snapshot.anchor_after(state.source_range.end);
let thread_store = self.thread_store.clone();
let text_thread_store = self.text_thread_store.clone();
let editor = self.message_editor.clone(); let editor = self.message_editor.clone();
let Ok((exclude_paths, exclude_threads)) =
self.message_editor.update(cx, |message_editor, _cx| {
message_editor.mentioned_path_and_threads()
})
else {
return Task::ready(Ok(Vec::new()));
};
let MentionCompletion { mode, argument, .. } = state; let MentionCompletion { mode, argument, .. } = state;
let query = argument.unwrap_or_else(|| "".to_string()); let query = argument.unwrap_or_else(|| "".to_string());
let recent_entries = recent_context_picker_entries( let search_task = self.search(mode, query, Arc::<AtomicBool>::default(), cx);
Some(thread_store.clone()),
Some(text_thread_store.clone()),
workspace.clone(),
&exclude_paths,
&exclude_threads,
cx,
);
let prompt_store = thread_store
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
.ok()
.flatten();
let search_task = search(
mode,
query,
Arc::<AtomicBool>::default(),
recent_entries,
prompt_store,
thread_store.clone(),
text_thread_store.clone(),
workspace.clone(),
cx,
);
cx.spawn(async move |_, cx| { cx.spawn(async move |_, cx| {
let matches = search_task.await; let matches = search_task.await;
@ -668,12 +686,18 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx, cx,
), ),
Match::Thread(ThreadMatch { Match::Thread(thread) => Some(Self::completion_for_thread(
thread, is_recent, ..
}) => Some(Self::completion_for_thread(
thread, thread,
source_range.clone(), source_range.clone(),
is_recent, false,
editor.clone(),
cx,
)),
Match::RecentThread(thread) => Some(Self::completion_for_thread(
thread,
source_range.clone(),
true,
editor.clone(), editor.clone(),
cx, cx,
)), )),
@ -747,6 +771,42 @@ impl CompletionProvider for ContextPickerCompletionProvider {
} }
} }
pub(crate) fn search_threads(
query: String,
cancellation_flag: Arc<AtomicBool>,
history_store: &Entity<HistoryStore>,
cx: &mut App,
) -> Task<Vec<HistoryEntry>> {
let threads = history_store.read(cx).entries(cx);
if query.is_empty() {
return Task::ready(threads);
}
let executor = cx.background_executor().clone();
cx.background_spawn(async move {
let candidates = threads
.iter()
.enumerate()
.map(|(id, thread)| StringMatchCandidate::new(id, thread.title()))
.collect::<Vec<_>>();
let matches = fuzzy::match_strings(
&candidates,
&query,
false,
true,
100,
&cancellation_flag,
executor,
)
.await;
matches
.into_iter()
.map(|mat| threads[mat.candidate_id].clone())
.collect()
})
}
fn confirm_completion_callback( fn confirm_completion_callback(
crease_text: SharedString, crease_text: SharedString,
start: Anchor, start: Anchor,
@ -762,14 +822,16 @@ fn confirm_completion_callback(
message_editor message_editor
.clone() .clone()
.update(cx, |message_editor, cx| { .update(cx, |message_editor, cx| {
message_editor.confirm_completion( message_editor
crease_text, .confirm_completion(
start, crease_text,
content_len, start,
mention_uri, content_len,
window, mention_uri,
cx, window,
) cx,
)
.detach();
}) })
.ok(); .ok();
}); });
@ -794,7 +856,7 @@ impl MentionCompletion {
&& line && line
.chars() .chars()
.nth(last_mention_start - 1) .nth(last_mention_start - 1)
.map_or(false, |c| !c.is_whitespace()) .is_some_and(|c| !c.is_whitespace())
{ {
return None; return None;
} }

View file

@ -1,15 +1,16 @@
use std::ops::Range; use std::ops::Range;
use acp_thread::{AcpThread, AgentThreadEntry}; use acp_thread::{AcpThread, AgentThreadEntry};
use agent::{TextThreadStore, ThreadStore}; use agent2::HistoryStore;
use collections::HashMap; use collections::HashMap;
use editor::{Editor, EditorMode, MinimapVisibility}; use editor::{Editor, EditorMode, MinimapVisibility};
use gpui::{ use gpui::{
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement, AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
WeakEntity, Window, TextStyleRefinement, WeakEntity, Window,
}; };
use language::language_settings::SoftWrap; use language::language_settings::SoftWrap;
use project::Project; use project::Project;
use prompt_store::PromptStore;
use settings::Settings as _; use settings::Settings as _;
use terminal_view::TerminalView; use terminal_view::TerminalView;
use theme::ThemeSettings; use theme::ThemeSettings;
@ -21,24 +22,27 @@ use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
pub struct EntryViewState { pub struct EntryViewState {
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
project: Entity<Project>, project: Entity<Project>,
thread_store: Entity<ThreadStore>, history_store: Entity<HistoryStore>,
text_thread_store: Entity<TextThreadStore>, prompt_store: Option<Entity<PromptStore>>,
entries: Vec<Entry>, entries: Vec<Entry>,
prevent_slash_commands: bool,
} }
impl EntryViewState { impl EntryViewState {
pub fn new( pub fn new(
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
project: Entity<Project>, project: Entity<Project>,
thread_store: Entity<ThreadStore>, history_store: Entity<HistoryStore>,
text_thread_store: Entity<TextThreadStore>, prompt_store: Option<Entity<PromptStore>>,
prevent_slash_commands: bool,
) -> Self { ) -> Self {
Self { Self {
workspace, workspace,
project, project,
thread_store, history_store,
text_thread_store, prompt_store,
entries: Vec::new(), entries: Vec::new(),
prevent_slash_commands,
} }
} }
@ -61,33 +65,45 @@ impl EntryViewState {
AgentThreadEntry::UserMessage(message) => { AgentThreadEntry::UserMessage(message) => {
let has_id = message.id.is_some(); let has_id = message.id.is_some();
let chunks = message.chunks.clone(); let chunks = message.chunks.clone();
let message_editor = cx.new(|cx| { if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
let mut editor = MessageEditor::new( if !editor.focus_handle(cx).is_focused(window) {
self.workspace.clone(), // Only update if we are not editing.
self.project.clone(), // If we are, cancelling the edit will set the message to the newest content.
self.thread_store.clone(), editor.update(cx, |editor, cx| {
self.text_thread_store.clone(), editor.set_message(chunks, window, cx);
editor::EditorMode::AutoHeight { });
min_lines: 1,
max_lines: None,
},
window,
cx,
);
if !has_id {
editor.set_read_only(true, cx);
} }
editor.set_message(chunks, window, cx); } else {
editor let message_editor = cx.new(|cx| {
}); let mut editor = MessageEditor::new(
cx.subscribe(&message_editor, move |_, editor, event, cx| { self.workspace.clone(),
cx.emit(EntryViewEvent { self.project.clone(),
entry_index: index, self.history_store.clone(),
view_event: ViewEvent::MessageEditorEvent(editor, *event), self.prompt_store.clone(),
"Edit message @ to include context",
self.prevent_slash_commands,
editor::EditorMode::AutoHeight {
min_lines: 1,
max_lines: None,
},
window,
cx,
);
if !has_id {
editor.set_read_only(true, cx);
}
editor.set_message(chunks, window, cx);
editor
});
cx.subscribe(&message_editor, move |_, editor, event, cx| {
cx.emit(EntryViewEvent {
entry_index: index,
view_event: ViewEvent::MessageEditorEvent(editor, *event),
})
}) })
}) .detach();
.detach(); self.set_entry(index, Entry::UserMessage(message_editor));
self.set_entry(index, Entry::UserMessage(message_editor)); }
} }
AgentThreadEntry::ToolCall(tool_call) => { AgentThreadEntry::ToolCall(tool_call) => {
let terminals = tool_call.terminals().cloned().collect::<Vec<_>>(); let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
@ -174,6 +190,7 @@ pub enum ViewEvent {
MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent), MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
} }
#[derive(Debug)]
pub enum Entry { pub enum Entry {
UserMessage(Entity<MessageEditor>), UserMessage(Entity<MessageEditor>),
Content(HashMap<EntityId, AnyEntity>), Content(HashMap<EntityId, AnyEntity>),
@ -297,9 +314,10 @@ mod tests {
use std::{path::Path, rc::Rc}; use std::{path::Path, rc::Rc};
use acp_thread::{AgentConnection, StubAgentConnection}; use acp_thread::{AgentConnection, StubAgentConnection};
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::AgentSettings; use agent_settings::AgentSettings;
use agent2::HistoryStore;
use assistant_context::ContextStore;
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind}; use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
use editor::{EditorSettings, RowInfo}; use editor::{EditorSettings, RowInfo};
use fs::FakeFs; use fs::FakeFs;
@ -362,15 +380,16 @@ mod tests {
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx) connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
}); });
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx));
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
let view_state = cx.new(|_cx| { let view_state = cx.new(|_cx| {
EntryViewState::new( EntryViewState::new(
workspace.downgrade(), workspace.downgrade(),
project.clone(), project.clone(),
thread_store, history_store,
text_thread_store, None,
false,
) )
}); });

File diff suppressed because it is too large Load diff

View file

@ -330,7 +330,7 @@ async fn fuzzy_search(
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut matches = match_strings( let mut matches = match_strings(
&candidates, &candidates,
&query, query,
false, false,
true, true,
100, 100,

View file

@ -0,0 +1,721 @@
use crate::RemoveSelectedThread;
use agent2::{HistoryEntry, HistoryStore};
use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
use editor::{Editor, EditorEvent};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Task,
UniformListScrollHandle, Window, uniform_list,
};
use std::{fmt::Display, ops::Range, sync::Arc};
use time::{OffsetDateTime, UtcOffset};
use ui::{
HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tooltip, WithScrollbar,
prelude::*,
};
use util::ResultExt;
pub struct AcpThreadHistory {
pub(crate) history_store: Entity<HistoryStore>,
scroll_handle: UniformListScrollHandle,
selected_index: usize,
hovered_index: Option<usize>,
search_editor: Entity<Editor>,
all_entries: Arc<Vec<HistoryEntry>>,
// When the search is empty, we display date separators between history entries
// This vector contains an enum of either a separator or an actual entry
separated_items: Vec<ListItemType>,
// Maps entry indexes to list item indexes
separated_item_indexes: Vec<u32>,
_separated_items_task: Option<Task<()>>,
search_state: SearchState,
local_timezone: UtcOffset,
_subscriptions: Vec<gpui::Subscription>,
}
enum SearchState {
Empty,
Searching {
query: SharedString,
_task: Task<()>,
},
Searched {
query: SharedString,
matches: Vec<StringMatch>,
},
}
enum ListItemType {
BucketSeparator(TimeBucket),
Entry {
index: usize,
format: EntryTimeFormat,
},
}
pub enum ThreadHistoryEvent {
Open(HistoryEntry),
}
impl EventEmitter<ThreadHistoryEvent> for AcpThreadHistory {}
impl AcpThreadHistory {
pub(crate) fn new(
history_store: Entity<agent2::HistoryStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let search_editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_placeholder_text("Search threads...", cx);
editor
});
let search_editor_subscription =
cx.subscribe(&search_editor, |this, search_editor, event, cx| {
if let EditorEvent::BufferEdited = event {
let query = search_editor.read(cx).text(cx);
this.search(query.into(), cx);
}
});
let history_store_subscription = cx.observe(&history_store, |this, _, cx| {
this.update_all_entries(cx);
});
let scroll_handle = UniformListScrollHandle::default();
let mut this = Self {
history_store,
scroll_handle,
selected_index: 0,
hovered_index: None,
search_state: SearchState::Empty,
all_entries: Default::default(),
separated_items: Default::default(),
separated_item_indexes: Default::default(),
search_editor,
local_timezone: UtcOffset::from_whole_seconds(
chrono::Local::now().offset().local_minus_utc(),
)
.unwrap(),
_subscriptions: vec![search_editor_subscription, history_store_subscription],
_separated_items_task: None,
};
this.update_all_entries(cx);
this
}
fn update_all_entries(&mut self, cx: &mut Context<Self>) {
let new_entries: Arc<Vec<HistoryEntry>> = self
.history_store
.update(cx, |store, cx| store.entries(cx))
.into();
self._separated_items_task.take();
let mut items = Vec::with_capacity(new_entries.len() + 1);
let mut indexes = Vec::with_capacity(new_entries.len() + 1);
let bg_task = cx.background_spawn(async move {
let mut bucket = None;
let today = Local::now().naive_local().date();
for (index, entry) in new_entries.iter().enumerate() {
let entry_date = entry
.updated_at()
.with_timezone(&Local)
.naive_local()
.date();
let entry_bucket = TimeBucket::from_dates(today, entry_date);
if Some(entry_bucket) != bucket {
bucket = Some(entry_bucket);
items.push(ListItemType::BucketSeparator(entry_bucket));
}
indexes.push(items.len() as u32);
items.push(ListItemType::Entry {
index,
format: entry_bucket.into(),
});
}
(new_entries, items, indexes)
});
let task = cx.spawn(async move |this, cx| {
let (new_entries, items, indexes) = bg_task.await;
this.update(cx, |this, cx| {
let previously_selected_entry =
this.all_entries.get(this.selected_index).map(|e| e.id());
this.all_entries = new_entries;
this.separated_items = items;
this.separated_item_indexes = indexes;
match &this.search_state {
SearchState::Empty => {
if this.selected_index >= this.all_entries.len() {
this.set_selected_entry_index(
this.all_entries.len().saturating_sub(1),
cx,
);
} else if let Some(prev_id) = previously_selected_entry
&& let Some(new_ix) = this
.all_entries
.iter()
.position(|probe| probe.id() == prev_id)
{
this.set_selected_entry_index(new_ix, cx);
}
}
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {
this.search(query.clone(), cx);
}
}
cx.notify();
})
.log_err();
});
self._separated_items_task = Some(task);
}
fn search(&mut self, query: SharedString, cx: &mut Context<Self>) {
if query.is_empty() {
self.search_state = SearchState::Empty;
cx.notify();
return;
}
let all_entries = self.all_entries.clone();
let fuzzy_search_task = cx.background_spawn({
let query = query.clone();
let executor = cx.background_executor().clone();
async move {
let mut candidates = Vec::with_capacity(all_entries.len());
for (idx, entry) in all_entries.iter().enumerate() {
candidates.push(StringMatchCandidate::new(idx, entry.title()));
}
const MAX_MATCHES: usize = 100;
fuzzy::match_strings(
&candidates,
&query,
false,
true,
MAX_MATCHES,
&Default::default(),
executor,
)
.await
}
});
let task = cx.spawn({
let query = query.clone();
async move |this, cx| {
let matches = fuzzy_search_task.await;
this.update(cx, |this, cx| {
let SearchState::Searching {
query: current_query,
_task,
} = &this.search_state
else {
return;
};
if &query == current_query {
this.search_state = SearchState::Searched {
query: query.clone(),
matches,
};
this.set_selected_entry_index(0, cx);
cx.notify();
};
})
.log_err();
}
});
self.search_state = SearchState::Searching { query, _task: task };
cx.notify();
}
fn matched_count(&self) -> usize {
match &self.search_state {
SearchState::Empty => self.all_entries.len(),
SearchState::Searching { .. } => 0,
SearchState::Searched { matches, .. } => matches.len(),
}
}
fn list_item_count(&self) -> usize {
match &self.search_state {
SearchState::Empty => self.separated_items.len(),
SearchState::Searching { .. } => 0,
SearchState::Searched { matches, .. } => matches.len(),
}
}
fn search_produced_no_matches(&self) -> bool {
match &self.search_state {
SearchState::Empty => false,
SearchState::Searching { .. } => false,
SearchState::Searched { matches, .. } => matches.is_empty(),
}
}
fn get_match(&self, ix: usize) -> Option<&HistoryEntry> {
match &self.search_state {
SearchState::Empty => self.all_entries.get(ix),
SearchState::Searching { .. } => None,
SearchState::Searched { matches, .. } => matches
.get(ix)
.and_then(|m| self.all_entries.get(m.candidate_id)),
}
}
pub fn select_previous(
&mut self,
_: &menu::SelectPrevious,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self.matched_count();
if count > 0 {
if self.selected_index == 0 {
self.set_selected_entry_index(count - 1, cx);
} else {
self.set_selected_entry_index(self.selected_index - 1, cx);
}
}
}
pub fn select_next(
&mut self,
_: &menu::SelectNext,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self.matched_count();
if count > 0 {
if self.selected_index == count - 1 {
self.set_selected_entry_index(0, cx);
} else {
self.set_selected_entry_index(self.selected_index + 1, cx);
}
}
}
fn select_first(
&mut self,
_: &menu::SelectFirst,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self.matched_count();
if count > 0 {
self.set_selected_entry_index(0, cx);
}
}
fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
let count = self.matched_count();
if count > 0 {
self.set_selected_entry_index(count - 1, cx);
}
}
fn set_selected_entry_index(&mut self, entry_index: usize, cx: &mut Context<Self>) {
self.selected_index = entry_index;
let scroll_ix = match self.search_state {
SearchState::Empty | SearchState::Searching { .. } => self
.separated_item_indexes
.get(entry_index)
.map(|ix| *ix as usize)
.unwrap_or(entry_index + 1),
SearchState::Searched { .. } => entry_index,
};
self.scroll_handle
.scroll_to_item(scroll_ix, ScrollStrategy::Top);
cx.notify();
}
fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
self.confirm_entry(self.selected_index, cx);
}
fn confirm_entry(&mut self, ix: usize, cx: &mut Context<Self>) {
let Some(entry) = self.get_match(ix) else {
return;
};
cx.emit(ThreadHistoryEvent::Open(entry.clone()));
}
fn remove_selected_thread(
&mut self,
_: &RemoveSelectedThread,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.remove_thread(self.selected_index, cx)
}
fn remove_thread(&mut self, ix: usize, cx: &mut Context<Self>) {
let Some(entry) = self.get_match(ix) else {
return;
};
let task = match entry {
HistoryEntry::AcpThread(thread) => self
.history_store
.update(cx, |this, cx| this.delete_thread(thread.id.clone(), cx)),
HistoryEntry::TextThread(context) => self.history_store.update(cx, |this, cx| {
this.delete_text_thread(context.path.clone(), cx)
}),
};
task.detach_and_log_err(cx);
}
fn list_items(
&mut self,
range: Range<usize>,
_window: &mut Window,
cx: &mut Context<Self>,
) -> Vec<AnyElement> {
match &self.search_state {
SearchState::Empty => self
.separated_items
.get(range)
.iter()
.flat_map(|items| {
items
.iter()
.map(|item| self.render_list_item(item, vec![], cx))
})
.collect(),
SearchState::Searched { matches, .. } => matches[range]
.iter()
.filter_map(|m| {
let entry = self.all_entries.get(m.candidate_id)?;
Some(self.render_history_entry(
entry,
EntryTimeFormat::DateAndTime,
m.candidate_id,
m.positions.clone(),
cx,
))
})
.collect(),
SearchState::Searching { .. } => {
vec![]
}
}
}
fn render_list_item(
&self,
item: &ListItemType,
highlight_positions: Vec<usize>,
cx: &Context<Self>,
) -> AnyElement {
match item {
ListItemType::Entry { index, format } => match self.all_entries.get(*index) {
Some(entry) => self
.render_history_entry(entry, *format, *index, highlight_positions, cx)
.into_any(),
None => Empty.into_any_element(),
},
ListItemType::BucketSeparator(bucket) => div()
.px(DynamicSpacing::Base06.rems(cx))
.pt_2()
.pb_1()
.child(
Label::new(bucket.to_string())
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.into_any_element(),
}
}
fn render_history_entry(
&self,
entry: &HistoryEntry,
format: EntryTimeFormat,
list_entry_ix: usize,
highlight_positions: Vec<usize>,
cx: &Context<Self>,
) -> AnyElement {
let selected = list_entry_ix == self.selected_index;
let hovered = Some(list_entry_ix) == self.hovered_index;
let timestamp = entry.updated_at().timestamp();
let thread_timestamp = format.format_timestamp(timestamp, self.local_timezone);
h_flex()
.w_full()
.pb_1()
.child(
ListItem::new(list_entry_ix)
.rounded()
.toggle_state(selected)
.spacing(ListItemSpacing::Sparse)
.start_slot(
h_flex()
.w_full()
.gap_2()
.justify_between()
.child(
HighlightedLabel::new(entry.title(), highlight_positions)
.size(LabelSize::Small)
.truncate(),
)
.child(
Label::new(thread_timestamp)
.color(Color::Muted)
.size(LabelSize::XSmall),
),
)
.on_hover(cx.listener(move |this, is_hovered, _window, cx| {
if *is_hovered {
this.hovered_index = Some(list_entry_ix);
} else if this.hovered_index == Some(list_entry_ix) {
this.hovered_index = None;
}
cx.notify();
}))
.end_slot::<IconButton>(if hovered || selected {
Some(
IconButton::new("delete", IconName::Trash)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})
.on_click(cx.listener(move |this, _, _, cx| {
this.remove_thread(list_entry_ix, cx)
})),
)
} else {
None
})
.on_click(
cx.listener(move |this, _, _, cx| this.confirm_entry(list_entry_ix, cx)),
),
)
.into_any_element()
}
}
impl Focusable for AcpThreadHistory {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.search_editor.focus_handle(cx)
}
}
impl Render for AcpThreadHistory {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.key_context("ThreadHistory")
.size_full()
.on_action(cx.listener(Self::select_previous))
.on_action(cx.listener(Self::select_next))
.on_action(cx.listener(Self::select_first))
.on_action(cx.listener(Self::select_last))
.on_action(cx.listener(Self::confirm))
.on_action(cx.listener(Self::remove_selected_thread))
.when(!self.all_entries.is_empty(), |parent| {
parent.child(
h_flex()
.h(px(41.)) // Match the toolbar perfectly
.w_full()
.py_1()
.px_2()
.gap_2()
.justify_between()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
Icon::new(IconName::MagnifyingGlass)
.color(Color::Muted)
.size(IconSize::Small),
)
.child(self.search_editor.clone()),
)
})
.child({
let view = v_flex()
.id("list-container")
.relative()
.overflow_hidden()
.flex_grow();
if self.all_entries.is_empty() {
view.justify_center()
.child(
h_flex().w_full().justify_center().child(
Label::new("You don't have any past threads yet.")
.size(LabelSize::Small),
),
)
} else if self.search_produced_no_matches() {
view.justify_center().child(
h_flex().w_full().justify_center().child(
Label::new("No threads match your search.").size(LabelSize::Small),
),
)
} else {
view.pr_5().child(
uniform_list(
"thread-history",
self.list_item_count(),
cx.processor(|this, range: Range<usize>, window, cx| {
this.list_items(range, window, cx)
}),
)
.p_1()
.track_scroll(self.scroll_handle.clone())
.vertical_scrollbar_for(self.scroll_handle.clone(), window, cx)
.flex_grow(),
)
}
})
}
}
#[derive(Clone, Copy)]
pub enum EntryTimeFormat {
DateAndTime,
TimeOnly,
}
impl EntryTimeFormat {
fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String {
let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap();
match self {
EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp(
timestamp,
OffsetDateTime::now_utc(),
timezone,
time_format::TimestampFormat::EnhancedAbsolute,
),
EntryTimeFormat::TimeOnly => time_format::format_time(timestamp),
}
}
}
impl From<TimeBucket> for EntryTimeFormat {
fn from(bucket: TimeBucket) -> Self {
match bucket {
TimeBucket::Today => EntryTimeFormat::TimeOnly,
TimeBucket::Yesterday => EntryTimeFormat::TimeOnly,
TimeBucket::ThisWeek => EntryTimeFormat::DateAndTime,
TimeBucket::PastWeek => EntryTimeFormat::DateAndTime,
TimeBucket::All => EntryTimeFormat::DateAndTime,
}
}
}
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
enum TimeBucket {
Today,
Yesterday,
ThisWeek,
PastWeek,
All,
}
impl TimeBucket {
fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self {
if date == reference {
return TimeBucket::Today;
}
if date == reference - TimeDelta::days(1) {
return TimeBucket::Yesterday;
}
let week = date.iso_week();
if reference.iso_week() == week {
return TimeBucket::ThisWeek;
}
let last_week = (reference - TimeDelta::days(7)).iso_week();
if week == last_week {
return TimeBucket::PastWeek;
}
TimeBucket::All
}
}
impl Display for TimeBucket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TimeBucket::Today => write!(f, "Today"),
TimeBucket::Yesterday => write!(f, "Yesterday"),
TimeBucket::ThisWeek => write!(f, "This Week"),
TimeBucket::PastWeek => write!(f, "Past Week"),
TimeBucket::All => write!(f, "All"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::NaiveDate;
#[test]
fn test_time_bucket_from_dates() {
let today = NaiveDate::from_ymd_opt(2023, 1, 15).unwrap();
let date = today;
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Today);
let date = NaiveDate::from_ymd_opt(2023, 1, 14).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Yesterday);
let date = NaiveDate::from_ymd_opt(2023, 1, 13).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek);
let date = NaiveDate::from_ymd_opt(2023, 1, 11).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek);
let date = NaiveDate::from_ymd_opt(2023, 1, 8).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek);
let date = NaiveDate::from_ymd_opt(2023, 1, 5).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek);
// All: not in this week or last week
let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap();
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::All);
// Test year boundary cases
let new_year = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap();
let date = NaiveDate::from_ymd_opt(2022, 12, 31).unwrap();
assert_eq!(
TimeBucket::from_dates(new_year, date),
TimeBucket::Yesterday
);
let date = NaiveDate::from_ymd_opt(2022, 12, 28).unwrap();
assert_eq!(TimeBucket::from_dates(new_year, date), TimeBucket::ThisWeek);
}
}

File diff suppressed because it is too large Load diff

View file

@ -1040,12 +1040,12 @@ impl ActiveThread {
); );
} }
ThreadEvent::StreamedAssistantText(message_id, text) => { ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) { if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(message_id) {
rendered_message.append_text(text, cx); rendered_message.append_text(text, cx);
} }
} }
ThreadEvent::StreamedAssistantThinking(message_id, text) => { ThreadEvent::StreamedAssistantThinking(message_id, text) => {
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) { if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(message_id) {
rendered_message.append_thinking(text, cx); rendered_message.append_thinking(text, cx);
} }
} }
@ -1068,8 +1068,8 @@ impl ActiveThread {
} }
ThreadEvent::MessageEdited(message_id) => { ThreadEvent::MessageEdited(message_id) => {
self.clear_last_error(); self.clear_last_error();
if let Some(index) = self.messages.iter().position(|id| id == message_id) { if let Some(index) = self.messages.iter().position(|id| id == message_id)
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| { && let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
thread.message(*message_id).map(|message| { thread.message(*message_id).map(|message| {
let mut rendered_message = RenderedMessage { let mut rendered_message = RenderedMessage {
language_registry: self.language_registry.clone(), language_registry: self.language_registry.clone(),
@ -1080,14 +1080,14 @@ impl ActiveThread {
} }
rendered_message rendered_message
}) })
}) { })
self.list_state.splice(index..index + 1, 1); {
self.rendered_messages_by_id self.list_state.splice(index..index + 1, 1);
.insert(*message_id, rendered_message); self.rendered_messages_by_id
self.scroll_to_bottom(cx); .insert(*message_id, rendered_message);
self.save_thread(cx); self.scroll_to_bottom(cx);
cx.notify(); self.save_thread(cx);
} cx.notify();
} }
} }
ThreadEvent::MessageDeleted(message_id) => { ThreadEvent::MessageDeleted(message_id) => {
@ -1268,62 +1268,61 @@ impl ActiveThread {
}) })
}) })
.log_err() .log_err()
&& let Some(pop_up) = screen_window.entity(cx).log_err()
{ {
if let Some(pop_up) = screen_window.entity(cx).log_err() { self.notification_subscriptions
self.notification_subscriptions .entry(screen_window)
.entry(screen_window) .or_insert_with(Vec::new)
.or_insert_with(Vec::new) .push(cx.subscribe_in(&pop_up, window, {
.push(cx.subscribe_in(&pop_up, window, { |this, _, event, window, cx| match event {
|this, _, event, window, cx| match event { AgentNotificationEvent::Accepted => {
AgentNotificationEvent::Accepted => { let handle = window.window_handle();
let handle = window.window_handle(); cx.activate(true);
cx.activate(true);
let workspace_handle = this.workspace.clone(); let workspace_handle = this.workspace.clone();
// If there are multiple Zed windows, activate the correct one. // If there are multiple Zed windows, activate the correct one.
cx.defer(move |cx| { cx.defer(move |cx| {
handle handle
.update(cx, |_view, window, _cx| { .update(cx, |_view, window, _cx| {
window.activate_window(); window.activate_window();
if let Some(workspace) = workspace_handle.upgrade() { if let Some(workspace) = workspace_handle.upgrade() {
workspace.update(_cx, |workspace, cx| { workspace.update(_cx, |workspace, cx| {
workspace.focus_panel::<AgentPanel>(window, cx); workspace.focus_panel::<AgentPanel>(window, cx);
}); });
} }
}) })
.log_err(); .log_err();
}); });
this.dismiss_notifications(cx); this.dismiss_notifications(cx);
}
AgentNotificationEvent::Dismissed => {
this.dismiss_notifications(cx);
}
} }
})); AgentNotificationEvent::Dismissed => {
this.dismiss_notifications(cx);
}
}
}));
self.notifications.push(screen_window); self.notifications.push(screen_window);
// If the user manually refocuses the original window, dismiss the popup. // If the user manually refocuses the original window, dismiss the popup.
self.notification_subscriptions self.notification_subscriptions
.entry(screen_window) .entry(screen_window)
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
.push({ .push({
let pop_up_weak = pop_up.downgrade(); let pop_up_weak = pop_up.downgrade();
cx.observe_window_activation(window, move |_, window, cx| { cx.observe_window_activation(window, move |_, window, cx| {
if window.is_window_active() { if window.is_window_active()
if let Some(pop_up) = pop_up_weak.upgrade() { && let Some(pop_up) = pop_up_weak.upgrade()
pop_up.update(cx, |_, cx| { {
cx.emit(AgentNotificationEvent::Dismissed); pop_up.update(cx, |_, cx| {
}); cx.emit(AgentNotificationEvent::Dismissed);
} });
} }
}) })
}); });
}
} }
} }
@ -1370,12 +1369,12 @@ impl ActiveThread {
editor.focus_handle(cx).focus(window); editor.focus_handle(cx).focus(window);
editor.move_to_end(&editor::actions::MoveToEnd, window, cx); editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
}); });
let buffer_edited_subscription = cx.subscribe(&editor, |this, _, event, cx| match event { let buffer_edited_subscription =
EditorEvent::BufferEdited => { cx.subscribe(&editor, |this, _, event: &EditorEvent, cx| {
this.update_editing_message_token_count(true, cx); if event == &EditorEvent::BufferEdited {
} this.update_editing_message_token_count(true, cx);
_ => {} }
}); });
let context_picker_menu_handle = PopoverMenuHandle::default(); let context_picker_menu_handle = PopoverMenuHandle::default();
let context_strip = cx.new(|cx| { let context_strip = cx.new(|cx| {
@ -2243,9 +2242,7 @@ impl ActiveThread {
let after_editing_message = self let after_editing_message = self
.editing_message .editing_message
.as_ref() .as_ref()
.map_or(false, |(editing_message_id, _)| { .is_some_and(|(editing_message_id, _)| message_id > *editing_message_id);
message_id > *editing_message_id
});
let backdrop = div() let backdrop = div()
.id(("backdrop", ix)) .id(("backdrop", ix))
@ -2265,13 +2262,12 @@ impl ActiveThread {
let mut error = None; let mut error = None;
if let Some(last_restore_checkpoint) = if let Some(last_restore_checkpoint) =
self.thread.read(cx).last_restore_checkpoint() self.thread.read(cx).last_restore_checkpoint()
&& last_restore_checkpoint.message_id() == message_id
{ {
if last_restore_checkpoint.message_id() == message_id { match last_restore_checkpoint {
match last_restore_checkpoint { LastRestoreCheckpoint::Pending { .. } => is_pending = true,
LastRestoreCheckpoint::Pending { .. } => is_pending = true, LastRestoreCheckpoint::Error { error: err, .. } => {
LastRestoreCheckpoint::Error { error: err, .. } => { error = Some(err.clone());
error = Some(err.clone());
}
} }
} }
} }
@ -2469,7 +2465,7 @@ impl ActiveThread {
message_id, message_id,
index, index,
content.clone(), content.clone(),
&scroll_handle, scroll_handle,
Some(index) == pending_thinking_segment_index, Some(index) == pending_thinking_segment_index,
window, window,
cx, cx,
@ -2593,7 +2589,7 @@ impl ActiveThread {
.id(("message-container", ix)) .id(("message-container", ix))
.py_1() .py_1()
.px_2p5() .px_2p5()
.child(Banner::new().severity(ui::Severity::Warning).child(message)) .child(Banner::new().severity(Severity::Warning).child(message))
} }
fn render_message_thinking_segment( fn render_message_thinking_segment(

View file

@ -94,7 +94,7 @@ impl AgentConfiguration {
let mut expanded_provider_configurations = HashMap::default(); let mut expanded_provider_configurations = HashMap::default();
if LanguageModelRegistry::read_global(cx) if LanguageModelRegistry::read_global(cx)
.provider(&ZED_CLOUD_PROVIDER_ID) .provider(&ZED_CLOUD_PROVIDER_ID)
.map_or(false, |cloud_provider| cloud_provider.must_accept_terms(cx)) .is_some_and(|cloud_provider| cloud_provider.must_accept_terms(cx))
{ {
expanded_provider_configurations.insert(ZED_CLOUD_PROVIDER_ID, true); expanded_provider_configurations.insert(ZED_CLOUD_PROVIDER_ID, true);
} }
@ -134,7 +134,11 @@ impl AgentConfiguration {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let configuration_view = provider.configuration_view(window, cx); let configuration_view = provider.configuration_view(
language_model::ConfigurationViewTargetAgent::ZedAgent,
window,
cx,
);
self.configuration_views_by_provider self.configuration_views_by_provider
.insert(provider.id(), configuration_view); .insert(provider.id(), configuration_view);
} }
@ -951,7 +955,7 @@ impl AgentConfiguration {
} }
parent.child(v_flex().py_1p5().px_1().gap_1().children( parent.child(v_flex().py_1p5().px_1().gap_1().children(
tools.into_iter().enumerate().map(|(ix, tool)| { tools.iter().enumerate().map(|(ix, tool)| {
h_flex() h_flex()
.id(("tool-item", ix)) .id(("tool-item", ix))
.px_1() .px_1()

View file

@ -454,7 +454,7 @@ impl Render for AddLlmProviderModal {
this.section( this.section(
Section::new().child( Section::new().child(
Banner::new() Banner::new()
.severity(ui::Severity::Warning) .severity(Severity::Warning)
.child(div().text_xs().child(error)), .child(div().text_xs().child(error)),
), ),
) )

View file

@ -163,10 +163,10 @@ impl ConfigurationSource {
.read(cx) .read(cx)
.text(cx); .text(cx);
let settings = serde_json_lenient::from_str::<serde_json::Value>(&text)?; let settings = serde_json_lenient::from_str::<serde_json::Value>(&text)?;
if let Some(settings_validator) = settings_validator { if let Some(settings_validator) = settings_validator
if let Err(error) = settings_validator.validate(&settings) { && let Err(error) = settings_validator.validate(&settings)
return Err(anyhow::anyhow!(error.to_string())); {
} return Err(anyhow::anyhow!(error.to_string()));
} }
Ok(( Ok((
id.clone(), id.clone(),
@ -487,7 +487,7 @@ impl ConfigureContextServerModal {
} }
fn render_modal_description(&self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement { fn render_modal_description(&self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
const MODAL_DESCRIPTION: &'static str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables."; const MODAL_DESCRIPTION: &str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables.";
if let ConfigurationSource::Extension { if let ConfigurationSource::Extension {
installation_instructions: Some(installation_instructions), installation_instructions: Some(installation_instructions),
@ -716,24 +716,24 @@ fn wait_for_context_server(
project::context_server_store::Event::ServerStatusChanged { server_id, status } => { project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status { match status {
ContextServerStatus::Running => { ContextServerStatus::Running => {
if server_id == &context_server_id { if server_id == &context_server_id
if let Some(tx) = tx.lock().unwrap().take() { && let Some(tx) = tx.lock().unwrap().take()
let _ = tx.send(Ok(())); {
} let _ = tx.send(Ok(()));
} }
} }
ContextServerStatus::Stopped => { ContextServerStatus::Stopped => {
if server_id == &context_server_id { if server_id == &context_server_id
if let Some(tx) = tx.lock().unwrap().take() { && let Some(tx) = tx.lock().unwrap().take()
let _ = tx.send(Err("Context server stopped running".into())); {
} let _ = tx.send(Err("Context server stopped running".into()));
} }
} }
ContextServerStatus::Error(error) => { ContextServerStatus::Error(error) => {
if server_id == &context_server_id { if server_id == &context_server_id
if let Some(tx) = tx.lock().unwrap().take() { && let Some(tx) = tx.lock().unwrap().take()
let _ = tx.send(Err(error.clone())); {
} let _ = tx.send(Err(error.clone()));
} }
} }
_ => {} _ => {}

View file

@ -191,10 +191,10 @@ impl PickerDelegate for ToolPickerDelegate {
BTreeMap::default(); BTreeMap::default();
for item in all_items.iter() { for item in all_items.iter() {
if let PickerItem::Tool { server_id, name } = item.clone() { if let PickerItem::Tool { server_id, name } = item.clone()
if name.contains(&query) { && name.contains(&query)
tools_by_provider.entry(server_id).or_default().push(name); {
} tools_by_provider.entry(server_id).or_default().push(name);
} }
} }

View file

@ -199,24 +199,21 @@ impl AgentDiffPane {
let action_log = thread.action_log(cx).clone(); let action_log = thread.action_log(cx).clone();
let mut this = Self { let mut this = Self {
_subscriptions: [ _subscriptions: vec![
Some( cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
cx.observe_in(&action_log, window, |this, _action_log, window, cx| { this.update_excerpts(window, cx)
this.update_excerpts(window, cx) }),
}),
),
match &thread { match &thread {
AgentDiffThread::Native(thread) => { AgentDiffThread::Native(thread) => cx
Some(cx.subscribe(&thread, |this, _thread, event, cx| { .subscribe(thread, |this, _thread, event, cx| {
this.handle_thread_event(event, cx) this.handle_native_thread_event(event, cx)
})) }),
} AgentDiffThread::AcpThread(thread) => cx
AgentDiffThread::AcpThread(_) => None, .subscribe(thread, |this, _thread, event, cx| {
this.handle_acp_thread_event(event, cx)
}),
}, },
] ],
.into_iter()
.flatten()
.collect(),
title: SharedString::default(), title: SharedString::default(),
multibuffer, multibuffer,
editor, editor,
@ -288,7 +285,7 @@ impl AgentDiffPane {
&& buffer && buffer
.read(cx) .read(cx)
.file() .file()
.map_or(false, |file| file.disk_state() == DiskState::Deleted) .is_some_and(|file| file.disk_state() == DiskState::Deleted)
{ {
editor.fold_buffer(snapshot.text.remote_id(), cx) editor.fold_buffer(snapshot.text.remote_id(), cx)
} }
@ -324,10 +321,15 @@ impl AgentDiffPane {
} }
} }
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) { fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
match event { if let ThreadEvent::SummaryGenerated = event {
ThreadEvent::SummaryGenerated => self.update_title(cx), self.update_title(cx)
_ => {} }
}
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
if let AcpThreadEvent::TitleUpdated = event {
self.update_title(cx)
} }
} }
@ -398,7 +400,7 @@ fn keep_edits_in_selection(
.disjoint_anchor_ranges() .disjoint_anchor_ranges()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
keep_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx) keep_edits_in_ranges(editor, buffer_snapshot, thread, ranges, window, cx)
} }
fn reject_edits_in_selection( fn reject_edits_in_selection(
@ -412,7 +414,7 @@ fn reject_edits_in_selection(
.selections .selections
.disjoint_anchor_ranges() .disjoint_anchor_ranges()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
reject_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx) reject_edits_in_ranges(editor, buffer_snapshot, thread, ranges, window, cx)
} }
fn keep_edits_in_ranges( fn keep_edits_in_ranges(
@ -503,8 +505,7 @@ fn update_editor_selection(
&[last_kept_hunk_end..editor::Anchor::max()], &[last_kept_hunk_end..editor::Anchor::max()],
buffer_snapshot, buffer_snapshot,
) )
.skip(1) .nth(1)
.next()
}) })
.or_else(|| { .or_else(|| {
let first_kept_hunk = diff_hunks.first()?; let first_kept_hunk = diff_hunks.first()?;
@ -1001,7 +1002,7 @@ impl AgentDiffToolbar {
return; return;
}; };
*state = agent_diff.read(cx).editor_state(&editor); *state = agent_diff.read(cx).editor_state(editor);
self.update_location(cx); self.update_location(cx);
cx.notify(); cx.notify();
} }
@ -1044,23 +1045,23 @@ impl ToolbarItemView for AgentDiffToolbar {
return self.location(cx); return self.location(cx);
} }
if let Some(editor) = item.act_as::<Editor>(cx) { if let Some(editor) = item.act_as::<Editor>(cx)
if editor.read(cx).mode().is_full() { && editor.read(cx).mode().is_full()
let agent_diff = AgentDiff::global(cx); {
let agent_diff = AgentDiff::global(cx);
self.active_item = Some(AgentDiffToolbarItem::Editor { self.active_item = Some(AgentDiffToolbarItem::Editor {
editor: editor.downgrade(), editor: editor.downgrade(),
state: agent_diff.read(cx).editor_state(&editor.downgrade()), state: agent_diff.read(cx).editor_state(&editor.downgrade()),
_diff_subscription: cx.observe(&agent_diff, Self::handle_diff_notify), _diff_subscription: cx.observe(&agent_diff, Self::handle_diff_notify),
}); });
return self.location(cx); return self.location(cx);
}
} }
} }
self.active_item = None; self.active_item = None;
return self.location(cx); self.location(cx)
} }
fn pane_focus_update( fn pane_focus_update(
@ -1343,13 +1344,13 @@ impl AgentDiff {
}); });
let thread_subscription = match &thread { let thread_subscription = match &thread {
AgentDiffThread::Native(thread) => cx.subscribe_in(&thread, window, { AgentDiffThread::Native(thread) => cx.subscribe_in(thread, window, {
let workspace = workspace.clone(); let workspace = workspace.clone();
move |this, _thread, event, window, cx| { move |this, _thread, event, window, cx| {
this.handle_native_thread_event(&workspace, event, window, cx) this.handle_native_thread_event(&workspace, event, window, cx)
} }
}), }),
AgentDiffThread::AcpThread(thread) => cx.subscribe_in(&thread, window, { AgentDiffThread::AcpThread(thread) => cx.subscribe_in(thread, window, {
let workspace = workspace.clone(); let workspace = workspace.clone();
move |this, thread, event, window, cx| { move |this, thread, event, window, cx| {
this.handle_acp_thread_event(&workspace, thread, event, window, cx) this.handle_acp_thread_event(&workspace, thread, event, window, cx)
@ -1357,11 +1358,11 @@ impl AgentDiff {
}), }),
}; };
if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) { if let Some(workspace_thread) = self.workspace_threads.get_mut(workspace) {
// replace thread and action log subscription, but keep editors // replace thread and action log subscription, but keep editors
workspace_thread.thread = thread.downgrade(); workspace_thread.thread = thread.downgrade();
workspace_thread._thread_subscriptions = (action_log_subscription, thread_subscription); workspace_thread._thread_subscriptions = (action_log_subscription, thread_subscription);
self.update_reviewing_editors(&workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
return; return;
} }
@ -1506,7 +1507,7 @@ impl AgentDiff {
.read(cx) .read(cx)
.entries() .entries()
.last() .last()
.map_or(false, |entry| entry.diffs().next().is_some()) .is_some_and(|entry| entry.diffs().next().is_some())
{ {
self.update_reviewing_editors(workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
} }
@ -1516,16 +1517,19 @@ impl AgentDiff {
.read(cx) .read(cx)
.entries() .entries()
.get(*ix) .get(*ix)
.map_or(false, |entry| entry.diffs().next().is_some()) .is_some_and(|entry| entry.diffs().next().is_some())
{ {
self.update_reviewing_editors(workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
} }
} }
AcpThreadEvent::EntriesRemoved(_) AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::LoadError(_) => {
| AcpThreadEvent::Stopped self.update_reviewing_editors(workspace, window, cx);
}
AcpThreadEvent::TitleUpdated
| AcpThreadEvent::TokenUsageUpdated
| AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Error | AcpThreadEvent::Retry(_) => {}
| AcpThreadEvent::ServerExited(_) => {}
} }
} }
@ -1536,21 +1540,11 @@ impl AgentDiff {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
match event { if let workspace::Event::ItemAdded { item } = event
workspace::Event::ItemAdded { item } => { && let Some(editor) = item.downcast::<Editor>()
if let Some(editor) = item.downcast::<Editor>() { && let Some(buffer) = Self::full_editor_buffer(editor.read(cx), cx)
if let Some(buffer) = Self::full_editor_buffer(editor.read(cx), cx) { {
self.register_editor( self.register_editor(workspace.downgrade(), buffer.clone(), editor, window, cx);
workspace.downgrade(),
buffer.clone(),
editor,
window,
cx,
);
}
}
}
_ => {}
} }
} }
@ -1677,7 +1671,7 @@ impl AgentDiff {
editor.register_addon(EditorAgentDiffAddon); editor.register_addon(EditorAgentDiffAddon);
}); });
} else { } else {
unaffected.remove(&weak_editor); unaffected.remove(weak_editor);
} }
if new_state == EditorState::Reviewing && previous_state != Some(new_state) { if new_state == EditorState::Reviewing && previous_state != Some(new_state) {
@ -1710,7 +1704,7 @@ impl AgentDiff {
.read_with(cx, |editor, _cx| editor.workspace()) .read_with(cx, |editor, _cx| editor.workspace())
.ok() .ok()
.flatten() .flatten()
.map_or(false, |editor_workspace| { .is_some_and(|editor_workspace| {
editor_workspace.entity_id() == workspace.entity_id() editor_workspace.entity_id() == workspace.entity_id()
}); });
@ -1730,7 +1724,7 @@ impl AgentDiff {
fn editor_state(&self, editor: &WeakEntity<Editor>) -> EditorState { fn editor_state(&self, editor: &WeakEntity<Editor>) -> EditorState {
self.reviewing_editors self.reviewing_editors
.get(&editor) .get(editor)
.cloned() .cloned()
.unwrap_or(EditorState::Idle) .unwrap_or(EditorState::Idle)
} }
@ -1850,26 +1844,26 @@ impl AgentDiff {
let thread = thread.upgrade()?; let thread = thread.upgrade()?;
if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) { if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx)
if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() { && let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton()
let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx); {
let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx);
let mut keys = changed_buffers.keys().cycle(); let mut keys = changed_buffers.keys().cycle();
keys.find(|k| *k == &curr_buffer); keys.find(|k| *k == &curr_buffer);
let next_project_path = keys let next_project_path = keys
.next() .next()
.filter(|k| *k != &curr_buffer) .filter(|k| *k != &curr_buffer)
.and_then(|after| after.read(cx).project_path(cx)); .and_then(|after| after.read(cx).project_path(cx));
if let Some(path) = next_project_path { if let Some(path) = next_project_path {
let task = workspace.open_path(path, None, true, window, cx); let task = workspace.open_path(path, None, true, window, cx);
let task = cx.spawn(async move |_, _cx| task.await.map(|_| ())); let task = cx.spawn(async move |_, _cx| task.await.map(|_| ()));
return Some(task); return Some(task);
}
} }
} }
return Some(Task::ready(Ok(()))); Some(Task::ready(Ok(())))
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -156,11 +156,15 @@ enum ExternalAgent {
} }
impl ExternalAgent { impl ExternalAgent {
pub fn server(&self, fs: Arc<dyn fs::Fs>) -> Rc<dyn agent_servers::AgentServer> { pub fn server(
&self,
fs: Arc<dyn fs::Fs>,
history: Entity<agent2::HistoryStore>,
) -> Rc<dyn agent_servers::AgentServer> {
match self { match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)), ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs, history)),
} }
} }
} }
@ -320,7 +324,7 @@ fn init_language_model_settings(cx: &mut App) {
cx.subscribe( cx.subscribe(
&LanguageModelRegistry::global(cx), &LanguageModelRegistry::global(cx),
|_, event: &language_model::Event, cx| match event { |_, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
update_active_language_model_from_settings(cx); update_active_language_model_from_settings(cx);

View file

@ -352,12 +352,12 @@ impl CodegenAlternative {
event: &multi_buffer::Event, event: &multi_buffer::Event,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event { if let multi_buffer::Event::TransactionUndone { transaction_id } = event
if self.transformation_transaction_id == Some(*transaction_id) { && self.transformation_transaction_id == Some(*transaction_id)
self.transformation_transaction_id = None; {
self.generation = Task::ready(()); self.transformation_transaction_id = None;
cx.emit(CodegenEvent::Undone); self.generation = Task::ready(());
} cx.emit(CodegenEvent::Undone);
} }
} }
@ -388,7 +388,7 @@ impl CodegenAlternative {
} else { } else {
let request = self.build_request(&model, user_prompt, cx)?; let request = self.build_request(&model, user_prompt, cx)?;
cx.spawn(async move |_, cx| { cx.spawn(async move |_, cx| {
Ok(model.stream_completion_text(request.await, &cx).await?) Ok(model.stream_completion_text(request.await, cx).await?)
}) })
.boxed_local() .boxed_local()
}; };
@ -447,7 +447,7 @@ impl CodegenAlternative {
} }
}); });
let temperature = AgentSettings::temperature_for_model(&model, cx); let temperature = AgentSettings::temperature_for_model(model, cx);
Ok(cx.spawn(async move |_cx| { Ok(cx.spawn(async move |_cx| {
let mut request_message = LanguageModelRequestMessage { let mut request_message = LanguageModelRequestMessage {
@ -576,38 +576,34 @@ impl CodegenAlternative {
let mut lines = chunk.split('\n').peekable(); let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() { while let Some(line) = lines.next() {
new_text.push_str(line); new_text.push_str(line);
if line_indent.is_none() { if line_indent.is_none()
if let Some(non_whitespace_ch_ix) = && let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace()) new_text.find(|ch: char| !ch.is_whitespace())
{ {
line_indent = Some(non_whitespace_ch_ix); line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent); base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap(); let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap(); let base_indent = base_indent.unwrap();
let indent_delta = let indent_delta = line_indent as i32 - base_indent as i32;
line_indent as i32 - base_indent as i32; let mut corrected_indent_len = cmp::max(
let mut corrected_indent_len = cmp::max( 0,
0, suggested_line_indent.len as i32 + indent_delta,
suggested_line_indent.len as i32 + indent_delta, )
) as usize;
as usize; if first_line {
if first_line { corrected_indent_len = corrected_indent_len
corrected_indent_len = corrected_indent_len .saturating_sub(selection_start.column as usize);
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
} }
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
} }
if line_indent.is_some() { if line_indent.is_some() {
@ -1028,7 +1024,7 @@ where
chunk.push('\n'); chunk.push('\n');
} }
chunk.push_str(&line); chunk.push_str(line);
} }
consumed += line.len(); consumed += line.len();

View file

@ -385,12 +385,11 @@ impl ContextPicker {
} }
pub fn select_first(&mut self, window: &mut Window, cx: &mut Context<Self>) { pub fn select_first(&mut self, window: &mut Window, cx: &mut Context<Self>) {
match &self.mode { // Other variants already select their first entry on open automatically
ContextPickerState::Default(entity) => entity.update(cx, |entity, cx| { if let ContextPickerState::Default(entity) = &self.mode {
entity.update(cx, |entity, cx| {
entity.select_first(&Default::default(), window, cx) entity.select_first(&Default::default(), window, cx)
}), })
// Other variants already select their first entry on open automatically
_ => {}
} }
} }
@ -610,9 +609,7 @@ pub(crate) fn available_context_picker_entries(
.read(cx) .read(cx)
.active_item(cx) .active_item(cx)
.and_then(|item| item.downcast::<Editor>()) .and_then(|item| item.downcast::<Editor>())
.map_or(false, |editor| { .is_some_and(|editor| editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx)));
editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx))
});
if has_selection { if has_selection {
entries.push(ContextPickerEntry::Action( entries.push(ContextPickerEntry::Action(
ContextPickerAction::AddSelections, ContextPickerAction::AddSelections,
@ -680,7 +677,7 @@ pub(crate) fn recent_context_picker_entries(
.filter(|(_, abs_path)| { .filter(|(_, abs_path)| {
abs_path abs_path
.as_ref() .as_ref()
.map_or(true, |path| !exclude_paths.contains(path.as_path())) .is_none_or(|path| !exclude_paths.contains(path.as_path()))
}) })
.take(4) .take(4)
.filter_map(|(project_path, _)| { .filter_map(|(project_path, _)| {

View file

@ -728,11 +728,11 @@ fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx:
let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId); let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId);
let mut label = CodeLabel::default(); let mut label = CodeLabel::default();
label.push_str(&file_name, None); label.push_str(file_name, None);
label.push_str(" ", None); label.push_str(" ", None);
if let Some(directory) = directory { if let Some(directory) = directory {
label.push_str(&directory, comment_id); label.push_str(directory, comment_id);
} }
label.filter_range = 0..label.text().len(); label.filter_range = 0..label.text().len();
@ -1020,7 +1020,7 @@ impl MentionCompletion {
&& line && line
.chars() .chars()
.nth(last_mention_start - 1) .nth(last_mention_start - 1)
.map_or(false, |c| !c.is_whitespace()) .is_some_and(|c| !c.is_whitespace())
{ {
return None; return None;
} }

View file

@ -226,9 +226,10 @@ impl PickerDelegate for FetchContextPickerDelegate {
_window: &mut Window, _window: &mut Window,
cx: &mut Context<Picker<Self>>, cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> { ) -> Option<Self::ListItem> {
let added = self.context_store.upgrade().map_or(false, |context_store| { let added = self
context_store.read(cx).includes_url(&self.url) .context_store
}); .upgrade()
.is_some_and(|context_store| context_store.read(cx).includes_url(&self.url));
Some( Some(
ListItem::new(ix) ListItem::new(ix)

View file

@ -239,9 +239,7 @@ pub(crate) fn search_files(
PathMatchCandidateSet { PathMatchCandidateSet {
snapshot: worktree.snapshot(), snapshot: worktree.snapshot(),
include_ignored: worktree include_ignored: worktree.root_entry().is_some_and(|entry| entry.is_ignored),
.root_entry()
.map_or(false, |entry| entry.is_ignored),
include_root_name: true, include_root_name: true,
candidates: project::Candidates::Entries, candidates: project::Candidates::Entries,
} }
@ -315,7 +313,7 @@ pub fn render_file_context_entry(
context_store: WeakEntity<ContextStore>, context_store: WeakEntity<ContextStore>,
cx: &App, cx: &App,
) -> Stateful<Div> { ) -> Stateful<Div> {
let (file_name, directory) = extract_file_name_and_directory(&path, path_prefix); let (file_name, directory) = extract_file_name_and_directory(path, path_prefix);
let added = context_store.upgrade().and_then(|context_store| { let added = context_store.upgrade().and_then(|context_store| {
let project_path = ProjectPath { let project_path = ProjectPath {
@ -334,7 +332,7 @@ pub fn render_file_context_entry(
let file_icon = if is_directory { let file_icon = if is_directory {
FileIcons::get_folder_icon(false, cx) FileIcons::get_folder_icon(false, cx)
} else { } else {
FileIcons::get_icon(&path, cx) FileIcons::get_icon(path, cx)
} }
.map(Icon::from_path) .map(Icon::from_path)
.unwrap_or_else(|| Icon::new(IconName::File)); .unwrap_or_else(|| Icon::new(IconName::File));

View file

@ -159,7 +159,7 @@ pub fn render_thread_context_entry(
context_store: WeakEntity<ContextStore>, context_store: WeakEntity<ContextStore>,
cx: &mut App, cx: &mut App,
) -> Div { ) -> Div {
let added = context_store.upgrade().map_or(false, |context_store| { let added = context_store.upgrade().is_some_and(|context_store| {
context_store context_store
.read(cx) .read(cx)
.includes_user_rules(user_rules.prompt_id) .includes_user_rules(user_rules.prompt_id)

View file

@ -289,12 +289,12 @@ pub(crate) fn search_symbols(
.iter() .iter()
.enumerate() .enumerate()
.map(|(id, symbol)| { .map(|(id, symbol)| {
StringMatchCandidate::new(id, &symbol.label.filter_text()) StringMatchCandidate::new(id, symbol.label.filter_text())
}) })
.partition(|candidate| { .partition(|candidate| {
project project
.entry_for_path(&symbols[candidate.id].path, cx) .entry_for_path(&symbols[candidate.id].path, cx)
.map_or(false, |e| !e.is_ignored) .is_some_and(|e| !e.is_ignored)
}) })
}) })
.log_err() .log_err()

View file

@ -167,7 +167,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
return; return;
}; };
let open_thread_task = let open_thread_task =
thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx)); thread_store.update(cx, |this, cx| this.open_thread(id, window, cx));
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let thread = open_thread_task.await?; let thread = open_thread_task.await?;
@ -236,12 +236,10 @@ pub fn render_thread_context_entry(
let is_added = match entry { let is_added = match entry {
ThreadContextEntry::Thread { id, .. } => context_store ThreadContextEntry::Thread { id, .. } => context_store
.upgrade() .upgrade()
.map_or(false, |ctx_store| ctx_store.read(cx).includes_thread(&id)), .is_some_and(|ctx_store| ctx_store.read(cx).includes_thread(id)),
ThreadContextEntry::Context { path, .. } => { ThreadContextEntry::Context { path, .. } => context_store
context_store.upgrade().map_or(false, |ctx_store| { .upgrade()
ctx_store.read(cx).includes_text_thread(path) .is_some_and(|ctx_store| ctx_store.read(cx).includes_text_thread(path)),
})
}
}; };
h_flex() h_flex()
@ -338,7 +336,7 @@ pub(crate) fn search_threads(
let candidates = threads let candidates = threads
.iter() .iter()
.enumerate() .enumerate()
.map(|(id, (_, thread))| StringMatchCandidate::new(id, &thread.title())) .map(|(id, (_, thread))| StringMatchCandidate::new(id, thread.title()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let matches = fuzzy::match_strings( let matches = fuzzy::match_strings(
&candidates, &candidates,

View file

@ -145,7 +145,7 @@ impl ContextStrip {
} }
let file_name = active_buffer.file()?.file_name(cx); let file_name = active_buffer.file()?.file_name(cx);
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx); let icon_path = FileIcons::get_icon(Path::new(&file_name), cx);
Some(SuggestedContext::File { Some(SuggestedContext::File {
name: file_name.to_string_lossy().into_owned().into(), name: file_name.to_string_lossy().into_owned().into(),
buffer: active_buffer_entity.downgrade(), buffer: active_buffer_entity.downgrade(),
@ -368,16 +368,16 @@ impl ContextStrip {
_window: &mut Window, _window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Some(suggested) = self.suggested_context(cx) { if let Some(suggested) = self.suggested_context(cx)
if self.is_suggested_focused(&self.added_contexts(cx)) { && self.is_suggested_focused(&self.added_contexts(cx))
self.add_suggested_context(&suggested, cx); {
} self.add_suggested_context(&suggested, cx);
} }
} }
fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) { fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
self.context_store.update(cx, |context_store, cx| { self.context_store.update(cx, |context_store, cx| {
context_store.add_suggested_context(&suggested, cx) context_store.add_suggested_context(suggested, cx)
}); });
cx.notify(); cx.notify();
} }

View file

@ -182,13 +182,13 @@ impl InlineAssistant {
match event { match event {
workspace::Event::UserSavedItem { item, .. } => { workspace::Event::UserSavedItem { item, .. } => {
// When the user manually saves an editor, automatically accepts all finished transformations. // When the user manually saves an editor, automatically accepts all finished transformations.
if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) { if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx))
if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) { && let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade())
for assist_id in editor_assists.assist_ids.clone() { {
let assist = &self.assists[&assist_id]; for assist_id in editor_assists.assist_ids.clone() {
if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) { let assist = &self.assists[&assist_id];
self.finish_assist(assist_id, false, window, cx) if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
} self.finish_assist(assist_id, false, window, cx)
} }
} }
} }
@ -342,13 +342,11 @@ impl InlineAssistant {
) )
.await .await
.ok(); .ok();
if let Some(answer) = answer { if let Some(answer) = answer
if answer == 0 { && answer == 0
cx.update(|window, cx| { {
window.dispatch_action(Box::new(OpenSettings), cx) cx.update(|window, cx| window.dispatch_action(Box::new(OpenSettings), cx))
})
.ok(); .ok();
}
} }
anyhow::Ok(()) anyhow::Ok(())
}) })
@ -435,11 +433,11 @@ impl InlineAssistant {
} }
} }
if let Some(prev_selection) = selections.last_mut() { if let Some(prev_selection) = selections.last_mut()
if selection.start <= prev_selection.end { && selection.start <= prev_selection.end
prev_selection.end = selection.end; {
continue; prev_selection.end = selection.end;
} continue;
} }
let latest_selection = newest_selection.get_or_insert_with(|| selection.clone()); let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
@ -526,9 +524,9 @@ impl InlineAssistant {
if assist_to_focus.is_none() { if assist_to_focus.is_none() {
let focus_assist = if newest_selection.reversed { let focus_assist = if newest_selection.reversed {
range.start.to_point(&snapshot) == newest_selection.start range.start.to_point(snapshot) == newest_selection.start
} else { } else {
range.end.to_point(&snapshot) == newest_selection.end range.end.to_point(snapshot) == newest_selection.end
}; };
if focus_assist { if focus_assist {
assist_to_focus = Some(assist_id); assist_to_focus = Some(assist_id);
@ -550,7 +548,7 @@ impl InlineAssistant {
let editor_assists = self let editor_assists = self
.assists_by_editor .assists_by_editor
.entry(editor.downgrade()) .entry(editor.downgrade())
.or_insert_with(|| EditorInlineAssists::new(&editor, window, cx)); .or_insert_with(|| EditorInlineAssists::new(editor, window, cx));
let mut assist_group = InlineAssistGroup::new(); let mut assist_group = InlineAssistGroup::new();
for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists { for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
let codegen = prompt_editor.read(cx).codegen().clone(); let codegen = prompt_editor.read(cx).codegen().clone();
@ -649,7 +647,7 @@ impl InlineAssistant {
let editor_assists = self let editor_assists = self
.assists_by_editor .assists_by_editor
.entry(editor.downgrade()) .entry(editor.downgrade())
.or_insert_with(|| EditorInlineAssists::new(&editor, window, cx)); .or_insert_with(|| EditorInlineAssists::new(editor, window, cx));
let mut assist_group = InlineAssistGroup::new(); let mut assist_group = InlineAssistGroup::new();
self.assists.insert( self.assists.insert(
@ -985,14 +983,13 @@ impl InlineAssistant {
EditorEvent::SelectionsChanged { .. } => { EditorEvent::SelectionsChanged { .. } => {
for assist_id in editor_assists.assist_ids.clone() { for assist_id in editor_assists.assist_ids.clone() {
let assist = &self.assists[&assist_id]; let assist = &self.assists[&assist_id];
if let Some(decorations) = assist.decorations.as_ref() { if let Some(decorations) = assist.decorations.as_ref()
if decorations && decorations
.prompt_editor .prompt_editor
.focus_handle(cx) .focus_handle(cx)
.is_focused(window) .is_focused(window)
{ {
return; return;
}
} }
} }
@ -1123,7 +1120,7 @@ impl InlineAssistant {
if editor_assists if editor_assists
.scroll_lock .scroll_lock
.as_ref() .as_ref()
.map_or(false, |lock| lock.assist_id == assist_id) .is_some_and(|lock| lock.assist_id == assist_id)
{ {
editor_assists.scroll_lock = None; editor_assists.scroll_lock = None;
} }
@ -1503,20 +1500,18 @@ impl InlineAssistant {
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) -> Option<InlineAssistTarget> { ) -> Option<InlineAssistTarget> {
if let Some(terminal_panel) = workspace.panel::<TerminalPanel>(cx) { if let Some(terminal_panel) = workspace.panel::<TerminalPanel>(cx)
if terminal_panel && terminal_panel
.read(cx) .read(cx)
.focus_handle(cx) .focus_handle(cx)
.contains_focused(window, cx) .contains_focused(window, cx)
{ && let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| {
if let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| { pane.read(cx)
pane.read(cx) .active_item()
.active_item() .and_then(|t| t.downcast::<TerminalView>())
.and_then(|t| t.downcast::<TerminalView>()) })
}) { {
return Some(InlineAssistTarget::Terminal(terminal_view)); return Some(InlineAssistTarget::Terminal(terminal_view));
}
}
} }
let context_editor = agent_panel let context_editor = agent_panel
@ -1741,22 +1736,20 @@ impl InlineAssist {
return; return;
}; };
if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) { if let CodegenStatus::Error(error) = codegen.read(cx).status(cx)
if assist.decorations.is_none() { && assist.decorations.is_none()
if let Some(workspace) = assist.workspace.upgrade() { && let Some(workspace) = assist.workspace.upgrade()
let error = format!("Inline assistant error: {}", error); {
workspace.update(cx, |workspace, cx| { let error = format!("Inline assistant error: {}", error);
struct InlineAssistantError; workspace.update(cx, |workspace, cx| {
struct InlineAssistantError;
let id = let id = NotificationId::composite::<InlineAssistantError>(
NotificationId::composite::<InlineAssistantError>( assist_id.0,
assist_id.0, );
);
workspace.show_toast(Toast::new(id, error), cx); workspace.show_toast(Toast::new(id, error), cx);
}) })
}
}
} }
if assist.decorations.is_none() { if assist.decorations.is_none() {
@ -1821,18 +1814,18 @@ impl CodeActionProvider for AssistantCodeActionProvider {
has_diagnostics = true; has_diagnostics = true;
} }
if has_diagnostics { if has_diagnostics {
if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) { if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None)
if let Some(symbol) = symbols_containing_start.last() { && let Some(symbol) = symbols_containing_start.last()
range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); {
range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
} range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
} }
if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) { if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None)
if let Some(symbol) = symbols_containing_end.last() { && let Some(symbol) = symbols_containing_end.last()
range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); {
range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
} range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
} }
Task::ready(Ok(vec![CodeAction { Task::ready(Ok(vec![CodeAction {

View file

@ -75,7 +75,7 @@ impl<T: 'static> Render for PromptEditor<T> {
let codegen = codegen.read(cx); let codegen = codegen.read(cx);
if codegen.alternative_count(cx) > 1 { if codegen.alternative_count(cx) > 1 {
buttons.push(self.render_cycle_controls(&codegen, cx)); buttons.push(self.render_cycle_controls(codegen, cx));
} }
let editor_margins = editor_margins.lock(); let editor_margins = editor_margins.lock();
@ -345,7 +345,7 @@ impl<T: 'static> PromptEditor<T> {
let prompt = self.editor.read(cx).text(cx); let prompt = self.editor.read(cx).text(cx);
if self if self
.prompt_history_ix .prompt_history_ix
.map_or(true, |ix| self.prompt_history[ix] != prompt) .is_none_or(|ix| self.prompt_history[ix] != prompt)
{ {
self.prompt_history_ix.take(); self.prompt_history_ix.take();
self.pending_prompt = prompt; self.pending_prompt = prompt;

View file

@ -104,7 +104,7 @@ impl LanguageModelPickerDelegate {
window, window,
|picker, _, event, window, cx| { |picker, _, event, window, cx| {
match event { match event {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
let query = picker.query(cx); let query = picker.query(cx);
@ -296,7 +296,7 @@ impl ModelMatcher {
pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> { pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
let mut matches = self.bg_executor.block(match_strings( let mut matches = self.bg_executor.block(match_strings(
&self.candidates, &self.candidates,
&query, query,
false, false,
true, true,
100, 100,

View file

@ -117,7 +117,7 @@ pub(crate) fn create_editor(
let mut editor = Editor::new( let mut editor = Editor::new(
editor::EditorMode::AutoHeight { editor::EditorMode::AutoHeight {
min_lines, min_lines,
max_lines: max_lines, max_lines,
}, },
buffer, buffer,
None, None,
@ -156,7 +156,7 @@ impl ProfileProvider for Entity<Thread> {
fn profiles_supported(&self, cx: &App) -> bool { fn profiles_supported(&self, cx: &App) -> bool {
self.read(cx) self.read(cx)
.configured_model() .configured_model()
.map_or(false, |model| model.model.supports_tools()) .is_some_and(|model| model.model.supports_tools())
} }
fn profile_id(&self, cx: &App) -> AgentProfileId { fn profile_id(&self, cx: &App) -> AgentProfileId {
@ -215,9 +215,10 @@ impl MessageEditor {
let subscriptions = vec![ let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
cx.subscribe(&editor, |this, _, event, cx| match event { cx.subscribe(&editor, |this, _, event: &EditorEvent, cx| {
EditorEvent::BufferEdited => this.handle_message_changed(cx), if event == &EditorEvent::BufferEdited {
_ => {} this.handle_message_changed(cx)
}
}), }),
cx.observe(&context_store, |this, _, cx| { cx.observe(&context_store, |this, _, cx| {
// When context changes, reload it for token counting. // When context changes, reload it for token counting.
@ -690,11 +691,7 @@ impl MessageEditor {
.as_ref() .as_ref()
.map(|model| { .map(|model| {
self.incompatible_tools_state.update(cx, |state, cx| { self.incompatible_tools_state.update(cx, |state, cx| {
state state.incompatible_tools(&model.model, cx).to_vec()
.incompatible_tools(&model.model, cx)
.iter()
.cloned()
.collect::<Vec<_>>()
}) })
}) })
.unwrap_or_default(); .unwrap_or_default();
@ -1136,7 +1133,7 @@ impl MessageEditor {
) )
.when(is_edit_changes_expanded, |parent| { .when(is_edit_changes_expanded, |parent| {
parent.child( parent.child(
v_flex().children(changed_buffers.into_iter().enumerate().flat_map( v_flex().children(changed_buffers.iter().enumerate().flat_map(
|(index, (buffer, _diff))| { |(index, (buffer, _diff))| {
let file = buffer.read(cx).file()?; let file = buffer.read(cx).file()?;
let path = file.path(); let path = file.path();
@ -1166,7 +1163,7 @@ impl MessageEditor {
.buffer_font(cx) .buffer_font(cx)
}); });
let file_icon = FileIcons::get_icon(&path, cx) let file_icon = FileIcons::get_icon(path, cx)
.map(Icon::from_path) .map(Icon::from_path)
.map(|icon| icon.color(Color::Muted).size(IconSize::Small)) .map(|icon| icon.color(Color::Muted).size(IconSize::Small))
.unwrap_or_else(|| { .unwrap_or_else(|| {
@ -1293,7 +1290,7 @@ impl MessageEditor {
self.thread self.thread
.read(cx) .read(cx)
.configured_model() .configured_model()
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID) .is_some_and(|model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
} }
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> { fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {
@ -1323,14 +1320,10 @@ impl MessageEditor {
token_usage_ratio: TokenUsageRatio, token_usage_ratio: TokenUsageRatio,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Option<Div> { ) -> Option<Div> {
let icon = if token_usage_ratio == TokenUsageRatio::Exceeded { let (icon, severity) = if token_usage_ratio == TokenUsageRatio::Exceeded {
Icon::new(IconName::Close) (IconName::Close, Severity::Error)
.color(Color::Error)
.size(IconSize::XSmall)
} else { } else {
Icon::new(IconName::Warning) (IconName::Warning, Severity::Warning)
.color(Color::Warning)
.size(IconSize::XSmall)
}; };
let title = if token_usage_ratio == TokenUsageRatio::Exceeded { let title = if token_usage_ratio == TokenUsageRatio::Exceeded {
@ -1345,30 +1338,34 @@ impl MessageEditor {
"To continue, start a new thread from a summary." "To continue, start a new thread from a summary."
}; };
let mut callout = Callout::new() let callout = Callout::new()
.line_height(line_height) .line_height(line_height)
.severity(severity)
.icon(icon) .icon(icon)
.title(title) .title(title)
.description(description) .description(description)
.primary_action( .actions_slot(
Button::new("start-new-thread", "Start New Thread") h_flex()
.label_size(LabelSize::Small) .gap_0p5()
.on_click(cx.listener(|this, _, window, cx| { .when(self.is_using_zed_provider(cx), |this| {
let from_thread_id = Some(this.thread.read(cx).id().clone()); this.child(
window.dispatch_action(Box::new(NewThread { from_thread_id }), cx); IconButton::new("burn-mode-callout", IconName::ZedBurnMode)
})), .icon_size(IconSize::XSmall)
.on_click(cx.listener(|this, _event, window, cx| {
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
})),
)
})
.child(
Button::new("start-new-thread", "Start New Thread")
.label_size(LabelSize::Small)
.on_click(cx.listener(|this, _, window, cx| {
let from_thread_id = Some(this.thread.read(cx).id().clone());
window.dispatch_action(Box::new(NewThread { from_thread_id }), cx);
})),
),
); );
if self.is_using_zed_provider(cx) {
callout = callout.secondary_action(
IconButton::new("burn-mode-callout", IconName::ZedBurnMode)
.icon_size(IconSize::XSmall)
.on_click(cx.listener(|this, _event, window, cx| {
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
})),
);
}
Some( Some(
div() div()
.border_t_1() .border_t_1()
@ -1446,7 +1443,7 @@ impl MessageEditor {
let message_text = editor.read(cx).text(cx); let message_text = editor.read(cx).text(cx);
if message_text.is_empty() if message_text.is_empty()
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty()) && loaded_context.is_none_or(|loaded_context| loaded_context.is_empty())
{ {
return None; return None;
} }
@ -1559,9 +1556,8 @@ impl ContextCreasesAddon {
cx: &mut Context<Editor>, cx: &mut Context<Editor>,
) { ) {
self.creases.entry(key).or_default().extend(creases); self.creases.entry(key).or_default().extend(creases);
self._subscription = Some(cx.subscribe( self._subscription = Some(
&context_store, cx.subscribe(context_store, |editor, _, event, cx| match event {
|editor, _, event, cx| match event {
ContextStoreEvent::ContextRemoved(key) => { ContextStoreEvent::ContextRemoved(key) => {
let Some(this) = editor.addon_mut::<Self>() else { let Some(this) = editor.addon_mut::<Self>() else {
return; return;
@ -1581,8 +1577,8 @@ impl ContextCreasesAddon {
editor.edit(ranges.into_iter().zip(replacement_texts), cx); editor.edit(ranges.into_iter().zip(replacement_texts), cx);
cx.notify(); cx.notify();
} }
}, }),
)) )
} }
pub fn into_inner(self) -> HashMap<AgentContextKey, Vec<(CreaseId, SharedString)>> { pub fn into_inner(self) -> HashMap<AgentContextKey, Vec<(CreaseId, SharedString)>> {
@ -1610,7 +1606,8 @@ pub fn extract_message_creases(
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
// Filter the addon's list of creases based on what the editor reports, // Filter the addon's list of creases based on what the editor reports,
// since the addon might have removed creases in it. // since the addon might have removed creases in it.
let creases = editor.display_map.update(cx, |display_map, cx| {
editor.display_map.update(cx, |display_map, cx| {
display_map display_map
.snapshot(cx) .snapshot(cx)
.crease_snapshot .crease_snapshot
@ -1634,8 +1631,7 @@ pub fn extract_message_creases(
} }
}) })
.collect() .collect()
}); })
creases
} }
impl EventEmitter<MessageEditorEvent> for MessageEditor {} impl EventEmitter<MessageEditorEvent> for MessageEditor {}

View file

@ -140,12 +140,10 @@ impl PickerDelegate for SlashCommandDelegate {
); );
ret.push(index - 1); ret.push(index - 1);
} }
} else { } else if let SlashCommandEntry::Advert { .. } = command {
if let SlashCommandEntry::Advert { .. } = command { previous_is_advert = true;
previous_is_advert = true; if index != 0 {
if index != 0 { ret.push(index - 1);
ret.push(index - 1);
}
} }
} }
} }
@ -214,7 +212,7 @@ impl PickerDelegate for SlashCommandDelegate {
let mut label = format!("{}", info.name); let mut label = format!("{}", info.name);
if let Some(args) = info.args.as_ref().filter(|_| selected) if let Some(args) = info.args.as_ref().filter(|_| selected)
{ {
label.push_str(&args); label.push_str(args);
} }
Label::new(label) Label::new(label)
.single_line() .single_line()
@ -329,9 +327,7 @@ where
}; };
let picker_view = cx.new(|cx| { let picker_view = cx.new(|cx| {
let picker = Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()));
picker
}); });
let handle = self let handle = self

View file

@ -48,7 +48,7 @@ impl TerminalCodegen {
let prompt = prompt_task.await; let prompt = prompt_task.await;
let model_telemetry_id = model.telemetry_id(); let model_telemetry_id = model.telemetry_id();
let model_provider_id = model.provider_id(); let model_provider_id = model.provider_id();
let response = model.stream_completion_text(prompt, &cx).await; let response = model.stream_completion_text(prompt, cx).await;
let generate = async { let generate = async {
let message_id = response let message_id = response
.as_ref() .as_ref()

View file

@ -388,20 +388,20 @@ impl TerminalInlineAssistant {
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut App,
) { ) {
if let Some(assist) = self.assists.get_mut(&assist_id) { if let Some(assist) = self.assists.get_mut(&assist_id)
if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() { && let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned()
assist {
.terminal assist
.update(cx, |terminal, cx| { .terminal
terminal.clear_block_below_cursor(cx); .update(cx, |terminal, cx| {
let block = terminal_view::BlockProperties { terminal.clear_block_below_cursor(cx);
height, let block = terminal_view::BlockProperties {
render: Box::new(move |_| prompt_editor.clone().into_any_element()), height,
}; render: Box::new(move |_| prompt_editor.clone().into_any_element()),
terminal.set_block_below_cursor(block, window, cx); };
}) terminal.set_block_below_cursor(block, window, cx);
.log_err(); })
} .log_err();
} }
} }
} }
@ -450,23 +450,20 @@ impl TerminalInlineAssist {
return; return;
}; };
if let CodegenStatus::Error(error) = &codegen.read(cx).status { if let CodegenStatus::Error(error) = &codegen.read(cx).status
if assist.prompt_editor.is_none() { && assist.prompt_editor.is_none()
if let Some(workspace) = assist.workspace.upgrade() { && let Some(workspace) = assist.workspace.upgrade()
let error = {
format!("Terminal inline assistant error: {}", error); let error = format!("Terminal inline assistant error: {}", error);
workspace.update(cx, |workspace, cx| { workspace.update(cx, |workspace, cx| {
struct InlineAssistantError; struct InlineAssistantError;
let id = let id = NotificationId::composite::<InlineAssistantError>(
NotificationId::composite::<InlineAssistantError>( assist_id.0,
assist_id.0, );
);
workspace.show_toast(Toast::new(id, error), cx); workspace.show_toast(Toast::new(id, error), cx);
}) })
}
}
} }
if assist.prompt_editor.is_none() { if assist.prompt_editor.is_none() {

View file

@ -373,7 +373,7 @@ impl TextThreadEditor {
.map(|default| default.provider); .map(|default| default.provider);
if provider if provider
.as_ref() .as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx)) .is_some_and(|provider| provider.must_accept_terms(cx))
{ {
self.show_accept_terms = true; self.show_accept_terms = true;
cx.notify(); cx.notify();
@ -457,7 +457,7 @@ impl TextThreadEditor {
|| snapshot || snapshot
.chars_at(newest_cursor) .chars_at(newest_cursor)
.next() .next()
.map_or(false, |ch| ch != '\n') .is_some_and(|ch| ch != '\n')
{ {
editor.move_to_end_of_line( editor.move_to_end_of_line(
&MoveToEndOfLine { &MoveToEndOfLine {
@ -540,7 +540,7 @@ impl TextThreadEditor {
let context = self.context.read(cx); let context = self.context.read(cx);
let sections = context let sections = context
.slash_command_output_sections() .slash_command_output_sections()
.into_iter() .iter()
.filter(|section| section.is_valid(context.buffer().read(cx))) .filter(|section| section.is_valid(context.buffer().read(cx)))
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -745,32 +745,27 @@ impl TextThreadEditor {
) { ) {
if let Some(invoked_slash_command) = if let Some(invoked_slash_command) =
self.context.read(cx).invoked_slash_command(&command_id) self.context.read(cx).invoked_slash_command(&command_id)
&& let InvokedSlashCommandStatus::Finished = invoked_slash_command.status
{ {
if let InvokedSlashCommandStatus::Finished = invoked_slash_command.status { let run_commands_in_ranges = invoked_slash_command.run_commands_in_ranges.clone();
let run_commands_in_ranges = invoked_slash_command for range in run_commands_in_ranges {
.run_commands_in_ranges let commands = self.context.update(cx, |context, cx| {
.iter() context.reparse(cx);
.cloned() context
.collect::<Vec<_>>(); .pending_commands_for_range(range.clone(), cx)
for range in run_commands_in_ranges { .to_vec()
let commands = self.context.update(cx, |context, cx| { });
context.reparse(cx);
context
.pending_commands_for_range(range.clone(), cx)
.to_vec()
});
for command in commands { for command in commands {
self.run_command( self.run_command(
command.source_range, command.source_range,
&command.name, &command.name,
&command.arguments, &command.arguments,
false, false,
self.workspace.clone(), self.workspace.clone(),
window, window,
cx, cx,
); );
}
} }
} }
} }
@ -1242,7 +1237,7 @@ impl TextThreadEditor {
let mut new_blocks = vec![]; let mut new_blocks = vec![];
let mut block_index_to_message = vec![]; let mut block_index_to_message = vec![];
for message in self.context.read(cx).messages(cx) { for message in self.context.read(cx).messages(cx) {
if let Some(_) = blocks_to_remove.remove(&message.id) { if blocks_to_remove.remove(&message.id).is_some() {
// This is an old message that we might modify. // This is an old message that we might modify.
let Some((meta, block_id)) = old_blocks.get_mut(&message.id) else { let Some((meta, block_id)) = old_blocks.get_mut(&message.id) else {
debug_assert!( debug_assert!(
@ -1280,7 +1275,7 @@ impl TextThreadEditor {
context_editor_view: &Entity<TextThreadEditor>, context_editor_view: &Entity<TextThreadEditor>,
cx: &mut Context<Workspace>, cx: &mut Context<Workspace>,
) -> Option<(String, bool)> { ) -> Option<(String, bool)> {
const CODE_FENCE_DELIMITER: &'static str = "```"; const CODE_FENCE_DELIMITER: &str = "```";
let context_editor = context_editor_view.read(cx).editor.clone(); let context_editor = context_editor_view.read(cx).editor.clone();
context_editor.update(cx, |context_editor, cx| { context_editor.update(cx, |context_editor, cx| {
@ -2166,8 +2161,8 @@ impl TextThreadEditor {
/// Returns the contents of the *outermost* fenced code block that contains the given offset. /// Returns the contents of the *outermost* fenced code block that contains the given offset.
fn find_surrounding_code_block(snapshot: &BufferSnapshot, offset: usize) -> Option<Range<usize>> { fn find_surrounding_code_block(snapshot: &BufferSnapshot, offset: usize) -> Option<Range<usize>> {
const CODE_BLOCK_NODE: &'static str = "fenced_code_block"; const CODE_BLOCK_NODE: &str = "fenced_code_block";
const CODE_BLOCK_CONTENT: &'static str = "code_fence_content"; const CODE_BLOCK_CONTENT: &str = "code_fence_content";
let layer = snapshot.syntax_layers().next()?; let layer = snapshot.syntax_layers().next()?;
@ -3134,7 +3129,7 @@ mod tests {
let context_editor = window let context_editor = window
.update(&mut cx, |_, window, cx| { .update(&mut cx, |_, window, cx| {
cx.new(|cx| { cx.new(|cx| {
let editor = TextThreadEditor::for_context( TextThreadEditor::for_context(
context.clone(), context.clone(),
fs, fs,
workspace.downgrade(), workspace.downgrade(),
@ -3142,8 +3137,7 @@ mod tests {
None, None,
window, window,
cx, cx,
); )
editor
}) })
}) })
.unwrap(); .unwrap();

View file

@ -161,14 +161,13 @@ impl ThreadHistory {
this.all_entries.len().saturating_sub(1), this.all_entries.len().saturating_sub(1),
cx, cx,
); );
} else if let Some(prev_id) = previously_selected_entry { } else if let Some(prev_id) = previously_selected_entry
if let Some(new_ix) = this && let Some(new_ix) = this
.all_entries .all_entries
.iter() .iter()
.position(|probe| probe.id() == prev_id) .position(|probe| probe.id() == prev_id)
{ {
this.set_selected_entry_index(new_ix, cx); this.set_selected_entry_index(new_ix, cx);
}
} }
} }
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => { SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {

View file

@ -14,13 +14,11 @@ pub struct IncompatibleToolsState {
impl IncompatibleToolsState { impl IncompatibleToolsState {
pub fn new(thread: Entity<Thread>, cx: &mut Context<Self>) -> Self { pub fn new(thread: Entity<Thread>, cx: &mut Context<Self>) -> Self {
let _tool_working_set_subscription = let _tool_working_set_subscription = cx.subscribe(&thread, |this, _, event, _| {
cx.subscribe(&thread, |this, _, event, _| match event { if let ThreadEvent::ProfileChanged = event {
ThreadEvent::ProfileChanged => { this.cache.clear();
this.cache.clear(); }
} });
_ => {}
});
Self { Self {
cache: HashMap::default(), cache: HashMap::default(),

View file

@ -353,7 +353,7 @@ impl AddedContext {
name, name,
parent, parent,
tooltip: Some(full_path_string), tooltip: Some(full_path_string),
icon_path: FileIcons::get_icon(&full_path, cx), icon_path: FileIcons::get_icon(full_path, cx),
status: ContextStatus::Ready, status: ContextStatus::Ready,
render_hover: None, render_hover: None,
handle: AgentContextHandle::File(handle), handle: AgentContextHandle::File(handle),
@ -615,7 +615,7 @@ impl AddedContext {
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into(); let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
let (name, parent) = let (name, parent) =
extract_file_name_and_directory_from_full_path(full_path, &full_path_string); extract_file_name_and_directory_from_full_path(full_path, &full_path_string);
let icon_path = FileIcons::get_icon(&full_path, cx); let icon_path = FileIcons::get_icon(full_path, cx);
(name, parent, icon_path) (name, parent, icon_path)
} else { } else {
("Image".into(), None, None) ("Image".into(), None, None)
@ -706,7 +706,7 @@ impl ContextFileExcerpt {
.and_then(|p| p.file_name()) .and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into()); .map(|n| n.to_string_lossy().into_owned().into());
let icon_path = FileIcons::get_icon(&full_path, cx); let icon_path = FileIcons::get_icon(full_path, cx);
ContextFileExcerpt { ContextFileExcerpt {
file_name_and_range: file_name_and_range.into(), file_name_and_range: file_name_and_range.into(),

View file

@ -80,14 +80,10 @@ impl RenderOnce for UsageCallout {
} }
}; };
let icon = if is_limit_reached { let (icon, severity) = if is_limit_reached {
Icon::new(IconName::Close) (IconName::Close, Severity::Error)
.color(Color::Error)
.size(IconSize::XSmall)
} else { } else {
Icon::new(IconName::Warning) (IconName::Warning, Severity::Warning)
.color(Color::Warning)
.size(IconSize::XSmall)
}; };
div() div()
@ -95,10 +91,12 @@ impl RenderOnce for UsageCallout {
.border_color(cx.theme().colors().border) .border_color(cx.theme().colors().border)
.child( .child(
Callout::new() Callout::new()
.icon(icon)
.severity(severity)
.icon(icon) .icon(icon)
.title(title) .title(title)
.description(message) .description(message)
.primary_action( .actions_slot(
Button::new("upgrade", button_text) Button::new("upgrade", button_text)
.label_size(LabelSize::Small) .label_size(LabelSize::Small)
.on_click(move |_, _, cx| { .on_click(move |_, _, cx| {

View file

@ -11,7 +11,7 @@ impl ApiKeysWithProviders {
cx.subscribe( cx.subscribe(
&LanguageModelRegistry::global(cx), &LanguageModelRegistry::global(cx),
|this: &mut Self, _registry, event: &language_model::Event, cx| match event { |this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_configured_providers(cx) this.configured_providers = Self::compute_configured_providers(cx)

View file

@ -25,7 +25,7 @@ impl AgentPanelOnboarding {
cx.subscribe( cx.subscribe(
&LanguageModelRegistry::global(cx), &LanguageModelRegistry::global(cx),
|this: &mut Self, _registry, event: &language_model::Event, cx| match event { |this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_available_providers(cx) this.configured_providers = Self::compute_available_providers(cx)

View file

@ -332,17 +332,25 @@ impl ZedAiOnboarding {
.mb_2(), .mb_2(),
) )
.child(plan_definitions.pro_plan(false)) .child(plan_definitions.pro_plan(false))
.child( .when_some(
Button::new("pro", "Continue with Zed Pro") self.dismiss_onboarding.as_ref(),
.full_width() |this, dismiss_callback| {
.style(ButtonStyle::Outlined) let callback = dismiss_callback.clone();
.on_click({ this.child(
let callback = self.continue_with_zed_ai.clone(); h_flex().absolute().top_0().right_0().child(
move |_, window, cx| { IconButton::new("dismiss_onboarding", IconName::Close)
telemetry::event!("Banner Dismissed", source = "AI Onboarding"); .icon_size(IconSize::Small)
callback(window, cx) .tooltip(Tooltip::text("Dismiss"))
} .on_click(move |_, window, cx| {
}), telemetry::event!(
"Banner Dismissed",
source = "AI Onboarding",
);
callback(window, cx)
}),
),
)
},
) )
.into_any_element() .into_any_element()
} }

View file

@ -17,6 +17,6 @@ impl RenderOnce for YoungAccountBanner {
div() div()
.max_w_full() .max_w_full()
.my_1() .my_1()
.child(Banner::new().severity(ui::Severity::Warning).child(label)) .child(Banner::new().severity(Severity::Warning).child(label))
} }
} }

View file

@ -177,11 +177,11 @@ impl AskPassSession {
_ = askpass_opened_rx.fuse() => { _ = askpass_opened_rx.fuse() => {
// Note: this await can only resolve after we are dropped. // Note: this await can only resolve after we are dropped.
askpass_kill_master_rx.await.ok(); askpass_kill_master_rx.await.ok();
return AskPassResult::CancelledByUser AskPassResult::CancelledByUser
} }
_ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => { _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
return AskPassResult::Timedout AskPassResult::Timedout
} }
} }
} }
@ -215,7 +215,7 @@ pub fn main(socket: &str) {
} }
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
while buffer.last().map_or(false, |&b| b == b'\n' || b == b'\r') { while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
buffer.pop(); buffer.pop();
} }
if buffer.last() != Some(&b'\0') { if buffer.last() != Some(&b'\0') {

View file

@ -590,17 +590,16 @@ impl From<&Message> for MessageMetadata {
impl MessageMetadata { impl MessageMetadata {
pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range<usize>) -> bool { pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range<usize>) -> bool {
let result = match &self.cache { match &self.cache {
Some(MessageCacheMetadata { cached_at, .. }) => !buffer.has_edits_since_in_range( Some(MessageCacheMetadata { cached_at, .. }) => !buffer.has_edits_since_in_range(
&cached_at, cached_at,
Range { Range {
start: buffer.anchor_at(range.start, Bias::Right), start: buffer.anchor_at(range.start, Bias::Right),
end: buffer.anchor_at(range.end, Bias::Left), end: buffer.anchor_at(range.end, Bias::Left),
}, },
), ),
_ => false, _ => false,
}; }
result
} }
} }
@ -1023,9 +1022,11 @@ impl AssistantContext {
summary: new_summary, summary: new_summary,
.. ..
} => { } => {
if self.summary.timestamp().map_or(true, |current_timestamp| { if self
new_summary.timestamp > current_timestamp .summary
}) { .timestamp()
.is_none_or(|current_timestamp| new_summary.timestamp > current_timestamp)
{
self.summary = ContextSummary::Content(new_summary); self.summary = ContextSummary::Content(new_summary);
summary_generated = true; summary_generated = true;
} }
@ -1076,20 +1077,20 @@ impl AssistantContext {
timestamp, timestamp,
.. ..
} => { } => {
if let Some(slash_command) = self.invoked_slash_commands.get_mut(&id) { if let Some(slash_command) = self.invoked_slash_commands.get_mut(&id)
if timestamp > slash_command.timestamp { && timestamp > slash_command.timestamp
slash_command.timestamp = timestamp; {
match error_message { slash_command.timestamp = timestamp;
Some(message) => { match error_message {
slash_command.status = Some(message) => {
InvokedSlashCommandStatus::Error(message.into()); slash_command.status =
} InvokedSlashCommandStatus::Error(message.into());
None => { }
slash_command.status = InvokedSlashCommandStatus::Finished; None => {
} slash_command.status = InvokedSlashCommandStatus::Finished;
} }
cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id });
} }
cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id });
} }
} }
ContextOperation::BufferOperation(_) => unreachable!(), ContextOperation::BufferOperation(_) => unreachable!(),
@ -1339,7 +1340,7 @@ impl AssistantContext {
let is_invalid = self let is_invalid = self
.messages_metadata .messages_metadata
.get(&message_id) .get(&message_id)
.map_or(true, |metadata| { .is_none_or(|metadata| {
!metadata.is_cache_valid(&buffer, &message.offset_range) !metadata.is_cache_valid(&buffer, &message.offset_range)
|| *encountered_invalid || *encountered_invalid
}); });
@ -1368,10 +1369,10 @@ impl AssistantContext {
continue; continue;
} }
if let Some(last_anchor) = last_anchor { if let Some(last_anchor) = last_anchor
if message.id == last_anchor { && message.id == last_anchor
hit_last_anchor = true; {
} hit_last_anchor = true;
} }
new_anchor_needs_caching = new_anchor_needs_caching new_anchor_needs_caching = new_anchor_needs_caching
@ -1406,14 +1407,14 @@ impl AssistantContext {
if !self.pending_completions.is_empty() { if !self.pending_completions.is_empty() {
return; return;
} }
if let Some(cache_configuration) = cache_configuration { if let Some(cache_configuration) = cache_configuration
if !cache_configuration.should_speculate { && !cache_configuration.should_speculate
return; {
} return;
} }
let request = { let request = {
let mut req = self.to_completion_request(Some(&model), cx); let mut req = self.to_completion_request(Some(model), cx);
// Skip the last message because it's likely to change and // Skip the last message because it's likely to change and
// therefore would be a waste to cache. // therefore would be a waste to cache.
req.messages.pop(); req.messages.pop();
@ -1428,7 +1429,7 @@ impl AssistantContext {
let model = Arc::clone(model); let model = Arc::clone(model);
self.pending_cache_warming_task = cx.spawn(async move |this, cx| { self.pending_cache_warming_task = cx.spawn(async move |this, cx| {
async move { async move {
match model.stream_completion(request, &cx).await { match model.stream_completion(request, cx).await {
Ok(mut stream) => { Ok(mut stream) => {
stream.next().await; stream.next().await;
log::info!("Cache warming completed successfully"); log::info!("Cache warming completed successfully");
@ -1552,25 +1553,24 @@ impl AssistantContext {
}) })
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
.collect::<SmallVec<_>>(); .collect::<SmallVec<_>>();
if let Some(command) = self.slash_commands.command(name, cx) { if let Some(command) = self.slash_commands.command(name, cx)
if !command.requires_argument() || !arguments.is_empty() { && (!command.requires_argument() || !arguments.is_empty())
let start_ix = offset + command_line.name.start - 1; {
let end_ix = offset let start_ix = offset + command_line.name.start - 1;
+ command_line let end_ix = offset
.arguments + command_line
.last() .arguments
.map_or(command_line.name.end, |argument| argument.end); .last()
let source_range = .map_or(command_line.name.end, |argument| argument.end);
buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix); let source_range = buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
let pending_command = ParsedSlashCommand { let pending_command = ParsedSlashCommand {
name: name.to_string(), name: name.to_string(),
arguments, arguments,
source_range, source_range,
status: PendingSlashCommandStatus::Idle, status: PendingSlashCommandStatus::Idle,
}; };
updated.push(pending_command.clone()); updated.push(pending_command.clone());
new_commands.push(pending_command); new_commands.push(pending_command);
}
} }
} }
@ -1661,12 +1661,12 @@ impl AssistantContext {
) -> Range<usize> { ) -> Range<usize> {
let buffer = self.buffer.read(cx); let buffer = self.buffer.read(cx);
let start_ix = match all_annotations let start_ix = match all_annotations
.binary_search_by(|probe| probe.range().end.cmp(&range.start, &buffer)) .binary_search_by(|probe| probe.range().end.cmp(&range.start, buffer))
{ {
Ok(ix) | Err(ix) => ix, Ok(ix) | Err(ix) => ix,
}; };
let end_ix = match all_annotations let end_ix = match all_annotations
.binary_search_by(|probe| probe.range().start.cmp(&range.end, &buffer)) .binary_search_by(|probe| probe.range().start.cmp(&range.end, buffer))
{ {
Ok(ix) => ix + 1, Ok(ix) => ix + 1,
Err(ix) => ix, Err(ix) => ix,
@ -1799,14 +1799,13 @@ impl AssistantContext {
}); });
let end = this.buffer.read(cx).anchor_before(insert_position); let end = this.buffer.read(cx).anchor_before(insert_position);
if run_commands_in_text { if run_commands_in_text
if let Some(invoked_slash_command) = && let Some(invoked_slash_command) =
this.invoked_slash_commands.get_mut(&command_id) this.invoked_slash_commands.get_mut(&command_id)
{ {
invoked_slash_command invoked_slash_command
.run_commands_in_ranges .run_commands_in_ranges
.push(start..end); .push(start..end);
}
} }
} }
SlashCommandEvent::EndSection => { SlashCommandEvent::EndSection => {
@ -1862,7 +1861,7 @@ impl AssistantContext {
{ {
let newline_offset = insert_position.saturating_sub(1); let newline_offset = insert_position.saturating_sub(1);
if buffer.contains_str_at(newline_offset, "\n") if buffer.contains_str_at(newline_offset, "\n")
&& last_section_range.map_or(true, |last_section_range| { && last_section_range.is_none_or(|last_section_range| {
!last_section_range !last_section_range
.to_offset(buffer) .to_offset(buffer)
.contains(&newline_offset) .contains(&newline_offset)
@ -2045,7 +2044,7 @@ impl AssistantContext {
let task = cx.spawn({ let task = cx.spawn({
async move |this, cx| { async move |this, cx| {
let stream = model.stream_completion(request, &cx); let stream = model.stream_completion(request, cx);
let assistant_message_id = assistant_message.id; let assistant_message_id = assistant_message.id;
let mut response_latency = None; let mut response_latency = None;
let stream_completion = async { let stream_completion = async {
@ -2081,15 +2080,12 @@ impl AssistantContext {
match event { match event {
LanguageModelCompletionEvent::StatusUpdate(status_update) => { LanguageModelCompletionEvent::StatusUpdate(status_update) => {
match status_update { if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
CompletionRequestStatus::UsageUpdated { amount, limit } => { this.update_model_request_usage(
this.update_model_request_usage( amount as u32,
amount as u32, limit,
limit, cx,
cx, );
);
}
_ => {}
} }
} }
LanguageModelCompletionEvent::StartMessage { .. } => {} LanguageModelCompletionEvent::StartMessage { .. } => {}
@ -2315,10 +2311,7 @@ impl AssistantContext {
let mut request_message = LanguageModelRequestMessage { let mut request_message = LanguageModelRequestMessage {
role: message.role, role: message.role,
content: Vec::new(), content: Vec::new(),
cache: message cache: message.cache.as_ref().is_some_and(|cache| cache.is_anchor),
.cache
.as_ref()
.map_or(false, |cache| cache.is_anchor),
}; };
while let Some(content) = contents.peek() { while let Some(content) = contents.peek() {
@ -2708,7 +2701,7 @@ impl AssistantContext {
self.summary_task = cx.spawn(async move |this, cx| { self.summary_task = cx.spawn(async move |this, cx| {
let result = async { let result = async {
let stream = model.model.stream_completion_text(request, &cx); let stream = model.model.stream_completion_text(request, cx);
let mut messages = stream.await?; let mut messages = stream.await?;
let mut replaced = !replace_old; let mut replaced = !replace_old;
@ -2741,10 +2734,10 @@ impl AssistantContext {
} }
this.read_with(cx, |this, _cx| { this.read_with(cx, |this, _cx| {
if let Some(summary) = this.summary.content() { if let Some(summary) = this.summary.content()
if summary.text.is_empty() { && summary.text.is_empty()
bail!("Model generated an empty summary"); {
} bail!("Model generated an empty summary");
} }
Ok(()) Ok(())
})??; })??;
@ -2799,7 +2792,7 @@ impl AssistantContext {
let mut current_message = messages.next(); let mut current_message = messages.next();
while let Some(offset) = offsets.next() { while let Some(offset) = offsets.next() {
// Locate the message that contains the offset. // Locate the message that contains the offset.
while current_message.as_ref().map_or(false, |message| { while current_message.as_ref().is_some_and(|message| {
!message.offset_range.contains(&offset) && messages.peek().is_some() !message.offset_range.contains(&offset) && messages.peek().is_some()
}) { }) {
current_message = messages.next(); current_message = messages.next();
@ -2809,7 +2802,7 @@ impl AssistantContext {
}; };
// Skip offsets that are in the same message. // Skip offsets that are in the same message.
while offsets.peek().map_or(false, |offset| { while offsets.peek().is_some_and(|offset| {
message.offset_range.contains(offset) || messages.peek().is_none() message.offset_range.contains(offset) || messages.peek().is_none()
}) { }) {
offsets.next(); offsets.next();
@ -2924,18 +2917,18 @@ impl AssistantContext {
fs.create_dir(contexts_dir().as_ref()).await?; fs.create_dir(contexts_dir().as_ref()).await?;
// rename before write ensures that only one file exists // rename before write ensures that only one file exists
if let Some(old_path) = old_path.as_ref() { if let Some(old_path) = old_path.as_ref()
if new_path.as_path() != old_path.as_ref() { && new_path.as_path() != old_path.as_ref()
fs.rename( {
&old_path, fs.rename(
&new_path, old_path,
RenameOptions { &new_path,
overwrite: true, RenameOptions {
ignore_if_exists: true, overwrite: true,
}, ignore_if_exists: true,
) },
.await?; )
} .await?;
} }
// update path before write in case it fails // update path before write in case it fails

View file

@ -1055,7 +1055,7 @@ fn test_mark_cache_anchors(cx: &mut App) {
assert_eq!( assert_eq!(
messages_cache(&context, cx) messages_cache(&context, cx)
.iter() .iter()
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
.count(), .count(),
0, 0,
"Empty messages should not have any cache anchors." "Empty messages should not have any cache anchors."
@ -1083,7 +1083,7 @@ fn test_mark_cache_anchors(cx: &mut App) {
assert_eq!( assert_eq!(
messages_cache(&context, cx) messages_cache(&context, cx)
.iter() .iter()
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
.count(), .count(),
0, 0,
"Messages should not be marked for cache before going over the token minimum." "Messages should not be marked for cache before going over the token minimum."
@ -1098,7 +1098,7 @@ fn test_mark_cache_anchors(cx: &mut App) {
assert_eq!( assert_eq!(
messages_cache(&context, cx) messages_cache(&context, cx)
.iter() .iter()
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
.collect::<Vec<bool>>(), .collect::<Vec<bool>>(),
vec![true, true, false], vec![true, true, false],
"Last message should not be an anchor on speculative request." "Last message should not be an anchor on speculative request."
@ -1116,7 +1116,7 @@ fn test_mark_cache_anchors(cx: &mut App) {
assert_eq!( assert_eq!(
messages_cache(&context, cx) messages_cache(&context, cx)
.iter() .iter()
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
.collect::<Vec<bool>>(), .collect::<Vec<bool>>(),
vec![false, true, true, false], vec![false, true, true, false],
"Most recent message should also be cached if not a speculative request." "Most recent message should also be cached if not a speculative request."
@ -1300,7 +1300,7 @@ fn test_summarize_error(
context.assist(cx); context.assist(cx);
}); });
simulate_successful_response(&model, cx); simulate_successful_response(model, cx);
context.read_with(cx, |context, _| { context.read_with(cx, |context, _| {
assert!(!context.summary().content().unwrap().done); assert!(!context.summary().content().unwrap().done);

Some files were not shown because too many files have changed in this diff Show more