Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
Finn Evers
89c7e544d5 sibling of 397b5f9301 2025-07-29 16:04:51 +00:00
190 changed files with 4918 additions and 11219 deletions

View file

@ -269,10 +269,6 @@ jobs:
mkdir -p ./../.cargo mkdir -p ./../.cargo
cp ./.cargo/ci-config.toml ./../.cargo/config.toml cp ./.cargo/ci-config.toml ./../.cargo/config.toml
- name: Check that Cargo.lock is up to date
run: |
cargo update --locked --workspace
- name: cargo clippy - name: cargo clippy
run: ./script/clippy run: ./script/clippy
@ -771,7 +767,7 @@ jobs:
timeout-minutes: 120 timeout-minutes: 120
name: Create a Windows installer name: Create a Windows installer
runs-on: [self-hosted, Windows, X64] runs-on: [self-hosted, Windows, X64]
if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) if: false && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
needs: [windows_tests] needs: [windows_tests]
env: env:
AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }} AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}

View file

@ -111,11 +111,6 @@ jobs:
echo "Publishing version: ${version} on release channel nightly" echo "Publishing version: ${version} on release channel nightly"
echo "nightly" > crates/zed/RELEASE_CHANNEL echo "nightly" > crates/zed/RELEASE_CHANNEL
- name: Setup Sentry CLI
uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2
with:
token: ${{ SECRETS.SENTRY_AUTH_TOKEN }}
- name: Create macOS app bundle - name: Create macOS app bundle
run: script/bundle-mac run: script/bundle-mac
@ -141,11 +136,6 @@ jobs:
- name: Install Linux dependencies - name: Install Linux dependencies
run: ./script/linux && ./script/install-mold 2.34.0 run: ./script/linux && ./script/install-mold 2.34.0
- name: Setup Sentry CLI
uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2
with:
token: ${{ SECRETS.SENTRY_AUTH_TOKEN }}
- name: Limit target directory size - name: Limit target directory size
run: script/clear-target-dir-if-larger-than 100 run: script/clear-target-dir-if-larger-than 100
@ -178,11 +168,6 @@ jobs:
- name: Install Linux dependencies - name: Install Linux dependencies
run: ./script/linux run: ./script/linux
- name: Setup Sentry CLI
uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2
with:
token: ${{ SECRETS.SENTRY_AUTH_TOKEN }}
- name: Limit target directory size - name: Limit target directory size
run: script/clear-target-dir-if-larger-than 100 run: script/clear-target-dir-if-larger-than 100
@ -277,11 +262,6 @@ jobs:
Write-Host "Publishing version: $version on release channel nightly" Write-Host "Publishing version: $version on release channel nightly"
"nightly" | Set-Content -Path "crates/zed/RELEASE_CHANNEL" "nightly" | Set-Content -Path "crates/zed/RELEASE_CHANNEL"
- name: Setup Sentry CLI
uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2
with:
token: ${{ SECRETS.SENTRY_AUTH_TOKEN }}
- name: Build Zed installer - name: Build Zed installer
working-directory: ${{ env.ZED_WORKSPACE }} working-directory: ${{ env.ZED_WORKSPACE }}
run: script/bundle-windows.ps1 run: script/bundle-windows.ps1

57
Cargo.lock generated
View file

@ -6,7 +6,6 @@ version = 4
name = "acp_thread" name = "acp_thread"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"agent-client-protocol",
"agentic-coding-protocol", "agentic-coding-protocol",
"anyhow", "anyhow",
"assistant_tool", "assistant_tool",
@ -136,23 +135,11 @@ dependencies = [
"zstd", "zstd",
] ]
[[package]]
name = "agent-client-protocol"
version = "0.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b"
dependencies = [
"schemars",
"serde",
"serde_json",
]
[[package]] [[package]]
name = "agent_servers" name = "agent_servers"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"acp_thread", "acp_thread",
"agent-client-protocol",
"agentic-coding-protocol", "agentic-coding-protocol",
"anyhow", "anyhow",
"collections", "collections",
@ -168,7 +155,6 @@ dependencies = [
"nix 0.29.0", "nix 0.29.0",
"paths", "paths",
"project", "project",
"rand 0.8.5",
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
@ -209,9 +195,9 @@ version = "0.1.0"
dependencies = [ dependencies = [
"acp_thread", "acp_thread",
"agent", "agent",
"agent-client-protocol",
"agent_servers", "agent_servers",
"agent_settings", "agent_settings",
"agentic-coding-protocol",
"ai_onboarding", "ai_onboarding",
"anyhow", "anyhow",
"assistant_context", "assistant_context",
@ -4258,7 +4244,7 @@ dependencies = [
[[package]] [[package]]
name = "dap-types" name = "dap-types"
version = "0.0.1" version = "0.0.1"
source = "git+https://github.com/zed-industries/dap-types?rev=1b461b310481d01e02b2603c16d7144b926339f8#1b461b310481d01e02b2603c16d7144b926339f8" source = "git+https://github.com/zed-industries/dap-types?rev=7f39295b441614ca9dbf44293e53c32f666897f9#7f39295b441614ca9dbf44293e53c32f666897f9"
dependencies = [ dependencies = [
"schemars", "schemars",
"serde", "serde",
@ -4980,7 +4966,6 @@ dependencies = [
"text", "text",
"theme", "theme",
"time", "time",
"tree-sitter-bash",
"tree-sitter-html", "tree-sitter-html",
"tree-sitter-python", "tree-sitter-python",
"tree-sitter-rust", "tree-sitter-rust",
@ -5386,13 +5371,11 @@ dependencies = [
"log", "log",
"lsp", "lsp",
"parking_lot", "parking_lot",
"pretty_assertions",
"semantic_version", "semantic_version",
"serde", "serde",
"serde_json", "serde_json",
"task", "task",
"toml 0.8.20", "toml 0.8.20",
"url",
"util", "util",
"wasm-encoder 0.221.3", "wasm-encoder 0.221.3",
"wasmparser 0.221.3", "wasmparser 0.221.3",
@ -7419,9 +7402,9 @@ dependencies = [
[[package]] [[package]]
name = "grid" name = "grid"
version = "0.17.0" version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71b01d27060ad58be4663b9e4ac9e2d4806918e8876af8912afbddd1a91d5eaa" checksum = "be136d9dacc2a13cc70bb6c8f902b414fb2641f8db1314637c6b7933411a8f82"
[[package]] [[package]]
name = "group" name = "group"
@ -7692,12 +7675,6 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hex-literal"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcaaec4551594c969335c98c903c1397853d4198408ea609190f420500f6be71"
[[package]] [[package]]
name = "hexf-parse" name = "hexf-parse"
version = "0.2.1" version = "0.2.1"
@ -9226,7 +9203,6 @@ dependencies = [
"chrono", "chrono",
"collections", "collections",
"dap", "dap",
"feature_flags",
"futures 0.3.31", "futures 0.3.31",
"gpui", "gpui",
"http_client", "http_client",
@ -9419,7 +9395,7 @@ dependencies = [
[[package]] [[package]]
name = "libwebrtc" name = "libwebrtc"
version = "0.3.10" version = "0.3.10"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
dependencies = [ dependencies = [
"cxx", "cxx",
"jni", "jni",
@ -9499,7 +9475,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856"
[[package]] [[package]]
name = "livekit" name = "livekit"
version = "0.7.8" version = "0.7.8"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
dependencies = [ dependencies = [
"chrono", "chrono",
"futures-util", "futures-util",
@ -9522,7 +9498,7 @@ dependencies = [
[[package]] [[package]]
name = "livekit-api" name = "livekit-api"
version = "0.4.2" version = "0.4.2"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
dependencies = [ dependencies = [
"futures-util", "futures-util",
"http 0.2.12", "http 0.2.12",
@ -9546,7 +9522,7 @@ dependencies = [
[[package]] [[package]]
name = "livekit-protocol" name = "livekit-protocol"
version = "0.3.9" version = "0.3.9"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
dependencies = [ dependencies = [
"futures-util", "futures-util",
"livekit-runtime", "livekit-runtime",
@ -9563,7 +9539,7 @@ dependencies = [
[[package]] [[package]]
name = "livekit-runtime" name = "livekit-runtime"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
dependencies = [ dependencies = [
"tokio", "tokio",
"tokio-stream", "tokio-stream",
@ -11029,7 +11005,6 @@ dependencies = [
"ui", "ui",
"workspace", "workspace",
"workspace-hack", "workspace-hack",
"zed_actions",
] ]
[[package]] [[package]]
@ -15986,12 +15961,13 @@ dependencies = [
[[package]] [[package]]
name = "taffy" name = "taffy"
version = "0.8.3" version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aaef0ac998e6527d6d0d5582f7e43953bb17221ac75bb8eb2fcc2db3396db1c" checksum = "e8b61630cba2afd2c851821add2e1bb1b7851a2436e839ab73b56558b009035e"
dependencies = [ dependencies = [
"arrayvec", "arrayvec",
"grid", "grid",
"num-traits",
"serde", "serde",
"slotmap", "slotmap",
] ]
@ -18551,7 +18527,7 @@ dependencies = [
[[package]] [[package]]
name = "webrtc-sys" name = "webrtc-sys"
version = "0.3.7" version = "0.3.7"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
dependencies = [ dependencies = [
"cc", "cc",
"cxx", "cxx",
@ -18564,15 +18540,13 @@ dependencies = [
[[package]] [[package]]
name = "webrtc-sys-build" name = "webrtc-sys-build"
version = "0.3.6" version = "0.3.6"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
dependencies = [ dependencies = [
"fs2", "fs2",
"hex-literal",
"regex", "regex",
"reqwest 0.11.27", "reqwest 0.11.27",
"scratch", "scratch",
"semver", "semver",
"sha2",
"zip", "zip",
] ]
@ -20196,7 +20170,7 @@ dependencies = [
[[package]] [[package]]
name = "zed" name = "zed"
version = "0.198.0" version = "0.197.3"
dependencies = [ dependencies = [
"activity_indicator", "activity_indicator",
"agent", "agent",
@ -20237,7 +20211,6 @@ dependencies = [
"extension", "extension",
"extension_host", "extension_host",
"extensions_ui", "extensions_ui",
"feature_flags",
"feedback", "feedback",
"file_finder", "file_finder",
"fs", "fs",

View file

@ -413,7 +413,6 @@ zlog_settings = { path = "crates/zlog_settings" }
# #
agentic-coding-protocol = "0.0.10" agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.11"
aho-corasick = "1.1" aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14" any_vec = "0.14"
@ -460,7 +459,7 @@ core-video = { version = "0.4.3", features = ["metal"] }
cpal = "0.16" cpal = "0.16"
criterion = { version = "0.5", features = ["html_reports"] } criterion = { version = "0.5", features = ["html_reports"] }
ctor = "0.4.0" ctor = "0.4.0"
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "1b461b310481d01e02b2603c16d7144b926339f8" } dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "7f39295b441614ca9dbf44293e53c32f666897f9" }
dashmap = "6.0" dashmap = "6.0"
derive_more = "0.99.17" derive_more = "0.99.17"
dirs = "4.0" dirs = "4.0"
@ -720,11 +719,6 @@ workspace-hack = { path = "tooling/workspace-hack" }
split-debuginfo = "unpacked" split-debuginfo = "unpacked"
codegen-units = 16 codegen-units = 16
# mirror configuration for crates compiled for the build platform
# (without this cargo will compile ~400 crates twice)
[profile.dev.build-override]
codegen-units = 16
[profile.dev.package] [profile.dev.package]
taffy = { opt-level = 3 } taffy = { opt-level = 3 }
cranelift-codegen = { opt-level = 3 } cranelift-codegen = { opt-level = 3 }

View file

@ -1,7 +1 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-volume-off"><path d="M16 9a5 5 0 0 1 .95 2.293"/><path d="M19.364 5.636a9 9 0 0 1 1.889 9.96"/><path d="m2 2 20 20"/><path d="m7 7-.587.587A1.4 1.4 0 0 1 5.416 8H3a1 1 0 0 0-1 1v6a1 1 0 0 0 1 1h2.416a1.4 1.4 0 0 1 .997.413l3.383 3.384A.705.705 0 0 0 11 19.298V11"/><path d="M9.828 4.172A.686.686 0 0 1 11 4.657v.686"/></svg>
<path d="M10.6667 6C11.003 6.44823 11.2208 6.97398 11.3001 7.52867" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M12.9094 3.75732C13.7621 4.6095 14.3383 5.69876 14.5629 6.88315C14.7875 8.06754 14.6502 9.29213 14.1688 10.3973" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2.66675 2L13.6667 13" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M5.33333 4.66669L4.942 5.05802C4.85494 5.1456 4.75136 5.21504 4.63726 5.2623C4.52317 5.30957 4.40083 5.33372 4.27733 5.33335H2.66667C2.48986 5.33335 2.32029 5.40359 2.19526 5.52862C2.07024 5.65364 2 5.82321 2 6.00002V10C2 10.1768 2.07024 10.3464 2.19526 10.4714C2.32029 10.5964 2.48986 10.6667 2.66667 10.6667H4.27733C4.40083 10.6663 4.52317 10.6905 4.63726 10.7377C4.75136 10.785 4.85494 10.8544 4.942 10.942L7.19733 13.198C7.26307 13.2639 7.34687 13.3088 7.43813 13.3269C7.52939 13.3451 7.62399 13.3358 7.70995 13.3002C7.79591 13.2646 7.86936 13.2042 7.921 13.1268C7.97263 13.0494 8.00013 12.9584 8 12.8654V7.33335" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M7.21875 2.78136C7.28267 2.71719 7.36421 2.67345 7.45303 2.65568C7.54184 2.63791 7.63393 2.64691 7.71762 2.68154C7.80132 2.71618 7.87284 2.77488 7.92312 2.85022C7.97341 2.92555 8.0002 3.01412 8.00008 3.10469V3.56202" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.6 KiB

After

Width:  |  Height:  |  Size: 527 B

Before After
Before After

View file

@ -1,5 +1 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-volume-2"><path d="M11 4.702a.705.705 0 0 0-1.203-.498L6.413 7.587A1.4 1.4 0 0 1 5.416 8H3a1 1 0 0 0-1 1v6a1 1 0 0 0 1 1h2.416a1.4 1.4 0 0 1 .997.413l3.383 3.384A.705.705 0 0 0 11 19.298z"/><path d="M16 9a5 5 0 0 1 0 6"/><path d="M19.364 18.364a9 9 0 0 0 0-12.728"/></svg>
<path d="M8 3.13467C7.99987 3.04181 7.97223 2.95107 7.92057 2.8739C7.86892 2.79674 7.79557 2.7366 7.70977 2.70108C7.62397 2.66557 7.52958 2.65626 7.43849 2.67434C7.34741 2.69242 7.26373 2.73707 7.198 2.80266L4.942 5.058C4.85494 5.14558 4.75136 5.21502 4.63726 5.26228C4.52317 5.30954 4.40083 5.33369 4.27733 5.33333H2.66667C2.48986 5.33333 2.32029 5.40357 2.19526 5.52859C2.07024 5.65362 2 5.82319 2 6V10C2 10.1768 2.07024 10.3464 2.19526 10.4714C2.32029 10.5964 2.48986 10.6667 2.66667 10.6667H4.27733C4.40083 10.6663 4.52317 10.6905 4.63726 10.7377C4.75136 10.785 4.85494 10.8544 4.942 10.942L7.19733 13.198C7.26307 13.2639 7.34687 13.3087 7.43813 13.3269C7.52939 13.3451 7.62399 13.3358 7.70995 13.3002C7.79591 13.2645 7.86936 13.2042 7.921 13.1268C7.97263 13.0494 8.00013 12.9584 8 12.8653V3.13467Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M10.6667 6C11.0995 6.57699 11.3334 7.27877 11.3334 8C11.3334 8.72123 11.0995 9.42301 10.6667 10" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M12.9094 12.2427C13.4666 11.6855 13.9085 11.0241 14.2101 10.2961C14.5116 9.56815 14.6668 8.78793 14.6668 7.99999C14.6668 7.21205 14.5116 6.43183 14.2101 5.70387C13.9085 4.97591 13.4666 4.31448 12.9094 3.75732" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.4 KiB

After

Width:  |  Height:  |  Size: 475 B

Before After
Before After

View file

@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-cloud-download-icon lucide-cloud-download"><path d="M12 13v8l-4-4"/><path d="m12 21 4-4"/><path d="M4.393 15.269A7 7 0 1 1 15.71 8h1.79a4.5 4.5 0 0 1 2.436 8.284"/></svg>

Before

Width:  |  Height:  |  Size: 372 B

View file

@ -1,5 +1,8 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.437 11.0461L13.4831 8L10.437 4.95392" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/> <path
<path d="M13 8L8 8" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/> fill-rule="evenodd"
<path d="M6.6553 13.4659H4.21843C3.89528 13.4659 3.58537 13.3375 3.35687 13.109C3.12837 12.8805 3 12.5706 3 12.2475V3.71843C3 3.39528 3.12837 3.08537 3.35687 2.85687C3.58537 2.62837 3.89528 2.5 4.21843 2.5H6.6553" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/> clip-rule="evenodd"
d="M3 1C2.44771 1 2 1.44772 2 2V13C2 13.5523 2.44772 14 3 14H10.5C10.7761 14 11 13.7761 11 13.5C11 13.2239 10.7761 13 10.5 13H3V2L10.5 2C10.7761 2 11 1.77614 11 1.5C11 1.22386 10.7761 1 10.5 1H3ZM12.6036 4.89645C12.4083 4.70118 12.0917 4.70118 11.8964 4.89645C11.7012 5.09171 11.7012 5.40829 11.8964 5.60355L13.2929 7H6.5C6.22386 7 6 7.22386 6 7.5C6 7.77614 6.22386 8 6.5 8H13.2929L11.8964 9.39645C11.7012 9.59171 11.7012 9.90829 11.8964 10.1036C12.0917 10.2988 12.4083 10.2988 12.6036 10.1036L14.8536 7.85355C15.0488 7.65829 15.0488 7.34171 14.8536 7.14645L12.6036 4.89645Z"
fill="currentColor"
/>
</svg> </svg>

Before

Width:  |  Height:  |  Size: 637 B

After

Width:  |  Height:  |  Size: 768 B

Before After
Before After

View file

@ -1,3 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M9.628 11.0743V10.4575H8.45562L8.65084 10.2445C8.75911 10.1264 8.96952 9.79454 9.11862 9.50789C9.52153 8.73047 9.51798 7.25107 9.11862 6.43992C8.58614 5.35722 7.49453 4.56381 6.24942 4.35703C4.59252 4.08192 2.86196 5.00312 2.14045 6.54287C1.77038 7.33182 1.77038 8.64437 2.14045 9.43333C2.45905 10.1122 3.11309 10.8204 3.73609 11.1595C4.51439 11.5828 5.18264 11.676 7.51312 11.6848L9.62627 11.6928L9.628 11.0743ZM5.30605 10.169C4.24109 10.0111 3.45215 9.07124 3.45659 7.96813C3.45659 7.33004 3.70064 6.80022 4.18697 6.36182C4.67685 5.91986 5.1312 5.77344 5.86602 5.82048C7.00287 5.89236 7.82382 6.79845 7.82382 7.98056C7.82382 8.61332 7.71996 8.91682 7.33036 9.42534C6.90172 9.98444 6.08345 10.2853 5.30692 10.1699M15.1374 10.9802V10.2684H11.8138V4.47509H10.1986V11.6928H15.1374V10.9802Z" fill="black"/>
</svg>

Before

Width:  |  Height:  |  Size: 916 B

View file

@ -1,5 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 12.2028V14.3042" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/> <path fill-rule="evenodd" clip-rule="evenodd" d="M3.72742 8.83338C3.63539 8.57302 3.34973 8.43656 3.08937 8.52858C2.82901 8.6206 2.69255 8.90626 2.78458 9.16662C2.86101 9.38288 2.95188 9.59228 3.056 9.79364C3.81427 11.2601 5.27842 12.3044 7.00014 12.4753L7.00014 14L5.50014 14C5.22399 14 5.00014 14.2239 5.00014 14.5C5.00014 14.7761 5.22399 15 5.50014 15L7.50014 15L9.50014 15C9.77628 15 10.0001 14.7761 10.0001 14.5C10.0001 14.2239 9.77628 14 9.50014 14L8.00014 14L8.00014 12.4753C9.72168 12.3043 11.1857 11.26 11.9439 9.79364C12.048 9.59228 12.1389 9.38288 12.2153 9.16662C12.3073 8.90626 12.1709 8.6206 11.9105 8.52858C11.6501 8.43656 11.3645 8.57302 11.2725 8.83338C11.2114 9.00607 11.1388 9.17337 11.0556 9.33433C10.3899 10.6218 9.04706 11.5 7.49994 11.5C5.95282 11.5 4.60997 10.6218 3.94428 9.33433C3.86104 9.17337 3.78845 9.00607 3.72742 8.83338ZM5.5 3.5L5.5 7.5C5.5 8.60457 6.39543 9.5 7.5 9.5C8.60457 9.5 9.5 8.60457 9.5 7.5L9.5 3.5C9.5 2.39543 8.60457 1.5 7.5 1.5C6.39543 1.5 5.5 2.39543 5.5 3.5ZM4.5 7.5C4.5 9.15685 5.84315 10.5 7.5 10.5C9.15685 10.5 10.5 9.15685 10.5 7.5L10.5 3.5C10.5 1.84315 9.15685 0.5 7.5 0.5C5.84315 0.5 4.5 1.84315 4.5 3.5L4.5 7.5Z" fill="black"/>
<path d="M12.2027 6.94928V8.11672C12.2027 9.20041 11.7599 10.2397 10.9717 11.006C10.1836 11.7723 9.11457 12.2028 7.99992 12.2028C6.88527 12.2028 5.81627 11.7723 5.02809 11.006C4.23991 10.2397 3.79712 9.20041 3.79712 8.11672V6.94928" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M10.1015 3.63555C10.1015 2.56426 9.16065 1.6958 8.00008 1.6958C6.83951 1.6958 5.89868 2.56426 5.89868 3.63555V8.16165C5.89868 9.23294 6.83951 10.1014 8.00008 10.1014C9.16065 10.1014 10.1015 9.23294 10.1015 8.16165V3.63555Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg> </svg>

Before

Width:  |  Height:  |  Size: 847 B

After

Width:  |  Height:  |  Size: 1.3 KiB

Before After
Before After

View file

@ -1,8 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3 3L13 13" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/> <path fill-rule="evenodd" clip-rule="evenodd" d="M12.87 1.83637C13.0557 1.63204 13.0407 1.31581 12.8363 1.13006C12.632 0.944307 12.3158 0.959365 12.13 1.16369L10.4589 3.00199C10.2216 1.58215 8.98719 0.5 7.5 0.5C5.84315 0.5 4.5 1.84315 4.5 3.5L4.5 7.5C4.5 8.0754 4.66199 8.61297 4.94286 9.06958L4.24966 9.8321C4.1363 9.6744 4.03412 9.5081 3.94428 9.33433C3.86104 9.17337 3.78845 9.00607 3.72742 8.83338C3.63539 8.57302 3.34973 8.43656 3.08937 8.52858C2.82901 8.6206 2.69255 8.90626 2.78458 9.16662C2.86101 9.38288 2.95188 9.59228 3.056 9.79364C3.20094 10.074 3.37167 10.3388 3.56506 10.5852L2.13003 12.1637C1.94428 12.368 1.95933 12.6842 2.16366 12.87C2.36799 13.0558 2.68422 13.0407 2.86997 12.8364L4.25951 11.3079C5.01297 11.9497 5.95951 12.372 7.00014 12.4753L7.00014 14L5.50014 14C5.22399 14 5.00014 14.2239 5.00014 14.5C5.00014 14.7761 5.22399 15 5.50014 15L7.50014 15L9.50014 15C9.77628 15 10.0001 14.7761 10.0001 14.5C10.0001 14.2239 9.77628 14 9.50014 14L8.00014 14L8.00014 12.4753C9.72168 12.3043 11.1857 11.26 11.9439 9.79364C12.048 9.59228 12.1389 9.38288 12.2153 9.16662C12.3073 8.90626 12.1709 8.6206 11.9105 8.52858C11.6501 8.43656 11.3645 8.57302 11.2725 8.83338C11.2114 9.00607 11.1388 9.17337 11.0556 9.33433C10.3899 10.6218 9.04706 11.5 7.49994 11.5C6.523 11.5 5.62751 11.1498 4.93254 10.5675L5.60604 9.82669C6.12251 10.2476 6.78178 10.5 7.5 10.5C9.15685 10.5 10.5 9.15685 10.5 7.5L10.5 4.44333L12.87 1.83637ZM9.5 4.05673L9.5 3.5C9.5 2.39543 8.60457 1.5 7.5 1.5C6.39543 1.5 5.5 2.39543 5.5 3.5L5.5 7.5C5.5 7.77755 5.55653 8.04189 5.65872 8.28214L9.5 4.05673ZM6.28022 9.08509L9.5 5.54333L9.5 7.5C9.5 8.60457 8.60457 9.5 7.5 9.5C7.04083 9.5 6.6178 9.34527 6.28022 9.08509Z" fill="black"/>
<path d="M12 9C12 8.74858 12 8.49375 12 8.23839V7" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M4.00043 7V8.09869C3.98856 8.86731 4.22157 9.62164 4.66938 10.2643C5.11718 10.907 5.75924 11.4085 6.51267 11.7042C7.2661 11.9999 8.09632 12.0761 8.89619 11.923C9.47851 11.8115 10.0253 11.5823 10.5 11.2539" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M10 6V3.62904C9.99714 3.26103 9.8347 2.90448 9.53885 2.6168C9.24299 2.32913 8.83093 2.12707 8.36903 2.04316C7.90713 1.95926 7.42226 1.9984 6.99252 2.15427C6.56278 2.31015 6.21317 2.57369 6 2.90245" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M6 6V8.00088C6.00031 8.39636 6.10356 8.78287 6.29674 9.11159C6.48991 9.44031 6.76433 9.69649 7.08534 9.84779C7.40634 9.99909 7.75954 10.0387 8.10032 9.96165C8.4411 9.88459 8.75417 9.69431 9 9.41483" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M8 12V14" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg> </svg>

Before

Width:  |  Height:  |  Size: 1.3 KiB

After

Width:  |  Height:  |  Size: 1.8 KiB

Before After
Before After

View file

@ -1,5 +1,8 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12.8 3H3.2C2.53726 3 2 3.51167 2 4.14286V9.85714C2 10.4883 2.53726 11 3.2 11H12.8C13.4627 11 14 10.4883 14 9.85714V4.14286C14 3.51167 13.4627 3 12.8 3Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/> <path
<path d="M5.33325 14H10.6666" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/> fill-rule="evenodd"
<path d="M8 11.3333V14" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/> clip-rule="evenodd"
d="M1 3.25C1 3.11193 1.11193 3 1.25 3H13.75C13.8881 3 14 3.11193 14 3.25V10.75C14 10.8881 13.8881 11 13.75 11H1.25C1.11193 11 1 10.8881 1 10.75V3.25ZM1.25 2C0.559643 2 0 2.55964 0 3.25V10.75C0 11.4404 0.559644 12 1.25 12H5.07341L4.82991 13.2986C4.76645 13.6371 5.02612 13.95 5.37049 13.95H9.62951C9.97389 13.95 10.2336 13.6371 10.1701 13.2986L9.92659 12H13.75C14.4404 12 15 11.4404 15 10.75V3.25C15 2.55964 14.4404 2 13.75 2H1.25ZM9.01091 12H5.98909L5.79222 13.05H9.20778L9.01091 12Z"
fill="currentColor"
/>
</svg> </svg>

Before

Width:  |  Height:  |  Size: 569 B

After

Width:  |  Height:  |  Size: 677 B

Before After
Before After

View file

@ -495,7 +495,7 @@
"shift-f12": "editor::GoToImplementation", "shift-f12": "editor::GoToImplementation",
"alt-ctrl-f12": "editor::GoToTypeDefinitionSplit", "alt-ctrl-f12": "editor::GoToTypeDefinitionSplit",
"alt-shift-f12": "editor::FindAllReferences", "alt-shift-f12": "editor::FindAllReferences",
"ctrl-m": "editor::MoveToEnclosingBracket", // from jetbrains "ctrl-m": "editor::MoveToEnclosingBracket",
"ctrl-|": "editor::MoveToEnclosingBracket", "ctrl-|": "editor::MoveToEnclosingBracket",
"ctrl-{": "editor::Fold", "ctrl-{": "editor::Fold",
"ctrl-}": "editor::UnfoldLines", "ctrl-}": "editor::UnfoldLines",

View file

@ -549,7 +549,7 @@
"alt-cmd-f12": "editor::GoToTypeDefinitionSplit", "alt-cmd-f12": "editor::GoToTypeDefinitionSplit",
"alt-shift-f12": "editor::FindAllReferences", "alt-shift-f12": "editor::FindAllReferences",
"cmd-|": "editor::MoveToEnclosingBracket", "cmd-|": "editor::MoveToEnclosingBracket",
"ctrl-m": "editor::MoveToEnclosingBracket", // From Jetbrains "ctrl-m": "editor::MoveToEnclosingBracket",
"alt-cmd-[": "editor::Fold", "alt-cmd-[": "editor::Fold",
"alt-cmd-]": "editor::UnfoldLines", "alt-cmd-]": "editor::UnfoldLines",
"cmd-k cmd-l": "editor::ToggleFold", "cmd-k cmd-l": "editor::ToggleFold",

View file

@ -4,7 +4,6 @@
"ctrl-alt-s": "zed::OpenSettings", "ctrl-alt-s": "zed::OpenSettings",
"ctrl-{": "pane::ActivatePreviousItem", "ctrl-{": "pane::ActivatePreviousItem",
"ctrl-}": "pane::ActivateNextItem", "ctrl-}": "pane::ActivateNextItem",
"shift-escape": null, // Unmap workspace::zoom
"ctrl-f2": "debugger::Stop", "ctrl-f2": "debugger::Stop",
"f6": "debugger::Pause", "f6": "debugger::Pause",
"f7": "debugger::StepInto", "f7": "debugger::StepInto",
@ -45,8 +44,8 @@
"ctrl-alt-right": "pane::GoForward", "ctrl-alt-right": "pane::GoForward",
"alt-f7": "editor::FindAllReferences", "alt-f7": "editor::FindAllReferences",
"ctrl-alt-f7": "editor::FindAllReferences", "ctrl-alt-f7": "editor::FindAllReferences",
"ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock // "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock
"ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleRightDock // "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleLeftDock
"ctrl-shift-b": "editor::GoToTypeDefinition", "ctrl-shift-b": "editor::GoToTypeDefinition",
"ctrl-alt-shift-b": "editor::GoToTypeDefinitionSplit", "ctrl-alt-shift-b": "editor::GoToTypeDefinitionSplit",
"f2": "editor::GoToDiagnostic", "f2": "editor::GoToDiagnostic",
@ -101,27 +100,12 @@
"shift shift": "command_palette::Toggle", "shift shift": "command_palette::Toggle",
"ctrl-alt-shift-n": "project_symbols::Toggle", "ctrl-alt-shift-n": "project_symbols::Toggle",
"alt-0": "git_panel::ToggleFocus", "alt-0": "git_panel::ToggleFocus",
"alt-1": "project_panel::ToggleFocus", "alt-1": "workspace::ToggleLeftDock",
"alt-5": "debug_panel::ToggleFocus", "alt-5": "debug_panel::ToggleFocus",
"alt-6": "diagnostics::Deploy", "alt-6": "diagnostics::Deploy",
"alt-7": "outline_panel::ToggleFocus" "alt-7": "outline_panel::ToggleFocus"
} }
}, },
{
"context": "Pane", // this is to override the default Pane mappings to switch tabs
"bindings": {
"alt-1": "project_panel::ToggleFocus",
"alt-2": null, // Bookmarks (left dock)
"alt-3": null, // Find Panel (bottom dock)
"alt-4": null, // Run Panel (bottom dock)
"alt-5": "debug_panel::ToggleFocus",
"alt-6": "diagnostics::Deploy",
"alt-7": "outline_panel::ToggleFocus",
"alt-8": null, // Services (bottom dock)
"alt-9": null, // Git History (bottom dock)
"alt-0": "git_panel::ToggleFocus"
}
},
{ {
"context": "Workspace || Editor", "context": "Workspace || Editor",
"bindings": { "bindings": {
@ -167,9 +151,6 @@
{ "context": "OutlinePanel", "bindings": { "alt-7": "workspace::CloseActiveDock" } }, { "context": "OutlinePanel", "bindings": { "alt-7": "workspace::CloseActiveDock" } },
{ {
"context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", "context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)",
"bindings": { "bindings": { "escape": "editor::ToggleFocus" }
"escape": "editor::ToggleFocus",
"shift-escape": "workspace::CloseActiveDock"
}
} }
] ]

View file

@ -4,7 +4,6 @@
"cmd-{": "pane::ActivatePreviousItem", "cmd-{": "pane::ActivatePreviousItem",
"cmd-}": "pane::ActivateNextItem", "cmd-}": "pane::ActivateNextItem",
"cmd-0": "git_panel::ToggleFocus", // overrides `cmd-0` zoom reset "cmd-0": "git_panel::ToggleFocus", // overrides `cmd-0` zoom reset
"shift-escape": null, // Unmap workspace::zoom
"ctrl-f2": "debugger::Stop", "ctrl-f2": "debugger::Stop",
"f6": "debugger::Pause", "f6": "debugger::Pause",
"f7": "debugger::StepInto", "f7": "debugger::StepInto",
@ -109,21 +108,6 @@
"cmd-7": "outline_panel::ToggleFocus" "cmd-7": "outline_panel::ToggleFocus"
} }
}, },
{
"context": "Pane", // this is to override the default Pane mappings to switch tabs
"bindings": {
"cmd-1": "project_panel::ToggleFocus",
"cmd-2": null, // Bookmarks (left dock)
"cmd-3": null, // Find Panel (bottom dock)
"cmd-4": null, // Run Panel (bottom dock)
"cmd-5": "debug_panel::ToggleFocus",
"cmd-6": "diagnostics::Deploy",
"cmd-7": "outline_panel::ToggleFocus",
"cmd-8": null, // Services (bottom dock)
"cmd-9": null, // Git History (bottom dock)
"cmd-0": "git_panel::ToggleFocus"
}
},
{ {
"context": "Workspace || Editor", "context": "Workspace || Editor",
"bindings": { "bindings": {
@ -162,15 +146,11 @@
} }
}, },
{ "context": "GitPanel", "bindings": { "cmd-0": "workspace::CloseActiveDock" } }, { "context": "GitPanel", "bindings": { "cmd-0": "workspace::CloseActiveDock" } },
{ "context": "ProjectPanel", "bindings": { "cmd-1": "workspace::CloseActiveDock" } },
{ "context": "DebugPanel", "bindings": { "cmd-5": "workspace::CloseActiveDock" } }, { "context": "DebugPanel", "bindings": { "cmd-5": "workspace::CloseActiveDock" } },
{ "context": "Diagnostics > Editor", "bindings": { "cmd-6": "pane::CloseActiveItem" } }, { "context": "Diagnostics > Editor", "bindings": { "cmd-6": "pane::CloseActiveItem" } },
{ "context": "OutlinePanel", "bindings": { "cmd-7": "workspace::CloseActiveDock" } }, { "context": "OutlinePanel", "bindings": { "cmd-7": "workspace::CloseActiveDock" } },
{ {
"context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", "context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)",
"bindings": { "bindings": { "escape": "editor::ToggleFocus" }
"escape": "editor::ToggleFocus",
"shift-escape": "workspace::CloseActiveDock"
}
} }
] ]

View file

@ -220,8 +220,6 @@
{ {
"context": "vim_mode == normal", "context": "vim_mode == normal",
"bindings": { "bindings": {
"i": "vim::InsertBefore",
"a": "vim::InsertAfter",
"ctrl-[": "editor::Cancel", "ctrl-[": "editor::Cancel",
":": "command_palette::Toggle", ":": "command_palette::Toggle",
"c": "vim::PushChange", "c": "vim::PushChange",
@ -355,7 +353,9 @@
"shift-d": "vim::DeleteToEndOfLine", "shift-d": "vim::DeleteToEndOfLine",
"shift-j": "vim::JoinLines", "shift-j": "vim::JoinLines",
"shift-y": "vim::YankLine", "shift-y": "vim::YankLine",
"i": "vim::InsertBefore",
"shift-i": "vim::InsertFirstNonWhitespace", "shift-i": "vim::InsertFirstNonWhitespace",
"a": "vim::InsertAfter",
"shift-a": "vim::InsertEndOfLine", "shift-a": "vim::InsertEndOfLine",
"o": "vim::InsertLineBelow", "o": "vim::InsertLineBelow",
"shift-o": "vim::InsertLineAbove", "shift-o": "vim::InsertLineAbove",
@ -377,8 +377,6 @@
{ {
"context": "vim_mode == helix_normal && !menu", "context": "vim_mode == helix_normal && !menu",
"bindings": { "bindings": {
"i": "vim::HelixInsert",
"a": "vim::HelixAppend",
"ctrl-[": "editor::Cancel", "ctrl-[": "editor::Cancel",
";": "vim::HelixCollapseSelection", ";": "vim::HelixCollapseSelection",
":": "command_palette::Toggle", ":": "command_palette::Toggle",

View file

@ -691,10 +691,7 @@
// 5. Never show the scrollbar: // 5. Never show the scrollbar:
// "never" // "never"
"show": null "show": null
}, }
// Default depth to expand outline items in the current file.
// Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper.
"expand_outlines_with_depth": 100
}, },
"collaboration_panel": { "collaboration_panel": {
// Whether to show the collaboration panel button in the status bar. // Whether to show the collaboration panel button in the status bar.

View file

@ -16,7 +16,6 @@ doctest = false
test-support = ["gpui/test-support", "project/test-support"] test-support = ["gpui/test-support", "project/test-support"]
[dependencies] [dependencies]
agent-client-protocol.workspace = true
agentic-coding-protocol.workspace = true agentic-coding-protocol.workspace = true
anyhow.workspace = true anyhow.workspace = true
assistant_tool.workspace = true assistant_tool.workspace = true

File diff suppressed because it is too large Load diff

View file

@ -1,26 +1,20 @@
use std::{path::Path, rc::Rc}; use agentic_coding_protocol as acp;
use agent_client_protocol as acp;
use anyhow::Result; use anyhow::Result;
use gpui::{AsyncApp, Entity, Task}; use futures::future::{FutureExt as _, LocalBoxFuture};
use project::Project;
use ui::App;
use crate::AcpThread;
pub trait AgentConnection { pub trait AgentConnection {
fn name(&self) -> &'static str; fn request_any(
&self,
fn new_thread( params: acp::AnyAgentRequest,
self: Rc<Self>, ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>>;
project: Entity<Project>, }
cwd: &Path,
cx: &mut AsyncApp, impl AgentConnection for acp::AgentConnection {
) -> Task<Result<Entity<AcpThread>>>; fn request_any(
&self,
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>; params: acp::AnyAgentRequest,
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>; let task = self.request_any(params);
async move { Ok(task.await?) }.boxed_local()
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); }
} }

View file

@ -1,453 +0,0 @@
// Translates old acp agents into the new schema
use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
use crate::{AcpThread, AgentConnection};
#[derive(Clone)]
pub struct OldAcpClientDelegate {
thread: Rc<RefCell<WeakEntity<AcpThread>>>,
cx: AsyncApp,
next_tool_call_id: Rc<RefCell<u64>>,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
impl OldAcpClientDelegate {
pub fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
Self {
thread,
cx,
next_tool_call_id: Rc::new(RefCell::new(0)),
}
}
}
impl acp_old::Client for OldAcpClientDelegate {
async fn stream_assistant_message_chunk(
&self,
params: acp_old::StreamAssistantMessageChunkParams,
) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread
.borrow()
.update(cx, |thread, cx| match params.chunk {
acp_old::AssistantMessageChunk::Text { text } => {
thread.push_assistant_content_block(text.into(), false, cx)
}
acp_old::AssistantMessageChunk::Thought { thought } => {
thread.push_assistant_content_block(thought.into(), true, cx)
}
})
.log_err();
})?;
Ok(())
}
async fn request_tool_call_confirmation(
&self,
request: acp_old::RequestToolCallConfirmationParams,
) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
let cx = &mut self.cx.clone();
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
self.next_tool_call_id.replace(old_acp_id);
let tool_call = into_new_tool_call(
acp::ToolCallId(old_acp_id.to_string().into()),
request.tool_call,
);
let mut options = match request.confirmation {
acp_old::ToolCallConfirmation::Edit { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow Edits".to_string(),
)],
acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", root_command),
)],
acp_old::ToolCallConfirmation::Mcp {
server_name,
tool_name,
..
} => vec![
(
acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", server_name),
),
(
acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", tool_name),
),
],
acp_old::ToolCallConfirmation::Fetch { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow".to_string(),
)],
acp_old::ToolCallConfirmation::Other { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow".to_string(),
)],
};
options.extend([
(
acp_old::ToolCallConfirmationOutcome::Allow,
acp::PermissionOptionKind::AllowOnce,
"Allow".to_string(),
),
(
acp_old::ToolCallConfirmationOutcome::Reject,
acp::PermissionOptionKind::RejectOnce,
"Reject".to_string(),
),
]);
let mut outcomes = Vec::with_capacity(options.len());
let mut acp_options = Vec::with_capacity(options.len());
for (index, (outcome, kind, label)) in options.into_iter().enumerate() {
outcomes.push(outcome);
acp_options.push(acp::PermissionOption {
id: acp::PermissionOptionId(index.to_string().into()),
label,
kind,
})
}
let response = cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.request_tool_call_permission(tool_call, acp_options, cx)
})
})?
.context("Failed to update thread")?
.await;
let outcome = match response {
Ok(option_id) => outcomes[option_id.0.parse::<usize>().unwrap_or(0)],
Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel,
};
Ok(acp_old::RequestToolCallConfirmationResponse {
id: acp_old::ToolCallId(old_acp_id),
outcome: outcome,
})
}
async fn push_tool_call(
&self,
request: acp_old::PushToolCallParams,
) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
let cx = &mut self.cx.clone();
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
self.next_tool_call_id.replace(old_acp_id);
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.upsert_tool_call(
into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request),
cx,
)
})
})?
.context("Failed to update thread")?;
Ok(acp_old::PushToolCallResponse {
id: acp_old::ToolCallId(old_acp_id),
})
}
async fn update_tool_call(
&self,
request: acp_old::UpdateToolCallParams,
) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(request.tool_call_id.0.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(into_new_tool_call_status(request.status)),
content: Some(
request
.content
.into_iter()
.map(into_new_tool_call_content)
.collect::<Vec<_>>(),
),
..Default::default()
},
},
cx,
)
})
})?
.context("Failed to update thread")??;
Ok(())
}
async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.update_plan(
acp::Plan {
entries: request
.entries
.into_iter()
.map(into_new_plan_entry)
.collect(),
},
cx,
)
})
})?
.context("Failed to update thread")?;
Ok(())
}
async fn read_text_file(
&self,
acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams,
) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
let content = self
.cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.read_text_file(path, line, limit, false, cx)
})
})?
.context("Failed to update thread")?
.await?;
Ok(acp_old::ReadTextFileResponse { content })
}
async fn write_text_file(
&self,
acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams,
) -> Result<(), acp_old::Error> {
self.cx
.update(|cx| {
self.thread
.borrow()
.update(cx, |thread, cx| thread.write_text_file(path, content, cx))
})?
.context("Failed to update thread")?
.await?;
Ok(())
}
}
fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall {
acp::ToolCall {
id: id,
label: request.label,
kind: acp_kind_from_old_icon(request.icon),
status: acp::ToolCallStatus::InProgress,
content: request
.content
.into_iter()
.map(into_new_tool_call_content)
.collect(),
locations: request
.locations
.into_iter()
.map(into_new_tool_call_location)
.collect(),
raw_input: None,
}
}
fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind {
match icon {
acp_old::Icon::FileSearch => acp::ToolKind::Search,
acp_old::Icon::Folder => acp::ToolKind::Search,
acp_old::Icon::Globe => acp::ToolKind::Search,
acp_old::Icon::Hammer => acp::ToolKind::Other,
acp_old::Icon::LightBulb => acp::ToolKind::Think,
acp_old::Icon::Pencil => acp::ToolKind::Edit,
acp_old::Icon::Regex => acp::ToolKind::Search,
acp_old::Icon::Terminal => acp::ToolKind::Execute,
}
}
fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus {
match status {
acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress,
acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed,
acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed,
}
}
fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent {
match content {
acp_old::ToolCallContent::Markdown { markdown } => markdown.into(),
acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff {
diff: into_new_diff(diff),
},
}
}
fn into_new_diff(diff: acp_old::Diff) -> acp::Diff {
acp::Diff {
path: diff.path,
old_text: diff.old_text,
new_text: diff.new_text,
}
}
fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation {
acp::ToolCallLocation {
path: location.path,
line: location.line,
}
}
fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry {
acp::PlanEntry {
content: entry.content,
priority: into_new_plan_priority(entry.priority),
status: into_new_plan_status(entry.status),
}
}
fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority {
match priority {
acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low,
acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium,
acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High,
}
}
fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus {
match status {
acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending,
acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress,
acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed,
}
}
#[derive(Debug)]
pub struct Unauthenticated;
impl Error for Unauthenticated {}
impl fmt::Display for Unauthenticated {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unauthenticated")
}
}
pub struct OldAcpAgentConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
pub child_status: Task<Result<()>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
}
impl AgentConnection for OldAcpAgentConnection {
fn name(&self) -> &'static str {
self.name
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let task = self.connection.request_any(
acp_old::InitializeParams {
protocol_version: acp_old::ProtocolVersion::latest(),
}
.into_any(),
);
let current_thread = self.current_thread.clone();
cx.spawn(async move |cx| {
let result = task.await?;
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
anyhow::bail!(Unauthenticated)
}
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new(self.clone(), project, session_id, cx)
});
current_thread.replace(thread.downgrade());
thread
})
})
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
let task = self
.connection
.request_any(acp_old::AuthenticateParams.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
Ok(())
})
}
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
let chunks = params
.prompt
.into_iter()
.filter_map(|block| match block {
acp::ContentBlock::Text(text) => {
Some(acp_old::UserMessageChunk::Text { text: text.text })
}
acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path {
path: link.uri.into(),
}),
_ => None,
})
.collect();
let task = self
.connection
.request_any(acp_old::SendUserMessageParams { chunks }.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
anyhow::Ok(())
})
}
fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) {
let task = self
.connection
.request_any(acp_old::CancelSendMessageParams.into_any());
cx.foreground_executor()
.spawn(async move {
task.await?;
anyhow::Ok(())
})
.detach_and_log_err(cx)
}
}

View file

@ -18,7 +18,6 @@ doctest = false
[dependencies] [dependencies]
acp_thread.workspace = true acp_thread.workspace = true
agent-client-protocol.workspace = true
agentic-coding-protocol.workspace = true agentic-coding-protocol.workspace = true
anyhow.workspace = true anyhow.workspace = true
collections.workspace = true collections.workspace = true
@ -29,7 +28,6 @@ itertools.workspace = true
log.workspace = true log.workspace = true
paths.workspace = true paths.workspace = true
project.workspace = true project.workspace = true
rand.workspace = true
schemars.workspace = true schemars.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
@ -41,7 +39,6 @@ ui.workspace = true
util.workspace = true util.workspace = true
uuid.workspace = true uuid.workspace = true
watch.workspace = true watch.workspace = true
indoc.workspace = true
which.workspace = true which.workspace = true
workspace-hack.workspace = true workspace-hack.workspace = true

View file

@ -1,18 +1,17 @@
mod claude; mod claude;
mod codex;
mod gemini; mod gemini;
mod mcp_server;
mod settings; mod settings;
mod stdio_agent_server;
#[cfg(test)] #[cfg(test)]
mod e2e_tests; mod e2e_tests;
pub use claude::*; pub use claude::*;
pub use codex::*;
pub use gemini::*; pub use gemini::*;
pub use settings::*; pub use settings::*;
pub use stdio_agent_server::*;
use acp_thread::AgentConnection; use acp_thread::AcpThread;
use anyhow::Result; use anyhow::Result;
use collections::HashMap; use collections::HashMap;
use gpui::{App, AsyncApp, Entity, SharedString, Task}; use gpui::{App, AsyncApp, Entity, SharedString, Task};
@ -21,7 +20,6 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
rc::Rc,
sync::Arc, sync::Arc,
}; };
use util::ResultExt as _; use util::ResultExt as _;
@ -35,14 +33,14 @@ pub trait AgentServer: Send {
fn name(&self) -> &'static str; fn name(&self) -> &'static str;
fn empty_state_headline(&self) -> &'static str; fn empty_state_headline(&self) -> &'static str;
fn empty_state_message(&self) -> &'static str; fn empty_state_message(&self) -> &'static str;
fn supports_always_allow(&self) -> bool;
fn connect( fn new_thread(
&self, &self,
// these will go away when old_acp is fully removed
root_dir: &Path, root_dir: &Path,
project: &Entity<Project>, project: &Entity<Project>,
cx: &mut App, cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>>; ) -> Task<Result<Entity<AcpThread>>>;
} }
impl std::fmt::Debug for AgentServerCommand { impl std::fmt::Debug for AgentServerCommand {

View file

@ -1,35 +1,39 @@
mod mcp_server; mod mcp_server;
pub mod tools; mod tools;
use collections::HashMap; use collections::HashMap;
use context_server::listener::McpServerTool;
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use smol::process::Child; use smol::process::Child;
use std::cell::RefCell; use std::cell::RefCell;
use std::fmt::Display; use std::fmt::Display;
use std::path::Path; use std::path::Path;
use std::pin::pin;
use std::rc::Rc; use std::rc::Rc;
use uuid::Uuid; use uuid::Uuid;
use agent_client_protocol as acp; use agentic_coding_protocol::{
self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams,
};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use futures::channel::oneshot; use futures::channel::oneshot;
use futures::{AsyncBufReadExt, AsyncWriteExt}; use futures::future::LocalBoxFuture;
use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt};
use futures::{ use futures::{
AsyncRead, AsyncWrite, FutureExt, StreamExt, AsyncRead, AsyncWrite, FutureExt, StreamExt,
channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
io::BufReader, io::BufReader,
select_biased, select_biased,
}; };
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use gpui::{App, AppContext, Entity, Task};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::ResultExt; use util::ResultExt;
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::mcp_server::ClaudeMcpServer;
use crate::claude::tools::ClaudeTool; use crate::claude::tools::ClaudeTool;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpThread, AgentConnection}; use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
#[derive(Clone)] #[derive(Clone)]
pub struct ClaudeCode; pub struct ClaudeCode;
@ -44,51 +48,36 @@ impl AgentServer for ClaudeCode {
} }
fn empty_state_message(&self) -> &'static str { fn empty_state_message(&self) -> &'static str {
"How can I help you today?" ""
} }
fn logo(&self) -> ui::IconName { fn logo(&self) -> ui::IconName {
ui::IconName::AiClaude ui::IconName::AiClaude
} }
fn connect( fn supports_always_allow(&self) -> bool {
&self, false
_root_dir: &Path,
_project: &Entity<Project>,
_cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let connection = ClaudeAgentConnection {
sessions: Default::default(),
};
Task::ready(Ok(Rc::new(connection) as _))
}
}
struct ClaudeAgentConnection {
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
impl AgentConnection for ClaudeAgentConnection {
fn name(&self) -> &'static str {
ClaudeCode.name()
} }
fn new_thread( fn new_thread(
self: Rc<Self>, &self,
project: Entity<Project>, root_dir: &Path,
cwd: &Path, project: &Entity<Project>,
cx: &mut AsyncApp, cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> { ) -> Task<Result<Entity<AcpThread>>> {
let cwd = cwd.to_owned(); let project = project.clone();
let root_dir = root_dir.to_path_buf();
let title = self.name().into();
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); let (mut delegate_tx, delegate_rx) = watch::channel(None);
let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
let mut mcp_servers = HashMap::default(); let mut mcp_servers = HashMap::default();
mcp_servers.insert( mcp_servers.insert(
mcp_server::SERVER_NAME.to_string(), mcp_server::SERVER_NAME.to_string(),
permission_mcp_server.server_config()?, mcp_server.server_config()?,
); );
let mcp_config = McpConfig { mcp_servers }; let mcp_config = McpConfig { mcp_servers };
@ -113,158 +102,192 @@ impl AgentConnection for ClaudeAgentConnection {
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
let (cancel_tx, mut cancel_rx) = mpsc::unbounded::<oneshot::Sender<Result<()>>>();
let session_id = acp::SessionId(Uuid::new_v4().to_string().into()); let session_id = Uuid::new_v4();
log::trace!("Starting session with id: {}", session_id); log::trace!("Starting session with id: {}", session_id);
cx.background_spawn({ cx.background_spawn(async move {
let session_id = session_id.clone(); let mut outgoing_rx = Some(outgoing_rx);
async move { let mut mode = ClaudeSessionMode::Start;
let mut outgoing_rx = Some(outgoing_rx);
let mut child = spawn_claude( loop {
&command, let mut child =
ClaudeSessionMode::Start, spawn_claude(&command, mode, session_id, &mcp_config_path, &root_dir)
session_id.clone(), .await?;
&mcp_config_path, mode = ClaudeSessionMode::Resume;
&cwd,
)
.await?;
let pid = child.id(); let pid = child.id();
log::trace!("Spawned (pid: {})", pid); log::trace!("Spawned (pid: {})", pid);
ClaudeAgentSession::handle_io( let mut io_fut = pin!(
outgoing_rx.take().unwrap(), ClaudeAgentConnection::handle_io(
incoming_message_tx.clone(), outgoing_rx.take().unwrap(),
child.stdin.take().unwrap(), incoming_message_tx.clone(),
child.stdout.take().unwrap(), child.stdin.take().unwrap(),
) child.stdout.take().unwrap(),
.await?; )
.fuse()
);
select_biased! {
done_tx = cancel_rx.next() => {
if let Some(done_tx) = done_tx {
log::trace!("Interrupted (pid: {})", pid);
let result = send_interrupt(pid as i32);
outgoing_rx.replace(io_fut.await?);
done_tx.send(result).log_err();
continue;
}
}
result = io_fut => {
result?;
}
}
log::trace!("Stopped (pid: {})", pid); log::trace!("Stopped (pid: {})", pid);
break;
drop(mcp_config_path);
anyhow::Ok(())
} }
drop(mcp_config_path);
anyhow::Ok(())
}) })
.detach(); .detach();
let end_turn_tx = Rc::new(RefCell::new(None)); cx.new(|cx| {
let handler_task = cx.spawn({ let end_turn_tx = Rc::new(RefCell::new(None));
let end_turn_tx = end_turn_tx.clone(); let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
let thread_rx = thread_rx.clone(); delegate_tx.send(Some(delegate.clone())).log_err();
async move |cx| {
while let Some(message) = incoming_message_rx.next().await { let handler_task = cx.foreground_executor().spawn({
ClaudeAgentSession::handle_message( let end_turn_tx = end_turn_tx.clone();
thread_rx.clone(), let tool_id_map = tool_id_map.clone();
message, let delegate = delegate.clone();
end_turn_tx.clone(), async move {
cx, while let Some(message) = incoming_message_rx.next().await {
) ClaudeAgentConnection::handle_message(
.await delegate.clone(),
message,
end_turn_tx.clone(),
tool_id_map.clone(),
)
.await
}
} }
} });
});
let thread = let mut connection = ClaudeAgentConnection {
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; delegate,
outgoing_tx,
end_turn_tx,
cancel_tx,
session_id,
_handler_task: handler_task,
_mcp_server: None,
};
thread_tx.send(thread.downgrade())?; connection._mcp_server = Some(mcp_server);
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
let session = ClaudeAgentSession { })
outgoing_tx,
end_turn_tx,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
self.sessions.borrow_mut().insert(session_id, session);
Ok(thread)
}) })
} }
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> { #[cfg(unix)]
Task::ready(Err(anyhow!("Authentication not supported"))) fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
} let pid = nix::unistd::Pid::from_raw(pid);
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> { nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
let sessions = self.sessions.borrow(); .map_err(|e| anyhow!("Failed to interrupt process: {}", e))
let Some(session) = sessions.get(&params.session_id) else { }
return Task::ready(Err(anyhow!(
"Attempted to send message to nonexistent session {}",
params.session_id
)));
};
let (tx, rx) = oneshot::channel(); #[cfg(windows)]
session.end_turn_tx.borrow_mut().replace(tx); fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
panic!("Cancel not implemented on Windows")
}
let mut content = String::new(); impl AgentConnection for ClaudeAgentConnection {
for chunk in params.prompt { /// Send a request to the agent and wait for a response.
match chunk { fn request_any(
acp::ContentBlock::Text(text_content) => { &self,
content.push_str(&text_content.text); params: AnyAgentRequest,
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
let delegate = self.delegate.clone();
let end_turn_tx = self.end_turn_tx.clone();
let outgoing_tx = self.outgoing_tx.clone();
let mut cancel_tx = self.cancel_tx.clone();
let session_id = self.session_id;
async move {
match params {
// todo: consider sending an empty request so we get the init response?
AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
acp::InitializeResponse {
is_authenticated: true,
protocol_version: ProtocolVersion::latest(),
},
)),
AnyAgentRequest::AuthenticateParams(_) => {
Err(anyhow!("Authentication not supported"))
} }
acp::ContentBlock::ResourceLink(resource_link) => { AnyAgentRequest::SendUserMessageParams(message) => {
content.push_str(&format!("@{}", resource_link.uri)); delegate.clear_completed_plan_entries().await?;
let (tx, rx) = oneshot::channel();
end_turn_tx.borrow_mut().replace(tx);
let mut content = String::new();
for chunk in message.chunks {
match chunk {
agentic_coding_protocol::UserMessageChunk::Text { text } => {
content.push_str(&text)
}
agentic_coding_protocol::UserMessageChunk::Path { path } => {
content.push_str(&format!("@{path:?}"))
}
}
}
outgoing_tx.unbounded_send(SdkMessage::User {
message: Message {
role: Role::User,
content: Content::UntaggedText(content),
id: None,
model: None,
stop_reason: None,
stop_sequence: None,
usage: None,
},
session_id: Some(session_id),
})?;
rx.await??;
Ok(AnyAgentResult::SendUserMessageResponse(
acp::SendUserMessageResponse,
))
} }
acp::ContentBlock::Audio(_) AnyAgentRequest::CancelSendMessageParams(_) => {
| acp::ContentBlock::Image(_) let (done_tx, done_rx) = oneshot::channel();
| acp::ContentBlock::Resource(_) => { cancel_tx.send(done_tx).await?;
// TODO done_rx.await??;
Ok(AnyAgentResult::CancelSendMessageResponse(
acp::CancelSendMessageResponse,
))
} }
} }
} }
.boxed_local()
if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
message: Message {
role: Role::User,
content: Content::UntaggedText(content),
id: None,
model: None,
stop_reason: None,
stop_sequence: None,
usage: None,
},
session_id: Some(params.session_id.to_string()),
}) {
return Task::ready(Err(anyhow!(err)));
}
cx.foreground_executor().spawn(async move {
rx.await??;
Ok(())
})
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
let sessions = self.sessions.borrow();
let Some(session) = sessions.get(&session_id) else {
log::warn!("Attempted to cancel nonexistent session {}", session_id);
return;
};
session
.outgoing_tx
.unbounded_send(SdkMessage::new_interrupt_message())
.log_err();
} }
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
enum ClaudeSessionMode { enum ClaudeSessionMode {
Start, Start,
#[expect(dead_code)]
Resume, Resume,
} }
async fn spawn_claude( async fn spawn_claude(
command: &AgentServerCommand, command: &AgentServerCommand,
mode: ClaudeSessionMode, mode: ClaudeSessionMode,
session_id: acp::SessionId, session_id: Uuid,
mcp_config_path: &Path, mcp_config_path: &Path,
root_dir: &Path, root_dir: &Path,
) -> Result<Child> { ) -> Result<Child> {
@ -282,16 +305,10 @@ async fn spawn_claude(
&format!( &format!(
"mcp__{}__{}", "mcp__{}__{}",
mcp_server::SERVER_NAME, mcp_server::SERVER_NAME,
mcp_server::PermissionTool::NAME, mcp_server::PERMISSION_TOOL
), ),
"--allowedTools", "--allowedTools",
&format!( "mcp__zed__Read,mcp__zed__Edit",
"mcp__{}__{},mcp__{}__{}",
mcp_server::SERVER_NAME,
mcp_server::EditTool::NAME,
mcp_server::SERVER_NAME,
mcp_server::ReadTool::NAME
),
"--disallowedTools", "--disallowedTools",
"Read,Edit", "Read,Edit",
]) ])
@ -310,135 +327,105 @@ async fn spawn_claude(
Ok(child) Ok(child)
} }
struct ClaudeAgentSession { struct ClaudeAgentConnection {
delegate: AcpClientDelegate,
session_id: Uuid,
outgoing_tx: UnboundedSender<SdkMessage>, outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>, end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
_mcp_server: Option<ClaudeZedMcpServer>, cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
_mcp_server: Option<ClaudeMcpServer>,
_handler_task: Task<()>, _handler_task: Task<()>,
} }
impl ClaudeAgentSession { impl ClaudeAgentConnection {
async fn handle_message( async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>, delegate: AcpClientDelegate,
message: SdkMessage, message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>, end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
cx: &mut AsyncApp, tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
) { ) {
match message { match message {
// we should only be sending these out, they don't need to be in the thread SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => {
SdkMessage::ControlRequest { .. } => {}
SdkMessage::Assistant {
message,
session_id: _,
}
| SdkMessage::User {
message,
session_id: _,
} => {
let Some(thread) = thread_rx
.recv()
.await
.log_err()
.and_then(|entity| entity.upgrade())
else {
log::error!("Received an SDK message but thread is gone");
return;
};
for chunk in message.content.chunks() { for chunk in message.content.chunks() {
match chunk { match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
thread delegate
.update(cx, |thread, cx| { .stream_assistant_message_chunk(StreamAssistantMessageChunkParams {
thread.push_assistant_content_block(text.into(), false, cx) chunk: acp::AssistantMessageChunk::Text { text },
}) })
.await
.log_err(); .log_err();
} }
ContentChunk::ToolUse { id, name, input } => { ContentChunk::ToolUse { id, name, input } => {
let claude_tool = ClaudeTool::infer(&name, input); let claude_tool = ClaudeTool::infer(&name, input);
thread if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
.update(cx, |thread, cx| { delegate
if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { .update_plan(acp::UpdatePlanParams {
thread.update_plan( entries: params.todos.into_iter().map(Into::into).collect(),
acp::Plan { })
entries: params .await
.todos .log_err();
.into_iter() } else if let Some(resp) = delegate
.map(Into::into) .push_tool_call(claude_tool.as_acp())
.collect(), .await
}, .log_err()
cx, {
) tool_id_map.borrow_mut().insert(id, resp.id);
} else { }
thread.upsert_tool_call(
claude_tool.as_acp(acp::ToolCallId(id.into())),
cx,
);
}
})
.log_err();
} }
ContentChunk::ToolResult { ContentChunk::ToolResult {
content, content,
tool_use_id, tool_use_id,
} => { } => {
let content = content.to_string(); let id = tool_id_map.borrow_mut().remove(&tool_use_id);
thread if let Some(id) = id {
.update(cx, |thread, cx| { let content = content.to_string();
thread.update_tool_call( delegate
acp::ToolCallUpdate { .update_tool_call(UpdateToolCallParams {
id: acp::ToolCallId(tool_use_id.into()), tool_call_id: id,
fields: acp::ToolCallUpdateFields { status: acp::ToolCallStatus::Finished,
status: Some(acp::ToolCallStatus::Completed), // Don't unset existing content
content: (!content.is_empty()) content: (!content.is_empty()).then_some(
.then(|| vec![content.into()]), ToolCallContent::Markdown {
..Default::default() // For now we only include text content
markdown: content,
}, },
}, ),
cx, })
) .await
}) .log_err();
.log_err(); }
} }
ContentChunk::Image ContentChunk::Image
| ContentChunk::Document | ContentChunk::Document
| ContentChunk::Thinking | ContentChunk::Thinking
| ContentChunk::RedactedThinking | ContentChunk::RedactedThinking
| ContentChunk::WebSearchToolResult => { | ContentChunk::WebSearchToolResult => {
thread delegate
.update(cx, |thread, cx| { .stream_assistant_message_chunk(StreamAssistantMessageChunkParams {
thread.push_assistant_content_block( chunk: acp::AssistantMessageChunk::Text {
format!("Unsupported content: {:?}", chunk).into(), text: format!("Unsupported content: {:?}", chunk),
false, },
cx,
)
}) })
.await
.log_err(); .log_err();
} }
} }
} }
} }
SdkMessage::Result { SdkMessage::Result {
is_error, is_error, subtype, ..
subtype,
result,
..
} => { } => {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
if is_error { if is_error {
end_turn_tx end_turn_tx.send(Err(anyhow!("Error: {subtype}"))).ok();
.send(Err(anyhow!(
"Error: {}",
result.unwrap_or_else(|| subtype.to_string())
)))
.ok();
} else { } else {
end_turn_tx.send(Ok(())).ok(); end_turn_tx.send(Ok(())).ok();
} }
} }
} }
SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {} SdkMessage::System { .. } => {}
} }
} }
@ -605,14 +592,16 @@ enum SdkMessage {
Assistant { Assistant {
message: Message, // from Anthropic SDK message: Message, // from Anthropic SDK
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<String>, session_id: Option<Uuid>,
}, },
// A user message // A user message
User { User {
message: Message, // from Anthropic SDK message: Message, // from Anthropic SDK
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<String>, session_id: Option<Uuid>,
}, },
// Emitted as the last message in a conversation // Emitted as the last message in a conversation
Result { Result {
subtype: ResultErrorType, subtype: ResultErrorType,
@ -637,26 +626,6 @@ enum SdkMessage {
#[serde(rename = "permissionMode")] #[serde(rename = "permissionMode")]
permission_mode: PermissionMode, permission_mode: PermissionMode,
}, },
/// Messages used to control the conversation, outside of chat messages to the model
ControlRequest {
request_id: String,
request: ControlRequest,
},
/// Response to a control request
ControlResponse { response: ControlResponse },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
/// Cancel the current conversation
Interrupt,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
request_id: String,
subtype: ResultErrorType,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -677,24 +646,6 @@ impl Display for ResultErrorType {
} }
} }
impl SdkMessage {
fn new_interrupt_message() -> Self {
use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string,
// `Math.random().toString(36).substring(2, 15)`
let request_id = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(12)
.map(char::from)
.collect();
Self::ControlRequest {
request_id,
request: ControlRequest::Interrupt,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer { struct McpServer {
name: String, name: String,
@ -710,12 +661,27 @@ enum PermissionMode {
Plan, Plan,
} }
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct McpConfig {
mcp_servers: HashMap<String, McpServerConfig>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct McpServerConfig {
command: String,
args: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
env: Option<HashMap<String, String>>,
}
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use super::*; use super::*;
use serde_json::json; use serde_json::json;
crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow"); crate::common_e2e_tests!(ClaudeCode);
pub fn local_command() -> AgentServerCommand { pub fn local_command() -> AgentServerCommand {
AgentServerCommand { AgentServerCommand {

View file

@ -1,53 +1,78 @@
use std::path::PathBuf; use std::{cell::RefCell, rc::Rc};
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; use acp_thread::AcpClientDelegate;
use acp_thread::AcpThread; use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams};
use agent_client_protocol as acp;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use collections::HashMap; use collections::HashMap;
use context_server::listener::{McpServerTool, ToolResponse}; use context_server::{
use context_server::types::{ listener::McpServer,
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, types::{
ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
ToolResponseContent, ToolsCapabilities, requests,
},
}; };
use gpui::{App, AsyncApp, Task, WeakEntity}; use gpui::{App, AsyncApp, Task};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::debug_panic;
pub struct ClaudeZedMcpServer { use crate::claude::{
server: context_server::listener::McpServer, McpServerConfig,
tools::{ClaudeTool, EditToolParams, ReadToolParams},
};
pub struct ClaudeMcpServer {
server: McpServer,
} }
pub const SERVER_NAME: &str = "zed"; pub const SERVER_NAME: &str = "zed";
pub const READ_TOOL: &str = "Read";
pub const EDIT_TOOL: &str = "Edit";
pub const PERMISSION_TOOL: &str = "Confirmation";
impl ClaudeZedMcpServer { #[derive(Deserialize, JsonSchema, Debug)]
struct PermissionToolParams {
tool_name: String,
input: serde_json::Value,
tool_use_id: Option<String>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct PermissionToolResponse {
behavior: PermissionToolBehavior,
updated_input: serde_json::Value,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum PermissionToolBehavior {
Allow,
Deny,
}
impl ClaudeMcpServer {
pub async fn new( pub async fn new(
thread_rx: watch::Receiver<WeakEntity<AcpThread>>, delegate: watch::Receiver<Option<AcpClientDelegate>>,
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
cx: &AsyncApp, cx: &AsyncApp,
) -> Result<Self> { ) -> Result<Self> {
let mut mcp_server = context_server::listener::McpServer::new(cx).await?; let mut mcp_server = McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize); mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
mcp_server.add_tool(PermissionTool { mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
thread_rx: thread_rx.clone(), Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx)
});
mcp_server.add_tool(ReadTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(EditTool {
thread_rx: thread_rx.clone(),
}); });
Ok(Self { server: mcp_server }) Ok(Self { server: mcp_server })
} }
pub fn server_config(&self) -> Result<McpServerConfig> { pub fn server_config(&self) -> Result<McpServerConfig> {
#[cfg(not(test))]
let zed_path = std::env::current_exe() let zed_path = std::env::current_exe()
.context("finding current executable path for use in mcp_server")?; .context("finding current executable path for use in mcp_server")?
.to_string_lossy()
#[cfg(test)] .to_string();
let zed_path = crate::e2e_tests::get_zed_path();
Ok(McpServerConfig { Ok(McpServerConfig {
command: zed_path, command: zed_path,
@ -81,222 +106,191 @@ impl ClaudeZedMcpServer {
}) })
}) })
} }
}
#[derive(Serialize)] fn handle_list_tools(_: (), cx: &App) -> Task<Result<ListToolsResponse>> {
#[serde(rename_all = "camelCase")] cx.foreground_executor().spawn(async move {
pub struct McpConfig { Ok(ListToolsResponse {
pub mcp_servers: HashMap<String, McpServerConfig>, tools: vec![
} Tool {
name: PERMISSION_TOOL.into(),
#[derive(Serialize, Clone)] input_schema: schemars::schema_for!(PermissionToolParams).into(),
#[serde(rename_all = "camelCase")] description: None,
pub struct McpServerConfig { annotations: None,
pub command: PathBuf, },
pub args: Vec<String>, Tool {
#[serde(skip_serializing_if = "Option::is_none")] name: READ_TOOL.into(),
pub env: Option<HashMap<String, String>>, input_schema: schemars::schema_for!(ReadToolParams).into(),
} description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()),
annotations: Some(ToolAnnotations {
// Tools title: Some("Read file".to_string()),
read_only_hint: Some(true),
#[derive(Clone)] destructive_hint: Some(false),
pub struct PermissionTool { open_world_hint: Some(false),
thread_rx: watch::Receiver<WeakEntity<AcpThread>>, // if time passes the contents might change, but it's not going to do anything different
} // true or false seem too strong, let's try a none.
idempotent_hint: None,
#[derive(Deserialize, JsonSchema, Debug)] }),
pub struct PermissionToolParams { },
tool_name: String, Tool {
input: serde_json::Value, name: EDIT_TOOL.into(),
tool_use_id: Option<String>, input_schema: schemars::schema_for!(EditToolParams).into(),
} description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()),
annotations: Some(ToolAnnotations {
#[derive(Serialize)] title: Some("Edit file".to_string()),
#[serde(rename_all = "camelCase")] read_only_hint: Some(false),
pub struct PermissionToolResponse { destructive_hint: Some(false),
behavior: PermissionToolBehavior, open_world_hint: Some(false),
updated_input: serde_json::Value, idempotent_hint: Some(false),
} }),
},
#[derive(Serialize)] ],
#[serde(rename_all = "snake_case")] next_cursor: None,
enum PermissionToolBehavior { meta: None,
Allow, })
Deny, })
}
impl McpServerTool for PermissionTool {
type Input = PermissionToolParams;
type Output = ();
const NAME: &'static str = "Confirmation";
fn description(&self) -> &'static str {
"Request permission for tool calls"
} }
async fn run( fn handle_call_tool(
&self, request: CallToolParams,
input: Self::Input, mut delegate_watch: watch::Receiver<Option<AcpClientDelegate>>,
cx: &mut AsyncApp, tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
) -> Result<ToolResponse<Self::Output>> { cx: &App,
let mut thread_rx = self.thread_rx.clone(); ) -> Task<Result<CallToolResponse>> {
let Some(thread) = thread_rx.recv().await?.upgrade() else { cx.spawn(async move |cx| {
anyhow::bail!("Thread closed"); let Some(delegate) = delegate_watch.recv().await? else {
}; debug_panic!("Sent None delegate");
anyhow::bail!("Server not available");
};
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); if request.name.as_str() == PERMISSION_TOOL {
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); let input =
let allow_option_id = acp::PermissionOptionId("allow".into()); serde_json::from_value(request.arguments.context("Arguments required")?)?;
let reject_option_id = acp::PermissionOptionId("reject".into());
let chosen_option = thread let result =
.update(cx, |thread, cx| { Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?;
thread.request_tool_call_permission( Ok(CallToolResponse {
claude_tool.as_acp(tool_call_id), content: vec![ToolResponseContent::Text {
vec![ text: serde_json::to_string(&result)?,
acp::PermissionOption { }],
id: allow_option_id.clone(), is_error: None,
label: "Allow".into(), meta: None,
kind: acp::PermissionOptionKind::AllowOnce, })
}, } else if request.name.as_str() == READ_TOOL {
acp::PermissionOption { let input =
id: reject_option_id.clone(), serde_json::from_value(request.arguments.context("Arguments required")?)?;
label: "Reject".into(),
kind: acp::PermissionOptionKind::RejectOnce, let content = Self::handle_read_tool_call(input, delegate, cx).await?;
}, Ok(CallToolResponse {
], content,
cx, is_error: None,
meta: None,
})
} else if request.name.as_str() == EDIT_TOOL {
let input =
serde_json::from_value(request.arguments.context("Arguments required")?)?;
Self::handle_edit_tool_call(input, delegate, cx).await?;
Ok(CallToolResponse {
content: vec![],
is_error: None,
meta: None,
})
} else {
anyhow::bail!("Unsupported tool");
}
})
}
fn handle_read_tool_call(
params: ReadToolParams,
delegate: AcpClientDelegate,
cx: &AsyncApp,
) -> Task<Result<Vec<ToolResponseContent>>> {
cx.foreground_executor().spawn(async move {
let response = delegate
.read_text_file(ReadTextFileParams {
path: params.abs_path,
line: params.offset,
limit: params.limit,
})
.await?;
Ok(vec![ToolResponseContent::Text {
text: response.content,
}])
})
}
fn handle_edit_tool_call(
params: EditToolParams,
delegate: AcpClientDelegate,
cx: &AsyncApp,
) -> Task<Result<()>> {
cx.foreground_executor().spawn(async move {
let response = delegate
.read_text_file_reusing_snapshot(ReadTextFileParams {
path: params.abs_path.clone(),
line: None,
limit: None,
})
.await?;
let new_content = response.content.replace(&params.old_text, &params.new_text);
if new_content == response.content {
return Err(anyhow::anyhow!("The old_text was not found in the content"));
}
delegate
.write_text_file(WriteTextFileParams {
path: params.abs_path,
content: new_content,
})
.await?;
Ok(())
})
}
fn handle_permissions_tool_call(
params: PermissionToolParams,
delegate: AcpClientDelegate,
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
cx: &AsyncApp,
) -> Task<Result<PermissionToolResponse>> {
cx.foreground_executor().spawn(async move {
let claude_tool = ClaudeTool::infer(&params.tool_name, params.input.clone());
let tool_call_id = match params.tool_use_id {
Some(tool_use_id) => tool_id_map
.borrow()
.get(&tool_use_id)
.cloned()
.context("Tool call ID not found")?,
None => delegate.push_tool_call(claude_tool.as_acp()).await?.id,
};
let outcome = delegate
.request_existing_tool_call_confirmation(
tool_call_id,
claude_tool.confirmation(None),
) )
})? .await?;
.await?;
let response = if chosen_option == allow_option_id { match outcome {
PermissionToolResponse { acp::ToolCallConfirmationOutcome::Allow
behavior: PermissionToolBehavior::Allow, | acp::ToolCallConfirmationOutcome::AlwaysAllow
updated_input: input.input, | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
| acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: params.input,
}),
acp::ToolCallConfirmationOutcome::Reject
| acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: params.input,
}),
} }
} else {
debug_assert_eq!(chosen_option, reject_option_id);
PermissionToolResponse {
behavior: PermissionToolBehavior::Deny,
updated_input: input.input,
}
};
Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
})
}
}
#[derive(Clone)]
pub struct ReadTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for ReadTool {
type Input = ReadToolParams;
type Output = ();
const NAME: &'static str = "Read";
fn description(&self) -> &'static str {
"Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents."
}
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Read file".to_string()),
read_only_hint: Some(true),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: None,
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![ToolResponseContent::Text { text: content }],
structured_content: (),
})
}
}
#[derive(Clone)]
pub struct EditTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for EditTool {
type Input = EditToolParams;
type Output = ();
const NAME: &'static str = "Edit";
fn description(&self) -> &'static str {
"Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better."
}
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: Some("Edit file".to_string()),
read_only_hint: Some(false),
destructive_hint: Some(false),
open_world_hint: Some(false),
idempotent_hint: Some(false),
}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
})?
.await?;
let new_content = content.replace(&input.old_text, &input.new_text);
if new_content == content {
return Err(anyhow::anyhow!("The old_text was not found in the content"));
}
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.abs_path, new_content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: (),
}) })
} }
} }

View file

@ -1,6 +1,6 @@
use std::path::PathBuf; use std::path::PathBuf;
use agent_client_protocol as acp; use agentic_coding_protocol::{self as acp, PushToolCallParams, ToolCallLocation};
use itertools::Itertools; use itertools::Itertools;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -115,36 +115,51 @@ impl ClaudeTool {
Self::Other { name, .. } => name.clone(), Self::Other { name, .. } => name.clone(),
} }
} }
pub fn content(&self) -> Vec<acp::ToolCallContent> {
pub fn content(&self) -> Option<acp::ToolCallContent> {
match &self { match &self {
Self::Other { input, .. } => vec![ Self::Other { input, .. } => Some(acp::ToolCallContent::Markdown {
format!( markdown: format!(
"```json\n{}```", "```json\n{}```",
serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) serde_json::to_string_pretty(&input).unwrap_or("{}".to_string())
) ),
.into(), }),
], Self::Task(Some(params)) => Some(acp::ToolCallContent::Markdown {
Self::Task(Some(params)) => vec![params.prompt.clone().into()], markdown: params.prompt.clone(),
Self::NotebookRead(Some(params)) => { }),
vec![params.notebook_path.display().to_string().into()] Self::NotebookRead(Some(params)) => Some(acp::ToolCallContent::Markdown {
} markdown: params.notebook_path.display().to_string(),
Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()], }),
Self::Terminal(Some(params)) => vec![ Self::NotebookEdit(Some(params)) => Some(acp::ToolCallContent::Markdown {
format!( markdown: params.new_source.clone(),
}),
Self::Terminal(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: format!(
"`{}`\n\n{}", "`{}`\n\n{}",
params.command, params.command,
params.description.as_deref().unwrap_or_default() params.description.as_deref().unwrap_or_default()
) ),
.into(), }),
], Self::ReadFile(Some(params)) => Some(acp::ToolCallContent::Markdown {
Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()], markdown: params.abs_path.display().to_string(),
Self::Ls(Some(params)) => vec![params.path.display().to_string().into()], }),
Self::Glob(Some(params)) => vec![params.to_string().into()], Self::Ls(Some(params)) => Some(acp::ToolCallContent::Markdown {
Self::Grep(Some(params)) => vec![format!("`{params}`").into()], markdown: params.path.display().to_string(),
Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], }),
Self::WebSearch(Some(params)) => vec![params.to_string().into()], Self::Glob(Some(params)) => Some(acp::ToolCallContent::Markdown {
Self::TodoWrite(Some(params)) => vec![ markdown: params.to_string(),
params }),
Self::Grep(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: format!("`{params}`"),
}),
Self::WebFetch(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: params.prompt.clone(),
}),
Self::WebSearch(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: params.to_string(),
}),
Self::TodoWrite(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: params
.todos .todos
.iter() .iter()
.map(|todo| { .map(|todo| {
@ -159,39 +174,34 @@ impl ClaudeTool {
todo.content todo.content
) )
}) })
.join("\n") .join("\n"),
.into(), }),
], Self::ExitPlanMode(Some(params)) => Some(acp::ToolCallContent::Markdown {
Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], markdown: params.plan.clone(),
Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { }),
Self::Edit(Some(params)) => Some(acp::ToolCallContent::Diff {
diff: acp::Diff { diff: acp::Diff {
path: params.abs_path.clone(), path: params.abs_path.clone(),
old_text: Some(params.old_text.clone()), old_text: Some(params.old_text.clone()),
new_text: params.new_text.clone(), new_text: params.new_text.clone(),
}, },
}], }),
Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { Self::Write(Some(params)) => Some(acp::ToolCallContent::Diff {
diff: acp::Diff { diff: acp::Diff {
path: params.file_path.clone(), path: params.file_path.clone(),
old_text: None, old_text: None,
new_text: params.content.clone(), new_text: params.content.clone(),
}, },
}], }),
Self::MultiEdit(Some(params)) => { Self::MultiEdit(Some(params)) => {
// todo: show multiple edits in a multibuffer? // todo: show multiple edits in a multibuffer?
params params.edits.first().map(|edit| acp::ToolCallContent::Diff {
.edits diff: acp::Diff {
.first() path: params.file_path.clone(),
.map(|edit| { old_text: Some(edit.old_string.clone()),
vec![acp::ToolCallContent::Diff { new_text: edit.new_string.clone(),
diff: acp::Diff { },
path: params.file_path.clone(), })
old_text: Some(edit.old_string.clone()),
new_text: edit.new_string.clone(),
},
}]
})
.unwrap_or_default()
} }
Self::Task(None) Self::Task(None)
| Self::NotebookRead(None) | Self::NotebookRead(None)
@ -207,80 +217,181 @@ impl ClaudeTool {
| Self::ExitPlanMode(None) | Self::ExitPlanMode(None)
| Self::Edit(None) | Self::Edit(None)
| Self::Write(None) | Self::Write(None)
| Self::MultiEdit(None) => vec![], | Self::MultiEdit(None) => None,
} }
} }
pub fn kind(&self) -> acp::ToolKind { pub fn icon(&self) -> acp::Icon {
match self { match self {
Self::Task(_) => acp::ToolKind::Think, Self::Task(_) => acp::Icon::Hammer,
Self::NotebookRead(_) => acp::ToolKind::Read, Self::NotebookRead(_) => acp::Icon::FileSearch,
Self::NotebookEdit(_) => acp::ToolKind::Edit, Self::NotebookEdit(_) => acp::Icon::Pencil,
Self::Edit(_) => acp::ToolKind::Edit, Self::Edit(_) => acp::Icon::Pencil,
Self::MultiEdit(_) => acp::ToolKind::Edit, Self::MultiEdit(_) => acp::Icon::Pencil,
Self::Write(_) => acp::ToolKind::Edit, Self::Write(_) => acp::Icon::Pencil,
Self::ReadFile(_) => acp::ToolKind::Read, Self::ReadFile(_) => acp::Icon::FileSearch,
Self::Ls(_) => acp::ToolKind::Search, Self::Ls(_) => acp::Icon::Folder,
Self::Glob(_) => acp::ToolKind::Search, Self::Glob(_) => acp::Icon::FileSearch,
Self::Grep(_) => acp::ToolKind::Search, Self::Grep(_) => acp::Icon::Regex,
Self::Terminal(_) => acp::ToolKind::Execute, Self::Terminal(_) => acp::Icon::Terminal,
Self::WebSearch(_) => acp::ToolKind::Search, Self::WebSearch(_) => acp::Icon::Globe,
Self::WebFetch(_) => acp::ToolKind::Fetch, Self::WebFetch(_) => acp::Icon::Globe,
Self::TodoWrite(_) => acp::ToolKind::Think, Self::TodoWrite(_) => acp::Icon::LightBulb,
Self::ExitPlanMode(_) => acp::ToolKind::Think, Self::ExitPlanMode(_) => acp::Icon::Hammer,
Self::Other { .. } => acp::ToolKind::Other, Self::Other { .. } => acp::Icon::Hammer,
}
}
pub fn confirmation(&self, description: Option<String>) -> acp::ToolCallConfirmation {
match &self {
Self::Edit(_) | Self::Write(_) | Self::NotebookEdit(_) | Self::MultiEdit(_) => {
acp::ToolCallConfirmation::Edit { description }
}
Self::WebFetch(params) => acp::ToolCallConfirmation::Fetch {
urls: params
.as_ref()
.map(|p| vec![p.url.clone()])
.unwrap_or_default(),
description,
},
Self::Terminal(Some(BashToolParams {
description,
command,
..
})) => acp::ToolCallConfirmation::Execute {
command: command.clone(),
root_command: command.clone(),
description: description.clone(),
},
Self::ExitPlanMode(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {}", params.plan)
} else {
params.plan.clone()
},
},
Self::Task(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {}", params.description)
} else {
params.description.clone()
},
},
Self::Ls(Some(LsToolParams { path, .. }))
| Self::ReadFile(Some(ReadToolParams { abs_path: path, .. })) => {
let path = path.display();
acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {path}")
} else {
path.to_string()
},
}
}
Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
let path = notebook_path.display();
acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {path}")
} else {
path.to_string()
},
}
}
Self::Glob(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {params}")
} else {
params.to_string()
},
},
Self::Grep(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {params}")
} else {
params.to_string()
},
},
Self::WebSearch(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {params}")
} else {
params.to_string()
},
},
Self::TodoWrite(Some(params)) => {
let params = params.todos.iter().map(|todo| &todo.content).join(", ");
acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {params}")
} else {
params
},
}
}
Self::Terminal(None)
| Self::Task(None)
| Self::NotebookRead(None)
| Self::ExitPlanMode(None)
| Self::Ls(None)
| Self::Glob(None)
| Self::Grep(None)
| Self::ReadFile(None)
| Self::WebSearch(None)
| Self::TodoWrite(None)
| Self::Other { .. } => acp::ToolCallConfirmation::Other {
description: description.unwrap_or("".to_string()),
},
} }
} }
pub fn locations(&self) -> Vec<acp::ToolCallLocation> { pub fn locations(&self) -> Vec<acp::ToolCallLocation> {
match &self { match &self {
Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation { Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![ToolCallLocation {
path: abs_path.clone(), path: abs_path.clone(),
line: None, line: None,
}], }],
Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => {
vec![acp::ToolCallLocation { vec![ToolCallLocation {
path: file_path.clone(),
line: None,
}]
}
Self::Write(Some(WriteToolParams { file_path, .. })) => {
vec![acp::ToolCallLocation {
path: file_path.clone(), path: file_path.clone(),
line: None, line: None,
}] }]
} }
Self::Write(Some(WriteToolParams { file_path, .. })) => vec![ToolCallLocation {
path: file_path.clone(),
line: None,
}],
Self::ReadFile(Some(ReadToolParams { Self::ReadFile(Some(ReadToolParams {
abs_path, offset, .. abs_path, offset, ..
})) => vec![acp::ToolCallLocation { })) => vec![ToolCallLocation {
path: abs_path.clone(), path: abs_path.clone(),
line: *offset, line: *offset,
}], }],
Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
vec![acp::ToolCallLocation { vec![ToolCallLocation {
path: notebook_path.clone(), path: notebook_path.clone(),
line: None, line: None,
}] }]
} }
Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => {
vec![acp::ToolCallLocation { vec![ToolCallLocation {
path: notebook_path.clone(), path: notebook_path.clone(),
line: None, line: None,
}] }]
} }
Self::Glob(Some(GlobToolParams { Self::Glob(Some(GlobToolParams {
path: Some(path), .. path: Some(path), ..
})) => vec![acp::ToolCallLocation { })) => vec![ToolCallLocation {
path: path.clone(), path: path.clone(),
line: None, line: None,
}], }],
Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation { Self::Ls(Some(LsToolParams { path, .. })) => vec![ToolCallLocation {
path: path.clone(), path: path.clone(),
line: None, line: None,
}], }],
Self::Grep(Some(GrepToolParams { Self::Grep(Some(GrepToolParams {
path: Some(path), .. path: Some(path), ..
})) => vec![acp::ToolCallLocation { })) => vec![ToolCallLocation {
path: PathBuf::from(path), path: PathBuf::from(path),
line: None, line: None,
}], }],
@ -303,15 +414,12 @@ impl ClaudeTool {
} }
} }
pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall { pub fn as_acp(&self) -> PushToolCallParams {
acp::ToolCall { PushToolCallParams {
id,
kind: self.kind(),
status: acp::ToolCallStatus::InProgress,
label: self.label(), label: self.label(),
content: self.content(), content: self.content(),
icon: self.icon(),
locations: self.locations(), locations: self.locations(),
raw_input: None,
} }
} }
} }

View file

@ -1,317 +0,0 @@
use agent_client_protocol as acp;
use anyhow::anyhow;
use collections::HashMap;
use context_server::listener::McpServerTool;
use context_server::types::requests;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project;
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::cell::RefCell;
use std::rc::Rc;
use std::{path::Path, sync::Arc};
use util::ResultExt;
use anyhow::{Context, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::mcp_server::ZedMcpServer;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
use acp_thread::{AcpThread, AgentConnection};
#[derive(Clone)]
pub struct Codex;
impl AgentServer for Codex {
fn name(&self) -> &'static str {
"Codex"
}
fn empty_state_headline(&self) -> &'static str {
"Welcome to Codex"
}
fn empty_state_message(&self) -> &'static str {
"What can I help with?"
}
fn logo(&self) -> ui::IconName {
ui::IconName::AiOpenAi
}
fn connect(
&self,
_root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let project = project.clone();
cx.spawn(async move |cx| {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).codex.clone()
})?;
let Some(command) =
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
else {
anyhow::bail!("Failed to find codex binary");
};
let client: Arc<ContextServer> = ContextServer::stdio(
ContextServerId("codex-mcp-server".into()),
ContextServerCommand {
path: command.path,
args: command.args,
env: command.env,
},
)
.into();
ContextServer::start(client.clone(), cx).await?;
let (notification_tx, mut notification_rx) = mpsc::unbounded();
client
.client()
.context("Failed to subscribe")?
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
move |notification, _cx| {
let notification_tx = notification_tx.clone();
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(notification) =
serde_json::from_value::<acp::SessionNotification>(notification)
.log_err()
{
notification_tx.unbounded_send(notification).ok();
}
}
});
let sessions = Rc::new(RefCell::new(HashMap::default()));
let notification_handler_task = cx.spawn({
let sessions = sessions.clone();
async move |cx| {
while let Some(notification) = notification_rx.next().await {
CodexConnection::handle_session_notification(
notification,
sessions.clone(),
cx,
)
}
}
});
let connection = CodexConnection {
client,
sessions,
_notification_handler_task: notification_handler_task,
};
Ok(Rc::new(connection) as _)
})
}
}
struct CodexConnection {
client: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
_notification_handler_task: Task<()>,
}
struct CodexSession {
thread: WeakEntity<AcpThread>,
cancel_tx: Option<oneshot::Sender<()>>,
_mcp_server: ZedMcpServer,
}
impl AgentConnection for CodexConnection {
fn name(&self) -> &'static str {
"Codex"
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let client = self.client.client();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let client = client.context("MCP server is not initialized yet")?;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
let response = client
.request::<requests::CallTool>(context_server::types::CallToolParams {
name: acp::NEW_SESSION_TOOL_NAME.into(),
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
mcp_servers: [(
mcp_server::SERVER_NAME.to_string(),
mcp_server.server_config()?,
)]
.into(),
client_tools: acp::ClientTools {
request_permission: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
}),
read_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
}),
write_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
}),
},
cwd,
})?),
meta: None,
})
.await?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
let result = serde_json::from_value::<acp::NewSessionOutput>(
response.structured_content.context("Empty response")?,
)?;
let thread =
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
thread_tx.send(thread.downgrade())?;
let session = CodexSession {
thread: thread.downgrade(),
cancel_tx: None,
_mcp_server: mcp_server,
};
sessions.borrow_mut().insert(result.session_id, session);
Ok(thread)
})
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported")))
}
fn prompt(
&self,
params: agent_client_protocol::PromptArguments,
cx: &mut App,
) -> Task<Result<()>> {
let client = self.client.client();
let sessions = self.sessions.clone();
cx.foreground_executor().spawn(async move {
let client = client.context("MCP server is not initialized yet")?;
let (new_cancel_tx, cancel_rx) = oneshot::channel();
{
let mut sessions = sessions.borrow_mut();
let session = sessions
.get_mut(&params.session_id)
.context("Session not found")?;
session.cancel_tx.replace(new_cancel_tx);
}
let result = client
.request_with::<requests::CallTool>(
context_server::types::CallToolParams {
name: acp::PROMPT_TOOL_NAME.into(),
arguments: Some(serde_json::to_value(params)?),
meta: None,
},
Some(cancel_rx),
None,
)
.await;
if let Err(err) = &result
&& err.is::<context_server::client::RequestCanceled>()
{
return Ok(());
}
let response = result?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
Ok(())
})
}
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
let mut sessions = self.sessions.borrow_mut();
if let Some(cancel_tx) = sessions
.get_mut(session_id)
.and_then(|session| session.cancel_tx.take())
{
cancel_tx.send(()).ok();
}
}
}
impl CodexConnection {
pub fn handle_session_notification(
notification: acp::SessionNotification,
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
cx: &mut AsyncApp,
) {
let threads = threads.borrow();
let Some(thread) = threads
.get(&notification.session_id)
.and_then(|session| session.thread.upgrade())
else {
log::error!(
"Thread not found for session ID: {}",
notification.session_id
);
return;
};
thread
.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})
.log_err();
}
}
impl Drop for CodexConnection {
fn drop(&mut self) {
self.client.stop().log_err();
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::AgentServerCommand;
use std::path::Path;
crate::common_e2e_tests!(Codex, allow_option_id = "approve");
pub fn local_command() -> AgentServerCommand {
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../../../codex/codex-rs/target/debug/codex");
AgentServerCommand {
path: cli_path,
args: vec![],
env: None,
}
}
}

View file

@ -1,17 +1,15 @@
use std::{ use std::{path::Path, sync::Arc, time::Duration};
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; use acp_thread::{
use agent_client_protocol as acp; AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
};
use agentic_coding_protocol as acp;
use futures::{FutureExt, StreamExt, channel::mpsc, select}; use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::{Entity, TestAppContext}; use gpui::{Entity, TestAppContext};
use indoc::indoc; use indoc::indoc;
use project::{FakeFs, Project}; use project::{FakeFs, Project};
use serde_json::json;
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
use util::path; use util::path;
@ -26,11 +24,7 @@ pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppCont
.unwrap(); .unwrap();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
assert!( assert_eq!(thread.entries().len(), 2);
thread.entries().len() >= 2,
"Expected at least 2 entries. Got: {:?}",
thread.entries()
);
assert!(matches!( assert!(matches!(
thread.entries()[0], thread.entries()[0],
AgentThreadEntry::UserMessage(_) AgentThreadEntry::UserMessage(_)
@ -60,25 +54,19 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send( thread.send(
vec![ acp::SendUserMessageParams {
acp::ContentBlock::Text(acp::TextContent { chunks: vec![
text: "Read the file ".into(), acp::UserMessageChunk::Text {
annotations: None, text: "Read the file ".into(),
}), },
acp::ContentBlock::ResourceLink(acp::ResourceLink { acp::UserMessageChunk::Path {
uri: "foo.rs".into(), path: Path::new("foo.rs").into(),
name: "foo.rs".into(), },
annotations: None, acp::UserMessageChunk::Text {
description: None, text: " and tell me what the content of the println! is".into(),
mime_type: None, },
size: None, ],
title: None, },
}),
acp::ContentBlock::Text(acp::TextContent {
text: " and tell me what the content of the println! is".into(),
annotations: None,
}),
],
cx, cx,
) )
}) })
@ -86,44 +74,37 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes
.unwrap(); .unwrap();
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, cx| {
assert_eq!(thread.entries().len(), 3);
assert!(matches!( assert!(matches!(
thread.entries()[0], thread.entries()[0],
AgentThreadEntry::UserMessage(_) AgentThreadEntry::UserMessage(_)
)); ));
let assistant_message = &thread assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
.entries() let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
.iter() panic!("Expected AssistantMessage")
.rev() };
.find_map(|entry| match entry {
AgentThreadEntry::AssistantMessage(msg) => Some(msg),
_ => None,
})
.unwrap();
assert!( assert!(
assistant_message.to_markdown(cx).contains("Hello, world!"), assistant_message.to_markdown(cx).contains("Hello, world!"),
"unexpected assistant message: {:?}", "unexpected assistant message: {:?}",
assistant_message.to_markdown(cx) assistant_message.to_markdown(cx)
); );
}); });
drop(tempdir);
} }
pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) { pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
let _fs = init_test(cx).await; let fs = init_test(cx).await;
fs.insert_tree(
let tempdir = tempfile::tempdir().unwrap(); path!("/private/tmp"),
let foo_path = tempdir.path().join("foo"); json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file"); )
.await;
let project = Project::example([tempdir.path()], &mut cx.to_async()).await; let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send_raw( thread.send_raw(
&format!("Read {} and tell me what you see.", foo_path.display()), "Read the '/private/tmp/foo' file and tell me what you see.",
cx, cx,
) )
}) })
@ -146,13 +127,10 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp
.any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) }) .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
); );
}); });
drop(tempdir);
} }
pub async fn test_tool_call_with_confirmation( pub async fn test_tool_call_with_confirmation(
server: impl AgentServer + 'static, server: impl AgentServer + 'static,
allow_option_id: acp::PermissionOptionId,
cx: &mut TestAppContext, cx: &mut TestAppContext,
) { ) {
let fs = init_test(cx).await; let fs = init_test(cx).await;
@ -160,7 +138,7 @@ pub async fn test_tool_call_with_confirmation(
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| { let full_turn = thread.update(cx, |thread, cx| {
thread.send_raw( thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
cx, cx,
) )
}); });
@ -180,11 +158,14 @@ pub async fn test_tool_call_with_confirmation(
) )
.await; .await;
let tool_call_id = thread.read_with(cx, |thread, cx| { let tool_call_id = thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
id, id,
label, status:
status: ToolCallStatus::WaitingForConfirmation { .. }, ToolCallStatus::WaitingForConfirmation {
confirmation: ToolCallConfirmation::Execute { root_command, .. },
..
},
.. ..
}) = &thread }) = &thread
.entries() .entries()
@ -195,19 +176,13 @@ pub async fn test_tool_call_with_confirmation(
panic!(); panic!();
}; };
let label = label.read(cx).source(); assert!(root_command.contains("touch"));
assert!(label.contains("touch"), "Got: {}", label);
id.clone() *id
}); });
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.authorize_tool_call( thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
tool_call_id,
allow_option_id,
acp::PermissionOptionKind::AllowOnce,
cx,
);
assert!(thread.entries().iter().any(|entry| matches!( assert!(thread.entries().iter().any(|entry| matches!(
entry, entry,
@ -222,7 +197,7 @@ pub async fn test_tool_call_with_confirmation(
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, cx| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
content, content: Some(ToolCallContent::Markdown { markdown }),
status: ToolCallStatus::Allowed { .. }, status: ToolCallStatus::Allowed { .. },
.. ..
}) = thread }) = thread
@ -234,10 +209,13 @@ pub async fn test_tool_call_with_confirmation(
panic!(); panic!();
}; };
assert!( markdown.read_with(cx, |md, _cx| {
content.iter().any(|c| c.to_markdown(cx).contains("Hello")), assert!(
"Expected content to contain 'Hello'" md.source().contains("Hello"),
); r#"Expected '{}' to contain "Hello""#,
md.source()
);
});
}); });
} }
@ -248,7 +226,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| { let full_turn = thread.update(cx, |thread, cx| {
thread.send_raw( thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
cx, cx,
) )
}); });
@ -268,24 +246,29 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
) )
.await; .await;
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
id, id,
label, status:
status: ToolCallStatus::WaitingForConfirmation { .. }, ToolCallStatus::WaitingForConfirmation {
confirmation: ToolCallConfirmation::Execute { root_command, .. },
..
},
.. ..
}) = &thread.entries()[first_tool_call_ix] }) = &thread.entries()[first_tool_call_ix]
else { else {
panic!("{:?}", thread.entries()[1]); panic!("{:?}", thread.entries()[1]);
}; };
let label = label.read(cx).source(); assert!(root_command.contains("touch"));
assert!(label.contains("touch"), "Got: {}", label);
id.clone() *id
}); });
let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); thread
.update(cx, |thread, cx| thread.cancel(cx))
.await
.unwrap();
full_turn.await.unwrap(); full_turn.await.unwrap();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
@ -313,7 +296,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
#[macro_export] #[macro_export]
macro_rules! common_e2e_tests { macro_rules! common_e2e_tests {
($server:expr, allow_option_id = $allow_option_id:expr) => { ($server:expr) => {
mod common_e2e { mod common_e2e {
use super::*; use super::*;
@ -338,12 +321,7 @@ macro_rules! common_e2e_tests {
#[::gpui::test] #[::gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)] #[cfg_attr(not(feature = "e2e"), ignore)]
async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) { async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
$crate::e2e_tests::test_tool_call_with_confirmation( $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
$server,
::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
cx,
)
.await;
} }
#[::gpui::test] #[::gpui::test]
@ -375,9 +353,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
gemini: Some(AgentServerSettings { gemini: Some(AgentServerSettings {
command: crate::gemini::tests::local_command(), command: crate::gemini::tests::local_command(),
}), }),
codex: Some(AgentServerSettings {
command: crate::codex::tests::local_command(),
}),
}, },
cx, cx,
); );
@ -394,16 +369,15 @@ pub async fn new_test_thread(
current_dir: impl AsRef<Path>, current_dir: impl AsRef<Path>,
cx: &mut TestAppContext, cx: &mut TestAppContext,
) -> Entity<AcpThread> { ) -> Entity<AcpThread> {
let connection = cx let thread = cx
.update(|cx| server.connect(current_dir.as_ref(), &project, cx)) .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
.await .await
.unwrap(); .unwrap();
let thread = connection thread
.new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) .update(cx, |thread, _| thread.initialize())
.await .await
.unwrap(); .unwrap();
thread thread
} }
@ -436,24 +410,3 @@ pub async fn run_until_first_tool_call(
} }
} }
} }
pub fn get_zed_path() -> PathBuf {
let mut zed_path = std::env::current_exe().unwrap();
while zed_path
.file_name()
.map_or(true, |name| name.to_string_lossy() != "debug")
{
if !zed_path.pop() {
panic!("Could not find target directory");
}
}
zed_path.push("zed");
if !zed_path.exists() {
panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
}
zed_path
}

View file

@ -1,17 +1,9 @@
use anyhow::anyhow; use crate::stdio_agent_server::StdioAgentServer;
use std::cell::RefCell; use crate::{AgentServerCommand, AgentServerVersion};
use std::path::Path;
use std::rc::Rc;
use util::ResultExt as _;
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
use agentic_coding_protocol as acp_old;
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result};
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use gpui::{AsyncApp, Entity};
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use ui::App;
use crate::AllAgentServersSettings; use crate::AllAgentServersSettings;
@ -20,7 +12,7 @@ pub struct Gemini;
const ACP_ARG: &str = "--experimental-acp"; const ACP_ARG: &str = "--experimental-acp";
impl AgentServer for Gemini { impl StdioAgentServer for Gemini {
fn name(&self) -> &'static str { fn name(&self) -> &'static str {
"Gemini" "Gemini"
} }
@ -33,89 +25,14 @@ impl AgentServer for Gemini {
"Ask questions, edit files, run commands.\nBe specific for the best results." "Ask questions, edit files, run commands.\nBe specific for the best results."
} }
fn supports_always_allow(&self) -> bool {
true
}
fn logo(&self) -> ui::IconName { fn logo(&self) -> ui::IconName {
ui::IconName::AiGemini ui::IconName::AiGemini
} }
fn connect(
&self,
root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let root_dir = root_dir.to_path_buf();
let project = project.clone();
let this = self.clone();
let name = self.name();
cx.spawn(async move |cx| {
let command = this.command(&project, cx).await?;
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let foreground_executor = cx.foreground_executor().clone();
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => {
if let Some(AgentServerVersion::Unsupported {
error_message,
upgrade_message,
upgrade_command,
}) = this.version(&command).await.log_err()
{
Err(anyhow!(LoadError::Unsupported {
error_message,
upgrade_message,
upgrade_command
}))
} else {
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
}
}
};
drop(io_task);
result
});
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
name,
connection,
child_status,
current_thread: thread_rc,
});
Ok(connection)
})
}
}
impl Gemini {
async fn command( async fn command(
&self, &self,
project: &Entity<Project>, project: &Entity<Project>,
@ -189,7 +106,7 @@ pub(crate) mod tests {
use crate::AgentServerCommand; use crate::AgentServerCommand;
use std::path::Path; use std::path::Path;
crate::common_e2e_tests!(Gemini, allow_option_id = "0"); crate::common_e2e_tests!(Gemini);
pub fn local_command() -> AgentServerCommand { pub fn local_command() -> AgentServerCommand {
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))

View file

@ -1,207 +0,0 @@
use acp_thread::AcpThread;
use agent_client_protocol as acp;
use anyhow::Result;
use context_server::listener::{McpServerTool, ToolResponse};
use context_server::types::{
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
ToolsCapabilities, requests,
};
use futures::channel::oneshot;
use gpui::{App, AsyncApp, Task, WeakEntity};
use indoc::indoc;
pub struct ZedMcpServer {
server: context_server::listener::McpServer,
}
pub const SERVER_NAME: &str = "zed";
impl ZedMcpServer {
pub async fn new(
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
cx: &AsyncApp,
) -> Result<Self> {
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.add_tool(RequestPermissionTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(ReadTextFileTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(WriteTextFileTool {
thread_rx: thread_rx.clone(),
});
Ok(Self { server: mcp_server })
}
pub fn server_config(&self) -> Result<acp::McpServerConfig> {
#[cfg(not(test))]
let zed_path = anyhow::Context::context(
std::env::current_exe(),
"finding current executable path for use in mcp_server",
)?;
#[cfg(test)]
let zed_path = crate::e2e_tests::get_zed_path();
Ok(acp::McpServerConfig {
command: zed_path,
args: vec![
"--nc".into(),
self.server.socket_path().display().to_string(),
],
env: None,
})
}
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
cx.foreground_executor().spawn(async move {
Ok(InitializeResponse {
protocol_version: ProtocolVersion("2025-06-18".into()),
capabilities: ServerCapabilities {
experimental: None,
logging: None,
completions: None,
prompts: None,
resources: None,
tools: Some(ToolsCapabilities {
list_changed: Some(false),
}),
},
server_info: Implementation {
name: SERVER_NAME.into(),
version: "0.1.0".into(),
},
meta: None,
})
})
}
}
// Tools
#[derive(Clone)]
pub struct RequestPermissionTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for RequestPermissionTool {
type Input = acp::RequestPermissionArguments;
type Output = acp::RequestPermissionOutput;
const NAME: &'static str = "Confirmation";
fn description(&self) -> &'static str {
indoc! {"
Request permission for tool calls.
This tool is meant to be called programmatically by the agent loop, not the LLM.
"}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let result = thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(input.tool_call, input.options, cx)
})?
.await;
let outcome = match result {
Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
};
Ok(ToolResponse {
content: vec![],
structured_content: acp::RequestPermissionOutput { outcome },
})
}
}
#[derive(Clone)]
pub struct ReadTextFileTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for ReadTextFileTool {
type Input = acp::ReadTextFileArguments;
type Output = acp::ReadTextFileOutput;
const NAME: &'static str = "Read";
fn description(&self) -> &'static str {
"Reads the content of the given file in the project including unsaved changes."
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.path, input.line, input.limit, false, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: acp::ReadTextFileOutput { content },
})
}
}
#[derive(Clone)]
pub struct WriteTextFileTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for WriteTextFileTool {
type Input = acp::WriteTextFileArguments;
type Output = ();
const NAME: &'static str = "Write";
fn description(&self) -> &'static str {
"Write to a file replacing its contents"
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.path, input.content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: (),
})
}
}

View file

@ -13,7 +13,6 @@ pub fn init(cx: &mut App) {
pub struct AllAgentServersSettings { pub struct AllAgentServersSettings {
pub gemini: Option<AgentServerSettings>, pub gemini: Option<AgentServerSettings>,
pub claude: Option<AgentServerSettings>, pub claude: Option<AgentServerSettings>,
pub codex: Option<AgentServerSettings>,
} }
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] #[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings {
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> { fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let mut settings = AllAgentServersSettings::default(); let mut settings = AllAgentServersSettings::default();
for AllAgentServersSettings { for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
gemini,
claude,
codex,
} in sources.defaults_and_customizations()
{
if gemini.is_some() { if gemini.is_some() {
settings.gemini = gemini.clone(); settings.gemini = gemini.clone();
} }
if claude.is_some() { if claude.is_some() {
settings.claude = claude.clone(); settings.claude = claude.clone();
} }
if codex.is_some() {
settings.codex = codex.clone();
}
} }
Ok(settings) Ok(settings)

View file

@ -0,0 +1,119 @@
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
use acp_thread::{AcpClientDelegate, AcpThread, LoadError};
use agentic_coding_protocol as acp;
use anyhow::{Result, anyhow};
use gpui::{App, AsyncApp, Entity, Task, prelude::*};
use project::Project;
use std::path::Path;
use util::ResultExt;
pub trait StdioAgentServer: Send + Clone {
fn logo(&self) -> ui::IconName;
fn name(&self) -> &'static str;
fn empty_state_headline(&self) -> &'static str;
fn empty_state_message(&self) -> &'static str;
fn supports_always_allow(&self) -> bool;
fn command(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> impl Future<Output = Result<AgentServerCommand>>;
fn version(
&self,
command: &AgentServerCommand,
) -> impl Future<Output = Result<AgentServerVersion>> + Send;
}
impl<T: StdioAgentServer + 'static> AgentServer for T {
fn name(&self) -> &'static str {
self.name()
}
fn empty_state_headline(&self) -> &'static str {
self.empty_state_headline()
}
fn empty_state_message(&self) -> &'static str {
self.empty_state_message()
}
fn logo(&self) -> ui::IconName {
self.logo()
}
fn supports_always_allow(&self) -> bool {
self.supports_always_allow()
}
fn new_thread(
&self,
root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let root_dir = root_dir.to_path_buf();
let project = project.clone();
let this = self.clone();
let title = self.name().into();
cx.spawn(async move |cx| {
let command = this.command(&project, cx).await?;
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
cx.new(|cx| {
let foreground_executor = cx.foreground_executor().clone();
let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => {
if let Some(AgentServerVersion::Unsupported {
error_message,
upgrade_message,
upgrade_command,
}) = this.version(&command).await.log_err()
{
Err(anyhow!(LoadError::Unsupported {
error_message,
upgrade_message,
upgrade_command
}))
} else {
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
}
}
};
drop(io_task);
result
});
AcpThread::new(connection, title, Some(child_status), project.clone(), cx)
})
})
}
}

View file

@ -17,10 +17,10 @@ test-support = ["gpui/test-support", "language/test-support"]
[dependencies] [dependencies]
acp_thread.workspace = true acp_thread.workspace = true
agent-client-protocol.workspace = true
agent.workspace = true agent.workspace = true
agent_servers.workspace = true agentic-coding-protocol.workspace = true
agent_settings.workspace = true agent_settings.workspace = true
agent_servers.workspace = true
ai_onboarding.workspace = true ai_onboarding.workspace = true
anyhow.workspace = true anyhow.workspace = true
assistant_context.workspace = true assistant_context.workspace = true

View file

@ -1,4 +1,4 @@
use acp_thread::{AgentConnection, Plan}; use acp_thread::Plan;
use agent_servers::AgentServer; use agent_servers::AgentServer;
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::BTreeMap; use std::collections::BTreeMap;
@ -7,7 +7,7 @@ use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use agent_client_protocol as acp; use agentic_coding_protocol::{self as acp};
use assistant_tool::ActionLog; use assistant_tool::ActionLog;
use buffer_diff::BufferDiff; use buffer_diff::BufferDiff;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
@ -16,6 +16,7 @@ use editor::{
EditorStyle, MinimapVisibility, MultiBuffer, PathKey, EditorStyle, MinimapVisibility, MultiBuffer, PathKey,
}; };
use file_icons::FileIcons; use file_icons::FileIcons;
use futures::channel::oneshot;
use gpui::{ use gpui::{
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement,
@ -38,7 +39,8 @@ use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
use ::acp_thread::{ use ::acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent,
ToolCallId, ToolCallStatus,
}; };
use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet};
@ -62,13 +64,12 @@ pub struct AcpThreadView {
last_error: Option<Entity<Markdown>>, last_error: Option<Entity<Markdown>>,
list_state: ListState, list_state: ListState,
auth_task: Option<Task<()>>, auth_task: Option<Task<()>>,
expanded_tool_calls: HashSet<acp::ToolCallId>, expanded_tool_calls: HashSet<ToolCallId>,
expanded_thinking_blocks: HashSet<(usize, usize)>, expanded_thinking_blocks: HashSet<(usize, usize)>,
edits_expanded: bool, edits_expanded: bool,
plan_expanded: bool, plan_expanded: bool,
editor_expanded: bool, editor_expanded: bool,
message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>, message_history: Rc<RefCell<MessageHistory<acp::SendUserMessageParams>>>,
_cancel_task: Option<Task<()>>,
} }
enum ThreadState { enum ThreadState {
@ -81,16 +82,22 @@ enum ThreadState {
}, },
LoadError(LoadError), LoadError(LoadError),
Unauthenticated { Unauthenticated {
connection: Rc<dyn AgentConnection>, thread: Entity<AcpThread>,
}, },
} }
struct AlwaysAllowOption {
id: &'static str,
label: SharedString,
outcome: acp::ToolCallConfirmationOutcome,
}
impl AcpThreadView { impl AcpThreadView {
pub fn new( pub fn new(
agent: Rc<dyn AgentServer>, agent: Rc<dyn AgentServer>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
project: Entity<Project>, project: Entity<Project>,
message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>, message_history: Rc<RefCell<MessageHistory<acp::SendUserMessageParams>>>,
min_lines: usize, min_lines: usize,
max_lines: Option<usize>, max_lines: Option<usize>,
window: &mut Window, window: &mut Window,
@ -184,7 +191,6 @@ impl AcpThreadView {
plan_expanded: false, plan_expanded: false,
editor_expanded: false, editor_expanded: false,
message_history, message_history,
_cancel_task: None,
} }
} }
@ -202,9 +208,9 @@ impl AcpThreadView {
.map(|worktree| worktree.read(cx).abs_path()) .map(|worktree| worktree.read(cx).abs_path())
.unwrap_or_else(|| paths::home_dir().as_path().into()); .unwrap_or_else(|| paths::home_dir().as_path().into());
let connect_task = agent.connect(&root_dir, &project, cx); let task = agent.new_thread(&root_dir, &project, cx);
let load_task = cx.spawn_in(window, async move |this, cx| { let load_task = cx.spawn_in(window, async move |this, cx| {
let connection = match connect_task.await { let thread = match task.await {
Ok(thread) => thread, Ok(thread) => thread,
Err(err) => { Err(err) => {
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
@ -216,30 +222,48 @@ impl AcpThreadView {
} }
}; };
let result = match connection let init_response = async {
.clone() let resp = thread
.new_thread(project.clone(), &root_dir, cx) .read_with(cx, |thread, _cx| thread.initialize())?
.await .await?;
{ anyhow::Ok(resp)
};
let result = match init_response.await {
Err(e) => { Err(e) => {
let mut cx = cx.clone(); let mut cx = cx.clone();
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() { if e.downcast_ref::<oneshot::Canceled>().is_some() {
this.update(&mut cx, |this, cx| { let child_status = thread
this.thread_state = ThreadState::Unauthenticated { connection }; .update(&mut cx, |thread, _| thread.child_status())
cx.notify(); .ok()
}) .flatten();
.ok(); if let Some(child_status) = child_status {
return; match child_status.await {
Ok(_) => Err(e),
Err(e) => Err(e),
}
} else {
Err(e)
}
} else { } else {
Err(e) Err(e)
} }
} }
Ok(session_id) => Ok(session_id), Ok(response) => {
if !response.is_authenticated {
this.update(cx, |this, _| {
this.thread_state = ThreadState::Unauthenticated { thread };
})
.ok();
return;
};
Ok(())
}
}; };
this.update_in(cx, |this, window, cx| { this.update_in(cx, |this, window, cx| {
match result { match result {
Ok(thread) => { Ok(()) => {
let thread_subscription = let thread_subscription =
cx.subscribe_in(&thread, window, Self::handle_thread_event); cx.subscribe_in(&thread, window, Self::handle_thread_event);
@ -281,10 +305,10 @@ impl AcpThreadView {
pub fn thread(&self) -> Option<&Entity<AcpThread>> { pub fn thread(&self) -> Option<&Entity<AcpThread>> {
match &self.thread_state { match &self.thread_state {
ThreadState::Ready { thread, .. } => Some(thread), ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => {
ThreadState::Unauthenticated { .. } Some(thread)
| ThreadState::Loading { .. } }
| ThreadState::LoadError(..) => None, ThreadState::Loading { .. } | ThreadState::LoadError(..) => None,
} }
} }
@ -301,7 +325,7 @@ impl AcpThreadView {
self.last_error.take(); self.last_error.take();
if let Some(thread) = self.thread() { if let Some(thread) = self.thread() {
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
} }
} }
@ -338,7 +362,7 @@ impl AcpThreadView {
self.last_error.take(); self.last_error.take();
let mut ix = 0; let mut ix = 0;
let mut chunks: Vec<acp::ContentBlock> = Vec::new(); let mut chunks: Vec<acp::UserMessageChunk> = Vec::new();
let project = self.project.clone(); let project = self.project.clone();
self.message_editor.update(cx, |editor, cx| { self.message_editor.update(cx, |editor, cx| {
let text = editor.text(cx); let text = editor.text(cx);
@ -350,19 +374,12 @@ impl AcpThreadView {
{ {
let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot);
if crease_range.start > ix { if crease_range.start > ix {
chunks.push(text[ix..crease_range.start].into()); chunks.push(acp::UserMessageChunk::Text {
text: text[ix..crease_range.start].to_string(),
});
} }
if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) {
let path_str = abs_path.display().to_string(); chunks.push(acp::UserMessageChunk::Path { path: abs_path });
chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink {
uri: path_str.clone(),
name: path_str,
annotations: None,
description: None,
mime_type: None,
size: None,
title: None,
}));
} }
ix = crease_range.end; ix = crease_range.end;
} }
@ -371,7 +388,9 @@ impl AcpThreadView {
if ix < text.len() { if ix < text.len() {
let last_chunk = text[ix..].trim(); let last_chunk = text[ix..].trim();
if !last_chunk.is_empty() { if !last_chunk.is_empty() {
chunks.push(last_chunk.into()); chunks.push(acp::UserMessageChunk::Text {
text: last_chunk.into(),
});
} }
} }
}) })
@ -382,7 +401,8 @@ impl AcpThreadView {
} }
let Some(thread) = self.thread() else { return }; let Some(thread) = self.thread() else { return };
let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); let message = acp::SendUserMessageParams { chunks };
let task = thread.update(cx, |thread, cx| thread.send(message.clone(), cx));
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let result = task.await; let result = task.await;
@ -404,7 +424,7 @@ impl AcpThreadView {
editor.remove_creases(mention_set.lock().drain(), cx) editor.remove_creases(mention_set.lock().drain(), cx)
}); });
self.message_history.borrow_mut().push(chunks); self.message_history.borrow_mut().push(message);
} }
fn previous_history_message( fn previous_history_message(
@ -470,7 +490,7 @@ impl AcpThreadView {
message_editor: Entity<Editor>, message_editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>, mention_set: Arc<Mutex<MentionSet>>,
project: Entity<Project>, project: Entity<Project>,
message: Option<&Vec<acp::ContentBlock>>, message: Option<&acp::SendUserMessageParams>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> bool { ) -> bool {
@ -483,19 +503,18 @@ impl AcpThreadView {
let mut text = String::new(); let mut text = String::new();
let mut mentions = Vec::new(); let mut mentions = Vec::new();
for chunk in message { for chunk in &message.chunks {
match chunk { match chunk {
acp::ContentBlock::Text(text_content) => { acp::UserMessageChunk::Text { text: chunk } => {
text.push_str(&text_content.text); text.push_str(&chunk);
} }
acp::ContentBlock::ResourceLink(resource_link) => { acp::UserMessageChunk::Path { path } => {
let path = Path::new(&resource_link.uri);
let start = text.len(); let start = text.len();
let content = MentionPath::new(&path).to_string(); let content = MentionPath::new(path).to_string();
text.push_str(&content); text.push_str(&content);
let end = text.len(); let end = text.len();
if let Some(project_path) = if let Some(project_path) =
project.read(cx).project_path_for_absolute_path(&path, cx) project.read(cx).project_path_for_absolute_path(path, cx)
{ {
let filename: SharedString = path let filename: SharedString = path
.file_name() .file_name()
@ -506,9 +525,6 @@ impl AcpThreadView {
mentions.push((start..end, project_path, filename)); mentions.push((start..end, project_path, filename));
} }
} }
acp::ContentBlock::Image(_)
| acp::ContentBlock::Audio(_)
| acp::ContentBlock::Resource(_) => {}
} }
} }
@ -574,79 +590,71 @@ impl AcpThreadView {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { let Some(multibuffer) = self.entry_diff_multibuffer(entry_ix, cx) else {
return; return;
}; };
let multibuffers = multibuffers.collect::<Vec<_>>(); if self.diff_editors.contains_key(&multibuffer.entity_id()) {
return;
for multibuffer in multibuffers {
if self.diff_editors.contains_key(&multibuffer.entity_id()) {
return;
}
let editor = cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
None,
window,
cx,
);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_vertical_scrollbar(false, cx);
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_indent_guides(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
});
editor
});
let entity_id = multibuffer.entity_id();
cx.observe_release(&multibuffer, move |this, _, _| {
this.diff_editors.remove(&entity_id);
})
.detach();
self.diff_editors.insert(entity_id, editor);
} }
let editor = cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
None,
window,
cx,
);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_vertical_scrollbar(false, cx);
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_indent_guides(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
});
editor
});
let entity_id = multibuffer.entity_id();
cx.observe_release(&multibuffer, move |this, _, _| {
this.diff_editors.remove(&entity_id);
})
.detach();
self.diff_editors.insert(entity_id, editor);
} }
fn entry_diff_multibuffers( fn entry_diff_multibuffer(&self, entry_ix: usize, cx: &App) -> Option<Entity<MultiBuffer>> {
&self,
entry_ix: usize,
cx: &App,
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
let entry = self.thread()?.read(cx).entries().get(entry_ix)?; let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
Some(entry.diffs().map(|diff| diff.multibuffer.clone())) entry.diff().map(|diff| diff.multibuffer.clone())
} }
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let ThreadState::Unauthenticated { ref connection } = self.thread_state else { let Some(thread) = self.thread().cloned() else {
return; return;
}; };
self.last_error.take(); self.last_error.take();
let authenticate = connection.authenticate(cx); let authenticate = thread.read(cx).authenticate();
self.auth_task = Some(cx.spawn_in(window, { self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone(); let project = self.project.clone();
let agent = self.agent.clone(); let agent = self.agent.clone();
@ -676,16 +684,15 @@ impl AcpThreadView {
fn authorize_tool_call( fn authorize_tool_call(
&mut self, &mut self,
tool_call_id: acp::ToolCallId, id: ToolCallId,
option_id: acp::PermissionOptionId, outcome: acp::ToolCallConfirmationOutcome,
option_kind: acp::PermissionOptionKind,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let Some(thread) = self.thread() else { let Some(thread) = self.thread() else {
return; return;
}; };
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx); thread.authorize_tool_call(id, outcome, cx);
}); });
cx.notify(); cx.notify();
} }
@ -712,12 +719,10 @@ impl AcpThreadView {
.border_1() .border_1()
.border_color(cx.theme().colors().border) .border_color(cx.theme().colors().border)
.text_xs() .text_xs()
.children(message.content.markdown().map(|md| { .child(self.render_markdown(
self.render_markdown( message.content.clone(),
md.clone(), user_message_markdown_style(window, cx),
user_message_markdown_style(window, cx), )),
)
})),
) )
.into_any(), .into_any(),
AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => {
@ -725,28 +730,20 @@ impl AcpThreadView {
let message_body = v_flex() let message_body = v_flex()
.w_full() .w_full()
.gap_2p5() .gap_2p5()
.children(chunks.iter().enumerate().filter_map( .children(chunks.iter().enumerate().map(|(chunk_ix, chunk)| {
|(chunk_ix, chunk)| match chunk { match chunk {
AssistantMessageChunk::Message { block } => { AssistantMessageChunk::Text { chunk } => self
block.markdown().map(|md| { .render_markdown(chunk.clone(), style.clone())
self.render_markdown(md.clone(), style.clone()) .into_any_element(),
.into_any_element() AssistantMessageChunk::Thought { chunk } => self.render_thinking_block(
}) index,
} chunk_ix,
AssistantMessageChunk::Thought { block } => { chunk.clone(),
block.markdown().map(|md| { window,
self.render_thinking_block( cx,
index, ),
chunk_ix, }
md.clone(), }))
window,
cx,
)
.into_any_element()
})
}
},
))
.into_any(); .into_any();
v_flex() v_flex()
@ -872,12 +869,9 @@ impl AcpThreadView {
let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix)); let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix));
let status_icon = match &tool_call.status { let status_icon = match &tool_call.status {
ToolCallStatus::WaitingForConfirmation { .. } => None,
ToolCallStatus::Allowed { ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Pending, status: acp::ToolCallStatus::Running,
}
| ToolCallStatus::WaitingForConfirmation { .. } => None,
ToolCallStatus::Allowed {
status: acp::ToolCallStatus::InProgress,
.. ..
} => Some( } => Some(
Icon::new(IconName::ArrowCircle) Icon::new(IconName::ArrowCircle)
@ -891,13 +885,13 @@ impl AcpThreadView {
.into_any(), .into_any(),
), ),
ToolCallStatus::Allowed { ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Completed, status: acp::ToolCallStatus::Finished,
.. ..
} => None, } => None,
ToolCallStatus::Rejected ToolCallStatus::Rejected
| ToolCallStatus::Canceled | ToolCallStatus::Canceled
| ToolCallStatus::Allowed { | ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Failed, status: acp::ToolCallStatus::Error,
.. ..
} => Some( } => Some(
Icon::new(IconName::X) Icon::new(IconName::X)
@ -915,9 +909,34 @@ impl AcpThreadView {
.any(|content| matches!(content, ToolCallContent::Diff { .. })), .any(|content| matches!(content, ToolCallContent::Diff { .. })),
}; };
let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; let is_collapsible = tool_call.content.is_some() && !needs_confirmation;
let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id);
let content = if is_open {
match &tool_call.status {
ToolCallStatus::WaitingForConfirmation { confirmation, .. } => {
Some(self.render_tool_call_confirmation(
tool_call.id,
confirmation,
tool_call.content.as_ref(),
window,
cx,
))
}
ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {
tool_call.content.as_ref().map(|content| {
div()
.py_1p5()
.child(self.render_tool_call_content(content, window, cx))
.into_any_element()
})
}
ToolCallStatus::Rejected => None,
}
} else {
None
};
v_flex() v_flex()
.when(needs_confirmation, |this| { .when(needs_confirmation, |this| {
this.rounded_lg() this.rounded_lg()
@ -957,19 +976,9 @@ impl AcpThreadView {
}) })
.gap_1p5() .gap_1p5()
.child( .child(
Icon::new(match tool_call.kind { Icon::new(tool_call.icon)
acp::ToolKind::Read => IconName::ToolRead, .size(IconSize::Small)
acp::ToolKind::Edit => IconName::ToolPencil, .color(Color::Muted),
acp::ToolKind::Delete => IconName::ToolDeleteFile,
acp::ToolKind::Move => IconName::ArrowRightLeft,
acp::ToolKind::Search => IconName::ToolSearch,
acp::ToolKind::Execute => IconName::ToolTerminal,
acp::ToolKind::Think => IconName::ToolBulb,
acp::ToolKind::Fetch => IconName::ToolWeb,
acp::ToolKind::Other => IconName::ToolHammer,
})
.size(IconSize::Small)
.color(Color::Muted),
) )
.child(if tool_call.locations.len() == 1 { .child(if tool_call.locations.len() == 1 {
let name = tool_call.locations[0] let name = tool_call.locations[0]
@ -1014,16 +1023,16 @@ impl AcpThreadView {
.gap_0p5() .gap_0p5()
.when(is_collapsible, |this| { .when(is_collapsible, |this| {
this.child( this.child(
Disclosure::new(("expand", entry_ix), is_open) Disclosure::new(("expand", tool_call.id.0), is_open)
.opened_icon(IconName::ChevronUp) .opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown) .closed_icon(IconName::ChevronDown)
.on_click(cx.listener({ .on_click(cx.listener({
let id = tool_call.id.clone(); let id = tool_call.id;
move |this: &mut Self, _, _, cx: &mut Context<Self>| { move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open { if is_open {
this.expanded_tool_calls.remove(&id); this.expanded_tool_calls.remove(&id);
} else { } else {
this.expanded_tool_calls.insert(id.clone()); this.expanded_tool_calls.insert(id);
} }
cx.notify(); cx.notify();
} }
@ -1033,12 +1042,12 @@ impl AcpThreadView {
.children(status_icon), .children(status_icon),
) )
.on_click(cx.listener({ .on_click(cx.listener({
let id = tool_call.id.clone(); let id = tool_call.id;
move |this: &mut Self, _, _, cx: &mut Context<Self>| { move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open { if is_open {
this.expanded_tool_calls.remove(&id); this.expanded_tool_calls.remove(&id);
} else { } else {
this.expanded_tool_calls.insert(id.clone()); this.expanded_tool_calls.insert(id);
} }
cx.notify(); cx.notify();
} }
@ -1046,7 +1055,7 @@ impl AcpThreadView {
) )
.when(is_open, |this| { .when(is_open, |this| {
this.child( this.child(
v_flex() div()
.text_xs() .text_xs()
.when(is_collapsible, |this| { .when(is_collapsible, |this| {
this.mt_1() this.mt_1()
@ -1055,45 +1064,7 @@ impl AcpThreadView {
.bg(cx.theme().colors().editor_background) .bg(cx.theme().colors().editor_background)
.rounded_lg() .rounded_lg()
}) })
.map(|this| { .children(content),
if is_open {
match &tool_call.status {
ToolCallStatus::WaitingForConfirmation { options, .. } => this
.children(tool_call.content.iter().map(|content| {
div()
.py_1p5()
.child(
self.render_tool_call_content(
content, window, cx,
),
)
.into_any_element()
}))
.child(self.render_permission_buttons(
options,
entry_ix,
tool_call.id.clone(),
tool_call.content.is_empty(),
cx,
)),
ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {
this.children(tool_call.content.iter().map(|content| {
div()
.py_1p5()
.child(
self.render_tool_call_content(
content, window, cx,
),
)
.into_any_element()
}))
}
ToolCallStatus::Rejected => this,
}
} else {
this
}
}),
) )
}) })
} }
@ -1105,20 +1076,14 @@ impl AcpThreadView {
cx: &Context<Self>, cx: &Context<Self>,
) -> AnyElement { ) -> AnyElement {
match content { match content {
ToolCallContent::ContentBlock { content } => { ToolCallContent::Markdown { markdown } => {
if let Some(md) = content.markdown() { div()
div() .p_2()
.p_2() .child(self.render_markdown(
.child( markdown.clone(),
self.render_markdown( default_markdown_style(false, window, cx),
md.clone(), ))
default_markdown_style(false, window, cx), .into_any_element()
),
)
.into_any_element()
} else {
Empty.into_any_element()
}
} }
ToolCallContent::Diff { ToolCallContent::Diff {
diff: Diff { multibuffer, .. }, diff: Diff { multibuffer, .. },
@ -1127,56 +1092,223 @@ impl AcpThreadView {
} }
} }
fn render_permission_buttons( fn render_tool_call_confirmation(
&self, &self,
options: &[acp::PermissionOption], tool_call_id: ToolCallId,
entry_ix: usize, confirmation: &ToolCallConfirmation,
tool_call_id: acp::ToolCallId, content: Option<&ToolCallContent>,
empty_content: bool, window: &Window,
cx: &Context<Self>,
) -> AnyElement {
let confirmation_container = v_flex().mt_1().py_1p5();
match confirmation {
ToolCallConfirmation::Edit { description } => confirmation_container
.child(
div()
.px_2()
.children(description.clone().map(|description| {
self.render_markdown(
description,
default_markdown_style(false, window, cx),
)
})),
)
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {
id: "always_allow",
label: "Always Allow Edits".into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
}],
tool_call_id,
cx,
))
.into_any(),
ToolCallConfirmation::Execute {
command,
root_command,
description,
} => confirmation_container
.child(v_flex().px_2().pb_1p5().child(command.clone()).children(
description.clone().map(|description| {
self.render_markdown(description, default_markdown_style(false, window, cx))
.on_url_click({
let workspace = self.workspace.clone();
move |text, window, cx| {
Self::open_link(text, &workspace, window, cx);
}
})
}),
))
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {
id: "always_allow",
label: format!("Always Allow {root_command}").into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
}],
tool_call_id,
cx,
))
.into_any(),
ToolCallConfirmation::Mcp {
server_name,
tool_name: _,
tool_display_name,
description,
} => confirmation_container
.child(
v_flex()
.px_2()
.pb_1p5()
.child(format!("{server_name} - {tool_display_name}"))
.children(description.clone().map(|description| {
self.render_markdown(
description,
default_markdown_style(false, window, cx),
)
})),
)
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[
AlwaysAllowOption {
id: "always_allow_server",
label: format!("Always Allow {server_name}").into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
},
AlwaysAllowOption {
id: "always_allow_tool",
label: format!("Always Allow {tool_display_name}").into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowTool,
},
],
tool_call_id,
cx,
))
.into_any(),
ToolCallConfirmation::Fetch { description, urls } => confirmation_container
.child(
v_flex()
.px_2()
.pb_1p5()
.gap_1()
.children(urls.iter().map(|url| {
h_flex().child(
Button::new(url.clone(), url)
.icon(IconName::ArrowUpRight)
.icon_color(Color::Muted)
.icon_size(IconSize::XSmall)
.on_click({
let url = url.clone();
move |_, _, cx| cx.open_url(&url)
}),
)
}))
.children(description.clone().map(|description| {
self.render_markdown(
description,
default_markdown_style(false, window, cx),
)
})),
)
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {
id: "always_allow",
label: "Always Allow".into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
}],
tool_call_id,
cx,
))
.into_any(),
ToolCallConfirmation::Other { description } => confirmation_container
.child(v_flex().px_2().pb_1p5().child(self.render_markdown(
description.clone(),
default_markdown_style(false, window, cx),
)))
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {
id: "always_allow",
label: "Always Allow".into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
}],
tool_call_id,
cx,
))
.into_any(),
}
}
fn render_confirmation_buttons(
&self,
always_allow_options: &[AlwaysAllowOption],
tool_call_id: ToolCallId,
cx: &Context<Self>, cx: &Context<Self>,
) -> Div { ) -> Div {
h_flex() h_flex()
.py_1p5() .pt_1p5()
.px_1p5() .px_1p5()
.gap_1() .gap_1()
.justify_end() .justify_end()
.when(!empty_content, |this| { .border_t_1()
this.border_t_1() .border_color(self.tool_card_border_color(cx))
.border_color(self.tool_card_border_color(cx)) .when(self.agent.supports_always_allow(), |this| {
}) this.children(always_allow_options.into_iter().map(|always_allow_option| {
.children(options.iter().map(|option| { let outcome = always_allow_option.outcome;
let option_id = SharedString::from(option.id.0.clone()); Button::new(
Button::new((option_id, entry_ix), option.label.clone()) (always_allow_option.id, tool_call_id.0),
.map(|this| match option.kind { always_allow_option.label.clone(),
acp::PermissionOptionKind::AllowOnce => { )
this.icon(IconName::Check).icon_color(Color::Success) .icon(IconName::CheckDouble)
}
acp::PermissionOptionKind::AllowAlways => {
this.icon(IconName::CheckDouble).icon_color(Color::Success)
}
acp::PermissionOptionKind::RejectOnce => {
this.icon(IconName::X).icon_color(Color::Error)
}
acp::PermissionOptionKind::RejectAlways => {
this.icon(IconName::X).icon_color(Color::Error)
}
})
.icon_position(IconPosition::Start) .icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall) .icon_size(IconSize::XSmall)
.icon_color(Color::Success)
.on_click(cx.listener({ .on_click(cx.listener({
let tool_call_id = tool_call_id.clone(); let id = tool_call_id;
let option_id = option.id.clone(); move |this, _, _, cx| {
let option_kind = option.kind; this.authorize_tool_call(id, outcome, cx);
}
}))
}))
})
.child(
Button::new(("allow", tool_call_id.0), "Allow")
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| { move |this, _, _, cx| {
this.authorize_tool_call( this.authorize_tool_call(
tool_call_id.clone(), id,
option_id.clone(), acp::ToolCallConfirmationOutcome::Allow,
option_kind,
cx, cx,
); );
} }
})) })),
})) )
.child(
Button::new(("reject", tool_call_id.0), "Reject")
.icon(IconName::X)
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon_color(Color::Error)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Reject,
cx,
);
}
})),
)
} }
fn render_diff_editor(&self, multibuffer: &Entity<MultiBuffer>) -> AnyElement { fn render_diff_editor(&self, multibuffer: &Entity<MultiBuffer>) -> AnyElement {
@ -2113,11 +2245,12 @@ impl AcpThreadView {
.languages .languages
.language_for_name("Markdown"); .language_for_name("Markdown");
let (thread_summary, markdown) = if let Some(thread) = self.thread() { let (thread_summary, markdown) = match &self.thread_state {
let thread = thread.read(cx); ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => {
(thread.title().to_string(), thread.to_markdown(cx)) let thread = thread.read(cx);
} else { (thread.title().to_string(), thread.to_markdown(cx))
return Task::ready(Ok(())); }
ThreadState::Loading { .. } | ThreadState::LoadError(..) => return Task::ready(Ok(())),
}; };
window.spawn(cx, async move |cx| { window.spawn(cx, async move |cx| {

View file

@ -193,7 +193,6 @@ impl AgentConfiguration {
.unwrap_or(false); .unwrap_or(false);
v_flex() v_flex()
.w_full()
.when(is_expanded, |this| this.mb_2()) .when(is_expanded, |this| this.mb_2())
.child( .child(
div() div()
@ -224,7 +223,6 @@ impl AgentConfiguration {
.hover(|hover| hover.bg(cx.theme().colors().element_hover)) .hover(|hover| hover.bg(cx.theme().colors().element_hover))
.child( .child(
h_flex() h_flex()
.w_full()
.gap_2() .gap_2()
.child( .child(
Icon::new(provider.icon()) Icon::new(provider.icon())
@ -233,7 +231,6 @@ impl AgentConfiguration {
) )
.child( .child(
h_flex() h_flex()
.w_full()
.gap_1() .gap_1()
.child( .child(
Label::new(provider_name.clone()) Label::new(provider_name.clone())
@ -317,7 +314,6 @@ impl AgentConfiguration {
let providers = LanguageModelRegistry::read_global(cx).providers(); let providers = LanguageModelRegistry::read_global(cx).providers();
v_flex() v_flex()
.w_full()
.child( .child(
h_flex() h_flex()
.p(DynamicSpacing::Base16.rems(cx)) .p(DynamicSpacing::Base16.rems(cx))
@ -328,67 +324,50 @@ impl AgentConfiguration {
.justify_between() .justify_between()
.child( .child(
v_flex() v_flex()
.w_full()
.gap_0p5() .gap_0p5()
.child( .child(Headline::new("LLM Providers"))
h_flex()
.w_full()
.gap_2()
.justify_between()
.child(Headline::new("LLM Providers"))
.child(
PopoverMenu::new("add-provider-popover")
.trigger(
Button::new("add-provider", "Add Provider")
.icon_position(IconPosition::Start)
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.label_size(LabelSize::Small),
)
.anchor(gpui::Corner::TopRight)
.menu({
let workspace = self.workspace.clone();
move |window, cx| {
Some(ContextMenu::build(
window,
cx,
|menu, _window, _cx| {
menu.header("Compatible APIs").entry(
"OpenAI",
None,
{
let workspace =
workspace.clone();
move |window, cx| {
workspace
.update(cx, |workspace, cx| {
AddLlmProviderModal::toggle(
LlmCompatibleProvider::OpenAi,
workspace,
window,
cx,
);
})
.log_err();
}
},
)
},
))
}
}),
),
)
.child( .child(
Label::new("Add at least one provider to use AI-powered features.") Label::new("Add at least one provider to use AI-powered features.")
.color(Color::Muted), .color(Color::Muted),
), ),
)
.child(
PopoverMenu::new("add-provider-popover")
.trigger(
Button::new("add-provider", "Add Provider")
.icon_position(IconPosition::Start)
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.label_size(LabelSize::Small),
)
.anchor(gpui::Corner::TopRight)
.menu({
let workspace = self.workspace.clone();
move |window, cx| {
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
menu.header("Compatible APIs").entry("OpenAI", None, {
let workspace = workspace.clone();
move |window, cx| {
workspace
.update(cx, |workspace, cx| {
AddLlmProviderModal::toggle(
LlmCompatibleProvider::OpenAi,
workspace,
window,
cx,
);
})
.log_err();
}
})
}))
}
}),
), ),
) )
.child( .child(
div() div()
.w_full()
.pl(DynamicSpacing::Base08.rems(cx)) .pl(DynamicSpacing::Base08.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx)) .pr(DynamicSpacing::Base20.rems(cx))
.children( .children(
@ -404,9 +383,9 @@ impl AgentConfiguration {
let fs = self.fs.clone(); let fs = self.fs.clone();
SwitchField::new( SwitchField::new(
"always-allow-tool-actions-switch", "single-file-review",
"Allow running commands without asking for confirmation", "Enable single-file agent reviews",
"The agent can perform potentially destructive actions without asking for your confirmation.", "Agent edits are also displayed in single-file editors for review.",
always_allow_tool_actions, always_allow_tool_actions,
move |state, _window, cx| { move |state, _window, cx| {
let allow = state == &ToggleState::Selected; let allow = state == &ToggleState::Selected;

View file

@ -1506,7 +1506,8 @@ impl AgentDiff {
.read(cx) .read(cx)
.entries() .entries()
.last() .last()
.map_or(false, |entry| entry.diffs().next().is_some()) .and_then(|entry| entry.diff())
.is_some()
{ {
self.update_reviewing_editors(workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
} }
@ -1516,7 +1517,8 @@ impl AgentDiff {
.read(cx) .read(cx)
.entries() .entries()
.get(*ix) .get(*ix)
.map_or(false, |entry| entry.diffs().next().is_some()) .and_then(|entry| entry.diff())
.is_some()
{ {
self.update_reviewing_editors(workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
} }

View file

@ -440,7 +440,7 @@ pub struct AgentPanel {
local_timezone: UtcOffset, local_timezone: UtcOffset,
active_view: ActiveView, active_view: ActiveView,
acp_message_history: acp_message_history:
Rc<RefCell<crate::acp::MessageHistory<Vec<agent_client_protocol::ContentBlock>>>>, Rc<RefCell<crate::acp::MessageHistory<agentic_coding_protocol::SendUserMessageParams>>>,
previous_view: Option<ActiveView>, previous_view: Option<ActiveView>,
history_store: Entity<HistoryStore>, history_store: Entity<HistoryStore>,
history: Entity<ThreadHistory>, history: Entity<ThreadHistory>,
@ -1991,20 +1991,6 @@ impl AgentPanel {
); );
}), }),
) )
.item(
ContextMenuEntry::new("New Codex Thread")
.icon(IconName::AiOpenAi)
.icon_color(Color::Muted)
.handler(move |window, cx| {
window.dispatch_action(
NewExternalAgentThread {
agent: Some(crate::ExternalAgent::Codex),
}
.boxed_clone(),
cx,
);
}),
)
}); });
menu menu
})) }))
@ -2030,69 +2016,65 @@ impl AgentPanel {
) )
.anchor(Corner::TopRight) .anchor(Corner::TopRight)
.with_handle(self.agent_panel_menu_handle.clone()) .with_handle(self.agent_panel_menu_handle.clone())
.menu({ .menu(move |window, cx| {
let focus_handle = focus_handle.clone(); Some(ContextMenu::build(window, cx, |mut menu, _window, _| {
move |window, cx| { if let Some(usage) = usage {
Some(ContextMenu::build(window, cx, |mut menu, _window, _| {
menu = menu.context(focus_handle.clone());
if let Some(usage) = usage {
menu = menu
.header_with_link("Prompt Usage", "Manage", account_url.clone())
.custom_entry(
move |_window, cx| {
let used_percentage = match usage.limit {
UsageLimit::Limited(limit) => {
Some((usage.amount as f32 / limit as f32) * 100.)
}
UsageLimit::Unlimited => None,
};
h_flex()
.flex_1()
.gap_1p5()
.children(used_percentage.map(|percent| {
ProgressBar::new("usage", percent, 100., cx)
}))
.child(
Label::new(match usage.limit {
UsageLimit::Limited(limit) => {
format!("{} / {limit}", usage.amount)
}
UsageLimit::Unlimited => {
format!("{} / ∞", usage.amount)
}
})
.size(LabelSize::Small)
.color(Color::Muted),
)
.into_any_element()
},
move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
)
.separator()
}
menu = menu menu = menu
.header("MCP Servers") .header_with_link("Prompt Usage", "Manage", account_url.clone())
.action( .custom_entry(
"View Server Extensions", move |_window, cx| {
Box::new(zed_actions::Extensions { let used_percentage = match usage.limit {
category_filter: Some( UsageLimit::Limited(limit) => {
zed_actions::ExtensionCategoryFilter::ContextServers, Some((usage.amount as f32 / limit as f32) * 100.)
), }
id: None, UsageLimit::Unlimited => None,
}), };
h_flex()
.flex_1()
.gap_1p5()
.children(used_percentage.map(|percent| {
ProgressBar::new("usage", percent, 100., cx)
}))
.child(
Label::new(match usage.limit {
UsageLimit::Limited(limit) => {
format!("{} / {limit}", usage.amount)
}
UsageLimit::Unlimited => {
format!("{} / ∞", usage.amount)
}
})
.size(LabelSize::Small)
.color(Color::Muted),
)
.into_any_element()
},
move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
) )
.action("Add Custom Server…", Box::new(AddContextServer)) .separator()
.separator(); }
menu = menu menu = menu
.action("Rules…", Box::new(OpenRulesLibrary::default())) .header("MCP Servers")
.action("Settings", Box::new(OpenConfiguration)) .action(
.action(zoom_in_label, Box::new(ToggleZoom)); "View Server Extensions",
menu Box::new(zed_actions::Extensions {
})) category_filter: Some(
} zed_actions::ExtensionCategoryFilter::ContextServers,
),
id: None,
}),
)
.action("Add Custom Server…", Box::new(AddContextServer))
.separator();
menu = menu
.action("Rules…", Box::new(OpenRulesLibrary::default()))
.action("Settings", Box::new(OpenConfiguration))
.action(zoom_in_label, Box::new(ToggleZoom));
menu
}))
}); });
h_flex() h_flex()
@ -2666,25 +2648,6 @@ impl AgentPanel {
) )
}, },
), ),
)
.child(
NewThreadButton::new(
"new-codex-thread-btn",
"New Codex Thread",
IconName::AiOpenAi,
)
.on_click(
|window, cx| {
window.dispatch_action(
Box::new(NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::Codex,
),
}),
cx,
)
},
),
), ),
) )
}), }),

View file

@ -150,7 +150,6 @@ enum ExternalAgent {
#[default] #[default]
Gemini, Gemini,
ClaudeCode, ClaudeCode,
Codex,
} }
impl ExternalAgent { impl ExternalAgent {
@ -158,7 +157,6 @@ impl ExternalAgent {
match self { match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
} }
} }
} }
@ -264,9 +262,7 @@ fn update_command_palette_filter(cx: &mut App) {
if disable_ai { if disable_ai {
filter.hide_namespace("agent"); filter.hide_namespace("agent");
filter.hide_namespace("assistant"); filter.hide_namespace("assistant");
filter.hide_namespace("copilot");
filter.hide_namespace("zed_predict_onboarding"); filter.hide_namespace("zed_predict_onboarding");
filter.hide_namespace("edit_prediction"); filter.hide_namespace("edit_prediction");
use editor::actions::{ use editor::actions::{
@ -286,7 +282,6 @@ fn update_command_palette_filter(cx: &mut App) {
} else { } else {
filter.show_namespace("agent"); filter.show_namespace("agent");
filter.show_namespace("assistant"); filter.show_namespace("assistant");
filter.show_namespace("copilot");
filter.show_namespace("zed_predict_onboarding"); filter.show_namespace("zed_predict_onboarding");
filter.show_namespace("edit_prediction"); filter.show_namespace("edit_prediction");

View file

@ -1,14 +1,12 @@
mod agent_api_keys_onboarding; mod agent_api_keys_onboarding;
mod agent_panel_onboarding_card; mod agent_panel_onboarding_card;
mod agent_panel_onboarding_content; mod agent_panel_onboarding_content;
mod ai_upsell_card;
mod edit_prediction_onboarding_content; mod edit_prediction_onboarding_content;
mod young_account_banner; mod young_account_banner;
pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders}; pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders};
pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; pub use agent_panel_onboarding_card::AgentPanelOnboardingCard;
pub use agent_panel_onboarding_content::AgentPanelOnboarding; pub use agent_panel_onboarding_content::AgentPanelOnboarding;
pub use ai_upsell_card::AiUpsellCard;
pub use edit_prediction_onboarding_content::EditPredictionOnboarding; pub use edit_prediction_onboarding_content::EditPredictionOnboarding;
pub use young_account_banner::YoungAccountBanner; pub use young_account_banner::YoungAccountBanner;
@ -56,7 +54,6 @@ impl RenderOnce for BulletItem {
} }
} }
#[derive(PartialEq)]
pub enum SignInStatus { pub enum SignInStatus {
SignedIn, SignedIn,
SigningIn, SigningIn,

View file

@ -1,201 +0,0 @@
use std::sync::Arc;
use client::{Client, zed_urls};
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
use ui::{Divider, List, Vector, VectorName, prelude::*};
use crate::{BulletItem, SignInStatus};
#[derive(IntoElement, RegisterComponent)]
pub struct AiUpsellCard {
pub sign_in_status: SignInStatus,
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
}
impl AiUpsellCard {
pub fn new(client: Arc<Client>) -> Self {
let status = *client.status().borrow();
Self {
sign_in_status: status.into(),
sign_in: Arc::new(move |_window, cx| {
cx.spawn({
let client = client.clone();
async move |cx| {
client.authenticate_and_connect(true, cx).await;
}
})
.detach();
}),
}
}
}
impl RenderOnce for AiUpsellCard {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let pro_section = v_flex()
.w_full()
.gap_1()
.child(
h_flex()
.gap_2()
.child(
Label::new("Pro")
.size(LabelSize::Small)
.color(Color::Accent)
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(BulletItem::new("500 prompts with Claude models"))
.child(BulletItem::new(
"Unlimited edit predictions with Zeta, our open-source model",
)),
);
let free_section = v_flex()
.w_full()
.gap_1()
.child(
h_flex()
.gap_2()
.child(
Label::new("Free")
.size(LabelSize::Small)
.color(Color::Muted)
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(BulletItem::new("50 prompts with the Claude models"))
.child(BulletItem::new("2,000 accepted edit predictions")),
);
let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child(
Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.))
.color(Color::Custom(cx.theme().colors().border.opacity(0.05))),
);
let gradient_bg = div()
.absolute()
.inset_0()
.size_full()
.bg(gpui::linear_gradient(
180.,
gpui::linear_color_stop(
cx.theme().colors().elevated_surface_background.opacity(0.8),
0.,
),
gpui::linear_color_stop(
cx.theme().colors().elevated_surface_background.opacity(0.),
0.8,
),
));
const DESCRIPTION: &str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI.";
let footer_buttons = match self.sign_in_status {
SignInStatus::SignedIn => v_flex()
.items_center()
.gap_1()
.child(
Button::new("sign_in", "Start 14-day Free Pro Trial")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(move |_, _window, cx| {
telemetry::event!("Start Trial Clicked", state = "post-sign-in");
cx.open_url(&zed_urls::start_trial_url(cx))
}),
)
.child(
Label::new("No credit card required")
.size(LabelSize::Small)
.color(Color::Muted),
)
.into_any_element(),
_ => Button::new("sign_in", "Sign In")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click({
let callback = self.sign_in.clone();
move |_, window, cx| {
telemetry::event!("Start Trial Clicked", state = "pre-sign-in");
callback(window, cx)
}
})
.into_any_element(),
};
v_flex()
.relative()
.p_6()
.pt_4()
.border_1()
.border_color(cx.theme().colors().border)
.rounded_lg()
.overflow_hidden()
.child(grid_bg)
.child(gradient_bg)
.child(Headline::new("Try Zed AI"))
.child(Label::new(DESCRIPTION).color(Color::Muted).mb_2())
.child(
h_flex()
.mt_1p5()
.mb_2p5()
.items_start()
.gap_12()
.child(free_section)
.child(pro_section),
)
.child(footer_buttons)
}
}
impl Component for AiUpsellCard {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn name() -> &'static str {
"AI Upsell Card"
}
fn sort_name() -> &'static str {
"AI Upsell Card"
}
fn description() -> Option<&'static str> {
Some("A card presenting the Zed AI product during user's first-open onboarding flow.")
}
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
Some(
v_flex()
.p_4()
.gap_4()
.children(vec![example_group(vec![
single_example(
"Signed Out State",
AiUpsellCard {
sign_in_status: SignInStatus::SignedOut,
sign_in: Arc::new(|_, _| {}),
}
.into_any_element(),
),
single_example(
"Signed In State",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
}
.into_any_element(),
),
])])
.into_any_element(),
)
}
}

View file

@ -1138,7 +1138,7 @@ impl Client {
.to_str() .to_str()
.map_err(EstablishConnectionError::other)? .map_err(EstablishConnectionError::other)?
.to_string(); .to_string();
Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}")) Url::parse(&collab_url).with_context(|| format!("parsing colab rpc url {collab_url}"))
} }
} }

View file

@ -358,13 +358,13 @@ impl Telemetry {
worktree_id: WorktreeId, worktree_id: WorktreeId,
updated_entries_set: &UpdatedEntriesSet, updated_entries_set: &UpdatedEntriesSet,
) { ) {
let Some(project_types) = self.detect_project_types(worktree_id, updated_entries_set) let Some(project_type_names) = self.detect_project_types(worktree_id, updated_entries_set)
else { else {
return; return;
}; };
for project_type in project_types { for project_type_name in project_type_names {
telemetry::event!("Project Opened", project_type = project_type); telemetry::event!("Project Opened", project_type = project_type_name);
} }
} }

View file

@ -106,6 +106,7 @@ pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
.route("/users/:id/update_plan", post(update_plan)) .route("/users/:id/update_plan", post(update_plan))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
.merge(billing::router())
.merge(contributors::router()) .merge(contributors::router())
.layer( .layer(
ServiceBuilder::new() ServiceBuilder::new()

View file

@ -1,13 +1,23 @@
use anyhow::{Context as _, bail}; use anyhow::{Context as _, bail};
use axum::{Extension, Json, Router, extract, routing::post};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use reqwest::StatusCode;
use sea_orm::ActiveValue; use sea_orm::ActiveValue;
use std::{sync::Arc, time::Duration}; use serde::{Deserialize, Serialize};
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus}; use std::{str::FromStr, sync::Arc, time::Duration};
use stripe::{
BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::{ResultExt, maybe}; use util::{ResultExt, maybe};
use zed_llm_client::LanguageModelProvider; use zed_llm_client::LanguageModelProvider;
use crate::AppState;
use crate::db::billing_subscription::{ use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
}; };
@ -17,16 +27,331 @@ use crate::stripe_client::{
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription, StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionId,
}; };
use crate::{AppState, Error, Result};
use crate::{db::UserId, llm::db::LlmDatabase}; use crate::{db::UserId, llm::db::LlmDatabase};
use crate::{ use crate::{
db::{ db::{
CreateBillingCustomerParams, CreateBillingSubscriptionParams, BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
CreateProcessedStripeEventParams, UpdateBillingCustomerParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
UpdateBillingSubscriptionParams, billing_customer, UpdateBillingSubscriptionParams, billing_customer,
}, },
stripe_billing::StripeBilling, stripe_billing::StripeBilling,
}; };
pub fn router() -> Router {
Router::new()
.route(
"/billing/subscriptions/manage",
post(manage_billing_subscription),
)
.route(
"/billing/subscriptions/sync",
post(sync_billing_subscription),
)
}
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
/// The user intends to manage their subscription.
///
/// This will open the Stripe billing portal without putting the user in a specific flow.
ManageSubscription,
/// The user intends to update their payment method.
UpdatePaymentMethod,
/// The user intends to upgrade to Zed Pro.
UpgradeToPro,
/// The user intends to cancel their subscription.
Cancel,
/// The user intends to stop the cancellation of their subscription.
StopCancellation,
}
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
github_user_id: i32,
intent: ManageSubscriptionIntent,
/// The ID of the subscription to manage.
subscription_id: BillingSubscriptionId,
redirect_to: Option<String>,
}
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
billing_portal_session_url: Option<String>,
}
/// Initiates a Stripe customer portal session for managing a billing subscription.
async fn manage_billing_subscription(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
) -> Result<Json<ManageBillingSubscriptionResponse>> {
let user = app
.db
.get_user_by_github_user_id(body.github_user_id)
.await?
.context("user not found")?;
let Some(stripe_client) = app.real_stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::error!("failed to retrieve Stripe billing object");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let customer = app
.db
.get_billing_customer_by_user_id(user.id)
.await?
.context("billing customer not found")?;
let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
.context("failed to parse customer ID")?;
let subscription = app
.db
.get_billing_subscription_by_id(body.subscription_id)
.await?
.context("subscription not found")?;
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;
if body.intent == ManageSubscriptionIntent::StopCancellation {
let updated_stripe_subscription = Subscription::update(
&stripe_client,
&subscription_id,
stripe::UpdateSubscription {
cancel_at_period_end: Some(false),
..Default::default()
},
)
.await?;
app.db
.update_billing_subscription(
subscription.id,
&UpdateBillingSubscriptionParams {
stripe_cancel_at: ActiveValue::set(
updated_stripe_subscription
.cancel_at
.and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
.map(|time| time.naive_utc()),
),
..Default::default()
},
)
.await?;
return Ok(Json(ManageBillingSubscriptionResponse {
billing_portal_session_url: None,
}));
}
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
let zed_pro_price_id: stripe::PriceId =
stripe_billing.zed_pro_price_id().await?.try_into()?;
let zed_free_price_id: stripe::PriceId =
stripe_billing.zed_free_price_id().await?.try_into()?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
&& stripe_subscription.items.data.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id == zed_pro_price_id)
});
if is_on_zed_pro_trial {
let payment_methods = PaymentMethod::list(
&stripe_client,
&stripe::ListPaymentMethods {
customer: Some(stripe_subscription.customer.id()),
..Default::default()
},
)
.await?;
let has_payment_method = !payment_methods.data.is_empty();
if !has_payment_method {
return Err(Error::http(
StatusCode::BAD_REQUEST,
"missing payment method".into(),
));
}
// If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
Subscription::update(
&stripe_client,
&stripe_subscription.id,
stripe::UpdateSubscription {
trial_end: Some(stripe::Scheduled::now()),
..Default::default()
},
)
.await?;
return Ok(Json(ManageBillingSubscriptionResponse {
billing_portal_session_url: None,
}));
}
let subscription_item_to_update = stripe_subscription
.items
.data
.iter()
.find_map(|item| {
let price = item.price.as_ref()?;
if price.id == zed_free_price_id {
Some(item.id.clone())
} else {
None
}
})
.context("No subscription item to update")?;
Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
subscription_update_confirm: Some(
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
subscription: subscription.stripe_subscription_id,
items: vec![
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
id: subscription_item_to_update.to_string(),
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
},
],
discounts: None,
},
),
..Default::default()
})
}
ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
return_url: format!(
"{}{path}",
app.config.zed_dot_dev_url(),
path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
),
}),
..Default::default()
}),
..Default::default()
}),
ManageSubscriptionIntent::Cancel => {
if subscription.kind == Some(SubscriptionKind::ZedFree) {
return Err(Error::http(
StatusCode::BAD_REQUEST,
"free subscription cannot be canceled".into(),
));
}
Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
return_url: format!("{}/account", app.config.zed_dot_dev_url()),
}),
..Default::default()
}),
subscription_cancel: Some(
stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
subscription: subscription.stripe_subscription_id,
retention: None,
},
),
..Default::default()
})
}
ManageSubscriptionIntent::StopCancellation => unreachable!(),
};
let mut params = CreateBillingPortalSession::new(customer_id);
params.flow_data = flow;
let return_url = format!("{}/account", app.config.zed_dot_dev_url());
params.return_url = Some(&return_url);
let session = BillingPortalSession::create(&stripe_client, params).await?;
Ok(Json(ManageBillingSubscriptionResponse {
billing_portal_session_url: Some(session.url),
}))
}
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
github_user_id: i32,
}
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
stripe_customer_id: String,
}
async fn sync_billing_subscription(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
) -> Result<Json<SyncBillingSubscriptionResponse>> {
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let user = app
.db
.get_user_by_github_user_id(body.github_user_id)
.await?
.context("user not found")?;
let billing_customer = app
.db
.get_billing_customer_by_user_id(user.id)
.await?
.context("billing customer not found")?;
let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let subscriptions = stripe_client
.list_subscriptions_for_customer(&stripe_customer_id)
.await?;
for subscription in subscriptions {
let subscription_id = subscription.id.clone();
sync_subscription(&app, &stripe_client, subscription)
.await
.with_context(|| {
format!(
"failed to sync subscription {subscription_id} for user {}",
user.id,
)
})?;
}
Ok(Json(SyncBillingSubscriptionResponse {
stripe_customer_id: billing_customer.stripe_customer_id.clone(),
}))
}
/// The amount of time we wait in between each poll of Stripe events. /// The amount of time we wait in between each poll of Stripe events.
/// ///
/// This value should strike a balance between: /// This value should strike a balance between:

View file

@ -433,8 +433,6 @@ impl Server {
.add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>) .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
.add_request_handler(forward_mutating_project_request::<proto::Stage>) .add_request_handler(forward_mutating_project_request::<proto::Stage>)
.add_request_handler(forward_mutating_project_request::<proto::Unstage>) .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
.add_request_handler(forward_mutating_project_request::<proto::Stash>)
.add_request_handler(forward_mutating_project_request::<proto::StashPop>)
.add_request_handler(forward_mutating_project_request::<proto::Commit>) .add_request_handler(forward_mutating_project_request::<proto::Commit>)
.add_request_handler(forward_mutating_project_request::<proto::GitInit>) .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
.add_request_handler(forward_read_only_project_request::<proto::GetRemotes>) .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
@ -831,7 +829,7 @@ impl Server {
// This arrangement ensures we will attempt to process earlier messages first, but fall // This arrangement ensures we will attempt to process earlier messages first, but fall
// back to processing messages arrived later in the spirit of making progress. // back to processing messages arrived later in the spirit of making progress.
let mut foreground_message_handlers = FuturesUnordered::new(); let mut foreground_message_handlers = FuturesUnordered::new();
let concurrent_handlers = Arc::new(Semaphore::new(512)); let concurrent_handlers = Arc::new(Semaphore::new(256));
loop { loop {
let next_message = async { let next_message = async {
let permit = concurrent_handlers.clone().acquire_owned().await.unwrap(); let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();

View file

@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use collections::HashMap; use collections::HashMap;
use futures::{FutureExt, StreamExt, channel::oneshot, future, select}; use futures::{FutureExt, StreamExt, channel::oneshot, select};
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
use parking_lot::Mutex; use parking_lot::Mutex;
use postage::barrier; use postage::barrier;
@ -10,19 +10,15 @@ use smol::channel;
use std::{ use std::{
fmt, fmt,
path::PathBuf, path::PathBuf,
pin::pin,
sync::{ sync::{
Arc, Arc,
atomic::{AtomicI32, Ordering::SeqCst}, atomic::{AtomicI32, Ordering::SeqCst},
}, },
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use util::{ResultExt, TryFutureExt}; use util::TryFutureExt;
use crate::{ use crate::transport::{StdioTransport, Transport};
transport::{StdioTransport, Transport},
types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
};
const JSON_RPC_VERSION: &str = "2.0"; const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
@ -36,7 +32,6 @@ pub const INTERNAL_ERROR: i32 = -32603;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>; type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>; type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
@ -83,15 +78,6 @@ pub struct Request<'a, T> {
pub params: T, pub params: T,
} }
#[derive(Serialize, Deserialize)]
pub struct AnyRequest<'a> {
pub jsonrpc: &'a str,
pub id: RequestId,
pub method: &'a str,
#[serde(skip_serializing_if = "is_null_value")]
pub params: Option<&'a RawValue>,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct AnyResponse<'a> { struct AnyResponse<'a> {
jsonrpc: &'a str, jsonrpc: &'a str,
@ -190,23 +176,15 @@ impl Client {
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers = let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
let receive_input_task = cx.spawn({ let receive_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone(); let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone(); let response_handlers = response_handlers.clone();
let request_handlers = request_handlers.clone();
let transport = transport.clone(); let transport = transport.clone();
async move |cx| { async move |cx| {
Self::handle_input( Self::handle_input(transport, notification_handlers, response_handlers, cx)
transport, .log_err()
notification_handlers, .await
request_handlers,
response_handlers,
cx,
)
.log_err()
.await
} }
}); });
let receive_err_task = cx.spawn({ let receive_err_task = cx.spawn({
@ -252,24 +230,13 @@ impl Client {
async fn handle_input( async fn handle_input(
transport: Arc<dyn Transport>, transport: Arc<dyn Transport>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut receiver = transport.receive(); let mut receiver = transport.receive();
while let Some(message) = receiver.next().await { while let Some(message) = receiver.next().await {
log::trace!("recv: {}", &message); if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
if let Ok(request) = serde_json::from_str::<AnyRequest>(&message) {
let mut request_handlers = request_handlers.lock();
if let Some(handler) = request_handlers.get_mut(request.method) {
handler(
request.id,
request.params.unwrap_or(RawValue::NULL),
cx.clone(),
);
}
} else if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
if let Some(handlers) = response_handlers.lock().as_mut() { if let Some(handlers) = response_handlers.lock().as_mut() {
if let Some(handler) = handlers.remove(&response.id) { if let Some(handler) = handlers.remove(&response.id) {
handler(Ok(message.to_string())); handler(Ok(message.to_string()));
@ -280,8 +247,6 @@ impl Client {
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
handler(notification.params.unwrap_or(Value::Null), cx.clone()); handler(notification.params.unwrap_or(Value::Null), cx.clone());
} }
} else {
log::error!("Unhandled JSON from context_server: {}", message);
} }
} }
@ -329,17 +294,6 @@ impl Client {
&self, &self,
method: &str, method: &str,
params: impl Serialize, params: impl Serialize,
) -> Result<T> {
self.request_with(method, params, None, Some(REQUEST_TIMEOUT))
.await
}
pub async fn request_with<T: DeserializeOwned>(
&self,
method: &str,
params: impl Serialize,
cancel_rx: Option<oneshot::Receiver<()>>,
timeout: Option<Duration>,
) -> Result<T> { ) -> Result<T> {
let id = self.next_id.fetch_add(1, SeqCst); let id = self.next_id.fetch_add(1, SeqCst);
let request = serde_json::to_string(&Request { let request = serde_json::to_string(&Request {
@ -375,23 +329,7 @@ impl Client {
handle_response?; handle_response?;
send?; send?;
let mut timeout_fut = pin!( let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
match timeout {
Some(timeout) => future::Either::Left(executor.timer(timeout)),
None => future::Either::Right(future::pending()),
}
.fuse()
);
let mut cancel_fut = pin!(
match cancel_rx {
Some(rx) => future::Either::Left(async {
rx.await.log_err();
}),
None => future::Either::Right(future::pending()),
}
.fuse()
);
select! { select! {
response = rx.fuse() => { response = rx.fuse() => {
let elapsed = started.elapsed(); let elapsed = started.elapsed();
@ -410,18 +348,8 @@ impl Client {
Err(_) => anyhow::bail!("cancelled") Err(_) => anyhow::bail!("cancelled")
} }
} }
_ = cancel_fut => { _ = timeout => {
self.notify( log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
Cancelled::METHOD,
ClientNotification::Cancelled(CancelledParams {
request_id: RequestId::Int(id),
reason: None
})
).log_err();
anyhow::bail!(RequestCanceled)
}
_ = timeout_fut => {
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout.unwrap());
anyhow::bail!("Context server request timeout"); anyhow::bail!("Context server request timeout");
} }
} }
@ -451,17 +379,6 @@ impl Client {
} }
} }
#[derive(Debug)]
pub struct RequestCanceled;
impl std::error::Error for RequestCanceled {}
impl std::fmt::Display for RequestCanceled {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Context server request was canceled")
}
}
impl fmt::Display for ContextServerId { impl fmt::Display for ContextServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f) self.0.fmt(f)

View file

@ -9,8 +9,6 @@ use futures::{
}; };
use gpui::{App, AppContext, AsyncApp, Task}; use gpui::{App, AppContext, AsyncApp, Task};
use net::async_net::{UnixListener, UnixStream}; use net::async_net::{UnixListener, UnixStream};
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde_json::{json, value::RawValue}; use serde_json::{json, value::RawValue};
use smol::stream::StreamExt; use smol::stream::StreamExt;
use std::{ use std::{
@ -22,32 +20,16 @@ use util::ResultExt;
use crate::{ use crate::{
client::{CspResult, RequestId, Response}, client::{CspResult, RequestId, Response},
types::{ types::Request,
CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
ToolResponseContent,
requests::{CallTool, ListTools},
},
}; };
pub struct McpServer { pub struct McpServer {
socket_path: PathBuf, socket_path: PathBuf,
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>, handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
_server_task: Task<()>, _server_task: Task<()>,
} }
struct RegisteredTool { type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
tool: Tool,
handler: ToolHandler,
}
type ToolHandler = Box<
dyn Fn(
Option<serde_json::Value>,
&mut AsyncApp,
) -> Task<Result<ToolResponse<serde_json::Value>>>,
>;
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
impl McpServer { impl McpServer {
pub fn new(cx: &AsyncApp) -> Task<Result<Self>> { pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
@ -61,14 +43,12 @@ impl McpServer {
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let (temp_dir, socket_path, listener) = task.await?; let (temp_dir, socket_path, listener) = task.await?;
let tools = Rc::new(RefCell::new(HashMap::default()));
let handlers = Rc::new(RefCell::new(HashMap::default())); let handlers = Rc::new(RefCell::new(HashMap::default()));
let server_task = cx.spawn({ let server_task = cx.spawn({
let tools = tools.clone();
let handlers = handlers.clone(); let handlers = handlers.clone();
async move |cx| { async move |cx| {
while let Ok((stream, _)) = listener.accept().await { while let Ok((stream, _)) = listener.accept().await {
Self::serve_connection(stream, tools.clone(), handlers.clone(), cx); Self::serve_connection(stream, handlers.clone(), cx);
} }
drop(temp_dir) drop(temp_dir)
} }
@ -76,56 +56,11 @@ impl McpServer {
Ok(Self { Ok(Self {
socket_path, socket_path,
_server_task: server_task, _server_task: server_task,
tools, handlers: handlers.clone(),
handlers: handlers,
}) })
}) })
} }
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
let output_schema = schemars::schema_for!(T::Output);
let unit_schema = schemars::schema_for!(());
let registered_tool = RegisteredTool {
tool: Tool {
name: T::NAME.into(),
description: Some(tool.description().into()),
input_schema: schemars::schema_for!(T::Input).into(),
output_schema: if output_schema == unit_schema {
None
} else {
Some(output_schema.into())
},
annotations: Some(tool.annotations()),
},
handler: Box::new({
let tool = tool.clone();
move |input_value, cx| {
let input = match input_value {
Some(input) => serde_json::from_value(input),
None => serde_json::from_value(serde_json::Value::Null),
};
let tool = tool.clone();
match input {
Ok(input) => cx.spawn(async move |cx| {
let output = tool.run(input, cx).await?;
Ok(ToolResponse {
content: output.content,
structured_content: serde_json::to_value(output.structured_content)
.unwrap_or_default(),
})
}),
Err(err) => Task::ready(Err(err.into())),
}
}
}),
};
self.tools.borrow_mut().insert(T::NAME, registered_tool);
}
pub fn handle_request<R: Request>( pub fn handle_request<R: Request>(
&mut self, &mut self,
f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static, f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
@ -185,8 +120,7 @@ impl McpServer {
fn serve_connection( fn serve_connection(
stream: UnixStream, stream: UnixStream,
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>, handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) { ) {
let (read, write) = smol::io::split(stream); let (read, write) = smol::io::split(stream);
@ -201,13 +135,7 @@ impl McpServer {
let Some(request_id) = request.id.clone() else { let Some(request_id) = request.id.clone() else {
continue; continue;
}; };
if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
if request.method == CallTool::METHOD {
Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
.await;
} else if request.method == ListTools::METHOD {
Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
} else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
let outgoing_tx = outgoing_tx.clone(); let outgoing_tx = outgoing_tx.clone();
if let Some(task) = cx if let Some(task) = cx
@ -221,126 +149,25 @@ impl McpServer {
.detach(); .detach();
} }
} else { } else {
Self::send_err( outgoing_tx
request_id, .unbounded_send(
format!("unhandled method {}", request.method), serde_json::to_string(&Response::<()> {
&outgoing_tx, jsonrpc: "2.0",
); id: request.id.unwrap(),
value: CspResult::Error(Some(crate::client::Error {
message: format!("unhandled method {}", request.method),
code: -32601,
})),
})
.unwrap(),
)
.ok();
} }
} }
}) })
.detach(); .detach();
} }
fn handle_list_tools(
request_id: RequestId,
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
outgoing_tx: &UnboundedSender<String>,
) {
let response = ListToolsResponse {
tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
next_cursor: None,
meta: None,
};
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Ok(Some(response)),
})
.unwrap_or_default(),
)
.ok();
}
async fn handle_call_tool(
request_id: RequestId,
params: Option<Box<RawValue>>,
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
outgoing_tx: &UnboundedSender<String>,
cx: &mut AsyncApp,
) {
let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
Some(params) => serde_json::from_str(params.get()),
None => serde_json::from_value(serde_json::Value::Null),
};
match result {
Ok(params) => {
if let Some(tool) = tools.borrow().get(&params.name.as_ref()) {
let outgoing_tx = outgoing_tx.clone();
let task = (tool.handler)(params.arguments, cx);
cx.spawn(async move |_| {
let response = match task.await {
Ok(result) => CallToolResponse {
content: result.content,
is_error: Some(false),
meta: None,
structured_content: if result.structured_content.is_null() {
None
} else {
Some(result.structured_content)
},
},
Err(err) => CallToolResponse {
content: vec![ToolResponseContent::Text {
text: err.to_string(),
}],
is_error: Some(true),
meta: None,
structured_content: None,
},
};
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Ok(Some(response)),
})
.unwrap_or_default(),
)
.ok();
})
.detach();
} else {
Self::send_err(
request_id,
format!("Tool not found: {}", params.name),
&outgoing_tx,
);
}
}
Err(err) => {
Self::send_err(request_id, err.to_string(), &outgoing_tx);
}
}
}
fn send_err(
request_id: RequestId,
message: impl Into<String>,
outgoing_tx: &UnboundedSender<String>,
) {
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response::<()> {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Error(Some(crate::client::Error {
message: message.into(),
code: -32601,
})),
})
.unwrap(),
)
.ok();
}
async fn handle_io( async fn handle_io(
mut outgoing_rx: UnboundedReceiver<String>, mut outgoing_rx: UnboundedReceiver<String>,
incoming_tx: UnboundedSender<RawRequest>, incoming_tx: UnboundedSender<RawRequest>,
@ -389,37 +216,7 @@ impl McpServer {
} }
} }
pub trait McpServerTool { #[derive(Serialize, Deserialize)]
type Input: DeserializeOwned + JsonSchema;
type Output: Serialize + JsonSchema;
const NAME: &'static str;
fn description(&self) -> &'static str;
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: None,
read_only_hint: None,
destructive_hint: None,
idempotent_hint: None,
open_world_hint: None,
}
}
fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
}
pub struct ToolResponse<T> {
pub content: Vec<ToolResponseContent>,
pub structured_content: T,
}
#[derive(Debug, Serialize, Deserialize)]
struct RawRequest { struct RawRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
id: Option<RequestId>, id: Option<RequestId>,

View file

@ -5,12 +5,7 @@
//! read/write messages and the types from types.rs for serialization/deserialization //! read/write messages and the types from types.rs for serialization/deserialization
//! of messages. //! of messages.
use std::time::Duration;
use anyhow::Result; use anyhow::Result;
use futures::channel::oneshot;
use gpui::AsyncApp;
use serde_json::Value;
use crate::client::Client; use crate::client::Client;
use crate::types::{self, Notification, Request}; use crate::types::{self, Notification, Request};
@ -100,25 +95,7 @@ impl InitializedContextServerProtocol {
self.inner.request(T::METHOD, params).await self.inner.request(T::METHOD, params).await
} }
pub async fn request_with<T: Request>(
&self,
params: T::Params,
cancel_rx: Option<oneshot::Receiver<()>>,
timeout: Option<Duration>,
) -> Result<T::Response> {
self.inner
.request_with(T::METHOD, params, cancel_rx, timeout)
.await
}
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> { pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
self.inner.notify(T::METHOD, params) self.inner.notify(T::METHOD, params)
} }
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncApp),
{
self.inner.on_notification(method, f);
}
} }

View file

@ -3,8 +3,6 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use url::Url; use url::Url;
use crate::client::RequestId;
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
pub const VERSION_2024_11_05: &str = "2024-11-05"; pub const VERSION_2024_11_05: &str = "2024-11-05";
@ -102,7 +100,6 @@ pub mod notifications {
notification!("notifications/initialized", Initialized, ()); notification!("notifications/initialized", Initialized, ());
notification!("notifications/progress", Progress, ProgressParams); notification!("notifications/progress", Progress, ProgressParams);
notification!("notifications/message", Message, MessageParams); notification!("notifications/message", Message, MessageParams);
notification!("notifications/cancelled", Cancelled, CancelledParams);
notification!( notification!(
"notifications/resources/updated", "notifications/resources/updated",
ResourcesUpdated, ResourcesUpdated,
@ -495,20 +492,18 @@ pub struct RootsCapabilities {
pub list_changed: Option<bool>, pub list_changed: Option<bool>,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Tool { pub struct Tool {
pub name: String, pub name: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>, pub description: Option<String>,
pub input_schema: serde_json::Value, pub input_schema: serde_json::Value,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_schema: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub annotations: Option<ToolAnnotations>, pub annotations: Option<ToolAnnotations>,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ToolAnnotations { pub struct ToolAnnotations {
/// A human-readable title for the tool. /// A human-readable title for the tool.
@ -622,15 +617,11 @@ pub enum ClientNotification {
Initialized, Initialized,
Progress(ProgressParams), Progress(ProgressParams),
RootsListChanged, RootsListChanged,
Cancelled(CancelledParams), Cancelled {
} request_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[derive(Debug, Serialize, Deserialize)] reason: Option<String>,
#[serde(rename_all = "camelCase")] },
pub struct CancelledParams {
pub request_id: RequestId,
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -682,20 +673,6 @@ pub struct CallToolResponse {
pub is_error: Option<bool>, pub is_error: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub structured_content: Option<serde_json::Value>,
}
impl CallToolResponse {
pub fn text_contents(&self) -> String {
let mut text = String::new();
for chunk in &self.content {
if let ToolResponseContent::Text { text: chunk } = chunk {
text.push_str(&chunk)
};
}
text
}
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]

View file

@ -918,7 +918,7 @@ async fn test_debug_panel_item_thread_status_reset_on_failure(
.unwrap(); .unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap()); let client = session.update(cx, |session, _| session.adapter_client().unwrap());
const THREAD_ID_NUM: i64 = 1; const THREAD_ID_NUM: u64 = 1;
client.on_request::<dap::requests::Threads, _>(move |_, _| { client.on_request::<dap::requests::Threads, _>(move |_, _| {
Ok(dap::ThreadsResponse { Ok(dap::ThreadsResponse {

View file

@ -110,7 +110,6 @@ tree-sitter-html.workspace = true
tree-sitter-rust.workspace = true tree-sitter-rust.workspace = true
tree-sitter-typescript.workspace = true tree-sitter-typescript.workspace = true
tree-sitter-yaml.workspace = true tree-sitter-yaml.workspace = true
tree-sitter-bash.workspace = true
unindent.workspace = true unindent.workspace = true
util = { workspace = true, features = ["test-support"] } util = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] }

View file

@ -365,8 +365,6 @@ actions!(
ConvertToLowerCase, ConvertToLowerCase,
/// Toggles the case of selected text. /// Toggles the case of selected text.
ConvertToOppositeCase, ConvertToOppositeCase,
/// Converts selected text to sentence case.
ConvertToSentenceCase,
/// Converts selected text to snake_case. /// Converts selected text to snake_case.
ConvertToSnakeCase, ConvertToSnakeCase,
/// Converts selected text to Title Case. /// Converts selected text to Title Case.

View file

@ -94,7 +94,7 @@ async fn test_fuzzy_score(cx: &mut TestAppContext) {
filter_and_sort_matches("set_text", &completions, SnippetSortOrder::Top, cx).await; filter_and_sort_matches("set_text", &completions, SnippetSortOrder::Top, cx).await;
assert_eq!(matches[0].string, "set_text"); assert_eq!(matches[0].string, "set_text");
assert_eq!(matches[1].string, "set_text_style_refinement"); assert_eq!(matches[1].string, "set_text_style_refinement");
assert_eq!(matches[2].string, "set_placeholder_text"); assert_eq!(matches[2].string, "set_context_menu_options");
} }
// fuzzy filter text over label, sort_text and sort_kind // fuzzy filter text over label, sort_text and sort_kind
@ -216,28 +216,6 @@ async fn test_sort_positions(cx: &mut TestAppContext) {
assert_eq!(matches[0].string, "rounded-full"); assert_eq!(matches[0].string, "rounded-full");
} }
#[gpui::test]
async fn test_fuzzy_over_sort_positions(cx: &mut TestAppContext) {
let completions = vec![
CompletionBuilder::variable("lsp_document_colors", None, "7fffffff"), // 0.29 fuzzy score
CompletionBuilder::function(
"language_servers_running_disk_based_diagnostics",
None,
"7fffffff",
), // 0.168 fuzzy score
CompletionBuilder::function("code_lens", None, "7fffffff"), // 3.2 fuzzy score
CompletionBuilder::variable("lsp_code_lens", None, "7fffffff"), // 3.2 fuzzy score
CompletionBuilder::function("fetch_code_lens", None, "7fffffff"), // 3.2 fuzzy score
];
let matches =
filter_and_sort_matches("lens", &completions, SnippetSortOrder::default(), cx).await;
assert_eq!(matches[0].string, "code_lens");
assert_eq!(matches[1].string, "lsp_code_lens");
assert_eq!(matches[2].string, "fetch_code_lens");
}
async fn test_for_each_prefix<F>( async fn test_for_each_prefix<F>(
target: &str, target: &str,
completions: &Vec<Completion>, completions: &Vec<Completion>,

View file

@ -1057,9 +1057,9 @@ impl CompletionsMenu {
enum MatchTier<'a> { enum MatchTier<'a> {
WordStartMatch { WordStartMatch {
sort_exact: Reverse<i32>, sort_exact: Reverse<i32>,
sort_positions: Vec<usize>,
sort_snippet: Reverse<i32>, sort_snippet: Reverse<i32>,
sort_score: Reverse<OrderedFloat<f64>>, sort_score: Reverse<OrderedFloat<f64>>,
sort_positions: Vec<usize>,
sort_text: Option<&'a str>, sort_text: Option<&'a str>,
sort_kind: usize, sort_kind: usize,
sort_label: &'a str, sort_label: &'a str,
@ -1137,9 +1137,9 @@ impl CompletionsMenu {
MatchTier::WordStartMatch { MatchTier::WordStartMatch {
sort_exact, sort_exact,
sort_positions,
sort_snippet, sort_snippet,
sort_score, sort_score,
sort_positions,
sort_text, sort_text,
sort_kind, sort_kind,
sort_label, sort_label,

View file

@ -10877,6 +10877,17 @@ impl Editor {
}); });
} }
pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context<Self>) {
self.manipulate_text(window, cx, |text| {
let has_upper_case_characters = text.chars().any(|c| c.is_uppercase());
if has_upper_case_characters {
text.to_lowercase()
} else {
text.to_uppercase()
}
})
}
fn manipulate_immutable_lines<Fn>( fn manipulate_immutable_lines<Fn>(
&mut self, &mut self,
window: &mut Window, window: &mut Window,
@ -11132,26 +11143,6 @@ impl Editor {
}) })
} }
pub fn convert_to_sentence_case(
&mut self,
_: &ConvertToSentenceCase,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_text(window, cx, |text| text.to_case(Case::Sentence))
}
pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context<Self>) {
self.manipulate_text(window, cx, |text| {
let has_upper_case_characters = text.chars().any(|c| c.is_uppercase());
if has_upper_case_characters {
text.to_lowercase()
} else {
text.to_uppercase()
}
})
}
pub fn convert_to_rot13( pub fn convert_to_rot13(
&mut self, &mut self,
_: &ConvertToRot13, _: &ConvertToRot13,
@ -16976,7 +16967,7 @@ impl Editor {
now: Instant, now: Instant,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Option<TransactionId> { ) {
self.end_selection(window, cx); self.end_selection(window, cx);
if let Some(tx_id) = self if let Some(tx_id) = self
.buffer .buffer
@ -16986,10 +16977,7 @@ impl Editor {
.insert_transaction(tx_id, self.selections.disjoint_anchors()); .insert_transaction(tx_id, self.selections.disjoint_anchors());
cx.emit(EditorEvent::TransactionBegun { cx.emit(EditorEvent::TransactionBegun {
transaction_id: tx_id, transaction_id: tx_id,
}); })
Some(tx_id)
} else {
None
} }
} }
@ -17017,17 +17005,6 @@ impl Editor {
} }
} }
pub fn modify_transaction_selection_history(
&mut self,
transaction_id: TransactionId,
modify: impl FnOnce(&mut (Arc<[Selection<Anchor>]>, Option<Arc<[Selection<Anchor>]>>)),
) -> bool {
self.selection_history
.transaction_mut(transaction_id)
.map(modify)
.is_some()
}
pub fn set_mark(&mut self, _: &actions::SetMark, window: &mut Window, cx: &mut Context<Self>) { pub fn set_mark(&mut self, _: &actions::SetMark, window: &mut Window, cx: &mut Context<Self>) {
if self.selection_mark_mode { if self.selection_mark_mode {
self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
@ -22280,7 +22257,7 @@ fn consume_contiguous_rows(
selections: &mut Peekable<std::slice::Iter<Selection<Point>>>, selections: &mut Peekable<std::slice::Iter<Selection<Point>>>,
) -> (MultiBufferRow, MultiBufferRow) { ) -> (MultiBufferRow, MultiBufferRow) {
contiguous_row_selections.push(selection.clone()); contiguous_row_selections.push(selection.clone());
let start_row = starting_row(selection, display_map); let start_row = MultiBufferRow(selection.start.row);
let mut end_row = ending_row(selection, display_map); let mut end_row = ending_row(selection, display_map);
while let Some(next_selection) = selections.peek() { while let Some(next_selection) = selections.peek() {
@ -22294,14 +22271,6 @@ fn consume_contiguous_rows(
(start_row, end_row) (start_row, end_row)
} }
fn starting_row(selection: &Selection<Point>, display_map: &DisplaySnapshot) -> MultiBufferRow {
if selection.start.column > 0 {
MultiBufferRow(display_map.prev_line_boundary(selection.start).0.row)
} else {
MultiBufferRow(selection.start.row)
}
}
fn ending_row(next_selection: &Selection<Point>, display_map: &DisplaySnapshot) -> MultiBufferRow { fn ending_row(next_selection: &Selection<Point>, display_map: &DisplaySnapshot) -> MultiBufferRow {
if next_selection.end.column > 0 || next_selection.is_empty() { if next_selection.end.column > 0 || next_selection.is_empty() {
MultiBufferRow(display_map.next_line_boundary(next_selection.end).0.row + 1) MultiBufferRow(display_map.next_line_boundary(next_selection.end).0.row + 1)

View file

@ -4724,23 +4724,6 @@ async fn test_toggle_case(cx: &mut TestAppContext) {
"}); "});
} }
#[gpui::test]
async fn test_convert_to_sentence_case(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
cx.set_state(indoc! {"
«implement-windows-supportˇ»
"});
cx.update_editor(|e, window, cx| {
e.convert_to_sentence_case(&ConvertToSentenceCase, window, cx)
});
cx.assert_editor_state(indoc! {"
«Implement windows supportˇ»
"});
}
#[gpui::test] #[gpui::test]
async fn test_manipulate_text(cx: &mut TestAppContext) { async fn test_manipulate_text(cx: &mut TestAppContext) {
init_test(cx, |_| {}); init_test(cx, |_| {});
@ -5086,33 +5069,6 @@ fn test_move_line_up_down(cx: &mut TestAppContext) {
}); });
} }
#[gpui::test]
fn test_move_line_up_selection_at_end_of_fold(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let editor = cx.add_window(|window, cx| {
let buffer = MultiBuffer::build_simple("\n\n\n\n\n\naaaa\nbbbb\ncccc", cx);
build_editor(buffer, window, cx)
});
_ = editor.update(cx, |editor, window, cx| {
editor.fold_creases(
vec![Crease::simple(
Point::new(6, 4)..Point::new(7, 4),
FoldPlaceholder::test(),
)],
true,
window,
cx,
);
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
s.select_ranges([Point::new(7, 4)..Point::new(7, 4)])
});
assert_eq!(editor.display_text(cx), "\n\n\n\n\n\naaaa⋯\ncccc");
editor.move_line_up(&MoveLineUp, window, cx);
let buffer_text = editor.buffer.read(cx).snapshot(cx).text();
assert_eq!(buffer_text, "\n\n\n\n\naaaa\nbbbb\n\ncccc");
});
}
#[gpui::test] #[gpui::test]
fn test_move_line_up_down_with_blocks(cx: &mut TestAppContext) { fn test_move_line_up_down_with_blocks(cx: &mut TestAppContext) {
init_test(cx, |_| {}); init_test(cx, |_| {});
@ -22663,435 +22619,6 @@ async fn test_indent_on_newline_for_python(cx: &mut TestAppContext) {
"}); "});
} }
#[gpui::test]
async fn test_tab_in_leading_whitespace_auto_indents_for_bash(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into());
cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx));
// test cursor move to start of each line on tab
// for `if`, `elif`, `else`, `while`, `for`, `case` and `function`
cx.set_state(indoc! {"
function main() {
ˇ for item in $items; do
ˇ while [ -n \"$item\" ]; do
ˇ if [ \"$value\" -gt 10 ]; then
ˇ continue
ˇ elif [ \"$value\" -lt 0 ]; then
ˇ break
ˇ else
ˇ echo \"$item\"
ˇ fi
ˇ done
ˇ done
ˇ}
"});
cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx));
cx.assert_editor_state(indoc! {"
function main() {
ˇfor item in $items; do
ˇwhile [ -n \"$item\" ]; do
ˇif [ \"$value\" -gt 10 ]; then
ˇcontinue
ˇelif [ \"$value\" -lt 0 ]; then
ˇbreak
ˇelse
ˇecho \"$item\"
ˇfi
ˇdone
ˇdone
ˇ}
"});
// test relative indent is preserved when tab
cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx));
cx.assert_editor_state(indoc! {"
function main() {
ˇfor item in $items; do
ˇwhile [ -n \"$item\" ]; do
ˇif [ \"$value\" -gt 10 ]; then
ˇcontinue
ˇelif [ \"$value\" -lt 0 ]; then
ˇbreak
ˇelse
ˇecho \"$item\"
ˇfi
ˇdone
ˇdone
ˇ}
"});
// test cursor move to start of each line on tab
// for `case` statement with patterns
cx.set_state(indoc! {"
function handle() {
ˇ case \"$1\" in
ˇ start)
ˇ echo \"a\"
ˇ ;;
ˇ stop)
ˇ echo \"b\"
ˇ ;;
ˇ *)
ˇ echo \"c\"
ˇ ;;
ˇ esac
ˇ}
"});
cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx));
cx.assert_editor_state(indoc! {"
function handle() {
ˇcase \"$1\" in
ˇstart)
ˇecho \"a\"
ˇ;;
ˇstop)
ˇecho \"b\"
ˇ;;
ˇ*)
ˇecho \"c\"
ˇ;;
ˇesac
ˇ}
"});
}
#[gpui::test]
async fn test_indent_after_input_for_bash(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into());
cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx));
// test indents on comment insert
cx.set_state(indoc! {"
function main() {
ˇ for item in $items; do
ˇ while [ -n \"$item\" ]; do
ˇ if [ \"$value\" -gt 10 ]; then
ˇ continue
ˇ elif [ \"$value\" -lt 0 ]; then
ˇ break
ˇ else
ˇ echo \"$item\"
ˇ fi
ˇ done
ˇ done
ˇ}
"});
cx.update_editor(|e, window, cx| e.handle_input("#", window, cx));
cx.assert_editor_state(indoc! {"
function main() {
#ˇ for item in $items; do
#ˇ while [ -n \"$item\" ]; do
#ˇ if [ \"$value\" -gt 10 ]; then
#ˇ continue
#ˇ elif [ \"$value\" -lt 0 ]; then
#ˇ break
#ˇ else
#ˇ echo \"$item\"
#ˇ fi
#ˇ done
#ˇ done
#ˇ}
"});
}
#[gpui::test]
async fn test_outdent_after_input_for_bash(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into());
cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx));
// test `else` auto outdents when typed inside `if` block
cx.set_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
echo \"foo bar\"
ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("else", window, cx);
});
cx.assert_editor_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
echo \"foo bar\"
elseˇ
"});
// test `elif` auto outdents when typed inside `if` block
cx.set_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
echo \"foo bar\"
ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("elif", window, cx);
});
cx.assert_editor_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
echo \"foo bar\"
elifˇ
"});
// test `fi` auto outdents when typed inside `else` block
cx.set_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
echo \"foo bar\"
else
echo \"bar baz\"
ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("fi", window, cx);
});
cx.assert_editor_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
echo \"foo bar\"
else
echo \"bar baz\"
fiˇ
"});
// test `done` auto outdents when typed inside `while` block
cx.set_state(indoc! {"
while read line; do
echo \"$line\"
ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("done", window, cx);
});
cx.assert_editor_state(indoc! {"
while read line; do
echo \"$line\"
doneˇ
"});
// test `done` auto outdents when typed inside `for` block
cx.set_state(indoc! {"
for file in *.txt; do
cat \"$file\"
ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("done", window, cx);
});
cx.assert_editor_state(indoc! {"
for file in *.txt; do
cat \"$file\"
doneˇ
"});
// test `esac` auto outdents when typed inside `case` block
cx.set_state(indoc! {"
case \"$1\" in
start)
echo \"foo bar\"
;;
stop)
echo \"bar baz\"
;;
ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("esac", window, cx);
});
cx.assert_editor_state(indoc! {"
case \"$1\" in
start)
echo \"foo bar\"
;;
stop)
echo \"bar baz\"
;;
esacˇ
"});
// test `*)` auto outdents when typed inside `case` block
cx.set_state(indoc! {"
case \"$1\" in
start)
echo \"foo bar\"
;;
ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("*)", window, cx);
});
cx.assert_editor_state(indoc! {"
case \"$1\" in
start)
echo \"foo bar\"
;;
*)ˇ
"});
// test `fi` outdents to correct level with nested if blocks
cx.set_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
echo \"outer if\"
if [ \"$2\" = \"debug\" ]; then
echo \"inner if\"
ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("fi", window, cx);
});
cx.assert_editor_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
echo \"outer if\"
if [ \"$2\" = \"debug\" ]; then
echo \"inner if\"
fiˇ
"});
}
#[gpui::test]
async fn test_indent_on_newline_for_bash(cx: &mut TestAppContext) {
init_test(cx, |_| {});
update_test_language_settings(cx, |settings| {
settings.defaults.extend_comment_on_newline = Some(false);
});
let mut cx = EditorTestContext::new(cx).await;
let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into());
cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx));
// test correct indent after newline on comment
cx.set_state(indoc! {"
# COMMENT:ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.assert_editor_state(indoc! {"
# COMMENT:
ˇ
"});
// test correct indent after newline after `then`
cx.set_state(indoc! {"
if [ \"$1\" = \"test\" ]; thenˇ
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
ˇ
"});
// test correct indent after newline after `else`
cx.set_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
elseˇ
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
else
ˇ
"});
// test correct indent after newline after `elif`
cx.set_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
elifˇ
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
if [ \"$1\" = \"test\" ]; then
elif
ˇ
"});
// test correct indent after newline after `do`
cx.set_state(indoc! {"
for file in *.txt; doˇ
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
for file in *.txt; do
ˇ
"});
// test correct indent after newline after case pattern
cx.set_state(indoc! {"
case \"$1\" in
start)ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
case \"$1\" in
start)
ˇ
"});
// test correct indent after newline after case pattern
cx.set_state(indoc! {"
case \"$1\" in
start)
;;
*)ˇ
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
case \"$1\" in
start)
;;
*)
ˇ
"});
// test correct indent after newline after function opening brace
cx.set_state(indoc! {"
function test() {ˇ}
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
function test() {
ˇ
}
"});
// test no extra indent after semicolon on same line
cx.set_state(indoc! {"
echo \"test\"
"});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
echo \"test\";
ˇ
"});
}
fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> { fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> {
let point = DisplayPoint::new(DisplayRow(row as u32), column as u32); let point = DisplayPoint::new(DisplayRow(row as u32), column as u32);
point..point point..point

View file

@ -230,6 +230,7 @@ impl EditorElement {
register_action(editor, window, Editor::sort_lines_case_insensitive); register_action(editor, window, Editor::sort_lines_case_insensitive);
register_action(editor, window, Editor::reverse_lines); register_action(editor, window, Editor::reverse_lines);
register_action(editor, window, Editor::shuffle_lines); register_action(editor, window, Editor::shuffle_lines);
register_action(editor, window, Editor::toggle_case);
register_action(editor, window, Editor::convert_indentation_to_spaces); register_action(editor, window, Editor::convert_indentation_to_spaces);
register_action(editor, window, Editor::convert_indentation_to_tabs); register_action(editor, window, Editor::convert_indentation_to_tabs);
register_action(editor, window, Editor::convert_to_upper_case); register_action(editor, window, Editor::convert_to_upper_case);
@ -240,8 +241,6 @@ impl EditorElement {
register_action(editor, window, Editor::convert_to_upper_camel_case); register_action(editor, window, Editor::convert_to_upper_camel_case);
register_action(editor, window, Editor::convert_to_lower_camel_case); register_action(editor, window, Editor::convert_to_lower_camel_case);
register_action(editor, window, Editor::convert_to_opposite_case); register_action(editor, window, Editor::convert_to_opposite_case);
register_action(editor, window, Editor::convert_to_sentence_case);
register_action(editor, window, Editor::toggle_case);
register_action(editor, window, Editor::convert_to_rot13); register_action(editor, window, Editor::convert_to_rot13);
register_action(editor, window, Editor::convert_to_rot47); register_action(editor, window, Editor::convert_to_rot47);
register_action(editor, window, Editor::delete_to_previous_word_start); register_action(editor, window, Editor::delete_to_previous_word_start);
@ -4011,7 +4010,6 @@ impl EditorElement {
let available_width = hitbox.bounds.size.width - right_margin; let available_width = hitbox.bounds.size.width - right_margin;
let mut header = v_flex() let mut header = v_flex()
.w_full()
.relative() .relative()
.child( .child(
div() div()
@ -7944,11 +7942,17 @@ impl Element for EditorElement {
right: right_margin, right: right_margin,
}; };
// Offset the content_bounds from the text_bounds by the gutter margin (which
// is roughly half a character wide) to make hit testing work more like how we want.
let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO);
let editor_content_width = editor_width - content_offset.x;
snapshot = self.editor.update(cx, |editor, cx| { snapshot = self.editor.update(cx, |editor, cx| {
editor.last_bounds = Some(bounds); editor.last_bounds = Some(bounds);
editor.gutter_dimensions = gutter_dimensions; editor.gutter_dimensions = gutter_dimensions;
editor.set_visible_line_count(bounds.size.height / line_height, window, cx); editor.set_visible_line_count(bounds.size.height / line_height, window, cx);
editor.set_visible_column_count(editor_width / em_advance); editor.set_visible_column_count(editor_content_width / em_advance);
if matches!( if matches!(
editor.mode, editor.mode,
@ -7960,10 +7964,10 @@ impl Element for EditorElement {
let wrap_width = match editor.soft_wrap_mode(cx) { let wrap_width = match editor.soft_wrap_mode(cx) {
SoftWrap::GitDiff => None, SoftWrap::GitDiff => None,
SoftWrap::None => Some(wrap_width_for(MAX_LINE_LEN as u32 / 2)), SoftWrap::None => Some(wrap_width_for(MAX_LINE_LEN as u32 / 2)),
SoftWrap::EditorWidth => Some(editor_width), SoftWrap::EditorWidth => Some(editor_content_width),
SoftWrap::Column(column) => Some(wrap_width_for(column)), SoftWrap::Column(column) => Some(wrap_width_for(column)),
SoftWrap::Bounded(column) => { SoftWrap::Bounded(column) => {
Some(editor_width.min(wrap_width_for(column))) Some(editor_content_width.min(wrap_width_for(column)))
} }
}; };
@ -7988,12 +7992,13 @@ impl Element for EditorElement {
HitboxBehavior::Normal, HitboxBehavior::Normal,
); );
// Offset the content_bounds from the text_bounds by the gutter margin (which
// is roughly half a character wide) to make hit testing work more like how we want.
let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO);
let content_origin = text_hitbox.origin + content_offset; let content_origin = text_hitbox.origin + content_offset;
let height_in_lines = bounds.size.height / line_height; let editor_text_bounds =
Bounds::from_corners(content_origin, bounds.bottom_right());
let height_in_lines = editor_text_bounds.size.height / line_height;
let max_row = snapshot.max_point().row().as_f32(); let max_row = snapshot.max_point().row().as_f32();
// The max scroll position for the top of the window // The max scroll position for the top of the window
@ -8377,6 +8382,7 @@ impl Element for EditorElement {
glyph_grid_cell, glyph_grid_cell,
size(longest_line_width, max_row.as_f32() * line_height), size(longest_line_width, max_row.as_f32() * line_height),
longest_line_blame_width, longest_line_blame_width,
editor_width,
EditorSettings::get_global(cx), EditorSettings::get_global(cx),
); );
@ -8448,7 +8454,7 @@ impl Element for EditorElement {
MultiBufferRow(end_anchor.to_point(&snapshot.buffer_snapshot).row); MultiBufferRow(end_anchor.to_point(&snapshot.buffer_snapshot).row);
let scroll_max = point( let scroll_max = point(
((scroll_width - editor_width) / em_advance).max(0.0), ((scroll_width - editor_content_width) / em_advance).max(0.0),
max_scroll_top, max_scroll_top,
); );
@ -8460,7 +8466,7 @@ impl Element for EditorElement {
if needs_horizontal_autoscroll.0 if needs_horizontal_autoscroll.0
&& let Some(new_scroll_position) = editor.autoscroll_horizontally( && let Some(new_scroll_position) = editor.autoscroll_horizontally(
start_row, start_row,
editor_width, editor_content_width,
scroll_width, scroll_width,
em_advance, em_advance,
&line_layouts, &line_layouts,
@ -9041,6 +9047,7 @@ impl ScrollbarLayoutInformation {
glyph_grid_cell: Size<Pixels>, glyph_grid_cell: Size<Pixels>,
document_size: Size<Pixels>, document_size: Size<Pixels>,
longest_line_blame_width: Pixels, longest_line_blame_width: Pixels,
editor_width: Pixels,
settings: &EditorSettings, settings: &EditorSettings,
) -> Self { ) -> Self {
let vertical_overscroll = match settings.scroll_beyond_last_line { let vertical_overscroll = match settings.scroll_beyond_last_line {
@ -9051,11 +9058,19 @@ impl ScrollbarLayoutInformation {
} }
}; };
let overscroll = size(longest_line_blame_width, vertical_overscroll); let right_margin = if document_size.width + longest_line_blame_width >= editor_width {
glyph_grid_cell.width
} else {
px(0.0)
};
let overscroll = size(right_margin + longest_line_blame_width, vertical_overscroll);
let scroll_range = document_size + overscroll;
ScrollbarLayoutInformation { ScrollbarLayoutInformation {
editor_bounds, editor_bounds,
scroll_range: document_size + overscroll, scroll_range,
glyph_grid_cell, glyph_grid_cell,
} }
} }
@ -9160,7 +9175,7 @@ struct EditorScrollbars {
impl EditorScrollbars { impl EditorScrollbars {
pub fn from_scrollbar_axes( pub fn from_scrollbar_axes(
show_scrollbar: ScrollbarAxes, settings_visibility: ScrollbarAxes,
layout_information: &ScrollbarLayoutInformation, layout_information: &ScrollbarLayoutInformation,
content_offset: gpui::Point<Pixels>, content_offset: gpui::Point<Pixels>,
scroll_position: gpui::Point<f32>, scroll_position: gpui::Point<f32>,
@ -9198,13 +9213,22 @@ impl EditorScrollbars {
}; };
let mut create_scrollbar_layout = |axis| { let mut create_scrollbar_layout = |axis| {
let viewport_size = viewport_size.along(axis); settings_visibility
let scroll_range = scroll_range.along(axis); .along(axis)
// We always want a vertical scrollbar track for scrollbar diagnostic visibility.
(show_scrollbar.along(axis)
&& (axis == ScrollbarAxis::Vertical || scroll_range > viewport_size))
.then(|| { .then(|| {
(
viewport_size.along(axis) - content_offset.along(axis),
scroll_range.along(axis),
)
})
.filter(|(viewport_size, scroll_range)| {
// The scrollbar should only be rendered if the content does
// not entirely fit into the editor
// However, this only applies to the horizontal scrollbar, as information about the
// vertical scrollbar layout is always needed for scrollbar diagnostics.
axis != ScrollbarAxis::Horizontal || viewport_size < scroll_range
})
.map(|(viewport_size, scroll_range)| {
ScrollbarLayout::new( ScrollbarLayout::new(
window.insert_hitbox(scrollbar_bounds_for(axis), HitboxBehavior::Normal), window.insert_hitbox(scrollbar_bounds_for(axis), HitboxBehavior::Normal),
viewport_size, viewport_size,

View file

@ -32,11 +32,7 @@ serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
task.workspace = true task.workspace = true
toml.workspace = true toml.workspace = true
url.workspace = true
util.workspace = true util.workspace = true
wasm-encoder.workspace = true wasm-encoder.workspace = true
wasmparser.workspace = true wasmparser.workspace = true
workspace-hack.workspace = true workspace-hack.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true

View file

@ -1,20 +0,0 @@
mod download_file_capability;
mod npm_install_package_capability;
mod process_exec_capability;
pub use download_file_capability::*;
pub use npm_install_package_capability::*;
pub use process_exec_capability::*;
use serde::{Deserialize, Serialize};
/// A capability for an extension.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ExtensionCapability {
#[serde(rename = "process:exec")]
ProcessExec(ProcessExecCapability),
DownloadFile(DownloadFileCapability),
#[serde(rename = "npm:install")]
NpmInstallPackage(NpmInstallPackageCapability),
}

View file

@ -1,121 +0,0 @@
use serde::{Deserialize, Serialize};
use url::Url;
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct DownloadFileCapability {
pub host: String,
pub path: Vec<String>,
}
impl DownloadFileCapability {
/// Returns whether the capability allows downloading a file from the given URL.
pub fn allows(&self, url: &Url) -> bool {
let Some(desired_host) = url.host_str() else {
return false;
};
let Some(desired_path) = url.path_segments() else {
return false;
};
let desired_path = desired_path.collect::<Vec<_>>();
if self.host != desired_host && self.host != "*" {
return false;
}
for (ix, path_segment) in self.path.iter().enumerate() {
if path_segment == "**" {
return true;
}
if ix >= desired_path.len() {
return false;
}
if path_segment != "*" && path_segment != desired_path[ix] {
return false;
}
}
if self.path.len() < desired_path.len() {
return false;
}
true
}
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn test_allows() {
let capability = DownloadFileCapability {
host: "*".to_string(),
path: vec!["**".to_string()],
};
assert_eq!(
capability.allows(&"https://example.com/some/path".parse().unwrap()),
true
);
let capability = DownloadFileCapability {
host: "github.com".to_string(),
path: vec!["**".to_string()],
};
assert_eq!(
capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()),
true
);
assert_eq!(
capability.allows(
&"https://fake-github.com/some-owner/some-repo"
.parse()
.unwrap()
),
false
);
let capability = DownloadFileCapability {
host: "github.com".to_string(),
path: vec!["specific-owner".to_string(), "*".to_string()],
};
assert_eq!(
capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()),
false
);
assert_eq!(
capability.allows(
&"https://github.com/specific-owner/some-repo"
.parse()
.unwrap()
),
true
);
let capability = DownloadFileCapability {
host: "github.com".to_string(),
path: vec!["specific-owner".to_string(), "*".to_string()],
};
assert_eq!(
capability.allows(
&"https://github.com/some-owner/some-repo/extra"
.parse()
.unwrap()
),
false
);
assert_eq!(
capability.allows(
&"https://github.com/specific-owner/some-repo/extra"
.parse()
.unwrap()
),
false
);
}
}

View file

@ -1,39 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct NpmInstallPackageCapability {
pub package: String,
}
impl NpmInstallPackageCapability {
/// Returns whether the capability allows installing the given NPM package.
pub fn allows(&self, package: &str) -> bool {
self.package == "*" || self.package == package
}
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn test_allows() {
let capability = NpmInstallPackageCapability {
package: "*".to_string(),
};
assert_eq!(capability.allows("package"), true);
let capability = NpmInstallPackageCapability {
package: "react".to_string(),
};
assert_eq!(capability.allows("react"), true);
let capability = NpmInstallPackageCapability {
package: "react".to_string(),
};
assert_eq!(capability.allows("malicious-package"), false);
}
}

View file

@ -1,116 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ProcessExecCapability {
/// The command to execute.
pub command: String,
/// The arguments to pass to the command. Use `*` for a single wildcard argument.
/// If the last element is `**`, then any trailing arguments are allowed.
pub args: Vec<String>,
}
impl ProcessExecCapability {
/// Returns whether the capability allows the given command and arguments.
pub fn allows(
&self,
desired_command: &str,
desired_args: &[impl AsRef<str> + std::fmt::Debug],
) -> bool {
if self.command != desired_command && self.command != "*" {
return false;
}
for (ix, arg) in self.args.iter().enumerate() {
if arg == "**" {
return true;
}
if ix >= desired_args.len() {
return false;
}
if arg != "*" && arg != desired_args[ix].as_ref() {
return false;
}
}
if self.args.len() < desired_args.len() {
return false;
}
true
}
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn test_allows_with_exact_match() {
let capability = ProcessExecCapability {
command: "ls".to_string(),
args: vec!["-la".to_string()],
};
assert_eq!(capability.allows("ls", &["-la"]), true);
assert_eq!(capability.allows("ls", &["-l"]), false);
assert_eq!(capability.allows("pwd", &[] as &[&str]), false);
}
#[test]
fn test_allows_with_wildcard_arg() {
let capability = ProcessExecCapability {
command: "git".to_string(),
args: vec!["*".to_string()],
};
assert_eq!(capability.allows("git", &["status"]), true);
assert_eq!(capability.allows("git", &["commit"]), true);
// Too many args.
assert_eq!(capability.allows("git", &["status", "-s"]), false);
// Wrong command.
assert_eq!(capability.allows("npm", &["install"]), false);
}
#[test]
fn test_allows_with_double_wildcard() {
let capability = ProcessExecCapability {
command: "cargo".to_string(),
args: vec!["test".to_string(), "**".to_string()],
};
assert_eq!(capability.allows("cargo", &["test"]), true);
assert_eq!(capability.allows("cargo", &["test", "--all"]), true);
assert_eq!(
capability.allows("cargo", &["test", "--all", "--no-fail-fast"]),
true
);
// Wrong first arg.
assert_eq!(capability.allows("cargo", &["build"]), false);
}
#[test]
fn test_allows_with_mixed_wildcards() {
let capability = ProcessExecCapability {
command: "docker".to_string(),
args: vec!["run".to_string(), "*".to_string(), "**".to_string()],
};
assert_eq!(capability.allows("docker", &["run", "nginx"]), true);
assert_eq!(capability.allows("docker", &["run"]), false);
assert_eq!(
capability.allows("docker", &["run", "ubuntu", "bash"]),
true
);
assert_eq!(
capability.allows("docker", &["run", "alpine", "sh", "-c", "echo hello"]),
true
);
// Wrong first arg.
assert_eq!(capability.allows("docker", &["ps"]), false);
}
}

View file

@ -1,4 +1,3 @@
mod capabilities;
pub mod extension_builder; pub mod extension_builder;
mod extension_events; mod extension_events;
mod extension_host_proxy; mod extension_host_proxy;
@ -17,7 +16,6 @@ use language::LanguageName;
use semantic_version::SemanticVersion; use semantic_version::SemanticVersion;
use task::{SpawnInTerminal, ZedDebugConfig}; use task::{SpawnInTerminal, ZedDebugConfig};
pub use crate::capabilities::*;
pub use crate::extension_events::*; pub use crate::extension_events::*;
pub use crate::extension_host_proxy::*; pub use crate::extension_host_proxy::*;
pub use crate::extension_manifest::*; pub use crate::extension_manifest::*;

View file

@ -12,8 +12,6 @@ use std::{
sync::Arc, sync::Arc,
}; };
use crate::ExtensionCapability;
/// This is the old version of the extension manifest, from when it was `extension.json`. /// This is the old version of the extension manifest, from when it was `extension.json`.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct OldExtensionManifest { pub struct OldExtensionManifest {
@ -102,8 +100,24 @@ impl ExtensionManifest {
desired_args: &[impl AsRef<str> + std::fmt::Debug], desired_args: &[impl AsRef<str> + std::fmt::Debug],
) -> Result<()> { ) -> Result<()> {
let is_allowed = self.capabilities.iter().any(|capability| match capability { let is_allowed = self.capabilities.iter().any(|capability| match capability {
ExtensionCapability::ProcessExec(capability) => { ExtensionCapability::ProcessExec { command, args } if command == desired_command => {
capability.allows(desired_command, desired_args) for (ix, arg) in args.iter().enumerate() {
if arg == "**" {
return true;
}
if ix >= desired_args.len() {
return false;
}
if arg != "*" && arg != desired_args[ix].as_ref() {
return false;
}
}
if args.len() < desired_args.len() {
return false;
}
true
} }
_ => false, _ => false,
}); });
@ -134,6 +148,20 @@ pub fn build_debug_adapter_schema_path(
}) })
} }
/// A capability for an extension.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum ExtensionCapability {
#[serde(rename = "process:exec")]
ProcessExec {
/// The command to execute.
command: String,
/// The arguments to pass to the command. Use `*` for a single wildcard argument.
/// If the last element is `**`, then any trailing arguments are allowed.
args: Vec<String>,
},
}
#[derive(Clone, Default, PartialEq, Eq, Debug, Deserialize, Serialize)] #[derive(Clone, Default, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct LibManifestEntry { pub struct LibManifestEntry {
pub kind: Option<ExtensionLibraryKind>, pub kind: Option<ExtensionLibraryKind>,
@ -281,10 +309,6 @@ fn manifest_from_old_manifest(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use pretty_assertions::assert_eq;
use crate::ProcessExecCapability;
use super::*; use super::*;
fn extension_manifest() -> ExtensionManifest { fn extension_manifest() -> ExtensionManifest {
@ -336,12 +360,12 @@ mod tests {
} }
#[test] #[test]
fn test_allow_exec_exact_match() { fn test_allow_exact_match() {
let manifest = ExtensionManifest { let manifest = ExtensionManifest {
capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { capabilities: vec![ExtensionCapability::ProcessExec {
command: "ls".to_string(), command: "ls".to_string(),
args: vec!["-la".to_string()], args: vec!["-la".to_string()],
})], }],
..extension_manifest() ..extension_manifest()
}; };
@ -351,12 +375,12 @@ mod tests {
} }
#[test] #[test]
fn test_allow_exec_wildcard_arg() { fn test_allow_wildcard_arg() {
let manifest = ExtensionManifest { let manifest = ExtensionManifest {
capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { capabilities: vec![ExtensionCapability::ProcessExec {
command: "git".to_string(), command: "git".to_string(),
args: vec!["*".to_string()], args: vec!["*".to_string()],
})], }],
..extension_manifest() ..extension_manifest()
}; };
@ -367,12 +391,12 @@ mod tests {
} }
#[test] #[test]
fn test_allow_exec_double_wildcard() { fn test_allow_double_wildcard() {
let manifest = ExtensionManifest { let manifest = ExtensionManifest {
capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { capabilities: vec![ExtensionCapability::ProcessExec {
command: "cargo".to_string(), command: "cargo".to_string(),
args: vec!["test".to_string(), "**".to_string()], args: vec!["test".to_string(), "**".to_string()],
})], }],
..extension_manifest() ..extension_manifest()
}; };
@ -387,12 +411,12 @@ mod tests {
} }
#[test] #[test]
fn test_allow_exec_mixed_wildcards() { fn test_allow_mixed_wildcards() {
let manifest = ExtensionManifest { let manifest = ExtensionManifest {
capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { capabilities: vec![ExtensionCapability::ProcessExec {
command: "docker".to_string(), command: "docker".to_string(),
args: vec!["run".to_string(), "*".to_string(), "**".to_string()], args: vec!["run".to_string(), "*".to_string(), "**".to_string()],
})], }],
..extension_manifest() ..extension_manifest()
}; };

View file

@ -134,12 +134,10 @@ fn manifest() -> ExtensionManifest {
slash_commands: BTreeMap::default(), slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(), indexed_docs_providers: BTreeMap::default(),
snippets: None, snippets: None,
capabilities: vec![ExtensionCapability::ProcessExec( capabilities: vec![ExtensionCapability::ProcessExec {
extension::ProcessExecCapability { command: "echo".into(),
command: "echo".into(), args: vec!["hello!".into()],
args: vec!["hello!".into()], }],
},
)],
debug_adapters: Default::default(), debug_adapters: Default::default(),
debug_locators: Default::default(), debug_locators: Default::default(),
} }

View file

@ -1,153 +0,0 @@
use std::sync::Arc;
use anyhow::{Result, bail};
use extension::{ExtensionCapability, ExtensionManifest};
use url::Url;
pub struct CapabilityGranter {
granted_capabilities: Vec<ExtensionCapability>,
manifest: Arc<ExtensionManifest>,
}
impl CapabilityGranter {
pub fn new(
granted_capabilities: Vec<ExtensionCapability>,
manifest: Arc<ExtensionManifest>,
) -> Self {
Self {
granted_capabilities,
manifest,
}
}
pub fn grant_exec(
&self,
desired_command: &str,
desired_args: &[impl AsRef<str> + std::fmt::Debug],
) -> Result<()> {
self.manifest.allow_exec(desired_command, desired_args)?;
let is_allowed = self
.granted_capabilities
.iter()
.any(|capability| match capability {
ExtensionCapability::ProcessExec(capability) => {
capability.allows(desired_command, desired_args)
}
_ => false,
});
if !is_allowed {
bail!(
"capability for process:exec {desired_command} {desired_args:?} is not granted by the extension host",
);
}
Ok(())
}
pub fn grant_download_file(&self, desired_url: &Url) -> Result<()> {
let is_allowed = self
.granted_capabilities
.iter()
.any(|capability| match capability {
ExtensionCapability::DownloadFile(capability) => capability.allows(desired_url),
_ => false,
});
if !is_allowed {
bail!(
"capability for download_file {desired_url} is not granted by the extension host",
);
}
Ok(())
}
pub fn grant_npm_install_package(&self, package_name: &str) -> Result<()> {
let is_allowed = self
.granted_capabilities
.iter()
.any(|capability| match capability {
ExtensionCapability::NpmInstallPackage(capability) => {
capability.allows(package_name)
}
_ => false,
});
if !is_allowed {
bail!("capability for npm:install {package_name} is not granted by the extension host",);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use extension::{ProcessExecCapability, SchemaVersion};
use super::*;
fn extension_manifest() -> ExtensionManifest {
ExtensionManifest {
id: "test".into(),
name: "Test".to_string(),
version: "1.0.0".into(),
schema_version: SchemaVersion::ZERO,
description: None,
repository: None,
authors: vec![],
lib: Default::default(),
themes: vec![],
icon_themes: vec![],
languages: vec![],
grammars: BTreeMap::default(),
language_servers: BTreeMap::default(),
context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
}
}
#[test]
fn test_grant_exec() {
let manifest = Arc::new(ExtensionManifest {
capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
command: "ls".to_string(),
args: vec!["-la".to_string()],
})],
..extension_manifest()
});
// It returns an error when the extension host has no granted capabilities.
let granter = CapabilityGranter::new(Vec::new(), manifest.clone());
assert!(granter.grant_exec("ls", &["-la"]).is_err());
// It succeeds when the extension host has the exact capability.
let granter = CapabilityGranter::new(
vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
command: "ls".to_string(),
args: vec!["-la".to_string()],
})],
manifest.clone(),
);
assert!(granter.grant_exec("ls", &["-la"]).is_ok());
// It succeeds when the extension host has a wildcard capability.
let granter = CapabilityGranter::new(
vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
command: "*".to_string(),
args: vec!["**".to_string()],
})],
manifest.clone(),
);
assert!(granter.grant_exec("ls", &["-la"]).is_ok());
}
}

View file

@ -1,4 +1,3 @@
mod capability_granter;
pub mod extension_settings; pub mod extension_settings;
pub mod headless_host; pub mod headless_host;
pub mod wasm_host; pub mod wasm_host;

View file

@ -1,15 +1,13 @@
pub mod wit; pub mod wit;
use crate::ExtensionManifest; use crate::ExtensionManifest;
use crate::capability_granter::CapabilityGranter;
use anyhow::{Context as _, Result, anyhow, bail}; use anyhow::{Context as _, Result, anyhow, bail};
use async_trait::async_trait; use async_trait::async_trait;
use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest}; use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
use extension::{ use extension::{
CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary, CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
DebugTaskDefinition, DownloadFileCapability, ExtensionCapability, ExtensionHostProxy, DebugTaskDefinition, ExtensionHostProxy, KeyValueStoreDelegate, ProjectDelegate, SlashCommand,
KeyValueStoreDelegate, NpmInstallPackageCapability, ProcessExecCapability, ProjectDelegate, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
}; };
use fs::{Fs, normalize_path}; use fs::{Fs, normalize_path};
use futures::future::LocalBoxFuture; use futures::future::LocalBoxFuture;
@ -52,8 +50,6 @@ pub struct WasmHost {
pub(crate) proxy: Arc<ExtensionHostProxy>, pub(crate) proxy: Arc<ExtensionHostProxy>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
pub work_dir: PathBuf, pub work_dir: PathBuf,
/// The capabilities granted to extensions running on the host.
pub(crate) granted_capabilities: Vec<ExtensionCapability>,
_main_thread_message_task: Task<()>, _main_thread_message_task: Task<()>,
main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>, main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>,
} }
@ -490,7 +486,6 @@ pub struct WasmState {
pub table: ResourceTable, pub table: ResourceTable,
ctx: wasi::WasiCtx, ctx: wasi::WasiCtx,
pub host: Arc<WasmHost>, pub host: Arc<WasmHost>,
pub(crate) capability_granter: CapabilityGranter,
} }
type MainThreadCall = Box<dyn Send + for<'a> FnOnce(&'a mut AsyncApp) -> LocalBoxFuture<'a, ()>>; type MainThreadCall = Box<dyn Send + for<'a> FnOnce(&'a mut AsyncApp) -> LocalBoxFuture<'a, ()>>;
@ -576,19 +571,6 @@ impl WasmHost {
node_runtime, node_runtime,
proxy, proxy,
release_channel: ReleaseChannel::global(cx), release_channel: ReleaseChannel::global(cx),
granted_capabilities: vec![
ExtensionCapability::ProcessExec(ProcessExecCapability {
command: "*".to_string(),
args: vec!["**".to_string()],
}),
ExtensionCapability::DownloadFile(DownloadFileCapability {
host: "*".to_string(),
path: vec!["**".to_string()],
}),
ExtensionCapability::NpmInstallPackage(NpmInstallPackageCapability {
package: "*".to_string(),
}),
],
_main_thread_message_task: task, _main_thread_message_task: task,
main_thread_message_tx: tx, main_thread_message_tx: tx,
}) })
@ -615,10 +597,6 @@ impl WasmHost {
manifest: manifest.clone(), manifest: manifest.clone(),
table: ResourceTable::new(), table: ResourceTable::new(),
host: this.clone(), host: this.clone(),
capability_granter: CapabilityGranter::new(
this.granted_capabilities.clone(),
manifest.clone(),
),
}, },
); );
// Store will yield after 1 tick, and get a new deadline of 1 tick after each yield. // Store will yield after 1 tick, and get a new deadline of 1 tick after each yield.
@ -777,18 +755,8 @@ impl WasmExtension {
} }
.boxed() .boxed()
})) }))
.unwrap_or_else(|_| { .expect("wasm extension channel should not be closed yet");
panic!( return_rx.await.expect("wasm extension channel")
"wasm extension channel should not be closed yet, extension {} (id {})",
self.manifest.name, self.manifest.id,
)
});
return_rx.await.unwrap_or_else(|_| {
panic!(
"wasm extension channel, extension {} (id {})",
self.manifest.name, self.manifest.id,
)
})
} }
} }
@ -809,19 +777,8 @@ impl WasmState {
} }
.boxed_local() .boxed_local()
})) }))
.unwrap_or_else(|_| { .expect("main thread message channel should not be closed yet");
panic!( async move { return_rx.await.expect("main thread message channel") }
"main thread message channel should not be closed yet, extension {} (id {})",
self.manifest.name, self.manifest.id,
)
});
let name = self.manifest.name.clone();
let id = self.manifest.id.clone();
async move {
return_rx.await.unwrap_or_else(|_| {
panic!("main thread message channel, extension {name} (id {id})")
})
}
} }
fn work_dir(&self) -> PathBuf { fn work_dir(&self) -> PathBuf {

View file

@ -30,7 +30,6 @@ use std::{
sync::{Arc, OnceLock}, sync::{Arc, OnceLock},
}; };
use task::{SpawnInTerminal, ZedDebugConfig}; use task::{SpawnInTerminal, ZedDebugConfig};
use url::Url;
use util::{archive::extract_zip, fs::make_file_executable, maybe}; use util::{archive::extract_zip, fs::make_file_executable, maybe};
use wasmtime::component::{Linker, Resource}; use wasmtime::component::{Linker, Resource};
@ -745,9 +744,6 @@ impl nodejs::Host for WasmState {
package_name: String, package_name: String,
version: String, version: String,
) -> wasmtime::Result<Result<(), String>> { ) -> wasmtime::Result<Result<(), String>> {
self.capability_granter
.grant_npm_install_package(&package_name)?;
self.host self.host
.node_runtime .node_runtime
.npm_install_packages(&self.work_dir(), &[(&package_name, &version)]) .npm_install_packages(&self.work_dir(), &[(&package_name, &version)])
@ -851,8 +847,7 @@ impl process::Host for WasmState {
command: process::Command, command: process::Command,
) -> wasmtime::Result<Result<process::Output, String>> { ) -> wasmtime::Result<Result<process::Output, String>> {
maybe!(async { maybe!(async {
self.capability_granter self.manifest.allow_exec(&command.command, &command.args)?;
.grant_exec(&command.command, &command.args)?;
let output = util::command::new_smol_command(command.command.as_str()) let output = util::command::new_smol_command(command.command.as_str())
.args(&command.args) .args(&command.args)
@ -1015,9 +1010,6 @@ impl ExtensionImports for WasmState {
file_type: DownloadedFileType, file_type: DownloadedFileType,
) -> wasmtime::Result<Result<(), String>> { ) -> wasmtime::Result<Result<(), String>> {
maybe!(async { maybe!(async {
let parsed_url = Url::parse(&url)?;
self.capability_granter.grant_download_file(&parsed_url)?;
let path = PathBuf::from(path); let path = PathBuf::from(path);
let extension_work_dir = self.host.work_dir.join(self.manifest.id.as_ref()); let extension_work_dir = self.host.work_dir.join(self.manifest.id.as_ref());

View file

@ -85,11 +85,6 @@ impl FeatureFlag for ThreadAutoCaptureFeatureFlag {
false false
} }
} }
pub struct PanicFeatureFlag;
impl FeatureFlag for PanicFeatureFlag {
const NAME: &'static str = "panic";
}
pub struct JjUiFeatureFlag {} pub struct JjUiFeatureFlag {}

View file

@ -1404,21 +1404,14 @@ impl PickerDelegate for FileFinderDelegate {
} else { } else {
let path_position = PathWithPosition::parse_str(&raw_query); let path_position = PathWithPosition::parse_str(&raw_query);
#[cfg(windows)]
let raw_query = raw_query.trim().to_owned().replace("/", "\\");
#[cfg(not(windows))]
let raw_query = raw_query.trim().to_owned();
let file_query_end = if path_position.path.to_str().unwrap_or(&raw_query) == raw_query {
None
} else {
// Safe to unwrap as we won't get here when the unwrap in if fails
Some(path_position.path.to_str().unwrap().len())
};
let query = FileSearchQuery { let query = FileSearchQuery {
raw_query, raw_query: raw_query.trim().to_owned(),
file_query_end, file_query_end: if path_position.path.to_str().unwrap_or(raw_query) == raw_query {
None
} else {
// Safe to unwrap as we won't get here when the unwrap in if fails
Some(path_position.path.to_str().unwrap().len())
},
path_position, path_position,
}; };

View file

@ -398,18 +398,6 @@ impl GitRepository for FakeGitRepository {
}) })
} }
fn stash_paths(
&self,
_paths: Vec<RepoPath>,
_env: Arc<HashMap<String, String>>,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn stash_pop(&self, _env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn commit( fn commit(
&self, &self,
_message: gpui::SharedString, _message: gpui::SharedString,

View file

@ -55,10 +55,6 @@ actions!(
StageAll, StageAll,
/// Unstages all changes in the repository. /// Unstages all changes in the repository.
UnstageAll, UnstageAll,
/// Stashes all changes in the repository, including untracked files.
StashAll,
/// Pops the most recent stash.
StashPop,
/// Restores all tracked files to their last committed state. /// Restores all tracked files to their last committed state.
RestoreTrackedFiles, RestoreTrackedFiles,
/// Moves all untracked files to trash. /// Moves all untracked files to trash.

View file

@ -395,14 +395,6 @@ pub trait GitRepository: Send + Sync {
env: Arc<HashMap<String, String>>, env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>>; ) -> BoxFuture<'_, Result<()>>;
fn stash_paths(
&self,
paths: Vec<RepoPath>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<Result<()>>;
fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>>;
fn push( fn push(
&self, &self,
branch_name: String, branch_name: String,
@ -1197,55 +1189,6 @@ impl GitRepository for RealGitRepository {
.boxed() .boxed()
} }
fn stash_paths(
&self,
paths: Vec<RepoPath>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<Result<()>> {
let working_directory = self.working_directory();
self.executor
.spawn(async move {
let mut cmd = new_smol_command("git");
cmd.current_dir(&working_directory?)
.envs(env.iter())
.args(["stash", "push", "--quiet"])
.arg("--include-untracked");
cmd.args(paths.iter().map(|p| p.as_ref()));
let output = cmd.output().await?;
anyhow::ensure!(
output.status.success(),
"Failed to stash:\n{}",
String::from_utf8_lossy(&output.stderr)
);
Ok(())
})
.boxed()
}
fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>> {
let working_directory = self.working_directory();
self.executor
.spawn(async move {
let mut cmd = new_smol_command("git");
cmd.current_dir(&working_directory?)
.envs(env.iter())
.args(["stash", "pop"]);
let output = cmd.output().await?;
anyhow::ensure!(
output.status.success(),
"Failed to stash pop:\n{}",
String::from_utf8_lossy(&output.stderr)
);
Ok(())
})
.boxed()
}
fn commit( fn commit(
&self, &self,
message: SharedString, message: SharedString,

View file

@ -159,11 +159,7 @@ impl GitHostingProvider for Github {
} }
let mut path_segments = url.path_segments()?; let mut path_segments = url.path_segments()?;
let mut owner = path_segments.next()?; let owner = path_segments.next()?;
if owner.is_empty() {
owner = path_segments.next()?;
}
let repo = path_segments.next()?.trim_end_matches(".git"); let repo = path_segments.next()?.trim_end_matches(".git");
Some(ParsedGitRemote { Some(ParsedGitRemote {
@ -248,22 +244,6 @@ mod tests {
use super::*; use super::*;
#[test]
fn test_remote_url_with_root_slash() {
let remote_url = "git@github.com:/zed-industries/zed";
let parsed_remote = Github::public_instance()
.parse_remote_url(remote_url)
.unwrap();
assert_eq!(
parsed_remote,
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
}
);
}
#[test] #[test]
fn test_invalid_self_hosted_remote_url() { fn test_invalid_self_hosted_remote_url() {
let remote_url = "git@github.com:zed-industries/zed.git"; let remote_url = "git@github.com:zed-industries/zed.git";

View file

@ -27,10 +27,7 @@ use git::repository::{
}; };
use git::status::StageStatus; use git::status::StageStatus;
use git::{Amend, Signoff, ToggleStaged, repository::RepoPath, status::FileStatus}; use git::{Amend, Signoff, ToggleStaged, repository::RepoPath, status::FileStatus};
use git::{ use git::{ExpandCommitEditor, RestoreTrackedFiles, StageAll, TrashUntrackedFiles, UnstageAll};
ExpandCommitEditor, RestoreTrackedFiles, StageAll, StashAll, StashPop, TrashUntrackedFiles,
UnstageAll,
};
use gpui::{ use gpui::{
Action, Animation, AnimationExt as _, AsyncApp, AsyncWindowContext, Axis, ClickEvent, Corner, Action, Animation, AnimationExt as _, AsyncApp, AsyncWindowContext, Axis, ClickEvent, Corner,
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, KeyContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, KeyContext,
@ -143,13 +140,6 @@ fn git_panel_context_menu(
UnstageAll.boxed_clone(), UnstageAll.boxed_clone(),
) )
.separator() .separator()
.action_disabled_when(
!(state.has_new_changes || state.has_tracked_changes),
"Stash All",
StashAll.boxed_clone(),
)
.action("Stash Pop", StashPop.boxed_clone())
.separator()
.action("Open Diff", project_diff::Diff.boxed_clone()) .action("Open Diff", project_diff::Diff.boxed_clone())
.separator() .separator()
.action_disabled_when( .action_disabled_when(
@ -390,9 +380,6 @@ pub(crate) fn commit_message_editor(
window: &mut Window, window: &mut Window,
cx: &mut Context<Editor>, cx: &mut Context<Editor>,
) -> Editor { ) -> Editor {
project.update(cx, |this, cx| {
this.mark_buffer_as_non_searchable(commit_message_buffer.read(cx).remote_id(), cx);
});
let buffer = cx.new(|cx| MultiBuffer::singleton(commit_message_buffer, cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(commit_message_buffer, cx));
let max_lines = if in_panel { MAX_PANEL_EDITOR_LINES } else { 18 }; let max_lines = if in_panel { MAX_PANEL_EDITOR_LINES } else { 18 };
let mut commit_editor = Editor::new( let mut commit_editor = Editor::new(
@ -1425,52 +1412,6 @@ impl GitPanel {
self.tracked_staged_count + self.new_staged_count + self.conflicted_staged_count self.tracked_staged_count + self.new_staged_count + self.conflicted_staged_count
} }
pub fn stash_pop(&mut self, _: &StashPop, _window: &mut Window, cx: &mut Context<Self>) {
let Some(active_repository) = self.active_repository.clone() else {
return;
};
cx.spawn({
async move |this, cx| {
let stash_task = active_repository
.update(cx, |repo, cx| repo.stash_pop(cx))?
.await;
this.update(cx, |this, cx| {
stash_task
.map_err(|e| {
this.show_error_toast("stash pop", e, cx);
})
.ok();
cx.notify();
})
}
})
.detach();
}
pub fn stash_all(&mut self, _: &StashAll, _window: &mut Window, cx: &mut Context<Self>) {
let Some(active_repository) = self.active_repository.clone() else {
return;
};
cx.spawn({
async move |this, cx| {
let stash_task = active_repository
.update(cx, |repo, cx| repo.stash_all(cx))?
.await;
this.update(cx, |this, cx| {
stash_task
.map_err(|e| {
this.show_error_toast("stash", e, cx);
})
.ok();
cx.notify();
})
}
})
.detach();
}
pub fn commit_message_buffer(&self, cx: &App) -> Entity<Buffer> { pub fn commit_message_buffer(&self, cx: &App) -> Entity<Buffer> {
self.commit_editor self.commit_editor
.read(cx) .read(cx)
@ -4430,8 +4371,6 @@ impl Render for GitPanel {
.on_action(cx.listener(Self::revert_selected)) .on_action(cx.listener(Self::revert_selected))
.on_action(cx.listener(Self::clean_all)) .on_action(cx.listener(Self::clean_all))
.on_action(cx.listener(Self::generate_commit_message_action)) .on_action(cx.listener(Self::generate_commit_message_action))
.on_action(cx.listener(Self::stash_all))
.on_action(cx.listener(Self::stash_pop))
}) })
.on_action(cx.listener(Self::select_first)) .on_action(cx.listener(Self::select_first))
.on_action(cx.listener(Self::select_next)) .on_action(cx.listener(Self::select_next))

View file

@ -114,22 +114,6 @@ pub fn init(cx: &mut App) {
}); });
}); });
} }
workspace.register_action(|workspace, action: &git::StashAll, window, cx| {
let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else {
return;
};
panel.update(cx, |panel, cx| {
panel.stash_all(action, window, cx);
});
});
workspace.register_action(|workspace, action: &git::StashPop, window, cx| {
let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else {
return;
};
panel.update(cx, |panel, cx| {
panel.stash_pop(action, window, cx);
});
});
workspace.register_action(|workspace, action: &git::StageAll, window, cx| { workspace.register_action(|workspace, action: &git::StageAll, window, cx| {
let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else {
return; return;

View file

@ -121,7 +121,7 @@ smallvec.workspace = true
smol.workspace = true smol.workspace = true
strum.workspace = true strum.workspace = true
sum_tree.workspace = true sum_tree.workspace = true
taffy = "=0.8.3" taffy = "=0.5.1"
thiserror.workspace = true thiserror.workspace = true
util.workspace = true util.workspace = true
uuid.workspace = true uuid.workspace = true
@ -287,10 +287,6 @@ path = "examples/shadow.rs"
name = "svg" name = "svg"
path = "examples/svg/svg.rs" path = "examples/svg/svg.rs"
[[example]]
name = "tab_stop"
path = "examples/tab_stop.rs"
[[example]] [[example]]
name = "text" name = "text"
path = "examples/text.rs" path = "examples/text.rs"

View file

@ -128,7 +128,6 @@ mod macos {
"AtlasTile".into(), "AtlasTile".into(),
"PathRasterizationInputIndex".into(), "PathRasterizationInputIndex".into(),
"PathVertex_ScaledPixels".into(), "PathVertex_ScaledPixels".into(),
"PathRasterizationVertex".into(),
"ShadowInputIndex".into(), "ShadowInputIndex".into(),
"Shadow".into(), "Shadow".into(),
"QuadInputIndex".into(), "QuadInputIndex".into(),

View file

@ -1,12 +1,11 @@
use gpui::{ use gpui::{
Application, Background, Bounds, ColorSpace, Context, MouseDownEvent, Path, PathBuilder, Application, Background, Bounds, ColorSpace, Context, MouseDownEvent, Path, PathBuilder,
PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowOptions, canvas, PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowOptions, canvas,
div, linear_color_stop, linear_gradient, point, prelude::*, px, quad, rgb, size, div, linear_color_stop, linear_gradient, point, prelude::*, px, rgb, size,
}; };
struct PaintingViewer { struct PaintingViewer {
default_lines: Vec<(Path<Pixels>, Background)>, default_lines: Vec<(Path<Pixels>, Background)>,
background_quads: Vec<(Bounds<Pixels>, Background)>,
lines: Vec<Vec<Point<Pixels>>>, lines: Vec<Vec<Point<Pixels>>>,
start: Point<Pixels>, start: Point<Pixels>,
dashed: bool, dashed: bool,
@ -17,148 +16,12 @@ impl PaintingViewer {
fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self { fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self {
let mut lines = vec![]; let mut lines = vec![];
// Black squares beneath transparent paths.
let background_quads = vec![
(
Bounds {
origin: point(px(70.), px(70.)),
size: size(px(40.), px(40.)),
},
gpui::black().into(),
),
(
Bounds {
origin: point(px(170.), px(70.)),
size: size(px(40.), px(40.)),
},
gpui::black().into(),
),
(
Bounds {
origin: point(px(270.), px(70.)),
size: size(px(40.), px(40.)),
},
gpui::black().into(),
),
(
Bounds {
origin: point(px(370.), px(70.)),
size: size(px(40.), px(40.)),
},
gpui::black().into(),
),
(
Bounds {
origin: point(px(450.), px(50.)),
size: size(px(80.), px(80.)),
},
gpui::black().into(),
),
];
// 50% opaque red path that extends across black quad.
let mut builder = PathBuilder::fill();
builder.move_to(point(px(50.), px(50.)));
builder.line_to(point(px(130.), px(50.)));
builder.line_to(point(px(130.), px(130.)));
builder.line_to(point(px(50.), px(130.)));
builder.close();
let path = builder.build().unwrap();
let mut red = rgb(0xFF0000);
red.a = 0.5;
lines.push((path, red.into()));
// 50% opaque blue path that extends across black quad.
let mut builder = PathBuilder::fill();
builder.move_to(point(px(150.), px(50.)));
builder.line_to(point(px(230.), px(50.)));
builder.line_to(point(px(230.), px(130.)));
builder.line_to(point(px(150.), px(130.)));
builder.close();
let path = builder.build().unwrap();
let mut blue = rgb(0x0000FF);
blue.a = 0.5;
lines.push((path, blue.into()));
// 50% opaque green path that extends across black quad.
let mut builder = PathBuilder::fill();
builder.move_to(point(px(250.), px(50.)));
builder.line_to(point(px(330.), px(50.)));
builder.line_to(point(px(330.), px(130.)));
builder.line_to(point(px(250.), px(130.)));
builder.close();
let path = builder.build().unwrap();
let mut green = rgb(0x00FF00);
green.a = 0.5;
lines.push((path, green.into()));
// 50% opaque black path that extends across black quad.
let mut builder = PathBuilder::fill();
builder.move_to(point(px(350.), px(50.)));
builder.line_to(point(px(430.), px(50.)));
builder.line_to(point(px(430.), px(130.)));
builder.line_to(point(px(350.), px(130.)));
builder.close();
let path = builder.build().unwrap();
let mut black = rgb(0x000000);
black.a = 0.5;
lines.push((path, black.into()));
// Two 50% opaque red circles overlapping - center should be darker red
let mut builder = PathBuilder::fill();
let center = point(px(530.), px(85.));
let radius = px(30.);
builder.move_to(point(center.x + radius, center.y));
builder.arc_to(
point(radius, radius),
px(0.),
false,
false,
point(center.x - radius, center.y),
);
builder.arc_to(
point(radius, radius),
px(0.),
false,
false,
point(center.x + radius, center.y),
);
builder.close();
let path = builder.build().unwrap();
let mut red1 = rgb(0xFF0000);
red1.a = 0.5;
lines.push((path, red1.into()));
let mut builder = PathBuilder::fill();
let center = point(px(570.), px(85.));
let radius = px(30.);
builder.move_to(point(center.x + radius, center.y));
builder.arc_to(
point(radius, radius),
px(0.),
false,
false,
point(center.x - radius, center.y),
);
builder.arc_to(
point(radius, radius),
px(0.),
false,
false,
point(center.x + radius, center.y),
);
builder.close();
let path = builder.build().unwrap();
let mut red2 = rgb(0xFF0000);
red2.a = 0.5;
lines.push((path, red2.into()));
// draw a Rust logo // draw a Rust logo
let mut builder = lyon::path::Path::svg_builder(); let mut builder = lyon::path::Path::svg_builder();
lyon::extra::rust_logo::build_logo_path(&mut builder); lyon::extra::rust_logo::build_logo_path(&mut builder);
// move down the Path // move down the Path
let mut builder: PathBuilder = builder.into(); let mut builder: PathBuilder = builder.into();
builder.translate(point(px(10.), px(200.))); builder.translate(point(px(10.), px(100.)));
builder.scale(0.9); builder.scale(0.9);
let path = builder.build().unwrap(); let path = builder.build().unwrap();
lines.push((path, gpui::black().into())); lines.push((path, gpui::black().into()));
@ -167,10 +30,10 @@ impl PaintingViewer {
let mut builder = PathBuilder::fill(); let mut builder = PathBuilder::fill();
builder.add_polygon( builder.add_polygon(
&[ &[
point(px(150.), px(300.)), point(px(150.), px(200.)),
point(px(200.), px(225.)), point(px(200.), px(125.)),
point(px(200.), px(275.)), point(px(200.), px(175.)),
point(px(250.), px(200.)), point(px(250.), px(100.)),
], ],
false, false,
); );
@ -179,17 +42,17 @@ impl PaintingViewer {
// draw a ⭐ // draw a ⭐
let mut builder = PathBuilder::fill(); let mut builder = PathBuilder::fill();
builder.move_to(point(px(350.), px(200.))); builder.move_to(point(px(350.), px(100.)));
builder.line_to(point(px(370.), px(260.))); builder.line_to(point(px(370.), px(160.)));
builder.line_to(point(px(430.), px(260.))); builder.line_to(point(px(430.), px(160.)));
builder.line_to(point(px(380.), px(300.))); builder.line_to(point(px(380.), px(200.)));
builder.line_to(point(px(400.), px(360.))); builder.line_to(point(px(400.), px(260.)));
builder.line_to(point(px(350.), px(320.))); builder.line_to(point(px(350.), px(220.)));
builder.line_to(point(px(300.), px(360.))); builder.line_to(point(px(300.), px(260.)));
builder.line_to(point(px(320.), px(300.))); builder.line_to(point(px(320.), px(200.)));
builder.line_to(point(px(270.), px(260.))); builder.line_to(point(px(270.), px(160.)));
builder.line_to(point(px(330.), px(260.))); builder.line_to(point(px(330.), px(160.)));
builder.line_to(point(px(350.), px(200.))); builder.line_to(point(px(350.), px(100.)));
let path = builder.build().unwrap(); let path = builder.build().unwrap();
lines.push(( lines.push((
path, path,
@ -203,7 +66,7 @@ impl PaintingViewer {
// draw linear gradient // draw linear gradient
let square_bounds = Bounds { let square_bounds = Bounds {
origin: point(px(450.), px(200.)), origin: point(px(450.), px(100.)),
size: size(px(200.), px(80.)), size: size(px(200.), px(80.)),
}; };
let height = square_bounds.size.height; let height = square_bounds.size.height;
@ -233,31 +96,31 @@ impl PaintingViewer {
// draw a pie chart // draw a pie chart
let center = point(px(96.), px(96.)); let center = point(px(96.), px(96.));
let pie_center = point(px(775.), px(255.)); let pie_center = point(px(775.), px(155.));
let segments = [ let segments = [
( (
point(px(871.), px(255.)), point(px(871.), px(155.)),
point(px(747.), px(163.)), point(px(747.), px(63.)),
rgb(0x1374e9), rgb(0x1374e9),
), ),
( (
point(px(747.), px(163.)), point(px(747.), px(63.)),
point(px(679.), px(263.)), point(px(679.), px(163.)),
rgb(0xe13527), rgb(0xe13527),
), ),
( (
point(px(679.), px(263.)), point(px(679.), px(163.)),
point(px(754.), px(349.)), point(px(754.), px(249.)),
rgb(0x0751ce), rgb(0x0751ce),
), ),
( (
point(px(754.), px(349.)), point(px(754.), px(249.)),
point(px(854.), px(310.)), point(px(854.), px(210.)),
rgb(0x209742), rgb(0x209742),
), ),
( (
point(px(854.), px(310.)), point(px(854.), px(210.)),
point(px(871.), px(255.)), point(px(871.), px(155.)),
rgb(0xfbc10a), rgb(0xfbc10a),
), ),
]; ];
@ -277,11 +140,11 @@ impl PaintingViewer {
.with_line_width(1.) .with_line_width(1.)
.with_line_join(lyon::path::LineJoin::Bevel); .with_line_join(lyon::path::LineJoin::Bevel);
let mut builder = PathBuilder::stroke(px(1.)).with_style(PathStyle::Stroke(options)); let mut builder = PathBuilder::stroke(px(1.)).with_style(PathStyle::Stroke(options));
builder.move_to(point(px(40.), px(420.))); builder.move_to(point(px(40.), px(320.)));
for i in 1..50 { for i in 1..50 {
builder.line_to(point( builder.line_to(point(
px(40.0 + i as f32 * 10.0), px(40.0 + i as f32 * 10.0),
px(420.0 + (i as f32 * 10.0).sin() * 40.0), px(320.0 + (i as f32 * 10.0).sin() * 40.0),
)); ));
} }
let path = builder.build().unwrap(); let path = builder.build().unwrap();
@ -289,7 +152,6 @@ impl PaintingViewer {
Self { Self {
default_lines: lines.clone(), default_lines: lines.clone(),
background_quads,
lines: vec![], lines: vec![],
start: point(px(0.), px(0.)), start: point(px(0.), px(0.)),
dashed: false, dashed: false,
@ -323,7 +185,6 @@ fn button(
impl Render for PaintingViewer { impl Render for PaintingViewer {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let default_lines = self.default_lines.clone(); let default_lines = self.default_lines.clone();
let background_quads = self.background_quads.clone();
let lines = self.lines.clone(); let lines = self.lines.clone();
let dashed = self.dashed; let dashed = self.dashed;
@ -360,19 +221,6 @@ impl Render for PaintingViewer {
canvas( canvas(
move |_, _, _| {}, move |_, _, _| {},
move |_, _, window, _| { move |_, _, window, _| {
// First draw background quads
for (bounds, color) in background_quads.iter() {
window.paint_quad(quad(
*bounds,
px(0.),
*color,
px(0.),
gpui::transparent_black(),
Default::default(),
));
}
// Then draw the default paths on top
for (path, color) in default_lines { for (path, color) in default_lines {
window.paint_path(path, color); window.paint_path(path, color);
} }
@ -455,10 +303,6 @@ fn main() {
|window, cx| cx.new(|cx| PaintingViewer::new(window, cx)), |window, cx| cx.new(|cx| PaintingViewer::new(window, cx)),
) )
.unwrap(); .unwrap();
cx.on_window_closed(|cx| {
cx.quit();
})
.detach();
cx.activate(true); cx.activate(true);
}); });
} }

View file

@ -1,92 +0,0 @@
use gpui::{
Application, Background, Bounds, ColorSpace, Context, Path, PathBuilder, Pixels, Render,
TitlebarOptions, Window, WindowBounds, WindowOptions, canvas, div, linear_color_stop,
linear_gradient, point, prelude::*, px, rgb, size,
};
const DEFAULT_WINDOW_WIDTH: Pixels = px(1024.0);
const DEFAULT_WINDOW_HEIGHT: Pixels = px(768.0);
struct PaintingViewer {
default_lines: Vec<(Path<Pixels>, Background)>,
_painting: bool,
}
impl PaintingViewer {
fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self {
let mut lines = vec![];
// draw a lightening bolt ⚡
for _ in 0..2000 {
// draw a ⭐
let mut builder = PathBuilder::fill();
builder.move_to(point(px(350.), px(100.)));
builder.line_to(point(px(370.), px(160.)));
builder.line_to(point(px(430.), px(160.)));
builder.line_to(point(px(380.), px(200.)));
builder.line_to(point(px(400.), px(260.)));
builder.line_to(point(px(350.), px(220.)));
builder.line_to(point(px(300.), px(260.)));
builder.line_to(point(px(320.), px(200.)));
builder.line_to(point(px(270.), px(160.)));
builder.line_to(point(px(330.), px(160.)));
builder.line_to(point(px(350.), px(100.)));
let path = builder.build().unwrap();
lines.push((
path,
linear_gradient(
180.,
linear_color_stop(rgb(0xFACC15), 0.7),
linear_color_stop(rgb(0xD56D0C), 1.),
)
.color_space(ColorSpace::Oklab),
));
}
Self {
default_lines: lines,
_painting: false,
}
}
}
impl Render for PaintingViewer {
fn render(&mut self, window: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
window.request_animation_frame();
let lines = self.default_lines.clone();
div().size_full().child(
canvas(
move |_, _, _| {},
move |_, _, window, _| {
for (path, color) in lines {
window.paint_path(path, color);
}
},
)
.size_full(),
)
}
}
fn main() {
Application::new().run(|cx| {
cx.open_window(
WindowOptions {
titlebar: Some(TitlebarOptions {
title: Some("Vulkan".into()),
..Default::default()
}),
focus: true,
window_bounds: Some(WindowBounds::Windowed(Bounds::centered(
None,
size(DEFAULT_WINDOW_WIDTH, DEFAULT_WINDOW_HEIGHT),
cx,
))),
..Default::default()
},
|window, cx| cx.new(|cx| PaintingViewer::new(window, cx)),
)
.unwrap();
cx.activate(true);
});
}

View file

@ -6,7 +6,6 @@ use gpui::{
actions!(example, [Tab, TabPrev]); actions!(example, [Tab, TabPrev]);
struct Example { struct Example {
focus_handle: FocusHandle,
items: Vec<FocusHandle>, items: Vec<FocusHandle>,
message: SharedString, message: SharedString,
} }
@ -21,11 +20,8 @@ impl Example {
cx.focus_handle().tab_index(2).tab_stop(true), cx.focus_handle().tab_index(2).tab_stop(true),
]; ];
let focus_handle = cx.focus_handle(); window.focus(items.first().unwrap());
window.focus(&focus_handle);
Self { Self {
focus_handle,
items, items,
message: SharedString::from("Press `Tab`, `Shift-Tab` to switch focus."), message: SharedString::from("Press `Tab`, `Shift-Tab` to switch focus."),
} }
@ -44,10 +40,6 @@ impl Example {
impl Render for Example { impl Render for Example {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn tab_stop_style<T: Styled>(this: T) -> T {
this.border_3().border_color(gpui::blue())
}
fn button(id: impl Into<ElementId>) -> Stateful<Div> { fn button(id: impl Into<ElementId>) -> Stateful<Div> {
div() div()
.id(id) .id(id)
@ -60,13 +52,12 @@ impl Render for Example {
.border_color(gpui::black()) .border_color(gpui::black())
.bg(gpui::black()) .bg(gpui::black())
.text_color(gpui::white()) .text_color(gpui::white())
.focus(tab_stop_style) .focus(|this| this.border_color(gpui::blue()))
.shadow_sm() .shadow_sm()
} }
div() div()
.id("app") .id("app")
.track_focus(&self.focus_handle)
.on_action(cx.listener(Self::on_tab)) .on_action(cx.listener(Self::on_tab))
.on_action(cx.listener(Self::on_tab_prev)) .on_action(cx.listener(Self::on_tab_prev))
.size_full() .size_full()
@ -95,7 +86,7 @@ impl Render for Example {
.border_color(gpui::black()) .border_color(gpui::black())
.when( .when(
item_handle.tab_stop && item_handle.is_focused(window), item_handle.tab_stop && item_handle.is_focused(window),
tab_stop_style, |this| this.border_color(gpui::blue()),
) )
.map(|this| match item_handle.tab_stop { .map(|this| match item_handle.tab_stop {
true => this true => this

View file

@ -1334,6 +1334,7 @@ impl Element for Div {
} else if let Some(scroll_handle) = self.interactivity.tracked_scroll_handle.as_ref() { } else if let Some(scroll_handle) = self.interactivity.tracked_scroll_handle.as_ref() {
let mut state = scroll_handle.0.borrow_mut(); let mut state = scroll_handle.0.borrow_mut();
state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len()); state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len());
state.bounds = bounds;
for child_layout_id in &request_layout.child_layout_ids { for child_layout_id in &request_layout.child_layout_ids {
let child_bounds = window.layout_bounds(*child_layout_id); let child_bounds = window.layout_bounds(*child_layout_id);
child_min = child_min.min(&child_bounds.origin); child_min = child_min.min(&child_bounds.origin);
@ -1705,7 +1706,6 @@ impl Interactivity {
if let Some(mut scroll_handle_state) = tracked_scroll_handle { if let Some(mut scroll_handle_state) = tracked_scroll_handle {
scroll_handle_state.max_offset = scroll_max; scroll_handle_state.max_offset = scroll_max;
scroll_handle_state.bounds = bounds;
} }
*scroll_offset *scroll_offset
@ -3007,6 +3007,11 @@ impl ScrollHandle {
self.0.borrow().bounds self.0.borrow().bounds
} }
/// Set the bounds into which this child is painted
pub(super) fn set_bounds(&self, bounds: Bounds<Pixels>) {
self.0.borrow_mut().bounds = bounds;
}
/// Get the bounds for a specific child. /// Get the bounds for a specific child.
pub fn bounds_for_item(&self, ix: usize) -> Option<Bounds<Pixels>> { pub fn bounds_for_item(&self, ix: usize) -> Option<Bounds<Pixels>> {
self.0.borrow().child_bounds.get(ix).cloned() self.0.borrow().child_bounds.get(ix).cloned()

View file

@ -88,24 +88,15 @@ pub enum ScrollStrategy {
/// May not be possible if there's not enough list items above the item scrolled to: /// May not be possible if there's not enough list items above the item scrolled to:
/// in this case, the element will be placed at the closest possible position. /// in this case, the element will be placed at the closest possible position.
Center, Center,
} /// Scrolls the element to be at the given item index from the top of the viewport.
ToPosition(usize),
#[derive(Clone, Copy, Debug)]
#[allow(missing_docs)]
pub struct DeferredScrollToItem {
/// The item index to scroll to
pub item_index: usize,
/// The scroll strategy to use
pub strategy: ScrollStrategy,
/// The offset in number of items
pub offset: usize,
} }
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
#[allow(missing_docs)] #[allow(missing_docs)]
pub struct UniformListScrollState { pub struct UniformListScrollState {
pub base_handle: ScrollHandle, pub base_handle: ScrollHandle,
pub deferred_scroll_to_item: Option<DeferredScrollToItem>, pub deferred_scroll_to_item: Option<(usize, ScrollStrategy)>,
/// Size of the item, captured during last layout. /// Size of the item, captured during last layout.
pub last_item_size: Option<ItemSize>, pub last_item_size: Option<ItemSize>,
/// Whether the list was vertically flipped during last layout. /// Whether the list was vertically flipped during last layout.
@ -135,24 +126,7 @@ impl UniformListScrollHandle {
/// Scroll the list to the given item index. /// Scroll the list to the given item index.
pub fn scroll_to_item(&self, ix: usize, strategy: ScrollStrategy) { pub fn scroll_to_item(&self, ix: usize, strategy: ScrollStrategy) {
self.0.borrow_mut().deferred_scroll_to_item = Some(DeferredScrollToItem { self.0.borrow_mut().deferred_scroll_to_item = Some((ix, strategy));
item_index: ix,
strategy,
offset: 0,
});
}
/// Scroll the list to the given item index with an offset.
///
/// For ScrollStrategy::Top, the item will be placed at the offset position from the top.
///
/// For ScrollStrategy::Center, the item will be centered between offset and the last visible item.
pub fn scroll_to_item_with_offset(&self, ix: usize, strategy: ScrollStrategy, offset: usize) {
self.0.borrow_mut().deferred_scroll_to_item = Some(DeferredScrollToItem {
item_index: ix,
strategy,
offset,
});
} }
/// Check if the list is flipped vertically. /// Check if the list is flipped vertically.
@ -165,8 +139,7 @@ impl UniformListScrollHandle {
pub fn logical_scroll_top_index(&self) -> usize { pub fn logical_scroll_top_index(&self) -> usize {
let this = self.0.borrow(); let this = self.0.borrow();
this.deferred_scroll_to_item this.deferred_scroll_to_item
.as_ref() .map(|(ix, _)| ix)
.map(|deferred| deferred.item_index)
.unwrap_or_else(|| this.base_handle.logical_scroll_top().0) .unwrap_or_else(|| this.base_handle.logical_scroll_top().0)
} }
@ -322,8 +295,9 @@ impl Element for UniformList {
bounds.bottom_right() - point(border.right + padding.right, border.bottom), bounds.bottom_right() - point(border.right + padding.right, border.bottom),
); );
let y_flipped = if let Some(scroll_handle) = &self.scroll_handle { let y_flipped = if let Some(scroll_handle) = self.scroll_handle.as_mut() {
let scroll_state = scroll_handle.0.borrow(); let mut scroll_state = scroll_handle.0.borrow_mut();
scroll_state.base_handle.set_bounds(bounds);
scroll_state.y_flipped scroll_state.y_flipped
} else { } else {
false false
@ -347,8 +321,7 @@ impl Element for UniformList {
scroll_offset.x = Pixels::ZERO; scroll_offset.x = Pixels::ZERO;
} }
if let Some(deferred_scroll) = shared_scroll_to_item { if let Some((mut ix, scroll_strategy)) = shared_scroll_to_item {
let mut ix = deferred_scroll.item_index;
if y_flipped { if y_flipped {
ix = self.item_count.saturating_sub(ix + 1); ix = self.item_count.saturating_sub(ix + 1);
} }
@ -357,28 +330,23 @@ impl Element for UniformList {
let item_top = item_height * ix + padding.top; let item_top = item_height * ix + padding.top;
let item_bottom = item_top + item_height; let item_bottom = item_top + item_height;
let scroll_top = -updated_scroll_offset.y; let scroll_top = -updated_scroll_offset.y;
let offset_pixels = item_height * deferred_scroll.offset;
let mut scrolled_to_top = false; let mut scrolled_to_top = false;
if item_top < scroll_top + padding.top {
if item_top < scroll_top + padding.top + offset_pixels {
scrolled_to_top = true; scrolled_to_top = true;
updated_scroll_offset.y = -(item_top) + padding.top + offset_pixels; updated_scroll_offset.y = -(item_top) + padding.top;
} else if item_bottom > scroll_top + list_height - padding.bottom { } else if item_bottom > scroll_top + list_height - padding.bottom {
scrolled_to_top = true; scrolled_to_top = true;
updated_scroll_offset.y = -(item_bottom - list_height) - padding.bottom; updated_scroll_offset.y = -(item_bottom - list_height) - padding.bottom;
} }
match deferred_scroll.strategy { match scroll_strategy {
ScrollStrategy::Top => {} ScrollStrategy::Top => {}
ScrollStrategy::Center => { ScrollStrategy::Center => {
if scrolled_to_top { if scrolled_to_top {
let item_center = item_top + item_height / 2.0; let item_center = item_top + item_height / 2.0;
let target_scroll_top = item_center - list_height / 2.0;
let viewport_height = list_height - offset_pixels; if item_top < scroll_top
let viewport_center = offset_pixels + viewport_height / 2.0;
let target_scroll_top = item_center - viewport_center;
if item_top < scroll_top + offset_pixels
|| item_bottom > scroll_top + list_height || item_bottom > scroll_top + list_height
{ {
updated_scroll_offset.y = -target_scroll_top updated_scroll_offset.y = -target_scroll_top
@ -388,6 +356,15 @@ impl Element for UniformList {
} }
} }
} }
ScrollStrategy::ToPosition(sticky_index) => {
let target_y_in_viewport = item_height * sticky_index;
let target_scroll_top = item_top - target_y_in_viewport;
let max_scroll_top =
(content_height - list_height).max(Pixels::ZERO);
let new_scroll_top =
target_scroll_top.clamp(Pixels::ZERO, max_scroll_top);
updated_scroll_offset.y = -new_scroll_top;
}
} }
scroll_offset = *updated_scroll_offset scroll_offset = *updated_scroll_offset
} }

View file

@ -809,6 +809,7 @@ pub(crate) struct AtlasTextureId {
pub(crate) enum AtlasTextureKind { pub(crate) enum AtlasTextureKind {
Monochrome = 0, Monochrome = 0,
Polychrome = 1, Polychrome = 1,
Path = 2,
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]

View file

@ -10,6 +10,8 @@ use etagere::BucketedAtlasAllocator;
use parking_lot::Mutex; use parking_lot::Mutex;
use std::{borrow::Cow, ops, sync::Arc}; use std::{borrow::Cow, ops, sync::Arc};
pub(crate) const PATH_TEXTURE_FORMAT: gpu::TextureFormat = gpu::TextureFormat::R16Float;
pub(crate) struct BladeAtlas(Mutex<BladeAtlasState>); pub(crate) struct BladeAtlas(Mutex<BladeAtlasState>);
struct PendingUpload { struct PendingUpload {
@ -25,6 +27,7 @@ struct BladeAtlasState {
tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
initializations: Vec<AtlasTextureId>, initializations: Vec<AtlasTextureId>,
uploads: Vec<PendingUpload>, uploads: Vec<PendingUpload>,
path_sample_count: u32,
} }
#[cfg(gles)] #[cfg(gles)]
@ -38,11 +41,13 @@ impl BladeAtlasState {
} }
pub struct BladeTextureInfo { pub struct BladeTextureInfo {
pub size: gpu::Extent,
pub raw_view: gpu::TextureView, pub raw_view: gpu::TextureView,
pub msaa_view: Option<gpu::TextureView>,
} }
impl BladeAtlas { impl BladeAtlas {
pub(crate) fn new(gpu: &Arc<gpu::Context>) -> Self { pub(crate) fn new(gpu: &Arc<gpu::Context>, path_sample_count: u32) -> Self {
BladeAtlas(Mutex::new(BladeAtlasState { BladeAtlas(Mutex::new(BladeAtlasState {
gpu: Arc::clone(gpu), gpu: Arc::clone(gpu),
upload_belt: BufferBelt::new(BufferBeltDescriptor { upload_belt: BufferBelt::new(BufferBeltDescriptor {
@ -54,6 +59,7 @@ impl BladeAtlas {
tiles_by_key: Default::default(), tiles_by_key: Default::default(),
initializations: Vec::new(), initializations: Vec::new(),
uploads: Vec::new(), uploads: Vec::new(),
path_sample_count,
})) }))
} }
@ -61,6 +67,27 @@ impl BladeAtlas {
self.0.lock().destroy(); self.0.lock().destroy();
} }
pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) {
let mut lock = self.0.lock();
let textures = &mut lock.storage[texture_kind];
for texture in textures.iter_mut() {
texture.clear();
}
}
/// Allocate a rectangle and make it available for rendering immediately (without waiting for `before_frame`)
pub fn allocate_for_rendering(
&self,
size: Size<DevicePixels>,
texture_kind: AtlasTextureKind,
gpu_encoder: &mut gpu::CommandEncoder,
) -> AtlasTile {
let mut lock = self.0.lock();
let tile = lock.allocate(size, texture_kind);
lock.flush_initializations(gpu_encoder);
tile
}
pub fn before_frame(&self, gpu_encoder: &mut gpu::CommandEncoder) { pub fn before_frame(&self, gpu_encoder: &mut gpu::CommandEncoder) {
let mut lock = self.0.lock(); let mut lock = self.0.lock();
lock.flush(gpu_encoder); lock.flush(gpu_encoder);
@ -74,8 +101,15 @@ impl BladeAtlas {
pub fn get_texture_info(&self, id: AtlasTextureId) -> BladeTextureInfo { pub fn get_texture_info(&self, id: AtlasTextureId) -> BladeTextureInfo {
let lock = self.0.lock(); let lock = self.0.lock();
let texture = &lock.storage[id]; let texture = &lock.storage[id];
let size = texture.allocator.size();
BladeTextureInfo { BladeTextureInfo {
size: gpu::Extent {
width: size.width as u32,
height: size.height as u32,
depth: 1,
},
raw_view: texture.raw_view, raw_view: texture.raw_view,
msaa_view: texture.msaa_view,
} }
} }
} }
@ -166,8 +200,48 @@ impl BladeAtlasState {
format = gpu::TextureFormat::Bgra8UnormSrgb; format = gpu::TextureFormat::Bgra8UnormSrgb;
usage = gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE; usage = gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE;
} }
AtlasTextureKind::Path => {
format = PATH_TEXTURE_FORMAT;
usage = gpu::TextureUsage::COPY
| gpu::TextureUsage::RESOURCE
| gpu::TextureUsage::TARGET;
}
} }
// We currently only enable MSAA for path textures.
let (msaa, msaa_view) = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path {
let msaa = self.gpu.create_texture(gpu::TextureDesc {
name: "msaa path texture",
format,
size: gpu::Extent {
width: size.width.into(),
height: size.height.into(),
depth: 1,
},
array_layer_count: 1,
mip_level_count: 1,
sample_count: self.path_sample_count,
dimension: gpu::TextureDimension::D2,
usage: gpu::TextureUsage::TARGET,
external: None,
});
(
Some(msaa),
Some(self.gpu.create_texture_view(
msaa,
gpu::TextureViewDesc {
name: "msaa texture view",
format,
dimension: gpu::ViewDimension::D2,
subresources: &Default::default(),
},
)),
)
} else {
(None, None)
};
let raw = self.gpu.create_texture(gpu::TextureDesc { let raw = self.gpu.create_texture(gpu::TextureDesc {
name: "atlas", name: "atlas",
format, format,
@ -205,6 +279,8 @@ impl BladeAtlasState {
format, format,
raw, raw,
raw_view, raw_view,
msaa,
msaa_view,
live_atlas_keys: 0, live_atlas_keys: 0,
}; };
@ -264,6 +340,7 @@ impl BladeAtlasState {
struct BladeAtlasStorage { struct BladeAtlasStorage {
monochrome_textures: AtlasTextureList<BladeAtlasTexture>, monochrome_textures: AtlasTextureList<BladeAtlasTexture>,
polychrome_textures: AtlasTextureList<BladeAtlasTexture>, polychrome_textures: AtlasTextureList<BladeAtlasTexture>,
path_textures: AtlasTextureList<BladeAtlasTexture>,
} }
impl ops::Index<AtlasTextureKind> for BladeAtlasStorage { impl ops::Index<AtlasTextureKind> for BladeAtlasStorage {
@ -272,6 +349,7 @@ impl ops::Index<AtlasTextureKind> for BladeAtlasStorage {
match kind { match kind {
crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Monochrome => &self.monochrome_textures,
crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures,
crate::AtlasTextureKind::Path => &self.path_textures,
} }
} }
} }
@ -281,6 +359,7 @@ impl ops::IndexMut<AtlasTextureKind> for BladeAtlasStorage {
match kind { match kind {
crate::AtlasTextureKind::Monochrome => &mut self.monochrome_textures, crate::AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
crate::AtlasTextureKind::Polychrome => &mut self.polychrome_textures, crate::AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
crate::AtlasTextureKind::Path => &mut self.path_textures,
} }
} }
} }
@ -291,6 +370,7 @@ impl ops::Index<AtlasTextureId> for BladeAtlasStorage {
let textures = match id.kind { let textures = match id.kind {
crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Monochrome => &self.monochrome_textures,
crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures,
crate::AtlasTextureKind::Path => &self.path_textures,
}; };
textures[id.index as usize].as_ref().unwrap() textures[id.index as usize].as_ref().unwrap()
} }
@ -304,6 +384,9 @@ impl BladeAtlasStorage {
for mut texture in self.polychrome_textures.drain().flatten() { for mut texture in self.polychrome_textures.drain().flatten() {
texture.destroy(gpu); texture.destroy(gpu);
} }
for mut texture in self.path_textures.drain().flatten() {
texture.destroy(gpu);
}
} }
} }
@ -312,11 +395,17 @@ struct BladeAtlasTexture {
allocator: BucketedAtlasAllocator, allocator: BucketedAtlasAllocator,
raw: gpu::Texture, raw: gpu::Texture,
raw_view: gpu::TextureView, raw_view: gpu::TextureView,
msaa: Option<gpu::Texture>,
msaa_view: Option<gpu::TextureView>,
format: gpu::TextureFormat, format: gpu::TextureFormat,
live_atlas_keys: u32, live_atlas_keys: u32,
} }
impl BladeAtlasTexture { impl BladeAtlasTexture {
fn clear(&mut self) {
self.allocator.clear();
}
fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> {
let allocation = self.allocator.allocate(size.into())?; let allocation = self.allocator.allocate(size.into())?;
let tile = AtlasTile { let tile = AtlasTile {
@ -335,6 +424,12 @@ impl BladeAtlasTexture {
fn destroy(&mut self, gpu: &gpu::Context) { fn destroy(&mut self, gpu: &gpu::Context) {
gpu.destroy_texture(self.raw); gpu.destroy_texture(self.raw);
gpu.destroy_texture_view(self.raw_view); gpu.destroy_texture_view(self.raw_view);
if let Some(msaa) = self.msaa {
gpu.destroy_texture(msaa);
}
if let Some(msaa_view) = self.msaa_view {
gpu.destroy_texture_view(msaa_view);
}
} }
fn bytes_per_pixel(&self) -> u8 { fn bytes_per_pixel(&self) -> u8 {

View file

@ -1,19 +1,24 @@
// Doing `if let` gives you nice scoping with passes/encoders // Doing `if let` gives you nice scoping with passes/encoders
#![allow(irrefutable_let_patterns)] #![allow(irrefutable_let_patterns)]
use super::{BladeAtlas, BladeContext}; use super::{BladeAtlas, BladeContext, PATH_TEXTURE_FORMAT};
use crate::{ use crate::{
Background, Bounds, DevicePixels, GpuSpecs, MonochromeSprite, Path, Point, PolychromeSprite, AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels, GpuSpecs,
PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Underline, MonochromeSprite, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch, Quad,
ScaledPixels, Scene, Shadow, Size, Underline,
}; };
use blade_graphics as gpu; use blade_graphics as gpu;
use blade_util::{BufferBelt, BufferBeltDescriptor}; use blade_util::{BufferBelt, BufferBeltDescriptor};
use bytemuck::{Pod, Zeroable}; use bytemuck::{Pod, Zeroable};
use collections::HashMap;
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
use media::core_video::CVMetalTextureCache; use media::core_video::CVMetalTextureCache;
use std::sync::Arc; use std::{mem, sync::Arc};
const MAX_FRAME_TIME_MS: u32 = 10000; const MAX_FRAME_TIME_MS: u32 = 10000;
// Use 4x MSAA, all devices support it.
// https://developer.apple.com/documentation/metal/mtldevice/1433355-supportstexturesamplecount
const DEFAULT_PATH_SAMPLE_COUNT: u32 = 4;
#[repr(C)] #[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)] #[derive(Clone, Copy, Pod, Zeroable)]
@ -109,15 +114,8 @@ struct ShaderSurfacesData {
#[repr(C)] #[repr(C)]
struct PathSprite { struct PathSprite {
bounds: Bounds<ScaledPixels>, bounds: Bounds<ScaledPixels>,
}
#[derive(Clone, Debug)]
#[repr(C)]
struct PathRasterizationVertex {
xy_position: Point<ScaledPixels>,
st_position: Point<f32>,
color: Background, color: Background,
bounds: Bounds<ScaledPixels>, tile: AtlasTile,
} }
struct BladePipelines { struct BladePipelines {
@ -146,7 +144,10 @@ impl BladePipelines {
shader.check_struct_size::<SurfaceParams>(); shader.check_struct_size::<SurfaceParams>();
shader.check_struct_size::<Quad>(); shader.check_struct_size::<Quad>();
shader.check_struct_size::<Shadow>(); shader.check_struct_size::<Shadow>();
shader.check_struct_size::<PathRasterizationVertex>(); assert_eq!(
mem::size_of::<PathVertex<ScaledPixels>>(),
shader.get_struct_size("PathVertex") as usize,
);
shader.check_struct_size::<PathSprite>(); shader.check_struct_size::<PathSprite>();
shader.check_struct_size::<Underline>(); shader.check_struct_size::<Underline>();
shader.check_struct_size::<MonochromeSprite>(); shader.check_struct_size::<MonochromeSprite>();
@ -204,16 +205,9 @@ impl BladePipelines {
}, },
depth_stencil: None, depth_stencil: None,
fragment: Some(shader.at("fs_path_rasterization")), fragment: Some(shader.at("fs_path_rasterization")),
// The original implementation was using ADDITIVE blende mode,
// I don't know why
// color_targets: &[gpu::ColorTargetState {
// format: PATH_TEXTURE_FORMAT,
// blend: Some(gpu::BlendState::ADDITIVE),
// write_mask: gpu::ColorWrites::default(),
// }],
color_targets: &[gpu::ColorTargetState { color_targets: &[gpu::ColorTargetState {
format: surface_info.format, format: PATH_TEXTURE_FORMAT,
blend: Some(gpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING), blend: Some(gpu::BlendState::ADDITIVE),
write_mask: gpu::ColorWrites::default(), write_mask: gpu::ColorWrites::default(),
}], }],
multisample_state: gpu::MultisampleState { multisample_state: gpu::MultisampleState {
@ -232,14 +226,7 @@ impl BladePipelines {
}, },
depth_stencil: None, depth_stencil: None,
fragment: Some(shader.at("fs_path")), fragment: Some(shader.at("fs_path")),
color_targets: &[gpu::ColorTargetState { color_targets,
format: surface_info.format,
blend: Some(gpu::BlendState {
color: gpu::BlendComponent::OVER,
alpha: gpu::BlendComponent::ADDITIVE,
}),
write_mask: gpu::ColorWrites::default(),
}],
multisample_state: gpu::MultisampleState::default(), multisample_state: gpu::MultisampleState::default(),
}), }),
underlines: gpu.create_render_pipeline(gpu::RenderPipelineDesc { underlines: gpu.create_render_pipeline(gpu::RenderPipelineDesc {
@ -330,15 +317,12 @@ pub struct BladeRenderer {
last_sync_point: Option<gpu::SyncPoint>, last_sync_point: Option<gpu::SyncPoint>,
pipelines: BladePipelines, pipelines: BladePipelines,
instance_belt: BufferBelt, instance_belt: BufferBelt,
path_tiles: HashMap<PathId, AtlasTile>,
atlas: Arc<BladeAtlas>, atlas: Arc<BladeAtlas>,
atlas_sampler: gpu::Sampler, atlas_sampler: gpu::Sampler,
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
core_video_texture_cache: CVMetalTextureCache, core_video_texture_cache: CVMetalTextureCache,
path_sample_count: u32, path_sample_count: u32,
path_intermediate_texture: gpu::Texture,
path_intermediate_texture_view: gpu::TextureView,
path_intermediate_msaa_texture: Option<gpu::Texture>,
path_intermediate_msaa_texture_view: Option<gpu::TextureView>,
} }
impl BladeRenderer { impl BladeRenderer {
@ -368,43 +352,21 @@ impl BladeRenderer {
let path_sample_count = std::env::var("ZED_PATH_SAMPLE_COUNT") let path_sample_count = std::env::var("ZED_PATH_SAMPLE_COUNT")
.ok() .ok()
.and_then(|v| v.parse().ok()) .and_then(|v| v.parse().ok())
.or_else(|| { .unwrap_or(DEFAULT_PATH_SAMPLE_COUNT);
[4, 2, 1]
.into_iter()
.find(|count| context.gpu.supports_texture_sample_count(*count))
})
.unwrap_or(1);
let pipelines = BladePipelines::new(&context.gpu, surface.info(), path_sample_count); let pipelines = BladePipelines::new(&context.gpu, surface.info(), path_sample_count);
let instance_belt = BufferBelt::new(BufferBeltDescriptor { let instance_belt = BufferBelt::new(BufferBeltDescriptor {
memory: gpu::Memory::Shared, memory: gpu::Memory::Shared,
min_chunk_size: 0x1000, min_chunk_size: 0x1000,
alignment: 0x40, // Vulkan `minStorageBufferOffsetAlignment` on Intel Xe alignment: 0x40, // Vulkan `minStorageBufferOffsetAlignment` on Intel Xe
}); });
let atlas = Arc::new(BladeAtlas::new(&context.gpu)); let atlas = Arc::new(BladeAtlas::new(&context.gpu, path_sample_count));
let atlas_sampler = context.gpu.create_sampler(gpu::SamplerDesc { let atlas_sampler = context.gpu.create_sampler(gpu::SamplerDesc {
name: "path rasterization sampler", name: "atlas",
mag_filter: gpu::FilterMode::Linear, mag_filter: gpu::FilterMode::Linear,
min_filter: gpu::FilterMode::Linear, min_filter: gpu::FilterMode::Linear,
..Default::default() ..Default::default()
}); });
let (path_intermediate_texture, path_intermediate_texture_view) =
create_path_intermediate_texture(
&context.gpu,
surface.info().format,
config.size.width,
config.size.height,
);
let (path_intermediate_msaa_texture, path_intermediate_msaa_texture_view) =
create_msaa_texture_if_needed(
&context.gpu,
surface.info().format,
config.size.width,
config.size.height,
path_sample_count,
)
.unzip();
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
let core_video_texture_cache = unsafe { let core_video_texture_cache = unsafe {
CVMetalTextureCache::new( CVMetalTextureCache::new(
@ -421,15 +383,12 @@ impl BladeRenderer {
last_sync_point: None, last_sync_point: None,
pipelines, pipelines,
instance_belt, instance_belt,
path_tiles: HashMap::default(),
atlas, atlas,
atlas_sampler, atlas_sampler,
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
core_video_texture_cache, core_video_texture_cache,
path_sample_count, path_sample_count,
path_intermediate_texture,
path_intermediate_texture_view,
path_intermediate_msaa_texture,
path_intermediate_msaa_texture_view,
}) })
} }
@ -482,35 +441,6 @@ impl BladeRenderer {
self.surface_config.size = gpu_size; self.surface_config.size = gpu_size;
self.gpu self.gpu
.reconfigure_surface(&mut self.surface, self.surface_config); .reconfigure_surface(&mut self.surface, self.surface_config);
self.gpu.destroy_texture(self.path_intermediate_texture);
self.gpu
.destroy_texture_view(self.path_intermediate_texture_view);
if let Some(msaa_texture) = self.path_intermediate_msaa_texture {
self.gpu.destroy_texture(msaa_texture);
}
if let Some(msaa_view) = self.path_intermediate_msaa_texture_view {
self.gpu.destroy_texture_view(msaa_view);
}
let (path_intermediate_texture, path_intermediate_texture_view) =
create_path_intermediate_texture(
&self.gpu,
self.surface.info().format,
gpu_size.width,
gpu_size.height,
);
self.path_intermediate_texture = path_intermediate_texture;
self.path_intermediate_texture_view = path_intermediate_texture_view;
let (path_intermediate_msaa_texture, path_intermediate_msaa_texture_view) =
create_msaa_texture_if_needed(
&self.gpu,
self.surface.info().format,
gpu_size.width,
gpu_size.height,
self.path_sample_count,
)
.unzip();
self.path_intermediate_msaa_texture = path_intermediate_msaa_texture;
self.path_intermediate_msaa_texture_view = path_intermediate_msaa_texture_view;
} }
} }
@ -561,63 +491,76 @@ impl BladeRenderer {
} }
#[profiling::function] #[profiling::function]
fn draw_paths_to_intermediate( fn rasterize_paths(&mut self, paths: &[Path<ScaledPixels>]) {
&mut self, self.path_tiles.clear();
paths: &[Path<ScaledPixels>], let mut vertices_by_texture_id = HashMap::default();
width: f32,
height: f32, for path in paths {
) { let clipped_bounds = path
self.command_encoder .bounds
.init_texture(self.path_intermediate_texture); .intersect(&path.content_mask.bounds)
if let Some(msaa_texture) = self.path_intermediate_msaa_texture { .map_origin(|origin| origin.floor())
self.command_encoder.init_texture(msaa_texture); .map_size(|size| size.ceil());
let tile = self.atlas.allocate_for_rendering(
clipped_bounds.size.map(Into::into),
AtlasTextureKind::Path,
&mut self.command_encoder,
);
vertices_by_texture_id
.entry(tile.texture_id)
.or_insert(Vec::new())
.extend(path.vertices.iter().map(|vertex| PathVertex {
xy_position: vertex.xy_position - clipped_bounds.origin
+ tile.bounds.origin.map(Into::into),
st_position: vertex.st_position,
content_mask: ContentMask {
bounds: tile.bounds.map(Into::into),
},
}));
self.path_tiles.insert(path.id, tile);
} }
let target = if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { for (texture_id, vertices) in vertices_by_texture_id {
gpu::RenderTarget { let tex_info = self.atlas.get_texture_info(texture_id);
view: msaa_view,
init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack),
finish_op: gpu::FinishOp::ResolveTo(self.path_intermediate_texture_view),
}
} else {
gpu::RenderTarget {
view: self.path_intermediate_texture_view,
init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack),
finish_op: gpu::FinishOp::Store,
}
};
if let mut pass = self.command_encoder.render(
"rasterize paths",
gpu::RenderTargetSet {
colors: &[target],
depth_stencil: None,
},
) {
let globals = GlobalParams { let globals = GlobalParams {
viewport_size: [width, height], viewport_size: [tex_info.size.width as f32, tex_info.size.height as f32],
premultiplied_alpha: 0, premultiplied_alpha: 0,
pad: 0, pad: 0,
}; };
let mut encoder = pass.with(&self.pipelines.path_rasterization);
let mut vertices = Vec::new();
for path in paths {
vertices.extend(path.vertices.iter().map(|v| PathRasterizationVertex {
xy_position: v.xy_position,
st_position: v.st_position,
color: path.color,
bounds: path.bounds.intersect(&path.content_mask.bounds),
}));
}
let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) }; let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) };
encoder.bind( let frame_view = tex_info.raw_view;
0, let color_target = if let Some(msaa_view) = tex_info.msaa_view {
&ShaderPathRasterizationData { gpu::RenderTarget {
globals, view: msaa_view,
b_path_vertices: vertex_buf, init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack),
finish_op: gpu::FinishOp::ResolveTo(frame_view),
}
} else {
gpu::RenderTarget {
view: frame_view,
init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack),
finish_op: gpu::FinishOp::Store,
}
};
if let mut pass = self.command_encoder.render(
"paths",
gpu::RenderTargetSet {
colors: &[color_target],
depth_stencil: None,
}, },
); ) {
encoder.draw(0, vertices.len() as u32, 0, 1); let mut encoder = pass.with(&self.pipelines.path_rasterization);
encoder.bind(
0,
&ShaderPathRasterizationData {
globals,
b_path_vertices: vertex_buf,
},
);
encoder.draw(0, vertices.len() as u32, 0, 1);
}
} }
} }
@ -629,20 +572,12 @@ impl BladeRenderer {
self.gpu.destroy_command_encoder(&mut self.command_encoder); self.gpu.destroy_command_encoder(&mut self.command_encoder);
self.pipelines.destroy(&self.gpu); self.pipelines.destroy(&self.gpu);
self.gpu.destroy_surface(&mut self.surface); self.gpu.destroy_surface(&mut self.surface);
self.gpu.destroy_texture(self.path_intermediate_texture);
self.gpu
.destroy_texture_view(self.path_intermediate_texture_view);
if let Some(msaa_texture) = self.path_intermediate_msaa_texture {
self.gpu.destroy_texture(msaa_texture);
}
if let Some(msaa_view) = self.path_intermediate_msaa_texture_view {
self.gpu.destroy_texture_view(msaa_view);
}
} }
pub fn draw(&mut self, scene: &Scene) { pub fn draw(&mut self, scene: &Scene) {
self.command_encoder.start(); self.command_encoder.start();
self.atlas.before_frame(&mut self.command_encoder); self.atlas.before_frame(&mut self.command_encoder);
self.rasterize_paths(scene.paths());
let frame = { let frame = {
profiling::scope!("acquire frame"); profiling::scope!("acquire frame");
@ -662,7 +597,7 @@ impl BladeRenderer {
pad: 0, pad: 0,
}; };
let mut pass = self.command_encoder.render( if let mut pass = self.command_encoder.render(
"main", "main",
gpu::RenderTargetSet { gpu::RenderTargetSet {
colors: &[gpu::RenderTarget { colors: &[gpu::RenderTarget {
@ -672,235 +607,209 @@ impl BladeRenderer {
}], }],
depth_stencil: None, depth_stencil: None,
}, },
); ) {
profiling::scope!("render pass");
for batch in scene.batches() {
match batch {
PrimitiveBatch::Quads(quads) => {
let instance_buf =
unsafe { self.instance_belt.alloc_typed(quads, &self.gpu) };
let mut encoder = pass.with(&self.pipelines.quads);
encoder.bind(
0,
&ShaderQuadsData {
globals,
b_quads: instance_buf,
},
);
encoder.draw(0, 4, 0, quads.len() as u32);
}
PrimitiveBatch::Shadows(shadows) => {
let instance_buf =
unsafe { self.instance_belt.alloc_typed(shadows, &self.gpu) };
let mut encoder = pass.with(&self.pipelines.shadows);
encoder.bind(
0,
&ShaderShadowsData {
globals,
b_shadows: instance_buf,
},
);
encoder.draw(0, 4, 0, shadows.len() as u32);
}
PrimitiveBatch::Paths(paths) => {
let mut encoder = pass.with(&self.pipelines.paths);
// todo(linux): group by texture ID
for path in paths {
let tile = &self.path_tiles[&path.id];
let tex_info = self.atlas.get_texture_info(tile.texture_id);
let origin = path.bounds.intersect(&path.content_mask.bounds).origin;
let sprites = [PathSprite {
bounds: Bounds {
origin: origin.map(|p| p.floor()),
size: tile.bounds.size.map(Into::into),
},
color: path.color,
tile: (*tile).clone(),
}];
profiling::scope!("render pass"); let instance_buf =
for batch in scene.batches() { unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) };
match batch { encoder.bind(
PrimitiveBatch::Quads(quads) => { 0,
let instance_buf = unsafe { self.instance_belt.alloc_typed(quads, &self.gpu) }; &ShaderPathsData {
let mut encoder = pass.with(&self.pipelines.quads); globals,
encoder.bind( t_sprite: tex_info.raw_view,
0, s_sprite: self.atlas_sampler,
&ShaderQuadsData { b_path_sprites: instance_buf,
globals, },
b_quads: instance_buf, );
}, encoder.draw(0, 4, 0, sprites.len() as u32);
);
encoder.draw(0, 4, 0, quads.len() as u32);
}
PrimitiveBatch::Shadows(shadows) => {
let instance_buf =
unsafe { self.instance_belt.alloc_typed(shadows, &self.gpu) };
let mut encoder = pass.with(&self.pipelines.shadows);
encoder.bind(
0,
&ShaderShadowsData {
globals,
b_shadows: instance_buf,
},
);
encoder.draw(0, 4, 0, shadows.len() as u32);
}
PrimitiveBatch::Paths(paths) => {
let Some(first_path) = paths.first() else {
continue;
};
drop(pass);
self.draw_paths_to_intermediate(
paths,
self.surface_config.size.width as f32,
self.surface_config.size.height as f32,
);
pass = self.command_encoder.render(
"main",
gpu::RenderTargetSet {
colors: &[gpu::RenderTarget {
view: frame.texture_view(),
init_op: gpu::InitOp::Load,
finish_op: gpu::FinishOp::Store,
}],
depth_stencil: None,
},
);
let mut encoder = pass.with(&self.pipelines.paths);
// When copying paths from the intermediate texture to the drawable,
// each pixel must only be copied once, in case of transparent paths.
//
// If all paths have the same draw order, then their bounds are all
// disjoint, so we can copy each path's bounds individually. If this
// batch combines different draw orders, we perform a single copy
// for a minimal spanning rect.
let sprites = if paths.last().unwrap().order == first_path.order {
paths
.iter()
.map(|path| PathSprite {
bounds: path.bounds,
})
.collect()
} else {
let mut bounds = first_path.bounds;
for path in paths.iter().skip(1) {
bounds = bounds.union(&path.bounds);
} }
vec![PathSprite { bounds }] }
}; PrimitiveBatch::Underlines(underlines) => {
let instance_buf = let instance_buf =
unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; unsafe { self.instance_belt.alloc_typed(underlines, &self.gpu) };
encoder.bind( let mut encoder = pass.with(&self.pipelines.underlines);
0, encoder.bind(
&ShaderPathsData { 0,
globals, &ShaderUnderlinesData {
t_sprite: self.path_intermediate_texture_view, globals,
s_sprite: self.atlas_sampler, b_underlines: instance_buf,
b_path_sprites: instance_buf, },
}, );
); encoder.draw(0, 4, 0, underlines.len() as u32);
encoder.draw(0, 4, 0, sprites.len() as u32); }
} PrimitiveBatch::MonochromeSprites {
PrimitiveBatch::Underlines(underlines) => { texture_id,
let instance_buf = sprites,
unsafe { self.instance_belt.alloc_typed(underlines, &self.gpu) }; } => {
let mut encoder = pass.with(&self.pipelines.underlines); let tex_info = self.atlas.get_texture_info(texture_id);
encoder.bind( let instance_buf =
0, unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) };
&ShaderUnderlinesData { let mut encoder = pass.with(&self.pipelines.mono_sprites);
globals, encoder.bind(
b_underlines: instance_buf, 0,
}, &ShaderMonoSpritesData {
); globals,
encoder.draw(0, 4, 0, underlines.len() as u32); t_sprite: tex_info.raw_view,
} s_sprite: self.atlas_sampler,
PrimitiveBatch::MonochromeSprites { b_mono_sprites: instance_buf,
texture_id, },
sprites, );
} => { encoder.draw(0, 4, 0, sprites.len() as u32);
let tex_info = self.atlas.get_texture_info(texture_id); }
let instance_buf = PrimitiveBatch::PolychromeSprites {
unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; texture_id,
let mut encoder = pass.with(&self.pipelines.mono_sprites); sprites,
encoder.bind( } => {
0, let tex_info = self.atlas.get_texture_info(texture_id);
&ShaderMonoSpritesData { let instance_buf =
globals, unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) };
t_sprite: tex_info.raw_view, let mut encoder = pass.with(&self.pipelines.poly_sprites);
s_sprite: self.atlas_sampler, encoder.bind(
b_mono_sprites: instance_buf, 0,
}, &ShaderPolySpritesData {
); globals,
encoder.draw(0, 4, 0, sprites.len() as u32); t_sprite: tex_info.raw_view,
} s_sprite: self.atlas_sampler,
PrimitiveBatch::PolychromeSprites { b_poly_sprites: instance_buf,
texture_id, },
sprites, );
} => { encoder.draw(0, 4, 0, sprites.len() as u32);
let tex_info = self.atlas.get_texture_info(texture_id); }
let instance_buf = PrimitiveBatch::Surfaces(surfaces) => {
unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; let mut _encoder = pass.with(&self.pipelines.surfaces);
let mut encoder = pass.with(&self.pipelines.poly_sprites);
encoder.bind(
0,
&ShaderPolySpritesData {
globals,
t_sprite: tex_info.raw_view,
s_sprite: self.atlas_sampler,
b_poly_sprites: instance_buf,
},
);
encoder.draw(0, 4, 0, sprites.len() as u32);
}
PrimitiveBatch::Surfaces(surfaces) => {
let mut _encoder = pass.with(&self.pipelines.surfaces);
for surface in surfaces { for surface in surfaces {
#[cfg(not(target_os = "macos"))] #[cfg(not(target_os = "macos"))]
{ {
let _ = surface; let _ = surface;
continue; continue;
}; };
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
{ {
let (t_y, t_cb_cr) = unsafe { let (t_y, t_cb_cr) = unsafe {
use core_foundation::base::TCFType as _; use core_foundation::base::TCFType as _;
use std::ptr; use std::ptr;
assert_eq!( assert_eq!(
surface.image_buffer.get_pixel_format(), surface.image_buffer.get_pixel_format(),
core_video::pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange core_video::pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange
); );
let y_texture = self let y_texture = self
.core_video_texture_cache .core_video_texture_cache
.create_texture_from_image( .create_texture_from_image(
surface.image_buffer.as_concrete_TypeRef(), surface.image_buffer.as_concrete_TypeRef(),
ptr::null(), ptr::null(),
metal::MTLPixelFormat::R8Unorm, metal::MTLPixelFormat::R8Unorm,
surface.image_buffer.get_width_of_plane(0), surface.image_buffer.get_width_of_plane(0),
surface.image_buffer.get_height_of_plane(0), surface.image_buffer.get_height_of_plane(0),
0, 0,
)
.unwrap();
let cb_cr_texture = self
.core_video_texture_cache
.create_texture_from_image(
surface.image_buffer.as_concrete_TypeRef(),
ptr::null(),
metal::MTLPixelFormat::RG8Unorm,
surface.image_buffer.get_width_of_plane(1),
surface.image_buffer.get_height_of_plane(1),
1,
)
.unwrap();
(
gpu::TextureView::from_metal_texture(
&objc2::rc::Retained::retain(
foreign_types::ForeignTypeRef::as_ptr(
y_texture.as_texture_ref(),
)
as *mut objc2::runtime::ProtocolObject<
dyn objc2_metal::MTLTexture,
>,
) )
.unwrap(), .unwrap();
gpu::TexelAspects::COLOR, let cb_cr_texture = self
), .core_video_texture_cache
gpu::TextureView::from_metal_texture( .create_texture_from_image(
&objc2::rc::Retained::retain( surface.image_buffer.as_concrete_TypeRef(),
foreign_types::ForeignTypeRef::as_ptr( ptr::null(),
cb_cr_texture.as_texture_ref(), metal::MTLPixelFormat::RG8Unorm,
) surface.image_buffer.get_width_of_plane(1),
as *mut objc2::runtime::ProtocolObject< surface.image_buffer.get_height_of_plane(1),
dyn objc2_metal::MTLTexture, 1,
>,
) )
.unwrap(), .unwrap();
gpu::TexelAspects::COLOR, (
), gpu::TextureView::from_metal_texture(
) &objc2::rc::Retained::retain(
}; foreign_types::ForeignTypeRef::as_ptr(
y_texture.as_texture_ref(),
)
as *mut objc2::runtime::ProtocolObject<
dyn objc2_metal::MTLTexture,
>,
)
.unwrap(),
gpu::TexelAspects::COLOR,
),
gpu::TextureView::from_metal_texture(
&objc2::rc::Retained::retain(
foreign_types::ForeignTypeRef::as_ptr(
cb_cr_texture.as_texture_ref(),
)
as *mut objc2::runtime::ProtocolObject<
dyn objc2_metal::MTLTexture,
>,
)
.unwrap(),
gpu::TexelAspects::COLOR,
),
)
};
_encoder.bind( _encoder.bind(
0, 0,
&ShaderSurfacesData { &ShaderSurfacesData {
globals, globals,
surface_locals: SurfaceParams { surface_locals: SurfaceParams {
bounds: surface.bounds.into(), bounds: surface.bounds.into(),
content_mask: surface.content_mask.bounds.into(), content_mask: surface.content_mask.bounds.into(),
},
t_y,
t_cb_cr,
s_surface: self.atlas_sampler,
}, },
t_y, );
t_cb_cr,
s_surface: self.atlas_sampler,
},
);
_encoder.draw(0, 4, 0, 1); _encoder.draw(0, 4, 0, 1);
}
} }
} }
} }
} }
} }
drop(pass);
self.command_encoder.present(frame); self.command_encoder.present(frame);
let sync_point = self.gpu.submit(&mut self.command_encoder); let sync_point = self.gpu.submit(&mut self.command_encoder);
@ -908,79 +817,9 @@ impl BladeRenderer {
profiling::scope!("finish"); profiling::scope!("finish");
self.instance_belt.flush(&sync_point); self.instance_belt.flush(&sync_point);
self.atlas.after_frame(&sync_point); self.atlas.after_frame(&sync_point);
self.atlas.clear_textures(AtlasTextureKind::Path);
self.wait_for_gpu(); self.wait_for_gpu();
self.last_sync_point = Some(sync_point); self.last_sync_point = Some(sync_point);
} }
} }
fn create_path_intermediate_texture(
gpu: &gpu::Context,
format: gpu::TextureFormat,
width: u32,
height: u32,
) -> (gpu::Texture, gpu::TextureView) {
let texture = gpu.create_texture(gpu::TextureDesc {
name: "path intermediate",
format,
size: gpu::Extent {
width,
height,
depth: 1,
},
array_layer_count: 1,
mip_level_count: 1,
sample_count: 1,
dimension: gpu::TextureDimension::D2,
usage: gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE | gpu::TextureUsage::TARGET,
external: None,
});
let texture_view = gpu.create_texture_view(
texture,
gpu::TextureViewDesc {
name: "path intermediate view",
format,
dimension: gpu::ViewDimension::D2,
subresources: &Default::default(),
},
);
(texture, texture_view)
}
fn create_msaa_texture_if_needed(
gpu: &gpu::Context,
format: gpu::TextureFormat,
width: u32,
height: u32,
sample_count: u32,
) -> Option<(gpu::Texture, gpu::TextureView)> {
if sample_count <= 1 {
return None;
}
let texture_msaa = gpu.create_texture(gpu::TextureDesc {
name: "path intermediate msaa",
format,
size: gpu::Extent {
width,
height,
depth: 1,
},
array_layer_count: 1,
mip_level_count: 1,
sample_count,
dimension: gpu::TextureDimension::D2,
usage: gpu::TextureUsage::TARGET,
external: None,
});
let texture_view_msaa = gpu.create_texture_view(
texture_msaa,
gpu::TextureViewDesc {
name: "path intermediate msaa view",
format,
dimension: gpu::ViewDimension::D2,
subresources: &Default::default(),
},
);
Some((texture_msaa, texture_view_msaa))
}

View file

@ -924,19 +924,16 @@ fn fs_shadow(input: ShadowVarying) -> @location(0) vec4<f32> {
// --- path rasterization --- // // --- path rasterization --- //
struct PathRasterizationVertex { struct PathVertex {
xy_position: vec2<f32>, xy_position: vec2<f32>,
st_position: vec2<f32>, st_position: vec2<f32>,
color: Background, content_mask: Bounds,
bounds: Bounds,
} }
var<storage, read> b_path_vertices: array<PathVertex>;
var<storage, read> b_path_vertices: array<PathRasterizationVertex>;
struct PathRasterizationVarying { struct PathRasterizationVarying {
@builtin(position) position: vec4<f32>, @builtin(position) position: vec4<f32>,
@location(0) st_position: vec2<f32>, @location(0) st_position: vec2<f32>,
@location(1) vertex_id: u32,
//TODO: use `clip_distance` once Naga supports it //TODO: use `clip_distance` once Naga supports it
@location(3) clip_distances: vec4<f32>, @location(3) clip_distances: vec4<f32>,
} }
@ -948,54 +945,40 @@ fn vs_path_rasterization(@builtin(vertex_index) vertex_id: u32) -> PathRasteriza
var out = PathRasterizationVarying(); var out = PathRasterizationVarying();
out.position = to_device_position_impl(v.xy_position); out.position = to_device_position_impl(v.xy_position);
out.st_position = v.st_position; out.st_position = v.st_position;
out.vertex_id = vertex_id; out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.content_mask);
out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.bounds);
return out; return out;
} }
@fragment @fragment
fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) vec4<f32> { fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) f32 {
let dx = dpdx(input.st_position); let dx = dpdx(input.st_position);
let dy = dpdy(input.st_position); let dy = dpdy(input.st_position);
if (any(input.clip_distances < vec4<f32>(0.0))) { if (any(input.clip_distances < vec4<f32>(0.0))) {
return vec4<f32>(0.0); return 0.0;
} }
let v = b_path_vertices[input.vertex_id]; let gradient = 2.0 * input.st_position.xx * vec2<f32>(dx.x, dy.x) - vec2<f32>(dx.y, dy.y);
let background = v.color; let f = input.st_position.x * input.st_position.x - input.st_position.y;
let bounds = v.bounds; let distance = f / length(gradient);
return saturate(0.5 - distance);
var alpha: f32;
if (length(vec2<f32>(dx.x, dy.x)) < 0.001) {
// If the gradient is too small, return a solid color.
alpha = 1.0;
} else {
let gradient = 2.0 * input.st_position.xx * vec2<f32>(dx.x, dy.x) - vec2<f32>(dx.y, dy.y);
let f = input.st_position.x * input.st_position.x - input.st_position.y;
let distance = f / length(gradient);
alpha = saturate(0.5 - distance);
}
let gradient_color = prepare_gradient_color(
background.tag,
background.color_space,
background.solid,
background.colors,
);
let color = gradient_color(background, input.position.xy, bounds,
gradient_color.solid, gradient_color.color0, gradient_color.color1);
return vec4<f32>(color.rgb * color.a * alpha, color.a * alpha);
} }
// --- paths --- // // --- paths --- //
struct PathSprite { struct PathSprite {
bounds: Bounds, bounds: Bounds,
color: Background,
tile: AtlasTile,
} }
var<storage, read> b_path_sprites: array<PathSprite>; var<storage, read> b_path_sprites: array<PathSprite>;
struct PathVarying { struct PathVarying {
@builtin(position) position: vec4<f32>, @builtin(position) position: vec4<f32>,
@location(0) texture_coords: vec2<f32>, @location(0) tile_position: vec2<f32>,
@location(1) @interpolate(flat) instance_id: u32,
@location(2) @interpolate(flat) color_solid: vec4<f32>,
@location(3) @interpolate(flat) color0: vec4<f32>,
@location(4) @interpolate(flat) color1: vec4<f32>,
} }
@vertex @vertex
@ -1003,22 +986,33 @@ fn vs_path(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) insta
let unit_vertex = vec2<f32>(f32(vertex_id & 1u), 0.5 * f32(vertex_id & 2u)); let unit_vertex = vec2<f32>(f32(vertex_id & 1u), 0.5 * f32(vertex_id & 2u));
let sprite = b_path_sprites[instance_id]; let sprite = b_path_sprites[instance_id];
// Don't apply content mask because it was already accounted for when rasterizing the path. // Don't apply content mask because it was already accounted for when rasterizing the path.
let device_position = to_device_position(unit_vertex, sprite.bounds);
// For screen-space intermediate texture, convert screen position to texture coordinates
let screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size;
let texture_coords = screen_position / globals.viewport_size;
var out = PathVarying(); var out = PathVarying();
out.position = device_position; out.position = to_device_position(unit_vertex, sprite.bounds);
out.texture_coords = texture_coords; out.tile_position = to_tile_position(unit_vertex, sprite.tile);
out.instance_id = instance_id;
let gradient = prepare_gradient_color(
sprite.color.tag,
sprite.color.color_space,
sprite.color.solid,
sprite.color.colors
);
out.color_solid = gradient.solid;
out.color0 = gradient.color0;
out.color1 = gradient.color1;
return out; return out;
} }
@fragment @fragment
fn fs_path(input: PathVarying) -> @location(0) vec4<f32> { fn fs_path(input: PathVarying) -> @location(0) vec4<f32> {
let sample = textureSample(t_sprite, s_sprite, input.texture_coords); let sample = textureSample(t_sprite, s_sprite, input.tile_position).r;
return sample; let mask = 1.0 - abs(1.0 - sample % 2.0);
let sprite = b_path_sprites[input.instance_id];
let background = sprite.color;
let color = gradient_color(background, input.position.xy, sprite.bounds,
input.color_solid, input.color0, input.color1);
return blend_color(color, mask);
} }
// --- underlines --- // // --- underlines --- //

View file

@ -111,7 +111,7 @@ pub struct WaylandWindowState {
resize_throttle: bool, resize_throttle: bool,
in_progress_window_controls: Option<WindowControls>, in_progress_window_controls: Option<WindowControls>,
window_controls: WindowControls, window_controls: WindowControls,
client_inset: Option<Pixels>, inset: Option<Pixels>,
} }
#[derive(Clone)] #[derive(Clone)]
@ -186,7 +186,7 @@ impl WaylandWindowState {
hovered: false, hovered: false,
in_progress_window_controls: None, in_progress_window_controls: None,
window_controls: WindowControls::default(), window_controls: WindowControls::default(),
client_inset: None, inset: None,
}) })
} }
@ -211,13 +211,6 @@ impl WaylandWindowState {
self.display = current_output; self.display = current_output;
scale scale
} }
pub fn inset(&self) -> Pixels {
match self.decorations {
WindowDecorations::Server => px(0.0),
WindowDecorations::Client => self.client_inset.unwrap_or(px(0.0)),
}
}
} }
pub(crate) struct WaylandWindow(pub WaylandWindowStatePtr); pub(crate) struct WaylandWindow(pub WaylandWindowStatePtr);
@ -387,7 +380,7 @@ impl WaylandWindowStatePtr {
configure.size = if got_unmaximized { configure.size = if got_unmaximized {
Some(state.window_bounds.size) Some(state.window_bounds.size)
} else { } else {
compute_outer_size(state.inset(), configure.size, state.tiling) compute_outer_size(state.inset, configure.size, state.tiling)
}; };
if let Some(size) = configure.size { if let Some(size) = configure.size {
state.window_bounds = Bounds { state.window_bounds = Bounds {
@ -407,7 +400,7 @@ impl WaylandWindowStatePtr {
let window_geometry = inset_by_tiling( let window_geometry = inset_by_tiling(
state.bounds.map_origin(|_| px(0.0)), state.bounds.map_origin(|_| px(0.0)),
state.inset(), state.inset.unwrap_or(px(0.0)),
state.tiling, state.tiling,
) )
.map(|v| v.0 as i32) .map(|v| v.0 as i32)
@ -825,7 +818,7 @@ impl PlatformWindow for WaylandWindow {
} else if state.maximized { } else if state.maximized {
WindowBounds::Maximized(state.window_bounds) WindowBounds::Maximized(state.window_bounds)
} else { } else {
let inset = state.inset(); let inset = state.inset.unwrap_or(px(0.));
drop(state); drop(state);
WindowBounds::Windowed(self.bounds().inset(inset)) WindowBounds::Windowed(self.bounds().inset(inset))
} }
@ -1080,8 +1073,8 @@ impl PlatformWindow for WaylandWindow {
fn set_client_inset(&self, inset: Pixels) { fn set_client_inset(&self, inset: Pixels) {
let mut state = self.borrow_mut(); let mut state = self.borrow_mut();
if Some(inset) != state.client_inset { if Some(inset) != state.inset {
state.client_inset = Some(inset); state.inset = Some(inset);
update_window(state); update_window(state);
} }
} }
@ -1101,7 +1094,9 @@ fn update_window(mut state: RefMut<WaylandWindowState>) {
state.renderer.update_transparency(!opaque); state.renderer.update_transparency(!opaque);
let mut opaque_area = state.window_bounds.map(|v| v.0 as i32); let mut opaque_area = state.window_bounds.map(|v| v.0 as i32);
opaque_area.inset(state.inset().0 as i32); if let Some(inset) = state.inset {
opaque_area.inset(inset.0 as i32);
}
let region = state let region = state
.globals .globals
@ -1174,10 +1169,12 @@ impl ResizeEdge {
/// updating to account for the client decorations. But that's not the area we want to render /// updating to account for the client decorations. But that's not the area we want to render
/// to, due to our intrusize CSD. So, here we calculate the 'actual' size, by adding back in the insets /// to, due to our intrusize CSD. So, here we calculate the 'actual' size, by adding back in the insets
fn compute_outer_size( fn compute_outer_size(
inset: Pixels, inset: Option<Pixels>,
new_size: Option<Size<Pixels>>, new_size: Option<Size<Pixels>>,
tiling: Tiling, tiling: Tiling,
) -> Option<Size<Pixels>> { ) -> Option<Size<Pixels>> {
let Some(inset) = inset else { return new_size };
new_size.map(|mut new_size| { new_size.map(|mut new_size| {
if !tiling.top { if !tiling.top {
new_size.height += inset; new_size.height += inset;

View file

@ -13,25 +13,53 @@ use std::borrow::Cow;
pub(crate) struct MetalAtlas(Mutex<MetalAtlasState>); pub(crate) struct MetalAtlas(Mutex<MetalAtlasState>);
impl MetalAtlas { impl MetalAtlas {
pub(crate) fn new(device: Device) -> Self { pub(crate) fn new(device: Device, path_sample_count: u32) -> Self {
MetalAtlas(Mutex::new(MetalAtlasState { MetalAtlas(Mutex::new(MetalAtlasState {
device: AssertSend(device), device: AssertSend(device),
monochrome_textures: Default::default(), monochrome_textures: Default::default(),
polychrome_textures: Default::default(), polychrome_textures: Default::default(),
path_textures: Default::default(),
tiles_by_key: Default::default(), tiles_by_key: Default::default(),
path_sample_count,
})) }))
} }
pub(crate) fn metal_texture(&self, id: AtlasTextureId) -> metal::Texture { pub(crate) fn metal_texture(&self, id: AtlasTextureId) -> metal::Texture {
self.0.lock().texture(id).metal_texture.clone() self.0.lock().texture(id).metal_texture.clone()
} }
pub(crate) fn msaa_texture(&self, id: AtlasTextureId) -> Option<metal::Texture> {
self.0.lock().texture(id).msaa_texture.clone()
}
pub(crate) fn allocate(
&self,
size: Size<DevicePixels>,
texture_kind: AtlasTextureKind,
) -> Option<AtlasTile> {
self.0.lock().allocate(size, texture_kind)
}
pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) {
let mut lock = self.0.lock();
let textures = match texture_kind {
AtlasTextureKind::Monochrome => &mut lock.monochrome_textures,
AtlasTextureKind::Polychrome => &mut lock.polychrome_textures,
AtlasTextureKind::Path => &mut lock.path_textures,
};
for texture in textures.iter_mut() {
texture.clear();
}
}
} }
struct MetalAtlasState { struct MetalAtlasState {
device: AssertSend<Device>, device: AssertSend<Device>,
monochrome_textures: AtlasTextureList<MetalAtlasTexture>, monochrome_textures: AtlasTextureList<MetalAtlasTexture>,
polychrome_textures: AtlasTextureList<MetalAtlasTexture>, polychrome_textures: AtlasTextureList<MetalAtlasTexture>,
path_textures: AtlasTextureList<MetalAtlasTexture>,
tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
path_sample_count: u32,
} }
impl PlatformAtlas for MetalAtlas { impl PlatformAtlas for MetalAtlas {
@ -66,6 +94,7 @@ impl PlatformAtlas for MetalAtlas {
let textures = match id.kind { let textures = match id.kind {
AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, AtlasTextureKind::Monochrome => &mut lock.monochrome_textures,
AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, AtlasTextureKind::Polychrome => &mut lock.polychrome_textures,
AtlasTextureKind::Path => &mut lock.polychrome_textures,
}; };
let Some(texture_slot) = textures let Some(texture_slot) = textures
@ -99,6 +128,7 @@ impl MetalAtlasState {
let textures = match texture_kind { let textures = match texture_kind {
AtlasTextureKind::Monochrome => &mut self.monochrome_textures, AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
AtlasTextureKind::Polychrome => &mut self.polychrome_textures, AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
AtlasTextureKind::Path => &mut self.path_textures,
}; };
if let Some(tile) = textures if let Some(tile) = textures
@ -143,14 +173,31 @@ impl MetalAtlasState {
pixel_format = metal::MTLPixelFormat::BGRA8Unorm; pixel_format = metal::MTLPixelFormat::BGRA8Unorm;
usage = metal::MTLTextureUsage::ShaderRead; usage = metal::MTLTextureUsage::ShaderRead;
} }
AtlasTextureKind::Path => {
pixel_format = metal::MTLPixelFormat::R16Float;
usage = metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead;
}
} }
texture_descriptor.set_pixel_format(pixel_format); texture_descriptor.set_pixel_format(pixel_format);
texture_descriptor.set_usage(usage); texture_descriptor.set_usage(usage);
let metal_texture = self.device.new_texture(&texture_descriptor); let metal_texture = self.device.new_texture(&texture_descriptor);
// We currently only enable MSAA for path textures.
let msaa_texture = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path {
let mut descriptor = texture_descriptor.clone();
descriptor.set_texture_type(metal::MTLTextureType::D2Multisample);
descriptor.set_storage_mode(metal::MTLStorageMode::Private);
descriptor.set_sample_count(self.path_sample_count as _);
let msaa_texture = self.device.new_texture(&descriptor);
Some(msaa_texture)
} else {
None
};
let texture_list = match kind { let texture_list = match kind {
AtlasTextureKind::Monochrome => &mut self.monochrome_textures, AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
AtlasTextureKind::Polychrome => &mut self.polychrome_textures, AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
AtlasTextureKind::Path => &mut self.path_textures,
}; };
let index = texture_list.free_list.pop(); let index = texture_list.free_list.pop();
@ -162,6 +209,7 @@ impl MetalAtlasState {
}, },
allocator: etagere::BucketedAtlasAllocator::new(size.into()), allocator: etagere::BucketedAtlasAllocator::new(size.into()),
metal_texture: AssertSend(metal_texture), metal_texture: AssertSend(metal_texture),
msaa_texture: AssertSend(msaa_texture),
live_atlas_keys: 0, live_atlas_keys: 0,
}; };
@ -178,6 +226,7 @@ impl MetalAtlasState {
let textures = match id.kind { let textures = match id.kind {
crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Monochrome => &self.monochrome_textures,
crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures,
crate::AtlasTextureKind::Path => &self.path_textures,
}; };
textures[id.index as usize].as_ref().unwrap() textures[id.index as usize].as_ref().unwrap()
} }
@ -187,10 +236,15 @@ struct MetalAtlasTexture {
id: AtlasTextureId, id: AtlasTextureId,
allocator: BucketedAtlasAllocator, allocator: BucketedAtlasAllocator,
metal_texture: AssertSend<metal::Texture>, metal_texture: AssertSend<metal::Texture>,
msaa_texture: AssertSend<Option<metal::Texture>>,
live_atlas_keys: u32, live_atlas_keys: u32,
} }
impl MetalAtlasTexture { impl MetalAtlasTexture {
fn clear(&mut self) {
self.allocator.clear();
}
fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> {
let allocation = self.allocator.allocate(size.into())?; let allocation = self.allocator.allocate(size.into())?;
let tile = AtlasTile { let tile = AtlasTile {

View file

@ -1,30 +1,27 @@
use super::metal_atlas::MetalAtlas; use super::metal_atlas::MetalAtlas;
use crate::{ use crate::{
AtlasTextureId, Background, Bounds, ContentMask, DevicePixels, MonochromeSprite, PaintSurface, AtlasTextureId, AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels,
Path, Point, PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch,
Surface, Underline, point, size, Quad, ScaledPixels, Scene, Shadow, Size, Surface, Underline, point, size,
}; };
use anyhow::Result; use anyhow::{Context as _, Result};
use block::ConcreteBlock; use block::ConcreteBlock;
use cocoa::{ use cocoa::{
base::{NO, YES}, base::{NO, YES},
foundation::{NSSize, NSUInteger}, foundation::{NSSize, NSUInteger},
quartzcore::AutoresizingMask, quartzcore::AutoresizingMask,
}; };
use collections::HashMap;
use core_foundation::base::TCFType; use core_foundation::base::TCFType;
use core_video::{ use core_video::{
metal_texture::CVMetalTextureGetTexture, metal_texture_cache::CVMetalTextureCache, metal_texture::CVMetalTextureGetTexture, metal_texture_cache::CVMetalTextureCache,
pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange, pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange,
}; };
use foreign_types::{ForeignType, ForeignTypeRef}; use foreign_types::{ForeignType, ForeignTypeRef};
use metal::{ use metal::{CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange};
CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange,
RenderPassColorAttachmentDescriptorRef,
};
use objc::{self, msg_send, sel, sel_impl}; use objc::{self, msg_send, sel, sel_impl};
use parking_lot::Mutex; use parking_lot::Mutex;
use smallvec::SmallVec;
use std::{cell::Cell, ffi::c_void, mem, ptr, sync::Arc}; use std::{cell::Cell, ffi::c_void, mem, ptr, sync::Arc};
// Exported to metal // Exported to metal
@ -114,17 +111,6 @@ pub(crate) struct MetalRenderer {
instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>, instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>,
sprite_atlas: Arc<MetalAtlas>, sprite_atlas: Arc<MetalAtlas>,
core_video_texture_cache: core_video::metal_texture_cache::CVMetalTextureCache, core_video_texture_cache: core_video::metal_texture_cache::CVMetalTextureCache,
path_intermediate_texture: Option<metal::Texture>,
path_intermediate_msaa_texture: Option<metal::Texture>,
path_sample_count: u32,
}
#[repr(C)]
pub struct PathRasterizationVertex {
pub xy_position: Point<ScaledPixels>,
pub st_position: Point<f32>,
pub color: Background,
pub bounds: Bounds<ScaledPixels>,
} }
impl MetalRenderer { impl MetalRenderer {
@ -189,10 +175,10 @@ impl MetalRenderer {
"paths_rasterization", "paths_rasterization",
"path_rasterization_vertex", "path_rasterization_vertex",
"path_rasterization_fragment", "path_rasterization_fragment",
MTLPixelFormat::BGRA8Unorm, MTLPixelFormat::R16Float,
PATH_SAMPLE_COUNT, PATH_SAMPLE_COUNT,
); );
let path_sprites_pipeline_state = build_path_sprite_pipeline_state( let path_sprites_pipeline_state = build_pipeline_state(
&device, &device,
&library, &library,
"path_sprites", "path_sprites",
@ -250,7 +236,7 @@ impl MetalRenderer {
); );
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let sprite_atlas = Arc::new(MetalAtlas::new(device.clone())); let sprite_atlas = Arc::new(MetalAtlas::new(device.clone(), PATH_SAMPLE_COUNT));
let core_video_texture_cache = let core_video_texture_cache =
CVMetalTextureCache::new(None, device.clone(), None).unwrap(); CVMetalTextureCache::new(None, device.clone(), None).unwrap();
@ -271,9 +257,6 @@ impl MetalRenderer {
instance_buffer_pool, instance_buffer_pool,
sprite_atlas, sprite_atlas,
core_video_texture_cache, core_video_texture_cache,
path_intermediate_texture: None,
path_intermediate_msaa_texture: None,
path_sample_count: PATH_SAMPLE_COUNT,
} }
} }
@ -306,31 +289,6 @@ impl MetalRenderer {
setDrawableSize: size setDrawableSize: size
]; ];
} }
let device_pixels_size = Size {
width: DevicePixels(size.width as i32),
height: DevicePixels(size.height as i32),
};
self.update_path_intermediate_textures(device_pixels_size);
}
fn update_path_intermediate_textures(&mut self, size: Size<DevicePixels>) {
let texture_descriptor = metal::TextureDescriptor::new();
texture_descriptor.set_width(size.width.0 as u64);
texture_descriptor.set_height(size.height.0 as u64);
texture_descriptor.set_pixel_format(metal::MTLPixelFormat::BGRA8Unorm);
texture_descriptor
.set_usage(metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead);
self.path_intermediate_texture = Some(self.device.new_texture(&texture_descriptor));
if self.path_sample_count > 1 {
let mut msaa_descriptor = texture_descriptor.clone();
msaa_descriptor.set_texture_type(metal::MTLTextureType::D2Multisample);
msaa_descriptor.set_storage_mode(metal::MTLStorageMode::Private);
msaa_descriptor.set_sample_count(self.path_sample_count as _);
self.path_intermediate_msaa_texture = Some(self.device.new_texture(&msaa_descriptor));
} else {
self.path_intermediate_msaa_texture = None;
}
} }
pub fn update_transparency(&self, _transparent: bool) { pub fn update_transparency(&self, _transparent: bool) {
@ -416,18 +374,38 @@ impl MetalRenderer {
) -> Result<metal::CommandBuffer> { ) -> Result<metal::CommandBuffer> {
let command_queue = self.command_queue.clone(); let command_queue = self.command_queue.clone();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let alpha = if self.layer.is_opaque() { 1. } else { 0. };
let mut instance_offset = 0; let mut instance_offset = 0;
let mut command_encoder = new_command_encoder( let path_tiles = self
command_buffer, .rasterize_paths(
drawable, scene.paths(),
viewport_size, instance_buffer,
|color_attachment| { &mut instance_offset,
color_attachment.set_load_action(metal::MTLLoadAction::Clear); command_buffer,
color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha)); )
}, .with_context(|| format!("rasterizing {} paths", scene.paths().len()))?;
);
let render_pass_descriptor = metal::RenderPassDescriptor::new();
let color_attachment = render_pass_descriptor
.color_attachments()
.object_at(0)
.unwrap();
color_attachment.set_texture(Some(drawable.texture()));
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
color_attachment.set_store_action(metal::MTLStoreAction::Store);
let alpha = if self.layer.is_opaque() { 1. } else { 0. };
color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha));
let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor);
command_encoder.set_viewport(metal::MTLViewport {
originX: 0.0,
originY: 0.0,
width: i32::from(viewport_size.width) as f64,
height: i32::from(viewport_size.height) as f64,
znear: 0.0,
zfar: 1.0,
});
for batch in scene.batches() { for batch in scene.batches() {
let ok = match batch { let ok = match batch {
@ -436,53 +414,29 @@ impl MetalRenderer {
instance_buffer, instance_buffer,
&mut instance_offset, &mut instance_offset,
viewport_size, viewport_size,
&command_encoder, command_encoder,
), ),
PrimitiveBatch::Quads(quads) => self.draw_quads( PrimitiveBatch::Quads(quads) => self.draw_quads(
quads, quads,
instance_buffer, instance_buffer,
&mut instance_offset, &mut instance_offset,
viewport_size, viewport_size,
&command_encoder, command_encoder,
),
PrimitiveBatch::Paths(paths) => self.draw_paths(
paths,
&path_tiles,
instance_buffer,
&mut instance_offset,
viewport_size,
command_encoder,
), ),
PrimitiveBatch::Paths(paths) => {
command_encoder.end_encoding();
let did_draw = self.draw_paths_to_intermediate(
paths,
instance_buffer,
&mut instance_offset,
viewport_size,
command_buffer,
);
command_encoder = new_command_encoder(
command_buffer,
drawable,
viewport_size,
|color_attachment| {
color_attachment.set_load_action(metal::MTLLoadAction::Load);
},
);
if did_draw {
self.draw_paths_from_intermediate(
paths,
instance_buffer,
&mut instance_offset,
viewport_size,
&command_encoder,
)
} else {
false
}
}
PrimitiveBatch::Underlines(underlines) => self.draw_underlines( PrimitiveBatch::Underlines(underlines) => self.draw_underlines(
underlines, underlines,
instance_buffer, instance_buffer,
&mut instance_offset, &mut instance_offset,
viewport_size, viewport_size,
&command_encoder, command_encoder,
), ),
PrimitiveBatch::MonochromeSprites { PrimitiveBatch::MonochromeSprites {
texture_id, texture_id,
@ -493,7 +447,7 @@ impl MetalRenderer {
instance_buffer, instance_buffer,
&mut instance_offset, &mut instance_offset,
viewport_size, viewport_size,
&command_encoder, command_encoder,
), ),
PrimitiveBatch::PolychromeSprites { PrimitiveBatch::PolychromeSprites {
texture_id, texture_id,
@ -504,16 +458,17 @@ impl MetalRenderer {
instance_buffer, instance_buffer,
&mut instance_offset, &mut instance_offset,
viewport_size, viewport_size,
&command_encoder, command_encoder,
), ),
PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces( PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces(
surfaces, surfaces,
instance_buffer, instance_buffer,
&mut instance_offset, &mut instance_offset,
viewport_size, viewport_size,
&command_encoder, command_encoder,
), ),
}; };
if !ok { if !ok {
command_encoder.end_encoding(); command_encoder.end_encoding();
anyhow::bail!( anyhow::bail!(
@ -538,90 +493,104 @@ impl MetalRenderer {
Ok(command_buffer.to_owned()) Ok(command_buffer.to_owned())
} }
fn draw_paths_to_intermediate( fn rasterize_paths(
&self, &self,
paths: &[Path<ScaledPixels>], paths: &[Path<ScaledPixels>],
instance_buffer: &mut InstanceBuffer, instance_buffer: &mut InstanceBuffer,
instance_offset: &mut usize, instance_offset: &mut usize,
viewport_size: Size<DevicePixels>,
command_buffer: &metal::CommandBufferRef, command_buffer: &metal::CommandBufferRef,
) -> bool { ) -> Option<HashMap<PathId, AtlasTile>> {
if paths.is_empty() { self.sprite_atlas.clear_textures(AtlasTextureKind::Path);
return true;
}
let Some(intermediate_texture) = &self.path_intermediate_texture else {
return false;
};
let render_pass_descriptor = metal::RenderPassDescriptor::new(); let mut tiles = HashMap::default();
let color_attachment = render_pass_descriptor let mut vertices_by_texture_id = HashMap::default();
.color_attachments()
.object_at(0)
.unwrap();
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 0.));
if let Some(msaa_texture) = &self.path_intermediate_msaa_texture {
color_attachment.set_texture(Some(msaa_texture));
color_attachment.set_resolve_texture(Some(intermediate_texture));
color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve);
} else {
color_attachment.set_texture(Some(intermediate_texture));
color_attachment.set_store_action(metal::MTLStoreAction::Store);
}
let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor);
command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state);
align_offset(instance_offset);
let mut vertices = Vec::new();
for path in paths { for path in paths {
vertices.extend(path.vertices.iter().map(|v| PathRasterizationVertex { let clipped_bounds = path.bounds.intersect(&path.content_mask.bounds);
xy_position: v.xy_position,
st_position: v.st_position,
color: path.color,
bounds: path.bounds.intersect(&path.content_mask.bounds),
}));
}
let vertices_bytes_len = mem::size_of_val(vertices.as_slice());
let next_offset = *instance_offset + vertices_bytes_len;
if next_offset > instance_buffer.size {
command_encoder.end_encoding();
return false;
}
command_encoder.set_vertex_buffer(
PathRasterizationInputIndex::Vertices as u64,
Some(&instance_buffer.metal_buffer),
*instance_offset as u64,
);
command_encoder.set_vertex_bytes(
PathRasterizationInputIndex::ViewportSize as u64,
mem::size_of_val(&viewport_size) as u64,
&viewport_size as *const Size<DevicePixels> as *const _,
);
command_encoder.set_fragment_buffer(
PathRasterizationInputIndex::Vertices as u64,
Some(&instance_buffer.metal_buffer),
*instance_offset as u64,
);
let buffer_contents =
unsafe { (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) };
unsafe {
ptr::copy_nonoverlapping(
vertices.as_ptr() as *const u8,
buffer_contents,
vertices_bytes_len,
);
}
command_encoder.draw_primitives(
metal::MTLPrimitiveType::Triangle,
0,
vertices.len() as u64,
);
*instance_offset = next_offset;
command_encoder.end_encoding(); let tile = self
true .sprite_atlas
.allocate(clipped_bounds.size.map(Into::into), AtlasTextureKind::Path)?;
vertices_by_texture_id
.entry(tile.texture_id)
.or_insert(Vec::new())
.extend(path.vertices.iter().map(|vertex| PathVertex {
xy_position: vertex.xy_position - clipped_bounds.origin
+ tile.bounds.origin.map(Into::into),
st_position: vertex.st_position,
content_mask: ContentMask {
bounds: tile.bounds.map(Into::into),
},
}));
tiles.insert(path.id, tile);
}
for (texture_id, vertices) in vertices_by_texture_id {
align_offset(instance_offset);
let vertices_bytes_len = mem::size_of_val(vertices.as_slice());
let next_offset = *instance_offset + vertices_bytes_len;
if next_offset > instance_buffer.size {
return None;
}
let render_pass_descriptor = metal::RenderPassDescriptor::new();
let color_attachment = render_pass_descriptor
.color_attachments()
.object_at(0)
.unwrap();
let texture = self.sprite_atlas.metal_texture(texture_id);
let msaa_texture = self.sprite_atlas.msaa_texture(texture_id);
if let Some(msaa_texture) = msaa_texture {
color_attachment.set_texture(Some(&msaa_texture));
color_attachment.set_resolve_texture(Some(&texture));
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve);
} else {
color_attachment.set_texture(Some(&texture));
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
color_attachment.set_store_action(metal::MTLStoreAction::Store);
}
color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 1.));
let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor);
command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state);
command_encoder.set_vertex_buffer(
PathRasterizationInputIndex::Vertices as u64,
Some(&instance_buffer.metal_buffer),
*instance_offset as u64,
);
let texture_size = Size {
width: DevicePixels::from(texture.width()),
height: DevicePixels::from(texture.height()),
};
command_encoder.set_vertex_bytes(
PathRasterizationInputIndex::AtlasTextureSize as u64,
mem::size_of_val(&texture_size) as u64,
&texture_size as *const Size<DevicePixels> as *const _,
);
let buffer_contents = unsafe {
(instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset)
};
unsafe {
ptr::copy_nonoverlapping(
vertices.as_ptr() as *const u8,
buffer_contents,
vertices_bytes_len,
);
}
command_encoder.draw_primitives(
metal::MTLPrimitiveType::Triangle,
0,
vertices.len() as u64,
);
command_encoder.end_encoding();
*instance_offset = next_offset;
}
Some(tiles)
} }
fn draw_shadows( fn draw_shadows(
@ -746,21 +715,18 @@ impl MetalRenderer {
true true
} }
fn draw_paths_from_intermediate( fn draw_paths(
&self, &self,
paths: &[Path<ScaledPixels>], paths: &[Path<ScaledPixels>],
tiles_by_path_id: &HashMap<PathId, AtlasTile>,
instance_buffer: &mut InstanceBuffer, instance_buffer: &mut InstanceBuffer,
instance_offset: &mut usize, instance_offset: &mut usize,
viewport_size: Size<DevicePixels>, viewport_size: Size<DevicePixels>,
command_encoder: &metal::RenderCommandEncoderRef, command_encoder: &metal::RenderCommandEncoderRef,
) -> bool { ) -> bool {
let Some(ref first_path) = paths.first() else { if paths.is_empty() {
return true; return true;
}; }
let Some(ref intermediate_texture) = self.path_intermediate_texture else {
return false;
};
command_encoder.set_render_pipeline_state(&self.path_sprites_pipeline_state); command_encoder.set_render_pipeline_state(&self.path_sprites_pipeline_state);
command_encoder.set_vertex_buffer( command_encoder.set_vertex_buffer(
@ -774,65 +740,88 @@ impl MetalRenderer {
&viewport_size as *const Size<DevicePixels> as *const _, &viewport_size as *const Size<DevicePixels> as *const _,
); );
command_encoder.set_fragment_texture( let mut prev_texture_id = None;
SpriteInputIndex::AtlasTexture as u64, let mut sprites = SmallVec::<[_; 1]>::new();
Some(intermediate_texture), let mut paths_and_tiles = paths
); .iter()
.map(|path| (path, tiles_by_path_id.get(&path.id).unwrap()))
.peekable();
// When copying paths from the intermediate texture to the drawable, loop {
// each pixel must only be copied once, in case of transparent paths. if let Some((path, tile)) = paths_and_tiles.peek() {
// if prev_texture_id.map_or(true, |texture_id| texture_id == tile.texture_id) {
// If all paths have the same draw order, then their bounds are all prev_texture_id = Some(tile.texture_id);
// disjoint, so we can copy each path's bounds individually. If this let origin = path.bounds.intersect(&path.content_mask.bounds).origin;
// batch combines different draw orders, we perform a single copy sprites.push(PathSprite {
// for a minimal spanning rect. bounds: Bounds {
let sprites; origin: origin.map(|p| p.floor()),
if paths.last().unwrap().order == first_path.order { size: tile.bounds.size.map(Into::into),
sprites = paths },
.iter() color: path.color,
.map(|path| PathSprite { tile: (*tile).clone(),
bounds: path.bounds, });
}) paths_and_tiles.next();
.collect(); continue;
} else { }
let mut bounds = first_path.bounds; }
for path in paths.iter().skip(1) {
bounds = bounds.union(&path.bounds); if sprites.is_empty() {
break;
} else {
align_offset(instance_offset);
let texture_id = prev_texture_id.take().unwrap();
let texture: metal::Texture = self.sprite_atlas.metal_texture(texture_id);
let texture_size = size(
DevicePixels(texture.width() as i32),
DevicePixels(texture.height() as i32),
);
command_encoder.set_vertex_buffer(
SpriteInputIndex::Sprites as u64,
Some(&instance_buffer.metal_buffer),
*instance_offset as u64,
);
command_encoder.set_vertex_bytes(
SpriteInputIndex::AtlasTextureSize as u64,
mem::size_of_val(&texture_size) as u64,
&texture_size as *const Size<DevicePixels> as *const _,
);
command_encoder.set_fragment_buffer(
SpriteInputIndex::Sprites as u64,
Some(&instance_buffer.metal_buffer),
*instance_offset as u64,
);
command_encoder
.set_fragment_texture(SpriteInputIndex::AtlasTexture as u64, Some(&texture));
let sprite_bytes_len = mem::size_of_val(sprites.as_slice());
let next_offset = *instance_offset + sprite_bytes_len;
if next_offset > instance_buffer.size {
return false;
}
let buffer_contents = unsafe {
(instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset)
};
unsafe {
ptr::copy_nonoverlapping(
sprites.as_ptr() as *const u8,
buffer_contents,
sprite_bytes_len,
);
}
command_encoder.draw_primitives_instanced(
metal::MTLPrimitiveType::Triangle,
0,
6,
sprites.len() as u64,
);
*instance_offset = next_offset;
sprites.clear();
} }
sprites = vec![PathSprite { bounds }];
} }
align_offset(instance_offset);
let sprite_bytes_len = mem::size_of_val(sprites.as_slice());
let next_offset = *instance_offset + sprite_bytes_len;
if next_offset > instance_buffer.size {
return false;
}
command_encoder.set_vertex_buffer(
SpriteInputIndex::Sprites as u64,
Some(&instance_buffer.metal_buffer),
*instance_offset as u64,
);
let buffer_contents =
unsafe { (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) };
unsafe {
ptr::copy_nonoverlapping(
sprites.as_ptr() as *const u8,
buffer_contents,
sprite_bytes_len,
);
}
command_encoder.draw_primitives_instanced(
metal::MTLPrimitiveType::Triangle,
0,
6,
sprites.len() as u64,
);
*instance_offset = next_offset;
true true
} }
@ -1147,33 +1136,6 @@ impl MetalRenderer {
} }
} }
fn new_command_encoder<'a>(
command_buffer: &'a metal::CommandBufferRef,
drawable: &'a metal::MetalDrawableRef,
viewport_size: Size<DevicePixels>,
configure_color_attachment: impl Fn(&RenderPassColorAttachmentDescriptorRef),
) -> &'a metal::RenderCommandEncoderRef {
let render_pass_descriptor = metal::RenderPassDescriptor::new();
let color_attachment = render_pass_descriptor
.color_attachments()
.object_at(0)
.unwrap();
color_attachment.set_texture(Some(drawable.texture()));
color_attachment.set_store_action(metal::MTLStoreAction::Store);
configure_color_attachment(color_attachment);
let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor);
command_encoder.set_viewport(metal::MTLViewport {
originX: 0.0,
originY: 0.0,
width: i32::from(viewport_size.width) as f64,
height: i32::from(viewport_size.height) as f64,
znear: 0.0,
zfar: 1.0,
});
command_encoder
}
fn build_pipeline_state( fn build_pipeline_state(
device: &metal::DeviceRef, device: &metal::DeviceRef,
library: &metal::LibraryRef, library: &metal::LibraryRef,
@ -1208,40 +1170,6 @@ fn build_pipeline_state(
.expect("could not create render pipeline state") .expect("could not create render pipeline state")
} }
fn build_path_sprite_pipeline_state(
device: &metal::DeviceRef,
library: &metal::LibraryRef,
label: &str,
vertex_fn_name: &str,
fragment_fn_name: &str,
pixel_format: metal::MTLPixelFormat,
) -> metal::RenderPipelineState {
let vertex_fn = library
.get_function(vertex_fn_name, None)
.expect("error locating vertex function");
let fragment_fn = library
.get_function(fragment_fn_name, None)
.expect("error locating fragment function");
let descriptor = metal::RenderPipelineDescriptor::new();
descriptor.set_label(label);
descriptor.set_vertex_function(Some(vertex_fn.as_ref()));
descriptor.set_fragment_function(Some(fragment_fn.as_ref()));
let color_attachment = descriptor.color_attachments().object_at(0).unwrap();
color_attachment.set_pixel_format(pixel_format);
color_attachment.set_blending_enabled(true);
color_attachment.set_rgb_blend_operation(metal::MTLBlendOperation::Add);
color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add);
color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One);
color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One);
color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha);
color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One);
device
.new_render_pipeline_state(&descriptor)
.expect("could not create render pipeline state")
}
fn build_path_rasterization_pipeline_state( fn build_path_rasterization_pipeline_state(
device: &metal::DeviceRef, device: &metal::DeviceRef,
library: &metal::LibraryRef, library: &metal::LibraryRef,
@ -1264,7 +1192,7 @@ fn build_path_rasterization_pipeline_state(
descriptor.set_fragment_function(Some(fragment_fn.as_ref())); descriptor.set_fragment_function(Some(fragment_fn.as_ref()));
if path_sample_count > 1 { if path_sample_count > 1 {
descriptor.set_raster_sample_count(path_sample_count as _); descriptor.set_raster_sample_count(path_sample_count as _);
descriptor.set_alpha_to_coverage_enabled(false); descriptor.set_alpha_to_coverage_enabled(true);
} }
let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); let color_attachment = descriptor.color_attachments().object_at(0).unwrap();
color_attachment.set_pixel_format(pixel_format); color_attachment.set_pixel_format(pixel_format);
@ -1273,8 +1201,8 @@ fn build_path_rasterization_pipeline_state(
color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add); color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add);
color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One); color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One);
color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One); color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One);
color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::One);
color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One);
device device
.new_render_pipeline_state(&descriptor) .new_render_pipeline_state(&descriptor)
@ -1329,13 +1257,15 @@ enum SurfaceInputIndex {
#[repr(C)] #[repr(C)]
enum PathRasterizationInputIndex { enum PathRasterizationInputIndex {
Vertices = 0, Vertices = 0,
ViewportSize = 1, AtlasTextureSize = 1,
} }
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[repr(C)] #[repr(C)]
pub struct PathSprite { pub struct PathSprite {
pub bounds: Bounds<ScaledPixels>, pub bounds: Bounds<ScaledPixels>,
pub color: Background,
pub tile: AtlasTile,
} }
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]

View file

@ -701,117 +701,107 @@ fragment float4 polychrome_sprite_fragment(
struct PathRasterizationVertexOutput { struct PathRasterizationVertexOutput {
float4 position [[position]]; float4 position [[position]];
float2 st_position; float2 st_position;
uint vertex_id [[flat]];
float clip_rect_distance [[clip_distance]][4]; float clip_rect_distance [[clip_distance]][4];
}; };
struct PathRasterizationFragmentInput { struct PathRasterizationFragmentInput {
float4 position [[position]]; float4 position [[position]];
float2 st_position; float2 st_position;
uint vertex_id [[flat]];
}; };
vertex PathRasterizationVertexOutput path_rasterization_vertex( vertex PathRasterizationVertexOutput path_rasterization_vertex(
uint vertex_id [[vertex_id]], uint vertex_id [[vertex_id]],
constant PathRasterizationVertex *vertices [[buffer(PathRasterizationInputIndex_Vertices)]], constant PathVertex_ScaledPixels *vertices
constant Size_DevicePixels *atlas_size [[buffer(PathRasterizationInputIndex_ViewportSize)]] [[buffer(PathRasterizationInputIndex_Vertices)]],
) { constant Size_DevicePixels *atlas_size
PathRasterizationVertex v = vertices[vertex_id]; [[buffer(PathRasterizationInputIndex_AtlasTextureSize)]]) {
PathVertex_ScaledPixels v = vertices[vertex_id];
float2 vertex_position = float2(v.xy_position.x, v.xy_position.y); float2 vertex_position = float2(v.xy_position.x, v.xy_position.y);
float4 position = float4( float2 viewport_size = float2(atlas_size->width, atlas_size->height);
vertex_position * float2(2. / atlas_size->width, -2. / atlas_size->height) + float2(-1., 1.),
0.,
1.
);
return PathRasterizationVertexOutput{ return PathRasterizationVertexOutput{
position, float4(vertex_position / viewport_size * float2(2., -2.) +
float2(-1., 1.),
0., 1.),
float2(v.st_position.x, v.st_position.y), float2(v.st_position.x, v.st_position.y),
vertex_id, {v.xy_position.x - v.content_mask.bounds.origin.x,
{ v.content_mask.bounds.origin.x + v.content_mask.bounds.size.width -
v.xy_position.x - v.bounds.origin.x, v.xy_position.x,
v.bounds.origin.x + v.bounds.size.width - v.xy_position.x, v.xy_position.y - v.content_mask.bounds.origin.y,
v.xy_position.y - v.bounds.origin.y, v.content_mask.bounds.origin.y + v.content_mask.bounds.size.height -
v.bounds.origin.y + v.bounds.size.height - v.xy_position.y v.xy_position.y}};
}
};
} }
fragment float4 path_rasterization_fragment( fragment float4 path_rasterization_fragment(PathRasterizationFragmentInput input
PathRasterizationFragmentInput input [[stage_in]], [[stage_in]]) {
constant PathRasterizationVertex *vertices [[buffer(PathRasterizationInputIndex_Vertices)]]
) {
float2 dx = dfdx(input.st_position); float2 dx = dfdx(input.st_position);
float2 dy = dfdy(input.st_position); float2 dy = dfdy(input.st_position);
float2 gradient = float2((2. * input.st_position.x) * dx.x - dx.y,
PathRasterizationVertex v = vertices[input.vertex_id]; (2. * input.st_position.x) * dy.x - dy.y);
Background background = v.color; float f = (input.st_position.x * input.st_position.x) - input.st_position.y;
Bounds_ScaledPixels path_bounds = v.bounds; float distance = f / length(gradient);
float alpha; float alpha = saturate(0.5 - distance);
if (length(float2(dx.x, dy.x)) < 0.001) { return float4(alpha, 0., 0., 1.);
alpha = 1.0;
} else {
float2 gradient = float2(
(2. * input.st_position.x) * dx.x - dx.y,
(2. * input.st_position.x) * dy.x - dy.y
);
float f = (input.st_position.x * input.st_position.x) - input.st_position.y;
float distance = f / length(gradient);
alpha = saturate(0.5 - distance);
}
GradientColor gradient_color = prepare_fill_color(
background.tag,
background.color_space,
background.solid,
background.colors[0].color,
background.colors[1].color
);
float4 color = fill_color(
background,
input.position.xy,
path_bounds,
gradient_color.solid,
gradient_color.color0,
gradient_color.color1
);
return float4(color.rgb * color.a * alpha, alpha * color.a);
} }
struct PathSpriteVertexOutput { struct PathSpriteVertexOutput {
float4 position [[position]]; float4 position [[position]];
float2 texture_coords; float2 tile_position;
uint sprite_id [[flat]];
float4 solid_color [[flat]];
float4 color0 [[flat]];
float4 color1 [[flat]];
}; };
vertex PathSpriteVertexOutput path_sprite_vertex( vertex PathSpriteVertexOutput path_sprite_vertex(
uint unit_vertex_id [[vertex_id]], uint unit_vertex_id [[vertex_id]], uint sprite_id [[instance_id]],
uint sprite_id [[instance_id]], constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]],
constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]], constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]],
constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], constant Size_DevicePixels *viewport_size
constant Size_DevicePixels *viewport_size [[buffer(SpriteInputIndex_ViewportSize)]] [[buffer(SpriteInputIndex_ViewportSize)]],
) { constant Size_DevicePixels *atlas_size
[[buffer(SpriteInputIndex_AtlasTextureSize)]]) {
float2 unit_vertex = unit_vertices[unit_vertex_id]; float2 unit_vertex = unit_vertices[unit_vertex_id];
PathSprite sprite = sprites[sprite_id]; PathSprite sprite = sprites[sprite_id];
// Don't apply content mask because it was already accounted for when // Don't apply content mask because it was already accounted for when
// rasterizing the path. // rasterizing the path.
float4 device_position = float4 device_position =
to_device_position(unit_vertex, sprite.bounds, viewport_size); to_device_position(unit_vertex, sprite.bounds, viewport_size);
float2 tile_position = to_tile_position(unit_vertex, sprite.tile, atlas_size);
float2 screen_position = float2(sprite.bounds.origin.x, sprite.bounds.origin.y) + unit_vertex * float2(sprite.bounds.size.width, sprite.bounds.size.height); GradientColor gradient = prepare_fill_color(
float2 texture_coords = screen_position / float2(viewport_size->width, viewport_size->height); sprite.color.tag,
sprite.color.color_space,
sprite.color.solid,
sprite.color.colors[0].color,
sprite.color.colors[1].color
);
return PathSpriteVertexOutput{ return PathSpriteVertexOutput{
device_position, device_position,
texture_coords tile_position,
sprite_id,
gradient.solid,
gradient.color0,
gradient.color1
}; };
} }
fragment float4 path_sprite_fragment( fragment float4 path_sprite_fragment(
PathSpriteVertexOutput input [[stage_in]], PathSpriteVertexOutput input [[stage_in]],
texture2d<float> intermediate_texture [[texture(SpriteInputIndex_AtlasTexture)]] constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]],
) { texture2d<float> atlas_texture [[texture(SpriteInputIndex_AtlasTexture)]]) {
constexpr sampler intermediate_texture_sampler(mag_filter::linear, min_filter::linear); constexpr sampler atlas_texture_sampler(mag_filter::linear,
return intermediate_texture.sample(intermediate_texture_sampler, input.texture_coords); min_filter::linear);
float4 sample =
atlas_texture.sample(atlas_texture_sampler, input.tile_position);
float mask = 1. - abs(1. - fmod(sample.r, 2.));
PathSprite sprite = sprites[input.sprite_id];
Background background = sprite.color;
float4 color = fill_color(background, input.position.xy, sprite.bounds,
input.solid_color, input.color0, input.color1);
color.a *= mask;
return color;
} }
struct SurfaceVertexOutput { struct SurfaceVertexOutput {

View file

@ -341,7 +341,7 @@ impl PlatformAtlas for TestAtlas {
crate::AtlasTile { crate::AtlasTile {
texture_id: AtlasTextureId { texture_id: AtlasTextureId {
index: texture_id, index: texture_id,
kind: crate::AtlasTextureKind::Monochrome, kind: crate::AtlasTextureKind::Path,
}, },
tile_id: TileId(tile_id), tile_id: TileId(tile_id),
padding: 0, padding: 0,

View file

@ -43,6 +43,17 @@ impl Scene {
self.surfaces.clear(); self.surfaces.clear();
} }
#[cfg_attr(
all(
any(target_os = "linux", target_os = "freebsd"),
not(any(feature = "x11", feature = "wayland"))
),
allow(dead_code)
)]
pub fn paths(&self) -> &[Path<ScaledPixels>] {
&self.paths
}
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.paint_operations.len() self.paint_operations.len()
} }
@ -670,7 +681,7 @@ pub(crate) struct PathId(pub(crate) usize);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Path<P: Clone + Debug + Default + PartialEq> { pub struct Path<P: Clone + Debug + Default + PartialEq> {
pub(crate) id: PathId, pub(crate) id: PathId,
pub(crate) order: DrawOrder, order: DrawOrder,
pub(crate) bounds: Bounds<P>, pub(crate) bounds: Bounds<P>,
pub(crate) content_mask: ContentMask<P>, pub(crate) content_mask: ContentMask<P>,
pub(crate) vertices: Vec<PathVertex<P>>, pub(crate) vertices: Vec<PathVertex<P>>,

View file

@ -5,7 +5,7 @@ use crate::{FocusHandle, FocusId};
/// Used to manage the `Tab` event to switch between focus handles. /// Used to manage the `Tab` event to switch between focus handles.
#[derive(Default)] #[derive(Default)]
pub(crate) struct TabHandles { pub(crate) struct TabHandles {
pub(crate) handles: Vec<FocusHandle>, handles: Vec<FocusHandle>,
} }
impl TabHandles { impl TabHandles {
@ -32,18 +32,20 @@ impl TabHandles {
self.handles.clear(); self.handles.clear();
} }
fn current_index(&self, focused_id: Option<&FocusId>) -> Option<usize> { fn current_index(&self, focused_id: Option<&FocusId>) -> usize {
self.handles.iter().position(|h| Some(&h.id) == focused_id) self.handles
.iter()
.position(|h| Some(&h.id) == focused_id)
.unwrap_or_default()
} }
pub(crate) fn next(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { pub(crate) fn next(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> {
let next_ix = self let ix = self.current_index(focused_id);
.current_index(focused_id)
.and_then(|ix| { let mut next_ix = ix + 1;
let next_ix = ix + 1; if next_ix + 1 > self.handles.len() {
(next_ix < self.handles.len()).then_some(next_ix) next_ix = 0;
}) }
.unwrap_or_default();
if let Some(next_handle) = self.handles.get(next_ix) { if let Some(next_handle) = self.handles.get(next_ix) {
Some(next_handle.clone()) Some(next_handle.clone())
@ -53,7 +55,7 @@ impl TabHandles {
} }
pub(crate) fn prev(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { pub(crate) fn prev(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> {
let ix = self.current_index(focused_id).unwrap_or_default(); let ix = self.current_index(focused_id);
let prev_ix; let prev_ix;
if ix == 0 { if ix == 0 {
prev_ix = self.handles.len().saturating_sub(1); prev_ix = self.handles.len().saturating_sub(1);
@ -106,14 +108,8 @@ mod tests {
] ]
); );
// Select first tab index if no handle is currently focused. // next
assert_eq!(tab.next(None), Some(tab.handles[0].clone())); assert_eq!(tab.next(None), Some(tab.handles[1].clone()));
// Select last tab index if no handle is currently focused.
assert_eq!(
tab.prev(None),
Some(tab.handles[tab.handles.len() - 1].clone())
);
assert_eq!( assert_eq!(
tab.next(Some(&tab.handles[0].id)), tab.next(Some(&tab.handles[0].id)),
Some(tab.handles[1].clone()) Some(tab.handles[1].clone())

View file

@ -283,7 +283,7 @@ impl ToTaffy<taffy::style::LengthPercentageAuto> for Length {
fn to_taffy(&self, rem_size: Pixels) -> taffy::prelude::LengthPercentageAuto { fn to_taffy(&self, rem_size: Pixels) -> taffy::prelude::LengthPercentageAuto {
match self { match self {
Length::Definite(length) => length.to_taffy(rem_size), Length::Definite(length) => length.to_taffy(rem_size),
Length::Auto => taffy::prelude::LengthPercentageAuto::auto(), Length::Auto => taffy::prelude::LengthPercentageAuto::Auto,
} }
} }
} }
@ -292,7 +292,7 @@ impl ToTaffy<taffy::style::Dimension> for Length {
fn to_taffy(&self, rem_size: Pixels) -> taffy::prelude::Dimension { fn to_taffy(&self, rem_size: Pixels) -> taffy::prelude::Dimension {
match self { match self {
Length::Definite(length) => length.to_taffy(rem_size), Length::Definite(length) => length.to_taffy(rem_size),
Length::Auto => taffy::prelude::Dimension::auto(), Length::Auto => taffy::prelude::Dimension::Auto,
} }
} }
} }
@ -302,14 +302,14 @@ impl ToTaffy<taffy::style::LengthPercentage> for DefiniteLength {
match self { match self {
DefiniteLength::Absolute(length) => match length { DefiniteLength::Absolute(length) => match length {
AbsoluteLength::Pixels(pixels) => { AbsoluteLength::Pixels(pixels) => {
taffy::style::LengthPercentage::length(pixels.into()) taffy::style::LengthPercentage::Length(pixels.into())
} }
AbsoluteLength::Rems(rems) => { AbsoluteLength::Rems(rems) => {
taffy::style::LengthPercentage::length((*rems * rem_size).into()) taffy::style::LengthPercentage::Length((*rems * rem_size).into())
} }
}, },
DefiniteLength::Fraction(fraction) => { DefiniteLength::Fraction(fraction) => {
taffy::style::LengthPercentage::percent(*fraction) taffy::style::LengthPercentage::Percent(*fraction)
} }
} }
} }
@ -320,14 +320,14 @@ impl ToTaffy<taffy::style::LengthPercentageAuto> for DefiniteLength {
match self { match self {
DefiniteLength::Absolute(length) => match length { DefiniteLength::Absolute(length) => match length {
AbsoluteLength::Pixels(pixels) => { AbsoluteLength::Pixels(pixels) => {
taffy::style::LengthPercentageAuto::length(pixels.into()) taffy::style::LengthPercentageAuto::Length(pixels.into())
} }
AbsoluteLength::Rems(rems) => { AbsoluteLength::Rems(rems) => {
taffy::style::LengthPercentageAuto::length((*rems * rem_size).into()) taffy::style::LengthPercentageAuto::Length((*rems * rem_size).into())
} }
}, },
DefiniteLength::Fraction(fraction) => { DefiniteLength::Fraction(fraction) => {
taffy::style::LengthPercentageAuto::percent(*fraction) taffy::style::LengthPercentageAuto::Percent(*fraction)
} }
} }
} }
@ -337,12 +337,12 @@ impl ToTaffy<taffy::style::Dimension> for DefiniteLength {
fn to_taffy(&self, rem_size: Pixels) -> taffy::style::Dimension { fn to_taffy(&self, rem_size: Pixels) -> taffy::style::Dimension {
match self { match self {
DefiniteLength::Absolute(length) => match length { DefiniteLength::Absolute(length) => match length {
AbsoluteLength::Pixels(pixels) => taffy::style::Dimension::length(pixels.into()), AbsoluteLength::Pixels(pixels) => taffy::style::Dimension::Length(pixels.into()),
AbsoluteLength::Rems(rems) => { AbsoluteLength::Rems(rems) => {
taffy::style::Dimension::length((*rems * rem_size).into()) taffy::style::Dimension::Length((*rems * rem_size).into())
} }
}, },
DefiniteLength::Fraction(fraction) => taffy::style::Dimension::percent(*fraction), DefiniteLength::Fraction(fraction) => taffy::style::Dimension::Percent(*fraction),
} }
} }
} }
@ -350,9 +350,9 @@ impl ToTaffy<taffy::style::Dimension> for DefiniteLength {
impl ToTaffy<taffy::style::LengthPercentage> for AbsoluteLength { impl ToTaffy<taffy::style::LengthPercentage> for AbsoluteLength {
fn to_taffy(&self, rem_size: Pixels) -> taffy::style::LengthPercentage { fn to_taffy(&self, rem_size: Pixels) -> taffy::style::LengthPercentage {
match self { match self {
AbsoluteLength::Pixels(pixels) => taffy::style::LengthPercentage::length(pixels.into()), AbsoluteLength::Pixels(pixels) => taffy::style::LengthPercentage::Length(pixels.into()),
AbsoluteLength::Rems(rems) => { AbsoluteLength::Rems(rems) => {
taffy::style::LengthPercentage::length((*rems * rem_size).into()) taffy::style::LengthPercentage::Length((*rems * rem_size).into())
} }
} }
} }

View file

@ -702,7 +702,6 @@ pub(crate) struct PaintIndex {
input_handlers_index: usize, input_handlers_index: usize,
cursor_styles_index: usize, cursor_styles_index: usize,
accessed_element_states_index: usize, accessed_element_states_index: usize,
tab_handle_index: usize,
line_layout_index: LineLayoutIndex, line_layout_index: LineLayoutIndex,
} }
@ -2209,7 +2208,6 @@ impl Window {
input_handlers_index: self.next_frame.input_handlers.len(), input_handlers_index: self.next_frame.input_handlers.len(),
cursor_styles_index: self.next_frame.cursor_styles.len(), cursor_styles_index: self.next_frame.cursor_styles.len(),
accessed_element_states_index: self.next_frame.accessed_element_states.len(), accessed_element_states_index: self.next_frame.accessed_element_states.len(),
tab_handle_index: self.next_frame.tab_handles.handles.len(),
line_layout_index: self.text_system.layout_index(), line_layout_index: self.text_system.layout_index(),
} }
} }
@ -2239,12 +2237,6 @@ impl Window {
.iter() .iter()
.map(|(id, type_id)| (GlobalElementId(id.0.clone()), *type_id)), .map(|(id, type_id)| (GlobalElementId(id.0.clone()), *type_id)),
); );
self.next_frame.tab_handles.handles.extend(
self.rendered_frame.tab_handles.handles
[range.start.tab_handle_index..range.end.tab_handle_index]
.iter()
.cloned(),
);
self.text_system self.text_system
.reuse_layouts(range.start.line_layout_index..range.end.line_layout_index); .reuse_layouts(range.start.line_layout_index..range.end.line_layout_index);

View file

@ -71,7 +71,6 @@ pub enum IconName {
CircleHelp, CircleHelp,
Close, Close,
Cloud, Cloud,
CloudDownload,
Code, Code,
Cog, Cog,
Command, Command,

View file

@ -166,6 +166,7 @@ pub struct CachedLspAdapter {
pub reinstall_attempt_count: AtomicU64, pub reinstall_attempt_count: AtomicU64,
cached_binary: futures::lock::Mutex<Option<LanguageServerBinary>>, cached_binary: futures::lock::Mutex<Option<LanguageServerBinary>>,
manifest_name: OnceLock<Option<ManifestName>>, manifest_name: OnceLock<Option<ManifestName>>,
attach_kind: OnceLock<Attach>,
} }
impl Debug for CachedLspAdapter { impl Debug for CachedLspAdapter {
@ -201,6 +202,7 @@ impl CachedLspAdapter {
adapter, adapter,
cached_binary: Default::default(), cached_binary: Default::default(),
reinstall_attempt_count: AtomicU64::new(0), reinstall_attempt_count: AtomicU64::new(0),
attach_kind: Default::default(),
manifest_name: Default::default(), manifest_name: Default::default(),
}) })
} }
@ -286,15 +288,29 @@ impl CachedLspAdapter {
.get_or_init(|| self.adapter.manifest_name()) .get_or_init(|| self.adapter.manifest_name())
.clone() .clone()
} }
pub fn attach_kind(&self) -> Attach {
*self.attach_kind.get_or_init(|| self.adapter.attach_kind())
}
} }
/// Determines what gets sent out as a workspace folders content
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
pub enum WorkspaceFoldersContent { pub enum Attach {
/// Send out a single entry with the root of the workspace. /// Create a single language server instance per subproject root.
WorktreeRoot, InstancePerRoot,
/// Send out a list of subproject roots. /// Use one shared language server instance for all subprojects within a project.
SubprojectRoots, Shared,
}
impl Attach {
pub fn root_path(
&self,
root_subproject_path: (WorktreeId, Arc<Path>),
) -> (WorktreeId, Arc<Path>) {
match self {
Attach::InstancePerRoot => root_subproject_path,
Attach::Shared => (root_subproject_path.0, Arc::from(Path::new(""))),
}
}
} }
/// [`LspAdapterDelegate`] allows [`LspAdapter]` implementations to interface with the application /// [`LspAdapterDelegate`] allows [`LspAdapter]` implementations to interface with the application
@ -586,11 +602,8 @@ pub trait LspAdapter: 'static + Send + Sync {
Ok(original) Ok(original)
} }
/// Determines whether a language server supports workspace folders. fn attach_kind(&self) -> Attach {
/// Attach::Shared
/// And does not trip over itself in the process.
fn workspace_folders_content(&self) -> WorkspaceFoldersContent {
WorkspaceFoldersContent::SubprojectRoots
} }
fn manifest_name(&self) -> Option<ManifestName> { fn manifest_name(&self) -> Option<ManifestName> {

Some files were not shown because too many files have changed in this diff Show more