diff --git a/.github/actions/build_docs/action.yml b/.github/actions/build_docs/action.yml index a7effad247..9a2d7e1ec7 100644 --- a/.github/actions/build_docs/action.yml +++ b/.github/actions/build_docs/action.yml @@ -19,7 +19,7 @@ runs: shell: bash -euxo pipefail {0} run: ./script/linux - - name: Check for broken links (in MD) + - name: Check for broken links uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1 with: args: --no-progress --exclude '^http' './docs/src/**/*' @@ -30,9 +30,3 @@ runs: run: | mkdir -p target/deploy mdbook build ./docs --dest-dir=../target/deploy/docs/ - - - name: Check for broken links (in HTML) - uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1 - with: - args: --no-progress --exclude '^http' 'target/deploy/docs/' - fail: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7dfc33e0d2..a4da5e99ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -269,10 +269,6 @@ jobs: mkdir -p ./../.cargo cp ./.cargo/ci-config.toml ./../.cargo/config.toml - - name: Check that Cargo.lock is up to date - run: | - cargo update --locked --workspace - - name: cargo clippy run: ./script/clippy @@ -771,8 +767,7 @@ jobs: timeout-minutes: 120 name: Create a Windows installer runs-on: [self-hosted, Windows, X64] - if: contains(github.event.pull_request.labels.*.name, 'run-bundling') - # if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) + if: false && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) needs: [windows_tests] env: AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }} diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index 4f7506967b..f799133ea7 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -111,11 +111,6 @@ jobs: echo "Publishing version: ${version} on release channel nightly" echo "nightly" > crates/zed/RELEASE_CHANNEL - - name: Setup Sentry CLI - uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 - with: - token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} - - name: Create macOS app bundle run: script/bundle-mac @@ -141,11 +136,6 @@ jobs: - name: Install Linux dependencies run: ./script/linux && ./script/install-mold 2.34.0 - - name: Setup Sentry CLI - uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 - with: - token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} - - name: Limit target directory size run: script/clear-target-dir-if-larger-than 100 @@ -178,11 +168,6 @@ jobs: - name: Install Linux dependencies run: ./script/linux - - name: Setup Sentry CLI - uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 - with: - token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} - - name: Limit target directory size run: script/clear-target-dir-if-larger-than 100 @@ -277,11 +262,6 @@ jobs: Write-Host "Publishing version: $version on release channel nightly" "nightly" | Set-Content -Path "crates/zed/RELEASE_CHANNEL" - - name: Setup Sentry CLI - uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 - with: - token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} - - name: Build Zed installer working-directory: ${{ env.ZED_WORKSPACE }} run: script/bundle-windows.ps1 diff --git a/Cargo.lock b/Cargo.lock index 56210557d2..c5ab86ceb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,9 +6,10 @@ version = 4 name = "acp_thread" version = "0.1.0" dependencies = [ - "agent-client-protocol", + "agentic-coding-protocol", "anyhow", "assistant_tool", + "async-pipe", "buffer_diff", "editor", "env_logger 0.11.8", @@ -18,9 +19,7 @@ dependencies = [ "itertools 0.14.0", "language", "markdown", - "parking_lot", "project", - "rand 0.8.5", "serde", "serde_json", "settings", @@ -90,7 +89,6 @@ dependencies = [ "assistant_tools", "chrono", "client", - "cloud_llm_client", "collections", "component", "context_server", @@ -114,6 +112,7 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", + "proto", "rand 0.8.5", "ref-cast", "rope", @@ -132,30 +131,15 @@ dependencies = [ "uuid", "workspace", "workspace-hack", + "zed_llm_client", "zstd", ] -[[package]] -name = "agent-client-protocol" -version = "0.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22c5180e40d31a9998ffa5f8eb067667f0870908a4aeed65a6a299e2d1d95443" -dependencies = [ - "anyhow", - "futures 0.3.31", - "log", - "parking_lot", - "schemars", - "serde", - "serde_json", -] - [[package]] name = "agent_servers" version = "0.1.0" dependencies = [ "acp_thread", - "agent-client-protocol", "agentic-coding-protocol", "anyhow", "collections", @@ -171,7 +155,6 @@ dependencies = [ "nix 0.29.0", "paths", "project", - "rand 0.8.5", "schemars", "serde", "serde_json", @@ -179,7 +162,6 @@ dependencies = [ "smol", "strum 0.27.1", "tempfile", - "thiserror 2.0.12", "ui", "util", "uuid", @@ -193,7 +175,6 @@ name = "agent_settings" version = "0.1.0" dependencies = [ "anyhow", - "cloud_llm_client", "collections", "fs", "gpui", @@ -205,6 +186,7 @@ dependencies = [ "serde_json_lenient", "settings", "workspace-hack", + "zed_llm_client", ] [[package]] @@ -213,9 +195,9 @@ version = "0.1.0" dependencies = [ "acp_thread", "agent", - "agent-client-protocol", "agent_servers", "agent_settings", + "agentic-coding-protocol", "ai_onboarding", "anyhow", "assistant_context", @@ -227,7 +209,6 @@ dependencies = [ "buffer_diff", "chrono", "client", - "cloud_llm_client", "collections", "command_palette_hooks", "component", @@ -299,6 +280,7 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", + "zed_llm_client", ] [[package]] @@ -359,10 +341,10 @@ name = "ai_onboarding" version = "0.1.0" dependencies = [ "client", - "cloud_llm_client", "component", "gpui", "language_model", + "proto", "serde", "smallvec", "telemetry", @@ -691,7 +673,6 @@ dependencies = [ "chrono", "client", "clock", - "cloud_llm_client", "collections", "context_server", "fs", @@ -725,6 +706,7 @@ dependencies = [ "uuid", "workspace", "workspace-hack", + "zed_llm_client", ] [[package]] @@ -832,7 +814,6 @@ dependencies = [ "chrono", "client", "clock", - "cloud_llm_client", "collections", "component", "derive_more 0.99.19", @@ -886,6 +867,7 @@ dependencies = [ "which 6.0.3", "workspace", "workspace-hack", + "zed_llm_client", "zlog", ] @@ -1079,6 +1061,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "async-recursion" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "async-recursion" version = "1.1.1" @@ -2964,12 +2957,11 @@ name = "client" version = "0.1.0" dependencies = [ "anyhow", + "async-recursion 0.3.2", "async-tungstenite", "base64 0.22.1", "chrono", "clock", - "cloud_api_client", - "cloud_llm_client", "cocoa 0.26.0", "collections", "credentials_provider", @@ -3012,6 +3004,7 @@ dependencies = [ "windows 0.61.1", "workspace-hack", "worktree", + "zed_llm_client", ] [[package]] @@ -3024,44 +3017,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "cloud_api_client" -version = "0.1.0" -dependencies = [ - "anyhow", - "cloud_api_types", - "futures 0.3.31", - "http_client", - "parking_lot", - "serde_json", - "workspace-hack", -] - -[[package]] -name = "cloud_api_types" -version = "0.1.0" -dependencies = [ - "chrono", - "cloud_llm_client", - "pretty_assertions", - "serde", - "serde_json", - "workspace-hack", -] - -[[package]] -name = "cloud_llm_client" -version = "0.1.0" -dependencies = [ - "anyhow", - "pretty_assertions", - "serde", - "serde_json", - "strum 0.27.1", - "uuid", - "workspace-hack", -] - [[package]] name = "clru" version = "0.6.2" @@ -3188,7 +3143,6 @@ dependencies = [ "chrono", "client", "clock", - "cloud_llm_client", "collab_ui", "collections", "command_palette_hooks", @@ -3275,6 +3229,7 @@ dependencies = [ "workspace", "workspace-hack", "worktree", + "zed_llm_client", "zlog", ] @@ -3715,6 +3670,17 @@ dependencies = [ "libm", ] +[[package]] +name = "coreaudio-rs" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "321077172d79c662f64f5071a03120748d5bb652f5231570141be24cfcd2bace" +dependencies = [ + "bitflags 1.3.2", + "core-foundation-sys", + "coreaudio-sys", +] + [[package]] name = "coreaudio-rs" version = "0.12.1" @@ -3772,6 +3738,29 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cpal" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "873dab07c8f743075e57f524c583985fbaf745602acbe916a01539364369a779" +dependencies = [ + "alsa", + "core-foundation-sys", + "coreaudio-rs 0.11.3", + "dasp_sample", + "jni", + "js-sys", + "libc", + "mach2", + "ndk 0.8.0", + "ndk-context", + "oboe", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "windows 0.54.0", +] + [[package]] name = "cpal" version = "0.16.0" @@ -3785,7 +3774,7 @@ dependencies = [ "js-sys", "libc", "mach2", - "ndk", + "ndk 0.9.0", "ndk-context", "num-derive", "num-traits", @@ -4255,7 +4244,7 @@ dependencies = [ [[package]] name = "dap-types" version = "0.0.1" -source = "git+https://github.com/zed-industries/dap-types?rev=1b461b310481d01e02b2603c16d7144b926339f8#1b461b310481d01e02b2603c16d7144b926339f8" +source = "git+https://github.com/zed-industries/dap-types?rev=7f39295b441614ca9dbf44293e53c32f666897f9#7f39295b441614ca9dbf44293e53c32f666897f9" dependencies = [ "schemars", "serde", @@ -4723,6 +4712,7 @@ name = "docs_preprocessor" version = "0.1.0" dependencies = [ "anyhow", + "clap", "command_palette", "gpui", "mdbook", @@ -4733,7 +4723,6 @@ dependencies = [ "util", "workspace-hack", "zed", - "zlog", ] [[package]] @@ -4911,7 +4900,6 @@ dependencies = [ "text", "theme", "time", - "tree-sitter-bash", "tree-sitter-c", "tree-sitter-html", "tree-sitter-python", @@ -5195,7 +5183,6 @@ dependencies = [ "chrono", "clap", "client", - "cloud_llm_client", "collections", "debug_adapter_extension", "dirs 4.0.0", @@ -5235,6 +5222,7 @@ dependencies = [ "uuid", "watch", "workspace-hack", + "zed_llm_client", ] [[package]] @@ -5299,12 +5287,6 @@ dependencies = [ "zune-inflate", ] -[[package]] -name = "extended" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" - [[package]] name = "extension" version = "0.1.0" @@ -5324,13 +5306,11 @@ dependencies = [ "log", "lsp", "parking_lot", - "pretty_assertions", "semantic_version", "serde", "serde_json", "task", "toml 0.8.20", - "url", "util", "wasm-encoder 0.221.3", "wasmparser 0.221.3", @@ -6316,7 +6296,6 @@ dependencies = [ "call", "chrono", "client", - "cloud_llm_client", "collections", "command_palette_hooks", "component", @@ -6359,6 +6338,7 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", + "zed_llm_client", "zlog", ] @@ -7357,9 +7337,9 @@ dependencies = [ [[package]] name = "grid" -version = "0.17.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71b01d27060ad58be4663b9e4ac9e2d4806918e8876af8912afbddd1a91d5eaa" +checksum = "be136d9dacc2a13cc70bb6c8f902b414fb2641f8db1314637c6b7933411a8f82" [[package]] name = "group" @@ -7672,6 +7652,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + [[package]] name = "html5ever" version = "0.27.0" @@ -7805,7 +7791,6 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "log", - "parking_lot", "serde", "serde_json", "url", @@ -8305,7 +8290,6 @@ version = "0.1.0" dependencies = [ "anyhow", "client", - "cloud_llm_client", "copilot", "editor", "feature_flags", @@ -8328,6 +8312,7 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", + "zed_llm_client", "zeta", ] @@ -9020,7 +9005,6 @@ dependencies = [ "anyhow", "base64 0.22.1", "client", - "cloud_llm_client", "collections", "futures 0.3.31", "gpui", @@ -9038,6 +9022,7 @@ dependencies = [ "thiserror 2.0.12", "util", "workspace-hack", + "zed_llm_client", ] [[package]] @@ -9053,7 +9038,6 @@ dependencies = [ "bedrock", "chrono", "client", - "cloud_llm_client", "collections", "component", "convert_case 0.8.0", @@ -9077,6 +9061,7 @@ dependencies = [ "open_router", "partial-json-fixer", "project", + "proto", "release_channel", "schemars", "serde", @@ -9094,6 +9079,7 @@ dependencies = [ "vercel", "workspace-hack", "x_ai", + "zed_llm_client", ] [[package]] @@ -9155,7 +9141,6 @@ dependencies = [ "chrono", "collections", "dap", - "feature_flags", "futures 0.3.31", "gpui", "http_client", @@ -9348,7 +9333,7 @@ dependencies = [ [[package]] name = "libwebrtc" version = "0.3.10" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" dependencies = [ "cxx", "jni", @@ -9428,7 +9413,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "livekit" version = "0.7.8" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" dependencies = [ "chrono", "futures-util", @@ -9451,7 +9436,7 @@ dependencies = [ [[package]] name = "livekit-api" version = "0.4.2" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" dependencies = [ "futures-util", "http 0.2.12", @@ -9475,7 +9460,7 @@ dependencies = [ [[package]] name = "livekit-protocol" version = "0.3.9" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" dependencies = [ "futures-util", "livekit-runtime", @@ -9492,7 +9477,7 @@ dependencies = [ [[package]] name = "livekit-runtime" version = "0.4.0" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" dependencies = [ "tokio", "tokio-stream", @@ -9524,7 +9509,7 @@ dependencies = [ "core-foundation 0.10.0", "core-video", "coreaudio-rs 0.12.1", - "cpal", + "cpal 0.16.0", "futures 0.3.31", "gpui", "gpui_tokio", @@ -9575,9 +9560,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.13" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" dependencies = [ "autocfg", "scopeguard", @@ -9814,7 +9799,7 @@ name = "markdown_preview" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion", + "async-recursion 1.1.1", "collections", "editor", "fs", @@ -10305,6 +10290,20 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "ndk" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7" +dependencies = [ + "bitflags 2.9.0", + "jni-sys", + "log", + "ndk-sys 0.5.0+25.2.9519653", + "num_enum", + "thiserror 1.0.69", +] + [[package]] name = "ndk" version = "0.9.0" @@ -10314,7 +10313,7 @@ dependencies = [ "bitflags 2.9.0", "jni-sys", "log", - "ndk-sys", + "ndk-sys 0.6.0+11769913", "num_enum", "thiserror 1.0.69", ] @@ -10325,6 +10324,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" +[[package]] +name = "ndk-sys" +version = "0.5.0+25.2.9519653" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691" +dependencies = [ + "jni-sys", +] + [[package]] name = "ndk-sys" version = "0.6.0+11769913" @@ -10897,6 +10905,29 @@ dependencies = [ "memchr", ] +[[package]] +name = "oboe" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8b61bebd49e5d43f5f8cc7ee2891c16e0f41ec7954d36bcb6c14c5e0de867fb" +dependencies = [ + "jni", + "ndk 0.8.0", + "ndk-context", + "num-derive", + "num-traits", + "oboe-sys", +] + +[[package]] +name = "oboe-sys" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8bb09a4a2b1d668170cfe0a7d5bc103f8999fb316c98099b6a9939c9f2e79d" +dependencies = [ + "cc", +] + [[package]] name = "ollama" version = "0.1.0" @@ -10914,33 +10945,17 @@ dependencies = [ name = "onboarding" version = "0.1.0" dependencies = [ - "ai_onboarding", "anyhow", - "client", "command_palette_hooks", - "component", "db", - "documented", - "editor", "feature_flags", "fs", "gpui", - "itertools 0.14.0", - "language", - "language_model", - "menu", - "project", - "schemars", - "serde", "settings", "theme", "ui", - "util", - "vim_mode_setting", "workspace", "workspace-hack", - "zed_actions", - "zlog", ] [[package]] @@ -11291,9 +11306,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.4" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", "parking_lot_core", @@ -11301,9 +11316,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.11" +version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", @@ -13691,15 +13706,12 @@ dependencies = [ [[package]] name = "rodio" -version = "0.21.1" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183" +checksum = "e7ceb6607dd738c99bc8cb28eff249b7cd5c8ec88b9db96c0608c1480d140fb1" dependencies = [ - "cpal", - "dasp_sample", - "num-rational", - "symphonia", - "tracing", + "cpal 0.15.3", + "hound", ] [[package]] @@ -14704,27 +14716,6 @@ dependencies = [ "zlog", ] -[[package]] -name = "settings_profile_selector" -version = "0.1.0" -dependencies = [ - "client", - "editor", - "fuzzy", - "gpui", - "language", - "menu", - "picker", - "project", - "serde_json", - "settings", - "theme", - "ui", - "workspace", - "workspace-hack", - "zed_actions", -] - [[package]] name = "settings_ui" version = "0.1.0" @@ -14747,6 +14738,7 @@ dependencies = [ "notifications", "paths", "project", + "schemars", "search", "serde", "serde_json", @@ -15740,66 +15732,6 @@ dependencies = [ "zeno", ] -[[package]] -name = "symphonia" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9" -dependencies = [ - "lazy_static", - "symphonia-codec-pcm", - "symphonia-core", - "symphonia-format-riff", - "symphonia-metadata", -] - -[[package]] -name = "symphonia-codec-pcm" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f395a67057c2ebc5e84d7bb1be71cce1a7ba99f64e0f0f0e303a03f79116f89b" -dependencies = [ - "log", - "symphonia-core", -] - -[[package]] -name = "symphonia-core" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "798306779e3dc7d5231bd5691f5a813496dc79d3f56bf82e25789f2094e022c3" -dependencies = [ - "arrayvec", - "bitflags 1.3.2", - "bytemuck", - "lazy_static", - "log", -] - -[[package]] -name = "symphonia-format-riff" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f7be232f962f937f4b7115cbe62c330929345434c834359425e043bfd15f50" -dependencies = [ - "extended", - "log", - "symphonia-core", - "symphonia-metadata", -] - -[[package]] -name = "symphonia-metadata" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc622b9841a10089c5b18e99eb904f4341615d5aa55bbf4eedde1be721a4023c" -dependencies = [ - "encoding_rs", - "lazy_static", - "log", - "symphonia-core", -] - [[package]] name = "syn" version = "1.0.109" @@ -15980,12 +15912,13 @@ dependencies = [ [[package]] name = "taffy" -version = "0.8.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7aaef0ac998e6527d6d0d5582f7e43953bb17221ac75bb8eb2fcc2db3396db1c" +checksum = "e8b61630cba2afd2c851821add2e1bb1b7851a2436e839ab73b56558b009035e" dependencies = [ "arrayvec", "grid", + "num-traits", "serde", "slotmap", ] @@ -16183,7 +16116,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_slash_command", - "async-recursion", + "async-recursion 1.1.1", "breadcrumbs", "client", "collections", @@ -16532,7 +16465,6 @@ dependencies = [ "call", "chrono", "client", - "cloud_llm_client", "collections", "db", "gpui", @@ -18501,11 +18433,11 @@ name = "web_search" version = "0.1.0" dependencies = [ "anyhow", - "cloud_llm_client", "collections", "gpui", "serde", "workspace-hack", + "zed_llm_client", ] [[package]] @@ -18514,7 +18446,6 @@ version = "0.1.0" dependencies = [ "anyhow", "client", - "cloud_llm_client", "futures 0.3.31", "gpui", "http_client", @@ -18523,6 +18454,7 @@ dependencies = [ "serde_json", "web_search", "workspace-hack", + "zed_llm_client", ] [[package]] @@ -18546,7 +18478,7 @@ dependencies = [ [[package]] name = "webrtc-sys" version = "0.3.7" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" dependencies = [ "cc", "cxx", @@ -18559,7 +18491,7 @@ dependencies = [ [[package]] name = "webrtc-sys-build" version = "0.3.6" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" dependencies = [ "fs2", "regex", @@ -18594,6 +18526,7 @@ dependencies = [ "serde", "settings", "telemetry", + "theme", "ui", "util", "vim_mode_setting", @@ -19608,7 +19541,7 @@ version = "0.1.0" dependencies = [ "any_vec", "anyhow", - "async-recursion", + "async-recursion 1.1.1", "bincode", "call", "client", @@ -19685,12 +19618,14 @@ dependencies = [ "cc", "chrono", "cipher", + "clang-sys", "clap", "clap_builder", "codespan-reporting 0.12.0", "concurrent-queue", "core-foundation 0.9.4", "core-foundation-sys", + "coreaudio-sys", "cranelift-codegen", "crc32fast", "crossbeam-epoch", @@ -20133,7 +20068,7 @@ dependencies = [ "async-io", "async-lock", "async-process", - "async-recursion", + "async-recursion 1.1.1", "async-task", "async-trait", "blocking", @@ -20186,7 +20121,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.199.0" +version = "0.197.5" dependencies = [ "activity_indicator", "agent", @@ -20227,7 +20162,6 @@ dependencies = [ "extension", "extension_host", "extensions_ui", - "feature_flags", "feedback", "file_finder", "fs", @@ -20289,7 +20223,6 @@ dependencies = [ "serde_json", "session", "settings", - "settings_profile_selector", "settings_ui", "shellexpand 2.1.2", "smol", @@ -20348,7 +20281,7 @@ dependencies = [ [[package]] name = "zed_emmet" -version = "0.0.4" +version = "0.0.3" dependencies = [ "zed_extension_api 0.1.0", ] @@ -20387,6 +20320,19 @@ dependencies = [ "zed_extension_api 0.1.0", ] +[[package]] +name = "zed_llm_client" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc" +dependencies = [ + "anyhow", + "serde", + "serde_json", + "strum 0.27.1", + "uuid", +] + [[package]] name = "zed_proto" version = "0.2.2" @@ -20566,8 +20512,6 @@ dependencies = [ "call", "client", "clock", - "cloud_api_types", - "cloud_llm_client", "collections", "command_palette_hooks", "copilot", @@ -20587,6 +20531,7 @@ dependencies = [ "menu", "postage", "project", + "proto", "regex", "release_channel", "reqwest_client", @@ -20608,45 +20553,10 @@ dependencies = [ "workspace-hack", "worktree", "zed_actions", + "zed_llm_client", "zlog", ] -[[package]] -name = "zeta_cli" -version = "0.1.0" -dependencies = [ - "anyhow", - "clap", - "client", - "debug_adapter_extension", - "extension", - "fs", - "futures 0.3.31", - "gpui", - "gpui_tokio", - "language", - "language_extension", - "language_model", - "language_models", - "languages", - "node_runtime", - "paths", - "project", - "prompt_store", - "release_channel", - "reqwest_client", - "serde", - "serde_json", - "settings", - "shellexpand 2.1.2", - "smol", - "terminal_view", - "util", - "watch", - "workspace-hack", - "zeta", -] - [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index 5d852f8842..ec793a7429 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [workspace] resolver = "2" members = [ - "crates/acp_thread", "crates/activity_indicator", - "crates/agent", - "crates/agent_servers", - "crates/agent_settings", + "crates/acp_thread", "crates/agent_ui", + "crates/agent", + "crates/agent_settings", "crates/ai_onboarding", + "crates/agent_servers", "crates/anthropic", "crates/askpass", "crates/assets", @@ -29,9 +29,6 @@ members = [ "crates/cli", "crates/client", "crates/clock", - "crates/cloud_api_client", - "crates/cloud_api_types", - "crates/cloud_llm_client", "crates/collab", "crates/collab_ui", "crates/collections", @@ -51,8 +48,8 @@ members = [ "crates/diagnostics", "crates/docs_preprocessor", "crates/editor", - "crates/eval", "crates/explorer_command_injector", + "crates/eval", "crates/extension", "crates/extension_api", "crates/extension_cli", @@ -73,6 +70,7 @@ members = [ "crates/gpui", "crates/gpui_macros", "crates/gpui_tokio", + "crates/html_to_markdown", "crates/http_client", "crates/http_client_tls", @@ -101,6 +99,7 @@ members = [ "crates/markdown_preview", "crates/media", "crates/menu", + "crates/svg_preview", "crates/migrator", "crates/mistral", "crates/multi_buffer", @@ -141,7 +140,6 @@ members = [ "crates/semantic_version", "crates/session", "crates/settings", - "crates/settings_profile_selector", "crates/settings_ui", "crates/snippet", "crates/snippet_provider", @@ -154,7 +152,6 @@ members = [ "crates/sum_tree", "crates/supermaven", "crates/supermaven_api", - "crates/svg_preview", "crates/tab_switcher", "crates/task", "crates/tasks_ui", @@ -189,7 +186,6 @@ members = [ "crates/zed", "crates/zed_actions", "crates/zeta", - "crates/zeta_cli", "crates/zlog", "crates/zlog_settings", @@ -255,9 +251,6 @@ channel = { path = "crates/channel" } cli = { path = "crates/cli" } client = { path = "crates/client" } clock = { path = "crates/clock" } -cloud_api_client = { path = "crates/cloud_api_client" } -cloud_api_types = { path = "crates/cloud_api_types" } -cloud_llm_client = { path = "crates/cloud_llm_client" } collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } @@ -344,7 +337,6 @@ picker = { path = "crates/picker" } plugin = { path = "crates/plugin" } plugin_macros = { path = "crates/plugin_macros" } prettier = { path = "crates/prettier" } -settings_profile_selector = { path = "crates/settings_profile_selector" } project = { path = "crates/project" } project_panel = { path = "crates/project_panel" } project_symbols = { path = "crates/project_symbols" } @@ -421,7 +413,6 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.17" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -468,7 +459,7 @@ core-video = { version = "0.4.3", features = ["metal"] } cpal = "0.16" criterion = { version = "0.5", features = ["html_reports"] } ctor = "0.4.0" -dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "1b461b310481d01e02b2603c16d7144b926339f8" } +dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "7f39295b441614ca9dbf44293e53c32f666897f9" } dashmap = "6.0" derive_more = "0.99.17" dirs = "4.0" @@ -653,6 +644,7 @@ which = "6.0.0" windows-core = "0.61" wit-component = "0.221" workspace-hack = "0.1.0" +zed_llm_client = "= 0.8.6" zstd = "0.11" [workspace.dependencies.async-stripe] @@ -679,16 +671,14 @@ features = [ "UI_ViewManagement", "Wdk_System_SystemServices", "Win32_Globalization", - "Win32_Graphics_Direct3D", - "Win32_Graphics_Direct3D11", - "Win32_Graphics_Direct3D_Fxc", - "Win32_Graphics_DirectComposition", + "Win32_Graphics_Direct2D", + "Win32_Graphics_Direct2D_Common", "Win32_Graphics_DirectWrite", "Win32_Graphics_Dwm", - "Win32_Graphics_Dxgi", "Win32_Graphics_Dxgi_Common", "Win32_Graphics_Gdi", "Win32_Graphics_Imaging", + "Win32_Graphics_Imaging_D2D", "Win32_Networking_WinSock", "Win32_Security", "Win32_Security_Credentials", @@ -729,11 +719,6 @@ workspace-hack = { path = "tooling/workspace-hack" } split-debuginfo = "unpacked" codegen-units = 16 -# mirror configuration for crates compiled for the build platform -# (without this cargo will compile ~400 crates twice) -[profile.dev.build-override] -codegen-units = 16 - [profile.dev.package] taffy = { opt-level = 3 } cranelift-codegen = { opt-level = 3 } diff --git a/README.md b/README.md index 38547c1ca4..4c794efc3d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ # Zed -[![Zed](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/zed-industries/zed/main/assets/badge/v0.json)](https://zed.dev) [![CI](https://github.com/zed-industries/zed/actions/workflows/ci.yml/badge.svg)](https://github.com/zed-industries/zed/actions/workflows/ci.yml) Welcome to Zed, a high-performance, multiplayer code editor from the creators of [Atom](https://github.com/atom/atom) and [Tree-sitter](https://github.com/tree-sitter/tree-sitter). diff --git a/assets/badge/v0.json b/assets/badge/v0.json deleted file mode 100644 index c7d18bb42b..0000000000 --- a/assets/badge/v0.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "label": "", - "message": "Zed", - "logoSvg": "", - "logoWidth": 16, - "labelColor": "black", - "color": "white" -} diff --git a/assets/icons/ai_bedrock.svg b/assets/icons/ai_bedrock.svg index c9bbcc82e1..2b672c364e 100644 --- a/assets/icons/ai_bedrock.svg +++ b/assets/icons/ai_bedrock.svg @@ -1,8 +1,4 @@ - - - - - - - + + + diff --git a/assets/icons/ai_deep_seek.svg b/assets/icons/ai_deep_seek.svg index c8e5483fb3..cf480c834c 100644 --- a/assets/icons/ai_deep_seek.svg +++ b/assets/icons/ai_deep_seek.svg @@ -1,3 +1 @@ - - - +DeepSeek diff --git a/assets/icons/ai_lm_studio.svg b/assets/icons/ai_lm_studio.svg index 5cfdeb5578..0b455f48a7 100644 --- a/assets/icons/ai_lm_studio.svg +++ b/assets/icons/ai_lm_studio.svg @@ -1,15 +1,33 @@ - - - - - - - - - - - - - - + + + Artboard + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/icons/ai_mistral.svg b/assets/icons/ai_mistral.svg index f11c177e2f..23b8f2ef6c 100644 --- a/assets/icons/ai_mistral.svg +++ b/assets/icons/ai_mistral.svg @@ -1,8 +1 @@ - - - - - - - - +Mistral \ No newline at end of file diff --git a/assets/icons/ai_ollama.svg b/assets/icons/ai_ollama.svg index 36a88c1ad6..d433df3981 100644 --- a/assets/icons/ai_ollama.svg +++ b/assets/icons/ai_ollama.svg @@ -1,7 +1,14 @@ - - - - - + + + + + + + + + + + + diff --git a/assets/icons/ai_open_ai.svg b/assets/icons/ai_open_ai.svg index e45ac315a0..e659a472d8 100644 --- a/assets/icons/ai_open_ai.svg +++ b/assets/icons/ai_open_ai.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/ai_open_router.svg b/assets/icons/ai_open_router.svg index b6f5164e0b..94f2849146 100644 --- a/assets/icons/ai_open_router.svg +++ b/assets/icons/ai_open_router.svg @@ -1,8 +1,8 @@ - - - - - - - + + + + + + + diff --git a/assets/icons/ai_x_ai.svg b/assets/icons/ai_x_ai.svg index d3400fbe9c..289525c8ef 100644 --- a/assets/icons/ai_x_ai.svg +++ b/assets/icons/ai_x_ai.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/ai_zed.svg b/assets/icons/ai_zed.svg index 6d78efacd5..1c6bb8ad63 100644 --- a/assets/icons/ai_zed.svg +++ b/assets/icons/ai_zed.svg @@ -1,3 +1,10 @@ - + + + + + + + + diff --git a/assets/icons/cloud_download.svg b/assets/icons/at_sign.svg similarity index 51% rename from assets/icons/cloud_download.svg rename to assets/icons/at_sign.svg index bc7a8376d1..4cf8cd468f 100644 --- a/assets/icons/cloud_download.svg +++ b/assets/icons/at_sign.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/assets/icons/audio_off.svg b/assets/icons/audio_off.svg index dfb5a1c458..93b98471ca 100644 --- a/assets/icons/audio_off.svg +++ b/assets/icons/audio_off.svg @@ -1,7 +1 @@ - - - - - - - + diff --git a/assets/icons/audio_on.svg b/assets/icons/audio_on.svg index d1bef0d337..42310ea32c 100644 --- a/assets/icons/audio_on.svg +++ b/assets/icons/audio_on.svg @@ -1,5 +1 @@ - - - - - + diff --git a/assets/icons/bolt.svg b/assets/icons/bolt.svg new file mode 100644 index 0000000000..2688ede2a5 --- /dev/null +++ b/assets/icons/bolt.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/bolt_filled.svg b/assets/icons/bolt_filled.svg index 14d8f53e02..543e72adf8 100644 --- a/assets/icons/bolt_filled.svg +++ b/assets/icons/bolt_filled.svg @@ -1,3 +1,3 @@ - - + + diff --git a/assets/icons/bolt_filled_alt.svg b/assets/icons/bolt_filled_alt.svg new file mode 100644 index 0000000000..141e1c5f57 --- /dev/null +++ b/assets/icons/bolt_filled_alt.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/bolt_outlined.svg b/assets/icons/bolt_outlined.svg deleted file mode 100644 index 58fccf7788..0000000000 --- a/assets/icons/bolt_outlined.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/book_plus.svg b/assets/icons/book_plus.svg new file mode 100644 index 0000000000..2868f07cd0 --- /dev/null +++ b/assets/icons/book_plus.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/brain.svg b/assets/icons/brain.svg new file mode 100644 index 0000000000..80c93814f7 --- /dev/null +++ b/assets/icons/brain.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/chat.svg b/assets/icons/chat.svg deleted file mode 100644 index a0548c3d3e..0000000000 --- a/assets/icons/chat.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - - diff --git a/assets/icons/editor_atom.svg b/assets/icons/editor_atom.svg deleted file mode 100644 index cc5fa83843..0000000000 --- a/assets/icons/editor_atom.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/editor_cursor.svg b/assets/icons/editor_cursor.svg deleted file mode 100644 index 338697be8a..0000000000 --- a/assets/icons/editor_cursor.svg +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - diff --git a/assets/icons/editor_emacs.svg b/assets/icons/editor_emacs.svg deleted file mode 100644 index 951d7b2be1..0000000000 --- a/assets/icons/editor_emacs.svg +++ /dev/null @@ -1,10 +0,0 @@ - - - - - - - - - - diff --git a/assets/icons/editor_jet_brains.svg b/assets/icons/editor_jet_brains.svg deleted file mode 100644 index 7d9cf0c65c..0000000000 --- a/assets/icons/editor_jet_brains.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/editor_sublime.svg b/assets/icons/editor_sublime.svg deleted file mode 100644 index 95a04f6b54..0000000000 --- a/assets/icons/editor_sublime.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/assets/icons/editor_vs_code.svg b/assets/icons/editor_vs_code.svg deleted file mode 100644 index 2a71ad52af..0000000000 --- a/assets/icons/editor_vs_code.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/exit.svg b/assets/icons/exit.svg index 1ff9d78824..2cc6ce120d 100644 --- a/assets/icons/exit.svg +++ b/assets/icons/exit.svg @@ -1,5 +1,8 @@ - - - - + + diff --git a/assets/icons/file_icons/kdl.svg b/assets/icons/file_icons/kdl.svg deleted file mode 100644 index 92d9f28428..0000000000 --- a/assets/icons/file_icons/kdl.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/assets/icons/file_icons/surrealql.svg b/assets/icons/file_icons/surrealql.svg deleted file mode 100644 index 076f93e808..0000000000 --- a/assets/icons/file_icons/surrealql.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/file_text.svg b/assets/icons/file_text.svg index a9b8f971e0..7c602f2ac7 100644 --- a/assets/icons/file_text.svg +++ b/assets/icons/file_text.svg @@ -1,6 +1 @@ - - - - - - + diff --git a/assets/icons/git_onboarding_bg.svg b/assets/icons/git_onboarding_bg.svg new file mode 100644 index 0000000000..18da0230a2 --- /dev/null +++ b/assets/icons/git_onboarding_bg.svg @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/icons/message_bubbles.svg b/assets/icons/message_bubbles.svg new file mode 100644 index 0000000000..03a6c7760c --- /dev/null +++ b/assets/icons/message_bubbles.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/assets/icons/mic.svg b/assets/icons/mic.svg index 1d9c5bc9ed..01f4c9bf66 100644 --- a/assets/icons/mic.svg +++ b/assets/icons/mic.svg @@ -1,5 +1,3 @@ - - - - + + diff --git a/assets/icons/mic_mute.svg b/assets/icons/mic_mute.svg index 8c61ae2f1c..fe5f8201cc 100644 --- a/assets/icons/mic_mute.svg +++ b/assets/icons/mic_mute.svg @@ -1,8 +1,3 @@ - - - - - - - + + diff --git a/assets/icons/microscope.svg b/assets/icons/microscope.svg new file mode 100644 index 0000000000..2b3009a28b --- /dev/null +++ b/assets/icons/microscope.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/new_from_summary.svg b/assets/icons/new_from_summary.svg new file mode 100644 index 0000000000..3b61ca51a0 --- /dev/null +++ b/assets/icons/new_from_summary.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/assets/icons/text_thread.svg b/assets/icons/new_text_thread.svg similarity index 100% rename from assets/icons/text_thread.svg rename to assets/icons/new_text_thread.svg diff --git a/assets/icons/thread.svg b/assets/icons/new_thread.svg similarity index 100% rename from assets/icons/thread.svg rename to assets/icons/new_thread.svg diff --git a/assets/icons/play.svg b/assets/icons/play.svg new file mode 100644 index 0000000000..2481bda7d6 --- /dev/null +++ b/assets/icons/play.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/play_outlined.svg b/assets/icons/play_alt.svg similarity index 70% rename from assets/icons/play_outlined.svg rename to assets/icons/play_alt.svg index 7e1cacd5af..b327ab07b5 100644 --- a/assets/icons/play_outlined.svg +++ b/assets/icons/play_alt.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/play_bug.svg b/assets/icons/play_bug.svg new file mode 100644 index 0000000000..7d265dd42a --- /dev/null +++ b/assets/icons/play_bug.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/assets/icons/play_filled.svg b/assets/icons/play_filled.svg index c632434305..387304ef04 100644 --- a/assets/icons/play_filled.svg +++ b/assets/icons/play_filled.svg @@ -1,3 +1,3 @@ - - + + diff --git a/assets/icons/reveal.svg b/assets/icons/reveal.svg new file mode 100644 index 0000000000..ff5444d8f8 --- /dev/null +++ b/assets/icons/reveal.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/screen.svg b/assets/icons/screen.svg index 4b686b58f9..ad252e64cf 100644 --- a/assets/icons/screen.svg +++ b/assets/icons/screen.svg @@ -1,5 +1,8 @@ - - - - + + diff --git a/assets/icons/shield_check.svg b/assets/icons/shield_check.svg deleted file mode 100644 index 6e58c31468..0000000000 --- a/assets/icons/shield_check.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - - diff --git a/assets/icons/spinner.svg b/assets/icons/spinner.svg new file mode 100644 index 0000000000..4f4034ae89 --- /dev/null +++ b/assets/icons/spinner.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/assets/icons/strikethrough.svg b/assets/icons/strikethrough.svg new file mode 100644 index 0000000000..d7d0905912 --- /dev/null +++ b/assets/icons/strikethrough.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/thread_from_summary.svg b/assets/icons/thread_from_summary.svg deleted file mode 100644 index 7519935aff..0000000000 --- a/assets/icons/thread_from_summary.svg +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - diff --git a/assets/icons/trash.svg b/assets/icons/trash.svg index 1322e90f9f..b71035b99c 100644 --- a/assets/icons/trash.svg +++ b/assets/icons/trash.svg @@ -1,5 +1 @@ - - - - - + diff --git a/assets/icons/trash_alt.svg b/assets/icons/trash_alt.svg new file mode 100644 index 0000000000..6867b42147 --- /dev/null +++ b/assets/icons/trash_alt.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/zed_predict_bg.svg b/assets/icons/zed_predict_bg.svg new file mode 100644 index 0000000000..1dccbb51af --- /dev/null +++ b/assets/icons/zed_predict_bg.svg @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index ef5354e82d..a4f812b2fc 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -232,7 +232,7 @@ "ctrl-n": "agent::NewThread", "ctrl-alt-n": "agent::NewTextThread", "ctrl-shift-h": "agent::OpenHistory", - "ctrl-alt-c": "agent::OpenSettings", + "ctrl-alt-c": "agent::OpenConfiguration", "ctrl-alt-p": "agent::OpenRulesLibrary", "ctrl-i": "agent::ToggleProfileSelector", "ctrl-alt-/": "agent::ToggleModelSelector", @@ -495,7 +495,7 @@ "shift-f12": "editor::GoToImplementation", "alt-ctrl-f12": "editor::GoToTypeDefinitionSplit", "alt-shift-f12": "editor::FindAllReferences", - "ctrl-m": "editor::MoveToEnclosingBracket", // from jetbrains + "ctrl-m": "editor::MoveToEnclosingBracket", "ctrl-|": "editor::MoveToEnclosingBracket", "ctrl-{": "editor::Fold", "ctrl-}": "editor::UnfoldLines", @@ -598,7 +598,6 @@ "ctrl-shift-t": "pane::ReopenClosedItem", "ctrl-k ctrl-s": "zed::OpenKeymapEditor", "ctrl-k ctrl-t": "theme_selector::Toggle", - "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -1168,14 +1167,5 @@ "up": "menu::SelectPrevious", "down": "menu::SelectNext" } - }, - { - "context": "Onboarding", - "use_key_equivalents": true, - "bindings": { - "ctrl-1": "onboarding::ActivateBasicsPage", - "ctrl-2": "onboarding::ActivateEditingPage", - "ctrl-3": "onboarding::ActivateAISetupPage" - } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 3287e50acb..eded8c73e6 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -272,7 +272,7 @@ "cmd-n": "agent::NewThread", "cmd-alt-n": "agent::NewTextThread", "cmd-shift-h": "agent::OpenHistory", - "cmd-alt-c": "agent::OpenSettings", + "cmd-alt-c": "agent::OpenConfiguration", "cmd-alt-p": "agent::OpenRulesLibrary", "cmd-i": "agent::ToggleProfileSelector", "cmd-alt-/": "agent::ToggleModelSelector", @@ -549,7 +549,7 @@ "alt-cmd-f12": "editor::GoToTypeDefinitionSplit", "alt-shift-f12": "editor::FindAllReferences", "cmd-|": "editor::MoveToEnclosingBracket", - "ctrl-m": "editor::MoveToEnclosingBracket", // From Jetbrains + "ctrl-m": "editor::MoveToEnclosingBracket", "alt-cmd-[": "editor::Fold", "alt-cmd-]": "editor::UnfoldLines", "cmd-k cmd-l": "editor::ToggleFold", @@ -665,7 +665,6 @@ "cmd-shift-t": "pane::ReopenClosedItem", "cmd-k cmd-s": "zed::OpenKeymapEditor", "cmd-k cmd-t": "theme_selector::Toggle", - "ctrl-alt-cmd-p": "settings_profile_selector::Toggle", "cmd-t": "project_symbols::Toggle", "cmd-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -1270,14 +1269,5 @@ "up": "menu::SelectPrevious", "down": "menu::SelectNext" } - }, - { - "context": "Onboarding", - "use_key_equivalents": true, - "bindings": { - "cmd-1": "onboarding::ActivateBasicsPage", - "cmd-2": "onboarding::ActivateEditingPage", - "cmd-3": "onboarding::ActivateAISetupPage" - } } ] diff --git a/assets/keymaps/linux/cursor.json b/assets/keymaps/linux/cursor.json index 1c381b0cf0..347b7885fc 100644 --- a/assets/keymaps/linux/cursor.json +++ b/assets/keymaps/linux/cursor.json @@ -8,7 +8,7 @@ "ctrl-shift-i": "agent::ToggleFocus", "ctrl-l": "agent::ToggleFocus", "ctrl-shift-l": "agent::ToggleFocus", - "ctrl-shift-j": "agent::OpenSettings" + "ctrl-shift-j": "agent::OpenConfiguration" } }, { diff --git a/assets/keymaps/linux/jetbrains.json b/assets/keymaps/linux/jetbrains.json index 3df1243fed..c1d8bbebe6 100644 --- a/assets/keymaps/linux/jetbrains.json +++ b/assets/keymaps/linux/jetbrains.json @@ -4,7 +4,6 @@ "ctrl-alt-s": "zed::OpenSettings", "ctrl-{": "pane::ActivatePreviousItem", "ctrl-}": "pane::ActivateNextItem", - "shift-escape": null, // Unmap workspace::zoom "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -45,8 +44,8 @@ "ctrl-alt-right": "pane::GoForward", "alt-f7": "editor::FindAllReferences", "ctrl-alt-f7": "editor::FindAllReferences", - "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock - "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleRightDock + // "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock + // "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleLeftDock "ctrl-shift-b": "editor::GoToTypeDefinition", "ctrl-alt-shift-b": "editor::GoToTypeDefinitionSplit", "f2": "editor::GoToDiagnostic", @@ -101,27 +100,12 @@ "shift shift": "command_palette::Toggle", "ctrl-alt-shift-n": "project_symbols::Toggle", "alt-0": "git_panel::ToggleFocus", - "alt-1": "project_panel::ToggleFocus", + "alt-1": "workspace::ToggleLeftDock", "alt-5": "debug_panel::ToggleFocus", "alt-6": "diagnostics::Deploy", "alt-7": "outline_panel::ToggleFocus" } }, - { - "context": "Pane", // this is to override the default Pane mappings to switch tabs - "bindings": { - "alt-1": "project_panel::ToggleFocus", - "alt-2": null, // Bookmarks (left dock) - "alt-3": null, // Find Panel (bottom dock) - "alt-4": null, // Run Panel (bottom dock) - "alt-5": "debug_panel::ToggleFocus", - "alt-6": "diagnostics::Deploy", - "alt-7": "outline_panel::ToggleFocus", - "alt-8": null, // Services (bottom dock) - "alt-9": null, // Git History (bottom dock) - "alt-0": "git_panel::ToggleFocus" - } - }, { "context": "Workspace || Editor", "bindings": { @@ -167,9 +151,6 @@ { "context": "OutlinePanel", "bindings": { "alt-7": "workspace::CloseActiveDock" } }, { "context": "Dock || Workspace || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", - "bindings": { - "escape": "editor::ToggleFocus", - "shift-escape": "workspace::CloseActiveDock" - } + "bindings": { "escape": "editor::ToggleFocus" } } ] diff --git a/assets/keymaps/macos/cursor.json b/assets/keymaps/macos/cursor.json index fdf9c437cf..b1d39bef9e 100644 --- a/assets/keymaps/macos/cursor.json +++ b/assets/keymaps/macos/cursor.json @@ -8,7 +8,7 @@ "cmd-shift-i": "agent::ToggleFocus", "cmd-l": "agent::ToggleFocus", "cmd-shift-l": "agent::ToggleFocus", - "cmd-shift-j": "agent::OpenSettings" + "cmd-shift-j": "agent::OpenConfiguration" } }, { diff --git a/assets/keymaps/macos/jetbrains.json b/assets/keymaps/macos/jetbrains.json index 66962811f4..a8d11835e6 100644 --- a/assets/keymaps/macos/jetbrains.json +++ b/assets/keymaps/macos/jetbrains.json @@ -4,7 +4,6 @@ "cmd-{": "pane::ActivatePreviousItem", "cmd-}": "pane::ActivateNextItem", "cmd-0": "git_panel::ToggleFocus", // overrides `cmd-0` zoom reset - "shift-escape": null, // Unmap workspace::zoom "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -109,21 +108,6 @@ "cmd-7": "outline_panel::ToggleFocus" } }, - { - "context": "Pane", // this is to override the default Pane mappings to switch tabs - "bindings": { - "cmd-1": "project_panel::ToggleFocus", - "cmd-2": null, // Bookmarks (left dock) - "cmd-3": null, // Find Panel (bottom dock) - "cmd-4": null, // Run Panel (bottom dock) - "cmd-5": "debug_panel::ToggleFocus", - "cmd-6": "diagnostics::Deploy", - "cmd-7": "outline_panel::ToggleFocus", - "cmd-8": null, // Services (bottom dock) - "cmd-9": null, // Git History (bottom dock) - "cmd-0": "git_panel::ToggleFocus" - } - }, { "context": "Workspace || Editor", "bindings": { @@ -162,15 +146,11 @@ } }, { "context": "GitPanel", "bindings": { "cmd-0": "workspace::CloseActiveDock" } }, - { "context": "ProjectPanel", "bindings": { "cmd-1": "workspace::CloseActiveDock" } }, { "context": "DebugPanel", "bindings": { "cmd-5": "workspace::CloseActiveDock" } }, { "context": "Diagnostics > Editor", "bindings": { "cmd-6": "pane::CloseActiveItem" } }, { "context": "OutlinePanel", "bindings": { "cmd-7": "workspace::CloseActiveDock" } }, { "context": "Dock || Workspace || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", - "bindings": { - "escape": "editor::ToggleFocus", - "shift-escape": "workspace::CloseActiveDock" - } + "bindings": { "escape": "editor::ToggleFocus" } } ] diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index 6458ac1510..d0cf4621a5 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -220,8 +220,6 @@ { "context": "vim_mode == normal", "bindings": { - "i": "vim::InsertBefore", - "a": "vim::InsertAfter", "ctrl-[": "editor::Cancel", ":": "command_palette::Toggle", "c": "vim::PushChange", @@ -355,7 +353,9 @@ "shift-d": "vim::DeleteToEndOfLine", "shift-j": "vim::JoinLines", "shift-y": "vim::YankLine", + "i": "vim::InsertBefore", "shift-i": "vim::InsertFirstNonWhitespace", + "a": "vim::InsertAfter", "shift-a": "vim::InsertEndOfLine", "o": "vim::InsertLineBelow", "shift-o": "vim::InsertLineAbove", @@ -377,8 +377,6 @@ { "context": "vim_mode == helix_normal && !menu", "bindings": { - "i": "vim::HelixInsert", - "a": "vim::HelixAppend", "ctrl-[": "editor::Cancel", ";": "vim::HelixCollapseSelection", ":": "command_palette::Toggle", diff --git a/assets/settings/default.json b/assets/settings/default.json index 4734b5d118..dab1684aef 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -691,10 +691,7 @@ // 5. Never show the scrollbar: // "never" "show": null - }, - // Default depth to expand outline items in the current file. - // Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper. - "expand_outlines_with_depth": 100 + } }, "collaboration_panel": { // Whether to show the collaboration panel button in the status bar. @@ -1877,25 +1874,5 @@ "save_breakpoints": true, "dock": "bottom", "button": true - }, - // Configures any number of settings profiles that are temporarily applied on - // top of your existing user settings when selected from - // `settings profile selector: toggle`. - // Examples: - // "profiles": { - // "Presenting": { - // "agent_font_size": 20.0, - // "buffer_font_size": 20.0, - // "theme": "One Light", - // "ui_font_size": 20.0 - // }, - // "Python (ty)": { - // "languages": { - // "Python": { - // "language_servers": ["ty"] - // } - // } - // } - // } - "profiles": [] + } } diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 225597415c..b44c25ccc9 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -16,7 +16,7 @@ doctest = false test-support = ["gpui/test-support", "project/test-support"] [dependencies] -agent-client-protocol.workspace = true +agentic-coding-protocol.workspace = true anyhow.workspace = true assistant_tool.workspace = true buffer_diff.workspace = true @@ -36,12 +36,11 @@ util.workspace = true workspace-hack.workspace = true [dev-dependencies] +async-pipe.workspace = true env_logger.workspace = true gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true -parking_lot.workspace = true project = { workspace = true, "features" = ["test-support"] } -rand.workspace = true tempfile.workspace = true util.workspace = true settings.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 079a207358..9af1eeb187 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,13 +1,17 @@ mod connection; pub use connection::*; -use agent_client_protocol as acp; +pub use acp::ToolCallId; +use agentic_coding_protocol::{ + self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation, + UserMessageChunk, +}; use anyhow::{Context as _, Result}; use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use editor::{Bias, MultiBuffer, PathKey}; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; -use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; +use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use itertools::Itertools; use language::{ Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, @@ -17,37 +21,46 @@ use markdown::Markdown; use project::{AgentLocation, Project}; use std::collections::HashMap; use std::error::Error; -use std::fmt::Formatter; -use std::rc::Rc; +use std::fmt::{Formatter, Write}; use std::{ fmt::Display, mem, path::{Path, PathBuf}, sync::Arc, }; -use ui::App; +use ui::{App, IconName}; use util::ResultExt; -#[derive(Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct UserMessage { - pub content: ContentBlock, + pub content: Entity, } impl UserMessage { pub fn from_acp( - message: impl IntoIterator, + message: &acp::SendUserMessageParams, language_registry: Arc, cx: &mut App, ) -> Self { - let mut content = ContentBlock::Empty; - for chunk in message { - content.append(chunk, &language_registry, cx) + let mut md_source = String::new(); + + for chunk in &message.chunks { + match chunk { + UserMessageChunk::Text { text } => md_source.push_str(&text), + UserMessageChunk::Path { path } => { + write!(&mut md_source, "{}", MentionPath(&path)).unwrap() + } + } + } + + Self { + content: cx + .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)), } - Self { content: content } } fn to_markdown(&self, cx: &App) -> String { - format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) + format!("## User\n\n{}\n\n", self.content.read(cx).source()) } } @@ -83,7 +96,7 @@ impl Display for MentionPath<'_> { } } -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct AssistantMessage { pub chunks: Vec, } @@ -100,24 +113,42 @@ impl AssistantMessage { } } -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum AssistantMessageChunk { - Message { block: ContentBlock }, - Thought { block: ContentBlock }, + Text { chunk: Entity }, + Thought { chunk: Entity }, } impl AssistantMessageChunk { - pub fn from_str(chunk: &str, language_registry: &Arc, cx: &mut App) -> Self { - Self::Message { - block: ContentBlock::new(chunk.into(), language_registry, cx), + pub fn from_acp( + chunk: acp::AssistantMessageChunk, + language_registry: Arc, + cx: &mut App, + ) -> Self { + match chunk { + acp::AssistantMessageChunk::Text { text } => Self::Text { + chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)), + }, + acp::AssistantMessageChunk::Thought { thought } => Self::Thought { + chunk: cx + .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)), + }, + } + } + + pub fn from_str(chunk: &str, language_registry: Arc, cx: &mut App) -> Self { + Self::Text { + chunk: cx.new(|cx| { + Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx) + }), } } fn to_markdown(&self, cx: &App) -> String { match self { - Self::Message { block } => block.to_markdown(cx).to_string(), - Self::Thought { block } => { - format!("\n{}\n", block.to_markdown(cx)) + Self::Text { chunk } => chunk.read(cx).source().to_string(), + Self::Thought { chunk } => { + format!("\n{}\n", chunk.read(cx).source()) } } } @@ -135,15 +166,19 @@ impl AgentThreadEntry { match self { Self::UserMessage(message) => message.to_markdown(cx), Self::AssistantMessage(message) => message.to_markdown(cx), - Self::ToolCall(tool_call) => tool_call.to_markdown(cx), + Self::ToolCall(too_call) => too_call.to_markdown(cx), } } - pub fn diffs(&self) -> impl Iterator { - if let AgentThreadEntry::ToolCall(call) = self { - itertools::Either::Left(call.diffs()) + pub fn diff(&self) -> Option<&Diff> { + if let AgentThreadEntry::ToolCall(ToolCall { + content: Some(ToolCallContent::Diff { diff }), + .. + }) = self + { + Some(&diff) } else { - itertools::Either::Right(std::iter::empty()) + None } } @@ -160,99 +195,20 @@ impl AgentThreadEntry { pub struct ToolCall { pub id: acp::ToolCallId, pub label: Entity, - pub kind: acp::ToolKind, - pub content: Vec, + pub icon: IconName, + pub content: Option, pub status: ToolCallStatus, pub locations: Vec, - pub raw_input: Option, } impl ToolCall { - fn from_acp( - tool_call: acp::ToolCall, - status: ToolCallStatus, - language_registry: Arc, - cx: &mut App, - ) -> Self { - Self { - id: tool_call.id, - label: cx.new(|cx| { - Markdown::new( - tool_call.label.into(), - Some(language_registry.clone()), - None, - cx, - ) - }), - kind: tool_call.kind, - content: tool_call - .content - .into_iter() - .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx)) - .collect(), - locations: tool_call.locations, - status, - raw_input: tool_call.raw_input, - } - } - - fn update( - &mut self, - fields: acp::ToolCallUpdateFields, - language_registry: Arc, - cx: &mut App, - ) { - let acp::ToolCallUpdateFields { - kind, - status, - label, - content, - locations, - raw_input, - } = fields; - - if let Some(kind) = kind { - self.kind = kind; - } - - if let Some(status) = status { - self.status = ToolCallStatus::Allowed { status }; - } - - if let Some(label) = label { - self.label = cx.new(|cx| Markdown::new_text(label.into(), cx)); - } - - if let Some(content) = content { - self.content = content - .into_iter() - .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx)) - .collect(); - } - - if let Some(locations) = locations { - self.locations = locations; - } - - if let Some(raw_input) = raw_input { - self.raw_input = Some(raw_input); - } - } - - pub fn diffs(&self) -> impl Iterator { - self.content.iter().filter_map(|content| match content { - ToolCallContent::ContentBlock { .. } => None, - ToolCallContent::Diff { diff } => Some(diff), - }) - } - fn to_markdown(&self, cx: &App) -> String { let mut markdown = format!( "**Tool Call: {}**\nStatus: {}\n\n", self.label.read(cx).source(), self.status ); - for content in &self.content { + if let Some(content) = &self.content { markdown.push_str(content.to_markdown(cx).as_str()); markdown.push_str("\n\n"); } @@ -263,8 +219,8 @@ impl ToolCall { #[derive(Debug)] pub enum ToolCallStatus { WaitingForConfirmation { - options: Vec, - respond_tx: oneshot::Sender, + confirmation: ToolCallConfirmation, + respond_tx: oneshot::Sender, }, Allowed { status: acp::ToolCallStatus, @@ -281,10 +237,9 @@ impl Display for ToolCallStatus { match self { ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation", ToolCallStatus::Allowed { status } => match status { - acp::ToolCallStatus::Pending => "Pending", - acp::ToolCallStatus::InProgress => "In Progress", - acp::ToolCallStatus::Completed => "Completed", - acp::ToolCallStatus::Failed => "Failed", + acp::ToolCallStatus::Running => "Running", + acp::ToolCallStatus::Finished => "Finished", + acp::ToolCallStatus::Error => "Error", }, ToolCallStatus::Rejected => "Rejected", ToolCallStatus::Canceled => "Canceled", @@ -293,92 +248,86 @@ impl Display for ToolCallStatus { } } -#[derive(Debug, PartialEq, Clone)] -pub enum ContentBlock { - Empty, - Markdown { markdown: Entity }, +#[derive(Debug)] +pub enum ToolCallConfirmation { + Edit { + description: Option>, + }, + Execute { + command: String, + root_command: String, + description: Option>, + }, + Mcp { + server_name: String, + tool_name: String, + tool_display_name: String, + description: Option>, + }, + Fetch { + urls: Vec, + description: Option>, + }, + Other { + description: Entity, + }, } -impl ContentBlock { - pub fn new( - block: acp::ContentBlock, - language_registry: &Arc, - cx: &mut App, - ) -> Self { - let mut this = Self::Empty; - this.append(block, language_registry, cx); - this - } - - pub fn new_combined( - blocks: impl IntoIterator, +impl ToolCallConfirmation { + pub fn from_acp( + confirmation: acp::ToolCallConfirmation, language_registry: Arc, cx: &mut App, ) -> Self { - let mut this = Self::Empty; - for block in blocks { - this.append(block, &language_registry, cx); - } - this - } - - pub fn append( - &mut self, - block: acp::ContentBlock, - language_registry: &Arc, - cx: &mut App, - ) { - let new_content = match block { - acp::ContentBlock::Text(text_content) => text_content.text.clone(), - acp::ContentBlock::ResourceLink(resource_link) => { - if let Some(path) = resource_link.uri.strip_prefix("file://") { - format!("{}", MentionPath(path.as_ref())) - } else { - resource_link.uri.clone() - } - } - acp::ContentBlock::Image(_) - | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => String::new(), + let to_md = |description: String, cx: &mut App| -> Entity { + cx.new(|cx| { + Markdown::new( + description.into(), + Some(language_registry.clone()), + None, + cx, + ) + }) }; - match self { - ContentBlock::Empty => { - *self = ContentBlock::Markdown { - markdown: cx.new(|cx| { - Markdown::new( - new_content.into(), - Some(language_registry.clone()), - None, - cx, - ) - }), - }; - } - ContentBlock::Markdown { markdown } => { - markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx)); - } - } - } - - fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str { - match self { - ContentBlock::Empty => "", - ContentBlock::Markdown { markdown } => markdown.read(cx).source(), - } - } - - pub fn markdown(&self) -> Option<&Entity> { - match self { - ContentBlock::Empty => None, - ContentBlock::Markdown { markdown } => Some(markdown), + match confirmation { + acp::ToolCallConfirmation::Edit { description } => Self::Edit { + description: description.map(|description| to_md(description, cx)), + }, + acp::ToolCallConfirmation::Execute { + command, + root_command, + description, + } => Self::Execute { + command, + root_command, + description: description.map(|description| to_md(description, cx)), + }, + acp::ToolCallConfirmation::Mcp { + server_name, + tool_name, + tool_display_name, + description, + } => Self::Mcp { + server_name, + tool_name, + tool_display_name, + description: description.map(|description| to_md(description, cx)), + }, + acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch { + urls: urls.iter().map(|url| url.into()).collect(), + description: description.map(|description| to_md(description, cx)), + }, + acp::ToolCallConfirmation::Other { description } => Self::Other { + description: to_md(description, cx), + }, } } } #[derive(Debug)] pub enum ToolCallContent { - ContentBlock { content: ContentBlock }, + Markdown { markdown: Entity }, Diff { diff: Diff }, } @@ -389,8 +338,8 @@ impl ToolCallContent { cx: &mut App, ) -> Self { match content { - acp::ToolCallContent::Content { content } => Self::ContentBlock { - content: ContentBlock::new(content, &language_registry, cx), + acp::ToolCallContent::Markdown { markdown } => Self::Markdown { + markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)), }, acp::ToolCallContent::Diff { diff } => Self::Diff { diff: Diff::from_acp(diff, language_registry, cx), @@ -398,9 +347,9 @@ impl ToolCallContent { } } - pub fn to_markdown(&self, cx: &App) -> String { + fn to_markdown(&self, cx: &App) -> String { match self { - Self::ContentBlock { content } => content.to_markdown(cx).to_string(), + Self::Markdown { markdown } => markdown.read(cx).source().to_string(), Self::Diff { diff } => diff.to_markdown(cx), } } @@ -571,16 +520,13 @@ pub struct AcpThread { action_log: Entity, shared_buffers: HashMap, BufferSnapshot>, send_task: Option>, - connection: Rc, - session_id: acp::SessionId, + connection: Arc, + child_status: Option>>, } pub enum AcpThreadEvent { NewEntry, EntryUpdated(usize), - ToolAuthorizationRequired, - Stopped, - Error, } impl EventEmitter for AcpThread {} @@ -617,10 +563,10 @@ impl Error for LoadError {} impl AcpThread { pub fn new( - title: impl Into, - connection: Rc, + connection: impl AgentConnection + 'static, + title: SharedString, + child_status: Option>>, project: Entity, - session_id: acp::SessionId, cx: &mut Context, ) -> Self { let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -630,11 +576,24 @@ impl AcpThread { shared_buffers: Default::default(), entries: Default::default(), plan: Default::default(), - title: title.into(), + title, project, send_task: None, - connection, - session_id, + connection: Arc::new(connection), + child_status, + } + } + + /// Send a request to the agent and wait for a response. + pub fn request( + &self, + params: R, + ) -> impl use + Future> { + let params = params.into_any(); + let result = self.connection.request_any(params); + async move { + let result = result.await?; + Ok(R::response_from_any(result)?) } } @@ -670,18 +629,15 @@ impl AcpThread { for entry in self.entries.iter().rev() { match entry { AgentThreadEntry::UserMessage(_) => return false, - AgentThreadEntry::ToolCall( - call @ ToolCall { - status: - ToolCallStatus::Allowed { - status: - acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending, - }, - .. - }, - ) if call.diffs().next().is_some() => { - return true; - } + AgentThreadEntry::ToolCall(ToolCall { + status: + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Running, + .. + }, + content: Some(ToolCallContent::Diff { .. }), + .. + }) => return true, AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {} } } @@ -689,94 +645,49 @@ impl AcpThread { false } - pub fn used_tools_since_last_user_message(&self) -> bool { - for entry in self.entries.iter().rev() { - match entry { - AgentThreadEntry::UserMessage(..) => return false, - AgentThreadEntry::AssistantMessage(..) => continue, - AgentThreadEntry::ToolCall(..) => return true, - } - } - - false + pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { + self.entries.push(entry); + cx.emit(AcpThreadEvent::NewEntry); } - pub fn handle_session_update( + pub fn push_assistant_chunk( &mut self, - update: acp::SessionUpdate, - cx: &mut Context, - ) -> Result<()> { - match update { - acp::SessionUpdate::UserMessageChunk { content } => { - self.push_user_content_block(content, cx); - } - acp::SessionUpdate::AgentMessageChunk { content } => { - self.push_assistant_content_block(content, false, cx); - } - acp::SessionUpdate::AgentThoughtChunk { content } => { - self.push_assistant_content_block(content, true, cx); - } - acp::SessionUpdate::ToolCall(tool_call) => { - self.upsert_tool_call(tool_call, cx); - } - acp::SessionUpdate::ToolCallUpdate(tool_call_update) => { - self.update_tool_call(tool_call_update, cx)?; - } - acp::SessionUpdate::Plan(plan) => { - self.update_plan(plan, cx); - } - } - Ok(()) - } - - pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context) { - let language_registry = self.project.read(cx).languages().clone(); - let entries_len = self.entries.len(); - - if let Some(last_entry) = self.entries.last_mut() - && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry - { - content.append(chunk, &language_registry, cx); - cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); - } else { - let content = ContentBlock::new(chunk, &language_registry, cx); - self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx); - } - } - - pub fn push_assistant_content_block( - &mut self, - chunk: acp::ContentBlock, - is_thought: bool, + chunk: acp::AssistantMessageChunk, cx: &mut Context, ) { - let language_registry = self.project.read(cx).languages().clone(); let entries_len = self.entries.len(); if let Some(last_entry) = self.entries.last_mut() && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry { cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); - match (chunks.last_mut(), is_thought) { - (Some(AssistantMessageChunk::Message { block }), false) - | (Some(AssistantMessageChunk::Thought { block }), true) => { - block.append(chunk, &language_registry, cx) + + match (chunks.last_mut(), &chunk) { + ( + Some(AssistantMessageChunk::Text { chunk: old_chunk }), + acp::AssistantMessageChunk::Text { text: new_chunk }, + ) + | ( + Some(AssistantMessageChunk::Thought { chunk: old_chunk }), + acp::AssistantMessageChunk::Thought { thought: new_chunk }, + ) => { + old_chunk.update(cx, |old_chunk, cx| { + old_chunk.append(&new_chunk, cx); + }); } _ => { - let block = ContentBlock::new(chunk, &language_registry, cx); - if is_thought { - chunks.push(AssistantMessageChunk::Thought { block }) - } else { - chunks.push(AssistantMessageChunk::Message { block }) - } + chunks.push(AssistantMessageChunk::from_acp( + chunk, + self.project.read(cx).languages().clone(), + cx, + )); } } } else { - let block = ContentBlock::new(chunk, &language_registry, cx); - let chunk = if is_thought { - AssistantMessageChunk::Thought { block } - } else { - AssistantMessageChunk::Message { block } - }; + let chunk = AssistantMessageChunk::from_acp( + chunk, + self.project.read(cx).languages().clone(), + cx, + ); self.push_entry( AgentThreadEntry::AssistantMessage(AssistantMessage { @@ -787,79 +698,212 @@ impl AcpThread { } } - fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { - self.entries.push(entry); - cx.emit(AcpThreadEvent::NewEntry); + pub fn request_new_tool_call( + &mut self, + tool_call: acp::RequestToolCallConfirmationParams, + cx: &mut Context, + ) -> ToolCallRequest { + let (tx, rx) = oneshot::channel(); + + let status = ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::from_acp( + tool_call.confirmation, + self.project.read(cx).languages().clone(), + cx, + ), + respond_tx: tx, + }; + + let id = self.insert_tool_call(tool_call.tool_call, status, cx); + ToolCallRequest { id, outcome: rx } + } + + pub fn request_tool_call_confirmation( + &mut self, + tool_call_id: ToolCallId, + confirmation: acp::ToolCallConfirmation, + cx: &mut Context, + ) -> Result { + let project = self.project.read(cx).languages().clone(); + let Some((idx, call)) = self.tool_call_mut(tool_call_id) else { + anyhow::bail!("Tool call not found"); + }; + + let (tx, rx) = oneshot::channel(); + + call.status = ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx), + respond_tx: tx, + }; + + cx.emit(AcpThreadEvent::EntryUpdated(idx)); + + Ok(ToolCallRequest { + id: tool_call_id, + outcome: rx, + }) + } + + pub fn push_tool_call( + &mut self, + request: acp::PushToolCallParams, + cx: &mut Context, + ) -> acp::ToolCallId { + let status = ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Running, + }; + + self.insert_tool_call(request, status, cx) + } + + fn insert_tool_call( + &mut self, + tool_call: acp::PushToolCallParams, + status: ToolCallStatus, + cx: &mut Context, + ) -> acp::ToolCallId { + let language_registry = self.project.read(cx).languages().clone(); + let id = acp::ToolCallId(self.entries.len() as u64); + let call = ToolCall { + id, + label: cx.new(|cx| { + Markdown::new( + tool_call.label.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + icon: acp_icon_to_ui_icon(tool_call.icon), + content: tool_call + .content + .map(|content| ToolCallContent::from_acp(content, language_registry, cx)), + locations: tool_call.locations, + status, + }; + + let location = call.locations.last().cloned(); + if let Some(location) = location { + self.set_project_location(location, cx) + } + + self.push_entry(AgentThreadEntry::ToolCall(call), cx); + + id + } + + pub fn authorize_tool_call( + &mut self, + id: acp::ToolCallId, + outcome: acp::ToolCallConfirmationOutcome, + cx: &mut Context, + ) { + let Some((ix, call)) = self.tool_call_mut(id) else { + return; + }; + + let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject { + ToolCallStatus::Rejected + } else { + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Running, + } + }; + + let curr_status = mem::replace(&mut call.status, new_status); + + if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status { + respond_tx.send(outcome).log_err(); + } else if cfg!(debug_assertions) { + panic!("tried to authorize an already authorized tool call"); + } + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); } pub fn update_tool_call( &mut self, - update: acp::ToolCallUpdate, + id: acp::ToolCallId, + new_status: acp::ToolCallStatus, + new_content: Option, cx: &mut Context, ) -> Result<()> { - let languages = self.project.read(cx).languages().clone(); - - let (ix, current_call) = self - .tool_call_mut(&update.id) - .context("Tool call not found")?; - current_call.update(update.fields, languages, cx); - - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - - Ok(()) - } - - /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. - pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { - let status = ToolCallStatus::Allowed { - status: tool_call.status, - }; - self.upsert_tool_call_inner(tool_call, status, cx) - } - - pub fn upsert_tool_call_inner( - &mut self, - tool_call: acp::ToolCall, - status: ToolCallStatus, - cx: &mut Context, - ) { let language_registry = self.project.read(cx).languages().clone(); - let call = ToolCall::from_acp(tool_call, status, language_registry, cx); + let (ix, call) = self.tool_call_mut(id).context("Entry not found")?; - let location = call.locations.last().cloned(); - - if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { - *current_call = call; - - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - } else { - self.push_entry(AgentThreadEntry::ToolCall(call), cx); + if let Some(new_content) = new_content { + call.content = Some(ToolCallContent::from_acp( + new_content, + language_registry, + cx, + )); } + match &mut call.status { + ToolCallStatus::Allowed { status } => { + *status = new_status; + } + ToolCallStatus::WaitingForConfirmation { .. } => { + anyhow::bail!("Tool call hasn't been authorized yet") + } + ToolCallStatus::Rejected => { + anyhow::bail!("Tool call was rejected and therefore can't be updated") + } + ToolCallStatus::Canceled => { + call.status = ToolCallStatus::Allowed { status: new_status }; + } + } + + let location = call.locations.last().cloned(); if let Some(location) = location { self.set_project_location(location, cx) } + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + Ok(()) } - fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { - // The tool call we are looking for is typically the last one, or very close to the end. - // At the moment, it doesn't seem like a hashmap would be a good fit for this use case. - self.entries - .iter_mut() - .enumerate() - .rev() - .find_map(|(index, tool_call)| { - if let AgentThreadEntry::ToolCall(tool_call) = tool_call - && &tool_call.id == id - { - Some((index, tool_call)) - } else { - None + fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { + let entry = self.entries.get_mut(id.0 as usize); + debug_assert!( + entry.is_some(), + "We shouldn't give out ids to entries that don't exist" + ); + match entry { + Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)), + _ => { + if cfg!(debug_assertions) { + panic!("entry is not a tool call"); } - }) + None + } + } } - pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { + pub fn plan(&self) -> &Plan { + &self.plan + } + + pub fn update_plan(&mut self, request: acp::UpdatePlanParams, cx: &mut Context) { + self.plan = Plan { + entries: request + .entries + .into_iter() + .map(|entry| PlanEntry::from_acp(entry, cx)) + .collect(), + }; + + cx.notify(); + } + + pub fn clear_completed_plan_entries(&mut self, cx: &mut Context) { + self.plan + .entries + .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed)); + cx.notify(); + } + + pub fn set_project_location(&self, location: ToolCallLocation, cx: &mut Context) { self.project.update(cx, |project, cx| { let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { return; @@ -890,57 +934,6 @@ impl AcpThread { }); } - pub fn request_tool_call_permission( - &mut self, - tool_call: acp::ToolCall, - options: Vec, - cx: &mut Context, - ) -> oneshot::Receiver { - let (tx, rx) = oneshot::channel(); - - let status = ToolCallStatus::WaitingForConfirmation { - options, - respond_tx: tx, - }; - - self.upsert_tool_call_inner(tool_call, status, cx); - cx.emit(AcpThreadEvent::ToolAuthorizationRequired); - rx - } - - pub fn authorize_tool_call( - &mut self, - id: acp::ToolCallId, - option_id: acp::PermissionOptionId, - option_kind: acp::PermissionOptionKind, - cx: &mut Context, - ) { - let Some((ix, call)) = self.tool_call_mut(&id) else { - return; - }; - - let new_status = match option_kind { - acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => { - ToolCallStatus::Rejected - } - acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => { - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress, - } - } - }; - - let curr_status = mem::replace(&mut call.status, new_status); - - if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status { - respond_tx.send(option_id).log_err(); - } else if cfg!(debug_assertions) { - panic!("tried to authorize an already authorized tool call"); - } - - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - } - /// Returns true if the last turn is awaiting tool authorization pub fn waiting_for_tool_confirmation(&self) -> bool { for entry in self.entries.iter().rev() { @@ -960,27 +953,14 @@ impl AcpThread { false } - pub fn plan(&self) -> &Plan { - &self.plan + pub fn initialize(&self) -> impl use<> + Future> { + self.request(acp::InitializeParams { + protocol_version: ProtocolVersion::latest(), + }) } - pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context) { - self.plan = Plan { - entries: request - .entries - .into_iter() - .map(|entry| PlanEntry::from_acp(entry, cx)) - .collect(), - }; - - cx.notify(); - } - - fn clear_completed_plan_entries(&mut self, cx: &mut Context) { - self.plan - .entries - .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed)); - cx.notify(); + pub fn authenticate(&self) -> impl use<> + Future> { + self.request(acp::AuthenticateParams) } #[cfg(any(test, feature = "test-support"))] @@ -988,50 +968,39 @@ impl AcpThread { &mut self, message: &str, cx: &mut Context, - ) -> BoxFuture<'static, Result<()>> { + ) -> BoxFuture<'static, Result<(), acp::Error>> { self.send( - vec![acp::ContentBlock::Text(acp::TextContent { - text: message.to_string(), - annotations: None, - })], + acp::SendUserMessageParams { + chunks: vec![acp::UserMessageChunk::Text { + text: message.to_string(), + }], + }, cx, ) } pub fn send( &mut self, - message: Vec, + message: acp::SendUserMessageParams, cx: &mut Context, - ) -> BoxFuture<'static, Result<()>> { - let block = ContentBlock::new_combined( - message.clone(), - self.project.read(cx).languages().clone(), - cx, - ); + ) -> BoxFuture<'static, Result<(), acp::Error>> { self.push_entry( - AgentThreadEntry::UserMessage(UserMessage { content: block }), + AgentThreadEntry::UserMessage(UserMessage::from_acp( + &message, + self.project.read(cx).languages().clone(), + cx, + )), cx, ); - self.clear_completed_plan_entries(cx); let (tx, rx) = oneshot::channel(); - let cancel_task = self.cancel(cx); + let cancel = self.cancel(cx); self.send_task = Some(cx.spawn(async move |this, cx| { async { - cancel_task.await; + cancel.await.log_err(); - let result = this - .update(cx, |this, cx| { - this.connection.prompt( - acp::PromptRequest { - prompt: message, - session_id: this.session_id.clone(), - }, - cx, - ) - })? - .await; + let result = this.update(cx, |this, _| this.request(message))?.await; tx.send(result).log_err(); this.update(cx, |this, _cx| this.send_task.take())?; anyhow::Ok(()) @@ -1040,53 +1009,57 @@ impl AcpThread { .log_err(); })); - cx.spawn(async move |this, cx| match rx.await { - Ok(Err(e)) => { - this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error)) - .log_err(); - Err(e)? + async move { + match rx.await { + Ok(Err(e)) => Err(e)?, + _ => Ok(()), } - _ => { - this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) - .log_err(); - Ok(()) - } - }) + } .boxed() } - pub fn cancel(&mut self, cx: &mut Context) -> Task<()> { - let Some(send_task) = self.send_task.take() else { - return Task::ready(()); - }; + pub fn cancel(&mut self, cx: &mut Context) -> Task> { + if self.send_task.take().is_some() { + let request = self.request(acp::CancelSendMessageParams); + cx.spawn(async move |this, cx| { + request.await?; + this.update(cx, |this, _cx| { + for entry in this.entries.iter_mut() { + if let AgentThreadEntry::ToolCall(call) = entry { + let cancel = matches!( + call.status, + ToolCallStatus::WaitingForConfirmation { .. } + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Running + } + ); - for entry in self.entries.iter_mut() { - if let AgentThreadEntry::ToolCall(call) = entry { - let cancel = matches!( - call.status, - ToolCallStatus::WaitingForConfirmation { .. } - | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress + if cancel { + let curr_status = + mem::replace(&mut call.status, ToolCallStatus::Canceled); + + if let ToolCallStatus::WaitingForConfirmation { + respond_tx, .. + } = curr_status + { + respond_tx + .send(acp::ToolCallConfirmationOutcome::Cancel) + .ok(); + } + } } - ); - - if cancel { - call.status = ToolCallStatus::Canceled; - } - } + } + })?; + Ok(()) + }) + } else { + Task::ready(Ok(())) } - - self.connection.cancel(&self.session_id, cx); - - // Wait for the send task to complete - cx.foreground_executor().spawn(send_task) } pub fn read_text_file( &self, - path: PathBuf, - line: Option, - limit: Option, + request: acp::ReadTextFileParams, reuse_shared_snapshot: bool, cx: &mut Context, ) -> Task> { @@ -1095,7 +1068,7 @@ impl AcpThread { cx.spawn(async move |this, cx| { let load = project.update(cx, |project, cx| { let path = project - .project_path_for_absolute_path(&path, cx) + .project_path_for_absolute_path(&request.path, cx) .context("invalid path")?; anyhow::Ok(project.open_buffer(path, cx)) }); @@ -1121,7 +1094,7 @@ impl AcpThread { let position = buffer .read(cx) .snapshot() - .anchor_before(Point::new(line.unwrap_or_default(), 0)); + .anchor_before(Point::new(request.line.unwrap_or_default(), 0)); project.set_agent_location( Some(AgentLocation { buffer: buffer.downgrade(), @@ -1137,11 +1110,11 @@ impl AcpThread { this.update(cx, |this, _| { let text = snapshot.text(); this.shared_buffers.insert(buffer.clone(), snapshot); - if line.is_none() && limit.is_none() { + if request.line.is_none() && request.limit.is_none() { return Ok(text); } - let limit = limit.unwrap_or(u32::MAX) as usize; - let Some(line) = line else { + let limit = request.limit.unwrap_or(u32::MAX) as usize; + let Some(line) = request.line else { return Ok(text.lines().take(limit).collect::()); }; @@ -1226,25 +1199,207 @@ impl AcpThread { }) } + pub fn child_status(&mut self) -> Option>> { + self.child_status.take() + } + pub fn to_markdown(&self, cx: &App) -> String { self.entries.iter().map(|e| e.to_markdown(cx)).collect() } } +#[derive(Clone)] +pub struct AcpClientDelegate { + thread: WeakEntity, + cx: AsyncApp, + // sent_buffer_versions: HashMap, HashMap>, +} + +impl AcpClientDelegate { + pub fn new(thread: WeakEntity, cx: AsyncApp) -> Self { + Self { thread, cx } + } + + pub async fn clear_completed_plan_entries(&self) -> Result<()> { + let cx = &mut self.cx.clone(); + cx.update(|cx| { + self.thread + .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx)) + })? + .context("Failed to update thread")?; + + Ok(()) + } + + pub async fn request_existing_tool_call_confirmation( + &self, + tool_call_id: ToolCallId, + confirmation: acp::ToolCallConfirmation, + ) -> Result { + let cx = &mut self.cx.clone(); + let ToolCallRequest { outcome, .. } = cx + .update(|cx| { + self.thread.update(cx, |thread, cx| { + thread.request_tool_call_confirmation(tool_call_id, confirmation, cx) + }) + })? + .context("Failed to update thread")??; + + Ok(outcome.await?) + } + + pub async fn read_text_file_reusing_snapshot( + &self, + request: acp::ReadTextFileParams, + ) -> Result { + let content = self + .cx + .update(|cx| { + self.thread + .update(cx, |thread, cx| thread.read_text_file(request, true, cx)) + })? + .context("Failed to update thread")? + .await?; + Ok(acp::ReadTextFileResponse { content }) + } +} + +impl acp::Client for AcpClientDelegate { + async fn stream_assistant_message_chunk( + &self, + params: acp::StreamAssistantMessageChunkParams, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread + .update(cx, |thread, cx| { + thread.push_assistant_chunk(params.chunk, cx) + }) + .ok(); + })?; + + Ok(()) + } + + async fn request_tool_call_confirmation( + &self, + request: acp::RequestToolCallConfirmationParams, + ) -> Result { + let cx = &mut self.cx.clone(); + let ToolCallRequest { id, outcome } = cx + .update(|cx| { + self.thread + .update(cx, |thread, cx| thread.request_new_tool_call(request, cx)) + })? + .context("Failed to update thread")?; + + Ok(acp::RequestToolCallConfirmationResponse { + id, + outcome: outcome.await.map_err(acp::Error::into_internal_error)?, + }) + } + + async fn push_tool_call( + &self, + request: acp::PushToolCallParams, + ) -> Result { + let cx = &mut self.cx.clone(); + let id = cx + .update(|cx| { + self.thread + .update(cx, |thread, cx| thread.push_tool_call(request, cx)) + })? + .context("Failed to update thread")?; + + Ok(acp::PushToolCallResponse { id }) + } + + async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.update(cx, |thread, cx| { + thread.update_tool_call(request.tool_call_id, request.status, request.content, cx) + }) + })? + .context("Failed to update thread")??; + + Ok(()) + } + + async fn update_plan(&self, request: acp::UpdatePlanParams) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread + .update(cx, |thread, cx| thread.update_plan(request, cx)) + })? + .context("Failed to update thread")?; + + Ok(()) + } + + async fn read_text_file( + &self, + request: acp::ReadTextFileParams, + ) -> Result { + let content = self + .cx + .update(|cx| { + self.thread + .update(cx, |thread, cx| thread.read_text_file(request, false, cx)) + })? + .context("Failed to update thread")? + .await?; + Ok(acp::ReadTextFileResponse { content }) + } + + async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> { + self.cx + .update(|cx| { + self.thread.update(cx, |thread, cx| { + thread.write_text_file(request.path, request.content, cx) + }) + })? + .context("Failed to update thread")? + .await?; + + Ok(()) + } +} + +fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName { + match icon { + acp::Icon::FileSearch => IconName::ToolSearch, + acp::Icon::Folder => IconName::ToolFolder, + acp::Icon::Globe => IconName::ToolWeb, + acp::Icon::Hammer => IconName::ToolHammer, + acp::Icon::LightBulb => IconName::ToolBulb, + acp::Icon::Pencil => IconName::ToolPencil, + acp::Icon::Regex => IconName::ToolRegex, + acp::Icon::Terminal => IconName::ToolTerminal, + } +} + +pub struct ToolCallRequest { + pub id: acp::ToolCallId, + pub outcome: oneshot::Receiver, +} + #[cfg(test)] mod tests { use super::*; use anyhow::anyhow; + use async_pipe::{PipeReader, PipeWriter}; use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext, WeakEntity}; + use gpui::{AsyncApp, TestAppContext}; use indoc::indoc; use project::FakeFs; - use rand::Rng as _; use serde_json::json; use settings::SettingsStore; - use smol::stream::StreamExt as _; + use smol::{future::BoxedLocal, stream::StreamExt as _}; use std::{cell::RefCell, rc::Rc, time::Duration}; - use util::path; fn init_test(cx: &mut TestAppContext) { @@ -1258,133 +1413,39 @@ mod tests { } #[gpui::test] - async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) { + async fn test_thinking_concatenation(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let connection = Rc::new(FakeAgentConnection::new()); - let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) + let (thread, fake_server) = fake_acp_thread(project, cx); + + fake_server.update(cx, |fake_server, _| { + fake_server.on_user_message(move |_, server, mut cx| async move { + server + .update(&mut cx, |server, _| { + server.send_to_zed(acp::StreamAssistantMessageChunkParams { + chunk: acp::AssistantMessageChunk::Thought { + thought: "Thinking ".into(), + }, + }) + })? .await - }) - .await - .unwrap(); - - // Test creating a new user message - thread.update(cx, |thread, cx| { - thread.push_user_content_block( - acp::ContentBlock::Text(acp::TextContent { - annotations: None, - text: "Hello, ".to_string(), - }), - cx, - ); - }); - - thread.update(cx, |thread, cx| { - assert_eq!(thread.entries.len(), 1); - if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { - assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); - } else { - panic!("Expected UserMessage"); - } - }); - - // Test appending to existing user message - thread.update(cx, |thread, cx| { - thread.push_user_content_block( - acp::ContentBlock::Text(acp::TextContent { - annotations: None, - text: "world!".to_string(), - }), - cx, - ); - }); - - thread.update(cx, |thread, cx| { - assert_eq!(thread.entries.len(), 1); - if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { - assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); - } else { - panic!("Expected UserMessage"); - } - }); - - // Test creating new user message after assistant message - thread.update(cx, |thread, cx| { - thread.push_assistant_content_block( - acp::ContentBlock::Text(acp::TextContent { - annotations: None, - text: "Assistant response".to_string(), - }), - false, - cx, - ); - }); - - thread.update(cx, |thread, cx| { - thread.push_user_content_block( - acp::ContentBlock::Text(acp::TextContent { - annotations: None, - text: "New user message".to_string(), - }), - cx, - ); - }); - - thread.update(cx, |thread, cx| { - assert_eq!(thread.entries.len(), 3); - if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { - assert_eq!(user_msg.content.to_markdown(cx), "New user message"); - } else { - panic!("Expected UserMessage at index 2"); - } - }); - } - - #[gpui::test] - async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let connection = Rc::new(FakeAgentConnection::new().on_user_message( - |_, thread, mut cx| { - async move { - thread.update(&mut cx, |thread, cx| { - thread - .handle_session_update( - acp::SessionUpdate::AgentThoughtChunk { - content: "Thinking ".into(), - }, - cx, - ) - .unwrap(); - thread - .handle_session_update( - acp::SessionUpdate::AgentThoughtChunk { - content: "hard!".into(), - }, - cx, - ) - .unwrap(); - }) - } - .boxed_local() - }, - )); - - let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) + .unwrap(); + server + .update(&mut cx, |server, _| { + server.send_to_zed(acp::StreamAssistantMessageChunkParams { + chunk: acp::AssistantMessageChunk::Thought { + thought: "hard!".into(), + }, + }) + })? .await + .unwrap(); + + Ok(()) }) - .await - .unwrap(); + }); thread .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) @@ -1417,38 +1478,7 @@ mod tests { fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"})) .await; let project = Project::test(fs.clone(), [], cx).await; - let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); - let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); - let connection = Rc::new(FakeAgentConnection::new().on_user_message( - move |_, thread, mut cx| { - let read_file_tx = read_file_tx.clone(); - async move { - let content = thread - .update(&mut cx, |thread, cx| { - thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx) - }) - .unwrap() - .await - .unwrap(); - assert_eq!(content, "one\ntwo\nthree\n"); - read_file_tx.take().unwrap().send(()).unwrap(); - thread - .update(&mut cx, |thread, cx| { - thread.write_text_file( - path!("/tmp/foo").into(), - "one\ntwo\nthree\nfour\nfive\n".to_string(), - cx, - ) - }) - .unwrap() - .await - .unwrap(); - Ok(()) - } - .boxed_local() - }, - )); - + let (thread, fake_server) = fake_acp_thread(project.clone(), cx); let (worktree, pathbuf) = project .update(cx, |project, cx| { project.find_or_create_worktree(path!("/tmp/foo"), true, cx) @@ -1462,10 +1492,38 @@ mod tests { .await .unwrap(); - let thread = cx - .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx)) - .await - .unwrap(); + let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); + let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); + + fake_server.update(cx, |fake_server, _| { + fake_server.on_user_message(move |_, server, mut cx| { + let read_file_tx = read_file_tx.clone(); + async move { + let content = server + .update(&mut cx, |server, _| { + server.send_to_zed(acp::ReadTextFileParams { + path: path!("/tmp/foo").into(), + line: None, + limit: None, + }) + })? + .await + .unwrap(); + assert_eq!(content.content, "one\ntwo\nthree\n"); + read_file_tx.take().unwrap().send(()).unwrap(); + server + .update(&mut cx, |server, _| { + server.send_to_zed(acp::WriteTextFileParams { + path: path!("/tmp/foo").into(), + content: "one\ntwo\nthree\nfour\nfive\n".to_string(), + }) + })? + .await + .unwrap(); + Ok(()) + } + }) + }); let request = thread.update(cx, |thread, cx| { thread.send_raw("Extend the count in /tmp/foo", cx) @@ -1492,44 +1550,36 @@ mod tests { let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let id = acp::ToolCallId("test".into()); + let (thread, fake_server) = fake_acp_thread(project, cx); - let connection = Rc::new(FakeAgentConnection::new().on_user_message({ - let id = id.clone(); - move |_, thread, mut cx| { - let id = id.clone(); + let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>(); + + let tool_call_id = Rc::new(RefCell::new(None)); + let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx))); + fake_server.update(cx, |fake_server, _| { + let tool_call_id = tool_call_id.clone(); + fake_server.on_user_message(move |_, server, mut cx| { + let end_turn_rx = end_turn_rx.clone(); + let tool_call_id = tool_call_id.clone(); async move { - thread - .update(&mut cx, |thread, cx| { - thread.handle_session_update( - acp::SessionUpdate::ToolCall(acp::ToolCall { - id: id.clone(), - label: "Label".into(), - kind: acp::ToolKind::Fetch, - status: acp::ToolCallStatus::InProgress, - content: vec![], - locations: vec![], - raw_input: None, - }), - cx, - ) - }) - .unwrap() + let tool_call_result = server + .update(&mut cx, |server, _| { + server.send_to_zed(acp::PushToolCallParams { + label: "Fetch".to_string(), + icon: acp::Icon::Globe, + content: None, + locations: vec![], + }) + })? + .await .unwrap(); + *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id); + end_turn_rx.take().unwrap().await.ok(); + Ok(()) } - .boxed_local() - } - })); - - let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await }) - .await - .unwrap(); + }); let request = thread.update(cx, |thread, cx| { thread.send_raw("Fetch https://example.com", cx) @@ -1542,7 +1592,7 @@ mod tests { thread.entries[1], AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress, + status: acp::ToolCallStatus::Running, .. }, .. @@ -1550,7 +1600,12 @@ mod tests { )); }); - thread.update(cx, |thread, cx| thread.cancel(cx)).await; + cx.run_until_parked(); + + thread + .update(cx, |thread, cx| thread.cancel(cx)) + .await + .unwrap(); thread.read_with(cx, |thread, _| { assert!(matches!( @@ -1562,21 +1617,18 @@ mod tests { )); }); - thread - .update(cx, |thread, cx| { - thread.handle_session_update( - acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate { - id, - fields: acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::Completed), - ..Default::default() - }, - }), - cx, - ) + fake_server + .update(cx, |fake_server, _| { + fake_server.send_to_zed(acp::UpdateToolCallParams { + tool_call_id: tool_call_id.borrow().unwrap(), + status: acp::ToolCallStatus::Finished, + content: None, + }) }) + .await .unwrap(); + drop(end_turn_tx); request.await.unwrap(); thread.read_with(cx, |thread, _| { @@ -1584,7 +1636,7 @@ mod tests { thread.entries[1], AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Completed, + status: acp::ToolCallStatus::Finished, .. }, .. @@ -1593,56 +1645,6 @@ mod tests { }); } - #[gpui::test] - async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) { - init_test(cx); - let fs = FakeFs::new(cx.background_executor.clone()); - fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; - - let connection = Rc::new(FakeAgentConnection::new().on_user_message({ - move |_, thread, mut cx| { - async move { - thread - .update(&mut cx, |thread, cx| { - thread.handle_session_update( - acp::SessionUpdate::ToolCall(acp::ToolCall { - id: acp::ToolCallId("test".into()), - label: "Label".into(), - kind: acp::ToolKind::Edit, - status: acp::ToolCallStatus::Completed, - content: vec![acp::ToolCallContent::Diff { - diff: acp::Diff { - path: "/test/test.txt".into(), - old_text: None, - new_text: "foo".into(), - }, - }], - locations: vec![], - raw_input: None, - }), - cx, - ) - }) - .unwrap() - .unwrap(); - Ok(()) - } - .boxed_local() - } - })); - - let thread = connection - .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) - .await - .unwrap(); - cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) - .await - .unwrap(); - - assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); - } - async fn run_until_first_tool_call( thread: &Entity, cx: &mut TestAppContext, @@ -1670,108 +1672,140 @@ mod tests { } } - #[derive(Clone, Default)] - struct FakeAgentConnection { - auth_methods: Vec, - sessions: Arc>>>, + pub fn fake_acp_thread( + project: Entity, + cx: &mut TestAppContext, + ) -> (Entity, Entity) { + let (stdin_tx, stdin_rx) = async_pipe::pipe(); + let (stdout_tx, stdout_rx) = async_pipe::pipe(); + + let thread = cx.new(|cx| { + let foreground_executor = cx.foreground_executor().clone(); + let (connection, io_fut) = acp::AgentConnection::connect_to_agent( + AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), + stdin_tx, + stdout_rx, + move |fut| { + foreground_executor.spawn(fut).detach(); + }, + ); + + let io_task = cx.background_spawn({ + async move { + io_fut.await.log_err(); + Ok(()) + } + }); + AcpThread::new(connection, "Test".into(), Some(io_task), project, cx) + }); + let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); + (thread, agent) + } + + pub struct FakeAcpServer { + connection: acp::ClientConnection, + + _io_task: Task<()>, on_user_message: Option< Rc< dyn Fn( - acp::PromptRequest, - WeakEntity, - AsyncApp, - ) -> LocalBoxFuture<'static, Result<()>> - + 'static, + acp::SendUserMessageParams, + Entity, + AsyncApp, + ) -> LocalBoxFuture<'static, Result<(), acp::Error>>, >, >, } - impl FakeAgentConnection { - fn new() -> Self { - Self { - auth_methods: Vec::new(), - on_user_message: None, - sessions: Arc::default(), + #[derive(Clone)] + struct FakeAgent { + server: Entity, + cx: AsyncApp, + } + + impl acp::Agent for FakeAgent { + async fn initialize( + &self, + params: acp::InitializeParams, + ) -> Result { + Ok(acp::InitializeResponse { + protocol_version: params.protocol_version, + is_authenticated: true, + }) + } + + async fn authenticate(&self) -> Result<(), acp::Error> { + Ok(()) + } + + async fn cancel_send_message(&self) -> Result<(), acp::Error> { + Ok(()) + } + + async fn send_user_message( + &self, + request: acp::SendUserMessageParams, + ) -> Result<(), acp::Error> { + let mut cx = self.cx.clone(); + let handler = self + .server + .update(&mut cx, |server, _| server.on_user_message.clone()) + .ok() + .flatten(); + if let Some(handler) = handler { + handler(request, self.server.clone(), self.cx.clone()).await + } else { + Err(anyhow::anyhow!("No handler for on_user_message").into()) } } - - #[expect(unused)] - fn with_auth_methods(mut self, auth_methods: Vec) -> Self { - self.auth_methods = auth_methods; - self - } - - fn on_user_message( - mut self, - handler: impl Fn( - acp::PromptRequest, - WeakEntity, - AsyncApp, - ) -> LocalBoxFuture<'static, Result<()>> - + 'static, - ) -> Self { - self.on_user_message.replace(Rc::new(handler)); - self - } } - impl AgentConnection for FakeAgentConnection { - fn auth_methods(&self) -> &[acp::AuthMethod] { - &self.auth_methods - } + impl FakeAcpServer { + fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context) -> Self { + let agent = FakeAgent { + server: cx.entity(), + cx: cx.to_async(), + }; + let foreground_executor = cx.foreground_executor().clone(); - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut gpui::AsyncApp, - ) -> Task>> { - let session_id = acp::SessionId( - rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(7) - .map(char::from) - .collect::() - .into(), + let (connection, io_fut) = acp::ClientConnection::connect_to_client( + agent.clone(), + stdout, + stdin, + move |fut| { + foreground_executor.spawn(fut).detach(); + }, ); - let thread = cx - .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) - .unwrap(); - self.sessions.lock().insert(session_id, thread.downgrade()); - Task::ready(Ok(thread)) - } - - fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task> { - if self.auth_methods().iter().any(|m| m.id == method) { - Task::ready(Ok(())) - } else { - Task::ready(Err(anyhow!("Invalid Auth Method"))) + FakeAcpServer { + connection: connection, + on_user_message: None, + _io_task: cx.background_spawn(async move { + io_fut.await.log_err(); + }), } } - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { - let sessions = self.sessions.lock(); - let thread = sessions.get(¶ms.session_id).unwrap(); - if let Some(handler) = &self.on_user_message { - let handler = handler.clone(); - let thread = thread.clone(); - cx.spawn(async move |cx| handler(params, thread, cx.clone()).await) - } else { - Task::ready(Ok(())) - } + fn on_user_message( + &mut self, + handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity, AsyncApp) -> F + + 'static, + ) where + F: Future> + 'static, + { + self.on_user_message + .replace(Rc::new(move |request, server, cx| { + handler(request, server, cx).boxed_local() + })); } - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { - let sessions = self.sessions.lock(); - let thread = sessions.get(&session_id).unwrap().clone(); - - cx.spawn(async move |cx| { - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .unwrap() - .await - }) - .detach(); + fn send_to_zed( + &self, + message: T, + ) -> BoxedLocal> { + self.connection + .request(message) + .map(|f| f.map_err(|err| anyhow!(err))) + .boxed_local() } } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 929500a67b..7c0ba4f41c 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,36 +1,20 @@ -use std::{error::Error, fmt, path::Path, rc::Rc}; - -use agent_client_protocol::{self as acp}; +use agentic_coding_protocol as acp; use anyhow::Result; -use gpui::{AsyncApp, Entity, Task}; -use project::Project; -use ui::App; - -use crate::AcpThread; +use futures::future::{FutureExt as _, LocalBoxFuture}; pub trait AgentConnection { - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>>; - - fn auth_methods(&self) -> &[acp::AuthMethod]; - - fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; - - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task>; - - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + fn request_any( + &self, + params: acp::AnyAgentRequest, + ) -> LocalBoxFuture<'static, Result>; } -#[derive(Debug)] -pub struct AuthRequired; - -impl Error for AuthRequired {} -impl fmt::Display for AuthRequired { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "AuthRequired") +impl AgentConnection for acp::AgentConnection { + fn request_any( + &self, + params: acp::AnyAgentRequest, + ) -> LocalBoxFuture<'static, Result> { + let task = self.request_any(params); + async move { Ok(task.await?) }.boxed_local() } } diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 7bc0e82cad..135363ab65 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -25,7 +25,6 @@ assistant_context.workspace = true assistant_tool.workspace = true chrono.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true component.workspace = true context_server.workspace = true @@ -36,9 +35,9 @@ futures.workspace = true git.workspace = true gpui.workspace = true heed.workspace = true -http_client.workspace = true icons.workspace = true indoc.workspace = true +http_client.workspace = true itertools.workspace = true language.workspace = true language_model.workspace = true @@ -47,6 +46,7 @@ paths.workspace = true postage.workspace = true project.workspace = true prompt_store.workspace = true +proto.workspace = true ref-cast.workspace = true rope.workspace = true schemars.workspace = true @@ -63,6 +63,7 @@ time.workspace = true util.workspace = true uuid.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true zstd.workspace = true [dev-dependencies] diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index cd366b8308..ddd13de491 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -42,8 +42,8 @@ impl ContextKind { ContextKind::Symbol => IconName::Code, ContextKind::Selection => IconName::Context, ContextKind::FetchedUrl => IconName::Globe, - ContextKind::Thread => IconName::Thread, - ContextKind::TextThread => IconName::TextThread, + ContextKind::Thread => IconName::MessageBubbles, + ContextKind::TextThread => IconName::MessageBubbles, ContextKind::Rules => RULES_ICON, ContextKind::Image => IconName::Image, } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 8558dd528d..1b8aa012a1 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -13,7 +13,6 @@ use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; -use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; use futures::{FutureExt, StreamExt as _, future::Shared}; @@ -37,6 +36,7 @@ use project::{ git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, }; use prompt_store::{ModelContext, PromptBuilder}; +use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -49,6 +49,7 @@ use std::{ use thiserror::Error; use util::{ResultExt as _, post_inc}; use uuid::Uuid; +use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; const MAX_RETRY_ATTEMPTS: u8 = 4; const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); @@ -1680,7 +1681,7 @@ impl Thread { let completion_mode = request .mode - .unwrap_or(cloud_llm_client::CompletionMode::Normal); + .unwrap_or(zed_llm_client::CompletionMode::Normal); self.last_received_chunk_at = Some(Instant::now()); @@ -3254,10 +3255,8 @@ impl Thread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.project - .read(cx) - .user_store() - .update(cx, |user_store, cx| { + self.project.update(cx, |project, cx| { + project.user_store().update(cx, |user_store, cx| { user_store.update_model_request_usage( ModelRequestUsage(RequestUsage { amount: amount as i32, @@ -3265,7 +3264,8 @@ impl Thread { }), cx, ) - }); + }) + }); } pub fn deny_tool_use( diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 81c97c8aa6..4714245b94 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -18,19 +18,16 @@ doctest = false [dependencies] acp_thread.workspace = true -agent-client-protocol.workspace = true agentic-coding-protocol.workspace = true anyhow.workspace = true collections.workspace = true context_server.workspace = true futures.workspace = true gpui.workspace = true -indoc.workspace = true itertools.workspace = true log.workspace = true paths.workspace = true project.workspace = true -rand.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true @@ -38,7 +35,6 @@ settings.workspace = true smol.workspace = true strum.workspace = true tempfile.workspace = true -thiserror.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs deleted file mode 100644 index 00e3e3df50..0000000000 --- a/crates/agent_servers/src/acp.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::{path::Path, rc::Rc}; - -use crate::AgentServerCommand; -use acp_thread::AgentConnection; -use anyhow::Result; -use gpui::AsyncApp; -use thiserror::Error; - -mod v0; -mod v1; - -#[derive(Debug, Error)] -#[error("Unsupported version")] -pub struct UnsupportedVersion; - -pub async fn connect( - server_name: &'static str, - command: AgentServerCommand, - root_dir: &Path, - cx: &mut AsyncApp, -) -> Result> { - let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await; - - match conn { - Ok(conn) => Ok(Rc::new(conn) as _), - Err(err) if err.is::() => { - // Consider re-using initialize response and subprocess when adding another version here - let conn: Rc = - Rc::new(v0::AcpConnection::stdio(server_name, command, &root_dir, cx).await?); - Ok(conn) - } - Err(err) => Err(err), - } -} diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs deleted file mode 100644 index 6839ff2462..0000000000 --- a/crates/agent_servers/src/acp/v0.rs +++ /dev/null @@ -1,501 +0,0 @@ -// Translates old acp agents into the new schema -use agent_client_protocol as acp; -use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; -use anyhow::{Context as _, Result, anyhow}; -use futures::channel::oneshot; -use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; -use project::Project; -use std::{cell::RefCell, path::Path, rc::Rc}; -use ui::App; -use util::ResultExt as _; - -use crate::AgentServerCommand; -use acp_thread::{AcpThread, AgentConnection, AuthRequired}; - -#[derive(Clone)] -struct OldAcpClientDelegate { - thread: Rc>>, - cx: AsyncApp, - next_tool_call_id: Rc>, - // sent_buffer_versions: HashMap, HashMap>, -} - -impl OldAcpClientDelegate { - fn new(thread: Rc>>, cx: AsyncApp) -> Self { - Self { - thread, - cx, - next_tool_call_id: Rc::new(RefCell::new(0)), - } - } -} - -impl acp_old::Client for OldAcpClientDelegate { - async fn stream_assistant_message_chunk( - &self, - params: acp_old::StreamAssistantMessageChunkParams, - ) -> Result<(), acp_old::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread - .borrow() - .update(cx, |thread, cx| match params.chunk { - acp_old::AssistantMessageChunk::Text { text } => { - thread.push_assistant_content_block(text.into(), false, cx) - } - acp_old::AssistantMessageChunk::Thought { thought } => { - thread.push_assistant_content_block(thought.into(), true, cx) - } - }) - .log_err(); - })?; - - Ok(()) - } - - async fn request_tool_call_confirmation( - &self, - request: acp_old::RequestToolCallConfirmationParams, - ) -> Result { - let cx = &mut self.cx.clone(); - - let old_acp_id = *self.next_tool_call_id.borrow() + 1; - self.next_tool_call_id.replace(old_acp_id); - - let tool_call = into_new_tool_call( - acp::ToolCallId(old_acp_id.to_string().into()), - request.tool_call, - ); - - let mut options = match request.confirmation { - acp_old::ToolCallConfirmation::Edit { .. } => vec![( - acp_old::ToolCallConfirmationOutcome::AlwaysAllow, - acp::PermissionOptionKind::AllowAlways, - "Always Allow Edits".to_string(), - )], - acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![( - acp_old::ToolCallConfirmationOutcome::AlwaysAllow, - acp::PermissionOptionKind::AllowAlways, - format!("Always Allow {}", root_command), - )], - acp_old::ToolCallConfirmation::Mcp { - server_name, - tool_name, - .. - } => vec![ - ( - acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, - acp::PermissionOptionKind::AllowAlways, - format!("Always Allow {}", server_name), - ), - ( - acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool, - acp::PermissionOptionKind::AllowAlways, - format!("Always Allow {}", tool_name), - ), - ], - acp_old::ToolCallConfirmation::Fetch { .. } => vec![( - acp_old::ToolCallConfirmationOutcome::AlwaysAllow, - acp::PermissionOptionKind::AllowAlways, - "Always Allow".to_string(), - )], - acp_old::ToolCallConfirmation::Other { .. } => vec![( - acp_old::ToolCallConfirmationOutcome::AlwaysAllow, - acp::PermissionOptionKind::AllowAlways, - "Always Allow".to_string(), - )], - }; - - options.extend([ - ( - acp_old::ToolCallConfirmationOutcome::Allow, - acp::PermissionOptionKind::AllowOnce, - "Allow".to_string(), - ), - ( - acp_old::ToolCallConfirmationOutcome::Reject, - acp::PermissionOptionKind::RejectOnce, - "Reject".to_string(), - ), - ]); - - let mut outcomes = Vec::with_capacity(options.len()); - let mut acp_options = Vec::with_capacity(options.len()); - - for (index, (outcome, kind, label)) in options.into_iter().enumerate() { - outcomes.push(outcome); - acp_options.push(acp::PermissionOption { - id: acp::PermissionOptionId(index.to_string().into()), - label, - kind, - }) - } - - let response = cx - .update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.request_tool_call_permission(tool_call, acp_options, cx) - }) - })? - .context("Failed to update thread")? - .await; - - let outcome = match response { - Ok(option_id) => outcomes[option_id.0.parse::().unwrap_or(0)], - Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel, - }; - - Ok(acp_old::RequestToolCallConfirmationResponse { - id: acp_old::ToolCallId(old_acp_id), - outcome: outcome, - }) - } - - async fn push_tool_call( - &self, - request: acp_old::PushToolCallParams, - ) -> Result { - let cx = &mut self.cx.clone(); - - let old_acp_id = *self.next_tool_call_id.borrow() + 1; - self.next_tool_call_id.replace(old_acp_id); - - cx.update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.upsert_tool_call( - into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request), - cx, - ) - }) - })? - .context("Failed to update thread")?; - - Ok(acp_old::PushToolCallResponse { - id: acp_old::ToolCallId(old_acp_id), - }) - } - - async fn update_tool_call( - &self, - request: acp_old::UpdateToolCallParams, - ) -> Result<(), acp_old::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.update_tool_call( - acp::ToolCallUpdate { - id: acp::ToolCallId(request.tool_call_id.0.to_string().into()), - fields: acp::ToolCallUpdateFields { - status: Some(into_new_tool_call_status(request.status)), - content: Some( - request - .content - .into_iter() - .map(into_new_tool_call_content) - .collect::>(), - ), - ..Default::default() - }, - }, - cx, - ) - }) - })? - .context("Failed to update thread")??; - - Ok(()) - } - - async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.update_plan( - acp::Plan { - entries: request - .entries - .into_iter() - .map(into_new_plan_entry) - .collect(), - }, - cx, - ) - }) - })? - .context("Failed to update thread")?; - - Ok(()) - } - - async fn read_text_file( - &self, - acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams, - ) -> Result { - let content = self - .cx - .update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.read_text_file(path, line, limit, false, cx) - }) - })? - .context("Failed to update thread")? - .await?; - Ok(acp_old::ReadTextFileResponse { content }) - } - - async fn write_text_file( - &self, - acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams, - ) -> Result<(), acp_old::Error> { - self.cx - .update(|cx| { - self.thread - .borrow() - .update(cx, |thread, cx| thread.write_text_file(path, content, cx)) - })? - .context("Failed to update thread")? - .await?; - - Ok(()) - } -} - -fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall { - acp::ToolCall { - id: id, - label: request.label, - kind: acp_kind_from_old_icon(request.icon), - status: acp::ToolCallStatus::InProgress, - content: request - .content - .into_iter() - .map(into_new_tool_call_content) - .collect(), - locations: request - .locations - .into_iter() - .map(into_new_tool_call_location) - .collect(), - raw_input: None, - } -} - -fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind { - match icon { - acp_old::Icon::FileSearch => acp::ToolKind::Search, - acp_old::Icon::Folder => acp::ToolKind::Search, - acp_old::Icon::Globe => acp::ToolKind::Search, - acp_old::Icon::Hammer => acp::ToolKind::Other, - acp_old::Icon::LightBulb => acp::ToolKind::Think, - acp_old::Icon::Pencil => acp::ToolKind::Edit, - acp_old::Icon::Regex => acp::ToolKind::Search, - acp_old::Icon::Terminal => acp::ToolKind::Execute, - } -} - -fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus { - match status { - acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress, - acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed, - acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed, - } -} - -fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent { - match content { - acp_old::ToolCallContent::Markdown { markdown } => markdown.into(), - acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff { - diff: into_new_diff(diff), - }, - } -} - -fn into_new_diff(diff: acp_old::Diff) -> acp::Diff { - acp::Diff { - path: diff.path, - old_text: diff.old_text, - new_text: diff.new_text, - } -} - -fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation { - acp::ToolCallLocation { - path: location.path, - line: location.line, - } -} - -fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry { - acp::PlanEntry { - content: entry.content, - priority: into_new_plan_priority(entry.priority), - status: into_new_plan_status(entry.status), - } -} - -fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority { - match priority { - acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low, - acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium, - acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High, - } -} - -fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus { - match status { - acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending, - acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress, - acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed, - } -} - -pub struct AcpConnection { - pub name: &'static str, - pub connection: acp_old::AgentConnection, - pub _child_status: Task>, - pub current_thread: Rc>>, -} - -impl AcpConnection { - pub fn stdio( - name: &'static str, - command: AgentServerCommand, - root_dir: &Path, - cx: &mut AsyncApp, - ) -> Task> { - let root_dir = root_dir.to_path_buf(); - - cx.spawn(async move |cx| { - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - let foreground_executor = cx.foreground_executor().clone(); - - let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => Err(anyhow!(result)), - }; - drop(io_task); - result - }); - - Ok(Self { - name, - connection, - _child_status: child_status, - current_thread: thread_rc, - }) - }) - } -} - -impl AgentConnection for AcpConnection { - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let task = self.connection.request_any( - acp_old::InitializeParams { - protocol_version: acp_old::ProtocolVersion::latest(), - } - .into_any(), - ); - let current_thread = self.current_thread.clone(); - cx.spawn(async move |cx| { - let result = task.await?; - let result = acp_old::InitializeParams::response_from_any(result)?; - - if !result.is_authenticated { - anyhow::bail!(AuthRequired) - } - - cx.update(|cx| { - let thread = cx.new(|cx| { - let session_id = acp::SessionId("acp-old-no-id".into()); - AcpThread::new(self.name, self.clone(), project, session_id, cx) - }); - current_thread.replace(thread.downgrade()); - thread - }) - }) - } - - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } - - fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { - let task = self - .connection - .request_any(acp_old::AuthenticateParams.into_any()); - cx.foreground_executor().spawn(async move { - task.await?; - Ok(()) - }) - } - - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { - let chunks = params - .prompt - .into_iter() - .filter_map(|block| match block { - acp::ContentBlock::Text(text) => { - Some(acp_old::UserMessageChunk::Text { text: text.text }) - } - acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path { - path: link.uri.into(), - }), - _ => None, - }) - .collect(); - - let task = self - .connection - .request_any(acp_old::SendUserMessageParams { chunks }.into_any()); - cx.foreground_executor().spawn(async move { - task.await?; - anyhow::Ok(()) - }) - } - - fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) { - let task = self - .connection - .request_any(acp_old::CancelSendMessageParams.into_any()); - cx.foreground_executor() - .spawn(async move { - task.await?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx) - } -} diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs deleted file mode 100644 index 9e2193ce18..0000000000 --- a/crates/agent_servers/src/acp/v1.rs +++ /dev/null @@ -1,254 +0,0 @@ -use agent_client_protocol::{self as acp, Agent as _}; -use collections::HashMap; -use futures::channel::oneshot; -use project::Project; -use std::cell::RefCell; -use std::path::Path; -use std::rc::Rc; - -use anyhow::{Context as _, Result}; -use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; - -use crate::{AgentServerCommand, acp::UnsupportedVersion}; -use acp_thread::{AcpThread, AgentConnection, AuthRequired}; - -pub struct AcpConnection { - server_name: &'static str, - connection: Rc, - sessions: Rc>>, - auth_methods: Vec, - _io_task: Task>, - _child: smol::process::Child, -} - -pub struct AcpSession { - thread: WeakEntity, -} - -const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; - -impl AcpConnection { - pub async fn stdio( - server_name: &'static str, - command: AgentServerCommand, - root_dir: &Path, - cx: &mut AsyncApp, - ) -> Result { - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter().map(|arg| arg.as_str())) - .envs(command.env.iter().flatten()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - let stdout = child.stdout.take().expect("Failed to take stdout"); - let stdin = child.stdin.take().expect("Failed to take stdin"); - - let sessions = Rc::new(RefCell::new(HashMap::default())); - - let client = ClientDelegate { - sessions: sessions.clone(), - cx: cx.clone(), - }; - let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { - let foreground_executor = cx.foreground_executor().clone(); - move |fut| { - foreground_executor.spawn(fut).detach(); - } - }); - - let io_task = cx.background_spawn(io_task); - - let response = connection - .initialize(acp::InitializeRequest { - protocol_version: acp::VERSION, - client_capabilities: acp::ClientCapabilities { - fs: acp::FileSystemCapability { - read_text_file: true, - write_text_file: true, - }, - }, - }) - .await?; - - if response.protocol_version < MINIMUM_SUPPORTED_VERSION { - return Err(UnsupportedVersion.into()); - } - - Ok(Self { - auth_methods: response.auth_methods, - connection: connection.into(), - server_name, - sessions, - _child: child, - _io_task: io_task, - }) - } -} - -impl AgentConnection for AcpConnection { - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let conn = self.connection.clone(); - let sessions = self.sessions.clone(); - let cwd = cwd.to_path_buf(); - cx.spawn(async move |cx| { - let response = conn - .new_session(acp::NewSessionRequest { - mcp_servers: vec![], - cwd, - }) - .await?; - - let Some(session_id) = response.session_id else { - anyhow::bail!(AuthRequired); - }; - - let thread = cx.new(|cx| { - AcpThread::new( - self.server_name, - self.clone(), - project, - session_id.clone(), - cx, - ) - })?; - - let session = AcpSession { - thread: thread.downgrade(), - }; - sessions.borrow_mut().insert(session_id, session); - - Ok(thread) - }) - } - - fn auth_methods(&self) -> &[acp::AuthMethod] { - &self.auth_methods - } - - fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { - let conn = self.connection.clone(); - cx.foreground_executor().spawn(async move { - let result = conn - .authenticate(acp::AuthenticateRequest { - method_id: method_id.clone(), - }) - .await?; - - Ok(result) - }) - } - - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { - let conn = self.connection.clone(); - cx.foreground_executor() - .spawn(async move { Ok(conn.prompt(params).await?) }) - } - - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { - let conn = self.connection.clone(); - let params = acp::CancelledNotification { - session_id: session_id.clone(), - }; - cx.foreground_executor() - .spawn(async move { conn.cancelled(params).await }) - .detach(); - } -} - -struct ClientDelegate { - sessions: Rc>>, - cx: AsyncApp, -} - -impl acp::Client for ClientDelegate { - async fn request_permission( - &self, - arguments: acp::RequestPermissionRequest, - ) -> Result { - let cx = &mut self.cx.clone(); - let rx = self - .sessions - .borrow() - .get(&arguments.session_id) - .context("Failed to get session")? - .thread - .update(cx, |thread, cx| { - thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx) - })?; - - let result = rx.await; - - let outcome = match result { - Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, - Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, - }; - - Ok(acp::RequestPermissionResponse { outcome }) - } - - async fn write_text_file( - &self, - arguments: acp::WriteTextFileRequest, - ) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - let task = self - .sessions - .borrow() - .get(&arguments.session_id) - .context("Failed to get session")? - .thread - .update(cx, |thread, cx| { - thread.write_text_file(arguments.path, arguments.content, cx) - })?; - - task.await?; - - Ok(()) - } - - async fn read_text_file( - &self, - arguments: acp::ReadTextFileRequest, - ) -> Result { - let cx = &mut self.cx.clone(); - let task = self - .sessions - .borrow() - .get(&arguments.session_id) - .context("Failed to get session")? - .thread - .update(cx, |thread, cx| { - thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) - })?; - - let content = task.await?; - - Ok(acp::ReadTextFileResponse { content }) - } - - async fn session_notification( - &self, - notification: acp::SessionNotification, - ) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - let sessions = self.sessions.borrow(); - let session = sessions - .get(¬ification.session_id) - .context("Failed to get session")?; - - session.thread.update(cx, |thread, cx| { - thread.handle_session_update(notification.update, cx) - })??; - - Ok(()) - } -} diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index ec69290206..6d9c77f296 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,7 +1,7 @@ -mod acp; mod claude; mod gemini; mod settings; +mod stdio_agent_server; #[cfg(test)] mod e2e_tests; @@ -9,8 +9,9 @@ mod e2e_tests; pub use claude::*; pub use gemini::*; pub use settings::*; +pub use stdio_agent_server::*; -use acp_thread::AgentConnection; +use acp_thread::AcpThread; use anyhow::Result; use collections::HashMap; use gpui::{App, AsyncApp, Entity, SharedString, Task}; @@ -19,7 +20,6 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{ path::{Path, PathBuf}, - rc::Rc, sync::Arc, }; use util::ResultExt as _; @@ -33,13 +33,14 @@ pub trait AgentServer: Send { fn name(&self) -> &'static str; fn empty_state_headline(&self) -> &'static str; fn empty_state_message(&self) -> &'static str; + fn supports_always_allow(&self) -> bool; - fn connect( + fn new_thread( &self, root_dir: &Path, project: &Entity, cx: &mut App, - ) -> Task>>; + ) -> Task>>; } impl std::fmt::Debug for AgentServerCommand { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 9040b83085..835efbd655 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -1,35 +1,39 @@ mod mcp_server; -pub mod tools; +mod tools; use collections::HashMap; -use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; use std::cell::RefCell; use std::fmt::Display; use std::path::Path; +use std::pin::pin; use std::rc::Rc; use uuid::Uuid; -use agent_client_protocol as acp; +use agentic_coding_protocol::{ + self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion, + StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams, +}; use anyhow::{Result, anyhow}; use futures::channel::oneshot; -use futures::{AsyncBufReadExt, AsyncWriteExt}; +use futures::future::LocalBoxFuture; +use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt}; use futures::{ AsyncRead, AsyncWrite, FutureExt, StreamExt, channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, io::BufReader, select_biased, }; -use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; +use gpui::{App, AppContext, Entity, Task}; use serde::{Deserialize, Serialize}; use util::ResultExt; -use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; +use crate::claude::mcp_server::ClaudeMcpServer; use crate::claude::tools::ClaudeTool; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpThread, AgentConnection}; +use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection}; #[derive(Clone)] pub struct ClaudeCode; @@ -44,47 +48,36 @@ impl AgentServer for ClaudeCode { } fn empty_state_message(&self) -> &'static str { - "How can I help you today?" + "" } fn logo(&self) -> ui::IconName { ui::IconName::AiClaude } - fn connect( - &self, - _root_dir: &Path, - _project: &Entity, - _cx: &mut App, - ) -> Task>> { - let connection = ClaudeAgentConnection { - sessions: Default::default(), - }; - - Task::ready(Ok(Rc::new(connection) as _)) + fn supports_always_allow(&self) -> bool { + false } -} -struct ClaudeAgentConnection { - sessions: Rc>>, -} - -impl AgentConnection for ClaudeAgentConnection { fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, + &self, + root_dir: &Path, + project: &Entity, + cx: &mut App, ) -> Task>> { - let cwd = cwd.to_owned(); + let project = project.clone(); + let root_dir = root_dir.to_path_buf(); + let title = self.name().into(); cx.spawn(async move |cx| { - let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); - let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; + let (mut delegate_tx, delegate_rx) = watch::channel(None); + let tool_id_map = Rc::new(RefCell::new(HashMap::default())); + + let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?; let mut mcp_servers = HashMap::default(); mcp_servers.insert( mcp_server::SERVER_NAME.to_string(), - permission_mcp_server.server_config()?, + mcp_server.server_config()?, ); let mcp_config = McpConfig { mcp_servers }; @@ -109,163 +102,192 @@ impl AgentConnection for ClaudeAgentConnection { let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded::>>(); - let session_id = acp::SessionId(Uuid::new_v4().to_string().into()); + let session_id = Uuid::new_v4(); log::trace!("Starting session with id: {}", session_id); - cx.background_spawn({ - let session_id = session_id.clone(); - async move { - let mut outgoing_rx = Some(outgoing_rx); + cx.background_spawn(async move { + let mut outgoing_rx = Some(outgoing_rx); + let mut mode = ClaudeSessionMode::Start; - let mut child = spawn_claude( - &command, - ClaudeSessionMode::Start, - session_id.clone(), - &mcp_config_path, - &cwd, - ) - .await?; + loop { + let mut child = + spawn_claude(&command, mode, session_id, &mcp_config_path, &root_dir) + .await?; + mode = ClaudeSessionMode::Resume; let pid = child.id(); log::trace!("Spawned (pid: {})", pid); - ClaudeAgentSession::handle_io( - outgoing_rx.take().unwrap(), - incoming_message_tx.clone(), - child.stdin.take().unwrap(), - child.stdout.take().unwrap(), - ) - .await?; + let mut io_fut = pin!( + ClaudeAgentConnection::handle_io( + outgoing_rx.take().unwrap(), + incoming_message_tx.clone(), + child.stdin.take().unwrap(), + child.stdout.take().unwrap(), + ) + .fuse() + ); + + select_biased! { + done_tx = cancel_rx.next() => { + if let Some(done_tx) = done_tx { + log::trace!("Interrupted (pid: {})", pid); + let result = send_interrupt(pid as i32); + outgoing_rx.replace(io_fut.await?); + done_tx.send(result).log_err(); + continue; + } + } + result = io_fut => { + result?; + } + } log::trace!("Stopped (pid: {})", pid); - - drop(mcp_config_path); - anyhow::Ok(()) + break; } + + drop(mcp_config_path); + anyhow::Ok(()) }) .detach(); - let end_turn_tx = Rc::new(RefCell::new(None)); - let handler_task = cx.spawn({ - let end_turn_tx = end_turn_tx.clone(); - let thread_rx = thread_rx.clone(); - async move |cx| { - while let Some(message) = incoming_message_rx.next().await { - ClaudeAgentSession::handle_message( - thread_rx.clone(), - message, - end_turn_tx.clone(), - cx, - ) - .await + cx.new(|cx| { + let end_turn_tx = Rc::new(RefCell::new(None)); + let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); + delegate_tx.send(Some(delegate.clone())).log_err(); + + let handler_task = cx.foreground_executor().spawn({ + let end_turn_tx = end_turn_tx.clone(); + let tool_id_map = tool_id_map.clone(); + let delegate = delegate.clone(); + async move { + while let Some(message) = incoming_message_rx.next().await { + ClaudeAgentConnection::handle_message( + delegate.clone(), + message, + end_turn_tx.clone(), + tool_id_map.clone(), + ) + .await + } } - } - }); + }); - let thread = cx.new(|cx| { - AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) - })?; + let mut connection = ClaudeAgentConnection { + delegate, + outgoing_tx, + end_turn_tx, + cancel_tx, + session_id, + _handler_task: handler_task, + _mcp_server: None, + }; - thread_tx.send(thread.downgrade())?; - - let session = ClaudeAgentSession { - outgoing_tx, - end_turn_tx, - _handler_task: handler_task, - _mcp_server: Some(permission_mcp_server), - }; - - self.sessions.borrow_mut().insert(session_id, session); - - Ok(thread) + connection._mcp_server = Some(mcp_server); + acp_thread::AcpThread::new(connection, title, None, project.clone(), cx) + }) }) } +} - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } +#[cfg(unix)] +fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { + let pid = nix::unistd::Pid::from_raw(pid); - fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { - Task::ready(Err(anyhow!("Authentication not supported"))) - } + nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) + .map_err(|e| anyhow!("Failed to interrupt process: {}", e)) +} - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { - let sessions = self.sessions.borrow(); - let Some(session) = sessions.get(¶ms.session_id) else { - return Task::ready(Err(anyhow!( - "Attempted to send message to nonexistent session {}", - params.session_id - ))); - }; +#[cfg(windows)] +fn send_interrupt(_pid: i32) -> anyhow::Result<()> { + panic!("Cancel not implemented on Windows") +} - let (tx, rx) = oneshot::channel(); - session.end_turn_tx.borrow_mut().replace(tx); - - let mut content = String::new(); - for chunk in params.prompt { - match chunk { - acp::ContentBlock::Text(text_content) => { - content.push_str(&text_content.text); +impl AgentConnection for ClaudeAgentConnection { + /// Send a request to the agent and wait for a response. + fn request_any( + &self, + params: AnyAgentRequest, + ) -> LocalBoxFuture<'static, Result> { + let delegate = self.delegate.clone(); + let end_turn_tx = self.end_turn_tx.clone(); + let outgoing_tx = self.outgoing_tx.clone(); + let mut cancel_tx = self.cancel_tx.clone(); + let session_id = self.session_id; + async move { + match params { + // todo: consider sending an empty request so we get the init response? + AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse( + acp::InitializeResponse { + is_authenticated: true, + protocol_version: ProtocolVersion::latest(), + }, + )), + AnyAgentRequest::AuthenticateParams(_) => { + Err(anyhow!("Authentication not supported")) } - acp::ContentBlock::ResourceLink(resource_link) => { - content.push_str(&format!("@{}", resource_link.uri)); + AnyAgentRequest::SendUserMessageParams(message) => { + delegate.clear_completed_plan_entries().await?; + + let (tx, rx) = oneshot::channel(); + end_turn_tx.borrow_mut().replace(tx); + let mut content = String::new(); + for chunk in message.chunks { + match chunk { + agentic_coding_protocol::UserMessageChunk::Text { text } => { + content.push_str(&text) + } + agentic_coding_protocol::UserMessageChunk::Path { path } => { + content.push_str(&format!("@{path:?}")) + } + } + } + outgoing_tx.unbounded_send(SdkMessage::User { + message: Message { + role: Role::User, + content: Content::UntaggedText(content), + id: None, + model: None, + stop_reason: None, + stop_sequence: None, + usage: None, + }, + session_id: Some(session_id), + })?; + rx.await??; + Ok(AnyAgentResult::SendUserMessageResponse( + acp::SendUserMessageResponse, + )) } - acp::ContentBlock::Audio(_) - | acp::ContentBlock::Image(_) - | acp::ContentBlock::Resource(_) => { - // TODO + AnyAgentRequest::CancelSendMessageParams(_) => { + let (done_tx, done_rx) = oneshot::channel(); + cancel_tx.send(done_tx).await?; + done_rx.await??; + + Ok(AnyAgentResult::CancelSendMessageResponse( + acp::CancelSendMessageResponse, + )) } } } - - if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User { - message: Message { - role: Role::User, - content: Content::UntaggedText(content), - id: None, - model: None, - stop_reason: None, - stop_sequence: None, - usage: None, - }, - session_id: Some(params.session_id.to_string()), - }) { - return Task::ready(Err(anyhow!(err))); - } - - cx.foreground_executor().spawn(async move { - rx.await??; - Ok(()) - }) - } - - fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { - let sessions = self.sessions.borrow(); - let Some(session) = sessions.get(&session_id) else { - log::warn!("Attempted to cancel nonexistent session {}", session_id); - return; - }; - - session - .outgoing_tx - .unbounded_send(SdkMessage::new_interrupt_message()) - .log_err(); + .boxed_local() } } #[derive(Clone, Copy)] enum ClaudeSessionMode { Start, - #[expect(dead_code)] Resume, } async fn spawn_claude( command: &AgentServerCommand, mode: ClaudeSessionMode, - session_id: acp::SessionId, + session_id: Uuid, mcp_config_path: &Path, root_dir: &Path, ) -> Result { @@ -283,16 +305,10 @@ async fn spawn_claude( &format!( "mcp__{}__{}", mcp_server::SERVER_NAME, - mcp_server::PermissionTool::NAME, + mcp_server::PERMISSION_TOOL ), "--allowedTools", - &format!( - "mcp__{}__{},mcp__{}__{}", - mcp_server::SERVER_NAME, - mcp_server::EditTool::NAME, - mcp_server::SERVER_NAME, - mcp_server::ReadTool::NAME - ), + "mcp__zed__Read,mcp__zed__Edit", "--disallowedTools", "Read,Edit", ]) @@ -311,135 +327,105 @@ async fn spawn_claude( Ok(child) } -struct ClaudeAgentSession { +struct ClaudeAgentConnection { + delegate: AcpClientDelegate, + session_id: Uuid, outgoing_tx: UnboundedSender, end_turn_tx: Rc>>>>, - _mcp_server: Option, + cancel_tx: UnboundedSender>>, + _mcp_server: Option, _handler_task: Task<()>, } -impl ClaudeAgentSession { +impl ClaudeAgentConnection { async fn handle_message( - mut thread_rx: watch::Receiver>, + delegate: AcpClientDelegate, message: SdkMessage, end_turn_tx: Rc>>>>, - cx: &mut AsyncApp, + tool_id_map: Rc>>, ) { match message { - // we should only be sending these out, they don't need to be in the thread - SdkMessage::ControlRequest { .. } => {} - SdkMessage::Assistant { - message, - session_id: _, - } - | SdkMessage::User { - message, - session_id: _, - } => { - let Some(thread) = thread_rx - .recv() - .await - .log_err() - .and_then(|entity| entity.upgrade()) - else { - log::error!("Received an SDK message but thread is gone"); - return; - }; - + SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => { for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block(text.into(), false, cx) + delegate + .stream_assistant_message_chunk(StreamAssistantMessageChunkParams { + chunk: acp::AssistantMessageChunk::Text { text }, }) + .await .log_err(); } ContentChunk::ToolUse { id, name, input } => { let claude_tool = ClaudeTool::infer(&name, input); - thread - .update(cx, |thread, cx| { - if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { - thread.update_plan( - acp::Plan { - entries: params - .todos - .into_iter() - .map(Into::into) - .collect(), - }, - cx, - ) - } else { - thread.upsert_tool_call( - claude_tool.as_acp(acp::ToolCallId(id.into())), - cx, - ); - } - }) - .log_err(); + if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { + delegate + .update_plan(acp::UpdatePlanParams { + entries: params.todos.into_iter().map(Into::into).collect(), + }) + .await + .log_err(); + } else if let Some(resp) = delegate + .push_tool_call(claude_tool.as_acp()) + .await + .log_err() + { + tool_id_map.borrow_mut().insert(id, resp.id); + } } ContentChunk::ToolResult { content, tool_use_id, } => { - let content = content.to_string(); - thread - .update(cx, |thread, cx| { - thread.update_tool_call( - acp::ToolCallUpdate { - id: acp::ToolCallId(tool_use_id.into()), - fields: acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::Completed), - content: (!content.is_empty()) - .then(|| vec![content.into()]), - ..Default::default() + let id = tool_id_map.borrow_mut().remove(&tool_use_id); + if let Some(id) = id { + let content = content.to_string(); + delegate + .update_tool_call(UpdateToolCallParams { + tool_call_id: id, + status: acp::ToolCallStatus::Finished, + // Don't unset existing content + content: (!content.is_empty()).then_some( + ToolCallContent::Markdown { + // For now we only include text content + markdown: content, }, - }, - cx, - ) - }) - .log_err(); + ), + }) + .await + .log_err(); + } } ContentChunk::Image | ContentChunk::Document | ContentChunk::Thinking | ContentChunk::RedactedThinking | ContentChunk::WebSearchToolResult => { - thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block( - format!("Unsupported content: {:?}", chunk).into(), - false, - cx, - ) + delegate + .stream_assistant_message_chunk(StreamAssistantMessageChunkParams { + chunk: acp::AssistantMessageChunk::Text { + text: format!("Unsupported content: {:?}", chunk), + }, }) + .await .log_err(); } } } } SdkMessage::Result { - is_error, - subtype, - result, - .. + is_error, subtype, .. } => { if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { if is_error { - end_turn_tx - .send(Err(anyhow!( - "Error: {}", - result.unwrap_or_else(|| subtype.to_string()) - ))) - .ok(); + end_turn_tx.send(Err(anyhow!("Error: {subtype}"))).ok(); } else { end_turn_tx.send(Ok(())).ok(); } } } - SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {} + SdkMessage::System { .. } => {} } } @@ -606,14 +592,16 @@ enum SdkMessage { Assistant { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, + // A user message User { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, + // Emitted as the last message in a conversation Result { subtype: ResultErrorType, @@ -638,26 +626,6 @@ enum SdkMessage { #[serde(rename = "permissionMode")] permission_mode: PermissionMode, }, - /// Messages used to control the conversation, outside of chat messages to the model - ControlRequest { - request_id: String, - request: ControlRequest, - }, - /// Response to a control request - ControlResponse { response: ControlResponse }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "subtype", rename_all = "snake_case")] -enum ControlRequest { - /// Cancel the current conversation - Interrupt, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ControlResponse { - request_id: String, - subtype: ResultErrorType, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -678,24 +646,6 @@ impl Display for ResultErrorType { } } -impl SdkMessage { - fn new_interrupt_message() -> Self { - use rand::Rng; - // In the Claude Code TS SDK they just generate a random 12 character string, - // `Math.random().toString(36).substring(2, 15)` - let request_id = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(12) - .map(char::from) - .collect(); - - Self::ControlRequest { - request_id, - request: ControlRequest::Interrupt, - } - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] struct McpServer { name: String, @@ -711,12 +661,27 @@ enum PermissionMode { Plan, } +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct McpConfig { + mcp_servers: HashMap, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct McpServerConfig { + command: String, + args: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + env: Option>, +} + #[cfg(test)] pub(crate) mod tests { use super::*; use serde_json::json; - crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow"); + crate::common_e2e_tests!(ClaudeCode); pub fn local_command() -> AgentServerCommand { AgentServerCommand { diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index cc303016f1..2405603550 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -1,53 +1,78 @@ -use std::path::PathBuf; +use std::{cell::RefCell, rc::Rc}; -use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; -use acp_thread::AcpThread; -use agent_client_protocol as acp; +use acp_thread::AcpClientDelegate; +use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams}; use anyhow::{Context, Result}; use collections::HashMap; -use context_server::listener::{McpServerTool, ToolResponse}; -use context_server::types::{ - Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, - ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, +use context_server::{ + listener::McpServer, + types::{ + CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse, + ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, + ToolResponseContent, ToolsCapabilities, requests, + }, }; -use gpui::{App, AsyncApp, Task, WeakEntity}; +use gpui::{App, AsyncApp, Task}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use util::debug_panic; -pub struct ClaudeZedMcpServer { - server: context_server::listener::McpServer, +use crate::claude::{ + McpServerConfig, + tools::{ClaudeTool, EditToolParams, ReadToolParams}, +}; + +pub struct ClaudeMcpServer { + server: McpServer, } pub const SERVER_NAME: &str = "zed"; +pub const READ_TOOL: &str = "Read"; +pub const EDIT_TOOL: &str = "Edit"; +pub const PERMISSION_TOOL: &str = "Confirmation"; -impl ClaudeZedMcpServer { +#[derive(Deserialize, JsonSchema, Debug)] +struct PermissionToolParams { + tool_name: String, + input: serde_json::Value, + tool_use_id: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct PermissionToolResponse { + behavior: PermissionToolBehavior, + updated_input: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +enum PermissionToolBehavior { + Allow, + Deny, +} + +impl ClaudeMcpServer { pub async fn new( - thread_rx: watch::Receiver>, + delegate: watch::Receiver>, + tool_id_map: Rc>>, cx: &AsyncApp, ) -> Result { - let mut mcp_server = context_server::listener::McpServer::new(cx).await?; + let mut mcp_server = McpServer::new(cx).await?; mcp_server.handle_request::(Self::handle_initialize); - - mcp_server.add_tool(PermissionTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(ReadTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(EditTool { - thread_rx: thread_rx.clone(), + mcp_server.handle_request::(Self::handle_list_tools); + mcp_server.handle_request::(move |request, cx| { + Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx) }); Ok(Self { server: mcp_server }) } pub fn server_config(&self) -> Result { - #[cfg(not(test))] let zed_path = std::env::current_exe() - .context("finding current executable path for use in mcp_server")?; - - #[cfg(test)] - let zed_path = crate::e2e_tests::get_zed_path(); + .context("finding current executable path for use in mcp_server")? + .to_string_lossy() + .to_string(); Ok(McpServerConfig { command: zed_path, @@ -81,222 +106,191 @@ impl ClaudeZedMcpServer { }) }) } -} -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct McpConfig { - pub mcp_servers: HashMap, -} - -#[derive(Serialize, Clone)] -#[serde(rename_all = "camelCase")] -pub struct McpServerConfig { - pub command: PathBuf, - pub args: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub env: Option>, -} - -// Tools - -#[derive(Clone)] -pub struct PermissionTool { - thread_rx: watch::Receiver>, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct PermissionToolParams { - tool_name: String, - input: serde_json::Value, - tool_use_id: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct PermissionToolResponse { - behavior: PermissionToolBehavior, - updated_input: serde_json::Value, -} - -#[derive(Serialize)] -#[serde(rename_all = "snake_case")] -enum PermissionToolBehavior { - Allow, - Deny, -} - -impl McpServerTool for PermissionTool { - type Input = PermissionToolParams; - type Output = (); - - const NAME: &'static str = "Confirmation"; - - fn description(&self) -> &'static str { - "Request permission for tool calls" + fn handle_list_tools(_: (), cx: &App) -> Task> { + cx.foreground_executor().spawn(async move { + Ok(ListToolsResponse { + tools: vec![ + Tool { + name: PERMISSION_TOOL.into(), + input_schema: schemars::schema_for!(PermissionToolParams).into(), + description: None, + annotations: None, + }, + Tool { + name: READ_TOOL.into(), + input_schema: schemars::schema_for!(ReadToolParams).into(), + description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()), + annotations: Some(ToolAnnotations { + title: Some("Read file".to_string()), + read_only_hint: Some(true), + destructive_hint: Some(false), + open_world_hint: Some(false), + // if time passes the contents might change, but it's not going to do anything different + // true or false seem too strong, let's try a none. + idempotent_hint: None, + }), + }, + Tool { + name: EDIT_TOOL.into(), + input_schema: schemars::schema_for!(EditToolParams).into(), + description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()), + annotations: Some(ToolAnnotations { + title: Some("Edit file".to_string()), + read_only_hint: Some(false), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: Some(false), + }), + }, + ], + next_cursor: None, + meta: None, + }) + }) } - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; + fn handle_call_tool( + request: CallToolParams, + mut delegate_watch: watch::Receiver>, + tool_id_map: Rc>>, + cx: &App, + ) -> Task> { + cx.spawn(async move |cx| { + let Some(delegate) = delegate_watch.recv().await? else { + debug_panic!("Sent None delegate"); + anyhow::bail!("Server not available"); + }; - let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); - let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); - let allow_option_id = acp::PermissionOptionId("allow".into()); - let reject_option_id = acp::PermissionOptionId("reject".into()); + if request.name.as_str() == PERMISSION_TOOL { + let input = + serde_json::from_value(request.arguments.context("Arguments required")?)?; - let chosen_option = thread - .update(cx, |thread, cx| { - thread.request_tool_call_permission( - claude_tool.as_acp(tool_call_id), - vec![ - acp::PermissionOption { - id: allow_option_id.clone(), - label: "Allow".into(), - kind: acp::PermissionOptionKind::AllowOnce, - }, - acp::PermissionOption { - id: reject_option_id.clone(), - label: "Reject".into(), - kind: acp::PermissionOptionKind::RejectOnce, - }, - ], - cx, + let result = + Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?; + Ok(CallToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&result)?, + }], + is_error: None, + meta: None, + }) + } else if request.name.as_str() == READ_TOOL { + let input = + serde_json::from_value(request.arguments.context("Arguments required")?)?; + + let content = Self::handle_read_tool_call(input, delegate, cx).await?; + Ok(CallToolResponse { + content, + is_error: None, + meta: None, + }) + } else if request.name.as_str() == EDIT_TOOL { + let input = + serde_json::from_value(request.arguments.context("Arguments required")?)?; + + Self::handle_edit_tool_call(input, delegate, cx).await?; + Ok(CallToolResponse { + content: vec![], + is_error: None, + meta: None, + }) + } else { + anyhow::bail!("Unsupported tool"); + } + }) + } + + fn handle_read_tool_call( + params: ReadToolParams, + delegate: AcpClientDelegate, + cx: &AsyncApp, + ) -> Task>> { + cx.foreground_executor().spawn(async move { + let response = delegate + .read_text_file(ReadTextFileParams { + path: params.abs_path, + line: params.offset, + limit: params.limit, + }) + .await?; + + Ok(vec![ToolResponseContent::Text { + text: response.content, + }]) + }) + } + + fn handle_edit_tool_call( + params: EditToolParams, + delegate: AcpClientDelegate, + cx: &AsyncApp, + ) -> Task> { + cx.foreground_executor().spawn(async move { + let response = delegate + .read_text_file_reusing_snapshot(ReadTextFileParams { + path: params.abs_path.clone(), + line: None, + limit: None, + }) + .await?; + + let new_content = response.content.replace(¶ms.old_text, ¶ms.new_text); + if new_content == response.content { + return Err(anyhow::anyhow!("The old_text was not found in the content")); + } + + delegate + .write_text_file(WriteTextFileParams { + path: params.abs_path, + content: new_content, + }) + .await?; + + Ok(()) + }) + } + + fn handle_permissions_tool_call( + params: PermissionToolParams, + delegate: AcpClientDelegate, + tool_id_map: Rc>>, + cx: &AsyncApp, + ) -> Task> { + cx.foreground_executor().spawn(async move { + let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone()); + + let tool_call_id = match params.tool_use_id { + Some(tool_use_id) => tool_id_map + .borrow() + .get(&tool_use_id) + .cloned() + .context("Tool call ID not found")?, + + None => delegate.push_tool_call(claude_tool.as_acp()).await?.id, + }; + + let outcome = delegate + .request_existing_tool_call_confirmation( + tool_call_id, + claude_tool.confirmation(None), ) - })? - .await?; + .await?; - let response = if chosen_option == allow_option_id { - PermissionToolResponse { - behavior: PermissionToolBehavior::Allow, - updated_input: input.input, + match outcome { + acp::ToolCallConfirmationOutcome::Allow + | acp::ToolCallConfirmationOutcome::AlwaysAllow + | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer + | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: params.input, + }), + acp::ToolCallConfirmationOutcome::Reject + | acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse { + behavior: PermissionToolBehavior::Deny, + updated_input: params.input, + }), } - } else { - debug_assert_eq!(chosen_option, reject_option_id); - PermissionToolResponse { - behavior: PermissionToolBehavior::Deny, - updated_input: input.input, - } - }; - - Ok(ToolResponse { - content: vec![ToolResponseContent::Text { - text: serde_json::to_string(&response)?, - }], - structured_content: (), - }) - } -} - -#[derive(Clone)] -pub struct ReadTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for ReadTool { - type Input = ReadToolParams; - type Output = (); - - const NAME: &'static str = "Read"; - - fn description(&self) -> &'static str { - "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents." - } - - fn annotations(&self) -> ToolAnnotations { - ToolAnnotations { - title: Some("Read file".to_string()), - read_only_hint: Some(true), - destructive_hint: Some(false), - open_world_hint: Some(false), - idempotent_hint: None, - } - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![ToolResponseContent::Text { text: content }], - structured_content: (), - }) - } -} - -#[derive(Clone)] -pub struct EditTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for EditTool { - type Input = EditToolParams; - type Output = (); - - const NAME: &'static str = "Edit"; - - fn description(&self) -> &'static str { - "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better." - } - - fn annotations(&self) -> ToolAnnotations { - ToolAnnotations { - title: Some("Edit file".to_string()), - read_only_hint: Some(false), - destructive_hint: Some(false), - open_world_hint: Some(false), - idempotent_hint: Some(false), - } - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.abs_path.clone(), None, None, true, cx) - })? - .await?; - - let new_content = content.replace(&input.old_text, &input.new_text); - if new_content == content { - return Err(anyhow::anyhow!("The old_text was not found in the content")); - } - - thread - .update(cx, |thread, cx| { - thread.write_text_file(input.abs_path, new_content, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: (), }) } } diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index 6acb6355aa..75a26ee230 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use agent_client_protocol as acp; +use agentic_coding_protocol::{self as acp, PushToolCallParams, ToolCallLocation}; use itertools::Itertools; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -115,36 +115,51 @@ impl ClaudeTool { Self::Other { name, .. } => name.clone(), } } - pub fn content(&self) -> Vec { + + pub fn content(&self) -> Option { match &self { - Self::Other { input, .. } => vec![ - format!( + Self::Other { input, .. } => Some(acp::ToolCallContent::Markdown { + markdown: format!( "```json\n{}```", serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) - ) - .into(), - ], - Self::Task(Some(params)) => vec![params.prompt.clone().into()], - Self::NotebookRead(Some(params)) => { - vec![params.notebook_path.display().to_string().into()] - } - Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()], - Self::Terminal(Some(params)) => vec![ - format!( + ), + }), + Self::Task(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.prompt.clone(), + }), + Self::NotebookRead(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.notebook_path.display().to_string(), + }), + Self::NotebookEdit(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.new_source.clone(), + }), + Self::Terminal(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: format!( "`{}`\n\n{}", params.command, params.description.as_deref().unwrap_or_default() - ) - .into(), - ], - Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()], - Self::Ls(Some(params)) => vec![params.path.display().to_string().into()], - Self::Glob(Some(params)) => vec![params.to_string().into()], - Self::Grep(Some(params)) => vec![format!("`{params}`").into()], - Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], - Self::WebSearch(Some(params)) => vec![params.to_string().into()], - Self::TodoWrite(Some(params)) => vec![ - params + ), + }), + Self::ReadFile(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.abs_path.display().to_string(), + }), + Self::Ls(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.path.display().to_string(), + }), + Self::Glob(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.to_string(), + }), + Self::Grep(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: format!("`{params}`"), + }), + Self::WebFetch(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.prompt.clone(), + }), + Self::WebSearch(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.to_string(), + }), + Self::TodoWrite(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params .todos .iter() .map(|todo| { @@ -159,39 +174,34 @@ impl ClaudeTool { todo.content ) }) - .join("\n") - .into(), - ], - Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], - Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { + .join("\n"), + }), + Self::ExitPlanMode(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.plan.clone(), + }), + Self::Edit(Some(params)) => Some(acp::ToolCallContent::Diff { diff: acp::Diff { path: params.abs_path.clone(), old_text: Some(params.old_text.clone()), new_text: params.new_text.clone(), }, - }], - Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { + }), + Self::Write(Some(params)) => Some(acp::ToolCallContent::Diff { diff: acp::Diff { path: params.file_path.clone(), old_text: None, new_text: params.content.clone(), }, - }], + }), Self::MultiEdit(Some(params)) => { // todo: show multiple edits in a multibuffer? - params - .edits - .first() - .map(|edit| { - vec![acp::ToolCallContent::Diff { - diff: acp::Diff { - path: params.file_path.clone(), - old_text: Some(edit.old_string.clone()), - new_text: edit.new_string.clone(), - }, - }] - }) - .unwrap_or_default() + params.edits.first().map(|edit| acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: Some(edit.old_string.clone()), + new_text: edit.new_string.clone(), + }, + }) } Self::Task(None) | Self::NotebookRead(None) @@ -207,80 +217,181 @@ impl ClaudeTool { | Self::ExitPlanMode(None) | Self::Edit(None) | Self::Write(None) - | Self::MultiEdit(None) => vec![], + | Self::MultiEdit(None) => None, } } - pub fn kind(&self) -> acp::ToolKind { + pub fn icon(&self) -> acp::Icon { match self { - Self::Task(_) => acp::ToolKind::Think, - Self::NotebookRead(_) => acp::ToolKind::Read, - Self::NotebookEdit(_) => acp::ToolKind::Edit, - Self::Edit(_) => acp::ToolKind::Edit, - Self::MultiEdit(_) => acp::ToolKind::Edit, - Self::Write(_) => acp::ToolKind::Edit, - Self::ReadFile(_) => acp::ToolKind::Read, - Self::Ls(_) => acp::ToolKind::Search, - Self::Glob(_) => acp::ToolKind::Search, - Self::Grep(_) => acp::ToolKind::Search, - Self::Terminal(_) => acp::ToolKind::Execute, - Self::WebSearch(_) => acp::ToolKind::Search, - Self::WebFetch(_) => acp::ToolKind::Fetch, - Self::TodoWrite(_) => acp::ToolKind::Think, - Self::ExitPlanMode(_) => acp::ToolKind::Think, - Self::Other { .. } => acp::ToolKind::Other, + Self::Task(_) => acp::Icon::Hammer, + Self::NotebookRead(_) => acp::Icon::FileSearch, + Self::NotebookEdit(_) => acp::Icon::Pencil, + Self::Edit(_) => acp::Icon::Pencil, + Self::MultiEdit(_) => acp::Icon::Pencil, + Self::Write(_) => acp::Icon::Pencil, + Self::ReadFile(_) => acp::Icon::FileSearch, + Self::Ls(_) => acp::Icon::Folder, + Self::Glob(_) => acp::Icon::FileSearch, + Self::Grep(_) => acp::Icon::Regex, + Self::Terminal(_) => acp::Icon::Terminal, + Self::WebSearch(_) => acp::Icon::Globe, + Self::WebFetch(_) => acp::Icon::Globe, + Self::TodoWrite(_) => acp::Icon::LightBulb, + Self::ExitPlanMode(_) => acp::Icon::Hammer, + Self::Other { .. } => acp::Icon::Hammer, + } + } + + pub fn confirmation(&self, description: Option) -> acp::ToolCallConfirmation { + match &self { + Self::Edit(_) | Self::Write(_) | Self::NotebookEdit(_) | Self::MultiEdit(_) => { + acp::ToolCallConfirmation::Edit { description } + } + Self::WebFetch(params) => acp::ToolCallConfirmation::Fetch { + urls: params + .as_ref() + .map(|p| vec![p.url.clone()]) + .unwrap_or_default(), + description, + }, + Self::Terminal(Some(BashToolParams { + description, + command, + .. + })) => acp::ToolCallConfirmation::Execute { + command: command.clone(), + root_command: command.clone(), + description: description.clone(), + }, + Self::ExitPlanMode(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {}", params.plan) + } else { + params.plan.clone() + }, + }, + Self::Task(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {}", params.description) + } else { + params.description.clone() + }, + }, + Self::Ls(Some(LsToolParams { path, .. })) + | Self::ReadFile(Some(ReadToolParams { abs_path: path, .. })) => { + let path = path.display(); + acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {path}") + } else { + path.to_string() + }, + } + } + Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { + let path = notebook_path.display(); + acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {path}") + } else { + path.to_string() + }, + } + } + Self::Glob(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {params}") + } else { + params.to_string() + }, + }, + Self::Grep(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {params}") + } else { + params.to_string() + }, + }, + Self::WebSearch(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {params}") + } else { + params.to_string() + }, + }, + Self::TodoWrite(Some(params)) => { + let params = params.todos.iter().map(|todo| &todo.content).join(", "); + acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {params}") + } else { + params + }, + } + } + Self::Terminal(None) + | Self::Task(None) + | Self::NotebookRead(None) + | Self::ExitPlanMode(None) + | Self::Ls(None) + | Self::Glob(None) + | Self::Grep(None) + | Self::ReadFile(None) + | Self::WebSearch(None) + | Self::TodoWrite(None) + | Self::Other { .. } => acp::ToolCallConfirmation::Other { + description: description.unwrap_or("".to_string()), + }, } } pub fn locations(&self) -> Vec { match &self { - Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation { + Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![ToolCallLocation { path: abs_path.clone(), line: None, }], Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { - vec![acp::ToolCallLocation { - path: file_path.clone(), - line: None, - }] - } - Self::Write(Some(WriteToolParams { file_path, .. })) => { - vec![acp::ToolCallLocation { + vec![ToolCallLocation { path: file_path.clone(), line: None, }] } + Self::Write(Some(WriteToolParams { file_path, .. })) => vec![ToolCallLocation { + path: file_path.clone(), + line: None, + }], Self::ReadFile(Some(ReadToolParams { abs_path, offset, .. - })) => vec![acp::ToolCallLocation { + })) => vec![ToolCallLocation { path: abs_path.clone(), line: *offset, }], Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { - vec![acp::ToolCallLocation { + vec![ToolCallLocation { path: notebook_path.clone(), line: None, }] } Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { - vec![acp::ToolCallLocation { + vec![ToolCallLocation { path: notebook_path.clone(), line: None, }] } Self::Glob(Some(GlobToolParams { path: Some(path), .. - })) => vec![acp::ToolCallLocation { + })) => vec![ToolCallLocation { path: path.clone(), line: None, }], - Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation { + Self::Ls(Some(LsToolParams { path, .. })) => vec![ToolCallLocation { path: path.clone(), line: None, }], Self::Grep(Some(GrepToolParams { path: Some(path), .. - })) => vec![acp::ToolCallLocation { + })) => vec![ToolCallLocation { path: PathBuf::from(path), line: None, }], @@ -303,15 +414,12 @@ impl ClaudeTool { } } - pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall { - acp::ToolCall { - id, - kind: self.kind(), - status: acp::ToolCallStatus::InProgress, + pub fn as_acp(&self) -> PushToolCallParams { + PushToolCallParams { label: self.label(), content: self.content(), + icon: self.icon(), locations: self.locations(), - raw_input: None, } } } diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index a60aefb7b9..12f74cb13e 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -1,17 +1,15 @@ -use std::{ - path::{Path, PathBuf}, - sync::Arc, - time::Duration, -}; +use std::{path::Path, sync::Arc, time::Duration}; use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; -use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; -use agent_client_protocol as acp; - +use acp_thread::{ + AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus, +}; +use agentic_coding_protocol as acp; use futures::{FutureExt, StreamExt, channel::mpsc, select}; use gpui::{Entity, TestAppContext}; use indoc::indoc; use project::{FakeFs, Project}; +use serde_json::json; use settings::{Settings, SettingsStore}; use util::path; @@ -26,11 +24,7 @@ pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppCont .unwrap(); thread.read_with(cx, |thread, _| { - assert!( - thread.entries().len() >= 2, - "Expected at least 2 entries. Got: {:?}", - thread.entries() - ); + assert_eq!(thread.entries().len(), 2); assert!(matches!( thread.entries()[0], AgentThreadEntry::UserMessage(_) @@ -60,25 +54,19 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes thread .update(cx, |thread, cx| { thread.send( - vec![ - acp::ContentBlock::Text(acp::TextContent { - text: "Read the file ".into(), - annotations: None, - }), - acp::ContentBlock::ResourceLink(acp::ResourceLink { - uri: "foo.rs".into(), - name: "foo.rs".into(), - annotations: None, - description: None, - mime_type: None, - size: None, - title: None, - }), - acp::ContentBlock::Text(acp::TextContent { - text: " and tell me what the content of the println! is".into(), - annotations: None, - }), - ], + acp::SendUserMessageParams { + chunks: vec![ + acp::UserMessageChunk::Text { + text: "Read the file ".into(), + }, + acp::UserMessageChunk::Path { + path: Path::new("foo.rs").into(), + }, + acp::UserMessageChunk::Text { + text: " and tell me what the content of the println! is".into(), + }, + ], + }, cx, ) }) @@ -86,44 +74,37 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes .unwrap(); thread.read_with(cx, |thread, cx| { + assert_eq!(thread.entries().len(), 3); assert!(matches!( thread.entries()[0], AgentThreadEntry::UserMessage(_) )); - let assistant_message = &thread - .entries() - .iter() - .rev() - .find_map(|entry| match entry { - AgentThreadEntry::AssistantMessage(msg) => Some(msg), - _ => None, - }) - .unwrap(); - + assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_))); + let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else { + panic!("Expected AssistantMessage") + }; assert!( assistant_message.to_markdown(cx).contains("Hello, world!"), "unexpected assistant message: {:?}", assistant_message.to_markdown(cx) ); }); - - drop(tempdir); } pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) { - let _fs = init_test(cx).await; - - let tempdir = tempfile::tempdir().unwrap(); - let foo_path = tempdir.path().join("foo"); - std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file"); - - let project = Project::example([tempdir.path()], &mut cx.to_async()).await; + let fs = init_test(cx).await; + fs.insert_tree( + path!("/private/tmp"), + json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), + ) + .await; + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; thread .update(cx, |thread, cx| { thread.send_raw( - &format!("Read {} and tell me what you see.", foo_path.display()), + "Read the '/private/tmp/foo' file and tell me what you see.", cx, ) }) @@ -146,13 +127,10 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) }) ); }); - - drop(tempdir); } -pub async fn test_tool_call_with_permission( +pub async fn test_tool_call_with_confirmation( server: impl AgentServer + 'static, - allow_option_id: acp::PermissionOptionId, cx: &mut TestAppContext, ) { let fs = init_test(cx).await; @@ -160,7 +138,7 @@ pub async fn test_tool_call_with_permission( let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let full_turn = thread.update(cx, |thread, cx| { thread.send_raw( - r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, + r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#, cx, ) }); @@ -180,11 +158,14 @@ pub async fn test_tool_call_with_permission( ) .await; - let tool_call_id = thread.read_with(cx, |thread, cx| { + let tool_call_id = thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { id, - label, - status: ToolCallStatus::WaitingForConfirmation { .. }, + status: + ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::Execute { root_command, .. }, + .. + }, .. }) = &thread .entries() @@ -195,19 +176,13 @@ pub async fn test_tool_call_with_permission( panic!(); }; - let label = label.read(cx).source(); - assert!(label.contains("touch"), "Got: {}", label); + assert!(root_command.contains("touch")); - id.clone() + *id }); thread.update(cx, |thread, cx| { - thread.authorize_tool_call( - tool_call_id, - allow_option_id, - acp::PermissionOptionKind::AllowOnce, - cx, - ); + thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); assert!(thread.entries().iter().any(|entry| matches!( entry, @@ -222,7 +197,7 @@ pub async fn test_tool_call_with_permission( thread.read_with(cx, |thread, cx| { let AgentThreadEntry::ToolCall(ToolCall { - content, + content: Some(ToolCallContent::Markdown { markdown }), status: ToolCallStatus::Allowed { .. }, .. }) = thread @@ -234,10 +209,13 @@ pub async fn test_tool_call_with_permission( panic!(); }; - assert!( - content.iter().any(|c| c.to_markdown(cx).contains("Hello")), - "Expected content to contain 'Hello'" - ); + markdown.read_with(cx, |md, _cx| { + assert!( + md.source().contains("Hello"), + r#"Expected '{}' to contain "Hello""#, + md.source() + ); + }); }); } @@ -248,7 +226,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let full_turn = thread.update(cx, |thread, cx| { thread.send_raw( - r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, + r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#, cx, ) }); @@ -268,24 +246,29 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon ) .await; - thread.read_with(cx, |thread, cx| { + thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { id, - label, - status: ToolCallStatus::WaitingForConfirmation { .. }, + status: + ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::Execute { root_command, .. }, + .. + }, .. }) = &thread.entries()[first_tool_call_ix] else { panic!("{:?}", thread.entries()[1]); }; - let label = label.read(cx).source(); - assert!(label.contains("touch"), "Got: {}", label); + assert!(root_command.contains("touch")); - id.clone() + *id }); - let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); + thread + .update(cx, |thread, cx| thread.cancel(cx)) + .await + .unwrap(); full_turn.await.unwrap(); thread.read_with(cx, |thread, _| { let AgentThreadEntry::ToolCall(ToolCall { @@ -313,7 +296,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon #[macro_export] macro_rules! common_e2e_tests { - ($server:expr, allow_option_id = $allow_option_id:expr) => { + ($server:expr) => { mod common_e2e { use super::*; @@ -337,13 +320,8 @@ macro_rules! common_e2e_tests { #[::gpui::test] #[cfg_attr(not(feature = "e2e"), ignore)] - async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) { - $crate::e2e_tests::test_tool_call_with_permission( - $server, - ::agent_client_protocol::PermissionOptionId($allow_option_id.into()), - cx, - ) - .await; + async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await; } #[::gpui::test] @@ -391,16 +369,15 @@ pub async fn new_test_thread( current_dir: impl AsRef, cx: &mut TestAppContext, ) -> Entity { - let connection = cx - .update(|cx| server.connect(current_dir.as_ref(), &project, cx)) + let thread = cx + .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx)) .await .unwrap(); - let thread = connection - .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) + thread + .update(cx, |thread, _| thread.initialize()) .await .unwrap(); - thread } @@ -433,24 +410,3 @@ pub async fn run_until_first_tool_call( } } } - -pub fn get_zed_path() -> PathBuf { - let mut zed_path = std::env::current_exe().unwrap(); - - while zed_path - .file_name() - .map_or(true, |name| name.to_string_lossy() != "debug") - { - if !zed_path.pop() { - panic!("Could not find target directory"); - } - } - - zed_path.push("zed"); - - if !zed_path.exists() { - panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n"); - } - - zed_path -} diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 2366783d22..8ad147cbff 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,13 +1,9 @@ -use std::path::Path; -use std::rc::Rc; - -use crate::{AgentServer, AgentServerCommand}; -use acp_thread::AgentConnection; -use anyhow::Result; -use gpui::{Entity, Task}; +use crate::stdio_agent_server::StdioAgentServer; +use crate::{AgentServerCommand, AgentServerVersion}; +use anyhow::{Context as _, Result}; +use gpui::{AsyncApp, Entity}; use project::Project; use settings::SettingsStore; -use ui::App; use crate::AllAgentServersSettings; @@ -16,7 +12,7 @@ pub struct Gemini; const ACP_ARG: &str = "--experimental-acp"; -impl AgentServer for Gemini { +impl StdioAgentServer for Gemini { fn name(&self) -> &'static str { "Gemini" } @@ -29,33 +25,79 @@ impl AgentServer for Gemini { "Ask questions, edit files, run commands.\nBe specific for the best results." } + fn supports_always_allow(&self) -> bool { + true + } + fn logo(&self) -> ui::IconName { ui::IconName::AiGemini } - fn connect( + async fn command( &self, - root_dir: &Path, project: &Entity, - cx: &mut App, - ) -> Task>> { - let project = project.clone(); - let root_dir = root_dir.to_path_buf(); - let server_name = self.name(); - cx.spawn(async move |cx| { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).gemini.clone() - })?; + cx: &mut AsyncApp, + ) -> Result { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() + })?; - let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await - else { - anyhow::bail!("Failed to find gemini binary"); - }; + if let Some(command) = + AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + { + return Ok(command); + }; - crate::acp::connect(server_name, command, &root_dir, cx).await + let (fs, node_runtime) = project.update(cx, |project, _| { + (project.fs().clone(), project.node_runtime().cloned()) + })?; + let node_runtime = node_runtime.context("gemini not found on path")?; + + let directory = ::paths::agent_servers_dir().join("gemini"); + fs.create_dir(&directory).await?; + node_runtime + .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) + .await?; + let path = directory.join("node_modules/.bin/gemini"); + + Ok(AgentServerCommand { + path, + args: vec![ACP_ARG.into()], + env: None, }) } + + async fn version(&self, command: &AgentServerCommand) -> Result { + let version_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--version") + .kill_on_drop(true) + .output(); + + let help_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--help") + .kill_on_drop(true) + .output(); + + let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; + + let current_version = String::from_utf8(version_output?.stdout)?; + let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); + + if supported { + Ok(AgentServerVersion::Supported) + } else { + Ok(AgentServerVersion::Unsupported { + error_message: format!( + "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", + current_version + ).into(), + upgrade_message: "Upgrade Gemini to Latest".into(), + upgrade_command: "npm install -g @google/gemini-cli@latest".into(), + }) + } + } } #[cfg(test)] @@ -64,7 +106,7 @@ pub(crate) mod tests { use crate::AgentServerCommand; use std::path::Path; - crate::common_e2e_tests!(Gemini, allow_option_id = "proceed_once"); + crate::common_e2e_tests!(Gemini); pub fn local_command() -> AgentServerCommand { let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) @@ -74,7 +116,7 @@ pub(crate) mod tests { AgentServerCommand { path: "node".into(), - args: vec![cli_path], + args: vec![cli_path, ACP_ARG.into()], env: None, } } diff --git a/crates/agent_servers/src/stdio_agent_server.rs b/crates/agent_servers/src/stdio_agent_server.rs new file mode 100644 index 0000000000..e60dd39de4 --- /dev/null +++ b/crates/agent_servers/src/stdio_agent_server.rs @@ -0,0 +1,119 @@ +use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; +use acp_thread::{AcpClientDelegate, AcpThread, LoadError}; +use agentic_coding_protocol as acp; +use anyhow::{Result, anyhow}; +use gpui::{App, AsyncApp, Entity, Task, prelude::*}; +use project::Project; +use std::path::Path; +use util::ResultExt; + +pub trait StdioAgentServer: Send + Clone { + fn logo(&self) -> ui::IconName; + fn name(&self) -> &'static str; + fn empty_state_headline(&self) -> &'static str; + fn empty_state_message(&self) -> &'static str; + fn supports_always_allow(&self) -> bool; + + fn command( + &self, + project: &Entity, + cx: &mut AsyncApp, + ) -> impl Future>; + + fn version( + &self, + command: &AgentServerCommand, + ) -> impl Future> + Send; +} + +impl AgentServer for T { + fn name(&self) -> &'static str { + self.name() + } + + fn empty_state_headline(&self) -> &'static str { + self.empty_state_headline() + } + + fn empty_state_message(&self) -> &'static str { + self.empty_state_message() + } + + fn logo(&self) -> ui::IconName { + self.logo() + } + + fn supports_always_allow(&self) -> bool { + self.supports_always_allow() + } + + fn new_thread( + &self, + root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + let root_dir = root_dir.to_path_buf(); + let project = project.clone(); + let this = self.clone(); + let title = self.name().into(); + + cx.spawn(async move |cx| { + let command = this.command(&project, cx).await?; + + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + cx.new(|cx| { + let foreground_executor = cx.foreground_executor().clone(); + + let (connection, io_fut) = acp::AgentConnection::connect_to_agent( + AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => { + if let Some(AgentServerVersion::Unsupported { + error_message, + upgrade_message, + upgrade_command, + }) = this.version(&command).await.log_err() + { + Err(anyhow!(LoadError::Unsupported { + error_message, + upgrade_message, + upgrade_command + })) + } else { + Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) + } + } + }; + drop(io_task); + result + }); + + AcpThread::new(connection, title, Some(child_status), project.clone(), cx) + }) + }) + } +} diff --git a/crates/agent_settings/Cargo.toml b/crates/agent_settings/Cargo.toml index d34396a5d3..3afe5ae547 100644 --- a/crates/agent_settings/Cargo.toml +++ b/crates/agent_settings/Cargo.toml @@ -13,7 +13,6 @@ path = "src/agent_settings.rs" [dependencies] anyhow.workspace = true -cloud_llm_client.workspace = true collections.workspace = true gpui.workspace = true language_model.workspace = true @@ -21,6 +20,7 @@ schemars.workspace = true serde.workspace = true settings.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] fs.workspace = true diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 4e872c78d7..13b966608c 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -321,11 +321,11 @@ pub enum CompletionMode { Burn, } -impl From for cloud_llm_client::CompletionMode { +impl From for zed_llm_client::CompletionMode { fn from(value: CompletionMode) -> Self { match value { - CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal, - CompletionMode::Burn => cloud_llm_client::CompletionMode::Max, + CompletionMode::Normal => zed_llm_client::CompletionMode::Normal, + CompletionMode::Burn => zed_llm_client::CompletionMode::Max, } } } diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 95fd2b1757..7d3b84e42e 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -17,10 +17,10 @@ test-support = ["gpui/test-support", "language/test-support"] [dependencies] acp_thread.workspace = true -agent-client-protocol.workspace = true agent.workspace = true -agent_servers.workspace = true +agentic-coding-protocol.workspace = true agent_settings.workspace = true +agent_servers.workspace = true ai_onboarding.workspace = true anyhow.workspace = true assistant_context.workspace = true @@ -31,7 +31,6 @@ audio.workspace = true buffer_diff.workspace = true chrono.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true component.workspace = true @@ -47,9 +46,9 @@ futures.workspace = true fuzzy.workspace = true gpui.workspace = true html_to_markdown.workspace = true +indoc.workspace = true http_client.workspace = true indexed_docs.workspace = true -indoc.workspace = true inventory.workspace = true itertools.workspace = true jsonschema.workspace = true @@ -98,6 +97,7 @@ watch.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true +zed_llm_client.workspace = true [dev-dependencies] assistant_tools.workspace = true diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 24d8b73396..95f4f81205 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,7 +1,5 @@ -use acp_thread::{AgentConnection, Plan}; +use acp_thread::Plan; use agent_servers::AgentServer; -use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; -use audio::{Audio, Sound}; use std::cell::RefCell; use std::collections::BTreeMap; use std::path::Path; @@ -9,7 +7,7 @@ use std::rc::Rc; use std::sync::Arc; use std::time::Duration; -use agent_client_protocol as acp; +use agentic_coding_protocol::{self as acp}; use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use collections::{HashMap, HashSet}; @@ -18,12 +16,13 @@ use editor::{ EditorStyle, MinimapVisibility, MultiBuffer, PathKey, }; use file_icons::FileIcons; +use futures::channel::oneshot; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, - FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString, - StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation, - UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient, - list, percentage, point, prelude::*, pulsating_between, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, + Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, + Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*, + pulsating_between, }; use language::language_settings::SoftWrap; use language::{Buffer, Language}; @@ -31,7 +30,7 @@ use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::Project; use settings::Settings as _; -use text::{Anchor, BufferSnapshot}; +use text::Anchor; use theme::ThemeSettings; use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*}; use util::ResultExt; @@ -40,17 +39,15 @@ use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; use ::acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, - LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, + LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent, + ToolCallId, ToolCallStatus, }; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::message_history::MessageHistory; use crate::agent_diff::AgentDiff; use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; -use crate::ui::{AgentNotification, AgentNotificationEvent}; -use crate::{ - AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll, -}; +use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll}; const RESPONSE_PADDING_X: Pixels = px(19.); @@ -61,21 +58,18 @@ pub struct AcpThreadView { thread_state: ThreadState, diff_editors: HashMap>, message_editor: Entity, - message_set_from_history: Option, + message_set_from_history: bool, _message_editor_subscription: Subscription, mention_set: Arc>, - notifications: Vec>, - notification_subscriptions: HashMap, Vec>, last_error: Option>, list_state: ListState, auth_task: Option>, - expanded_tool_calls: HashSet, + expanded_tool_calls: HashSet, expanded_thinking_blocks: HashSet<(usize, usize)>, edits_expanded: bool, plan_expanded: bool, editor_expanded: bool, - message_history: Rc>>>, - _cancel_task: Option>, + message_history: Rc>>, } enum ThreadState { @@ -88,16 +82,22 @@ enum ThreadState { }, LoadError(LoadError), Unauthenticated { - connection: Rc, + thread: Entity, }, } +struct AlwaysAllowOption { + id: &'static str, + label: SharedString, + outcome: acp::ToolCallConfirmationOutcome, +} + impl AcpThreadView { pub fn new( agent: Rc, workspace: WeakEntity, project: Entity, - message_history: Rc>>>, + message_history: Rc>>, min_lines: usize, max_lines: Option, window: &mut Window, @@ -144,28 +144,14 @@ impl AcpThreadView { editor }); - let message_editor_subscription = - cx.subscribe(&message_editor, |this, editor, event, cx| { - if let editor::EditorEvent::BufferEdited = &event { - let buffer = editor - .read(cx) - .buffer() - .read(cx) - .as_singleton() - .unwrap() - .read(cx) - .snapshot(); - if let Some(message) = this.message_set_from_history.clone() - && message.version() != buffer.version() - { - this.message_set_from_history = None; - } - - if this.message_set_from_history.is_none() { - this.message_history.borrow_mut().reset_position(); - } + let message_editor_subscription = cx.subscribe(&message_editor, |this, _, event, _| { + if let editor::EditorEvent::BufferEdited = &event { + if !this.message_set_from_history { + this.message_history.borrow_mut().reset_position(); } - }); + this.message_set_from_history = false; + } + }); let mention_set = mention_set.clone(); @@ -192,11 +178,9 @@ impl AcpThreadView { project: project.clone(), thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, - message_set_from_history: None, + message_set_from_history: false, _message_editor_subscription: message_editor_subscription, mention_set, - notifications: Vec::new(), - notification_subscriptions: HashMap::default(), diff_editors: Default::default(), list_state: list_state, last_error: None, @@ -207,7 +191,6 @@ impl AcpThreadView { plan_expanded: false, editor_expanded: false, message_history, - _cancel_task: None, } } @@ -225,9 +208,9 @@ impl AcpThreadView { .map(|worktree| worktree.read(cx).abs_path()) .unwrap_or_else(|| paths::home_dir().as_path().into()); - let connect_task = agent.connect(&root_dir, &project, cx); + let task = agent.new_thread(&root_dir, &project, cx); let load_task = cx.spawn_in(window, async move |this, cx| { - let connection = match connect_task.await { + let thread = match task.await { Ok(thread) => thread, Err(err) => { this.update(cx, |this, cx| { @@ -239,30 +222,48 @@ impl AcpThreadView { } }; - let result = match connection - .clone() - .new_thread(project.clone(), &root_dir, cx) - .await - { + let init_response = async { + let resp = thread + .read_with(cx, |thread, _cx| thread.initialize())? + .await?; + anyhow::Ok(resp) + }; + + let result = match init_response.await { Err(e) => { let mut cx = cx.clone(); - if e.is::() { - this.update(&mut cx, |this, cx| { - this.thread_state = ThreadState::Unauthenticated { connection }; - cx.notify(); - }) - .ok(); - return; + if e.downcast_ref::().is_some() { + let child_status = thread + .update(&mut cx, |thread, _| thread.child_status()) + .ok() + .flatten(); + if let Some(child_status) = child_status { + match child_status.await { + Ok(_) => Err(e), + Err(e) => Err(e), + } + } else { + Err(e) + } } else { Err(e) } } - Ok(session_id) => Ok(session_id), + Ok(response) => { + if !response.is_authenticated { + this.update(cx, |this, _| { + this.thread_state = ThreadState::Unauthenticated { thread }; + }) + .ok(); + return; + }; + Ok(()) + } }; this.update_in(cx, |this, window, cx| { match result { - Ok(thread) => { + Ok(()) => { let thread_subscription = cx.subscribe_in(&thread, window, Self::handle_thread_event); @@ -304,10 +305,10 @@ impl AcpThreadView { pub fn thread(&self) -> Option<&Entity> { match &self.thread_state { - ThreadState::Ready { thread, .. } => Some(thread), - ThreadState::Unauthenticated { .. } - | ThreadState::Loading { .. } - | ThreadState::LoadError(..) => None, + ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { + Some(thread) + } + ThreadState::Loading { .. } | ThreadState::LoadError(..) => None, } } @@ -324,7 +325,7 @@ impl AcpThreadView { self.last_error.take(); if let Some(thread) = self.thread() { - self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); + thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); } } @@ -361,7 +362,7 @@ impl AcpThreadView { self.last_error.take(); let mut ix = 0; - let mut chunks: Vec = Vec::new(); + let mut chunks: Vec = Vec::new(); let project = self.project.clone(); self.message_editor.update(cx, |editor, cx| { let text = editor.text(cx); @@ -373,19 +374,12 @@ impl AcpThreadView { { let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); if crease_range.start > ix { - chunks.push(text[ix..crease_range.start].into()); + chunks.push(acp::UserMessageChunk::Text { + text: text[ix..crease_range.start].to_string(), + }); } if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { - let path_str = abs_path.display().to_string(); - chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink { - uri: path_str.clone(), - name: path_str, - annotations: None, - description: None, - mime_type: None, - size: None, - title: None, - })); + chunks.push(acp::UserMessageChunk::Path { path: abs_path }); } ix = crease_range.end; } @@ -394,7 +388,9 @@ impl AcpThreadView { if ix < text.len() { let last_chunk = text[ix..].trim(); if !last_chunk.is_empty() { - chunks.push(last_chunk.into()); + chunks.push(acp::UserMessageChunk::Text { + text: last_chunk.into(), + }); } } }) @@ -404,10 +400,9 @@ impl AcpThreadView { return; } - let Some(thread) = self.thread() else { - return; - }; - let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); + let Some(thread) = self.thread() else { return }; + let message = acp::SendUserMessageParams { chunks }; + let task = thread.update(cx, |thread, cx| thread.send(message.clone(), cx)); cx.spawn(async move |this, cx| { let result = task.await; @@ -424,15 +419,12 @@ impl AcpThreadView { let mention_set = self.mention_set.clone(); self.set_editor_is_expanded(false, cx); - self.message_editor.update(cx, |editor, cx| { editor.clear(window, cx); editor.remove_creases(mention_set.lock().drain(), cx) }); - self.scroll_to_bottom(cx); - - self.message_history.borrow_mut().push(chunks); + self.message_history.borrow_mut().push(message); } fn previous_history_message( @@ -441,21 +433,11 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - if self.message_set_from_history.is_none() && !self.message_editor.read(cx).is_empty(cx) { - self.message_editor.update(cx, |editor, cx| { - editor.move_up(&Default::default(), window, cx); - }); - return; - } - self.message_set_from_history = Self::set_draft_message( self.message_editor.clone(), self.mention_set.clone(), self.project.clone(), - self.message_history - .borrow_mut() - .prev() - .map(|blocks| blocks.as_slice()), + self.message_history.borrow_mut().prev(), window, cx, ); @@ -467,35 +449,14 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - if self.message_set_from_history.is_none() { - self.message_editor.update(cx, |editor, cx| { - editor.move_down(&Default::default(), window, cx); - }); - return; - } - - let mut message_history = self.message_history.borrow_mut(); - let next_history = message_history.next(); - - let set_draft_message = Self::set_draft_message( + self.message_set_from_history = Self::set_draft_message( self.message_editor.clone(), self.mention_set.clone(), self.project.clone(), - Some( - next_history - .map(|blocks| blocks.as_slice()) - .unwrap_or_else(|| &[]), - ), + self.message_history.borrow_mut().next(), window, cx, ); - // If we reset the text to an empty string because we ran out of history, - // we don't want to mark it as coming from the history - self.message_set_from_history = if next_history.is_some() { - set_draft_message - } else { - None - }; } fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { @@ -529,30 +490,31 @@ impl AcpThreadView { message_editor: Entity, mention_set: Arc>, project: Entity, - message: Option<&[acp::ContentBlock]>, + message: Option<&acp::SendUserMessageParams>, window: &mut Window, cx: &mut Context, - ) -> Option { + ) -> bool { cx.notify(); - let message = message?; + let Some(message) = message else { + return false; + }; let mut text = String::new(); let mut mentions = Vec::new(); - for chunk in message { + for chunk in &message.chunks { match chunk { - acp::ContentBlock::Text(text_content) => { - text.push_str(&text_content.text); + acp::UserMessageChunk::Text { text: chunk } => { + text.push_str(&chunk); } - acp::ContentBlock::ResourceLink(resource_link) => { - let path = Path::new(&resource_link.uri); + acp::UserMessageChunk::Path { path } => { let start = text.len(); - let content = MentionPath::new(&path).to_string(); + let content = MentionPath::new(path).to_string(); text.push_str(&content); let end = text.len(); if let Some(project_path) = - project.read(cx).project_path_for_absolute_path(&path, cx) + project.read(cx).project_path_for_absolute_path(path, cx) { let filename: SharedString = path .file_name() @@ -563,9 +525,6 @@ impl AcpThreadView { mentions.push((start..end, project_path, filename)); } } - acp::ContentBlock::Image(_) - | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => {} } } @@ -599,8 +558,7 @@ impl AcpThreadView { } } - let snapshot = snapshot.as_singleton().unwrap().2.clone(); - Some(snapshot.text) + true } fn handle_thread_event( @@ -622,30 +580,6 @@ impl AcpThreadView { self.sync_thread_entry_view(index, window, cx); self.list_state.splice(index..index + 1, 1); } - AcpThreadEvent::ToolAuthorizationRequired => { - self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); - } - AcpThreadEvent::Stopped => { - let used_tools = thread.read(cx).used_tools_since_last_user_message(); - self.notify_with_sound( - if used_tools { - "Finished running tools" - } else { - "New message" - }, - IconName::ZedAssistant, - window, - cx, - ); - } - AcpThreadEvent::Error => { - self.notify_with_sound( - "Agent stopped due to an error", - IconName::Warning, - window, - cx, - ); - } } cx.notify(); } @@ -656,84 +590,71 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { + let Some(multibuffer) = self.entry_diff_multibuffer(entry_ix, cx) else { return; }; - let multibuffers = multibuffers.collect::>(); - - for multibuffer in multibuffers { - if self.diff_editors.contains_key(&multibuffer.entity_id()) { - return; - } - - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) - .into(), - ), - ..Default::default() - }); - editor - }); - let entity_id = multibuffer.entity_id(); - cx.observe_release(&multibuffer, move |this, _, _| { - this.diff_editors.remove(&entity_id); - }) - .detach(); - - self.diff_editors.insert(entity_id, editor); + if self.diff_editors.contains_key(&multibuffer.entity_id()) { + return; } + + let editor = cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + multibuffer.clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + }); + editor + }); + let entity_id = multibuffer.entity_id(); + cx.observe_release(&multibuffer, move |this, _, _| { + this.diff_editors.remove(&entity_id); + }) + .detach(); + + self.diff_editors.insert(entity_id, editor); } - fn entry_diff_multibuffers( - &self, - entry_ix: usize, - cx: &App, - ) -> Option>> { + fn entry_diff_multibuffer(&self, entry_ix: usize, cx: &App) -> Option> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some(entry.diffs().map(|diff| diff.multibuffer.clone())) + entry.diff().map(|diff| diff.multibuffer.clone()) } - fn authenticate( - &mut self, - method: acp::AuthMethodId, - window: &mut Window, - cx: &mut Context, - ) { - let ThreadState::Unauthenticated { ref connection } = self.thread_state else { + fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { + let Some(thread) = self.thread().cloned() else { return; }; self.last_error.take(); - let authenticate = connection.authenticate(method, cx); + let authenticate = thread.read(cx).authenticate(); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); let agent = self.agent.clone(); @@ -763,16 +684,15 @@ impl AcpThreadView { fn authorize_tool_call( &mut self, - tool_call_id: acp::ToolCallId, - option_id: acp::PermissionOptionId, - option_kind: acp::PermissionOptionKind, + id: ToolCallId, + outcome: acp::ToolCallConfirmationOutcome, cx: &mut Context, ) { let Some(thread) = self.thread() else { return; }; thread.update(cx, |thread, cx| { - thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx); + thread.authorize_tool_call(id, outcome, cx); }); cx.notify(); } @@ -799,12 +719,10 @@ impl AcpThreadView { .border_1() .border_color(cx.theme().colors().border) .text_xs() - .children(message.content.markdown().map(|md| { - self.render_markdown( - md.clone(), - user_message_markdown_style(window, cx), - ) - })), + .child(self.render_markdown( + message.content.clone(), + user_message_markdown_style(window, cx), + )), ) .into_any(), AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { @@ -812,28 +730,20 @@ impl AcpThreadView { let message_body = v_flex() .w_full() .gap_2p5() - .children(chunks.iter().enumerate().filter_map( - |(chunk_ix, chunk)| match chunk { - AssistantMessageChunk::Message { block } => { - block.markdown().map(|md| { - self.render_markdown(md.clone(), style.clone()) - .into_any_element() - }) - } - AssistantMessageChunk::Thought { block } => { - block.markdown().map(|md| { - self.render_thinking_block( - index, - chunk_ix, - md.clone(), - window, - cx, - ) - .into_any_element() - }) - } - }, - )) + .children(chunks.iter().enumerate().map(|(chunk_ix, chunk)| { + match chunk { + AssistantMessageChunk::Text { chunk } => self + .render_markdown(chunk.clone(), style.clone()) + .into_any_element(), + AssistantMessageChunk::Thought { chunk } => self.render_thinking_block( + index, + chunk_ix, + chunk.clone(), + window, + cx, + ), + } + })) .into_any(); v_flex() @@ -959,12 +869,9 @@ impl AcpThreadView { let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix)); let status_icon = match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { .. } => None, ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Pending, - } - | ToolCallStatus::WaitingForConfirmation { .. } => None, - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress, + status: acp::ToolCallStatus::Running, .. } => Some( Icon::new(IconName::ArrowCircle) @@ -978,13 +885,13 @@ impl AcpThreadView { .into_any(), ), ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Completed, + status: acp::ToolCallStatus::Finished, .. } => None, ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Failed, + status: acp::ToolCallStatus::Error, .. } => Some( Icon::new(IconName::X) @@ -1002,9 +909,34 @@ impl AcpThreadView { .any(|content| matches!(content, ToolCallContent::Diff { .. })), }; - let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; + let is_collapsible = tool_call.content.is_some() && !needs_confirmation; let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); + let content = if is_open { + match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { confirmation, .. } => { + Some(self.render_tool_call_confirmation( + tool_call.id, + confirmation, + tool_call.content.as_ref(), + window, + cx, + )) + } + ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { + tool_call.content.as_ref().map(|content| { + div() + .py_1p5() + .child(self.render_tool_call_content(content, window, cx)) + .into_any_element() + }) + } + ToolCallStatus::Rejected => None, + } + } else { + None + }; + v_flex() .when(needs_confirmation, |this| { this.rounded_lg() @@ -1044,19 +976,9 @@ impl AcpThreadView { }) .gap_1p5() .child( - Icon::new(match tool_call.kind { - acp::ToolKind::Read => IconName::ToolRead, - acp::ToolKind::Edit => IconName::ToolPencil, - acp::ToolKind::Delete => IconName::ToolDeleteFile, - acp::ToolKind::Move => IconName::ArrowRightLeft, - acp::ToolKind::Search => IconName::ToolSearch, - acp::ToolKind::Execute => IconName::ToolTerminal, - acp::ToolKind::Think => IconName::ToolBulb, - acp::ToolKind::Fetch => IconName::ToolWeb, - acp::ToolKind::Other => IconName::ToolHammer, - }) - .size(IconSize::Small) - .color(Color::Muted), + Icon::new(tool_call.icon) + .size(IconSize::Small) + .color(Color::Muted), ) .child(if tool_call.locations.len() == 1 { let name = tool_call.locations[0] @@ -1101,16 +1023,16 @@ impl AcpThreadView { .gap_0p5() .when(is_collapsible, |this| { this.child( - Disclosure::new(("expand", entry_ix), is_open) + Disclosure::new(("expand", tool_call.id.0), is_open) .opened_icon(IconName::ChevronUp) .closed_icon(IconName::ChevronDown) .on_click(cx.listener({ - let id = tool_call.id.clone(); + let id = tool_call.id; move |this: &mut Self, _, _, cx: &mut Context| { if is_open { this.expanded_tool_calls.remove(&id); } else { - this.expanded_tool_calls.insert(id.clone()); + this.expanded_tool_calls.insert(id); } cx.notify(); } @@ -1120,12 +1042,12 @@ impl AcpThreadView { .children(status_icon), ) .on_click(cx.listener({ - let id = tool_call.id.clone(); + let id = tool_call.id; move |this: &mut Self, _, _, cx: &mut Context| { if is_open { this.expanded_tool_calls.remove(&id); } else { - this.expanded_tool_calls.insert(id.clone()); + this.expanded_tool_calls.insert(id); } cx.notify(); } @@ -1133,7 +1055,7 @@ impl AcpThreadView { ) .when(is_open, |this| { this.child( - v_flex() + div() .text_xs() .when(is_collapsible, |this| { this.mt_1() @@ -1142,45 +1064,7 @@ impl AcpThreadView { .bg(cx.theme().colors().editor_background) .rounded_lg() }) - .map(|this| { - if is_open { - match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { options, .. } => this - .children(tool_call.content.iter().map(|content| { - div() - .py_1p5() - .child( - self.render_tool_call_content( - content, window, cx, - ), - ) - .into_any_element() - })) - .child(self.render_permission_buttons( - options, - entry_ix, - tool_call.id.clone(), - tool_call.content.is_empty(), - cx, - )), - ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { - this.children(tool_call.content.iter().map(|content| { - div() - .py_1p5() - .child( - self.render_tool_call_content( - content, window, cx, - ), - ) - .into_any_element() - })) - } - ToolCallStatus::Rejected => this, - } - } else { - this - } - }), + .children(content), ) }) } @@ -1192,20 +1076,14 @@ impl AcpThreadView { cx: &Context, ) -> AnyElement { match content { - ToolCallContent::ContentBlock { content } => { - if let Some(md) = content.markdown() { - div() - .p_2() - .child( - self.render_markdown( - md.clone(), - default_markdown_style(false, window, cx), - ), - ) - .into_any_element() - } else { - Empty.into_any_element() - } + ToolCallContent::Markdown { markdown } => { + div() + .p_2() + .child(self.render_markdown( + markdown.clone(), + default_markdown_style(false, window, cx), + )) + .into_any_element() } ToolCallContent::Diff { diff: Diff { multibuffer, .. }, @@ -1214,56 +1092,223 @@ impl AcpThreadView { } } - fn render_permission_buttons( + fn render_tool_call_confirmation( &self, - options: &[acp::PermissionOption], - entry_ix: usize, - tool_call_id: acp::ToolCallId, - empty_content: bool, + tool_call_id: ToolCallId, + confirmation: &ToolCallConfirmation, + content: Option<&ToolCallContent>, + window: &Window, + cx: &Context, + ) -> AnyElement { + let confirmation_container = v_flex().mt_1().py_1p5(); + + match confirmation { + ToolCallConfirmation::Edit { description } => confirmation_container + .child( + div() + .px_2() + .children(description.clone().map(|description| { + self.render_markdown( + description, + default_markdown_style(false, window, cx), + ) + })), + ) + .children(content.map(|content| self.render_tool_call_content(content, window, cx))) + .child(self.render_confirmation_buttons( + &[AlwaysAllowOption { + id: "always_allow", + label: "Always Allow Edits".into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, + }], + tool_call_id, + cx, + )) + .into_any(), + ToolCallConfirmation::Execute { + command, + root_command, + description, + } => confirmation_container + .child(v_flex().px_2().pb_1p5().child(command.clone()).children( + description.clone().map(|description| { + self.render_markdown(description, default_markdown_style(false, window, cx)) + .on_url_click({ + let workspace = self.workspace.clone(); + move |text, window, cx| { + Self::open_link(text, &workspace, window, cx); + } + }) + }), + )) + .children(content.map(|content| self.render_tool_call_content(content, window, cx))) + .child(self.render_confirmation_buttons( + &[AlwaysAllowOption { + id: "always_allow", + label: format!("Always Allow {root_command}").into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, + }], + tool_call_id, + cx, + )) + .into_any(), + ToolCallConfirmation::Mcp { + server_name, + tool_name: _, + tool_display_name, + description, + } => confirmation_container + .child( + v_flex() + .px_2() + .pb_1p5() + .child(format!("{server_name} - {tool_display_name}")) + .children(description.clone().map(|description| { + self.render_markdown( + description, + default_markdown_style(false, window, cx), + ) + })), + ) + .children(content.map(|content| self.render_tool_call_content(content, window, cx))) + .child(self.render_confirmation_buttons( + &[ + AlwaysAllowOption { + id: "always_allow_server", + label: format!("Always Allow {server_name}").into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, + }, + AlwaysAllowOption { + id: "always_allow_tool", + label: format!("Always Allow {tool_display_name}").into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowTool, + }, + ], + tool_call_id, + cx, + )) + .into_any(), + ToolCallConfirmation::Fetch { description, urls } => confirmation_container + .child( + v_flex() + .px_2() + .pb_1p5() + .gap_1() + .children(urls.iter().map(|url| { + h_flex().child( + Button::new(url.clone(), url) + .icon(IconName::ArrowUpRight) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click({ + let url = url.clone(); + move |_, _, cx| cx.open_url(&url) + }), + ) + })) + .children(description.clone().map(|description| { + self.render_markdown( + description, + default_markdown_style(false, window, cx), + ) + })), + ) + .children(content.map(|content| self.render_tool_call_content(content, window, cx))) + .child(self.render_confirmation_buttons( + &[AlwaysAllowOption { + id: "always_allow", + label: "Always Allow".into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, + }], + tool_call_id, + cx, + )) + .into_any(), + ToolCallConfirmation::Other { description } => confirmation_container + .child(v_flex().px_2().pb_1p5().child(self.render_markdown( + description.clone(), + default_markdown_style(false, window, cx), + ))) + .children(content.map(|content| self.render_tool_call_content(content, window, cx))) + .child(self.render_confirmation_buttons( + &[AlwaysAllowOption { + id: "always_allow", + label: "Always Allow".into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, + }], + tool_call_id, + cx, + )) + .into_any(), + } + } + + fn render_confirmation_buttons( + &self, + always_allow_options: &[AlwaysAllowOption], + tool_call_id: ToolCallId, cx: &Context, ) -> Div { h_flex() - .py_1p5() + .pt_1p5() .px_1p5() .gap_1() .justify_end() - .when(!empty_content, |this| { - this.border_t_1() - .border_color(self.tool_card_border_color(cx)) - }) - .children(options.iter().map(|option| { - let option_id = SharedString::from(option.id.0.clone()); - Button::new((option_id, entry_ix), option.label.clone()) - .map(|this| match option.kind { - acp::PermissionOptionKind::AllowOnce => { - this.icon(IconName::Check).icon_color(Color::Success) - } - acp::PermissionOptionKind::AllowAlways => { - this.icon(IconName::CheckDouble).icon_color(Color::Success) - } - acp::PermissionOptionKind::RejectOnce => { - this.icon(IconName::X).icon_color(Color::Error) - } - acp::PermissionOptionKind::RejectAlways => { - this.icon(IconName::X).icon_color(Color::Error) - } - }) + .border_t_1() + .border_color(self.tool_card_border_color(cx)) + .when(self.agent.supports_always_allow(), |this| { + this.children(always_allow_options.into_iter().map(|always_allow_option| { + let outcome = always_allow_option.outcome; + Button::new( + (always_allow_option.id, tool_call_id.0), + always_allow_option.label.clone(), + ) + .icon(IconName::CheckDouble) .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) + .icon_color(Color::Success) .on_click(cx.listener({ - let tool_call_id = tool_call_id.clone(); - let option_id = option.id.clone(); - let option_kind = option.kind; + let id = tool_call_id; + move |this, _, _, cx| { + this.authorize_tool_call(id, outcome, cx); + } + })) + })) + }) + .child( + Button::new(("allow", tool_call_id.0), "Allow") + .icon(IconName::Check) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Success) + .on_click(cx.listener({ + let id = tool_call_id; move |this, _, _, cx| { this.authorize_tool_call( - tool_call_id.clone(), - option_id.clone(), - option_kind, + id, + acp::ToolCallConfirmationOutcome::Allow, cx, ); } - })) - })) + })), + ) + .child( + Button::new(("reject", tool_call_id.0), "Reject") + .icon(IconName::X) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Error) + .on_click(cx.listener({ + let id = tool_call_id; + move |this, _, _, cx| { + this.authorize_tool_call( + id, + acp::ToolCallConfirmationOutcome::Reject, + cx, + ); + } + })), + ) } fn render_diff_editor(&self, multibuffer: &Entity) -> AnyElement { @@ -2025,15 +2070,15 @@ impl AcpThreadView { .icon_color(Color::Accent) .style(ButtonStyle::Filled) .disabled(self.thread().is_none() || is_editor_empty) + .on_click(cx.listener(|this, _, window, cx| { + this.chat(&Chat, window, cx); + })) .when(!is_editor_empty, |button| { button.tooltip(move |window, cx| Tooltip::for_action("Send", &Chat, window, cx)) }) .when(is_editor_empty, |button| { button.tooltip(Tooltip::text("Type a message to submit")) }) - .on_click(cx.listener(|this, _, window, cx| { - this.chat(&Chat, window, cx); - })) .into_any_element() } else { IconButton::new("stop-generation", IconName::StopFilled) @@ -2200,11 +2245,12 @@ impl AcpThreadView { .languages .language_for_name("Markdown"); - let (thread_summary, markdown) = if let Some(thread) = self.thread() { - let thread = thread.read(cx); - (thread.title().to_string(), thread.to_markdown(cx)) - } else { - return Task::ready(Ok(())); + let (thread_summary, markdown) = match &self.thread_state { + ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { + let thread = thread.read(cx); + (thread.title().to_string(), thread.to_markdown(cx)) + } + ThreadState::Loading { .. } | ThreadState::LoadError(..) => return Task::ready(Ok(())), }; window.spawn(cx, async move |cx| { @@ -2247,165 +2293,17 @@ impl AcpThreadView { self.list_state.scroll_to(ListOffset::default()); cx.notify(); } +} - pub fn scroll_to_bottom(&mut self, cx: &mut Context) { - if let Some(thread) = self.thread() { - let entry_count = thread.read(cx).entries().len(); - self.list_state.reset(entry_count); - cx.notify(); - } +impl Focusable for AcpThreadView { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.message_editor.focus_handle(cx) } +} - fn notify_with_sound( - &mut self, - caption: impl Into, - icon: IconName, - window: &mut Window, - cx: &mut Context, - ) { - self.play_notification_sound(window, cx); - self.show_notification(caption, icon, window, cx); - } - - fn play_notification_sound(&self, window: &Window, cx: &mut App) { - let settings = AgentSettings::get_global(cx); - if settings.play_sound_when_agent_done && !window.is_window_active() { - Audio::play_sound(Sound::AgentDone, cx); - } - } - - fn show_notification( - &mut self, - caption: impl Into, - icon: IconName, - window: &mut Window, - cx: &mut Context, - ) { - if window.is_window_active() || !self.notifications.is_empty() { - return; - } - - let title = self.title(cx); - - match AgentSettings::get_global(cx).notify_when_agent_waiting { - NotifyWhenAgentWaiting::PrimaryScreen => { - if let Some(primary) = cx.primary_display() { - self.pop_up(icon, caption.into(), title, window, primary, cx); - } - } - NotifyWhenAgentWaiting::AllScreens => { - let caption = caption.into(); - for screen in cx.displays() { - self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx); - } - } - NotifyWhenAgentWaiting::Never => { - // Don't show anything - } - } - } - - fn pop_up( - &mut self, - icon: IconName, - caption: SharedString, - title: SharedString, - window: &mut Window, - screen: Rc, - cx: &mut Context, - ) { - let options = AgentNotification::window_options(screen, cx); - - let project_name = self.workspace.upgrade().and_then(|workspace| { - workspace - .read(cx) - .project() - .read(cx) - .visible_worktrees(cx) - .next() - .map(|worktree| worktree.read(cx).root_name().to_string()) - }); - - if let Some(screen_window) = cx - .open_window(options, |_, cx| { - cx.new(|_| { - AgentNotification::new(title.clone(), caption.clone(), icon, project_name) - }) - }) - .log_err() - { - if let Some(pop_up) = screen_window.entity(cx).log_err() { - self.notification_subscriptions - .entry(screen_window) - .or_insert_with(Vec::new) - .push(cx.subscribe_in(&pop_up, window, { - |this, _, event, window, cx| match event { - AgentNotificationEvent::Accepted => { - let handle = window.window_handle(); - cx.activate(true); - - let workspace_handle = this.workspace.clone(); - - // If there are multiple Zed windows, activate the correct one. - cx.defer(move |cx| { - handle - .update(cx, |_view, window, _cx| { - window.activate_window(); - - if let Some(workspace) = workspace_handle.upgrade() { - workspace.update(_cx, |workspace, cx| { - workspace.focus_panel::(window, cx); - }); - } - }) - .log_err(); - }); - - this.dismiss_notifications(cx); - } - AgentNotificationEvent::Dismissed => { - this.dismiss_notifications(cx); - } - } - })); - - self.notifications.push(screen_window); - - // If the user manually refocuses the original window, dismiss the popup. - self.notification_subscriptions - .entry(screen_window) - .or_insert_with(Vec::new) - .push({ - let pop_up_weak = pop_up.downgrade(); - - cx.observe_window_activation(window, move |_, window, cx| { - if window.is_window_active() { - if let Some(pop_up) = pop_up_weak.upgrade() { - pop_up.update(cx, |_, cx| { - cx.emit(AgentNotificationEvent::Dismissed); - }); - } - } - }) - }); - } - } - } - - fn dismiss_notifications(&mut self, cx: &mut Context) { - for window in self.notifications.drain(..) { - window - .update(cx, |_, window, _| { - window.remove_window(); - }) - .ok(); - - self.notification_subscriptions.remove(&window); - } - } - - fn render_thread_controls(&mut self, cx: &mut Context) -> impl IntoElement { - let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileText) +impl Render for AcpThreadView { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let open_as_markdown = IconButton::new("open-as-markdown", IconName::DocumentText) .icon_size(IconSize::XSmall) .icon_color(Color::Ignored) .tooltip(Tooltip::text("Open Thread as Markdown")) @@ -2424,28 +2322,6 @@ impl AcpThreadView { this.scroll_to_top(cx); })); - h_flex() - .mt_1() - .mr_1() - .py_2() - .px(RESPONSE_PADDING_X) - .opacity(0.4) - .hover(|style| style.opacity(1.)) - .flex_wrap() - .justify_end() - .child(open_as_markdown) - .child(scroll_to_top) - } -} - -impl Focusable for AcpThreadView { - fn focus_handle(&self, cx: &App) -> FocusHandle { - self.message_editor.focus_handle(cx) - } -} - -impl Render for AcpThreadView { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() .size_full() .key_context("AcpThread") @@ -2454,26 +2330,22 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) .child(match &self.thread_state { - ThreadState::Unauthenticated { connection } => v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child(h_flex().mt_1p5().justify_center().children( - connection.auth_methods().into_iter().map(|method| { - Button::new( - SharedString::from(method.id.0.clone()), - method.label.clone(), - ) - .on_click({ - let method_id = method.id.clone(); - cx.listener(move |this, _, window, cx| { - this.authenticate(method_id.clone(), window, cx) - }) - }) - }), - )), + ThreadState::Unauthenticated { .. } => { + v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_pending_auth_state()) + .child( + h_flex().mt_1p5().justify_center().child( + Button::new("sign-in", format!("Sign in to {}", self.agent.name())) + .on_click(cx.listener(|this, _, window, cx| { + this.authenticate(window, cx) + })), + ), + ) + } ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() @@ -2481,39 +2353,42 @@ impl Render for AcpThreadView { .items_center() .justify_center() .child(self.render_error_state(e, cx)), - ThreadState::Ready { thread, .. } => { - let thread_clone = thread.clone(); - - v_flex().flex_1().map(|this| { - if self.list_state.item_count() > 0 { - let is_generating = - matches!(thread_clone.read(cx).status(), ThreadStatus::Generating); - - this.child( - list(self.list_state.clone()) - .with_sizing_behavior(gpui::ListSizingBehavior::Auto) - .flex_grow() - .into_any(), - ) - .when(!is_generating, |this| { - this.child(self.render_thread_controls(cx)) - }) - .children(match thread_clone.read(cx).status() { - ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => { - None - } - ThreadStatus::Generating => div() - .px_5() - .py_2() - .child(LoadingLabel::new("").size(LabelSize::Small)) - .into(), - }) - .children(self.render_activity_bar(&thread_clone, window, cx)) - } else { - this.child(self.render_empty_state(cx)) - } - }) - } + ThreadState::Ready { thread, .. } => v_flex().flex_1().map(|this| { + if self.list_state.item_count() > 0 { + this.child( + list(self.list_state.clone()) + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .flex_grow() + .into_any(), + ) + .child( + h_flex() + .group("controls") + .mt_1() + .mr_1() + .py_2() + .px(RESPONSE_PADDING_X) + .opacity(0.4) + .hover(|style| style.opacity(1.)) + .flex_wrap() + .justify_end() + .child(open_as_markdown) + .child(scroll_to_top) + .into_any_element(), + ) + .children(match thread.read(cx).status() { + ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => None, + ThreadStatus::Generating => div() + .px_5() + .py_2() + .child(LoadingLabel::new("").size(LabelSize::Small)) + .into(), + }) + .children(self.render_activity_bar(&thread, window, cx)) + } else { + this.child(self.render_empty_state(cx)) + } + }), }) .when_some(self.last_error.clone(), |el, error| { el.child( @@ -2699,347 +2574,3 @@ fn plan_label_markdown_style( ..default_md_style } } - -#[cfg(test)] -mod tests { - use agent_client_protocol::SessionId; - use editor::EditorSettings; - use fs::FakeFs; - use futures::future::try_join_all; - use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; - use rand::Rng; - use settings::SettingsStore; - - use super::*; - - #[gpui::test] - async fn test_notification_for_stop_event(cx: &mut TestAppContext) { - init_test(cx); - - let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await; - - let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); - message_editor.update_in(cx, |editor, window, cx| { - editor.set_text("Hello", window, cx); - }); - - cx.deactivate_window(); - - thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); - }); - - cx.run_until_parked(); - - assert!( - cx.windows() - .iter() - .any(|window| window.downcast::().is_some()) - ); - } - - #[gpui::test] - async fn test_notification_for_error(cx: &mut TestAppContext) { - init_test(cx); - - let (thread_view, cx) = - setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await; - - let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); - message_editor.update_in(cx, |editor, window, cx| { - editor.set_text("Hello", window, cx); - }); - - cx.deactivate_window(); - - thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); - }); - - cx.run_until_parked(); - - assert!( - cx.windows() - .iter() - .any(|window| window.downcast::().is_some()) - ); - } - - #[gpui::test] - async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { - init_test(cx); - - let tool_call_id = acp::ToolCallId("1".into()); - let tool_call = acp::ToolCall { - id: tool_call_id.clone(), - label: "Label".into(), - kind: acp::ToolKind::Edit, - status: acp::ToolCallStatus::Pending, - content: vec!["hi".into()], - locations: vec![], - raw_input: None, - }; - let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) - .with_permission_requests(HashMap::from_iter([( - tool_call_id, - vec![acp::PermissionOption { - id: acp::PermissionOptionId("1".into()), - label: "Allow".into(), - kind: acp::PermissionOptionKind::AllowOnce, - }], - )])); - let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; - - let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); - message_editor.update_in(cx, |editor, window, cx| { - editor.set_text("Hello", window, cx); - }); - - cx.deactivate_window(); - - thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); - }); - - cx.run_until_parked(); - - assert!( - cx.windows() - .iter() - .any(|window| window.downcast::().is_some()) - ); - } - - async fn setup_thread_view( - agent: impl AgentServer + 'static, - cx: &mut TestAppContext, - ) -> (Entity, &mut VisualTestContext) { - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let (workspace, cx) = - cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - - let thread_view = cx.update(|window, cx| { - cx.new(|cx| { - AcpThreadView::new( - Rc::new(agent), - workspace.downgrade(), - project, - Rc::new(RefCell::new(MessageHistory::default())), - 1, - None, - window, - cx, - ) - }) - }); - cx.run_until_parked(); - (thread_view, cx) - } - - struct StubAgentServer { - connection: C, - } - - impl StubAgentServer { - fn new(connection: C) -> Self { - Self { connection } - } - } - - impl StubAgentServer { - fn default() -> Self { - Self::new(StubAgentConnection::default()) - } - } - - impl AgentServer for StubAgentServer - where - C: 'static + AgentConnection + Send + Clone, - { - fn logo(&self) -> ui::IconName { - unimplemented!() - } - - fn name(&self) -> &'static str { - unimplemented!() - } - - fn empty_state_headline(&self) -> &'static str { - unimplemented!() - } - - fn empty_state_message(&self) -> &'static str { - unimplemented!() - } - - fn connect( - &self, - _root_dir: &Path, - _project: &Entity, - _cx: &mut App, - ) -> Task>> { - Task::ready(Ok(Rc::new(self.connection.clone()))) - } - } - - #[derive(Clone, Default)] - struct StubAgentConnection { - sessions: Arc>>>, - permission_requests: HashMap>, - updates: Vec, - } - - impl StubAgentConnection { - fn new(updates: Vec) -> Self { - Self { - updates, - permission_requests: HashMap::default(), - sessions: Arc::default(), - } - } - - fn with_permission_requests( - mut self, - permission_requests: HashMap>, - ) -> Self { - self.permission_requests = permission_requests; - self - } - } - - impl AgentConnection for StubAgentConnection { - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } - - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut gpui::AsyncApp, - ) -> Task>> { - let session_id = SessionId( - rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(7) - .map(char::from) - .collect::() - .into(), - ); - let thread = cx - .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) - .unwrap(); - self.sessions.lock().insert(session_id, thread.downgrade()); - Task::ready(Ok(thread)) - } - - fn authenticate( - &self, - _method_id: acp::AuthMethodId, - _cx: &mut App, - ) -> Task> { - unimplemented!() - } - - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { - let sessions = self.sessions.lock(); - let thread = sessions.get(¶ms.session_id).unwrap(); - let mut tasks = vec![]; - for update in &self.updates { - let thread = thread.clone(); - let update = update.clone(); - let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update - && let Some(options) = self.permission_requests.get(&tool_call.id) - { - Some((tool_call.clone(), options.clone())) - } else { - None - }; - let task = cx.spawn(async move |cx| { - if let Some((tool_call, options)) = permission_request { - let permission = thread.update(cx, |thread, cx| { - thread.request_tool_call_permission( - tool_call.clone(), - options.clone(), - cx, - ) - })?; - permission.await?; - } - thread.update(cx, |thread, cx| { - thread.handle_session_update(update.clone(), cx).unwrap(); - })?; - anyhow::Ok(()) - }); - tasks.push(task); - } - cx.spawn(async move |_| { - try_join_all(tasks).await?; - Ok(()) - }) - } - - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { - unimplemented!() - } - } - - #[derive(Clone)] - struct SaboteurAgentConnection; - - impl AgentConnection for SaboteurAgentConnection { - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut gpui::AsyncApp, - ) -> Task>> { - Task::ready(Ok(cx - .new(|cx| { - AcpThread::new( - "SaboteurAgentConnection", - self, - project, - SessionId("test".into()), - cx, - ) - }) - .unwrap())) - } - - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } - - fn authenticate( - &self, - _method_id: acp::AuthMethodId, - _cx: &mut App, - ) -> Task> { - unimplemented!() - } - - fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task> { - Task::ready(Err(anyhow::anyhow!("Error prompting"))) - } - - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { - unimplemented!() - } - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - AgentSettings::register(cx); - workspace::init_settings(cx); - ThemeSettings::register(cx); - release_channel::init(SemanticVersion::default(), cx); - EditorSettings::register(cx); - }); - } -} diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 04a093c7d0..e27c318221 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -14,7 +14,6 @@ use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; use anyhow::Context as _; use assistant_tool::ToolUseStatus; use audio::{Audio, Sound}; -use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::scroll::Autoscroll; @@ -53,6 +52,7 @@ use util::ResultExt as _; use util::markdown::MarkdownCodeBlock; use workspace::{CollaboratorId, Workspace}; use zed_actions::assistant::OpenRulesLibrary; +use zed_llm_client::CompletionIntent; const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container"; const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1; diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index dad930be9e..fabeee2bce 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -7,7 +7,6 @@ use std::{sync::Arc, time::Duration}; use agent_settings::AgentSettings; use assistant_tool::{ToolSource, ToolWorkingSet}; -use cloud_llm_client::Plan; use collections::HashMap; use context_server::ContextServerId; use extension::ExtensionManifest; @@ -26,6 +25,7 @@ use project::{ context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore}, project_settings::{ContextServerSettings, ProjectSettings}, }; +use proto::Plan; use settings::{Settings, update_settings_file}; use ui::{ Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, @@ -180,7 +180,7 @@ impl AgentConfiguration { let current_plan = if is_zed_provider { self.workspace .upgrade() - .and_then(|workspace| workspace.read(cx).user_store().read(cx).plan()) + .and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan()) } else { None }; @@ -193,7 +193,6 @@ impl AgentConfiguration { .unwrap_or(false); v_flex() - .w_full() .when(is_expanded, |this| this.mb_2()) .child( div() @@ -224,7 +223,6 @@ impl AgentConfiguration { .hover(|hover| hover.bg(cx.theme().colors().element_hover)) .child( h_flex() - .w_full() .gap_2() .child( Icon::new(provider.icon()) @@ -233,7 +231,6 @@ impl AgentConfiguration { ) .child( h_flex() - .w_full() .gap_1() .child( Label::new(provider_name.clone()) @@ -317,7 +314,6 @@ impl AgentConfiguration { let providers = LanguageModelRegistry::read_global(cx).providers(); v_flex() - .w_full() .child( h_flex() .p(DynamicSpacing::Base16.rems(cx)) @@ -328,67 +324,50 @@ impl AgentConfiguration { .justify_between() .child( v_flex() - .w_full() .gap_0p5() - .child( - h_flex() - .w_full() - .gap_2() - .justify_between() - .child(Headline::new("LLM Providers")) - .child( - PopoverMenu::new("add-provider-popover") - .trigger( - Button::new("add-provider", "Add Provider") - .icon_position(IconPosition::Start) - .icon(IconName::Plus) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .label_size(LabelSize::Small), - ) - .anchor(gpui::Corner::TopRight) - .menu({ - let workspace = self.workspace.clone(); - move |window, cx| { - Some(ContextMenu::build( - window, - cx, - |menu, _window, _cx| { - menu.header("Compatible APIs").entry( - "OpenAI", - None, - { - let workspace = - workspace.clone(); - move |window, cx| { - workspace - .update(cx, |workspace, cx| { - AddLlmProviderModal::toggle( - LlmCompatibleProvider::OpenAi, - workspace, - window, - cx, - ); - }) - .log_err(); - } - }, - ) - }, - )) - } - }), - ), - ) + .child(Headline::new("LLM Providers")) .child( Label::new("Add at least one provider to use AI-powered features.") .color(Color::Muted), ), + ) + .child( + PopoverMenu::new("add-provider-popover") + .trigger( + Button::new("add-provider", "Add Provider") + .icon_position(IconPosition::Start) + .icon(IconName::Plus) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .label_size(LabelSize::Small), + ) + .anchor(gpui::Corner::TopRight) + .menu({ + let workspace = self.workspace.clone(); + move |window, cx| { + Some(ContextMenu::build(window, cx, |menu, _window, _cx| { + menu.header("Compatible APIs").entry("OpenAI", None, { + let workspace = workspace.clone(); + move |window, cx| { + workspace + .update(cx, |workspace, cx| { + AddLlmProviderModal::toggle( + LlmCompatibleProvider::OpenAi, + workspace, + window, + cx, + ); + }) + .log_err(); + } + }) + })) + } + }), ), ) .child( div() - .w_full() .pl(DynamicSpacing::Base08.rems(cx)) .pr(DynamicSpacing::Base20.rems(cx)) .children( @@ -404,11 +383,9 @@ impl AgentConfiguration { let fs = self.fs.clone(); SwitchField::new( - "always-allow-tool-actions-switch", - "Allow running commands without asking for confirmation", - Some( - "The agent can perform potentially destructive actions without asking for your confirmation.".into(), - ), + "single-file-review", + "Enable single-file agent reviews", + "Agent edits are also displayed in single-file editors for review.", always_allow_tool_actions, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -426,7 +403,7 @@ impl AgentConfiguration { SwitchField::new( "single-file-review", "Enable single-file agent reviews", - Some("Agent edits are also displayed in single-file editors for review.".into()), + "Agent edits are also displayed in single-file editors for review.", single_file_review, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -444,9 +421,7 @@ impl AgentConfiguration { SwitchField::new( "sound-notification", "Play sound when finished generating", - Some( - "Hear a notification sound when the agent is done generating changes or needs your input.".into(), - ), + "Hear a notification sound when the agent is done generating changes or needs your input.", play_sound_when_agent_done, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -464,9 +439,7 @@ impl AgentConfiguration { SwitchField::new( "modifier-send", "Use modifier to submit a message", - Some( - "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), - ), + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.", use_modifier_to_send, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -508,7 +481,7 @@ impl AgentConfiguration { .blend(cx.theme().colors().text_accent.opacity(0.2)); let (plan_name, label_color, bg_color) = match plan { - Plan::ZedFree => ("Free", Color::Default, free_chip_bg), + Plan::Free => ("Free", Color::Default, free_chip_bg), Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), }; diff --git a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs index 5d44bb2d92..45536ff13b 100644 --- a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs @@ -483,7 +483,7 @@ impl ManageProfilesModal { let icon = match mode.profile_id.as_str() { "write" => IconName::Pencil, - "ask" => IconName::Chat, + "ask" => IconName::MessageBubbles, _ => IconName::UserRoundPen, }; diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index c4dc359093..e69664ce88 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1506,7 +1506,8 @@ impl AgentDiff { .read(cx) .entries() .last() - .map_or(false, |entry| entry.diffs().next().is_some()) + .and_then(|entry| entry.diff()) + .is_some() { self.update_reviewing_editors(workspace, window, cx); } @@ -1516,14 +1517,12 @@ impl AgentDiff { .read(cx) .entries() .get(*ix) - .map_or(false, |entry| entry.diffs().next().is_some()) + .and_then(|entry| entry.diff()) + .is_some() { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::Stopped - | AcpThreadEvent::ToolAuthorizationRequired - | AcpThreadEvent::Error => {} } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index b552a701f0..a0250816a0 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -44,7 +44,6 @@ use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; use client::{DisableAiSettings, UserStore, zed_urls}; -use cloud_llm_client::{CompletionIntent, Plan, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; use feature_flags::{self, FeatureFlagAppExt}; use fs::Fs; @@ -60,6 +59,7 @@ use language_model::{ }; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; +use proto::Plan; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, update_settings_file}; @@ -77,9 +77,10 @@ use workspace::{ }; use zed_actions::{ DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize, - agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector}, + agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector}, assistant::{OpenRulesLibrary, ToggleFocus}, }; +use zed_llm_client::{CompletionIntent, UsageLimit}; const AGENT_PANEL_KEY: &str = "agent_panel"; @@ -104,7 +105,7 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.open_history(window, cx)); } }) - .register_action(|workspace, _: &OpenSettings, window, cx| { + .register_action(|workspace, _: &OpenConfiguration, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| panel.open_configuration(window, cx)); @@ -439,7 +440,7 @@ pub struct AgentPanel { local_timezone: UtcOffset, active_view: ActiveView, acp_message_history: - Rc>>>, + Rc>>, previous_view: Option, history_store: Entity, history: Entity, @@ -578,6 +579,7 @@ impl AgentPanel { MessageEditor::new( fs.clone(), workspace.clone(), + user_store.clone(), message_editor_context_store.clone(), prompt_store.clone(), thread_store.downgrade(), @@ -846,6 +848,7 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), + self.user_store.clone(), context_store.clone(), self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1119,6 +1122,7 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), + self.user_store.clone(), context_store, self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1911,6 +1915,27 @@ impl AgentPanel { .when(cx.has_flag::(), |this| { this.header("Zed Agent") }) + .item( + ContextMenuEntry::new("New Thread") + .icon(IconName::NewThread) + .icon_color(Color::Muted) + .action(NewThread::default().boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Text Thread") + .icon(IconName::NewTextThread) + .icon_color(Color::Muted) + .action(NewTextThread.boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action(NewTextThread.boxed_clone(), cx); + }), + ) .when_some(active_thread, |this, active_thread| { let thread = active_thread.read(cx); @@ -1918,7 +1943,7 @@ impl AgentPanel { let thread_id = thread.id().clone(); this.item( ContextMenuEntry::new("New From Summary") - .icon(IconName::ThreadFromSummary) + .icon(IconName::NewFromSummary) .icon_color(Color::Muted) .handler(move |window, cx| { window.dispatch_action( @@ -1933,27 +1958,6 @@ impl AgentPanel { this } }) - .item( - ContextMenuEntry::new("New Thread") - .icon(IconName::Thread) - .icon_color(Color::Muted) - .action(NewThread::default().boxed_clone()) - .handler(move |window, cx| { - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Text Thread") - .icon(IconName::TextThread) - .icon_color(Color::Muted) - .action(NewTextThread.boxed_clone()) - .handler(move |window, cx| { - window.dispatch_action(NewTextThread.boxed_clone(), cx); - }), - ) .when(cx.has_flag::(), |this| { this.separator() .header("External Agents") @@ -2012,69 +2016,65 @@ impl AgentPanel { ) .anchor(Corner::TopRight) .with_handle(self.agent_panel_menu_handle.clone()) - .menu({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Some(ContextMenu::build(window, cx, |mut menu, _window, _| { - menu = menu.context(focus_handle.clone()); - if let Some(usage) = usage { - menu = menu - .header_with_link("Prompt Usage", "Manage", account_url.clone()) - .custom_entry( - move |_window, cx| { - let used_percentage = match usage.limit { - UsageLimit::Limited(limit) => { - Some((usage.amount as f32 / limit as f32) * 100.) - } - UsageLimit::Unlimited => None, - }; - - h_flex() - .flex_1() - .gap_1p5() - .children(used_percentage.map(|percent| { - ProgressBar::new("usage", percent, 100., cx) - })) - .child( - Label::new(match usage.limit { - UsageLimit::Limited(limit) => { - format!("{} / {limit}", usage.amount) - } - UsageLimit::Unlimited => { - format!("{} / ∞", usage.amount) - } - }) - .size(LabelSize::Small) - .color(Color::Muted), - ) - .into_any_element() - }, - move |_, cx| cx.open_url(&zed_urls::account_url(cx)), - ) - .separator() - } - + .menu(move |window, cx| { + Some(ContextMenu::build(window, cx, |mut menu, _window, _| { + if let Some(usage) = usage { menu = menu - .header("MCP Servers") - .action( - "View Server Extensions", - Box::new(zed_actions::Extensions { - category_filter: Some( - zed_actions::ExtensionCategoryFilter::ContextServers, - ), - id: None, - }), + .header_with_link("Prompt Usage", "Manage", account_url.clone()) + .custom_entry( + move |_window, cx| { + let used_percentage = match usage.limit { + UsageLimit::Limited(limit) => { + Some((usage.amount as f32 / limit as f32) * 100.) + } + UsageLimit::Unlimited => None, + }; + + h_flex() + .flex_1() + .gap_1p5() + .children(used_percentage.map(|percent| { + ProgressBar::new("usage", percent, 100., cx) + })) + .child( + Label::new(match usage.limit { + UsageLimit::Limited(limit) => { + format!("{} / {limit}", usage.amount) + } + UsageLimit::Unlimited => { + format!("{} / ∞", usage.amount) + } + }) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any_element() + }, + move |_, cx| cx.open_url(&zed_urls::account_url(cx)), ) - .action("Add Custom Server…", Box::new(AddContextServer)) - .separator(); + .separator() + } - menu = menu - .action("Rules…", Box::new(OpenRulesLibrary::default())) - .action("Settings", Box::new(OpenSettings)) - .action(zoom_in_label, Box::new(ToggleZoom)); - menu - })) - } + menu = menu + .header("MCP Servers") + .action( + "View Server Extensions", + Box::new(zed_actions::Extensions { + category_filter: Some( + zed_actions::ExtensionCategoryFilter::ContextServers, + ), + id: None, + }), + ) + .action("Add Custom Server…", Box::new(AddContextServer)) + .separator(); + + menu = menu + .action("Rules…", Box::new(OpenRulesLibrary::default())) + .action("Settings", Box::new(OpenConfiguration)) + .action(zoom_in_label, Box::new(ToggleZoom)); + menu + })) }); h_flex() @@ -2275,10 +2275,10 @@ impl AgentPanel { | ActiveView::Configuration => return false, } - let plan = self.user_store.read(cx).plan(); + let plan = self.user_store.read(cx).current_plan(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - matches!(plan, Some(Plan::ZedFree)) && has_previous_trial + matches!(plan, Some(Plan::Free)) && has_previous_trial } fn should_render_onboarding(&self, cx: &mut Context) -> bool { @@ -2464,14 +2464,14 @@ impl AgentPanel { .icon_color(Color::Muted) .full_width() .key_binding(KeyBinding::for_action_in( - &OpenSettings, + &OpenConfiguration, &focus_handle, window, cx, )) .on_click(|_event, window, cx| { window.dispatch_action( - OpenSettings.boxed_clone(), + OpenConfiguration.boxed_clone(), cx, ) }), @@ -2558,7 +2558,7 @@ impl AgentPanel { NewThreadButton::new( "new-thread-btn", "New Thread", - IconName::Thread, + IconName::NewThread, ) .keybinding(KeyBinding::for_action_in( &NewThread::default(), @@ -2579,7 +2579,7 @@ impl AgentPanel { NewThreadButton::new( "new-text-thread-btn", "New Text Thread", - IconName::TextThread, + IconName::NewTextThread, ) .keybinding(KeyBinding::for_action_in( &NewTextThread, @@ -2676,11 +2676,16 @@ impl AgentPanel { .style(ButtonStyle::Tinted(ui::TintColor::Warning)) .label_size(LabelSize::Small) .key_binding( - KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx) - .map(|kb| kb.size(rems_from_px(12.))), + KeyBinding::for_action_in( + &OpenConfiguration, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), ) .on_click(|_event, window, cx| { - window.dispatch_action(OpenSettings.boxed_clone(), cx) + window.dispatch_action(OpenConfiguration.boxed_clone(), cx) }), ), ConfigurationError::ProviderPendingTermsAcceptance(provider) => { @@ -2874,7 +2879,7 @@ impl AgentPanel { ) -> AnyElement { let error_message = match plan { Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", - Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.", + Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.", }; let icon = Icon::new(IconName::XCircle) @@ -3184,7 +3189,7 @@ impl Render for AgentPanel { .on_action(cx.listener(|this, _: &OpenHistory, window, cx| { this.open_history(window, cx); })) - .on_action(cx.listener(|this, _: &OpenSettings, window, cx| { + .on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| { this.open_configuration(window, cx); })) .on_action(cx.listener(Self::open_active_thread_as_markdown)) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index c5574c2371..22f1f92e90 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -284,7 +284,6 @@ fn update_command_palette_filter(cx: &mut App) { } else { filter.show_namespace("agent"); filter.show_namespace("assistant"); - filter.show_namespace("copilot"); filter.show_namespace("zed_predict_onboarding"); filter.show_namespace("edit_prediction"); diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 615142b73d..64498e9281 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -6,7 +6,6 @@ use agent::{ use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; -use cloud_llm_client::CompletionIntent; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; use futures::{ @@ -36,6 +35,7 @@ use std::{ }; use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; +use zed_llm_client::CompletionIntent; pub struct BufferCodegen { alternatives: Vec>, diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 32f9a096d9..5cc56b014e 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -148,7 +148,7 @@ impl ContextPickerMode { Self::File => IconName::File, Self::Symbol => IconName::Code, Self::Fetch => IconName::Globe, - Self::Thread => IconName::Thread, + Self::Thread => IconName::MessageBubbles, Self::Rules => RULES_ICON, } } diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index 5ca0913be7..b377e40b19 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -423,7 +423,7 @@ impl ContextPickerCompletionProvider { let icon_for_completion = if recent { IconName::HistoryRerun } else { - IconName::Thread + IconName::MessageBubbles }; let new_text = format!("{} ", MentionLink::for_thread(&thread_entry)); let new_text_len = new_text.len(); @@ -436,7 +436,7 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_for_completion.path().into()), confirm: Some(confirm_completion_callback( - IconName::Thread.path().into(), + IconName::MessageBubbles.path().into(), thread_entry.title().clone(), excerpt_id, source_range.start, diff --git a/crates/agent_ui/src/context_picker/thread_context_picker.rs b/crates/agent_ui/src/context_picker/thread_context_picker.rs index 15cc731f8f..cb2e97a493 100644 --- a/crates/agent_ui/src/context_picker/thread_context_picker.rs +++ b/crates/agent_ui/src/context_picker/thread_context_picker.rs @@ -253,7 +253,7 @@ pub fn render_thread_context_entry( .gap_1p5() .max_w_72() .child( - Icon::new(IconName::Thread) + Icon::new(IconName::MessageBubbles) .size(IconSize::XSmall) .color(Color::Muted), ) diff --git a/crates/agent_ui/src/debug.rs b/crates/agent_ui/src/debug.rs index bd34659210..ff6538dc85 100644 --- a/crates/agent_ui/src/debug.rs +++ b/crates/agent_ui/src/debug.rs @@ -1,10 +1,10 @@ #![allow(unused, dead_code)] use client::{ModelRequestUsage, RequestUsage}; -use cloud_llm_client::{Plan, UsageLimit}; use gpui::Global; use std::ops::{Deref, DerefMut}; use ui::prelude::*; +use zed_llm_client::{Plan, UsageLimit}; /// Debug only: Used for testing various account states /// diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index ffa654d12b..44ec050ae2 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _}; use ui::prelude::*; use util::{RangeExt, ResultExt, maybe}; use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId}; -use zed_actions::agent::OpenSettings; +use zed_actions::agent::OpenConfiguration; pub fn init( fs: Arc, @@ -345,7 +345,7 @@ impl InlineAssistant { if let Some(answer) = answer { if answer == 0 { cx.update(|window, cx| { - window.dispatch_action(Box::new(OpenSettings), cx) + window.dispatch_action(Box::new(OpenConfiguration), cx) }) .ok(); } diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index a5f90edb57..ade7a5e13d 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -541,7 +541,7 @@ impl PromptEditor { match &self.mode { PromptEditorMode::Terminal { .. } => vec![ accept, - IconButton::new("confirm", IconName::PlayOutlined) + IconButton::new("confirm", IconName::Play) .icon_color(Color::Info) .shape(IconButtonShape::Square) .tooltip(|window, cx| { diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 7121624c87..655e87d7cd 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -576,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { .icon_position(IconPosition::Start) .on_click(|_, window, cx| { window.dispatch_action( - zed_actions::agent::OpenSettings.boxed_clone(), + zed_actions::agent::OpenConfiguration.boxed_clone(), cx, ); }), diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 2185885347..c160f1de04 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -17,7 +17,7 @@ use agent::{ use agent_settings::{AgentSettings, CompletionMode}; use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; -use cloud_llm_client::CompletionIntent; +use client::UserStore; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::display_map::CreaseId; @@ -42,6 +42,7 @@ use language_model::{ use multi_buffer; use project::Project; use prompt_store::PromptStore; +use proto::Plan; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; @@ -52,6 +53,7 @@ use util::ResultExt as _; use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::Chat; use zed_actions::agent::ToggleModelSelector; +use zed_llm_client::CompletionIntent; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; @@ -77,6 +79,7 @@ pub struct MessageEditor { editor: Entity, workspace: WeakEntity, project: Entity, + user_store: Entity, context_store: Entity, prompt_store: Option>, history_store: Option>, @@ -156,6 +159,7 @@ impl MessageEditor { pub fn new( fs: Arc, workspace: WeakEntity, + user_store: Entity, context_store: Entity, prompt_store: Option>, thread_store: WeakEntity, @@ -227,6 +231,7 @@ impl MessageEditor { Self { editor: editor.clone(), project: thread.read(cx).project().clone(), + user_store, thread, incompatible_tools_state: incompatible_tools.clone(), workspace, @@ -1282,12 +1287,24 @@ impl MessageEditor { return None; } - let user_store = self.project.read(cx).user_store().read(cx); - if user_store.is_usage_based_billing_enabled() { + let user_store = self.user_store.read(cx); + + let ubb_enable = user_store + .usage_based_billing_enabled() + .map_or(false, |enabled| enabled); + + if ubb_enable { return None; } - let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree); + let plan = user_store + .current_plan() + .map(|plan| match plan { + Plan::Free => zed_llm_client::Plan::ZedFree, + Plan::ZedPro => zed_llm_client::Plan::ZedPro, + Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + }) + .unwrap_or(zed_llm_client::Plan::ZedFree); let usage = user_store.model_request_usage()?; @@ -1752,6 +1769,7 @@ impl AgentPreview for MessageEditor { ) -> Option { if let Some(workspace) = workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); + let user_store = workspace.read(cx).app_state().user_store.clone(); let project = workspace.read(cx).project().clone(); let weak_project = project.downgrade(); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); @@ -1764,6 +1782,7 @@ impl AgentPreview for MessageEditor { MessageEditor::new( fs, workspace.downgrade(), + user_store, context_store, None, thread_store.downgrade(), diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs index bcbc308c99..91867957cd 100644 --- a/crates/agent_ui/src/terminal_inline_assistant.rs +++ b/crates/agent_ui/src/terminal_inline_assistant.rs @@ -10,7 +10,6 @@ use agent::{ use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; -use cloud_llm_client::CompletionIntent; use collections::{HashMap, VecDeque}; use editor::{MultiBuffer, actions::SelectAll}; use fs::Fs; @@ -28,6 +27,7 @@ use terminal_view::TerminalView; use ui::prelude::*; use util::ResultExt; use workspace::{Toast, Workspace, notifications::NotificationId}; +use zed_llm_client::CompletionIntent; pub fn init( fs: Arc, diff --git a/crates/agent_ui/src/thread_history.rs b/crates/agent_ui/src/thread_history.rs index b8d1db88d6..a2ee816f73 100644 --- a/crates/agent_ui/src/thread_history.rs +++ b/crates/agent_ui/src/thread_history.rs @@ -701,7 +701,7 @@ impl RenderOnce for HistoryEntryElement { .on_hover(self.on_hover) .end_slot::(if self.hovered || self.selected { Some( - IconButton::new("delete", IconName::Trash) + IconButton::new("delete", IconName::TrashAlt) .shape(IconButtonShape::Square) .icon_size(IconSize::XSmall) .icon_color(Color::Muted) diff --git a/crates/agent_ui/src/ui/preview/usage_callouts.rs b/crates/agent_ui/src/ui/preview/usage_callouts.rs index 64869a6ec7..45af41395b 100644 --- a/crates/agent_ui/src/ui/preview/usage_callouts.rs +++ b/crates/agent_ui/src/ui/preview/usage_callouts.rs @@ -1,8 +1,8 @@ use client::{ModelRequestUsage, RequestUsage, zed_urls}; -use cloud_llm_client::{Plan, UsageLimit}; use component::{empty_example, example_group_with_title, single_example}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use ui::{Callout, prelude::*}; +use zed_llm_client::{Plan, UsageLimit}; #[derive(IntoElement, RegisterComponent)] pub struct UsageCallout { diff --git a/crates/ai_onboarding/Cargo.toml b/crates/ai_onboarding/Cargo.toml index 95a45b1a6f..9031e14e29 100644 --- a/crates/ai_onboarding/Cargo.toml +++ b/crates/ai_onboarding/Cargo.toml @@ -16,10 +16,10 @@ default = [] [dependencies] client.workspace = true -cloud_llm_client.workspace = true component.workspace = true gpui.workspace = true language_model.workspace = true +proto.workspace = true serde.workspace = true smallvec.workspace = true telemetry.workspace = true diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index e86568fe7a..5f56e4d26e 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -136,7 +136,10 @@ impl RenderOnce for ApiKeysWithoutProviders { .full_width() .style(ButtonStyle::Outlined) .on_click(move |_, window, cx| { - window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx); + window.dispatch_action( + zed_actions::agent::OpenConfiguration.boxed_clone(), + cx, + ); }), ) } diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index f1629eeff8..e8a62f7ff2 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use client::{Client, UserStore}; -use cloud_llm_client::Plan; use gpui::{Entity, IntoElement, ParentElement}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use ui::prelude::*; @@ -57,8 +56,15 @@ impl AgentPanelOnboarding { impl Render for AgentPanelOnboarding { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial); - let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro); + let enrolled_in_trial = matches!( + self.user_store.read(cx).current_plan(), + Some(proto::Plan::ZedProTrial) + ); + + let is_pro_user = matches!( + self.user_store.read(cx).current_plan(), + Some(proto::Plan::ZedPro) + ); AgentPanelOnboardingCard::new() .child( diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs index c252b65f20..7fffb60ecc 100644 --- a/crates/ai_onboarding/src/ai_onboarding.rs +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -1,15 +1,12 @@ mod agent_api_keys_onboarding; mod agent_panel_onboarding_card; mod agent_panel_onboarding_content; -mod ai_upsell_card; mod edit_prediction_onboarding_content; mod young_account_banner; pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders}; pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; pub use agent_panel_onboarding_content::AgentPanelOnboarding; -pub use ai_upsell_card::AiUpsellCard; -use cloud_llm_client::Plan; pub use edit_prediction_onboarding_content::EditPredictionOnboarding; pub use young_account_banner::YoungAccountBanner; @@ -57,7 +54,6 @@ impl RenderOnce for BulletItem { } } -#[derive(PartialEq)] pub enum SignInStatus { SignedIn, SigningIn, @@ -80,7 +76,7 @@ impl From for SignInStatus { pub struct ZedAiOnboarding { pub sign_in_status: SignInStatus, pub has_accepted_terms_of_service: bool, - pub plan: Option, + pub plan: Option, pub account_too_young: bool, pub continue_with_zed_ai: Arc, pub sign_in: Arc, @@ -100,8 +96,8 @@ impl ZedAiOnboarding { Self { sign_in_status: status.into(), - has_accepted_terms_of_service: store.has_accepted_terms_of_service(), - plan: store.plan(), + has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false), + plan: store.current_plan(), account_too_young: store.account_too_young(), continue_with_zed_ai, accept_terms_of_service: Arc::new({ @@ -114,9 +110,11 @@ impl ZedAiOnboarding { sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); - async move |cx| client.sign_in_with_optional_connect(true, cx).await + async move |cx| { + client.authenticate_and_connect(true, cx).await; + } }) - .detach_and_log_err(cx); + .detach(); }), dismiss_onboarding: None, } @@ -410,9 +408,9 @@ impl RenderOnce for ZedAiOnboarding { if matches!(self.sign_in_status, SignInStatus::SignedIn) { if self.has_accepted_terms_of_service { match self.plan { - None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), - Some(Plan::ZedProTrial) => self.render_trial_state(cx), - Some(Plan::ZedPro) => self.render_pro_plan_state(cx), + None | Some(proto::Plan::Free) => self.render_free_plan_state(cx), + Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx), + Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx), } } else { self.render_accept_terms_of_service() @@ -432,7 +430,7 @@ impl Component for ZedAiOnboarding { fn onboarding( sign_in_status: SignInStatus, has_accepted_terms_of_service: bool, - plan: Option, + plan: Option, account_too_young: bool, ) -> AnyElement { ZedAiOnboarding { @@ -467,15 +465,25 @@ impl Component for ZedAiOnboarding { ), single_example( "Free Plan", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false), + onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false), ), single_example( "Pro Trial", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false), + onboarding( + SignInStatus::SignedIn, + true, + Some(proto::Plan::ZedProTrial), + false, + ), ), single_example( "Pro Plan", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false), + onboarding( + SignInStatus::SignedIn, + true, + Some(proto::Plan::ZedPro), + false, + ), ), ]) .into_any_element(), diff --git a/crates/ai_onboarding/src/ai_upsell_card.rs b/crates/ai_onboarding/src/ai_upsell_card.rs deleted file mode 100644 index 2408b6aa37..0000000000 --- a/crates/ai_onboarding/src/ai_upsell_card.rs +++ /dev/null @@ -1,212 +0,0 @@ -use std::sync::Arc; - -use client::{Client, zed_urls}; -use cloud_llm_client::Plan; -use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; -use ui::{Divider, List, Vector, VectorName, prelude::*}; - -use crate::{BulletItem, SignInStatus}; - -#[derive(IntoElement, RegisterComponent)] -pub struct AiUpsellCard { - pub sign_in_status: SignInStatus, - pub sign_in: Arc, - pub user_plan: Option, -} - -impl AiUpsellCard { - pub fn new(client: Arc, user_plan: Option) -> Self { - let status = *client.status().borrow(); - - Self { - user_plan, - sign_in_status: status.into(), - sign_in: Arc::new(move |_window, cx| { - cx.spawn({ - let client = client.clone(); - async move |cx| client.sign_in_with_optional_connect(true, cx).await - }) - .detach_and_log_err(cx); - }), - } - } -} - -impl RenderOnce for AiUpsellCard { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let pro_section = v_flex() - .flex_grow() - .w_full() - .gap_1() - .child( - h_flex() - .gap_2() - .child( - Label::new("Pro") - .size(LabelSize::Small) - .color(Color::Accent) - .buffer_font(cx), - ) - .child(Divider::horizontal()), - ) - .child( - List::new() - .child(BulletItem::new("500 prompts with Claude models")) - .child(BulletItem::new( - "Unlimited edit predictions with Zeta, our open-source model", - )), - ); - - let free_section = v_flex() - .flex_grow() - .w_full() - .gap_1() - .child( - h_flex() - .gap_2() - .child( - Label::new("Free") - .size(LabelSize::Small) - .color(Color::Muted) - .buffer_font(cx), - ) - .child(Divider::horizontal()), - ) - .child( - List::new() - .child(BulletItem::new("50 prompts with Claude models")) - .child(BulletItem::new("2,000 accepted edit predictions")), - ); - - let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child( - Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.)) - .color(Color::Custom(cx.theme().colors().border.opacity(0.05))), - ); - - let gradient_bg = div() - .absolute() - .inset_0() - .size_full() - .bg(gpui::linear_gradient( - 180., - gpui::linear_color_stop( - cx.theme().colors().elevated_surface_background.opacity(0.8), - 0., - ), - gpui::linear_color_stop( - cx.theme().colors().elevated_surface_background.opacity(0.), - 0.8, - ), - )); - - const DESCRIPTION: &str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI."; - - let footer_buttons = match self.sign_in_status { - SignInStatus::SignedIn => v_flex() - .items_center() - .gap_1() - .child( - Button::new("sign_in", "Start 14-day Free Pro Trial") - .full_width() - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .on_click(move |_, _window, cx| { - telemetry::event!("Start Trial Clicked", state = "post-sign-in"); - cx.open_url(&zed_urls::start_trial_url(cx)) - }), - ) - .child( - Label::new("No credit card required") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .into_any_element(), - _ => Button::new("sign_in", "Sign In") - .full_width() - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .on_click({ - let callback = self.sign_in.clone(); - move |_, window, cx| { - telemetry::event!("Start Trial Clicked", state = "pre-sign-in"); - callback(window, cx) - } - }) - .into_any_element(), - }; - - v_flex() - .relative() - .p_4() - .pt_3() - .border_1() - .border_color(cx.theme().colors().border) - .rounded_lg() - .overflow_hidden() - .child(grid_bg) - .child(gradient_bg) - .child(Label::new("Try Zed AI").size(LabelSize::Large)) - .child( - div() - .max_w_3_4() - .mb_2() - .child(Label::new(DESCRIPTION).color(Color::Muted)), - ) - .child( - h_flex() - .w_full() - .mt_1p5() - .mb_2p5() - .items_start() - .gap_6() - .child(free_section) - .child(pro_section), - ) - .child(footer_buttons) - } -} - -impl Component for AiUpsellCard { - fn scope() -> ComponentScope { - ComponentScope::Agent - } - - fn name() -> &'static str { - "AI Upsell Card" - } - - fn sort_name() -> &'static str { - "AI Upsell Card" - } - - fn description() -> Option<&'static str> { - Some("A card presenting the Zed AI product during user's first-open onboarding flow.") - } - - fn preview(_window: &mut Window, _cx: &mut App) -> Option { - Some( - v_flex() - .p_4() - .gap_4() - .children(vec![example_group(vec![ - single_example( - "Signed Out State", - AiUpsellCard { - sign_in_status: SignInStatus::SignedOut, - sign_in: Arc::new(|_, _| {}), - user_plan: None, - } - .into_any_element(), - ), - single_example( - "Signed In State", - AiUpsellCard { - sign_in_status: SignInStatus::SignedIn, - sign_in: Arc::new(|_, _| {}), - user_plan: None, - } - .into_any_element(), - ), - ])]) - .into_any_element(), - ) - } -} diff --git a/crates/assistant_context/Cargo.toml b/crates/assistant_context/Cargo.toml index 8f5ff98790..f35dc43340 100644 --- a/crates/assistant_context/Cargo.toml +++ b/crates/assistant_context/Cargo.toml @@ -19,7 +19,6 @@ assistant_slash_commands.workspace = true chrono.workspace = true client.workspace = true clock.workspace = true -cloud_llm_client.workspace = true collections.workspace = true context_server.workspace = true fs.workspace = true @@ -49,6 +48,7 @@ util.workspace = true uuid.workspace = true workspace-hack.workspace = true workspace.workspace = true +zed_llm_client.workspace = true [dev-dependencies] indoc.workspace = true diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_context/src/assistant_context.rs index 4518bbff79..136468e084 100644 --- a/crates/assistant_context/src/assistant_context.rs +++ b/crates/assistant_context/src/assistant_context.rs @@ -11,7 +11,6 @@ use assistant_slash_command::{ use assistant_slash_commands::FileCommandMetadata; use client::{self, Client, proto, telemetry::Telemetry}; use clock::ReplicaId; -use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use fs::{Fs, RenameOptions}; use futures::{FutureExt, StreamExt, future::Shared}; @@ -47,6 +46,7 @@ use text::{BufferSnapshot, ToPoint}; use ui::IconName; use util::{ResultExt, TryFutureExt, post_inc}; use uuid::Uuid; +use zed_llm_client::CompletionIntent; pub use crate::context_store::*; diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index d4b8fa3afc..146800e094 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -21,11 +21,9 @@ assistant_tool.workspace = true buffer_diff.workspace = true chrono.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true component.workspace = true derive_more.workspace = true -diffy = "0.4.2" editor.workspace = true feature_flags.workspace = true futures.workspace = true @@ -65,6 +63,8 @@ web_search.workspace = true which.workspace = true workspace-hack.workspace = true workspace.workspace = true +zed_llm_client.workspace = true +diffy = "0.4.2" [dev-dependencies] lsp = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index fed79434bb..0184dff36c 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -7,7 +7,6 @@ mod streaming_fuzzy_matcher; use crate::{Template, Templates}; use anyhow::Result; use assistant_tool::ActionLog; -use cloud_llm_client::CompletionIntent; use create_file_parser::{CreateFileParser, CreateFileParserEvent}; pub use edit_parser::EditFormat; use edit_parser::{EditParser, EditParserEvent, EditParserMetrics}; @@ -30,6 +29,7 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task:: use streaming_diff::{CharOperation, StreamingDiff}; use streaming_fuzzy_matcher::StreamingFuzzyMatcher; use util::debug_panic; +use zed_llm_client::CompletionIntent; #[derive(Serialize)] struct CreateFilePromptTemplate { diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index 9a8e762455..eda7eee0e3 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -1658,24 +1658,23 @@ impl EditAgentTest { } async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> Result { - const MAX_RETRIES: usize = 20; let mut attempt = 0; - loop { attempt += 1; - let response = request().await; - - if attempt >= MAX_RETRIES { - return response; - } - - let retry_delay = match &response { - Ok(_) => None, - Err(err) => match err.downcast_ref::() { - Some(err) => match &err { + match request().await { + Ok(result) => return Ok(result), + Err(err) => match err.downcast::() { + Ok(err) => match &err { LanguageModelCompletionError::RateLimitExceeded { retry_after, .. } | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => { - Some(retry_after.unwrap_or(Duration::from_secs(5))) + let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); + // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time. + let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); + eprintln!( + "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}" + ); + Timer::after(retry_after + jitter).await; + continue; } LanguageModelCompletionError::UpstreamProviderError { status, @@ -1688,31 +1687,23 @@ async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE ) || status.as_u16() == 529; - if should_retry { - // Use server-provided retry_after if available, otherwise use default - Some(retry_after.unwrap_or(Duration::from_secs(5))) - } else { - None + if !should_retry { + return Err(err.into()); } - } - LanguageModelCompletionError::ApiReadResponseError { .. } - | LanguageModelCompletionError::ApiInternalServerError { .. } - | LanguageModelCompletionError::HttpSend { .. } => { - // Exponential backoff for transient I/O and internal server errors - Some(Duration::from_secs(2_u64.pow((attempt - 1) as u32).min(30))) - } - _ => None, - }, - _ => None, - }, - }; - if let Some(retry_after) = retry_delay { - let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); - eprintln!("Attempt #{attempt}: Retry after {retry_after:?} + jitter of {jitter:?}"); - Timer::after(retry_after + jitter).await; - } else { - return response; + // Use server-provided retry_after if available, otherwise use default + let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); + let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); + eprintln!( + "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}" + ); + Timer::after(retry_after + jitter).await; + continue; + } + _ => return Err(err.into()), + }, + Err(err) => return Err(err), + }, } } } diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index d4a12f22c5..5eeca9c2c4 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -6,7 +6,6 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; -use cloud_llm_client::{WebSearchResponse, WebSearchResult}; use futures::{Future, FutureExt, TryFutureExt}; use gpui::{ AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, @@ -18,6 +17,7 @@ use serde::{Deserialize, Serialize}; use ui::{IconName, Tooltip, prelude::*}; use web_search::WebSearchRegistry; use workspace::Workspace; +use zed_llm_client::{WebSearchResponse, WebSearchResult}; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct WebSearchToolInput { diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index d857a3eb2f..960aaf8e08 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -18,6 +18,6 @@ collections.workspace = true derive_more.workspace = true gpui.workspace = true parking_lot.workspace = true -rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] } +rodio = { version = "0.20.0", default-features = false, features = ["wav"] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/audio/src/assets.rs b/crates/audio/src/assets.rs index fd5c935d87..02da79dc24 100644 --- a/crates/audio/src/assets.rs +++ b/crates/audio/src/assets.rs @@ -3,9 +3,12 @@ use std::{io::Cursor, sync::Arc}; use anyhow::{Context as _, Result}; use collections::HashMap; use gpui::{App, AssetSource, Global}; -use rodio::{Decoder, Source, source::Buffered}; +use rodio::{ + Decoder, Source, + source::{Buffered, SamplesConverter}, +}; -type Sound = Buffered>>>; +type Sound = Buffered>>, f32>>; pub struct SoundRegistry { cache: Arc>>, @@ -45,7 +48,7 @@ impl SoundRegistry { .with_context(|| format!("No asset available for path {path}"))?? .into_owned(); let cursor = Cursor::new(bytes); - let source = Decoder::new(cursor)?.buffered(); + let source = Decoder::new(cursor)?.convert_samples::().buffered(); self.cache.lock().insert(name.to_string(), source.clone()); diff --git a/crates/audio/src/audio.rs b/crates/audio/src/audio.rs index 44baa16aa2..e7b9a59e8f 100644 --- a/crates/audio/src/audio.rs +++ b/crates/audio/src/audio.rs @@ -1,7 +1,7 @@ use assets::SoundRegistry; use derive_more::{Deref, DerefMut}; use gpui::{App, AssetSource, BorrowAppContext, Global}; -use rodio::{OutputStream, OutputStreamBuilder}; +use rodio::{OutputStream, OutputStreamHandle}; use util::ResultExt; mod assets; @@ -37,7 +37,8 @@ impl Sound { #[derive(Default)] pub struct Audio { - output_handle: Option, + _output_stream: Option, + output_handle: Option, } #[derive(Deref, DerefMut)] @@ -50,9 +51,11 @@ impl Audio { Self::default() } - fn ensure_output_exists(&mut self) -> Option<&OutputStream> { + fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> { if self.output_handle.is_none() { - self.output_handle = OutputStreamBuilder::open_default_stream().log_err(); + let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip(); + self.output_handle = output_handle; + self._output_stream = _output_stream; } self.output_handle.as_ref() @@ -66,7 +69,7 @@ impl Audio { cx.update_global::(|this, cx| { let output_handle = this.ensure_output_exists()?; let source = SoundRegistry::global(cx).get(sound.file()).log_err()?; - output_handle.mixer().add(source); + output_handle.play_raw(source).log_err()?; Some(()) }); } @@ -77,6 +80,7 @@ impl Audio { } cx.update_global::(|this, _| { + this._output_stream.take(); this.output_handle.take(); }); } diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 4ad156b9fb..b7ba811421 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -126,7 +126,7 @@ impl ChannelMembership { proto::channel_member::Kind::Member => 0, proto::channel_member::Kind::Invitee => 1, }, - username_order: &self.user.github_login, + username_order: self.user.github_login.as_str(), } } } diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index c92226eeeb..f8f5de3c39 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -259,6 +259,20 @@ async fn test_channel_messages(cx: &mut TestAppContext) { assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx); }); + let get_users = server.receive::().await.unwrap(); + assert_eq!(get_users.payload.user_ids, vec![5]); + server.respond( + get_users.receipt(), + proto::UsersResponse { + users: vec![proto::User { + id: 5, + github_login: "nathansobo".into(), + avatar_url: "http://avatar.com/nathansobo".into(), + name: None, + }], + }, + ); + // Join a channel and populate its existing messages. let channel = channel_store.update(cx, |store, cx| { let channel_id = store.ordered_channels().next().unwrap().1.id; @@ -320,7 +334,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("user-5".into(), "a".into()), + ("nathansobo".into(), "a".into()), ("maxbrunsfeld".into(), "b".into()) ] ); @@ -423,7 +437,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("user-5".into(), "y".into()), + ("nathansobo".into(), "y".into()), ("maxbrunsfeld".into(), "z".into()) ] ); diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 365625b445..b741f515fd 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -17,12 +17,11 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup [dependencies] anyhow.workspace = true +async-recursion = "0.3" async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] } base64.workspace = true chrono = { workspace = true, features = ["serde"] } clock.workspace = true -cloud_api_client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true credentials_provider.workspace = true derive_more.workspace = true @@ -34,8 +33,8 @@ http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" log.workspace = true -parking_lot.workspace = true paths.workspace = true +parking_lot.workspace = true postage.workspace = true rand.workspace = true regex.workspace = true @@ -47,18 +46,19 @@ serde_json.workspace = true settings.workspace = true sha2.workspace = true smol.workspace = true -telemetry.workspace = true telemetry_events.workspace = true text.workspace = true thiserror.workspace = true time.workspace = true tiny_http.workspace = true tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] } -tokio.workspace = true url.workspace = true util.workspace = true -workspace-hack.workspace = true worktree.workspace = true +telemetry.workspace = true +tokio.workspace = true +workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index e6d8f10d12..8aafbf383f 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -6,21 +6,22 @@ pub mod telemetry; pub mod user; pub mod zed_urls; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result, anyhow, bail}; +use async_recursion::async_recursion; use async_tungstenite::tungstenite::{ client::IntoClientRequest, error::Error as WebsocketError, http::{HeaderValue, Request, StatusCode}, }; +use chrono::{DateTime, Utc}; use clock::SystemClock; -use cloud_api_client::CloudApiClient; use credentials_provider::CredentialsProvider; use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, channel::oneshot, future::BoxFuture, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; -use http_client::{HttpClient, HttpClientWithUrl, http}; +use http_client::{AsyncBody, HttpClient, HttpClientWithUrl, http}; use parking_lot::RwLock; use postage::watch; use proxy::connect_proxy_stream; @@ -161,8 +162,20 @@ pub fn init(client: &Arc, cx: &mut App) { let client = client.clone(); move |_: &SignIn, cx| { if let Some(client) = client.upgrade() { - cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await) - .detach_and_log_err(cx); + cx.spawn( + async move |cx| match client.authenticate_and_connect(true, &cx).await { + ConnectionResult::Timeout => { + log::error!("Initial authentication timed out"); + } + ConnectionResult::ConnectionReset => { + log::error!("Initial authentication connection reset"); + } + ConnectionResult::Result(r) => { + r.log_err(); + } + }, + ) + .detach(); } } }); @@ -200,7 +213,6 @@ pub struct Client { id: AtomicU64, peer: Arc, http: Arc, - cloud_client: Arc, telemetry: Arc, credentials_provider: ClientCredentialsProvider, state: RwLock, @@ -271,8 +283,6 @@ pub enum Status { SignedOut, UpgradeRequired, Authenticating, - Authenticated, - AuthenticationError, Connecting, ConnectionError, Connected { @@ -576,7 +586,6 @@ impl Client { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), - cloud_client: Arc::new(CloudApiClient::new(http.clone())), http, credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), @@ -609,10 +618,6 @@ impl Client { self.http.clone() } - pub fn cloud_client(&self) -> Arc { - self.cloud_client.clone() - } - pub fn set_id(&self, id: u64) -> &Self { self.id.store(id, Ordering::SeqCst); self @@ -699,7 +704,7 @@ impl Client { let mut delay = INITIAL_RECONNECTION_DELAY; loop { - match client.connect(true, &cx).await { + match client.authenticate_and_connect(true, &cx).await { ConnectionResult::Timeout => { log::error!("client connect attempt timed out") } @@ -869,123 +874,17 @@ impl Client { .is_some() } - pub async fn sign_in( - self: &Arc, - try_provider: bool, - cx: &AsyncApp, - ) -> Result { - if self.status().borrow().is_signed_out() { - self.set_status(Status::Authenticating, cx); - } else { - self.set_status(Status::Reauthenticating, cx); - } - - let mut credentials = None; - - let old_credentials = self.state.read().credentials.clone(); - if let Some(old_credentials) = old_credentials { - if self - .cloud_client - .validate_credentials( - old_credentials.user_id as u32, - &old_credentials.access_token, - ) - .await? - { - credentials = Some(old_credentials); - } - } - - if credentials.is_none() && try_provider { - if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await { - if self - .cloud_client - .validate_credentials( - stored_credentials.user_id as u32, - &stored_credentials.access_token, - ) - .await? - { - credentials = Some(stored_credentials); - } else { - self.credentials_provider - .delete_credentials(cx) - .await - .log_err(); - } - } - } - - if credentials.is_none() { - let mut status_rx = self.status(); - let _ = status_rx.next().await; - futures::select_biased! { - authenticate = self.authenticate(cx).fuse() => { - match authenticate { - Ok(creds) => { - if IMPERSONATE_LOGIN.is_none() { - self.credentials_provider - .write_credentials(creds.user_id, creds.access_token.clone(), cx) - .await - .log_err(); - } - - credentials = Some(creds); - }, - Err(err) => { - self.set_status(Status::AuthenticationError, cx); - return Err(err); - } - } - } - _ = status_rx.next().fuse() => { - return Err(anyhow!("authentication canceled")); - } - } - } - - let credentials = credentials.unwrap(); - self.set_id(credentials.user_id); - self.cloud_client - .set_credentials(credentials.user_id as u32, credentials.access_token.clone()); - self.state.write().credentials = Some(credentials.clone()); - self.set_status(Status::Authenticated, cx); - - Ok(credentials) - } - - /// Performs a sign-in and also connects to Collab. - /// - /// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls - /// to `sign_in` when we're ready to remove auto-connection to Collab. - pub async fn sign_in_with_optional_connect( - self: &Arc, - try_provider: bool, - cx: &AsyncApp, - ) -> Result<()> { - let credentials = self.sign_in(try_provider, cx).await?; - - let connect_result = match self.connect_with_credentials(credentials, cx).await { - ConnectionResult::Timeout => Err(anyhow!("connection timed out")), - ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")), - ConnectionResult::Result(result) => result.context("client auth and connect"), - }; - connect_result.log_err(); - - Ok(()) - } - - pub async fn connect( + #[async_recursion(?Send)] + pub async fn authenticate_and_connect( self: &Arc, try_provider: bool, cx: &AsyncApp, ) -> ConnectionResult<()> { let was_disconnected = match *self.status().borrow() { - Status::SignedOut | Status::Authenticated => true, + Status::SignedOut => true, Status::ConnectionError | Status::ConnectionLost | Status::Authenticating { .. } - | Status::AuthenticationError | Status::Reauthenticating { .. } | Status::ReconnectionError { .. } => false, Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { @@ -998,10 +897,39 @@ impl Client { ); } }; - let credentials = match self.sign_in(try_provider, cx).await { - Ok(credentials) => credentials, - Err(err) => return ConnectionResult::Result(Err(err)), - }; + if was_disconnected { + self.set_status(Status::Authenticating, cx); + } else { + self.set_status(Status::Reauthenticating, cx) + } + + let mut read_from_provider = false; + let mut credentials = self.state.read().credentials.clone(); + if credentials.is_none() && try_provider { + credentials = self.credentials_provider.read_credentials(cx).await; + read_from_provider = credentials.is_some(); + } + + if credentials.is_none() { + let mut status_rx = self.status(); + let _ = status_rx.next().await; + futures::select_biased! { + authenticate = self.authenticate(cx).fuse() => { + match authenticate { + Ok(creds) => credentials = Some(creds), + Err(err) => { + self.set_status(Status::ConnectionError, cx); + return ConnectionResult::Result(Err(err)); + } + } + } + _ = status_rx.next().fuse() => { + return ConnectionResult::Result(Err(anyhow!("authentication canceled"))); + } + } + } + let credentials = credentials.unwrap(); + self.set_id(credentials.user_id); if was_disconnected { self.set_status(Status::Connecting, cx); @@ -1009,20 +937,17 @@ impl Client { self.set_status(Status::Reconnecting, cx); } - self.connect_with_credentials(credentials, cx).await - } - - async fn connect_with_credentials( - self: &Arc, - credentials: Credentials, - cx: &AsyncApp, - ) -> ConnectionResult<()> { let mut timeout = futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT)); futures::select_biased! { connection = self.establish_connection(&credentials, cx).fuse() => { match connection { Ok(conn) => { + self.state.write().credentials = Some(credentials.clone()); + if !read_from_provider && IMPERSONATE_LOGIN.is_none() { + self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); + } + futures::select_biased! { result = self.set_connection(conn, cx).fuse() => { match result.context("client auth and connect") { @@ -1040,8 +965,15 @@ impl Client { } } Err(EstablishConnectionError::Unauthorized) => { - self.set_status(Status::ConnectionError, cx); - ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) + self.state.write().credentials.take(); + if read_from_provider { + self.credentials_provider.delete_credentials(cx).await.log_err(); + self.set_status(Status::SignedOut, cx); + self.authenticate_and_connect(false, cx).await + } else { + self.set_status(Status::ConnectionError, cx); + ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) + } } Err(EstablishConnectionError::UpgradeRequired) => { self.set_status(Status::UpgradeRequired, cx); @@ -1205,7 +1137,7 @@ impl Client { .to_str() .map_err(EstablishConnectionError::other)? .to_string(); - Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}")) + Url::parse(&collab_url).with_context(|| format!("parsing colab rpc url {collab_url}")) } } @@ -1436,31 +1368,96 @@ impl Client { self: &Arc, http: Arc, login: String, - api_token: String, + mut api_token: String, ) -> Result { - #[derive(Serialize)] - struct ImpersonateUserBody { - github_login: String, + #[derive(Deserialize)] + struct AuthenticatedUserResponse { + user: User, } #[derive(Deserialize)] - struct ImpersonateUserResponse { - user_id: u64, - access_token: String, + struct User { + id: u64, } - let url = self - .http - .build_zed_cloud_url("/internal/users/impersonate", &[])?; - let request = Request::post(url.as_str()) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {api_token}")) - .body( - serde_json::to_string(&ImpersonateUserBody { - github_login: login, - })? - .into(), - )?; + let github_user = { + #[derive(Deserialize)] + struct GithubUser { + id: i32, + login: String, + created_at: DateTime, + } + + let request = { + let mut request_builder = + Request::get(&format!("https://api.github.com/users/{login}")); + if let Ok(github_token) = std::env::var("GITHUB_TOKEN") { + request_builder = + request_builder.header("Authorization", format!("Bearer {}", github_token)); + } + + request_builder.body(AsyncBody::empty())? + }; + + let mut response = http + .send(request) + .await + .context("error fetching GitHub user")?; + + let mut body = Vec::new(); + response + .body_mut() + .read_to_end(&mut body) + .await + .context("error reading GitHub user")?; + + if !response.status().is_success() { + let text = String::from_utf8_lossy(body.as_slice()); + bail!( + "status error {}, response: {text:?}", + response.status().as_u16() + ); + } + + serde_json::from_slice::(body.as_slice()).map_err(|err| { + log::error!("Error deserializing: {:?}", err); + log::error!( + "GitHub API response text: {:?}", + String::from_utf8_lossy(body.as_slice()) + ); + anyhow!("error deserializing GitHub user") + })? + }; + + let query_params = [ + ("github_login", &github_user.login), + ("github_user_id", &github_user.id.to_string()), + ( + "github_user_created_at", + &github_user.created_at.to_rfc3339(), + ), + ]; + + // Use the collab server's admin API to retrieve the ID + // of the impersonated user. + let mut url = self.rpc_url(http.clone(), None).await?; + url.set_path("/user"); + url.set_query(Some( + &query_params + .iter() + .map(|(key, value)| { + format!( + "{}={}", + key, + url::form_urlencoded::byte_serialize(value.as_bytes()).collect::() + ) + }) + .collect::>() + .join("&"), + )); + let request: http_client::Request = Request::get(url.as_str()) + .header("Authorization", format!("token {api_token}")) + .body("".into())?; let mut response = http.send(request).await?; let mut body = String::new(); @@ -1471,17 +1468,18 @@ impl Client { response.status().as_u16(), body, ); - let response: ImpersonateUserResponse = serde_json::from_str(&body)?; + let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; + // Use the admin API token to authenticate as the impersonated user. + api_token.insert_str(0, "ADMIN_TOKEN:"); Ok(Credentials { - user_id: response.user_id, - access_token: response.access_token, + user_id: response.user.id, + access_token: api_token, }) } pub async fn sign_out(self: &Arc, cx: &AsyncApp) { self.state.write().credentials = None; - self.cloud_client.clear_credentials(); self.disconnect(cx); if self.has_credentials(cx).await { @@ -1710,7 +1708,7 @@ pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> { #[cfg(test)] mod tests { use super::*; - use crate::test::{FakeServer, parse_authorization_header}; + use crate::test::FakeServer; use clock::FakeSystemClock; use gpui::{AppContext as _, BackgroundExecutor, TestAppContext}; @@ -1791,7 +1789,7 @@ mod tests { }); let auth_and_connect = cx.spawn({ let client = client.clone(); - |cx| async move { client.connect(false, &cx).await } + |cx| async move { client.authenticate_and_connect(false, &cx).await } }); executor.run_until_parked(); assert!(matches!(status.next().await, Some(Status::Connecting))); @@ -1836,75 +1834,6 @@ mod tests { )); } - #[gpui::test(iterations = 10)] - async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) { - init_test(cx); - let auth_count = Arc::new(Mutex::new(0)); - let http_client = FakeHttpClient::create(|_request| async move { - Ok(http_client::Response::builder() - .status(200) - .body("".into()) - .unwrap()) - }); - let client = - cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx)); - client.override_authenticate({ - let auth_count = auth_count.clone(); - move |cx| { - let auth_count = auth_count.clone(); - cx.background_spawn(async move { - *auth_count.lock() += 1; - Ok(Credentials { - user_id: 1, - access_token: auth_count.lock().to_string(), - }) - }) - } - }); - - let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); - assert_eq!(*auth_count.lock(), 1); - assert_eq!(credentials.access_token, "1"); - - // If credentials are still valid, signing in doesn't trigger authentication. - let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); - assert_eq!(*auth_count.lock(), 1); - assert_eq!(credentials.access_token, "1"); - - // If the server is unavailable, signing in doesn't trigger authentication. - http_client - .as_fake() - .replace_handler(|_, _request| async move { - Ok(http_client::Response::builder() - .status(503) - .body("".into()) - .unwrap()) - }); - client.sign_in(false, &cx.to_async()).await.unwrap_err(); - assert_eq!(*auth_count.lock(), 1); - - // If credentials became invalid, signing in triggers authentication. - http_client - .as_fake() - .replace_handler(|_, request| async move { - let credentials = parse_authorization_header(&request).unwrap(); - if credentials.access_token == "2" { - Ok(http_client::Response::builder() - .status(200) - .body("".into()) - .unwrap()) - } else { - Ok(http_client::Response::builder() - .status(401) - .body("".into()) - .unwrap()) - } - }); - let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); - assert_eq!(*auth_count.lock(), 2); - assert_eq!(credentials.access_token, "2"); - } - #[gpui::test(iterations = 10)] async fn test_authenticating_more_than_once( cx: &mut TestAppContext, @@ -1937,7 +1866,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - move |cx| async move { client.connect(false, &cx).await } + move |cx| async move { client.authenticate_and_connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 1); @@ -1945,7 +1874,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - |cx| async move { client.connect(false, &cx).await } + |cx| async move { client.authenticate_and_connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 2); diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 7d39464e4a..4983fda5ef 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -358,13 +358,13 @@ impl Telemetry { worktree_id: WorktreeId, updated_entries_set: &UpdatedEntriesSet, ) { - let Some(project_types) = self.detect_project_types(worktree_id, updated_entries_set) + let Some(project_type_names) = self.detect_project_types(worktree_id, updated_entries_set) else { return; }; - for project_type in project_types { - telemetry::event!("Project Opened", project_type = project_type); + for project_type_name in project_type_names { + telemetry::event!("Project Opened", project_type = project_type_name); } } diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 439fb100d2..6ce79fa9c5 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,11 +1,8 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; use chrono::Duration; -use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; -use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; -use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; use rpc::{ ConnectionId, Peer, Receipt, TypedEnvelope, @@ -42,44 +39,6 @@ impl FakeServer { executor: cx.executor(), }; - client.http_client().as_fake().replace_handler({ - let state = server.state.clone(); - move |old_handler, req| { - let state = state.clone(); - let old_handler = old_handler.clone(); - async move { - match (req.method(), req.uri().path()) { - (&Method::GET, "/client/users/me") => { - let credentials = parse_authorization_header(&req); - if credentials - != Some(Credentials { - user_id: client_user_id, - access_token: state.lock().access_token.to_string(), - }) - { - return Ok(http_client::Response::builder() - .status(401) - .body("Unauthorized".into()) - .unwrap()); - } - - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&make_get_authenticated_user_response( - client_user_id as i32, - format!("user-{client_user_id}"), - )) - .unwrap() - .into(), - ) - .unwrap()) - } - _ => old_handler(req).await, - } - } - } - }); client .override_authenticate({ let state = Arc::downgrade(&server.state); @@ -146,7 +105,7 @@ impl FakeServer { }); client - .connect(false, &cx.to_async()) + .authenticate_and_connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -264,54 +223,3 @@ impl Drop for FakeServer { self.disconnect(); } } - -pub fn parse_authorization_header(req: &Request) -> Option { - let mut auth_header = req - .headers() - .get(http::header::AUTHORIZATION)? - .to_str() - .ok()? - .split_whitespace(); - let user_id = auth_header.next()?.parse().ok()?; - let access_token = auth_header.next()?; - Some(Credentials { - user_id, - access_token: access_token.to_string(), - }) -} - -pub fn make_get_authenticated_user_response( - user_id: i32, - github_login: String, -) -> GetAuthenticatedUserResponse { - GetAuthenticatedUserResponse { - user: AuthenticatedUser { - id: user_id, - metrics_id: format!("metrics-id-{user_id}"), - avatar_url: "".to_string(), - github_login, - name: None, - is_staff: false, - accepted_tos_at: None, - }, - feature_flags: vec![], - plan: PlanInfo { - plan: Plan::ZedPro, - subscription_period: None, - usage: CurrentUsage { - model_requests: UsageData { - used: 0, - limit: UsageLimit::Limited(500), - }, - edit_predictions: UsageData { - used: 250, - limit: UsageLimit::Unlimited, - }, - }, - trial_started_at: None, - is_usage_based_billing_enabled: false, - is_account_too_young: false, - has_overdue_invoices: false, - }, - } -} diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 3c125a0882..5ed258aa8e 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,11 +1,6 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; -use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; -use cloud_llm_client::{ - EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, - MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; use collections::{HashMap, HashSet, hash_map::Entry}; use derive_more::Deref; use feature_flags::FeatureFlagAppExt; @@ -21,7 +16,11 @@ use std::{ sync::{Arc, Weak}, }; use text::ReplicaId; -use util::{ResultExt, TryFutureExt as _}; +use util::{TryFutureExt as _, maybe}; +use zed_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, +}; pub type UserId = u64; @@ -56,7 +55,7 @@ pub struct ParticipantIndex(pub u32); #[derive(Default, Debug)] pub struct User { pub id: UserId, - pub github_login: SharedString, + pub github_login: String, pub avatar_uri: SharedUri, pub name: Option, } @@ -108,14 +107,19 @@ pub enum ContactRequestStatus { pub struct UserStore { users: HashMap>, - by_github_login: HashMap, + by_github_login: HashMap, participant_indices: HashMap, update_contacts_tx: mpsc::UnboundedSender, + current_plan: Option, + subscription_period: Option<(DateTime, DateTime)>, + trial_started_at: Option>, model_request_usage: Option, edit_prediction_usage: Option, - plan_info: Option, + is_usage_based_billing_enabled: Option, + account_too_young: Option, + has_overdue_invoices: Option, current_user: watch::Receiver>>, - accepted_tos_at: Option>, + accepted_tos_at: Option>>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -141,7 +145,6 @@ pub enum Event { ShowContacts, ParticipantIndicesChanged, PrivateUserInfoUpdated, - PlanUpdated, } #[derive(Clone, Copy)] @@ -185,9 +188,14 @@ impl UserStore { users: Default::default(), by_github_login: Default::default(), current_user: current_user_rx, - plan_info: None, + current_plan: None, + subscription_period: None, + trial_started_at: None, model_request_usage: None, edit_prediction_usage: None, + is_usage_based_billing_enabled: None, + account_too_young: None, + has_overdue_invoices: None, accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), @@ -217,30 +225,53 @@ impl UserStore { return Ok(()); }; match status { - Status::Authenticated | Status::Connected { .. } => { + Status::Connected { .. } => { if let Some(user_id) = client.user_id() { - let response = client.cloud_client().get_authenticated_user().await; - let mut current_user = None; + let fetch_user = if let Ok(fetch_user) = + this.update(cx, |this, cx| this.get_user(user_id, cx).log_err()) + { + fetch_user + } else { + break; + }; + let fetch_private_user_info = + client.request(proto::GetPrivateUserInfo {}).log_err(); + let (user, info) = + futures::join!(fetch_user, fetch_private_user_info); + cx.update(|cx| { - if let Some(response) = response.log_err() { - let user = Arc::new(User { - id: user_id, - github_login: response.user.github_login.clone().into(), - avatar_uri: response.user.avatar_url.clone().into(), - name: response.user.name.clone(), - }); - current_user = Some(user.clone()); + if let Some(info) = info { + let staff = + info.staff && !*feature_flags::ZED_DISABLE_STAFF; + cx.update_flags(staff, info.flags); + client.telemetry.set_authenticated_user_info( + Some(info.metrics_id.clone()), + staff, + ); + this.update(cx, |this, cx| { - this.by_github_login - .insert(user.github_login.clone(), user_id); - this.users.insert(user_id, user); - this.update_authenticated_user(response, cx) + let accepted_tos_at = { + #[cfg(debug_assertions)] + if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() + { + None + } else { + info.accepted_tos_at + } + + #[cfg(not(debug_assertions))] + info.accepted_tos_at + }; + + this.set_current_user_accepted_tos_at(accepted_tos_at); + cx.emit(Event::PrivateUserInfoUpdated); }) } else { anyhow::Ok(()) } })??; - current_user_tx.send(current_user).await.ok(); + + current_user_tx.send(user).await.ok(); this.update(cx, |_, cx| cx.notify())?; } @@ -321,22 +352,59 @@ impl UserStore { async fn handle_update_plan( this: Entity, - _message: TypedEnvelope, + message: TypedEnvelope, mut cx: AsyncApp, ) -> Result<()> { - let client = this - .read_with(&cx, |this, _| this.client.upgrade())? - .context("client was dropped")?; - - let response = client - .cloud_client() - .get_authenticated_user() - .await - .context("failed to fetch authenticated user")?; - this.update(&mut cx, |this, cx| { - this.update_authenticated_user(response, cx); - }) + this.current_plan = Some(message.payload.plan()); + this.subscription_period = maybe!({ + let period = message.payload.subscription_period?; + let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?; + let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?; + + Some((started_at, ended_at)) + }); + this.trial_started_at = message + .payload + .trial_started_at + .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0)); + this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled; + this.account_too_young = message.payload.account_too_young; + this.has_overdue_invoices = message.payload.has_overdue_invoices; + + if let Some(usage) = message.payload.usage { + // limits are always present even though they are wrapped in Option + this.model_request_usage = usage + .model_requests_usage_limit + .and_then(|limit| { + RequestUsage::from_proto(usage.model_requests_usage_amount, limit) + }) + .map(ModelRequestUsage); + this.edit_prediction_usage = usage + .edit_predictions_usage_limit + .and_then(|limit| { + RequestUsage::from_proto(usage.model_requests_usage_amount, limit) + }) + .map(EditPredictionUsage); + } + + cx.notify(); + })?; + Ok(()) + } + + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); } fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { @@ -695,131 +763,59 @@ impl UserStore { self.current_user.borrow().clone() } - pub fn plan(&self) -> Option { + pub fn current_plan(&self) -> Option { #[cfg(debug_assertions)] if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { return match plan.as_str() { - "free" => Some(cloud_llm_client::Plan::ZedFree), - "trial" => Some(cloud_llm_client::Plan::ZedProTrial), - "pro" => Some(cloud_llm_client::Plan::ZedPro), + "free" => Some(proto::Plan::Free), + "trial" => Some(proto::Plan::ZedProTrial), + "pro" => Some(proto::Plan::ZedPro), _ => { panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); } }; } - self.plan_info.as_ref().map(|info| info.plan) + self.current_plan } pub fn subscription_period(&self) -> Option<(DateTime, DateTime)> { - self.plan_info - .as_ref() - .and_then(|plan| plan.subscription_period) - .map(|subscription_period| { - ( - subscription_period.started_at.0, - subscription_period.ended_at.0, - ) - }) + self.subscription_period } pub fn trial_started_at(&self) -> Option> { - self.plan_info - .as_ref() - .and_then(|plan| plan.trial_started_at) - .map(|trial_started_at| trial_started_at.0) + self.trial_started_at } - /// Returns whether the user's account is too new to use the service. - pub fn account_too_young(&self) -> bool { - self.plan_info - .as_ref() - .map(|plan| plan.is_account_too_young) - .unwrap_or_default() - } - - /// Returns whether the current user has overdue invoices and usage should be blocked. - pub fn has_overdue_invoices(&self) -> bool { - self.plan_info - .as_ref() - .map(|plan| plan.has_overdue_invoices) - .unwrap_or_default() - } - - pub fn is_usage_based_billing_enabled(&self) -> bool { - self.plan_info - .as_ref() - .map(|plan| plan.is_usage_based_billing_enabled) - .unwrap_or_default() + pub fn usage_based_billing_enabled(&self) -> Option { + self.is_usage_based_billing_enabled } pub fn model_request_usage(&self) -> Option { self.model_request_usage } - pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { - self.model_request_usage = Some(usage); - cx.notify(); - } - pub fn edit_prediction_usage(&self) -> Option { self.edit_prediction_usage } - pub fn update_edit_prediction_usage( - &mut self, - usage: EditPredictionUsage, - cx: &mut Context, - ) { - self.edit_prediction_usage = Some(usage); - cx.notify(); - } - - fn update_authenticated_user( - &mut self, - response: GetAuthenticatedUserResponse, - cx: &mut Context, - ) { - let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF; - cx.update_flags(staff, response.feature_flags); - if let Some(client) = self.client.upgrade() { - client - .telemetry - .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff); - } - - let accepted_tos_at = { - #[cfg(debug_assertions)] - if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() { - None - } else { - response.user.accepted_tos_at - } - - #[cfg(not(debug_assertions))] - response.user.accepted_tos_at - }; - - self.accepted_tos_at = Some(accepted_tos_at); - self.model_request_usage = Some(ModelRequestUsage(RequestUsage { - limit: response.plan.usage.model_requests.limit, - amount: response.plan.usage.model_requests.used as i32, - })); - self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { - limit: response.plan.usage.edit_predictions.limit, - amount: response.plan.usage.edit_predictions.used as i32, - })); - self.plan_info = Some(response.plan); - cx.emit(Event::PrivateUserInfoUpdated); - } - pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } - pub fn has_accepted_terms_of_service(&self) -> bool { + /// Returns whether the user's account is too new to use the service. + pub fn account_too_young(&self) -> bool { + self.account_too_young.unwrap_or(false) + } + + /// Returns whether the current user has overdue invoices and usage should be blocked. + pub fn has_overdue_invoices(&self) -> bool { + self.has_overdue_invoices.unwrap_or(false) + } + + pub fn current_user_has_accepted_terms(&self) -> Option { self.accepted_tos_at - .map_or(false, |accepted_tos_at| accepted_tos_at.is_some()) + .map(|accepted_tos_at| accepted_tos_at.is_some()) } pub fn accept_terms_of_service(&self, cx: &Context) -> Task> { @@ -831,18 +827,23 @@ impl UserStore { cx.spawn(async move |this, cx| -> anyhow::Result<()> { let client = client.upgrade().context("client not found")?; let response = client - .cloud_client() - .accept_terms_of_service() + .request(proto::AcceptTermsOfService {}) .await .context("error accepting tos")?; this.update(cx, |this, cx| { - this.accepted_tos_at = Some(response.user.accepted_tos_at); + this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)); cx.emit(Event::PrivateUserInfoUpdated); })?; Ok(()) }) } + fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option) { + self.accepted_tos_at = Some( + accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)), + ); + } + fn load_users( &self, request: impl RequestMessage, @@ -901,7 +902,7 @@ impl UserStore { let mut missing_user_ids = Vec::new(); for id in user_ids { if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) { - ret.insert(id, github_login); + ret.insert(id, github_login.into()); } else { missing_user_ids.push(id) } @@ -922,7 +923,7 @@ impl User { fn new(message: proto::User) -> Arc { Arc::new(User { id: message.id, - github_login: message.github_login.into(), + github_login: message.github_login, avatar_uri: message.avatar_url.into(), name: message.name, }) diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml deleted file mode 100644 index d56aa94c6e..0000000000 --- a/crates/cloud_api_client/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "cloud_api_client" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "Apache-2.0" - -[lints] -workspace = true - -[lib] -path = "src/cloud_api_client.rs" - -[dependencies] -anyhow.workspace = true -cloud_api_types.workspace = true -futures.workspace = true -http_client.workspace = true -parking_lot.workspace = true -serde_json.workspace = true -workspace-hack.workspace = true diff --git a/crates/cloud_api_client/LICENSE-APACHE b/crates/cloud_api_client/LICENSE-APACHE deleted file mode 120000 index 1cd601d0a3..0000000000 --- a/crates/cloud_api_client/LICENSE-APACHE +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs deleted file mode 100644 index edac051a0e..0000000000 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ /dev/null @@ -1,188 +0,0 @@ -use std::sync::Arc; - -use anyhow::{Context, Result, anyhow}; -pub use cloud_api_types::*; -use futures::AsyncReadExt as _; -use http_client::http::request; -use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode}; -use parking_lot::RwLock; - -struct Credentials { - user_id: u32, - access_token: String, -} - -pub struct CloudApiClient { - credentials: RwLock>, - http_client: Arc, -} - -impl CloudApiClient { - pub fn new(http_client: Arc) -> Self { - Self { - credentials: RwLock::new(None), - http_client, - } - } - - pub fn has_credentials(&self) -> bool { - self.credentials.read().is_some() - } - - pub fn set_credentials(&self, user_id: u32, access_token: String) { - *self.credentials.write() = Some(Credentials { - user_id, - access_token, - }); - } - - pub fn clear_credentials(&self) { - *self.credentials.write() = None; - } - - fn build_request( - &self, - req: request::Builder, - body: impl Into, - ) -> Result> { - let credentials = self.credentials.read(); - let credentials = credentials.as_ref().context("no credentials provided")?; - build_request(req, body, credentials) - } - - pub async fn get_authenticated_user(&self) -> Result { - let request = self.build_request( - Request::builder().method(Method::GET).uri( - self.http_client - .build_zed_cloud_url("/client/users/me", &[])? - .as_ref(), - ), - AsyncBody::default(), - )?; - - let mut response = self.http_client.send(request).await?; - - if !response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - anyhow::bail!( - "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", - response.status() - ) - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - Ok(serde_json::from_str(&body)?) - } - - pub async fn accept_terms_of_service(&self) -> Result { - let request = self.build_request( - Request::builder().method(Method::POST).uri( - self.http_client - .build_zed_cloud_url("/client/terms_of_service/accept", &[])? - .as_ref(), - ), - AsyncBody::default(), - )?; - - let mut response = self.http_client.send(request).await?; - - if !response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - anyhow::bail!( - "Failed to accept terms of service.\nStatus: {:?}\nBody: {body}", - response.status() - ) - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - Ok(serde_json::from_str(&body)?) - } - - pub async fn create_llm_token( - &self, - system_id: Option, - ) -> Result { - let mut request_builder = Request::builder().method(Method::POST).uri( - self.http_client - .build_zed_cloud_url("/client/llm_tokens", &[])? - .as_ref(), - ); - - if let Some(system_id) = system_id { - request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id); - } - - let request = self.build_request(request_builder, AsyncBody::default())?; - - let mut response = self.http_client.send(request).await?; - - if !response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - anyhow::bail!( - "Failed to create LLM token.\nStatus: {:?}\nBody: {body}", - response.status() - ) - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - Ok(serde_json::from_str(&body)?) - } - - pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result { - let request = build_request( - Request::builder().method(Method::GET).uri( - self.http_client - .build_zed_cloud_url("/client/users/me", &[])? - .as_ref(), - ), - AsyncBody::default(), - &Credentials { - user_id, - access_token: access_token.into(), - }, - )?; - - let mut response = self.http_client.send(request).await?; - - if response.status().is_success() { - Ok(true) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - if response.status() == StatusCode::UNAUTHORIZED { - return Ok(false); - } else { - return Err(anyhow!( - "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", - response.status() - )); - } - } - } -} - -fn build_request( - req: request::Builder, - body: impl Into, - credentials: &Credentials, -) -> Result> { - Ok(req - .header("Content-Type", "application/json") - .header( - "Authorization", - format!("{} {}", credentials.user_id, credentials.access_token), - ) - .body(body.into())?) -} diff --git a/crates/cloud_api_types/Cargo.toml b/crates/cloud_api_types/Cargo.toml deleted file mode 100644 index 868797df3b..0000000000 --- a/crates/cloud_api_types/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "cloud_api_types" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "Apache-2.0" - -[lints] -workspace = true - -[lib] -path = "src/cloud_api_types.rs" - -[dependencies] -chrono.workspace = true -cloud_llm_client.workspace = true -serde.workspace = true -workspace-hack.workspace = true - -[dev-dependencies] -pretty_assertions.workspace = true -serde_json.workspace = true diff --git a/crates/cloud_api_types/LICENSE-APACHE b/crates/cloud_api_types/LICENSE-APACHE deleted file mode 120000 index 1cd601d0a3..0000000000 --- a/crates/cloud_api_types/LICENSE-APACHE +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs deleted file mode 100644 index b38b38cde1..0000000000 --- a/crates/cloud_api_types/src/cloud_api_types.rs +++ /dev/null @@ -1,55 +0,0 @@ -mod timestamp; - -use serde::{Deserialize, Serialize}; - -pub use crate::timestamp::Timestamp; - -pub const ZED_SYSTEM_ID_HEADER_NAME: &str = "x-zed-system-id"; - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -pub struct GetAuthenticatedUserResponse { - pub user: AuthenticatedUser, - pub feature_flags: Vec, - pub plan: PlanInfo, -} - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -pub struct AuthenticatedUser { - pub id: i32, - pub metrics_id: String, - pub avatar_url: String, - pub github_login: String, - pub name: Option, - pub is_staff: bool, - pub accepted_tos_at: Option, -} - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -pub struct PlanInfo { - pub plan: cloud_llm_client::Plan, - pub subscription_period: Option, - pub usage: cloud_llm_client::CurrentUsage, - pub trial_started_at: Option, - pub is_usage_based_billing_enabled: bool, - pub is_account_too_young: bool, - pub has_overdue_invoices: bool, -} - -#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] -pub struct SubscriptionPeriod { - pub started_at: Timestamp, - pub ended_at: Timestamp, -} - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -pub struct AcceptTermsOfServiceResponse { - pub user: AuthenticatedUser, -} - -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] -pub struct LlmToken(pub String); - -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] -pub struct CreateLlmTokenResponse { - pub token: LlmToken, -} diff --git a/crates/cloud_api_types/src/timestamp.rs b/crates/cloud_api_types/src/timestamp.rs deleted file mode 100644 index 1f055d58ef..0000000000 --- a/crates/cloud_api_types/src/timestamp.rs +++ /dev/null @@ -1,166 +0,0 @@ -use chrono::{DateTime, NaiveDateTime, SecondsFormat, Utc}; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -/// A timestamp with a serialized representation in RFC 3339 format. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] -pub struct Timestamp(pub DateTime); - -impl Timestamp { - pub fn new(datetime: DateTime) -> Self { - Self(datetime) - } -} - -impl From> for Timestamp { - fn from(value: DateTime) -> Self { - Self(value) - } -} - -impl From for Timestamp { - fn from(value: NaiveDateTime) -> Self { - Self(value.and_utc()) - } -} - -impl Serialize for Timestamp { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let rfc3339_string = self.0.to_rfc3339_opts(SecondsFormat::Millis, true); - serializer.serialize_str(&rfc3339_string) - } -} - -impl<'de> Deserialize<'de> for Timestamp { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let value = String::deserialize(deserializer)?; - let datetime = DateTime::parse_from_rfc3339(&value) - .map_err(serde::de::Error::custom)? - .to_utc(); - Ok(Self(datetime)) - } -} - -#[cfg(test)] -mod tests { - use chrono::NaiveDate; - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn test_timestamp_serialization() { - let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") - .unwrap() - .to_utc(); - let timestamp = Timestamp::new(datetime); - - let json = serde_json::to_string(×tamp).unwrap(); - assert_eq!(json, "\"2023-12-25T14:30:45.123Z\""); - } - - #[test] - fn test_timestamp_deserialization() { - let json = "\"2023-12-25T14:30:45.123Z\""; - let timestamp: Timestamp = serde_json::from_str(json).unwrap(); - - let expected = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") - .unwrap() - .to_utc(); - - assert_eq!(timestamp.0, expected); - } - - #[test] - fn test_timestamp_roundtrip() { - let original = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") - .unwrap() - .to_utc(); - - let timestamp = Timestamp::new(original); - let json = serde_json::to_string(×tamp).unwrap(); - let deserialized: Timestamp = serde_json::from_str(&json).unwrap(); - - assert_eq!(deserialized.0, original); - } - - #[test] - fn test_timestamp_from_datetime_utc() { - let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") - .unwrap() - .to_utc(); - - let timestamp = Timestamp::from(datetime); - assert_eq!(timestamp.0, datetime); - } - - #[test] - fn test_timestamp_from_naive_datetime() { - let naive_dt = NaiveDate::from_ymd_opt(2023, 12, 25) - .unwrap() - .and_hms_milli_opt(14, 30, 45, 123) - .unwrap(); - - let timestamp = Timestamp::from(naive_dt); - let expected = naive_dt.and_utc(); - - assert_eq!(timestamp.0, expected); - } - - #[test] - fn test_timestamp_serialization_with_microseconds() { - // Test that microseconds are truncated to milliseconds - let datetime = NaiveDate::from_ymd_opt(2023, 12, 25) - .unwrap() - .and_hms_micro_opt(14, 30, 45, 123456) - .unwrap() - .and_utc(); - - let timestamp = Timestamp::new(datetime); - let json = serde_json::to_string(×tamp).unwrap(); - - // Should be truncated to milliseconds - assert_eq!(json, "\"2023-12-25T14:30:45.123Z\""); - } - - #[test] - fn test_timestamp_deserialization_without_milliseconds() { - let json = "\"2023-12-25T14:30:45Z\""; - let timestamp: Timestamp = serde_json::from_str(json).unwrap(); - - let expected = NaiveDate::from_ymd_opt(2023, 12, 25) - .unwrap() - .and_hms_opt(14, 30, 45) - .unwrap() - .and_utc(); - - assert_eq!(timestamp.0, expected); - } - - #[test] - fn test_timestamp_deserialization_with_timezone() { - let json = "\"2023-12-25T14:30:45.123+05:30\""; - let timestamp: Timestamp = serde_json::from_str(json).unwrap(); - - // Should be converted to UTC - let expected = NaiveDate::from_ymd_opt(2023, 12, 25) - .unwrap() - .and_hms_milli_opt(9, 0, 45, 123) // 14:30:45 + 5:30 = 20:00:45, but we want UTC so subtract 5:30 - .unwrap() - .and_utc(); - - assert_eq!(timestamp.0, expected); - } - - #[test] - fn test_timestamp_deserialization_with_invalid_format() { - let json = "\"invalid-date\""; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - } -} diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml deleted file mode 100644 index 6f090d3c6e..0000000000 --- a/crates/cloud_llm_client/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -[package] -name = "cloud_llm_client" -version = "0.1.0" -publish.workspace = true -edition.workspace = true -license = "Apache-2.0" - -[lints] -workspace = true - -[lib] -path = "src/cloud_llm_client.rs" - -[dependencies] -anyhow.workspace = true -serde = { workspace = true, features = ["derive", "rc"] } -serde_json.workspace = true -strum = { workspace = true, features = ["derive"] } -uuid = { workspace = true, features = ["serde"] } -workspace-hack.workspace = true - -[dev-dependencies] -pretty_assertions.workspace = true diff --git a/crates/cloud_llm_client/LICENSE-APACHE b/crates/cloud_llm_client/LICENSE-APACHE deleted file mode 120000 index 1cd601d0a3..0000000000 --- a/crates/cloud_llm_client/LICENSE-APACHE +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs deleted file mode 100644 index 171c923154..0000000000 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ /dev/null @@ -1,370 +0,0 @@ -use std::str::FromStr; -use std::sync::Arc; - -use anyhow::Context as _; -use serde::{Deserialize, Serialize}; -use strum::{Display, EnumIter, EnumString}; -use uuid::Uuid; - -/// The name of the header used to indicate which version of Zed the client is running. -pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version"; - -/// The name of the header used to indicate when a request failed due to an -/// expired LLM token. -/// -/// The client may use this as a signal to refresh the token. -pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; - -/// The name of the header used to indicate what plan the user is currently on. -pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan"; - -/// The name of the header used to indicate the usage limit for model requests. -pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit"; - -/// The name of the header used to indicate the usage amount for model requests. -pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount"; - -/// The name of the header used to indicate the usage limit for edit predictions. -pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit"; - -/// The name of the header used to indicate the usage amount for edit predictions. -pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount"; - -/// The name of the header used to indicate the resource for which the subscription limit has been reached. -pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource"; - -pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests"; -pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions"; - -/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached. -pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached"; - -/// The name of the header used to indicate the the minimum required Zed version. -/// -/// This can be used to force a Zed upgrade in order to continue communicating -/// with the LLM service. -pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version"; - -/// The name of the header used by the client to indicate to the server that it supports receiving status messages. -pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = - "x-zed-client-supports-status-messages"; - -/// The name of the header used by the server to indicate to the client that it supports sending status messages. -pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = - "x-zed-server-supports-status-messages"; - -#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum UsageLimit { - Limited(i32), - Unlimited, -} - -impl FromStr for UsageLimit { - type Err = anyhow::Error; - - fn from_str(value: &str) -> Result { - match value { - "unlimited" => Ok(Self::Unlimited), - limit => limit - .parse::() - .map(Self::Limited) - .context("failed to parse limit"), - } - } -} - -#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum Plan { - #[default] - #[serde(alias = "Free")] - ZedFree, - #[serde(alias = "ZedPro")] - ZedPro, - #[serde(alias = "ZedProTrial")] - ZedProTrial, -} - -impl Plan { - pub fn as_str(&self) -> &'static str { - match self { - Plan::ZedFree => "zed_free", - Plan::ZedPro => "zed_pro", - Plan::ZedProTrial => "zed_pro_trial", - } - } - - pub fn model_requests_limit(&self) -> UsageLimit { - match self { - Plan::ZedPro => UsageLimit::Limited(500), - Plan::ZedProTrial => UsageLimit::Limited(150), - Plan::ZedFree => UsageLimit::Limited(50), - } - } - - pub fn edit_predictions_limit(&self) -> UsageLimit { - match self { - Plan::ZedPro => UsageLimit::Unlimited, - Plan::ZedProTrial => UsageLimit::Unlimited, - Plan::ZedFree => UsageLimit::Limited(2_000), - } - } -} - -impl FromStr for Plan { - type Err = anyhow::Error; - - fn from_str(value: &str) -> Result { - match value { - "zed_free" => Ok(Plan::ZedFree), - "zed_pro" => Ok(Plan::ZedPro), - "zed_pro_trial" => Ok(Plan::ZedProTrial), - plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")), - } - } -} - -#[derive( - Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, -)] -#[serde(rename_all = "snake_case")] -#[strum(serialize_all = "snake_case")] -pub enum LanguageModelProvider { - Anthropic, - OpenAi, - Google, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PredictEditsBody { - #[serde(skip_serializing_if = "Option::is_none", default)] - pub outline: Option, - pub input_events: String, - pub input_excerpt: String, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub speculated_output: Option, - /// Whether the user provided consent for sampling this interaction. - #[serde(default, alias = "data_collection_permission")] - pub can_collect_data: bool, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub diagnostic_groups: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PredictEditsResponse { - pub request_id: Uuid, - pub output_excerpt: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AcceptEditPredictionBody { - pub request_id: Uuid, -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CompletionMode { - Normal, - Max, -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CompletionIntent { - UserPrompt, - ToolResults, - ThreadSummarization, - ThreadContextSummarization, - CreateFile, - EditFile, - InlineAssist, - TerminalInlineAssist, - GenerateGitCommitMessage, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct CompletionBody { - #[serde(skip_serializing_if = "Option::is_none", default)] - pub thread_id: Option, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub prompt_id: Option, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub intent: Option, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub mode: Option, - pub provider: LanguageModelProvider, - pub model: String, - pub provider_request: serde_json::Value, -} - -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CompletionRequestStatus { - Queued { - position: usize, - }, - Started, - Failed { - code: String, - message: String, - request_id: Uuid, - /// Retry duration in seconds. - retry_after: Option, - }, - UsageUpdated { - amount: usize, - limit: UsageLimit, - }, - ToolUseLimitReached, -} - -#[derive(Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CompletionEvent { - Status(CompletionRequestStatus), - Event(T), -} - -impl CompletionEvent { - pub fn into_status(self) -> Option { - match self { - Self::Status(status) => Some(status), - Self::Event(_) => None, - } - } - - pub fn into_event(self) -> Option { - match self { - Self::Event(event) => Some(event), - Self::Status(_) => None, - } - } -} - -#[derive(Serialize, Deserialize)] -pub struct WebSearchBody { - pub query: String, -} - -#[derive(Serialize, Deserialize, Clone)] -pub struct WebSearchResponse { - pub results: Vec, -} - -#[derive(Serialize, Deserialize, Clone)] -pub struct WebSearchResult { - pub title: String, - pub url: String, - pub text: String, -} - -#[derive(Serialize, Deserialize)] -pub struct CountTokensBody { - pub provider: LanguageModelProvider, - pub model: String, - pub provider_request: serde_json::Value, -} - -#[derive(Serialize, Deserialize)] -pub struct CountTokensResponse { - pub tokens: usize, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] -pub struct LanguageModelId(pub Arc); - -impl std::fmt::Display for LanguageModelId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct LanguageModel { - pub provider: LanguageModelProvider, - pub id: LanguageModelId, - pub display_name: String, - pub max_token_count: usize, - pub max_token_count_in_max_mode: Option, - pub max_output_tokens: usize, - pub supports_tools: bool, - pub supports_images: bool, - pub supports_thinking: bool, - pub supports_max_mode: bool, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ListModelsResponse { - pub models: Vec, - pub default_model: LanguageModelId, - pub default_fast_model: LanguageModelId, - pub recommended_models: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct GetSubscriptionResponse { - pub plan: Plan, - pub usage: Option, -} - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -pub struct CurrentUsage { - pub model_requests: UsageData, - pub edit_predictions: UsageData, -} - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -pub struct UsageData { - pub used: u32, - pub limit: UsageLimit, -} - -#[cfg(test)] -mod tests { - use pretty_assertions::assert_eq; - use serde_json::json; - - use super::*; - - #[test] - fn test_plan_deserialize_snake_case() { - let plan = serde_json::from_value::(json!("zed_free")).unwrap(); - assert_eq!(plan, Plan::ZedFree); - - let plan = serde_json::from_value::(json!("zed_pro")).unwrap(); - assert_eq!(plan, Plan::ZedPro); - - let plan = serde_json::from_value::(json!("zed_pro_trial")).unwrap(); - assert_eq!(plan, Plan::ZedProTrial); - } - - #[test] - fn test_plan_deserialize_aliases() { - let plan = serde_json::from_value::(json!("Free")).unwrap(); - assert_eq!(plan, Plan::ZedFree); - - let plan = serde_json::from_value::(json!("ZedPro")).unwrap(); - assert_eq!(plan, Plan::ZedPro); - - let plan = serde_json::from_value::(json!("ZedProTrial")).unwrap(); - assert_eq!(plan, Plan::ZedProTrial); - } - - #[test] - fn test_usage_limit_from_str() { - let limit = UsageLimit::from_str("unlimited").unwrap(); - assert!(matches!(limit, UsageLimit::Unlimited)); - - let limit = UsageLimit::from_str(&0.to_string()).unwrap(); - assert!(matches!(limit, UsageLimit::Limited(0))); - - let limit = UsageLimit::from_str(&50.to_string()).unwrap(); - assert!(matches!(limit, UsageLimit::Limited(50))); - - for value in ["not_a_number", "50xyz"] { - let limit = UsageLimit::from_str(value); - assert!(limit.is_err()); - } - } -} diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 9af95317e6..d3b5048283 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -23,14 +23,13 @@ async-stripe.workspace = true async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } -aws-sdk-kinesis = "1.51.0" aws-sdk-s3 = { version = "1.15.0" } +aws-sdk-kinesis = "1.51.0" axum = { version = "0.6", features = ["json", "headers", "ws"] } axum-extra = { version = "0.4", features = ["erased-json"] } base64.workspace = true chrono.workspace = true clock.workspace = true -cloud_llm_client.workspace = true collections.workspace = true dashmap.workspace = true derive_more.workspace = true @@ -76,6 +75,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re util.workspace = true uuid.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] agent_settings.workspace = true diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 6cf3f68f54..3b0f5396a7 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -100,11 +100,13 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() + .route("/user", get(update_or_create_authenticated_user)) .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) .route("/users/:id/update_plan", post(update_plan)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) + .merge(billing::router()) .merge(contributors::router()) .layer( ServiceBuilder::new() @@ -144,6 +146,48 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR Ok::<_, Error>(next.run(req).await) } +#[derive(Debug, Deserialize)] +struct AuthenticatedUserParams { + github_user_id: i32, + github_login: String, + github_email: Option, + github_name: Option, + github_user_created_at: chrono::DateTime, +} + +#[derive(Debug, Serialize)] +struct AuthenticatedUserResponse { + user: User, + metrics_id: String, + feature_flags: Vec, +} + +async fn update_or_create_authenticated_user( + Query(params): Query, + Extension(app): Extension>, +) -> Result> { + let initial_channel_id = app.config.auto_join_channel_id; + + let user = app + .db + .update_or_create_user_by_github_account( + ¶ms.github_login, + params.github_user_id, + params.github_email.as_deref(), + params.github_name.as_deref(), + params.github_user_created_at, + initial_channel_id, + ) + .await?; + let metrics_id = app.db.get_user_metrics_id(user.id).await?; + let feature_flags = app.db.get_user_flags(user.id).await?; + Ok(Json(AuthenticatedUserResponse { + user, + metrics_id, + feature_flags, + })) +} + #[derive(Debug, Deserialize)] struct LookUpUserParams { identifier: String, @@ -310,9 +354,9 @@ async fn refresh_llm_tokens( #[derive(Debug, Serialize, Deserialize)] struct UpdatePlanBody { - pub plan: cloud_llm_client::Plan, + pub plan: zed_llm_client::Plan, pub subscription_period: SubscriptionPeriod, - pub usage: cloud_llm_client::CurrentUsage, + pub usage: zed_llm_client::CurrentUsage, pub trial_started_at: Option>, pub is_usage_based_billing_enabled: bool, pub is_account_too_young: bool, @@ -334,9 +378,9 @@ async fn update_plan( extract::Json(body): extract::Json, ) -> Result> { let plan = match body.plan { - cloud_llm_client::Plan::ZedFree => proto::Plan::Free, - cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro, - cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, + zed_llm_client::Plan::ZedFree => proto::Plan::Free, + zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro, + zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, }; let update_user_plan = proto::UpdateUserPlan { @@ -368,15 +412,15 @@ async fn update_plan( Ok(Json(UpdatePlanResponse {})) } -fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit { +fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit { proto::UsageLimit { variant: Some(match limit { - cloud_llm_client::UsageLimit::Limited(limit) => { + zed_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - cloud_llm_client::UsageLimit::Unlimited => { + zed_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 0e15308ffe..bd7b99b3eb 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,13 +1,23 @@ use anyhow::{Context as _, bail}; +use axum::{Extension, Json, Router, extract, routing::post}; use chrono::{DateTime, Utc}; -use cloud_llm_client::LanguageModelProvider; use collections::{HashMap, HashSet}; +use reqwest::StatusCode; use sea_orm::ActiveValue; -use std::{sync::Arc, time::Duration}; -use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus}; +use serde::{Deserialize, Serialize}; +use std::{str::FromStr, sync::Arc, time::Duration}; +use stripe::{ + BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession, + CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, + CreateBillingPortalSessionFlowDataAfterCompletionRedirect, + CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm, + CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems, + CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents, + PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus, +}; use util::{ResultExt, maybe}; +use zed_llm_client::LanguageModelProvider; -use crate::AppState; use crate::db::billing_subscription::{ StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, }; @@ -17,16 +27,331 @@ use crate::stripe_client::{ StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription, StripeSubscriptionId, }; +use crate::{AppState, Error, Result}; use crate::{db::UserId, llm::db::LlmDatabase}; use crate::{ db::{ - CreateBillingCustomerParams, CreateBillingSubscriptionParams, + BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams, UpdateBillingSubscriptionParams, billing_customer, }, stripe_billing::StripeBilling, }; +pub fn router() -> Router { + Router::new() + .route( + "/billing/subscriptions/manage", + post(manage_billing_subscription), + ) + .route( + "/billing/subscriptions/sync", + post(sync_billing_subscription), + ) +} + +#[derive(Debug, PartialEq, Deserialize)] +#[serde(rename_all = "snake_case")] +enum ManageSubscriptionIntent { + /// The user intends to manage their subscription. + /// + /// This will open the Stripe billing portal without putting the user in a specific flow. + ManageSubscription, + /// The user intends to update their payment method. + UpdatePaymentMethod, + /// The user intends to upgrade to Zed Pro. + UpgradeToPro, + /// The user intends to cancel their subscription. + Cancel, + /// The user intends to stop the cancellation of their subscription. + StopCancellation, +} + +#[derive(Debug, Deserialize)] +struct ManageBillingSubscriptionBody { + github_user_id: i32, + intent: ManageSubscriptionIntent, + /// The ID of the subscription to manage. + subscription_id: BillingSubscriptionId, + redirect_to: Option, +} + +#[derive(Debug, Serialize)] +struct ManageBillingSubscriptionResponse { + billing_portal_session_url: Option, +} + +/// Initiates a Stripe customer portal session for managing a billing subscription. +async fn manage_billing_subscription( + Extension(app): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let user = app + .db + .get_user_by_github_user_id(body.github_user_id) + .await? + .context("user not found")?; + + let Some(stripe_client) = app.real_stripe_client.clone() else { + log::error!("failed to retrieve Stripe client"); + Err(Error::http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + + let Some(stripe_billing) = app.stripe_billing.clone() else { + log::error!("failed to retrieve Stripe billing object"); + Err(Error::http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + + let customer = app + .db + .get_billing_customer_by_user_id(user.id) + .await? + .context("billing customer not found")?; + let customer_id = CustomerId::from_str(&customer.stripe_customer_id) + .context("failed to parse customer ID")?; + + let subscription = app + .db + .get_billing_subscription_by_id(body.subscription_id) + .await? + .context("subscription not found")?; + let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id) + .context("failed to parse subscription ID")?; + + if body.intent == ManageSubscriptionIntent::StopCancellation { + let updated_stripe_subscription = Subscription::update( + &stripe_client, + &subscription_id, + stripe::UpdateSubscription { + cancel_at_period_end: Some(false), + ..Default::default() + }, + ) + .await?; + + app.db + .update_billing_subscription( + subscription.id, + &UpdateBillingSubscriptionParams { + stripe_cancel_at: ActiveValue::set( + updated_stripe_subscription + .cancel_at + .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0)) + .map(|time| time.naive_utc()), + ), + ..Default::default() + }, + ) + .await?; + + return Ok(Json(ManageBillingSubscriptionResponse { + billing_portal_session_url: None, + })); + } + + let flow = match body.intent { + ManageSubscriptionIntent::ManageSubscription => None, + ManageSubscriptionIntent::UpgradeToPro => { + let zed_pro_price_id: stripe::PriceId = + stripe_billing.zed_pro_price_id().await?.try_into()?; + let zed_free_price_id: stripe::PriceId = + stripe_billing.zed_free_price_id().await?.try_into()?; + + let stripe_subscription = + Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?; + + let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing + && stripe_subscription.items.data.iter().any(|item| { + item.price + .as_ref() + .map_or(false, |price| price.id == zed_pro_price_id) + }); + if is_on_zed_pro_trial { + let payment_methods = PaymentMethod::list( + &stripe_client, + &stripe::ListPaymentMethods { + customer: Some(stripe_subscription.customer.id()), + ..Default::default() + }, + ) + .await?; + + let has_payment_method = !payment_methods.data.is_empty(); + if !has_payment_method { + return Err(Error::http( + StatusCode::BAD_REQUEST, + "missing payment method".into(), + )); + } + + // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early. + Subscription::update( + &stripe_client, + &stripe_subscription.id, + stripe::UpdateSubscription { + trial_end: Some(stripe::Scheduled::now()), + ..Default::default() + }, + ) + .await?; + + return Ok(Json(ManageBillingSubscriptionResponse { + billing_portal_session_url: None, + })); + } + + let subscription_item_to_update = stripe_subscription + .items + .data + .iter() + .find_map(|item| { + let price = item.price.as_ref()?; + + if price.id == zed_free_price_id { + Some(item.id.clone()) + } else { + None + } + }) + .context("No subscription item to update")?; + + Some(CreateBillingPortalSessionFlowData { + type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm, + subscription_update_confirm: Some( + CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm { + subscription: subscription.stripe_subscription_id, + items: vec![ + CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems { + id: subscription_item_to_update.to_string(), + price: Some(zed_pro_price_id.to_string()), + quantity: Some(1), + }, + ], + discounts: None, + }, + ), + ..Default::default() + }) + } + ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData { + type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate, + after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { + type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect, + redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect { + return_url: format!( + "{}{path}", + app.config.zed_dot_dev_url(), + path = body.redirect_to.unwrap_or_else(|| "/account".to_string()) + ), + }), + ..Default::default() + }), + ..Default::default() + }), + ManageSubscriptionIntent::Cancel => { + if subscription.kind == Some(SubscriptionKind::ZedFree) { + return Err(Error::http( + StatusCode::BAD_REQUEST, + "free subscription cannot be canceled".into(), + )); + } + + Some(CreateBillingPortalSessionFlowData { + type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel, + after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { + type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect, + redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect { + return_url: format!("{}/account", app.config.zed_dot_dev_url()), + }), + ..Default::default() + }), + subscription_cancel: Some( + stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel { + subscription: subscription.stripe_subscription_id, + retention: None, + }, + ), + ..Default::default() + }) + } + ManageSubscriptionIntent::StopCancellation => unreachable!(), + }; + + let mut params = CreateBillingPortalSession::new(customer_id); + params.flow_data = flow; + let return_url = format!("{}/account", app.config.zed_dot_dev_url()); + params.return_url = Some(&return_url); + + let session = BillingPortalSession::create(&stripe_client, params).await?; + + Ok(Json(ManageBillingSubscriptionResponse { + billing_portal_session_url: Some(session.url), + })) +} + +#[derive(Debug, Deserialize)] +struct SyncBillingSubscriptionBody { + github_user_id: i32, +} + +#[derive(Debug, Serialize)] +struct SyncBillingSubscriptionResponse { + stripe_customer_id: String, +} + +async fn sync_billing_subscription( + Extension(app): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let Some(stripe_client) = app.stripe_client.clone() else { + log::error!("failed to retrieve Stripe client"); + Err(Error::http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + + let user = app + .db + .get_user_by_github_user_id(body.github_user_id) + .await? + .context("user not found")?; + + let billing_customer = app + .db + .get_billing_customer_by_user_id(user.id) + .await? + .context("billing customer not found")?; + let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); + + let subscriptions = stripe_client + .list_subscriptions_for_customer(&stripe_customer_id) + .await?; + + for subscription in subscriptions { + let subscription_id = subscription.id.clone(); + + sync_subscription(&app, &stripe_client, subscription) + .await + .with_context(|| { + format!( + "failed to sync subscription {subscription_id} for user {}", + user.id, + ) + })?; + } + + Ok(Json(SyncBillingSubscriptionResponse { + stripe_customer_id: billing_customer.stripe_customer_id.clone(), + })) +} + /// The amount of time we wait in between each poll of Stripe events. /// /// This value should strike a balance between: @@ -87,14 +412,6 @@ async fn poll_stripe_events( stripe_client: &Arc, real_stripe_client: &stripe::Client, ) -> anyhow::Result<()> { - let feature_flags = app.db.list_feature_flags().await?; - let sync_events_using_cloud = feature_flags - .iter() - .any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all); - if sync_events_using_cloud { - return Ok(()); - } - fn event_type_to_string(event_type: EventType) -> String { // Calling `to_string` on `stripe::EventType` members gives us a quoted string, // so we need to unquote it. @@ -577,14 +894,6 @@ async fn sync_model_request_usage_with_stripe( llm_db: &Arc, stripe_billing: &Arc, ) -> anyhow::Result<()> { - let feature_flags = app.db.list_feature_flags().await?; - let sync_model_request_usage_using_cloud = feature_flags - .iter() - .any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all); - if sync_model_request_usage_using_cloud { - return Ok(()); - } - log::info!("Stripe usage sync: Starting"); let started_at = Utc::now(); diff --git a/crates/collab/src/api/contributors.rs b/crates/collab/src/api/contributors.rs index 8cfef0ad7e..9296c1d428 100644 --- a/crates/collab/src/api/contributors.rs +++ b/crates/collab/src/api/contributors.rs @@ -8,6 +8,7 @@ use axum::{ use chrono::{NaiveDateTime, SecondsFormat}; use serde::{Deserialize, Serialize}; +use crate::api::AuthenticatedUserParams; use crate::db::ContributorSelector; use crate::{AppState, Result}; @@ -103,18 +104,9 @@ impl RenovateBot { } } -#[derive(Debug, Deserialize)] -struct AddContributorBody { - github_user_id: i32, - github_login: String, - github_email: Option, - github_name: Option, - github_user_created_at: chrono::DateTime, -} - async fn add_contributor( Extension(app): Extension>, - extract::Json(params): extract::Json, + extract::Json(params): extract::Json, ) -> Result<()> { let initial_channel_id = app.config.auto_join_channel_id; app.db diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 522973dbc9..43198f9859 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -95,7 +95,7 @@ pub enum SubscriptionKind { ZedFree, } -impl From for cloud_llm_client::Plan { +impl From for zed_llm_client::Plan { fn from(value: SubscriptionKind) -> Self { match value { SubscriptionKind::ZedPro => Self::ZedPro, diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 18ad624dab..6a6efca0de 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -6,11 +6,11 @@ mod tables; #[cfg(test)] mod tests; -use cloud_llm_client::LanguageModelProvider; use collections::HashMap; pub use ids::*; pub use seed::*; pub use tables::*; +use zed_llm_client::LanguageModelProvider; #[cfg(test)] pub use tests::TestLlmDb; diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs index f4e1de40ec..7d52964b93 100644 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ b/crates/collab/src/llm/db/tests/provider_tests.rs @@ -1,5 +1,5 @@ -use cloud_llm_client::LanguageModelProvider; use pretty_assertions::assert_eq; +use zed_llm_client::LanguageModelProvider; use crate::llm::db::LlmDatabase; use crate::test_llm_db; diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index da01c7f3be..d4566ffcb4 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -4,12 +4,12 @@ use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEA use crate::{Config, db::billing_preference}; use anyhow::{Context as _, Result}; use chrono::{NaiveDateTime, Utc}; -use cloud_llm_client::Plan; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::time::Duration; use thiserror::Error; use uuid::Uuid; +use zed_llm_client::Plan; #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index e648617fe1..0735b08e89 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -23,7 +23,6 @@ use anyhow::{Context as _, anyhow, bail}; use async_tungstenite::tungstenite::{ Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame, }; -use axum::headers::UserAgent; use axum::{ Extension, Router, TypedHeader, body::Body, @@ -42,7 +41,7 @@ use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; use reqwest_client::ReqwestClient; -use rpc::proto::{MultiLspQuery, split_repository_update}; +use rpc::proto::split_repository_update; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; use futures::{ @@ -374,7 +373,7 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) - .add_request_handler(multi_lsp_query) + .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) @@ -434,8 +433,6 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_read_only_project_request::) @@ -751,7 +748,6 @@ impl Server { address: String, principal: Principal, zed_version: ZedVersion, - user_agent: Option, geoip_country_code: Option, system_id: Option, send_connection_id: Option>, @@ -764,14 +760,9 @@ impl Server { user_id=field::Empty, login=field::Empty, impersonator=field::Empty, - user_agent=field::Empty, geoip_country_code=field::Empty ); principal.update_span(&span); - if let Some(user_agent) = user_agent { - span.record("user_agent", user_agent); - } - if let Some(country_code) = geoip_country_code.as_ref() { span.record("geoip_country_code", country_code); } @@ -865,7 +856,6 @@ impl Server { user_id=field::Empty, login=field::Empty, impersonator=field::Empty, - multi_lsp_query_request=field::Empty, ); principal.update_span(&span); let span_enter = span.enter(); @@ -1180,7 +1170,6 @@ pub async fn handle_websocket_request( ConnectInfo(socket_address): ConnectInfo, Extension(server): Extension>, Extension(principal): Extension, - user_agent: Option>, country_code_header: Option>, system_id_header: Option>, ws: WebSocketUpgrade, @@ -1236,7 +1225,6 @@ pub async fn handle_websocket_request( socket_address, principal, version, - user_agent.map(|header| header.to_string()), country_code_header.map(|header| header.to_string()), system_id_header.map(|header| header.to_string()), None, @@ -2330,15 +2318,6 @@ where Ok(()) } -async fn multi_lsp_query( - request: MultiLspQuery, - response: Response, - session: Session, -) -> Result<()> { - tracing::Span::current().record("multi_lsp_query_request", request.request_str()); - forward_mutating_project_request(request, response, session).await -} - /// Notify other participants that a new buffer has been created async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, @@ -2878,12 +2857,12 @@ async fn make_update_user_plan_message( } fn model_requests_limit( - plan: cloud_llm_client::Plan, + plan: zed_llm_client::Plan, feature_flags: &Vec, -) -> cloud_llm_client::UsageLimit { +) -> zed_llm_client::UsageLimit { match plan.model_requests_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == cloud_llm_client::Plan::ZedProTrial + zed_llm_client::UsageLimit::Limited(limit) => { + let limit = if plan == zed_llm_client::Plan::ZedProTrial && feature_flags .iter() .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) @@ -2893,9 +2872,9 @@ fn model_requests_limit( limit }; - cloud_llm_client::UsageLimit::Limited(limit) + zed_llm_client::UsageLimit::Limited(limit) } - cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited, + zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited, } } @@ -2905,21 +2884,21 @@ fn subscription_usage_to_proto( feature_flags: &Vec, ) -> proto::SubscriptionUsage { let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, + proto::Plan::Free => zed_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, }; proto::SubscriptionUsage { model_requests_usage_amount: usage.model_requests as u32, model_requests_usage_limit: Some(proto::UsageLimit { variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { + zed_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - cloud_llm_client::UsageLimit::Unlimited => { + zed_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2927,12 +2906,12 @@ fn subscription_usage_to_proto( edit_predictions_usage_amount: usage.edit_predictions as u32, edit_predictions_usage_limit: Some(proto::UsageLimit { variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { + zed_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - cloud_llm_client::UsageLimit::Unlimited => { + zed_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2945,21 +2924,21 @@ fn make_default_subscription_usage( feature_flags: &Vec, ) -> proto::SubscriptionUsage { let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, + proto::Plan::Free => zed_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, }; proto::SubscriptionUsage { model_requests_usage_amount: 0, model_requests_usage_limit: Some(proto::UsageLimit { variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { + zed_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - cloud_llm_client::UsageLimit::Unlimited => { + zed_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2967,12 +2946,12 @@ fn make_default_subscription_usage( edit_predictions_usage_amount: 0, edit_predictions_usage_limit: Some(proto::UsageLimit { variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { + zed_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - cloud_llm_client::UsageLimit::Unlimited => { + zed_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 8d5d076780..19e410de5b 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -38,12 +38,12 @@ fn room_participants(room: &Entity, cx: &mut TestAppContext) -> RoomPartic let mut remote = room .remote_participants() .values() - .map(|participant| participant.user.github_login.clone().to_string()) + .map(|participant| participant.user.github_login.clone()) .collect::>(); let mut pending = room .pending_participants() .iter() - .map(|user| user.github_login.clone().to_string()) + .map(|user| user.github_login.clone()) .collect::>(); remote.sort(); pending.sort(); diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index aea359d75b..f1cc2bf24a 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections( client_b1.disconnect(&cx_b1.to_async()); executor.advance_clock(RECEIVE_TIMEOUT); client_b1 - .connect(false, &cx_b1.to_async()) + .authenticate_and_connect(false, &cx_b1.to_async()) .await .into_response() .unwrap(); @@ -1667,7 +1667,7 @@ async fn test_project_reconnect( // Client A reconnects. Their project is re-shared, and client B re-joins it. server.allow_connections(); client_a - .connect(false, &cx_a.to_async()) + .authenticate_and_connect(false, &cx_a.to_async()) .await .into_response() .unwrap(); @@ -1796,7 +1796,7 @@ async fn test_project_reconnect( // Client B reconnects. They re-join the room and the remaining shared project. server.allow_connections(); client_b - .connect(false, &cx_b.to_async()) + .authenticate_and_connect(false, &cx_b.to_async()) .await .into_response() .unwrap(); @@ -1881,7 +1881,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_a.user_id().unwrap(), - github_login: "user_a".into(), + github_login: "user_a".to_string(), avatar_uri: "avatar_a".into(), name: None, }), @@ -1900,7 +1900,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_b.user_id().unwrap(), - github_login: "user_b".into(), + github_login: "user_b".to_string(), avatar_uri: "avatar_b".into(), name: None, }), @@ -5738,7 +5738,7 @@ async fn test_contacts( server.allow_connections(); client_c - .connect(false, &cx_c.to_async()) + .authenticate_and_connect(false, &cx_c.to_async()) .await .into_response() .unwrap(); @@ -6079,7 +6079,7 @@ async fn test_contacts( .iter() .map(|contact| { ( - contact.user.github_login.clone().to_string(), + contact.user.github_login.clone(), if contact.online { "online" } else { "offline" }, if contact.busy { "busy" } else { "free" }, ) @@ -6269,7 +6269,7 @@ async fn test_contact_requests( client.disconnect(&cx.to_async()); client.clear_contacts(cx).await; client - .connect(false, &cx.to_async()) + .authenticate_and_connect(false, &cx.to_async()) .await .into_response() .unwrap(); diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs index 9bf906694e..4e64b5526b 100644 --- a/crates/collab/src/tests/notification_tests.rs +++ b/crates/collab/src/tests/notification_tests.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use gpui::{BackgroundExecutor, TestAppContext}; use notifications::NotificationEvent; use parking_lot::Mutex; -use pretty_assertions::assert_eq; use rpc::{Notification, proto}; use crate::tests::TestServer; @@ -18,9 +17,6 @@ async fn test_notifications( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; - // Wait for authentication/connection to Collab to be established. - executor.run_until_parked(); - let notification_events_a = Arc::new(Mutex::new(Vec::new())); let notification_events_b = Arc::new(Mutex::new(Vec::new())); client_a.notification_store().update(cx_a, |_, cx| { diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 5fcc622fc1..ab84e02b19 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -8,7 +8,6 @@ use crate::{ use anyhow::anyhow; use call::ActiveCall; use channel::{ChannelBuffer, ChannelStore}; -use client::test::{make_get_authenticated_user_response, parse_authorization_header}; use client::{ self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore, proto::PeerId, @@ -21,7 +20,7 @@ use fs::FakeFs; use futures::{StreamExt as _, channel::oneshot}; use git::GitHostingProviderRegistry; use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext}; -use http_client::{FakeHttpClient, Method}; +use http_client::FakeHttpClient; use language::LanguageRegistry; use node_runtime::NodeRuntime; use notifications::NotificationStore; @@ -162,8 +161,6 @@ impl TestServer { } pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { - const ACCESS_TOKEN: &str = "the-token"; - let fs = FakeFs::new(cx.executor()); cx.update(|cx| { @@ -178,7 +175,7 @@ impl TestServer { }); let clock = Arc::new(FakeSystemClock::new()); - + let http = FakeHttpClient::with_404_response(); let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await { user.id @@ -200,47 +197,6 @@ impl TestServer { .expect("creating user failed") .user_id }; - - let http = FakeHttpClient::create({ - let name = name.to_string(); - move |req| { - let name = name.clone(); - async move { - match (req.method(), req.uri().path()) { - (&Method::GET, "/client/users/me") => { - let credentials = parse_authorization_header(&req); - if credentials - != Some(Credentials { - user_id: user_id.to_proto(), - access_token: ACCESS_TOKEN.into(), - }) - { - return Ok(http_client::Response::builder() - .status(401) - .body("Unauthorized".into()) - .unwrap()); - } - - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&make_get_authenticated_user_response( - user_id.0, name, - )) - .unwrap() - .into(), - ) - .unwrap()) - } - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), - } - } - } - }); - let client_name = name.to_string(); let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx)); let server = self.server.clone(); @@ -252,10 +208,11 @@ impl TestServer { .unwrap() .set_id(user_id.to_proto()) .override_authenticate(move |cx| { + let access_token = "the-token".to_string(); cx.spawn(async move |_| { Ok(Credentials { user_id: user_id.to_proto(), - access_token: ACCESS_TOKEN.into(), + access_token, }) }) }) @@ -264,7 +221,7 @@ impl TestServer { credentials, &Credentials { user_id: user_id.0 as u64, - access_token: ACCESS_TOKEN.into(), + access_token: "the-token".into() } ); @@ -299,7 +256,6 @@ impl TestServer { ZedVersion(SemanticVersion::new(1, 0, 0)), None, None, - None, Some(connection_id_tx), Executor::Deterministic(cx.background_executor().clone()), None, @@ -362,7 +318,7 @@ impl TestServer { }); client - .connect(false, &cx.to_async()) + .authenticate_and_connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -735,17 +691,17 @@ impl TestClient { current: store .contacts() .iter() - .map(|contact| contact.user.github_login.clone().to_string()) + .map(|contact| contact.user.github_login.clone()) .collect(), outgoing_requests: store .outgoing_contact_requests() .iter() - .map(|user| user.github_login.clone().to_string()) + .map(|user| user.github_login.clone()) .collect(), incoming_requests: store .incoming_contact_requests() .iter() - .map(|user| user.github_login.clone().to_string()) + .map(|user| user.github_login.clone()) .collect(), }) } diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 3a9b568264..3e2d813f1b 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -1162,7 +1162,7 @@ impl Panel for ChatPanel { } fn icon(&self, _window: &Window, cx: &App) -> Option { - self.enabled(cx).then(|| ui::IconName::Chat) + self.enabled(cx).then(|| ui::IconName::MessageBubbles) } fn icon_tooltip(&self, _: &Window, _: &App) -> Option<&'static str> { diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 689591df12..4d5973481e 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -940,7 +940,7 @@ impl CollabPanel { room.read(cx).local_participant().role == proto::ChannelRole::Admin }); - ListItem::new(user.github_login.clone()) + ListItem::new(SharedString::from(user.github_login.clone())) .start_slot(Avatar::new(user.avatar_uri.clone())) .child(Label::new(user.github_login.clone())) .toggle_state(is_selected) @@ -1124,7 +1124,7 @@ impl CollabPanel { .relative() .gap_1() .child(render_tree_branch(false, false, window, cx)) - .child(IconButton::new(0, IconName::Chat)) + .child(IconButton::new(0, IconName::MessageBubbles)) .children(has_messages_notification.then(|| { div() .w_1p5() @@ -2331,7 +2331,7 @@ impl CollabPanel { let client = this.client.clone(); cx.spawn_in(window, async move |_, cx| { client - .connect(true, &cx) + .authenticate_and_connect(true, &cx) .await .into_response() .notify_async_err(cx); @@ -2583,7 +2583,7 @@ impl CollabPanel { ) -> impl IntoElement { let online = contact.online; let busy = contact.busy || calling; - let github_login = contact.user.github_login.clone(); + let github_login = SharedString::from(contact.user.github_login.clone()); let item = ListItem::new(github_login.clone()) .indent_level(1) .indent_step_size(px(20.)) @@ -2662,7 +2662,7 @@ impl CollabPanel { is_selected: bool, cx: &mut Context, ) -> impl IntoElement { - let github_login = user.github_login.clone(); + let github_login = SharedString::from(user.github_login.clone()); let user_id = user.id; let is_response_pending = self.user_store.read(cx).is_contact_request_pending(user); let color = if is_response_pending { @@ -2923,7 +2923,7 @@ impl CollabPanel { .gap_1() .px_1() .child( - IconButton::new("channel_chat", IconName::Chat) + IconButton::new("channel_chat", IconName::MessageBubbles) .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) @@ -2939,7 +2939,7 @@ impl CollabPanel { .visible_on_hover(""), ) .child( - IconButton::new("channel_notes", IconName::FileText) + IconButton::new("channel_notes", IconName::File) .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index c3e834b645..fba8f66c2d 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -634,13 +634,13 @@ impl Render for NotificationPanel { .child(Icon::new(IconName::Envelope)), ) .map(|this| { - if !self.client.status().borrow().is_connected() { + if self.client.user_id().is_none() { this.child( v_flex() .gap_2() .p_4() .child( - Button::new("connect_prompt_button", "Connect") + Button::new("sign_in_prompt_button", "Sign in") .icon_color(Color::Muted) .icon(IconName::Github) .icon_position(IconPosition::Start) @@ -652,7 +652,10 @@ impl Render for NotificationPanel { let client = client.clone(); window .spawn(cx, async move |cx| { - match client.connect(true, &cx).await { + match client + .authenticate_and_connect(true, &cx) + .await + { util::ConnectionResult::Timeout => { log::error!("Connection timeout"); } @@ -670,7 +673,7 @@ impl Render for NotificationPanel { ) .child( div().flex().w_full().items_center().child( - Label::new("Connect to view notifications.") + Label::new("Sign in to view notifications.") .color(Color::Muted) .size(LabelSize::Small), ), diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 65283afa87..a1facb817d 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; -use futures::{FutureExt, StreamExt, channel::oneshot, future, select}; +use futures::{FutureExt, StreamExt, channel::oneshot, select}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use parking_lot::Mutex; use postage::barrier; @@ -10,19 +10,15 @@ use smol::channel; use std::{ fmt, path::PathBuf, - pin::pin, sync::{ Arc, atomic::{AtomicI32, Ordering::SeqCst}, }, time::{Duration, Instant}, }; -use util::{ResultExt, TryFutureExt}; +use util::TryFutureExt; -use crate::{ - transport::{StdioTransport, Transport}, - types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled}, -}; +use crate::transport::{StdioTransport, Transport}; const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); @@ -36,7 +32,6 @@ pub const INTERNAL_ERROR: i32 = -32603; type ResponseHandler = Box)>; type NotificationHandler = Box; -type RequestHandler = Box; #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -83,15 +78,6 @@ pub struct Request<'a, T> { pub params: T, } -#[derive(Serialize, Deserialize)] -pub struct AnyRequest<'a> { - pub jsonrpc: &'a str, - pub id: RequestId, - pub method: &'a str, - #[serde(skip_serializing_if = "is_null_value")] - pub params: Option<&'a RawValue>, -} - #[derive(Serialize, Deserialize)] struct AnyResponse<'a> { jsonrpc: &'a str, @@ -191,23 +177,15 @@ impl Client { Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); - let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default())); let receive_input_task = cx.spawn({ let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); - let request_handlers = request_handlers.clone(); let transport = transport.clone(); async move |cx| { - Self::handle_input( - transport, - notification_handlers, - request_handlers, - response_handlers, - cx, - ) - .log_err() - .await + Self::handle_input(transport, notification_handlers, response_handlers, cx) + .log_err() + .await } }); let receive_err_task = cx.spawn({ @@ -253,24 +231,13 @@ impl Client { async fn handle_input( transport: Arc, notification_handlers: Arc>>, - request_handlers: Arc>>, response_handlers: Arc>>>, cx: &mut AsyncApp, ) -> anyhow::Result<()> { let mut receiver = transport.receive(); while let Some(message) = receiver.next().await { - log::trace!("recv: {}", &message); - if let Ok(request) = serde_json::from_str::(&message) { - let mut request_handlers = request_handlers.lock(); - if let Some(handler) = request_handlers.get_mut(request.method) { - handler( - request.id, - request.params.unwrap_or(RawValue::NULL), - cx.clone(), - ); - } - } else if let Ok(response) = serde_json::from_str::(&message) { + if let Ok(response) = serde_json::from_str::(&message) { if let Some(handlers) = response_handlers.lock().as_mut() { if let Some(handler) = handlers.remove(&response.id) { handler(Ok(message.to_string())); @@ -281,8 +248,6 @@ impl Client { if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { handler(notification.params.unwrap_or(Value::Null), cx.clone()); } - } else { - log::error!("Unhandled JSON from context_server: {}", message); } } @@ -330,17 +295,6 @@ impl Client { &self, method: &str, params: impl Serialize, - ) -> Result { - self.request_with(method, params, None, Some(REQUEST_TIMEOUT)) - .await - } - - pub async fn request_with( - &self, - method: &str, - params: impl Serialize, - cancel_rx: Option>, - timeout: Option, ) -> Result { let id = self.next_id.fetch_add(1, SeqCst); let request = serde_json::to_string(&Request { @@ -376,23 +330,7 @@ impl Client { handle_response?; send?; - let mut timeout_fut = pin!( - match timeout { - Some(timeout) => future::Either::Left(executor.timer(timeout)), - None => future::Either::Right(future::pending()), - } - .fuse() - ); - let mut cancel_fut = pin!( - match cancel_rx { - Some(rx) => future::Either::Left(async { - rx.await.log_err(); - }), - None => future::Either::Right(future::pending()), - } - .fuse() - ); - + let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); select! { response = rx.fuse() => { let elapsed = started.elapsed(); @@ -411,18 +349,8 @@ impl Client { Err(_) => anyhow::bail!("cancelled") } } - _ = cancel_fut => { - self.notify( - Cancelled::METHOD, - ClientNotification::Cancelled(CancelledParams { - request_id: RequestId::Int(id), - reason: None - }) - ).log_err(); - anyhow::bail!(RequestCanceled) - } - _ = timeout_fut => { - log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout.unwrap()); + _ = timeout => { + log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); anyhow::bail!("Context server request timeout"); } } @@ -441,23 +369,14 @@ impl Client { Ok(()) } - pub fn on_notification( - &self, - method: &'static str, - f: Box, - ) { - self.notification_handlers.lock().insert(method, f); - } -} - -#[derive(Debug)] -pub struct RequestCanceled; - -impl std::error::Error for RequestCanceled {} - -impl std::fmt::Display for RequestCanceled { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("Context server request was canceled") + #[allow(unused)] + pub fn on_notification(&self, method: &'static str, f: F) + where + F: 'static + Send + FnMut(Value, AsyncApp), + { + self.notification_handlers + .lock() + .insert(method, Box::new(f)); } } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 34fa29678d..e76e7972f7 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -95,28 +95,8 @@ impl ContextServer { self.client.read().clone() } - pub async fn start(&self, cx: &AsyncApp) -> Result<()> { - self.initialize(self.new_client(cx)?).await - } - - /// Starts the context server, making sure handlers are registered before initialization happens - pub async fn start_with_handlers( - &self, - notification_handlers: Vec<( - &'static str, - Box, - )>, - cx: &AsyncApp, - ) -> Result<()> { - let client = self.new_client(cx)?; - for (method, handler) in notification_handlers { - client.on_notification(method, handler); - } - self.initialize(client).await - } - - fn new_client(&self, cx: &AsyncApp) -> Result { - Ok(match &self.configuration { + pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { + let client = match &self.configuration { ContextServerTransport::Stdio(command, working_directory) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { @@ -133,7 +113,8 @@ impl ContextServer { transport.clone(), cx.clone(), )?, - }) + }; + self.initialize(client).await } async fn initialize(&self, client: Client) -> Result<()> { diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 0e85fb2129..9295ad979c 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -9,8 +9,6 @@ use futures::{ }; use gpui::{App, AppContext, AsyncApp, Task}; use net::async_net::{UnixListener, UnixStream}; -use schemars::JsonSchema; -use serde::de::DeserializeOwned; use serde_json::{json, value::RawValue}; use smol::stream::StreamExt; use std::{ @@ -22,32 +20,16 @@ use util::ResultExt; use crate::{ client::{CspResult, RequestId, Response}, - types::{ - CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations, - ToolResponseContent, - requests::{CallTool, ListTools}, - }, + types::Request, }; pub struct McpServer { socket_path: PathBuf, - tools: Rc>>, - handlers: Rc>>, + handlers: Rc>>, _server_task: Task<()>, } -struct RegisteredTool { - tool: Tool, - handler: ToolHandler, -} - -type ToolHandler = Box< - dyn Fn( - Option, - &mut AsyncApp, - ) -> Task>>, ->; -type RequestHandler = Box>, &App) -> Task>; +type McpHandler = Box>, &App) -> Task>; impl McpServer { pub fn new(cx: &AsyncApp) -> Task> { @@ -61,14 +43,12 @@ impl McpServer { cx.spawn(async move |cx| { let (temp_dir, socket_path, listener) = task.await?; - let tools = Rc::new(RefCell::new(HashMap::default())); let handlers = Rc::new(RefCell::new(HashMap::default())); let server_task = cx.spawn({ - let tools = tools.clone(); let handlers = handlers.clone(); async move |cx| { while let Ok((stream, _)) = listener.accept().await { - Self::serve_connection(stream, tools.clone(), handlers.clone(), cx); + Self::serve_connection(stream, handlers.clone(), cx); } drop(temp_dir) } @@ -76,60 +56,11 @@ impl McpServer { Ok(Self { socket_path, _server_task: server_task, - tools, - handlers: handlers, + handlers: handlers.clone(), }) }) } - pub fn add_tool(&mut self, tool: T) { - let mut settings = schemars::generate::SchemaSettings::draft07(); - settings.inline_subschemas = true; - let mut generator = settings.into_generator(); - - let output_schema = generator.root_schema_for::(); - let unit_schema = generator.root_schema_for::(); - - let registered_tool = RegisteredTool { - tool: Tool { - name: T::NAME.into(), - description: Some(tool.description().into()), - input_schema: generator.root_schema_for::().into(), - output_schema: if output_schema == unit_schema { - None - } else { - Some(output_schema.into()) - }, - annotations: Some(tool.annotations()), - }, - handler: Box::new({ - let tool = tool.clone(); - move |input_value, cx| { - let input = match input_value { - Some(input) => serde_json::from_value(input), - None => serde_json::from_value(serde_json::Value::Null), - }; - - let tool = tool.clone(); - match input { - Ok(input) => cx.spawn(async move |cx| { - let output = tool.run(input, cx).await?; - - Ok(ToolResponse { - content: output.content, - structured_content: serde_json::to_value(output.structured_content) - .unwrap_or_default(), - }) - }), - Err(err) => Task::ready(Err(err.into())), - } - } - }), - }; - - self.tools.borrow_mut().insert(T::NAME, registered_tool); - } - pub fn handle_request( &mut self, f: impl Fn(R::Params, &App) -> Task> + 'static, @@ -189,8 +120,7 @@ impl McpServer { fn serve_connection( stream: UnixStream, - tools: Rc>>, - handlers: Rc>>, + handlers: Rc>>, cx: &mut AsyncApp, ) { let (read, write) = smol::io::split(stream); @@ -205,13 +135,7 @@ impl McpServer { let Some(request_id) = request.id.clone() else { continue; }; - - if request.method == CallTool::METHOD { - Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx) - .await; - } else if request.method == ListTools::METHOD { - Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx); - } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { + if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { let outgoing_tx = outgoing_tx.clone(); if let Some(task) = cx @@ -225,126 +149,25 @@ impl McpServer { .detach(); } } else { - Self::send_err( - request_id, - format!("unhandled method {}", request.method), - &outgoing_tx, - ); + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response::<()> { + jsonrpc: "2.0", + id: request.id.unwrap(), + value: CspResult::Error(Some(crate::client::Error { + message: format!("unhandled method {}", request.method), + code: -32601, + })), + }) + .unwrap(), + ) + .ok(); } } }) .detach(); } - fn handle_list_tools( - request_id: RequestId, - tools: &Rc>>, - outgoing_tx: &UnboundedSender, - ) { - let response = ListToolsResponse { - tools: tools.borrow().values().map(|t| t.tool.clone()).collect(), - next_cursor: None, - meta: None, - }; - - outgoing_tx - .unbounded_send( - serde_json::to_string(&Response { - jsonrpc: "2.0", - id: request_id, - value: CspResult::Ok(Some(response)), - }) - .unwrap_or_default(), - ) - .ok(); - } - - async fn handle_call_tool( - request_id: RequestId, - params: Option>, - tools: &Rc>>, - outgoing_tx: &UnboundedSender, - cx: &mut AsyncApp, - ) { - let result: Result = match params.as_ref() { - Some(params) => serde_json::from_str(params.get()), - None => serde_json::from_value(serde_json::Value::Null), - }; - - match result { - Ok(params) => { - if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) { - let outgoing_tx = outgoing_tx.clone(); - - let task = (tool.handler)(params.arguments, cx); - cx.spawn(async move |_| { - let response = match task.await { - Ok(result) => CallToolResponse { - content: result.content, - is_error: Some(false), - meta: None, - structured_content: if result.structured_content.is_null() { - None - } else { - Some(result.structured_content) - }, - }, - Err(err) => CallToolResponse { - content: vec![ToolResponseContent::Text { - text: err.to_string(), - }], - is_error: Some(true), - meta: None, - structured_content: None, - }, - }; - - outgoing_tx - .unbounded_send( - serde_json::to_string(&Response { - jsonrpc: "2.0", - id: request_id, - value: CspResult::Ok(Some(response)), - }) - .unwrap_or_default(), - ) - .ok(); - }) - .detach(); - } else { - Self::send_err( - request_id, - format!("Tool not found: {}", params.name), - &outgoing_tx, - ); - } - } - Err(err) => { - Self::send_err(request_id, err.to_string(), &outgoing_tx); - } - } - } - - fn send_err( - request_id: RequestId, - message: impl Into, - outgoing_tx: &UnboundedSender, - ) { - outgoing_tx - .unbounded_send( - serde_json::to_string(&Response::<()> { - jsonrpc: "2.0", - id: request_id, - value: CspResult::Error(Some(crate::client::Error { - message: message.into(), - code: -32601, - })), - }) - .unwrap(), - ) - .ok(); - } - async fn handle_io( mut outgoing_rx: UnboundedReceiver, incoming_tx: UnboundedSender, @@ -393,37 +216,7 @@ impl McpServer { } } -pub trait McpServerTool { - type Input: DeserializeOwned + JsonSchema; - type Output: Serialize + JsonSchema; - - const NAME: &'static str; - - fn description(&self) -> &'static str; - - fn annotations(&self) -> ToolAnnotations { - ToolAnnotations { - title: None, - read_only_hint: None, - destructive_hint: None, - idempotent_hint: None, - open_world_hint: None, - } - } - - fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> impl Future>>; -} - -pub struct ToolResponse { - pub content: Vec, - pub structured_content: T, -} - -#[derive(Debug, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] struct RawRequest { #[serde(skip_serializing_if = "Option::is_none")] id: Option, diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 5355f20f62..d8bbac60d6 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -5,12 +5,7 @@ //! read/write messages and the types from types.rs for serialization/deserialization //! of messages. -use std::time::Duration; - use anyhow::Result; -use futures::channel::oneshot; -use gpui::AsyncApp; -use serde_json::Value; use crate::client::Client; use crate::types::{self, Notification, Request}; @@ -100,26 +95,7 @@ impl InitializedContextServerProtocol { self.inner.request(T::METHOD, params).await } - pub async fn request_with( - &self, - params: T::Params, - cancel_rx: Option>, - timeout: Option, - ) -> Result { - self.inner - .request_with(T::METHOD, params, cancel_rx, timeout) - .await - } - pub fn notify(&self, params: T::Params) -> Result<()> { self.inner.notify(T::METHOD, params) } - - pub fn on_notification( - &self, - method: &'static str, - f: Box, - ) { - self.inner.on_notification(method, f); - } } diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 5fa2420a3d..4a6fdcabd3 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -3,8 +3,6 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; -use crate::client::RequestId; - pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const VERSION_2024_11_05: &str = "2024-11-05"; @@ -102,7 +100,6 @@ pub mod notifications { notification!("notifications/initialized", Initialized, ()); notification!("notifications/progress", Progress, ProgressParams); notification!("notifications/message", Message, MessageParams); - notification!("notifications/cancelled", Cancelled, CancelledParams); notification!( "notifications/resources/updated", ResourcesUpdated, @@ -495,20 +492,18 @@ pub struct RootsCapabilities { pub list_changed: Option, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Tool { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, pub input_schema: serde_json::Value, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub output_schema: Option, #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ToolAnnotations { /// A human-readable title for the tool. @@ -622,15 +617,11 @@ pub enum ClientNotification { Initialized, Progress(ProgressParams), RootsListChanged, - Cancelled(CancelledParams), -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CancelledParams { - pub request_id: RequestId, - #[serde(skip_serializing_if = "Option::is_none")] - pub reason: Option, + Cancelled { + request_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + reason: Option, + }, } #[derive(Debug, Serialize, Deserialize)] @@ -682,20 +673,6 @@ pub struct CallToolResponse { pub is_error: Option, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub structured_content: Option, -} - -impl CallToolResponse { - pub fn text_contents(&self) -> String { - let mut text = String::new(); - for chunk in &self.content { - if let ToolResponseContent::Text { text: chunk } = chunk { - text.push_str(&chunk) - }; - } - text - } } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index cacf834e0d..e11242cb15 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -85,13 +85,45 @@ pub fn init( move |cx| Copilot::start(new_server_id, fs, node_runtime, cx) }); Copilot::set_global(copilot.clone(), cx); - cx.observe(&copilot, |copilot, cx| { - copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); - }) - .detach(); - cx.observe_global::(|cx| { - if let Some(copilot) = Copilot::global(cx) { - copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); + cx.observe(&copilot, |handle, cx| { + let copilot_action_types = [ + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + let copilot_auth_action_types = [TypeId::of::()]; + let copilot_no_auth_action_types = [TypeId::of::()]; + let status = handle.read(cx).status(); + + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let filter = CommandPaletteFilter::global_mut(cx); + + if is_ai_disabled { + filter.hide_action_types(&copilot_action_types); + filter.hide_action_types(&copilot_auth_action_types); + filter.hide_action_types(&copilot_no_auth_action_types); + } else { + match status { + Status::Disabled => { + filter.hide_action_types(&copilot_action_types); + filter.hide_action_types(&copilot_auth_action_types); + filter.hide_action_types(&copilot_no_auth_action_types); + } + Status::Authorized => { + filter.hide_action_types(&copilot_no_auth_action_types); + filter.show_action_types( + copilot_action_types + .iter() + .chain(&copilot_auth_action_types), + ); + } + _ => { + filter.hide_action_types(&copilot_action_types); + filter.hide_action_types(&copilot_auth_action_types); + filter.show_action_types(copilot_no_auth_action_types.iter()); + } + } } }) .detach(); @@ -1099,44 +1131,6 @@ impl Copilot { cx.notify(); } } - - fn update_action_visibilities(&self, cx: &mut App) { - let signed_in_actions = [ - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - ]; - let auth_actions = [TypeId::of::()]; - let no_auth_actions = [TypeId::of::()]; - let status = self.status(); - - let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; - let filter = CommandPaletteFilter::global_mut(cx); - - if is_ai_disabled { - filter.hide_action_types(&signed_in_actions); - filter.hide_action_types(&auth_actions); - filter.hide_action_types(&no_auth_actions); - } else { - match status { - Status::Disabled => { - filter.hide_action_types(&signed_in_actions); - filter.hide_action_types(&auth_actions); - filter.hide_action_types(&no_auth_actions); - } - Status::Authorized => { - filter.hide_action_types(&no_auth_actions); - filter.show_action_types(signed_in_actions.iter().chain(&auth_actions)); - } - _ => { - filter.hide_action_types(&signed_in_actions); - filter.hide_action_types(&auth_actions); - filter.show_action_types(no_auth_actions.iter()); - } - } - } - } } fn id_for_language(language: Option<&Arc>) -> String { diff --git a/crates/debugger_ui/src/tests/debugger_panel.rs b/crates/debugger_ui/src/tests/debugger_panel.rs index 6180831ea9..505df09cfb 100644 --- a/crates/debugger_ui/src/tests/debugger_panel.rs +++ b/crates/debugger_ui/src/tests/debugger_panel.rs @@ -918,7 +918,7 @@ async fn test_debug_panel_item_thread_status_reset_on_failure( .unwrap(); let client = session.update(cx, |session, _| session.adapter_client().unwrap()); - const THREAD_ID_NUM: i64 = 1; + const THREAD_ID_NUM: u64 = 1; client.on_request::(move |_, _| { Ok(dap::ThreadsResponse { diff --git a/crates/docs_preprocessor/Cargo.toml b/crates/docs_preprocessor/Cargo.toml index e46ceb18db..a0df669abe 100644 --- a/crates/docs_preprocessor/Cargo.toml +++ b/crates/docs_preprocessor/Cargo.toml @@ -7,19 +7,17 @@ license = "GPL-3.0-or-later" [dependencies] anyhow.workspace = true -command_palette.workspace = true -gpui.workspace = true -# We are specifically pinning this version of mdbook, as later versions introduce issues with double-nested subdirectories. -# Ask @maxdeviant about this before bumping. -mdbook = "= 0.4.40" -regex.workspace = true +clap.workspace = true +mdbook = "0.4.40" serde.workspace = true serde_json.workspace = true settings.workspace = true +regex.workspace = true util.workspace = true workspace-hack.workspace = true zed.workspace = true -zlog.workspace = true +gpui.workspace = true +command_palette.workspace = true [lints] workspace = true diff --git a/crates/docs_preprocessor/src/main.rs b/crates/docs_preprocessor/src/main.rs index 1448f4cb52..8eeeb6f0c5 100644 --- a/crates/docs_preprocessor/src/main.rs +++ b/crates/docs_preprocessor/src/main.rs @@ -1,15 +1,14 @@ -use anyhow::{Context, Result}; +use anyhow::Result; +use clap::{Arg, ArgMatches, Command}; use mdbook::BookItem; use mdbook::book::{Book, Chapter}; use mdbook::preprocess::CmdPreprocessor; use regex::Regex; use settings::KeymapFile; -use std::borrow::Cow; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::io::{self, Read}; use std::process; use std::sync::LazyLock; -use util::paths::PathExt; static KEYMAP_MACOS: LazyLock = LazyLock::new(|| { load_keymap("keymaps/default-macos.json").expect("Failed to load MacOS keymap") @@ -21,68 +20,60 @@ static KEYMAP_LINUX: LazyLock = LazyLock::new(|| { static ALL_ACTIONS: LazyLock> = LazyLock::new(dump_all_gpui_actions); -const FRONT_MATTER_COMMENT: &'static str = ""; +pub fn make_app() -> Command { + Command::new("zed-docs-preprocessor") + .about("Preprocesses Zed Docs content to provide rich action & keybinding support and more") + .subcommand( + Command::new("supports") + .arg(Arg::new("renderer").required(true)) + .about("Check whether a renderer is supported by this preprocessor"), + ) +} fn main() -> Result<()> { - zlog::init(); - zlog::init_output_stderr(); + let matches = make_app().get_matches(); // call a zed:: function so everything in `zed` crate is linked and // all actions in the actual app are registered zed::stdout_is_a_pty(); - let args = std::env::args().skip(1).collect::>(); - match args.get(0).map(String::as_str) { - Some("supports") => { - let renderer = args.get(1).expect("Required argument"); - let supported = renderer != "not-supported"; - if supported { - process::exit(0); - } else { - process::exit(1); - } - } - Some("postprocess") => handle_postprocessing()?, - _ => handle_preprocessing()?, + if let Some(sub_args) = matches.subcommand_matches("supports") { + handle_supports(sub_args); + } else { + handle_preprocessing()?; } Ok(()) } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum PreprocessorError { +enum Error { ActionNotFound { action_name: String }, DeprecatedActionUsed { used: String, should_be: String }, - InvalidFrontmatterLine(String), } -impl PreprocessorError { +impl Error { fn new_for_not_found_action(action_name: String) -> Self { for action in &*ALL_ACTIONS { for alias in action.deprecated_aliases { if alias == &action_name { - return PreprocessorError::DeprecatedActionUsed { + return Error::DeprecatedActionUsed { used: action_name.clone(), should_be: action.name.to_string(), }; } } } - PreprocessorError::ActionNotFound { + Error::ActionNotFound { action_name: action_name.to_string(), } } } -impl std::fmt::Display for PreprocessorError { +impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - PreprocessorError::InvalidFrontmatterLine(line) => { - write!(f, "Invalid frontmatter line: {}", line) - } - PreprocessorError::ActionNotFound { action_name } => { - write!(f, "Action not found: {}", action_name) - } - PreprocessorError::DeprecatedActionUsed { used, should_be } => write!( + Error::ActionNotFound { action_name } => write!(f, "Action not found: {}", action_name), + Error::DeprecatedActionUsed { used, should_be } => write!( f, "Deprecated action used: {} should be {}", used, should_be @@ -98,9 +89,8 @@ fn handle_preprocessing() -> Result<()> { let (_ctx, mut book) = CmdPreprocessor::parse_input(input.as_bytes())?; - let mut errors = HashSet::::new(); + let mut errors = HashSet::::new(); - handle_frontmatter(&mut book, &mut errors); template_and_validate_keybindings(&mut book, &mut errors); template_and_validate_actions(&mut book, &mut errors); @@ -118,41 +108,19 @@ fn handle_preprocessing() -> Result<()> { Ok(()) } -fn handle_frontmatter(book: &mut Book, errors: &mut HashSet) { - let frontmatter_regex = Regex::new(r"(?s)^\s*---(.*?)---").unwrap(); - for_each_chapter_mut(book, |chapter| { - let new_content = frontmatter_regex.replace(&chapter.content, |caps: ®ex::Captures| { - let frontmatter = caps[1].trim(); - let frontmatter = frontmatter.trim_matches(&[' ', '-', '\n']); - let mut metadata = HashMap::::default(); - for line in frontmatter.lines() { - let Some((name, value)) = line.split_once(':') else { - errors.insert(PreprocessorError::InvalidFrontmatterLine(format!( - "{}: {}", - chapter_breadcrumbs(&chapter), - line - ))); - continue; - }; - let name = name.trim(); - let value = value.trim(); - metadata.insert(name.to_string(), value.to_string()); - } - FRONT_MATTER_COMMENT.replace( - "{}", - &serde_json::to_string(&metadata).expect("Failed to serialize metadata"), - ) - }); - match new_content { - Cow::Owned(content) => { - chapter.content = content; - } - Cow::Borrowed(_) => {} - } - }); +fn handle_supports(sub_args: &ArgMatches) -> ! { + let renderer = sub_args + .get_one::("renderer") + .expect("Required argument"); + let supported = renderer != "not-supported"; + if supported { + process::exit(0); + } else { + process::exit(1); + } } -fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { +fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { let regex = Regex::new(r"\{#kb (.*?)\}").unwrap(); for_each_chapter_mut(book, |chapter| { @@ -160,9 +128,7 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { +fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet) { let regex = Regex::new(r"\{#action (.*?)\}").unwrap(); for_each_chapter_mut(book, |chapter| { @@ -186,9 +152,7 @@ fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet{}", &action.human_name) @@ -253,13 +217,6 @@ fn name_for_action(action_as_str: String) -> String { .unwrap_or(action_as_str) } -fn chapter_breadcrumbs(chapter: &Chapter) -> String { - let mut breadcrumbs = Vec::with_capacity(chapter.parent_names.len() + 1); - breadcrumbs.extend(chapter.parent_names.iter().map(String::as_str)); - breadcrumbs.push(chapter.name.as_str()); - format!("[{:?}] {}", chapter.source_path, breadcrumbs.join(" > ")) -} - fn load_keymap(asset_path: &str) -> Result { let content = util::asset_str::(asset_path); KeymapFile::parse(content.as_ref()) @@ -297,126 +254,3 @@ fn dump_all_gpui_actions() -> Vec { return actions; } - -fn handle_postprocessing() -> Result<()> { - let logger = zlog::scoped!("render"); - let mut ctx = mdbook::renderer::RenderContext::from_json(io::stdin())?; - let output = ctx - .config - .get_mut("output") - .expect("has output") - .as_table_mut() - .expect("output is table"); - let zed_html = output.remove("zed-html").expect("zed-html output defined"); - let default_description = zed_html - .get("default-description") - .expect("Default description not found") - .as_str() - .expect("Default description not a string") - .to_string(); - let default_title = zed_html - .get("default-title") - .expect("Default title not found") - .as_str() - .expect("Default title not a string") - .to_string(); - - output.insert("html".to_string(), zed_html); - mdbook::Renderer::render(&mdbook::renderer::HtmlHandlebars::new(), &ctx)?; - let ignore_list = ["toc.html"]; - - let root_dir = ctx.destination.clone(); - let mut files = Vec::with_capacity(128); - let mut queue = Vec::with_capacity(64); - queue.push(root_dir.clone()); - while let Some(dir) = queue.pop() { - for entry in std::fs::read_dir(&dir).context(dir.to_sanitized_string())? { - let Ok(entry) = entry else { - continue; - }; - let file_type = entry.file_type().context("Failed to determine file type")?; - if file_type.is_dir() { - queue.push(entry.path()); - } - if file_type.is_file() - && matches!( - entry.path().extension().and_then(std::ffi::OsStr::to_str), - Some("html") - ) - { - if ignore_list.contains(&&*entry.file_name().to_string_lossy()) { - zlog::info!(logger => "Ignoring {}", entry.path().to_string_lossy()); - } else { - files.push(entry.path()); - } - } - } - } - - zlog::info!(logger => "Processing {} `.html` files", files.len()); - let meta_regex = Regex::new(&FRONT_MATTER_COMMENT.replace("{}", "(.*)")).unwrap(); - for file in files { - let contents = std::fs::read_to_string(&file)?; - let mut meta_description = None; - let mut meta_title = None; - let contents = meta_regex.replace(&contents, |caps: ®ex::Captures| { - let metadata: HashMap = serde_json::from_str(&caps[1]).with_context(|| format!("JSON Metadata: {:?}", &caps[1])).expect("Failed to deserialize metadata"); - for (kind, content) in metadata { - match kind.as_str() { - "description" => { - meta_description = Some(content); - } - "title" => { - meta_title = Some(content); - } - _ => { - zlog::warn!(logger => "Unrecognized frontmatter key: {} in {:?}", kind, pretty_path(&file, &root_dir)); - } - } - } - String::new() - }); - let meta_description = meta_description.as_ref().unwrap_or_else(|| { - zlog::warn!(logger => "No meta description found for {:?}", pretty_path(&file, &root_dir)); - &default_description - }); - let page_title = extract_title_from_page(&contents, pretty_path(&file, &root_dir)); - let meta_title = meta_title.as_ref().unwrap_or_else(|| { - zlog::debug!(logger => "No meta title found for {:?}", pretty_path(&file, &root_dir)); - &default_title - }); - let meta_title = format!("{} | {}", page_title, meta_title); - zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir)); - let contents = contents.replace("#description#", meta_description); - let contents = TITLE_REGEX - .replace(&contents, |_: ®ex::Captures| { - format!("{}", meta_title) - }) - .to_string(); - // let contents = contents.replace("#title#", &meta_title); - std::fs::write(file, contents)?; - } - return Ok(()); - - fn pretty_path<'a>( - path: &'a std::path::PathBuf, - root: &'a std::path::PathBuf, - ) -> &'a std::path::Path { - &path.strip_prefix(&root).unwrap_or(&path) - } - const TITLE_REGEX: std::cell::LazyCell = - std::cell::LazyCell::new(|| Regex::new(r"\s*(.*?)\s*").unwrap()); - fn extract_title_from_page(contents: &str, pretty_path: &std::path::Path) -> String { - let title_tag_contents = &TITLE_REGEX - .captures(&contents) - .with_context(|| format!("Failed to find title in {:?}", pretty_path)) - .expect("Page has element")[1]; - let title = title_tag_contents - .trim() - .strip_suffix("- Zed") - .unwrap_or(title_tag_contents) - .trim() - .to_string(); - title - } -} diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index ab2d1c8ecb..41022b3d3c 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -113,7 +113,6 @@ tree-sitter-html.workspace = true tree-sitter-rust.workspace = true tree-sitter-typescript.workspace = true tree-sitter-yaml.workspace = true -tree-sitter-bash.workspace = true unindent.workspace = true util = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index 3a3a57ca64..f80a6afbbb 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -315,8 +315,9 @@ actions!( [ /// Accepts the full edit prediction. AcceptEditPrediction, + /// Accepts a partial Copilot suggestion. + AcceptPartialCopilotSuggestion, /// Accepts a partial edit prediction. - #[action(deprecated_aliases = ["editor::AcceptPartialCopilotSuggestion"])] AcceptPartialEditPrediction, /// Adds a cursor above the current selection. AddSelectionAbove, @@ -364,8 +365,6 @@ actions!( ConvertToLowerCase, /// Toggles the case of selected text. ConvertToOppositeCase, - /// Converts selected text to sentence case. - ConvertToSentenceCase, /// Converts selected text to snake_case. ConvertToSnakeCase, /// Converts selected text to Title Case. diff --git a/crates/editor/src/code_completion_tests.rs b/crates/editor/src/code_completion_tests.rs index fd8db29584..4f9822b597 100644 --- a/crates/editor/src/code_completion_tests.rs +++ b/crates/editor/src/code_completion_tests.rs @@ -94,7 +94,7 @@ async fn test_fuzzy_score(cx: &mut TestAppContext) { filter_and_sort_matches("set_text", &completions, SnippetSortOrder::Top, cx).await; assert_eq!(matches[0].string, "set_text"); assert_eq!(matches[1].string, "set_text_style_refinement"); - assert_eq!(matches[2].string, "set_placeholder_text"); + assert_eq!(matches[2].string, "set_context_menu_options"); } // fuzzy filter text over label, sort_text and sort_kind @@ -216,28 +216,6 @@ async fn test_sort_positions(cx: &mut TestAppContext) { assert_eq!(matches[0].string, "rounded-full"); } -#[gpui::test] -async fn test_fuzzy_over_sort_positions(cx: &mut TestAppContext) { - let completions = vec![ - CompletionBuilder::variable("lsp_document_colors", None, "7fffffff"), // 0.29 fuzzy score - CompletionBuilder::function( - "language_servers_running_disk_based_diagnostics", - None, - "7fffffff", - ), // 0.168 fuzzy score - CompletionBuilder::function("code_lens", None, "7fffffff"), // 3.2 fuzzy score - CompletionBuilder::variable("lsp_code_lens", None, "7fffffff"), // 3.2 fuzzy score - CompletionBuilder::function("fetch_code_lens", None, "7fffffff"), // 3.2 fuzzy score - ]; - - let matches = - filter_and_sort_matches("lens", &completions, SnippetSortOrder::default(), cx).await; - - assert_eq!(matches[0].string, "code_lens"); - assert_eq!(matches[1].string, "lsp_code_lens"); - assert_eq!(matches[2].string, "fetch_code_lens"); -} - async fn test_for_each_prefix<F>( target: &str, completions: &Vec<Completion>, diff --git a/crates/editor/src/code_context_menus.rs b/crates/editor/src/code_context_menus.rs index 4ae2a14ca7..9f842836ed 100644 --- a/crates/editor/src/code_context_menus.rs +++ b/crates/editor/src/code_context_menus.rs @@ -1057,9 +1057,9 @@ impl CompletionsMenu { enum MatchTier<'a> { WordStartMatch { sort_exact: Reverse<i32>, + sort_positions: Vec<usize>, sort_snippet: Reverse<i32>, sort_score: Reverse<OrderedFloat<f64>>, - sort_positions: Vec<usize>, sort_text: Option<&'a str>, sort_kind: usize, sort_label: &'a str, @@ -1137,9 +1137,9 @@ impl CompletionsMenu { MatchTier::WordStartMatch { sort_exact, + sort_positions, sort_snippet, sort_score, - sort_positions, sort_text, sort_kind, sort_label, diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index e4628b43aa..eccc8d3e25 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -51,56 +51,42 @@ mod signature_help; pub mod test; pub(crate) use actions::*; -pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; -pub use editor_settings::{ - CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, - ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap, ShowScrollbar, -}; -pub use editor_settings_controls::*; -pub use element::{ - CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition, -}; -pub use git::blame::BlameRenderer; -pub use hover_popover::hover_markdown_style; -pub use inline_completion::Direction; -pub use items::MAX_TAB_TITLE_LEN; -pub use lsp::CompletionContext; -pub use lsp_ext::lsp_tasks; -pub use multi_buffer::{ - Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, PathKey, - RowInfo, ToOffset, ToPoint, -}; -pub use proposed_changes_editor::{ - ProposedChangeLocation, ProposedChangesEditor, ProposedChangesEditorToolbar, -}; -pub use text::Bias; - -use ::git::{ - Restore, - blame::{BlameEntry, ParsedCommitMessage}, -}; +pub use actions::{AcceptEditPrediction, OpenExcerpts, OpenExcerptsSplit}; use aho_corasick::AhoCorasick; use anyhow::{Context as _, Result, anyhow}; use blink_manager::BlinkManager; use buffer_diff::DiffHunkStatus; use client::{Collaborator, DisableAiSettings, ParticipantIndex}; use clock::{AGENT_REPLICA_ID, ReplicaId}; -use code_context_menus::{ - AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, - CompletionsMenu, ContextMenuOrigin, -}; use collections::{BTreeMap, HashMap, HashSet, VecDeque}; use convert_case::{Case, Casing}; use dap::TelemetrySpawnLocation; use display_map::*; +pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; +pub use editor_settings::{ + CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, + ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowScrollbar, +}; use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings}; +pub use editor_settings_controls::*; use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line}; +pub use element::{ + CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition, +}; use futures::{ FutureExt, StreamExt as _, future::{self, Shared, join}, stream::FuturesUnordered, }; use fuzzy::{StringMatch, StringMatchCandidate}; +use lsp_colors::LspColorData; + +use ::git::blame::BlameEntry; +use ::git::{Restore, blame::ParsedCommitMessage}; +use code_context_menus::{ + AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, + CompletionsMenu, ContextMenuOrigin, +}; use git::blame::{GitBlame, GlobalBlameRenderer}; use gpui::{ Action, Animation, AnimationExt, AnyElement, App, AppContext, AsyncWindowContext, @@ -114,43 +100,32 @@ use gpui::{ }; use highlight_matching_bracket::refresh_matching_bracket_highlights; use hover_links::{HoverLink, HoveredLinkState, InlayHighlight, find_file}; +pub use hover_popover::hover_markdown_style; use hover_popover::{HoverState, hide_hover}; use indent_guides::ActiveIndentGuidesState; use inlay_hint_cache::{InlayHintCache, InlaySplice, InvalidationStrategy}; +pub use inline_completion::Direction; use inline_completion::{EditPredictionProvider, InlineCompletionProviderHandle}; +pub use items::MAX_TAB_TITLE_LEN; use itertools::Itertools; use language::{ - AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, BufferRow, - BufferSnapshot, Capability, CharClassifier, CharKind, CodeLabel, CursorShape, DiagnosticEntry, - DiffOptions, EditPredictionsMode, EditPreview, HighlightedText, IndentKind, IndentSize, - Language, OffsetRangeExt, Point, Runnable, RunnableRange, Selection, SelectionGoal, TextObject, - TransactionId, TreeSitterOptions, WordsQuery, + AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, Capability, CharKind, + CodeLabel, CursorShape, DiagnosticEntry, DiffOptions, EditPredictionsMode, EditPreview, + HighlightedText, IndentKind, IndentSize, Language, OffsetRangeExt, Point, Selection, + SelectionGoal, TextObject, TransactionId, TreeSitterOptions, WordsQuery, language_settings::{ self, InlayHintSettings, LspInsertMode, RewrapBehavior, WordsCompletionMode, all_language_settings, language_settings, }, - point_from_lsp, point_to_lsp, text_diff_with_options, + point_from_lsp, text_diff_with_options, }; +use language::{BufferRow, CharClassifier, Runnable, RunnableRange, point_to_lsp}; use linked_editing_ranges::refresh_linked_ranges; -use lsp::{ - CodeActionKind, CompletionItemKind, CompletionTriggerKind, InsertTextFormat, InsertTextMode, - LanguageServerId, LanguageServerName, -}; -use lsp_colors::LspColorData; use markdown::Markdown; use mouse_context_menu::MouseContextMenu; -use movement::TextLayoutDetails; -use multi_buffer::{ - ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow, - MultiOrSingleBufferOffsetRange, ToOffsetUtf16, -}; -use parking_lot::Mutex; use persistence::DB; use project::{ - BreakpointWithPosition, CodeAction, Completion, CompletionIntent, CompletionResponse, - CompletionSource, DocumentHighlight, InlayHint, Location, LocationLink, PrepareRenameResponse, - Project, ProjectItem, ProjectPath, ProjectTransaction, TaskSourceKind, - debugger::breakpoint_store::Breakpoint, + BreakpointWithPosition, CompletionResponse, ProjectPath, debugger::{ breakpoint_store::{ BreakpointEditAction, BreakpointSessionState, BreakpointState, BreakpointStore, @@ -159,12 +134,44 @@ use project::{ session::{Session, SessionEvent}, }, git_store::{GitStoreEvent, RepositoryEvent}, - lsp_store::{CompletionDocumentation, FormatTrigger, LspFormatTarget, OpenLspBufferHandle}, project_settings::{DiagnosticSeverity, GoToDiagnosticSeverityFilter}, +}; + +pub use git::blame::BlameRenderer; +pub use proposed_changes_editor::{ + ProposedChangeLocation, ProposedChangesEditor, ProposedChangesEditorToolbar, +}; +use std::{cell::OnceCell, iter::Peekable, ops::Not}; +use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; + +pub use lsp::CompletionContext; +use lsp::{ + CodeActionKind, CompletionItemKind, CompletionTriggerKind, InsertTextFormat, InsertTextMode, + LanguageServerId, LanguageServerName, +}; + +use language::BufferSnapshot; +pub use lsp_ext::lsp_tasks; +use movement::TextLayoutDetails; +pub use multi_buffer::{ + Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, PathKey, + RowInfo, ToOffset, ToPoint, +}; +use multi_buffer::{ + ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow, + MultiOrSingleBufferOffsetRange, ToOffsetUtf16, +}; +use parking_lot::Mutex; +use project::{ + CodeAction, Completion, CompletionIntent, CompletionSource, DocumentHighlight, InlayHint, + Location, LocationLink, PrepareRenameResponse, Project, ProjectItem, ProjectTransaction, + TaskSourceKind, + debugger::breakpoint_store::Breakpoint, + lsp_store::{CompletionDocumentation, FormatTrigger, LspFormatTarget, OpenLspBufferHandle}, project_settings::{GitGutterSetting, ProjectSettings}, }; -use rand::{seq::SliceRandom, thread_rng}; -use rpc::{ErrorCode, ErrorExt, proto::PeerId}; +use rand::prelude::*; +use rpc::{ErrorExt, proto::*}; use scroll::{Autoscroll, OngoingScroll, ScrollAnchor, ScrollManager, ScrollbarAutoHide}; use selections_collection::{ MutableSelectionsCollection, SelectionsCollection, resolve_selections, @@ -173,24 +180,21 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsLocation, SettingsStore, update_settings_file}; use smallvec::{SmallVec, smallvec}; use snippet::Snippet; +use std::sync::Arc; use std::{ any::TypeId, borrow::Cow, - cell::OnceCell, cell::RefCell, cmp::{self, Ordering, Reverse}, - iter::Peekable, mem, num::NonZeroU32, - ops::Not, ops::{ControlFlow, Deref, DerefMut, Range, RangeInclusive}, path::{Path, PathBuf}, rc::Rc, - sync::Arc, time::{Duration, Instant}, }; +pub use sum_tree::Bias; use sum_tree::TreeMap; -use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; use text::{BufferId, FromAnchor, OffsetUtf16, Rope}; use theme::{ ActiveTheme, PlayerColor, StatusColors, SyntaxTheme, Theme, ThemeSettings, @@ -209,11 +213,14 @@ use workspace::{ notifications::{DetachAndPromptErr, NotificationId, NotifyTaskExt}, searchable::SearchEvent, }; +use zed_actions; use crate::{ code_context_menus::CompletionsMenuSource, - editor_settings::MultiCursorModifier, hover_links::{find_url, find_url_from_range}, +}; +use crate::{ + editor_settings::MultiCursorModifier, signature_help::{SignatureHelpHiddenBy, SignatureHelpState}, }; @@ -6403,6 +6410,7 @@ impl Editor { IconButton::new("inline_code_actions", ui::IconName::BoltFilled) .icon_size(icon_size) .shape(ui::IconButtonShape::Square) + .style(ButtonStyle::Transparent) .icon_color(ui::Color::Hidden) .toggle_state(is_active) .when(show_tooltip, |this| { @@ -8337,29 +8345,26 @@ impl Editor { let color = Color::Muted; let position = breakpoint.as_ref().map(|(anchor, _, _)| *anchor); - IconButton::new( - ("run_indicator", row.0 as usize), - ui::IconName::PlayOutlined, - ) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(color) - .toggle_state(is_active) - .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { - let quick_launch = e.down.button == MouseButton::Left; - window.focus(&editor.focus_handle(cx)); - editor.toggle_code_actions( - &ToggleCodeActions { - deployed_from: Some(CodeActionSource::RunMenu(row)), - quick_launch, - }, - window, - cx, - ); - })) - .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { - editor.set_breakpoint_context_menu(row, position, event.down.position, window, cx); - })) + IconButton::new(("run_indicator", row.0 as usize), ui::IconName::Play) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(color) + .toggle_state(is_active) + .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { + let quick_launch = e.down.button == MouseButton::Left; + window.focus(&editor.focus_handle(cx)); + editor.toggle_code_actions( + &ToggleCodeActions { + deployed_from: Some(CodeActionSource::RunMenu(row)), + quick_launch, + }, + window, + cx, + ); + })) + .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { + editor.set_breakpoint_context_menu(row, position, event.down.position, window, cx); + })) } pub fn context_menu_visible(&self) -> bool { @@ -10901,6 +10906,17 @@ impl Editor { }); } + pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context<Self>) { + self.manipulate_text(window, cx, |text| { + let has_upper_case_characters = text.chars().any(|c| c.is_uppercase()); + if has_upper_case_characters { + text.to_lowercase() + } else { + text.to_uppercase() + } + }) + } + fn manipulate_immutable_lines<Fn>( &mut self, window: &mut Window, @@ -11156,26 +11172,6 @@ impl Editor { }) } - pub fn convert_to_sentence_case( - &mut self, - _: &ConvertToSentenceCase, - window: &mut Window, - cx: &mut Context<Self>, - ) { - self.manipulate_text(window, cx, |text| text.to_case(Case::Sentence)) - } - - pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context<Self>) { - self.manipulate_text(window, cx, |text| { - let has_upper_case_characters = text.chars().any(|c| c.is_uppercase()); - if has_upper_case_characters { - text.to_lowercase() - } else { - text.to_uppercase() - } - }) - } - pub fn convert_to_rot13( &mut self, _: &ConvertToRot13, @@ -17000,7 +16996,7 @@ impl Editor { now: Instant, window: &mut Window, cx: &mut Context<Self>, - ) -> Option<TransactionId> { + ) { self.end_selection(window, cx); if let Some(tx_id) = self .buffer @@ -17010,10 +17006,7 @@ impl Editor { .insert_transaction(tx_id, self.selections.disjoint_anchors()); cx.emit(EditorEvent::TransactionBegun { transaction_id: tx_id, - }); - Some(tx_id) - } else { - None + }) } } @@ -17041,17 +17034,6 @@ impl Editor { } } - pub fn modify_transaction_selection_history( - &mut self, - transaction_id: TransactionId, - modify: impl FnOnce(&mut (Arc<[Selection<Anchor>]>, Option<Arc<[Selection<Anchor>]>>)), - ) -> bool { - self.selection_history - .transaction_mut(transaction_id) - .map(modify) - .is_some() - } - pub fn set_mark(&mut self, _: &actions::SetMark, window: &mut Window, cx: &mut Context<Self>) { if self.selection_mark_mode { self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { @@ -21124,6 +21106,13 @@ fn process_completion_for_edit( .is_le(), "replace_range should start before or at cursor position" ); + debug_assert!( + insert_range + .end + .cmp(&cursor_position, &buffer_snapshot) + .is_le(), + "insert_range should end before or at cursor position" + ); let should_replace = match intent { CompletionIntent::CompleteWithInsert => false, @@ -22297,7 +22286,7 @@ fn consume_contiguous_rows( selections: &mut Peekable<std::slice::Iter<Selection<Point>>>, ) -> (MultiBufferRow, MultiBufferRow) { contiguous_row_selections.push(selection.clone()); - let start_row = starting_row(selection, display_map); + let start_row = MultiBufferRow(selection.start.row); let mut end_row = ending_row(selection, display_map); while let Some(next_selection) = selections.peek() { @@ -22311,14 +22300,6 @@ fn consume_contiguous_rows( (start_row, end_row) } -fn starting_row(selection: &Selection<Point>, display_map: &DisplaySnapshot) -> MultiBufferRow { - if selection.start.column > 0 { - MultiBufferRow(display_map.prev_line_boundary(selection.start).0.row) - } else { - MultiBufferRow(selection.start.row) - } -} - fn ending_row(next_selection: &Selection<Point>, display_map: &DisplaySnapshot) -> MultiBufferRow { if next_selection.end.column > 0 || next_selection.is_empty() { MultiBufferRow(display_map.next_line_boundary(next_selection.end).0.row + 1) diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 1a4f444275..42daf14615 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -4724,23 +4724,6 @@ async fn test_toggle_case(cx: &mut TestAppContext) { "}); } -#[gpui::test] -async fn test_convert_to_sentence_case(cx: &mut TestAppContext) { - init_test(cx, |_| {}); - - let mut cx = EditorTestContext::new(cx).await; - - cx.set_state(indoc! {" - «implement-windows-supportˇ» - "}); - cx.update_editor(|e, window, cx| { - e.convert_to_sentence_case(&ConvertToSentenceCase, window, cx) - }); - cx.assert_editor_state(indoc! {" - «Implement windows supportˇ» - "}); -} - #[gpui::test] async fn test_manipulate_text(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -5086,33 +5069,6 @@ fn test_move_line_up_down(cx: &mut TestAppContext) { }); } -#[gpui::test] -fn test_move_line_up_selection_at_end_of_fold(cx: &mut TestAppContext) { - init_test(cx, |_| {}); - let editor = cx.add_window(|window, cx| { - let buffer = MultiBuffer::build_simple("\n\n\n\n\n\naaaa\nbbbb\ncccc", cx); - build_editor(buffer, window, cx) - }); - _ = editor.update(cx, |editor, window, cx| { - editor.fold_creases( - vec![Crease::simple( - Point::new(6, 4)..Point::new(7, 4), - FoldPlaceholder::test(), - )], - true, - window, - cx, - ); - editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_ranges([Point::new(7, 4)..Point::new(7, 4)]) - }); - assert_eq!(editor.display_text(cx), "\n\n\n\n\n\naaaa⋯\ncccc"); - editor.move_line_up(&MoveLineUp, window, cx); - let buffer_text = editor.buffer.read(cx).snapshot(cx).text(); - assert_eq!(buffer_text, "\n\n\n\n\naaaa\nbbbb\n\ncccc"); - }); -} - #[gpui::test] fn test_move_line_up_down_with_blocks(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -8612,7 +8568,6 @@ async fn test_autoclose_with_embedded_language(cx: &mut TestAppContext) { cx.language_registry().add(html_language.clone()); cx.language_registry().add(javascript_language.clone()); - cx.executor().run_until_parked(); cx.update_buffer(|buffer, cx| { buffer.set_language(Some(html_language), cx); @@ -22836,435 +22791,6 @@ async fn test_indent_on_newline_for_python(cx: &mut TestAppContext) { "}); } -#[gpui::test] -async fn test_tab_in_leading_whitespace_auto_indents_for_bash(cx: &mut TestAppContext) { - init_test(cx, |_| {}); - - let mut cx = EditorTestContext::new(cx).await; - let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); - cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); - - // test cursor move to start of each line on tab - // for `if`, `elif`, `else`, `while`, `for`, `case` and `function` - cx.set_state(indoc! {" - function main() { - ˇ for item in $items; do - ˇ while [ -n \"$item\" ]; do - ˇ if [ \"$value\" -gt 10 ]; then - ˇ continue - ˇ elif [ \"$value\" -lt 0 ]; then - ˇ break - ˇ else - ˇ echo \"$item\" - ˇ fi - ˇ done - ˇ done - ˇ} - "}); - cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); - cx.assert_editor_state(indoc! {" - function main() { - ˇfor item in $items; do - ˇwhile [ -n \"$item\" ]; do - ˇif [ \"$value\" -gt 10 ]; then - ˇcontinue - ˇelif [ \"$value\" -lt 0 ]; then - ˇbreak - ˇelse - ˇecho \"$item\" - ˇfi - ˇdone - ˇdone - ˇ} - "}); - // test relative indent is preserved when tab - cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); - cx.assert_editor_state(indoc! {" - function main() { - ˇfor item in $items; do - ˇwhile [ -n \"$item\" ]; do - ˇif [ \"$value\" -gt 10 ]; then - ˇcontinue - ˇelif [ \"$value\" -lt 0 ]; then - ˇbreak - ˇelse - ˇecho \"$item\" - ˇfi - ˇdone - ˇdone - ˇ} - "}); - - // test cursor move to start of each line on tab - // for `case` statement with patterns - cx.set_state(indoc! {" - function handle() { - ˇ case \"$1\" in - ˇ start) - ˇ echo \"a\" - ˇ ;; - ˇ stop) - ˇ echo \"b\" - ˇ ;; - ˇ *) - ˇ echo \"c\" - ˇ ;; - ˇ esac - ˇ} - "}); - cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); - cx.assert_editor_state(indoc! {" - function handle() { - ˇcase \"$1\" in - ˇstart) - ˇecho \"a\" - ˇ;; - ˇstop) - ˇecho \"b\" - ˇ;; - ˇ*) - ˇecho \"c\" - ˇ;; - ˇesac - ˇ} - "}); -} - -#[gpui::test] -async fn test_indent_after_input_for_bash(cx: &mut TestAppContext) { - init_test(cx, |_| {}); - - let mut cx = EditorTestContext::new(cx).await; - let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); - cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); - - // test indents on comment insert - cx.set_state(indoc! {" - function main() { - ˇ for item in $items; do - ˇ while [ -n \"$item\" ]; do - ˇ if [ \"$value\" -gt 10 ]; then - ˇ continue - ˇ elif [ \"$value\" -lt 0 ]; then - ˇ break - ˇ else - ˇ echo \"$item\" - ˇ fi - ˇ done - ˇ done - ˇ} - "}); - cx.update_editor(|e, window, cx| e.handle_input("#", window, cx)); - cx.assert_editor_state(indoc! {" - function main() { - #ˇ for item in $items; do - #ˇ while [ -n \"$item\" ]; do - #ˇ if [ \"$value\" -gt 10 ]; then - #ˇ continue - #ˇ elif [ \"$value\" -lt 0 ]; then - #ˇ break - #ˇ else - #ˇ echo \"$item\" - #ˇ fi - #ˇ done - #ˇ done - #ˇ} - "}); -} - -#[gpui::test] -async fn test_outdent_after_input_for_bash(cx: &mut TestAppContext) { - init_test(cx, |_| {}); - - let mut cx = EditorTestContext::new(cx).await; - let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); - cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); - - // test `else` auto outdents when typed inside `if` block - cx.set_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - echo \"foo bar\" - ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.handle_input("else", window, cx); - }); - cx.assert_editor_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - echo \"foo bar\" - elseˇ - "}); - - // test `elif` auto outdents when typed inside `if` block - cx.set_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - echo \"foo bar\" - ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.handle_input("elif", window, cx); - }); - cx.assert_editor_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - echo \"foo bar\" - elifˇ - "}); - - // test `fi` auto outdents when typed inside `else` block - cx.set_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - echo \"foo bar\" - else - echo \"bar baz\" - ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.handle_input("fi", window, cx); - }); - cx.assert_editor_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - echo \"foo bar\" - else - echo \"bar baz\" - fiˇ - "}); - - // test `done` auto outdents when typed inside `while` block - cx.set_state(indoc! {" - while read line; do - echo \"$line\" - ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.handle_input("done", window, cx); - }); - cx.assert_editor_state(indoc! {" - while read line; do - echo \"$line\" - doneˇ - "}); - - // test `done` auto outdents when typed inside `for` block - cx.set_state(indoc! {" - for file in *.txt; do - cat \"$file\" - ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.handle_input("done", window, cx); - }); - cx.assert_editor_state(indoc! {" - for file in *.txt; do - cat \"$file\" - doneˇ - "}); - - // test `esac` auto outdents when typed inside `case` block - cx.set_state(indoc! {" - case \"$1\" in - start) - echo \"foo bar\" - ;; - stop) - echo \"bar baz\" - ;; - ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.handle_input("esac", window, cx); - }); - cx.assert_editor_state(indoc! {" - case \"$1\" in - start) - echo \"foo bar\" - ;; - stop) - echo \"bar baz\" - ;; - esacˇ - "}); - - // test `*)` auto outdents when typed inside `case` block - cx.set_state(indoc! {" - case \"$1\" in - start) - echo \"foo bar\" - ;; - ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.handle_input("*)", window, cx); - }); - cx.assert_editor_state(indoc! {" - case \"$1\" in - start) - echo \"foo bar\" - ;; - *)ˇ - "}); - - // test `fi` outdents to correct level with nested if blocks - cx.set_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - echo \"outer if\" - if [ \"$2\" = \"debug\" ]; then - echo \"inner if\" - ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.handle_input("fi", window, cx); - }); - cx.assert_editor_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - echo \"outer if\" - if [ \"$2\" = \"debug\" ]; then - echo \"inner if\" - fiˇ - "}); -} - -#[gpui::test] -async fn test_indent_on_newline_for_bash(cx: &mut TestAppContext) { - init_test(cx, |_| {}); - update_test_language_settings(cx, |settings| { - settings.defaults.extend_comment_on_newline = Some(false); - }); - let mut cx = EditorTestContext::new(cx).await; - let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); - cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); - - // test correct indent after newline on comment - cx.set_state(indoc! {" - # COMMENT:ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.assert_editor_state(indoc! {" - # COMMENT: - ˇ - "}); - - // test correct indent after newline after `then` - cx.set_state(indoc! {" - - if [ \"$1\" = \"test\" ]; thenˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.run_until_parked(); - cx.assert_editor_state(indoc! {" - - if [ \"$1\" = \"test\" ]; then - ˇ - "}); - - // test correct indent after newline after `else` - cx.set_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - elseˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.run_until_parked(); - cx.assert_editor_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - else - ˇ - "}); - - // test correct indent after newline after `elif` - cx.set_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - elifˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.run_until_parked(); - cx.assert_editor_state(indoc! {" - if [ \"$1\" = \"test\" ]; then - elif - ˇ - "}); - - // test correct indent after newline after `do` - cx.set_state(indoc! {" - for file in *.txt; doˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.run_until_parked(); - cx.assert_editor_state(indoc! {" - for file in *.txt; do - ˇ - "}); - - // test correct indent after newline after case pattern - cx.set_state(indoc! {" - case \"$1\" in - start)ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.run_until_parked(); - cx.assert_editor_state(indoc! {" - case \"$1\" in - start) - ˇ - "}); - - // test correct indent after newline after case pattern - cx.set_state(indoc! {" - case \"$1\" in - start) - ;; - *)ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.run_until_parked(); - cx.assert_editor_state(indoc! {" - case \"$1\" in - start) - ;; - *) - ˇ - "}); - - // test correct indent after newline after function opening brace - cx.set_state(indoc! {" - function test() {ˇ} - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.run_until_parked(); - cx.assert_editor_state(indoc! {" - function test() { - ˇ - } - "}); - - // test no extra indent after semicolon on same line - cx.set_state(indoc! {" - echo \"test\";ˇ - "}); - cx.update_editor(|editor, window, cx| { - editor.newline(&Newline, window, cx); - }); - cx.run_until_parked(); - cx.assert_editor_state(indoc! {" - echo \"test\"; - ˇ - "}); -} - fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> { let point = DisplayPoint::new(DisplayRow(row as u32), column as u32); point..point diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 7e77f113ac..1b372a7d53 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -230,6 +230,7 @@ impl EditorElement { register_action(editor, window, Editor::sort_lines_case_insensitive); register_action(editor, window, Editor::reverse_lines); register_action(editor, window, Editor::shuffle_lines); + register_action(editor, window, Editor::toggle_case); register_action(editor, window, Editor::convert_indentation_to_spaces); register_action(editor, window, Editor::convert_indentation_to_tabs); register_action(editor, window, Editor::convert_to_upper_case); @@ -240,8 +241,6 @@ impl EditorElement { register_action(editor, window, Editor::convert_to_upper_camel_case); register_action(editor, window, Editor::convert_to_lower_camel_case); register_action(editor, window, Editor::convert_to_opposite_case); - register_action(editor, window, Editor::convert_to_sentence_case); - register_action(editor, window, Editor::toggle_case); register_action(editor, window, Editor::convert_to_rot13); register_action(editor, window, Editor::convert_to_rot47); register_action(editor, window, Editor::delete_to_previous_word_start); @@ -4011,7 +4010,6 @@ impl EditorElement { let available_width = hitbox.bounds.size.width - right_margin; let mut header = v_flex() - .w_full() .relative() .child( div() @@ -7944,11 +7942,17 @@ impl Element for EditorElement { right: right_margin, }; + // Offset the content_bounds from the text_bounds by the gutter margin (which + // is roughly half a character wide) to make hit testing work more like how we want. + let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO); + + let editor_content_width = editor_width - content_offset.x; + snapshot = self.editor.update(cx, |editor, cx| { editor.last_bounds = Some(bounds); editor.gutter_dimensions = gutter_dimensions; editor.set_visible_line_count(bounds.size.height / line_height, window, cx); - editor.set_visible_column_count(editor_width / em_advance); + editor.set_visible_column_count(editor_content_width / em_advance); if matches!( editor.mode, @@ -7960,10 +7964,10 @@ impl Element for EditorElement { let wrap_width = match editor.soft_wrap_mode(cx) { SoftWrap::GitDiff => None, SoftWrap::None => Some(wrap_width_for(MAX_LINE_LEN as u32 / 2)), - SoftWrap::EditorWidth => Some(editor_width), + SoftWrap::EditorWidth => Some(editor_content_width), SoftWrap::Column(column) => Some(wrap_width_for(column)), SoftWrap::Bounded(column) => { - Some(editor_width.min(wrap_width_for(column))) + Some(editor_content_width.min(wrap_width_for(column))) } }; @@ -7988,12 +7992,13 @@ impl Element for EditorElement { HitboxBehavior::Normal, ); - // Offset the content_bounds from the text_bounds by the gutter margin (which - // is roughly half a character wide) to make hit testing work more like how we want. - let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO); let content_origin = text_hitbox.origin + content_offset; - let height_in_lines = bounds.size.height / line_height; + let editor_text_bounds = + Bounds::from_corners(content_origin, bounds.bottom_right()); + + let height_in_lines = editor_text_bounds.size.height / line_height; + let max_row = snapshot.max_point().row().as_f32(); // The max scroll position for the top of the window @@ -8377,6 +8382,7 @@ impl Element for EditorElement { glyph_grid_cell, size(longest_line_width, max_row.as_f32() * line_height), longest_line_blame_width, + editor_width, EditorSettings::get_global(cx), ); @@ -8448,7 +8454,7 @@ impl Element for EditorElement { MultiBufferRow(end_anchor.to_point(&snapshot.buffer_snapshot).row); let scroll_max = point( - ((scroll_width - editor_width) / em_advance).max(0.0), + ((scroll_width - editor_content_width) / em_advance).max(0.0), max_scroll_top, ); @@ -8460,7 +8466,7 @@ impl Element for EditorElement { if needs_horizontal_autoscroll.0 && let Some(new_scroll_position) = editor.autoscroll_horizontally( start_row, - editor_width, + editor_content_width, scroll_width, em_advance, &line_layouts, @@ -9041,6 +9047,7 @@ impl ScrollbarLayoutInformation { glyph_grid_cell: Size<Pixels>, document_size: Size<Pixels>, longest_line_blame_width: Pixels, + editor_width: Pixels, settings: &EditorSettings, ) -> Self { let vertical_overscroll = match settings.scroll_beyond_last_line { @@ -9051,11 +9058,19 @@ impl ScrollbarLayoutInformation { } }; - let overscroll = size(longest_line_blame_width, vertical_overscroll); + let right_margin = if document_size.width + longest_line_blame_width >= editor_width { + glyph_grid_cell.width + } else { + px(0.0) + }; + + let overscroll = size(right_margin + longest_line_blame_width, vertical_overscroll); + + let scroll_range = document_size + overscroll; ScrollbarLayoutInformation { editor_bounds, - scroll_range: document_size + overscroll, + scroll_range, glyph_grid_cell, } } @@ -9160,7 +9175,7 @@ struct EditorScrollbars { impl EditorScrollbars { pub fn from_scrollbar_axes( - show_scrollbar: ScrollbarAxes, + settings_visibility: ScrollbarAxes, layout_information: &ScrollbarLayoutInformation, content_offset: gpui::Point<Pixels>, scroll_position: gpui::Point<f32>, @@ -9198,13 +9213,22 @@ impl EditorScrollbars { }; let mut create_scrollbar_layout = |axis| { - let viewport_size = viewport_size.along(axis); - let scroll_range = scroll_range.along(axis); - - // We always want a vertical scrollbar track for scrollbar diagnostic visibility. - (show_scrollbar.along(axis) - && (axis == ScrollbarAxis::Vertical || scroll_range > viewport_size)) + settings_visibility + .along(axis) .then(|| { + ( + viewport_size.along(axis) - content_offset.along(axis), + scroll_range.along(axis), + ) + }) + .filter(|(viewport_size, scroll_range)| { + // The scrollbar should only be rendered if the content does + // not entirely fit into the editor + // However, this only applies to the horizontal scrollbar, as information about the + // vertical scrollbar layout is always needed for scrollbar diagnostics. + axis != ScrollbarAxis::Horizontal || viewport_size < scroll_range + }) + .map(|(viewport_size, scroll_range)| { ScrollbarLayout::new( window.insert_hitbox(scrollbar_bounds_for(axis), HitboxBehavior::Normal), viewport_size, diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index a0214c76a1..d5db7f71a4 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -19,8 +19,8 @@ path = "src/explorer.rs" [dependencies] agent.workspace = true -agent_settings.workspace = true agent_ui.workspace = true +agent_settings.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true @@ -29,7 +29,6 @@ buffer_diff.workspace = true chrono.workspace = true clap.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true debug_adapter_extension.workspace = true dirs.workspace = true @@ -69,3 +68,4 @@ util.workspace = true uuid.workspace = true watch.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index d638ac171f..a02b4a7f0b 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -18,7 +18,7 @@ use collections::{HashMap, HashSet}; use extension::ExtensionHostProxy; use futures::future; use gpui::http_client::read_proxy_from_env; -use gpui::{App, AppContext, Application, AsyncApp, Entity, UpdateGlobal}; +use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal}; use gpui_tokio::Tokio; use language::LanguageRegistry; use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel}; @@ -337,8 +337,7 @@ pub struct AgentAppState { } pub fn init(cx: &mut App) -> Arc<AgentAppState> { - let app_version = AppVersion::global(cx); - release_channel::init(app_version, cx); + release_channel::init(SemanticVersion::default(), cx); gpui_tokio::init(cx); let mut settings_store = SettingsStore::new(cx); @@ -351,7 +350,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> { // Set User-Agent so we can download language servers from GitHub let user_agent = format!( "Zed/{} ({}; {})", - app_version, + AppVersion::global(cx), std::env::consts::OS, std::env::consts::ARCH ); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 23c8814916..7ce3b1fdf1 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -15,11 +15,11 @@ use agent_settings::AgentProfileId; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; -use cloud_llm_client::CompletionIntent; use collections::HashMap; use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; use gpui::{App, AppContext, AsyncApp, Entity}; use language_model::{LanguageModel, Role, StopReason}; +use zed_llm_client::CompletionIntent; pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); diff --git a/crates/extension/Cargo.toml b/crates/extension/Cargo.toml index 42189f20b3..4fc7da2dca 100644 --- a/crates/extension/Cargo.toml +++ b/crates/extension/Cargo.toml @@ -32,11 +32,7 @@ serde.workspace = true serde_json.workspace = true task.workspace = true toml.workspace = true -url.workspace = true util.workspace = true wasm-encoder.workspace = true wasmparser.workspace = true workspace-hack.workspace = true - -[dev-dependencies] -pretty_assertions.workspace = true diff --git a/crates/extension/src/capabilities.rs b/crates/extension/src/capabilities.rs deleted file mode 100644 index b8afc4ec06..0000000000 --- a/crates/extension/src/capabilities.rs +++ /dev/null @@ -1,20 +0,0 @@ -mod download_file_capability; -mod npm_install_package_capability; -mod process_exec_capability; - -pub use download_file_capability::*; -pub use npm_install_package_capability::*; -pub use process_exec_capability::*; - -use serde::{Deserialize, Serialize}; - -/// A capability for an extension. -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -#[serde(tag = "kind", rename_all = "snake_case")] -pub enum ExtensionCapability { - #[serde(rename = "process:exec")] - ProcessExec(ProcessExecCapability), - DownloadFile(DownloadFileCapability), - #[serde(rename = "npm:install")] - NpmInstallPackage(NpmInstallPackageCapability), -} diff --git a/crates/extension/src/capabilities/download_file_capability.rs b/crates/extension/src/capabilities/download_file_capability.rs deleted file mode 100644 index a76755b593..0000000000 --- a/crates/extension/src/capabilities/download_file_capability.rs +++ /dev/null @@ -1,121 +0,0 @@ -use serde::{Deserialize, Serialize}; -use url::Url; - -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub struct DownloadFileCapability { - pub host: String, - pub path: Vec<String>, -} - -impl DownloadFileCapability { - /// Returns whether the capability allows downloading a file from the given URL. - pub fn allows(&self, url: &Url) -> bool { - let Some(desired_host) = url.host_str() else { - return false; - }; - - let Some(desired_path) = url.path_segments() else { - return false; - }; - let desired_path = desired_path.collect::<Vec<_>>(); - - if self.host != desired_host && self.host != "*" { - return false; - } - - for (ix, path_segment) in self.path.iter().enumerate() { - if path_segment == "**" { - return true; - } - - if ix >= desired_path.len() { - return false; - } - - if path_segment != "*" && path_segment != desired_path[ix] { - return false; - } - } - - if self.path.len() < desired_path.len() { - return false; - } - - true - } -} - -#[cfg(test)] -mod tests { - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn test_allows() { - let capability = DownloadFileCapability { - host: "*".to_string(), - path: vec!["**".to_string()], - }; - assert_eq!( - capability.allows(&"https://example.com/some/path".parse().unwrap()), - true - ); - - let capability = DownloadFileCapability { - host: "github.com".to_string(), - path: vec!["**".to_string()], - }; - assert_eq!( - capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()), - true - ); - assert_eq!( - capability.allows( - &"https://fake-github.com/some-owner/some-repo" - .parse() - .unwrap() - ), - false - ); - - let capability = DownloadFileCapability { - host: "github.com".to_string(), - path: vec!["specific-owner".to_string(), "*".to_string()], - }; - assert_eq!( - capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()), - false - ); - assert_eq!( - capability.allows( - &"https://github.com/specific-owner/some-repo" - .parse() - .unwrap() - ), - true - ); - - let capability = DownloadFileCapability { - host: "github.com".to_string(), - path: vec!["specific-owner".to_string(), "*".to_string()], - }; - assert_eq!( - capability.allows( - &"https://github.com/some-owner/some-repo/extra" - .parse() - .unwrap() - ), - false - ); - assert_eq!( - capability.allows( - &"https://github.com/specific-owner/some-repo/extra" - .parse() - .unwrap() - ), - false - ); - } -} diff --git a/crates/extension/src/capabilities/npm_install_package_capability.rs b/crates/extension/src/capabilities/npm_install_package_capability.rs deleted file mode 100644 index 287645fc75..0000000000 --- a/crates/extension/src/capabilities/npm_install_package_capability.rs +++ /dev/null @@ -1,39 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub struct NpmInstallPackageCapability { - pub package: String, -} - -impl NpmInstallPackageCapability { - /// Returns whether the capability allows installing the given NPM package. - pub fn allows(&self, package: &str) -> bool { - self.package == "*" || self.package == package - } -} - -#[cfg(test)] -mod tests { - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn test_allows() { - let capability = NpmInstallPackageCapability { - package: "*".to_string(), - }; - assert_eq!(capability.allows("package"), true); - - let capability = NpmInstallPackageCapability { - package: "react".to_string(), - }; - assert_eq!(capability.allows("react"), true); - - let capability = NpmInstallPackageCapability { - package: "react".to_string(), - }; - assert_eq!(capability.allows("malicious-package"), false); - } -} diff --git a/crates/extension/src/capabilities/process_exec_capability.rs b/crates/extension/src/capabilities/process_exec_capability.rs deleted file mode 100644 index 053a7b212b..0000000000 --- a/crates/extension/src/capabilities/process_exec_capability.rs +++ /dev/null @@ -1,116 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub struct ProcessExecCapability { - /// The command to execute. - pub command: String, - /// The arguments to pass to the command. Use `*` for a single wildcard argument. - /// If the last element is `**`, then any trailing arguments are allowed. - pub args: Vec<String>, -} - -impl ProcessExecCapability { - /// Returns whether the capability allows the given command and arguments. - pub fn allows( - &self, - desired_command: &str, - desired_args: &[impl AsRef<str> + std::fmt::Debug], - ) -> bool { - if self.command != desired_command && self.command != "*" { - return false; - } - - for (ix, arg) in self.args.iter().enumerate() { - if arg == "**" { - return true; - } - - if ix >= desired_args.len() { - return false; - } - - if arg != "*" && arg != desired_args[ix].as_ref() { - return false; - } - } - - if self.args.len() < desired_args.len() { - return false; - } - - true - } -} - -#[cfg(test)] -mod tests { - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn test_allows_with_exact_match() { - let capability = ProcessExecCapability { - command: "ls".to_string(), - args: vec!["-la".to_string()], - }; - - assert_eq!(capability.allows("ls", &["-la"]), true); - assert_eq!(capability.allows("ls", &["-l"]), false); - assert_eq!(capability.allows("pwd", &[] as &[&str]), false); - } - - #[test] - fn test_allows_with_wildcard_arg() { - let capability = ProcessExecCapability { - command: "git".to_string(), - args: vec!["*".to_string()], - }; - - assert_eq!(capability.allows("git", &["status"]), true); - assert_eq!(capability.allows("git", &["commit"]), true); - // Too many args. - assert_eq!(capability.allows("git", &["status", "-s"]), false); - // Wrong command. - assert_eq!(capability.allows("npm", &["install"]), false); - } - - #[test] - fn test_allows_with_double_wildcard() { - let capability = ProcessExecCapability { - command: "cargo".to_string(), - args: vec!["test".to_string(), "**".to_string()], - }; - - assert_eq!(capability.allows("cargo", &["test"]), true); - assert_eq!(capability.allows("cargo", &["test", "--all"]), true); - assert_eq!( - capability.allows("cargo", &["test", "--all", "--no-fail-fast"]), - true - ); - // Wrong first arg. - assert_eq!(capability.allows("cargo", &["build"]), false); - } - - #[test] - fn test_allows_with_mixed_wildcards() { - let capability = ProcessExecCapability { - command: "docker".to_string(), - args: vec!["run".to_string(), "*".to_string(), "**".to_string()], - }; - - assert_eq!(capability.allows("docker", &["run", "nginx"]), true); - assert_eq!(capability.allows("docker", &["run"]), false); - assert_eq!( - capability.allows("docker", &["run", "ubuntu", "bash"]), - true - ); - assert_eq!( - capability.allows("docker", &["run", "alpine", "sh", "-c", "echo hello"]), - true - ); - // Wrong first arg. - assert_eq!(capability.allows("docker", &["ps"]), false); - } -} diff --git a/crates/extension/src/extension.rs b/crates/extension/src/extension.rs index 35f7f41938..8b150e19b9 100644 --- a/crates/extension/src/extension.rs +++ b/crates/extension/src/extension.rs @@ -1,4 +1,3 @@ -mod capabilities; pub mod extension_builder; mod extension_events; mod extension_host_proxy; @@ -17,7 +16,6 @@ use language::LanguageName; use semantic_version::SemanticVersion; use task::{SpawnInTerminal, ZedDebugConfig}; -pub use crate::capabilities::*; pub use crate::extension_events::*; pub use crate::extension_host_proxy::*; pub use crate::extension_manifest::*; diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 5852b3e3fc..0a14923c0c 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -12,8 +12,6 @@ use std::{ sync::Arc, }; -use crate::ExtensionCapability; - /// This is the old version of the extension manifest, from when it was `extension.json`. #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct OldExtensionManifest { @@ -102,8 +100,24 @@ impl ExtensionManifest { desired_args: &[impl AsRef<str> + std::fmt::Debug], ) -> Result<()> { let is_allowed = self.capabilities.iter().any(|capability| match capability { - ExtensionCapability::ProcessExec(capability) => { - capability.allows(desired_command, desired_args) + ExtensionCapability::ProcessExec { command, args } if command == desired_command => { + for (ix, arg) in args.iter().enumerate() { + if arg == "**" { + return true; + } + + if ix >= desired_args.len() { + return false; + } + + if arg != "*" && arg != desired_args[ix].as_ref() { + return false; + } + } + if args.len() < desired_args.len() { + return false; + } + true } _ => false, }); @@ -134,6 +148,20 @@ pub fn build_debug_adapter_schema_path( }) } +/// A capability for an extension. +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(tag = "kind")] +pub enum ExtensionCapability { + #[serde(rename = "process:exec")] + ProcessExec { + /// The command to execute. + command: String, + /// The arguments to pass to the command. Use `*` for a single wildcard argument. + /// If the last element is `**`, then any trailing arguments are allowed. + args: Vec<String>, + }, +} + #[derive(Clone, Default, PartialEq, Eq, Debug, Deserialize, Serialize)] pub struct LibManifestEntry { pub kind: Option<ExtensionLibraryKind>, @@ -163,7 +191,7 @@ pub struct LanguageServerManifestEntry { #[serde(default)] languages: Vec<LanguageName>, #[serde(default)] - pub language_ids: HashMap<LanguageName, String>, + pub language_ids: HashMap<String, String>, #[serde(default)] pub code_action_kinds: Option<Vec<lsp::CodeActionKind>>, } @@ -281,10 +309,6 @@ fn manifest_from_old_manifest( #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; - - use crate::ProcessExecCapability; - use super::*; fn extension_manifest() -> ExtensionManifest { @@ -336,12 +360,12 @@ mod tests { } #[test] - fn test_allow_exec_exact_match() { + fn test_allow_exact_match() { let manifest = ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { + capabilities: vec![ExtensionCapability::ProcessExec { command: "ls".to_string(), args: vec!["-la".to_string()], - })], + }], ..extension_manifest() }; @@ -351,12 +375,12 @@ mod tests { } #[test] - fn test_allow_exec_wildcard_arg() { + fn test_allow_wildcard_arg() { let manifest = ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { + capabilities: vec![ExtensionCapability::ProcessExec { command: "git".to_string(), args: vec!["*".to_string()], - })], + }], ..extension_manifest() }; @@ -367,12 +391,12 @@ mod tests { } #[test] - fn test_allow_exec_double_wildcard() { + fn test_allow_double_wildcard() { let manifest = ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { + capabilities: vec![ExtensionCapability::ProcessExec { command: "cargo".to_string(), args: vec!["test".to_string(), "**".to_string()], - })], + }], ..extension_manifest() }; @@ -387,12 +411,12 @@ mod tests { } #[test] - fn test_allow_exec_mixed_wildcards() { + fn test_allow_mixed_wildcards() { let manifest = ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { + capabilities: vec![ExtensionCapability::ProcessExec { command: "docker".to_string(), args: vec!["run".to_string(), "*".to_string(), "**".to_string()], - })], + }], ..extension_manifest() }; diff --git a/crates/extension_host/benches/extension_compilation_benchmark.rs b/crates/extension_host/benches/extension_compilation_benchmark.rs index a4fa9bfeff..9d867af041 100644 --- a/crates/extension_host/benches/extension_compilation_benchmark.rs +++ b/crates/extension_host/benches/extension_compilation_benchmark.rs @@ -134,12 +134,10 @@ fn manifest() -> ExtensionManifest { slash_commands: BTreeMap::default(), indexed_docs_providers: BTreeMap::default(), snippets: None, - capabilities: vec![ExtensionCapability::ProcessExec( - extension::ProcessExecCapability { - command: "echo".into(), - args: vec!["hello!".into()], - }, - )], + capabilities: vec![ExtensionCapability::ProcessExec { + command: "echo".into(), + args: vec!["hello!".into()], + }], debug_adapters: Default::default(), debug_locators: Default::default(), } diff --git a/crates/extension_host/src/capability_granter.rs b/crates/extension_host/src/capability_granter.rs deleted file mode 100644 index c77e5ecba1..0000000000 --- a/crates/extension_host/src/capability_granter.rs +++ /dev/null @@ -1,153 +0,0 @@ -use std::sync::Arc; - -use anyhow::{Result, bail}; -use extension::{ExtensionCapability, ExtensionManifest}; -use url::Url; - -pub struct CapabilityGranter { - granted_capabilities: Vec<ExtensionCapability>, - manifest: Arc<ExtensionManifest>, -} - -impl CapabilityGranter { - pub fn new( - granted_capabilities: Vec<ExtensionCapability>, - manifest: Arc<ExtensionManifest>, - ) -> Self { - Self { - granted_capabilities, - manifest, - } - } - - pub fn grant_exec( - &self, - desired_command: &str, - desired_args: &[impl AsRef<str> + std::fmt::Debug], - ) -> Result<()> { - self.manifest.allow_exec(desired_command, desired_args)?; - - let is_allowed = self - .granted_capabilities - .iter() - .any(|capability| match capability { - ExtensionCapability::ProcessExec(capability) => { - capability.allows(desired_command, desired_args) - } - _ => false, - }); - - if !is_allowed { - bail!( - "capability for process:exec {desired_command} {desired_args:?} is not granted by the extension host", - ); - } - - Ok(()) - } - - pub fn grant_download_file(&self, desired_url: &Url) -> Result<()> { - let is_allowed = self - .granted_capabilities - .iter() - .any(|capability| match capability { - ExtensionCapability::DownloadFile(capability) => capability.allows(desired_url), - _ => false, - }); - - if !is_allowed { - bail!( - "capability for download_file {desired_url} is not granted by the extension host", - ); - } - - Ok(()) - } - - pub fn grant_npm_install_package(&self, package_name: &str) -> Result<()> { - let is_allowed = self - .granted_capabilities - .iter() - .any(|capability| match capability { - ExtensionCapability::NpmInstallPackage(capability) => { - capability.allows(package_name) - } - _ => false, - }); - - if !is_allowed { - bail!("capability for npm:install {package_name} is not granted by the extension host",); - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeMap; - - use extension::{ProcessExecCapability, SchemaVersion}; - - use super::*; - - fn extension_manifest() -> ExtensionManifest { - ExtensionManifest { - id: "test".into(), - name: "Test".to_string(), - version: "1.0.0".into(), - schema_version: SchemaVersion::ZERO, - description: None, - repository: None, - authors: vec![], - lib: Default::default(), - themes: vec![], - icon_themes: vec![], - languages: vec![], - grammars: BTreeMap::default(), - language_servers: BTreeMap::default(), - context_servers: BTreeMap::default(), - slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), - snippets: None, - capabilities: vec![], - debug_adapters: Default::default(), - debug_locators: Default::default(), - } - } - - #[test] - fn test_grant_exec() { - let manifest = Arc::new(ExtensionManifest { - capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability { - command: "ls".to_string(), - args: vec!["-la".to_string()], - })], - ..extension_manifest() - }); - - // It returns an error when the extension host has no granted capabilities. - let granter = CapabilityGranter::new(Vec::new(), manifest.clone()); - assert!(granter.grant_exec("ls", &["-la"]).is_err()); - - // It succeeds when the extension host has the exact capability. - let granter = CapabilityGranter::new( - vec![ExtensionCapability::ProcessExec(ProcessExecCapability { - command: "ls".to_string(), - args: vec!["-la".to_string()], - })], - manifest.clone(), - ); - assert!(granter.grant_exec("ls", &["-la"]).is_ok()); - - // It succeeds when the extension host has a wildcard capability. - let granter = CapabilityGranter::new( - vec![ExtensionCapability::ProcessExec(ProcessExecCapability { - command: "*".to_string(), - args: vec!["**".to_string()], - })], - manifest.clone(), - ); - assert!(granter.grant_exec("ls", &["-la"]).is_ok()); - } -} diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index dc38c244f1..fd64d3fa59 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -1,4 +1,3 @@ -mod capability_granter; pub mod extension_settings; pub mod headless_host; pub mod wasm_host; diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs index d990b670f4..dcd52d0d02 100644 --- a/crates/extension_host/src/wasm_host.rs +++ b/crates/extension_host/src/wasm_host.rs @@ -1,15 +1,13 @@ pub mod wit; use crate::ExtensionManifest; -use crate::capability_granter::CapabilityGranter; use anyhow::{Context as _, Result, anyhow, bail}; use async_trait::async_trait; use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest}; use extension::{ CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary, - DebugTaskDefinition, DownloadFileCapability, ExtensionCapability, ExtensionHostProxy, - KeyValueStoreDelegate, NpmInstallPackageCapability, ProcessExecCapability, ProjectDelegate, - SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate, + DebugTaskDefinition, ExtensionHostProxy, KeyValueStoreDelegate, ProjectDelegate, SlashCommand, + SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate, }; use fs::{Fs, normalize_path}; use futures::future::LocalBoxFuture; @@ -52,8 +50,6 @@ pub struct WasmHost { pub(crate) proxy: Arc<ExtensionHostProxy>, fs: Arc<dyn Fs>, pub work_dir: PathBuf, - /// The capabilities granted to extensions running on the host. - pub(crate) granted_capabilities: Vec<ExtensionCapability>, _main_thread_message_task: Task<()>, main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>, } @@ -490,7 +486,6 @@ pub struct WasmState { pub table: ResourceTable, ctx: wasi::WasiCtx, pub host: Arc<WasmHost>, - pub(crate) capability_granter: CapabilityGranter, } type MainThreadCall = Box<dyn Send + for<'a> FnOnce(&'a mut AsyncApp) -> LocalBoxFuture<'a, ()>>; @@ -576,19 +571,6 @@ impl WasmHost { node_runtime, proxy, release_channel: ReleaseChannel::global(cx), - granted_capabilities: vec![ - ExtensionCapability::ProcessExec(ProcessExecCapability { - command: "*".to_string(), - args: vec!["**".to_string()], - }), - ExtensionCapability::DownloadFile(DownloadFileCapability { - host: "*".to_string(), - path: vec!["**".to_string()], - }), - ExtensionCapability::NpmInstallPackage(NpmInstallPackageCapability { - package: "*".to_string(), - }), - ], _main_thread_message_task: task, main_thread_message_tx: tx, }) @@ -615,10 +597,6 @@ impl WasmHost { manifest: manifest.clone(), table: ResourceTable::new(), host: this.clone(), - capability_granter: CapabilityGranter::new( - this.granted_capabilities.clone(), - manifest.clone(), - ), }, ); // Store will yield after 1 tick, and get a new deadline of 1 tick after each yield. diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs index 767b9033ad..d25328ee7f 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs @@ -30,7 +30,6 @@ use std::{ sync::{Arc, OnceLock}, }; use task::{SpawnInTerminal, ZedDebugConfig}; -use url::Url; use util::{archive::extract_zip, fs::make_file_executable, maybe}; use wasmtime::component::{Linker, Resource}; @@ -745,9 +744,6 @@ impl nodejs::Host for WasmState { package_name: String, version: String, ) -> wasmtime::Result<Result<(), String>> { - self.capability_granter - .grant_npm_install_package(&package_name)?; - self.host .node_runtime .npm_install_packages(&self.work_dir(), &[(&package_name, &version)]) @@ -851,8 +847,7 @@ impl process::Host for WasmState { command: process::Command, ) -> wasmtime::Result<Result<process::Output, String>> { maybe!(async { - self.capability_granter - .grant_exec(&command.command, &command.args)?; + self.manifest.allow_exec(&command.command, &command.args)?; let output = util::command::new_smol_command(command.command.as_str()) .args(&command.args) @@ -1015,9 +1010,6 @@ impl ExtensionImports for WasmState { file_type: DownloadedFileType, ) -> wasmtime::Result<Result<(), String>> { maybe!(async { - let parsed_url = Url::parse(&url)?; - self.capability_granter.grant_download_file(&parsed_url)?; - let path = PathBuf::from(path); let extension_work_dir = self.host.work_dir.join(self.manifest.id.as_ref()); diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 631bafc841..da85133bb9 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -85,11 +85,6 @@ impl FeatureFlag for ThreadAutoCaptureFeatureFlag { false } } -pub struct PanicFeatureFlag; - -impl FeatureFlag for PanicFeatureFlag { - const NAME: &'static str = "panic"; -} pub struct JjUiFeatureFlag {} diff --git a/crates/file_finder/src/file_finder.rs b/crates/file_finder/src/file_finder.rs index e5ac70bb58..a4d61dd56f 100644 --- a/crates/file_finder/src/file_finder.rs +++ b/crates/file_finder/src/file_finder.rs @@ -1404,21 +1404,14 @@ impl PickerDelegate for FileFinderDelegate { } else { let path_position = PathWithPosition::parse_str(&raw_query); - #[cfg(windows)] - let raw_query = raw_query.trim().to_owned().replace("/", "\\"); - #[cfg(not(windows))] - let raw_query = raw_query.trim().to_owned(); - - let file_query_end = if path_position.path.to_str().unwrap_or(&raw_query) == raw_query { - None - } else { - // Safe to unwrap as we won't get here when the unwrap in if fails - Some(path_position.path.to_str().unwrap().len()) - }; - let query = FileSearchQuery { - raw_query, - file_query_end, + raw_query: raw_query.trim().to_owned(), + file_query_end: if path_position.path.to_str().unwrap_or(raw_query) == raw_query { + None + } else { + // Safe to unwrap as we won't get here when the unwrap in if fails + Some(path_position.path.to_str().unwrap().len()) + }, path_position, }; diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 378a8fb7df..8a4f7c03bb 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -398,18 +398,6 @@ impl GitRepository for FakeGitRepository { }) } - fn stash_paths( - &self, - _paths: Vec<RepoPath>, - _env: Arc<HashMap<String, String>>, - ) -> BoxFuture<Result<()>> { - unimplemented!() - } - - fn stash_pop(&self, _env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>> { - unimplemented!() - } - fn commit( &self, _message: gpui::SharedString, diff --git a/crates/git/src/git.rs b/crates/git/src/git.rs index 553361e673..3714086dd0 100644 --- a/crates/git/src/git.rs +++ b/crates/git/src/git.rs @@ -55,10 +55,6 @@ actions!( StageAll, /// Unstages all changes in the repository. UnstageAll, - /// Stashes all changes in the repository, including untracked files. - StashAll, - /// Pops the most recent stash. - StashPop, /// Restores all tracked files to their last committed state. RestoreTrackedFiles, /// Moves all untracked files to trash. diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index a63315e69e..9cc3442392 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -395,14 +395,6 @@ pub trait GitRepository: Send + Sync { env: Arc<HashMap<String, String>>, ) -> BoxFuture<'_, Result<()>>; - fn stash_paths( - &self, - paths: Vec<RepoPath>, - env: Arc<HashMap<String, String>>, - ) -> BoxFuture<Result<()>>; - - fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>>; - fn push( &self, branch_name: String, @@ -1197,55 +1189,6 @@ impl GitRepository for RealGitRepository { .boxed() } - fn stash_paths( - &self, - paths: Vec<RepoPath>, - env: Arc<HashMap<String, String>>, - ) -> BoxFuture<Result<()>> { - let working_directory = self.working_directory(); - self.executor - .spawn(async move { - let mut cmd = new_smol_command("git"); - cmd.current_dir(&working_directory?) - .envs(env.iter()) - .args(["stash", "push", "--quiet"]) - .arg("--include-untracked"); - - cmd.args(paths.iter().map(|p| p.as_ref())); - - let output = cmd.output().await?; - - anyhow::ensure!( - output.status.success(), - "Failed to stash:\n{}", - String::from_utf8_lossy(&output.stderr) - ); - Ok(()) - }) - .boxed() - } - - fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>> { - let working_directory = self.working_directory(); - self.executor - .spawn(async move { - let mut cmd = new_smol_command("git"); - cmd.current_dir(&working_directory?) - .envs(env.iter()) - .args(["stash", "pop"]); - - let output = cmd.output().await?; - - anyhow::ensure!( - output.status.success(), - "Failed to stash pop:\n{}", - String::from_utf8_lossy(&output.stderr) - ); - Ok(()) - }) - .boxed() - } - fn commit( &self, message: SharedString, diff --git a/crates/git_hosting_providers/src/providers/github.rs b/crates/git_hosting_providers/src/providers/github.rs index 30f8d058a7..649b2f30ae 100644 --- a/crates/git_hosting_providers/src/providers/github.rs +++ b/crates/git_hosting_providers/src/providers/github.rs @@ -159,11 +159,7 @@ impl GitHostingProvider for Github { } let mut path_segments = url.path_segments()?; - let mut owner = path_segments.next()?; - if owner.is_empty() { - owner = path_segments.next()?; - } - + let owner = path_segments.next()?; let repo = path_segments.next()?.trim_end_matches(".git"); Some(ParsedGitRemote { @@ -248,22 +244,6 @@ mod tests { use super::*; - #[test] - fn test_remote_url_with_root_slash() { - let remote_url = "git@github.com:/zed-industries/zed"; - let parsed_remote = Github::public_instance() - .parse_remote_url(remote_url) - .unwrap(); - - assert_eq!( - parsed_remote, - ParsedGitRemote { - owner: "zed-industries".into(), - repo: "zed".into(), - } - ); - } - #[test] fn test_invalid_self_hosted_remote_url() { let remote_url = "git@github.com:zed-industries/zed.git"; diff --git a/crates/git_ui/Cargo.toml b/crates/git_ui/Cargo.toml index 4c919249ee..2fb80b7e73 100644 --- a/crates/git_ui/Cargo.toml +++ b/crates/git_ui/Cargo.toml @@ -24,7 +24,6 @@ buffer_diff.workspace = true call.workspace = true chrono.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true component.workspace = true @@ -63,6 +62,7 @@ watch.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true +zed_llm_client.workspace = true [target.'cfg(windows)'.dependencies] windows.workspace = true diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index ee74ac4d54..19e2712d7c 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -27,10 +27,7 @@ use git::repository::{ }; use git::status::StageStatus; use git::{Amend, Signoff, ToggleStaged, repository::RepoPath, status::FileStatus}; -use git::{ - ExpandCommitEditor, RestoreTrackedFiles, StageAll, StashAll, StashPop, TrashUntrackedFiles, - UnstageAll, -}; +use git::{ExpandCommitEditor, RestoreTrackedFiles, StageAll, TrashUntrackedFiles, UnstageAll}; use gpui::{ Action, Animation, AnimationExt as _, AsyncApp, AsyncWindowContext, Axis, ClickEvent, Corner, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, KeyContext, @@ -71,12 +68,12 @@ use ui::{ use util::{ResultExt, TryFutureExt, maybe}; use workspace::SERIALIZATION_THROTTLE_TIME; -use cloud_llm_client::CompletionIntent; use workspace::{ Workspace, dock::{DockPosition, Panel, PanelEvent}, notifications::{DetachAndPromptErr, ErrorMessagePrompt, NotificationId}, }; +use zed_llm_client::CompletionIntent; actions!( git_panel, @@ -143,13 +140,6 @@ fn git_panel_context_menu( UnstageAll.boxed_clone(), ) .separator() - .action_disabled_when( - !(state.has_new_changes || state.has_tracked_changes), - "Stash All", - StashAll.boxed_clone(), - ) - .action("Stash Pop", StashPop.boxed_clone()) - .separator() .action("Open Diff", project_diff::Diff.boxed_clone()) .separator() .action_disabled_when( @@ -390,9 +380,6 @@ pub(crate) fn commit_message_editor( window: &mut Window, cx: &mut Context<Editor>, ) -> Editor { - project.update(cx, |this, cx| { - this.mark_buffer_as_non_searchable(commit_message_buffer.read(cx).remote_id(), cx); - }); let buffer = cx.new(|cx| MultiBuffer::singleton(commit_message_buffer, cx)); let max_lines = if in_panel { MAX_PANEL_EDITOR_LINES } else { 18 }; let mut commit_editor = Editor::new( @@ -1425,52 +1412,6 @@ impl GitPanel { self.tracked_staged_count + self.new_staged_count + self.conflicted_staged_count } - pub fn stash_pop(&mut self, _: &StashPop, _window: &mut Window, cx: &mut Context<Self>) { - let Some(active_repository) = self.active_repository.clone() else { - return; - }; - - cx.spawn({ - async move |this, cx| { - let stash_task = active_repository - .update(cx, |repo, cx| repo.stash_pop(cx))? - .await; - this.update(cx, |this, cx| { - stash_task - .map_err(|e| { - this.show_error_toast("stash pop", e, cx); - }) - .ok(); - cx.notify(); - }) - } - }) - .detach(); - } - - pub fn stash_all(&mut self, _: &StashAll, _window: &mut Window, cx: &mut Context<Self>) { - let Some(active_repository) = self.active_repository.clone() else { - return; - }; - - cx.spawn({ - async move |this, cx| { - let stash_task = active_repository - .update(cx, |repo, cx| repo.stash_all(cx))? - .await; - this.update(cx, |this, cx| { - stash_task - .map_err(|e| { - this.show_error_toast("stash", e, cx); - }) - .ok(); - cx.notify(); - }) - } - }) - .detach(); - } - pub fn commit_message_buffer(&self, cx: &App) -> Entity<Buffer> { self.commit_editor .read(cx) @@ -2416,7 +2357,7 @@ impl GitPanel { .committer_name .clone() .or_else(|| participant.user.name.clone()) - .unwrap_or_else(|| participant.user.github_login.clone().to_string()); + .unwrap_or_else(|| participant.user.github_login.clone()); new_co_authors.push((name.clone(), email.clone())) } } @@ -2436,7 +2377,7 @@ impl GitPanel { .name .clone() .or_else(|| user.name.clone()) - .unwrap_or_else(|| user.github_login.clone().to_string()); + .unwrap_or_else(|| user.github_login.clone()); Some((name, email)) } @@ -4430,8 +4371,6 @@ impl Render for GitPanel { .on_action(cx.listener(Self::revert_selected)) .on_action(cx.listener(Self::clean_all)) .on_action(cx.listener(Self::generate_commit_message_action)) - .on_action(cx.listener(Self::stash_all)) - .on_action(cx.listener(Self::stash_pop)) }) .on_action(cx.listener(Self::select_first)) .on_action(cx.listener(Self::select_next)) diff --git a/crates/git_ui/src/git_ui.rs b/crates/git_ui/src/git_ui.rs index 0163175eda..2d7fba13c5 100644 --- a/crates/git_ui/src/git_ui.rs +++ b/crates/git_ui/src/git_ui.rs @@ -114,22 +114,6 @@ pub fn init(cx: &mut App) { }); }); } - workspace.register_action(|workspace, action: &git::StashAll, window, cx| { - let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { - return; - }; - panel.update(cx, |panel, cx| { - panel.stash_all(action, window, cx); - }); - }); - workspace.register_action(|workspace, action: &git::StashPop, window, cx| { - let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { - return; - }; - panel.update(cx, |panel, cx| { - panel.stash_pop(action, window, cx); - }); - }); workspace.register_action(|workspace, action: &git::StageAll, window, cx| { let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { return; diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 2bf49fa7d8..b446ea8bd8 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -121,7 +121,7 @@ smallvec.workspace = true smol.workspace = true strum.workspace = true sum_tree.workspace = true -taffy = "=0.8.3" +taffy = "=0.5.1" thiserror.workspace = true util.workspace = true uuid.workspace = true @@ -216,6 +216,10 @@ xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf x11-clipboard = { version = "0.9.3", optional = true } [target.'cfg(target_os = "windows")'.dependencies] +blade-util.workspace = true +bytemuck = "1" +blade-graphics.workspace = true +blade-macros.workspace = true flume = "0.11" rand.workspace = true windows.workspace = true @@ -236,6 +240,7 @@ util = { workspace = true, features = ["test-support"] } [target.'cfg(target_os = "windows")'.build-dependencies] embed-resource = "3.0" +naga.workspace = true [target.'cfg(target_os = "macos")'.build-dependencies] bindgen = "0.71" @@ -282,10 +287,6 @@ path = "examples/shadow.rs" name = "svg" path = "examples/svg/svg.rs" -[[example]] -name = "tab_stop" -path = "examples/tab_stop.rs" - [[example]] name = "text" path = "examples/text.rs" diff --git a/crates/gpui/build.rs b/crates/gpui/build.rs index 93a1c15c41..aed4397440 100644 --- a/crates/gpui/build.rs +++ b/crates/gpui/build.rs @@ -9,10 +9,7 @@ fn main() { let target = env::var("CARGO_CFG_TARGET_OS"); println!("cargo::rustc-check-cfg=cfg(gles)"); - #[cfg(any( - not(any(target_os = "macos", target_os = "windows")), - all(target_os = "macos", feature = "macos-blade") - ))] + #[cfg(any(not(target_os = "macos"), feature = "macos-blade"))] check_wgsl_shaders(); match target.as_deref() { @@ -20,18 +17,21 @@ fn main() { #[cfg(target_os = "macos")] macos::build(); } + #[cfg(all(target_os = "windows", feature = "windows-manifest"))] Ok("windows") => { - #[cfg(target_os = "windows")] - windows::build(); + let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); + let rc_file = std::path::Path::new("resources/windows/gpui.rc"); + println!("cargo:rerun-if-changed={}", manifest.display()); + println!("cargo:rerun-if-changed={}", rc_file.display()); + embed_resource::compile(rc_file, embed_resource::NONE) + .manifest_required() + .unwrap(); } _ => (), }; } -#[cfg(any( - not(any(target_os = "macos", target_os = "windows")), - all(target_os = "macos", feature = "macos-blade") -))] +#[allow(dead_code)] fn check_wgsl_shaders() { use std::path::PathBuf; use std::process; @@ -128,7 +128,6 @@ mod macos { "AtlasTile".into(), "PathRasterizationInputIndex".into(), "PathVertex_ScaledPixels".into(), - "PathRasterizationVertex".into(), "ShadowInputIndex".into(), "Shadow".into(), "QuadInputIndex".into(), @@ -243,215 +242,3 @@ mod macos { } } } - -#[cfg(target_os = "windows")] -mod windows { - use std::{ - fs, - io::Write, - path::{Path, PathBuf}, - process::{self, Command}, - }; - - pub(super) fn build() { - // Compile HLSL shaders - #[cfg(not(debug_assertions))] - compile_shaders(); - - // Embed the Windows manifest and resource file - #[cfg(feature = "windows-manifest")] - embed_resource(); - } - - #[cfg(feature = "windows-manifest")] - fn embed_resource() { - let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); - let rc_file = std::path::Path::new("resources/windows/gpui.rc"); - println!("cargo:rerun-if-changed={}", manifest.display()); - println!("cargo:rerun-if-changed={}", rc_file.display()); - embed_resource::compile(rc_file, embed_resource::NONE) - .manifest_required() - .unwrap(); - } - - /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. - fn compile_shaders() { - let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) - .join("src/platform/windows/shaders.hlsl"); - let out_dir = std::env::var("OUT_DIR").unwrap(); - - println!("cargo:rerun-if-changed={}", shader_path.display()); - - // Check if fxc.exe is available - let fxc_path = find_fxc_compiler(); - - // Define all modules - let modules = [ - "quad", - "shadow", - "path_rasterization", - "path_sprite", - "underline", - "monochrome_sprite", - "polychrome_sprite", - ]; - - let rust_binding_path = format!("{}/shaders_bytes.rs", out_dir); - if Path::new(&rust_binding_path).exists() { - fs::remove_file(&rust_binding_path) - .expect("Failed to remove existing Rust binding file"); - } - for module in modules { - compile_shader_for_module( - module, - &out_dir, - &fxc_path, - shader_path.to_str().unwrap(), - &rust_binding_path, - ); - } - - { - let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) - .join("src/platform/windows/color_text_raster.hlsl"); - compile_shader_for_module( - "emoji_rasterization", - &out_dir, - &fxc_path, - shader_path.to_str().unwrap(), - &rust_binding_path, - ); - } - } - - /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. - fn find_fxc_compiler() -> String { - // Check environment variable - if let Ok(path) = std::env::var("GPUI_FXC_PATH") { - if Path::new(&path).exists() { - return path; - } - } - - // Try to find in PATH - // NOTE: This has to be `where.exe` on Windows, not `where`, it must be ended with `.exe` - if let Ok(output) = std::process::Command::new("where.exe") - .arg("fxc.exe") - .output() - { - if output.status.success() { - let path = String::from_utf8_lossy(&output.stdout); - return path.trim().to_string(); - } - } - - // Check the default path - if Path::new(r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe") - .exists() - { - return r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe" - .to_string(); - } - - panic!("Failed to find fxc.exe"); - } - - fn compile_shader_for_module( - module: &str, - out_dir: &str, - fxc_path: &str, - shader_path: &str, - rust_binding_path: &str, - ) { - // Compile vertex shader - let output_file = format!("{}/{}_vs.h", out_dir, module); - let const_name = format!("{}_VERTEX_BYTES", module.to_uppercase()); - compile_shader_impl( - fxc_path, - &format!("{module}_vertex"), - &output_file, - &const_name, - shader_path, - "vs_4_1", - ); - generate_rust_binding(&const_name, &output_file, &rust_binding_path); - - // Compile fragment shader - let output_file = format!("{}/{}_ps.h", out_dir, module); - let const_name = format!("{}_FRAGMENT_BYTES", module.to_uppercase()); - compile_shader_impl( - fxc_path, - &format!("{module}_fragment"), - &output_file, - &const_name, - shader_path, - "ps_4_1", - ); - generate_rust_binding(&const_name, &output_file, &rust_binding_path); - } - - fn compile_shader_impl( - fxc_path: &str, - entry_point: &str, - output_path: &str, - var_name: &str, - shader_path: &str, - target: &str, - ) { - let output = Command::new(fxc_path) - .args([ - "/T", - target, - "/E", - entry_point, - "/Fh", - output_path, - "/Vn", - var_name, - "/O3", - shader_path, - ]) - .output(); - - match output { - Ok(result) => { - if result.status.success() { - return; - } - eprintln!( - "Shader compilation failed for {}:\n{}", - entry_point, - String::from_utf8_lossy(&result.stderr) - ); - process::exit(1); - } - Err(e) => { - eprintln!("Failed to run fxc for {}: {}", entry_point, e); - process::exit(1); - } - } - } - - fn generate_rust_binding(const_name: &str, head_file: &str, output_path: &str) { - let header_content = fs::read_to_string(head_file).expect("Failed to read header file"); - let const_definition = { - let global_var_start = header_content.find("const BYTE").unwrap(); - let global_var = &header_content[global_var_start..]; - let equal = global_var.find('=').unwrap(); - global_var[equal + 1..].trim() - }; - let rust_binding = format!( - "const {}: &[u8] = &{}\n", - const_name, - const_definition.replace('{', "[").replace('}', "]") - ); - let mut options = fs::OpenOptions::new() - .create(true) - .append(true) - .open(output_path) - .expect("Failed to open Rust binding file"); - options - .write_all(rust_binding.as_bytes()) - .expect("Failed to write Rust binding file"); - } -} diff --git a/crates/gpui/examples/painting.rs b/crates/gpui/examples/painting.rs index 668aed2377..ff4b64cbda 100644 --- a/crates/gpui/examples/painting.rs +++ b/crates/gpui/examples/painting.rs @@ -1,12 +1,11 @@ use gpui::{ Application, Background, Bounds, ColorSpace, Context, MouseDownEvent, Path, PathBuilder, PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowOptions, canvas, - div, linear_color_stop, linear_gradient, point, prelude::*, px, quad, rgb, size, + div, linear_color_stop, linear_gradient, point, prelude::*, px, rgb, size, }; struct PaintingViewer { default_lines: Vec<(Path<Pixels>, Background)>, - background_quads: Vec<(Bounds<Pixels>, Background)>, lines: Vec<Vec<Point<Pixels>>>, start: Point<Pixels>, dashed: bool, @@ -17,148 +16,12 @@ impl PaintingViewer { fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self { let mut lines = vec![]; - // Black squares beneath transparent paths. - let background_quads = vec![ - ( - Bounds { - origin: point(px(70.), px(70.)), - size: size(px(40.), px(40.)), - }, - gpui::black().into(), - ), - ( - Bounds { - origin: point(px(170.), px(70.)), - size: size(px(40.), px(40.)), - }, - gpui::black().into(), - ), - ( - Bounds { - origin: point(px(270.), px(70.)), - size: size(px(40.), px(40.)), - }, - gpui::black().into(), - ), - ( - Bounds { - origin: point(px(370.), px(70.)), - size: size(px(40.), px(40.)), - }, - gpui::black().into(), - ), - ( - Bounds { - origin: point(px(450.), px(50.)), - size: size(px(80.), px(80.)), - }, - gpui::black().into(), - ), - ]; - - // 50% opaque red path that extends across black quad. - let mut builder = PathBuilder::fill(); - builder.move_to(point(px(50.), px(50.))); - builder.line_to(point(px(130.), px(50.))); - builder.line_to(point(px(130.), px(130.))); - builder.line_to(point(px(50.), px(130.))); - builder.close(); - let path = builder.build().unwrap(); - let mut red = rgb(0xFF0000); - red.a = 0.5; - lines.push((path, red.into())); - - // 50% opaque blue path that extends across black quad. - let mut builder = PathBuilder::fill(); - builder.move_to(point(px(150.), px(50.))); - builder.line_to(point(px(230.), px(50.))); - builder.line_to(point(px(230.), px(130.))); - builder.line_to(point(px(150.), px(130.))); - builder.close(); - let path = builder.build().unwrap(); - let mut blue = rgb(0x0000FF); - blue.a = 0.5; - lines.push((path, blue.into())); - - // 50% opaque green path that extends across black quad. - let mut builder = PathBuilder::fill(); - builder.move_to(point(px(250.), px(50.))); - builder.line_to(point(px(330.), px(50.))); - builder.line_to(point(px(330.), px(130.))); - builder.line_to(point(px(250.), px(130.))); - builder.close(); - let path = builder.build().unwrap(); - let mut green = rgb(0x00FF00); - green.a = 0.5; - lines.push((path, green.into())); - - // 50% opaque black path that extends across black quad. - let mut builder = PathBuilder::fill(); - builder.move_to(point(px(350.), px(50.))); - builder.line_to(point(px(430.), px(50.))); - builder.line_to(point(px(430.), px(130.))); - builder.line_to(point(px(350.), px(130.))); - builder.close(); - let path = builder.build().unwrap(); - let mut black = rgb(0x000000); - black.a = 0.5; - lines.push((path, black.into())); - - // Two 50% opaque red circles overlapping - center should be darker red - let mut builder = PathBuilder::fill(); - let center = point(px(530.), px(85.)); - let radius = px(30.); - builder.move_to(point(center.x + radius, center.y)); - builder.arc_to( - point(radius, radius), - px(0.), - false, - false, - point(center.x - radius, center.y), - ); - builder.arc_to( - point(radius, radius), - px(0.), - false, - false, - point(center.x + radius, center.y), - ); - builder.close(); - let path = builder.build().unwrap(); - let mut red1 = rgb(0xFF0000); - red1.a = 0.5; - lines.push((path, red1.into())); - - let mut builder = PathBuilder::fill(); - let center = point(px(570.), px(85.)); - let radius = px(30.); - builder.move_to(point(center.x + radius, center.y)); - builder.arc_to( - point(radius, radius), - px(0.), - false, - false, - point(center.x - radius, center.y), - ); - builder.arc_to( - point(radius, radius), - px(0.), - false, - false, - point(center.x + radius, center.y), - ); - builder.close(); - let path = builder.build().unwrap(); - let mut red2 = rgb(0xFF0000); - red2.a = 0.5; - lines.push((path, red2.into())); - // draw a Rust logo let mut builder = lyon::path::Path::svg_builder(); lyon::extra::rust_logo::build_logo_path(&mut builder); // move down the Path let mut builder: PathBuilder = builder.into(); - builder.translate(point(px(10.), px(200.))); + builder.translate(point(px(10.), px(100.))); builder.scale(0.9); let path = builder.build().unwrap(); lines.push((path, gpui::black().into())); @@ -167,10 +30,10 @@ impl PaintingViewer { let mut builder = PathBuilder::fill(); builder.add_polygon( &[ - point(px(150.), px(300.)), - point(px(200.), px(225.)), - point(px(200.), px(275.)), - point(px(250.), px(200.)), + point(px(150.), px(200.)), + point(px(200.), px(125.)), + point(px(200.), px(175.)), + point(px(250.), px(100.)), ], false, ); @@ -179,17 +42,17 @@ impl PaintingViewer { // draw a ⭐ let mut builder = PathBuilder::fill(); - builder.move_to(point(px(350.), px(200.))); - builder.line_to(point(px(370.), px(260.))); - builder.line_to(point(px(430.), px(260.))); - builder.line_to(point(px(380.), px(300.))); - builder.line_to(point(px(400.), px(360.))); - builder.line_to(point(px(350.), px(320.))); - builder.line_to(point(px(300.), px(360.))); - builder.line_to(point(px(320.), px(300.))); - builder.line_to(point(px(270.), px(260.))); - builder.line_to(point(px(330.), px(260.))); - builder.line_to(point(px(350.), px(200.))); + builder.move_to(point(px(350.), px(100.))); + builder.line_to(point(px(370.), px(160.))); + builder.line_to(point(px(430.), px(160.))); + builder.line_to(point(px(380.), px(200.))); + builder.line_to(point(px(400.), px(260.))); + builder.line_to(point(px(350.), px(220.))); + builder.line_to(point(px(300.), px(260.))); + builder.line_to(point(px(320.), px(200.))); + builder.line_to(point(px(270.), px(160.))); + builder.line_to(point(px(330.), px(160.))); + builder.line_to(point(px(350.), px(100.))); let path = builder.build().unwrap(); lines.push(( path, @@ -203,7 +66,7 @@ impl PaintingViewer { // draw linear gradient let square_bounds = Bounds { - origin: point(px(450.), px(200.)), + origin: point(px(450.), px(100.)), size: size(px(200.), px(80.)), }; let height = square_bounds.size.height; @@ -233,31 +96,31 @@ impl PaintingViewer { // draw a pie chart let center = point(px(96.), px(96.)); - let pie_center = point(px(775.), px(255.)); + let pie_center = point(px(775.), px(155.)); let segments = [ ( - point(px(871.), px(255.)), - point(px(747.), px(163.)), + point(px(871.), px(155.)), + point(px(747.), px(63.)), rgb(0x1374e9), ), ( - point(px(747.), px(163.)), - point(px(679.), px(263.)), + point(px(747.), px(63.)), + point(px(679.), px(163.)), rgb(0xe13527), ), ( - point(px(679.), px(263.)), - point(px(754.), px(349.)), + point(px(679.), px(163.)), + point(px(754.), px(249.)), rgb(0x0751ce), ), ( - point(px(754.), px(349.)), - point(px(854.), px(310.)), + point(px(754.), px(249.)), + point(px(854.), px(210.)), rgb(0x209742), ), ( - point(px(854.), px(310.)), - point(px(871.), px(255.)), + point(px(854.), px(210.)), + point(px(871.), px(155.)), rgb(0xfbc10a), ), ]; @@ -277,11 +140,11 @@ impl PaintingViewer { .with_line_width(1.) .with_line_join(lyon::path::LineJoin::Bevel); let mut builder = PathBuilder::stroke(px(1.)).with_style(PathStyle::Stroke(options)); - builder.move_to(point(px(40.), px(420.))); + builder.move_to(point(px(40.), px(320.))); for i in 1..50 { builder.line_to(point( px(40.0 + i as f32 * 10.0), - px(420.0 + (i as f32 * 10.0).sin() * 40.0), + px(320.0 + (i as f32 * 10.0).sin() * 40.0), )); } let path = builder.build().unwrap(); @@ -289,7 +152,6 @@ impl PaintingViewer { Self { default_lines: lines.clone(), - background_quads, lines: vec![], start: point(px(0.), px(0.)), dashed: false, @@ -323,7 +185,6 @@ fn button( impl Render for PaintingViewer { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let default_lines = self.default_lines.clone(); - let background_quads = self.background_quads.clone(); let lines = self.lines.clone(); let dashed = self.dashed; @@ -360,19 +221,6 @@ impl Render for PaintingViewer { canvas( move |_, _, _| {}, move |_, _, window, _| { - // First draw background quads - for (bounds, color) in background_quads.iter() { - window.paint_quad(quad( - *bounds, - px(0.), - *color, - px(0.), - gpui::transparent_black(), - Default::default(), - )); - } - - // Then draw the default paths on top for (path, color) in default_lines { window.paint_path(path, color); } @@ -455,10 +303,6 @@ fn main() { |window, cx| cx.new(|cx| PaintingViewer::new(window, cx)), ) .unwrap(); - cx.on_window_closed(|cx| { - cx.quit(); - }) - .detach(); cx.activate(true); }); } diff --git a/crates/gpui/examples/paths_bench.rs b/crates/gpui/examples/paths_bench.rs deleted file mode 100644 index a801889ae8..0000000000 --- a/crates/gpui/examples/paths_bench.rs +++ /dev/null @@ -1,92 +0,0 @@ -use gpui::{ - Application, Background, Bounds, ColorSpace, Context, Path, PathBuilder, Pixels, Render, - TitlebarOptions, Window, WindowBounds, WindowOptions, canvas, div, linear_color_stop, - linear_gradient, point, prelude::*, px, rgb, size, -}; - -const DEFAULT_WINDOW_WIDTH: Pixels = px(1024.0); -const DEFAULT_WINDOW_HEIGHT: Pixels = px(768.0); - -struct PaintingViewer { - default_lines: Vec<(Path<Pixels>, Background)>, - _painting: bool, -} - -impl PaintingViewer { - fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self { - let mut lines = vec![]; - - // draw a lightening bolt ⚡ - for _ in 0..2000 { - // draw a ⭐ - let mut builder = PathBuilder::fill(); - builder.move_to(point(px(350.), px(100.))); - builder.line_to(point(px(370.), px(160.))); - builder.line_to(point(px(430.), px(160.))); - builder.line_to(point(px(380.), px(200.))); - builder.line_to(point(px(400.), px(260.))); - builder.line_to(point(px(350.), px(220.))); - builder.line_to(point(px(300.), px(260.))); - builder.line_to(point(px(320.), px(200.))); - builder.line_to(point(px(270.), px(160.))); - builder.line_to(point(px(330.), px(160.))); - builder.line_to(point(px(350.), px(100.))); - let path = builder.build().unwrap(); - lines.push(( - path, - linear_gradient( - 180., - linear_color_stop(rgb(0xFACC15), 0.7), - linear_color_stop(rgb(0xD56D0C), 1.), - ) - .color_space(ColorSpace::Oklab), - )); - } - - Self { - default_lines: lines, - _painting: false, - } - } -} - -impl Render for PaintingViewer { - fn render(&mut self, window: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - window.request_animation_frame(); - let lines = self.default_lines.clone(); - div().size_full().child( - canvas( - move |_, _, _| {}, - move |_, _, window, _| { - for (path, color) in lines { - window.paint_path(path, color); - } - }, - ) - .size_full(), - ) - } -} - -fn main() { - Application::new().run(|cx| { - cx.open_window( - WindowOptions { - titlebar: Some(TitlebarOptions { - title: Some("Vulkan".into()), - ..Default::default() - }), - focus: true, - window_bounds: Some(WindowBounds::Windowed(Bounds::centered( - None, - size(DEFAULT_WINDOW_WIDTH, DEFAULT_WINDOW_HEIGHT), - cx, - ))), - ..Default::default() - }, - |window, cx| cx.new(|cx| PaintingViewer::new(window, cx)), - ) - .unwrap(); - cx.activate(true); - }); -} diff --git a/crates/gpui/examples/tab_stop.rs b/crates/gpui/examples/tab_stop.rs index 1f6500f3e6..9c58b52a5e 100644 --- a/crates/gpui/examples/tab_stop.rs +++ b/crates/gpui/examples/tab_stop.rs @@ -6,7 +6,6 @@ use gpui::{ actions!(example, [Tab, TabPrev]); struct Example { - focus_handle: FocusHandle, items: Vec<FocusHandle>, message: SharedString, } @@ -21,11 +20,8 @@ impl Example { cx.focus_handle().tab_index(2).tab_stop(true), ]; - let focus_handle = cx.focus_handle(); - window.focus(&focus_handle); - + window.focus(items.first().unwrap()); Self { - focus_handle, items, message: SharedString::from("Press `Tab`, `Shift-Tab` to switch focus."), } @@ -44,10 +40,6 @@ impl Example { impl Render for Example { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - fn tab_stop_style<T: Styled>(this: T) -> T { - this.border_3().border_color(gpui::blue()) - } - fn button(id: impl Into<ElementId>) -> Stateful<Div> { div() .id(id) @@ -60,13 +52,12 @@ impl Render for Example { .border_color(gpui::black()) .bg(gpui::black()) .text_color(gpui::white()) - .focus(tab_stop_style) + .focus(|this| this.border_color(gpui::blue())) .shadow_sm() } div() .id("app") - .track_focus(&self.focus_handle) .on_action(cx.listener(Self::on_tab)) .on_action(cx.listener(Self::on_tab_prev)) .size_full() @@ -95,7 +86,7 @@ impl Render for Example { .border_color(gpui::black()) .when( item_handle.tab_stop && item_handle.is_focused(window), - tab_stop_style, + |this| this.border_color(gpui::blue()), ) .map(|this| match item_handle.tab_stop { true => this diff --git a/crates/gpui/examples/text.rs b/crates/gpui/examples/text.rs index 1166bb2795..19214aebde 100644 --- a/crates/gpui/examples/text.rs +++ b/crates/gpui/examples/text.rs @@ -198,7 +198,7 @@ impl RenderOnce for CharacterGrid { "χ", "ψ", "∂", "а", "в", "Ж", "ж", "З", "з", "К", "к", "л", "м", "Н", "н", "Р", "р", "У", "у", "ф", "ч", "ь", "ы", "Э", "э", "Я", "я", "ij", "öẋ", ".,", "⣝⣑", "~", "*", "_", "^", "`", "'", "(", "{", "«", "#", "&", "@", "$", "¢", "%", "|", "?", "¶", "µ", - "❮", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎", + "❮", "<=", "!=", "==", "--", "++", "=>", "->", ]; let columns = 11; diff --git a/crates/gpui/src/color.rs b/crates/gpui/src/color.rs index 639c84c101..a16c8f46be 100644 --- a/crates/gpui/src/color.rs +++ b/crates/gpui/src/color.rs @@ -35,7 +35,6 @@ pub(crate) fn swap_rgba_pa_to_bgra(color: &mut [u8]) { /// An RGBA color #[derive(PartialEq, Clone, Copy, Default)] -#[repr(C)] pub struct Rgba { /// The red component of the color, in the range 0.0 to 1.0 pub r: f32, diff --git a/crates/gpui/src/elements/div.rs b/crates/gpui/src/elements/div.rs index fa47758581..4655c92409 100644 --- a/crates/gpui/src/elements/div.rs +++ b/crates/gpui/src/elements/div.rs @@ -1334,6 +1334,7 @@ impl Element for Div { } else if let Some(scroll_handle) = self.interactivity.tracked_scroll_handle.as_ref() { let mut state = scroll_handle.0.borrow_mut(); state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len()); + state.bounds = bounds; for child_layout_id in &request_layout.child_layout_ids { let child_bounds = window.layout_bounds(*child_layout_id); child_min = child_min.min(&child_bounds.origin); @@ -1705,7 +1706,6 @@ impl Interactivity { if let Some(mut scroll_handle_state) = tracked_scroll_handle { scroll_handle_state.max_offset = scroll_max; - scroll_handle_state.bounds = bounds; } *scroll_offset @@ -3007,6 +3007,11 @@ impl ScrollHandle { self.0.borrow().bounds } + /// Set the bounds into which this child is painted + pub(super) fn set_bounds(&self, bounds: Bounds<Pixels>) { + self.0.borrow_mut().bounds = bounds; + } + /// Get the bounds for a specific child. pub fn bounds_for_item(&self, ix: usize) -> Option<Bounds<Pixels>> { self.0.borrow().child_bounds.get(ix).cloned() diff --git a/crates/gpui/src/elements/uniform_list.rs b/crates/gpui/src/elements/uniform_list.rs index cdf90d4eb8..2ee6e9827d 100644 --- a/crates/gpui/src/elements/uniform_list.rs +++ b/crates/gpui/src/elements/uniform_list.rs @@ -322,8 +322,9 @@ impl Element for UniformList { bounds.bottom_right() - point(border.right + padding.right, border.bottom), ); - let y_flipped = if let Some(scroll_handle) = &self.scroll_handle { - let scroll_state = scroll_handle.0.borrow(); + let y_flipped = if let Some(scroll_handle) = self.scroll_handle.as_mut() { + let mut scroll_state = scroll_handle.0.borrow_mut(); + scroll_state.base_handle.set_bounds(bounds); scroll_state.y_flipped } else { false diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index b495d70dfd..6f227f1d07 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -13,7 +13,8 @@ mod mac; any(target_os = "linux", target_os = "freebsd"), any(feature = "x11", feature = "wayland") ), - all(target_os = "macos", feature = "macos-blade") + target_os = "windows", + feature = "macos-blade" ))] mod blade; @@ -447,8 +448,6 @@ impl Tiling { #[derive(Debug, Copy, Clone, Eq, PartialEq, Default)] pub(crate) struct RequestFrameOptions { pub(crate) require_presentation: bool, - /// Force refresh of all rendering states when true - pub(crate) force_render: bool, } pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle { @@ -810,6 +809,7 @@ pub(crate) struct AtlasTextureId { pub(crate) enum AtlasTextureKind { Monochrome = 0, Polychrome = 1, + Path = 2, } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] diff --git a/crates/gpui/src/platform/blade/blade_atlas.rs b/crates/gpui/src/platform/blade/blade_atlas.rs index 74500ebf83..78ba52056a 100644 --- a/crates/gpui/src/platform/blade/blade_atlas.rs +++ b/crates/gpui/src/platform/blade/blade_atlas.rs @@ -10,6 +10,8 @@ use etagere::BucketedAtlasAllocator; use parking_lot::Mutex; use std::{borrow::Cow, ops, sync::Arc}; +pub(crate) const PATH_TEXTURE_FORMAT: gpu::TextureFormat = gpu::TextureFormat::R16Float; + pub(crate) struct BladeAtlas(Mutex<BladeAtlasState>); struct PendingUpload { @@ -25,6 +27,7 @@ struct BladeAtlasState { tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, initializations: Vec<AtlasTextureId>, uploads: Vec<PendingUpload>, + path_sample_count: u32, } #[cfg(gles)] @@ -38,11 +41,13 @@ impl BladeAtlasState { } pub struct BladeTextureInfo { + pub size: gpu::Extent, pub raw_view: gpu::TextureView, + pub msaa_view: Option<gpu::TextureView>, } impl BladeAtlas { - pub(crate) fn new(gpu: &Arc<gpu::Context>) -> Self { + pub(crate) fn new(gpu: &Arc<gpu::Context>, path_sample_count: u32) -> Self { BladeAtlas(Mutex::new(BladeAtlasState { gpu: Arc::clone(gpu), upload_belt: BufferBelt::new(BufferBeltDescriptor { @@ -54,6 +59,7 @@ impl BladeAtlas { tiles_by_key: Default::default(), initializations: Vec::new(), uploads: Vec::new(), + path_sample_count, })) } @@ -61,6 +67,27 @@ impl BladeAtlas { self.0.lock().destroy(); } + pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) { + let mut lock = self.0.lock(); + let textures = &mut lock.storage[texture_kind]; + for texture in textures.iter_mut() { + texture.clear(); + } + } + + /// Allocate a rectangle and make it available for rendering immediately (without waiting for `before_frame`) + pub fn allocate_for_rendering( + &self, + size: Size<DevicePixels>, + texture_kind: AtlasTextureKind, + gpu_encoder: &mut gpu::CommandEncoder, + ) -> AtlasTile { + let mut lock = self.0.lock(); + let tile = lock.allocate(size, texture_kind); + lock.flush_initializations(gpu_encoder); + tile + } + pub fn before_frame(&self, gpu_encoder: &mut gpu::CommandEncoder) { let mut lock = self.0.lock(); lock.flush(gpu_encoder); @@ -74,8 +101,15 @@ impl BladeAtlas { pub fn get_texture_info(&self, id: AtlasTextureId) -> BladeTextureInfo { let lock = self.0.lock(); let texture = &lock.storage[id]; + let size = texture.allocator.size(); BladeTextureInfo { + size: gpu::Extent { + width: size.width as u32, + height: size.height as u32, + depth: 1, + }, raw_view: texture.raw_view, + msaa_view: texture.msaa_view, } } } @@ -166,8 +200,48 @@ impl BladeAtlasState { format = gpu::TextureFormat::Bgra8UnormSrgb; usage = gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE; } + AtlasTextureKind::Path => { + format = PATH_TEXTURE_FORMAT; + usage = gpu::TextureUsage::COPY + | gpu::TextureUsage::RESOURCE + | gpu::TextureUsage::TARGET; + } } + // We currently only enable MSAA for path textures. + let (msaa, msaa_view) = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path { + let msaa = self.gpu.create_texture(gpu::TextureDesc { + name: "msaa path texture", + format, + size: gpu::Extent { + width: size.width.into(), + height: size.height.into(), + depth: 1, + }, + array_layer_count: 1, + mip_level_count: 1, + sample_count: self.path_sample_count, + dimension: gpu::TextureDimension::D2, + usage: gpu::TextureUsage::TARGET, + external: None, + }); + + ( + Some(msaa), + Some(self.gpu.create_texture_view( + msaa, + gpu::TextureViewDesc { + name: "msaa texture view", + format, + dimension: gpu::ViewDimension::D2, + subresources: &Default::default(), + }, + )), + ) + } else { + (None, None) + }; + let raw = self.gpu.create_texture(gpu::TextureDesc { name: "atlas", format, @@ -205,6 +279,8 @@ impl BladeAtlasState { format, raw, raw_view, + msaa, + msaa_view, live_atlas_keys: 0, }; @@ -264,6 +340,7 @@ impl BladeAtlasState { struct BladeAtlasStorage { monochrome_textures: AtlasTextureList<BladeAtlasTexture>, polychrome_textures: AtlasTextureList<BladeAtlasTexture>, + path_textures: AtlasTextureList<BladeAtlasTexture>, } impl ops::Index<AtlasTextureKind> for BladeAtlasStorage { @@ -272,6 +349,7 @@ impl ops::Index<AtlasTextureKind> for BladeAtlasStorage { match kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + crate::AtlasTextureKind::Path => &self.path_textures, } } } @@ -281,6 +359,7 @@ impl ops::IndexMut<AtlasTextureKind> for BladeAtlasStorage { match kind { crate::AtlasTextureKind::Monochrome => &mut self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + crate::AtlasTextureKind::Path => &mut self.path_textures, } } } @@ -291,6 +370,7 @@ impl ops::Index<AtlasTextureId> for BladeAtlasStorage { let textures = match id.kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + crate::AtlasTextureKind::Path => &self.path_textures, }; textures[id.index as usize].as_ref().unwrap() } @@ -304,6 +384,9 @@ impl BladeAtlasStorage { for mut texture in self.polychrome_textures.drain().flatten() { texture.destroy(gpu); } + for mut texture in self.path_textures.drain().flatten() { + texture.destroy(gpu); + } } } @@ -312,11 +395,17 @@ struct BladeAtlasTexture { allocator: BucketedAtlasAllocator, raw: gpu::Texture, raw_view: gpu::TextureView, + msaa: Option<gpu::Texture>, + msaa_view: Option<gpu::TextureView>, format: gpu::TextureFormat, live_atlas_keys: u32, } impl BladeAtlasTexture { + fn clear(&mut self) { + self.allocator.clear(); + } + fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { let allocation = self.allocator.allocate(size.into())?; let tile = AtlasTile { @@ -335,6 +424,12 @@ impl BladeAtlasTexture { fn destroy(&mut self, gpu: &gpu::Context) { gpu.destroy_texture(self.raw); gpu.destroy_texture_view(self.raw_view); + if let Some(msaa) = self.msaa { + gpu.destroy_texture(msaa); + } + if let Some(msaa_view) = self.msaa_view { + gpu.destroy_texture_view(msaa_view); + } } fn bytes_per_pixel(&self) -> u8 { diff --git a/crates/gpui/src/platform/blade/blade_renderer.rs b/crates/gpui/src/platform/blade/blade_renderer.rs index 2e18d2be22..cac47434ae 100644 --- a/crates/gpui/src/platform/blade/blade_renderer.rs +++ b/crates/gpui/src/platform/blade/blade_renderer.rs @@ -1,19 +1,24 @@ // Doing `if let` gives you nice scoping with passes/encoders #![allow(irrefutable_let_patterns)] -use super::{BladeAtlas, BladeContext}; +use super::{BladeAtlas, BladeContext, PATH_TEXTURE_FORMAT}; use crate::{ - Background, Bounds, DevicePixels, GpuSpecs, MonochromeSprite, Path, Point, PolychromeSprite, - PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Underline, + AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels, GpuSpecs, + MonochromeSprite, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch, Quad, + ScaledPixels, Scene, Shadow, Size, Underline, }; use blade_graphics as gpu; use blade_util::{BufferBelt, BufferBeltDescriptor}; use bytemuck::{Pod, Zeroable}; +use collections::HashMap; #[cfg(target_os = "macos")] use media::core_video::CVMetalTextureCache; -use std::sync::Arc; +use std::{mem, sync::Arc}; const MAX_FRAME_TIME_MS: u32 = 10000; +// Use 4x MSAA, all devices support it. +// https://developer.apple.com/documentation/metal/mtldevice/1433355-supportstexturesamplecount +const DEFAULT_PATH_SAMPLE_COUNT: u32 = 4; #[repr(C)] #[derive(Clone, Copy, Pod, Zeroable)] @@ -109,15 +114,8 @@ struct ShaderSurfacesData { #[repr(C)] struct PathSprite { bounds: Bounds<ScaledPixels>, -} - -#[derive(Clone, Debug)] -#[repr(C)] -struct PathRasterizationVertex { - xy_position: Point<ScaledPixels>, - st_position: Point<f32>, color: Background, - bounds: Bounds<ScaledPixels>, + tile: AtlasTile, } struct BladePipelines { @@ -146,7 +144,10 @@ impl BladePipelines { shader.check_struct_size::<SurfaceParams>(); shader.check_struct_size::<Quad>(); shader.check_struct_size::<Shadow>(); - shader.check_struct_size::<PathRasterizationVertex>(); + assert_eq!( + mem::size_of::<PathVertex<ScaledPixels>>(), + shader.get_struct_size("PathVertex") as usize, + ); shader.check_struct_size::<PathSprite>(); shader.check_struct_size::<Underline>(); shader.check_struct_size::<MonochromeSprite>(); @@ -204,16 +205,9 @@ impl BladePipelines { }, depth_stencil: None, fragment: Some(shader.at("fs_path_rasterization")), - // The original implementation was using ADDITIVE blende mode, - // I don't know why - // color_targets: &[gpu::ColorTargetState { - // format: PATH_TEXTURE_FORMAT, - // blend: Some(gpu::BlendState::ADDITIVE), - // write_mask: gpu::ColorWrites::default(), - // }], color_targets: &[gpu::ColorTargetState { - format: surface_info.format, - blend: Some(gpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING), + format: PATH_TEXTURE_FORMAT, + blend: Some(gpu::BlendState::ADDITIVE), write_mask: gpu::ColorWrites::default(), }], multisample_state: gpu::MultisampleState { @@ -232,14 +226,7 @@ impl BladePipelines { }, depth_stencil: None, fragment: Some(shader.at("fs_path")), - color_targets: &[gpu::ColorTargetState { - format: surface_info.format, - blend: Some(gpu::BlendState { - color: gpu::BlendComponent::OVER, - alpha: gpu::BlendComponent::ADDITIVE, - }), - write_mask: gpu::ColorWrites::default(), - }], + color_targets, multisample_state: gpu::MultisampleState::default(), }), underlines: gpu.create_render_pipeline(gpu::RenderPipelineDesc { @@ -330,15 +317,12 @@ pub struct BladeRenderer { last_sync_point: Option<gpu::SyncPoint>, pipelines: BladePipelines, instance_belt: BufferBelt, + path_tiles: HashMap<PathId, AtlasTile>, atlas: Arc<BladeAtlas>, atlas_sampler: gpu::Sampler, #[cfg(target_os = "macos")] core_video_texture_cache: CVMetalTextureCache, path_sample_count: u32, - path_intermediate_texture: gpu::Texture, - path_intermediate_texture_view: gpu::TextureView, - path_intermediate_msaa_texture: Option<gpu::Texture>, - path_intermediate_msaa_texture_view: Option<gpu::TextureView>, } impl BladeRenderer { @@ -368,43 +352,21 @@ impl BladeRenderer { let path_sample_count = std::env::var("ZED_PATH_SAMPLE_COUNT") .ok() .and_then(|v| v.parse().ok()) - .or_else(|| { - [4, 2, 1] - .into_iter() - .find(|count| context.gpu.supports_texture_sample_count(*count)) - }) - .unwrap_or(1); + .unwrap_or(DEFAULT_PATH_SAMPLE_COUNT); let pipelines = BladePipelines::new(&context.gpu, surface.info(), path_sample_count); let instance_belt = BufferBelt::new(BufferBeltDescriptor { memory: gpu::Memory::Shared, min_chunk_size: 0x1000, alignment: 0x40, // Vulkan `minStorageBufferOffsetAlignment` on Intel Xe }); - let atlas = Arc::new(BladeAtlas::new(&context.gpu)); + let atlas = Arc::new(BladeAtlas::new(&context.gpu, path_sample_count)); let atlas_sampler = context.gpu.create_sampler(gpu::SamplerDesc { - name: "path rasterization sampler", + name: "atlas", mag_filter: gpu::FilterMode::Linear, min_filter: gpu::FilterMode::Linear, ..Default::default() }); - let (path_intermediate_texture, path_intermediate_texture_view) = - create_path_intermediate_texture( - &context.gpu, - surface.info().format, - config.size.width, - config.size.height, - ); - let (path_intermediate_msaa_texture, path_intermediate_msaa_texture_view) = - create_msaa_texture_if_needed( - &context.gpu, - surface.info().format, - config.size.width, - config.size.height, - path_sample_count, - ) - .unzip(); - #[cfg(target_os = "macos")] let core_video_texture_cache = unsafe { CVMetalTextureCache::new( @@ -421,15 +383,12 @@ impl BladeRenderer { last_sync_point: None, pipelines, instance_belt, + path_tiles: HashMap::default(), atlas, atlas_sampler, #[cfg(target_os = "macos")] core_video_texture_cache, path_sample_count, - path_intermediate_texture, - path_intermediate_texture_view, - path_intermediate_msaa_texture, - path_intermediate_msaa_texture_view, }) } @@ -482,35 +441,6 @@ impl BladeRenderer { self.surface_config.size = gpu_size; self.gpu .reconfigure_surface(&mut self.surface, self.surface_config); - self.gpu.destroy_texture(self.path_intermediate_texture); - self.gpu - .destroy_texture_view(self.path_intermediate_texture_view); - if let Some(msaa_texture) = self.path_intermediate_msaa_texture { - self.gpu.destroy_texture(msaa_texture); - } - if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { - self.gpu.destroy_texture_view(msaa_view); - } - let (path_intermediate_texture, path_intermediate_texture_view) = - create_path_intermediate_texture( - &self.gpu, - self.surface.info().format, - gpu_size.width, - gpu_size.height, - ); - self.path_intermediate_texture = path_intermediate_texture; - self.path_intermediate_texture_view = path_intermediate_texture_view; - let (path_intermediate_msaa_texture, path_intermediate_msaa_texture_view) = - create_msaa_texture_if_needed( - &self.gpu, - self.surface.info().format, - gpu_size.width, - gpu_size.height, - self.path_sample_count, - ) - .unzip(); - self.path_intermediate_msaa_texture = path_intermediate_msaa_texture; - self.path_intermediate_msaa_texture_view = path_intermediate_msaa_texture_view; } } @@ -561,63 +491,76 @@ impl BladeRenderer { } #[profiling::function] - fn draw_paths_to_intermediate( - &mut self, - paths: &[Path<ScaledPixels>], - width: f32, - height: f32, - ) { - self.command_encoder - .init_texture(self.path_intermediate_texture); - if let Some(msaa_texture) = self.path_intermediate_msaa_texture { - self.command_encoder.init_texture(msaa_texture); + fn rasterize_paths(&mut self, paths: &[Path<ScaledPixels>]) { + self.path_tiles.clear(); + let mut vertices_by_texture_id = HashMap::default(); + + for path in paths { + let clipped_bounds = path + .bounds + .intersect(&path.content_mask.bounds) + .map_origin(|origin| origin.floor()) + .map_size(|size| size.ceil()); + let tile = self.atlas.allocate_for_rendering( + clipped_bounds.size.map(Into::into), + AtlasTextureKind::Path, + &mut self.command_encoder, + ); + vertices_by_texture_id + .entry(tile.texture_id) + .or_insert(Vec::new()) + .extend(path.vertices.iter().map(|vertex| PathVertex { + xy_position: vertex.xy_position - clipped_bounds.origin + + tile.bounds.origin.map(Into::into), + st_position: vertex.st_position, + content_mask: ContentMask { + bounds: tile.bounds.map(Into::into), + }, + })); + self.path_tiles.insert(path.id, tile); } - let target = if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { - gpu::RenderTarget { - view: msaa_view, - init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), - finish_op: gpu::FinishOp::ResolveTo(self.path_intermediate_texture_view), - } - } else { - gpu::RenderTarget { - view: self.path_intermediate_texture_view, - init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), - finish_op: gpu::FinishOp::Store, - } - }; - if let mut pass = self.command_encoder.render( - "rasterize paths", - gpu::RenderTargetSet { - colors: &[target], - depth_stencil: None, - }, - ) { + for (texture_id, vertices) in vertices_by_texture_id { + let tex_info = self.atlas.get_texture_info(texture_id); let globals = GlobalParams { - viewport_size: [width, height], + viewport_size: [tex_info.size.width as f32, tex_info.size.height as f32], premultiplied_alpha: 0, pad: 0, }; - let mut encoder = pass.with(&self.pipelines.path_rasterization); - let mut vertices = Vec::new(); - for path in paths { - vertices.extend(path.vertices.iter().map(|v| PathRasterizationVertex { - xy_position: v.xy_position, - st_position: v.st_position, - color: path.color, - bounds: path.bounds.intersect(&path.content_mask.bounds), - })); - } let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) }; - encoder.bind( - 0, - &ShaderPathRasterizationData { - globals, - b_path_vertices: vertex_buf, + let frame_view = tex_info.raw_view; + let color_target = if let Some(msaa_view) = tex_info.msaa_view { + gpu::RenderTarget { + view: msaa_view, + init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack), + finish_op: gpu::FinishOp::ResolveTo(frame_view), + } + } else { + gpu::RenderTarget { + view: frame_view, + init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack), + finish_op: gpu::FinishOp::Store, + } + }; + + if let mut pass = self.command_encoder.render( + "paths", + gpu::RenderTargetSet { + colors: &[color_target], + depth_stencil: None, }, - ); - encoder.draw(0, vertices.len() as u32, 0, 1); + ) { + let mut encoder = pass.with(&self.pipelines.path_rasterization); + encoder.bind( + 0, + &ShaderPathRasterizationData { + globals, + b_path_vertices: vertex_buf, + }, + ); + encoder.draw(0, vertices.len() as u32, 0, 1); + } } } @@ -629,20 +572,12 @@ impl BladeRenderer { self.gpu.destroy_command_encoder(&mut self.command_encoder); self.pipelines.destroy(&self.gpu); self.gpu.destroy_surface(&mut self.surface); - self.gpu.destroy_texture(self.path_intermediate_texture); - self.gpu - .destroy_texture_view(self.path_intermediate_texture_view); - if let Some(msaa_texture) = self.path_intermediate_msaa_texture { - self.gpu.destroy_texture(msaa_texture); - } - if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { - self.gpu.destroy_texture_view(msaa_view); - } } pub fn draw(&mut self, scene: &Scene) { self.command_encoder.start(); self.atlas.before_frame(&mut self.command_encoder); + self.rasterize_paths(scene.paths()); let frame = { profiling::scope!("acquire frame"); @@ -662,7 +597,7 @@ impl BladeRenderer { pad: 0, }; - let mut pass = self.command_encoder.render( + if let mut pass = self.command_encoder.render( "main", gpu::RenderTargetSet { colors: &[gpu::RenderTarget { @@ -672,235 +607,209 @@ impl BladeRenderer { }], depth_stencil: None, }, - ); + ) { + profiling::scope!("render pass"); + for batch in scene.batches() { + match batch { + PrimitiveBatch::Quads(quads) => { + let instance_buf = + unsafe { self.instance_belt.alloc_typed(quads, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.quads); + encoder.bind( + 0, + &ShaderQuadsData { + globals, + b_quads: instance_buf, + }, + ); + encoder.draw(0, 4, 0, quads.len() as u32); + } + PrimitiveBatch::Shadows(shadows) => { + let instance_buf = + unsafe { self.instance_belt.alloc_typed(shadows, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.shadows); + encoder.bind( + 0, + &ShaderShadowsData { + globals, + b_shadows: instance_buf, + }, + ); + encoder.draw(0, 4, 0, shadows.len() as u32); + } + PrimitiveBatch::Paths(paths) => { + let mut encoder = pass.with(&self.pipelines.paths); + // todo(linux): group by texture ID + for path in paths { + let tile = &self.path_tiles[&path.id]; + let tex_info = self.atlas.get_texture_info(tile.texture_id); + let origin = path.bounds.intersect(&path.content_mask.bounds).origin; + let sprites = [PathSprite { + bounds: Bounds { + origin: origin.map(|p| p.floor()), + size: tile.bounds.size.map(Into::into), + }, + color: path.color, + tile: (*tile).clone(), + }]; - profiling::scope!("render pass"); - for batch in scene.batches() { - match batch { - PrimitiveBatch::Quads(quads) => { - let instance_buf = unsafe { self.instance_belt.alloc_typed(quads, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.quads); - encoder.bind( - 0, - &ShaderQuadsData { - globals, - b_quads: instance_buf, - }, - ); - encoder.draw(0, 4, 0, quads.len() as u32); - } - PrimitiveBatch::Shadows(shadows) => { - let instance_buf = - unsafe { self.instance_belt.alloc_typed(shadows, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.shadows); - encoder.bind( - 0, - &ShaderShadowsData { - globals, - b_shadows: instance_buf, - }, - ); - encoder.draw(0, 4, 0, shadows.len() as u32); - } - PrimitiveBatch::Paths(paths) => { - let Some(first_path) = paths.first() else { - continue; - }; - drop(pass); - self.draw_paths_to_intermediate( - paths, - self.surface_config.size.width as f32, - self.surface_config.size.height as f32, - ); - pass = self.command_encoder.render( - "main", - gpu::RenderTargetSet { - colors: &[gpu::RenderTarget { - view: frame.texture_view(), - init_op: gpu::InitOp::Load, - finish_op: gpu::FinishOp::Store, - }], - depth_stencil: None, - }, - ); - let mut encoder = pass.with(&self.pipelines.paths); - // When copying paths from the intermediate texture to the drawable, - // each pixel must only be copied once, in case of transparent paths. - // - // If all paths have the same draw order, then their bounds are all - // disjoint, so we can copy each path's bounds individually. If this - // batch combines different draw orders, we perform a single copy - // for a minimal spanning rect. - let sprites = if paths.last().unwrap().order == first_path.order { - paths - .iter() - .map(|path| PathSprite { - bounds: path.bounds, - }) - .collect() - } else { - let mut bounds = first_path.bounds; - for path in paths.iter().skip(1) { - bounds = bounds.union(&path.bounds); + let instance_buf = + unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; + encoder.bind( + 0, + &ShaderPathsData { + globals, + t_sprite: tex_info.raw_view, + s_sprite: self.atlas_sampler, + b_path_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); } - vec![PathSprite { bounds }] - }; - let instance_buf = - unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; - encoder.bind( - 0, - &ShaderPathsData { - globals, - t_sprite: self.path_intermediate_texture_view, - s_sprite: self.atlas_sampler, - b_path_sprites: instance_buf, - }, - ); - encoder.draw(0, 4, 0, sprites.len() as u32); - } - PrimitiveBatch::Underlines(underlines) => { - let instance_buf = - unsafe { self.instance_belt.alloc_typed(underlines, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.underlines); - encoder.bind( - 0, - &ShaderUnderlinesData { - globals, - b_underlines: instance_buf, - }, - ); - encoder.draw(0, 4, 0, underlines.len() as u32); - } - PrimitiveBatch::MonochromeSprites { - texture_id, - sprites, - } => { - let tex_info = self.atlas.get_texture_info(texture_id); - let instance_buf = - unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.mono_sprites); - encoder.bind( - 0, - &ShaderMonoSpritesData { - globals, - t_sprite: tex_info.raw_view, - s_sprite: self.atlas_sampler, - b_mono_sprites: instance_buf, - }, - ); - encoder.draw(0, 4, 0, sprites.len() as u32); - } - PrimitiveBatch::PolychromeSprites { - texture_id, - sprites, - } => { - let tex_info = self.atlas.get_texture_info(texture_id); - let instance_buf = - unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.poly_sprites); - encoder.bind( - 0, - &ShaderPolySpritesData { - globals, - t_sprite: tex_info.raw_view, - s_sprite: self.atlas_sampler, - b_poly_sprites: instance_buf, - }, - ); - encoder.draw(0, 4, 0, sprites.len() as u32); - } - PrimitiveBatch::Surfaces(surfaces) => { - let mut _encoder = pass.with(&self.pipelines.surfaces); + } + PrimitiveBatch::Underlines(underlines) => { + let instance_buf = + unsafe { self.instance_belt.alloc_typed(underlines, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.underlines); + encoder.bind( + 0, + &ShaderUnderlinesData { + globals, + b_underlines: instance_buf, + }, + ); + encoder.draw(0, 4, 0, underlines.len() as u32); + } + PrimitiveBatch::MonochromeSprites { + texture_id, + sprites, + } => { + let tex_info = self.atlas.get_texture_info(texture_id); + let instance_buf = + unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.mono_sprites); + encoder.bind( + 0, + &ShaderMonoSpritesData { + globals, + t_sprite: tex_info.raw_view, + s_sprite: self.atlas_sampler, + b_mono_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); + } + PrimitiveBatch::PolychromeSprites { + texture_id, + sprites, + } => { + let tex_info = self.atlas.get_texture_info(texture_id); + let instance_buf = + unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.poly_sprites); + encoder.bind( + 0, + &ShaderPolySpritesData { + globals, + t_sprite: tex_info.raw_view, + s_sprite: self.atlas_sampler, + b_poly_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); + } + PrimitiveBatch::Surfaces(surfaces) => { + let mut _encoder = pass.with(&self.pipelines.surfaces); - for surface in surfaces { - #[cfg(not(target_os = "macos"))] - { - let _ = surface; - continue; - }; + for surface in surfaces { + #[cfg(not(target_os = "macos"))] + { + let _ = surface; + continue; + }; - #[cfg(target_os = "macos")] - { - let (t_y, t_cb_cr) = unsafe { - use core_foundation::base::TCFType as _; - use std::ptr; + #[cfg(target_os = "macos")] + { + let (t_y, t_cb_cr) = unsafe { + use core_foundation::base::TCFType as _; + use std::ptr; - assert_eq!( + assert_eq!( surface.image_buffer.get_pixel_format(), core_video::pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange ); - let y_texture = self - .core_video_texture_cache - .create_texture_from_image( - surface.image_buffer.as_concrete_TypeRef(), - ptr::null(), - metal::MTLPixelFormat::R8Unorm, - surface.image_buffer.get_width_of_plane(0), - surface.image_buffer.get_height_of_plane(0), - 0, - ) - .unwrap(); - let cb_cr_texture = self - .core_video_texture_cache - .create_texture_from_image( - surface.image_buffer.as_concrete_TypeRef(), - ptr::null(), - metal::MTLPixelFormat::RG8Unorm, - surface.image_buffer.get_width_of_plane(1), - surface.image_buffer.get_height_of_plane(1), - 1, - ) - .unwrap(); - ( - gpu::TextureView::from_metal_texture( - &objc2::rc::Retained::retain( - foreign_types::ForeignTypeRef::as_ptr( - y_texture.as_texture_ref(), - ) - as *mut objc2::runtime::ProtocolObject< - dyn objc2_metal::MTLTexture, - >, + let y_texture = self + .core_video_texture_cache + .create_texture_from_image( + surface.image_buffer.as_concrete_TypeRef(), + ptr::null(), + metal::MTLPixelFormat::R8Unorm, + surface.image_buffer.get_width_of_plane(0), + surface.image_buffer.get_height_of_plane(0), + 0, ) - .unwrap(), - gpu::TexelAspects::COLOR, - ), - gpu::TextureView::from_metal_texture( - &objc2::rc::Retained::retain( - foreign_types::ForeignTypeRef::as_ptr( - cb_cr_texture.as_texture_ref(), - ) - as *mut objc2::runtime::ProtocolObject< - dyn objc2_metal::MTLTexture, - >, + .unwrap(); + let cb_cr_texture = self + .core_video_texture_cache + .create_texture_from_image( + surface.image_buffer.as_concrete_TypeRef(), + ptr::null(), + metal::MTLPixelFormat::RG8Unorm, + surface.image_buffer.get_width_of_plane(1), + surface.image_buffer.get_height_of_plane(1), + 1, ) - .unwrap(), - gpu::TexelAspects::COLOR, - ), - ) - }; + .unwrap(); + ( + gpu::TextureView::from_metal_texture( + &objc2::rc::Retained::retain( + foreign_types::ForeignTypeRef::as_ptr( + y_texture.as_texture_ref(), + ) + as *mut objc2::runtime::ProtocolObject< + dyn objc2_metal::MTLTexture, + >, + ) + .unwrap(), + gpu::TexelAspects::COLOR, + ), + gpu::TextureView::from_metal_texture( + &objc2::rc::Retained::retain( + foreign_types::ForeignTypeRef::as_ptr( + cb_cr_texture.as_texture_ref(), + ) + as *mut objc2::runtime::ProtocolObject< + dyn objc2_metal::MTLTexture, + >, + ) + .unwrap(), + gpu::TexelAspects::COLOR, + ), + ) + }; - _encoder.bind( - 0, - &ShaderSurfacesData { - globals, - surface_locals: SurfaceParams { - bounds: surface.bounds.into(), - content_mask: surface.content_mask.bounds.into(), + _encoder.bind( + 0, + &ShaderSurfacesData { + globals, + surface_locals: SurfaceParams { + bounds: surface.bounds.into(), + content_mask: surface.content_mask.bounds.into(), + }, + t_y, + t_cb_cr, + s_surface: self.atlas_sampler, }, - t_y, - t_cb_cr, - s_surface: self.atlas_sampler, - }, - ); + ); - _encoder.draw(0, 4, 0, 1); + _encoder.draw(0, 4, 0, 1); + } } } } } } - drop(pass); self.command_encoder.present(frame); let sync_point = self.gpu.submit(&mut self.command_encoder); @@ -908,79 +817,9 @@ impl BladeRenderer { profiling::scope!("finish"); self.instance_belt.flush(&sync_point); self.atlas.after_frame(&sync_point); + self.atlas.clear_textures(AtlasTextureKind::Path); self.wait_for_gpu(); self.last_sync_point = Some(sync_point); } } - -fn create_path_intermediate_texture( - gpu: &gpu::Context, - format: gpu::TextureFormat, - width: u32, - height: u32, -) -> (gpu::Texture, gpu::TextureView) { - let texture = gpu.create_texture(gpu::TextureDesc { - name: "path intermediate", - format, - size: gpu::Extent { - width, - height, - depth: 1, - }, - array_layer_count: 1, - mip_level_count: 1, - sample_count: 1, - dimension: gpu::TextureDimension::D2, - usage: gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE | gpu::TextureUsage::TARGET, - external: None, - }); - let texture_view = gpu.create_texture_view( - texture, - gpu::TextureViewDesc { - name: "path intermediate view", - format, - dimension: gpu::ViewDimension::D2, - subresources: &Default::default(), - }, - ); - (texture, texture_view) -} - -fn create_msaa_texture_if_needed( - gpu: &gpu::Context, - format: gpu::TextureFormat, - width: u32, - height: u32, - sample_count: u32, -) -> Option<(gpu::Texture, gpu::TextureView)> { - if sample_count <= 1 { - return None; - } - let texture_msaa = gpu.create_texture(gpu::TextureDesc { - name: "path intermediate msaa", - format, - size: gpu::Extent { - width, - height, - depth: 1, - }, - array_layer_count: 1, - mip_level_count: 1, - sample_count, - dimension: gpu::TextureDimension::D2, - usage: gpu::TextureUsage::TARGET, - external: None, - }); - let texture_view_msaa = gpu.create_texture_view( - texture_msaa, - gpu::TextureViewDesc { - name: "path intermediate msaa view", - format, - dimension: gpu::ViewDimension::D2, - subresources: &Default::default(), - }, - ); - - Some((texture_msaa, texture_view_msaa)) -} diff --git a/crates/gpui/src/platform/blade/shaders.wgsl b/crates/gpui/src/platform/blade/shaders.wgsl index b1ffb1812e..0b34a0eea3 100644 --- a/crates/gpui/src/platform/blade/shaders.wgsl +++ b/crates/gpui/src/platform/blade/shaders.wgsl @@ -924,19 +924,16 @@ fn fs_shadow(input: ShadowVarying) -> @location(0) vec4<f32> { // --- path rasterization --- // -struct PathRasterizationVertex { +struct PathVertex { xy_position: vec2<f32>, st_position: vec2<f32>, - color: Background, - bounds: Bounds, + content_mask: Bounds, } - -var<storage, read> b_path_vertices: array<PathRasterizationVertex>; +var<storage, read> b_path_vertices: array<PathVertex>; struct PathRasterizationVarying { @builtin(position) position: vec4<f32>, @location(0) st_position: vec2<f32>, - @location(1) vertex_id: u32, //TODO: use `clip_distance` once Naga supports it @location(3) clip_distances: vec4<f32>, } @@ -948,54 +945,40 @@ fn vs_path_rasterization(@builtin(vertex_index) vertex_id: u32) -> PathRasteriza var out = PathRasterizationVarying(); out.position = to_device_position_impl(v.xy_position); out.st_position = v.st_position; - out.vertex_id = vertex_id; - out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.bounds); + out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.content_mask); return out; } @fragment -fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) vec4<f32> { +fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) f32 { let dx = dpdx(input.st_position); let dy = dpdy(input.st_position); if (any(input.clip_distances < vec4<f32>(0.0))) { - return vec4<f32>(0.0); + return 0.0; } - let v = b_path_vertices[input.vertex_id]; - let background = v.color; - let bounds = v.bounds; - - var alpha: f32; - if (length(vec2<f32>(dx.x, dy.x)) < 0.001) { - // If the gradient is too small, return a solid color. - alpha = 1.0; - } else { - let gradient = 2.0 * input.st_position.xx * vec2<f32>(dx.x, dy.x) - vec2<f32>(dx.y, dy.y); - let f = input.st_position.x * input.st_position.x - input.st_position.y; - let distance = f / length(gradient); - alpha = saturate(0.5 - distance); - } - let gradient_color = prepare_gradient_color( - background.tag, - background.color_space, - background.solid, - background.colors, - ); - let color = gradient_color(background, input.position.xy, bounds, - gradient_color.solid, gradient_color.color0, gradient_color.color1); - return vec4<f32>(color.rgb * color.a * alpha, color.a * alpha); + let gradient = 2.0 * input.st_position.xx * vec2<f32>(dx.x, dy.x) - vec2<f32>(dx.y, dy.y); + let f = input.st_position.x * input.st_position.x - input.st_position.y; + let distance = f / length(gradient); + return saturate(0.5 - distance); } // --- paths --- // struct PathSprite { bounds: Bounds, + color: Background, + tile: AtlasTile, } var<storage, read> b_path_sprites: array<PathSprite>; struct PathVarying { @builtin(position) position: vec4<f32>, - @location(0) texture_coords: vec2<f32>, + @location(0) tile_position: vec2<f32>, + @location(1) @interpolate(flat) instance_id: u32, + @location(2) @interpolate(flat) color_solid: vec4<f32>, + @location(3) @interpolate(flat) color0: vec4<f32>, + @location(4) @interpolate(flat) color1: vec4<f32>, } @vertex @@ -1003,22 +986,33 @@ fn vs_path(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) insta let unit_vertex = vec2<f32>(f32(vertex_id & 1u), 0.5 * f32(vertex_id & 2u)); let sprite = b_path_sprites[instance_id]; // Don't apply content mask because it was already accounted for when rasterizing the path. - let device_position = to_device_position(unit_vertex, sprite.bounds); - // For screen-space intermediate texture, convert screen position to texture coordinates - let screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size; - let texture_coords = screen_position / globals.viewport_size; var out = PathVarying(); - out.position = device_position; - out.texture_coords = texture_coords; + out.position = to_device_position(unit_vertex, sprite.bounds); + out.tile_position = to_tile_position(unit_vertex, sprite.tile); + out.instance_id = instance_id; + let gradient = prepare_gradient_color( + sprite.color.tag, + sprite.color.color_space, + sprite.color.solid, + sprite.color.colors + ); + out.color_solid = gradient.solid; + out.color0 = gradient.color0; + out.color1 = gradient.color1; return out; } @fragment fn fs_path(input: PathVarying) -> @location(0) vec4<f32> { - let sample = textureSample(t_sprite, s_sprite, input.texture_coords); - return sample; + let sample = textureSample(t_sprite, s_sprite, input.tile_position).r; + let mask = 1.0 - abs(1.0 - sample % 2.0); + let sprite = b_path_sprites[input.instance_id]; + let background = sprite.color; + let color = gradient_color(background, input.position.xy, sprite.bounds, + input.color_solid, input.color0, input.color1); + return blend_color(color, mask); } // --- underlines --- // diff --git a/crates/gpui/src/platform/linux/wayland/window.rs b/crates/gpui/src/platform/linux/wayland/window.rs index 2b2207e22c..255ae9c372 100644 --- a/crates/gpui/src/platform/linux/wayland/window.rs +++ b/crates/gpui/src/platform/linux/wayland/window.rs @@ -111,7 +111,7 @@ pub struct WaylandWindowState { resize_throttle: bool, in_progress_window_controls: Option<WindowControls>, window_controls: WindowControls, - client_inset: Option<Pixels>, + inset: Option<Pixels>, } #[derive(Clone)] @@ -186,7 +186,7 @@ impl WaylandWindowState { hovered: false, in_progress_window_controls: None, window_controls: WindowControls::default(), - client_inset: None, + inset: None, }) } @@ -211,13 +211,6 @@ impl WaylandWindowState { self.display = current_output; scale } - - pub fn inset(&self) -> Pixels { - match self.decorations { - WindowDecorations::Server => px(0.0), - WindowDecorations::Client => self.client_inset.unwrap_or(px(0.0)), - } - } } pub(crate) struct WaylandWindow(pub WaylandWindowStatePtr); @@ -387,7 +380,7 @@ impl WaylandWindowStatePtr { configure.size = if got_unmaximized { Some(state.window_bounds.size) } else { - compute_outer_size(state.inset(), configure.size, state.tiling) + compute_outer_size(state.inset, configure.size, state.tiling) }; if let Some(size) = configure.size { state.window_bounds = Bounds { @@ -407,7 +400,7 @@ impl WaylandWindowStatePtr { let window_geometry = inset_by_tiling( state.bounds.map_origin(|_| px(0.0)), - state.inset(), + state.inset.unwrap_or(px(0.0)), state.tiling, ) .map(|v| v.0 as i32) @@ -825,7 +818,7 @@ impl PlatformWindow for WaylandWindow { } else if state.maximized { WindowBounds::Maximized(state.window_bounds) } else { - let inset = state.inset(); + let inset = state.inset.unwrap_or(px(0.)); drop(state); WindowBounds::Windowed(self.bounds().inset(inset)) } @@ -1080,8 +1073,8 @@ impl PlatformWindow for WaylandWindow { fn set_client_inset(&self, inset: Pixels) { let mut state = self.borrow_mut(); - if Some(inset) != state.client_inset { - state.client_inset = Some(inset); + if Some(inset) != state.inset { + state.inset = Some(inset); update_window(state); } } @@ -1101,7 +1094,9 @@ fn update_window(mut state: RefMut<WaylandWindowState>) { state.renderer.update_transparency(!opaque); let mut opaque_area = state.window_bounds.map(|v| v.0 as i32); - opaque_area.inset(state.inset().0 as i32); + if let Some(inset) = state.inset { + opaque_area.inset(inset.0 as i32); + } let region = state .globals @@ -1174,10 +1169,12 @@ impl ResizeEdge { /// updating to account for the client decorations. But that's not the area we want to render /// to, due to our intrusize CSD. So, here we calculate the 'actual' size, by adding back in the insets fn compute_outer_size( - inset: Pixels, + inset: Option<Pixels>, new_size: Option<Size<Pixels>>, tiling: Tiling, ) -> Option<Size<Pixels>> { + let Some(inset) = inset else { return new_size }; + new_size.map(|mut new_size| { if !tiling.top { new_size.height += inset; diff --git a/crates/gpui/src/platform/linux/x11/client.rs b/crates/gpui/src/platform/linux/x11/client.rs index 573e4addf7..16a7a768e2 100644 --- a/crates/gpui/src/platform/linux/x11/client.rs +++ b/crates/gpui/src/platform/linux/x11/client.rs @@ -1795,7 +1795,6 @@ impl X11ClientState { drop(state); window.refresh(RequestFrameOptions { require_presentation: expose_event_received, - force_render: false, }); } xcb_connection diff --git a/crates/gpui/src/platform/mac/metal_atlas.rs b/crates/gpui/src/platform/mac/metal_atlas.rs index 5d2d8e63e0..366f2dcc3c 100644 --- a/crates/gpui/src/platform/mac/metal_atlas.rs +++ b/crates/gpui/src/platform/mac/metal_atlas.rs @@ -13,25 +13,53 @@ use std::borrow::Cow; pub(crate) struct MetalAtlas(Mutex<MetalAtlasState>); impl MetalAtlas { - pub(crate) fn new(device: Device) -> Self { + pub(crate) fn new(device: Device, path_sample_count: u32) -> Self { MetalAtlas(Mutex::new(MetalAtlasState { device: AssertSend(device), monochrome_textures: Default::default(), polychrome_textures: Default::default(), + path_textures: Default::default(), tiles_by_key: Default::default(), + path_sample_count, })) } pub(crate) fn metal_texture(&self, id: AtlasTextureId) -> metal::Texture { self.0.lock().texture(id).metal_texture.clone() } + + pub(crate) fn msaa_texture(&self, id: AtlasTextureId) -> Option<metal::Texture> { + self.0.lock().texture(id).msaa_texture.clone() + } + + pub(crate) fn allocate( + &self, + size: Size<DevicePixels>, + texture_kind: AtlasTextureKind, + ) -> Option<AtlasTile> { + self.0.lock().allocate(size, texture_kind) + } + + pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) { + let mut lock = self.0.lock(); + let textures = match texture_kind { + AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, + AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, + AtlasTextureKind::Path => &mut lock.path_textures, + }; + for texture in textures.iter_mut() { + texture.clear(); + } + } } struct MetalAtlasState { device: AssertSend<Device>, monochrome_textures: AtlasTextureList<MetalAtlasTexture>, polychrome_textures: AtlasTextureList<MetalAtlasTexture>, + path_textures: AtlasTextureList<MetalAtlasTexture>, tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, + path_sample_count: u32, } impl PlatformAtlas for MetalAtlas { @@ -66,6 +94,7 @@ impl PlatformAtlas for MetalAtlas { let textures = match id.kind { AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, + AtlasTextureKind::Path => &mut lock.polychrome_textures, }; let Some(texture_slot) = textures @@ -99,6 +128,7 @@ impl MetalAtlasState { let textures = match texture_kind { AtlasTextureKind::Monochrome => &mut self.monochrome_textures, AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + AtlasTextureKind::Path => &mut self.path_textures, }; if let Some(tile) = textures @@ -143,14 +173,31 @@ impl MetalAtlasState { pixel_format = metal::MTLPixelFormat::BGRA8Unorm; usage = metal::MTLTextureUsage::ShaderRead; } + AtlasTextureKind::Path => { + pixel_format = metal::MTLPixelFormat::R16Float; + usage = metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead; + } } texture_descriptor.set_pixel_format(pixel_format); texture_descriptor.set_usage(usage); let metal_texture = self.device.new_texture(&texture_descriptor); + // We currently only enable MSAA for path textures. + let msaa_texture = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path { + let mut descriptor = texture_descriptor.clone(); + descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); + descriptor.set_storage_mode(metal::MTLStorageMode::Private); + descriptor.set_sample_count(self.path_sample_count as _); + let msaa_texture = self.device.new_texture(&descriptor); + Some(msaa_texture) + } else { + None + }; + let texture_list = match kind { AtlasTextureKind::Monochrome => &mut self.monochrome_textures, AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + AtlasTextureKind::Path => &mut self.path_textures, }; let index = texture_list.free_list.pop(); @@ -162,6 +209,7 @@ impl MetalAtlasState { }, allocator: etagere::BucketedAtlasAllocator::new(size.into()), metal_texture: AssertSend(metal_texture), + msaa_texture: AssertSend(msaa_texture), live_atlas_keys: 0, }; @@ -178,6 +226,7 @@ impl MetalAtlasState { let textures = match id.kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + crate::AtlasTextureKind::Path => &self.path_textures, }; textures[id.index as usize].as_ref().unwrap() } @@ -187,10 +236,15 @@ struct MetalAtlasTexture { id: AtlasTextureId, allocator: BucketedAtlasAllocator, metal_texture: AssertSend<metal::Texture>, + msaa_texture: AssertSend<Option<metal::Texture>>, live_atlas_keys: u32, } impl MetalAtlasTexture { + fn clear(&mut self) { + self.allocator.clear(); + } + fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { let allocation = self.allocator.allocate(size.into())?; let tile = AtlasTile { diff --git a/crates/gpui/src/platform/mac/metal_renderer.rs b/crates/gpui/src/platform/mac/metal_renderer.rs index fb5cb852d6..3cdc2dd2cf 100644 --- a/crates/gpui/src/platform/mac/metal_renderer.rs +++ b/crates/gpui/src/platform/mac/metal_renderer.rs @@ -1,30 +1,27 @@ use super::metal_atlas::MetalAtlas; use crate::{ - AtlasTextureId, Background, Bounds, ContentMask, DevicePixels, MonochromeSprite, PaintSurface, - Path, Point, PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, - Surface, Underline, point, size, + AtlasTextureId, AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels, + MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch, + Quad, ScaledPixels, Scene, Shadow, Size, Surface, Underline, point, size, }; -use anyhow::Result; +use anyhow::{Context as _, Result}; use block::ConcreteBlock; use cocoa::{ base::{NO, YES}, foundation::{NSSize, NSUInteger}, quartzcore::AutoresizingMask, }; - +use collections::HashMap; use core_foundation::base::TCFType; use core_video::{ metal_texture::CVMetalTextureGetTexture, metal_texture_cache::CVMetalTextureCache, pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange, }; use foreign_types::{ForeignType, ForeignTypeRef}; -use metal::{ - CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange, - RenderPassColorAttachmentDescriptorRef, -}; +use metal::{CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange}; use objc::{self, msg_send, sel, sel_impl}; use parking_lot::Mutex; - +use smallvec::SmallVec; use std::{cell::Cell, ffi::c_void, mem, ptr, sync::Arc}; // Exported to metal @@ -114,17 +111,6 @@ pub(crate) struct MetalRenderer { instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>, sprite_atlas: Arc<MetalAtlas>, core_video_texture_cache: core_video::metal_texture_cache::CVMetalTextureCache, - path_intermediate_texture: Option<metal::Texture>, - path_intermediate_msaa_texture: Option<metal::Texture>, - path_sample_count: u32, -} - -#[repr(C)] -pub struct PathRasterizationVertex { - pub xy_position: Point<ScaledPixels>, - pub st_position: Point<f32>, - pub color: Background, - pub bounds: Bounds<ScaledPixels>, } impl MetalRenderer { @@ -189,10 +175,10 @@ impl MetalRenderer { "paths_rasterization", "path_rasterization_vertex", "path_rasterization_fragment", - MTLPixelFormat::BGRA8Unorm, + MTLPixelFormat::R16Float, PATH_SAMPLE_COUNT, ); - let path_sprites_pipeline_state = build_path_sprite_pipeline_state( + let path_sprites_pipeline_state = build_pipeline_state( &device, &library, "path_sprites", @@ -250,7 +236,7 @@ impl MetalRenderer { ); let command_queue = device.new_command_queue(); - let sprite_atlas = Arc::new(MetalAtlas::new(device.clone())); + let sprite_atlas = Arc::new(MetalAtlas::new(device.clone(), PATH_SAMPLE_COUNT)); let core_video_texture_cache = CVMetalTextureCache::new(None, device.clone(), None).unwrap(); @@ -271,9 +257,6 @@ impl MetalRenderer { instance_buffer_pool, sprite_atlas, core_video_texture_cache, - path_intermediate_texture: None, - path_intermediate_msaa_texture: None, - path_sample_count: PATH_SAMPLE_COUNT, } } @@ -306,31 +289,6 @@ impl MetalRenderer { setDrawableSize: size ]; } - let device_pixels_size = Size { - width: DevicePixels(size.width as i32), - height: DevicePixels(size.height as i32), - }; - self.update_path_intermediate_textures(device_pixels_size); - } - - fn update_path_intermediate_textures(&mut self, size: Size<DevicePixels>) { - let texture_descriptor = metal::TextureDescriptor::new(); - texture_descriptor.set_width(size.width.0 as u64); - texture_descriptor.set_height(size.height.0 as u64); - texture_descriptor.set_pixel_format(metal::MTLPixelFormat::BGRA8Unorm); - texture_descriptor - .set_usage(metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead); - self.path_intermediate_texture = Some(self.device.new_texture(&texture_descriptor)); - - if self.path_sample_count > 1 { - let mut msaa_descriptor = texture_descriptor.clone(); - msaa_descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); - msaa_descriptor.set_storage_mode(metal::MTLStorageMode::Private); - msaa_descriptor.set_sample_count(self.path_sample_count as _); - self.path_intermediate_msaa_texture = Some(self.device.new_texture(&msaa_descriptor)); - } else { - self.path_intermediate_msaa_texture = None; - } } pub fn update_transparency(&self, _transparent: bool) { @@ -416,18 +374,38 @@ impl MetalRenderer { ) -> Result<metal::CommandBuffer> { let command_queue = self.command_queue.clone(); let command_buffer = command_queue.new_command_buffer(); - let alpha = if self.layer.is_opaque() { 1. } else { 0. }; let mut instance_offset = 0; - let mut command_encoder = new_command_encoder( - command_buffer, - drawable, - viewport_size, - |color_attachment| { - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha)); - }, - ); + let path_tiles = self + .rasterize_paths( + scene.paths(), + instance_buffer, + &mut instance_offset, + command_buffer, + ) + .with_context(|| format!("rasterizing {} paths", scene.paths().len()))?; + + let render_pass_descriptor = metal::RenderPassDescriptor::new(); + let color_attachment = render_pass_descriptor + .color_attachments() + .object_at(0) + .unwrap(); + + color_attachment.set_texture(Some(drawable.texture())); + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_store_action(metal::MTLStoreAction::Store); + let alpha = if self.layer.is_opaque() { 1. } else { 0. }; + color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha)); + let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); + + command_encoder.set_viewport(metal::MTLViewport { + originX: 0.0, + originY: 0.0, + width: i32::from(viewport_size.width) as f64, + height: i32::from(viewport_size.height) as f64, + znear: 0.0, + zfar: 1.0, + }); for batch in scene.batches() { let ok = match batch { @@ -436,53 +414,29 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::Quads(quads) => self.draw_quads( quads, instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, + ), + PrimitiveBatch::Paths(paths) => self.draw_paths( + paths, + &path_tiles, + instance_buffer, + &mut instance_offset, + viewport_size, + command_encoder, ), - PrimitiveBatch::Paths(paths) => { - command_encoder.end_encoding(); - - let did_draw = self.draw_paths_to_intermediate( - paths, - instance_buffer, - &mut instance_offset, - viewport_size, - command_buffer, - ); - - command_encoder = new_command_encoder( - command_buffer, - drawable, - viewport_size, - |color_attachment| { - color_attachment.set_load_action(metal::MTLLoadAction::Load); - }, - ); - - if did_draw { - self.draw_paths_from_intermediate( - paths, - instance_buffer, - &mut instance_offset, - viewport_size, - &command_encoder, - ) - } else { - false - } - } PrimitiveBatch::Underlines(underlines) => self.draw_underlines( underlines, instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::MonochromeSprites { texture_id, @@ -493,7 +447,7 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::PolychromeSprites { texture_id, @@ -504,16 +458,17 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces( surfaces, instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), }; + if !ok { command_encoder.end_encoding(); anyhow::bail!( @@ -538,90 +493,104 @@ impl MetalRenderer { Ok(command_buffer.to_owned()) } - fn draw_paths_to_intermediate( + fn rasterize_paths( &self, paths: &[Path<ScaledPixels>], instance_buffer: &mut InstanceBuffer, instance_offset: &mut usize, - viewport_size: Size<DevicePixels>, command_buffer: &metal::CommandBufferRef, - ) -> bool { - if paths.is_empty() { - return true; - } - let Some(intermediate_texture) = &self.path_intermediate_texture else { - return false; - }; + ) -> Option<HashMap<PathId, AtlasTile>> { + self.sprite_atlas.clear_textures(AtlasTextureKind::Path); - let render_pass_descriptor = metal::RenderPassDescriptor::new(); - let color_attachment = render_pass_descriptor - .color_attachments() - .object_at(0) - .unwrap(); - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 0.)); - - if let Some(msaa_texture) = &self.path_intermediate_msaa_texture { - color_attachment.set_texture(Some(msaa_texture)); - color_attachment.set_resolve_texture(Some(intermediate_texture)); - color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve); - } else { - color_attachment.set_texture(Some(intermediate_texture)); - color_attachment.set_store_action(metal::MTLStoreAction::Store); - } - - let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); - command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state); - - align_offset(instance_offset); - let mut vertices = Vec::new(); + let mut tiles = HashMap::default(); + let mut vertices_by_texture_id = HashMap::default(); for path in paths { - vertices.extend(path.vertices.iter().map(|v| PathRasterizationVertex { - xy_position: v.xy_position, - st_position: v.st_position, - color: path.color, - bounds: path.bounds.intersect(&path.content_mask.bounds), - })); - } - let vertices_bytes_len = mem::size_of_val(vertices.as_slice()); - let next_offset = *instance_offset + vertices_bytes_len; - if next_offset > instance_buffer.size { - command_encoder.end_encoding(); - return false; - } - command_encoder.set_vertex_buffer( - PathRasterizationInputIndex::Vertices as u64, - Some(&instance_buffer.metal_buffer), - *instance_offset as u64, - ); - command_encoder.set_vertex_bytes( - PathRasterizationInputIndex::ViewportSize as u64, - mem::size_of_val(&viewport_size) as u64, - &viewport_size as *const Size<DevicePixels> as *const _, - ); - command_encoder.set_fragment_buffer( - PathRasterizationInputIndex::Vertices as u64, - Some(&instance_buffer.metal_buffer), - *instance_offset as u64, - ); - let buffer_contents = - unsafe { (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) }; - unsafe { - ptr::copy_nonoverlapping( - vertices.as_ptr() as *const u8, - buffer_contents, - vertices_bytes_len, - ); - } - command_encoder.draw_primitives( - metal::MTLPrimitiveType::Triangle, - 0, - vertices.len() as u64, - ); - *instance_offset = next_offset; + let clipped_bounds = path.bounds.intersect(&path.content_mask.bounds); - command_encoder.end_encoding(); - true + let tile = self + .sprite_atlas + .allocate(clipped_bounds.size.map(Into::into), AtlasTextureKind::Path)?; + vertices_by_texture_id + .entry(tile.texture_id) + .or_insert(Vec::new()) + .extend(path.vertices.iter().map(|vertex| PathVertex { + xy_position: vertex.xy_position - clipped_bounds.origin + + tile.bounds.origin.map(Into::into), + st_position: vertex.st_position, + content_mask: ContentMask { + bounds: tile.bounds.map(Into::into), + }, + })); + tiles.insert(path.id, tile); + } + + for (texture_id, vertices) in vertices_by_texture_id { + align_offset(instance_offset); + let vertices_bytes_len = mem::size_of_val(vertices.as_slice()); + let next_offset = *instance_offset + vertices_bytes_len; + if next_offset > instance_buffer.size { + return None; + } + + let render_pass_descriptor = metal::RenderPassDescriptor::new(); + let color_attachment = render_pass_descriptor + .color_attachments() + .object_at(0) + .unwrap(); + + let texture = self.sprite_atlas.metal_texture(texture_id); + let msaa_texture = self.sprite_atlas.msaa_texture(texture_id); + + if let Some(msaa_texture) = msaa_texture { + color_attachment.set_texture(Some(&msaa_texture)); + color_attachment.set_resolve_texture(Some(&texture)); + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve); + } else { + color_attachment.set_texture(Some(&texture)); + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_store_action(metal::MTLStoreAction::Store); + } + color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 1.)); + + let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); + command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state); + command_encoder.set_vertex_buffer( + PathRasterizationInputIndex::Vertices as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + let texture_size = Size { + width: DevicePixels::from(texture.width()), + height: DevicePixels::from(texture.height()), + }; + command_encoder.set_vertex_bytes( + PathRasterizationInputIndex::AtlasTextureSize as u64, + mem::size_of_val(&texture_size) as u64, + &texture_size as *const Size<DevicePixels> as *const _, + ); + + let buffer_contents = unsafe { + (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) + }; + unsafe { + ptr::copy_nonoverlapping( + vertices.as_ptr() as *const u8, + buffer_contents, + vertices_bytes_len, + ); + } + + command_encoder.draw_primitives( + metal::MTLPrimitiveType::Triangle, + 0, + vertices.len() as u64, + ); + command_encoder.end_encoding(); + *instance_offset = next_offset; + } + + Some(tiles) } fn draw_shadows( @@ -746,21 +715,18 @@ impl MetalRenderer { true } - fn draw_paths_from_intermediate( + fn draw_paths( &self, paths: &[Path<ScaledPixels>], + tiles_by_path_id: &HashMap<PathId, AtlasTile>, instance_buffer: &mut InstanceBuffer, instance_offset: &mut usize, viewport_size: Size<DevicePixels>, command_encoder: &metal::RenderCommandEncoderRef, ) -> bool { - let Some(ref first_path) = paths.first() else { + if paths.is_empty() { return true; - }; - - let Some(ref intermediate_texture) = self.path_intermediate_texture else { - return false; - }; + } command_encoder.set_render_pipeline_state(&self.path_sprites_pipeline_state); command_encoder.set_vertex_buffer( @@ -774,65 +740,88 @@ impl MetalRenderer { &viewport_size as *const Size<DevicePixels> as *const _, ); - command_encoder.set_fragment_texture( - SpriteInputIndex::AtlasTexture as u64, - Some(intermediate_texture), - ); + let mut prev_texture_id = None; + let mut sprites = SmallVec::<[_; 1]>::new(); + let mut paths_and_tiles = paths + .iter() + .map(|path| (path, tiles_by_path_id.get(&path.id).unwrap())) + .peekable(); - // When copying paths from the intermediate texture to the drawable, - // each pixel must only be copied once, in case of transparent paths. - // - // If all paths have the same draw order, then their bounds are all - // disjoint, so we can copy each path's bounds individually. If this - // batch combines different draw orders, we perform a single copy - // for a minimal spanning rect. - let sprites; - if paths.last().unwrap().order == first_path.order { - sprites = paths - .iter() - .map(|path| PathSprite { - bounds: path.bounds, - }) - .collect(); - } else { - let mut bounds = first_path.bounds; - for path in paths.iter().skip(1) { - bounds = bounds.union(&path.bounds); + loop { + if let Some((path, tile)) = paths_and_tiles.peek() { + if prev_texture_id.map_or(true, |texture_id| texture_id == tile.texture_id) { + prev_texture_id = Some(tile.texture_id); + let origin = path.bounds.intersect(&path.content_mask.bounds).origin; + sprites.push(PathSprite { + bounds: Bounds { + origin: origin.map(|p| p.floor()), + size: tile.bounds.size.map(Into::into), + }, + color: path.color, + tile: (*tile).clone(), + }); + paths_and_tiles.next(); + continue; + } + } + + if sprites.is_empty() { + break; + } else { + align_offset(instance_offset); + let texture_id = prev_texture_id.take().unwrap(); + let texture: metal::Texture = self.sprite_atlas.metal_texture(texture_id); + let texture_size = size( + DevicePixels(texture.width() as i32), + DevicePixels(texture.height() as i32), + ); + + command_encoder.set_vertex_buffer( + SpriteInputIndex::Sprites as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + command_encoder.set_vertex_bytes( + SpriteInputIndex::AtlasTextureSize as u64, + mem::size_of_val(&texture_size) as u64, + &texture_size as *const Size<DevicePixels> as *const _, + ); + command_encoder.set_fragment_buffer( + SpriteInputIndex::Sprites as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + command_encoder + .set_fragment_texture(SpriteInputIndex::AtlasTexture as u64, Some(&texture)); + + let sprite_bytes_len = mem::size_of_val(sprites.as_slice()); + let next_offset = *instance_offset + sprite_bytes_len; + if next_offset > instance_buffer.size { + return false; + } + + let buffer_contents = unsafe { + (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) + }; + + unsafe { + ptr::copy_nonoverlapping( + sprites.as_ptr() as *const u8, + buffer_contents, + sprite_bytes_len, + ); + } + + command_encoder.draw_primitives_instanced( + metal::MTLPrimitiveType::Triangle, + 0, + 6, + sprites.len() as u64, + ); + *instance_offset = next_offset; + sprites.clear(); } - sprites = vec![PathSprite { bounds }]; } - - align_offset(instance_offset); - let sprite_bytes_len = mem::size_of_val(sprites.as_slice()); - let next_offset = *instance_offset + sprite_bytes_len; - if next_offset > instance_buffer.size { - return false; - } - - command_encoder.set_vertex_buffer( - SpriteInputIndex::Sprites as u64, - Some(&instance_buffer.metal_buffer), - *instance_offset as u64, - ); - - let buffer_contents = - unsafe { (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) }; - unsafe { - ptr::copy_nonoverlapping( - sprites.as_ptr() as *const u8, - buffer_contents, - sprite_bytes_len, - ); - } - - command_encoder.draw_primitives_instanced( - metal::MTLPrimitiveType::Triangle, - 0, - 6, - sprites.len() as u64, - ); - *instance_offset = next_offset; - true } @@ -1147,33 +1136,6 @@ impl MetalRenderer { } } -fn new_command_encoder<'a>( - command_buffer: &'a metal::CommandBufferRef, - drawable: &'a metal::MetalDrawableRef, - viewport_size: Size<DevicePixels>, - configure_color_attachment: impl Fn(&RenderPassColorAttachmentDescriptorRef), -) -> &'a metal::RenderCommandEncoderRef { - let render_pass_descriptor = metal::RenderPassDescriptor::new(); - let color_attachment = render_pass_descriptor - .color_attachments() - .object_at(0) - .unwrap(); - color_attachment.set_texture(Some(drawable.texture())); - color_attachment.set_store_action(metal::MTLStoreAction::Store); - configure_color_attachment(color_attachment); - - let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); - command_encoder.set_viewport(metal::MTLViewport { - originX: 0.0, - originY: 0.0, - width: i32::from(viewport_size.width) as f64, - height: i32::from(viewport_size.height) as f64, - znear: 0.0, - zfar: 1.0, - }); - command_encoder -} - fn build_pipeline_state( device: &metal::DeviceRef, library: &metal::LibraryRef, @@ -1208,40 +1170,6 @@ fn build_pipeline_state( .expect("could not create render pipeline state") } -fn build_path_sprite_pipeline_state( - device: &metal::DeviceRef, - library: &metal::LibraryRef, - label: &str, - vertex_fn_name: &str, - fragment_fn_name: &str, - pixel_format: metal::MTLPixelFormat, -) -> metal::RenderPipelineState { - let vertex_fn = library - .get_function(vertex_fn_name, None) - .expect("error locating vertex function"); - let fragment_fn = library - .get_function(fragment_fn_name, None) - .expect("error locating fragment function"); - - let descriptor = metal::RenderPipelineDescriptor::new(); - descriptor.set_label(label); - descriptor.set_vertex_function(Some(vertex_fn.as_ref())); - descriptor.set_fragment_function(Some(fragment_fn.as_ref())); - let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); - color_attachment.set_pixel_format(pixel_format); - color_attachment.set_blending_enabled(true); - color_attachment.set_rgb_blend_operation(metal::MTLBlendOperation::Add); - color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add); - color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One); - color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One); - color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); - color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One); - - device - .new_render_pipeline_state(&descriptor) - .expect("could not create render pipeline state") -} - fn build_path_rasterization_pipeline_state( device: &metal::DeviceRef, library: &metal::LibraryRef, @@ -1264,7 +1192,7 @@ fn build_path_rasterization_pipeline_state( descriptor.set_fragment_function(Some(fragment_fn.as_ref())); if path_sample_count > 1 { descriptor.set_raster_sample_count(path_sample_count as _); - descriptor.set_alpha_to_coverage_enabled(false); + descriptor.set_alpha_to_coverage_enabled(true); } let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); color_attachment.set_pixel_format(pixel_format); @@ -1273,8 +1201,8 @@ fn build_path_rasterization_pipeline_state( color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add); color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One); color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One); - color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); - color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); + color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One); device .new_render_pipeline_state(&descriptor) @@ -1329,13 +1257,15 @@ enum SurfaceInputIndex { #[repr(C)] enum PathRasterizationInputIndex { Vertices = 0, - ViewportSize = 1, + AtlasTextureSize = 1, } #[derive(Clone, Debug, Eq, PartialEq)] #[repr(C)] pub struct PathSprite { pub bounds: Bounds<ScaledPixels>, + pub color: Background, + pub tile: AtlasTile, } #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/crates/gpui/src/platform/mac/shaders.metal b/crates/gpui/src/platform/mac/shaders.metal index f9d5bdbf4c..64ebb1e22b 100644 --- a/crates/gpui/src/platform/mac/shaders.metal +++ b/crates/gpui/src/platform/mac/shaders.metal @@ -701,117 +701,107 @@ fragment float4 polychrome_sprite_fragment( struct PathRasterizationVertexOutput { float4 position [[position]]; float2 st_position; - uint vertex_id [[flat]]; float clip_rect_distance [[clip_distance]][4]; }; struct PathRasterizationFragmentInput { float4 position [[position]]; float2 st_position; - uint vertex_id [[flat]]; }; vertex PathRasterizationVertexOutput path_rasterization_vertex( - uint vertex_id [[vertex_id]], - constant PathRasterizationVertex *vertices [[buffer(PathRasterizationInputIndex_Vertices)]], - constant Size_DevicePixels *atlas_size [[buffer(PathRasterizationInputIndex_ViewportSize)]] -) { - PathRasterizationVertex v = vertices[vertex_id]; + uint vertex_id [[vertex_id]], + constant PathVertex_ScaledPixels *vertices + [[buffer(PathRasterizationInputIndex_Vertices)]], + constant Size_DevicePixels *atlas_size + [[buffer(PathRasterizationInputIndex_AtlasTextureSize)]]) { + PathVertex_ScaledPixels v = vertices[vertex_id]; float2 vertex_position = float2(v.xy_position.x, v.xy_position.y); - float4 position = float4( - vertex_position * float2(2. / atlas_size->width, -2. / atlas_size->height) + float2(-1., 1.), - 0., - 1. - ); + float2 viewport_size = float2(atlas_size->width, atlas_size->height); return PathRasterizationVertexOutput{ - position, + float4(vertex_position / viewport_size * float2(2., -2.) + + float2(-1., 1.), + 0., 1.), float2(v.st_position.x, v.st_position.y), - vertex_id, - { - v.xy_position.x - v.bounds.origin.x, - v.bounds.origin.x + v.bounds.size.width - v.xy_position.x, - v.xy_position.y - v.bounds.origin.y, - v.bounds.origin.y + v.bounds.size.height - v.xy_position.y - } - }; + {v.xy_position.x - v.content_mask.bounds.origin.x, + v.content_mask.bounds.origin.x + v.content_mask.bounds.size.width - + v.xy_position.x, + v.xy_position.y - v.content_mask.bounds.origin.y, + v.content_mask.bounds.origin.y + v.content_mask.bounds.size.height - + v.xy_position.y}}; } -fragment float4 path_rasterization_fragment( - PathRasterizationFragmentInput input [[stage_in]], - constant PathRasterizationVertex *vertices [[buffer(PathRasterizationInputIndex_Vertices)]] -) { +fragment float4 path_rasterization_fragment(PathRasterizationFragmentInput input + [[stage_in]]) { float2 dx = dfdx(input.st_position); float2 dy = dfdy(input.st_position); - - PathRasterizationVertex v = vertices[input.vertex_id]; - Background background = v.color; - Bounds_ScaledPixels path_bounds = v.bounds; - float alpha; - if (length(float2(dx.x, dy.x)) < 0.001) { - alpha = 1.0; - } else { - float2 gradient = float2( - (2. * input.st_position.x) * dx.x - dx.y, - (2. * input.st_position.x) * dy.x - dy.y - ); - float f = (input.st_position.x * input.st_position.x) - input.st_position.y; - float distance = f / length(gradient); - alpha = saturate(0.5 - distance); - } - - GradientColor gradient_color = prepare_fill_color( - background.tag, - background.color_space, - background.solid, - background.colors[0].color, - background.colors[1].color - ); - - float4 color = fill_color( - background, - input.position.xy, - path_bounds, - gradient_color.solid, - gradient_color.color0, - gradient_color.color1 - ); - return float4(color.rgb * color.a * alpha, alpha * color.a); + float2 gradient = float2((2. * input.st_position.x) * dx.x - dx.y, + (2. * input.st_position.x) * dy.x - dy.y); + float f = (input.st_position.x * input.st_position.x) - input.st_position.y; + float distance = f / length(gradient); + float alpha = saturate(0.5 - distance); + return float4(alpha, 0., 0., 1.); } struct PathSpriteVertexOutput { float4 position [[position]]; - float2 texture_coords; + float2 tile_position; + uint sprite_id [[flat]]; + float4 solid_color [[flat]]; + float4 color0 [[flat]]; + float4 color1 [[flat]]; }; vertex PathSpriteVertexOutput path_sprite_vertex( - uint unit_vertex_id [[vertex_id]], - uint sprite_id [[instance_id]], - constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]], - constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], - constant Size_DevicePixels *viewport_size [[buffer(SpriteInputIndex_ViewportSize)]] -) { + uint unit_vertex_id [[vertex_id]], uint sprite_id [[instance_id]], + constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]], + constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], + constant Size_DevicePixels *viewport_size + [[buffer(SpriteInputIndex_ViewportSize)]], + constant Size_DevicePixels *atlas_size + [[buffer(SpriteInputIndex_AtlasTextureSize)]]) { + float2 unit_vertex = unit_vertices[unit_vertex_id]; PathSprite sprite = sprites[sprite_id]; // Don't apply content mask because it was already accounted for when // rasterizing the path. float4 device_position = to_device_position(unit_vertex, sprite.bounds, viewport_size); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile, atlas_size); - float2 screen_position = float2(sprite.bounds.origin.x, sprite.bounds.origin.y) + unit_vertex * float2(sprite.bounds.size.width, sprite.bounds.size.height); - float2 texture_coords = screen_position / float2(viewport_size->width, viewport_size->height); + GradientColor gradient = prepare_fill_color( + sprite.color.tag, + sprite.color.color_space, + sprite.color.solid, + sprite.color.colors[0].color, + sprite.color.colors[1].color + ); return PathSpriteVertexOutput{ device_position, - texture_coords + tile_position, + sprite_id, + gradient.solid, + gradient.color0, + gradient.color1 }; } fragment float4 path_sprite_fragment( - PathSpriteVertexOutput input [[stage_in]], - texture2d<float> intermediate_texture [[texture(SpriteInputIndex_AtlasTexture)]] -) { - constexpr sampler intermediate_texture_sampler(mag_filter::linear, min_filter::linear); - return intermediate_texture.sample(intermediate_texture_sampler, input.texture_coords); + PathSpriteVertexOutput input [[stage_in]], + constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], + texture2d<float> atlas_texture [[texture(SpriteInputIndex_AtlasTexture)]]) { + constexpr sampler atlas_texture_sampler(mag_filter::linear, + min_filter::linear); + float4 sample = + atlas_texture.sample(atlas_texture_sampler, input.tile_position); + float mask = 1. - abs(1. - fmod(sample.r, 2.)); + PathSprite sprite = sprites[input.sprite_id]; + Background background = sprite.color; + float4 color = fill_color(background, input.position.xy, sprite.bounds, + input.solid_color, input.color0, input.color1); + color.a *= mask; + return color; } struct SurfaceVertexOutput { diff --git a/crates/gpui/src/platform/test/window.rs b/crates/gpui/src/platform/test/window.rs index e15bd7aeec..1b88415d3b 100644 --- a/crates/gpui/src/platform/test/window.rs +++ b/crates/gpui/src/platform/test/window.rs @@ -341,7 +341,7 @@ impl PlatformAtlas for TestAtlas { crate::AtlasTile { texture_id: AtlasTextureId { index: texture_id, - kind: crate::AtlasTextureKind::Monochrome, + kind: crate::AtlasTextureKind::Path, }, tile_id: TileId(tile_id), padding: 0, diff --git a/crates/gpui/src/platform/windows.rs b/crates/gpui/src/platform/windows.rs index 5268d3ccba..4bdf42080d 100644 --- a/crates/gpui/src/platform/windows.rs +++ b/crates/gpui/src/platform/windows.rs @@ -1,8 +1,6 @@ mod clipboard; mod destination_list; mod direct_write; -mod directx_atlas; -mod directx_renderer; mod dispatcher; mod display; mod events; @@ -16,8 +14,6 @@ mod wrapper; pub(crate) use clipboard::*; pub(crate) use destination_list::*; pub(crate) use direct_write::*; -pub(crate) use directx_atlas::*; -pub(crate) use directx_renderer::*; pub(crate) use dispatcher::*; pub(crate) use display::*; pub(crate) use events::*; diff --git a/crates/gpui/src/platform/windows/color_text_raster.hlsl b/crates/gpui/src/platform/windows/color_text_raster.hlsl deleted file mode 100644 index ccc5fa26f0..0000000000 --- a/crates/gpui/src/platform/windows/color_text_raster.hlsl +++ /dev/null @@ -1,39 +0,0 @@ -struct RasterVertexOutput { - float4 position : SV_Position; - float2 texcoord : TEXCOORD0; -}; - -RasterVertexOutput emoji_rasterization_vertex(uint vertexID : SV_VERTEXID) -{ - RasterVertexOutput output; - output.texcoord = float2((vertexID << 1) & 2, vertexID & 2); - output.position = float4(output.texcoord * 2.0f - 1.0f, 0.0f, 1.0f); - output.position.y = -output.position.y; - - return output; -} - -struct PixelInput { - float4 position: SV_Position; - float2 texcoord : TEXCOORD0; -}; - -struct Bounds { - int2 origin; - int2 size; -}; - -Texture2D<float4> t_layer : register(t0); -SamplerState s_layer : register(s0); - -cbuffer GlyphLayerTextureParams : register(b0) { - Bounds bounds; - float4 run_color; -}; - -float4 emoji_rasterization_fragment(PixelInput input): SV_Target { - float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb; - float alpha = (sampled.r + sampled.g + sampled.b) / 3; - - return float4(run_color.rgb, alpha); -} diff --git a/crates/gpui/src/platform/windows/direct_write.rs b/crates/gpui/src/platform/windows/direct_write.rs index 587cb7b4a6..ada306c15c 100644 --- a/crates/gpui/src/platform/windows/direct_write.rs +++ b/crates/gpui/src/platform/windows/direct_write.rs @@ -10,11 +10,10 @@ use windows::{ Foundation::*, Globalization::GetUserDefaultLocaleName, Graphics::{ - Direct3D::D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, - Direct3D11::*, + Direct2D::{Common::*, *}, DirectWrite::*, Dxgi::Common::*, - Gdi::{IsRectEmpty, LOGFONTW}, + Gdi::LOGFONTW, Imaging::*, }, System::SystemServices::LOCALE_NAME_MAX_LENGTH, @@ -41,21 +40,16 @@ struct DirectWriteComponent { locale: String, factory: IDWriteFactory5, bitmap_factory: AgileReference<IWICImagingFactory>, + d2d1_factory: ID2D1Factory, in_memory_loader: IDWriteInMemoryFontFileLoader, builder: IDWriteFontSetBuilder1, text_renderer: Arc<TextRendererWrapper>, - - render_params: IDWriteRenderingParams3, - gpu_state: GPUState, + render_context: GlyphRenderContext, } -struct GPUState { - device: ID3D11Device, - device_context: ID3D11DeviceContext, - sampler: [Option<ID3D11SamplerState>; 1], - blend_state: ID3D11BlendState, - vertex_shader: ID3D11VertexShader, - pixel_shader: ID3D11PixelShader, +struct GlyphRenderContext { + params: IDWriteRenderingParams3, + dc_target: ID2D1DeviceContext4, } struct DirectWriteState { @@ -76,11 +70,12 @@ struct FontIdentifier { } impl DirectWriteComponent { - pub fn new(bitmap_factory: &IWICImagingFactory, gpu_context: &DirectXDevices) -> Result<Self> { - // todo: ideally this would not be a large unsafe block but smaller isolated ones for easier auditing + pub fn new(bitmap_factory: &IWICImagingFactory) -> Result<Self> { unsafe { let factory: IDWriteFactory5 = DWriteCreateFactory(DWRITE_FACTORY_TYPE_SHARED)?; let bitmap_factory = AgileReference::new(bitmap_factory)?; + let d2d1_factory: ID2D1Factory = + D2D1CreateFactory(D2D1_FACTORY_TYPE_MULTI_THREADED, None)?; // The `IDWriteInMemoryFontFileLoader` here is supported starting from // Windows 10 Creators Update, which consequently requires the entire // `DirectWriteTextSystem` to run on `win10 1703`+. @@ -91,132 +86,60 @@ impl DirectWriteComponent { GetUserDefaultLocaleName(&mut locale_vec); let locale = String::from_utf16_lossy(&locale_vec); let text_renderer = Arc::new(TextRendererWrapper::new(&locale)); - - let render_params = { - let default_params: IDWriteRenderingParams3 = - factory.CreateRenderingParams()?.cast()?; - let gamma = default_params.GetGamma(); - let enhanced_contrast = default_params.GetEnhancedContrast(); - let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); - let cleartype_level = default_params.GetClearTypeLevel(); - let grid_fit_mode = default_params.GetGridFitMode(); - - factory.CreateCustomRenderingParams( - gamma, - enhanced_contrast, - gray_contrast, - cleartype_level, - DWRITE_PIXEL_GEOMETRY_RGB, - DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, - grid_fit_mode, - )? - }; - - let gpu_state = GPUState::new(gpu_context)?; + let render_context = GlyphRenderContext::new(&factory, &d2d1_factory)?; Ok(DirectWriteComponent { locale, factory, bitmap_factory, + d2d1_factory, in_memory_loader, builder, text_renderer, - render_params, - gpu_state, + render_context, }) } } } -impl GPUState { - fn new(gpu_context: &DirectXDevices) -> Result<Self> { - let device = gpu_context.device.clone(); - let device_context = gpu_context.device_context.clone(); +impl GlyphRenderContext { + pub fn new(factory: &IDWriteFactory5, d2d1_factory: &ID2D1Factory) -> Result<Self> { + unsafe { + let default_params: IDWriteRenderingParams3 = + factory.CreateRenderingParams()?.cast()?; + let gamma = default_params.GetGamma(); + let enhanced_contrast = default_params.GetEnhancedContrast(); + let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); + let cleartype_level = default_params.GetClearTypeLevel(); + let grid_fit_mode = default_params.GetGridFitMode(); - let blend_state = { - let mut blend_state = None; - let desc = D3D11_BLEND_DESC { - AlphaToCoverageEnable: false.into(), - IndependentBlendEnable: false.into(), - RenderTarget: [ - D3D11_RENDER_TARGET_BLEND_DESC { - BlendEnable: true.into(), - SrcBlend: D3D11_BLEND_SRC_ALPHA, - DestBlend: D3D11_BLEND_INV_SRC_ALPHA, - BlendOp: D3D11_BLEND_OP_ADD, - SrcBlendAlpha: D3D11_BLEND_SRC_ALPHA, - DestBlendAlpha: D3D11_BLEND_INV_SRC_ALPHA, - BlendOpAlpha: D3D11_BLEND_OP_ADD, - RenderTargetWriteMask: D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8, - }, - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - Default::default(), - ], - }; - unsafe { device.CreateBlendState(&desc, Some(&mut blend_state)) }?; - blend_state.unwrap() - }; - - let sampler = { - let mut sampler = None; - let desc = D3D11_SAMPLER_DESC { - Filter: D3D11_FILTER_MIN_MAG_MIP_POINT, - AddressU: D3D11_TEXTURE_ADDRESS_BORDER, - AddressV: D3D11_TEXTURE_ADDRESS_BORDER, - AddressW: D3D11_TEXTURE_ADDRESS_BORDER, - MipLODBias: 0.0, - MaxAnisotropy: 1, - ComparisonFunc: D3D11_COMPARISON_ALWAYS, - BorderColor: [0.0, 0.0, 0.0, 0.0], - MinLOD: 0.0, - MaxLOD: 0.0, - }; - unsafe { device.CreateSamplerState(&desc, Some(&mut sampler)) }?; - [sampler] - }; - - let vertex_shader = { - let source = shader_resources::RawShaderBytes::new( - shader_resources::ShaderModule::EmojiRasterization, - shader_resources::ShaderTarget::Vertex, + let params = factory.CreateCustomRenderingParams( + gamma, + enhanced_contrast, + gray_contrast, + cleartype_level, + DWRITE_PIXEL_GEOMETRY_RGB, + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + grid_fit_mode, )?; - let mut shader = None; - unsafe { device.CreateVertexShader(source.as_bytes(), None, Some(&mut shader)) }?; - shader.unwrap() - }; + let dc_target = { + let target = d2d1_factory.CreateDCRenderTarget(&get_render_target_property( + DXGI_FORMAT_B8G8R8A8_UNORM, + D2D1_ALPHA_MODE_PREMULTIPLIED, + ))?; + let target = target.cast::<ID2D1DeviceContext4>()?; + target.SetTextRenderingParams(¶ms); + target + }; - let pixel_shader = { - let source = shader_resources::RawShaderBytes::new( - shader_resources::ShaderModule::EmojiRasterization, - shader_resources::ShaderTarget::Fragment, - )?; - let mut shader = None; - unsafe { device.CreatePixelShader(source.as_bytes(), None, Some(&mut shader)) }?; - shader.unwrap() - }; - - Ok(Self { - device, - device_context, - sampler, - blend_state, - vertex_shader, - pixel_shader, - }) + Ok(Self { params, dc_target }) + } } } impl DirectWriteTextSystem { - pub(crate) fn new( - gpu_context: &DirectXDevices, - bitmap_factory: &IWICImagingFactory, - ) -> Result<Self> { - let components = DirectWriteComponent::new(bitmap_factory, gpu_context)?; + pub(crate) fn new(bitmap_factory: &IWICImagingFactory) -> Result<Self> { + let components = DirectWriteComponent::new(bitmap_factory)?; let system_font_collection = unsafe { let mut result = std::mem::zeroed(); components @@ -725,13 +648,15 @@ impl DirectWriteState { } } - fn create_glyph_run_analysis( - &self, - params: &RenderGlyphParams, - ) -> Result<IDWriteGlyphRunAnalysis> { + fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> { + let render_target = &self.components.render_context.dc_target; + unsafe { + render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); + render_target.SetDpi(96.0 * params.scale_factor, 96.0 * params.scale_factor); + } let font = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16]; - let advance = [0.0]; + let advance = [0.0f32]; let offset = [DWRITE_GLYPH_OFFSET::default()]; let glyph_run = DWRITE_GLYPH_RUN { fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, @@ -743,87 +668,44 @@ impl DirectWriteState { isSideways: BOOL(0), bidiLevel: 0, }; - let transform = DWRITE_MATRIX { - m11: params.scale_factor, - m12: 0.0, - m21: 0.0, - m22: params.scale_factor, - dx: 0.0, - dy: 0.0, - }; - let subpixel_shift = params - .subpixel_variant - .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); - let baseline_origin_x = subpixel_shift.x / params.scale_factor; - let baseline_origin_y = subpixel_shift.y / params.scale_factor; - - let mut rendering_mode = DWRITE_RENDERING_MODE1::default(); - let mut grid_fit_mode = DWRITE_GRID_FIT_MODE::default(); - unsafe { - font.font_face.GetRecommendedRenderingMode( - params.font_size.0, - // The dpi here seems that it has the same effect with `Some(&transform)` - 1.0, - 1.0, - Some(&transform), - false, - DWRITE_OUTLINE_THRESHOLD_ANTIALIASED, + let bounds = unsafe { + render_target.GetGlyphRunWorldBounds( + Vector2 { X: 0.0, Y: 0.0 }, + &glyph_run, DWRITE_MEASURING_MODE_NATURAL, - &self.components.render_params, - &mut rendering_mode, - &mut grid_fit_mode, - )?; + )? + }; + // todo(windows) + // This is a walkaround, deleted when figured out. + let y_offset; + let extra_height; + if params.is_emoji { + y_offset = 0; + extra_height = 0; + } else { + // make some room for scaler. + y_offset = -1; + extra_height = 2; } - let glyph_analysis = unsafe { - self.components.factory.CreateGlyphRunAnalysis( - &glyph_run, - Some(&transform), - rendering_mode, - DWRITE_MEASURING_MODE_NATURAL, - grid_fit_mode, - // We're using cleartype not grayscale for monochrome is because it provides better quality - DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, - baseline_origin_x, - baseline_origin_y, - ) - }?; - Ok(glyph_analysis) - } - - fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> { - let glyph_analysis = self.create_glyph_run_analysis(params)?; - - let bounds = unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1)? }; - // Some glyphs cannot be drawn with ClearType, such as bitmap fonts. In that case - // GetAlphaTextureBounds() supposedly returns an empty RECT, but I haven't tested that yet. - if !unsafe { IsRectEmpty(&bounds) }.as_bool() { + if bounds.right < bounds.left { Ok(Bounds { - origin: point(bounds.left.into(), bounds.top.into()), - size: size( - (bounds.right - bounds.left).into(), - (bounds.bottom - bounds.top).into(), - ), + origin: point(0.into(), 0.into()), + size: size(0.into(), 0.into()), }) } else { - // If it's empty, retry with grayscale AA. - let bounds = - unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1)? }; - - if bounds.right < bounds.left { - Ok(Bounds { - origin: point(0.into(), 0.into()), - size: size(0.into(), 0.into()), - }) - } else { - Ok(Bounds { - origin: point(bounds.left.into(), bounds.top.into()), - size: size( - (bounds.right - bounds.left).into(), - (bounds.bottom - bounds.top).into(), - ), - }) - } + Ok(Bounds { + origin: point( + ((bounds.left * params.scale_factor).ceil() as i32).into(), + ((bounds.top * params.scale_factor).ceil() as i32 + y_offset).into(), + ), + size: size( + (((bounds.right - bounds.left) * params.scale_factor).ceil() as i32).into(), + (((bounds.bottom - bounds.top) * params.scale_factor).ceil() as i32 + + extra_height) + .into(), + ), + }) } } @@ -849,95 +731,7 @@ impl DirectWriteState { anyhow::bail!("glyph bounds are empty"); } - let bitmap_data = if params.is_emoji { - if let Ok(color) = self.rasterize_color(¶ms, glyph_bounds) { - color - } else { - let monochrome = self.rasterize_monochrome(params, glyph_bounds)?; - monochrome - .into_iter() - .flat_map(|pixel| [0, 0, 0, pixel]) - .collect::<Vec<_>>() - } - } else { - self.rasterize_monochrome(params, glyph_bounds)? - }; - - Ok((glyph_bounds.size, bitmap_data)) - } - - fn rasterize_monochrome( - &self, - params: &RenderGlyphParams, - glyph_bounds: Bounds<DevicePixels>, - ) -> Result<Vec<u8>> { - let mut bitmap_data = - vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize * 3]; - - let glyph_analysis = self.create_glyph_run_analysis(params)?; - unsafe { - glyph_analysis.CreateAlphaTexture( - // We're using cleartype not grayscale for monochrome is because it provides better quality - DWRITE_TEXTURE_CLEARTYPE_3x1, - &RECT { - left: glyph_bounds.origin.x.0, - top: glyph_bounds.origin.y.0, - right: glyph_bounds.size.width.0 + glyph_bounds.origin.x.0, - bottom: glyph_bounds.size.height.0 + glyph_bounds.origin.y.0, - }, - &mut bitmap_data, - )?; - } - - let bitmap_factory = self.components.bitmap_factory.resolve()?; - let bitmap = unsafe { - bitmap_factory.CreateBitmapFromMemory( - glyph_bounds.size.width.0 as u32, - glyph_bounds.size.height.0 as u32, - &GUID_WICPixelFormat24bppRGB, - glyph_bounds.size.width.0 as u32 * 3, - &bitmap_data, - ) - }?; - - let grayscale_bitmap = - unsafe { WICConvertBitmapSource(&GUID_WICPixelFormat8bppGray, &bitmap) }?; - - let mut bitmap_data = - vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize]; - unsafe { - grayscale_bitmap.CopyPixels( - std::ptr::null() as _, - glyph_bounds.size.width.0 as u32, - &mut bitmap_data, - ) - }?; - - Ok(bitmap_data) - } - - fn rasterize_color( - &self, - params: &RenderGlyphParams, - glyph_bounds: Bounds<DevicePixels>, - ) -> Result<Vec<u8>> { - let bitmap_size = glyph_bounds.size; - let subpixel_shift = params - .subpixel_variant - .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); - let baseline_origin_x = subpixel_shift.x / params.scale_factor; - let baseline_origin_y = subpixel_shift.y / params.scale_factor; - - let transform = DWRITE_MATRIX { - m11: params.scale_factor, - m12: 0.0, - m21: 0.0, - m22: params.scale_factor, - dx: 0.0, - dy: 0.0, - }; - - let font = &self.fonts[params.font_id.0]; + let font_info = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16]; let advance = [glyph_bounds.size.width.0 as f32]; let offset = [DWRITE_GLYPH_OFFSET { @@ -945,7 +739,7 @@ impl DirectWriteState { ascenderOffset: glyph_bounds.origin.y.0 as f32 / params.scale_factor, }]; let glyph_run = DWRITE_GLYPH_RUN { - fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, + fontFace: unsafe { std::mem::transmute_copy(&font_info.font_face) }, fontEmSize: params.font_size.0, glyphCount: 1, glyphIndices: glyph_id.as_ptr(), @@ -955,254 +749,160 @@ impl DirectWriteState { bidiLevel: 0, }; - // todo: support formats other than COLR - let color_enumerator = unsafe { - self.components.factory.TranslateColorGlyphRun( - Vector2::new(baseline_origin_x, baseline_origin_y), - &glyph_run, - None, - DWRITE_GLYPH_IMAGE_FORMATS_COLR, - DWRITE_MEASURING_MODE_NATURAL, - Some(&transform), - 0, - ) - }?; + // Add an extra pixel when the subpixel variant isn't zero to make room for anti-aliasing. + let mut bitmap_size = glyph_bounds.size; + if params.subpixel_variant.x > 0 { + bitmap_size.width += DevicePixels(1); + } + if params.subpixel_variant.y > 0 { + bitmap_size.height += DevicePixels(1); + } + let bitmap_size = bitmap_size; - let mut glyph_layers = Vec::new(); - loop { - let color_run = unsafe { color_enumerator.GetCurrentRun() }?; - let color_run = unsafe { &*color_run }; - let image_format = color_run.glyphImageFormat & !DWRITE_GLYPH_IMAGE_FORMATS_TRUETYPE; - if image_format == DWRITE_GLYPH_IMAGE_FORMATS_COLR { - let color_analysis = unsafe { - self.components.factory.CreateGlyphRunAnalysis( - &color_run.Base.glyphRun as *const _, - Some(&transform), - DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, - DWRITE_MEASURING_MODE_NATURAL, - DWRITE_GRID_FIT_MODE_DEFAULT, - DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, - baseline_origin_x, - baseline_origin_y, - ) - }?; - - let color_bounds = - unsafe { color_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1) }?; - - let color_size = size( - color_bounds.right - color_bounds.left, - color_bounds.bottom - color_bounds.top, - ); - if color_size.width > 0 && color_size.height > 0 { - let mut alpha_data = - vec![0u8; (color_size.width * color_size.height * 3) as usize]; - unsafe { - color_analysis.CreateAlphaTexture( - DWRITE_TEXTURE_CLEARTYPE_3x1, - &color_bounds, - &mut alpha_data, - ) - }?; - - let run_color = { - let run_color = color_run.Base.runColor; - Rgba { - r: run_color.r, - g: run_color.g, - b: run_color.b, - a: run_color.a, - } - }; - let bounds = bounds(point(color_bounds.left, color_bounds.top), color_size); - let alpha_data = alpha_data - .chunks_exact(3) - .flat_map(|chunk| [chunk[0], chunk[1], chunk[2], 255]) - .collect::<Vec<_>>(); - glyph_layers.push(GlyphLayerTexture::new( - &self.components.gpu_state, - run_color, - bounds, - &alpha_data, - )?); - } - } - - let has_next = unsafe { color_enumerator.MoveNext() } - .map(|e| e.as_bool()) - .unwrap_or(false); - if !has_next { - break; - } + let total_bytes; + let bitmap_format; + let render_target_property; + let bitmap_width; + let bitmap_height; + let bitmap_stride; + let bitmap_dpi; + if params.is_emoji { + total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize * 4; + bitmap_format = &GUID_WICPixelFormat32bppPBGRA; + render_target_property = get_render_target_property( + DXGI_FORMAT_B8G8R8A8_UNORM, + D2D1_ALPHA_MODE_PREMULTIPLIED, + ); + bitmap_width = bitmap_size.width.0 as u32; + bitmap_height = bitmap_size.height.0 as u32; + bitmap_stride = bitmap_size.width.0 as u32 * 4; + bitmap_dpi = 96.0; + } else { + total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize; + bitmap_format = &GUID_WICPixelFormat8bppAlpha; + render_target_property = + get_render_target_property(DXGI_FORMAT_A8_UNORM, D2D1_ALPHA_MODE_STRAIGHT); + bitmap_width = bitmap_size.width.0 as u32 * 2; + bitmap_height = bitmap_size.height.0 as u32 * 2; + bitmap_stride = bitmap_size.width.0 as u32; + bitmap_dpi = 192.0; } - let gpu_state = &self.components.gpu_state; - let params_buffer = { - let desc = D3D11_BUFFER_DESC { - ByteWidth: std::mem::size_of::<GlyphLayerTextureParams>() as u32, - Usage: D3D11_USAGE_DYNAMIC, - BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, - CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, - MiscFlags: 0, - StructureByteStride: 0, + let bitmap_factory = self.components.bitmap_factory.resolve()?; + unsafe { + let bitmap = bitmap_factory.CreateBitmap( + bitmap_width, + bitmap_height, + bitmap_format, + WICBitmapCacheOnLoad, + )?; + let render_target = self + .components + .d2d1_factory + .CreateWicBitmapRenderTarget(&bitmap, &render_target_property)?; + let brush = render_target.CreateSolidColorBrush(&BRUSH_COLOR, None)?; + let subpixel_shift = params + .subpixel_variant + .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); + let baseline_origin = Vector2 { + X: subpixel_shift.x / params.scale_factor, + Y: subpixel_shift.y / params.scale_factor, }; - let mut buffer = None; - unsafe { - gpu_state - .device - .CreateBuffer(&desc, None, Some(&mut buffer)) - }?; - [buffer] - }; + // This `cast()` action here should never fail since we are running on Win10+, and + // ID2D1DeviceContext4 requires Win8+ + let render_target = render_target.cast::<ID2D1DeviceContext4>().unwrap(); + render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); + render_target.SetDpi( + bitmap_dpi * params.scale_factor, + bitmap_dpi * params.scale_factor, + ); + render_target.SetTextRenderingParams(&self.components.render_context.params); + render_target.BeginDraw(); - let render_target_texture = { - let mut texture = None; - let desc = D3D11_TEXTURE2D_DESC { - Width: bitmap_size.width.0 as u32, - Height: bitmap_size.height.0 as u32, - MipLevels: 1, - ArraySize: 1, - Format: DXGI_FORMAT_B8G8R8A8_UNORM, - SampleDesc: DXGI_SAMPLE_DESC { - Count: 1, - Quality: 0, - }, - Usage: D3D11_USAGE_DEFAULT, - BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, - CPUAccessFlags: 0, - MiscFlags: 0, - }; - unsafe { - gpu_state - .device - .CreateTexture2D(&desc, None, Some(&mut texture)) - }?; - texture.unwrap() - }; - - let render_target_view = { - let desc = D3D11_RENDER_TARGET_VIEW_DESC { - Format: DXGI_FORMAT_B8G8R8A8_UNORM, - ViewDimension: D3D11_RTV_DIMENSION_TEXTURE2D, - Anonymous: D3D11_RENDER_TARGET_VIEW_DESC_0 { - Texture2D: D3D11_TEX2D_RTV { MipSlice: 0 }, - }, - }; - let mut rtv = None; - unsafe { - gpu_state.device.CreateRenderTargetView( - &render_target_texture, - Some(&desc), - Some(&mut rtv), - ) - }?; - [rtv] - }; - - let staging_texture = { - let mut texture = None; - let desc = D3D11_TEXTURE2D_DESC { - Width: bitmap_size.width.0 as u32, - Height: bitmap_size.height.0 as u32, - MipLevels: 1, - ArraySize: 1, - Format: DXGI_FORMAT_B8G8R8A8_UNORM, - SampleDesc: DXGI_SAMPLE_DESC { - Count: 1, - Quality: 0, - }, - Usage: D3D11_USAGE_STAGING, - BindFlags: 0, - CPUAccessFlags: D3D11_CPU_ACCESS_READ.0 as u32, - MiscFlags: 0, - }; - unsafe { - gpu_state - .device - .CreateTexture2D(&desc, None, Some(&mut texture)) - }?; - texture.unwrap() - }; - - let device_context = &gpu_state.device_context; - unsafe { device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP) }; - unsafe { device_context.VSSetShader(&gpu_state.vertex_shader, None) }; - unsafe { device_context.PSSetShader(&gpu_state.pixel_shader, None) }; - unsafe { device_context.VSSetConstantBuffers(0, Some(¶ms_buffer)) }; - unsafe { device_context.PSSetConstantBuffers(0, Some(¶ms_buffer)) }; - unsafe { device_context.OMSetRenderTargets(Some(&render_target_view), None) }; - unsafe { device_context.PSSetSamplers(0, Some(&gpu_state.sampler)) }; - unsafe { device_context.OMSetBlendState(&gpu_state.blend_state, None, 0xffffffff) }; - - for layer in glyph_layers { - let params = GlyphLayerTextureParams { - run_color: layer.run_color, - bounds: layer.bounds, - }; - unsafe { - let mut dest = std::mem::zeroed(); - gpu_state.device_context.Map( - params_buffer[0].as_ref().unwrap(), + if params.is_emoji { + // WARN: only DWRITE_GLYPH_IMAGE_FORMATS_COLR has been tested + let enumerator = self.components.factory.TranslateColorGlyphRun( + baseline_origin, + &glyph_run as _, + None, + DWRITE_GLYPH_IMAGE_FORMATS_COLR + | DWRITE_GLYPH_IMAGE_FORMATS_SVG + | DWRITE_GLYPH_IMAGE_FORMATS_PNG + | DWRITE_GLYPH_IMAGE_FORMATS_JPEG + | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8, + DWRITE_MEASURING_MODE_NATURAL, + None, 0, - D3D11_MAP_WRITE_DISCARD, - 0, - Some(&mut dest), )?; - std::ptr::copy_nonoverlapping(¶ms as *const _, dest.pData as *mut _, 1); - gpu_state - .device_context - .Unmap(params_buffer[0].as_ref().unwrap(), 0); - }; + while enumerator.MoveNext().is_ok() { + let Ok(color_glyph) = enumerator.GetCurrentRun() else { + break; + }; + let color_glyph = &*color_glyph; + let brush_color = translate_color(&color_glyph.Base.runColor); + brush.SetColor(&brush_color); + match color_glyph.glyphImageFormat { + DWRITE_GLYPH_IMAGE_FORMATS_PNG + | DWRITE_GLYPH_IMAGE_FORMATS_JPEG + | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8 => render_target + .DrawColorBitmapGlyphRun( + color_glyph.glyphImageFormat, + baseline_origin, + &color_glyph.Base.glyphRun, + color_glyph.measuringMode, + D2D1_COLOR_BITMAP_GLYPH_SNAP_OPTION_DEFAULT, + ), + DWRITE_GLYPH_IMAGE_FORMATS_SVG => render_target.DrawSvgGlyphRun( + baseline_origin, + &color_glyph.Base.glyphRun, + &brush, + None, + color_glyph.Base.paletteIndex as u32, + color_glyph.measuringMode, + ), + _ => render_target.DrawGlyphRun( + baseline_origin, + &color_glyph.Base.glyphRun, + Some(color_glyph.Base.glyphRunDescription as *const _), + &brush, + color_glyph.measuringMode, + ), + } + } + } else { + render_target.DrawGlyphRun( + baseline_origin, + &glyph_run, + None, + &brush, + DWRITE_MEASURING_MODE_NATURAL, + ); + } + render_target.EndDraw(None, None)?; - let texture = [Some(layer.texture_view)]; - unsafe { device_context.PSSetShaderResources(0, Some(&texture)) }; - - let viewport = [D3D11_VIEWPORT { - TopLeftX: layer.bounds.origin.x as f32, - TopLeftY: layer.bounds.origin.y as f32, - Width: layer.bounds.size.width as f32, - Height: layer.bounds.size.height as f32, - MinDepth: 0.0, - MaxDepth: 1.0, - }]; - unsafe { device_context.RSSetViewports(Some(&viewport)) }; - - unsafe { device_context.Draw(4, 0) }; + let mut raw_data = vec![0u8; total_bytes]; + if params.is_emoji { + bitmap.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; + // Convert from BGRA with premultiplied alpha to BGRA with straight alpha. + for pixel in raw_data.chunks_exact_mut(4) { + let a = pixel[3] as f32 / 255.; + pixel[0] = (pixel[0] as f32 / a) as u8; + pixel[1] = (pixel[1] as f32 / a) as u8; + pixel[2] = (pixel[2] as f32 / a) as u8; + } + } else { + let scaler = bitmap_factory.CreateBitmapScaler()?; + scaler.Initialize( + &bitmap, + bitmap_size.width.0 as u32, + bitmap_size.height.0 as u32, + WICBitmapInterpolationModeHighQualityCubic, + )?; + scaler.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; + } + Ok((bitmap_size, raw_data)) } - - unsafe { device_context.CopyResource(&staging_texture, &render_target_texture) }; - - let mapped_data = { - let mut mapped_data = D3D11_MAPPED_SUBRESOURCE::default(); - unsafe { - device_context.Map( - &staging_texture, - 0, - D3D11_MAP_READ, - 0, - Some(&mut mapped_data), - ) - }?; - mapped_data - }; - let mut rasterized = - vec![0u8; (bitmap_size.width.0 as u32 * bitmap_size.height.0 as u32 * 4) as usize]; - - for y in 0..bitmap_size.height.0 as usize { - let width = bitmap_size.width.0 as usize; - unsafe { - std::ptr::copy_nonoverlapping::<u8>( - (mapped_data.pData as *const u8).byte_add(mapped_data.RowPitch as usize * y), - rasterized - .as_mut_ptr() - .byte_add(width * y * std::mem::size_of::<u32>()), - width * std::mem::size_of::<u32>(), - ) - }; - } - - Ok(rasterized) } fn get_typographic_bounds(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Bounds<f32>> { @@ -1276,84 +976,6 @@ impl Drop for DirectWriteState { } } -struct GlyphLayerTexture { - run_color: Rgba, - bounds: Bounds<i32>, - texture_view: ID3D11ShaderResourceView, - // holding on to the texture to not RAII drop it - _texture: ID3D11Texture2D, -} - -impl GlyphLayerTexture { - pub fn new( - gpu_state: &GPUState, - run_color: Rgba, - bounds: Bounds<i32>, - alpha_data: &[u8], - ) -> Result<Self> { - let texture_size = bounds.size; - - let desc = D3D11_TEXTURE2D_DESC { - Width: texture_size.width as u32, - Height: texture_size.height as u32, - MipLevels: 1, - ArraySize: 1, - Format: DXGI_FORMAT_R8G8B8A8_UNORM, - SampleDesc: DXGI_SAMPLE_DESC { - Count: 1, - Quality: 0, - }, - Usage: D3D11_USAGE_DEFAULT, - BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, - CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, - MiscFlags: 0, - }; - - let texture = { - let mut texture: Option<ID3D11Texture2D> = None; - unsafe { - gpu_state - .device - .CreateTexture2D(&desc, None, Some(&mut texture))? - }; - texture.unwrap() - }; - let texture_view = { - let mut view: Option<ID3D11ShaderResourceView> = None; - unsafe { - gpu_state - .device - .CreateShaderResourceView(&texture, None, Some(&mut view))? - }; - view.unwrap() - }; - - unsafe { - gpu_state.device_context.UpdateSubresource( - &texture, - 0, - None, - alpha_data.as_ptr() as _, - (texture_size.width * 4) as u32, - 0, - ) - }; - - Ok(GlyphLayerTexture { - run_color, - bounds, - texture_view, - _texture: texture, - }) - } -} - -#[repr(C)] -struct GlyphLayerTextureParams { - bounds: Bounds<i32>, - run_color: Rgba, -} - struct TextRendererWrapper(pub IDWriteTextRenderer); impl TextRendererWrapper { @@ -1848,6 +1470,16 @@ fn get_name(string: IDWriteLocalizedStrings, locale: &str) -> Result<String> { Ok(String::from_utf16_lossy(&name_vec[..name_length])) } +#[inline] +fn translate_color(color: &DWRITE_COLOR_F) -> D2D1_COLOR_F { + D2D1_COLOR_F { + r: color.r, + g: color.g, + b: color.b, + a: color.a, + } +} + fn get_system_ui_font_name() -> SharedString { unsafe { let mut info: LOGFONTW = std::mem::zeroed(); @@ -1872,6 +1504,24 @@ fn get_system_ui_font_name() -> SharedString { } } +#[inline] +fn get_render_target_property( + pixel_format: DXGI_FORMAT, + alpha_mode: D2D1_ALPHA_MODE, +) -> D2D1_RENDER_TARGET_PROPERTIES { + D2D1_RENDER_TARGET_PROPERTIES { + r#type: D2D1_RENDER_TARGET_TYPE_DEFAULT, + pixelFormat: D2D1_PIXEL_FORMAT { + format: pixel_format, + alphaMode: alpha_mode, + }, + dpiX: 96.0, + dpiY: 96.0, + usage: D2D1_RENDER_TARGET_USAGE_NONE, + minLevel: D2D1_FEATURE_LEVEL_DEFAULT, + } +} + // One would think that with newer DirectWrite method: IDWriteFontFace4::GetGlyphImageFormats // but that doesn't seem to work for some glyphs, say ❤ fn is_color_glyph( @@ -1911,6 +1561,12 @@ fn is_color_glyph( } const DEFAULT_LOCALE_NAME: PCWSTR = windows::core::w!("en-US"); +const BRUSH_COLOR: D2D1_COLOR_F = D2D1_COLOR_F { + r: 1.0, + g: 1.0, + b: 1.0, + a: 1.0, +}; #[cfg(test)] mod tests { diff --git a/crates/gpui/src/platform/windows/directx_atlas.rs b/crates/gpui/src/platform/windows/directx_atlas.rs deleted file mode 100644 index 6bced4c11d..0000000000 --- a/crates/gpui/src/platform/windows/directx_atlas.rs +++ /dev/null @@ -1,309 +0,0 @@ -use collections::FxHashMap; -use etagere::BucketedAtlasAllocator; -use parking_lot::Mutex; -use windows::Win32::Graphics::{ - Direct3D11::{ - D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_CPU_ACCESS_WRITE, D3D11_TEXTURE2D_DESC, - D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView, - ID3D11Texture2D, - }, - Dxgi::Common::*, -}; - -use crate::{ - AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, DevicePixels, PlatformAtlas, - Point, Size, platform::AtlasTextureList, -}; - -pub(crate) struct DirectXAtlas(Mutex<DirectXAtlasState>); - -struct DirectXAtlasState { - device: ID3D11Device, - device_context: ID3D11DeviceContext, - monochrome_textures: AtlasTextureList<DirectXAtlasTexture>, - polychrome_textures: AtlasTextureList<DirectXAtlasTexture>, - tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, -} - -struct DirectXAtlasTexture { - id: AtlasTextureId, - bytes_per_pixel: u32, - allocator: BucketedAtlasAllocator, - texture: ID3D11Texture2D, - view: [Option<ID3D11ShaderResourceView>; 1], - live_atlas_keys: u32, -} - -impl DirectXAtlas { - pub(crate) fn new(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Self { - DirectXAtlas(Mutex::new(DirectXAtlasState { - device: device.clone(), - device_context: device_context.clone(), - monochrome_textures: Default::default(), - polychrome_textures: Default::default(), - tiles_by_key: Default::default(), - })) - } - - pub(crate) fn get_texture_view( - &self, - id: AtlasTextureId, - ) -> [Option<ID3D11ShaderResourceView>; 1] { - let lock = self.0.lock(); - let tex = lock.texture(id); - tex.view.clone() - } - - pub(crate) fn handle_device_lost( - &self, - device: &ID3D11Device, - device_context: &ID3D11DeviceContext, - ) { - let mut lock = self.0.lock(); - lock.device = device.clone(); - lock.device_context = device_context.clone(); - lock.monochrome_textures = AtlasTextureList::default(); - lock.polychrome_textures = AtlasTextureList::default(); - lock.tiles_by_key.clear(); - } -} - -impl PlatformAtlas for DirectXAtlas { - fn get_or_insert_with<'a>( - &self, - key: &AtlasKey, - build: &mut dyn FnMut() -> anyhow::Result< - Option<(Size<DevicePixels>, std::borrow::Cow<'a, [u8]>)>, - >, - ) -> anyhow::Result<Option<AtlasTile>> { - let mut lock = self.0.lock(); - if let Some(tile) = lock.tiles_by_key.get(key) { - Ok(Some(tile.clone())) - } else { - let Some((size, bytes)) = build()? else { - return Ok(None); - }; - let tile = lock - .allocate(size, key.texture_kind()) - .ok_or_else(|| anyhow::anyhow!("failed to allocate"))?; - let texture = lock.texture(tile.texture_id); - texture.upload(&lock.device_context, tile.bounds, &bytes); - lock.tiles_by_key.insert(key.clone(), tile.clone()); - Ok(Some(tile)) - } - } - - fn remove(&self, key: &AtlasKey) { - let mut lock = self.0.lock(); - - let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else { - return; - }; - - let textures = match id.kind { - AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, - AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, - }; - - let Some(texture_slot) = textures.textures.get_mut(id.index as usize) else { - return; - }; - - if let Some(mut texture) = texture_slot.take() { - texture.decrement_ref_count(); - if texture.is_unreferenced() { - textures.free_list.push(texture.id.index as usize); - lock.tiles_by_key.remove(key); - } else { - *texture_slot = Some(texture); - } - } - } -} - -impl DirectXAtlasState { - fn allocate( - &mut self, - size: Size<DevicePixels>, - texture_kind: AtlasTextureKind, - ) -> Option<AtlasTile> { - { - let textures = match texture_kind { - AtlasTextureKind::Monochrome => &mut self.monochrome_textures, - AtlasTextureKind::Polychrome => &mut self.polychrome_textures, - }; - - if let Some(tile) = textures - .iter_mut() - .rev() - .find_map(|texture| texture.allocate(size)) - { - return Some(tile); - } - } - - let texture = self.push_texture(size, texture_kind)?; - texture.allocate(size) - } - - fn push_texture( - &mut self, - min_size: Size<DevicePixels>, - kind: AtlasTextureKind, - ) -> Option<&mut DirectXAtlasTexture> { - const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size { - width: DevicePixels(1024), - height: DevicePixels(1024), - }; - // Max texture size for DirectX. See: - // https://learn.microsoft.com/en-us/windows/win32/direct3d11/overviews-direct3d-11-resources-limits - const MAX_ATLAS_SIZE: Size<DevicePixels> = Size { - width: DevicePixels(16384), - height: DevicePixels(16384), - }; - let size = min_size.min(&MAX_ATLAS_SIZE).max(&DEFAULT_ATLAS_SIZE); - let pixel_format; - let bind_flag; - let bytes_per_pixel; - match kind { - AtlasTextureKind::Monochrome => { - pixel_format = DXGI_FORMAT_R8_UNORM; - bind_flag = D3D11_BIND_SHADER_RESOURCE; - bytes_per_pixel = 1; - } - AtlasTextureKind::Polychrome => { - pixel_format = DXGI_FORMAT_B8G8R8A8_UNORM; - bind_flag = D3D11_BIND_SHADER_RESOURCE; - bytes_per_pixel = 4; - } - } - let texture_desc = D3D11_TEXTURE2D_DESC { - Width: size.width.0 as u32, - Height: size.height.0 as u32, - MipLevels: 1, - ArraySize: 1, - Format: pixel_format, - SampleDesc: DXGI_SAMPLE_DESC { - Count: 1, - Quality: 0, - }, - Usage: D3D11_USAGE_DEFAULT, - BindFlags: bind_flag.0 as u32, - CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, - MiscFlags: 0, - }; - let mut texture: Option<ID3D11Texture2D> = None; - unsafe { - // This only returns None if the device is lost, which we will recreate later. - // So it's ok to return None here. - self.device - .CreateTexture2D(&texture_desc, None, Some(&mut texture)) - .ok()?; - } - let texture = texture.unwrap(); - - let texture_list = match kind { - AtlasTextureKind::Monochrome => &mut self.monochrome_textures, - AtlasTextureKind::Polychrome => &mut self.polychrome_textures, - }; - let index = texture_list.free_list.pop(); - let view = unsafe { - let mut view = None; - self.device - .CreateShaderResourceView(&texture, None, Some(&mut view)) - .ok()?; - [view] - }; - let atlas_texture = DirectXAtlasTexture { - id: AtlasTextureId { - index: index.unwrap_or(texture_list.textures.len()) as u32, - kind, - }, - bytes_per_pixel, - allocator: etagere::BucketedAtlasAllocator::new(size.into()), - texture, - view, - live_atlas_keys: 0, - }; - if let Some(ix) = index { - texture_list.textures[ix] = Some(atlas_texture); - texture_list.textures.get_mut(ix).unwrap().as_mut() - } else { - texture_list.textures.push(Some(atlas_texture)); - texture_list.textures.last_mut().unwrap().as_mut() - } - } - - fn texture(&self, id: AtlasTextureId) -> &DirectXAtlasTexture { - let textures = match id.kind { - crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, - crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, - }; - textures[id.index as usize].as_ref().unwrap() - } -} - -impl DirectXAtlasTexture { - fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { - let allocation = self.allocator.allocate(size.into())?; - let tile = AtlasTile { - texture_id: self.id, - tile_id: allocation.id.into(), - bounds: Bounds { - origin: allocation.rectangle.min.into(), - size, - }, - padding: 0, - }; - self.live_atlas_keys += 1; - Some(tile) - } - - fn upload( - &self, - device_context: &ID3D11DeviceContext, - bounds: Bounds<DevicePixels>, - bytes: &[u8], - ) { - unsafe { - device_context.UpdateSubresource( - &self.texture, - 0, - Some(&D3D11_BOX { - left: bounds.left().0 as u32, - top: bounds.top().0 as u32, - front: 0, - right: bounds.right().0 as u32, - bottom: bounds.bottom().0 as u32, - back: 1, - }), - bytes.as_ptr() as _, - bounds.size.width.to_bytes(self.bytes_per_pixel as u8), - 0, - ); - } - } - - fn decrement_ref_count(&mut self) { - self.live_atlas_keys -= 1; - } - - fn is_unreferenced(&mut self) -> bool { - self.live_atlas_keys == 0 - } -} - -impl From<Size<DevicePixels>> for etagere::Size { - fn from(size: Size<DevicePixels>) -> Self { - etagere::Size::new(size.width.into(), size.height.into()) - } -} - -impl From<etagere::Point> for Point<DevicePixels> { - fn from(value: etagere::Point) -> Self { - Point { - x: DevicePixels::from(value.x), - y: DevicePixels::from(value.y), - } - } -} diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs deleted file mode 100644 index 72cc12a5b4..0000000000 --- a/crates/gpui/src/platform/windows/directx_renderer.rs +++ /dev/null @@ -1,1807 +0,0 @@ -use std::{mem::ManuallyDrop, sync::Arc}; - -use ::util::ResultExt; -use anyhow::{Context, Result}; -use windows::{ - Win32::{ - Foundation::{HMODULE, HWND}, - Graphics::{ - Direct3D::*, - Direct3D11::*, - DirectComposition::*, - Dxgi::{Common::*, *}, - }, - }, - core::Interface, -}; - -use crate::{ - platform::windows::directx_renderer::shader_resources::{ - RawShaderBytes, ShaderModule, ShaderTarget, - }, - *, -}; - -pub(crate) const DISABLE_DIRECT_COMPOSITION: &str = "GPUI_DISABLE_DIRECT_COMPOSITION"; -const RENDER_TARGET_FORMAT: DXGI_FORMAT = DXGI_FORMAT_B8G8R8A8_UNORM; -// This configuration is used for MSAA rendering on paths only, and it's guaranteed to be supported by DirectX 11. -const PATH_MULTISAMPLE_COUNT: u32 = 4; - -pub(crate) struct DirectXRenderer { - hwnd: HWND, - atlas: Arc<DirectXAtlas>, - devices: ManuallyDrop<DirectXDevices>, - resources: ManuallyDrop<DirectXResources>, - globals: DirectXGlobalElements, - pipelines: DirectXRenderPipelines, - direct_composition: Option<DirectComposition>, -} - -/// Direct3D objects -#[derive(Clone)] -pub(crate) struct DirectXDevices { - adapter: IDXGIAdapter1, - dxgi_factory: IDXGIFactory6, - pub(crate) device: ID3D11Device, - pub(crate) device_context: ID3D11DeviceContext, - dxgi_device: Option<IDXGIDevice>, -} - -struct DirectXResources { - // Direct3D rendering objects - swap_chain: IDXGISwapChain1, - render_target: ManuallyDrop<ID3D11Texture2D>, - render_target_view: [Option<ID3D11RenderTargetView>; 1], - - // Path intermediate textures (with MSAA) - path_intermediate_texture: ID3D11Texture2D, - path_intermediate_srv: [Option<ID3D11ShaderResourceView>; 1], - path_intermediate_msaa_texture: ID3D11Texture2D, - path_intermediate_msaa_view: [Option<ID3D11RenderTargetView>; 1], - - // Cached window size and viewport - width: u32, - height: u32, - viewport: [D3D11_VIEWPORT; 1], -} - -struct DirectXRenderPipelines { - shadow_pipeline: PipelineState<Shadow>, - quad_pipeline: PipelineState<Quad>, - path_rasterization_pipeline: PipelineState<PathRasterizationSprite>, - path_sprite_pipeline: PipelineState<PathSprite>, - underline_pipeline: PipelineState<Underline>, - mono_sprites: PipelineState<MonochromeSprite>, - poly_sprites: PipelineState<PolychromeSprite>, -} - -struct DirectXGlobalElements { - global_params_buffer: [Option<ID3D11Buffer>; 1], - sampler: [Option<ID3D11SamplerState>; 1], -} - -struct DirectComposition { - comp_device: IDCompositionDevice, - comp_target: IDCompositionTarget, - comp_visual: IDCompositionVisual, -} - -impl DirectXDevices { - pub(crate) fn new(disable_direct_composition: bool) -> Result<ManuallyDrop<Self>> { - let debug_layer_available = check_debug_layer_available(); - let dxgi_factory = - get_dxgi_factory(debug_layer_available).context("Creating DXGI factory")?; - let adapter = - get_adapter(&dxgi_factory, debug_layer_available).context("Getting DXGI adapter")?; - let (device, device_context) = { - let mut device: Option<ID3D11Device> = None; - let mut context: Option<ID3D11DeviceContext> = None; - let mut feature_level = D3D_FEATURE_LEVEL::default(); - get_device( - &adapter, - Some(&mut device), - Some(&mut context), - Some(&mut feature_level), - debug_layer_available, - ) - .context("Creating Direct3D device")?; - match feature_level { - D3D_FEATURE_LEVEL_11_1 => { - log::info!("Created device with Direct3D 11.1 feature level.") - } - D3D_FEATURE_LEVEL_11_0 => { - log::info!("Created device with Direct3D 11.0 feature level.") - } - D3D_FEATURE_LEVEL_10_1 => { - log::info!("Created device with Direct3D 10.1 feature level.") - } - _ => unreachable!(), - } - (device.unwrap(), context.unwrap()) - }; - let dxgi_device = if disable_direct_composition { - None - } else { - Some(device.cast().context("Creating DXGI device")?) - }; - - Ok(ManuallyDrop::new(Self { - adapter, - dxgi_factory, - dxgi_device, - device, - device_context, - })) - } -} - -impl DirectXRenderer { - pub(crate) fn new(hwnd: HWND, disable_direct_composition: bool) -> Result<Self> { - if disable_direct_composition { - log::info!("Direct Composition is disabled."); - } - - let devices = - DirectXDevices::new(disable_direct_composition).context("Creating DirectX devices")?; - let atlas = Arc::new(DirectXAtlas::new(&devices.device, &devices.device_context)); - - let resources = DirectXResources::new(&devices, 1, 1, hwnd, disable_direct_composition) - .context("Creating DirectX resources")?; - let globals = DirectXGlobalElements::new(&devices.device) - .context("Creating DirectX global elements")?; - let pipelines = DirectXRenderPipelines::new(&devices.device) - .context("Creating DirectX render pipelines")?; - - let direct_composition = if disable_direct_composition { - None - } else { - let composition = DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), hwnd) - .context("Creating DirectComposition")?; - composition - .set_swap_chain(&resources.swap_chain) - .context("Setting swap chain for DirectComposition")?; - Some(composition) - }; - - Ok(DirectXRenderer { - hwnd, - atlas, - devices, - resources, - globals, - pipelines, - direct_composition, - }) - } - - pub(crate) fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> { - self.atlas.clone() - } - - fn pre_draw(&self) -> Result<()> { - update_buffer( - &self.devices.device_context, - self.globals.global_params_buffer[0].as_ref().unwrap(), - &[GlobalParams { - viewport_size: [ - self.resources.viewport[0].Width, - self.resources.viewport[0].Height, - ], - _pad: 0, - }], - )?; - unsafe { - self.devices.device_context.ClearRenderTargetView( - self.resources.render_target_view[0].as_ref().unwrap(), - &[0.0; 4], - ); - self.devices - .device_context - .OMSetRenderTargets(Some(&self.resources.render_target_view), None); - self.devices - .device_context - .RSSetViewports(Some(&self.resources.viewport)); - } - Ok(()) - } - - fn present(&mut self) -> Result<()> { - unsafe { - let result = self.resources.swap_chain.Present(1, DXGI_PRESENT(0)); - // Presenting the swap chain can fail if the DirectX device was removed or reset. - if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET { - let reason = self.devices.device.GetDeviceRemovedReason(); - log::error!( - "DirectX device removed or reset when drawing. Reason: {:?}", - reason - ); - self.handle_device_lost()?; - } else { - result.ok()?; - } - } - Ok(()) - } - - fn handle_device_lost(&mut self) -> Result<()> { - // Here we wait a bit to ensure the the system has time to recover from the device lost state. - // If we don't wait, the final drawing result will be blank. - std::thread::sleep(std::time::Duration::from_millis(300)); - let disable_direct_composition = self.direct_composition.is_none(); - - unsafe { - #[cfg(debug_assertions)] - report_live_objects(&self.devices.device) - .context("Failed to report live objects after device lost") - .log_err(); - - ManuallyDrop::drop(&mut self.resources); - self.devices.device_context.OMSetRenderTargets(None, None); - self.devices.device_context.ClearState(); - self.devices.device_context.Flush(); - - #[cfg(debug_assertions)] - report_live_objects(&self.devices.device) - .context("Failed to report live objects after device lost") - .log_err(); - - drop(self.direct_composition.take()); - ManuallyDrop::drop(&mut self.devices); - } - - let devices = DirectXDevices::new(disable_direct_composition) - .context("Recreating DirectX devices")?; - let resources = DirectXResources::new( - &devices, - self.resources.width, - self.resources.height, - self.hwnd, - disable_direct_composition, - )?; - let globals = DirectXGlobalElements::new(&devices.device)?; - let pipelines = DirectXRenderPipelines::new(&devices.device)?; - - let direct_composition = if disable_direct_composition { - None - } else { - let composition = - DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), self.hwnd)?; - composition.set_swap_chain(&resources.swap_chain)?; - Some(composition) - }; - - self.atlas - .handle_device_lost(&devices.device, &devices.device_context); - self.devices = devices; - self.resources = resources; - self.globals = globals; - self.pipelines = pipelines; - self.direct_composition = direct_composition; - - unsafe { - self.devices - .device_context - .OMSetRenderTargets(Some(&self.resources.render_target_view), None); - } - Ok(()) - } - - pub(crate) fn draw(&mut self, scene: &Scene) -> Result<()> { - self.pre_draw()?; - for batch in scene.batches() { - match batch { - PrimitiveBatch::Shadows(shadows) => self.draw_shadows(shadows), - PrimitiveBatch::Quads(quads) => self.draw_quads(quads), - PrimitiveBatch::Paths(paths) => { - self.draw_paths_to_intermediate(paths)?; - self.draw_paths_from_intermediate(paths) - } - PrimitiveBatch::Underlines(underlines) => self.draw_underlines(underlines), - PrimitiveBatch::MonochromeSprites { - texture_id, - sprites, - } => self.draw_monochrome_sprites(texture_id, sprites), - PrimitiveBatch::PolychromeSprites { - texture_id, - sprites, - } => self.draw_polychrome_sprites(texture_id, sprites), - PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces(surfaces), - }.context(format!("scene too large: {} paths, {} shadows, {} quads, {} underlines, {} mono, {} poly, {} surfaces", - scene.paths.len(), - scene.shadows.len(), - scene.quads.len(), - scene.underlines.len(), - scene.monochrome_sprites.len(), - scene.polychrome_sprites.len(), - scene.surfaces.len(),))?; - } - self.present() - } - - pub(crate) fn resize(&mut self, new_size: Size<DevicePixels>) -> Result<()> { - let width = new_size.width.0.max(1) as u32; - let height = new_size.height.0.max(1) as u32; - if self.resources.width == width && self.resources.height == height { - return Ok(()); - } - unsafe { - // Clear the render target before resizing - self.devices.device_context.OMSetRenderTargets(None, None); - ManuallyDrop::drop(&mut self.resources.render_target); - drop(self.resources.render_target_view[0].take().unwrap()); - - let result = self.resources.swap_chain.ResizeBuffers( - BUFFER_COUNT as u32, - width, - height, - RENDER_TARGET_FORMAT, - DXGI_SWAP_CHAIN_FLAG(0), - ); - // Resizing the swap chain requires a call to the underlying DXGI adapter, which can return the device removed error. - // The app might have moved to a monitor that's attached to a different graphics device. - // When a graphics device is removed or reset, the desktop resolution often changes, resulting in a window size change. - match result { - Ok(_) => {} - Err(e) => { - if e.code() == DXGI_ERROR_DEVICE_REMOVED || e.code() == DXGI_ERROR_DEVICE_RESET - { - let reason = self.devices.device.GetDeviceRemovedReason(); - log::error!( - "DirectX device removed or reset when resizing. Reason: {:?}", - reason - ); - self.resources.width = width; - self.resources.height = height; - self.handle_device_lost()?; - return Ok(()); - } else { - log::error!("Failed to resize swap chain: {:?}", e); - return Err(e.into()); - } - } - } - - self.resources - .recreate_resources(&self.devices, width, height)?; - self.devices - .device_context - .OMSetRenderTargets(Some(&self.resources.render_target_view), None); - } - Ok(()) - } - - fn draw_shadows(&mut self, shadows: &[Shadow]) -> Result<()> { - if shadows.is_empty() { - return Ok(()); - } - self.pipelines.shadow_pipeline.update_buffer( - &self.devices.device, - &self.devices.device_context, - shadows, - )?; - self.pipelines.shadow_pipeline.draw( - &self.devices.device_context, - &self.resources.viewport, - &self.globals.global_params_buffer, - D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, - 4, - shadows.len() as u32, - ) - } - - fn draw_quads(&mut self, quads: &[Quad]) -> Result<()> { - if quads.is_empty() { - return Ok(()); - } - self.pipelines.quad_pipeline.update_buffer( - &self.devices.device, - &self.devices.device_context, - quads, - )?; - self.pipelines.quad_pipeline.draw( - &self.devices.device_context, - &self.resources.viewport, - &self.globals.global_params_buffer, - D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, - 4, - quads.len() as u32, - ) - } - - fn draw_paths_to_intermediate(&mut self, paths: &[Path<ScaledPixels>]) -> Result<()> { - if paths.is_empty() { - return Ok(()); - } - - // Clear intermediate MSAA texture - unsafe { - self.devices.device_context.ClearRenderTargetView( - self.resources.path_intermediate_msaa_view[0] - .as_ref() - .unwrap(), - &[0.0; 4], - ); - // Set intermediate MSAA texture as render target - self.devices - .device_context - .OMSetRenderTargets(Some(&self.resources.path_intermediate_msaa_view), None); - } - - // Collect all vertices and sprites for a single draw call - let mut vertices = Vec::new(); - - for path in paths { - vertices.extend(path.vertices.iter().map(|v| PathRasterizationSprite { - xy_position: v.xy_position, - st_position: v.st_position, - color: path.color, - bounds: path.bounds.intersect(&path.content_mask.bounds), - })); - } - - self.pipelines.path_rasterization_pipeline.update_buffer( - &self.devices.device, - &self.devices.device_context, - &vertices, - )?; - self.pipelines.path_rasterization_pipeline.draw( - &self.devices.device_context, - &self.resources.viewport, - &self.globals.global_params_buffer, - D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST, - vertices.len() as u32, - 1, - )?; - - // Resolve MSAA to non-MSAA intermediate texture - unsafe { - self.devices.device_context.ResolveSubresource( - &self.resources.path_intermediate_texture, - 0, - &self.resources.path_intermediate_msaa_texture, - 0, - RENDER_TARGET_FORMAT, - ); - // Restore main render target - self.devices - .device_context - .OMSetRenderTargets(Some(&self.resources.render_target_view), None); - } - - Ok(()) - } - - fn draw_paths_from_intermediate(&mut self, paths: &[Path<ScaledPixels>]) -> Result<()> { - let Some(first_path) = paths.first() else { - return Ok(()); - }; - - // When copying paths from the intermediate texture to the drawable, - // each pixel must only be copied once, in case of transparent paths. - // - // If all paths have the same draw order, then their bounds are all - // disjoint, so we can copy each path's bounds individually. If this - // batch combines different draw orders, we perform a single copy - // for a minimal spanning rect. - let sprites = if paths.last().unwrap().order == first_path.order { - paths - .iter() - .map(|path| PathSprite { - bounds: path.bounds, - }) - .collect::<Vec<_>>() - } else { - let mut bounds = first_path.bounds; - for path in paths.iter().skip(1) { - bounds = bounds.union(&path.bounds); - } - vec![PathSprite { bounds }] - }; - - self.pipelines.path_sprite_pipeline.update_buffer( - &self.devices.device, - &self.devices.device_context, - &sprites, - )?; - - // Draw the sprites with the path texture - self.pipelines.path_sprite_pipeline.draw_with_texture( - &self.devices.device_context, - &self.resources.path_intermediate_srv, - &self.resources.viewport, - &self.globals.global_params_buffer, - &self.globals.sampler, - sprites.len() as u32, - ) - } - - fn draw_underlines(&mut self, underlines: &[Underline]) -> Result<()> { - if underlines.is_empty() { - return Ok(()); - } - self.pipelines.underline_pipeline.update_buffer( - &self.devices.device, - &self.devices.device_context, - underlines, - )?; - self.pipelines.underline_pipeline.draw( - &self.devices.device_context, - &self.resources.viewport, - &self.globals.global_params_buffer, - D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, - 4, - underlines.len() as u32, - ) - } - - fn draw_monochrome_sprites( - &mut self, - texture_id: AtlasTextureId, - sprites: &[MonochromeSprite], - ) -> Result<()> { - if sprites.is_empty() { - return Ok(()); - } - self.pipelines.mono_sprites.update_buffer( - &self.devices.device, - &self.devices.device_context, - sprites, - )?; - let texture_view = self.atlas.get_texture_view(texture_id); - self.pipelines.mono_sprites.draw_with_texture( - &self.devices.device_context, - &texture_view, - &self.resources.viewport, - &self.globals.global_params_buffer, - &self.globals.sampler, - sprites.len() as u32, - ) - } - - fn draw_polychrome_sprites( - &mut self, - texture_id: AtlasTextureId, - sprites: &[PolychromeSprite], - ) -> Result<()> { - if sprites.is_empty() { - return Ok(()); - } - self.pipelines.poly_sprites.update_buffer( - &self.devices.device, - &self.devices.device_context, - sprites, - )?; - let texture_view = self.atlas.get_texture_view(texture_id); - self.pipelines.poly_sprites.draw_with_texture( - &self.devices.device_context, - &texture_view, - &self.resources.viewport, - &self.globals.global_params_buffer, - &self.globals.sampler, - sprites.len() as u32, - ) - } - - fn draw_surfaces(&mut self, surfaces: &[PaintSurface]) -> Result<()> { - if surfaces.is_empty() { - return Ok(()); - } - Ok(()) - } - - pub(crate) fn gpu_specs(&self) -> Result<GpuSpecs> { - let desc = unsafe { self.devices.adapter.GetDesc1() }?; - let is_software_emulated = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE.0 as u32) != 0; - let device_name = String::from_utf16_lossy(&desc.Description) - .trim_matches(char::from(0)) - .to_string(); - let driver_name = match desc.VendorId { - 0x10DE => "NVIDIA Corporation".to_string(), - 0x1002 => "AMD Corporation".to_string(), - 0x8086 => "Intel Corporation".to_string(), - id => format!("Unknown Vendor (ID: {:#X})", id), - }; - let driver_version = match desc.VendorId { - 0x10DE => nvidia::get_driver_version(), - 0x1002 => amd::get_driver_version(), - // For Intel and other vendors, we use the DXGI API to get the driver version. - _ => dxgi::get_driver_version(&self.devices.adapter), - } - .context("Failed to get gpu driver info") - .log_err() - .unwrap_or("Unknown Driver".to_string()); - Ok(GpuSpecs { - is_software_emulated, - device_name, - driver_name, - driver_info: driver_version, - }) - } -} - -impl DirectXResources { - pub fn new( - devices: &DirectXDevices, - width: u32, - height: u32, - hwnd: HWND, - disable_direct_composition: bool, - ) -> Result<ManuallyDrop<Self>> { - let swap_chain = if disable_direct_composition { - create_swap_chain(&devices.dxgi_factory, &devices.device, hwnd, width, height)? - } else { - create_swap_chain_for_composition( - &devices.dxgi_factory, - &devices.device, - width, - height, - )? - }; - - let ( - render_target, - render_target_view, - path_intermediate_texture, - path_intermediate_srv, - path_intermediate_msaa_texture, - path_intermediate_msaa_view, - viewport, - ) = create_resources(devices, &swap_chain, width, height)?; - set_rasterizer_state(&devices.device, &devices.device_context)?; - - Ok(ManuallyDrop::new(Self { - swap_chain, - render_target, - render_target_view, - path_intermediate_texture, - path_intermediate_msaa_texture, - path_intermediate_msaa_view, - path_intermediate_srv, - viewport, - width, - height, - })) - } - - #[inline] - fn recreate_resources( - &mut self, - devices: &DirectXDevices, - width: u32, - height: u32, - ) -> Result<()> { - let ( - render_target, - render_target_view, - path_intermediate_texture, - path_intermediate_srv, - path_intermediate_msaa_texture, - path_intermediate_msaa_view, - viewport, - ) = create_resources(devices, &self.swap_chain, width, height)?; - self.render_target = render_target; - self.render_target_view = render_target_view; - self.path_intermediate_texture = path_intermediate_texture; - self.path_intermediate_msaa_texture = path_intermediate_msaa_texture; - self.path_intermediate_msaa_view = path_intermediate_msaa_view; - self.path_intermediate_srv = path_intermediate_srv; - self.viewport = viewport; - self.width = width; - self.height = height; - Ok(()) - } -} - -impl DirectXRenderPipelines { - pub fn new(device: &ID3D11Device) -> Result<Self> { - let shadow_pipeline = PipelineState::new( - device, - "shadow_pipeline", - ShaderModule::Shadow, - 4, - create_blend_state(device)?, - )?; - let quad_pipeline = PipelineState::new( - device, - "quad_pipeline", - ShaderModule::Quad, - 64, - create_blend_state(device)?, - )?; - let path_rasterization_pipeline = PipelineState::new( - device, - "path_rasterization_pipeline", - ShaderModule::PathRasterization, - 32, - create_blend_state_for_path_rasterization(device)?, - )?; - let path_sprite_pipeline = PipelineState::new( - device, - "path_sprite_pipeline", - ShaderModule::PathSprite, - 4, - create_blend_state_for_path_sprite(device)?, - )?; - let underline_pipeline = PipelineState::new( - device, - "underline_pipeline", - ShaderModule::Underline, - 4, - create_blend_state(device)?, - )?; - let mono_sprites = PipelineState::new( - device, - "monochrome_sprite_pipeline", - ShaderModule::MonochromeSprite, - 512, - create_blend_state(device)?, - )?; - let poly_sprites = PipelineState::new( - device, - "polychrome_sprite_pipeline", - ShaderModule::PolychromeSprite, - 16, - create_blend_state(device)?, - )?; - - Ok(Self { - shadow_pipeline, - quad_pipeline, - path_rasterization_pipeline, - path_sprite_pipeline, - underline_pipeline, - mono_sprites, - poly_sprites, - }) - } -} - -impl DirectComposition { - pub fn new(dxgi_device: &IDXGIDevice, hwnd: HWND) -> Result<Self> { - let comp_device = get_comp_device(&dxgi_device)?; - let comp_target = unsafe { comp_device.CreateTargetForHwnd(hwnd, true) }?; - let comp_visual = unsafe { comp_device.CreateVisual() }?; - - Ok(Self { - comp_device, - comp_target, - comp_visual, - }) - } - - pub fn set_swap_chain(&self, swap_chain: &IDXGISwapChain1) -> Result<()> { - unsafe { - self.comp_visual.SetContent(swap_chain)?; - self.comp_target.SetRoot(&self.comp_visual)?; - self.comp_device.Commit()?; - } - Ok(()) - } -} - -impl DirectXGlobalElements { - pub fn new(device: &ID3D11Device) -> Result<Self> { - let global_params_buffer = unsafe { - let desc = D3D11_BUFFER_DESC { - ByteWidth: std::mem::size_of::<GlobalParams>() as u32, - Usage: D3D11_USAGE_DYNAMIC, - BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, - CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, - ..Default::default() - }; - let mut buffer = None; - device.CreateBuffer(&desc, None, Some(&mut buffer))?; - [buffer] - }; - - let sampler = unsafe { - let desc = D3D11_SAMPLER_DESC { - Filter: D3D11_FILTER_MIN_MAG_MIP_LINEAR, - AddressU: D3D11_TEXTURE_ADDRESS_WRAP, - AddressV: D3D11_TEXTURE_ADDRESS_WRAP, - AddressW: D3D11_TEXTURE_ADDRESS_WRAP, - MipLODBias: 0.0, - MaxAnisotropy: 1, - ComparisonFunc: D3D11_COMPARISON_ALWAYS, - BorderColor: [0.0; 4], - MinLOD: 0.0, - MaxLOD: D3D11_FLOAT32_MAX, - }; - let mut output = None; - device.CreateSamplerState(&desc, Some(&mut output))?; - [output] - }; - - Ok(Self { - global_params_buffer, - sampler, - }) - } -} - -#[derive(Debug, Default)] -#[repr(C)] -struct GlobalParams { - viewport_size: [f32; 2], - _pad: u64, -} - -struct PipelineState<T> { - label: &'static str, - vertex: ID3D11VertexShader, - fragment: ID3D11PixelShader, - buffer: ID3D11Buffer, - buffer_size: usize, - view: [Option<ID3D11ShaderResourceView>; 1], - blend_state: ID3D11BlendState, - _marker: std::marker::PhantomData<T>, -} - -impl<T> PipelineState<T> { - fn new( - device: &ID3D11Device, - label: &'static str, - shader_module: ShaderModule, - buffer_size: usize, - blend_state: ID3D11BlendState, - ) -> Result<Self> { - let vertex = { - let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Vertex)?; - create_vertex_shader(device, raw_shader.as_bytes())? - }; - let fragment = { - let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Fragment)?; - create_fragment_shader(device, raw_shader.as_bytes())? - }; - let buffer = create_buffer(device, std::mem::size_of::<T>(), buffer_size)?; - let view = create_buffer_view(device, &buffer)?; - - Ok(PipelineState { - label, - vertex, - fragment, - buffer, - buffer_size, - view, - blend_state, - _marker: std::marker::PhantomData, - }) - } - - fn update_buffer( - &mut self, - device: &ID3D11Device, - device_context: &ID3D11DeviceContext, - data: &[T], - ) -> Result<()> { - if self.buffer_size < data.len() { - let new_buffer_size = data.len().next_power_of_two(); - log::info!( - "Updating {} buffer size from {} to {}", - self.label, - self.buffer_size, - new_buffer_size - ); - let buffer = create_buffer(device, std::mem::size_of::<T>(), new_buffer_size)?; - let view = create_buffer_view(device, &buffer)?; - self.buffer = buffer; - self.view = view; - self.buffer_size = new_buffer_size; - } - update_buffer(device_context, &self.buffer, data) - } - - fn draw( - &self, - device_context: &ID3D11DeviceContext, - viewport: &[D3D11_VIEWPORT], - global_params: &[Option<ID3D11Buffer>], - topology: D3D_PRIMITIVE_TOPOLOGY, - vertex_count: u32, - instance_count: u32, - ) -> Result<()> { - set_pipeline_state( - device_context, - &self.view, - topology, - viewport, - &self.vertex, - &self.fragment, - global_params, - &self.blend_state, - ); - unsafe { - device_context.DrawInstanced(vertex_count, instance_count, 0, 0); - } - Ok(()) - } - - fn draw_with_texture( - &self, - device_context: &ID3D11DeviceContext, - texture: &[Option<ID3D11ShaderResourceView>], - viewport: &[D3D11_VIEWPORT], - global_params: &[Option<ID3D11Buffer>], - sampler: &[Option<ID3D11SamplerState>], - instance_count: u32, - ) -> Result<()> { - set_pipeline_state( - device_context, - &self.view, - D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, - viewport, - &self.vertex, - &self.fragment, - global_params, - &self.blend_state, - ); - unsafe { - device_context.PSSetSamplers(0, Some(sampler)); - device_context.VSSetShaderResources(0, Some(texture)); - device_context.PSSetShaderResources(0, Some(texture)); - - device_context.DrawInstanced(4, instance_count, 0, 0); - } - Ok(()) - } -} - -#[derive(Clone, Copy)] -#[repr(C)] -struct PathRasterizationSprite { - xy_position: Point<ScaledPixels>, - st_position: Point<f32>, - color: Background, - bounds: Bounds<ScaledPixels>, -} - -#[derive(Clone, Copy)] -#[repr(C)] -struct PathSprite { - bounds: Bounds<ScaledPixels>, -} - -impl Drop for DirectXRenderer { - fn drop(&mut self) { - #[cfg(debug_assertions)] - report_live_objects(&self.devices.device).ok(); - unsafe { - ManuallyDrop::drop(&mut self.devices); - ManuallyDrop::drop(&mut self.resources); - } - } -} - -impl Drop for DirectXResources { - fn drop(&mut self) { - unsafe { - ManuallyDrop::drop(&mut self.render_target); - } - } -} - -#[inline] -fn check_debug_layer_available() -> bool { - #[cfg(debug_assertions)] - { - unsafe { DXGIGetDebugInterface1::<IDXGIInfoQueue>(0) } - .log_err() - .is_some() - } - #[cfg(not(debug_assertions))] - { - false - } -} - -#[inline] -fn get_dxgi_factory(debug_layer_available: bool) -> Result<IDXGIFactory6> { - let factory_flag = if debug_layer_available { - DXGI_CREATE_FACTORY_DEBUG - } else { - #[cfg(debug_assertions)] - log::warn!( - "Failed to get DXGI debug interface. DirectX debugging features will be disabled." - ); - DXGI_CREATE_FACTORY_FLAGS::default() - }; - unsafe { Ok(CreateDXGIFactory2(factory_flag)?) } -} - -fn get_adapter(dxgi_factory: &IDXGIFactory6, debug_layer_available: bool) -> Result<IDXGIAdapter1> { - for adapter_index in 0.. { - let adapter: IDXGIAdapter1 = unsafe { - dxgi_factory - .EnumAdapterByGpuPreference(adapter_index, DXGI_GPU_PREFERENCE_MINIMUM_POWER) - }?; - if let Ok(desc) = unsafe { adapter.GetDesc1() } { - let gpu_name = String::from_utf16_lossy(&desc.Description) - .trim_matches(char::from(0)) - .to_string(); - log::info!("Using GPU: {}", gpu_name); - } - // Check to see whether the adapter supports Direct3D 11, but don't - // create the actual device yet. - if get_device(&adapter, None, None, None, debug_layer_available) - .log_err() - .is_some() - { - return Ok(adapter); - } - } - - unreachable!() -} - -fn get_device( - adapter: &IDXGIAdapter1, - device: Option<*mut Option<ID3D11Device>>, - context: Option<*mut Option<ID3D11DeviceContext>>, - feature_level: Option<*mut D3D_FEATURE_LEVEL>, - debug_layer_available: bool, -) -> Result<()> { - let device_flags = if debug_layer_available { - D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG - } else { - D3D11_CREATE_DEVICE_BGRA_SUPPORT - }; - unsafe { - D3D11CreateDevice( - adapter, - D3D_DRIVER_TYPE_UNKNOWN, - HMODULE::default(), - device_flags, - // 4x MSAA is required for Direct3D Feature Level 10.1 or better - Some(&[ - D3D_FEATURE_LEVEL_11_1, - D3D_FEATURE_LEVEL_11_0, - D3D_FEATURE_LEVEL_10_1, - ]), - D3D11_SDK_VERSION, - device, - feature_level, - context, - )?; - } - Ok(()) -} - -#[inline] -fn get_comp_device(dxgi_device: &IDXGIDevice) -> Result<IDCompositionDevice> { - Ok(unsafe { DCompositionCreateDevice(dxgi_device)? }) -} - -fn create_swap_chain_for_composition( - dxgi_factory: &IDXGIFactory6, - device: &ID3D11Device, - width: u32, - height: u32, -) -> Result<IDXGISwapChain1> { - let desc = DXGI_SWAP_CHAIN_DESC1 { - Width: width, - Height: height, - Format: RENDER_TARGET_FORMAT, - Stereo: false.into(), - SampleDesc: DXGI_SAMPLE_DESC { - Count: 1, - Quality: 0, - }, - BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, - BufferCount: BUFFER_COUNT as u32, - // Composition SwapChains only support the DXGI_SCALING_STRETCH Scaling. - Scaling: DXGI_SCALING_STRETCH, - SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, - AlphaMode: DXGI_ALPHA_MODE_PREMULTIPLIED, - Flags: 0, - }; - Ok(unsafe { dxgi_factory.CreateSwapChainForComposition(device, &desc, None)? }) -} - -fn create_swap_chain( - dxgi_factory: &IDXGIFactory6, - device: &ID3D11Device, - hwnd: HWND, - width: u32, - height: u32, -) -> Result<IDXGISwapChain1> { - use windows::Win32::Graphics::Dxgi::DXGI_MWA_NO_ALT_ENTER; - - let desc = DXGI_SWAP_CHAIN_DESC1 { - Width: width, - Height: height, - Format: RENDER_TARGET_FORMAT, - Stereo: false.into(), - SampleDesc: DXGI_SAMPLE_DESC { - Count: 1, - Quality: 0, - }, - BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, - BufferCount: BUFFER_COUNT as u32, - Scaling: DXGI_SCALING_NONE, - SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, - AlphaMode: DXGI_ALPHA_MODE_IGNORE, - Flags: 0, - }; - let swap_chain = - unsafe { dxgi_factory.CreateSwapChainForHwnd(device, hwnd, &desc, None, None) }?; - unsafe { dxgi_factory.MakeWindowAssociation(hwnd, DXGI_MWA_NO_ALT_ENTER) }?; - Ok(swap_chain) -} - -#[inline] -fn create_resources( - devices: &DirectXDevices, - swap_chain: &IDXGISwapChain1, - width: u32, - height: u32, -) -> Result<( - ManuallyDrop<ID3D11Texture2D>, - [Option<ID3D11RenderTargetView>; 1], - ID3D11Texture2D, - [Option<ID3D11ShaderResourceView>; 1], - ID3D11Texture2D, - [Option<ID3D11RenderTargetView>; 1], - [D3D11_VIEWPORT; 1], -)> { - let (render_target, render_target_view) = - create_render_target_and_its_view(&swap_chain, &devices.device)?; - let (path_intermediate_texture, path_intermediate_srv) = - create_path_intermediate_texture(&devices.device, width, height)?; - let (path_intermediate_msaa_texture, path_intermediate_msaa_view) = - create_path_intermediate_msaa_texture_and_view(&devices.device, width, height)?; - let viewport = set_viewport(&devices.device_context, width as f32, height as f32); - Ok(( - render_target, - render_target_view, - path_intermediate_texture, - path_intermediate_srv, - path_intermediate_msaa_texture, - path_intermediate_msaa_view, - viewport, - )) -} - -#[inline] -fn create_render_target_and_its_view( - swap_chain: &IDXGISwapChain1, - device: &ID3D11Device, -) -> Result<( - ManuallyDrop<ID3D11Texture2D>, - [Option<ID3D11RenderTargetView>; 1], -)> { - let render_target: ID3D11Texture2D = unsafe { swap_chain.GetBuffer(0) }?; - let mut render_target_view = None; - unsafe { device.CreateRenderTargetView(&render_target, None, Some(&mut render_target_view))? }; - Ok(( - ManuallyDrop::new(render_target), - [Some(render_target_view.unwrap())], - )) -} - -#[inline] -fn create_path_intermediate_texture( - device: &ID3D11Device, - width: u32, - height: u32, -) -> Result<(ID3D11Texture2D, [Option<ID3D11ShaderResourceView>; 1])> { - let texture = unsafe { - let mut output = None; - let desc = D3D11_TEXTURE2D_DESC { - Width: width, - Height: height, - MipLevels: 1, - ArraySize: 1, - Format: RENDER_TARGET_FORMAT, - SampleDesc: DXGI_SAMPLE_DESC { - Count: 1, - Quality: 0, - }, - Usage: D3D11_USAGE_DEFAULT, - BindFlags: (D3D11_BIND_RENDER_TARGET.0 | D3D11_BIND_SHADER_RESOURCE.0) as u32, - CPUAccessFlags: 0, - MiscFlags: 0, - }; - device.CreateTexture2D(&desc, None, Some(&mut output))?; - output.unwrap() - }; - - let mut shader_resource_view = None; - unsafe { device.CreateShaderResourceView(&texture, None, Some(&mut shader_resource_view))? }; - - Ok((texture, [Some(shader_resource_view.unwrap())])) -} - -#[inline] -fn create_path_intermediate_msaa_texture_and_view( - device: &ID3D11Device, - width: u32, - height: u32, -) -> Result<(ID3D11Texture2D, [Option<ID3D11RenderTargetView>; 1])> { - let msaa_texture = unsafe { - let mut output = None; - let desc = D3D11_TEXTURE2D_DESC { - Width: width, - Height: height, - MipLevels: 1, - ArraySize: 1, - Format: RENDER_TARGET_FORMAT, - SampleDesc: DXGI_SAMPLE_DESC { - Count: PATH_MULTISAMPLE_COUNT, - Quality: D3D11_STANDARD_MULTISAMPLE_PATTERN.0 as u32, - }, - Usage: D3D11_USAGE_DEFAULT, - BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, - CPUAccessFlags: 0, - MiscFlags: 0, - }; - device.CreateTexture2D(&desc, None, Some(&mut output))?; - output.unwrap() - }; - let mut msaa_view = None; - unsafe { device.CreateRenderTargetView(&msaa_texture, None, Some(&mut msaa_view))? }; - Ok((msaa_texture, [Some(msaa_view.unwrap())])) -} - -#[inline] -fn set_viewport( - device_context: &ID3D11DeviceContext, - width: f32, - height: f32, -) -> [D3D11_VIEWPORT; 1] { - let viewport = [D3D11_VIEWPORT { - TopLeftX: 0.0, - TopLeftY: 0.0, - Width: width, - Height: height, - MinDepth: 0.0, - MaxDepth: 1.0, - }]; - unsafe { device_context.RSSetViewports(Some(&viewport)) }; - viewport -} - -#[inline] -fn set_rasterizer_state(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Result<()> { - let desc = D3D11_RASTERIZER_DESC { - FillMode: D3D11_FILL_SOLID, - CullMode: D3D11_CULL_NONE, - FrontCounterClockwise: false.into(), - DepthBias: 0, - DepthBiasClamp: 0.0, - SlopeScaledDepthBias: 0.0, - DepthClipEnable: true.into(), - ScissorEnable: false.into(), - MultisampleEnable: true.into(), - AntialiasedLineEnable: false.into(), - }; - let rasterizer_state = unsafe { - let mut state = None; - device.CreateRasterizerState(&desc, Some(&mut state))?; - state.unwrap() - }; - unsafe { device_context.RSSetState(&rasterizer_state) }; - Ok(()) -} - -// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_blend_desc -#[inline] -fn create_blend_state(device: &ID3D11Device) -> Result<ID3D11BlendState> { - // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display - // device performs the blend in linear space, which is ideal. - let mut desc = D3D11_BLEND_DESC::default(); - desc.RenderTarget[0].BlendEnable = true.into(); - desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; - desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; - desc.RenderTarget[0].SrcBlend = D3D11_BLEND_SRC_ALPHA; - desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; - desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; - desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; - desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; - unsafe { - let mut state = None; - device.CreateBlendState(&desc, Some(&mut state))?; - Ok(state.unwrap()) - } -} - -#[inline] -fn create_blend_state_for_path_rasterization(device: &ID3D11Device) -> Result<ID3D11BlendState> { - // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display - // device performs the blend in linear space, which is ideal. - let mut desc = D3D11_BLEND_DESC::default(); - desc.RenderTarget[0].BlendEnable = true.into(); - desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; - desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; - desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; - desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; - desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; - desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_INV_SRC_ALPHA; - desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; - unsafe { - let mut state = None; - device.CreateBlendState(&desc, Some(&mut state))?; - Ok(state.unwrap()) - } -} - -#[inline] -fn create_blend_state_for_path_sprite(device: &ID3D11Device) -> Result<ID3D11BlendState> { - // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display - // device performs the blend in linear space, which is ideal. - let mut desc = D3D11_BLEND_DESC::default(); - desc.RenderTarget[0].BlendEnable = true.into(); - desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; - desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; - desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; - desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; - desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; - desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; - desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; - unsafe { - let mut state = None; - device.CreateBlendState(&desc, Some(&mut state))?; - Ok(state.unwrap()) - } -} - -#[inline] -fn create_vertex_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11VertexShader> { - unsafe { - let mut shader = None; - device.CreateVertexShader(bytes, None, Some(&mut shader))?; - Ok(shader.unwrap()) - } -} - -#[inline] -fn create_fragment_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11PixelShader> { - unsafe { - let mut shader = None; - device.CreatePixelShader(bytes, None, Some(&mut shader))?; - Ok(shader.unwrap()) - } -} - -#[inline] -fn create_buffer( - device: &ID3D11Device, - element_size: usize, - buffer_size: usize, -) -> Result<ID3D11Buffer> { - let desc = D3D11_BUFFER_DESC { - ByteWidth: (element_size * buffer_size) as u32, - Usage: D3D11_USAGE_DYNAMIC, - BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, - CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, - MiscFlags: D3D11_RESOURCE_MISC_BUFFER_STRUCTURED.0 as u32, - StructureByteStride: element_size as u32, - }; - let mut buffer = None; - unsafe { device.CreateBuffer(&desc, None, Some(&mut buffer)) }?; - Ok(buffer.unwrap()) -} - -#[inline] -fn create_buffer_view( - device: &ID3D11Device, - buffer: &ID3D11Buffer, -) -> Result<[Option<ID3D11ShaderResourceView>; 1]> { - let mut view = None; - unsafe { device.CreateShaderResourceView(buffer, None, Some(&mut view)) }?; - Ok([view]) -} - -#[inline] -fn update_buffer<T>( - device_context: &ID3D11DeviceContext, - buffer: &ID3D11Buffer, - data: &[T], -) -> Result<()> { - unsafe { - let mut dest = std::mem::zeroed(); - device_context.Map(buffer, 0, D3D11_MAP_WRITE_DISCARD, 0, Some(&mut dest))?; - std::ptr::copy_nonoverlapping(data.as_ptr(), dest.pData as _, data.len()); - device_context.Unmap(buffer, 0); - } - Ok(()) -} - -#[inline] -fn set_pipeline_state( - device_context: &ID3D11DeviceContext, - buffer_view: &[Option<ID3D11ShaderResourceView>], - topology: D3D_PRIMITIVE_TOPOLOGY, - viewport: &[D3D11_VIEWPORT], - vertex_shader: &ID3D11VertexShader, - fragment_shader: &ID3D11PixelShader, - global_params: &[Option<ID3D11Buffer>], - blend_state: &ID3D11BlendState, -) { - unsafe { - device_context.VSSetShaderResources(1, Some(buffer_view)); - device_context.PSSetShaderResources(1, Some(buffer_view)); - device_context.IASetPrimitiveTopology(topology); - device_context.RSSetViewports(Some(viewport)); - device_context.VSSetShader(vertex_shader, None); - device_context.PSSetShader(fragment_shader, None); - device_context.VSSetConstantBuffers(0, Some(global_params)); - device_context.PSSetConstantBuffers(0, Some(global_params)); - device_context.OMSetBlendState(blend_state, None, 0xFFFFFFFF); - } -} - -#[cfg(debug_assertions)] -fn report_live_objects(device: &ID3D11Device) -> Result<()> { - let debug_device: ID3D11Debug = device.cast()?; - unsafe { - debug_device.ReportLiveDeviceObjects(D3D11_RLDO_DETAIL)?; - } - Ok(()) -} - -const BUFFER_COUNT: usize = 3; - -pub(crate) mod shader_resources { - use anyhow::Result; - - #[cfg(debug_assertions)] - use windows::{ - Win32::Graphics::Direct3D::{ - Fxc::{D3DCOMPILE_DEBUG, D3DCOMPILE_SKIP_OPTIMIZATION, D3DCompileFromFile}, - ID3DBlob, - }, - core::{HSTRING, PCSTR}, - }; - - #[derive(Copy, Clone, Debug, Eq, PartialEq)] - pub(crate) enum ShaderModule { - Quad, - Shadow, - Underline, - PathRasterization, - PathSprite, - MonochromeSprite, - PolychromeSprite, - EmojiRasterization, - } - - #[derive(Copy, Clone, Debug, Eq, PartialEq)] - pub(crate) enum ShaderTarget { - Vertex, - Fragment, - } - - pub(crate) struct RawShaderBytes<'t> { - inner: &'t [u8], - - #[cfg(debug_assertions)] - _blob: ID3DBlob, - } - - impl<'t> RawShaderBytes<'t> { - pub(crate) fn new(module: ShaderModule, target: ShaderTarget) -> Result<Self> { - #[cfg(not(debug_assertions))] - { - Ok(Self::from_bytes(module, target)) - } - #[cfg(debug_assertions)] - { - let blob = build_shader_blob(module, target)?; - let inner = unsafe { - std::slice::from_raw_parts( - blob.GetBufferPointer() as *const u8, - blob.GetBufferSize(), - ) - }; - Ok(Self { inner, _blob: blob }) - } - } - - pub(crate) fn as_bytes(&'t self) -> &'t [u8] { - self.inner - } - - #[cfg(not(debug_assertions))] - fn from_bytes(module: ShaderModule, target: ShaderTarget) -> Self { - let bytes = match module { - ShaderModule::Quad => match target { - ShaderTarget::Vertex => QUAD_VERTEX_BYTES, - ShaderTarget::Fragment => QUAD_FRAGMENT_BYTES, - }, - ShaderModule::Shadow => match target { - ShaderTarget::Vertex => SHADOW_VERTEX_BYTES, - ShaderTarget::Fragment => SHADOW_FRAGMENT_BYTES, - }, - ShaderModule::Underline => match target { - ShaderTarget::Vertex => UNDERLINE_VERTEX_BYTES, - ShaderTarget::Fragment => UNDERLINE_FRAGMENT_BYTES, - }, - ShaderModule::PathRasterization => match target { - ShaderTarget::Vertex => PATH_RASTERIZATION_VERTEX_BYTES, - ShaderTarget::Fragment => PATH_RASTERIZATION_FRAGMENT_BYTES, - }, - ShaderModule::PathSprite => match target { - ShaderTarget::Vertex => PATH_SPRITE_VERTEX_BYTES, - ShaderTarget::Fragment => PATH_SPRITE_FRAGMENT_BYTES, - }, - ShaderModule::MonochromeSprite => match target { - ShaderTarget::Vertex => MONOCHROME_SPRITE_VERTEX_BYTES, - ShaderTarget::Fragment => MONOCHROME_SPRITE_FRAGMENT_BYTES, - }, - ShaderModule::PolychromeSprite => match target { - ShaderTarget::Vertex => POLYCHROME_SPRITE_VERTEX_BYTES, - ShaderTarget::Fragment => POLYCHROME_SPRITE_FRAGMENT_BYTES, - }, - ShaderModule::EmojiRasterization => match target { - ShaderTarget::Vertex => EMOJI_RASTERIZATION_VERTEX_BYTES, - ShaderTarget::Fragment => EMOJI_RASTERIZATION_FRAGMENT_BYTES, - }, - }; - Self { inner: bytes } - } - } - - #[cfg(debug_assertions)] - pub(super) fn build_shader_blob(entry: ShaderModule, target: ShaderTarget) -> Result<ID3DBlob> { - unsafe { - let shader_name = if matches!(entry, ShaderModule::EmojiRasterization) { - "color_text_raster.hlsl" - } else { - "shaders.hlsl" - }; - - let entry = format!( - "{}_{}\0", - entry.as_str(), - match target { - ShaderTarget::Vertex => "vertex", - ShaderTarget::Fragment => "fragment", - } - ); - let target = match target { - ShaderTarget::Vertex => "vs_4_1\0", - ShaderTarget::Fragment => "ps_4_1\0", - }; - - let mut compile_blob = None; - let mut error_blob = None; - let shader_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join(&format!("src/platform/windows/{}", shader_name)) - .canonicalize()?; - - let entry_point = PCSTR::from_raw(entry.as_ptr()); - let target_cstr = PCSTR::from_raw(target.as_ptr()); - - let ret = D3DCompileFromFile( - &HSTRING::from(shader_path.to_str().unwrap()), - None, - None, - entry_point, - target_cstr, - D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION, - 0, - &mut compile_blob, - Some(&mut error_blob), - ); - if ret.is_err() { - let Some(error_blob) = error_blob else { - return Err(anyhow::anyhow!("{ret:?}")); - }; - - let error_string = - std::ffi::CStr::from_ptr(error_blob.GetBufferPointer() as *const i8) - .to_string_lossy(); - log::error!("Shader compile error: {}", error_string); - return Err(anyhow::anyhow!("Compile error: {}", error_string)); - } - Ok(compile_blob.unwrap()) - } - } - - #[cfg(not(debug_assertions))] - include!(concat!(env!("OUT_DIR"), "/shaders_bytes.rs")); - - #[cfg(debug_assertions)] - impl ShaderModule { - pub fn as_str(&self) -> &str { - match self { - ShaderModule::Quad => "quad", - ShaderModule::Shadow => "shadow", - ShaderModule::Underline => "underline", - ShaderModule::PathRasterization => "path_rasterization", - ShaderModule::PathSprite => "path_sprite", - ShaderModule::MonochromeSprite => "monochrome_sprite", - ShaderModule::PolychromeSprite => "polychrome_sprite", - ShaderModule::EmojiRasterization => "emoji_rasterization", - } - } - } -} - -mod nvidia { - use std::{ - ffi::CStr, - os::raw::{c_char, c_int, c_uint}, - }; - - use anyhow::{Context, Result}; - use windows::{ - Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, - core::s, - }; - - // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 - const NVAPI_SHORT_STRING_MAX: usize = 64; - - // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L235 - #[allow(non_camel_case_types)] - type NvAPI_ShortString = [c_char; NVAPI_SHORT_STRING_MAX]; - - // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L447 - #[allow(non_camel_case_types)] - type NvAPI_SYS_GetDriverAndBranchVersion_t = unsafe extern "C" fn( - driver_version: *mut c_uint, - build_branch_string: *mut NvAPI_ShortString, - ) -> c_int; - - pub(super) fn get_driver_version() -> Result<String> { - unsafe { - // Try to load the NVIDIA driver DLL - #[cfg(target_pointer_width = "64")] - let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?; - #[cfg(target_pointer_width = "32")] - let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?; - - let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface")) - .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?; - let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr); - - // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_interface.h#L41 - let nvapi_get_driver_version_ptr = nvapi_query(0x2926aaad); - if nvapi_get_driver_version_ptr.is_null() { - anyhow::bail!("Failed to get NVIDIA driver version function pointer"); - } - let nvapi_get_driver_version: NvAPI_SYS_GetDriverAndBranchVersion_t = - std::mem::transmute(nvapi_get_driver_version_ptr); - - let mut driver_version: c_uint = 0; - let mut build_branch_string: NvAPI_ShortString = [0; NVAPI_SHORT_STRING_MAX]; - let result = nvapi_get_driver_version( - &mut driver_version as *mut c_uint, - &mut build_branch_string as *mut NvAPI_ShortString, - ); - - if result != 0 { - anyhow::bail!( - "Failed to get NVIDIA driver version, error code: {}", - result - ); - } - let major = driver_version / 100; - let minor = driver_version % 100; - let branch_string = CStr::from_ptr(build_branch_string.as_ptr()); - Ok(format!( - "{}.{} {}", - major, - minor, - branch_string.to_string_lossy() - )) - } - } -} - -mod amd { - use std::os::raw::{c_char, c_int, c_void}; - - use anyhow::{Context, Result}; - use windows::{ - Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, - core::s, - }; - - // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 - const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); - - // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L204 - // This is an opaque type, using struct to represent it properly for FFI - #[repr(C)] - struct AGSContext { - _private: [u8; 0], - } - - #[repr(C)] - pub struct AGSGPUInfo { - pub driver_version: *const c_char, - pub radeon_software_version: *const c_char, - pub num_devices: c_int, - pub devices: *mut c_void, - } - - // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L429 - #[allow(non_camel_case_types)] - type agsInitialize_t = unsafe extern "C" fn( - version: c_int, - config: *const c_void, - context: *mut *mut AGSContext, - gpu_info: *mut AGSGPUInfo, - ) -> c_int; - - // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L436 - #[allow(non_camel_case_types)] - type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int; - - pub(super) fn get_driver_version() -> Result<String> { - unsafe { - #[cfg(target_pointer_width = "64")] - let amd_dll = - LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?; - #[cfg(target_pointer_width = "32")] - let amd_dll = - LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?; - - let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize")) - .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?; - let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize")) - .ok_or_else(|| anyhow::anyhow!("Failed to get agsDeInitialize address"))?; - - let ags_initialize: agsInitialize_t = std::mem::transmute(ags_initialize_addr); - let ags_deinitialize: agsDeInitialize_t = std::mem::transmute(ags_deinitialize_addr); - - let mut context: *mut AGSContext = std::ptr::null_mut(); - let mut gpu_info: AGSGPUInfo = AGSGPUInfo { - driver_version: std::ptr::null(), - radeon_software_version: std::ptr::null(), - num_devices: 0, - devices: std::ptr::null_mut(), - }; - - let result = ags_initialize( - AGS_CURRENT_VERSION, - std::ptr::null(), - &mut context, - &mut gpu_info, - ); - if result != 0 { - anyhow::bail!("Failed to initialize AMD AGS, error code: {}", result); - } - - // Vulkan acctually returns this as the driver version - let software_version = if !gpu_info.radeon_software_version.is_null() { - std::ffi::CStr::from_ptr(gpu_info.radeon_software_version) - .to_string_lossy() - .into_owned() - } else { - "Unknown Radeon Software Version".to_string() - }; - - let driver_version = if !gpu_info.driver_version.is_null() { - std::ffi::CStr::from_ptr(gpu_info.driver_version) - .to_string_lossy() - .into_owned() - } else { - "Unknown Radeon Driver Version".to_string() - }; - - ags_deinitialize(context); - Ok(format!("{} ({})", software_version, driver_version)) - } - } -} - -mod dxgi { - use windows::{ - Win32::Graphics::Dxgi::{IDXGIAdapter1, IDXGIDevice}, - core::Interface, - }; - - pub(super) fn get_driver_version(adapter: &IDXGIAdapter1) -> anyhow::Result<String> { - let number = unsafe { adapter.CheckInterfaceSupport(&IDXGIDevice::IID as _) }?; - Ok(format!( - "{}.{}.{}.{}", - number >> 48, - (number >> 32) & 0xFFFF, - (number >> 16) & 0xFFFF, - number & 0xFFFF - )) - } -} diff --git a/crates/gpui/src/platform/windows/events.rs b/crates/gpui/src/platform/windows/events.rs index 61f410a8c6..839fd10375 100644 --- a/crates/gpui/src/platform/windows/events.rs +++ b/crates/gpui/src/platform/windows/events.rs @@ -23,7 +23,6 @@ pub(crate) const WM_GPUI_CURSOR_STYLE_CHANGED: u32 = WM_USER + 1; pub(crate) const WM_GPUI_CLOSE_ONE_WINDOW: u32 = WM_USER + 2; pub(crate) const WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD: u32 = WM_USER + 3; pub(crate) const WM_GPUI_DOCK_MENU_ACTION: u32 = WM_USER + 4; -pub(crate) const WM_GPUI_FORCE_UPDATE_WINDOW: u32 = WM_USER + 5; const SIZE_MOVE_LOOP_TIMER_ID: usize = 1; const AUTO_HIDE_TASKBAR_THICKNESS_PX: i32 = 1; @@ -38,7 +37,6 @@ pub(crate) fn handle_msg( let handled = match msg { WM_ACTIVATE => handle_activate_msg(wparam, state_ptr), WM_CREATE => handle_create_msg(handle, state_ptr), - WM_DEVICECHANGE => handle_device_change_msg(handle, wparam, state_ptr), WM_MOVE => handle_move_msg(handle, lparam, state_ptr), WM_SIZE => handle_size_msg(wparam, lparam, state_ptr), WM_GETMINMAXINFO => handle_get_min_max_info_msg(lparam, state_ptr), @@ -50,7 +48,7 @@ pub(crate) fn handle_msg( WM_DISPLAYCHANGE => handle_display_change_msg(handle, state_ptr), WM_NCHITTEST => handle_hit_test_msg(handle, msg, wparam, lparam, state_ptr), WM_PAINT => handle_paint_msg(handle, state_ptr), - WM_CLOSE => handle_close_msg(state_ptr), + WM_CLOSE => handle_close_msg(handle, state_ptr), WM_DESTROY => handle_destroy_msg(handle, state_ptr), WM_MOUSEMOVE => handle_mouse_move_msg(handle, lparam, wparam, state_ptr), WM_MOUSELEAVE | WM_NCMOUSELEAVE => handle_mouse_leave_msg(state_ptr), @@ -98,7 +96,6 @@ pub(crate) fn handle_msg( WM_SETTINGCHANGE => handle_system_settings_changed(handle, wparam, lparam, state_ptr), WM_INPUTLANGCHANGE => handle_input_language_changed(lparam, state_ptr), WM_GPUI_CURSOR_STYLE_CHANGED => handle_cursor_changed(lparam, state_ptr), - WM_GPUI_FORCE_UPDATE_WINDOW => draw_window(handle, true, state_ptr), _ => None, }; if let Some(n) = handled { @@ -184,9 +181,11 @@ fn handle_size_msg( let new_size = size(DevicePixels(width), DevicePixels(height)); let scale_factor = lock.scale_factor; if lock.restore_from_minimized.is_some() { + lock.renderer + .update_drawable_size_even_if_unchanged(new_size); lock.callbacks.request_frame = lock.restore_from_minimized.take(); } else { - lock.renderer.resize(new_size).log_err(); + lock.renderer.update_drawable_size(new_size); } let new_size = new_size.to_pixels(scale_factor); lock.logical_size = new_size; @@ -239,14 +238,40 @@ fn handle_timer_msg( } fn handle_paint_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - draw_window(handle, false, state_ptr) + let mut lock = state_ptr.state.borrow_mut(); + if let Some(mut request_frame) = lock.callbacks.request_frame.take() { + drop(lock); + request_frame(Default::default()); + state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame); + } + unsafe { ValidateRect(Some(handle), None).ok().log_err() }; + Some(0) } -fn handle_close_msg(state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut callback = state_ptr.state.borrow_mut().callbacks.should_close.take()?; - let should_close = callback(); - state_ptr.state.borrow_mut().callbacks.should_close = Some(callback); - if should_close { None } else { Some(0) } +fn handle_close_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { + let mut lock = state_ptr.state.borrow_mut(); + let output = if let Some(mut callback) = lock.callbacks.should_close.take() { + drop(lock); + let should_close = callback(); + state_ptr.state.borrow_mut().callbacks.should_close = Some(callback); + if should_close { None } else { Some(0) } + } else { + None + }; + + // Workaround as window close animation is not played with `WS_EX_LAYERED` enabled. + if output.is_none() { + unsafe { + let current_style = get_window_long(handle, GWL_EXSTYLE); + set_window_long( + handle, + GWL_EXSTYLE, + current_style & !WS_EX_LAYERED.0 as isize, + ); + } + } + + output } fn handle_destroy_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { @@ -1198,53 +1223,6 @@ fn handle_input_language_changed( Some(0) } -fn handle_device_change_msg( - handle: HWND, - wparam: WPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if wparam.0 == DBT_DEVNODES_CHANGED as usize { - // The reason for sending this message is to actually trigger a redraw of the window. - unsafe { - PostMessageW( - Some(handle), - WM_GPUI_FORCE_UPDATE_WINDOW, - WPARAM(0), - LPARAM(0), - ) - .log_err(); - } - // If the GPU device is lost, this redraw will take care of recreating the device context. - // The WM_GPUI_FORCE_UPDATE_WINDOW message will take care of redrawing the window, after - // the device context has been recreated. - draw_window(handle, true, state_ptr) - } else { - // Other device change messages are not handled. - None - } -} - -#[inline] -fn draw_window( - handle: HWND, - force_render: bool, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut request_frame = state_ptr - .state - .borrow_mut() - .callbacks - .request_frame - .take()?; - request_frame(RequestFrameOptions { - require_presentation: false, - force_render, - }); - state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame); - unsafe { ValidateRect(Some(handle), None).ok().log_err() }; - Some(0) -} - #[inline] fn parse_char_message(wparam: WPARAM, state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<String> { let code_point = wparam.loword(); diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index bc09cc199d..401ecdeffe 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -28,12 +28,13 @@ use windows::{ core::*, }; -use crate::*; +use crate::{platform::blade::BladeContext, *}; pub(crate) struct WindowsPlatform { state: RefCell<WindowsPlatformState>, raw_window_handles: RwLock<SmallVec<[HWND; 4]>>, // The below members will never change throughout the entire lifecycle of the app. + gpu_context: BladeContext, icon: HICON, main_receiver: flume::Receiver<Runnable>, background_executor: BackgroundExecutor, @@ -44,7 +45,6 @@ pub(crate) struct WindowsPlatform { drop_target_helper: IDropTargetHelper, validation_number: usize, main_thread_id_win32: u32, - disable_direct_composition: bool, } pub(crate) struct WindowsPlatformState { @@ -94,18 +94,14 @@ impl WindowsPlatform { main_thread_id_win32, validation_number, )); - let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION) - .is_ok_and(|value| value == "true" || value == "1"); let background_executor = BackgroundExecutor::new(dispatcher.clone()); let foreground_executor = ForegroundExecutor::new(dispatcher); - let directx_devices = DirectXDevices::new(disable_direct_composition) - .context("Unable to init directx devices.")?; let bitmap_factory = ManuallyDrop::new(unsafe { CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER) .context("Error creating bitmap factory.")? }); let text_system = Arc::new( - DirectWriteTextSystem::new(&directx_devices, &bitmap_factory) + DirectWriteTextSystem::new(&bitmap_factory) .context("Error creating DirectWriteTextSystem")?, ); let drop_target_helper: IDropTargetHelper = unsafe { @@ -115,17 +111,18 @@ impl WindowsPlatform { let icon = load_icon().unwrap_or_default(); let state = RefCell::new(WindowsPlatformState::new()); let raw_window_handles = RwLock::new(SmallVec::new()); + let gpu_context = BladeContext::new().context("Unable to init GPU context")?; let windows_version = WindowsVersion::new().context("Error retrieve windows version")?; Ok(Self { state, raw_window_handles, + gpu_context, icon, main_receiver, background_executor, foreground_executor, text_system, - disable_direct_composition, windows_version, bitmap_factory, drop_target_helper, @@ -190,7 +187,6 @@ impl WindowsPlatform { validation_number: self.validation_number, main_receiver: self.main_receiver.clone(), main_thread_id_win32: self.main_thread_id_win32, - disable_direct_composition: self.disable_direct_composition, } } @@ -347,11 +343,27 @@ impl Platform for WindowsPlatform { fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>) { on_finish_launching(); - loop { - if self.handle_events() { - break; + let vsync_event = unsafe { Owned::new(CreateEventW(None, false, false, None).unwrap()) }; + begin_vsync(*vsync_event); + 'a: loop { + let wait_result = unsafe { + MsgWaitForMultipleObjects(Some(&[*vsync_event]), false, INFINITE, QS_ALLINPUT) + }; + + match wait_result { + // compositor clock ticked so we should draw a frame + WAIT_EVENT(0) => self.redraw_all(), + // Windows thread messages are posted + WAIT_EVENT(1) => { + if self.handle_events() { + break 'a; + } + } + _ => { + log::error!("Something went wrong while waiting {:?}", wait_result); + break; + } } - self.redraw_all(); } if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit { @@ -443,7 +455,12 @@ impl Platform for WindowsPlatform { handle: AnyWindowHandle, options: WindowParams, ) -> Result<Box<dyn PlatformWindow>> { - let window = WindowsWindow::new(handle, options, self.generate_creation_info())?; + let window = WindowsWindow::new( + handle, + options, + self.generate_creation_info(), + &self.gpu_context, + )?; let handle = window.get_raw_handle(); self.raw_window_handles.write().push(handle); @@ -722,7 +739,6 @@ pub(crate) struct WindowCreationInfo { pub(crate) validation_number: usize, pub(crate) main_receiver: flume::Receiver<Runnable>, pub(crate) main_thread_id_win32: u32, - pub(crate) disable_direct_composition: bool, } fn open_target(target: &str) { @@ -830,6 +846,16 @@ fn file_save_dialog(directory: PathBuf, window: Option<HWND>) -> Result<Option<P Ok(Some(PathBuf::from(file_path_string))) } +fn begin_vsync(vsync_event: HANDLE) { + let event: SafeHandle = vsync_event.into(); + std::thread::spawn(move || unsafe { + loop { + windows::Win32::Graphics::Dwm::DwmFlush().log_err(); + SetEvent(*event).log_err(); + } + }); +} + fn load_icon() -> Result<HICON> { let module = unsafe { GetModuleHandleW(None).context("unable to get module handle")? }; let handle = unsafe { diff --git a/crates/gpui/src/platform/windows/shaders.hlsl b/crates/gpui/src/platform/windows/shaders.hlsl deleted file mode 100644 index 25830e4b6c..0000000000 --- a/crates/gpui/src/platform/windows/shaders.hlsl +++ /dev/null @@ -1,1159 +0,0 @@ -cbuffer GlobalParams: register(b0) { - float2 global_viewport_size; - uint2 _pad; -}; - -Texture2D<float4> t_sprite: register(t0); -SamplerState s_sprite: register(s0); - -struct Bounds { - float2 origin; - float2 size; -}; - -struct Corners { - float top_left; - float top_right; - float bottom_right; - float bottom_left; -}; - -struct Edges { - float top; - float right; - float bottom; - float left; -}; - -struct Hsla { - float h; - float s; - float l; - float a; -}; - -struct LinearColorStop { - Hsla color; - float percentage; -}; - -struct Background { - // 0u is Solid - // 1u is LinearGradient - // 2u is PatternSlash - uint tag; - // 0u is sRGB linear color - // 1u is Oklab color - uint color_space; - Hsla solid; - float gradient_angle_or_pattern_height; - LinearColorStop colors[2]; - uint pad; -}; - -struct GradientColor { - float4 solid; - float4 color0; - float4 color1; -}; - -struct AtlasTextureId { - uint index; - uint kind; -}; - -struct AtlasBounds { - int2 origin; - int2 size; -}; - -struct AtlasTile { - AtlasTextureId texture_id; - uint tile_id; - uint padding; - AtlasBounds bounds; -}; - -struct TransformationMatrix { - float2x2 rotation_scale; - float2 translation; -}; - -static const float M_PI_F = 3.141592653f; -static const float3 GRAYSCALE_FACTORS = float3(0.2126f, 0.7152f, 0.0722f); - -float4 to_device_position_impl(float2 position) { - float2 device_position = position / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); - return float4(device_position, 0., 1.); -} - -float4 to_device_position(float2 unit_vertex, Bounds bounds) { - float2 position = unit_vertex * bounds.size + bounds.origin; - return to_device_position_impl(position); -} - -float4 distance_from_clip_rect_impl(float2 position, Bounds clip_bounds) { - float2 tl = position - clip_bounds.origin; - float2 br = clip_bounds.origin + clip_bounds.size - position; - return float4(tl.x, br.x, tl.y, br.y); -} - -float4 distance_from_clip_rect(float2 unit_vertex, Bounds bounds, Bounds clip_bounds) { - float2 position = unit_vertex * bounds.size + bounds.origin; - return distance_from_clip_rect_impl(position, clip_bounds); -} - -// Convert linear RGB to sRGB -float3 linear_to_srgb(float3 color) { - return pow(color, float3(2.2, 2.2, 2.2)); -} - -// Convert sRGB to linear RGB -float3 srgb_to_linear(float3 color) { - return pow(color, float3(1.0 / 2.2, 1.0 / 2.2, 1.0 / 2.2)); -} - -/// Hsla to linear RGBA conversion. -float4 hsla_to_rgba(Hsla hsla) { - float h = hsla.h * 6.0; // Now, it's an angle but scaled in [0, 6) range - float s = hsla.s; - float l = hsla.l; - float a = hsla.a; - - float c = (1.0 - abs(2.0 * l - 1.0)) * s; - float x = c * (1.0 - abs(fmod(h, 2.0) - 1.0)); - float m = l - c / 2.0; - - float r = 0.0; - float g = 0.0; - float b = 0.0; - - if (h >= 0.0 && h < 1.0) { - r = c; - g = x; - b = 0.0; - } else if (h >= 1.0 && h < 2.0) { - r = x; - g = c; - b = 0.0; - } else if (h >= 2.0 && h < 3.0) { - r = 0.0; - g = c; - b = x; - } else if (h >= 3.0 && h < 4.0) { - r = 0.0; - g = x; - b = c; - } else if (h >= 4.0 && h < 5.0) { - r = x; - g = 0.0; - b = c; - } else { - r = c; - g = 0.0; - b = x; - } - - float4 rgba; - rgba.x = (r + m); - rgba.y = (g + m); - rgba.z = (b + m); - rgba.w = a; - return rgba; -} - -// Converts a sRGB color to the Oklab color space. -// Reference: https://bottosson.github.io/posts/oklab/#converting-from-linear-srgb-to-oklab -float4 srgb_to_oklab(float4 color) { - // Convert non-linear sRGB to linear sRGB - color = float4(srgb_to_linear(color.rgb), color.a); - - float l = 0.4122214708 * color.r + 0.5363325363 * color.g + 0.0514459929 * color.b; - float m = 0.2119034982 * color.r + 0.6806995451 * color.g + 0.1073969566 * color.b; - float s = 0.0883024619 * color.r + 0.2817188376 * color.g + 0.6299787005 * color.b; - - float l_ = pow(l, 1.0/3.0); - float m_ = pow(m, 1.0/3.0); - float s_ = pow(s, 1.0/3.0); - - return float4( - 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, - 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, - 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, - color.a - ); -} - -// Converts an Oklab color to the sRGB color space. -float4 oklab_to_srgb(float4 color) { - float l_ = color.r + 0.3963377774 * color.g + 0.2158037573 * color.b; - float m_ = color.r - 0.1055613458 * color.g - 0.0638541728 * color.b; - float s_ = color.r - 0.0894841775 * color.g - 1.2914855480 * color.b; - - float l = l_ * l_ * l_; - float m = m_ * m_ * m_; - float s = s_ * s_ * s_; - - float3 linear_rgb = float3( - 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s, - -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s, - -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s - ); - - // Convert linear sRGB to non-linear sRGB - return float4(linear_to_srgb(linear_rgb), color.a); -} - -// This approximates the error function, needed for the gaussian integral -float2 erf(float2 x) { - float2 s = sign(x); - float2 a = abs(x); - x = 1. + (0.278393 + (0.230389 + 0.078108 * (a * a)) * a) * a; - x *= x; - return s - s / (x * x); -} - -float blur_along_x(float x, float y, float sigma, float corner, float2 half_size) { - float delta = min(half_size.y - corner - abs(y), 0.); - float curved = half_size.x - corner + sqrt(max(0., corner * corner - delta * delta)); - float2 integral = 0.5 + 0.5 * erf((x + float2(-curved, curved)) * (sqrt(0.5) / sigma)); - return integral.y - integral.x; -} - -// A standard gaussian function, used for weighting samples -float gaussian(float x, float sigma) { - return exp(-(x * x) / (2. * sigma * sigma)) / (sqrt(2. * M_PI_F) * sigma); -} - -float4 over(float4 below, float4 above) { - float4 result; - float alpha = above.a + below.a * (1.0 - above.a); - result.rgb = (above.rgb * above.a + below.rgb * below.a * (1.0 - above.a)) / alpha; - result.a = alpha; - return result; -} - -float2 to_tile_position(float2 unit_vertex, AtlasTile tile) { - float2 atlas_size; - t_sprite.GetDimensions(atlas_size.x, atlas_size.y); - return (float2(tile.bounds.origin) + unit_vertex * float2(tile.bounds.size)) / atlas_size; -} - -// Selects corner radius based on quadrant. -float pick_corner_radius(float2 center_to_point, Corners corner_radii) { - if (center_to_point.x < 0.) { - if (center_to_point.y < 0.) { - return corner_radii.top_left; - } else { - return corner_radii.bottom_left; - } - } else { - if (center_to_point.y < 0.) { - return corner_radii.top_right; - } else { - return corner_radii.bottom_right; - } - } -} - -float4 to_device_position_transformed(float2 unit_vertex, Bounds bounds, - TransformationMatrix transformation) { - float2 position = unit_vertex * bounds.size + bounds.origin; - float2 transformed = mul(position, transformation.rotation_scale) + transformation.translation; - float2 device_position = transformed / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); - return float4(device_position, 0.0, 1.0); -} - -// Implementation of quad signed distance field -float quad_sdf_impl(float2 corner_center_to_point, float corner_radius) { - if (corner_radius == 0.0) { - // Fast path for unrounded corners - return max(corner_center_to_point.x, corner_center_to_point.y); - } else { - // Signed distance of the point from a quad that is inset by corner_radius - // It is negative inside this quad, and positive outside - float signed_distance_to_inset_quad = - // 0 inside the inset quad, and positive outside - length(max(float2(0.0, 0.0), corner_center_to_point)) + - // 0 outside the inset quad, and negative inside - min(0.0, max(corner_center_to_point.x, corner_center_to_point.y)); - - return signed_distance_to_inset_quad - corner_radius; - } -} - -float quad_sdf(float2 pt, Bounds bounds, Corners corner_radii) { - float2 half_size = bounds.size / 2.; - float2 center = bounds.origin + half_size; - float2 center_to_point = pt - center; - float corner_radius = pick_corner_radius(center_to_point, corner_radii); - float2 corner_to_point = abs(center_to_point) - half_size; - float2 corner_center_to_point = corner_to_point + corner_radius; - return quad_sdf_impl(corner_center_to_point, corner_radius); -} - -GradientColor prepare_gradient_color(uint tag, uint color_space, Hsla solid, LinearColorStop colors[2]) { - GradientColor output; - if (tag == 0 || tag == 2) { - output.solid = hsla_to_rgba(solid); - } else if (tag == 1) { - output.color0 = hsla_to_rgba(colors[0].color); - output.color1 = hsla_to_rgba(colors[1].color); - - // Prepare color space in vertex for avoid conversion - // in fragment shader for performance reasons - if (color_space == 1) { - // Oklab - output.color0 = srgb_to_oklab(output.color0); - output.color1 = srgb_to_oklab(output.color1); - } - } - - return output; -} - -float2x2 rotate2d(float angle) { - float s = sin(angle); - float c = cos(angle); - return float2x2(c, -s, s, c); -} - -float4 gradient_color(Background background, - float2 position, - Bounds bounds, - float4 solid_color, float4 color0, float4 color1) { - float4 color; - - switch (background.tag) { - case 0: - color = solid_color; - break; - case 1: { - // -90 degrees to match the CSS gradient angle. - float gradient_angle = background.gradient_angle_or_pattern_height; - float radians = (fmod(gradient_angle, 360.0) - 90.0) * (M_PI_F / 180.0); - float2 direction = float2(cos(radians), sin(radians)); - - // Expand the short side to be the same as the long side - if (bounds.size.x > bounds.size.y) { - direction.y *= bounds.size.y / bounds.size.x; - } else { - direction.x *= bounds.size.x / bounds.size.y; - } - - // Get the t value for the linear gradient with the color stop percentages. - float2 half_size = bounds.size * 0.5; - float2 center = bounds.origin + half_size; - float2 center_to_point = position - center; - float t = dot(center_to_point, direction) / length(direction); - // Check the direct to determine the use x or y - if (abs(direction.x) > abs(direction.y)) { - t = (t + half_size.x) / bounds.size.x; - } else { - t = (t + half_size.y) / bounds.size.y; - } - - // Adjust t based on the stop percentages - t = (t - background.colors[0].percentage) - / (background.colors[1].percentage - - background.colors[0].percentage); - t = clamp(t, 0.0, 1.0); - - switch (background.color_space) { - case 0: - color = lerp(color0, color1, t); - break; - case 1: { - float4 oklab_color = lerp(color0, color1, t); - color = oklab_to_srgb(oklab_color); - break; - } - } - break; - } - case 2: { - float gradient_angle_or_pattern_height = background.gradient_angle_or_pattern_height; - float pattern_width = (gradient_angle_or_pattern_height / 65535.0f) / 255.0f; - float pattern_interval = fmod(gradient_angle_or_pattern_height, 65535.0f) / 255.0f; - float pattern_height = pattern_width + pattern_interval; - float stripe_angle = M_PI_F / 4.0; - float pattern_period = pattern_height * sin(stripe_angle); - float2x2 rotation = rotate2d(stripe_angle); - float2 relative_position = position - bounds.origin; - float2 rotated_point = mul(rotation, relative_position); - float pattern = fmod(rotated_point.x, pattern_period); - float distance = min(pattern, pattern_period - pattern) - pattern_period * (pattern_width / pattern_height) / 2.0f; - color = solid_color; - color.a *= saturate(0.5 - distance); - break; - } - } - - return color; -} - -// Returns the dash velocity of a corner given the dash velocity of the two -// sides, by returning the slower velocity (larger dashes). -// -// Since 0 is used for dash velocity when the border width is 0 (instead of -// +inf), this returns the other dash velocity in that case. -// -// An alternative to this might be to appropriately interpolate the dash -// velocity around the corner, but that seems overcomplicated. -float corner_dash_velocity(float dv1, float dv2) { - if (dv1 == 0.0) { - return dv2; - } else if (dv2 == 0.0) { - return dv1; - } else { - return min(dv1, dv2); - } -} - -// Returns alpha used to render antialiased dashes. -// `t` is within the dash when `fmod(t, period) < length`. -float dash_alpha( - float t, float period, float length, float dash_velocity, - float antialias_threshold -) { - float half_period = period / 2.0; - float half_length = length / 2.0; - // Value in [-half_period, half_period] - // The dash is in [-half_length, half_length] - float centered = fmod(t + half_period - half_length, period) - half_period; - // Signed distance for the dash, negative values are inside the dash - float signed_distance = abs(centered) - half_length; - // Antialiased alpha based on the signed distance - return saturate(antialias_threshold - signed_distance / dash_velocity); -} - -// This approximates distance to the nearest point to a quarter ellipse in a way -// that is sufficient for anti-aliasing when the ellipse is not very eccentric. -// The components of `point` are expected to be positive. -// -// Negative on the outside and positive on the inside. -float quarter_ellipse_sdf(float2 pt, float2 radii) { - // Scale the space to treat the ellipse like a unit circle - float2 circle_vec = pt / radii; - float unit_circle_sdf = length(circle_vec) - 1.0; - // Approximate up-scaling of the length by using the average of the radii. - // - // TODO: A better solution would be to use the gradient of the implicit - // function for an ellipse to approximate a scaling factor. - return unit_circle_sdf * (radii.x + radii.y) * -0.5; -} - -/* -** -** Quads -** -*/ - -struct Quad { - uint order; - uint border_style; - Bounds bounds; - Bounds content_mask; - Background background; - Hsla border_color; - Corners corner_radii; - Edges border_widths; -}; - -struct QuadVertexOutput { - nointerpolation uint quad_id: TEXCOORD0; - float4 position: SV_Position; - nointerpolation float4 border_color: COLOR0; - nointerpolation float4 background_solid: COLOR1; - nointerpolation float4 background_color0: COLOR2; - nointerpolation float4 background_color1: COLOR3; - float4 clip_distance: SV_ClipDistance; -}; - -struct QuadFragmentInput { - nointerpolation uint quad_id: TEXCOORD0; - float4 position: SV_Position; - nointerpolation float4 border_color: COLOR0; - nointerpolation float4 background_solid: COLOR1; - nointerpolation float4 background_color0: COLOR2; - nointerpolation float4 background_color1: COLOR3; -}; - -StructuredBuffer<Quad> quads: register(t1); - -QuadVertexOutput quad_vertex(uint vertex_id: SV_VertexID, uint quad_id: SV_InstanceID) { - float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); - Quad quad = quads[quad_id]; - float4 device_position = to_device_position(unit_vertex, quad.bounds); - - GradientColor gradient = prepare_gradient_color( - quad.background.tag, - quad.background.color_space, - quad.background.solid, - quad.background.colors - ); - float4 clip_distance = distance_from_clip_rect(unit_vertex, quad.bounds, quad.content_mask); - float4 border_color = hsla_to_rgba(quad.border_color); - - QuadVertexOutput output; - output.position = device_position; - output.border_color = border_color; - output.quad_id = quad_id; - output.background_solid = gradient.solid; - output.background_color0 = gradient.color0; - output.background_color1 = gradient.color1; - output.clip_distance = clip_distance; - return output; -} - -float4 quad_fragment(QuadFragmentInput input): SV_Target { - Quad quad = quads[input.quad_id]; - float4 background_color = gradient_color(quad.background, input.position.xy, quad.bounds, - input.background_solid, input.background_color0, input.background_color1); - - bool unrounded = quad.corner_radii.top_left == 0.0 && - quad.corner_radii.top_right == 0.0 && - quad.corner_radii.bottom_left == 0.0 && - quad.corner_radii.bottom_right == 0.0; - - // Fast path when the quad is not rounded and doesn't have any border - if (quad.border_widths.top == 0.0 && - quad.border_widths.left == 0.0 && - quad.border_widths.right == 0.0 && - quad.border_widths.bottom == 0.0 && - unrounded) { - return background_color; - } - - float2 size = quad.bounds.size; - float2 half_size = size / 2.; - float2 the_point = input.position.xy - quad.bounds.origin; - float2 center_to_point = the_point - half_size; - - // Signed distance field threshold for inclusion of pixels. 0.5 is the - // minimum distance between the center of the pixel and the edge. - const float antialias_threshold = 0.5; - - // Radius of the nearest corner - float corner_radius = pick_corner_radius(center_to_point, quad.corner_radii); - - float2 border = float2( - center_to_point.x < 0.0 ? quad.border_widths.left : quad.border_widths.right, - center_to_point.y < 0.0 ? quad.border_widths.top : quad.border_widths.bottom - ); - - // 0-width borders are reduced so that `inner_sdf >= antialias_threshold`. - // The purpose of this is to not draw antialiasing pixels in this case. - float2 reduced_border = float2( - border.x == 0.0 ? -antialias_threshold : border.x, - border.y == 0.0 ? -antialias_threshold : border.y - ); - - // Vector from the corner of the quad bounds to the point, after mirroring - // the point into the bottom right quadrant. Both components are <= 0. - float2 corner_to_point = abs(center_to_point) - half_size; - - // Vector from the point to the center of the rounded corner's circle, also - // mirrored into bottom right quadrant. - float2 corner_center_to_point = corner_to_point + corner_radius; - - // Whether the nearest point on the border is rounded - bool is_near_rounded_corner = - corner_center_to_point.x >= 0.0 && - corner_center_to_point.y >= 0.0; - - // Vector from straight border inner corner to point. - // - // 0-width borders are turned into width -1 so that inner_sdf is > 1.0 near - // the border. Without this, antialiasing pixels would be drawn. - float2 straight_border_inner_corner_to_point = corner_to_point + reduced_border; - - // Whether the point is beyond the inner edge of the straight border - bool is_beyond_inner_straight_border = - straight_border_inner_corner_to_point.x > 0.0 || - straight_border_inner_corner_to_point.y > 0.0; - - // Whether the point is far enough inside the quad, such that the pixels are - // not affected by the straight border. - bool is_within_inner_straight_border = - straight_border_inner_corner_to_point.x < -antialias_threshold && - straight_border_inner_corner_to_point.y < -antialias_threshold; - - // Fast path for points that must be part of the background - if (is_within_inner_straight_border && !is_near_rounded_corner) { - return background_color; - } - - // Signed distance of the point to the outside edge of the quad's border - float outer_sdf = quad_sdf_impl(corner_center_to_point, corner_radius); - - // Approximate signed distance of the point to the inside edge of the quad's - // border. It is negative outside this edge (within the border), and - // positive inside. - // - // This is not always an accurate signed distance: - // * The rounded portions with varying border width use an approximation of - // nearest-point-on-ellipse. - // * When it is quickly known to be outside the edge, -1.0 is used. - float inner_sdf = 0.0; - if (corner_center_to_point.x <= 0.0 || corner_center_to_point.y <= 0.0) { - // Fast paths for straight borders - inner_sdf = -max(straight_border_inner_corner_to_point.x, - straight_border_inner_corner_to_point.y); - } else if (is_beyond_inner_straight_border) { - // Fast path for points that must be outside the inner edge - inner_sdf = -1.0; - } else if (reduced_border.x == reduced_border.y) { - // Fast path for circular inner edge. - inner_sdf = -(outer_sdf + reduced_border.x); - } else { - float2 ellipse_radii = max(float2(0.0, 0.0), float2(corner_radius, corner_radius) - reduced_border); - inner_sdf = quarter_ellipse_sdf(corner_center_to_point, ellipse_radii); - } - - // Negative when inside the border - float border_sdf = max(inner_sdf, outer_sdf); - - float4 color = background_color; - if (border_sdf < antialias_threshold) { - float4 border_color = input.border_color; - // Dashed border logic when border_style == 1 - if (quad.border_style == 1) { - // Position along the perimeter in "dash space", where each dash - // period has length 1 - float t = 0.0; - - // Total number of dash periods, so that the dash spacing can be - // adjusted to evenly divide it - float max_t = 0.0; - - // Border width is proportional to dash size. This is the behavior - // used by browsers, but also avoids dashes from different segments - // overlapping when dash size is smaller than the border width. - // - // Dash pattern: (2 * border width) dash, (1 * border width) gap - const float dash_length_per_width = 2.0; - const float dash_gap_per_width = 1.0; - const float dash_period_per_width = dash_length_per_width + dash_gap_per_width; - - // Since the dash size is determined by border width, the density of - // dashes varies. Multiplying a pixel distance by this returns a - // position in dash space - it has units (dash period / pixels). So - // a dash velocity of (1 / 10) is 1 dash every 10 pixels. - float dash_velocity = 0.0; - - // Dividing this by the border width gives the dash velocity - const float dv_numerator = 1.0 / dash_period_per_width; - - if (unrounded) { - // When corners aren't rounded, the dashes are separately laid - // out on each straight line, rather than around the whole - // perimeter. This way each line starts and ends with a dash. - bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; - float border_width = is_horizontal ? border.x : border.y; - dash_velocity = dv_numerator / border_width; - t = is_horizontal ? the_point.x : the_point.y; - t *= dash_velocity; - max_t = is_horizontal ? size.x : size.y; - max_t *= dash_velocity; - } else { - // When corners are rounded, the dashes are laid out clockwise - // around the whole perimeter. - - float r_tr = quad.corner_radii.top_right; - float r_br = quad.corner_radii.bottom_right; - float r_bl = quad.corner_radii.bottom_left; - float r_tl = quad.corner_radii.top_left; - - float w_t = quad.border_widths.top; - float w_r = quad.border_widths.right; - float w_b = quad.border_widths.bottom; - float w_l = quad.border_widths.left; - - // Straight side dash velocities - float dv_t = w_t <= 0.0 ? 0.0 : dv_numerator / w_t; - float dv_r = w_r <= 0.0 ? 0.0 : dv_numerator / w_r; - float dv_b = w_b <= 0.0 ? 0.0 : dv_numerator / w_b; - float dv_l = w_l <= 0.0 ? 0.0 : dv_numerator / w_l; - - // Straight side lengths in dash space - float s_t = (size.x - r_tl - r_tr) * dv_t; - float s_r = (size.y - r_tr - r_br) * dv_r; - float s_b = (size.x - r_br - r_bl) * dv_b; - float s_l = (size.y - r_bl - r_tl) * dv_l; - - float corner_dash_velocity_tr = corner_dash_velocity(dv_t, dv_r); - float corner_dash_velocity_br = corner_dash_velocity(dv_b, dv_r); - float corner_dash_velocity_bl = corner_dash_velocity(dv_b, dv_l); - float corner_dash_velocity_tl = corner_dash_velocity(dv_t, dv_l); - - // Corner lengths in dash space - float c_tr = r_tr * (M_PI_F / 2.0) * corner_dash_velocity_tr; - float c_br = r_br * (M_PI_F / 2.0) * corner_dash_velocity_br; - float c_bl = r_bl * (M_PI_F / 2.0) * corner_dash_velocity_bl; - float c_tl = r_tl * (M_PI_F / 2.0) * corner_dash_velocity_tl; - - // Cumulative dash space upto each segment - float upto_tr = s_t; - float upto_r = upto_tr + c_tr; - float upto_br = upto_r + s_r; - float upto_b = upto_br + c_br; - float upto_bl = upto_b + s_b; - float upto_l = upto_bl + c_bl; - float upto_tl = upto_l + s_l; - max_t = upto_tl + c_tl; - - if (is_near_rounded_corner) { - float radians = atan2(corner_center_to_point.y, corner_center_to_point.x); - float corner_t = radians * corner_radius; - - if (center_to_point.x >= 0.0) { - if (center_to_point.y < 0.0) { - dash_velocity = corner_dash_velocity_tr; - // Subtracted because radians is pi/2 to 0 when - // going clockwise around the top right corner, - // since the y axis has been flipped - t = upto_r - corner_t * dash_velocity; - } else { - dash_velocity = corner_dash_velocity_br; - // Added because radians is 0 to pi/2 when going - // clockwise around the bottom-right corner - t = upto_br + corner_t * dash_velocity; - } - } else { - if (center_to_point.y >= 0.0) { - dash_velocity = corner_dash_velocity_bl; - // Subtracted because radians is pi/1 to 0 when - // going clockwise around the bottom-left corner, - // since the x axis has been flipped - t = upto_l - corner_t * dash_velocity; - } else { - dash_velocity = corner_dash_velocity_tl; - // Added because radians is 0 to pi/2 when going - // clockwise around the top-left corner, since both - // axis were flipped - t = upto_tl + corner_t * dash_velocity; - } - } - } else { - // Straight borders - bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; - if (is_horizontal) { - if (center_to_point.y < 0.0) { - dash_velocity = dv_t; - t = (the_point.x - r_tl) * dash_velocity; - } else { - dash_velocity = dv_b; - t = upto_bl - (the_point.x - r_bl) * dash_velocity; - } - } else { - if (center_to_point.x < 0.0) { - dash_velocity = dv_l; - t = upto_tl - (the_point.y - r_tl) * dash_velocity; - } else { - dash_velocity = dv_r; - t = upto_r + (the_point.y - r_tr) * dash_velocity; - } - } - } - } - float dash_length = dash_length_per_width / dash_period_per_width; - float desired_dash_gap = dash_gap_per_width / dash_period_per_width; - - // Straight borders should start and end with a dash, so max_t is - // reduced to cause this. - max_t -= unrounded ? dash_length : 0.0; - if (max_t >= 1.0) { - // Adjust dash gap to evenly divide max_t - float dash_count = floor(max_t); - float dash_period = max_t / dash_count; - border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); - } else if (unrounded) { - // When there isn't enough space for the full gap between the - // two start / end dashes of a straight border, reduce gap to - // make them fit. - float dash_gap = max_t - dash_length; - if (dash_gap > 0.0) { - float dash_period = dash_length + dash_gap; - border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); - } - } - } - - // Blend the border on top of the background and then linearly interpolate - // between the two as we slide inside the background. - float4 blended_border = over(background_color, border_color); - color = lerp(background_color, blended_border, - saturate(antialias_threshold - inner_sdf)); - } - - return color * float4(1.0, 1.0, 1.0, saturate(antialias_threshold - outer_sdf)); -} - -/* -** -** Shadows -** -*/ - -struct Shadow { - uint order; - float blur_radius; - Bounds bounds; - Corners corner_radii; - Bounds content_mask; - Hsla color; -}; - -struct ShadowVertexOutput { - nointerpolation uint shadow_id: TEXCOORD0; - float4 position: SV_Position; - nointerpolation float4 color: COLOR; - float4 clip_distance: SV_ClipDistance; -}; - -struct ShadowFragmentInput { - nointerpolation uint shadow_id: TEXCOORD0; - float4 position: SV_Position; - nointerpolation float4 color: COLOR; -}; - -StructuredBuffer<Shadow> shadows: register(t1); - -ShadowVertexOutput shadow_vertex(uint vertex_id: SV_VertexID, uint shadow_id: SV_InstanceID) { - float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); - Shadow shadow = shadows[shadow_id]; - - float margin = 3.0 * shadow.blur_radius; - Bounds bounds = shadow.bounds; - bounds.origin -= margin; - bounds.size += 2.0 * margin; - - float4 device_position = to_device_position(unit_vertex, bounds); - float4 clip_distance = distance_from_clip_rect(unit_vertex, bounds, shadow.content_mask); - float4 color = hsla_to_rgba(shadow.color); - - ShadowVertexOutput output; - output.position = device_position; - output.color = color; - output.shadow_id = shadow_id; - output.clip_distance = clip_distance; - - return output; -} - -float4 shadow_fragment(ShadowFragmentInput input): SV_TARGET { - Shadow shadow = shadows[input.shadow_id]; - - float2 half_size = shadow.bounds.size / 2.; - float2 center = shadow.bounds.origin + half_size; - float2 point0 = input.position.xy - center; - float corner_radius = pick_corner_radius(point0, shadow.corner_radii); - - // The signal is only non-zero in a limited range, so don't waste samples - float low = point0.y - half_size.y; - float high = point0.y + half_size.y; - float start = clamp(-3. * shadow.blur_radius, low, high); - float end = clamp(3. * shadow.blur_radius, low, high); - - // Accumulate samples (we can get away with surprisingly few samples) - float step = (end - start) / 4.; - float y = start + step * 0.5; - float alpha = 0.; - for (int i = 0; i < 4; i++) { - alpha += blur_along_x(point0.x, point0.y - y, shadow.blur_radius, - corner_radius, half_size) * - gaussian(y, shadow.blur_radius) * step; - y += step; - } - - return input.color * float4(1., 1., 1., alpha); -} - -/* -** -** Path Rasterization -** -*/ - -struct PathRasterizationSprite { - float2 xy_position; - float2 st_position; - Background color; - Bounds bounds; -}; - -StructuredBuffer<PathRasterizationSprite> path_rasterization_sprites: register(t1); - -struct PathVertexOutput { - float4 position: SV_Position; - float2 st_position: TEXCOORD0; - nointerpolation uint vertex_id: TEXCOORD1; - float4 clip_distance: SV_ClipDistance; -}; - -struct PathFragmentInput { - float4 position: SV_Position; - float2 st_position: TEXCOORD0; - nointerpolation uint vertex_id: TEXCOORD1; -}; - -PathVertexOutput path_rasterization_vertex(uint vertex_id: SV_VertexID) { - PathRasterizationSprite sprite = path_rasterization_sprites[vertex_id]; - - PathVertexOutput output; - output.position = to_device_position_impl(sprite.xy_position); - output.st_position = sprite.st_position; - output.vertex_id = vertex_id; - output.clip_distance = distance_from_clip_rect_impl(sprite.xy_position, sprite.bounds); - - return output; -} - -float4 path_rasterization_fragment(PathFragmentInput input): SV_Target { - float2 dx = ddx(input.st_position); - float2 dy = ddy(input.st_position); - PathRasterizationSprite sprite = path_rasterization_sprites[input.vertex_id]; - - Background background = sprite.color; - Bounds bounds = sprite.bounds; - - float alpha; - if (length(float2(dx.x, dy.x))) { - alpha = 1.0; - } else { - float2 gradient = 2.0 * input.st_position.xx * float2(dx.x, dy.x) - float2(dx.y, dy.y); - float f = input.st_position.x * input.st_position.x - input.st_position.y; - float distance = f / length(gradient); - alpha = saturate(0.5 - distance); - } - - GradientColor gradient = prepare_gradient_color( - background.tag, background.color_space, background.solid, background.colors); - - float4 color = gradient_color(background, input.position.xy, bounds, - gradient.solid, gradient.color0, gradient.color1); - return float4(color.rgb * color.a * alpha, alpha * color.a); -} - -/* -** -** Path Sprites -** -*/ - -struct PathSprite { - Bounds bounds; -}; - -struct PathSpriteVertexOutput { - float4 position: SV_Position; - float2 texture_coords: TEXCOORD0; -}; - -StructuredBuffer<PathSprite> path_sprites: register(t1); - -PathSpriteVertexOutput path_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { - float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); - PathSprite sprite = path_sprites[sprite_id]; - - // Don't apply content mask because it was already accounted for when rasterizing the path - float4 device_position = to_device_position(unit_vertex, sprite.bounds); - - float2 screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size; - float2 texture_coords = screen_position / global_viewport_size; - - PathSpriteVertexOutput output; - output.position = device_position; - output.texture_coords = texture_coords; - return output; -} - -float4 path_sprite_fragment(PathSpriteVertexOutput input): SV_Target { - return t_sprite.Sample(s_sprite, input.texture_coords); -} - -/* -** -** Underlines -** -*/ - -struct Underline { - uint order; - uint pad; - Bounds bounds; - Bounds content_mask; - Hsla color; - float thickness; - uint wavy; -}; - -struct UnderlineVertexOutput { - nointerpolation uint underline_id: TEXCOORD0; - float4 position: SV_Position; - nointerpolation float4 color: COLOR; - float4 clip_distance: SV_ClipDistance; -}; - -struct UnderlineFragmentInput { - nointerpolation uint underline_id: TEXCOORD0; - float4 position: SV_Position; - nointerpolation float4 color: COLOR; -}; - -StructuredBuffer<Underline> underlines: register(t1); - -UnderlineVertexOutput underline_vertex(uint vertex_id: SV_VertexID, uint underline_id: SV_InstanceID) { - float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); - Underline underline = underlines[underline_id]; - float4 device_position = to_device_position(unit_vertex, underline.bounds); - float4 clip_distance = distance_from_clip_rect(unit_vertex, underline.bounds, - underline.content_mask); - float4 color = hsla_to_rgba(underline.color); - - UnderlineVertexOutput output; - output.position = device_position; - output.color = color; - output.underline_id = underline_id; - output.clip_distance = clip_distance; - return output; -} - -float4 underline_fragment(UnderlineFragmentInput input): SV_Target { - Underline underline = underlines[input.underline_id]; - if (underline.wavy) { - float half_thickness = underline.thickness * 0.5; - float2 origin = underline.bounds.origin; - float2 st = ((input.position.xy - origin) / underline.bounds.size.y) - float2(0., 0.5); - float frequency = (M_PI_F * (3. * underline.thickness)) / 8.; - float amplitude = 1. / (2. * underline.thickness); - float sine = sin(st.x * frequency) * amplitude; - float dSine = cos(st.x * frequency) * amplitude * frequency; - float distance = (st.y - sine) / sqrt(1. + dSine * dSine); - float distance_in_pixels = distance * underline.bounds.size.y; - float distance_from_top_border = distance_in_pixels - half_thickness; - float distance_from_bottom_border = distance_in_pixels + half_thickness; - float alpha = saturate( - 0.5 - max(-distance_from_bottom_border, distance_from_top_border)); - return input.color * float4(1., 1., 1., alpha); - } else { - return input.color; - } -} - -/* -** -** Monochrome sprites -** -*/ - -struct MonochromeSprite { - uint order; - uint pad; - Bounds bounds; - Bounds content_mask; - Hsla color; - AtlasTile tile; - TransformationMatrix transformation; -}; - -struct MonochromeSpriteVertexOutput { - float4 position: SV_Position; - float2 tile_position: POSITION; - nointerpolation float4 color: COLOR; - float4 clip_distance: SV_ClipDistance; -}; - -struct MonochromeSpriteFragmentInput { - float4 position: SV_Position; - float2 tile_position: POSITION; - nointerpolation float4 color: COLOR; - float4 clip_distance: SV_ClipDistance; -}; - -StructuredBuffer<MonochromeSprite> mono_sprites: register(t1); - -MonochromeSpriteVertexOutput monochrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { - float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); - MonochromeSprite sprite = mono_sprites[sprite_id]; - float4 device_position = - to_device_position_transformed(unit_vertex, sprite.bounds, sprite.transformation); - float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, sprite.content_mask); - float2 tile_position = to_tile_position(unit_vertex, sprite.tile); - float4 color = hsla_to_rgba(sprite.color); - - MonochromeSpriteVertexOutput output; - output.position = device_position; - output.tile_position = tile_position; - output.color = color; - output.clip_distance = clip_distance; - return output; -} - -float4 monochrome_sprite_fragment(MonochromeSpriteFragmentInput input): SV_Target { - float sample = t_sprite.Sample(s_sprite, input.tile_position).r; - return float4(input.color.rgb, input.color.a * sample); -} - -/* -** -** Polychrome sprites -** -*/ - -struct PolychromeSprite { - uint order; - uint pad; - uint grayscale; - float opacity; - Bounds bounds; - Bounds content_mask; - Corners corner_radii; - AtlasTile tile; -}; - -struct PolychromeSpriteVertexOutput { - nointerpolation uint sprite_id: TEXCOORD0; - float4 position: SV_Position; - float2 tile_position: POSITION; - float4 clip_distance: SV_ClipDistance; -}; - -struct PolychromeSpriteFragmentInput { - nointerpolation uint sprite_id: TEXCOORD0; - float4 position: SV_Position; - float2 tile_position: POSITION; -}; - -StructuredBuffer<PolychromeSprite> poly_sprites: register(t1); - -PolychromeSpriteVertexOutput polychrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { - float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); - PolychromeSprite sprite = poly_sprites[sprite_id]; - float4 device_position = to_device_position(unit_vertex, sprite.bounds); - float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, - sprite.content_mask); - float2 tile_position = to_tile_position(unit_vertex, sprite.tile); - - PolychromeSpriteVertexOutput output; - output.position = device_position; - output.tile_position = tile_position; - output.sprite_id = sprite_id; - output.clip_distance = clip_distance; - return output; -} - -float4 polychrome_sprite_fragment(PolychromeSpriteFragmentInput input): SV_Target { - PolychromeSprite sprite = poly_sprites[input.sprite_id]; - float4 sample = t_sprite.Sample(s_sprite, input.tile_position); - float distance = quad_sdf(input.position.xy, sprite.bounds, sprite.corner_radii); - - float4 color = sample; - if ((sprite.grayscale & 0xFFu) != 0u) { - float3 grayscale = dot(color.rgb, GRAYSCALE_FACTORS); - color = float4(grayscale, sample.a); - } - color.a *= sprite.opacity * saturate(0.5 - distance); - return color; -} diff --git a/crates/gpui/src/platform/windows/window.rs b/crates/gpui/src/platform/windows/window.rs index 68b667569b..5703a82815 100644 --- a/crates/gpui/src/platform/windows/window.rs +++ b/crates/gpui/src/platform/windows/window.rs @@ -26,6 +26,7 @@ use windows::{ core::*, }; +use crate::platform::blade::{BladeContext, BladeRenderer}; use crate::*; pub(crate) struct WindowsWindow(pub Rc<WindowsWindowStatePtr>); @@ -48,7 +49,7 @@ pub struct WindowsWindowState { pub system_key_handled: bool, pub hovered: bool, - pub renderer: DirectXRenderer, + pub renderer: BladeRenderer, pub click_state: ClickState, pub system_settings: WindowsSystemSettings, @@ -79,12 +80,13 @@ pub(crate) struct WindowsWindowStatePtr { impl WindowsWindowState { fn new( hwnd: HWND, + transparent: bool, cs: &CREATESTRUCTW, current_cursor: Option<HCURSOR>, display: WindowsDisplay, + gpu_context: &BladeContext, min_size: Option<Size<Pixels>>, appearance: WindowAppearance, - disable_direct_composition: bool, ) -> Result<Self> { let scale_factor = { let monitor_dpi = unsafe { GetDpiForWindow(hwnd) } as f32; @@ -101,8 +103,7 @@ impl WindowsWindowState { }; let border_offset = WindowBorderOffset::default(); let restore_from_minimized = None; - let renderer = DirectXRenderer::new(hwnd, disable_direct_composition) - .context("Creating DirectX renderer")?; + let renderer = windows_renderer::init(gpu_context, hwnd, transparent)?; let callbacks = Callbacks::default(); let input_handler = None; let pending_surrogate = None; @@ -205,12 +206,13 @@ impl WindowsWindowStatePtr { fn new(context: &WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result<Rc<Self>> { let state = RefCell::new(WindowsWindowState::new( hwnd, + context.transparent, cs, context.current_cursor, context.display, + context.gpu_context, context.min_size, context.appearance, - context.disable_direct_composition, )?); Ok(Rc::new_cyclic(|this| Self { @@ -327,11 +329,12 @@ pub(crate) struct Callbacks { pub(crate) appearance_changed: Option<Box<dyn FnMut()>>, } -struct WindowCreateContext { +struct WindowCreateContext<'a> { inner: Option<Result<Rc<WindowsWindowStatePtr>>>, handle: AnyWindowHandle, hide_title_bar: bool, display: WindowsDisplay, + transparent: bool, is_movable: bool, min_size: Option<Size<Pixels>>, executor: ForegroundExecutor, @@ -340,9 +343,9 @@ struct WindowCreateContext { drop_target_helper: IDropTargetHelper, validation_number: usize, main_receiver: flume::Receiver<Runnable>, + gpu_context: &'a BladeContext, main_thread_id_win32: u32, appearance: WindowAppearance, - disable_direct_composition: bool, } impl WindowsWindow { @@ -350,6 +353,7 @@ impl WindowsWindow { handle: AnyWindowHandle, params: WindowParams, creation_info: WindowCreationInfo, + gpu_context: &BladeContext, ) -> Result<Self> { let WindowCreationInfo { icon, @@ -360,7 +364,6 @@ impl WindowsWindow { validation_number, main_receiver, main_thread_id_win32, - disable_direct_composition, } = creation_info; let classname = register_wnd_class(icon); let hide_title_bar = params @@ -376,18 +379,14 @@ impl WindowsWindow { .map(|title| title.as_ref()) .unwrap_or(""), ); - - let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp { - (WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0)) + let (dwexstyle, mut dwstyle) = if params.kind == WindowKind::PopUp { + (WS_EX_TOOLWINDOW | WS_EX_LAYERED, WINDOW_STYLE(0x0)) } else { ( - WS_EX_APPWINDOW, + WS_EX_APPWINDOW | WS_EX_LAYERED, WS_THICKFRAME | WS_SYSMENU | WS_MAXIMIZEBOX | WS_MINIMIZEBOX, ) }; - if !disable_direct_composition { - dwexstyle |= WS_EX_NOREDIRECTIONBITMAP; - } let hinstance = get_module_handle(); let display = if let Some(display_id) = params.display_id { @@ -402,6 +401,7 @@ impl WindowsWindow { handle, hide_title_bar, display, + transparent: true, is_movable: params.is_movable, min_size: params.window_min_size, executor, @@ -410,9 +410,9 @@ impl WindowsWindow { drop_target_helper, validation_number, main_receiver, + gpu_context, main_thread_id_win32, appearance, - disable_direct_composition, }; let lpparam = Some(&context as *const _ as *const _); let creation_result = unsafe { @@ -453,6 +453,14 @@ impl WindowsWindow { state: WindowOpenState::Windowed, }); } + // The render pipeline will perform compositing on the GPU when the + // swapchain is configured correctly (see downstream of + // update_transparency). + // The following configuration is a one-time setup to ensure that the + // window is going to be composited with per-pixel alpha, but the render + // pipeline is responsible for effectively calling UpdateLayeredWindow + // at the appropriate time. + unsafe { SetLayeredWindowAttributes(hwnd, COLORREF(0), 255, LWA_ALPHA)? }; Ok(Self(state_ptr)) } @@ -477,6 +485,7 @@ impl rwh::HasDisplayHandle for WindowsWindow { impl Drop for WindowsWindow { fn drop(&mut self) { + self.0.state.borrow_mut().renderer.destroy(); // clone this `Rc` to prevent early release of the pointer let this = self.0.clone(); self.0 @@ -696,21 +705,24 @@ impl PlatformWindow for WindowsWindow { } fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) { - let hwnd = self.0.hwnd; + let mut window_state = self.0.state.borrow_mut(); + window_state + .renderer + .update_transparency(background_appearance != WindowBackgroundAppearance::Opaque); match background_appearance { WindowBackgroundAppearance::Opaque => { // ACCENT_DISABLED - set_window_composition_attribute(hwnd, None, 0); + set_window_composition_attribute(window_state.hwnd, None, 0); } WindowBackgroundAppearance::Transparent => { // Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background - set_window_composition_attribute(hwnd, None, 2); + set_window_composition_attribute(window_state.hwnd, None, 2); } WindowBackgroundAppearance::Blurred => { // Enable acrylic blur // ACCENT_ENABLE_ACRYLICBLURBEHIND - set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4); + set_window_composition_attribute(window_state.hwnd, Some((0, 0, 0, 0)), 4); } } } @@ -782,11 +794,11 @@ impl PlatformWindow for WindowsWindow { } fn draw(&self, scene: &Scene) { - self.0.state.borrow_mut().renderer.draw(scene).log_err(); + self.0.state.borrow_mut().renderer.draw(scene) } fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> { - self.0.state.borrow().renderer.sprite_atlas() + self.0.state.borrow().renderer.sprite_atlas().clone() } fn get_raw_handle(&self) -> HWND { @@ -794,11 +806,11 @@ impl PlatformWindow for WindowsWindow { } fn gpu_specs(&self) -> Option<GpuSpecs> { - self.0.state.borrow().renderer.gpu_specs().log_err() + Some(self.0.state.borrow().renderer.gpu_specs()) } fn update_ime_position(&self, _bounds: Bounds<ScaledPixels>) { - // There is no such thing on Windows. + // todo(windows) } } @@ -1294,6 +1306,52 @@ fn set_window_composition_attribute(hwnd: HWND, color: Option<Color>, state: u32 } } +mod windows_renderer { + use crate::platform::blade::{BladeContext, BladeRenderer, BladeSurfaceConfig}; + use raw_window_handle as rwh; + use std::num::NonZeroIsize; + use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::GWLP_HINSTANCE}; + + use crate::{get_window_long, show_error}; + + pub(super) fn init( + context: &BladeContext, + hwnd: HWND, + transparent: bool, + ) -> anyhow::Result<BladeRenderer> { + let raw = RawWindow { hwnd }; + let config = BladeSurfaceConfig { + size: Default::default(), + transparent, + }; + BladeRenderer::new(context, &raw, config) + .inspect_err(|err| show_error("Failed to initialize BladeRenderer", err.to_string())) + } + + struct RawWindow { + hwnd: HWND, + } + + impl rwh::HasWindowHandle for RawWindow { + fn window_handle(&self) -> Result<rwh::WindowHandle<'_>, rwh::HandleError> { + Ok(unsafe { + let hwnd = NonZeroIsize::new_unchecked(self.hwnd.0 as isize); + let mut handle = rwh::Win32WindowHandle::new(hwnd); + let hinstance = get_window_long(self.hwnd, GWLP_HINSTANCE); + handle.hinstance = NonZeroIsize::new(hinstance); + rwh::WindowHandle::borrow_raw(handle.into()) + }) + } + } + + impl rwh::HasDisplayHandle for RawWindow { + fn display_handle(&self) -> Result<rwh::DisplayHandle<'_>, rwh::HandleError> { + let handle = rwh::WindowsDisplayHandle::new(); + Ok(unsafe { rwh::DisplayHandle::borrow_raw(handle.into()) }) + } + } +} + #[cfg(test)] mod tests { use super::ClickState; diff --git a/crates/gpui/src/scene.rs b/crates/gpui/src/scene.rs index ec8d720cdf..4eaef64afa 100644 --- a/crates/gpui/src/scene.rs +++ b/crates/gpui/src/scene.rs @@ -43,6 +43,17 @@ impl Scene { self.surfaces.clear(); } + #[cfg_attr( + all( + any(target_os = "linux", target_os = "freebsd"), + not(any(feature = "x11", feature = "wayland")) + ), + allow(dead_code) + )] + pub fn paths(&self) -> &[Path<ScaledPixels>] { + &self.paths + } + pub fn len(&self) -> usize { self.paint_operations.len() } @@ -670,7 +681,7 @@ pub(crate) struct PathId(pub(crate) usize); #[derive(Clone, Debug)] pub struct Path<P: Clone + Debug + Default + PartialEq> { pub(crate) id: PathId, - pub(crate) order: DrawOrder, + order: DrawOrder, pub(crate) bounds: Bounds<P>, pub(crate) content_mask: ContentMask<P>, pub(crate) vertices: Vec<PathVertex<P>>, diff --git a/crates/gpui/src/tab_stop.rs b/crates/gpui/src/tab_stop.rs index 7dde42efed..2ec3f560e8 100644 --- a/crates/gpui/src/tab_stop.rs +++ b/crates/gpui/src/tab_stop.rs @@ -5,7 +5,7 @@ use crate::{FocusHandle, FocusId}; /// Used to manage the `Tab` event to switch between focus handles. #[derive(Default)] pub(crate) struct TabHandles { - pub(crate) handles: Vec<FocusHandle>, + handles: Vec<FocusHandle>, } impl TabHandles { @@ -32,18 +32,20 @@ impl TabHandles { self.handles.clear(); } - fn current_index(&self, focused_id: Option<&FocusId>) -> Option<usize> { - self.handles.iter().position(|h| Some(&h.id) == focused_id) + fn current_index(&self, focused_id: Option<&FocusId>) -> usize { + self.handles + .iter() + .position(|h| Some(&h.id) == focused_id) + .unwrap_or_default() } pub(crate) fn next(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { - let next_ix = self - .current_index(focused_id) - .and_then(|ix| { - let next_ix = ix + 1; - (next_ix < self.handles.len()).then_some(next_ix) - }) - .unwrap_or_default(); + let ix = self.current_index(focused_id); + + let mut next_ix = ix + 1; + if next_ix + 1 > self.handles.len() { + next_ix = 0; + } if let Some(next_handle) = self.handles.get(next_ix) { Some(next_handle.clone()) @@ -53,7 +55,7 @@ impl TabHandles { } pub(crate) fn prev(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { - let ix = self.current_index(focused_id).unwrap_or_default(); + let ix = self.current_index(focused_id); let prev_ix; if ix == 0 { prev_ix = self.handles.len().saturating_sub(1); @@ -106,14 +108,8 @@ mod tests { ] ); - // Select first tab index if no handle is currently focused. - assert_eq!(tab.next(None), Some(tab.handles[0].clone())); - // Select last tab index if no handle is currently focused. - assert_eq!( - tab.prev(None), - Some(tab.handles[tab.handles.len() - 1].clone()) - ); - + // next + assert_eq!(tab.next(None), Some(tab.handles[1].clone())); assert_eq!( tab.next(Some(&tab.handles[0].id)), Some(tab.handles[1].clone()) diff --git a/crates/gpui/src/taffy.rs b/crates/gpui/src/taffy.rs index f7fa54256d..6228a60490 100644 --- a/crates/gpui/src/taffy.rs +++ b/crates/gpui/src/taffy.rs @@ -283,7 +283,7 @@ impl ToTaffy<taffy::style::LengthPercentageAuto> for Length { fn to_taffy(&self, rem_size: Pixels) -> taffy::prelude::LengthPercentageAuto { match self { Length::Definite(length) => length.to_taffy(rem_size), - Length::Auto => taffy::prelude::LengthPercentageAuto::auto(), + Length::Auto => taffy::prelude::LengthPercentageAuto::Auto, } } } @@ -292,7 +292,7 @@ impl ToTaffy<taffy::style::Dimension> for Length { fn to_taffy(&self, rem_size: Pixels) -> taffy::prelude::Dimension { match self { Length::Definite(length) => length.to_taffy(rem_size), - Length::Auto => taffy::prelude::Dimension::auto(), + Length::Auto => taffy::prelude::Dimension::Auto, } } } @@ -302,14 +302,14 @@ impl ToTaffy<taffy::style::LengthPercentage> for DefiniteLength { match self { DefiniteLength::Absolute(length) => match length { AbsoluteLength::Pixels(pixels) => { - taffy::style::LengthPercentage::length(pixels.into()) + taffy::style::LengthPercentage::Length(pixels.into()) } AbsoluteLength::Rems(rems) => { - taffy::style::LengthPercentage::length((*rems * rem_size).into()) + taffy::style::LengthPercentage::Length((*rems * rem_size).into()) } }, DefiniteLength::Fraction(fraction) => { - taffy::style::LengthPercentage::percent(*fraction) + taffy::style::LengthPercentage::Percent(*fraction) } } } @@ -320,14 +320,14 @@ impl ToTaffy<taffy::style::LengthPercentageAuto> for DefiniteLength { match self { DefiniteLength::Absolute(length) => match length { AbsoluteLength::Pixels(pixels) => { - taffy::style::LengthPercentageAuto::length(pixels.into()) + taffy::style::LengthPercentageAuto::Length(pixels.into()) } AbsoluteLength::Rems(rems) => { - taffy::style::LengthPercentageAuto::length((*rems * rem_size).into()) + taffy::style::LengthPercentageAuto::Length((*rems * rem_size).into()) } }, DefiniteLength::Fraction(fraction) => { - taffy::style::LengthPercentageAuto::percent(*fraction) + taffy::style::LengthPercentageAuto::Percent(*fraction) } } } @@ -337,12 +337,12 @@ impl ToTaffy<taffy::style::Dimension> for DefiniteLength { fn to_taffy(&self, rem_size: Pixels) -> taffy::style::Dimension { match self { DefiniteLength::Absolute(length) => match length { - AbsoluteLength::Pixels(pixels) => taffy::style::Dimension::length(pixels.into()), + AbsoluteLength::Pixels(pixels) => taffy::style::Dimension::Length(pixels.into()), AbsoluteLength::Rems(rems) => { - taffy::style::Dimension::length((*rems * rem_size).into()) + taffy::style::Dimension::Length((*rems * rem_size).into()) } }, - DefiniteLength::Fraction(fraction) => taffy::style::Dimension::percent(*fraction), + DefiniteLength::Fraction(fraction) => taffy::style::Dimension::Percent(*fraction), } } } @@ -350,9 +350,9 @@ impl ToTaffy<taffy::style::Dimension> for DefiniteLength { impl ToTaffy<taffy::style::LengthPercentage> for AbsoluteLength { fn to_taffy(&self, rem_size: Pixels) -> taffy::style::LengthPercentage { match self { - AbsoluteLength::Pixels(pixels) => taffy::style::LengthPercentage::length(pixels.into()), + AbsoluteLength::Pixels(pixels) => taffy::style::LengthPercentage::Length(pixels.into()), AbsoluteLength::Rems(rems) => { - taffy::style::LengthPercentage::length((*rems * rem_size).into()) + taffy::style::LengthPercentage::Length((*rems * rem_size).into()) } } } diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 6ebb1cac40..963d2bb45c 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -702,7 +702,6 @@ pub(crate) struct PaintIndex { input_handlers_index: usize, cursor_styles_index: usize, accessed_element_states_index: usize, - tab_handle_index: usize, line_layout_index: LineLayoutIndex, } @@ -1020,7 +1019,7 @@ impl Window { || (active.get() && last_input_timestamp.get().elapsed() < Duration::from_secs(1)); - if invalidator.is_dirty() || request_frame_options.force_render { + if invalidator.is_dirty() { measure("frame duration", || { handle .update(&mut cx, |_, window, cx| { @@ -2209,7 +2208,6 @@ impl Window { input_handlers_index: self.next_frame.input_handlers.len(), cursor_styles_index: self.next_frame.cursor_styles.len(), accessed_element_states_index: self.next_frame.accessed_element_states.len(), - tab_handle_index: self.next_frame.tab_handles.handles.len(), line_layout_index: self.text_system.layout_index(), } } @@ -2239,12 +2237,6 @@ impl Window { .iter() .map(|(id, type_id)| (GlobalElementId(id.0.clone()), *type_id)), ); - self.next_frame.tab_handles.handles.extend( - self.rendered_frame.tab_handles.handles - [range.start.tab_handle_index..range.end.tab_handle_index] - .iter() - .cloned(), - ); self.text_system .reuse_layouts(range.start.line_layout_index..range.end.line_layout_index); diff --git a/crates/http_client/Cargo.toml b/crates/http_client/Cargo.toml index 3f51cc5a23..2045708ff2 100644 --- a/crates/http_client/Cargo.toml +++ b/crates/http_client/Cargo.toml @@ -23,7 +23,6 @@ futures.workspace = true http.workspace = true http-body.workspace = true log.workspace = true -parking_lot.workspace = true serde.workspace = true serde_json.workspace = true url.workspace = true diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index d33bbefc06..434bd74fc8 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -9,10 +9,12 @@ pub use http::{self, Method, Request, Response, StatusCode, Uri}; use futures::future::BoxFuture; use http::request::Builder; -use parking_lot::Mutex; #[cfg(feature = "test-support")] use std::fmt; -use std::{any::type_name, sync::Arc}; +use std::{ + any::type_name, + sync::{Arc, Mutex}, +}; pub use url::Url; #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] @@ -84,11 +86,6 @@ pub trait HttpClient: 'static + Send + Sync { } fn proxy(&self) -> Option<&Url>; - - #[cfg(feature = "test-support")] - fn as_fake(&self) -> &FakeHttpClient { - panic!("called as_fake on {}", type_name::<Self>()) - } } /// An [`HttpClient`] that may have a proxy. @@ -135,11 +132,6 @@ impl HttpClient for HttpClientWithProxy { fn type_name(&self) -> &'static str { self.client.type_name() } - - #[cfg(feature = "test-support")] - fn as_fake(&self) -> &FakeHttpClient { - self.client.as_fake() - } } impl HttpClient for Arc<HttpClientWithProxy> { @@ -161,11 +153,6 @@ impl HttpClient for Arc<HttpClientWithProxy> { fn type_name(&self) -> &'static str { self.client.type_name() } - - #[cfg(feature = "test-support")] - fn as_fake(&self) -> &FakeHttpClient { - self.client.as_fake() - } } /// An [`HttpClient`] that has a base URL. @@ -212,13 +199,20 @@ impl HttpClientWithUrl { /// Returns the base URL. pub fn base_url(&self) -> String { - self.base_url.lock().clone() + self.base_url + .lock() + .map_or_else(|_| Default::default(), |url| url.clone()) } /// Sets the base URL. pub fn set_base_url(&self, base_url: impl Into<String>) { let base_url = base_url.into(); - *self.base_url.lock() = base_url; + self.base_url + .lock() + .map(|mut url| { + *url = base_url; + }) + .ok(); } /// Builds a URL using the given path. @@ -242,22 +236,6 @@ impl HttpClientWithUrl { )?) } - /// Builds a Zed Cloud URL using the given path. - pub fn build_zed_cloud_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> { - let base_url = self.base_url(); - let base_api_url = match base_url.as_ref() { - "https://zed.dev" => "https://cloud.zed.dev", - "https://staging.zed.dev" => "https://cloud.zed.dev", - "http://localhost:3000" => "http://localhost:8787", - other => other, - }; - - Ok(Url::parse_with_params( - &format!("{}{}", base_api_url, path), - query, - )?) - } - /// Builds a Zed LLM URL using the given path. pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> { let base_url = self.base_url(); @@ -294,11 +272,6 @@ impl HttpClient for Arc<HttpClientWithUrl> { fn type_name(&self) -> &'static str { self.client.type_name() } - - #[cfg(feature = "test-support")] - fn as_fake(&self) -> &FakeHttpClient { - self.client.as_fake() - } } impl HttpClient for HttpClientWithUrl { @@ -320,11 +293,6 @@ impl HttpClient for HttpClientWithUrl { fn type_name(&self) -> &'static str { self.client.type_name() } - - #[cfg(feature = "test-support")] - fn as_fake(&self) -> &FakeHttpClient { - self.client.as_fake() - } } pub fn read_proxy_from_env() -> Option<Url> { @@ -376,15 +344,10 @@ impl HttpClient for BlockedHttpClient { fn type_name(&self) -> &'static str { type_name::<Self>() } - - #[cfg(feature = "test-support")] - fn as_fake(&self) -> &FakeHttpClient { - panic!("called as_fake on {}", type_name::<Self>()) - } } #[cfg(feature = "test-support")] -type FakeHttpHandler = Arc< +type FakeHttpHandler = Box< dyn Fn(Request<AsyncBody>) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> + Send + Sync @@ -393,7 +356,7 @@ type FakeHttpHandler = Arc< #[cfg(feature = "test-support")] pub struct FakeHttpClient { - handler: Mutex<Option<FakeHttpHandler>>, + handler: FakeHttpHandler, user_agent: HeaderValue, } @@ -408,7 +371,7 @@ impl FakeHttpClient { base_url: Mutex::new("http://test.example".into()), client: HttpClientWithProxy { client: Arc::new(Self { - handler: Mutex::new(Some(Arc::new(move |req| Box::pin(handler(req))))), + handler: Box::new(move |req| Box::pin(handler(req))), user_agent: HeaderValue::from_static(type_name::<Self>()), }), proxy: None, @@ -433,18 +396,6 @@ impl FakeHttpClient { .unwrap()) }) } - - pub fn replace_handler<Fut, F>(&self, new_handler: F) - where - Fut: futures::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send + 'static, - F: Fn(FakeHttpHandler, Request<AsyncBody>) -> Fut + Send + Sync + 'static, - { - let mut handler = self.handler.lock(); - let old_handler = handler.take().unwrap(); - *handler = Some(Arc::new(move |req| { - Box::pin(new_handler(old_handler.clone(), req)) - })); - } } #[cfg(feature = "test-support")] @@ -460,7 +411,7 @@ impl HttpClient for FakeHttpClient { &self, req: Request<AsyncBody>, ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - let future = (self.handler.lock().as_ref().unwrap())(req); + let future = (self.handler)(req); future } @@ -475,8 +426,4 @@ impl HttpClient for FakeHttpClient { fn type_name(&self) -> &'static str { type_name::<Self>() } - - fn as_fake(&self) -> &FakeHttpClient { - self - } } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index a94d89bdc8..e7066ae151 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -38,6 +38,7 @@ pub enum IconName { ArrowUpFromLine, ArrowUpRight, ArrowUpRightAlt, + AtSign, AudioOff, AudioOn, Backspace, @@ -47,13 +48,15 @@ pub enum IconName { BellRing, Binary, Blocks, - BoltOutlined, + Bolt, BoltFilled, + BoltFilledAlt, Book, BookCopy, + BookPlus, + Brain, BugOff, CaseSensitive, - Chat, Check, CheckDouble, ChevronDown, @@ -68,7 +71,6 @@ pub enum IconName { CircleHelp, Close, Cloud, - CloudDownload, Code, Cog, Command, @@ -104,12 +106,6 @@ pub enum IconName { Disconnected, DocumentText, Download, - EditorAtom, - EditorCursor, - EditorEmacs, - EditorJetBrains, - EditorSublime, - EditorVsCode, Ellipsis, EllipsisVertical, Envelope, @@ -181,9 +177,14 @@ pub enum IconName { Maximize, Menu, MenuAlt, + MessageBubbles, Mic, MicMute, + Microscope, Minimize, + NewFromSummary, + NewTextThread, + NewThread, Option, PageDown, PageUp, @@ -194,7 +195,9 @@ pub enum IconName { PersonCircle, PhoneIncoming, Pin, - PlayOutlined, + Play, + PlayAlt, + PlayBug, PlayFilled, Plus, PocketKnife, @@ -211,6 +214,7 @@ pub enum IconName { ReplyArrowRight, Rerun, Return, + Reveal, RotateCcw, RotateCw, Route, @@ -224,7 +228,6 @@ pub enum IconName { Server, Settings, SettingsAlt, - ShieldCheck, Shift, Slash, SlashSquare, @@ -235,6 +238,7 @@ pub enum IconName { Sparkle, SparkleAlt, SparkleFilled, + Spinner, Split, SplitAlt, SquareDot, @@ -244,6 +248,7 @@ pub enum IconName { StarFilled, Stop, StopFilled, + Strikethrough, Supermaven, SupermavenDisabled, SupermavenError, @@ -253,9 +258,6 @@ pub enum IconName { Terminal, TerminalAlt, TextSnippet, - TextThread, - Thread, - ThreadFromSummary, ThumbsDown, ThumbsUp, TodoComplete, @@ -275,6 +277,7 @@ pub enum IconName { ToolTerminal, ToolWeb, Trash, + TrashAlt, Triangle, TriangleRight, Undo, diff --git a/crates/inline_completion_button/Cargo.toml b/crates/inline_completion_button/Cargo.toml index b34e59336b..c2a619d500 100644 --- a/crates/inline_completion_button/Cargo.toml +++ b/crates/inline_completion_button/Cargo.toml @@ -15,7 +15,6 @@ doctest = false [dependencies] anyhow.workspace = true client.workspace = true -cloud_llm_client.workspace = true copilot.workspace = true editor.workspace = true feature_flags.workspace = true @@ -33,6 +32,7 @@ ui.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true +zed_llm_client.workspace = true zeta.workspace = true [dev-dependencies] diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs index 2d7f211942..2615a8beef 100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -1,6 +1,5 @@ use anyhow::Result; use client::{DisableAiSettings, UserStore, zed_urls}; -use cloud_llm_client::UsageLimit; use copilot::{Copilot, Status}; use editor::{ Editor, SelectionEffects, @@ -35,6 +34,7 @@ use workspace::{ notifications::NotificationId, }; use zed_actions::OpenBrowser; +use zed_llm_client::UsageLimit; use zeta::RateCompletions; actions!( @@ -246,15 +246,12 @@ impl Render for InlineCompletionButton { }; if zeta::should_show_upsell_modal(&self.user_store, cx) { - let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { - if self.user_store.read(cx).has_accepted_terms_of_service() { - "Choose a Plan" - } else { - "Accept the Terms of Service" - } - } else { - "Sign In" - }; + let tooltip_meta = + match self.user_store.read(cx).current_user_has_accepted_terms() { + Some(true) => "Choose a Plan", + Some(false) => "Accept the Terms of Service", + None => "Sign In", + }; return div().child( IconButton::new("zed-predict-pending-button", zeta_icon) @@ -390,9 +387,9 @@ impl InlineCompletionButton { language: None, file: None, edit_prediction_provider: None, - user_store, popover_menu_handle, fs, + user_store, } } diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 894625b982..1df33286ee 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -161,11 +161,12 @@ pub struct CachedLspAdapter { pub name: LanguageServerName, pub disk_based_diagnostic_sources: Vec<String>, pub disk_based_diagnostics_progress_token: Option<String>, - language_ids: HashMap<LanguageName, String>, + language_ids: HashMap<String, String>, pub adapter: Arc<dyn LspAdapter>, pub reinstall_attempt_count: AtomicU64, cached_binary: futures::lock::Mutex<Option<LanguageServerBinary>>, manifest_name: OnceLock<Option<ManifestName>>, + attach_kind: OnceLock<Attach>, } impl Debug for CachedLspAdapter { @@ -201,6 +202,7 @@ impl CachedLspAdapter { adapter, cached_binary: Default::default(), reinstall_attempt_count: AtomicU64::new(0), + attach_kind: Default::default(), manifest_name: Default::default(), }) } @@ -277,25 +279,38 @@ impl CachedLspAdapter { pub fn language_id(&self, language_name: &LanguageName) -> String { self.language_ids - .get(language_name) + .get(language_name.as_ref()) .cloned() .unwrap_or_else(|| language_name.lsp_id()) } - pub fn manifest_name(&self) -> Option<ManifestName> { self.manifest_name .get_or_init(|| self.adapter.manifest_name()) .clone() } + pub fn attach_kind(&self) -> Attach { + *self.attach_kind.get_or_init(|| self.adapter.attach_kind()) + } } -/// Determines what gets sent out as a workspace folders content #[derive(Clone, Copy, Debug, PartialEq)] -pub enum WorkspaceFoldersContent { - /// Send out a single entry with the root of the workspace. - WorktreeRoot, - /// Send out a list of subproject roots. - SubprojectRoots, +pub enum Attach { + /// Create a single language server instance per subproject root. + InstancePerRoot, + /// Use one shared language server instance for all subprojects within a project. + Shared, +} + +impl Attach { + pub fn root_path( + &self, + root_subproject_path: (WorktreeId, Arc<Path>), + ) -> (WorktreeId, Arc<Path>) { + match self { + Attach::InstancePerRoot => root_subproject_path, + Attach::Shared => (root_subproject_path.0, Arc::from(Path::new(""))), + } + } } /// [`LspAdapterDelegate`] allows [`LspAdapter]` implementations to interface with the application @@ -574,8 +589,8 @@ pub trait LspAdapter: 'static + Send + Sync { None } - fn language_ids(&self) -> HashMap<LanguageName, String> { - HashMap::default() + fn language_ids(&self) -> HashMap<String, String> { + Default::default() } /// Support custom initialize params. @@ -587,11 +602,8 @@ pub trait LspAdapter: 'static + Send + Sync { Ok(original) } - /// Determines whether a language server supports workspace folders. - /// - /// And does not trip over itself in the process. - fn workspace_folders_content(&self) -> WorkspaceFoldersContent { - WorkspaceFoldersContent::SubprojectRoots + fn attach_kind(&self) -> Attach { + Attach::Shared } fn manifest_name(&self) -> Option<ManifestName> { diff --git a/crates/language_extension/src/extension_lsp_adapter.rs b/crates/language_extension/src/extension_lsp_adapter.rs index 98b6fd4b5a..58fbe6cda2 100644 --- a/crates/language_extension/src/extension_lsp_adapter.rs +++ b/crates/language_extension/src/extension_lsp_adapter.rs @@ -242,7 +242,7 @@ impl LspAdapter for ExtensionLspAdapter { ])) } - fn language_ids(&self) -> HashMap<LanguageName, String> { + fn language_ids(&self) -> HashMap<String, String> { // TODO: The language IDs can be provided via the language server options // in `extension.toml now but we're leaving these existing usages in place temporarily // to avoid any compatibility issues between Zed and the extension versions. @@ -250,7 +250,7 @@ impl LspAdapter for ExtensionLspAdapter { // We can remove once the following extension versions no longer see any use: // - php@0.0.1 if self.extension.manifest().id.as_ref() == "php" { - return HashMap::from_iter([(LanguageName::new("PHP"), "php".into())]); + return HashMap::from_iter([("PHP".into(), "php".into())]); } self.extension diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 841be60b0e..b718c530f5 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,7 +20,6 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true base64.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true @@ -38,6 +37,7 @@ telemetry_events.workspace = true thiserror.workspace = true util.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 1637d2de8a..54640419b6 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -11,7 +11,6 @@ pub mod fake_provider; use anthropic::{AnthropicError, parse_prompt_too_long}; use anyhow::{Result, anyhow}; use client::Client; -use cloud_llm_client::{CompletionMode, CompletionRequestStatus}; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; @@ -27,6 +26,7 @@ use std::time::Duration; use std::{fmt, io}; use thiserror::Error; use util::serde::is_default; +use zed_llm_client::{CompletionMode, CompletionRequestStatus}; pub use crate::model::*; pub use crate::rate_limiter::*; diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 8ae5893410..72b7132c60 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,11 +3,10 @@ use std::sync::Arc; use anyhow::Result; use client::Client; -use cloud_llm_client::Plan; use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, }; -use proto::TypedEnvelope; +use proto::{Plan, TypedEnvelope}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -31,7 +30,7 @@ pub struct ModelRequestLimitReachedError { impl fmt::Display for ModelRequestLimitReachedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let message = match self.plan { - Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.", + Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.", Plan::ZedPro => { "Model request limit reached. Upgrade to usage-based billing for more requests." } @@ -65,14 +64,9 @@ impl LlmApiToken { mut lock: RwLockWriteGuard<'_, Option<String>>, client: &Arc<Client>, ) -> Result<String> { - let system_id = client - .telemetry() - .system_id() - .map(|system_id| system_id.to_string()); - - let response = client.cloud_client().create_llm_token(system_id).await?; - *lock = Some(response.token.0.clone()); - Ok(response.token.0.clone()) + let response = client.request(proto::GetLlmToken {}).await?; + *lock = Some(response.token.clone()); + Ok(response.token.clone()) } } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index dc485e9937..6f3d420ad5 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,9 +1,10 @@ use std::io::{Cursor, Write}; use std::sync::Arc; +use crate::role::Role; +use crate::{LanguageModelToolUse, LanguageModelToolUseId}; use anyhow::Result; use base64::write::EncoderWriter; -use cloud_llm_client::{CompletionIntent, CompletionMode}; use gpui::{ App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task, point, px, size, @@ -11,9 +12,7 @@ use gpui::{ use image::codecs::png::PngEncoder; use serde::{Deserialize, Serialize}; use util::ResultExt; - -use crate::role::Role; -use crate::{LanguageModelToolUse, LanguageModelToolUseId}; +use zed_llm_client::{CompletionIntent, CompletionMode}; #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct LanguageModelImage { diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b5bfb870f6..574579aaa7 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -16,17 +16,18 @@ ai_onboarding.workspace = true anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true aws-config = { workspace = true, features = ["behavior-version-latest"] } -aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] } +aws-credential-types = { workspace = true, features = [ + "hardcoded-credentials", +] } aws_http_client.workspace = true bedrock.workspace = true chrono.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true component.workspace = true +credentials_provider.workspace = true convert_case.workspace = true copilot.workspace = true -credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } editor.workspace = true futures.workspace = true @@ -34,7 +35,6 @@ google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true gpui_tokio.workspace = true http_client.workspace = true -language.workspace = true language_model.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true @@ -43,7 +43,10 @@ mistral = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } +vercel = { workspace = true, features = ["schemars"] } +x_ai = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true +proto.workspace = true release_channel.workspace = true schemars.workspace = true serde.workspace = true @@ -58,9 +61,9 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } ui.workspace = true ui_input.workspace = true util.workspace = true -vercel = { workspace = true, features = ["schemars"] } workspace-hack.workspace = true -x_ai = { workspace = true, features = ["schemars"] } +zed_llm_client.workspace = true +language.workspace = true [dev-dependencies] editor = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 2108547c4f..09a2ac6e0a 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -3,13 +3,6 @@ use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use client::{Client, ModelRequestUsage, UserStore, zed_urls}; -use cloud_llm_client::{ - CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, - CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, - EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, - SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, - TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, -}; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; @@ -27,6 +20,7 @@ use language_model::{ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; +use proto::Plan; use release_channel::AppVersion; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -39,6 +33,13 @@ use std::time::Duration; use thiserror::Error; use ui::{TintColor, prelude::*}; use util::{ResultExt as _, maybe}; +use zed_llm_client::{ + CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, + CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, + ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, + SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, + TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, +}; use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic}; use crate::provider::google::{GoogleEventMapper, into_google}; @@ -119,10 +120,10 @@ pub struct State { user_store: Entity<UserStore>, status: client::Status, accept_terms_of_service_task: Option<Task<Result<()>>>, - models: Vec<Arc<cloud_llm_client::LanguageModel>>, - default_model: Option<Arc<cloud_llm_client::LanguageModel>>, - default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>, - recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>, + models: Vec<Arc<zed_llm_client::LanguageModel>>, + default_model: Option<Arc<zed_llm_client::LanguageModel>>, + default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>, + recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>, _fetch_models_task: Task<()>, _settings_subscription: Subscription, _llm_token_subscription: Subscription, @@ -136,10 +137,11 @@ impl State { cx: &mut Context<Self>, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); + Self { client: client.clone(), llm_api_token: LlmApiToken::default(), - user_store: user_store.clone(), + user_store, status, accept_terms_of_service_task: None, models: Vec::new(), @@ -152,9 +154,8 @@ impl State { .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; loop { - let is_authenticated = user_store - .read_with(cx, |user_store, _cx| user_store.current_user().is_some())?; - if is_authenticated { + let status = this.read_with(cx, |this, _cx| this.status)?; + if matches!(status, client::Status::Connected { .. }) { break; } @@ -193,20 +194,26 @@ impl State { } } - fn is_signed_out(&self, cx: &App) -> bool { - self.user_store.read(cx).current_user().is_none() + fn is_signed_out(&self) -> bool { + self.status.is_signed_out() } fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { let client = self.client.clone(); cx.spawn(async move |state, cx| { - client.sign_in_with_optional_connect(true, &cx).await?; + client + .authenticate_and_connect(true, &cx) + .await + .into_response()?; state.update(cx, |_, cx| cx.notify()) }) } fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.user_store.read(cx).has_accepted_terms_of_service() + self.user_store + .read(cx) + .current_user_has_accepted_terms() + .unwrap_or(false) } fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) { @@ -231,8 +238,8 @@ impl State { // Right now we represent thinking variants of models as separate models on the client, // so we need to insert variants for any model that supports thinking. if model.supports_thinking { - models.push(Arc::new(cloud_llm_client::LanguageModel { - id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), + models.push(Arc::new(zed_llm_client::LanguageModel { + id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), display_name: format!("{} Thinking", model.display_name), ..model })); @@ -321,7 +328,7 @@ impl CloudLanguageModelProvider { fn create_language_model( &self, - model: Arc<cloud_llm_client::LanguageModel>, + model: Arc<zed_llm_client::LanguageModel>, llm_api_token: LlmApiToken, ) -> Arc<dyn LanguageModel> { Arc::new(CloudLanguageModel { @@ -391,7 +398,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn is_authenticated(&self, cx: &App) -> bool { let state = self.state.read(cx); - !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx) + !state.is_signed_out() && state.has_accepted_terms_of_service(cx) } fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> { @@ -511,7 +518,7 @@ fn render_accept_terms( pub struct CloudLanguageModel { id: LanguageModelId, - model: Arc<cloud_llm_client::LanguageModel>, + model: Arc<zed_llm_client::LanguageModel>, llm_api_token: LlmApiToken, client: Arc<Client>, request_limiter: RateLimiter, @@ -604,8 +611,13 @@ impl CloudLanguageModel { .headers() .get(CURRENT_PLAN_HEADER_NAME) .and_then(|plan| plan.to_str().ok()) - .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) + .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok()) { + let plan = match plan { + zed_llm_client::Plan::ZedFree => Plan::Free, + zed_llm_client::Plan::ZedPro => Plan::ZedPro, + zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial, + }; return Err(anyhow!(ModelRequestLimitReachedError { plan })); } } @@ -717,7 +729,7 @@ impl LanguageModel for CloudLanguageModel { } fn upstream_provider_id(&self) -> LanguageModelProviderId { - use cloud_llm_client::LanguageModelProvider::*; + use zed_llm_client::LanguageModelProvider::*; match self.model.provider { Anthropic => language_model::ANTHROPIC_PROVIDER_ID, OpenAi => language_model::OPEN_AI_PROVIDER_ID, @@ -726,7 +738,7 @@ impl LanguageModel for CloudLanguageModel { } fn upstream_provider_name(&self) -> LanguageModelProviderName { - use cloud_llm_client::LanguageModelProvider::*; + use zed_llm_client::LanguageModelProvider::*; match self.model.provider { Anthropic => language_model::ANTHROPIC_PROVIDER_NAME, OpenAi => language_model::OPEN_AI_PROVIDER_NAME, @@ -760,11 +772,11 @@ impl LanguageModel for CloudLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic - | cloud_llm_client::LanguageModelProvider::OpenAi => { + zed_llm_client::LanguageModelProvider::Anthropic + | zed_llm_client::LanguageModelProvider::OpenAi => { LanguageModelToolSchemaFormat::JsonSchema } - cloud_llm_client::LanguageModelProvider::Google => { + zed_llm_client::LanguageModelProvider::Google => { LanguageModelToolSchemaFormat::JsonSchemaSubset } } @@ -783,15 +795,15 @@ impl LanguageModel for CloudLanguageModel { fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> { match &self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { + zed_llm_client::LanguageModelProvider::Anthropic => { Some(LanguageModelCacheConfiguration { min_total_token: 2_048, should_speculate: true, max_cache_anchors: 4, }) } - cloud_llm_client::LanguageModelProvider::OpenAi - | cloud_llm_client::LanguageModelProvider::Google => None, + zed_llm_client::LanguageModelProvider::OpenAi + | zed_llm_client::LanguageModelProvider::Google => None, } } @@ -801,17 +813,15 @@ impl LanguageModel for CloudLanguageModel { cx: &App, ) -> BoxFuture<'static, Result<u64>> { match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { - count_anthropic_tokens(request, cx) - } - cloud_llm_client::LanguageModelProvider::OpenAi => { + zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx), + zed_llm_client::LanguageModelProvider::OpenAi => { let model = match open_ai::Model::from_id(&self.model.id.0) { Ok(model) => model, Err(err) => return async move { Err(anyhow!(err)) }.boxed(), }; count_open_ai_tokens(request, model, cx) } - cloud_llm_client::LanguageModelProvider::Google => { + zed_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); let model_id = self.model.id.to_string(); @@ -822,7 +832,7 @@ impl LanguageModel for CloudLanguageModel { let token = llm_api_token.acquire(&client).await?; let request_body = CountTokensBody { - provider: cloud_llm_client::LanguageModelProvider::Google, + provider: zed_llm_client::LanguageModelProvider::Google, model: model_id, provider_request: serde_json::to_value(&google_ai::CountTokensRequest { generate_content_request, @@ -883,7 +893,7 @@ impl LanguageModel for CloudLanguageModel { let app_version = cx.update(|cx| AppVersion::global(cx)).ok(); let thinking_allowed = request.thinking_allowed; match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { + zed_llm_client::LanguageModelProvider::Anthropic => { let request = into_anthropic( request, self.model.id.to_string(), @@ -914,7 +924,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: cloud_llm_client::LanguageModelProvider::Anthropic, + provider: zed_llm_client::LanguageModelProvider::Anthropic, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -938,7 +948,7 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } - cloud_llm_client::LanguageModelProvider::OpenAi => { + zed_llm_client::LanguageModelProvider::OpenAi => { let client = self.client.clone(); let model = match open_ai::Model::from_id(&self.model.id.0) { Ok(model) => model, @@ -966,7 +976,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: cloud_llm_client::LanguageModelProvider::OpenAi, + provider: zed_llm_client::LanguageModelProvider::OpenAi, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -986,7 +996,7 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } - cloud_llm_client::LanguageModelProvider::Google => { + zed_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let request = into_google(request, self.model.id.to_string(), GoogleModelMode::Default); @@ -1006,7 +1016,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: cloud_llm_client::LanguageModelProvider::Google, + provider: zed_llm_client::LanguageModelProvider::Google, model: request.model.model_id.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -1030,8 +1040,15 @@ impl LanguageModel for CloudLanguageModel { } } +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CloudCompletionEvent<T> { + Status(CompletionRequestStatus), + Event(T), +} + fn map_cloud_completion_events<T, F>( - stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>, + stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>, mut map_callback: F, ) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> where @@ -1046,10 +1063,10 @@ where Err(error) => { vec![Err(LanguageModelCompletionError::from(error))] } - Ok(CompletionEvent::Status(event)) => { + Ok(CloudCompletionEvent::Status(event)) => { vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))] } - Ok(CompletionEvent::Event(event)) => map_callback(event), + Ok(CloudCompletionEvent::Event(event)) => map_callback(event), }) }) .boxed() @@ -1057,9 +1074,9 @@ where fn usage_updated_event<T>( usage: Option<ModelRequestUsage>, -) -> impl Stream<Item = Result<CompletionEvent<T>>> { +) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { futures::stream::iter(usage.map(|usage| { - Ok(CompletionEvent::Status( + Ok(CloudCompletionEvent::Status( CompletionRequestStatus::UsageUpdated { amount: usage.amount as usize, limit: usage.limit, @@ -1070,9 +1087,9 @@ fn usage_updated_event<T>( fn tool_use_limit_reached_event<T>( tool_use_limit_reached: bool, -) -> impl Stream<Item = Result<CompletionEvent<T>>> { +) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { futures::stream::iter(tool_use_limit_reached.then(|| { - Ok(CompletionEvent::Status( + Ok(CloudCompletionEvent::Status( CompletionRequestStatus::ToolUseLimitReached, )) })) @@ -1081,7 +1098,7 @@ fn tool_use_limit_reached_event<T>( fn response_lines<T: DeserializeOwned>( response: Response<AsyncBody>, includes_status_messages: bool, -) -> impl Stream<Item = Result<CompletionEvent<T>>> { +) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { futures::stream::try_unfold( (String::new(), BufReader::new(response.into_body())), move |(mut line, mut body)| async move { @@ -1089,9 +1106,9 @@ fn response_lines<T: DeserializeOwned>( Ok(0) => Ok(None), Ok(_) => { let event = if includes_status_messages { - serde_json::from_str::<CompletionEvent<T>>(&line)? + serde_json::from_str::<CloudCompletionEvent<T>>(&line)? } else { - CompletionEvent::Event(serde_json::from_str::<T>(&line)?) + CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?) }; line.clear(); @@ -1106,7 +1123,7 @@ fn response_lines<T: DeserializeOwned>( #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, - plan: Option<Plan>, + plan: Option<proto::Plan>, subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>, eligible_for_trial: bool, has_accepted_terms_of_service: bool, @@ -1120,15 +1137,15 @@ impl RenderOnce for ZedAiConfiguration { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let young_account_banner = YoungAccountBanner; - let is_pro = self.plan == Some(Plan::ZedPro); + let is_pro = self.plan == Some(proto::Plan::ZedPro); let subscription_text = match (self.plan, self.subscription_period) { - (Some(Plan::ZedPro), Some(_)) => { + (Some(proto::Plan::ZedPro), Some(_)) => { "You have access to Zed's hosted models through your Pro subscription." } - (Some(Plan::ZedProTrial), Some(_)) => { + (Some(proto::Plan::ZedProTrial), Some(_)) => { "You have access to Zed's hosted models through your Pro trial." } - (Some(Plan::ZedFree), Some(_)) => { + (Some(proto::Plan::Free), Some(_)) => { "You have basic access to Zed's hosted models through the Free plan." } _ => { @@ -1253,8 +1270,8 @@ impl Render for ConfigurationView { let user_store = state.user_store.read(cx); ZedAiConfiguration { - is_connected: !state.is_signed_out(cx), - plan: user_store.plan(), + is_connected: !state.is_signed_out(), + plan: user_store.current_plan(), subscription_period: user_store.subscription_period(), eligible_for_trial: user_store.trial_started_at().is_none(), has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), @@ -1274,7 +1291,7 @@ impl Component for ZedAiConfiguration { fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { fn configuration( is_connected: bool, - plan: Option<Plan>, + plan: Option<proto::Plan>, eligible_for_trial: bool, account_too_young: bool, has_accepted_terms_of_service: bool, @@ -1318,15 +1335,15 @@ impl Component for ZedAiConfiguration { ), single_example( "Free Plan", - configuration(true, Some(Plan::ZedFree), true, false, true), + configuration(true, Some(proto::Plan::Free), true, false, true), ), single_example( "Zed Pro Trial Plan", - configuration(true, Some(Plan::ZedProTrial), true, false, true), + configuration(true, Some(proto::Plan::ZedProTrial), true, false, true), ), single_example( "Zed Pro Plan", - configuration(true, Some(Plan::ZedPro), true, false, true), + configuration(true, Some(proto::Plan::ZedPro), true, false, true), ), ]) .into_any_element(), diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 3cdc2e5401..d9a84f1eb7 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -3,7 +3,6 @@ use std::str::FromStr as _; use std::sync::Arc; use anyhow::{Result, anyhow}; -use cloud_llm_client::CompletionIntent; use collections::HashMap; use copilot::copilot_chat::{ ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl, @@ -31,6 +30,7 @@ use settings::SettingsStore; use std::time::Duration; use ui::prelude::*; use util::debug_panic; +use zed_llm_client::CompletionIntent; use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 9792b4f27b..01600f3646 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -744,7 +744,7 @@ impl Render for ConfigurationView { Button::new("retry_lmstudio_models", "Connect") .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon(IconName::PlayOutlined) + .icon(IconName::Play) .on_click(cx.listener(move |this, _, _window, cx| { this.retry_connection(cx) })), diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index d4739bcab8..dc81e8be18 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -192,16 +192,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { IconName::AiOllama } - fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> { - // We shouldn't try to select default model, because it might lead to a load call for an unloaded model. - // In a constrained environment where user might not have enough resources it'll be a bad UX to select something - // to load by default. - None + fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> { + self.provided_models(cx).into_iter().next() } - fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> { - // See explanation for default_model. - None + fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> { + self.default_model(cx) } fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> { @@ -658,7 +654,7 @@ impl Render for ConfigurationView { Button::new("retry_ollama_models", "Connect") .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon(IconName::PlayOutlined) + .icon(IconName::Play) .on_click(cx.listener(move |this, _, _, cx| { this.retry_connection(cx) })), diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index 2b0e13f4be..d1a90d7dbb 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -867,7 +867,7 @@ impl LspLogView { BINARY = server.binary(), WORKSPACE_FOLDERS = server .workspace_folders() - .into_iter() + .iter() .filter_map(|path| path .to_file_path() .ok() diff --git a/crates/language_tools/src/lsp_tool.rs b/crates/language_tools/src/lsp_tool.rs index 50547253a9..9e95ed4673 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_tool.rs @@ -1015,7 +1015,7 @@ impl Render for LspTool { .anchor(Corner::BottomLeft) .with_handle(self.popover_menu_handle.clone()) .trigger_with_tooltip( - IconButton::new("zed-lsp-tool-button", IconName::BoltOutlined) + IconButton::new("zed-lsp-tool-button", IconName::BoltFilledAlt) .when_some(indicator, IconButton::indicator) .icon_size(IconSize::Small) .indicator_border_color(Some(cx.theme().colors().status_bar_background)), diff --git a/crates/languages/Cargo.toml b/crates/languages/Cargo.toml index 260126da63..2e8f007cff 100644 --- a/crates/languages/Cargo.toml +++ b/crates/languages/Cargo.toml @@ -41,7 +41,6 @@ async-trait.workspace = true chrono.workspace = true collections.workspace = true dap.workspace = true -feature_flags.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true diff --git a/crates/languages/src/bash/config.toml b/crates/languages/src/bash/config.toml index 8ff4802aee..db9a2749e7 100644 --- a/crates/languages/src/bash/config.toml +++ b/crates/languages/src/bash/config.toml @@ -18,20 +18,17 @@ brackets = [ { start = "in", end = "esac", close = false, newline = true, not_in = ["comment", "string"] }, ] -auto_indent_using_last_non_empty_line = false -increase_indent_pattern = "^\\s*(\\b(else|elif)\\b|([^#]+\\b(do|then|in)\\b)|([\\w\\*]+\\)))\\s*$" -decrease_indent_patterns = [ - { pattern = "^\\s*elif\\b.*", valid_after = ["if", "elif"] }, - { pattern = "^\\s*else\\b.*", valid_after = ["if", "elif", "for", "while"] }, - { pattern = "^\\s*fi\\b.*", valid_after = ["if", "elif", "else"] }, - { pattern = "^\\s*done\\b.*", valid_after = ["for", "while"] }, - { pattern = "^\\s*esac\\b.*", valid_after = ["case"] }, - { pattern = "^\\s*[\\w\\*]+\\)\\s*$", valid_after = ["case_item"] }, -] - -# We can't use decrease_indent_patterns simply for elif, because -# there is bug in tree sitter which throws ERROR on if match. -# -# This is workaround. That means, elif will outdents with despite -# of wrong context. Like using elif after else. -decrease_indent_pattern = "(^|\\s+|;)(elif)\\b.*$" +### WARN: the following is not working when you insert an `elif` just before an else +### example: (^ is cursor after hitting enter) +### ``` +### if true; then +### foo +### elif +### ^ +### else +### bar +### fi +### ``` +increase_indent_pattern = "(^|\\s+|;)(do|then|in|else|elif)\\b.*$" +decrease_indent_pattern = "(^|\\s+|;)(fi|done|esac|else|elif)\\b.*$" +# make sure to test each line mode & block mode diff --git a/crates/languages/src/bash/indents.scm b/crates/languages/src/bash/indents.scm index 468fc595e5..acdcddabfe 100644 --- a/crates/languages/src/bash/indents.scm +++ b/crates/languages/src/bash/indents.scm @@ -1,12 +1,12 @@ -(_ "[" "]" @end) @indent -(_ "{" "}" @end) @indent -(_ "(" ")" @end) @indent +(function_definition + "function"? + body: ( + _ + "{" @start + "}" @end + )) @indent -(function_definition) @start.function -(if_statement) @start.if -(elif_clause) @start.elif -(else_clause) @start.else -(for_statement) @start.for -(while_statement) @start.while -(case_statement) @start.case -(case_item) @start.case_item +(array + "(" @start + ")" @end + ) @indent diff --git a/crates/languages/src/go/runnables.scm b/crates/languages/src/go/runnables.scm index 6418cd04d8..49e112b860 100644 --- a/crates/languages/src/go/runnables.scm +++ b/crates/languages/src/go/runnables.scm @@ -69,7 +69,7 @@ ( ( (function_declaration name: (_) @run @_name - (#match? @_name "^Benchmark.*")) + (#match? @_name "^Benchmark.+")) ) @_ (#set! tag go-benchmark) ) diff --git a/crates/languages/src/json.rs b/crates/languages/src/json.rs index 601b4620c5..15818730b8 100644 --- a/crates/languages/src/json.rs +++ b/crates/languages/src/json.rs @@ -8,8 +8,8 @@ use futures::StreamExt; use gpui::{App, AsyncApp, Task}; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; use language::{ - ContextProvider, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile as _, - LspAdapter, LspAdapterDelegate, + ContextProvider, LanguageRegistry, LanguageToolchainStore, LocalFile as _, LspAdapter, + LspAdapterDelegate, }; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; @@ -408,10 +408,10 @@ impl LspAdapter for JsonLspAdapter { Ok(config) } - fn language_ids(&self) -> HashMap<LanguageName, String> { + fn language_ids(&self) -> HashMap<String, String> { [ - (LanguageName::new("JSON"), "json".into()), - (LanguageName::new("JSONC"), "jsonc".into()), + ("JSON".into(), "json".into()), + ("JSONC".into(), "jsonc".into()), ] .into_iter() .collect() diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index 001fd15200..a224111002 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -1,5 +1,4 @@ use anyhow::Context as _; -use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use gpui::{App, UpdateGlobal}; use node_runtime::NodeRuntime; use python::PyprojectTomlManifestProvider; @@ -12,7 +11,7 @@ use util::{ResultExt, asset_str}; pub use language::*; -use crate::{json::JsonTaskProvider, python::BasedPyrightLspAdapter}; +use crate::json::JsonTaskProvider; mod bash; mod c; @@ -53,12 +52,6 @@ pub static LANGUAGE_GIT_COMMIT: std::sync::LazyLock<Arc<Language>> = )) }); -struct BasedPyrightFeatureFlag; - -impl FeatureFlag for BasedPyrightFeatureFlag { - const NAME: &'static str = "basedpyright"; -} - pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { #[cfg(feature = "load-grammars")] languages.register_native_grammars([ @@ -95,7 +88,6 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { let py_lsp_adapter = Arc::new(python::PyLspAdapter::new()); let python_context_provider = Arc::new(python::PythonContextProvider); let python_lsp_adapter = Arc::new(python::PythonLspAdapter::new(node.clone())); - let basedpyright_lsp_adapter = Arc::new(BasedPyrightLspAdapter::new()); let python_toolchain_provider = Arc::new(python::PythonToolchainProvider::default()); let rust_context_provider = Arc::new(rust::RustContextProvider); let rust_lsp_adapter = Arc::new(rust::RustLspAdapter); @@ -236,20 +228,6 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { ); } - let mut basedpyright_lsp_adapter = Some(basedpyright_lsp_adapter); - cx.observe_flag::<BasedPyrightFeatureFlag, _>({ - let languages = languages.clone(); - move |enabled, _| { - if enabled { - if let Some(adapter) = basedpyright_lsp_adapter.take() { - languages - .register_available_lsp_adapter(adapter.name(), move || adapter.clone()); - } - } - } - }) - .detach(); - // Register globally available language servers. // // This will allow users to add support for a built-in language server (e.g., Tailwind) diff --git a/crates/languages/src/python.rs b/crates/languages/src/python.rs index 0524c02fd5..dc6996d399 100644 --- a/crates/languages/src/python.rs +++ b/crates/languages/src/python.rs @@ -4,13 +4,13 @@ use async_trait::async_trait; use collections::HashMap; use gpui::{App, Task}; use gpui::{AsyncApp, SharedString}; +use language::Toolchain; use language::ToolchainList; use language::ToolchainLister; use language::language_settings::language_settings; use language::{ContextLocation, LanguageToolchainStore}; use language::{ContextProvider, LspAdapter, LspAdapterDelegate}; use language::{LanguageName, ManifestName, ManifestProvider, ManifestQuery}; -use language::{Toolchain, WorkspaceFoldersContent}; use lsp::LanguageServerBinary; use lsp::LanguageServerName; use node_runtime::NodeRuntime; @@ -400,9 +400,6 @@ impl LspAdapter for PythonLspAdapter { fn manifest_name(&self) -> Option<ManifestName> { Some(SharedString::new_static("pyproject.toml").into()) } - fn workspace_folders_content(&self) -> WorkspaceFoldersContent { - WorkspaceFoldersContent::WorktreeRoot - } } async fn get_cached_server_binary( @@ -1285,350 +1282,6 @@ impl LspAdapter for PyLspAdapter { fn manifest_name(&self) -> Option<ManifestName> { Some(SharedString::new_static("pyproject.toml").into()) } - fn workspace_folders_content(&self) -> WorkspaceFoldersContent { - WorkspaceFoldersContent::WorktreeRoot - } -} - -pub(crate) struct BasedPyrightLspAdapter { - python_venv_base: OnceCell<Result<Arc<Path>, String>>, -} - -impl BasedPyrightLspAdapter { - const SERVER_NAME: LanguageServerName = LanguageServerName::new_static("basedpyright"); - const BINARY_NAME: &'static str = "basedpyright-langserver"; - - pub(crate) fn new() -> Self { - Self { - python_venv_base: OnceCell::new(), - } - } - - async fn ensure_venv(delegate: &dyn LspAdapterDelegate) -> Result<Arc<Path>> { - let python_path = Self::find_base_python(delegate) - .await - .context("Could not find Python installation for basedpyright")?; - let work_dir = delegate - .language_server_download_dir(&Self::SERVER_NAME) - .await - .context("Could not get working directory for basedpyright")?; - let mut path = PathBuf::from(work_dir.as_ref()); - path.push("basedpyright-venv"); - if !path.exists() { - util::command::new_smol_command(python_path) - .arg("-m") - .arg("venv") - .arg("basedpyright-venv") - .current_dir(work_dir) - .spawn()? - .output() - .await?; - } - - Ok(path.into()) - } - - // Find "baseline", user python version from which we'll create our own venv. - async fn find_base_python(delegate: &dyn LspAdapterDelegate) -> Option<PathBuf> { - for path in ["python3", "python"] { - if let Some(path) = delegate.which(path.as_ref()).await { - return Some(path); - } - } - None - } - - async fn base_venv(&self, delegate: &dyn LspAdapterDelegate) -> Result<Arc<Path>, String> { - self.python_venv_base - .get_or_init(move || async move { - Self::ensure_venv(delegate) - .await - .map_err(|e| format!("{e}")) - }) - .await - .clone() - } -} - -#[async_trait(?Send)] -impl LspAdapter for BasedPyrightLspAdapter { - fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() - } - - async fn initialization_options( - self: Arc<Self>, - _: &dyn Fs, - _: &Arc<dyn LspAdapterDelegate>, - ) -> Result<Option<Value>> { - // Provide minimal initialization options - // Virtual environment configuration will be handled through workspace configuration - Ok(Some(json!({ - "python": { - "analysis": { - "autoSearchPaths": true, - "useLibraryCodeForTypes": true, - "autoImportCompletions": true - } - } - }))) - } - - async fn check_if_user_installed( - &self, - delegate: &dyn LspAdapterDelegate, - toolchains: Arc<dyn LanguageToolchainStore>, - cx: &AsyncApp, - ) -> Option<LanguageServerBinary> { - if let Some(bin) = delegate.which(Self::BINARY_NAME.as_ref()).await { - let env = delegate.shell_env().await; - Some(LanguageServerBinary { - path: bin, - env: Some(env), - arguments: vec!["--stdio".into()], - }) - } else { - let venv = toolchains - .active_toolchain( - delegate.worktree_id(), - Arc::from("".as_ref()), - LanguageName::new("Python"), - &mut cx.clone(), - ) - .await?; - let path = Path::new(venv.path.as_ref()) - .parent()? - .join(Self::BINARY_NAME); - path.exists().then(|| LanguageServerBinary { - path, - arguments: vec!["--stdio".into()], - env: None, - }) - } - } - - async fn fetch_latest_server_version( - &self, - _: &dyn LspAdapterDelegate, - ) -> Result<Box<dyn 'static + Any + Send>> { - Ok(Box::new(()) as Box<_>) - } - - async fn fetch_server_binary( - &self, - _latest_version: Box<dyn 'static + Send + Any>, - _container_dir: PathBuf, - delegate: &dyn LspAdapterDelegate, - ) -> Result<LanguageServerBinary> { - let venv = self.base_venv(delegate).await.map_err(|e| anyhow!(e))?; - let pip_path = venv.join(BINARY_DIR).join("pip3"); - ensure!( - util::command::new_smol_command(pip_path.as_path()) - .arg("install") - .arg("basedpyright") - .arg("-U") - .output() - .await? - .status - .success(), - "basedpyright installation failed" - ); - let pylsp = venv.join(BINARY_DIR).join(Self::BINARY_NAME); - Ok(LanguageServerBinary { - path: pylsp, - env: None, - arguments: vec!["--stdio".into()], - }) - } - - async fn cached_server_binary( - &self, - _container_dir: PathBuf, - delegate: &dyn LspAdapterDelegate, - ) -> Option<LanguageServerBinary> { - let venv = self.base_venv(delegate).await.ok()?; - let pylsp = venv.join(BINARY_DIR).join(Self::BINARY_NAME); - Some(LanguageServerBinary { - path: pylsp, - env: None, - arguments: vec!["--stdio".into()], - }) - } - - async fn process_completions(&self, items: &mut [lsp::CompletionItem]) { - // Pyright assigns each completion item a `sortText` of the form `XX.YYYY.name`. - // Where `XX` is the sorting category, `YYYY` is based on most recent usage, - // and `name` is the symbol name itself. - // - // Because the symbol name is included, there generally are not ties when - // sorting by the `sortText`, so the symbol's fuzzy match score is not taken - // into account. Here, we remove the symbol name from the sortText in order - // to allow our own fuzzy score to be used to break ties. - // - // see https://github.com/microsoft/pyright/blob/95ef4e103b9b2f129c9320427e51b73ea7cf78bd/packages/pyright-internal/src/languageService/completionProvider.ts#LL2873 - for item in items { - let Some(sort_text) = &mut item.sort_text else { - continue; - }; - let mut parts = sort_text.split('.'); - let Some(first) = parts.next() else { continue }; - let Some(second) = parts.next() else { continue }; - let Some(_) = parts.next() else { continue }; - sort_text.replace_range(first.len() + second.len() + 1.., ""); - } - } - - async fn label_for_completion( - &self, - item: &lsp::CompletionItem, - language: &Arc<language::Language>, - ) -> Option<language::CodeLabel> { - let label = &item.label; - let grammar = language.grammar()?; - let highlight_id = match item.kind? { - lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method")?, - lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function")?, - lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type")?, - lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant")?, - _ => return None, - }; - let filter_range = item - .filter_text - .as_deref() - .and_then(|filter| label.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..label.len()); - Some(language::CodeLabel { - text: label.clone(), - runs: vec![(0..label.len(), highlight_id)], - filter_range, - }) - } - - async fn label_for_symbol( - &self, - name: &str, - kind: lsp::SymbolKind, - language: &Arc<language::Language>, - ) -> Option<language::CodeLabel> { - let (text, filter_range, display_range) = match kind { - lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => { - let text = format!("def {}():\n", name); - let filter_range = 4..4 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::CLASS => { - let text = format!("class {}:", name); - let filter_range = 6..6 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::CONSTANT => { - let text = format!("{} = 0", name); - let filter_range = 0..name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - _ => return None, - }; - - Some(language::CodeLabel { - runs: language.highlight_text(&text.as_str().into(), display_range.clone()), - text: text[display_range].to_string(), - filter_range, - }) - } - - async fn workspace_configuration( - self: Arc<Self>, - _: &dyn Fs, - adapter: &Arc<dyn LspAdapterDelegate>, - toolchains: Arc<dyn LanguageToolchainStore>, - cx: &mut AsyncApp, - ) -> Result<Value> { - let toolchain = toolchains - .active_toolchain( - adapter.worktree_id(), - Arc::from("".as_ref()), - LanguageName::new("Python"), - cx, - ) - .await; - cx.update(move |cx| { - let mut user_settings = - language_server_settings(adapter.as_ref(), &Self::SERVER_NAME, cx) - .and_then(|s| s.settings.clone()) - .unwrap_or_default(); - - // If we have a detected toolchain, configure Pyright to use it - if let Some(toolchain) = toolchain { - if user_settings.is_null() { - user_settings = Value::Object(serde_json::Map::default()); - } - let object = user_settings.as_object_mut().unwrap(); - - let interpreter_path = toolchain.path.to_string(); - - // Detect if this is a virtual environment - if let Some(interpreter_dir) = Path::new(&interpreter_path).parent() { - if let Some(venv_dir) = interpreter_dir.parent() { - // Check if this looks like a virtual environment - if venv_dir.join("pyvenv.cfg").exists() - || venv_dir.join("bin/activate").exists() - || venv_dir.join("Scripts/activate.bat").exists() - { - // Set venvPath and venv at the root level - // This matches the format of a pyrightconfig.json file - if let Some(parent) = venv_dir.parent() { - // Use relative path if the venv is inside the workspace - let venv_path = if parent == adapter.worktree_root_path() { - ".".to_string() - } else { - parent.to_string_lossy().into_owned() - }; - object.insert("venvPath".to_string(), Value::String(venv_path)); - } - - if let Some(venv_name) = venv_dir.file_name() { - object.insert( - "venv".to_owned(), - Value::String(venv_name.to_string_lossy().into_owned()), - ); - } - } - } - } - - // Always set the python interpreter path - // Get or create the python section - let python = object - .entry("python") - .or_insert(Value::Object(serde_json::Map::default())) - .as_object_mut() - .unwrap(); - - // Set both pythonPath and defaultInterpreterPath for compatibility - python.insert( - "pythonPath".to_owned(), - Value::String(interpreter_path.clone()), - ); - python.insert( - "defaultInterpreterPath".to_owned(), - Value::String(interpreter_path), - ); - } - - user_settings - }) - } - - fn manifest_name(&self) -> Option<ManifestName> { - Some(SharedString::new_static("pyproject.toml").into()) - } - - fn workspace_folders_content(&self) -> WorkspaceFoldersContent { - WorkspaceFoldersContent::WorktreeRoot - } } #[cfg(test)] diff --git a/crates/languages/src/tailwind.rs b/crates/languages/src/tailwind.rs index a7edbb148c..cb4e939083 100644 --- a/crates/languages/src/tailwind.rs +++ b/crates/languages/src/tailwind.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use collections::HashMap; use futures::StreamExt; use gpui::AsyncApp; -use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; @@ -168,20 +168,20 @@ impl LspAdapter for TailwindLspAdapter { })) } - fn language_ids(&self) -> HashMap<LanguageName, String> { + fn language_ids(&self) -> HashMap<String, String> { HashMap::from_iter([ - (LanguageName::new("Astro"), "astro".to_string()), - (LanguageName::new("HTML"), "html".to_string()), - (LanguageName::new("CSS"), "css".to_string()), - (LanguageName::new("JavaScript"), "javascript".to_string()), - (LanguageName::new("TSX"), "typescriptreact".to_string()), - (LanguageName::new("Svelte"), "svelte".to_string()), - (LanguageName::new("Elixir"), "phoenix-heex".to_string()), - (LanguageName::new("HEEX"), "phoenix-heex".to_string()), - (LanguageName::new("ERB"), "erb".to_string()), - (LanguageName::new("HTML/ERB"), "erb".to_string()), - (LanguageName::new("PHP"), "php".to_string()), - (LanguageName::new("Vue.js"), "vue".to_string()), + ("Astro".to_string(), "astro".to_string()), + ("HTML".to_string(), "html".to_string()), + ("CSS".to_string(), "css".to_string()), + ("JavaScript".to_string(), "javascript".to_string()), + ("TSX".to_string(), "typescriptreact".to_string()), + ("Svelte".to_string(), "svelte".to_string()), + ("Elixir".to_string(), "phoenix-heex".to_string()), + ("HEEX".to_string(), "phoenix-heex".to_string()), + ("ERB".to_string(), "erb".to_string()), + ("HTML/ERB".to_string(), "erb".to_string()), + ("PHP".to_string(), "php".to_string()), + ("Vue.js".to_string(), "vue".to_string()), ]) } } diff --git a/crates/languages/src/typescript.rs b/crates/languages/src/typescript.rs index 9dc3ee303d..34b9c3224e 100644 --- a/crates/languages/src/typescript.rs +++ b/crates/languages/src/typescript.rs @@ -8,8 +8,7 @@ use futures::future::join_all; use gpui::{App, AppContext, AsyncApp, Task}; use http_client::github::{AssetKind, GitHubLspBinaryVersion, build_asset_url}; use language::{ - ContextLocation, ContextProvider, File, LanguageName, LanguageToolchainStore, LspAdapter, - LspAdapterDelegate, + ContextLocation, ContextProvider, File, LanguageToolchainStore, LspAdapter, LspAdapterDelegate, }; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; @@ -513,7 +512,7 @@ fn eslint_server_binary_arguments(server_path: &Path) -> Vec<OsString> { fn replace_test_name_parameters(test_name: &str) -> String { let pattern = regex::Regex::new(r"(%|\$)[0-9a-zA-Z]+").unwrap(); - regex::escape(&pattern.replace_all(test_name, "(.+?)")) + pattern.replace_all(test_name, "(.+?)").to_string() } pub struct TypeScriptLspAdapter { @@ -742,11 +741,11 @@ impl LspAdapter for TypeScriptLspAdapter { })) } - fn language_ids(&self) -> HashMap<LanguageName, String> { + fn language_ids(&self) -> HashMap<String, String> { HashMap::from_iter([ - (LanguageName::new("TypeScript"), "typescript".into()), - (LanguageName::new("JavaScript"), "javascript".into()), - (LanguageName::new("TSX"), "typescriptreact".into()), + ("TypeScript".into(), "typescript".into()), + ("JavaScript".into(), "javascript".into()), + ("TSX".into(), "typescriptreact".into()), ]) } } diff --git a/crates/languages/src/typescript/runnables.scm b/crates/languages/src/typescript/runnables.scm index 6bfc536329..85702cf99d 100644 --- a/crates/languages/src/typescript/runnables.scm +++ b/crates/languages/src/typescript/runnables.scm @@ -1,4 +1,4 @@ -; Add support for (node:test, bun:test, Jest and Deno.test) runnable +; Add support for (node:test, bun:test and Jest) runnable ; Function expression that has `it`, `test` or `describe` as the function name ( (call_expression @@ -44,42 +44,3 @@ (#set! tag js-test) ) - -; Add support for Deno.test with string names -( - (call_expression - function: (member_expression - object: (identifier) @_namespace - property: (property_identifier) @_method - ) - (#eq? @_namespace "Deno") - (#eq? @_method "test") - arguments: ( - arguments . [ - (string (string_fragment) @run @DENO_TEST_NAME) - (identifier) @run @DENO_TEST_NAME - ] - ) - ) @_js-test - - (#set! tag js-test) -) - -; Add support for Deno.test with named function expressions -( - (call_expression - function: (member_expression - object: (identifier) @_namespace - property: (property_identifier) @_method - ) - (#eq? @_namespace "Deno") - (#eq? @_method "test") - arguments: ( - arguments . (function_expression - name: (identifier) @run @DENO_TEST_NAME - ) - ) - ) @_js-test - - (#set! tag js-test) -) diff --git a/crates/languages/src/vtsls.rs b/crates/languages/src/vtsls.rs index 33751f733e..ca07673d5f 100644 --- a/crates/languages/src/vtsls.rs +++ b/crates/languages/src/vtsls.rs @@ -2,7 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use collections::HashMap; use gpui::AsyncApp; -use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; @@ -273,11 +273,11 @@ impl LspAdapter for VtslsLspAdapter { Ok(default_workspace_configuration) } - fn language_ids(&self) -> HashMap<LanguageName, String> { + fn language_ids(&self) -> HashMap<String, String> { HashMap::from_iter([ - (LanguageName::new("TypeScript"), "typescript".into()), - (LanguageName::new("JavaScript"), "javascript".into()), - (LanguageName::new("TSX"), "typescriptreact".into()), + ("TypeScript".into(), "typescript".into()), + ("JavaScript".into(), "javascript".into()), + ("TSX".into(), "typescriptreact".into()), ]) } } diff --git a/crates/languages/src/yaml/outline.scm b/crates/languages/src/yaml/outline.scm index c5a7f8e5d4..7ab007835f 100644 --- a/crates/languages/src/yaml/outline.scm +++ b/crates/languages/src/yaml/outline.scm @@ -1,9 +1 @@ -(block_mapping_pair - key: - (flow_node - (plain_scalar - (string_scalar) @name)) - value: - (flow_node - (plain_scalar - (string_scalar) @context))?) @item +(block_mapping_pair key: (flow_node (plain_scalar (string_scalar) @name))) @item diff --git a/crates/livekit_client/Cargo.toml b/crates/livekit_client/Cargo.toml index 821fd5d390..a0c11d46e6 100644 --- a/crates/livekit_client/Cargo.toml +++ b/crates/livekit_client/Cargo.toml @@ -40,8 +40,8 @@ util.workspace = true workspace-hack.workspace = true [target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies] -libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" } -livekit = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ +libwebrtc = { rev = "d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4", git = "https://github.com/zed-industries/livekit-rust-sdks" } +livekit = { rev = "d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ "__rustls-tls" ] } diff --git a/crates/livekit_client/src/lib.rs b/crates/livekit_client/src/lib.rs index 149859fdc8..f94181b8f8 100644 --- a/crates/livekit_client/src/lib.rs +++ b/crates/livekit_client/src/lib.rs @@ -3,41 +3,16 @@ use collections::HashMap; mod remote_video_track_view; pub use remote_video_track_view::{RemoteVideoTrackView, RemoteVideoTrackViewEvent}; -#[cfg(not(any( - test, - feature = "test-support", - all(target_os = "windows", target_env = "gnu"), - target_os = "freebsd" -)))] +#[cfg(not(any(test, feature = "test-support", target_os = "freebsd")))] mod livekit_client; -#[cfg(not(any( - test, - feature = "test-support", - all(target_os = "windows", target_env = "gnu"), - target_os = "freebsd" -)))] +#[cfg(not(any(test, feature = "test-support", target_os = "freebsd")))] pub use livekit_client::*; -#[cfg(any( - test, - feature = "test-support", - all(target_os = "windows", target_env = "gnu"), - target_os = "freebsd" -))] +#[cfg(any(test, feature = "test-support", target_os = "freebsd"))] mod mock_client; -#[cfg(any( - test, - feature = "test-support", - all(target_os = "windows", target_env = "gnu"), - target_os = "freebsd" -))] +#[cfg(any(test, feature = "test-support", target_os = "freebsd"))] pub mod test; -#[cfg(any( - test, - feature = "test-support", - all(target_os = "windows", target_env = "gnu"), - target_os = "freebsd" -))] +#[cfg(any(test, feature = "test-support", target_os = "freebsd"))] pub use mock_client::*; #[derive(Debug, Clone)] diff --git a/crates/livekit_client/src/mock_client/participant.rs b/crates/livekit_client/src/mock_client/participant.rs index 033808cbb5..991d10bd50 100644 --- a/crates/livekit_client/src/mock_client/participant.rs +++ b/crates/livekit_client/src/mock_client/participant.rs @@ -5,9 +5,7 @@ use crate::{ }; use anyhow::Result; use collections::HashMap; -use gpui::{ - AsyncApp, DevicePixels, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata, size, -}; +use gpui::{AsyncApp, ScreenCaptureSource, ScreenCaptureStream, TestScreenCaptureStream}; #[derive(Clone, Debug)] pub struct LocalParticipant { @@ -121,16 +119,3 @@ impl RemoteParticipant { self.identity.clone() } } - -struct TestScreenCaptureStream; - -impl ScreenCaptureStream for TestScreenCaptureStream { - fn metadata(&self) -> Result<SourceMetadata> { - Ok(SourceMetadata { - id: 0, - is_main: None, - label: None, - resolution: size(DevicePixels(1), DevicePixels(1)), - }) - } -} diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index b9701a83d2..7dcfa61f47 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -4,7 +4,7 @@ pub use lsp_types::request::*; pub use lsp_types::*; use anyhow::{Context as _, Result, anyhow}; -use collections::{BTreeMap, HashMap}; +use collections::HashMap; use futures::{ AsyncRead, AsyncWrite, Future, FutureExt, channel::oneshot::{self, Canceled}, @@ -29,7 +29,7 @@ use std::{ ffi::{OsStr, OsString}, fmt, io::Write, - ops::DerefMut, + ops::{Deref, DerefMut}, path::PathBuf, pin::Pin, sync::{ @@ -40,7 +40,7 @@ use std::{ time::{Duration, Instant}, }; use std::{path::Path, process::Stdio}; -use util::{ConnectionResult, ResultExt, TryFutureExt, redact}; +use util::{ConnectionResult, ResultExt, TryFutureExt}; const JSON_RPC_VERSION: &str = "2.0"; const CONTENT_LEN_HEADER: &str = "Content-Length: "; @@ -62,7 +62,7 @@ pub enum IoKind { /// Represents a launchable language server. This can either be a standalone binary or the path /// to a runtime with arguments to instruct it to launch the actual language server file. -#[derive(Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct LanguageServerBinary { pub path: PathBuf, pub arguments: Vec<OsString>, @@ -100,7 +100,7 @@ pub struct LanguageServer { io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>, output_done_rx: Mutex<Option<barrier::Receiver>>, server: Arc<Mutex<Option<Child>>>, - workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, + workspace_folders: Arc<Mutex<BTreeSet<Url>>>, root_uri: Url, } @@ -307,7 +307,7 @@ impl LanguageServer { binary: LanguageServerBinary, root_path: &Path, code_action_kinds: Option<Vec<CodeActionKind>>, - workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, + workspace_folders: Arc<Mutex<BTreeSet<Url>>>, cx: &mut AsyncApp, ) -> Result<Self> { let working_dir = if root_path.is_dir() { @@ -381,7 +381,7 @@ impl LanguageServer { code_action_kinds: Option<Vec<CodeActionKind>>, binary: LanguageServerBinary, root_uri: Url, - workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, + workspace_folders: Arc<Mutex<BTreeSet<Url>>>, cx: &mut AsyncApp, on_unhandled_notification: F, ) -> Self @@ -421,14 +421,14 @@ impl LanguageServer { .map(|stderr| { let io_handlers = io_handlers.clone(); let stderr_captures = stderr_capture.clone(); - cx.background_spawn(async move { + cx.spawn(async move |_| { Self::handle_stderr(stderr, io_handlers, stderr_captures) .log_err() .await }) }) .unwrap_or_else(|| Task::ready(None)); - let input_task = cx.background_spawn(async move { + let input_task = cx.spawn(async move |_| { let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); stdout.or(stderr) }); @@ -595,26 +595,16 @@ impl LanguageServer { } pub fn default_initialize_params(&self, pull_diagnostics: bool, cx: &App) -> InitializeParams { - let workspace_folders = self.workspace_folders.as_ref().map_or_else( - || { - vec![WorkspaceFolder { - name: Default::default(), - uri: self.root_uri.clone(), - }] - }, - |folders| { - folders - .lock() - .iter() - .cloned() - .map(|uri| WorkspaceFolder { - name: Default::default(), - uri, - }) - .collect() - }, - ); - + let workspace_folders = self + .workspace_folders + .lock() + .iter() + .cloned() + .map(|uri| WorkspaceFolder { + name: Default::default(), + uri, + }) + .collect::<Vec<_>>(); #[allow(deprecated)] InitializeParams { process_id: None, @@ -846,7 +836,7 @@ impl LanguageServer { configuration: Arc<DidChangeConfigurationParams>, cx: &App, ) -> Task<Result<Arc<Self>>> { - cx.background_spawn(async move { + cx.spawn(async move |_| { let response = self .request::<request::Initialize>(params) .await @@ -887,41 +877,39 @@ impl LanguageServer { let server = self.server.clone(); let name = self.name.clone(); - let server_id = self.server_id; let mut timer = self.executor.timer(SERVER_SHUTDOWN_TIMEOUT).fuse(); - Some(async move { - log::debug!("language server shutdown started"); + Some( + async move { + log::debug!("language server shutdown started"); - select! { - request_result = shutdown_request.fuse() => { - match request_result { - ConnectionResult::Timeout => { - log::warn!("timeout waiting for language server {name} (id {server_id}) to shutdown"); - }, - ConnectionResult::ConnectionReset => { - log::warn!("language server {name} (id {server_id}) closed the shutdown request connection"); - }, - ConnectionResult::Result(Err(e)) => { - log::error!("Shutdown request failure, server {name} (id {server_id}): {e:#}"); - }, - ConnectionResult::Result(Ok(())) => {} + select! { + request_result = shutdown_request.fuse() => { + match request_result { + ConnectionResult::Timeout => { + log::warn!("timeout waiting for language server {name} to shutdown"); + }, + ConnectionResult::ConnectionReset => {}, + ConnectionResult::Result(r) => r?, + } } + + _ = timer => { + log::info!("timeout waiting for language server {name} to shutdown"); + }, } - _ = timer => { - log::info!("timeout waiting for language server {name} (id {server_id}) to shutdown"); - }, - } + response_handlers.lock().take(); + Self::notify_internal::<notification::Exit>(&outbound_tx, &()).ok(); + outbound_tx.close(); + output_done.recv().await; + server.lock().take().map(|mut child| child.kill()); + log::debug!("language server shutdown finished"); - response_handlers.lock().take(); - Self::notify_internal::<notification::Exit>(&outbound_tx, &()).ok(); - outbound_tx.close(); - output_done.recv().await; - server.lock().take().map(|mut child| child.kill()); - drop(tasks); - log::debug!("language server shutdown finished"); - Some(()) - }) + drop(tasks); + anyhow::Ok(()) + } + .log_err(), + ) } else { None } @@ -1325,10 +1313,7 @@ impl LanguageServer { return; } - let Some(workspace_folders) = self.workspace_folders.as_ref() else { - return; - }; - let is_new_folder = workspace_folders.lock().insert(uri.clone()); + let is_new_folder = self.workspace_folders.lock().insert(uri.clone()); if is_new_folder { let params = DidChangeWorkspaceFoldersParams { event: WorkspaceFoldersChangeEvent { @@ -1358,10 +1343,7 @@ impl LanguageServer { { return; } - let Some(workspace_folders) = self.workspace_folders.as_ref() else { - return; - }; - let was_removed = workspace_folders.lock().remove(&uri); + let was_removed = self.workspace_folders.lock().remove(&uri); if was_removed { let params = DidChangeWorkspaceFoldersParams { event: WorkspaceFoldersChangeEvent { @@ -1376,10 +1358,7 @@ impl LanguageServer { } } pub fn set_workspace_folders(&self, folders: BTreeSet<Url>) { - let Some(workspace_folders) = self.workspace_folders.as_ref() else { - return; - }; - let mut workspace_folders = workspace_folders.lock(); + let mut workspace_folders = self.workspace_folders.lock(); let old_workspace_folders = std::mem::take(&mut *workspace_folders); let added: Vec<_> = folders @@ -1408,11 +1387,8 @@ impl LanguageServer { } } - pub fn workspace_folders(&self) -> BTreeSet<Url> { - self.workspace_folders.as_ref().map_or_else( - || BTreeSet::from_iter([self.root_uri.clone()]), - |folders| folders.lock().clone(), - ) + pub fn workspace_folders(&self) -> impl Deref<Target = BTreeSet<Url>> + '_ { + self.workspace_folders.lock() } pub fn register_buffer( @@ -1472,33 +1448,6 @@ impl fmt::Debug for LanguageServer { } } -impl fmt::Debug for LanguageServerBinary { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut debug = f.debug_struct("LanguageServerBinary"); - debug.field("path", &self.path); - debug.field("arguments", &self.arguments); - - if let Some(env) = &self.env { - let redacted_env: BTreeMap<String, String> = env - .iter() - .map(|(key, value)| { - let redacted_value = if redact::should_redact(key) { - "REDACTED".to_string() - } else { - value.clone() - }; - (key.clone(), redacted_value) - }) - .collect(); - debug.field("env", &Some(redacted_env)); - } else { - debug.field("env", &self.env); - } - - debug.finish() - } -} - impl Drop for Subscription { fn drop(&mut self) { match self { @@ -1557,7 +1506,7 @@ impl FakeLanguageServer { None, binary.clone(), root, - Some(workspace_folders.clone()), + workspace_folders.clone(), cx, |_| {}, ); @@ -1576,7 +1525,7 @@ impl FakeLanguageServer { None, binary, Self::root_path(), - Some(workspace_folders), + workspace_folders, cx, move |msg| { notifications_tx diff --git a/crates/onboarding/Cargo.toml b/crates/onboarding/Cargo.toml index 8f684dd1b8..693e39d4ca 100644 --- a/crates/onboarding/Cargo.toml +++ b/crates/onboarding/Cargo.toml @@ -16,29 +16,13 @@ default = [] [dependencies] anyhow.workspace = true -ai_onboarding.workspace = true -client.workspace = true command_palette_hooks.workspace = true -component.workspace = true -documented.workspace = true db.workspace = true -editor.workspace = true feature_flags.workspace = true fs.workspace = true gpui.workspace = true -itertools.workspace = true -language.workspace = true -language_model.workspace = true -menu.workspace = true -project.workspace = true -schemars.workspace = true -serde.workspace = true settings.workspace = true theme.workspace = true ui.workspace = true -util.workspace = true -vim_mode_setting.workspace = true -workspace-hack.workspace = true workspace.workspace = true -zed_actions.workspace = true -zlog.workspace = true +workspace-hack.workspace = true diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs deleted file mode 100644 index 2f031e7bb8..0000000000 --- a/crates/onboarding/src/ai_setup_page.rs +++ /dev/null @@ -1,359 +0,0 @@ -use std::sync::Arc; - -use ai_onboarding::{AiUpsellCard, SignInStatus}; -use client::DisableAiSettings; -use fs::Fs; -use gpui::{ - Action, AnyView, App, DismissEvent, EventEmitter, FocusHandle, Focusable, Window, prelude::*, -}; -use itertools; - -use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry}; -use settings::{Settings, update_settings_file}; -use ui::{ - Badge, ButtonLike, Divider, Modal, ModalFooter, ModalHeader, Section, SwitchField, ToggleState, - prelude::*, -}; -use workspace::ModalView; - -use util::ResultExt; -use zed_actions::agent::OpenSettings; - -use crate::Onboarding; - -const FEATURED_PROVIDERS: [&'static str; 4] = ["anthropic", "google", "openai", "ollama"]; - -fn render_llm_provider_section( - onboarding: &Onboarding, - disabled: bool, - window: &mut Window, - cx: &mut App, -) -> impl IntoElement { - v_flex() - .gap_4() - .child( - v_flex() - .child(Label::new("Or use other LLM providers").size(LabelSize::Large)) - .child( - Label::new("Bring your API keys to use the available providers with Zed's UI for free.") - .color(Color::Muted), - ), - ) - .child(render_llm_provider_card(onboarding, disabled, window, cx)) -} - -fn render_privacy_card(disabled: bool, cx: &mut App) -> impl IntoElement { - let privacy_badge = || Badge::new("Privacy").icon(IconName::ShieldCheck); - - v_flex() - .relative() - .pt_2() - .pb_2p5() - .pl_3() - .pr_2() - .border_1() - .border_dashed() - .border_color(cx.theme().colors().border.opacity(0.5)) - .bg(cx.theme().colors().surface_background.opacity(0.3)) - .rounded_lg() - .overflow_hidden() - .map(|this| { - if disabled { - this.child( - h_flex() - .gap_2() - .justify_between() - .child( - h_flex() - .gap_1() - .child(Label::new("AI is disabled across Zed")) - .child( - Icon::new(IconName::Check) - .color(Color::Success) - .size(IconSize::XSmall), - ), - ) - .child(privacy_badge()), - ) - .child( - Label::new("Re-enable it any time in Settings.") - .size(LabelSize::Small) - .color(Color::Muted), - ) - } else { - this.child( - h_flex() - .gap_2() - .justify_between() - .child(Label::new("We don't train models using your data")) - .child( - h_flex().gap_1().child(privacy_badge()).child( - Button::new("learn_more", "Learn More") - .style(ButtonStyle::Outlined) - .label_size(LabelSize::Small) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .on_click(|_, _, cx| { - cx.open_url("https://zed.dev/docs/ai/privacy-and-security"); - }), - ), - ), - ) - .child( - Label::new( - "Feel confident in the security and privacy of your projects using Zed.", - ) - .size(LabelSize::Small) - .color(Color::Muted), - ) - } - }) -} - -fn render_llm_provider_card( - onboarding: &Onboarding, - disabled: bool, - _: &mut Window, - cx: &mut App, -) -> impl IntoElement { - let registry = LanguageModelRegistry::read_global(cx); - - v_flex() - .border_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().surface_background.opacity(0.5)) - .rounded_lg() - .overflow_hidden() - .children(itertools::intersperse_with( - FEATURED_PROVIDERS - .into_iter() - .flat_map(|provider_name| { - registry.provider(&LanguageModelProviderId::new(provider_name)) - }) - .enumerate() - .map(|(index, provider)| { - let group_name = SharedString::new(format!("onboarding-hover-group-{}", index)); - let is_authenticated = provider.is_authenticated(cx); - - ButtonLike::new(("onboarding-ai-setup-buttons", index)) - .size(ButtonSize::Large) - .child( - h_flex() - .group(&group_name) - .px_0p5() - .w_full() - .gap_2() - .justify_between() - .child( - h_flex() - .gap_1() - .child( - Icon::new(provider.icon()) - .color(Color::Muted) - .size(IconSize::XSmall), - ) - .child(Label::new(provider.name().0)), - ) - .child( - h_flex() - .gap_1() - .when(!is_authenticated, |el| { - el.visible_on_hover(group_name.clone()) - .child( - Icon::new(IconName::Settings) - .color(Color::Muted) - .size(IconSize::XSmall), - ) - .child( - Label::new("Configure") - .color(Color::Muted) - .size(LabelSize::Small), - ) - }) - .when(is_authenticated && !disabled, |el| { - el.child( - Icon::new(IconName::Check) - .color(Color::Success) - .size(IconSize::XSmall), - ) - .child( - Label::new("Configured") - .color(Color::Muted) - .size(LabelSize::Small), - ) - }), - ), - ) - .on_click({ - let workspace = onboarding.workspace.clone(); - move |_, window, cx| { - workspace - .update(cx, |workspace, cx| { - workspace.toggle_modal(window, cx, |window, cx| { - let modal = AiConfigurationModal::new( - provider.clone(), - window, - cx, - ); - window.focus(&modal.focus_handle(cx)); - modal - }); - }) - .log_err(); - } - }) - .into_any_element() - }), - || Divider::horizontal().into_any_element(), - )) - .child(Divider::horizontal()) - .child( - Button::new("agent_settings", "Add Many Others") - .size(ButtonSize::Large) - .icon(IconName::Plus) - .icon_position(IconPosition::Start) - .icon_color(Color::Muted) - .icon_size(IconSize::XSmall) - .on_click(|_event, window, cx| { - window.dispatch_action(OpenSettings.boxed_clone(), cx) - }), - ) -} - -pub(crate) fn render_ai_setup_page( - onboarding: &Onboarding, - window: &mut Window, - cx: &mut App, -) -> impl IntoElement { - let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; - - let backdrop = div() - .id("backdrop") - .size_full() - .absolute() - .inset_0() - .bg(cx.theme().colors().editor_background) - .opacity(0.8) - .block_mouse_except_scroll(); - - v_flex() - .gap_2() - .child(SwitchField::new( - "enable_ai", - "Enable AI features", - None, - if is_ai_disabled { - ToggleState::Unselected - } else { - ToggleState::Selected - }, - |toggle_state, _, cx| { - let enabled = match toggle_state { - ToggleState::Indeterminate => { - return; - } - ToggleState::Unselected => false, - ToggleState::Selected => true, - }; - - let fs = <dyn Fs>::global(cx); - update_settings_file::<DisableAiSettings>( - fs, - cx, - move |ai_settings: &mut Option<bool>, _| { - *ai_settings = Some(!enabled); - }, - ); - }, - )) - .child(render_privacy_card(is_ai_disabled, cx)) - .child( - v_flex() - .mt_2() - .gap_6() - .child(AiUpsellCard { - sign_in_status: SignInStatus::SignedIn, - sign_in: Arc::new(|_, _| {}), - user_plan: onboarding.user_store.read(cx).plan(), - }) - .child(render_llm_provider_section( - onboarding, - is_ai_disabled, - window, - cx, - )) - .when(is_ai_disabled, |this| this.child(backdrop)), - ) -} - -struct AiConfigurationModal { - focus_handle: FocusHandle, - selected_provider: Arc<dyn LanguageModelProvider>, - configuration_view: AnyView, -} - -impl AiConfigurationModal { - fn new( - selected_provider: Arc<dyn LanguageModelProvider>, - window: &mut Window, - cx: &mut Context<Self>, - ) -> Self { - let focus_handle = cx.focus_handle(); - let configuration_view = selected_provider.configuration_view(window, cx); - - Self { - focus_handle, - configuration_view, - selected_provider, - } - } -} - -impl ModalView for AiConfigurationModal {} - -impl EventEmitter<DismissEvent> for AiConfigurationModal {} - -impl Focusable for AiConfigurationModal { - fn focus_handle(&self, _cx: &App) -> FocusHandle { - self.focus_handle.clone() - } -} - -impl Render for AiConfigurationModal { - fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - v_flex() - .w(rems(34.)) - .elevation_3(cx) - .track_focus(&self.focus_handle) - .child( - Modal::new("onboarding-ai-setup-modal", None) - .header( - ModalHeader::new() - .icon( - Icon::new(self.selected_provider.icon()) - .color(Color::Muted) - .size(IconSize::Small), - ) - .headline(self.selected_provider.name().0), - ) - .section(Section::new().child(self.configuration_view.clone())) - .footer( - ModalFooter::new().end_slot( - h_flex() - .gap_1() - .child( - Button::new("onboarding-closing-cancel", "Cancel") - .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), - ) - .child(Button::new("save-btn", "Done").on_click(cx.listener( - |_, _, window, cx| { - window.dispatch_action(menu::Confirm.boxed_clone(), cx); - cx.emit(DismissEvent); - }, - ))), - ), - ), - ) - } -} diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs deleted file mode 100644 index 82688e6220..0000000000 --- a/crates/onboarding/src/basics_page.rs +++ /dev/null @@ -1,351 +0,0 @@ -use client::TelemetrySettings; -use fs::Fs; -use gpui::{App, Entity, IntoElement, Window}; -use settings::{BaseKeymap, Settings, update_settings_file}; -use theme::{Appearance, ThemeMode, ThemeName, ThemeRegistry, ThemeSelection, ThemeSettings}; -use ui::{ - ParentElement as _, StatefulInteractiveElement, SwitchField, ToggleButtonGroup, - ToggleButtonSimple, ToggleButtonWithIcon, prelude::*, rems_from_px, -}; -use vim_mode_setting::VimModeSetting; - -use crate::theme_preview::ThemePreviewTile; - -/// separates theme "mode" ("dark" | "light" | "system") into two separate states -/// - appearance = "dark" | "light" -/// - "system" true/false -/// when system selected: -/// - toggling between light and dark does not change theme.mode, just which variant will be changed -/// when system not selected: -/// - toggling between light and dark does change theme.mode -/// selecting a theme preview will always change theme.["light" | "dark"] to the selected theme, -/// -/// this allows for selecting a dark and light theme option regardless of whether the mode is set to system or not -/// it does not support setting theme to a static value -fn render_theme_section(window: &mut Window, cx: &mut App) -> impl IntoElement { - let theme_selection = ThemeSettings::get_global(cx).theme_selection.clone(); - let system_appearance = theme::SystemAppearance::global(cx); - let appearance_state = window.use_state(cx, |_, _cx| { - theme_selection - .as_ref() - .and_then(|selection| selection.mode()) - .and_then(|mode| match mode { - ThemeMode::System => None, - ThemeMode::Light => Some(Appearance::Light), - ThemeMode::Dark => Some(Appearance::Dark), - }) - .unwrap_or(*system_appearance) - }); - let appearance = *appearance_state.read(cx); - let theme_selection = theme_selection.unwrap_or_else(|| ThemeSelection::Dynamic { - mode: match *system_appearance { - Appearance::Light => ThemeMode::Light, - Appearance::Dark => ThemeMode::Dark, - }, - light: ThemeName("One Light".into()), - dark: ThemeName("One Dark".into()), - }); - let theme_registry = ThemeRegistry::global(cx); - - let current_theme_name = theme_selection.theme(appearance); - let theme_mode = theme_selection.mode().unwrap_or_default(); - - // let theme_mode = theme_selection.mode(); - // TODO: Clean this up once the "System" button inside the - // toggle button group is done - - let selected_index = match appearance { - Appearance::Light => 0, - Appearance::Dark => 1, - }; - - let theme_seed = 0xBEEF as f32; - - const LIGHT_THEMES: [&'static str; 3] = ["One Light", "Ayu Light", "Gruvbox Light"]; - const DARK_THEMES: [&'static str; 3] = ["One Dark", "Ayu Dark", "Gruvbox Dark"]; - - let theme_names = match appearance { - Appearance::Light => LIGHT_THEMES, - Appearance::Dark => DARK_THEMES, - }; - let themes = theme_names - .map(|theme_name| theme_registry.get(theme_name)) - .map(Result::unwrap); - - let theme_previews = themes.map(|theme| { - let is_selected = theme.name == current_theme_name; - let name = theme.name.clone(); - let colors = cx.theme().colors(); - - v_flex() - .id(name.clone()) - .w_full() - .items_center() - .gap_1() - .child( - div() - .w_full() - .border_2() - .border_color(colors.border_transparent) - .rounded(ThemePreviewTile::CORNER_RADIUS) - .map(|this| { - if is_selected { - this.border_color(colors.border_selected) - } else { - this.opacity(0.8).hover(|s| s.border_color(colors.border)) - } - }) - .child(ThemePreviewTile::new(theme.clone(), theme_seed)), - ) - .child(Label::new(name).color(Color::Muted).size(LabelSize::Small)) - .on_click({ - let theme_name = theme.name.clone(); - move |_, _, cx| { - let fs = <dyn Fs>::global(cx); - let theme_name = theme_name.clone(); - update_settings_file::<ThemeSettings>(fs, cx, move |settings, _| { - settings.set_theme(theme_name, appearance); - }); - } - }) - }); - - return v_flex() - .gap_2() - .child( - h_flex().justify_between().child(Label::new("Theme")).child( - ToggleButtonGroup::single_row( - "theme-selector-onboarding-dark-light", - [ - ToggleButtonSimple::new("Light", { - let appearance_state = appearance_state.clone(); - move |_, _, cx| { - write_appearance_change(&appearance_state, Appearance::Light, cx); - } - }), - ToggleButtonSimple::new("Dark", { - let appearance_state = appearance_state.clone(); - move |_, _, cx| { - write_appearance_change(&appearance_state, Appearance::Dark, cx); - } - }), - // TODO: Properly put the System back as a button within this group - // Currently, given "System" is not an option in the Appearance enum, - // this button doesn't get selected - ToggleButtonSimple::new("System", { - let theme = theme_selection.clone(); - move |_, _, cx| { - toggle_system_theme_mode(theme.clone(), appearance, cx); - } - }) - .selected(theme_mode == ThemeMode::System), - ], - ) - .selected_index(selected_index) - .style(ui::ToggleButtonGroupStyle::Outlined) - .button_width(rems_from_px(64.)), - ), - ) - .child(h_flex().gap_4().justify_between().children(theme_previews)); - - fn write_appearance_change( - appearance_state: &Entity<Appearance>, - new_appearance: Appearance, - cx: &mut App, - ) { - let fs = <dyn Fs>::global(cx); - appearance_state.write(cx, new_appearance); - - update_settings_file::<ThemeSettings>(fs, cx, move |settings, _| { - if settings.theme.as_ref().and_then(ThemeSelection::mode) == Some(ThemeMode::System) { - return; - } - let new_mode = match new_appearance { - Appearance::Light => ThemeMode::Light, - Appearance::Dark => ThemeMode::Dark, - }; - settings.set_mode(new_mode); - }); - } - - fn toggle_system_theme_mode( - theme_selection: ThemeSelection, - appearance: Appearance, - cx: &mut App, - ) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<ThemeSettings>(fs, cx, move |settings, _| { - settings.theme = Some(match theme_selection { - ThemeSelection::Static(theme_name) => ThemeSelection::Dynamic { - mode: ThemeMode::System, - light: theme_name.clone(), - dark: theme_name.clone(), - }, - ThemeSelection::Dynamic { - mode: ThemeMode::System, - light, - dark, - } => { - let mode = match appearance { - Appearance::Light => ThemeMode::Light, - Appearance::Dark => ThemeMode::Dark, - }; - ThemeSelection::Dynamic { mode, light, dark } - } - ThemeSelection::Dynamic { - mode: _, - light, - dark, - } => ThemeSelection::Dynamic { - mode: ThemeMode::System, - light, - dark, - }, - }); - }); - } -} - -fn write_keymap_base(keymap_base: BaseKeymap, cx: &App) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<BaseKeymap>(fs, cx, move |setting, _| { - *setting = Some(keymap_base); - }); -} - -fn render_telemetry_section(cx: &App) -> impl IntoElement { - let fs = <dyn Fs>::global(cx); - - v_flex() - .gap_4() - .child(Label::new("Telemetry").size(LabelSize::Large)) - .child(SwitchField::new( - "onboarding-telemetry-metrics", - "Help Improve Zed", - Some("Sending anonymous usage data helps us build the right features and create the best experience.".into()), - if TelemetrySettings::get_global(cx).metrics { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - { - let fs = fs.clone(); - move |selection, _, cx| { - let enabled = match selection { - ToggleState::Selected => true, - ToggleState::Unselected => false, - ToggleState::Indeterminate => { return; }, - }; - - update_settings_file::<TelemetrySettings>( - fs.clone(), - cx, - move |setting, _| setting.metrics = Some(enabled), - ); - }}, - )) - .child(SwitchField::new( - "onboarding-telemetry-crash-reports", - "Help Fix Zed", - Some("Send crash reports so we can fix critical issues fast.".into()), - if TelemetrySettings::get_global(cx).diagnostics { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - { - let fs = fs.clone(); - move |selection, _, cx| { - let enabled = match selection { - ToggleState::Selected => true, - ToggleState::Unselected => false, - ToggleState::Indeterminate => { return; }, - }; - - update_settings_file::<TelemetrySettings>( - fs.clone(), - cx, - move |setting, _| setting.diagnostics = Some(enabled), - ); - } - } - )) -} - -pub(crate) fn render_basics_page(window: &mut Window, cx: &mut App) -> impl IntoElement { - let base_keymap = match BaseKeymap::get_global(cx) { - BaseKeymap::VSCode => Some(0), - BaseKeymap::JetBrains => Some(1), - BaseKeymap::SublimeText => Some(2), - BaseKeymap::Atom => Some(3), - BaseKeymap::Emacs => Some(4), - BaseKeymap::Cursor => Some(5), - BaseKeymap::TextMate | BaseKeymap::None => None, - }; - - v_flex() - .gap_6() - .child(render_theme_section(window, cx)) - .child( - v_flex().gap_2().child(Label::new("Base Keymap")).child( - ToggleButtonGroup::two_rows( - "multiple_row_test", - [ - ToggleButtonWithIcon::new("VS Code", IconName::EditorVsCode, |_, _, cx| { - write_keymap_base(BaseKeymap::VSCode, cx); - }), - ToggleButtonWithIcon::new("Jetbrains", IconName::EditorJetBrains, |_, _, cx| { - write_keymap_base(BaseKeymap::JetBrains, cx); - }), - ToggleButtonWithIcon::new("Sublime Text", IconName::EditorSublime, |_, _, cx| { - write_keymap_base(BaseKeymap::SublimeText, cx); - }), - ], - [ - ToggleButtonWithIcon::new("Atom", IconName::EditorAtom, |_, _, cx| { - write_keymap_base(BaseKeymap::Atom, cx); - }), - ToggleButtonWithIcon::new("Emacs", IconName::EditorEmacs, |_, _, cx| { - write_keymap_base(BaseKeymap::Emacs, cx); - }), - ToggleButtonWithIcon::new("Cursor (Beta)", IconName::EditorCursor, |_, _, cx| { - write_keymap_base(BaseKeymap::Cursor, cx); - }), - ], - ) - .when_some(base_keymap, |this, base_keymap| this.selected_index(base_keymap)) - .button_width(rems_from_px(216.)) - .size(ui::ToggleButtonGroupSize::Medium) - .style(ui::ToggleButtonGroupStyle::Outlined) - ), - ) - .child(SwitchField::new( - "onboarding-vim-mode", - "Vim Mode", - Some("Coming from Neovim? Zed's first-class implementation of Vim Mode has got your back.".into()), - if VimModeSetting::get_global(cx).0 { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - { - let fs = <dyn Fs>::global(cx); - move |selection, _, cx| { - let enabled = match selection { - ToggleState::Selected => true, - ToggleState::Unselected => false, - ToggleState::Indeterminate => { return; }, - }; - - update_settings_file::<VimModeSetting>( - fs.clone(), - cx, - move |setting, _| *setting = Some(enabled), - ); - } - }, - )) - .child(render_telemetry_section(cx)) -} diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs deleted file mode 100644 index 2972f41348..0000000000 --- a/crates/onboarding/src/editing_page.rs +++ /dev/null @@ -1,457 +0,0 @@ -use std::sync::Arc; - -use editor::{EditorSettings, ShowMinimap}; -use fs::Fs; -use gpui::{Action, App, FontFeatures, IntoElement, Pixels, Window}; -use language::language_settings::{AllLanguageSettings, FormatOnSave}; -use project::project_settings::ProjectSettings; -use settings::{Settings as _, update_settings_file}; -use theme::{FontFamilyCache, FontFamilyName, ThemeSettings}; -use ui::{ - ButtonLike, ContextMenu, DropdownMenu, NumericStepper, SwitchField, ToggleButtonGroup, - ToggleButtonGroupStyle, ToggleButtonSimple, ToggleState, prelude::*, -}; - -use crate::{ImportCursorSettings, ImportVsCodeSettings}; - -fn read_show_mini_map(cx: &App) -> ShowMinimap { - editor::EditorSettings::get_global(cx).minimap.show -} - -fn write_show_mini_map(show: ShowMinimap, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - - // This is used to speed up the UI - // the UI reads the current values to get what toggle state to show on buttons - // there's a slight delay if we just call update_settings_file so we manually set - // the value here then call update_settings file to get around the delay - let mut curr_settings = EditorSettings::get_global(cx).clone(); - curr_settings.minimap.show = show; - EditorSettings::override_global(curr_settings, cx); - - update_settings_file::<EditorSettings>(fs, cx, move |editor_settings, _| { - editor_settings.minimap.get_or_insert_default().show = Some(show); - }); -} - -fn read_inlay_hints(cx: &App) -> bool { - AllLanguageSettings::get_global(cx) - .defaults - .inlay_hints - .enabled -} - -fn write_inlay_hints(enabled: bool, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - - let mut curr_settings = AllLanguageSettings::get_global(cx).clone(); - curr_settings.defaults.inlay_hints.enabled = enabled; - AllLanguageSettings::override_global(curr_settings, cx); - - update_settings_file::<AllLanguageSettings>(fs, cx, move |all_language_settings, cx| { - all_language_settings - .defaults - .inlay_hints - .get_or_insert_with(|| { - AllLanguageSettings::get_global(cx) - .clone() - .defaults - .inlay_hints - }) - .enabled = enabled; - }); -} - -fn read_git_blame(cx: &App) -> bool { - ProjectSettings::get_global(cx).git.inline_blame_enabled() -} - -fn set_git_blame(enabled: bool, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - - let mut curr_settings = ProjectSettings::get_global(cx).clone(); - curr_settings - .git - .inline_blame - .get_or_insert_default() - .enabled = enabled; - ProjectSettings::override_global(curr_settings, cx); - - update_settings_file::<ProjectSettings>(fs, cx, move |project_settings, _| { - project_settings - .git - .inline_blame - .get_or_insert_default() - .enabled = enabled; - }); -} - -fn write_ui_font_family(font: SharedString, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { - theme_settings.ui_font_family = Some(FontFamilyName(font.into())); - }); -} - -fn write_ui_font_size(size: Pixels, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { - theme_settings.ui_font_size = Some(size.into()); - }); -} - -fn write_buffer_font_size(size: Pixels, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { - theme_settings.buffer_font_size = Some(size.into()); - }); -} - -fn write_buffer_font_family(font_family: SharedString, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { - theme_settings.buffer_font_family = Some(FontFamilyName(font_family.into())); - }); -} - -fn read_font_ligatures(cx: &App) -> bool { - ThemeSettings::get_global(cx) - .buffer_font - .features - .is_calt_enabled() - .unwrap_or(true) -} - -fn write_font_ligatures(enabled: bool, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - let bit = if enabled { 1 } else { 0 }; - - update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { - let mut features = theme_settings - .buffer_font_features - .as_mut() - .map(|features| features.tag_value_list().to_vec()) - .unwrap_or_default(); - - if let Some(calt_index) = features.iter().position(|(tag, _)| tag == "calt") { - features[calt_index].1 = bit; - } else { - features.push(("calt".into(), bit)); - } - - theme_settings.buffer_font_features = Some(FontFeatures(Arc::new(features))); - }); -} - -fn read_format_on_save(cx: &App) -> bool { - match AllLanguageSettings::get_global(cx).defaults.format_on_save { - FormatOnSave::On | FormatOnSave::List(_) => true, - FormatOnSave::Off => false, - } -} - -fn write_format_on_save(format_on_save: bool, cx: &mut App) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<AllLanguageSettings>(fs, cx, move |language_settings, _| { - language_settings.defaults.format_on_save = Some(match format_on_save { - true => FormatOnSave::On, - false => FormatOnSave::Off, - }); - }); -} - -fn render_import_settings_section() -> impl IntoElement { - v_flex() - .gap_4() - .child( - v_flex() - .child(Label::new("Import Settings").size(LabelSize::Large)) - .child( - Label::new("Automatically pull your settings from other editors.") - .color(Color::Muted), - ), - ) - .child( - h_flex() - .w_full() - .gap_4() - .child( - h_flex().w_full().child( - ButtonLike::new("import_vs_code") - .full_width() - .style(ButtonStyle::Outlined) - .size(ButtonSize::Large) - .child( - h_flex() - .w_full() - .gap_1p5() - .px_1() - .child( - Icon::new(IconName::EditorVsCode) - .color(Color::Muted) - .size(IconSize::XSmall), - ) - .child(Label::new("VS Code")), - ) - .on_click(|_, window, cx| { - window.dispatch_action( - ImportVsCodeSettings::default().boxed_clone(), - cx, - ) - }), - ), - ) - .child( - h_flex().w_full().child( - ButtonLike::new("import_cursor") - .full_width() - .style(ButtonStyle::Outlined) - .size(ButtonSize::Large) - .child( - h_flex() - .w_full() - .gap_1p5() - .px_1() - .child( - Icon::new(IconName::EditorCursor) - .color(Color::Muted) - .size(IconSize::XSmall), - ) - .child(Label::new("Cursor")), - ) - .on_click(|_, window, cx| { - window.dispatch_action( - ImportCursorSettings::default().boxed_clone(), - cx, - ) - }), - ), - ), - ) -} - -fn render_font_customization_section(window: &mut Window, cx: &mut App) -> impl IntoElement { - let theme_settings = ThemeSettings::get_global(cx); - let ui_font_size = theme_settings.ui_font_size(cx); - let font_family = theme_settings.buffer_font.family.clone(); - let buffer_font_size = theme_settings.buffer_font_size(cx); - - h_flex() - .w_full() - .gap_4() - .child( - v_flex() - .w_full() - .gap_1() - .child(Label::new("UI Font")) - .child( - h_flex() - .w_full() - .justify_between() - .gap_2() - .child( - DropdownMenu::new( - "ui-font-family", - theme_settings.ui_font.family.clone(), - ContextMenu::build(window, cx, |mut menu, _, cx| { - let font_family_cache = FontFamilyCache::global(cx); - - for font_name in font_family_cache.list_font_families(cx) { - menu = menu.custom_entry( - { - let font_name = font_name.clone(); - move |_window, _cx| { - Label::new(font_name.clone()).into_any_element() - } - }, - { - let font_name = font_name.clone(); - move |_window, cx| { - write_ui_font_family(font_name.clone(), cx); - } - }, - ) - } - - menu - }), - ) - .style(ui::DropdownStyle::Outlined) - .full_width(true), - ) - .child( - NumericStepper::new( - "ui-font-size", - ui_font_size.to_string(), - move |_, _, cx| { - write_ui_font_size(ui_font_size - px(1.), cx); - }, - move |_, _, cx| { - write_ui_font_size(ui_font_size + px(1.), cx); - }, - ) - .style(ui::NumericStepperStyle::Outlined), - ), - ), - ) - .child( - v_flex() - .w_full() - .gap_1() - .child(Label::new("Editor Font")) - .child( - h_flex() - .w_full() - .justify_between() - .gap_2() - .child( - DropdownMenu::new( - "buffer-font-family", - font_family, - ContextMenu::build(window, cx, |mut menu, _, cx| { - let font_family_cache = FontFamilyCache::global(cx); - - for font_name in font_family_cache.list_font_families(cx) { - menu = menu.custom_entry( - { - let font_name = font_name.clone(); - move |_window, _cx| { - Label::new(font_name.clone()).into_any_element() - } - }, - { - let font_name = font_name.clone(); - move |_window, cx| { - write_buffer_font_family(font_name.clone(), cx); - } - }, - ) - } - - menu - }), - ) - .style(ui::DropdownStyle::Outlined) - .full_width(true), - ) - .child( - NumericStepper::new( - "buffer-font-size", - buffer_font_size.to_string(), - move |_, _, cx| { - write_buffer_font_size(buffer_font_size - px(1.), cx); - }, - move |_, _, cx| { - write_buffer_font_size(buffer_font_size + px(1.), cx); - }, - ) - .style(ui::NumericStepperStyle::Outlined), - ), - ), - ) -} - -fn render_popular_settings_section(window: &mut Window, cx: &mut App) -> impl IntoElement { - v_flex() - .gap_5() - .child(Label::new("Popular Settings").size(LabelSize::Large).mt_8()) - .child(render_font_customization_section(window, cx)) - .child(SwitchField::new( - "onboarding-font-ligatures", - "Font Ligatures", - Some("Combine text characters into their associated symbols.".into()), - if read_font_ligatures(cx) { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - |toggle_state, _, cx| { - write_font_ligatures(toggle_state == &ToggleState::Selected, cx); - }, - )) - .child(SwitchField::new( - "onboarding-format-on-save", - "Format on Save", - Some("Format code automatically when saving.".into()), - if read_format_on_save(cx) { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - |toggle_state, _, cx| { - write_format_on_save(toggle_state == &ToggleState::Selected, cx); - }, - )) - .child( - h_flex() - .items_start() - .justify_between() - .child( - v_flex().child(Label::new("Mini Map")).child( - Label::new("See a high-level overview of your source code.") - .color(Color::Muted), - ), - ) - .child( - ToggleButtonGroup::single_row( - "onboarding-show-mini-map", - [ - ToggleButtonSimple::new("Auto", |_, _, cx| { - write_show_mini_map(ShowMinimap::Auto, cx); - }), - ToggleButtonSimple::new("Always", |_, _, cx| { - write_show_mini_map(ShowMinimap::Always, cx); - }), - ToggleButtonSimple::new("Never", |_, _, cx| { - write_show_mini_map(ShowMinimap::Never, cx); - }), - ], - ) - .selected_index(match read_show_mini_map(cx) { - ShowMinimap::Auto => 0, - ShowMinimap::Always => 1, - ShowMinimap::Never => 2, - }) - .style(ToggleButtonGroupStyle::Outlined) - .button_width(ui::rems_from_px(64.)), - ), - ) - .child(SwitchField::new( - "onboarding-enable-inlay-hints", - "Inlay Hints", - Some("See parameter names for function and method calls inline.".into()), - if read_inlay_hints(cx) { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - |toggle_state, _, cx| { - write_inlay_hints(toggle_state == &ToggleState::Selected, cx); - }, - )) - .child(SwitchField::new( - "onboarding-git-blame-switch", - "Git Blame", - Some("See who committed each line on a given file.".into()), - if read_git_blame(cx) { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - |toggle_state, _, cx| { - set_git_blame(toggle_state == &ToggleState::Selected, cx); - }, - )) -} - -pub(crate) fn render_editing_page(window: &mut Window, cx: &mut App) -> impl IntoElement { - v_flex() - .gap_4() - .child(render_import_settings_section()) - .child(render_popular_settings_section(window, cx)) -} diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index f7e76f2f34..1ce236f941 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -1,60 +1,33 @@ -use crate::welcome::{ShowWelcome, WelcomePage}; -use client::{Client, UserStore}; use command_palette_hooks::CommandPaletteFilter; use db::kvp::KEY_VALUE_STORE; use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; use fs::Fs; use gpui::{ - Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter, - FocusHandle, Focusable, IntoElement, KeyContext, Render, SharedString, Subscription, Task, - WeakEntity, Window, actions, + AnyElement, App, AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, + IntoElement, Render, SharedString, Subscription, Task, WeakEntity, Window, actions, }; -use schemars::JsonSchema; -use serde::Deserialize; -use settings::{SettingsStore, VsCodeSettingsSource}; +use settings::{Settings, SettingsStore, update_settings_file}; use std::sync::Arc; +use theme::{ThemeMode, ThemeSettings}; use ui::{ - Avatar, ButtonLike, FluentBuilder, Headline, KeyBinding, ParentElement as _, - StatefulInteractiveElement, Vector, VectorName, prelude::*, rems_from_px, + ButtonCommon as _, ButtonSize, ButtonStyle, Clickable as _, Color, Divider, FluentBuilder, + Headline, InteractiveElement, KeyBinding, Label, LabelCommon, ParentElement as _, + StatefulInteractiveElement, Styled, ToggleButton, Toggleable as _, Vector, VectorName, div, + h_flex, rems, v_container, v_flex, }; use workspace::{ AppState, Workspace, WorkspaceId, dock::DockPosition, item::{Item, ItemEvent}, - notifications::NotifyResultExt as _, - open_new, register_serializable_item, with_active_or_new_workspace, + open_new, with_active_or_new_workspace, }; -mod ai_setup_page; -mod basics_page; -mod editing_page; -mod theme_preview; -mod welcome; - pub struct OnBoardingFeatureFlag {} impl FeatureFlag for OnBoardingFeatureFlag { const NAME: &'static str = "onboarding"; } -/// Imports settings from Visual Studio Code. -#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportVsCodeSettings { - #[serde(default)] - pub skip_prompt: bool, -} - -/// Imports settings from Cursor editor. -#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportCursorSettings { - #[serde(default)] - pub skip_prompt: bool, -} - pub const FIRST_OPEN: &str = "first_open"; actions!( @@ -65,18 +38,6 @@ actions!( ] ); -actions!( - onboarding, - [ - /// Activates the Basics page. - ActivateBasicsPage, - /// Activates the Editing page. - ActivateEditingPage, - /// Activates the AI Setup page. - ActivateAISetupPage, - ] -); - pub fn init(cx: &mut App) { cx.on_action(|_: &OpenOnboarding, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { @@ -91,7 +52,7 @@ pub fn init(cx: &mut App) { if let Some(existing) = existing { workspace.activate_item(&existing, true, true, window, cx); } else { - let settings_page = Onboarding::new(workspace, cx); + let settings_page = Onboarding::new(workspace.weak_handle(), cx); workspace.add_item_to_active_pane( Box::new(settings_page), None, @@ -104,80 +65,12 @@ pub fn init(cx: &mut App) { .detach(); }); }); - - cx.on_action(|_: &ShowWelcome, cx| { - with_active_or_new_workspace(cx, |workspace, window, cx| { - workspace - .with_local_workspace(window, cx, |workspace, window, cx| { - let existing = workspace - .active_pane() - .read(cx) - .items() - .find_map(|item| item.downcast::<WelcomePage>()); - - if let Some(existing) = existing { - workspace.activate_item(&existing, true, true, window, cx); - } else { - let settings_page = WelcomePage::new(window, cx); - workspace.add_item_to_active_pane( - Box::new(settings_page), - None, - true, - window, - cx, - ) - } - }) - .detach(); - }); - }); - - cx.observe_new(|workspace: &mut Workspace, _window, _cx| { - workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { - let fs = <dyn Fs>::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::VsCode, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - - workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { - let fs = <dyn Fs>::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::Cursor, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - }) - .detach(); - cx.observe_new::<Workspace>(|_, window, cx| { let Some(window) = window else { return; }; - let onboarding_actions = [ - std::any::TypeId::of::<OpenOnboarding>(), - std::any::TypeId::of::<ShowWelcome>(), - ]; + let onboarding_actions = [std::any::TypeId::of::<OpenOnboarding>()]; CommandPaletteFilter::update_global(cx, |filter, _cx| { filter.hide_action_types(&onboarding_actions); @@ -197,7 +90,6 @@ pub fn init(cx: &mut App) { .detach(); }) .detach(); - register_serializable_item::<Onboarding>(cx); } pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyhow::Result<()>> { @@ -208,7 +100,7 @@ pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyh |workspace, window, cx| { { workspace.toggle_dock(DockPosition::Left, window, cx); - let onboarding_page = Onboarding::new(workspace, cx); + let onboarding_page = Onboarding::new(workspace.weak_handle(), cx); workspace.add_item_to_center(Box::new(onboarding_page.clone()), window, cx); window.focus(&onboarding_page.focus_handle(cx)); @@ -222,6 +114,23 @@ pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyh ) } +fn read_theme_selection(cx: &App) -> ThemeMode { + let settings = ThemeSettings::get_global(cx); + settings + .theme_selection + .as_ref() + .and_then(|selection| selection.mode()) + .unwrap_or_default() +} + +fn write_theme_selection(theme_mode: ThemeMode, cx: &App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |settings, _| { + settings.set_mode(theme_mode); + }); +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum SelectedPage { Basics, @@ -233,278 +142,175 @@ struct Onboarding { workspace: WeakEntity<Workspace>, focus_handle: FocusHandle, selected_page: SelectedPage, - user_store: Entity<UserStore>, _settings_subscription: Subscription, } impl Onboarding { - fn new(workspace: &Workspace, cx: &mut App) -> Entity<Self> { + fn new(workspace: WeakEntity<Workspace>, cx: &mut App) -> Entity<Self> { cx.new(|cx| Self { - workspace: workspace.weak_handle(), + workspace, focus_handle: cx.focus_handle(), selected_page: SelectedPage::Basics, - user_store: workspace.user_store().clone(), _settings_subscription: cx.observe_global::<SettingsStore>(move |_, cx| cx.notify()), }) } - fn set_page(&mut self, page: SelectedPage, cx: &mut Context<Self>) { - self.selected_page = page; - cx.notify(); - cx.emit(ItemEvent::UpdateTab); - } - - fn render_nav_buttons( + fn render_page_nav( &mut self, - window: &mut Window, + page: SelectedPage, + _: &mut Window, cx: &mut Context<Self>, - ) -> [impl IntoElement; 3] { - let pages = [ - SelectedPage::Basics, - SelectedPage::Editing, - SelectedPage::AiSetup, - ]; - - let text = ["Basics", "Editing", "AI Setup"]; - - let actions: [&dyn Action; 3] = [ - &ActivateBasicsPage, - &ActivateEditingPage, - &ActivateAISetupPage, - ]; - - let mut binding = actions.map(|action| { - KeyBinding::for_action_in(action, &self.focus_handle, window, cx) - .map(|kb| kb.size(rems_from_px(12.))) - }); - - pages.map(|page| { - let i = page as usize; - let selected = self.selected_page == page; - h_flex() - .id(text[i]) - .relative() - .w_full() - .gap_2() - .px_2() - .py_0p5() - .justify_between() - .rounded_sm() - .when(selected, |this| { - this.child( - div() - .h_4() - .w_px() - .bg(cx.theme().colors().text_accent) - .absolute() - .left_0(), - ) - }) - .hover(|style| style.bg(cx.theme().colors().element_hover)) - .child(Label::new(text[i]).map(|this| { - if selected { - this.color(Color::Default) - } else { - this.color(Color::Muted) - } - })) - .child(binding[i].take().map_or( - gpui::Empty.into_any_element(), - IntoElement::into_any_element, - )) - .on_click(cx.listener(move |this, _, _, cx| { - this.set_page(page, cx); - })) - }) - } - - fn render_nav(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - v_flex() - .h_full() - .w(rems_from_px(220.)) - .flex_shrink_0() - .gap_4() + ) -> impl IntoElement { + let text = match page { + SelectedPage::Basics => "Basics", + SelectedPage::Editing => "Editing", + SelectedPage::AiSetup => "AI Setup", + }; + let binding = match page { + SelectedPage::Basics => { + KeyBinding::new(vec![gpui::Keystroke::parse("cmd-1").unwrap()], cx) + } + SelectedPage::Editing => { + KeyBinding::new(vec![gpui::Keystroke::parse("cmd-2").unwrap()], cx) + } + SelectedPage::AiSetup => { + KeyBinding::new(vec![gpui::Keystroke::parse("cmd-3").unwrap()], cx) + } + }; + let selected = self.selected_page == page; + h_flex() + .id(text) + .rounded_sm() + .child(text) + .child(binding) + .h_8() + .gap_2() + .px_2() + .py_0p5() + .w_full() .justify_between() - .child( - v_flex() - .gap_6() - .child( - h_flex() - .px_2() - .gap_4() - .child(Vector::square(VectorName::ZedLogo, rems(2.5))) - .child( - v_flex() - .child( - Headline::new("Welcome to Zed").size(HeadlineSize::Small), - ) - .child( - Label::new("The editor for what's next") - .color(Color::Muted) - .size(LabelSize::Small) - .italic(), - ), - ), - ) - .child( - v_flex() - .gap_4() - .child( - v_flex() - .py_4() - .border_y_1() - .border_color(cx.theme().colors().border_variant.opacity(0.5)) - .gap_1() - .children(self.render_nav_buttons(window, cx)), - ) - .child( - ButtonLike::new("skip_all") - .child(Label::new("Skip All").ml_1()) - .on_click(|_, _, cx| { - with_active_or_new_workspace( - cx, - |workspace, window, cx| { - let Some((onboarding_id, onboarding_idx)) = - workspace - .active_pane() - .read(cx) - .items() - .enumerate() - .find_map(|(idx, item)| { - let _ = - item.downcast::<Onboarding>()?; - Some((item.item_id(), idx)) - }) - else { - return; - }; - - workspace.active_pane().update(cx, |pane, cx| { - // Get the index here to get around the borrow checker - let idx = pane.items().enumerate().find_map( - |(idx, item)| { - let _ = - item.downcast::<WelcomePage>()?; - Some(idx) - }, - ); - - if let Some(idx) = idx { - pane.activate_item( - idx, true, true, window, cx, - ); - } else { - let item = - Box::new(WelcomePage::new(window, cx)); - pane.add_item( - item, - true, - true, - Some(onboarding_idx), - window, - cx, - ); - } - - pane.remove_item( - onboarding_id, - false, - false, - window, - cx, - ); - }); - }, - ); - }), - ), - ), - ) - .child( - if let Some(user) = self.user_store.read(cx).current_user() { - h_flex() - .pl_1p5() - .gap_2() - .child(Avatar::new(user.avatar_uri.clone())) - .child(Label::new(user.github_login.clone())) - .into_any_element() + .map(|this| { + if selected { + this.bg(Color::Selected.color(cx)) + .border_l_1() + .border_color(Color::Accent.color(cx)) } else { - Button::new("sign_in", "Sign In") - .style(ButtonStyle::Outlined) - .full_width() - .on_click(|_, window, cx| { - let client = Client::global(cx); - window - .spawn(cx, async move |cx| { - client - .sign_in_with_optional_connect(true, &cx) - .await - .notify_async_err(cx); - }) - .detach(); - }) - .into_any_element() - }, - ) + this.text_color(Color::Muted.color(cx)) + } + }) + .hover(|style| { + if selected { + style.bg(Color::Selected.color(cx).opacity(0.6)) + } else { + style.bg(Color::Selected.color(cx).opacity(0.3)) + } + }) + .on_click(cx.listener(move |this, _, _, cx| { + this.selected_page = page; + cx.notify(); + })) } fn render_page(&mut self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement { match self.selected_page { - SelectedPage::Basics => { - crate::basics_page::render_basics_page(window, cx).into_any_element() - } - SelectedPage::Editing => { - crate::editing_page::render_editing_page(window, cx).into_any_element() - } - SelectedPage::AiSetup => { - crate::ai_setup_page::render_ai_setup_page(&self, window, cx).into_any_element() - } + SelectedPage::Basics => self.render_basics_page(window, cx).into_any_element(), + SelectedPage::Editing => self.render_editing_page(window, cx).into_any_element(), + SelectedPage::AiSetup => self.render_ai_setup_page(window, cx).into_any_element(), } } + + fn render_basics_page(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let theme_mode = read_theme_selection(cx); + + v_container().child( + h_flex() + .items_center() + .justify_between() + .child(Label::new("Theme")) + .child( + h_flex() + .rounded_md() + .child( + ToggleButton::new("light", "Light") + .style(ButtonStyle::Filled) + .size(ButtonSize::Large) + .toggle_state(theme_mode == ThemeMode::Light) + .on_click(|_, _, cx| write_theme_selection(ThemeMode::Light, cx)) + .first(), + ) + .child( + ToggleButton::new("dark", "Dark") + .style(ButtonStyle::Filled) + .size(ButtonSize::Large) + .toggle_state(theme_mode == ThemeMode::Dark) + .on_click(|_, _, cx| write_theme_selection(ThemeMode::Dark, cx)) + .last(), + ) + .child( + ToggleButton::new("system", "System") + .style(ButtonStyle::Filled) + .size(ButtonSize::Large) + .toggle_state(theme_mode == ThemeMode::System) + .on_click(|_, _, cx| write_theme_selection(ThemeMode::System, cx)) + .middle(), + ), + ), + ) + } + + fn render_editing_page(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + // div().child("editing page") + "Right" + } + + fn render_ai_setup_page(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + div().child("ai setup page") + } } impl Render for Onboarding { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { h_flex() .image_cache(gpui::retain_all("onboarding-page")) - .key_context({ - let mut ctx = KeyContext::new_with_defaults(); - ctx.add("Onboarding"); - ctx - }) - .track_focus(&self.focus_handle) - .size_full() - .bg(cx.theme().colors().editor_background) - .on_action(cx.listener(|this, _: &ActivateBasicsPage, _, cx| { - this.set_page(SelectedPage::Basics, cx); - })) - .on_action(cx.listener(|this, _: &ActivateEditingPage, _, cx| { - this.set_page(SelectedPage::Editing, cx); - })) - .on_action(cx.listener(|this, _: &ActivateAISetupPage, _, cx| { - this.set_page(SelectedPage::AiSetup, cx); - })) + .key_context("onboarding-page") + .px_24() + .py_12() + .items_start() .child( - h_flex() - .max_w(rems_from_px(1100.)) - .size_full() - .m_auto() - .py_20() - .px_12() - .items_start() - .gap_12() - .child(self.render_nav(window, cx)) + v_flex() + .w_1_3() + .h_full() .child( - v_flex() - .max_w_full() - .min_w_0() - .pl_12() - .border_l_1() - .border_color(cx.theme().colors().border_variant.opacity(0.5)) - .size_full() - .child(self.render_page(window, cx)), + h_flex() + .pt_0p5() + .child(Vector::square(VectorName::ZedLogo, rems(2.))) + .child( + v_flex() + .left_1() + .items_center() + .child(Headline::new("Welcome to Zed")) + .child( + Label::new("The editor for what's next") + .color(Color::Muted) + .italic(), + ), + ), + ) + .p_1() + .child(Divider::horizontal_dashed()) + .child( + v_flex().gap_1().children([ + self.render_page_nav(SelectedPage::Basics, window, cx) + .into_element(), + self.render_page_nav(SelectedPage::Editing, window, cx) + .into_element(), + self.render_page_nav(SelectedPage::AiSetup, window, cx) + .into_element(), + ]), ), ) + // .child(Divider::vertical_dashed()) + .child(div().w_2_3().h_full().child(self.render_page(window, cx))) } } @@ -537,185 +343,10 @@ impl Item for Onboarding { _: &mut Window, cx: &mut Context<Self>, ) -> Option<Entity<Self>> { - self.workspace - .update(cx, |workspace, cx| Onboarding::new(workspace, cx)) - .ok() + Some(Onboarding::new(self.workspace.clone(), cx)) } fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { f(*event) } } - -pub async fn handle_import_vscode_settings( - source: VsCodeSettingsSource, - skip_prompt: bool, - fs: Arc<dyn Fs>, - cx: &mut AsyncWindowContext, -) { - use util::truncate_and_remove_front; - - let vscode_settings = - match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { - Ok(vscode_settings) => vscode_settings, - Err(err) => { - zlog::error!("{err}"); - let _ = cx.prompt( - gpui::PromptLevel::Info, - &format!("Could not find or load a {source} settings file"), - None, - &["Ok"], - ); - return; - } - }; - - if !skip_prompt { - let prompt = cx.prompt( - gpui::PromptLevel::Warning, - &format!( - "Importing {} settings may overwrite your existing settings. \ - Will import settings from {}", - vscode_settings.source, - truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), - ), - None, - &["Ok", "Cancel"], - ); - let result = cx.spawn(async move |_| prompt.await.ok()).await; - if result != Some(0) { - return; - } - }; - - cx.update(|_, cx| { - let source = vscode_settings.source; - let path = vscode_settings.path.clone(); - cx.global::<SettingsStore>() - .import_vscode_settings(fs, vscode_settings); - zlog::info!("Imported {source} settings from {}", path.display()); - }) - .ok(); -} - -impl workspace::SerializableItem for Onboarding { - fn serialized_item_kind() -> &'static str { - "OnboardingPage" - } - - fn cleanup( - workspace_id: workspace::WorkspaceId, - alive_items: Vec<workspace::ItemId>, - _window: &mut Window, - cx: &mut App, - ) -> gpui::Task<gpui::Result<()>> { - workspace::delete_unloaded_items( - alive_items, - workspace_id, - "onboarding_pages", - &persistence::ONBOARDING_PAGES, - cx, - ) - } - - fn deserialize( - _project: Entity<project::Project>, - workspace: WeakEntity<Workspace>, - workspace_id: workspace::WorkspaceId, - item_id: workspace::ItemId, - window: &mut Window, - cx: &mut App, - ) -> gpui::Task<gpui::Result<Entity<Self>>> { - window.spawn(cx, async move |cx| { - if let Some(page_number) = - persistence::ONBOARDING_PAGES.get_onboarding_page(item_id, workspace_id)? - { - let page = match page_number { - 0 => Some(SelectedPage::Basics), - 1 => Some(SelectedPage::Editing), - 2 => Some(SelectedPage::AiSetup), - _ => None, - }; - workspace.update(cx, |workspace, cx| { - let onboarding_page = Onboarding::new(workspace, cx); - if let Some(page) = page { - zlog::info!("Onboarding page {page:?} loaded"); - onboarding_page.update(cx, |onboarding_page, cx| { - onboarding_page.set_page(page, cx); - }) - } - onboarding_page - }) - } else { - Err(anyhow::anyhow!("No onboarding page to deserialize")) - } - }) - } - - fn serialize( - &mut self, - workspace: &mut Workspace, - item_id: workspace::ItemId, - _closing: bool, - _window: &mut Window, - cx: &mut ui::Context<Self>, - ) -> Option<gpui::Task<gpui::Result<()>>> { - let workspace_id = workspace.database_id()?; - let page_number = self.selected_page as u16; - Some(cx.background_spawn(async move { - persistence::ONBOARDING_PAGES - .save_onboarding_page(item_id, workspace_id, page_number) - .await - })) - } - - fn should_serialize(&self, event: &Self::Event) -> bool { - event == &ItemEvent::UpdateTab - } -} - -mod persistence { - use db::{define_connection, query, sqlez_macros::sql}; - use workspace::WorkspaceDb; - - define_connection! { - pub static ref ONBOARDING_PAGES: OnboardingPagesDb<WorkspaceDb> = - &[ - sql!( - CREATE TABLE onboarding_pages ( - workspace_id INTEGER, - item_id INTEGER UNIQUE, - page_number INTEGER, - - PRIMARY KEY(workspace_id, item_id), - FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) - ON DELETE CASCADE - ) STRICT; - ), - ]; - } - - impl OnboardingPagesDb { - query! { - pub async fn save_onboarding_page( - item_id: workspace::ItemId, - workspace_id: workspace::WorkspaceId, - page_number: u16 - ) -> Result<()> { - INSERT OR REPLACE INTO onboarding_pages(item_id, workspace_id, page_number) - VALUES (?, ?, ?) - } - } - - query! { - pub fn get_onboarding_page( - item_id: workspace::ItemId, - workspace_id: workspace::WorkspaceId - ) -> Result<Option<u16>> { - SELECT page_number - FROM onboarding_pages - WHERE item_id = ? AND workspace_id = ? - } - } - } -} diff --git a/crates/onboarding/src/welcome.rs b/crates/onboarding/src/welcome.rs deleted file mode 100644 index 3d2c034367..0000000000 --- a/crates/onboarding/src/welcome.rs +++ /dev/null @@ -1,336 +0,0 @@ -use gpui::{ - Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, - NoAction, ParentElement, Render, Styled, Window, actions, -}; -use ui::{ButtonLike, Divider, DividerColor, KeyBinding, Vector, VectorName, prelude::*}; -use workspace::{ - NewFile, Open, WorkspaceId, - item::{Item, ItemEvent}, - with_active_or_new_workspace, -}; -use zed_actions::{Extensions, OpenSettings, agent, command_palette}; - -use crate::{Onboarding, OpenOnboarding}; - -actions!( - zed, - [ - /// Show the Zed welcome screen - ShowWelcome - ] -); - -const CONTENT: (Section<4>, Section<3>) = ( - Section { - title: "Get Started", - entries: [ - SectionEntry { - icon: IconName::Plus, - title: "New File", - action: &NewFile, - }, - SectionEntry { - icon: IconName::FolderOpen, - title: "Open Project", - action: &Open, - }, - SectionEntry { - icon: IconName::CloudDownload, - title: "Clone a Repo", - // TODO: use proper action - action: &NoAction, - }, - SectionEntry { - icon: IconName::ListCollapse, - title: "Open Command Palette", - action: &command_palette::Toggle, - }, - ], - }, - Section { - title: "Configure", - entries: [ - SectionEntry { - icon: IconName::Settings, - title: "Open Settings", - action: &OpenSettings, - }, - SectionEntry { - icon: IconName::ZedAssistant, - title: "View AI Settings", - action: &agent::OpenSettings, - }, - SectionEntry { - icon: IconName::Blocks, - title: "Explore Extensions", - action: &Extensions { - category_filter: None, - id: None, - }, - }, - ], - }, -); - -struct Section<const COLS: usize> { - title: &'static str, - entries: [SectionEntry; COLS], -} - -impl<const COLS: usize> Section<COLS> { - fn render( - self, - index_offset: usize, - focus: &FocusHandle, - window: &mut Window, - cx: &mut App, - ) -> impl IntoElement { - v_flex() - .min_w_full() - .gap_2() - .child( - h_flex() - .px_1() - .gap_4() - .child( - Label::new(self.title.to_ascii_uppercase()) - .buffer_font(cx) - .color(Color::Muted) - .size(LabelSize::XSmall), - ) - .child(Divider::horizontal().color(DividerColor::Border)), - ) - .children( - self.entries - .iter() - .enumerate() - .map(|(index, entry)| entry.render(index_offset + index, &focus, window, cx)), - ) - } -} - -struct SectionEntry { - icon: IconName, - title: &'static str, - action: &'static dyn Action, -} - -impl SectionEntry { - fn render( - &self, - button_index: usize, - focus: &FocusHandle, - window: &Window, - cx: &App, - ) -> impl IntoElement { - ButtonLike::new(("onboarding-button-id", button_index)) - .full_width() - .child( - h_flex() - .w_full() - .gap_1() - .justify_between() - .child( - h_flex() - .gap_2() - .child( - Icon::new(self.icon) - .color(Color::Muted) - .size(IconSize::XSmall), - ) - .child(Label::new(self.title)), - ) - .children(KeyBinding::for_action_in(self.action, focus, window, cx)), - ) - .on_click(|_, window, cx| window.dispatch_action(self.action.boxed_clone(), cx)) - } -} - -pub struct WelcomePage { - focus_handle: FocusHandle, -} - -impl Render for WelcomePage { - fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - let (first_section, second_entries) = CONTENT; - let first_section_entries = first_section.entries.len(); - - h_flex() - .size_full() - .justify_center() - .overflow_hidden() - .bg(cx.theme().colors().editor_background) - .key_context("Welcome") - .track_focus(&self.focus_handle(cx)) - .child( - h_flex() - .px_12() - .py_40() - .size_full() - .relative() - .max_w(px(1100.)) - .child( - div() - .size_full() - .max_w_128() - .mx_auto() - .child( - h_flex() - .w_full() - .justify_center() - .gap_4() - .child(Vector::square(VectorName::ZedLogo, rems(2.))) - .child( - div().child(Headline::new("Welcome to Zed")).child( - Label::new("The editor for what's next") - .size(LabelSize::Small) - .color(Color::Muted) - .italic(), - ), - ), - ) - .child( - v_flex() - .mt_12() - .gap_8() - .child(first_section.render( - Default::default(), - &self.focus_handle, - window, - cx, - )) - .child(second_entries.render( - first_section_entries, - &self.focus_handle, - window, - cx, - )) - .child( - h_flex() - .w_full() - .pt_4() - .justify_center() - // We call this a hack - .rounded_b_xs() - .border_t_1() - .border_color(DividerColor::Border.hsla(cx)) - .border_dashed() - .child( - div().child( - Button::new("welcome-exit", "Return to Setup") - .full_width() - .label_size(LabelSize::XSmall) - .on_click(|_, window, cx| { - window.dispatch_action( - OpenOnboarding.boxed_clone(), - cx, - ); - - with_active_or_new_workspace(cx, |workspace, window, cx| { - let Some((welcome_id, welcome_idx)) = workspace - .active_pane() - .read(cx) - .items() - .enumerate() - .find_map(|(idx, item)| { - let _ = item.downcast::<WelcomePage>()?; - Some((item.item_id(), idx)) - }) - else { - return; - }; - - workspace.active_pane().update(cx, |pane, cx| { - // Get the index here to get around the borrow checker - let idx = pane.items().enumerate().find_map( - |(idx, item)| { - let _ = - item.downcast::<Onboarding>()?; - Some(idx) - }, - ); - - if let Some(idx) = idx { - pane.activate_item( - idx, true, true, window, cx, - ); - } else { - let item = - Box::new(Onboarding::new(workspace, cx)); - pane.add_item( - item, - true, - true, - Some(welcome_idx), - window, - cx, - ); - } - - pane.remove_item( - welcome_id, - false, - false, - window, - cx, - ); - }); - }); - }), - ), - ), - ), - ), - ), - ) - } -} - -impl WelcomePage { - pub fn new(window: &mut Window, cx: &mut App) -> Entity<Self> { - cx.new(|cx| { - let focus_handle = cx.focus_handle(); - cx.on_focus(&focus_handle, window, |_, _, cx| cx.notify()) - .detach(); - - WelcomePage { focus_handle } - }) - } -} - -impl EventEmitter<ItemEvent> for WelcomePage {} - -impl Focusable for WelcomePage { - fn focus_handle(&self, _: &App) -> gpui::FocusHandle { - self.focus_handle.clone() - } -} - -impl Item for WelcomePage { - type Event = ItemEvent; - - fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { - "Welcome".into() - } - - fn telemetry_event_text(&self) -> Option<&'static str> { - Some("New Welcome Page Opened") - } - - fn show_toolbar(&self) -> bool { - false - } - - fn clone_on_split( - &self, - _workspace_id: Option<WorkspaceId>, - _: &mut Window, - _: &mut Context<Self>, - ) -> Option<Entity<Self>> { - None - } - - fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { - f(*event) - } -} diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index ad96670db9..12dcab9e87 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -1,5 +1,19 @@ mod outline_panel_settings; +use std::{ + cmp, + collections::BTreeMap, + hash::Hash, + ops::Range, + path::{MAIN_SEPARATOR_STR, Path, PathBuf}, + sync::{ + Arc, OnceLock, + atomic::{self, AtomicBool}, + }, + time::Duration, + u32, +}; + use anyhow::Context as _; use collections::{BTreeSet, HashMap, HashSet, hash_map}; use db::kvp::KEY_VALUE_STORE; @@ -22,21 +36,8 @@ use gpui::{ uniform_list, }; use itertools::Itertools; -use language::{Anchor, BufferId, BufferSnapshot, OffsetRangeExt, OutlineItem}; +use language::{BufferId, BufferSnapshot, OffsetRangeExt, OutlineItem}; use menu::{Cancel, SelectFirst, SelectLast, SelectNext, SelectPrevious}; -use std::{ - cmp, - collections::BTreeMap, - hash::Hash, - ops::Range, - path::{MAIN_SEPARATOR_STR, Path, PathBuf}, - sync::{ - Arc, OnceLock, - atomic::{self, AtomicBool}, - }, - time::Duration, - u32, -}; use outline_panel_settings::{OutlinePanelDockPosition, OutlinePanelSettings, ShowIndentGuides}; use project::{File, Fs, GitEntry, GitTraversal, Project, ProjectItem}; @@ -131,8 +132,6 @@ pub struct OutlinePanel { hide_scrollbar_task: Option<Task<()>>, max_width_item_index: Option<usize>, preserve_selection_on_buffer_fold_toggles: HashSet<BufferId>, - pending_default_expansion_depth: Option<usize>, - outline_children_cache: HashMap<BufferId, HashMap<(Range<Anchor>, usize), bool>>, } #[derive(Debug)] @@ -319,13 +318,12 @@ struct CachedEntry { entry: PanelEntry, } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] enum CollapsedEntry { Dir(WorktreeId, ProjectEntryId), File(WorktreeId, BufferId), ExternalFile(BufferId), Excerpt(BufferId, ExcerptId), - Outline(BufferId, ExcerptId, Range<Anchor>), } #[derive(Debug)] @@ -805,56 +803,8 @@ impl OutlinePanel { outline_panel.update_cached_entries(Some(UPDATE_DEBOUNCE), window, cx); } } else if &outline_panel_settings != new_settings { - let old_expansion_depth = outline_panel_settings.expand_outlines_with_depth; outline_panel_settings = *new_settings; - - if old_expansion_depth != new_settings.expand_outlines_with_depth { - let old_collapsed_entries = outline_panel.collapsed_entries.clone(); - outline_panel - .collapsed_entries - .retain(|entry| !matches!(entry, CollapsedEntry::Outline(..))); - - let new_depth = new_settings.expand_outlines_with_depth; - - for (buffer_id, excerpts) in &outline_panel.excerpts { - for (excerpt_id, excerpt) in excerpts { - if let ExcerptOutlines::Outlines(outlines) = &excerpt.outlines { - for outline in outlines { - if outline_panel - .outline_children_cache - .get(buffer_id) - .and_then(|children_map| { - let key = - (outline.range.clone(), outline.depth); - children_map.get(&key) - }) - .copied() - .unwrap_or(false) - && (new_depth == 0 || outline.depth >= new_depth) - { - outline_panel.collapsed_entries.insert( - CollapsedEntry::Outline( - *buffer_id, - *excerpt_id, - outline.range.clone(), - ), - ); - } - } - } - } - } - - if old_collapsed_entries != outline_panel.collapsed_entries { - outline_panel.update_cached_entries( - Some(UPDATE_DEBOUNCE), - window, - cx, - ); - } - } else { - cx.notify(); - } + cx.notify(); } }); @@ -891,7 +841,6 @@ impl OutlinePanel { updating_cached_entries: false, new_entries_for_fs_update: HashSet::default(), preserve_selection_on_buffer_fold_toggles: HashSet::default(), - pending_default_expansion_depth: None, fs_entries_update_task: Task::ready(()), cached_entries_update_task: Task::ready(()), reveal_selection_task: Task::ready(Ok(())), @@ -906,7 +855,6 @@ impl OutlinePanel { workspace_subscription, filter_update_subscription, ], - outline_children_cache: HashMap::default(), }; if let Some((item, editor)) = workspace_active_editor(workspace, cx) { outline_panel.replace_active_editor(item, editor, window, cx); @@ -1041,7 +989,7 @@ impl OutlinePanel { fn open_excerpts( &mut self, - action: &editor::actions::OpenExcerpts, + action: &editor::OpenExcerpts, window: &mut Window, cx: &mut Context<Self>, ) { @@ -1057,7 +1005,7 @@ impl OutlinePanel { fn open_excerpts_split( &mut self, - action: &editor::actions::OpenExcerptsSplit, + action: &editor::OpenExcerptsSplit, window: &mut Window, cx: &mut Context<Self>, ) { @@ -1514,12 +1462,7 @@ impl OutlinePanel { PanelEntry::Outline(OutlineEntry::Excerpt(excerpt)) => { Some(CollapsedEntry::Excerpt(excerpt.buffer_id, excerpt.id)) } - PanelEntry::Outline(OutlineEntry::Outline(outline)) => Some(CollapsedEntry::Outline( - outline.buffer_id, - outline.excerpt_id, - outline.outline.range.clone(), - )), - PanelEntry::Search(_) => return, + PanelEntry::Search(_) | PanelEntry::Outline(..) => return, }; let Some(collapsed_entry) = entry_to_expand else { return; @@ -1622,14 +1565,7 @@ impl OutlinePanel { PanelEntry::Outline(OutlineEntry::Excerpt(excerpt)) => self .collapsed_entries .insert(CollapsedEntry::Excerpt(excerpt.buffer_id, excerpt.id)), - PanelEntry::Outline(OutlineEntry::Outline(outline)) => { - self.collapsed_entries.insert(CollapsedEntry::Outline( - outline.buffer_id, - outline.excerpt_id, - outline.outline.range.clone(), - )) - } - PanelEntry::Search(_) => false, + PanelEntry::Search(_) | PanelEntry::Outline(..) => false, }; if collapsed { @@ -1844,17 +1780,7 @@ impl OutlinePanel { self.collapsed_entries.insert(collapsed_entry); } } - PanelEntry::Outline(OutlineEntry::Outline(outline)) => { - let collapsed_entry = CollapsedEntry::Outline( - outline.buffer_id, - outline.excerpt_id, - outline.outline.range.clone(), - ); - if !self.collapsed_entries.remove(&collapsed_entry) { - self.collapsed_entries.insert(collapsed_entry); - } - } - _ => {} + PanelEntry::Search(_) | PanelEntry::Outline(..) => return, } active_editor.update(cx, |editor, cx| { @@ -2182,7 +2108,7 @@ impl OutlinePanel { PanelEntry::Outline(OutlineEntry::Excerpt(excerpt.clone())), item_id, depth, - icon, + Some(icon), is_active, label_element, window, @@ -2234,31 +2160,10 @@ impl OutlinePanel { _ => false, }; - let has_children = self - .outline_children_cache - .get(&outline.buffer_id) - .and_then(|children_map| { - let key = (outline.outline.range.clone(), outline.outline.depth); - children_map.get(&key) - }) - .copied() - .unwrap_or(false); - let is_expanded = !self.collapsed_entries.contains(&CollapsedEntry::Outline( - outline.buffer_id, - outline.excerpt_id, - outline.outline.range.clone(), - )); - - let icon = if has_children { - FileIcons::get_chevron_icon(is_expanded, cx) - .map(|icon_path| { - Icon::from_path(icon_path) - .color(entry_label_color(is_active)) - .into_any_element() - }) - .unwrap_or_else(empty_icon) + let icon = if self.is_singleton_active(cx) { + None } else { - empty_icon() + Some(empty_icon()) }; self.entry_element( @@ -2382,7 +2287,7 @@ impl OutlinePanel { PanelEntry::Fs(rendered_entry.clone()), item_id, depth, - icon, + Some(icon), is_active, label_element, window, @@ -2453,7 +2358,7 @@ impl OutlinePanel { PanelEntry::FoldedDirs(folded_dir.clone()), item_id, depth, - icon, + Some(icon), is_active, label_element, window, @@ -2544,7 +2449,7 @@ impl OutlinePanel { }), ElementId::from(SharedString::from(format!("search-{match_range:?}"))), depth, - empty_icon(), + None, is_active, entire_label, window, @@ -2557,7 +2462,7 @@ impl OutlinePanel { rendered_entry: PanelEntry, item_id: ElementId, depth: usize, - icon_element: AnyElement, + icon_element: Option<AnyElement>, is_active: bool, label_element: gpui::AnyElement, window: &mut Window, @@ -2573,10 +2478,8 @@ impl OutlinePanel { if event.down.button == MouseButton::Right || event.down.first_mouse { return; } - let change_focus = event.down.click_count > 1; outline_panel.toggle_expanded(&clicked_entry, window, cx); - outline_panel.scroll_editor_to_entry( &clicked_entry, true, @@ -2592,11 +2495,10 @@ impl OutlinePanel { .indent_level(depth) .indent_step_size(px(settings.indent_size)) .toggle_state(is_active) - .child( - h_flex() - .child(h_flex().w(px(16.)).justify_center().child(icon_element)) - .child(h_flex().h_6().child(label_element).ml_1()), - ) + .when_some(icon_element, |list_item, icon_element| { + list_item.child(h_flex().child(icon_element)) + }) + .child(h_flex().h_6().child(label_element).ml_1()) .on_secondary_mouse_down(cx.listener( move |outline_panel, event: &MouseDownEvent, window, cx| { // Stop propagation to prevent the catch-all context menu for the project @@ -3038,12 +2940,7 @@ impl OutlinePanel { outline_panel.fs_entries_depth = new_depth_map; outline_panel.fs_children_count = new_children_count; outline_panel.update_non_fs_items(window, cx); - - // Only update cached entries if we don't have outlines to fetch - // If we do have outlines to fetch, let fetch_outdated_outlines handle the update - if outline_panel.excerpt_fetch_ranges(cx).is_empty() { - outline_panel.update_cached_entries(debounce, window, cx); - } + outline_panel.update_cached_entries(debounce, window, cx); cx.notify(); }) @@ -3059,12 +2956,6 @@ impl OutlinePanel { cx: &mut Context<Self>, ) { self.clear_previous(window, cx); - - let default_expansion_depth = - OutlinePanelSettings::get_global(cx).expand_outlines_with_depth; - // We'll apply the expansion depth after outlines are loaded - self.pending_default_expansion_depth = Some(default_expansion_depth); - let buffer_search_subscription = cx.subscribe_in( &new_active_editor, window, @@ -3113,7 +3004,6 @@ impl OutlinePanel { self.selected_entry = SelectedEntry::None; self.pinned = false; self.mode = ItemsDisplayMode::Outline; - self.pending_default_expansion_depth = None; } fn location_for_editor_selection( @@ -3369,74 +3259,25 @@ impl OutlinePanel { || buffer_language.as_ref() == buffer_snapshot.language_at(outline.range.start) }); - - let outlines_with_children = outlines - .windows(2) - .filter_map(|window| { - let current = &window[0]; - let next = &window[1]; - if next.depth > current.depth { - Some((current.range.clone(), current.depth)) - } else { - None - } - }) - .collect::<HashSet<_>>(); - - (outlines, outlines_with_children) + outlines }) .await; - - let (fetched_outlines, outlines_with_children) = fetched_outlines; - outline_panel .update_in(cx, |outline_panel, window, cx| { - let pending_default_depth = - outline_panel.pending_default_expansion_depth.take(); - - let debounce = - if first_update.fetch_and(false, atomic::Ordering::AcqRel) { - None - } else { - Some(UPDATE_DEBOUNCE) - }; - if let Some(excerpt) = outline_panel .excerpts .entry(buffer_id) .or_default() .get_mut(&excerpt_id) { + let debounce = if first_update + .fetch_and(false, atomic::Ordering::AcqRel) + { + None + } else { + Some(UPDATE_DEBOUNCE) + }; excerpt.outlines = ExcerptOutlines::Outlines(fetched_outlines); - - if let Some(default_depth) = pending_default_depth { - if let ExcerptOutlines::Outlines(outlines) = - &excerpt.outlines - { - outlines - .iter() - .filter(|outline| { - (default_depth == 0 - || outline.depth >= default_depth) - && outlines_with_children.contains(&( - outline.range.clone(), - outline.depth, - )) - }) - .for_each(|outline| { - outline_panel.collapsed_entries.insert( - CollapsedEntry::Outline( - buffer_id, - excerpt_id, - outline.range.clone(), - ), - ); - }); - } - } - - // Even if no outlines to check, we still need to update cached entries - // to show the outline entries that were just fetched outline_panel.update_cached_entries(debounce, window, cx); } }) @@ -4242,7 +4083,7 @@ impl OutlinePanel { } fn add_excerpt_entries( - &mut self, + &self, state: &mut GenerationState, buffer_id: BufferId, entries_to_add: &[ExcerptId], @@ -4253,8 +4094,6 @@ impl OutlinePanel { cx: &mut Context<Self>, ) { if let Some(excerpts) = self.excerpts.get(&buffer_id) { - let buffer_snapshot = self.buffer_snapshot_for_id(buffer_id, cx); - for &excerpt_id in entries_to_add { let Some(excerpt) = excerpts.get(&excerpt_id) else { continue; @@ -4284,84 +4123,15 @@ impl OutlinePanel { continue; } - let mut last_depth_at_level: Vec<Option<Range<Anchor>>> = vec![None; 10]; - - let all_outlines: Vec<_> = excerpt.iter_outlines().collect(); - - let mut outline_has_children = HashMap::default(); - let mut visible_outlines = Vec::new(); - let mut collapsed_state: Option<(usize, Range<Anchor>)> = None; - - for (i, &outline) in all_outlines.iter().enumerate() { - let has_children = all_outlines - .get(i + 1) - .map(|next| next.depth > outline.depth) - .unwrap_or(false); - - outline_has_children - .insert((outline.range.clone(), outline.depth), has_children); - - let mut should_include = true; - - if let Some((collapsed_depth, collapsed_range)) = &collapsed_state { - if outline.depth <= *collapsed_depth { - collapsed_state = None; - } else if let Some(buffer_snapshot) = buffer_snapshot.as_ref() { - let outline_start = outline.range.start; - if outline_start - .cmp(&collapsed_range.start, buffer_snapshot) - .is_ge() - && outline_start - .cmp(&collapsed_range.end, buffer_snapshot) - .is_lt() - { - should_include = false; // Skip - inside collapsed range - } else { - collapsed_state = None; - } - } - } - - // Check if this outline itself is collapsed - if should_include - && self.collapsed_entries.contains(&CollapsedEntry::Outline( - buffer_id, - excerpt_id, - outline.range.clone(), - )) - { - collapsed_state = Some((outline.depth, outline.range.clone())); - } - - if should_include { - visible_outlines.push(outline); - } - } - - self.outline_children_cache - .entry(buffer_id) - .or_default() - .extend(outline_has_children); - - for outline in visible_outlines { - let outline_entry = OutlineEntryOutline { - buffer_id, - excerpt_id, - outline: outline.clone(), - }; - - if outline.depth < last_depth_at_level.len() { - last_depth_at_level[outline.depth] = Some(outline.range.clone()); - // Clear deeper levels when we go back to a shallower depth - for d in (outline.depth + 1)..last_depth_at_level.len() { - last_depth_at_level[d] = None; - } - } - + for outline in excerpt.iter_outlines() { self.push_entry( state, track_matches, - PanelEntry::Outline(OutlineEntry::Outline(outline_entry)), + PanelEntry::Outline(OutlineEntry::Outline(OutlineEntryOutline { + buffer_id, + excerpt_id, + outline: outline.clone(), + })), outline_base_depth + outline.depth, cx, ); @@ -5958,7 +5728,7 @@ mod tests { }); outline_panel.update_in(cx, |outline_panel, window, cx| { - outline_panel.open_excerpts(&editor::actions::OpenExcerpts, window, cx); + outline_panel.open_excerpts(&editor::OpenExcerpts, window, cx); }); cx.executor() .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); @@ -7138,540 +6908,4 @@ outline: struct OutlineEntryExcerpt multi_buffer_snapshot.text_for_range(line_start..line_end).collect::<String>().trim().to_owned() }) } - - #[gpui::test] - async fn test_outline_keyboard_expand_collapse(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.background_executor.clone()); - fs.insert_tree( - "/test", - json!({ - "src": { - "lib.rs": indoc!(" - mod outer { - pub struct OuterStruct { - field: String, - } - impl OuterStruct { - pub fn new() -> Self { - Self { field: String::new() } - } - pub fn method(&self) { - println!(\"{}\", self.field); - } - } - mod inner { - pub fn inner_function() { - let x = 42; - println!(\"{}\", x); - } - pub struct InnerStruct { - value: i32, - } - } - } - fn main() { - let s = outer::OuterStruct::new(); - s.method(); - } - "), - } - }), - ) - .await; - - let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; - project.read_with(cx, |project, _| { - project.languages().add(Arc::new( - rust_lang() - .with_outline_query( - r#" - (struct_item - (visibility_modifier)? @context - "struct" @context - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @context - "for"? @context - type: (_) @context - body: (_)) @item - (function_item - (visibility_modifier)? @context - "fn" @context - name: (_) @name - parameters: (_) @context) @item - (mod_item - (visibility_modifier)? @context - "mod" @context - name: (_) @name) @item - (enum_item - (visibility_modifier)? @context - "enum" @context - name: (_) @name) @item - (field_declaration - (visibility_modifier)? @context - name: (_) @name - ":" @context - type: (_) @context) @item - "#, - ) - .unwrap(), - )) - }); - let workspace = add_outline_panel(&project, cx).await; - let cx = &mut VisualTestContext::from_window(*workspace, cx); - let outline_panel = outline_panel(&workspace, cx); - - outline_panel.update_in(cx, |outline_panel, window, cx| { - outline_panel.set_active(true, window, cx) - }); - - workspace - .update(cx, |workspace, window, cx| { - workspace.open_abs_path( - PathBuf::from("/test/src/lib.rs"), - OpenOptions { - visible: Some(OpenVisible::All), - ..Default::default() - }, - window, - cx, - ) - }) - .unwrap() - .await - .unwrap(); - - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(500)); - cx.run_until_parked(); - - // Force another update cycle to ensure outlines are fetched - outline_panel.update_in(cx, |panel, window, cx| { - panel.update_non_fs_items(window, cx); - panel.update_cached_entries(Some(UPDATE_DEBOUNCE), window, cx); - }); - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(500)); - cx.run_until_parked(); - - outline_panel.update(cx, |outline_panel, cx| { - assert_eq!( - display_entries( - &project, - &snapshot(&outline_panel, cx), - &outline_panel.cached_entries, - outline_panel.selected_entry(), - cx, - ), - indoc!( - " -outline: mod outer <==== selected - outline: pub struct OuterStruct - outline: field: String - outline: impl OuterStruct - outline: pub fn new() - outline: pub fn method(&self) - outline: mod inner - outline: pub fn inner_function() - outline: pub struct InnerStruct - outline: value: i32 -outline: fn main()" - ) - ); - }); - - let parent_outline = outline_panel - .read_with(cx, |panel, _cx| { - panel - .cached_entries - .iter() - .find_map(|entry| match &entry.entry { - PanelEntry::Outline(OutlineEntry::Outline(outline)) - if panel - .outline_children_cache - .get(&outline.buffer_id) - .and_then(|children_map| { - let key = - (outline.outline.range.clone(), outline.outline.depth); - children_map.get(&key) - }) - .copied() - .unwrap_or(false) => - { - Some(entry.entry.clone()) - } - _ => None, - }) - }) - .expect("Should find an outline with children"); - - outline_panel.update_in(cx, |panel, window, cx| { - panel.select_entry(parent_outline.clone(), true, window, cx); - panel.collapse_selected_entry(&CollapseSelectedEntry, window, cx); - }); - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); - cx.run_until_parked(); - - outline_panel.update(cx, |outline_panel, cx| { - assert_eq!( - display_entries( - &project, - &snapshot(&outline_panel, cx), - &outline_panel.cached_entries, - outline_panel.selected_entry(), - cx, - ), - indoc!( - " -outline: mod outer <==== selected -outline: fn main()" - ) - ); - }); - - outline_panel.update_in(cx, |panel, window, cx| { - panel.expand_selected_entry(&ExpandSelectedEntry, window, cx); - }); - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); - cx.run_until_parked(); - - outline_panel.update(cx, |outline_panel, cx| { - assert_eq!( - display_entries( - &project, - &snapshot(&outline_panel, cx), - &outline_panel.cached_entries, - outline_panel.selected_entry(), - cx, - ), - indoc!( - " -outline: mod outer <==== selected - outline: pub struct OuterStruct - outline: field: String - outline: impl OuterStruct - outline: pub fn new() - outline: pub fn method(&self) - outline: mod inner - outline: pub fn inner_function() - outline: pub struct InnerStruct - outline: value: i32 -outline: fn main()" - ) - ); - }); - - outline_panel.update_in(cx, |panel, window, cx| { - panel.collapsed_entries.clear(); - panel.update_cached_entries(None, window, cx); - }); - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); - cx.run_until_parked(); - - outline_panel.update_in(cx, |panel, window, cx| { - let outlines_with_children: Vec<_> = panel - .cached_entries - .iter() - .filter_map(|entry| match &entry.entry { - PanelEntry::Outline(OutlineEntry::Outline(outline)) - if panel - .outline_children_cache - .get(&outline.buffer_id) - .and_then(|children_map| { - let key = (outline.outline.range.clone(), outline.outline.depth); - children_map.get(&key) - }) - .copied() - .unwrap_or(false) => - { - Some(entry.entry.clone()) - } - _ => None, - }) - .collect(); - - for outline in outlines_with_children { - panel.select_entry(outline, false, window, cx); - panel.collapse_selected_entry(&CollapseSelectedEntry, window, cx); - } - }); - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); - cx.run_until_parked(); - - outline_panel.update(cx, |outline_panel, cx| { - assert_eq!( - display_entries( - &project, - &snapshot(&outline_panel, cx), - &outline_panel.cached_entries, - outline_panel.selected_entry(), - cx, - ), - indoc!( - " -outline: mod outer -outline: fn main()" - ) - ); - }); - - let collapsed_entries_count = - outline_panel.read_with(cx, |panel, _| panel.collapsed_entries.len()); - assert!( - collapsed_entries_count > 0, - "Should have collapsed entries tracked" - ); - } - - #[gpui::test] - async fn test_outline_click_toggle_behavior(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.background_executor.clone()); - fs.insert_tree( - "/test", - json!({ - "src": { - "main.rs": indoc!(" - struct Config { - name: String, - value: i32, - } - impl Config { - fn new(name: String) -> Self { - Self { name, value: 0 } - } - fn get_value(&self) -> i32 { - self.value - } - } - enum Status { - Active, - Inactive, - } - fn process_config(config: Config) -> Status { - if config.get_value() > 0 { - Status::Active - } else { - Status::Inactive - } - } - fn main() { - let config = Config::new(\"test\".to_string()); - let status = process_config(config); - } - "), - } - }), - ) - .await; - - let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; - project.read_with(cx, |project, _| { - project.languages().add(Arc::new( - rust_lang() - .with_outline_query( - r#" - (struct_item - (visibility_modifier)? @context - "struct" @context - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @context - "for"? @context - type: (_) @context - body: (_)) @item - (function_item - (visibility_modifier)? @context - "fn" @context - name: (_) @name - parameters: (_) @context) @item - (mod_item - (visibility_modifier)? @context - "mod" @context - name: (_) @name) @item - (enum_item - (visibility_modifier)? @context - "enum" @context - name: (_) @name) @item - (field_declaration - (visibility_modifier)? @context - name: (_) @name - ":" @context - type: (_) @context) @item - "#, - ) - .unwrap(), - )) - }); - - let workspace = add_outline_panel(&project, cx).await; - let cx = &mut VisualTestContext::from_window(*workspace, cx); - let outline_panel = outline_panel(&workspace, cx); - - outline_panel.update_in(cx, |outline_panel, window, cx| { - outline_panel.set_active(true, window, cx) - }); - - let _editor = workspace - .update(cx, |workspace, window, cx| { - workspace.open_abs_path( - PathBuf::from("/test/src/main.rs"), - OpenOptions { - visible: Some(OpenVisible::All), - ..Default::default() - }, - window, - cx, - ) - }) - .unwrap() - .await - .unwrap(); - - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); - cx.run_until_parked(); - - outline_panel.update(cx, |outline_panel, _cx| { - outline_panel.selected_entry = SelectedEntry::None; - }); - - // Check initial state - all entries should be expanded by default - outline_panel.update(cx, |outline_panel, cx| { - assert_eq!( - display_entries( - &project, - &snapshot(&outline_panel, cx), - &outline_panel.cached_entries, - outline_panel.selected_entry(), - cx, - ), - indoc!( - " -outline: struct Config - outline: name: String - outline: value: i32 -outline: impl Config - outline: fn new(name: String) - outline: fn get_value(&self) -outline: enum Status -outline: fn process_config(config: Config) -outline: fn main()" - ) - ); - }); - - outline_panel.update(cx, |outline_panel, _cx| { - outline_panel.selected_entry = SelectedEntry::None; - }); - - cx.update(|window, cx| { - outline_panel.update(cx, |outline_panel, cx| { - outline_panel.select_first(&SelectFirst, window, cx); - }); - }); - - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); - cx.run_until_parked(); - - outline_panel.update(cx, |outline_panel, cx| { - assert_eq!( - display_entries( - &project, - &snapshot(&outline_panel, cx), - &outline_panel.cached_entries, - outline_panel.selected_entry(), - cx, - ), - indoc!( - " -outline: struct Config <==== selected - outline: name: String - outline: value: i32 -outline: impl Config - outline: fn new(name: String) - outline: fn get_value(&self) -outline: enum Status -outline: fn process_config(config: Config) -outline: fn main()" - ) - ); - }); - - cx.update(|window, cx| { - outline_panel.update(cx, |outline_panel, cx| { - outline_panel.open_selected_entry(&OpenSelectedEntry, window, cx); - }); - }); - - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); - cx.run_until_parked(); - - outline_panel.update(cx, |outline_panel, cx| { - assert_eq!( - display_entries( - &project, - &snapshot(&outline_panel, cx), - &outline_panel.cached_entries, - outline_panel.selected_entry(), - cx, - ), - indoc!( - " -outline: struct Config <==== selected -outline: impl Config - outline: fn new(name: String) - outline: fn get_value(&self) -outline: enum Status -outline: fn process_config(config: Config) -outline: fn main()" - ) - ); - }); - - cx.update(|window, cx| { - outline_panel.update(cx, |outline_panel, cx| { - outline_panel.open_selected_entry(&OpenSelectedEntry, window, cx); - }); - }); - - cx.executor() - .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); - cx.run_until_parked(); - - outline_panel.update(cx, |outline_panel, cx| { - assert_eq!( - display_entries( - &project, - &snapshot(&outline_panel, cx), - &outline_panel.cached_entries, - outline_panel.selected_entry(), - cx, - ), - indoc!( - " -outline: struct Config <==== selected - outline: name: String - outline: value: i32 -outline: impl Config - outline: fn new(name: String) - outline: fn get_value(&self) -outline: enum Status -outline: fn process_config(config: Config) -outline: fn main()" - ) - ); - }); - } } diff --git a/crates/outline_panel/src/outline_panel_settings.rs b/crates/outline_panel/src/outline_panel_settings.rs index 133d28b748..6b70cb54fb 100644 --- a/crates/outline_panel/src/outline_panel_settings.rs +++ b/crates/outline_panel/src/outline_panel_settings.rs @@ -31,7 +31,6 @@ pub struct OutlinePanelSettings { pub auto_reveal_entries: bool, pub auto_fold_dirs: bool, pub scrollbar: ScrollbarSettings, - pub expand_outlines_with_depth: usize, } #[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -106,13 +105,6 @@ pub struct OutlinePanelSettingsContent { pub indent_guides: Option<IndentGuidesSettingsContent>, /// Scrollbar-related settings pub scrollbar: Option<ScrollbarSettingsContent>, - /// Default depth to expand outline items in the current file. - /// The default depth to which outline entries are expanded on reveal. - /// - Set to 0 to collapse all items that have children - /// - Set to 1 or higher to collapse items at that depth or deeper - /// - /// Default: 100 - pub expand_outlines_with_depth: Option<usize>, } impl Settings for OutlinePanelSettings { diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs index 47a0f12c06..2f3b188980 100644 --- a/crates/paths/src/paths.rs +++ b/crates/paths/src/paths.rs @@ -35,7 +35,6 @@ pub fn remote_server_dir_relative() -> &'static Path { /// Sets a custom directory for all user data, overriding the default data directory. /// This function must be called before any other path operations that depend on the data directory. -/// The directory's path will be canonicalized to an absolute path by a blocking FS operation. /// The directory will be created if it doesn't exist. /// /// # Arguments @@ -51,20 +50,13 @@ pub fn remote_server_dir_relative() -> &'static Path { /// /// Panics if: /// * Called after the data directory has been initialized (e.g., via `data_dir` or `config_dir`) -/// * The directory's path cannot be canonicalized to an absolute path /// * The directory cannot be created pub fn set_custom_data_dir(dir: &str) -> &'static PathBuf { if CURRENT_DATA_DIR.get().is_some() || CONFIG_DIR.get().is_some() { panic!("set_custom_data_dir called after data_dir or config_dir was initialized"); } CUSTOM_DATA_DIR.get_or_init(|| { - let mut path = PathBuf::from(dir); - if path.is_relative() { - let abs_path = path - .canonicalize() - .expect("failed to canonicalize custom data directory's path to an absolute path"); - path = PathBuf::from(util::paths::SanitizedPath::from(abs_path)) - } + let path = PathBuf::from(dir); std::fs::create_dir_all(&path).expect("failed to create custom data directory"); path }) diff --git a/crates/prettier/src/prettier_server.js b/crates/prettier/src/prettier_server.js index b3d8a660a4..6799b4aceb 100644 --- a/crates/prettier/src/prettier_server.js +++ b/crates/prettier/src/prettier_server.js @@ -152,10 +152,6 @@ async function handleMessage(message, prettier) { throw new Error(`Message method is undefined: ${JSON.stringify(message)}`); } else if (method == "initialized") { return; - } else if (method === "shutdown") { - sendResponse({ result: {} }); - } else if (method == "exit") { - process.exit(0); } if (id === undefined) { diff --git a/crates/project/src/debugger/dap_command.rs b/crates/project/src/debugger/dap_command.rs index 3be3192369..1cb611680c 100644 --- a/crates/project/src/debugger/dap_command.rs +++ b/crates/project/src/debugger/dap_command.rs @@ -107,7 +107,7 @@ impl<T: DapCommand> DapCommand for Arc<T> { #[derive(Debug, Hash, PartialEq, Eq)] pub struct StepCommand { - pub thread_id: i64, + pub thread_id: u64, pub granularity: Option<SteppingGranularity>, pub single_thread: Option<bool>, } @@ -483,7 +483,7 @@ impl DapCommand for ContinueCommand { #[derive(Debug, Hash, PartialEq, Eq)] pub(crate) struct PauseCommand { - pub thread_id: i64, + pub thread_id: u64, } impl LocalDapCommand for PauseCommand { @@ -612,7 +612,7 @@ impl DapCommand for DisconnectCommand { #[derive(Debug, Hash, PartialEq, Eq)] pub(crate) struct TerminateThreadsCommand { - pub thread_ids: Option<Vec<i64>>, + pub thread_ids: Option<Vec<u64>>, } impl LocalDapCommand for TerminateThreadsCommand { @@ -1182,7 +1182,7 @@ impl DapCommand for LoadedSourcesCommand { #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub(crate) struct StackTraceCommand { - pub thread_id: i64, + pub thread_id: u64, pub start_frame: Option<u64>, pub levels: Option<u64>, } diff --git a/crates/project/src/debugger/dap_store.rs b/crates/project/src/debugger/dap_store.rs index 6f834b5dc0..d494088b13 100644 --- a/crates/project/src/debugger/dap_store.rs +++ b/crates/project/src/debugger/dap_store.rs @@ -920,22 +920,12 @@ impl dap::adapters::DapDelegate for DapAdapterDelegate { self.console.unbounded_send(msg).ok(); } - #[cfg(not(target_os = "windows"))] async fn which(&self, command: &OsStr) -> Option<PathBuf> { let worktree_abs_path = self.worktree.abs_path(); let shell_path = self.shell_env().await.get("PATH").cloned(); which::which_in(command, shell_path.as_ref(), worktree_abs_path).ok() } - #[cfg(target_os = "windows")] - async fn which(&self, command: &OsStr) -> Option<PathBuf> { - // On Windows, `PATH` is handled differently from Unix. Windows generally expects users to modify the `PATH` themselves, - // and every program loads it directly from the system at startup. - // There's also no concept of a default shell on Windows, and you can't really retrieve one, so trying to get shell environment variables - // from a specific directory doesn’t make sense on Windows. - which::which(command).ok() - } - async fn shell_env(&self) -> HashMap<String, String> { let task = self.load_shell_env_task.clone(); task.await.unwrap_or_default() diff --git a/crates/project/src/debugger/locators/cargo.rs b/crates/project/src/debugger/locators/cargo.rs index fa265dae58..7d70371380 100644 --- a/crates/project/src/debugger/locators/cargo.rs +++ b/crates/project/src/debugger/locators/cargo.rs @@ -128,7 +128,7 @@ impl DapLocator for CargoLocator { .chain(Some("--message-format=json".to_owned())) .collect(), ); - let mut child = util::command::new_smol_command(program) + let mut child = Command::new(program) .args(args) .envs(build_config.env.iter().map(|(k, v)| (k.clone(), v.clone()))) .current_dir(cwd) diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index f60a7becf7..1e296ac2ac 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -61,10 +61,15 @@ use worktree::Worktree; #[derive(Debug, Copy, Clone, Hash, PartialEq, PartialOrd, Ord, Eq)] #[repr(transparent)] -pub struct ThreadId(pub i64); +pub struct ThreadId(pub u64); -impl From<i64> for ThreadId { - fn from(id: i64) -> Self { +impl ThreadId { + pub const MIN: ThreadId = ThreadId(u64::MIN); + pub const MAX: ThreadId = ThreadId(u64::MAX); +} + +impl From<u64> for ThreadId { + fn from(id: u64) -> Self { Self(id) } } diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 28dd0e91e3..eb16446daf 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -14,10 +14,9 @@ use collections::HashMap; pub use conflict_set::{ConflictRegion, ConflictSet, ConflictSetSnapshot, ConflictSetUpdate}; use fs::Fs; use futures::{ - FutureExt, StreamExt, + FutureExt, StreamExt as _, channel::{mpsc, oneshot}, future::{self, Shared}, - stream::FuturesOrdered, }; use git::{ BuildPermalinkParams, GitHostingProviderRegistry, WORK_DIRECTORY_REPO_PATH, @@ -64,8 +63,8 @@ use sum_tree::{Edit, SumTree, TreeSet}; use text::{Bias, BufferId}; use util::{ResultExt, debug_panic, post_inc}; use worktree::{ - File, PathChange, PathKey, PathProgress, PathSummary, PathTarget, ProjectEntryId, - UpdatedGitRepositoriesSet, UpdatedGitRepository, Worktree, + File, PathKey, PathProgress, PathSummary, PathTarget, UpdatedGitRepositoriesSet, + UpdatedGitRepository, Worktree, }; pub struct GitStore { @@ -420,8 +419,6 @@ impl GitStore { client.add_entity_request_handler(Self::handle_fetch); client.add_entity_request_handler(Self::handle_stage); client.add_entity_request_handler(Self::handle_unstage); - client.add_entity_request_handler(Self::handle_stash); - client.add_entity_request_handler(Self::handle_stash_pop); client.add_entity_request_handler(Self::handle_commit); client.add_entity_request_handler(Self::handle_reset); client.add_entity_request_handler(Self::handle_show); @@ -1086,26 +1083,27 @@ impl GitStore { match event { WorktreeStoreEvent::WorktreeUpdatedEntries(worktree_id, updated_entries) => { - if let Some(worktree) = self - .worktree_store - .read(cx) - .worktree_for_id(*worktree_id, cx) - { - let paths_by_git_repo = - self.process_updated_entries(&worktree, updated_entries, cx); - let downstream = downstream - .as_ref() - .map(|downstream| downstream.updates_tx.clone()); - cx.spawn(async move |_, cx| { - let paths_by_git_repo = paths_by_git_repo.await; - for (repo, paths) in paths_by_git_repo { - repo.update(cx, |repo, cx| { - repo.paths_changed(paths, downstream.clone(), cx); - }) - .ok(); - } - }) - .detach(); + let mut paths_by_git_repo = HashMap::<_, Vec<_>>::default(); + for (relative_path, _, _) in updated_entries.iter() { + let Some((repo, repo_path)) = self.repository_and_path_for_project_path( + &(*worktree_id, relative_path.clone()).into(), + cx, + ) else { + continue; + }; + paths_by_git_repo.entry(repo).or_default().push(repo_path) + } + + for (repo, paths) in paths_by_git_repo { + repo.update(cx, |repo, cx| { + repo.paths_changed( + paths, + downstream + .as_ref() + .map(|downstream| downstream.updates_tx.clone()), + cx, + ); + }); } } WorktreeStoreEvent::WorktreeUpdatedGitRepositories(worktree_id, changed_repos) => { @@ -1698,48 +1696,6 @@ impl GitStore { Ok(proto::Ack {}) } - async fn handle_stash( - this: Entity<Self>, - envelope: TypedEnvelope<proto::Stash>, - mut cx: AsyncApp, - ) -> Result<proto::Ack> { - let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); - let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; - - let entries = envelope - .payload - .paths - .into_iter() - .map(PathBuf::from) - .map(RepoPath::new) - .collect(); - - repository_handle - .update(&mut cx, |repository_handle, cx| { - repository_handle.stash_entries(entries, cx) - })? - .await?; - - Ok(proto::Ack {}) - } - - async fn handle_stash_pop( - this: Entity<Self>, - envelope: TypedEnvelope<proto::StashPop>, - mut cx: AsyncApp, - ) -> Result<proto::Ack> { - let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); - let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; - - repository_handle - .update(&mut cx, |repository_handle, cx| { - repository_handle.stash_pop(cx) - })? - .await?; - - Ok(proto::Ack {}) - } - async fn handle_set_index_text( this: Entity<Self>, envelope: TypedEnvelope<proto::SetIndexText>, @@ -2235,80 +2191,6 @@ impl GitStore { .map(|(id, repo)| (*id, repo.read(cx).snapshot.clone())) .collect() } - - fn process_updated_entries( - &self, - worktree: &Entity<Worktree>, - updated_entries: &[(Arc<Path>, ProjectEntryId, PathChange)], - cx: &mut App, - ) -> Task<HashMap<Entity<Repository>, Vec<RepoPath>>> { - let mut repo_paths = self - .repositories - .values() - .map(|repo| (repo.read(cx).work_directory_abs_path.clone(), repo.clone())) - .collect::<Vec<_>>(); - let mut entries: Vec<_> = updated_entries - .iter() - .map(|(path, _, _)| path.clone()) - .collect(); - entries.sort(); - let worktree = worktree.read(cx); - - let entries = entries - .into_iter() - .filter_map(|path| worktree.absolutize(&path).ok()) - .collect::<Arc<[_]>>(); - - let executor = cx.background_executor().clone(); - cx.background_executor().spawn(async move { - repo_paths.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); - let mut paths_by_git_repo = HashMap::<_, Vec<_>>::default(); - let mut tasks = FuturesOrdered::new(); - for (repo_path, repo) in repo_paths.into_iter().rev() { - let entries = entries.clone(); - let task = executor.spawn(async move { - // Find all repository paths that belong to this repo - let mut ix = entries.partition_point(|path| path < &*repo_path); - if ix == entries.len() { - return None; - }; - - let mut paths = vec![]; - // All paths prefixed by a given repo will constitute a continuous range. - while let Some(path) = entries.get(ix) - && let Some(repo_path) = - RepositorySnapshot::abs_path_to_repo_path_inner(&repo_path, &path) - { - paths.push((repo_path, ix)); - ix += 1; - } - Some((repo, paths)) - }); - tasks.push_back(task); - } - - // Now, let's filter out the "duplicate" entries that were processed by multiple distinct repos. - let mut path_was_used = vec![false; entries.len()]; - let tasks = tasks.collect::<Vec<_>>().await; - // Process tasks from the back: iterating backwards allows us to see more-specific paths first. - // We always want to assign a path to it's innermost repository. - for t in tasks { - let Some((repo, paths)) = t else { - continue; - }; - let entry = paths_by_git_repo.entry(repo).or_default(); - for (repo_path, ix) in paths { - if path_was_used[ix] { - continue; - } - path_was_used[ix] = true; - entry.push(repo_path); - } - } - - paths_by_git_repo - }) - } } impl BufferGitState { @@ -2778,16 +2660,8 @@ impl RepositorySnapshot { } pub fn abs_path_to_repo_path(&self, abs_path: &Path) -> Option<RepoPath> { - Self::abs_path_to_repo_path_inner(&self.work_directory_abs_path, abs_path) - } - - #[inline] - fn abs_path_to_repo_path_inner( - work_directory_abs_path: &Path, - abs_path: &Path, - ) -> Option<RepoPath> { abs_path - .strip_prefix(&work_directory_abs_path) + .strip_prefix(&self.work_directory_abs_path) .map(RepoPath::from) .ok() } @@ -3584,82 +3458,6 @@ impl Repository { self.unstage_entries(to_unstage, cx) } - pub fn stash_all(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> { - let to_stash = self - .cached_status() - .map(|entry| entry.repo_path.clone()) - .collect(); - - self.stash_entries(to_stash, cx) - } - - pub fn stash_entries( - &mut self, - entries: Vec<RepoPath>, - cx: &mut Context<Self>, - ) -> Task<anyhow::Result<()>> { - let id = self.id; - - cx.spawn(async move |this, cx| { - this.update(cx, |this, _| { - this.send_job(None, move |git_repo, _cx| async move { - match git_repo { - RepositoryState::Local { - backend, - environment, - .. - } => backend.stash_paths(entries, environment).await, - RepositoryState::Remote { project_id, client } => { - client - .request(proto::Stash { - project_id: project_id.0, - repository_id: id.to_proto(), - paths: entries - .into_iter() - .map(|repo_path| repo_path.as_ref().to_proto()) - .collect(), - }) - .await - .context("sending stash request")?; - Ok(()) - } - } - }) - })? - .await??; - Ok(()) - }) - } - - pub fn stash_pop(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> { - let id = self.id; - cx.spawn(async move |this, cx| { - this.update(cx, |this, _| { - this.send_job(None, move |git_repo, _cx| async move { - match git_repo { - RepositoryState::Local { - backend, - environment, - .. - } => backend.stash_pop(environment).await, - RepositoryState::Remote { project_id, client } => { - client - .request(proto::StashPop { - project_id: project_id.0, - repository_id: id.to_proto(), - }) - .await - .context("sending stash pop request")?; - Ok(()) - } - } - }) - })? - .await??; - Ok(()) - }) - } - pub fn commit( &mut self, message: SharedString, diff --git a/crates/project/src/git_store/git_traversal.rs b/crates/project/src/git_store/git_traversal.rs index 777042cb02..cd173d5714 100644 --- a/crates/project/src/git_store/git_traversal.rs +++ b/crates/project/src/git_store/git_traversal.rs @@ -1,6 +1,6 @@ use collections::HashMap; -use git::{repository::RepoPath, status::GitSummary}; -use std::{collections::BTreeMap, ops::Deref, path::Path}; +use git::status::GitSummary; +use std::{ops::Deref, path::Path}; use sum_tree::Cursor; use text::Bias; use worktree::{Entry, PathProgress, PathTarget, Traversal}; @@ -11,7 +11,7 @@ use super::{RepositoryId, RepositorySnapshot, StatusEntry}; pub struct GitTraversal<'a> { traversal: Traversal<'a>, current_entry_summary: Option<GitSummary>, - repo_root_to_snapshot: BTreeMap<&'a Path, &'a RepositorySnapshot>, + repo_snapshots: &'a HashMap<RepositoryId, RepositorySnapshot>, repo_location: Option<(RepositoryId, Cursor<'a, StatusEntry, PathProgress<'a>>)>, } @@ -20,46 +20,16 @@ impl<'a> GitTraversal<'a> { repo_snapshots: &'a HashMap<RepositoryId, RepositorySnapshot>, traversal: Traversal<'a>, ) -> GitTraversal<'a> { - let repo_root_to_snapshot = repo_snapshots - .values() - .map(|snapshot| (&*snapshot.work_directory_abs_path, snapshot)) - .collect(); let mut this = GitTraversal { traversal, + repo_snapshots, current_entry_summary: None, repo_location: None, - repo_root_to_snapshot, }; this.synchronize_statuses(true); this } - fn repo_root_for_path(&self, path: &Path) -> Option<(&'a RepositorySnapshot, RepoPath)> { - // We might need to perform a range search multiple times, as there may be a nested repository inbetween - // the target and our path. E.g: - // /our_root_repo/ - // .git/ - // other_repo/ - // .git/ - // our_query.txt - let mut query = path.ancestors(); - while let Some(query) = query.next() { - let (_, snapshot) = self - .repo_root_to_snapshot - .range(Path::new("")..=query) - .last()?; - - let stripped = snapshot - .abs_path_to_repo_path(path) - .map(|repo_path| (*snapshot, repo_path)); - if stripped.is_some() { - return stripped; - } - } - - None - } - fn synchronize_statuses(&mut self, reset: bool) { self.current_entry_summary = None; @@ -72,7 +42,15 @@ impl<'a> GitTraversal<'a> { return; }; - let Some((repo, repo_path)) = self.repo_root_for_path(&abs_path) else { + let Some((repo, repo_path)) = self + .repo_snapshots + .values() + .filter_map(|repo_snapshot| { + let repo_path = repo_snapshot.abs_path_to_repo_path(&abs_path)?; + Some((repo_snapshot, repo_path)) + }) + .max_by_key(|(repo, _)| repo.work_directory_abs_path.clone()) + else { self.repo_location = None; return; }; diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index 2fd61ea0b2..958921a0e6 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -3580,18 +3580,6 @@ impl LspCommand for GetCodeLens { } } -impl LinkedEditingRange { - pub fn check_server_capabilities(capabilities: ServerCapabilities) -> bool { - let Some(linked_editing_options) = capabilities.linked_editing_range_provider else { - return false; - }; - if let LinkedEditingRangeServerCapabilities::Simple(false) = linked_editing_options { - return false; - } - true - } -} - #[async_trait(?Send)] impl LspCommand for LinkedEditingRange { type Response = Vec<Range<Anchor>>; @@ -3603,7 +3591,16 @@ impl LspCommand for LinkedEditingRange { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - Self::check_server_capabilities(capabilities.server_capabilities) + let Some(linked_editing_options) = &capabilities + .server_capabilities + .linked_editing_range_provider + else { + return false; + }; + if let LinkedEditingRangeServerCapabilities::Simple(false) = linked_editing_options { + return false; + } + true } fn to_lsp( diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 98cecc2e9b..161b861dd0 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -46,7 +46,6 @@ use language::{ DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, Diff, File as _, Language, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile, LspAdapter, LspAdapterDelegate, Patch, PointUtf16, TextBufferSnapshot, ToOffset, ToPointUtf16, Transaction, Unclipped, - WorkspaceFoldersContent, language_settings::{ FormatOnSave, Formatter, LanguageSettings, SelectedFormatter, language_settings, }, @@ -218,7 +217,6 @@ impl LocalLspStore { let binary = self.get_language_server_binary(adapter.clone(), delegate.clone(), true, cx); let pending_workspace_folders: Arc<Mutex<BTreeSet<Url>>> = Default::default(); - let pending_server = cx.spawn({ let adapter = adapter.clone(); let server_name = adapter.name.clone(); @@ -244,18 +242,14 @@ impl LocalLspStore { return Ok(server); } - let code_action_kinds = adapter.code_action_kinds(); lsp::LanguageServer::new( stderr_capture, server_id, server_name, binary, &root_path, - code_action_kinds, - Some(pending_workspace_folders).filter(|_| { - adapter.adapter.workspace_folders_content() - == WorkspaceFoldersContent::SubprojectRoots - }), + adapter.code_action_kinds(), + pending_workspace_folders, cx, ) } @@ -424,7 +418,7 @@ impl LocalLspStore { if settings.as_ref().is_some_and(|b| b.path.is_some()) { let settings = settings.unwrap(); - return cx.background_spawn(async move { + return cx.spawn(async move |_| { let mut env = delegate.shell_env().await; env.extend(settings.env.unwrap_or_default()); @@ -581,7 +575,8 @@ impl LocalLspStore { }; let root = server.workspace_folders(); Ok(Some( - root.into_iter() + root.iter() + .cloned() .map(|uri| WorkspaceFolder { uri, name: Default::default(), @@ -2425,12 +2420,36 @@ impl LocalLspStore { let server_id = server_node.server_id_or_init( |LaunchDisposition { server_name, - + attach, path, settings, }| { - let server_id = - { + let server_id = match attach { + language::Attach::InstancePerRoot => { + // todo: handle instance per root proper. + if let Some(server_ids) = self + .language_server_ids + .get(&(worktree_id, server_name.clone())) + { + server_ids.iter().cloned().next().unwrap() + } else { + let language_name = language.name(); + let adapter = self.languages + .lsp_adapters(&language_name) + .into_iter() + .find(|adapter| &adapter.name() == server_name) + .expect("To find LSP adapter"); + let server_id = self.start_language_server( + &worktree, + delegate.clone(), + adapter, + settings, + cx, + ); + server_id + } + } + language::Attach::Shared => { let uri = Url::from_file_path( worktree.read(cx).abs_path().join(&path.path), ); @@ -2465,7 +2484,7 @@ impl LocalLspStore { } else { unreachable!("Language server ID should be available, as it's registered on demand") } - + } }; let lsp_store = self.weak.clone(); let server_name = server_node.name(); @@ -4681,11 +4700,35 @@ impl LspStore { let server_id = node.server_id_or_init( |LaunchDisposition { server_name, - + attach, path, settings, - }| - { + }| match attach { + language::Attach::InstancePerRoot => { + // todo: handle instance per root proper. + if let Some(server_ids) = local + .language_server_ids + .get(&(worktree_id, server_name.clone())) + { + server_ids.iter().cloned().next().unwrap() + } else { + let adapter = local + .languages + .lsp_adapters(&language) + .into_iter() + .find(|adapter| &adapter.name() == server_name) + .expect("To find LSP adapter"); + let server_id = local.start_language_server( + &worktree, + delegate.clone(), + adapter, + settings, + cx, + ); + server_id + } + } + language::Attach::Shared => { let uri = Url::from_file_path( worktree.read(cx).abs_path().join(&path.path), ); @@ -4714,6 +4757,7 @@ impl LspStore { } server_id } + }, ); if let Some(language_server_id) = server_id { @@ -4911,7 +4955,7 @@ impl LspStore { language_server_id: server_id.0 as u64, hint: Some(InlayHints::project_to_proto_hint(hint.clone())), }; - cx.background_spawn(async move { + cx.spawn(async move |_, _| { let response = upstream_client .request(request) .await @@ -5069,7 +5113,10 @@ impl LspStore { local .language_servers_for_buffer(buffer, cx) .filter(|(_, server)| { - LinkedEditingRange::check_server_capabilities(server.capabilities()) + server + .capabilities() + .linked_editing_range_provider + .is_some() }) .filter(|(adapter, _)| { scope @@ -5122,7 +5169,7 @@ impl LspStore { trigger, version: serialize_version(&buffer.read(cx).version()), }; - cx.background_spawn(async move { + cx.spawn(async move |_, _| { client .request(request) .await? @@ -5281,7 +5328,7 @@ impl LspStore { GetDefinitions { position }, cx, ); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { Ok(definitions_task .await .into_iter() @@ -5354,7 +5401,7 @@ impl LspStore { GetDeclarations { position }, cx, ); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { Ok(declarations_task .await .into_iter() @@ -5427,7 +5474,7 @@ impl LspStore { GetTypeDefinitions { position }, cx, ); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { Ok(type_definitions_task .await .into_iter() @@ -5500,7 +5547,7 @@ impl LspStore { GetImplementations { position }, cx, ); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { Ok(implementations_task .await .into_iter() @@ -5573,7 +5620,7 @@ impl LspStore { GetReferences { position }, cx, ); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { Ok(references_task .await .into_iter() @@ -5657,7 +5704,7 @@ impl LspStore { }, cx, ); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { Ok(all_actions_task .await .into_iter() @@ -6041,6 +6088,7 @@ impl LspStore { let resolved = Self::resolve_completion_local( server, + &buffer_snapshot, completions.clone(), completion_index, ) @@ -6073,6 +6121,7 @@ impl LspStore { async fn resolve_completion_local( server: Arc<lsp::LanguageServer>, + snapshot: &BufferSnapshot, completions: Rc<RefCell<Box<[Completion]>>>, completion_index: usize, ) -> Result<()> { @@ -6117,8 +6166,26 @@ impl LspStore { .into_response() .context("resolve completion")?; - // We must not use any data such as sortText, filterText, insertText and textEdit to edit `Completion` since they are not suppose change during resolve. - // Refer: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_completion + if let Some(text_edit) = resolved_completion.text_edit.as_ref() { + // Technically we don't have to parse the whole `text_edit`, since the only + // language server we currently use that does update `text_edit` in `completionItem/resolve` + // is `typescript-language-server` and they only update `text_edit.new_text`. + // But we should not rely on that. + let edit = parse_completion_text_edit(text_edit, snapshot); + + if let Some(mut parsed_edit) = edit { + LineEnding::normalize(&mut parsed_edit.new_text); + + let mut completions = completions.borrow_mut(); + let completion = &mut completions[completion_index]; + + completion.new_text = parsed_edit.new_text; + completion.replace_range = parsed_edit.replace_range; + if let CompletionSource::Lsp { insert_range, .. } = &mut completion.source { + *insert_range = parsed_edit.insert_range; + } + } + } let mut completions = completions.borrow_mut(); let completion = &mut completions[completion_index]; @@ -6368,10 +6435,12 @@ impl LspStore { }) else { return Task::ready(Ok(None)); }; + let snapshot = buffer_handle.read(&cx).snapshot(); cx.spawn(async move |this, cx| { Self::resolve_completion_local( server.clone(), + &snapshot, completions.clone(), completion_index, ) @@ -6829,7 +6898,7 @@ impl LspStore { } else { let document_colors_task = self.request_multiple_lsp_locally(buffer, None::<usize>, GetDocumentColor, cx); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { Ok(document_colors_task .await .into_iter() @@ -6908,7 +6977,7 @@ impl LspStore { GetSignatureHelp { position }, cx, ); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { all_actions_task .await .into_iter() @@ -6985,7 +7054,7 @@ impl LspStore { GetHover { position }, cx, ); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { all_actions_task .await .into_iter() @@ -7988,7 +8057,7 @@ impl LspStore { }) .collect::<FuturesUnordered<_>>(); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { let mut responses = Vec::with_capacity(response_results.len()); while let Some((server_id, response_result)) = response_results.next().await { if let Some(response) = response_result.log_err() { diff --git a/crates/project/src/manifest_tree/server_tree.rs b/crates/project/src/manifest_tree/server_tree.rs index 81cb1c450c..0283f06eec 100644 --- a/crates/project/src/manifest_tree/server_tree.rs +++ b/crates/project/src/manifest_tree/server_tree.rs @@ -13,10 +13,10 @@ use std::{ sync::{Arc, Weak}, }; -use collections::IndexMap; +use collections::{HashMap, IndexMap}; use gpui::{App, AppContext as _, Entity, Subscription}; use language::{ - CachedLspAdapter, LanguageName, LanguageRegistry, ManifestDelegate, + Attach, CachedLspAdapter, LanguageName, LanguageRegistry, ManifestDelegate, language_settings::AllLanguageSettings, }; use lsp::LanguageServerName; @@ -38,6 +38,7 @@ pub(crate) struct ServersForWorktree { pub struct LanguageServerTree { manifest_tree: Entity<ManifestTree>, pub(crate) instances: BTreeMap<WorktreeId, ServersForWorktree>, + attach_kind_cache: HashMap<LanguageServerName, Attach>, languages: Arc<LanguageRegistry>, _subscriptions: Subscription, } @@ -52,6 +53,7 @@ pub struct LanguageServerTreeNode(Weak<InnerTreeNode>); #[derive(Debug)] pub(crate) struct LaunchDisposition<'a> { pub(crate) server_name: &'a LanguageServerName, + pub(crate) attach: Attach, pub(crate) path: ProjectPath, pub(crate) settings: Arc<LspSettings>, } @@ -60,6 +62,7 @@ impl<'a> From<&'a InnerTreeNode> for LaunchDisposition<'a> { fn from(value: &'a InnerTreeNode) -> Self { LaunchDisposition { server_name: &value.name, + attach: value.attach, path: value.path.clone(), settings: value.settings.clone(), } @@ -102,6 +105,7 @@ impl From<Weak<InnerTreeNode>> for LanguageServerTreeNode { pub struct InnerTreeNode { id: OnceLock<LanguageServerId>, name: LanguageServerName, + attach: Attach, path: ProjectPath, settings: Arc<LspSettings>, } @@ -109,12 +113,14 @@ pub struct InnerTreeNode { impl InnerTreeNode { fn new( name: LanguageServerName, + attach: Attach, path: ProjectPath, settings: impl Into<Arc<LspSettings>>, ) -> Self { InnerTreeNode { id: Default::default(), name, + attach, path, settings: settings.into(), } @@ -124,11 +130,8 @@ impl InnerTreeNode { /// Determines how the list of adapters to query should be constructed. pub(crate) enum AdapterQuery<'a> { /// Search for roots of all adapters associated with a given language name. - /// Layman: Look for all project roots along the queried path that have any - /// language server associated with this language running. Language(&'a LanguageName), /// Search for roots of adapter with a given name. - /// Layman: Look for all project roots along the queried path that have this server running. Adapter(&'a LanguageServerName), } @@ -144,7 +147,7 @@ impl LanguageServerTree { }), manifest_tree, instances: Default::default(), - + attach_kind_cache: Default::default(), languages, }) } @@ -220,6 +223,7 @@ impl LanguageServerTree { .and_then(|name| roots.get(&name)) .cloned() .unwrap_or_else(|| root_path.clone()); + let attach = adapter.attach_kind(); let inner_node = self .instances @@ -233,6 +237,7 @@ impl LanguageServerTree { ( Arc::new(InnerTreeNode::new( adapter.name(), + attach, root_path.clone(), settings.clone(), )), @@ -374,6 +379,7 @@ pub(crate) struct ServerTreeRebase<'a> { impl<'tree> ServerTreeRebase<'tree> { fn new(new_tree: &'tree mut LanguageServerTree) -> Self { let old_contents = std::mem::take(&mut new_tree.instances); + new_tree.attach_kind_cache.clear(); let all_server_ids = old_contents .values() .flat_map(|nodes| { @@ -440,7 +446,10 @@ impl<'tree> ServerTreeRebase<'tree> { .get(&disposition.path.worktree_id) .and_then(|worktree_nodes| worktree_nodes.roots.get(&disposition.path.path)) .and_then(|roots| roots.get(&disposition.name)) - .filter(|(old_node, _)| disposition.settings == old_node.settings) + .filter(|(old_node, _)| { + disposition.attach == old_node.attach + && disposition.settings == old_node.settings + }) else { return Some(node); }; diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 623f48d3c9..6b943216b3 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1362,7 +1362,10 @@ impl Project { fs: Arc<dyn Fs>, cx: AsyncApp, ) -> Result<Entity<Self>> { - client.connect(true, &cx).await.into_response()?; + client + .authenticate_and_connect(true, &cx) + .await + .into_response()?; let subscriptions = [ EntitySubscription::Project(client.subscribe_to_entity::<Self>(remote_id)?), @@ -3369,7 +3372,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.definitions(buffer, position, cx) }); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { let result = task.await; drop(guard); result @@ -3387,7 +3390,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.declarations(buffer, position, cx) }); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { let result = task.await; drop(guard); result @@ -3405,7 +3408,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.type_definitions(buffer, position, cx) }); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { let result = task.await; drop(guard); result @@ -3423,7 +3426,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.implementations(buffer, position, cx) }); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { let result = task.await; drop(guard); result @@ -3441,7 +3444,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.references(buffer, position, cx) }); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { let result = task.await; drop(guard); result @@ -3993,7 +3996,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.request_lsp(buffer_handle, server, request, cx) }); - cx.background_spawn(async move { + cx.spawn(async move |_, _| { let result = task.await; drop(guard); result diff --git a/crates/project/src/terminals.rs b/crates/project/src/terminals.rs index 973d4e8811..8cfbdff311 100644 --- a/crates/project/src/terminals.rs +++ b/crates/project/src/terminals.rs @@ -213,24 +213,17 @@ impl Project { cx: &mut Context<Self>, ) -> Result<Entity<Terminal>> { let this = &mut *self; - let ssh_details = this.ssh_details(cx); let path: Option<Arc<Path>> = match &kind { TerminalKind::Shell(path) => path.as_ref().map(|path| Arc::from(path.as_ref())), TerminalKind::Task(spawn_task) => { if let Some(cwd) = &spawn_task.cwd { - if ssh_details.is_some() { - Some(Arc::from(cwd.as_ref())) - } else { - let cwd = cwd.to_string_lossy(); - let tilde_substituted = shellexpand::tilde(&cwd); - Some(Arc::from(Path::new(tilde_substituted.as_ref()))) - } + Some(Arc::from(cwd.as_ref())) } else { this.active_project_directory(cx) } } }; - + let ssh_details = this.ssh_details(cx); let is_ssh_terminal = ssh_details.is_some(); let mut settings_location = None; diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 05e6bfe4df..b8a7aa2220 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -2731,7 +2731,26 @@ impl ProjectPanel { } fn index_for_selection(&self, selection: SelectedEntry) -> Option<(usize, usize, usize)> { - self.index_for_entry(selection.entry_id, selection.worktree_id) + let mut entry_index = 0; + let mut visible_entries_index = 0; + for (worktree_index, (worktree_id, worktree_entries, _)) in + self.visible_entries.iter().enumerate() + { + if *worktree_id == selection.worktree_id { + for entry in worktree_entries { + if entry.id == selection.entry_id { + return Some((worktree_index, entry_index, visible_entries_index)); + } else { + visible_entries_index += 1; + entry_index += 1; + } + } + break; + } else { + visible_entries_index += worktree_entries.len(); + } + } + None } fn disjoint_entries(&self, cx: &App) -> BTreeSet<SelectedEntry> { @@ -3342,12 +3361,12 @@ impl ProjectPanel { entry_id: ProjectEntryId, worktree_id: WorktreeId, ) -> Option<(usize, usize, usize)> { + let mut worktree_ix = 0; let mut total_ix = 0; - for (worktree_ix, (current_worktree_id, visible_worktree_entries, _)) in - self.visible_entries.iter().enumerate() - { + for (current_worktree_id, visible_worktree_entries, _) in &self.visible_entries { if worktree_id != *current_worktree_id { total_ix += visible_worktree_entries.len(); + worktree_ix += 1; continue; } diff --git a/crates/proto/proto/debugger.proto b/crates/proto/proto/debugger.proto index c6f9c9f134..09abd4bf1c 100644 --- a/crates/proto/proto/debugger.proto +++ b/crates/proto/proto/debugger.proto @@ -188,7 +188,7 @@ message DapSetVariableValueResponse { message DapPauseRequest { uint64 project_id = 1; uint64 client_id = 2; - int64 thread_id = 3; + uint64 thread_id = 3; } message DapDisconnectRequest { @@ -202,7 +202,7 @@ message DapDisconnectRequest { message DapTerminateThreadsRequest { uint64 project_id = 1; uint64 client_id = 2; - repeated int64 thread_ids = 3; + repeated uint64 thread_ids = 3; } message DapThreadsRequest { @@ -246,7 +246,7 @@ message IgnoreBreakpointState { message DapNextRequest { uint64 project_id = 1; uint64 client_id = 2; - int64 thread_id = 3; + uint64 thread_id = 3; optional bool single_thread = 4; optional SteppingGranularity granularity = 5; } @@ -254,7 +254,7 @@ message DapNextRequest { message DapStepInRequest { uint64 project_id = 1; uint64 client_id = 2; - int64 thread_id = 3; + uint64 thread_id = 3; optional uint64 target_id = 4; optional bool single_thread = 5; optional SteppingGranularity granularity = 6; @@ -263,7 +263,7 @@ message DapStepInRequest { message DapStepOutRequest { uint64 project_id = 1; uint64 client_id = 2; - int64 thread_id = 3; + uint64 thread_id = 3; optional bool single_thread = 4; optional SteppingGranularity granularity = 5; } @@ -271,7 +271,7 @@ message DapStepOutRequest { message DapStepBackRequest { uint64 project_id = 1; uint64 client_id = 2; - int64 thread_id = 3; + uint64 thread_id = 3; optional bool single_thread = 4; optional SteppingGranularity granularity = 5; } @@ -279,7 +279,7 @@ message DapStepBackRequest { message DapContinueRequest { uint64 project_id = 1; uint64 client_id = 2; - int64 thread_id = 3; + uint64 thread_id = 3; optional bool single_thread = 4; } @@ -311,7 +311,7 @@ message DapLoadedSourcesResponse { message DapStackTraceRequest { uint64 project_id = 1; uint64 client_id = 2; - int64 thread_id = 3; + uint64 thread_id = 3; optional uint64 start_frame = 4; optional uint64 stack_trace_levels = 5; } @@ -358,7 +358,7 @@ message DapVariable { } message DapThread { - int64 id = 1; + uint64 id = 1; string name = 2; } diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index ea08d36371..1d544b15ff 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -286,17 +286,6 @@ message Unstage { repeated string paths = 4; } -message Stash { - uint64 project_id = 1; - uint64 repository_id = 2; - repeated string paths = 3; -} - -message StashPop { - uint64 project_id = 1; - uint64 repository_id = 2; -} - message Commit { uint64 project_id = 1; reserved 2; diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 29ab2b1e90..31f929ec90 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -396,10 +396,8 @@ message Envelope { GetDocumentColor get_document_color = 353; GetDocumentColorResponse get_document_color_response = 354; GetColorPresentation get_color_presentation = 355; - GetColorPresentationResponse get_color_presentation_response = 356; + GetColorPresentationResponse get_color_presentation_response = 356; // current max - Stash stash = 357; - StashPop stash_pop = 358; // current max } reserved 87 to 88; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 83e5a77c86..918ac9e935 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -261,8 +261,6 @@ messages!( (Unfollow, Foreground), (UnshareProject, Foreground), (Unstage, Background), - (Stash, Background), - (StashPop, Background), (UpdateBuffer, Foreground), (UpdateBufferFile, Foreground), (UpdateChannelBuffer, Foreground), @@ -421,8 +419,6 @@ request_messages!( (TaskContextForLocation, TaskContext), (Test, Test), (Unstage, Ack), - (Stash, Ack), - (StashPop, Ack), (UpdateBuffer, Ack), (UpdateParticipantLocation, Ack), (UpdateProject, Ack), @@ -553,8 +549,6 @@ entity_messages!( TaskContextForLocation, UnshareProject, Unstage, - Stash, - StashPop, UpdateBuffer, UpdateBufferFile, UpdateDiagnosticSummary, @@ -784,25 +778,6 @@ pub fn split_repository_update( }]) } -impl MultiLspQuery { - pub fn request_str(&self) -> &str { - match self.request { - Some(multi_lsp_query::Request::GetHover(_)) => "GetHover", - Some(multi_lsp_query::Request::GetCodeActions(_)) => "GetCodeActions", - Some(multi_lsp_query::Request::GetSignatureHelp(_)) => "GetSignatureHelp", - Some(multi_lsp_query::Request::GetCodeLens(_)) => "GetCodeLens", - Some(multi_lsp_query::Request::GetDocumentDiagnostics(_)) => "GetDocumentDiagnostics", - Some(multi_lsp_query::Request::GetDocumentColor(_)) => "GetDocumentColor", - Some(multi_lsp_query::Request::GetDefinition(_)) => "GetDefinition", - Some(multi_lsp_query::Request::GetDeclaration(_)) => "GetDeclaration", - Some(multi_lsp_query::Request::GetTypeDefinition(_)) => "GetTypeDefinition", - Some(multi_lsp_query::Request::GetImplementation(_)) => "GetImplementation", - Some(multi_lsp_query::Request::GetReferences(_)) => "GetReferences", - None => "<unknown>", - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index 655e24860a..aa5103e62b 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -963,7 +963,7 @@ impl RemoteServerProjects { .child({ let project = project.clone(); // Right-margin to offset it from the Scrollbar - IconButton::new("remove-remote-project", IconName::Trash) + IconButton::new("remove-remote-project", IconName::TrashAlt) .icon_size(IconSize::Small) .shape(IconButtonShape::Square) .size(ButtonSize::Large) diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 4306251e44..e31d3dcfd5 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -1742,7 +1742,7 @@ impl SshRemoteConnection { } }); - cx.background_spawn(async move { + cx.spawn(async move |_| { let result = futures::select! { result = stdin_task.fuse() => { result.context("stdin") diff --git a/crates/repl/src/notebook/cell.rs b/crates/repl/src/notebook/cell.rs index 18851417c0..2ed68c17d1 100644 --- a/crates/repl/src/notebook/cell.rs +++ b/crates/repl/src/notebook/cell.rs @@ -38,7 +38,7 @@ pub enum CellControlType { impl CellControlType { fn icon_name(&self) -> IconName { match self { - CellControlType::RunCell => IconName::PlayOutlined, + CellControlType::RunCell => IconName::Play, CellControlType::RerunCell => IconName::ArrowCircle, CellControlType::ClearCell => IconName::ListX, CellControlType::CellOptions => IconName::Ellipsis, diff --git a/crates/repl/src/notebook/notebook_ui.rs b/crates/repl/src/notebook/notebook_ui.rs index 3e96cc4d11..d14f458fa9 100644 --- a/crates/repl/src/notebook/notebook_ui.rs +++ b/crates/repl/src/notebook/notebook_ui.rs @@ -343,7 +343,7 @@ impl NotebookEditor { .child( Self::render_notebook_control( "run-all-cells", - IconName::PlayOutlined, + IconName::Play, window, cx, ) diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index 2f77b4f3cc..be6a69c23b 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -319,7 +319,7 @@ impl PickerDelegate for RulePickerDelegate { }) .into_any() } else { - IconButton::new("delete-rule", IconName::Trash) + IconButton::new("delete-rule", IconName::TrashAlt) .icon_color(Color::Muted) .icon_size(IconSize::Small) .shape(IconButtonShape::Square) @@ -1163,7 +1163,7 @@ impl RulesLibrary { }) .into_any() } else { - IconButton::new("delete-rule", IconName::Trash) + IconButton::new("delete-rule", IconName::TrashAlt) .icon_size(IconSize::Small) .tooltip(move |window, cx| { Tooltip::for_action( diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index 5d77a95027..c2590ec9b0 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -228,17 +228,16 @@ impl Render for BufferSearchBar { if in_replace { key_context.add("in_replace"); } - let query_border = if self.query_error.is_some() { + let editor_border = if self.query_error.is_some() { Color::Error.color(cx) } else { cx.theme().colors().border }; - let replacement_border = cx.theme().colors().border; let container_width = window.viewport_size().width; let input_width = SearchInputWidth::calc_width(container_width); - let input_base_styles = |border_color| { + let input_base_styles = || { h_flex() .min_w_32() .w(input_width) @@ -247,7 +246,7 @@ impl Render for BufferSearchBar { .pr_1() .py_1() .border_1() - .border_color(border_color) + .border_color(editor_border) .rounded_lg() }; @@ -257,7 +256,7 @@ impl Render for BufferSearchBar { el.child(Label::new("Find in results").color(Color::Hint)) }) .child( - input_base_styles(query_border) + input_base_styles() .id("editor-scroll") .track_scroll(&self.editor_scroll_handle) .child(self.render_text_input(&self.query_editor, color_override, cx)) @@ -431,13 +430,11 @@ impl Render for BufferSearchBar { let replace_line = should_show_replace_input.then(|| { h_flex() .gap_2() - .child( - input_base_styles(replacement_border).child(self.render_text_input( - &self.replacement_editor, - None, - cx, - )), - ) + .child(input_base_styles().child(self.render_text_input( + &self.replacement_editor, + None, + cx, + ))) .child( h_flex() .min_w_64() @@ -703,11 +700,7 @@ impl BufferSearchBar { window: &mut Window, cx: &mut Context<Self>, ) -> Self { - let query_editor = cx.new(|cx| { - let mut editor = Editor::single_line(window, cx); - editor.set_use_autoclose(false); - editor - }); + let query_editor = cx.new(|cx| Editor::single_line(window, cx)); cx.subscribe_in(&query_editor, window, Self::on_query_editor_event) .detach(); let replacement_editor = cx.new(|cx| Editor::single_line(window, cx)); @@ -778,7 +771,6 @@ impl BufferSearchBar { pub fn dismiss(&mut self, _: &Dismiss, window: &mut Window, cx: &mut Context<Self>) { self.dismissed = true; - self.query_error = None; for searchable_item in self.searchable_items_with_matches.keys() { if let Some(searchable_item) = WeakSearchableItemHandle::upgrade(searchable_item.as_ref(), cx) diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 3b9700c5f1..57ca5e56b9 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -195,7 +195,6 @@ pub struct ProjectSearch { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum InputPanel { Query, - Replacement, Exclude, Include, } @@ -1963,7 +1962,7 @@ impl Render for ProjectSearchBar { MultipleInputs, } - let input_base_styles = |base_style: BaseStyle, panel: InputPanel| { + let input_base_styles = |base_style: BaseStyle| { h_flex() .min_w_32() .map(|div| match base_style { @@ -1975,11 +1974,11 @@ impl Render for ProjectSearchBar { .pr_1() .py_1() .border_1() - .border_color(search.border_color_for(panel, cx)) + .border_color(search.border_color_for(InputPanel::Query, cx)) .rounded_lg() }; - let query_column = input_base_styles(BaseStyle::SingleInput, InputPanel::Query) + let query_column = input_base_styles(BaseStyle::SingleInput) .on_action(cx.listener(|this, action, window, cx| this.confirm(action, window, cx))) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) @@ -2168,7 +2167,7 @@ impl Render for ProjectSearchBar { .child(h_flex().min_w_64().child(mode_column).child(matches_column)); let replace_line = search.replace_enabled.then(|| { - let replace_column = input_base_styles(BaseStyle::SingleInput, InputPanel::Replacement) + let replace_column = input_base_styles(BaseStyle::SingleInput) .child(self.render_text_input(&search.replacement_editor, cx)); let focus_handle = search.replacement_editor.read(cx).focus_handle(cx); @@ -2242,7 +2241,7 @@ impl Render for ProjectSearchBar { .gap_2() .w(input_width) .child( - input_base_styles(BaseStyle::MultipleInputs, InputPanel::Include) + input_base_styles(BaseStyle::MultipleInputs) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) })) @@ -2252,7 +2251,7 @@ impl Render for ProjectSearchBar { .child(self.render_text_input(&search.included_files_editor, cx)), ) .child( - input_base_styles(BaseStyle::MultipleInputs, InputPanel::Exclude) + input_base_styles(BaseStyle::MultipleInputs) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) })) diff --git a/crates/settings/src/settings.rs b/crates/settings/src/settings.rs index afd4ea0890..4e6bd94d92 100644 --- a/crates/settings/src/settings.rs +++ b/crates/settings/src/settings.rs @@ -7,7 +7,7 @@ mod settings_json; mod settings_store; mod vscode_import; -use gpui::{App, Global}; +use gpui::App; use rust_embed::RustEmbed; use std::{borrow::Cow, fmt, str}; use util::asset_str; @@ -27,11 +27,6 @@ pub use settings_store::{ }; pub use vscode_import::{VsCodeSettings, VsCodeSettingsSource}; -#[derive(Clone, Debug, PartialEq)] -pub struct ActiveSettingsProfileName(pub String); - -impl Global for ActiveSettingsProfileName {} - #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)] pub struct WorktreeId(usize); @@ -79,7 +74,6 @@ pub fn init(cx: &mut App) { .unwrap(); cx.set_global(settings); BaseKeymap::register(cx); - SettingsStore::observe_active_settings_profile_name(cx).detach(); } pub fn default_settings() -> Cow<'static, str> { diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 7f6437dac8..0d23385a68 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -26,8 +26,8 @@ use util::{ pub type EditorconfigProperties = ec4rs::Properties; use crate::{ - ActiveSettingsProfileName, ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, - WorktreeId, parse_json_with_comments, update_value_in_json_text, + ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, WorktreeId, + parse_json_with_comments, update_value_in_json_text, }; /// A value that can be defined as a user setting. @@ -122,8 +122,6 @@ pub struct SettingsSources<'a, T> { pub user: Option<&'a T>, /// The user settings for the current release channel. pub release_channel: Option<&'a T>, - /// The settings associated with an enabled settings profile - pub profile: Option<&'a T>, /// The server's settings. pub server: Option<&'a T>, /// The project settings, ordered from least specific to most specific. @@ -143,7 +141,6 @@ impl<'a, T: Serialize> SettingsSources<'a, T> { .chain(self.extensions) .chain(self.user) .chain(self.release_channel) - .chain(self.profile) .chain(self.server) .chain(self.project.iter().copied()) } @@ -285,14 +282,6 @@ impl SettingsStore { } } - pub fn observe_active_settings_profile_name(cx: &mut App) -> gpui::Subscription { - cx.observe_global::<ActiveSettingsProfileName>(|cx| { - Self::update_global(cx, |store, cx| { - store.recompute_values(None, cx).log_err(); - }); - }) - } - pub fn update<C, R>(cx: &mut C, f: impl FnOnce(&mut Self, &mut C) -> R) -> R where C: BorrowAppContext, @@ -332,17 +321,6 @@ impl SettingsStore { .log_err(); } - let mut profile_value = None; - if let Some(active_profile) = cx.try_global::<ActiveSettingsProfileName>() { - if let Some(profiles) = self.raw_user_settings.get("profiles") { - if let Some(profile_settings) = profiles.get(&active_profile.0) { - profile_value = setting_value - .deserialize_setting(profile_settings) - .log_err(); - } - } - } - let server_value = self .raw_server_settings .as_ref() @@ -362,7 +340,6 @@ impl SettingsStore { extensions: extension_value.as_ref(), user: user_value.as_ref(), release_channel: release_channel_value.as_ref(), - profile: profile_value.as_ref(), server: server_value.as_ref(), project: &[], }, @@ -425,16 +402,6 @@ impl SettingsStore { &self.raw_user_settings } - /// Get the configured settings profile names. - pub fn configured_settings_profiles(&self) -> impl Iterator<Item = &str> { - self.raw_user_settings - .get("profiles") - .and_then(|v| v.as_object()) - .into_iter() - .flat_map(|obj| obj.keys()) - .map(|s| s.as_str()) - } - /// Access the raw JSON value of the global settings. pub fn raw_global_settings(&self) -> Option<&Value> { self.raw_global_settings.as_ref() @@ -565,9 +532,7 @@ impl SettingsStore { })) .ok(); } -} -impl SettingsStore { /// Updates the value of a setting in a JSON file, returning the new text /// for that JSON file. pub fn new_text_for_update<T: Settings>( @@ -1036,18 +1001,18 @@ impl SettingsStore { const ZED_SETTINGS: &str = "ZedSettings"; let zed_settings_ref = add_new_subschema(&mut generator, ZED_SETTINGS, combined_schema); - // add `ZedSettingsOverride` which is the same as `ZedSettings` except that unknown - // fields are rejected. This is used for release stage settings and profiles. - let mut zed_settings_override = zed_settings_ref.clone(); - zed_settings_override.insert("unevaluatedProperties".to_string(), false.into()); - let zed_settings_override_ref = add_new_subschema( + // add `ZedReleaseStageSettings` which is the same as `ZedSettings` except that unknown + // fields are rejected. + let mut zed_release_stage_settings = zed_settings_ref.clone(); + zed_release_stage_settings.insert("unevaluatedProperties".to_string(), false.into()); + let zed_release_stage_settings_ref = add_new_subschema( &mut generator, - "ZedSettingsOverride", - zed_settings_override.to_value(), + "ZedReleaseStageSettings", + zed_release_stage_settings.to_value(), ); // Remove `"additionalProperties": false` added by `DefaultDenyUnknownFields` so that - // unknown fields can be handled by the root schema and `ZedSettingsOverride`. + // unknown fields can be handled by the root schema and `ZedReleaseStageSettings`. let mut definitions = generator.take_definitions(true); definitions .get_mut(ZED_SETTINGS) @@ -1067,20 +1032,15 @@ impl SettingsStore { "$schema": meta_schema, "title": "Zed Settings", "unevaluatedProperties": false, - // ZedSettings + settings overrides for each release stage / profiles + // ZedSettings + settings overrides for each release stage "allOf": [ zed_settings_ref, { "properties": { - "dev": zed_settings_override_ref, - "nightly": zed_settings_override_ref, - "stable": zed_settings_override_ref, - "preview": zed_settings_override_ref, - "profiles": { - "type": "object", - "description": "Configures any number of settings profiles.", - "additionalProperties": zed_settings_override_ref - } + "dev": zed_release_stage_settings_ref, + "nightly": zed_release_stage_settings_ref, + "stable": zed_release_stage_settings_ref, + "preview": zed_release_stage_settings_ref, } } ], @@ -1139,16 +1099,6 @@ impl SettingsStore { } } - let mut profile_settings = None; - if let Some(active_profile) = cx.try_global::<ActiveSettingsProfileName>() { - if let Some(profiles) = self.raw_user_settings.get("profiles") { - if let Some(profile_json) = profiles.get(&active_profile.0) { - profile_settings = - setting_value.deserialize_setting(profile_json).log_err(); - } - } - } - // If the global settings file changed, reload the global value for the field. if changed_local_path.is_none() { if let Some(value) = setting_value @@ -1159,7 +1109,6 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), - profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &[], }, @@ -1212,7 +1161,6 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), - profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &project_settings_stack.iter().collect::<Vec<_>>(), }, @@ -1338,9 +1286,6 @@ impl<T: Settings> AnySettingValue for SettingValue<T> { release_channel: values .release_channel .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), - profile: values - .profile - .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), server: values .server .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), diff --git a/crates/settings_profile_selector/Cargo.toml b/crates/settings_profile_selector/Cargo.toml deleted file mode 100644 index 189272e54b..0000000000 --- a/crates/settings_profile_selector/Cargo.toml +++ /dev/null @@ -1,35 +0,0 @@ -[package] -name = "settings_profile_selector" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/settings_profile_selector.rs" -doctest = false - -[dependencies] -fuzzy.workspace = true -gpui.workspace = true -picker.workspace = true -settings.workspace = true -ui.workspace = true -workspace-hack.workspace = true -workspace.workspace = true -zed_actions.workspace = true - -[dev-dependencies] -client = { workspace = true, features = ["test-support"] } -editor = { workspace = true, features = ["test-support"] } -gpui = { workspace = true, features = ["test-support"] } -language = { workspace = true, features = ["test-support"] } -menu.workspace = true -project = { workspace = true, features = ["test-support"] } -serde_json.workspace = true -settings = { workspace = true, features = ["test-support"] } -theme = { workspace = true, features = ["test-support"] } -workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_profile_selector/LICENSE-GPL b/crates/settings_profile_selector/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/settings_profile_selector/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/settings_profile_selector/src/settings_profile_selector.rs b/crates/settings_profile_selector/src/settings_profile_selector.rs deleted file mode 100644 index 8a34c12051..0000000000 --- a/crates/settings_profile_selector/src/settings_profile_selector.rs +++ /dev/null @@ -1,581 +0,0 @@ -use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; -use gpui::{ - App, Context, DismissEvent, Entity, EventEmitter, Focusable, Render, Task, WeakEntity, Window, -}; -use picker::{Picker, PickerDelegate}; -use settings::{ActiveSettingsProfileName, SettingsStore}; -use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*}; -use workspace::{ModalView, Workspace}; - -pub fn init(cx: &mut App) { - cx.on_action(|_: &zed_actions::settings_profile_selector::Toggle, cx| { - workspace::with_active_or_new_workspace(cx, |workspace, window, cx| { - toggle_settings_profile_selector(workspace, window, cx); - }); - }); -} - -fn toggle_settings_profile_selector( - workspace: &mut Workspace, - window: &mut Window, - cx: &mut Context<Workspace>, -) { - workspace.toggle_modal(window, cx, |window, cx| { - let delegate = SettingsProfileSelectorDelegate::new(cx.entity().downgrade(), window, cx); - SettingsProfileSelector::new(delegate, window, cx) - }); -} - -pub struct SettingsProfileSelector { - picker: Entity<Picker<SettingsProfileSelectorDelegate>>, -} - -impl ModalView for SettingsProfileSelector {} - -impl EventEmitter<DismissEvent> for SettingsProfileSelector {} - -impl Focusable for SettingsProfileSelector { - fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { - self.picker.focus_handle(cx) - } -} - -impl Render for SettingsProfileSelector { - fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - v_flex().w(rems(22.)).child(self.picker.clone()) - } -} - -impl SettingsProfileSelector { - pub fn new( - delegate: SettingsProfileSelectorDelegate, - window: &mut Window, - cx: &mut Context<Self>, - ) -> Self { - let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); - Self { picker } - } -} - -pub struct SettingsProfileSelectorDelegate { - matches: Vec<StringMatch>, - profile_names: Vec<Option<String>>, - original_profile_name: Option<String>, - selected_profile_name: Option<String>, - selected_index: usize, - selection_completed: bool, - selector: WeakEntity<SettingsProfileSelector>, -} - -impl SettingsProfileSelectorDelegate { - fn new( - selector: WeakEntity<SettingsProfileSelector>, - _: &mut Window, - cx: &mut Context<SettingsProfileSelector>, - ) -> Self { - let settings_store = cx.global::<SettingsStore>(); - let mut profile_names: Vec<Option<String>> = settings_store - .configured_settings_profiles() - .map(|s| Some(s.to_string())) - .collect(); - profile_names.insert(0, None); - - let matches = profile_names - .iter() - .enumerate() - .map(|(ix, profile_name)| StringMatch { - candidate_id: ix, - score: 0.0, - positions: Default::default(), - string: display_name(profile_name), - }) - .collect(); - - let profile_name = cx - .try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()); - - let mut this = Self { - matches, - profile_names, - original_profile_name: profile_name.clone(), - selected_profile_name: None, - selected_index: 0, - selection_completed: false, - selector, - }; - - if let Some(profile_name) = profile_name { - this.select_if_matching(&profile_name); - } - - this - } - - fn select_if_matching(&mut self, profile_name: &str) { - self.selected_index = self - .matches - .iter() - .position(|mat| mat.string == profile_name) - .unwrap_or(self.selected_index); - } - - fn set_selected_profile( - &self, - cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, - ) -> Option<String> { - let mat = self.matches.get(self.selected_index)?; - let profile_name = self.profile_names.get(mat.candidate_id)?; - return Self::update_active_profile_name_global(profile_name.clone(), cx); - } - - fn update_active_profile_name_global( - profile_name: Option<String>, - cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, - ) -> Option<String> { - if let Some(profile_name) = profile_name { - cx.set_global(ActiveSettingsProfileName(profile_name.clone())); - return Some(profile_name.clone()); - } - - if cx.has_global::<ActiveSettingsProfileName>() { - cx.remove_global::<ActiveSettingsProfileName>(); - } - - None - } -} - -impl PickerDelegate for SettingsProfileSelectorDelegate { - type ListItem = ListItem; - - fn placeholder_text(&self, _: &mut Window, _: &mut App) -> std::sync::Arc<str> { - "Select a settings profile...".into() - } - - fn match_count(&self) -> usize { - self.matches.len() - } - - fn selected_index(&self) -> usize { - self.selected_index - } - - fn set_selected_index( - &mut self, - ix: usize, - _: &mut Window, - cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, - ) { - self.selected_index = ix; - self.selected_profile_name = self.set_selected_profile(cx); - } - - fn update_matches( - &mut self, - query: String, - window: &mut Window, - cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, - ) -> Task<()> { - let background = cx.background_executor().clone(); - let candidates = self - .profile_names - .iter() - .enumerate() - .map(|(id, profile_name)| StringMatchCandidate::new(id, &display_name(profile_name))) - .collect::<Vec<_>>(); - - cx.spawn_in(window, async move |this, cx| { - let matches = if query.is_empty() { - candidates - .into_iter() - .enumerate() - .map(|(index, candidate)| StringMatch { - candidate_id: index, - string: candidate.string, - positions: Vec::new(), - score: 0.0, - }) - .collect() - } else { - match_strings( - &candidates, - &query, - false, - true, - 100, - &Default::default(), - background, - ) - .await - }; - - this.update_in(cx, |this, _, cx| { - this.delegate.matches = matches; - this.delegate.selected_index = this - .delegate - .selected_index - .min(this.delegate.matches.len().saturating_sub(1)); - this.delegate.selected_profile_name = this.delegate.set_selected_profile(cx); - }) - .ok(); - }) - } - - fn confirm( - &mut self, - _: bool, - _: &mut Window, - cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, - ) { - self.selection_completed = true; - self.selector - .update(cx, |_, cx| { - cx.emit(DismissEvent); - }) - .ok(); - } - - fn dismissed( - &mut self, - _: &mut Window, - cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, - ) { - if !self.selection_completed { - SettingsProfileSelectorDelegate::update_active_profile_name_global( - self.original_profile_name.clone(), - cx, - ); - } - self.selector.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); - } - - fn render_match( - &self, - ix: usize, - selected: bool, - _: &mut Window, - _: &mut Context<Picker<Self>>, - ) -> Option<Self::ListItem> { - let mat = &self.matches[ix]; - let profile_name = &self.profile_names[mat.candidate_id]; - - Some( - ListItem::new(ix) - .inset(true) - .spacing(ListItemSpacing::Sparse) - .toggle_state(selected) - .child(HighlightedLabel::new( - display_name(profile_name), - mat.positions.clone(), - )), - ) - } -} - -fn display_name(profile_name: &Option<String>) -> String { - profile_name.clone().unwrap_or("Disabled".into()) -} - -#[cfg(test)] -mod tests { - use super::*; - use client; - use editor; - use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; - use language; - use menu::{Cancel, Confirm, SelectNext, SelectPrevious}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::Settings; - use theme::{self, ThemeSettings}; - use workspace::{self, AppState}; - use zed_actions::settings_profile_selector; - - async fn init_test( - profiles_json: serde_json::Value, - cx: &mut TestAppContext, - ) -> (Entity<Workspace>, &mut VisualTestContext) { - cx.update(|cx| { - let state = AppState::test(cx); - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - settings::init(cx); - theme::init(theme::LoadThemes::JustBase, cx); - ThemeSettings::register(cx); - client::init_settings(cx); - language::init(cx); - super::init(cx); - editor::init(cx); - workspace::init_settings(cx); - Project::init_settings(cx); - state - }); - - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - let settings_json = json!({ - "buffer_font_size": 10.0, - "profiles": profiles_json, - }); - - store - .set_user_settings(&settings_json.to_string(), cx) - .unwrap(); - }); - }); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, ["/test".as_ref()], cx).await; - let (workspace, cx) = - cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - - cx.update(|_, cx| { - assert!(!cx.has_global::<ActiveSettingsProfileName>()); - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); - }); - - (workspace, cx) - } - - #[track_caller] - fn active_settings_profile_picker( - workspace: &Entity<Workspace>, - cx: &mut VisualTestContext, - ) -> Entity<Picker<SettingsProfileSelectorDelegate>> { - workspace.update(cx, |workspace, cx| { - workspace - .active_modal::<SettingsProfileSelector>(cx) - .expect("settings profile selector is not open") - .read(cx) - .picker - .clone() - }) - } - - #[gpui::test] - async fn test_settings_profile_selector_state(cx: &mut TestAppContext) { - let classroom_and_streaming_profile_name = "Classroom / Streaming".to_string(); - let demo_videos_profile_name = "Demo Videos".to_string(); - - let profiles_json = json!({ - classroom_and_streaming_profile_name.clone(): { - "buffer_font_size": 20.0, - }, - demo_videos_profile_name.clone(): { - "buffer_font_size": 15.0 - } - }); - let (workspace, cx) = init_test(profiles_json.clone(), cx).await; - - cx.dispatch_action(settings_profile_selector::Toggle); - let picker = active_settings_profile_picker(&workspace, cx); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.matches.len(), 3); - assert_eq!(picker.delegate.matches[0].string, display_name(&None)); - assert_eq!( - picker.delegate.matches[1].string, - classroom_and_streaming_profile_name - ); - assert_eq!(picker.delegate.matches[2].string, demo_videos_profile_name); - assert_eq!(picker.delegate.matches.get(3), None); - - assert_eq!(picker.delegate.selected_index, 0); - assert_eq!(picker.delegate.selected_profile_name, None); - - assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); - }); - - cx.dispatch_action(Confirm); - - cx.update(|_, cx| { - assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); - }); - - cx.dispatch_action(settings_profile_selector::Toggle); - let picker = active_settings_profile_picker(&workspace, cx); - cx.dispatch_action(SelectNext); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.selected_index, 1); - assert_eq!( - picker.delegate.selected_profile_name, - Some(classroom_and_streaming_profile_name.clone()) - ); - - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(classroom_and_streaming_profile_name.clone()) - ); - - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); - }); - - cx.dispatch_action(Cancel); - - cx.update(|_, cx| { - assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); - }); - - cx.dispatch_action(settings_profile_selector::Toggle); - let picker = active_settings_profile_picker(&workspace, cx); - - cx.dispatch_action(SelectNext); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.selected_index, 1); - assert_eq!( - picker.delegate.selected_profile_name, - Some(classroom_and_streaming_profile_name.clone()) - ); - - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(classroom_and_streaming_profile_name.clone()) - ); - - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); - }); - - cx.dispatch_action(SelectNext); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.selected_index, 2); - assert_eq!( - picker.delegate.selected_profile_name, - Some(demo_videos_profile_name.clone()) - ); - - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(demo_videos_profile_name.clone()) - ); - - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); - }); - - cx.dispatch_action(Confirm); - - cx.update(|_, cx| { - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(demo_videos_profile_name.clone()) - ); - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); - }); - - cx.dispatch_action(settings_profile_selector::Toggle); - let picker = active_settings_profile_picker(&workspace, cx); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.selected_index, 2); - assert_eq!( - picker.delegate.selected_profile_name, - Some(demo_videos_profile_name.clone()) - ); - - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(demo_videos_profile_name.clone()) - ); - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); - }); - - cx.dispatch_action(SelectPrevious); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.selected_index, 1); - assert_eq!( - picker.delegate.selected_profile_name, - Some(classroom_and_streaming_profile_name.clone()) - ); - - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(classroom_and_streaming_profile_name.clone()) - ); - - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); - }); - - cx.dispatch_action(Cancel); - - cx.update(|_, cx| { - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(demo_videos_profile_name.clone()) - ); - - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); - }); - - cx.dispatch_action(settings_profile_selector::Toggle); - let picker = active_settings_profile_picker(&workspace, cx); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.selected_index, 2); - assert_eq!( - picker.delegate.selected_profile_name, - Some(demo_videos_profile_name.clone()) - ); - - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(demo_videos_profile_name) - ); - - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); - }); - - cx.dispatch_action(SelectPrevious); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.selected_index, 1); - assert_eq!( - picker.delegate.selected_profile_name, - Some(classroom_and_streaming_profile_name.clone()) - ); - - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - Some(classroom_and_streaming_profile_name) - ); - - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); - }); - - cx.dispatch_action(SelectPrevious); - - picker.read_with(cx, |picker, cx| { - assert_eq!(picker.delegate.selected_index, 0); - assert_eq!(picker.delegate.selected_profile_name, None); - - assert_eq!( - cx.try_global::<ActiveSettingsProfileName>() - .map(|p| p.0.clone()), - None - ); - - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); - }); - - cx.dispatch_action(Confirm); - - cx.update(|_, cx| { - assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); - assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); - }); - } -} diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index a4c47081c6..02327045fd 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -30,6 +30,7 @@ menu.workspace = true notifications.workspace = true paths.workspace = true project.workspace = true +schemars.workspace = true search.workspace = true serde.workspace = true serde_json.workspace = true @@ -44,10 +45,3 @@ ui_input.workspace = true util.workspace = true workspace-hack.workspace = true workspace.workspace = true - -[dev-dependencies] -db = {"workspace"= true, "features" = ["test-support"]} -fs = { workspace = true, features = ["test-support"] } -gpui = { workspace = true, features = ["test-support"] } -project = { workspace = true, features = ["test-support"] } -workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_ui/src/keybindings.rs b/crates/settings_ui/src/keybindings.rs index 70afe1729c..9da7242e36 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/settings_ui/src/keybindings.rs @@ -11,10 +11,11 @@ use editor::{CompletionProvider, Editor, EditorEvent}; use fs::Fs; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ - Action, AppContext as _, AsyncApp, Axis, ClickEvent, Context, DismissEvent, Entity, - EventEmitter, FocusHandle, Focusable, Global, IsZero, KeyContext, Keystroke, MouseButton, - Point, ScrollStrategy, ScrollWheelEvent, Stateful, StyledText, Subscription, Task, - TextStyleRefinement, WeakEntity, actions, anchored, deferred, div, + Action, Animation, AnimationExt, AppContext as _, AsyncApp, Axis, ClickEvent, Context, + DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, Global, IsZero, + KeyContext, Keystroke, Modifiers, ModifiersChangedEvent, MouseButton, Point, ScrollStrategy, + ScrollWheelEvent, Stateful, StyledText, Subscription, Task, TextStyleRefinement, WeakEntity, + actions, anchored, deferred, div, }; use language::{Language, LanguageConfig, ToOffset as _}; use notifications::status_toast::{StatusToast, ToastIcon}; @@ -34,10 +35,7 @@ use workspace::{ use crate::{ keybindings::persistence::KEYBINDING_EDITORS, - ui_components::{ - keystroke_input::{ClearKeystrokes, KeystrokeInput, StartRecording, StopRecording}, - table::{ColumnWidths, ResizeBehavior, Table, TableInteractionState}, - }, + ui_components::table::{ColumnWidths, ResizeBehavior, Table, TableInteractionState}, }; const NO_ACTION_ARGUMENTS_TEXT: SharedString = SharedString::new_static("<no arguments>"); @@ -74,6 +72,18 @@ actions!( ] ); +actions!( + keystroke_input, + [ + /// Starts recording keystrokes + StartRecording, + /// Stops recording keystrokes + StopRecording, + /// Clears the recorded keystrokes + ClearKeystrokes, + ] +); + pub fn init(cx: &mut App) { let keymap_event_channel = KeymapEventChannel::new(); cx.set_global(keymap_event_channel); @@ -383,7 +393,7 @@ impl KeymapEditor { let keystroke_editor = cx.new(|cx| { let mut keystroke_editor = KeystrokeInput::new(None, window, cx); - keystroke_editor.set_search(true); + keystroke_editor.search = true; keystroke_editor }); @@ -1680,7 +1690,7 @@ impl Render for KeymapEditor { move |window, cx| this.read(cx).render_no_matches_hint(window, cx) }) .column_widths([ - DefiniteLength::Absolute(AbsoluteLength::Pixels(px(36.))), + DefiniteLength::Absolute(AbsoluteLength::Pixels(px(40.))), DefiniteLength::Fraction(0.25), DefiniteLength::Fraction(0.20), DefiniteLength::Fraction(0.14), @@ -1755,7 +1765,6 @@ impl Render for KeymapEditor { }, ) .into_any_element(); - let keystrokes = binding.ui_key_binding().cloned().map_or( binding .keystroke_text() @@ -1764,7 +1773,6 @@ impl Render for KeymapEditor { .into_any_element(), IntoElement::into_any_element, ); - let action_arguments = match binding.action().arguments.clone() { Some(arguments) => arguments.into_any_element(), @@ -1777,7 +1785,6 @@ impl Render for KeymapEditor { } } }; - let context = binding.context().cloned().map_or( gpui::Empty.into_any_element(), |context| { @@ -1802,13 +1809,11 @@ impl Render for KeymapEditor { .into_any_element() }, ); - let source = binding .keybind_source() .map(|source| source.name()) .unwrap_or_default() .into_any_element(); - Some([ icon.into_any_element(), action, @@ -2969,6 +2974,524 @@ async fn remove_keybinding( Ok(()) } +#[derive(PartialEq, Eq, Debug, Copy, Clone)] +enum CloseKeystrokeResult { + Partial, + Close, + None, +} + +struct KeystrokeInput { + keystrokes: Vec<Keystroke>, + placeholder_keystrokes: Option<Vec<Keystroke>>, + outer_focus_handle: FocusHandle, + inner_focus_handle: FocusHandle, + intercept_subscription: Option<Subscription>, + _focus_subscriptions: [Subscription; 2], + search: bool, + /// Handles tripe escape to stop recording + close_keystrokes: Option<Vec<Keystroke>>, + close_keystrokes_start: Option<usize>, + previous_modifiers: Modifiers, +} + +impl KeystrokeInput { + const KEYSTROKE_COUNT_MAX: usize = 3; + + fn new( + placeholder_keystrokes: Option<Vec<Keystroke>>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let outer_focus_handle = cx.focus_handle(); + let inner_focus_handle = cx.focus_handle(); + let _focus_subscriptions = [ + cx.on_focus_in(&inner_focus_handle, window, Self::on_inner_focus_in), + cx.on_focus_out(&inner_focus_handle, window, Self::on_inner_focus_out), + ]; + Self { + keystrokes: Vec::new(), + placeholder_keystrokes, + inner_focus_handle, + outer_focus_handle, + intercept_subscription: None, + _focus_subscriptions, + search: false, + close_keystrokes: None, + close_keystrokes_start: None, + previous_modifiers: Modifiers::default(), + } + } + + fn set_keystrokes(&mut self, keystrokes: Vec<Keystroke>, cx: &mut Context<Self>) { + self.keystrokes = keystrokes; + self.keystrokes_changed(cx); + } + + fn dummy(modifiers: Modifiers) -> Keystroke { + return Keystroke { + modifiers, + key: "".to_string(), + key_char: None, + }; + } + + fn keystrokes_changed(&self, cx: &mut Context<Self>) { + cx.emit(()); + cx.notify(); + } + + fn key_context() -> KeyContext { + let mut key_context = KeyContext::default(); + key_context.add("KeystrokeInput"); + key_context + } + + fn handle_possible_close_keystroke( + &mut self, + keystroke: &Keystroke, + window: &mut Window, + cx: &mut Context<Self>, + ) -> CloseKeystrokeResult { + let Some(keybind_for_close_action) = window + .highest_precedence_binding_for_action_in_context(&StopRecording, Self::key_context()) + else { + log::trace!("No keybinding to stop recording keystrokes in keystroke input"); + self.close_keystrokes.take(); + self.close_keystrokes_start.take(); + return CloseKeystrokeResult::None; + }; + let action_keystrokes = keybind_for_close_action.keystrokes(); + + if let Some(mut close_keystrokes) = self.close_keystrokes.take() { + let mut index = 0; + + while index < action_keystrokes.len() && index < close_keystrokes.len() { + if !close_keystrokes[index].should_match(&action_keystrokes[index]) { + break; + } + index += 1; + } + if index == close_keystrokes.len() { + if index >= action_keystrokes.len() { + self.close_keystrokes_start.take(); + return CloseKeystrokeResult::None; + } + if keystroke.should_match(&action_keystrokes[index]) { + if action_keystrokes.len() >= 1 && index == action_keystrokes.len() - 1 { + self.stop_recording(&StopRecording, window, cx); + return CloseKeystrokeResult::Close; + } else { + close_keystrokes.push(keystroke.clone()); + self.close_keystrokes = Some(close_keystrokes); + return CloseKeystrokeResult::Partial; + } + } else { + self.close_keystrokes_start.take(); + return CloseKeystrokeResult::None; + } + } + } else if let Some(first_action_keystroke) = action_keystrokes.first() + && keystroke.should_match(first_action_keystroke) + { + self.close_keystrokes = Some(vec![keystroke.clone()]); + return CloseKeystrokeResult::Partial; + } + self.close_keystrokes_start.take(); + return CloseKeystrokeResult::None; + } + + fn on_modifiers_changed( + &mut self, + event: &ModifiersChangedEvent, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + let keystrokes_len = self.keystrokes.len(); + + if self.previous_modifiers.modified() + && event.modifiers.is_subset_of(&self.previous_modifiers) + { + self.previous_modifiers &= event.modifiers; + cx.stop_propagation(); + return; + } + + if let Some(last) = self.keystrokes.last_mut() + && last.key.is_empty() + && keystrokes_len <= Self::KEYSTROKE_COUNT_MAX + { + if self.search { + if self.previous_modifiers.modified() { + last.modifiers |= event.modifiers; + self.previous_modifiers |= event.modifiers; + } else { + self.keystrokes.push(Self::dummy(event.modifiers)); + self.previous_modifiers |= event.modifiers; + } + } else if !event.modifiers.modified() { + self.keystrokes.pop(); + } else { + last.modifiers = event.modifiers; + } + + self.keystrokes_changed(cx); + } else if keystrokes_len < Self::KEYSTROKE_COUNT_MAX { + self.keystrokes.push(Self::dummy(event.modifiers)); + if self.search { + self.previous_modifiers |= event.modifiers; + } + self.keystrokes_changed(cx); + } + cx.stop_propagation(); + } + + fn handle_keystroke( + &mut self, + keystroke: &Keystroke, + window: &mut Window, + cx: &mut Context<Self>, + ) { + let close_keystroke_result = self.handle_possible_close_keystroke(keystroke, window, cx); + if close_keystroke_result != CloseKeystrokeResult::Close { + let key_len = self.keystrokes.len(); + if let Some(last) = self.keystrokes.last_mut() + && last.key.is_empty() + && key_len <= Self::KEYSTROKE_COUNT_MAX + { + if self.search { + last.key = keystroke.key.clone(); + if close_keystroke_result == CloseKeystrokeResult::Partial + && self.close_keystrokes_start.is_none() + { + self.close_keystrokes_start = Some(self.keystrokes.len() - 1); + } + if self.search { + self.previous_modifiers = keystroke.modifiers; + } + self.keystrokes_changed(cx); + cx.stop_propagation(); + return; + } else { + self.keystrokes.pop(); + } + } + if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX { + if close_keystroke_result == CloseKeystrokeResult::Partial + && self.close_keystrokes_start.is_none() + { + self.close_keystrokes_start = Some(self.keystrokes.len()); + } + self.keystrokes.push(keystroke.clone()); + if self.search { + self.previous_modifiers = keystroke.modifiers; + } else if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX { + self.keystrokes.push(Self::dummy(keystroke.modifiers)); + } + } else if close_keystroke_result != CloseKeystrokeResult::Partial { + self.clear_keystrokes(&ClearKeystrokes, window, cx); + } + } + self.keystrokes_changed(cx); + cx.stop_propagation(); + } + + fn on_inner_focus_in(&mut self, _window: &mut Window, cx: &mut Context<Self>) { + if self.intercept_subscription.is_none() { + let listener = cx.listener(|this, event: &gpui::KeystrokeEvent, window, cx| { + this.handle_keystroke(&event.keystroke, window, cx); + }); + self.intercept_subscription = Some(cx.intercept_keystrokes(listener)) + } + } + + fn on_inner_focus_out( + &mut self, + _event: gpui::FocusOutEvent, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.intercept_subscription.take(); + cx.notify(); + } + + fn keystrokes(&self) -> &[Keystroke] { + if let Some(placeholders) = self.placeholder_keystrokes.as_ref() + && self.keystrokes.is_empty() + { + return placeholders; + } + if !self.search + && self + .keystrokes + .last() + .map_or(false, |last| last.key.is_empty()) + { + return &self.keystrokes[..self.keystrokes.len() - 1]; + } + return &self.keystrokes; + } + + fn render_keystrokes(&self, is_recording: bool) -> impl Iterator<Item = Div> { + let keystrokes = if let Some(placeholders) = self.placeholder_keystrokes.as_ref() + && self.keystrokes.is_empty() + { + if is_recording { + &[] + } else { + placeholders.as_slice() + } + } else { + &self.keystrokes + }; + keystrokes.iter().map(move |keystroke| { + h_flex().children(ui::render_keystroke( + keystroke, + Some(Color::Default), + Some(rems(0.875).into()), + ui::PlatformStyle::platform(), + false, + )) + }) + } + + fn start_recording(&mut self, _: &StartRecording, window: &mut Window, cx: &mut Context<Self>) { + window.focus(&self.inner_focus_handle); + self.clear_keystrokes(&ClearKeystrokes, window, cx); + self.previous_modifiers = window.modifiers(); + cx.stop_propagation(); + } + + fn stop_recording(&mut self, _: &StopRecording, window: &mut Window, cx: &mut Context<Self>) { + if !self.inner_focus_handle.is_focused(window) { + return; + } + window.focus(&self.outer_focus_handle); + if let Some(close_keystrokes_start) = self.close_keystrokes_start.take() + && close_keystrokes_start < self.keystrokes.len() + { + self.keystrokes.drain(close_keystrokes_start..); + } + self.close_keystrokes.take(); + cx.notify(); + } + + fn clear_keystrokes( + &mut self, + _: &ClearKeystrokes, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.keystrokes.clear(); + self.keystrokes_changed(cx); + } +} + +impl EventEmitter<()> for KeystrokeInput {} + +impl Focusable for KeystrokeInput { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.outer_focus_handle.clone() + } +} + +impl Render for KeystrokeInput { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let colors = cx.theme().colors(); + let is_focused = self.outer_focus_handle.contains_focused(window, cx); + let is_recording = self.inner_focus_handle.is_focused(window); + + let horizontal_padding = rems_from_px(64.); + + let recording_bg_color = colors + .editor_background + .blend(colors.text_accent.opacity(0.1)); + + let recording_pulse = |color: Color| { + Icon::new(IconName::Circle) + .size(IconSize::Small) + .color(Color::Error) + .with_animation( + "recording-pulse", + Animation::new(std::time::Duration::from_secs(2)) + .repeat() + .with_easing(gpui::pulsating_between(0.4, 0.8)), + { + let color = color.color(cx); + move |this, delta| this.color(Color::Custom(color.opacity(delta))) + }, + ) + }; + + let recording_indicator = h_flex() + .h_4() + .pr_1() + .gap_0p5() + .border_1() + .border_color(colors.border) + .bg(colors + .editor_background + .blend(colors.text_accent.opacity(0.1))) + .rounded_sm() + .child(recording_pulse(Color::Error)) + .child( + Label::new("REC") + .size(LabelSize::XSmall) + .weight(FontWeight::SEMIBOLD) + .color(Color::Error), + ); + + let search_indicator = h_flex() + .h_4() + .pr_1() + .gap_0p5() + .border_1() + .border_color(colors.border) + .bg(colors + .editor_background + .blend(colors.text_accent.opacity(0.1))) + .rounded_sm() + .child(recording_pulse(Color::Accent)) + .child( + Label::new("SEARCH") + .size(LabelSize::XSmall) + .weight(FontWeight::SEMIBOLD) + .color(Color::Accent), + ); + + let record_icon = if self.search { + IconName::MagnifyingGlass + } else { + IconName::PlayFilled + }; + + h_flex() + .id("keystroke-input") + .track_focus(&self.outer_focus_handle) + .py_2() + .px_3() + .gap_2() + .min_h_10() + .w_full() + .flex_1() + .justify_between() + .rounded_lg() + .overflow_hidden() + .map(|this| { + if is_recording { + this.bg(recording_bg_color) + } else { + this.bg(colors.editor_background) + } + }) + .border_1() + .border_color(colors.border_variant) + .when(is_focused, |parent| { + parent.border_color(colors.border_focused) + }) + .key_context(Self::key_context()) + .on_action(cx.listener(Self::start_recording)) + .on_action(cx.listener(Self::clear_keystrokes)) + .child( + h_flex() + .w(horizontal_padding) + .gap_0p5() + .justify_start() + .flex_none() + .when(is_recording, |this| { + this.map(|this| { + if self.search { + this.child(search_indicator) + } else { + this.child(recording_indicator) + } + }) + }), + ) + .child( + h_flex() + .id("keystroke-input-inner") + .track_focus(&self.inner_focus_handle) + .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) + .size_full() + .when(!self.search, |this| { + this.focus(|mut style| { + style.border_color = Some(colors.border_focused); + style + }) + }) + .w_full() + .min_w_0() + .justify_center() + .flex_wrap() + .gap(ui::DynamicSpacing::Base04.rems(cx)) + .children(self.render_keystrokes(is_recording)), + ) + .child( + h_flex() + .w(horizontal_padding) + .gap_0p5() + .justify_end() + .flex_none() + .map(|this| { + if is_recording { + this.child( + IconButton::new("stop-record-btn", IconName::StopFilled) + .shape(ui::IconButtonShape::Square) + .map(|this| { + this.tooltip(Tooltip::for_action_title( + if self.search { + "Stop Searching" + } else { + "Stop Recording" + }, + &StopRecording, + )) + }) + .icon_color(Color::Error) + .on_click(cx.listener(|this, _event, window, cx| { + this.stop_recording(&StopRecording, window, cx); + })), + ) + } else { + this.child( + IconButton::new("record-btn", record_icon) + .shape(ui::IconButtonShape::Square) + .map(|this| { + this.tooltip(Tooltip::for_action_title( + if self.search { + "Start Searching" + } else { + "Start Recording" + }, + &StartRecording, + )) + }) + .when(!is_focused, |this| this.icon_color(Color::Muted)) + .on_click(cx.listener(|this, _event, window, cx| { + this.start_recording(&StartRecording, window, cx); + })), + ) + } + }) + .child( + IconButton::new("clear-btn", IconName::Delete) + .shape(ui::IconButtonShape::Square) + .tooltip(Tooltip::for_action_title( + "Clear Keystrokes", + &ClearKeystrokes, + )) + .when(!is_recording || !is_focused, |this| { + this.icon_color(Color::Muted) + }) + .on_click(cx.listener(|this, _event, window, cx| { + this.clear_keystrokes(&ClearKeystrokes, window, cx); + })), + ), + ) + } +} + fn collect_contexts_from_assets() -> Vec<SharedString> { let mut keymap_assets = vec![ util::asset_str::<SettingsAssets>(settings::DEFAULT_KEYMAP_PATH), diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index 3022cc7142..2f0abb4789 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -1,12 +1,20 @@ mod appearance_settings_controls; use std::any::TypeId; +use std::sync::Arc; use command_palette_hooks::CommandPaletteFilter; use editor::EditorSettingsControls; use feature_flags::{FeatureFlag, FeatureFlagViewExt}; -use gpui::{App, Entity, EventEmitter, FocusHandle, Focusable, actions}; +use fs::Fs; +use gpui::{ + Action, App, AsyncWindowContext, Entity, EventEmitter, FocusHandle, Focusable, Task, actions, +}; +use schemars::JsonSchema; +use serde::Deserialize; +use settings::{SettingsStore, VsCodeSettingsSource}; use ui::prelude::*; +use util::truncate_and_remove_front; use workspace::item::{Item, ItemEvent}; use workspace::{Workspace, with_active_or_new_workspace}; @@ -21,6 +29,23 @@ impl FeatureFlag for SettingsUiFeatureFlag { const NAME: &'static str = "settings-ui"; } +/// Imports settings from Visual Studio Code. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportVsCodeSettings { + #[serde(default)] + pub skip_prompt: bool, +} + +/// Imports settings from Cursor editor. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportCursorSettings { + #[serde(default)] + pub skip_prompt: bool, +} actions!( zed, [ @@ -47,11 +72,45 @@ pub fn init(cx: &mut App) { }); }); - cx.observe_new(|_workspace: &mut Workspace, window, cx| { + cx.observe_new(|workspace: &mut Workspace, window, cx| { let Some(window) = window else { return; }; + workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { + let fs = <dyn Fs>::global(cx); + let action = *action; + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + VsCodeSettingsSource::VsCode, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + + workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { + let fs = <dyn Fs>::global(cx); + let action = *action; + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + VsCodeSettingsSource::Cursor, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + let settings_ui_actions = [TypeId::of::<OpenSettingsEditor>()]; CommandPaletteFilter::update_global(cx, |filter, _cx| { @@ -79,6 +138,57 @@ pub fn init(cx: &mut App) { keybindings::init(cx); } +async fn handle_import_vscode_settings( + source: VsCodeSettingsSource, + skip_prompt: bool, + fs: Arc<dyn Fs>, + cx: &mut AsyncWindowContext, +) { + let vscode_settings = + match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { + Ok(vscode_settings) => vscode_settings, + Err(err) => { + log::error!("{err}"); + let _ = cx.prompt( + gpui::PromptLevel::Info, + &format!("Could not find or load a {source} settings file"), + None, + &["Ok"], + ); + return; + } + }; + + let prompt = if skip_prompt { + Task::ready(Some(0)) + } else { + let prompt = cx.prompt( + gpui::PromptLevel::Warning, + &format!( + "Importing {} settings may overwrite your existing settings. \ + Will import settings from {}", + vscode_settings.source, + truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), + ), + None, + &["Ok", "Cancel"], + ); + cx.spawn(async move |_| prompt.await.ok()) + }; + if prompt.await != Some(0) { + return; + } + + cx.update(|_, cx| { + let source = vscode_settings.source; + let path = vscode_settings.path.clone(); + cx.global::<SettingsStore>() + .import_vscode_settings(fs, vscode_settings); + log::info!("Imported {source} settings from {}", path.display()); + }) + .ok(); +} + pub struct SettingsPage { focus_handle: FocusHandle, } diff --git a/crates/settings_ui/src/ui_components/keystroke_input.rs b/crates/settings_ui/src/ui_components/keystroke_input.rs deleted file mode 100644 index 03d27d0ab9..0000000000 --- a/crates/settings_ui/src/ui_components/keystroke_input.rs +++ /dev/null @@ -1,1388 +0,0 @@ -use gpui::{ - Animation, AnimationExt, Context, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, - Keystroke, Modifiers, ModifiersChangedEvent, Subscription, Task, actions, -}; -use ui::{ - ActiveTheme as _, Color, IconButton, IconButtonShape, IconName, IconSize, Label, LabelSize, - ParentElement as _, Render, Styled as _, Tooltip, Window, prelude::*, -}; - -actions!( - keystroke_input, - [ - /// Starts recording keystrokes - StartRecording, - /// Stops recording keystrokes - StopRecording, - /// Clears the recorded keystrokes - ClearKeystrokes, - ] -); - -const KEY_CONTEXT_VALUE: &'static str = "KeystrokeInput"; - -const CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT: std::time::Duration = - std::time::Duration::from_millis(300); - -enum CloseKeystrokeResult { - Partial, - Close, - None, -} - -impl PartialEq for CloseKeystrokeResult { - fn eq(&self, other: &Self) -> bool { - matches!( - (self, other), - (CloseKeystrokeResult::Partial, CloseKeystrokeResult::Partial) - | (CloseKeystrokeResult::Close, CloseKeystrokeResult::Close) - | (CloseKeystrokeResult::None, CloseKeystrokeResult::None) - ) - } -} - -pub struct KeystrokeInput { - keystrokes: Vec<Keystroke>, - placeholder_keystrokes: Option<Vec<Keystroke>>, - outer_focus_handle: FocusHandle, - inner_focus_handle: FocusHandle, - intercept_subscription: Option<Subscription>, - _focus_subscriptions: [Subscription; 2], - search: bool, - /// The sequence of close keystrokes being typed - close_keystrokes: Option<Vec<Keystroke>>, - close_keystrokes_start: Option<usize>, - previous_modifiers: Modifiers, - /// In order to support inputting keystrokes that end with a prefix of the - /// close keybind keystrokes, we clear the close keystroke capture info - /// on a timeout after a close keystroke is pressed - /// - /// e.g. if close binding is `esc esc esc` and user wants to search for - /// `ctrl-g esc`, after entering the `ctrl-g esc`, hitting `esc` twice would - /// stop recording because of the sequence of three escapes making it - /// impossible to search for anything ending in `esc` - clear_close_keystrokes_timer: Option<Task<()>>, - #[cfg(test)] - recording: bool, -} - -impl KeystrokeInput { - const KEYSTROKE_COUNT_MAX: usize = 3; - - pub fn new( - placeholder_keystrokes: Option<Vec<Keystroke>>, - window: &mut Window, - cx: &mut Context<Self>, - ) -> Self { - let outer_focus_handle = cx.focus_handle(); - let inner_focus_handle = cx.focus_handle(); - let _focus_subscriptions = [ - cx.on_focus_in(&inner_focus_handle, window, Self::on_inner_focus_in), - cx.on_focus_out(&inner_focus_handle, window, Self::on_inner_focus_out), - ]; - Self { - keystrokes: Vec::new(), - placeholder_keystrokes, - inner_focus_handle, - outer_focus_handle, - intercept_subscription: None, - _focus_subscriptions, - search: false, - close_keystrokes: None, - close_keystrokes_start: None, - previous_modifiers: Modifiers::default(), - clear_close_keystrokes_timer: None, - #[cfg(test)] - recording: false, - } - } - - pub fn set_keystrokes(&mut self, keystrokes: Vec<Keystroke>, cx: &mut Context<Self>) { - self.keystrokes = keystrokes; - self.keystrokes_changed(cx); - } - - pub fn set_search(&mut self, search: bool) { - self.search = search; - } - - pub fn keystrokes(&self) -> &[Keystroke] { - if let Some(placeholders) = self.placeholder_keystrokes.as_ref() - && self.keystrokes.is_empty() - { - return placeholders; - } - if !self.search - && self - .keystrokes - .last() - .map_or(false, |last| last.key.is_empty()) - { - return &self.keystrokes[..self.keystrokes.len() - 1]; - } - return &self.keystrokes; - } - - fn dummy(modifiers: Modifiers) -> Keystroke { - return Keystroke { - modifiers, - key: "".to_string(), - key_char: None, - }; - } - - fn keystrokes_changed(&self, cx: &mut Context<Self>) { - cx.emit(()); - cx.notify(); - } - - fn key_context() -> KeyContext { - let mut key_context = KeyContext::default(); - key_context.add(KEY_CONTEXT_VALUE); - key_context - } - - fn determine_stop_recording_binding(window: &mut Window) -> Option<gpui::KeyBinding> { - if cfg!(test) { - Some(gpui::KeyBinding::new( - "escape escape escape", - StopRecording, - Some(KEY_CONTEXT_VALUE), - )) - } else { - window.highest_precedence_binding_for_action_in_context( - &StopRecording, - Self::key_context(), - ) - } - } - - fn upsert_close_keystrokes_start(&mut self, start: usize, cx: &mut Context<Self>) { - if self.close_keystrokes_start.is_some() { - return; - } - self.close_keystrokes_start = Some(start); - self.update_clear_close_keystrokes_timer(cx); - } - - fn update_clear_close_keystrokes_timer(&mut self, cx: &mut Context<Self>) { - self.clear_close_keystrokes_timer = Some(cx.spawn(async |this, cx| { - cx.background_executor() - .timer(CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT) - .await; - this.update(cx, |this, _cx| { - this.end_close_keystrokes_capture(); - }) - .ok(); - })); - } - - /// Interrupt the capture of close keystrokes, but do not clear the close keystrokes - /// from the input - fn end_close_keystrokes_capture(&mut self) -> Option<usize> { - self.close_keystrokes.take(); - self.clear_close_keystrokes_timer.take(); - return self.close_keystrokes_start.take(); - } - - fn handle_possible_close_keystroke( - &mut self, - keystroke: &Keystroke, - window: &mut Window, - cx: &mut Context<Self>, - ) -> CloseKeystrokeResult { - let Some(keybind_for_close_action) = Self::determine_stop_recording_binding(window) else { - log::trace!("No keybinding to stop recording keystrokes in keystroke input"); - self.end_close_keystrokes_capture(); - return CloseKeystrokeResult::None; - }; - let action_keystrokes = keybind_for_close_action.keystrokes(); - - if let Some(mut close_keystrokes) = self.close_keystrokes.take() { - let mut index = 0; - - while index < action_keystrokes.len() && index < close_keystrokes.len() { - if !close_keystrokes[index].should_match(&action_keystrokes[index]) { - break; - } - index += 1; - } - if index == close_keystrokes.len() { - if index >= action_keystrokes.len() { - self.end_close_keystrokes_capture(); - return CloseKeystrokeResult::None; - } - if keystroke.should_match(&action_keystrokes[index]) { - close_keystrokes.push(keystroke.clone()); - if close_keystrokes.len() == action_keystrokes.len() { - return CloseKeystrokeResult::Close; - } else { - self.close_keystrokes = Some(close_keystrokes); - self.update_clear_close_keystrokes_timer(cx); - return CloseKeystrokeResult::Partial; - } - } else { - self.end_close_keystrokes_capture(); - return CloseKeystrokeResult::None; - } - } - } else if let Some(first_action_keystroke) = action_keystrokes.first() - && keystroke.should_match(first_action_keystroke) - { - self.close_keystrokes = Some(vec![keystroke.clone()]); - return CloseKeystrokeResult::Partial; - } - self.end_close_keystrokes_capture(); - return CloseKeystrokeResult::None; - } - - fn on_modifiers_changed( - &mut self, - event: &ModifiersChangedEvent, - window: &mut Window, - cx: &mut Context<Self>, - ) { - cx.stop_propagation(); - let keystrokes_len = self.keystrokes.len(); - - if self.previous_modifiers.modified() - && event.modifiers.is_subset_of(&self.previous_modifiers) - { - self.previous_modifiers &= event.modifiers; - return; - } - self.keystrokes_changed(cx); - - if let Some(last) = self.keystrokes.last_mut() - && last.key.is_empty() - && keystrokes_len <= Self::KEYSTROKE_COUNT_MAX - { - if !self.search && !event.modifiers.modified() { - self.keystrokes.pop(); - return; - } - if self.search { - if self.previous_modifiers.modified() { - last.modifiers |= event.modifiers; - } else { - self.keystrokes.push(Self::dummy(event.modifiers)); - } - self.previous_modifiers |= event.modifiers; - } else { - last.modifiers = event.modifiers; - return; - } - } else if keystrokes_len < Self::KEYSTROKE_COUNT_MAX { - self.keystrokes.push(Self::dummy(event.modifiers)); - if self.search { - self.previous_modifiers |= event.modifiers; - } - } - if keystrokes_len >= Self::KEYSTROKE_COUNT_MAX { - self.clear_keystrokes(&ClearKeystrokes, window, cx); - } - } - - fn handle_keystroke( - &mut self, - keystroke: &Keystroke, - window: &mut Window, - cx: &mut Context<Self>, - ) { - cx.stop_propagation(); - - let close_keystroke_result = self.handle_possible_close_keystroke(keystroke, window, cx); - if close_keystroke_result == CloseKeystrokeResult::Close { - self.stop_recording(&StopRecording, window, cx); - return; - } - - let mut keystroke = keystroke.clone(); - if let Some(last) = self.keystrokes.last() - && last.key.is_empty() - && (!self.search || self.previous_modifiers.modified()) - { - let key = keystroke.key.clone(); - keystroke = last.clone(); - keystroke.key = key; - self.keystrokes.pop(); - } - - if close_keystroke_result == CloseKeystrokeResult::Partial { - self.upsert_close_keystrokes_start(self.keystrokes.len(), cx); - if self.keystrokes.len() >= Self::KEYSTROKE_COUNT_MAX { - return; - } - } - - if self.keystrokes.len() >= Self::KEYSTROKE_COUNT_MAX { - self.clear_keystrokes(&ClearKeystrokes, window, cx); - return; - } - - self.keystrokes.push(keystroke.clone()); - self.keystrokes_changed(cx); - - if self.search { - self.previous_modifiers = keystroke.modifiers; - return; - } - if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX && keystroke.modifiers.modified() { - self.keystrokes.push(Self::dummy(keystroke.modifiers)); - } - } - - fn on_inner_focus_in(&mut self, _window: &mut Window, cx: &mut Context<Self>) { - if self.intercept_subscription.is_none() { - let listener = cx.listener(|this, event: &gpui::KeystrokeEvent, window, cx| { - this.handle_keystroke(&event.keystroke, window, cx); - }); - self.intercept_subscription = Some(cx.intercept_keystrokes(listener)) - } - } - - fn on_inner_focus_out( - &mut self, - _event: gpui::FocusOutEvent, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - self.intercept_subscription.take(); - cx.notify(); - } - - fn render_keystrokes(&self, is_recording: bool) -> impl Iterator<Item = Div> { - let keystrokes = if let Some(placeholders) = self.placeholder_keystrokes.as_ref() - && self.keystrokes.is_empty() - { - if is_recording { - &[] - } else { - placeholders.as_slice() - } - } else { - &self.keystrokes - }; - keystrokes.iter().map(move |keystroke| { - h_flex().children(ui::render_keystroke( - keystroke, - Some(Color::Default), - Some(rems(0.875).into()), - ui::PlatformStyle::platform(), - false, - )) - }) - } - - pub fn start_recording( - &mut self, - _: &StartRecording, - window: &mut Window, - cx: &mut Context<Self>, - ) { - window.focus(&self.inner_focus_handle); - self.clear_keystrokes(&ClearKeystrokes, window, cx); - self.previous_modifiers = window.modifiers(); - #[cfg(test)] - { - self.recording = true; - } - cx.stop_propagation(); - } - - pub fn stop_recording( - &mut self, - _: &StopRecording, - window: &mut Window, - cx: &mut Context<Self>, - ) { - if !self.is_recording(window) { - return; - } - window.focus(&self.outer_focus_handle); - if let Some(close_keystrokes_start) = self.close_keystrokes_start.take() - && close_keystrokes_start < self.keystrokes.len() - { - self.keystrokes.drain(close_keystrokes_start..); - self.keystrokes_changed(cx); - } - self.end_close_keystrokes_capture(); - #[cfg(test)] - { - self.recording = false; - } - cx.notify(); - } - - pub fn clear_keystrokes( - &mut self, - _: &ClearKeystrokes, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - self.keystrokes.clear(); - self.keystrokes_changed(cx); - self.end_close_keystrokes_capture(); - } - - fn is_recording(&self, window: &Window) -> bool { - #[cfg(test)] - { - if true { - // in tests, we just need a simple bool that is toggled on start and stop recording - return self.recording; - } - } - // however, in the real world, checking if the inner focus handle is focused - // is a much more reliable check, as the intercept keystroke handlers are installed - // on focus of the inner focus handle, thereby ensuring our recording state does - // not get de-synced - return self.inner_focus_handle.is_focused(window); - } -} - -impl EventEmitter<()> for KeystrokeInput {} - -impl Focusable for KeystrokeInput { - fn focus_handle(&self, _cx: &gpui::App) -> FocusHandle { - self.outer_focus_handle.clone() - } -} - -impl Render for KeystrokeInput { - fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - let colors = cx.theme().colors(); - let is_focused = self.outer_focus_handle.contains_focused(window, cx); - let is_recording = self.is_recording(window); - - let horizontal_padding = rems_from_px(64.); - - let recording_bg_color = colors - .editor_background - .blend(colors.text_accent.opacity(0.1)); - - let recording_pulse = |color: Color| { - Icon::new(IconName::Circle) - .size(IconSize::Small) - .color(Color::Error) - .with_animation( - "recording-pulse", - Animation::new(std::time::Duration::from_secs(2)) - .repeat() - .with_easing(gpui::pulsating_between(0.4, 0.8)), - { - let color = color.color(cx); - move |this, delta| this.color(Color::Custom(color.opacity(delta))) - }, - ) - }; - - let recording_indicator = h_flex() - .h_4() - .pr_1() - .gap_0p5() - .border_1() - .border_color(colors.border) - .bg(colors - .editor_background - .blend(colors.text_accent.opacity(0.1))) - .rounded_sm() - .child(recording_pulse(Color::Error)) - .child( - Label::new("REC") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Error), - ); - - let search_indicator = h_flex() - .h_4() - .pr_1() - .gap_0p5() - .border_1() - .border_color(colors.border) - .bg(colors - .editor_background - .blend(colors.text_accent.opacity(0.1))) - .rounded_sm() - .child(recording_pulse(Color::Accent)) - .child( - Label::new("SEARCH") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Accent), - ); - - let record_icon = if self.search { - IconName::MagnifyingGlass - } else { - IconName::PlayFilled - }; - - h_flex() - .id("keystroke-input") - .track_focus(&self.outer_focus_handle) - .py_2() - .px_3() - .gap_2() - .min_h_10() - .w_full() - .flex_1() - .justify_between() - .rounded_lg() - .overflow_hidden() - .map(|this| { - if is_recording { - this.bg(recording_bg_color) - } else { - this.bg(colors.editor_background) - } - }) - .border_1() - .border_color(colors.border_variant) - .when(is_focused, |parent| { - parent.border_color(colors.border_focused) - }) - .key_context(Self::key_context()) - .on_action(cx.listener(Self::start_recording)) - .on_action(cx.listener(Self::clear_keystrokes)) - .child( - h_flex() - .w(horizontal_padding) - .gap_0p5() - .justify_start() - .flex_none() - .when(is_recording, |this| { - this.map(|this| { - if self.search { - this.child(search_indicator) - } else { - this.child(recording_indicator) - } - }) - }), - ) - .child( - h_flex() - .id("keystroke-input-inner") - .track_focus(&self.inner_focus_handle) - .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) - .size_full() - .when(!self.search, |this| { - this.focus(|mut style| { - style.border_color = Some(colors.border_focused); - style - }) - }) - .w_full() - .min_w_0() - .justify_center() - .flex_wrap() - .gap(ui::DynamicSpacing::Base04.rems(cx)) - .children(self.render_keystrokes(is_recording)), - ) - .child( - h_flex() - .w(horizontal_padding) - .gap_0p5() - .justify_end() - .flex_none() - .map(|this| { - if is_recording { - this.child( - IconButton::new("stop-record-btn", IconName::StopFilled) - .shape(IconButtonShape::Square) - .map(|this| { - this.tooltip(Tooltip::for_action_title( - if self.search { - "Stop Searching" - } else { - "Stop Recording" - }, - &StopRecording, - )) - }) - .icon_color(Color::Error) - .on_click(cx.listener(|this, _event, window, cx| { - this.stop_recording(&StopRecording, window, cx); - })), - ) - } else { - this.child( - IconButton::new("record-btn", record_icon) - .shape(IconButtonShape::Square) - .map(|this| { - this.tooltip(Tooltip::for_action_title( - if self.search { - "Start Searching" - } else { - "Start Recording" - }, - &StartRecording, - )) - }) - .when(!is_focused, |this| this.icon_color(Color::Muted)) - .on_click(cx.listener(|this, _event, window, cx| { - this.start_recording(&StartRecording, window, cx); - })), - ) - } - }) - .child( - IconButton::new("clear-btn", IconName::Delete) - .shape(IconButtonShape::Square) - .tooltip(Tooltip::for_action_title( - "Clear Keystrokes", - &ClearKeystrokes, - )) - .when(!is_recording || !is_focused, |this| { - this.icon_color(Color::Muted) - }) - .on_click(cx.listener(|this, _event, window, cx| { - this.clear_keystrokes(&ClearKeystrokes, window, cx); - })), - ), - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use fs::FakeFs; - use gpui::{Entity, TestAppContext, VisualTestContext}; - use itertools::Itertools as _; - use project::Project; - use settings::SettingsStore; - use workspace::Workspace; - - pub struct KeystrokeInputTestHelper { - input: Entity<KeystrokeInput>, - current_modifiers: Modifiers, - cx: VisualTestContext, - } - - impl KeystrokeInputTestHelper { - /// Creates a new test helper with default settings - pub fn new(mut cx: VisualTestContext) -> Self { - let input = cx.new_window_entity(|window, cx| KeystrokeInput::new(None, window, cx)); - - let mut helper = Self { - input, - current_modifiers: Modifiers::default(), - cx, - }; - - helper.start_recording(); - helper - } - - /// Sets search mode on the input - pub fn with_search_mode(&mut self, search: bool) -> &mut Self { - self.input.update(&mut self.cx, |input, _| { - input.set_search(search); - }); - self - } - - /// Sends a keystroke event based on string description - /// Examples: "a", "ctrl-a", "cmd-shift-z", "escape" - #[track_caller] - pub fn send_keystroke(&mut self, keystroke_input: &str) -> &mut Self { - self.expect_is_recording(true); - let keystroke_str = if keystroke_input.ends_with('-') { - format!("{}_", keystroke_input) - } else { - keystroke_input.to_string() - }; - - let mut keystroke = Keystroke::parse(&keystroke_str) - .unwrap_or_else(|_| panic!("Invalid keystroke: {}", keystroke_input)); - - // Remove the dummy key if we added it for modifier-only keystrokes - if keystroke_input.ends_with('-') && keystroke_str.ends_with("_") { - keystroke.key = "".to_string(); - } - - // Combine current modifiers with keystroke modifiers - keystroke.modifiers |= self.current_modifiers; - - self.update_input(|input, window, cx| { - input.handle_keystroke(&keystroke, window, cx); - }); - - // Don't update current_modifiers for keystrokes with actual keys - if keystroke.key.is_empty() { - self.current_modifiers = keystroke.modifiers; - } - self - } - - /// Sends a modifier change event based on string description - /// Examples: "+ctrl", "-ctrl", "+cmd+shift", "-all" - #[track_caller] - pub fn send_modifiers(&mut self, modifiers: &str) -> &mut Self { - self.expect_is_recording(true); - let new_modifiers = if modifiers == "-all" { - Modifiers::default() - } else { - self.parse_modifier_change(modifiers) - }; - - let event = ModifiersChangedEvent { - modifiers: new_modifiers, - capslock: gpui::Capslock::default(), - }; - - self.update_input(|input, window, cx| { - input.on_modifiers_changed(&event, window, cx); - }); - - self.current_modifiers = new_modifiers; - self - } - - /// Sends multiple events in sequence - /// Each event string is either a keystroke or modifier change - #[track_caller] - pub fn send_events(&mut self, events: &[&str]) -> &mut Self { - self.expect_is_recording(true); - for event in events { - if event.starts_with('+') || event.starts_with('-') { - self.send_modifiers(event); - } else { - self.send_keystroke(event); - } - } - self - } - - #[track_caller] - fn expect_keystrokes_equal(actual: &[Keystroke], expected: &[&str]) { - let expected_keystrokes: Result<Vec<Keystroke>, _> = expected - .iter() - .map(|s| { - let keystroke_str = if s.ends_with('-') { - format!("{}_", s) - } else { - s.to_string() - }; - - let mut keystroke = Keystroke::parse(&keystroke_str)?; - - // Remove the dummy key if we added it for modifier-only keystrokes - if s.ends_with('-') && keystroke_str.ends_with("_") { - keystroke.key = "".to_string(); - } - - Ok(keystroke) - }) - .collect(); - - let expected_keystrokes = expected_keystrokes - .unwrap_or_else(|e: anyhow::Error| panic!("Invalid expected keystroke: {}", e)); - - assert_eq!( - actual.len(), - expected_keystrokes.len(), - "Keystroke count mismatch. Expected: {:?}, Actual: {:?}", - expected_keystrokes - .iter() - .map(|k| k.unparse()) - .collect::<Vec<_>>(), - actual.iter().map(|k| k.unparse()).collect::<Vec<_>>() - ); - - for (i, (actual, expected)) in actual.iter().zip(expected_keystrokes.iter()).enumerate() - { - assert_eq!( - actual.unparse(), - expected.unparse(), - "Keystroke {} mismatch. Expected: '{}', Actual: '{}'", - i, - expected.unparse(), - actual.unparse() - ); - } - } - - /// Verifies that the keystrokes match the expected strings - #[track_caller] - pub fn expect_keystrokes(&mut self, expected: &[&str]) -> &mut Self { - let actual = self - .input - .read_with(&mut self.cx, |input, _| input.keystrokes.clone()); - Self::expect_keystrokes_equal(&actual, expected); - self - } - - #[track_caller] - pub fn expect_close_keystrokes(&mut self, expected: &[&str]) -> &mut Self { - let actual = self - .input - .read_with(&mut self.cx, |input, _| input.close_keystrokes.clone()) - .unwrap_or_default(); - Self::expect_keystrokes_equal(&actual, expected); - self - } - - /// Verifies that there are no keystrokes - #[track_caller] - pub fn expect_empty(&mut self) -> &mut Self { - self.expect_keystrokes(&[]) - } - - /// Starts recording keystrokes - #[track_caller] - pub fn start_recording(&mut self) -> &mut Self { - self.expect_is_recording(false); - self.input.update_in(&mut self.cx, |input, window, cx| { - input.start_recording(&StartRecording, window, cx); - }); - self - } - - /// Stops recording keystrokes - pub fn stop_recording(&mut self) -> &mut Self { - self.expect_is_recording(true); - self.input.update_in(&mut self.cx, |input, window, cx| { - input.stop_recording(&StopRecording, window, cx); - }); - self - } - - /// Clears all keystrokes - #[track_caller] - pub fn clear_keystrokes(&mut self) -> &mut Self { - let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); - self.input.update_in(&mut self.cx, |input, window, cx| { - input.clear_keystrokes(&ClearKeystrokes, window, cx); - }); - KeystrokeUpdateTracker::finish(change_tracker, &self.cx); - self.current_modifiers = Default::default(); - self - } - - /// Verifies the recording state - #[track_caller] - pub fn expect_is_recording(&mut self, expected: bool) -> &mut Self { - let actual = self - .input - .update_in(&mut self.cx, |input, window, _| input.is_recording(window)); - assert_eq!( - actual, expected, - "Recording state mismatch. Expected: {}, Actual: {}", - expected, actual - ); - self - } - - pub async fn wait_for_close_keystroke_capture_end(&mut self) -> &mut Self { - let task = self.input.update_in(&mut self.cx, |input, _, _| { - input.clear_close_keystrokes_timer.take() - }); - let task = task.expect("No close keystroke capture end timer task"); - self.cx - .executor() - .advance_clock(CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT); - task.await; - self - } - - /// Parses modifier change strings like "+ctrl", "-shift", "+cmd+alt" - #[track_caller] - fn parse_modifier_change(&self, modifiers_str: &str) -> Modifiers { - let mut modifiers = self.current_modifiers; - - assert!(!modifiers_str.is_empty(), "Empty modifier string"); - - let value; - let split_char; - let remaining; - if let Some(to_add) = modifiers_str.strip_prefix('+') { - value = true; - split_char = '+'; - remaining = to_add; - } else { - let to_remove = modifiers_str - .strip_prefix('-') - .expect("Modifier string must start with '+' or '-'"); - value = false; - split_char = '-'; - remaining = to_remove; - } - - for modifier in remaining.split(split_char) { - match modifier { - "ctrl" | "control" => modifiers.control = value, - "alt" | "option" => modifiers.alt = value, - "shift" => modifiers.shift = value, - "cmd" | "command" | "platform" => modifiers.platform = value, - "fn" | "function" => modifiers.function = value, - _ => panic!("Unknown modifier: {}", modifier), - } - } - - modifiers - } - - #[track_caller] - fn update_input<R>( - &mut self, - cb: impl FnOnce(&mut KeystrokeInput, &mut Window, &mut Context<KeystrokeInput>) -> R, - ) -> R { - let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); - let result = self.input.update_in(&mut self.cx, cb); - KeystrokeUpdateTracker::finish(change_tracker, &self.cx); - return result; - } - } - - struct KeystrokeUpdateTracker { - initial_keystrokes: Vec<Keystroke>, - _subscription: Subscription, - input: Entity<KeystrokeInput>, - received_keystrokes_updated: bool, - } - - impl KeystrokeUpdateTracker { - fn new(input: Entity<KeystrokeInput>, cx: &mut VisualTestContext) -> Entity<Self> { - cx.new(|cx| Self { - initial_keystrokes: input.read_with(cx, |input, _| input.keystrokes.clone()), - _subscription: cx.subscribe(&input, |this: &mut Self, _, _, _| { - this.received_keystrokes_updated = true; - }), - input, - received_keystrokes_updated: false, - }) - } - #[track_caller] - fn finish(this: Entity<Self>, cx: &VisualTestContext) { - let (received_keystrokes_updated, initial_keystrokes_str, updated_keystrokes_str) = - this.read_with(cx, |this, cx| { - let updated_keystrokes = this - .input - .read_with(cx, |input, _| input.keystrokes.clone()); - let initial_keystrokes_str = keystrokes_str(&this.initial_keystrokes); - let updated_keystrokes_str = keystrokes_str(&updated_keystrokes); - ( - this.received_keystrokes_updated, - initial_keystrokes_str, - updated_keystrokes_str, - ) - }); - if received_keystrokes_updated { - assert_ne!( - initial_keystrokes_str, updated_keystrokes_str, - "Received keystrokes_updated event, expected different keystrokes" - ); - } else { - assert_eq!( - initial_keystrokes_str, updated_keystrokes_str, - "Received no keystrokes_updated event, expected same keystrokes" - ); - } - - fn keystrokes_str(ks: &[Keystroke]) -> String { - ks.iter().map(|ks| ks.unparse()).join(" ") - } - } - } - - async fn init_test(cx: &mut TestAppContext) -> KeystrokeInputTestHelper { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - theme::init(theme::LoadThemes::JustBase, cx); - language::init(cx); - project::Project::init_settings(cx); - workspace::init_settings(cx); - }); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let workspace = - cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let cx = VisualTestContext::from_window(*workspace, cx); - KeystrokeInputTestHelper::new(cx) - } - - #[gpui::test] - async fn test_basic_keystroke_input(cx: &mut TestAppContext) { - init_test(cx) - .await - .send_keystroke("a") - .clear_keystrokes() - .expect_empty(); - } - - #[gpui::test] - async fn test_modifier_handling(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "a", "-ctrl"]) - .expect_keystrokes(&["ctrl-a"]); - } - - #[gpui::test] - async fn test_multiple_modifiers(cx: &mut TestAppContext) { - init_test(cx) - .await - .send_keystroke("cmd-shift-z") - .expect_keystrokes(&["cmd-shift-z", "cmd-shift-"]); - } - - #[gpui::test] - async fn test_search_mode_behavior(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+cmd", "shift-f", "-cmd"]) - // In search mode, when completing a modifier-only keystroke with a key, - // only the original modifiers are preserved, not the keystroke's modifiers - .expect_keystrokes(&["cmd-f"]); - } - - #[gpui::test] - async fn test_keystroke_limit(cx: &mut TestAppContext) { - init_test(cx) - .await - .send_keystroke("a") - .send_keystroke("b") - .send_keystroke("c") - .expect_keystrokes(&["a", "b", "c"]) // At max limit - .send_keystroke("d") - .expect_empty(); // Should clear when exceeding limit - } - - #[gpui::test] - async fn test_modifier_release_all(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl+shift", "a", "-all"]) - .expect_keystrokes(&["ctrl-shift-a"]); - } - - #[gpui::test] - async fn test_search_new_modifiers_not_added_until_all_released(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl+shift", "a", "-ctrl"]) - .expect_keystrokes(&["ctrl-shift-a"]) - .send_events(&["+ctrl"]) - .expect_keystrokes(&["ctrl-shift-a", "ctrl-shift-"]); - } - - #[gpui::test] - async fn test_previous_modifiers_no_effect_when_not_search(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(false) - .send_events(&["+ctrl+shift", "a", "-all"]) - .expect_keystrokes(&["ctrl-shift-a"]); - } - - #[gpui::test] - async fn test_keystroke_limit_overflow_non_search_mode(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(false) - .send_events(&["a", "b", "c", "d"]) // 4 keystrokes, exceeds limit of 3 - .expect_empty(); // Should clear when exceeding limit - } - - #[gpui::test] - async fn test_complex_modifier_sequences(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "+shift", "+alt", "a", "-ctrl", "-shift", "-alt"]) - .expect_keystrokes(&["ctrl-shift-alt-a"]); - } - - #[gpui::test] - async fn test_modifier_only_keystrokes_search_mode(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "+shift", "-ctrl", "-shift"]) - .expect_keystrokes(&["ctrl-shift-"]); // Modifier-only sequences create modifier-only keystrokes - } - - #[gpui::test] - async fn test_modifier_only_keystrokes_non_search_mode(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(false) - .send_events(&["+ctrl", "+shift", "-ctrl", "-shift"]) - .expect_empty(); // Modifier-only sequences get filtered in non-search mode - } - - #[gpui::test] - async fn test_rapid_modifier_changes(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "-ctrl", "+shift", "-shift", "+alt", "a", "-alt"]) - .expect_keystrokes(&["ctrl-", "shift-", "alt-a"]); - } - - #[gpui::test] - async fn test_clear_keystrokes_search_mode(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "a", "-ctrl", "b"]) - .expect_keystrokes(&["ctrl-a", "b"]) - .clear_keystrokes() - .expect_empty(); - } - - #[gpui::test] - async fn test_non_search_mode_modifier_key_sequence(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(false) - .send_events(&["+ctrl", "a"]) - .expect_keystrokes(&["ctrl-a", "ctrl-"]) - .send_events(&["-ctrl"]) - .expect_keystrokes(&["ctrl-a"]); // Non-search mode filters trailing empty keystrokes - } - - #[gpui::test] - async fn test_all_modifiers_at_once(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl+shift+alt+cmd", "a", "-all"]) - .expect_keystrokes(&["ctrl-shift-alt-cmd-a"]); - } - - #[gpui::test] - async fn test_keystrokes_at_exact_limit(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["a", "b", "c"]) // exactly 3 keystrokes (at limit) - .expect_keystrokes(&["a", "b", "c"]) - .send_events(&["d"]) // should clear when exceeding - .expect_empty(); - } - - #[gpui::test] - async fn test_function_modifier_key(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+fn", "f1", "-fn"]) - .expect_keystrokes(&["fn-f1"]); - } - - #[gpui::test] - async fn test_start_stop_recording(cx: &mut TestAppContext) { - init_test(cx) - .await - .send_events(&["a", "b"]) - .expect_keystrokes(&["a", "b"]) // start_recording clears existing keystrokes - .stop_recording() - .expect_is_recording(false) - .start_recording() - .send_events(&["c"]) - .expect_keystrokes(&["c"]); - } - - #[gpui::test] - async fn test_modifier_sequence_with_interruption(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "+shift", "a", "-shift", "b", "-ctrl"]) - .expect_keystrokes(&["ctrl-shift-a", "ctrl-b"]); - } - - #[gpui::test] - async fn test_empty_key_sequence_search_mode(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&[]) // No events at all - .expect_empty(); - } - - #[gpui::test] - async fn test_modifier_sequence_completion_search_mode(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "+shift", "-shift", "a", "-ctrl"]) - .expect_keystrokes(&["ctrl-shift-a"]); - } - - #[gpui::test] - async fn test_triple_escape_stops_recording_search_mode(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["a", "escape", "escape", "escape"]) - .expect_keystrokes(&["a"]) // Triple escape removes final escape, stops recording - .expect_is_recording(false); - } - - #[gpui::test] - async fn test_triple_escape_stops_recording_non_search_mode(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(false) - .send_events(&["a", "escape", "escape", "escape"]) - .expect_keystrokes(&["a"]); // Triple escape stops recording but only removes final escape - } - - #[gpui::test] - async fn test_triple_escape_at_keystroke_limit(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["a", "b", "c", "escape", "escape", "escape"]) // 6 keystrokes total, exceeds limit - .expect_keystrokes(&["a", "b", "c"]); // Triple escape stops recording and removes escapes, leaves original keystrokes - } - - #[gpui::test] - async fn test_interrupted_escape_sequence(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["escape", "escape", "a", "escape"]) // Partial escape sequence interrupted by 'a' - .expect_keystrokes(&["escape", "escape", "a"]); // Escape sequence interrupted by 'a', no close triggered - } - - #[gpui::test] - async fn test_interrupted_escape_sequence_within_limit(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["escape", "escape", "a"]) // Partial escape sequence interrupted by 'a' (3 keystrokes, at limit) - .expect_keystrokes(&["escape", "escape", "a"]); // Should not trigger close, interruption resets escape detection - } - - #[gpui::test] - async fn test_partial_escape_sequence_no_close(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["escape", "escape"]) // Only 2 escapes, not enough to close - .expect_keystrokes(&["escape", "escape"]) - .expect_is_recording(true); // Should remain in keystrokes, no close triggered - } - - #[gpui::test] - async fn test_recording_state_after_triple_escape(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["a", "escape", "escape", "escape"]) - .expect_keystrokes(&["a"]) // Triple escape stops recording, removes final escape - .expect_is_recording(false); - } - - #[gpui::test] - async fn test_triple_escape_mixed_with_other_keystrokes(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["a", "escape", "b", "escape", "escape"]) // Mixed sequence, should not trigger close - .expect_keystrokes(&["a", "escape", "b"]); // No complete triple escape sequence, stays at limit - } - - #[gpui::test] - async fn test_triple_escape_only(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["escape", "escape", "escape"]) // Pure triple escape sequence - .expect_empty(); - } - - #[gpui::test] - async fn test_end_close_keystroke_capture(cx: &mut TestAppContext) { - init_test(cx) - .await - .send_events(&["+ctrl", "g", "-ctrl", "escape"]) - .expect_keystrokes(&["ctrl-g", "escape"]) - .wait_for_close_keystroke_capture_end() - .await - .send_events(&["escape", "escape"]) - .expect_keystrokes(&["ctrl-g", "escape", "escape"]) - .expect_close_keystrokes(&["escape", "escape"]) - .send_keystroke("escape") - .expect_keystrokes(&["ctrl-g", "escape"]); - } - - #[gpui::test] - async fn test_search_previous_modifiers_are_sticky(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl+alt", "-ctrl", "j"]) - .expect_keystrokes(&["ctrl-alt-j"]); - } - - #[gpui::test] - async fn test_previous_modifiers_can_be_entered_separately(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "-ctrl"]) - .expect_keystrokes(&["ctrl-"]) - .send_events(&["+alt", "-alt"]) - .expect_keystrokes(&["ctrl-", "alt-"]); - } - - #[gpui::test] - async fn test_previous_modifiers_reset_on_key(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) - .expect_keystrokes(&["ctrl-shift-alt-"]) - .send_keystroke("j") - .expect_keystrokes(&["ctrl-shift-alt-j"]) - .send_keystroke("i") - .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i"]) - .send_events(&["-shift-alt", "+cmd"]) - .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i", "cmd-"]); - } - - #[gpui::test] - async fn test_previous_modifiers_reset_on_release_all(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) - .expect_keystrokes(&["ctrl-shift-alt-"]) - .send_events(&["-all", "j"]) - .expect_keystrokes(&["ctrl-shift-alt-", "j"]); - } - - #[gpui::test] - async fn test_search_repeat_modifiers(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(true) - .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) - .expect_keystrokes(&["ctrl-", "alt-", "shift-"]) - .send_events(&["+cmd"]) - .expect_empty(); - } - - #[gpui::test] - async fn test_not_search_repeat_modifiers(cx: &mut TestAppContext) { - init_test(cx) - .await - .with_search_mode(false) - .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) - .expect_empty(); - } -} diff --git a/crates/settings_ui/src/ui_components/mod.rs b/crates/settings_ui/src/ui_components/mod.rs index 5d6463a61a..13971b0a5d 100644 --- a/crates/settings_ui/src/ui_components/mod.rs +++ b/crates/settings_ui/src/ui_components/mod.rs @@ -1,2 +1 @@ -pub mod keystroke_input; pub mod table; diff --git a/crates/settings_ui/src/ui_components/table.rs b/crates/settings_ui/src/ui_components/table.rs index 3c9992bd68..69207f559b 100644 --- a/crates/settings_ui/src/ui_components/table.rs +++ b/crates/settings_ui/src/ui_components/table.rs @@ -2,9 +2,9 @@ use std::{ops::Range, rc::Rc, time::Duration}; use editor::{EditorSettings, ShowScrollbar, scroll::ScrollbarAutoHide}; use gpui::{ - AbsoluteLength, AppContext, Axis, Context, DefiniteLength, DragMoveEvent, Entity, EntityId, - FocusHandle, Length, ListHorizontalSizingBehavior, ListSizingBehavior, MouseButton, Point, - Stateful, Task, UniformListScrollHandle, WeakEntity, transparent_black, uniform_list, + AbsoluteLength, AppContext, Axis, Context, DefiniteLength, DragMoveEvent, Entity, FocusHandle, + Length, ListHorizontalSizingBehavior, ListSizingBehavior, MouseButton, Point, Stateful, Task, + UniformListScrollHandle, WeakEntity, transparent_black, uniform_list, }; use itertools::intersperse_with; @@ -13,12 +13,10 @@ use ui::{ ActiveTheme as _, AnyElement, App, Button, ButtonCommon as _, ButtonStyle, Color, Component, ComponentScope, Div, ElementId, FixedWidth as _, FluentBuilder as _, Indicator, InteractiveElement, IntoElement, ParentElement, Pixels, RegisterComponent, RenderOnce, - Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, StyledExt as _, + Scrollbar, ScrollbarState, StatefulInteractiveElement, Styled, StyledExt as _, StyledTypography, Window, div, example_group_with_title, h_flex, px, single_example, v_flex, }; -const RESIZE_COLUMN_WIDTH: f32 = 8.0; - #[derive(Debug)] struct DraggedColumn(usize); @@ -214,7 +212,6 @@ impl TableInteractionState { let mut column_ix = 0; let resizable_columns_slice = *resizable_columns; let mut resizable_columns = resizable_columns.into_iter(); - let dividers = intersperse_with(spacers, || { window.with_id(column_ix, |window| { let mut resize_divider = div() @@ -222,15 +219,15 @@ impl TableInteractionState { .id(column_ix) .relative() .top_0() - .w_px() + .w_0p5() .h_full() - .bg(cx.theme().colors().border.opacity(0.8)); + .bg(cx.theme().colors().border.opacity(0.5)); let mut resize_handle = div() .id("column-resize-handle") .absolute() .left_neg_0p5() - .w(px(RESIZE_COLUMN_WIDTH)) + .w(px(5.0)) .h_full(); if resizable_columns @@ -238,11 +235,9 @@ impl TableInteractionState { .is_some_and(ResizeBehavior::is_resizable) { let hovered = window.use_state(cx, |_window, _cx| false); - resize_divider = resize_divider.when(*hovered.read(cx), |div| { div.bg(cx.theme().colors().border_focused) }); - resize_handle = resize_handle .on_hover(move |&was_hovered, _, cx| hovered.write(cx, was_hovered)) .cursor_col_resize() @@ -272,11 +267,12 @@ impl TableInteractionState { }) }); - h_flex() + div() .id("resize-handles") + .h_flex() .absolute() - .inset_0() .w_full() + .inset_0() .children(dividers) .into_any_element() } @@ -482,7 +478,6 @@ impl ResizeBehavior { pub struct ColumnWidths<const COLS: usize> { widths: [DefiniteLength; COLS], - visible_widths: [DefiniteLength; COLS], cached_bounds_width: Pixels, initialized: bool, } @@ -491,7 +486,6 @@ impl<const COLS: usize> ColumnWidths<COLS> { pub fn new(_: &mut App) -> Self { Self { widths: [DefiniteLength::default(); COLS], - visible_widths: [DefiniteLength::default(); COLS], cached_bounds_width: Default::default(), initialized: false, } @@ -518,105 +512,46 @@ impl<const COLS: usize> ColumnWidths<COLS> { let rem_size = window.rem_size(); let initial_sizes = initial_sizes.map(|length| Self::get_fraction(&length, bounds_width, rem_size)); - let widths = self + let mut widths = self .widths .map(|length| Self::get_fraction(&length, bounds_width, rem_size)); - let updated_widths = Self::reset_to_initial_size( - double_click_position, - widths, - initial_sizes, - resize_behavior, - ); - self.widths = updated_widths.map(DefiniteLength::Fraction); - self.visible_widths = self.widths; - } + let diff = initial_sizes[double_click_position] - widths[double_click_position]; - fn reset_to_initial_size( - col_idx: usize, - mut widths: [f32; COLS], - initial_sizes: [f32; COLS], - resize_behavior: &[ResizeBehavior; COLS], - ) -> [f32; COLS] { - // RESET: - // Part 1: - // Figure out if we should shrink/grow the selected column - // Get diff which represents the change in column we want to make initial size delta curr_size = diff - // - // Part 2: We need to decide which side column we should move and where - // - // If we want to grow our column we should check the left/right columns diff to see what side - // has a greater delta than their initial size. Likewise, if we shrink our column we should check - // the left/right column diffs to see what side has the smallest delta. - // - // Part 3: resize - // - // col_idx represents the column handle to the right of an active column - // - // If growing and right has the greater delta { - // shift col_idx to the right - // } else if growing and left has the greater delta { - // shift col_idx - 1 to the left - // } else if shrinking and the right has the greater delta { - // shift - // } { - // - // } - // } - // - // if we need to shrink, then if the right - // + if diff > 0.0 { + let diff_remaining = self.propagate_resize_diff_right( + diff, + double_click_position, + &mut widths, + resize_behavior, + ); - // DRAGGING - // we get diff which represents the change in the _drag handle_ position - // -diff => dragging left -> - // grow the column to the right of the handle as much as we can shrink columns to the left of the handle - // +diff => dragging right -> growing handles column - // grow the column to the left of the handle as much as we can shrink columns to the right of the handle - // - - let diff = initial_sizes[col_idx] - widths[col_idx]; - - let left_diff = - initial_sizes[..col_idx].iter().sum::<f32>() - widths[..col_idx].iter().sum::<f32>(); - let right_diff = initial_sizes[col_idx + 1..].iter().sum::<f32>() - - widths[col_idx + 1..].iter().sum::<f32>(); - - let go_left_first = if diff < 0.0 { - left_diff > right_diff - } else { - left_diff < right_diff - }; - - if !go_left_first { - let diff_remaining = - Self::propagate_resize_diff(diff, col_idx, &mut widths, resize_behavior, 1); - - if diff_remaining != 0.0 && col_idx > 0 { - Self::propagate_resize_diff( - diff_remaining, - col_idx, + if diff_remaining > 0.0 && double_click_position > 0 { + self.propagate_resize_diff_left( + -diff_remaining, + double_click_position - 1, &mut widths, resize_behavior, - -1, ); } - } else { - let diff_remaining = - Self::propagate_resize_diff(diff, col_idx, &mut widths, resize_behavior, -1); + } else if double_click_position > 0 { + let diff_remaining = self.propagate_resize_diff_left( + diff, + double_click_position, + &mut widths, + resize_behavior, + ); - if diff_remaining != 0.0 { - Self::propagate_resize_diff( - diff_remaining, - col_idx, + if diff_remaining < 0.0 { + self.propagate_resize_diff_right( + -diff_remaining, + double_click_position, &mut widths, resize_behavior, - 1, ); } } - - widths + self.widths = widths.map(DefiniteLength::Fraction); } fn on_drag_move( @@ -634,102 +569,98 @@ impl<const COLS: usize> ColumnWidths<COLS> { let bounds_width = bounds.right() - bounds.left(); let col_idx = drag_event.drag(cx).0; - let column_handle_width = Self::get_fraction( - &DefiniteLength::Absolute(AbsoluteLength::Pixels(px(RESIZE_COLUMN_WIDTH))), - bounds_width, - rem_size, - ); - let mut widths = self .widths .map(|length| Self::get_fraction(&length, bounds_width, rem_size)); for length in widths[0..=col_idx].iter() { - col_position += length + column_handle_width; + col_position += length; } let mut total_length_ratio = col_position; for length in widths[col_idx + 1..].iter() { total_length_ratio += length; } - total_length_ratio += (COLS - 1 - col_idx) as f32 * column_handle_width; let drag_fraction = (drag_position.x - bounds.left()) / bounds_width; let drag_fraction = drag_fraction * total_length_ratio; - let diff = drag_fraction - col_position - column_handle_width / 2.0; + let diff = drag_fraction - col_position; - Self::drag_column_handle(diff, col_idx, &mut widths, resize_behavior); + let is_dragging_right = diff > 0.0; - self.visible_widths = widths.map(DefiniteLength::Fraction); - } - - fn drag_column_handle( - diff: f32, - col_idx: usize, - widths: &mut [f32; COLS], - resize_behavior: &[ResizeBehavior; COLS], - ) { - // if diff > 0.0 then go right - if diff > 0.0 { - Self::propagate_resize_diff(diff, col_idx, widths, resize_behavior, 1); + if is_dragging_right { + self.propagate_resize_diff_right(diff, col_idx, &mut widths, resize_behavior); } else { - Self::propagate_resize_diff(-diff, col_idx + 1, widths, resize_behavior, -1); + // Resize behavior should be improved in the future by also seeking to the right column when there's not enough space + self.propagate_resize_diff_left(diff, col_idx, &mut widths, resize_behavior); } + self.widths = widths.map(DefiniteLength::Fraction); } - fn propagate_resize_diff( + fn propagate_resize_diff_right( + &self, diff: f32, col_idx: usize, widths: &mut [f32; COLS], resize_behavior: &[ResizeBehavior; COLS], - direction: i8, ) -> f32 { let mut diff_remaining = diff; - if resize_behavior[col_idx].min_size().is_none() { - return diff; + let mut curr_column = col_idx + 1; + + while diff_remaining > 0.0 && curr_column < COLS { + let Some(min_size) = resize_behavior[curr_column - 1].min_size() else { + curr_column += 1; + continue; + }; + + let mut curr_width = widths[curr_column] - diff_remaining; + + diff_remaining = 0.0; + if min_size > curr_width { + diff_remaining += min_size - curr_width; + curr_width = min_size; + } + widths[curr_column] = curr_width; + curr_column += 1; } - let step_right; - let step_left; - if direction < 0 { - step_right = 0; - step_left = 1; - } else { - step_right = 1; - step_left = 0; - } - if col_idx == 0 && direction < 0 { - return diff; - } - let mut curr_column = col_idx + step_right - step_left; + widths[col_idx] = widths[col_idx] + (diff - diff_remaining); + return diff_remaining; + } - while diff_remaining != 0.0 && curr_column < COLS { + fn propagate_resize_diff_left( + &mut self, + diff: f32, + mut curr_column: usize, + widths: &mut [f32; COLS], + resize_behavior: &[ResizeBehavior; COLS], + ) -> f32 { + let mut diff_remaining = diff; + let col_idx = curr_column; + while diff_remaining < 0.0 { let Some(min_size) = resize_behavior[curr_column].min_size() else { if curr_column == 0 { break; } - curr_column -= step_left; - curr_column += step_right; + curr_column -= 1; continue; }; - let curr_width = widths[curr_column] - diff_remaining; - widths[curr_column] = curr_width; + let mut curr_width = widths[curr_column] + diff_remaining; - if min_size > curr_width { - diff_remaining = min_size - curr_width; - widths[curr_column] = min_size; - } else { - diff_remaining = 0.0; - break; + diff_remaining = 0.0; + if curr_width < min_size { + diff_remaining = curr_width - min_size; + curr_width = min_size } + + widths[curr_column] = curr_width; if curr_column == 0 { break; } - curr_column -= step_left; - curr_column += step_right; + curr_column -= 1; } - widths[col_idx] = widths[col_idx] + (diff - diff_remaining); + widths[col_idx + 1] = widths[col_idx + 1] - (diff - diff_remaining); return diff_remaining; } @@ -755,7 +686,7 @@ impl<const COLS: usize> TableWidths<COLS> { fn lengths(&self, cx: &App) -> [Length; COLS] { self.current .as_ref() - .map(|entity| entity.read(cx).visible_widths.map(Length::Definite)) + .map(|entity| entity.read(cx).widths.map(Length::Definite)) .unwrap_or(self.initial.map(Length::Definite)) } } @@ -868,7 +799,6 @@ impl<const COLS: usize> Table<COLS> { if !widths.initialized { widths.initialized = true; widths.widths = table_widths.initial; - widths.visible_widths = widths.widths; } }) } @@ -898,6 +828,7 @@ fn base_cell_style(width: Option<Length>) -> Div { .px_1p5() .when_some(width, |this, width| this.w(width)) .when(width.is_none(), |this| this.flex_1()) + .justify_start() .whitespace_nowrap() .text_ellipsis() .overflow_hidden() @@ -942,7 +873,7 @@ pub fn render_row<const COLS: usize>( .map(IntoElement::into_any_element) .into_iter() .zip(column_widths) - .map(|(cell, width)| base_cell_style_text(width, cx).px_1().py_0p5().child(cell)), + .map(|(cell, width)| base_cell_style_text(width, cx).px_1p5().py_1().child(cell)), ); let row = if let Some(map_row) = table_context.map_row { @@ -951,30 +882,17 @@ pub fn render_row<const COLS: usize>( row.into_any_element() }; - div().size_full().child(row).into_any_element() + div().h_full().w_full().child(row).into_any_element() } pub fn render_header<const COLS: usize>( headers: [impl IntoElement; COLS], table_context: TableRenderContext<COLS>, - columns_widths: Option<( - WeakEntity<ColumnWidths<COLS>>, - [ResizeBehavior; COLS], - [DefiniteLength; COLS], - )>, - entity_id: Option<EntityId>, cx: &mut App, ) -> impl IntoElement { let column_widths = table_context .column_widths .map_or([None; COLS], |widths| widths.map(Some)); - - let element_id = entity_id - .map(|entity| entity.to_string()) - .unwrap_or_default(); - - let shared_element_id: SharedString = format!("table-{}", element_id).into(); - div() .flex() .flex_row() @@ -984,39 +902,12 @@ pub fn render_header<const COLS: usize>( .p_2() .border_b_1() .border_color(cx.theme().colors().border) - .children(headers.into_iter().enumerate().zip(column_widths).map( - |((header_idx, h), width)| { - base_cell_style_text(width, cx) - .child(h) - .id(ElementId::NamedInteger( - shared_element_id.clone(), - header_idx as u64, - )) - .when_some( - columns_widths.as_ref().cloned(), - |this, (column_widths, resizables, initial_sizes)| { - if resizables[header_idx].is_resizable() { - this.on_click(move |event, window, cx| { - if event.down.click_count > 1 { - column_widths - .update(cx, |column, _| { - column.on_double_click( - header_idx, - &initial_sizes, - &resizables, - window, - ); - }) - .ok(); - } - }) - } else { - this - } - }, - ) - }, - )) + .children( + headers + .into_iter() + .zip(column_widths) + .map(|(h, width)| base_cell_style_text(width, cx).child(h)), + ) } #[derive(Clone)] @@ -1048,12 +939,6 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { .and_then(|widths| Some((widths.current.as_ref()?, widths.resizable))) .map(|(curr, resize_behavior)| (curr.downgrade(), resize_behavior)); - let current_widths_with_initial_sizes = self - .col_widths - .as_ref() - .and_then(|widths| Some((widths.current.as_ref()?, widths.resizable, widths.initial))) - .map(|(curr, resize_behavior, initial)| (curr.downgrade(), resize_behavior, initial)); - let scroll_track_size = px(16.); let h_scroll_offset = if interaction_state .as_ref() @@ -1073,13 +958,7 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { .h_full() .v_flex() .when_some(self.headers.take(), |this, headers| { - this.child(render_header( - headers, - table_context.clone(), - current_widths_with_initial_sizes, - interaction_state.as_ref().map(Entity::entity_id), - cx, - )) + this.child(render_header(headers, table_context.clone(), cx)) }) .when_some(current_widths, { |this, (widths, resize_behavior)| { @@ -1093,28 +972,19 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { .ok(); } }) - .on_children_prepainted({ - let widths = widths.clone(); - move |bounds, _, cx| { - widths - .update(cx, |widths, _| { - // This works because all children x axis bounds are the same - widths.cached_bounds_width = - bounds[0].right() - bounds[0].left(); - }) - .ok(); - } - }) - .on_drop::<DraggedColumn>(move |_, _, cx| { + .on_children_prepainted(move |bounds, _, cx| { widths .update(cx, |widths, _| { - widths.widths = widths.visible_widths; + // This works because all children x axis bounds are the same + widths.cached_bounds_width = bounds[0].right() - bounds[0].left(); }) .ok(); - // Finish the resize operation }) } }) + .on_drop::<DraggedColumn>(|_, _, _| { + // Finish the resize operation + }) .child( div() .flex_grow() @@ -1443,323 +1313,3 @@ impl Component for Table<3> { ) } } - -#[cfg(test)] -mod test { - use super::*; - - fn is_almost_eq(a: &[f32], b: &[f32]) -> bool { - a.len() == b.len() && a.iter().zip(b).all(|(x, y)| (x - y).abs() < 1e-6) - } - - fn cols_to_str<const COLS: usize>(cols: &[f32; COLS], total_size: f32) -> String { - cols.map(|f| "*".repeat(f32::round(f * total_size) as usize)) - .join("|") - } - - fn parse_resize_behavior<const COLS: usize>( - input: &str, - total_size: f32, - ) -> [ResizeBehavior; COLS] { - let mut resize_behavior = [ResizeBehavior::None; COLS]; - let mut max_index = 0; - for (index, col) in input.split('|').enumerate() { - if col.starts_with('X') || col.is_empty() { - resize_behavior[index] = ResizeBehavior::None; - } else if col.starts_with('*') { - resize_behavior[index] = ResizeBehavior::MinSize(col.len() as f32 / total_size); - } else { - panic!("invalid test input: unrecognized resize behavior: {}", col); - } - max_index = index; - } - - if max_index + 1 != COLS { - panic!("invalid test input: too many columns"); - } - resize_behavior - } - - mod reset_column_size { - use super::*; - - fn parse<const COLS: usize>(input: &str) -> ([f32; COLS], f32, Option<usize>) { - let mut widths = [f32::NAN; COLS]; - let mut column_index = None; - for (index, col) in input.split('|').enumerate() { - widths[index] = col.len() as f32; - if col.starts_with('X') { - column_index = Some(index); - } - } - - for w in widths { - assert!(w.is_finite(), "incorrect number of columns"); - } - let total = widths.iter().sum::<f32>(); - for width in &mut widths { - *width /= total; - } - (widths, total, column_index) - } - - #[track_caller] - fn check_reset_size<const COLS: usize>( - initial_sizes: &str, - widths: &str, - expected: &str, - resize_behavior: &str, - ) { - let (initial_sizes, total_1, None) = parse::<COLS>(initial_sizes) else { - panic!("invalid test input: initial sizes should not be marked"); - }; - let (widths, total_2, Some(column_index)) = parse::<COLS>(widths) else { - panic!("invalid test input: widths should be marked"); - }; - assert_eq!( - total_1, total_2, - "invalid test input: total width not the same {total_1}, {total_2}" - ); - let (expected, total_3, None) = parse::<COLS>(expected) else { - panic!("invalid test input: expected should not be marked: {expected:?}"); - }; - assert_eq!( - total_2, total_3, - "invalid test input: total width not the same" - ); - let resize_behavior = parse_resize_behavior::<COLS>(resize_behavior, total_1); - let result = ColumnWidths::reset_to_initial_size( - column_index, - widths, - initial_sizes, - &resize_behavior, - ); - let is_eq = is_almost_eq(&result, &expected); - if !is_eq { - let result_str = cols_to_str(&result, total_1); - let expected_str = cols_to_str(&expected, total_1); - panic!( - "resize failed\ncomputed: {result_str}\nexpected: {expected_str}\n\ncomputed values: {result:?}\nexpected values: {expected:?}\n:minimum widths: {resize_behavior:?}" - ); - } - } - - macro_rules! check_reset_size { - (columns: $cols:expr, starting: $initial:expr, snapshot: $current:expr, expected: $expected:expr, resizing: $resizing:expr $(,)?) => { - check_reset_size::<$cols>($initial, $current, $expected, $resizing); - }; - ($name:ident, columns: $cols:expr, starting: $initial:expr, snapshot: $current:expr, expected: $expected:expr, minimums: $resizing:expr $(,)?) => { - #[test] - fn $name() { - check_reset_size::<$cols>($initial, $current, $expected, $resizing); - } - }; - } - - check_reset_size!( - basic_right, - columns: 5, - starting: "**|**|**|**|**", - snapshot: "**|**|X|***|**", - expected: "**|**|**|**|**", - minimums: "X|*|*|*|*", - ); - - check_reset_size!( - basic_left, - columns: 5, - starting: "**|**|**|**|**", - snapshot: "**|**|***|X|**", - expected: "**|**|**|**|**", - minimums: "X|*|*|*|**", - ); - - check_reset_size!( - squashed_left_reset_col2, - columns: 6, - starting: "*|***|**|**|****|*", - snapshot: "*|*|X|*|*|********", - expected: "*|*|**|*|*|*******", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - grow_cascading_right, - columns: 6, - starting: "*|***|****|**|***|*", - snapshot: "*|***|X|**|**|*****", - expected: "*|***|****|*|*|****", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - squashed_right_reset_col4, - columns: 6, - starting: "*|***|**|**|****|*", - snapshot: "*|********|*|*|X|*", - expected: "*|*****|*|*|****|*", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - reset_col6_right, - columns: 6, - starting: "*|***|**|***|***|**", - snapshot: "*|***|**|***|**|XXX", - expected: "*|***|**|***|***|**", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - reset_col6_left, - columns: 6, - starting: "*|***|**|***|***|**", - snapshot: "*|***|**|***|****|X", - expected: "*|***|**|***|***|**", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - last_column_grow_cascading, - columns: 6, - starting: "*|***|**|**|**|***", - snapshot: "*|*******|*|**|*|X", - expected: "*|******|*|*|*|***", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - goes_left_when_left_has_extreme_diff, - columns: 6, - starting: "*|***|****|**|**|***", - snapshot: "*|********|X|*|**|**", - expected: "*|*****|****|*|**|**", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - basic_shrink_right, - columns: 6, - starting: "**|**|**|**|**|**", - snapshot: "**|**|XXX|*|**|**", - expected: "**|**|**|**|**|**", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - shrink_should_go_left, - columns: 6, - starting: "*|***|**|*|*|*", - snapshot: "*|*|XXX|**|*|*", - expected: "*|**|**|**|*|*", - minimums: "X|*|*|*|*|*", - ); - - check_reset_size!( - shrink_should_go_right, - columns: 6, - starting: "*|***|**|**|**|*", - snapshot: "*|****|XXX|*|*|*", - expected: "*|****|**|**|*|*", - minimums: "X|*|*|*|*|*", - ); - } - - mod drag_handle { - use super::*; - - fn parse<const COLS: usize>(input: &str) -> ([f32; COLS], f32, Option<usize>) { - let mut widths = [f32::NAN; COLS]; - let column_index = input.replace("*", "").find("I"); - for (index, col) in input.replace("I", "|").split('|').enumerate() { - widths[index] = col.len() as f32; - } - - for w in widths { - assert!(w.is_finite(), "incorrect number of columns"); - } - let total = widths.iter().sum::<f32>(); - for width in &mut widths { - *width /= total; - } - (widths, total, column_index) - } - - #[track_caller] - fn check<const COLS: usize>( - distance: i32, - widths: &str, - expected: &str, - resize_behavior: &str, - ) { - let (mut widths, total_1, Some(column_index)) = parse::<COLS>(widths) else { - panic!("invalid test input: widths should be marked"); - }; - let (expected, total_2, None) = parse::<COLS>(expected) else { - panic!("invalid test input: expected should not be marked: {expected:?}"); - }; - assert_eq!( - total_1, total_2, - "invalid test input: total width not the same" - ); - let resize_behavior = parse_resize_behavior::<COLS>(resize_behavior, total_1); - - let distance = distance as f32 / total_1; - - let result = ColumnWidths::drag_column_handle( - distance, - column_index, - &mut widths, - &resize_behavior, - ); - - let is_eq = is_almost_eq(&widths, &expected); - if !is_eq { - let result_str = cols_to_str(&widths, total_1); - let expected_str = cols_to_str(&expected, total_1); - panic!( - "resize failed\ncomputed: {result_str}\nexpected: {expected_str}\n\ncomputed values: {result:?}\nexpected values: {expected:?}\n:minimum widths: {resize_behavior:?}" - ); - } - } - - macro_rules! check { - (columns: $cols:expr, distance: $dist:expr, snapshot: $current:expr, expected: $expected:expr, resizing: $resizing:expr $(,)?) => { - check!($cols, $dist, $snapshot, $expected, $resizing); - }; - ($name:ident, columns: $cols:expr, distance: $dist:expr, snapshot: $current:expr, expected: $expected:expr, minimums: $resizing:expr $(,)?) => { - #[test] - fn $name() { - check::<$cols>($dist, $current, $expected, $resizing); - } - }; - } - - check!( - basic_right_drag, - columns: 3, - distance: 1, - snapshot: "**|**I**", - expected: "**|***|*", - minimums: "X|*|*", - ); - - check!( - drag_left_against_mins, - columns: 5, - distance: -1, - snapshot: "*|*|*|*I*******", - expected: "*|*|*|*|*******", - minimums: "X|*|*|*|*", - ); - - check!( - drag_left, - columns: 5, - distance: -2, - snapshot: "*|*|*|*****I***", - expected: "*|*|*|***|*****", - minimums: "X|*|*|*|*", - ); - } -} diff --git a/crates/sum_tree/src/sum_tree.rs b/crates/sum_tree/src/sum_tree.rs index 4c5ce39590..4f9e01ce20 100644 --- a/crates/sum_tree/src/sum_tree.rs +++ b/crates/sum_tree/src/sum_tree.rs @@ -41,14 +41,16 @@ pub trait Summary: Clone { fn add_summary(&mut self, summary: &Self, cx: &Self::Context); } -/// Catch-all implementation for when you need something that implements [`Summary`] without a specific type. -/// We implement it on a &'static, as that avoids blanket impl collisions with `impl<T: Summary> Dimension for T` -/// (as we also need unit type to be a fill-in dimension) -impl Summary for &'static () { +/// This type exists because we can't implement Summary for () without causing +/// type resolution errors +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct Unit; + +impl Summary for Unit { type Context = (); fn zero(_: &()) -> Self { - &() + Unit } fn add_summary(&mut self, _: &Self, _: &()) {} diff --git a/crates/tasks_ui/src/modal.rs b/crates/tasks_ui/src/modal.rs index c4b0931c35..1510f613e3 100644 --- a/crates/tasks_ui/src/modal.rs +++ b/crates/tasks_ui/src/modal.rs @@ -500,7 +500,7 @@ impl PickerDelegate for TasksModalDelegate { .map(|icon| icon.color(Color::Muted).size(IconSize::Small)); let indicator = if matches!(source_kind, TaskSourceKind::Lsp { .. }) { Some(Indicator::icon( - Icon::new(IconName::BoltOutlined).size(IconSize::Small), + Icon::new(IconName::Bolt).size(IconSize::Small), )) } else { None diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 2e6be5aaf4..1cc1fbcf6f 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -430,7 +430,6 @@ impl TerminalView { fn settings_changed(&mut self, cx: &mut Context<Self>) { let settings = TerminalSettings::get_global(cx); - let breadcrumb_visibility_changed = self.show_breadcrumbs != settings.toolbar.breadcrumbs; self.show_breadcrumbs = settings.toolbar.breadcrumbs; let new_cursor_shape = settings.cursor_shape.unwrap_or_default(); @@ -442,9 +441,6 @@ impl TerminalView { }); } - if breadcrumb_visibility_changed { - cx.emit(ItemEvent::UpdateBreadcrumbs); - } cx.notify(); } @@ -1591,7 +1587,7 @@ impl Item for TerminalView { let (icon, icon_color, rerun_button) = match terminal.task() { Some(terminal_task) => match &terminal_task.status { TaskStatus::Running => ( - IconName::PlayOutlined, + IconName::Play, Color::Disabled, TerminalView::rerun_button(&terminal_task), ), diff --git a/crates/theme/src/icon_theme.rs b/crates/theme/src/icon_theme.rs index 10fd1e002d..09f5df06b0 100644 --- a/crates/theme/src/icon_theme.rs +++ b/crates/theme/src/icon_theme.rs @@ -152,7 +152,6 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ("javascript", &["cjs", "js", "mjs"]), ("json", &["json"]), ("julia", &["jl"]), - ("kdl", &["kdl"]), ("kotlin", &["kt"]), ("lock", &["lock"]), ("log", &["log"]), @@ -217,7 +216,6 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ "stylelintrc.yml", ], ), - ("surrealql", &["surql"]), ("svelte", &["svelte"]), ("swift", &["swift"]), ("tcl", &["tcl"]), @@ -316,7 +314,6 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("javascript", "icons/file_icons/javascript.svg"), ("json", "icons/file_icons/code.svg"), ("julia", "icons/file_icons/julia.svg"), - ("kdl", "icons/file_icons/kdl.svg"), ("kotlin", "icons/file_icons/kotlin.svg"), ("lock", "icons/file_icons/lock.svg"), ("log", "icons/file_icons/info.svg"), @@ -343,7 +340,6 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("solidity", "icons/file_icons/file.svg"), ("storage", "icons/file_icons/database.svg"), ("stylelint", "icons/file_icons/javascript.svg"), - ("surrealql", "icons/file_icons/surrealql.svg"), ("svelte", "icons/file_icons/html.svg"), ("swift", "icons/file_icons/swift.svg"), ("tcl", "icons/file_icons/tcl.svg"), diff --git a/crates/theme/src/settings.rs b/crates/theme/src/settings.rs index 20c837f287..1c4c90a475 100644 --- a/crates/theme/src/settings.rs +++ b/crates/theme/src/settings.rs @@ -438,7 +438,7 @@ fn default_font_fallbacks() -> Option<FontFallbacks> { impl ThemeSettingsContent { /// Sets the theme for the given appearance to the theme with the specified name. - pub fn set_theme(&mut self, theme_name: impl Into<Arc<str>>, appearance: Appearance) { + pub fn set_theme(&mut self, theme_name: String, appearance: Appearance) { if let Some(selection) = self.theme.as_mut() { let theme_to_update = match selection { ThemeSelection::Static(theme) => theme, @@ -867,7 +867,6 @@ impl settings::Settings for ThemeSettings { .user .into_iter() .chain(sources.release_channel) - .chain(sources.profile) .chain(sources.server) { if let Some(value) = value.ui_density { diff --git a/crates/title_bar/Cargo.toml b/crates/title_bar/Cargo.toml index cf178e2850..8e95c6f79f 100644 --- a/crates/title_bar/Cargo.toml +++ b/crates/title_bar/Cargo.toml @@ -32,7 +32,6 @@ auto_update.workspace = true call.workspace = true chrono.workspace = true client.workspace = true -cloud_llm_client.workspace = true db.workspace = true gpui = { workspace = true, features = ["screen-capture"] } notifications.workspace = true diff --git a/crates/title_bar/src/collab.rs b/crates/title_bar/src/collab.rs index d026b4de14..056c981ccf 100644 --- a/crates/title_bar/src/collab.rs +++ b/crates/title_bar/src/collab.rs @@ -11,8 +11,8 @@ use gpui::{App, Task, Window, actions}; use rpc::proto::{self}; use theme::ActiveTheme; use ui::{ - Avatar, AvatarAudioStatusIndicator, ContextMenu, ContextMenuItem, Divider, DividerColor, - Facepile, PopoverMenu, SplitButton, SplitButtonStyle, TintColor, Tooltip, prelude::*, + Avatar, AvatarAudioStatusIndicator, ContextMenu, ContextMenuItem, Divider, Facepile, + PopoverMenu, SplitButton, SplitButtonStyle, TintColor, Tooltip, prelude::*, }; use util::maybe; use workspace::notifications::DetachAndPromptErr; @@ -343,24 +343,6 @@ impl TitleBar { let mut children = Vec::new(); - children.push( - h_flex() - .gap_1() - .child( - IconButton::new("leave-call", IconName::Exit) - .style(ButtonStyle::Subtle) - .tooltip(Tooltip::text("Leave Call")) - .icon_size(IconSize::Small) - .on_click(move |_, _window, cx| { - ActiveCall::global(cx) - .update(cx, |call, cx| call.hang_up(cx)) - .detach_and_log_err(cx); - }), - ) - .child(Divider::vertical().color(DividerColor::Border)) - .into_any_element(), - ); - if is_local && can_share_projects && !is_connecting_to_project { children.push( Button::new( @@ -387,14 +369,32 @@ impl TitleBar { ); } + children.push( + div() + .pr_2() + .child( + IconButton::new("leave-call", ui::IconName::Exit) + .style(ButtonStyle::Subtle) + .tooltip(Tooltip::text("Leave call")) + .icon_size(IconSize::Small) + .on_click(move |_, _window, cx| { + ActiveCall::global(cx) + .update(cx, |call, cx| call.hang_up(cx)) + .detach_and_log_err(cx); + }), + ) + .child(Divider::vertical()) + .into_any_element(), + ); + if can_use_microphone { children.push( IconButton::new( "mute-microphone", if is_muted { - IconName::MicMute + ui::IconName::MicMute } else { - IconName::Mic + ui::IconName::Mic }, ) .tooltip(move |window, cx| { @@ -429,9 +429,9 @@ impl TitleBar { IconButton::new( "mute-sound", if is_deafened { - IconName::AudioOff + ui::IconName::AudioOff } else { - IconName::AudioOn + ui::IconName::AudioOn }, ) .style(ButtonStyle::Subtle) @@ -462,7 +462,7 @@ impl TitleBar { ); if can_use_microphone && screen_sharing_supported { - let trigger = IconButton::new("screen-share", IconName::Screen) + let trigger = IconButton::new("screen-share", ui::IconName::Screen) .style(ButtonStyle::Subtle) .icon_size(IconSize::Small) .toggle_state(is_screen_sharing) @@ -498,7 +498,7 @@ impl TitleBar { trigger.render(window, cx), self.render_screen_list().into_any_element(), ) - .style(SplitButtonStyle::Transparent) + .style(SplitButtonStyle::Outlined) .into_any_element(), ); } @@ -513,11 +513,11 @@ impl TitleBar { .with_handle(self.screen_share_popover_handle.clone()) .trigger( ui::ButtonLike::new_rounded_right("screen-share-screen-list-trigger") + .layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::None) .child( - h_flex() - .mx_neg_0p5() - .h_full() - .justify_center() + div() + .px_1() .child(Icon::new(IconName::ChevronDownSmall).size(IconSize::XSmall)), ) .toggle_state(self.screen_share_popover_handle.is_deployed()), diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index a8b16d881f..17c4c85b6d 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -21,7 +21,6 @@ use crate::application_menu::{ use auto_update::AutoUpdateStatus; use call::ActiveCall; use client::{Client, UserStore, zed_urls}; -use cloud_llm_client::Plan; use gpui::{ Action, AnyElement, App, Context, Corner, Element, Entity, Focusable, InteractiveElement, IntoElement, MouseButton, ParentElement, Render, StatefulInteractiveElement, Styled, @@ -29,6 +28,7 @@ use gpui::{ }; use onboarding_banner::OnboardingBanner; use project::Project; +use rpc::proto; use settings::Settings as _; use settings_ui::keybindings; use std::sync::Arc; @@ -179,23 +179,24 @@ impl Render for TitleBar { children.push(self.banner.clone().into_any_element()) } - let status = self.client.status(); - let status = &*status.borrow(); - let user = self.user_store.read(cx).current_user(); - children.push( h_flex() .gap_1() .pr_1() .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()) .children(self.render_call_controls(window, cx)) - .children(self.render_connection_status(status, cx)) - .when( - user.is_none() && TitleBarSettings::get_global(cx).show_sign_in, - |el| el.child(self.render_sign_in_button(cx)), - ) - .when(user.is_some(), |parent| { - parent.child(self.render_user_menu_button(cx)) + .map(|el| { + let status = self.client.status(); + let status = &*status.borrow(); + if matches!(status, client::Status::Connected { .. }) { + el.child(self.render_user_menu_button(cx)) + } else { + el.children(self.render_connection_status(status, cx)) + .when(TitleBarSettings::get_global(cx).show_sign_in, |el| { + el.child(self.render_sign_in_button(cx)) + }) + .child(self.render_user_menu_button(cx)) + } }) .into_any_element(), ); @@ -617,8 +618,9 @@ impl TitleBar { window .spawn(cx, async move |cx| { client - .sign_in_with_optional_connect(true, &cx) + .authenticate_and_connect(true, &cx) .await + .into_response() .notify_async_err(cx); }) .detach(); @@ -628,8 +630,8 @@ impl TitleBar { pub fn render_user_menu_button(&mut self, cx: &mut Context<Self>) -> impl Element { let user_store = self.user_store.read(cx); if let Some(user) = user_store.current_user() { - let has_subscription_period = user_store.subscription_period().is_some(); - let plan = user_store.plan().filter(|_| { + let has_subscription_period = self.user_store.read(cx).subscription_period().is_some(); + let plan = self.user_store.read(cx).current_plan().filter(|_| { // Since the user might be on the legacy free plan we filter based on whether we have a subscription period. has_subscription_period }); @@ -656,9 +658,13 @@ impl TitleBar { let user_login = user.github_login.clone(); let (plan_name, label_color, bg_color) = match plan { - None | Some(Plan::ZedFree) => ("Free", Color::Default, free_chip_bg), - Some(Plan::ZedProTrial) => ("Pro Trial", Color::Accent, pro_chip_bg), - Some(Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), + None | Some(proto::Plan::Free) => { + ("Free", Color::Default, free_chip_bg) + } + Some(proto::Plan::ZedProTrial) => { + ("Pro Trial", Color::Accent, pro_chip_bg) + } + Some(proto::Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), }; menu.custom_entry( @@ -682,10 +688,6 @@ impl TitleBar { ) .separator() .action("Settings", zed_actions::OpenSettings.boxed_clone()) - .action( - "Settings Profiles", - zed_actions::settings_profile_selector::Toggle.boxed_clone(), - ) .action("Key Bindings", Box::new(keybindings::OpenKeymapEditor)) .action( "Themes…", @@ -730,10 +732,6 @@ impl TitleBar { .menu(|window, cx| { ContextMenu::build(window, cx, |menu, _, _| { menu.action("Settings", zed_actions::OpenSettings.boxed_clone()) - .action( - "Settings Profiles", - zed_actions::settings_profile_selector::Toggle.boxed_clone(), - ) .action("Key Bindings", Box::new(keybindings::OpenKeymapEditor)) .action( "Themes…", diff --git a/crates/ui/src/components.rs b/crates/ui/src/components.rs index 486673e733..9c2961c55f 100644 --- a/crates/ui/src/components.rs +++ b/crates/ui/src/components.rs @@ -1,5 +1,4 @@ mod avatar; -mod badge; mod banner; mod button; mod callout; @@ -42,7 +41,6 @@ mod tooltip; mod stories; pub use avatar::*; -pub use badge::*; pub use banner::*; pub use button::*; pub use callout::*; diff --git a/crates/ui/src/components/badge.rs b/crates/ui/src/components/badge.rs deleted file mode 100644 index 2eee084bbb..0000000000 --- a/crates/ui/src/components/badge.rs +++ /dev/null @@ -1,66 +0,0 @@ -use crate::Divider; -use crate::DividerColor; -use crate::component_prelude::*; -use crate::prelude::*; -use gpui::{AnyElement, IntoElement, SharedString, Window}; - -#[derive(IntoElement, RegisterComponent)] -pub struct Badge { - label: SharedString, - icon: IconName, -} - -impl Badge { - pub fn new(label: impl Into<SharedString>) -> Self { - Self { - label: label.into(), - icon: IconName::Check, - } - } - - pub fn icon(mut self, icon: IconName) -> Self { - self.icon = icon; - self - } -} - -impl RenderOnce for Badge { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - h_flex() - .h_full() - .gap_1() - .pl_1() - .pr_2() - .border_1() - .border_color(cx.theme().colors().border.opacity(0.6)) - .bg(cx.theme().colors().element_background) - .rounded_sm() - .overflow_hidden() - .child( - Icon::new(self.icon) - .size(IconSize::XSmall) - .color(Color::Muted), - ) - .child(Divider::vertical().color(DividerColor::Border)) - .child(Label::new(self.label.clone()).size(LabelSize::Small).ml_1()) - } -} - -impl Component for Badge { - fn scope() -> ComponentScope { - ComponentScope::DataDisplay - } - - fn description() -> Option<&'static str> { - Some( - "A compact, labeled component with optional icon for displaying status, categories, or metadata.", - ) - } - - fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { - Some( - single_example("Basic Badge", Badge::new("Default").into_any_element()) - .into_any_element(), - ) - } -} diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index 03f7964f35..135ecdfe62 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -358,7 +358,6 @@ impl ButtonStyle { #[derive(Default, PartialEq, Clone, Copy)] pub enum ButtonSize { Large, - Medium, #[default] Default, Compact, @@ -369,7 +368,6 @@ impl ButtonSize { pub fn rems(self) -> Rems { match self { ButtonSize::Large => rems_from_px(32.), - ButtonSize::Medium => rems_from_px(28.), ButtonSize::Default => rems_from_px(22.), ButtonSize::Compact => rems_from_px(18.), ButtonSize::None => rems_from_px(16.), @@ -575,7 +573,7 @@ impl RenderOnce for ButtonLike { }) .gap(DynamicSpacing::Base04.rems(cx)) .map(|this| match self.size { - ButtonSize::Large | ButtonSize::Medium => this.px(DynamicSpacing::Base06.rems(cx)), + ButtonSize::Large => this.px(DynamicSpacing::Base06.rems(cx)), ButtonSize::Default | ButtonSize::Compact => { this.px(DynamicSpacing::Base04.rems(cx)) } diff --git a/crates/ui/src/components/button/split_button.rs b/crates/ui/src/components/button/split_button.rs index 14b9fd153c..a7fa2106d1 100644 --- a/crates/ui/src/components/button/split_button.rs +++ b/crates/ui/src/components/button/split_button.rs @@ -12,7 +12,6 @@ use super::ButtonLike; pub enum SplitButtonStyle { Filled, Outlined, - Transparent, } /// /// A button with two parts: a primary action on the left and a secondary action on the right. @@ -45,17 +44,10 @@ impl SplitButton { impl RenderOnce for SplitButton { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let is_filled_or_outlined = matches!( - self.style, - SplitButtonStyle::Filled | SplitButtonStyle::Outlined - ); - h_flex() .rounded_sm() - .when(is_filled_or_outlined, |this| { - this.border_1() - .border_color(cx.theme().colors().border.opacity(0.8)) - }) + .border_1() + .border_color(cx.theme().colors().border.opacity(0.5)) .child(div().flex_grow().child(self.left)) .child( div() diff --git a/crates/ui/src/components/button/toggle_button.rs b/crates/ui/src/components/button/toggle_button.rs index a1e4d65a24..eca23fe6f7 100644 --- a/crates/ui/src/components/button/toggle_button.rs +++ b/crates/ui/src/components/button/toggle_button.rs @@ -1,6 +1,6 @@ use gpui::{AnyView, ClickEvent}; -use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, TintColor, prelude::*}; +use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, prelude::*}; /// The position of a [`ToggleButton`] within a group of buttons. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -290,617 +290,3 @@ impl Component for ToggleButton { ) } } - -pub struct ButtonConfiguration { - label: SharedString, - icon: Option<IconName>, - on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, - selected: bool, -} - -mod private { - pub trait ToggleButtonStyle {} -} - -pub trait ButtonBuilder: 'static + private::ToggleButtonStyle { - fn into_configuration(self) -> ButtonConfiguration; -} - -pub struct ToggleButtonSimple { - label: SharedString, - on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, - selected: bool, -} - -impl ToggleButtonSimple { - pub fn new( - label: impl Into<SharedString>, - on_click: impl Fn(&ClickEvent, &mut Window, &mut App) + 'static, - ) -> Self { - Self { - label: label.into(), - on_click: Box::new(on_click), - selected: false, - } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } -} - -impl private::ToggleButtonStyle for ToggleButtonSimple {} - -impl ButtonBuilder for ToggleButtonSimple { - fn into_configuration(self) -> ButtonConfiguration { - ButtonConfiguration { - label: self.label, - icon: None, - on_click: self.on_click, - selected: self.selected, - } - } -} - -pub struct ToggleButtonWithIcon { - label: SharedString, - icon: IconName, - on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, - selected: bool, -} - -impl ToggleButtonWithIcon { - pub fn new( - label: impl Into<SharedString>, - icon: IconName, - on_click: impl Fn(&ClickEvent, &mut Window, &mut App) + 'static, - ) -> Self { - Self { - label: label.into(), - icon, - on_click: Box::new(on_click), - selected: false, - } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } -} - -impl private::ToggleButtonStyle for ToggleButtonWithIcon {} - -impl ButtonBuilder for ToggleButtonWithIcon { - fn into_configuration(self) -> ButtonConfiguration { - ButtonConfiguration { - label: self.label, - icon: Some(self.icon), - on_click: self.on_click, - selected: self.selected, - } - } -} - -#[derive(Clone, Copy, PartialEq)] -pub enum ToggleButtonGroupStyle { - Transparent, - Filled, - Outlined, -} - -#[derive(Clone, Copy, PartialEq)] -pub enum ToggleButtonGroupSize { - Default, - Medium, -} - -#[derive(IntoElement)] -pub struct ToggleButtonGroup<T, const COLS: usize = 3, const ROWS: usize = 1> -where - T: ButtonBuilder, -{ - group_name: &'static str, - rows: [[T; COLS]; ROWS], - style: ToggleButtonGroupStyle, - size: ToggleButtonGroupSize, - button_width: Rems, - selected_index: usize, -} - -impl<T: ButtonBuilder, const COLS: usize> ToggleButtonGroup<T, COLS> { - pub fn single_row(group_name: &'static str, buttons: [T; COLS]) -> Self { - Self { - group_name, - rows: [buttons], - style: ToggleButtonGroupStyle::Transparent, - size: ToggleButtonGroupSize::Default, - button_width: rems_from_px(100.), - selected_index: 0, - } - } -} - -impl<T: ButtonBuilder, const COLS: usize> ToggleButtonGroup<T, COLS, 2> { - pub fn two_rows(group_name: &'static str, first_row: [T; COLS], second_row: [T; COLS]) -> Self { - Self { - group_name, - rows: [first_row, second_row], - style: ToggleButtonGroupStyle::Transparent, - size: ToggleButtonGroupSize::Default, - button_width: rems_from_px(100.), - selected_index: 0, - } - } -} - -impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> ToggleButtonGroup<T, COLS, ROWS> { - pub fn style(mut self, style: ToggleButtonGroupStyle) -> Self { - self.style = style; - self - } - - pub fn size(mut self, size: ToggleButtonGroupSize) -> Self { - self.size = size; - self - } - - pub fn button_width(mut self, button_width: Rems) -> Self { - self.button_width = button_width; - self - } - - pub fn selected_index(mut self, index: usize) -> Self { - self.selected_index = index; - self - } -} - -impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> RenderOnce - for ToggleButtonGroup<T, COLS, ROWS> -{ - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let entries = - self.rows.into_iter().enumerate().map(|(row_index, row)| { - row.into_iter().enumerate().map(move |(col_index, button)| { - let ButtonConfiguration { - label, - icon, - on_click, - selected, - } = button.into_configuration(); - - let entry_index = row_index * COLS + col_index; - - ButtonLike::new((self.group_name, entry_index)) - .when(entry_index == self.selected_index || selected, |this| { - this.toggle_state(true) - .selected_style(ButtonStyle::Tinted(TintColor::Accent)) - }) - .rounding(None) - .when(self.style == ToggleButtonGroupStyle::Filled, |button| { - button.style(ButtonStyle::Filled) - }) - .when(self.size == ToggleButtonGroupSize::Medium, |button| { - button.size(ButtonSize::Medium) - }) - .child( - h_flex() - .min_w(self.button_width) - .gap_1p5() - .px_3() - .py_1() - .justify_center() - .when_some(icon, |this, icon| { - this.py_2() - .child(Icon::new(icon).size(IconSize::XSmall).map(|this| { - if entry_index == self.selected_index || selected { - this.color(Color::Accent) - } else { - this.color(Color::Muted) - } - })) - }) - .child(Label::new(label).size(LabelSize::Small).when( - entry_index == self.selected_index || selected, - |this| this.color(Color::Accent), - )), - ) - .on_click(on_click) - .into_any_element() - }) - }); - - let border_color = cx.theme().colors().border.opacity(0.6); - let is_outlined_or_filled = self.style == ToggleButtonGroupStyle::Outlined - || self.style == ToggleButtonGroupStyle::Filled; - let is_transparent = self.style == ToggleButtonGroupStyle::Transparent; - - v_flex() - .rounded_md() - .overflow_hidden() - .map(|this| { - if is_transparent { - this.gap_px() - } else { - this.border_1().border_color(border_color) - } - }) - .children(entries.enumerate().map(|(row_index, row)| { - let last_row = row_index == ROWS - 1; - h_flex() - .when(!is_outlined_or_filled, |this| this.gap_px()) - .when(is_outlined_or_filled && !last_row, |this| { - this.border_b_1().border_color(border_color) - }) - .children(row.enumerate().map(|(item_index, item)| { - let last_item = item_index == COLS - 1; - div() - .when(is_outlined_or_filled && !last_item, |this| { - this.border_r_1().border_color(border_color) - }) - .child(item) - })) - })) - } -} - -fn register_toggle_button_group() { - component::register_component::<ToggleButtonGroup<ToggleButtonSimple>>(); -} - -component::__private::inventory::submit! { - component::ComponentFn::new(register_toggle_button_group) -} - -impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> Component - for ToggleButtonGroup<T, COLS, ROWS> -{ - fn name() -> &'static str { - "ToggleButtonGroup" - } - - fn scope() -> ComponentScope { - ComponentScope::Input - } - - fn sort_name() -> &'static str { - "ButtonG" - } - - fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { - Some( - v_flex() - .gap_6() - .children(vec![example_group_with_title( - "Transparent Variant", - vec![ - single_example( - "Single Row Group", - ToggleButtonGroup::single_row( - "single_row_test", - [ - ToggleButtonSimple::new("First", |_, _, _| {}), - ToggleButtonSimple::new("Second", |_, _, _| {}), - ToggleButtonSimple::new("Third", |_, _, _| {}), - ], - ) - .selected_index(1) - .button_width(rems_from_px(100.)) - .into_any_element(), - ), - single_example( - "Single Row Group with icons", - ToggleButtonGroup::single_row( - "single_row_test_icon", - [ - ToggleButtonWithIcon::new( - "First", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Second", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Third", - IconName::AiZed, - |_, _, _| {}, - ), - ], - ) - .selected_index(1) - .button_width(rems_from_px(100.)) - .into_any_element(), - ), - single_example( - "Multiple Row Group", - ToggleButtonGroup::two_rows( - "multiple_row_test", - [ - ToggleButtonSimple::new("First", |_, _, _| {}), - ToggleButtonSimple::new("Second", |_, _, _| {}), - ToggleButtonSimple::new("Third", |_, _, _| {}), - ], - [ - ToggleButtonSimple::new("Fourth", |_, _, _| {}), - ToggleButtonSimple::new("Fifth", |_, _, _| {}), - ToggleButtonSimple::new("Sixth", |_, _, _| {}), - ], - ) - .selected_index(3) - .button_width(rems_from_px(100.)) - .into_any_element(), - ), - single_example( - "Multiple Row Group with Icons", - ToggleButtonGroup::two_rows( - "multiple_row_test_icons", - [ - ToggleButtonWithIcon::new( - "First", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Second", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Third", - IconName::AiZed, - |_, _, _| {}, - ), - ], - [ - ToggleButtonWithIcon::new( - "Fourth", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Fifth", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Sixth", - IconName::AiZed, - |_, _, _| {}, - ), - ], - ) - .selected_index(3) - .button_width(rems_from_px(100.)) - .into_any_element(), - ), - ], - )]) - .children(vec![example_group_with_title( - "Outlined Variant", - vec![ - single_example( - "Single Row Group", - ToggleButtonGroup::single_row( - "single_row_test_outline", - [ - ToggleButtonSimple::new("First", |_, _, _| {}), - ToggleButtonSimple::new("Second", |_, _, _| {}), - ToggleButtonSimple::new("Third", |_, _, _| {}), - ], - ) - .selected_index(1) - .style(ToggleButtonGroupStyle::Outlined) - .into_any_element(), - ), - single_example( - "Single Row Group with icons", - ToggleButtonGroup::single_row( - "single_row_test_icon_outlined", - [ - ToggleButtonWithIcon::new( - "First", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Second", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Third", - IconName::AiZed, - |_, _, _| {}, - ), - ], - ) - .selected_index(1) - .button_width(rems_from_px(100.)) - .style(ToggleButtonGroupStyle::Outlined) - .into_any_element(), - ), - single_example( - "Multiple Row Group", - ToggleButtonGroup::two_rows( - "multiple_row_test", - [ - ToggleButtonSimple::new("First", |_, _, _| {}), - ToggleButtonSimple::new("Second", |_, _, _| {}), - ToggleButtonSimple::new("Third", |_, _, _| {}), - ], - [ - ToggleButtonSimple::new("Fourth", |_, _, _| {}), - ToggleButtonSimple::new("Fifth", |_, _, _| {}), - ToggleButtonSimple::new("Sixth", |_, _, _| {}), - ], - ) - .selected_index(3) - .button_width(rems_from_px(100.)) - .style(ToggleButtonGroupStyle::Outlined) - .into_any_element(), - ), - single_example( - "Multiple Row Group with Icons", - ToggleButtonGroup::two_rows( - "multiple_row_test", - [ - ToggleButtonWithIcon::new( - "First", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Second", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Third", - IconName::AiZed, - |_, _, _| {}, - ), - ], - [ - ToggleButtonWithIcon::new( - "Fourth", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Fifth", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Sixth", - IconName::AiZed, - |_, _, _| {}, - ), - ], - ) - .selected_index(3) - .button_width(rems_from_px(100.)) - .style(ToggleButtonGroupStyle::Outlined) - .into_any_element(), - ), - ], - )]) - .children(vec![example_group_with_title( - "Filled Variant", - vec![ - single_example( - "Single Row Group", - ToggleButtonGroup::single_row( - "single_row_test_outline", - [ - ToggleButtonSimple::new("First", |_, _, _| {}), - ToggleButtonSimple::new("Second", |_, _, _| {}), - ToggleButtonSimple::new("Third", |_, _, _| {}), - ], - ) - .selected_index(2) - .style(ToggleButtonGroupStyle::Filled) - .into_any_element(), - ), - single_example( - "Single Row Group with icons", - ToggleButtonGroup::single_row( - "single_row_test_icon_outlined", - [ - ToggleButtonWithIcon::new( - "First", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Second", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Third", - IconName::AiZed, - |_, _, _| {}, - ), - ], - ) - .selected_index(1) - .button_width(rems_from_px(100.)) - .style(ToggleButtonGroupStyle::Filled) - .into_any_element(), - ), - single_example( - "Multiple Row Group", - ToggleButtonGroup::two_rows( - "multiple_row_test", - [ - ToggleButtonSimple::new("First", |_, _, _| {}), - ToggleButtonSimple::new("Second", |_, _, _| {}), - ToggleButtonSimple::new("Third", |_, _, _| {}), - ], - [ - ToggleButtonSimple::new("Fourth", |_, _, _| {}), - ToggleButtonSimple::new("Fifth", |_, _, _| {}), - ToggleButtonSimple::new("Sixth", |_, _, _| {}), - ], - ) - .selected_index(3) - .button_width(rems_from_px(100.)) - .style(ToggleButtonGroupStyle::Filled) - .into_any_element(), - ), - single_example( - "Multiple Row Group with Icons", - ToggleButtonGroup::two_rows( - "multiple_row_test", - [ - ToggleButtonWithIcon::new( - "First", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Second", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Third", - IconName::AiZed, - |_, _, _| {}, - ), - ], - [ - ToggleButtonWithIcon::new( - "Fourth", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Fifth", - IconName::AiZed, - |_, _, _| {}, - ), - ToggleButtonWithIcon::new( - "Sixth", - IconName::AiZed, - |_, _, _| {}, - ), - ], - ) - .selected_index(3) - .button_width(rems_from_px(100.)) - .style(ToggleButtonGroupStyle::Filled) - .into_any_element(), - ), - ], - )]) - .into_any_element(), - ) - } -} diff --git a/crates/ui/src/components/dropdown_menu.rs b/crates/ui/src/components/dropdown_menu.rs index cdb98086ca..189fac930f 100644 --- a/crates/ui/src/components/dropdown_menu.rs +++ b/crates/ui/src/components/dropdown_menu.rs @@ -8,7 +8,6 @@ use super::PopoverMenuHandle; pub enum DropdownStyle { #[default] Solid, - Outlined, Ghost, } @@ -148,23 +147,6 @@ impl Component for DropdownMenu { ), ], ), - example_group_with_title( - "Styles", - vec![ - single_example( - "Outlined", - DropdownMenu::new("outlined", "Outlined Dropdown", menu.clone()) - .style(DropdownStyle::Outlined) - .into_any_element(), - ), - single_example( - "Ghost", - DropdownMenu::new("ghost", "Ghost Dropdown", menu.clone()) - .style(DropdownStyle::Ghost) - .into_any_element(), - ), - ], - ), example_group_with_title( "States", vec![single_example( @@ -188,13 +170,10 @@ pub struct DropdownTriggerStyle { impl DropdownTriggerStyle { pub fn for_style(style: DropdownStyle, cx: &App) -> Self { let colors = cx.theme().colors(); - let bg = match style { DropdownStyle::Solid => colors.editor_background, - DropdownStyle::Outlined => colors.surface_background, DropdownStyle::Ghost => colors.ghost_element_background, }; - Self { bg } } } @@ -265,24 +244,17 @@ impl RenderOnce for DropdownMenuTrigger { let disabled = self.disabled; let style = DropdownTriggerStyle::for_style(self.style, cx); - let is_outlined = matches!(self.style, DropdownStyle::Outlined); h_flex() .id("dropdown-menu-trigger") - .min_w_20() + .justify_between() + .rounded_sm() + .bg(style.bg) .pl_2() .pr_1p5() .py_0p5() .gap_2() - .justify_between() - .rounded_sm() - .bg(style.bg) - .hover(|s| s.bg(cx.theme().colors().element_hover)) - .when(is_outlined, |this| { - this.border_1() - .border_color(cx.theme().colors().border) - .overflow_hidden() - }) + .min_w_20() .map(|el| { if self.full_width { el.w_full() diff --git a/crates/ui/src/components/keybinding.rs b/crates/ui/src/components/keybinding.rs index 5779093ccc..1d91492f26 100644 --- a/crates/ui/src/components/keybinding.rs +++ b/crates/ui/src/components/keybinding.rs @@ -44,7 +44,7 @@ impl KeyBinding { pub fn for_action_in( action: &dyn Action, focus: &FocusHandle, - window: &Window, + window: &mut Window, cx: &App, ) -> Option<Self> { let key_binding = window.highest_precedence_binding_for_action_in(action, focus)?; diff --git a/crates/ui/src/components/modal.rs b/crates/ui/src/components/modal.rs index a70f5e1ea5..2145b34ef2 100644 --- a/crates/ui/src/components/modal.rs +++ b/crates/ui/src/components/modal.rs @@ -1,5 +1,5 @@ use crate::{ - Clickable, Color, DynamicSpacing, Headline, HeadlineSize, Icon, IconButton, IconButtonShape, + Clickable, Color, DynamicSpacing, Headline, HeadlineSize, IconButton, IconButtonShape, IconName, Label, LabelCommon, LabelSize, h_flex, v_flex, }; use gpui::{prelude::FluentBuilder, *}; @@ -92,7 +92,6 @@ impl RenderOnce for Modal { #[derive(IntoElement)] pub struct ModalHeader { - icon: Option<Icon>, headline: Option<SharedString>, description: Option<SharedString>, children: SmallVec<[AnyElement; 2]>, @@ -109,7 +108,6 @@ impl Default for ModalHeader { impl ModalHeader { pub fn new() -> Self { Self { - icon: None, headline: None, description: None, children: SmallVec::new(), @@ -118,11 +116,6 @@ impl ModalHeader { } } - pub fn icon(mut self, icon: Icon) -> Self { - self.icon = Some(icon); - self - } - /// Set the headline of the modal. /// /// This will insert the headline as the first item @@ -186,17 +179,12 @@ impl RenderOnce for ModalHeader { ) }) .child( - v_flex() - .flex_1() - .child( - h_flex() - .gap_1() - .when_some(self.icon, |this, icon| this.child(icon)) - .children(children), - ) - .when_some(self.description, |this, description| { + v_flex().flex_1().children(children).when_some( + self.description, + |this, description| { this.child(Label::new(description).color(Color::Muted).mb_2()) - }), + }, + ), ) .when(self.show_dismiss_button, |this| { this.child( diff --git a/crates/ui/src/components/numeric_stepper.rs b/crates/ui/src/components/numeric_stepper.rs index 5a84633d1b..f9e6e88f01 100644 --- a/crates/ui/src/components/numeric_stepper.rs +++ b/crates/ui/src/components/numeric_stepper.rs @@ -2,18 +2,10 @@ use gpui::ClickEvent; use crate::{IconButtonShape, prelude::*}; -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] -pub enum NumericStepperStyle { - Outlined, - #[default] - Ghost, -} - -#[derive(IntoElement, RegisterComponent)] +#[derive(IntoElement)] pub struct NumericStepper { id: ElementId, value: SharedString, - style: NumericStepperStyle, on_decrement: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, on_increment: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, /// Whether to reserve space for the reset button. @@ -31,7 +23,6 @@ impl NumericStepper { Self { id: id.into(), value: value.into(), - style: NumericStepperStyle::default(), on_decrement: Box::new(on_decrement), on_increment: Box::new(on_increment), reserve_space_for_reset: false, @@ -39,11 +30,6 @@ impl NumericStepper { } } - pub fn style(mut self, style: NumericStepperStyle) -> Self { - self.style = style; - self - } - pub fn reserve_space_for_reset(mut self, reserve_space_for_reset: bool) -> Self { self.reserve_space_for_reset = reserve_space_for_reset; self @@ -63,8 +49,6 @@ impl RenderOnce for NumericStepper { let shape = IconButtonShape::Square; let icon_size = IconSize::Small; - let is_outlined = matches!(self.style, NumericStepperStyle::Outlined); - h_flex() .id(self.id) .gap_1() @@ -90,117 +74,22 @@ impl RenderOnce for NumericStepper { .child( h_flex() .gap_1() - .rounded_sm() - .map(|this| { - if is_outlined { - this.overflow_hidden() - .bg(cx.theme().colors().surface_background) - .border_1() - .border_color(cx.theme().colors().border) - } else { - this.px_1().bg(cx.theme().colors().editor_background) - } - }) - .map(|decrement| { - if is_outlined { - decrement.child( - h_flex() - .id("decrement_button") - .p_1p5() - .size_full() - .justify_center() - .hover(|s| s.bg(cx.theme().colors().element_hover)) - .border_r_1() - .border_color(cx.theme().colors().border) - .child(Icon::new(IconName::Dash).size(IconSize::Small)) - .on_click(self.on_decrement), - ) - } else { - decrement.child( - IconButton::new("decrement", IconName::Dash) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_decrement), - ) - } - }) - .when(is_outlined, |this| this) - .child(Label::new(self.value).mx_3()) - .map(|increment| { - if is_outlined { - increment.child( - h_flex() - .id("increment_button") - .p_1p5() - .size_full() - .justify_center() - .hover(|s| s.bg(cx.theme().colors().element_hover)) - .border_l_1() - .border_color(cx.theme().colors().border) - .child(Icon::new(IconName::Plus).size(IconSize::Small)) - .on_click(self.on_increment), - ) - } else { - increment.child( - IconButton::new("increment", IconName::Dash) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_increment), - ) - } - }), + .px_1() + .rounded_xs() + .bg(cx.theme().colors().editor_background) + .child( + IconButton::new("decrement", IconName::Dash) + .shape(shape) + .icon_size(icon_size) + .on_click(self.on_decrement), + ) + .child(Label::new(self.value)) + .child( + IconButton::new("increment", IconName::Plus) + .shape(shape) + .icon_size(icon_size) + .on_click(self.on_increment), + ), ) } } - -impl Component for NumericStepper { - fn scope() -> ComponentScope { - ComponentScope::Input - } - - fn name() -> &'static str { - "Numeric Stepper" - } - - fn sort_name() -> &'static str { - Self::name() - } - - fn description() -> Option<&'static str> { - Some("A button used to increment or decrement a numeric value.") - } - - fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { - Some( - v_flex() - .gap_6() - .children(vec![example_group_with_title( - "Styles", - vec![ - single_example( - "Default", - NumericStepper::new( - "numeric-stepper-component-preview", - "10", - move |_, _, _| {}, - move |_, _, _| {}, - ) - .into_any_element(), - ), - single_example( - "Outlined", - NumericStepper::new( - "numeric-stepper-with-border-component-preview", - "10", - move |_, _, _| {}, - move |_, _, _| {}, - ) - .style(NumericStepperStyle::Outlined) - .into_any_element(), - ), - ], - )]) - .into_any_element(), - ) - } -} diff --git a/crates/ui/src/components/popover.rs b/crates/ui/src/components/popover.rs index 7143514c52..24460f6d9c 100644 --- a/crates/ui/src/components/popover.rs +++ b/crates/ui/src/components/popover.rs @@ -50,7 +50,7 @@ impl RenderOnce for Popover { v_flex() .elevation_2(cx) .py(POPOVER_Y_PADDING / 2.) - .child(div().children(self.children)), + .children(self.children), ) .when_some(self.aside, |this, aside| { this.child( diff --git a/crates/ui/src/components/scrollbar.rs b/crates/ui/src/components/scrollbar.rs index 7af55b76b7..17ab2e788f 100644 --- a/crates/ui/src/components/scrollbar.rs +++ b/crates/ui/src/components/scrollbar.rs @@ -4,8 +4,8 @@ use crate::{IntoElement, prelude::*, px, relative}; use gpui::{ Along, App, Axis as ScrollbarAxis, BorderStyle, Bounds, ContentMask, Corners, CursorStyle, Edges, Element, ElementId, Entity, EntityId, GlobalElementId, Hitbox, HitboxBehavior, Hsla, - IsZero, LayoutId, ListState, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, - Point, ScrollHandle, ScrollWheelEvent, Size, Style, UniformListScrollHandle, Window, quad, + IsZero, LayoutId, ListState, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, Point, + ScrollHandle, ScrollWheelEvent, Size, Style, UniformListScrollHandle, Window, quad, }; pub struct Scrollbar { @@ -301,6 +301,8 @@ impl Element for Scrollbar { window.set_cursor_style(CursorStyle::Arrow, hitbox); } + let scroll = self.state.scroll_handle.clone(); + enum ScrollbarMouseEvent { GutterClick, ThumbDrag(Pixels), @@ -335,12 +337,10 @@ impl Element for Scrollbar { }; window.on_mouse_event({ + let scroll = scroll.clone(); let state = self.state.clone(); move |event: &MouseDownEvent, phase, _, _| { - if !phase.bubble() - || event.button != MouseButton::Left - || !bounds.contains(&event.position) - { + if !(phase.bubble() && bounds.contains(&event.position)) { return; } @@ -348,71 +348,57 @@ impl Element for Scrollbar { let offset = event.position.along(axis) - thumb_bounds.origin.along(axis); state.set_dragging(offset); } else { - let scroll_handle = state.scroll_handle(); let click_offset = compute_click_offset( event.position, - scroll_handle.max_offset(), + scroll.max_offset(), ScrollbarMouseEvent::GutterClick, ); - scroll_handle - .set_offset(scroll_handle.offset().apply_along(axis, |_| click_offset)); + scroll.set_offset(scroll.offset().apply_along(axis, |_| click_offset)); } } }); window.on_mouse_event({ - let scroll_handle = self.state.scroll_handle().clone(); + let scroll = scroll.clone(); move |event: &ScrollWheelEvent, phase, window, _| { if phase.bubble() && bounds.contains(&event.position) { - let current_offset = scroll_handle.offset(); - scroll_handle.set_offset( + let current_offset = scroll.offset(); + scroll.set_offset( current_offset + event.delta.pixel_delta(window.line_height()), ); } } }); - window.on_mouse_event({ - let state = self.state.clone(); - move |event: &MouseMoveEvent, phase, window, cx| { - if phase.bubble() { - match state.thumb_state.get() { - ThumbState::Dragging(drag_state) if event.dragging() => { - let scroll_handle = state.scroll_handle(); - let drag_offset = compute_click_offset( - event.position, - scroll_handle.max_offset(), - ScrollbarMouseEvent::ThumbDrag(drag_state), - ); - scroll_handle.set_offset( - scroll_handle.offset().apply_along(axis, |_| drag_offset), - ); - window.refresh(); - if let Some(id) = state.parent_id { - cx.notify(id); - } - } - _ if event.pressed_button.is_none() => { - state.set_thumb_hovered(thumb_bounds.contains(&event.position)) - } - _ => {} + let state = self.state.clone(); + window.on_mouse_event(move |event: &MouseMoveEvent, _, window, cx| { + match state.thumb_state.get() { + ThumbState::Dragging(drag_state) if event.dragging() => { + let drag_offset = compute_click_offset( + event.position, + scroll.max_offset(), + ScrollbarMouseEvent::ThumbDrag(drag_state), + ); + scroll.set_offset(scroll.offset().apply_along(axis, |_| drag_offset)); + window.refresh(); + if let Some(id) = state.parent_id { + cx.notify(id); } } + _ => state.set_thumb_hovered(thumb_bounds.contains(&event.position)), } }); - - window.on_mouse_event({ - let state = self.state.clone(); - move |event: &MouseUpEvent, phase, _, cx| { - if phase.bubble() { - if state.is_dragging() { - state.scroll_handle().drag_ended(); - if let Some(id) = state.parent_id { - cx.notify(id); - } - } + let state = self.state.clone(); + let scroll = self.state.scroll_handle.clone(); + window.on_mouse_event(move |event: &MouseUpEvent, phase, _, cx| { + if phase.bubble() { + if state.is_dragging() { state.set_thumb_hovered(thumb_bounds.contains(&event.position)); } + scroll.drag_ended(); + if let Some(id) = state.parent_id { + cx.notify(id); + } } }); }) diff --git a/crates/ui/src/components/stories/icon_button.rs b/crates/ui/src/components/stories/icon_button.rs index ad6886252d..e787e81b55 100644 --- a/crates/ui/src/components/stories/icon_button.rs +++ b/crates/ui/src/components/stories/icon_button.rs @@ -77,7 +77,7 @@ impl Render for IconButtonStory { let with_tooltip_button = StoryItem::new( "With `tooltip`", - IconButton::new("with_tooltip_button", IconName::Chat) + IconButton::new("with_tooltip_button", IconName::MessageBubbles) .tooltip(Tooltip::text("Open messages")), ) .description("Displays an icon button that has a tooltip when hovered.") diff --git a/crates/ui/src/components/toggle.rs b/crates/ui/src/components/toggle.rs index 0d8f5c4107..cf2a56b1c9 100644 --- a/crates/ui/src/components/toggle.rs +++ b/crates/ui/src/components/toggle.rs @@ -566,7 +566,7 @@ impl RenderOnce for Switch { pub struct SwitchField { id: ElementId, label: SharedString, - description: Option<SharedString>, + description: SharedString, toggle_state: ToggleState, on_click: Arc<dyn Fn(&ToggleState, &mut Window, &mut App) + 'static>, disabled: bool, @@ -577,14 +577,14 @@ impl SwitchField { pub fn new( id: impl Into<ElementId>, label: impl Into<SharedString>, - description: Option<SharedString>, + description: impl Into<SharedString>, toggle_state: impl Into<ToggleState>, on_click: impl Fn(&ToggleState, &mut Window, &mut App) + 'static, ) -> Self { Self { id: id.into(), label: label.into(), - description: description, + description: description.into(), toggle_state: toggle_state.into(), on_click: Arc::new(on_click), disabled: false, @@ -592,11 +592,6 @@ impl SwitchField { } } - pub fn description(mut self, description: impl Into<SharedString>) -> Self { - self.description = Some(description.into()); - self - } - pub fn disabled(mut self, disabled: bool) -> Self { self.disabled = disabled; self @@ -614,22 +609,17 @@ impl RenderOnce for SwitchField { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { h_flex() .id(SharedString::from(format!("{}-container", self.id))) - .when(!self.disabled, |this| { - this.hover(|this| this.cursor_pointer()) - }) .w_full() .gap_4() .justify_between() .flex_wrap() - .child(match &self.description { - Some(description) => v_flex() + .child( + v_flex() .gap_0p5() .max_w_5_6() - .child(Label::new(self.label.clone())) - .child(Label::new(description.clone()).color(Color::Muted)) - .into_any_element(), - None => Label::new(self.label.clone()).into_any_element(), - }) + .child(Label::new(self.label)) + .child(Label::new(self.description).color(Color::Muted)), + ) .child( Switch::new( SharedString::from(format!("{}-switch", self.id)), @@ -678,7 +668,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_unselected", "Enable notifications", - Some("Receive notifications when new messages arrive.".into()), + "Receive notifications when new messages arrive.", ToggleState::Unselected, |_, _, _| {}, ) @@ -689,7 +679,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_selected", "Enable notifications", - Some("Receive notifications when new messages arrive.".into()), + "Receive notifications when new messages arrive.", ToggleState::Selected, |_, _, _| {}, ) @@ -705,7 +695,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_default", "Default color", - Some("This uses the default switch color.".into()), + "This uses the default switch color.", ToggleState::Selected, |_, _, _| {}, ) @@ -716,7 +706,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_accent", "Accent color", - Some("This uses the accent color scheme.".into()), + "This uses the accent color scheme.", ToggleState::Selected, |_, _, _| {}, ) @@ -732,7 +722,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_disabled", "Disabled field", - Some("This field is disabled and cannot be toggled.".into()), + "This field is disabled and cannot be toggled.", ToggleState::Selected, |_, _, _| {}, ) @@ -740,20 +730,6 @@ impl Component for SwitchField { .into_any_element(), )], ), - example_group_with_title( - "No Description", - vec![single_example( - "No Description", - SwitchField::new( - "switch_field_disabled", - "Disabled field", - None, - ToggleState::Selected, - |_, _, _| {}, - ) - .into_any_element(), - )], - ), ]) .into_any_element(), ) diff --git a/crates/ui/src/styles/animation.rs b/crates/ui/src/styles/animation.rs index 0649bee1f8..50c4e0eb0d 100644 --- a/crates/ui/src/styles/animation.rs +++ b/crates/ui/src/styles/animation.rs @@ -109,7 +109,7 @@ impl Component for Animation { fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { let container_size = 128.0; let element_size = 32.0; - let offset = container_size / 2.0 - element_size / 2.0; + let left_offset = element_size - container_size / 2.0; Some( v_flex() .gap_6() @@ -129,7 +129,7 @@ impl Component for Animation { .id("animate-in-from-bottom") .absolute() .size(px(element_size)) - .left(px(offset)) + .left(px(left_offset)) .rounded_md() .bg(gpui::red()) .animate_in(AnimationDirection::FromBottom, false), @@ -148,7 +148,7 @@ impl Component for Animation { .id("animate-in-from-top") .absolute() .size(px(element_size)) - .left(px(offset)) + .left(px(left_offset)) .rounded_md() .bg(gpui::blue()) .animate_in(AnimationDirection::FromTop, false), @@ -167,7 +167,7 @@ impl Component for Animation { .id("animate-in-from-left") .absolute() .size(px(element_size)) - .top(px(offset)) + .left(px(left_offset)) .rounded_md() .bg(gpui::green()) .animate_in(AnimationDirection::FromLeft, false), @@ -186,7 +186,7 @@ impl Component for Animation { .id("animate-in-from-right") .absolute() .size(px(element_size)) - .top(px(offset)) + .left(px(left_offset)) .rounded_md() .bg(gpui::yellow()) .animate_in(AnimationDirection::FromRight, false), @@ -211,7 +211,7 @@ impl Component for Animation { .id("fade-animate-in-from-bottom") .absolute() .size(px(element_size)) - .left(px(offset)) + .left(px(left_offset)) .rounded_md() .bg(gpui::red()) .animate_in(AnimationDirection::FromBottom, true), @@ -230,7 +230,7 @@ impl Component for Animation { .id("fade-animate-in-from-top") .absolute() .size(px(element_size)) - .left(px(offset)) + .left(px(left_offset)) .rounded_md() .bg(gpui::blue()) .animate_in(AnimationDirection::FromTop, true), @@ -249,7 +249,7 @@ impl Component for Animation { .id("fade-animate-in-from-left") .absolute() .size(px(element_size)) - .top(px(offset)) + .left(px(left_offset)) .rounded_md() .bg(gpui::green()) .animate_in(AnimationDirection::FromLeft, true), @@ -268,7 +268,7 @@ impl Component for Animation { .id("fade-animate-in-from-right") .absolute() .size(px(element_size)) - .top(px(offset)) + .left(px(left_offset)) .rounded_md() .bg(gpui::yellow()) .animate_in(AnimationDirection::FromRight, true), diff --git a/crates/ui_prompt/src/ui_prompt.rs b/crates/ui_prompt/src/ui_prompt.rs index fe6dc5b3f4..2b6a030f26 100644 --- a/crates/ui_prompt/src/ui_prompt.rs +++ b/crates/ui_prompt/src/ui_prompt.rs @@ -43,7 +43,7 @@ fn zed_prompt_renderer( let renderer = cx.new({ |cx| ZedPromptRenderer { _level: level, - message: cx.new(|cx| Markdown::new(SharedString::new(message), None, None, cx)), + message: message.to_string(), actions: actions.iter().map(|a| a.label().to_string()).collect(), focus: cx.focus_handle(), active_action_id: 0, @@ -58,7 +58,7 @@ fn zed_prompt_renderer( pub struct ZedPromptRenderer { _level: PromptLevel, - message: Entity<Markdown>, + message: String, actions: Vec<String>, focus: FocusHandle, active_action_id: usize, @@ -114,7 +114,7 @@ impl ZedPromptRenderer { impl Render for ZedPromptRenderer { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); - let font_size = settings.ui_font_size(cx).into(); + let font_family = settings.ui_font.family.clone(); let prompt = v_flex() .key_context("Prompt") .cursor_default() @@ -130,38 +130,24 @@ impl Render for ZedPromptRenderer { .overflow_hidden() .p_4() .gap_4() - .font_family(settings.ui_font.family.clone()) + .font_family(font_family) .child( div() .w_full() - .child(MarkdownElement::new(self.message.clone(), { - let mut base_text_style = window.text_style(); - base_text_style.refine(&TextStyleRefinement { - font_family: Some(settings.ui_font.family.clone()), - font_size: Some(font_size), - font_weight: Some(FontWeight::BOLD), - color: Some(ui::Color::Default.color(cx)), - ..Default::default() - }); - MarkdownStyle { - base_text_style, - selection_background_color: cx - .theme() - .colors() - .element_selection_background, - ..Default::default() - } - })), + .font_weight(FontWeight::BOLD) + .child(self.message.clone()) + .text_color(ui::Color::Default.color(cx)), ) .children(self.detail.clone().map(|detail| { div() .w_full() .text_xs() .child(MarkdownElement::new(detail, { + let settings = ThemeSettings::get_global(cx); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(settings.ui_font.family.clone()), - font_size: Some(font_size), + font_size: Some(settings.ui_font_size(cx).into()), color: Some(ui::Color::Muted.color(cx)), ..Default::default() }); @@ -190,28 +176,24 @@ impl Render for ZedPromptRenderer { }), )); - div() - .size_full() - .occlude() - .bg(gpui::black().opacity(0.2)) - .child( - div() - .size_full() - .absolute() - .top_0() - .left_0() - .flex() - .flex_col() - .justify_around() - .child( - div() - .w_full() - .flex() - .flex_row() - .justify_around() - .child(prompt), - ), - ) + div().size_full().occlude().child( + div() + .size_full() + .absolute() + .top_0() + .left_0() + .flex() + .flex_col() + .justify_around() + .child( + div() + .w_full() + .flex() + .flex_row() + .justify_around() + .child(prompt), + ), + ) } } diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index 7963db3571..23e04cae2c 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -6,7 +6,7 @@ use editor::{ actions::{SortLinesCaseInsensitive, SortLinesCaseSensitive}, display_map::ToDisplayPoint, }; -use gpui::{Action, App, AppContext as _, Context, Global, Keystroke, Window, actions}; +use gpui::{Action, App, AppContext as _, Context, Global, Window, actions}; use itertools::Itertools; use language::Point; use multi_buffer::MultiBufferRow; @@ -202,7 +202,6 @@ actions!( ArgumentRequired ] ); - /// Opens the specified file for editing. #[derive(Clone, PartialEq, Action)] #[action(namespace = vim, no_json, no_register)] @@ -210,13 +209,6 @@ struct VimEdit { pub filename: String, } -#[derive(Clone, PartialEq, Action)] -#[action(namespace = vim, no_json, no_register)] -struct VimNorm { - pub range: Option<CommandRange>, - pub command: String, -} - #[derive(Debug)] struct WrappedAction(Box<dyn Action>); @@ -455,81 +447,6 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { }); }); - Vim::action(editor, cx, |vim, action: &VimNorm, window, cx| { - let keystrokes = action - .command - .chars() - .map(|c| Keystroke::parse(&c.to_string()).unwrap()) - .collect(); - vim.switch_mode(Mode::Normal, true, window, cx); - let initial_selections = vim.update_editor(window, cx, |_, editor, _, _| { - editor.selections.disjoint_anchors() - }); - if let Some(range) = &action.range { - let result = vim.update_editor(window, cx, |vim, editor, window, cx| { - let range = range.buffer_range(vim, editor, window, cx)?; - editor.change_selections( - SelectionEffects::no_scroll().nav_history(false), - window, - cx, - |s| { - s.select_ranges( - (range.start.0..=range.end.0) - .map(|line| Point::new(line, 0)..Point::new(line, 0)), - ); - }, - ); - anyhow::Ok(()) - }); - if let Some(Err(err)) = result { - log::error!("Error selecting range: {}", err); - return; - } - }; - - let Some(workspace) = vim.workspace(window) else { - return; - }; - let task = workspace.update(cx, |workspace, cx| { - workspace.send_keystrokes_impl(keystrokes, window, cx) - }); - let had_range = action.range.is_some(); - - cx.spawn_in(window, async move |vim, cx| { - task.await; - vim.update_in(cx, |vim, window, cx| { - vim.update_editor(window, cx, |_, editor, window, cx| { - if had_range { - editor.change_selections(SelectionEffects::default(), window, cx, |s| { - s.select_anchor_ranges([s.newest_anchor().range()]); - }) - } - }); - if matches!(vim.mode, Mode::Insert | Mode::Replace) { - vim.normal_before(&Default::default(), window, cx); - } else { - vim.switch_mode(Mode::Normal, true, window, cx); - } - vim.update_editor(window, cx, |_, editor, _, cx| { - if let Some(first_sel) = initial_selections { - if let Some(tx_id) = editor - .buffer() - .update(cx, |multi, cx| multi.last_transaction_id(cx)) - { - let last_sel = editor.selections.disjoint_anchors(); - editor.modify_transaction_selection_history(tx_id, |old| { - old.0 = first_sel; - old.1 = Some(last_sel); - }); - } - } - }); - }) - .ok(); - }) - .detach(); - }); - Vim::action(editor, cx, |vim, _: &CountCommand, window, cx| { let Some(workspace) = vim.workspace(window) else { return; @@ -758,15 +675,14 @@ impl VimCommand { } else { return None; }; - - let action = if args.is_empty() { - action - } else { + if !args.is_empty() { // if command does not accept args and we have args then we should do no action - self.args.as_ref()?(action, args)? - }; - - if let Some(range) = range { + if let Some(args_fn) = &self.args { + args_fn.deref()(action, args) + } else { + None + } + } else if let Some(range) = range { self.range.as_ref().and_then(|f| f(action, range)) } else { Some(action) @@ -1145,27 +1061,6 @@ fn generate_commands(_: &App) -> Vec<VimCommand> { save_intent: Some(SaveIntent::Skip), close_pinned: true, }), - VimCommand::new( - ("norm", "al"), - VimNorm { - command: "".into(), - range: None, - }, - ) - .args(|_, args| { - Some( - VimNorm { - command: args, - range: None, - } - .boxed_clone(), - ) - }) - .range(|action, range| { - let mut action: VimNorm = action.as_any().downcast_ref::<VimNorm>().unwrap().clone(); - action.range.replace(range.clone()); - Some(Box::new(action)) - }), VimCommand::new(("bn", "ext"), workspace::ActivateNextItem).count(), VimCommand::new(("bN", "ext"), workspace::ActivatePreviousItem).count(), VimCommand::new(("bp", "revious"), workspace::ActivatePreviousItem).count(), @@ -2403,78 +2298,4 @@ mod test { }); assert!(mark.is_none()) } - - #[gpui::test] - async fn test_normal_command(cx: &mut TestAppContext) { - let mut cx = NeovimBackedTestContext::new(cx).await; - - cx.set_shared_state(indoc! {" - The quick - brown« fox - jumpsˇ» over - the lazy dog - "}) - .await; - - cx.simulate_shared_keystrokes(": n o r m space w C w o r d") - .await; - cx.simulate_shared_keystrokes("enter").await; - - cx.shared_state().await.assert_eq(indoc! {" - The quick - brown word - jumps worˇd - the lazy dog - "}); - - cx.simulate_shared_keystrokes(": n o r m space _ w c i w t e s t") - .await; - cx.simulate_shared_keystrokes("enter").await; - - cx.shared_state().await.assert_eq(indoc! {" - The quick - brown word - jumps tesˇt - the lazy dog - "}); - - cx.simulate_shared_keystrokes("_ l v l : n o r m space s l a") - .await; - cx.simulate_shared_keystrokes("enter").await; - - cx.shared_state().await.assert_eq(indoc! {" - The quick - brown word - lˇaumps test - the lazy dog - "}); - - cx.set_shared_state(indoc! {" - ˇThe quick - brown fox - jumps over - the lazy dog - "}) - .await; - - cx.simulate_shared_keystrokes("c i w M y escape").await; - - cx.shared_state().await.assert_eq(indoc! {" - Mˇy quick - brown fox - jumps over - the lazy dog - "}); - - cx.simulate_shared_keystrokes(": n o r m space u").await; - cx.simulate_shared_keystrokes("enter").await; - - cx.shared_state().await.assert_eq(indoc! {" - ˇThe quick - brown fox - jumps over - the lazy dog - "}); - // Once ctrl-v to input character literals is added there should be a test for redo - } } diff --git a/crates/vim/src/helix.rs b/crates/vim/src/helix.rs index ca93c9c1de..ec9b959b12 100644 --- a/crates/vim/src/helix.rs +++ b/crates/vim/src/helix.rs @@ -1,31 +1,21 @@ -use editor::{DisplayPoint, Editor, SelectionEffects, ToOffset, ToPoint, movement}; +use editor::{DisplayPoint, Editor, movement}; use gpui::{Action, actions}; use gpui::{Context, Window}; use language::{CharClassifier, CharKind}; -use text::{Bias, SelectionGoal}; +use text::SelectionGoal; -use crate::{ - Vim, - motion::{Motion, right}, - state::Mode, -}; +use crate::{Vim, motion::Motion, state::Mode}; actions!( vim, [ /// Switches to normal mode after the cursor (Helix-style). - HelixNormalAfter, - /// Inserts at the beginning of the selection. - HelixInsert, - /// Appends at the end of the selection. - HelixAppend, + HelixNormalAfter ] ); pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { Vim::action(editor, cx, Vim::helix_normal_after); - Vim::action(editor, cx, Vim::helix_insert); - Vim::action(editor, cx, Vim::helix_append); } impl Vim { @@ -309,112 +299,6 @@ impl Vim { _ => self.helix_move_and_collapse(motion, times, window, cx), } } - - fn helix_insert(&mut self, _: &HelixInsert, window: &mut Window, cx: &mut Context<Self>) { - self.start_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { - editor.change_selections(Default::default(), window, cx, |s| { - s.move_with(|_map, selection| { - // In helix normal mode, move cursor to start of selection and collapse - if !selection.is_empty() { - selection.collapse_to(selection.start, SelectionGoal::None); - } - }); - }); - }); - self.switch_mode(Mode::Insert, false, window, cx); - } - - fn helix_append(&mut self, _: &HelixAppend, window: &mut Window, cx: &mut Context<Self>) { - self.start_recording(cx); - self.switch_mode(Mode::Insert, false, window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { - editor.change_selections(Default::default(), window, cx, |s| { - s.move_with(|map, selection| { - let point = if selection.is_empty() { - right(map, selection.head(), 1) - } else { - selection.end - }; - selection.collapse_to(point, SelectionGoal::None); - }); - }); - }); - } - - pub fn helix_replace(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) { - self.update_editor(window, cx, |_, editor, window, cx| { - editor.transact(window, cx, |editor, window, cx| { - let (map, selections) = editor.selections.all_display(cx); - - // Store selection info for positioning after edit - let selection_info: Vec<_> = selections - .iter() - .map(|selection| { - let range = selection.range(); - let start_offset = range.start.to_offset(&map, Bias::Left); - let end_offset = range.end.to_offset(&map, Bias::Left); - let was_empty = range.is_empty(); - let was_reversed = selection.reversed; - ( - map.buffer_snapshot.anchor_at(start_offset, Bias::Left), - end_offset - start_offset, - was_empty, - was_reversed, - ) - }) - .collect(); - - let mut edits = Vec::new(); - for selection in &selections { - let mut range = selection.range(); - - // For empty selections, extend to replace one character - if range.is_empty() { - range.end = movement::saturating_right(&map, range.start); - } - - let byte_range = range.start.to_offset(&map, Bias::Left) - ..range.end.to_offset(&map, Bias::Left); - - if !byte_range.is_empty() { - let replacement_text = text.repeat(byte_range.len()); - edits.push((byte_range, replacement_text)); - } - } - - editor.edit(edits, cx); - - // Restore selections based on original info - let snapshot = editor.buffer().read(cx).snapshot(cx); - let ranges: Vec<_> = selection_info - .into_iter() - .map(|(start_anchor, original_len, was_empty, was_reversed)| { - let start_point = start_anchor.to_point(&snapshot); - if was_empty { - // For cursor-only, collapse to start - start_point..start_point - } else { - // For selections, span the replaced text - let replacement_len = text.len() * original_len; - let end_offset = start_anchor.to_offset(&snapshot) + replacement_len; - let end_point = snapshot.offset_to_point(end_offset); - if was_reversed { - end_point..start_point - } else { - start_point..end_point - } - } - }) - .collect(); - - editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_ranges(ranges); - }); - }); - }); - self.switch_mode(Mode::HelixNormal, true, window, cx); - } } #[cfg(test)] @@ -613,94 +497,4 @@ mod test { cx.assert_state("«ˇaa»\n", Mode::HelixNormal); } - - #[gpui::test] - async fn test_insert_selected(cx: &mut gpui::TestAppContext) { - let mut cx = VimTestContext::new(cx, true).await; - cx.set_state( - indoc! {" - «The ˇ»quick brown - fox jumps over - the lazy dog."}, - Mode::HelixNormal, - ); - - cx.simulate_keystrokes("i"); - - cx.assert_state( - indoc! {" - ˇThe quick brown - fox jumps over - the lazy dog."}, - Mode::Insert, - ); - } - - #[gpui::test] - async fn test_append(cx: &mut gpui::TestAppContext) { - let mut cx = VimTestContext::new(cx, true).await; - // test from the end of the selection - cx.set_state( - indoc! {" - «Theˇ» quick brown - fox jumps over - the lazy dog."}, - Mode::HelixNormal, - ); - - cx.simulate_keystrokes("a"); - - cx.assert_state( - indoc! {" - Theˇ quick brown - fox jumps over - the lazy dog."}, - Mode::Insert, - ); - - // test from the beginning of the selection - cx.set_state( - indoc! {" - «ˇThe» quick brown - fox jumps over - the lazy dog."}, - Mode::HelixNormal, - ); - - cx.simulate_keystrokes("a"); - - cx.assert_state( - indoc! {" - Theˇ quick brown - fox jumps over - the lazy dog."}, - Mode::Insert, - ); - } - - #[gpui::test] - async fn test_replace(cx: &mut gpui::TestAppContext) { - let mut cx = VimTestContext::new(cx, true).await; - - // No selection (single character) - cx.set_state("ˇaa", Mode::HelixNormal); - - cx.simulate_keystrokes("r x"); - - cx.assert_state("ˇxa", Mode::HelixNormal); - - // Cursor at the beginning - cx.set_state("«ˇaa»", Mode::HelixNormal); - - cx.simulate_keystrokes("r x"); - - cx.assert_state("«ˇxx»", Mode::HelixNormal); - - // Cursor at the end - cx.set_state("«aaˇ»", Mode::HelixNormal); - - cx.simulate_keystrokes("r x"); - - cx.assert_state("«xxˇ»", Mode::HelixNormal); - } } diff --git a/crates/vim/src/insert.rs b/crates/vim/src/insert.rs index 0a370e16ba..89c60adee7 100644 --- a/crates/vim/src/insert.rs +++ b/crates/vim/src/insert.rs @@ -21,7 +21,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { } impl Vim { - pub(crate) fn normal_before( + fn normal_before( &mut self, action: &NormalBefore, window: &mut Window, diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index c22cf0ef00..a50b238cc5 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -987,7 +987,7 @@ impl Motion { SelectionGoal::None, ), NextWordEnd { ignore_punctuation } => ( - next_word_end(map, point, *ignore_punctuation, times, true, true), + next_word_end(map, point, *ignore_punctuation, times, true), SelectionGoal::None, ), PreviousWordStart { ignore_punctuation } => ( @@ -1723,19 +1723,14 @@ pub(crate) fn next_word_end( ignore_punctuation: bool, times: usize, allow_cross_newline: bool, - always_advance: bool, ) -> DisplayPoint { let classifier = map .buffer_snapshot .char_classifier_at(point.to_point(map)) .ignore_punctuation(ignore_punctuation); for _ in 0..times { + let new_point = next_char(map, point, allow_cross_newline); let mut need_next_char = false; - let new_point = if always_advance { - next_char(map, point, allow_cross_newline) - } else { - point - }; let new_point = movement::find_boundary_exclusive( map, new_point, diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index 135cdd687f..9485f17477 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -51,7 +51,6 @@ impl Vim { ignore_punctuation, &text_layout_details, motion == Motion::NextSubwordStart { ignore_punctuation }, - !matches!(motion, Motion::NextWordStart { .. }), ) } _ => { @@ -149,7 +148,6 @@ fn expand_changed_word_selection( ignore_punctuation: bool, text_layout_details: &TextLayoutDetails, use_subword: bool, - always_advance: bool, ) -> Option<MotionKind> { let is_in_word = || { let classifier = map @@ -175,14 +173,8 @@ fn expand_changed_word_selection( selection.end = motion::next_subword_end(map, selection.end, ignore_punctuation, 1, false); } else { - selection.end = motion::next_word_end( - map, - selection.end, - ignore_punctuation, - 1, - false, - always_advance, - ); + selection.end = + motion::next_word_end(map, selection.end, ignore_punctuation, 1, false); } selection.end = motion::next_char(map, selection.end, false); } @@ -279,10 +271,6 @@ mod test { cx.simulate("c shift-w", "Test teˇst-test test") .await .assert_matches(); - - // on last character of word, `cw` doesn't eat subsequent punctuation - // see https://github.com/zed-industries/zed/issues/35269 - cx.simulate("c w", "tesˇt-test").await.assert_matches(); } #[gpui::test] diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 2f759ec8af..95a08d7c66 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -747,7 +747,7 @@ impl Vim { Vim::action( editor, cx, - |vim, action: &editor::actions::AcceptEditPrediction, window, cx| { + |vim, action: &editor::AcceptEditPrediction, window, cx| { vim.update_editor(window, cx, |_, editor, window, cx| { editor.accept_edit_prediction(action, window, cx); }); @@ -1639,7 +1639,6 @@ impl Vim { Mode::Visual | Mode::VisualLine | Mode::VisualBlock => { self.visual_replace(text, window, cx) } - Mode::HelixNormal => self.helix_replace(&text, window, cx), _ => self.clear_operator(window, cx), }, Some(Operator::Digraph { first_char }) => { diff --git a/crates/vim/test_data/test_change_w.json b/crates/vim/test_data/test_change_w.json index 149dac8420..27be543532 100644 --- a/crates/vim/test_data/test_change_w.json +++ b/crates/vim/test_data/test_change_w.json @@ -30,7 +30,3 @@ {"Key":"c"} {"Key":"shift-w"} {"Get":{"state":"Test teˇ test","mode":"Insert"}} -{"Put":{"state":"tesˇt-test"}} -{"Key":"c"} -{"Key":"w"} -{"Get":{"state":"tesˇ-test","mode":"Insert"}} diff --git a/crates/vim/test_data/test_normal_command.json b/crates/vim/test_data/test_normal_command.json deleted file mode 100644 index efd1d532c4..0000000000 --- a/crates/vim/test_data/test_normal_command.json +++ /dev/null @@ -1,64 +0,0 @@ -{"Put":{"state":"The quick\nbrown« fox\njumpsˇ» over\nthe lazy dog\n"}} -{"Key":":"} -{"Key":"n"} -{"Key":"o"} -{"Key":"r"} -{"Key":"m"} -{"Key":"space"} -{"Key":"w"} -{"Key":"C"} -{"Key":"w"} -{"Key":"o"} -{"Key":"r"} -{"Key":"d"} -{"Key":"enter"} -{"Get":{"state":"The quick\nbrown word\njumps worˇd\nthe lazy dog\n","mode":"Normal"}} -{"Key":":"} -{"Key":"n"} -{"Key":"o"} -{"Key":"r"} -{"Key":"m"} -{"Key":"space"} -{"Key":"_"} -{"Key":"w"} -{"Key":"c"} -{"Key":"i"} -{"Key":"w"} -{"Key":"t"} -{"Key":"e"} -{"Key":"s"} -{"Key":"t"} -{"Key":"enter"} -{"Get":{"state":"The quick\nbrown word\njumps tesˇt\nthe lazy dog\n","mode":"Normal"}} -{"Key":"_"} -{"Key":"l"} -{"Key":"v"} -{"Key":"l"} -{"Key":":"} -{"Key":"n"} -{"Key":"o"} -{"Key":"r"} -{"Key":"m"} -{"Key":"space"} -{"Key":"s"} -{"Key":"l"} -{"Key":"a"} -{"Key":"enter"} -{"Get":{"state":"The quick\nbrown word\nlˇaumps test\nthe lazy dog\n","mode":"Normal"}} -{"Put":{"state":"ˇThe quick\nbrown fox\njumps over\nthe lazy dog\n"}} -{"Key":"c"} -{"Key":"i"} -{"Key":"w"} -{"Key":"M"} -{"Key":"y"} -{"Key":"escape"} -{"Get":{"state":"Mˇy quick\nbrown fox\njumps over\nthe lazy dog\n","mode":"Normal"}} -{"Key":":"} -{"Key":"n"} -{"Key":"o"} -{"Key":"r"} -{"Key":"m"} -{"Key":"space"} -{"Key":"u"} -{"Key":"enter"} -{"Get":{"state":"ˇThe quick\nbrown fox\njumps over\nthe lazy dog\n","mode":"Normal"}} diff --git a/crates/web_search/Cargo.toml b/crates/web_search/Cargo.toml index 4ba46faec4..e5b8ca63b2 100644 --- a/crates/web_search/Cargo.toml +++ b/crates/web_search/Cargo.toml @@ -13,8 +13,8 @@ path = "src/web_search.rs" [dependencies] anyhow.workspace = true -cloud_llm_client.workspace = true collections.workspace = true gpui.workspace = true serde.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true diff --git a/crates/web_search/src/web_search.rs b/crates/web_search/src/web_search.rs index 8578cfe4aa..a131b0de71 100644 --- a/crates/web_search/src/web_search.rs +++ b/crates/web_search/src/web_search.rs @@ -1,9 +1,8 @@ -use std::sync::Arc; - use anyhow::Result; -use cloud_llm_client::WebSearchResponse; use collections::HashMap; use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task}; +use std::sync::Arc; +use zed_llm_client::WebSearchResponse; pub fn init(cx: &mut App) { let registry = cx.new(|_cx| WebSearchRegistry::default()); diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index f7a248d106..2e052796c4 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,7 +14,6 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true -cloud_llm_client.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true @@ -23,3 +22,4 @@ serde.workspace = true serde_json.workspace = true web_search.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 52ee0da0d4..adf79b0ff6 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; -use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Subscription, Task}; use http_client::{HttpClient, Method}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use web_search::{WebSearchProvider, WebSearchProviderId}; +use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; pub struct CloudWebSearchProvider { state: Entity<State>, diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml index acb3fe0f84..769dd8d6aa 100644 --- a/crates/welcome/Cargo.toml +++ b/crates/welcome/Cargo.toml @@ -29,6 +29,7 @@ project.workspace = true serde.workspace = true settings.workspace = true telemetry.workspace = true +theme.workspace = true ui.workspace = true util.workspace = true vim_mode_setting.workspace = true diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index 352118eee8..49bf2031ab 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -21,6 +21,7 @@ pub use multibuffer_hint::*; mod base_keymap_picker; mod multibuffer_hint; +mod welcome_ui; actions!( welcome, diff --git a/crates/welcome/src/welcome_ui.rs b/crates/welcome/src/welcome_ui.rs new file mode 100644 index 0000000000..622b6f448d --- /dev/null +++ b/crates/welcome/src/welcome_ui.rs @@ -0,0 +1 @@ +mod theme_preview; diff --git a/crates/onboarding/src/theme_preview.rs b/crates/welcome/src/welcome_ui/theme_preview.rs similarity index 72% rename from crates/onboarding/src/theme_preview.rs rename to crates/welcome/src/welcome_ui/theme_preview.rs index d51511b7f4..b3a80c74c3 100644 --- a/crates/onboarding/src/theme_preview.rs +++ b/crates/welcome/src/welcome_ui/theme_preview.rs @@ -11,14 +11,22 @@ use ui::{ #[derive(IntoElement, RegisterComponent, Documented)] pub struct ThemePreviewTile { theme: Arc<Theme>, + selected: bool, seed: f32, } impl ThemePreviewTile { - pub const CORNER_RADIUS: Pixels = px(8.0); + pub fn new(theme: Arc<Theme>, selected: bool, seed: f32) -> Self { + Self { + theme, + selected, + seed, + } + } - pub fn new(theme: Arc<Theme>, seed: f32) -> Self { - Self { theme, seed } + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self } } @@ -26,7 +34,7 @@ impl RenderOnce for ThemePreviewTile { fn render(self, _window: &mut ui::Window, _cx: &mut ui::App) -> impl IntoElement { let color = self.theme.colors(); - let root_radius = Self::CORNER_RADIUS; + let root_radius = px(8.0); let root_border = px(2.0); let root_padding = px(2.0); let child_border = px(1.0); @@ -35,7 +43,7 @@ impl RenderOnce for ThemePreviewTile { let item_skeleton = |w: Length, h: Pixels, bg: Hsla| div().w(w).h(h).rounded_full().bg(bg); - let skeleton_height = px(2.); + let skeleton_height = px(4.); let sidebar_seeded_width = |seed: f32, index: usize| { let value = (seed * 1000.0 + index as f32 * 10.0).sin() * 0.5 + 0.5; @@ -62,10 +70,12 @@ impl RenderOnce for ThemePreviewTile { .border_color(color.border_transparent) .bg(color.panel_background) .child( - v_flex() + div() .p_2() + .flex() + .flex_col() .size_full() - .gap_1() + .gap(px(4.)) .children(sidebar_skeleton), ); @@ -141,19 +151,32 @@ impl RenderOnce for ThemePreviewTile { v_flex() .size_full() .p_1() - .gap_1p5() + .gap(px(6.)) .children(lines) .into_any_element() }; - let pane = v_flex().h_full().flex_grow().child( - div() - .size_full() - .overflow_hidden() - .bg(color.editor_background) - .p_2() - .child(pseudo_code_skeleton(self.theme.clone(), self.seed)), - ); + let pane = div() + .h_full() + .flex_grow() + .flex() + .flex_col() + // .child( + // div() + // .w_full() + // .border_color(color.border) + // .border_b(px(1.)) + // .h(relative(0.1)) + // .bg(color.tab_bar_background), + // ) + .child( + div() + .size_full() + .overflow_hidden() + .bg(color.editor_background) + .p_2() + .child(pseudo_code_skeleton(self.theme.clone(), self.seed)), + ); let content = div().size_full().flex().child(sidebar).child(pane); @@ -161,6 +184,11 @@ impl RenderOnce for ThemePreviewTile { .size_full() .rounded(root_radius) .p(root_padding) + .border(root_border) + .border_color(color.border_transparent) + .when(self.selected, |this| { + this.border_color(color.border_selected) + }) .child( div() .size_full() @@ -202,14 +230,24 @@ impl Component for ThemePreviewTile { .p_4() .children({ if let Some(one_dark) = one_dark.ok() { - vec![example_group(vec![single_example( - "Default", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark.clone(), 0.42)) - .into_any_element(), - )])] + vec![example_group(vec![ + single_example( + "Default", + div() + .w(px(240.)) + .h(px(180.)) + .child(ThemePreviewTile::new(one_dark.clone(), false, 0.42)) + .into_any_element(), + ), + single_example( + "Selected", + div() + .w(px(240.)) + .h(px(180.)) + .child(ThemePreviewTile::new(one_dark, true, 0.42)) + .into_any_element(), + ), + ])] } else { vec![] } @@ -223,11 +261,12 @@ impl Component for ThemePreviewTile { themes_to_preview .iter() .enumerate() - .map(|(_, theme)| { - div() - .w(px(200.)) - .h(px(140.)) - .child(ThemePreviewTile::new(theme.clone(), 0.42)) + .map(|(i, theme)| { + div().w(px(200.)).h(px(140.)).child(ThemePreviewTile::new( + theme.clone(), + false, + 0.42, + )) }) .collect::<Vec<_>>(), ) diff --git a/crates/workspace/src/dock.rs b/crates/workspace/src/dock.rs index ca63d3e553..7165de23ec 100644 --- a/crates/workspace/src/dock.rs +++ b/crates/workspace/src/dock.rs @@ -934,10 +934,6 @@ impl Render for PanelButtons { h_flex() .gap_1() - .when( - has_buttons && dock.position == DockPosition::Bottom, - |this| this.child(Divider::vertical().color(DividerColor::Border)), - ) .children(buttons) .when(has_buttons && dock.position == DockPosition::Left, |this| { this.child(Divider::vertical().color(DividerColor::Border)) diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index ad1c74a040..c7a2562a1b 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -2832,7 +2832,7 @@ impl Pane { }) .collect::<Vec<_>>(); let tab_count = tab_items.len(); - if self.is_tab_pinned(tab_count) { + if self.pinned_tab_count > tab_count { log::warn!( "Pinned tab count ({}) exceeds actual tab count ({}). \ This should not happen. If possible, add reproduction steps, \ @@ -3030,7 +3030,7 @@ impl Pane { || cfg!(not(target_os = "macos")) && window.modifiers().control; let from_pane = dragged_tab.pane.clone(); - + let from_ix = dragged_tab.ix; self.workspace .update(cx, |_, cx| { cx.defer_in(window, move |workspace, window, cx| { @@ -3062,13 +3062,9 @@ impl Pane { } to_pane.update(cx, |this, _| { if to_pane == from_pane { - let actual_ix = this - .items - .iter() - .position(|item| item.item_id() == item_id) - .unwrap_or(0); - - let is_pinned_in_to_pane = this.is_tab_pinned(actual_ix); + let moved_right = ix > from_ix; + let ix = if moved_right { ix - 1 } else { ix }; + let is_pinned_in_to_pane = this.is_tab_pinned(ix); if !was_pinned_in_from_pane && is_pinned_in_to_pane { this.pinned_tab_count += 1; @@ -4954,43 +4950,6 @@ mod tests { assert_item_labels(&pane_a, ["B!", "A*!"], cx); } - #[gpui::test] - async fn test_dragging_pinned_tab_onto_unpinned_tab_reduces_unpinned_tab_count( - cx: &mut TestAppContext, - ) { - init_test(cx); - let fs = FakeFs::new(cx.executor()); - - let project = Project::test(fs, None, cx).await; - let (workspace, cx) = - cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let pane_a = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone()); - - // Add A, B to pane A and pin A - let item_a = add_labeled_item(&pane_a, "A", false, cx); - add_labeled_item(&pane_a, "B", false, cx); - pane_a.update_in(cx, |pane, window, cx| { - let ix = pane.index_for_item_id(item_a.item_id()).unwrap(); - pane.pin_tab_at(ix, window, cx); - }); - assert_item_labels(&pane_a, ["A!", "B*"], cx); - - // Drag pinned A on top of B in the same pane, which changes tab order to B, A - pane_a.update_in(cx, |pane, window, cx| { - let dragged_tab = DraggedTab { - pane: pane_a.clone(), - item: item_a.boxed_clone(), - ix: 0, - detail: 0, - is_active: true, - }; - pane.handle_tab_drop(&dragged_tab, 1, window, cx); - }); - - // Neither are pinned - assert_item_labels(&pane_a, ["B", "A*"], cx); - } - #[gpui::test] async fn test_drag_pinned_tab_beyond_unpinned_tab_in_same_pane_becomes_unpinned( cx: &mut TestAppContext, diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 6fa5c969e7..3f8b098203 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -939,26 +939,6 @@ impl WorkspaceDb { } } - query! { - pub async fn update_ssh_project_paths_query(ssh_project_id: u64, paths: String) -> Result<Option<SerializedSshProject>> { - UPDATE ssh_projects - SET paths = ?2 - WHERE id = ?1 - RETURNING id, host, port, paths, user - } - } - - pub(crate) async fn update_ssh_project_paths( - &self, - ssh_project_id: SshProjectId, - new_paths: Vec<String>, - ) -> Result<SerializedSshProject> { - let paths = serde_json::to_string(&new_paths)?; - self.update_ssh_project_paths_query(ssh_project_id.0, paths) - .await? - .context("failed to update ssh project paths") - } - query! { pub async fn next_id() -> Result<WorkspaceId> { INSERT INTO workspaces DEFAULT VALUES RETURNING workspace_id @@ -2644,56 +2624,4 @@ mod tests { assert_eq!(workspace.center_group, new_workspace.center_group); } - - #[gpui::test] - async fn test_update_ssh_project_paths() { - zlog::init_test(); - - let db = WorkspaceDb::open_test_db("test_update_ssh_project_paths").await; - - let (host, port, initial_paths, user) = ( - "example.com".to_string(), - Some(22_u16), - vec!["/home/user".to_string(), "/etc/nginx".to_string()], - Some("user".to_string()), - ); - - let project = db - .get_or_create_ssh_project(host.clone(), port, initial_paths.clone(), user.clone()) - .await - .unwrap(); - - assert_eq!(project.host, host); - assert_eq!(project.paths, initial_paths); - assert_eq!(project.user, user); - - let new_paths = vec![ - "/home/user".to_string(), - "/etc/nginx".to_string(), - "/var/log".to_string(), - "/opt/app".to_string(), - ]; - - let updated_project = db - .update_ssh_project_paths(project.id, new_paths.clone()) - .await - .unwrap(); - - assert_eq!(updated_project.id, project.id); - assert_eq!(updated_project.paths, new_paths); - - let retrieved_project = db - .get_ssh_project( - host.clone(), - port, - serde_json::to_string(&new_paths).unwrap(), - user.clone(), - ) - .await - .unwrap() - .unwrap(); - - assert_eq!(retrieved_project.id, project.id); - assert_eq!(retrieved_project.paths, new_paths); - } } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 6f7db668dd..52502c1aa8 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -32,7 +32,7 @@ use futures::{ mpsc::{self, UnboundedReceiver, UnboundedSender}, oneshot, }, - future::{Shared, try_join_all}, + future::try_join_all, }; use gpui::{ Action, AnyEntity, AnyView, AnyWeakView, App, AsyncApp, AsyncWindowContext, Bounds, Context, @@ -87,7 +87,7 @@ use std::{ borrow::Cow, cell::RefCell, cmp, - collections::{VecDeque, hash_map::DefaultHasher}, + collections::hash_map::DefaultHasher, env, hash::{Hash, Hasher}, path::{Path, PathBuf}, @@ -1043,13 +1043,6 @@ type PromptForOpenPath = Box< ) -> oneshot::Receiver<Option<Vec<PathBuf>>>, >; -#[derive(Default)] -struct DispatchingKeystrokes { - dispatched: HashSet<Vec<Keystroke>>, - queue: VecDeque<Keystroke>, - task: Option<Shared<Task<()>>>, -} - /// Collects everything project-related for a certain window opened. /// In some way, is a counterpart of a window, as the [`WindowHandle`] could be downcast into `Workspace`. /// @@ -1065,6 +1058,7 @@ pub struct Workspace { center: PaneGroup, left_dock: Entity<Dock>, bottom_dock: Entity<Dock>, + bottom_dock_layout: BottomDockLayout, right_dock: Entity<Dock>, panes: Vec<Entity<Pane>>, panes_by_item: HashMap<EntityId, WeakEntity<Pane>>, @@ -1086,12 +1080,11 @@ pub struct Workspace { leader_updates_tx: mpsc::UnboundedSender<(PeerId, proto::UpdateFollowers)>, database_id: Option<WorkspaceId>, app_state: Arc<AppState>, - dispatching_keystrokes: Rc<RefCell<DispatchingKeystrokes>>, + dispatching_keystrokes: Rc<RefCell<(HashSet<String>, Vec<Keystroke>)>>, _subscriptions: Vec<Subscription>, _apply_leader_updates: Task<Result<()>>, _observe_current_user: Task<Result<()>>, - _schedule_serialize_workspace: Option<Task<()>>, - _schedule_serialize_ssh_paths: Option<Task<()>>, + _schedule_serialize: Option<Task<()>>, pane_history_timestamp: Arc<AtomicUsize>, bounds: Bounds<Pixels>, pub centered_layout: bool, @@ -1150,8 +1143,6 @@ impl Workspace { project::Event::WorktreeRemoved(_) | project::Event::WorktreeAdded(_) => { this.update_window_title(window, cx); - this.update_ssh_paths(cx); - this.serialize_ssh_paths(window, cx); this.serialize_workspace(window, cx); // This event could be triggered by `AddFolderToProject` or `RemoveFromProject`. this.update_history(cx); @@ -1309,6 +1300,7 @@ impl Workspace { ) .detach(); + let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; let left_dock = Dock::new(DockPosition::Left, modal_layer.clone(), window, cx); let bottom_dock = Dock::new(DockPosition::Bottom, modal_layer.clone(), window, cx); let right_dock = Dock::new(DockPosition::Right, modal_layer.clone(), window, cx); @@ -1407,6 +1399,7 @@ impl Workspace { suppressed_notifications: HashSet::default(), left_dock, bottom_dock, + bottom_dock_layout, right_dock, project: project.clone(), follower_states: Default::default(), @@ -1419,8 +1412,7 @@ impl Workspace { app_state, _observe_current_user, _apply_leader_updates, - _schedule_serialize_workspace: None, - _schedule_serialize_ssh_paths: None, + _schedule_serialize: None, leader_updates_tx, _subscriptions: subscriptions, pane_history_timestamp, @@ -1634,6 +1626,10 @@ impl Workspace { &self.bottom_dock } + pub fn bottom_dock_layout(&self) -> BottomDockLayout { + self.bottom_dock_layout + } + pub fn set_bottom_dock_layout( &mut self, layout: BottomDockLayout, @@ -1645,6 +1641,7 @@ impl Workspace { content.bottom_dock_layout = Some(layout); }); + self.bottom_dock_layout = layout; cx.notify(); self.serialize_workspace(window, cx); } @@ -2316,65 +2313,49 @@ impl Workspace { window: &mut Window, cx: &mut Context<Self>, ) { - let keystrokes: Vec<Keystroke> = action + let mut state = self.dispatching_keystrokes.borrow_mut(); + if !state.0.insert(action.0.clone()) { + cx.propagate(); + return; + } + let mut keystrokes: Vec<Keystroke> = action .0 .split(' ') .flat_map(|k| Keystroke::parse(k).log_err()) .collect(); - let _ = self.send_keystrokes_impl(keystrokes, window, cx); - } + keystrokes.reverse(); - pub fn send_keystrokes_impl( - &mut self, - keystrokes: Vec<Keystroke>, - window: &mut Window, - cx: &mut Context<Self>, - ) -> Shared<Task<()>> { - let mut state = self.dispatching_keystrokes.borrow_mut(); - if !state.dispatched.insert(keystrokes.clone()) { - cx.propagate(); - return state.task.clone().unwrap(); - } - - state.queue.extend(keystrokes); + state.1.append(&mut keystrokes); + drop(state); let keystrokes = self.dispatching_keystrokes.clone(); - if state.task.is_none() { - state.task = Some( - window - .spawn(cx, async move |cx| { - // limit to 100 keystrokes to avoid infinite recursion. - for _ in 0..100 { - let mut state = keystrokes.borrow_mut(); - let Some(keystroke) = state.queue.pop_front() else { - state.dispatched.clear(); - state.task.take(); - return; - }; - drop(state); - cx.update(|window, cx| { - let focused = window.focused(cx); - window.dispatch_keystroke(keystroke.clone(), cx); - if window.focused(cx) != focused { - // dispatch_keystroke may cause the focus to change. - // draw's side effect is to schedule the FocusChanged events in the current flush effect cycle - // And we need that to happen before the next keystroke to keep vim mode happy... - // (Note that the tests always do this implicitly, so you must manually test with something like: - // "bindings": { "g z": ["workspace::SendKeystrokes", ": j <enter> u"]} - // ) - window.draw(cx).clear(); - } - }) - .ok(); + window + .spawn(cx, async move |cx| { + // limit to 100 keystrokes to avoid infinite recursion. + for _ in 0..100 { + let Some(keystroke) = keystrokes.borrow_mut().1.pop() else { + keystrokes.borrow_mut().0.clear(); + return Ok(()); + }; + cx.update(|window, cx| { + let focused = window.focused(cx); + window.dispatch_keystroke(keystroke.clone(), cx); + if window.focused(cx) != focused { + // dispatch_keystroke may cause the focus to change. + // draw's side effect is to schedule the FocusChanged events in the current flush effect cycle + // And we need that to happen before the next keystroke to keep vim mode happy... + // (Note that the tests always do this implicitly, so you must manually test with something like: + // "bindings": { "g z": ["workspace::SendKeystrokes", ": j <enter> u"]} + // ) + window.draw(cx).clear(); } + })?; + } - *keystrokes.borrow_mut() = Default::default(); - log::error!("over 100 keystrokes passed to send_keystrokes"); - }) - .shared(), - ); - } - state.task.clone().unwrap() + *keystrokes.borrow_mut() = Default::default(); + anyhow::bail!("over 100 keystrokes passed to send_keystrokes"); + }) + .detach_and_log_err(cx); } fn save_all_internal( @@ -5077,46 +5058,6 @@ impl Workspace { } } - fn update_ssh_paths(&mut self, cx: &App) { - let project = self.project().read(cx); - if !project.is_local() { - let paths: Vec<String> = project - .visible_worktrees(cx) - .map(|worktree| worktree.read(cx).abs_path().to_string_lossy().to_string()) - .collect(); - if let Some(ssh_project) = &mut self.serialized_ssh_project { - ssh_project.paths = paths; - } - } - } - - fn serialize_ssh_paths(&mut self, window: &mut Window, cx: &mut Context<Workspace>) { - if self._schedule_serialize_ssh_paths.is_none() { - self._schedule_serialize_ssh_paths = - Some(cx.spawn_in(window, async move |this, cx| { - cx.background_executor() - .timer(SERIALIZATION_THROTTLE_TIME) - .await; - this.update_in(cx, |this, window, cx| { - let task = if let Some(ssh_project) = &this.serialized_ssh_project { - let ssh_project_id = ssh_project.id; - let ssh_project_paths = ssh_project.paths.clone(); - window.spawn(cx, async move |_| { - persistence::DB - .update_ssh_project_paths(ssh_project_id, ssh_project_paths) - .await - }) - } else { - Task::ready(Err(anyhow::anyhow!("No SSH project to serialize"))) - }; - task.detach(); - this._schedule_serialize_ssh_paths.take(); - }) - .log_err(); - })); - } - } - fn remove_panes(&mut self, member: Member, window: &mut Window, cx: &mut Context<Workspace>) { match member { Member::Axis(PaneAxis { members, .. }) => { @@ -5160,18 +5101,17 @@ impl Workspace { } fn serialize_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) { - if self._schedule_serialize_workspace.is_none() { - self._schedule_serialize_workspace = - Some(cx.spawn_in(window, async move |this, cx| { - cx.background_executor() - .timer(SERIALIZATION_THROTTLE_TIME) - .await; - this.update_in(cx, |this, window, cx| { - this.serialize_workspace_internal(window, cx).detach(); - this._schedule_serialize_workspace.take(); - }) - .log_err(); - })); + if self._schedule_serialize.is_none() { + self._schedule_serialize = Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(Duration::from_millis(100)) + .await; + this.update_in(cx, |this, window, cx| { + this.serialize_workspace_internal(window, cx).detach(); + this._schedule_serialize.take(); + }) + .log_err(); + })); } } @@ -5734,6 +5674,7 @@ impl Workspace { let client = project.read(cx).client(); let user_store = project.read(cx).user_store(); + let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx)); window.activate_window(); @@ -6282,7 +6223,6 @@ impl Render for Workspace { .iter() .map(|(_, notification)| notification.entity_id()) .collect::<Vec<_>>(); - let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; client_side_decorations( self.actions(div(), window, cx) @@ -6406,7 +6346,7 @@ impl Render for Workspace { )) }) .child({ - match bottom_dock_layout { + match self.bottom_dock_layout { BottomDockLayout::Full => div() .flex() .flex_col() @@ -6938,13 +6878,10 @@ async fn join_channel_internal( match status { Status::Connecting | Status::Authenticating - | Status::Authenticated | Status::Reconnecting | Status::Reauthenticating => continue, Status::Connected { .. } => break 'outer, - Status::SignedOut | Status::AuthenticationError => { - return Err(ErrorCode::SignedOut.into()); - } + Status::SignedOut => return Err(ErrorCode::SignedOut.into()), Status::UpgradeRequired => return Err(ErrorCode::UpgradeRequired.into()), Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { return Err(ErrorCode::Disconnected.into()); diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index e6949f62df..4fc6b91abb 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -62,7 +62,7 @@ use std::{ }, time::{Duration, Instant}, }; -use sum_tree::{Bias, Edit, KeyedItem, SeekTarget, SumTree, Summary, TreeMap, TreeSet}; +use sum_tree::{Bias, Edit, KeyedItem, SeekTarget, SumTree, Summary, TreeMap, TreeSet, Unit}; use text::{LineEnding, Rope}; use util::{ ResultExt, @@ -407,12 +407,12 @@ struct LocalRepositoryEntry { } impl sum_tree::Item for LocalRepositoryEntry { - type Summary = PathSummary<&'static ()>; + type Summary = PathSummary<Unit>; fn summary(&self, _: &<Self::Summary as Summary>::Context) -> Self::Summary { PathSummary { max_path: self.work_directory.path_key().0, - item_summary: &(), + item_summary: Unit, } } } @@ -425,6 +425,12 @@ impl KeyedItem for LocalRepositoryEntry { } } +//impl LocalRepositoryEntry { +// pub fn repo(&self) -> &Arc<dyn GitRepository> { +// &self.repo_ptr +// } +//} + impl Deref for LocalRepositoryEntry { type Target = WorkDirectory; @@ -5411,7 +5417,7 @@ impl<'a> SeekTarget<'a, EntrySummary, TraversalProgress<'a>> for TraversalTarget } } -impl<'a> SeekTarget<'a, PathSummary<&'static ()>, TraversalProgress<'a>> for TraversalTarget<'_> { +impl<'a> SeekTarget<'a, PathSummary<Unit>, TraversalProgress<'a>> for TraversalTarget<'_> { fn cmp(&self, cursor_location: &TraversalProgress<'a>, _: &()) -> Ordering { self.cmp_progress(cursor_location) } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 536af7b7b9..3e8c169a83 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." edition.workspace = true name = "zed" -version = "0.199.0" +version = "0.197.5" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team <hi@zed.dev>"] @@ -56,7 +56,6 @@ env_logger.workspace = true extension.workspace = true extension_host.workspace = true extensions_ui.workspace = true -feature_flags.workspace = true feedback.workspace = true file_finder.workspace = true fs.workspace = true @@ -106,7 +105,6 @@ outline_panel.workspace = true parking_lot.workspace = true paths.workspace = true picker.workspace = true -settings_profile_selector.workspace = true profiling.workspace = true project.workspace = true project_panel.workspace = true diff --git a/crates/zed/RELEASE_CHANNEL b/crates/zed/RELEASE_CHANNEL index 38f8e886e1..870bbe4e50 100644 --- a/crates/zed/RELEASE_CHANNEL +++ b/crates/zed/RELEASE_CHANNEL @@ -1 +1 @@ -dev +stable \ No newline at end of file diff --git a/crates/zed/resources/app-icon-nightly.png b/crates/zed/resources/app-icon-nightly.png index 776cd06b1b..5f1304a6af 100644 Binary files a/crates/zed/resources/app-icon-nightly.png and b/crates/zed/resources/app-icon-nightly.png differ diff --git a/crates/zed/resources/app-icon-nightly@2x.png b/crates/zed/resources/app-icon-nightly@2x.png index 6d781594ac..edb416ede4 100644 Binary files a/crates/zed/resources/app-icon-nightly@2x.png and b/crates/zed/resources/app-icon-nightly@2x.png differ diff --git a/crates/zed/resources/windows/zed.iss b/crates/zed/resources/windows/zed.iss index 2e76f35a0b..9d104d1f15 100644 --- a/crates/zed/resources/windows/zed.iss +++ b/crates/zed/resources/windows/zed.iss @@ -62,7 +62,6 @@ Source: "{#ResourcesDir}\Zed.exe"; DestDir: "{code:GetInstallDir}"; Flags: ignor Source: "{#ResourcesDir}\bin\*"; DestDir: "{code:GetInstallDir}\bin"; Flags: ignoreversion Source: "{#ResourcesDir}\tools\*"; DestDir: "{app}\tools"; Flags: ignoreversion Source: "{#ResourcesDir}\appx\*"; DestDir: "{app}\appx"; BeforeInstall: RemoveAppxPackage; AfterInstall: AddAppxPackage; Flags: ignoreversion; Check: IsWindows11OrLater -Source: "{#ResourcesDir}\amd_ags_x64.dll"; DestDir: "{app}"; Flags: ignoreversion [Icons] Name: "{group}\{#AppName}"; Filename: "{app}\{#AppExeName}.exe"; AppUserModelID: "{#AppUserId}" @@ -1246,6 +1245,16 @@ Root: HKCU; Subkey: "Software\Classes\zed\DefaultIcon"; ValueType: "string"; Val Root: HKCU; Subkey: "Software\Classes\zed\shell\open\command"; ValueType: "string"; ValueData: """{app}\Zed.exe"" ""%1""" [Code] +function InitializeSetup(): Boolean; +begin + Result := True; + + if not WizardSilent() and IsAdmin() then begin + MsgBox('This User Installer is not meant to be run as an Administrator.', mbError, MB_OK); + Result := False; + end; +end; + function WizardNotSilent(): Boolean; begin Result := not WizardSilent(); diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index c264135e5c..d0b9c53397 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -42,7 +42,7 @@ use theme::{ ActiveTheme, IconThemeNotFoundError, SystemAppearance, ThemeNotFoundError, ThemeRegistry, ThemeSettings, }; -use util::{ResultExt, TryFutureExt, maybe}; +use util::{ConnectionResult, ResultExt, TryFutureExt, maybe}; use uuid::Uuid; use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::{ @@ -613,7 +613,6 @@ pub fn main() { language_selector::init(cx); toolchain_selector::init(cx); theme_selector::init(cx); - settings_profile_selector::init(cx); language_tools::init(cx); call::init(app_state.client.clone(), app_state.user_store.clone(), cx); notifications::init(app_state.client.clone(), app_state.user_store.clone(), cx); @@ -682,9 +681,17 @@ pub fn main() { cx.spawn({ let client = app_state.client.clone(); - async move |cx| authenticate(client, &cx).await + async move |cx| match authenticate(client, &cx).await { + ConnectionResult::Timeout => log::error!("Timeout during initial auth"), + ConnectionResult::ConnectionReset => { + log::error!("Connection reset during initial auth") + } + ConnectionResult::Result(r) => { + r.log_err(); + } + } }) - .detach_and_log_err(cx); + .detach(); let urls: Vec<_> = args .paths_or_urls @@ -834,7 +841,15 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut let client = app_state.client.clone(); // we continue even if authentication fails as join_channel/ open channel notes will // show a visible error message. - authenticate(client, &cx).await.log_err(); + match authenticate(client, &cx).await { + ConnectionResult::Timeout => { + log::error!("Timeout during open request handling") + } + ConnectionResult::ConnectionReset => { + log::error!("Connection reset during open request handling") + } + ConnectionResult::Result(r) => r?, + }; if let Some(channel_id) = request.join_channel { cx.update(|cx| { @@ -884,18 +899,18 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut } } -async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> Result<()> { +async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> ConnectionResult<()> { if stdout_is_a_pty() { if client::IMPERSONATE_LOGIN.is_some() { - client.sign_in_with_optional_connect(false, cx).await?; + return client.authenticate_and_connect(false, cx).await; } else if client.has_credentials(cx).await { - client.sign_in_with_optional_connect(true, cx).await?; + return client.authenticate_and_connect(true, cx).await; } } else if client.has_credentials(cx).await { - client.sign_in_with_optional_connect(true, cx).await?; + return client.authenticate_and_connect(true, cx).await; } - Ok(()) + ConnectionResult::Result(Ok(())) } async fn system_id() -> Result<IdType> { diff --git a/crates/zed/src/reliability.rs b/crates/zed/src/reliability.rs index d7f1473288..ccbe57e7b3 100644 --- a/crates/zed/src/reliability.rs +++ b/crates/zed/src/reliability.rs @@ -63,7 +63,7 @@ pub fn init_panic_hook( location.column(), match app_commit_sha.as_ref() { Some(commit_sha) => format!( - "https://github.com/zed-industries/zed/blob/{}/{}#L{} \ + "https://github.com/zed-industries/zed/blob/{}/src/{}#L{} \ (may not be uploaded, line may be incorrect if files modified)\n", commit_sha.full(), location.file(), diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index af317edeee..24c7ab5ba2 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -19,7 +19,6 @@ use collections::VecDeque; use debugger_ui::debugger_panel::DebugPanel; use editor::ProposedChangesEditorToolbar; use editor::{Editor, MultiBuffer}; -use feature_flags::{FeatureFlagAppExt, PanicFeatureFlag}; use futures::future::Either; use futures::{StreamExt, channel::mpsc, select_biased}; use git_ui::git_panel::GitPanel; @@ -54,12 +53,9 @@ use settings::{ initial_local_debug_tasks_content, initial_project_settings_content, initial_tasks_content, update_settings_file, }; -use std::{ - borrow::Cow, - path::{Path, PathBuf}, - sync::Arc, - sync::atomic::{self, AtomicBool}, -}; +use std::path::PathBuf; +use std::sync::atomic::{self, AtomicBool}; +use std::{borrow::Cow, path::Path, sync::Arc}; use terminal_view::terminal_panel::{self, TerminalPanel}; use theme::{ActiveTheme, ThemeSettings}; use ui::{PopoverMenuHandle, prelude::*}; @@ -111,8 +107,6 @@ actions!( Zoom, /// Triggers a test panic for debugging. TestPanic, - /// Triggers a hard crash for debugging. - TestCrash, ] ); @@ -126,28 +120,11 @@ pub fn init(cx: &mut App) { cx.on_action(quit); cx.on_action(|_: &RestoreBanner, cx| title_bar::restore_banner(cx)); - let flag = cx.wait_for_flag::<PanicFeatureFlag>(); - cx.spawn(async |cx| { - if cx - .update(|cx| ReleaseChannel::global(cx) == ReleaseChannel::Dev) - .unwrap_or_default() - || flag.await - { - cx.update(|cx| { - cx.on_action(|_: &TestPanic, _| panic!("Ran the TestPanic action")); - cx.on_action(|_: &TestCrash, _| { - unsafe extern "C" { - fn puts(s: *const i8); - } - unsafe { - puts(0xabad1d3a as *const i8); - } - }); - }) - .ok(); - }; - }) - .detach(); + + if ReleaseChannel::global(cx) == ReleaseChannel::Dev { + cx.on_action(test_panic); + } + cx.on_action(|_: &OpenLog, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { open_log_file(workspace, window, cx); @@ -1010,6 +987,10 @@ fn about( .detach(); } +fn test_panic(_: &TestPanic, _: &mut App) { + panic!("Ran the TestPanic action") +} + fn install_cli( _: &mut Workspace, _: &install_cli::Install, @@ -4354,7 +4335,6 @@ mod tests { "menu", "notebook", "notification_panel", - "onboarding", "outline", "outline_panel", "pane", @@ -4367,7 +4347,6 @@ mod tests { "repl", "rules_library", "search", - "settings_profile_selector", "snippets", "supermaven", "svg", diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index 15d5659f03..78532b10b4 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -24,10 +24,6 @@ pub fn app_menus() -> Vec<Menu> { zed_actions::OpenDefaultKeymap, ), MenuItem::action("Open Project Settings", super::OpenProjectSettings), - MenuItem::action( - "Select Settings Profile...", - zed_actions::settings_profile_selector::Toggle, - ), MenuItem::action( "Select Theme...", zed_actions::theme_selector::Toggle::default(), diff --git a/crates/zed/src/zed/component_preview.rs b/crates/zed/src/zed/component_preview.rs index 480505338b..670793cff3 100644 --- a/crates/zed/src/zed/component_preview.rs +++ b/crates/zed/src/zed/component_preview.rs @@ -105,7 +105,6 @@ enum PreviewPage { struct ComponentPreview { active_page: PreviewPage, active_thread: Option<Entity<ActiveThread>>, - reset_key: usize, component_list: ListState, component_map: HashMap<ComponentId, ComponentMetadata>, components: Vec<ComponentMetadata>, @@ -139,7 +138,8 @@ impl ComponentPreview { let project_clone = project.clone(); cx.spawn_in(window, async move |entity, cx| { - let thread_store_future = load_preview_thread_store(project_clone.clone(), cx); + let thread_store_future = + load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx); let text_thread_store_future = load_preview_text_thread_store(workspace_clone.clone(), project_clone.clone(), cx); @@ -188,7 +188,6 @@ impl ComponentPreview { let mut component_preview = Self { active_page, active_thread: None, - reset_key: 0, component_list, component_map: component_registry.component_map(), components: sorted_components, @@ -266,13 +265,8 @@ impl ComponentPreview { } fn set_active_page(&mut self, page: PreviewPage, cx: &mut Context<Self>) { - if self.active_page == page { - // Force the current preview page to render again - self.reset_key = self.reset_key.wrapping_add(1); - } else { - self.active_page = page; - cx.emit(ItemEvent::UpdateTab); - } + self.active_page = page; + cx.emit(ItemEvent::UpdateTab); cx.notify(); } @@ -696,7 +690,6 @@ impl ComponentPreview { component.clone(), self.workspace.clone(), self.active_thread.clone(), - self.reset_key, )) .into_any_element() } else { @@ -1048,7 +1041,6 @@ pub struct ComponentPreviewPage { component: ComponentMetadata, workspace: WeakEntity<Workspace>, active_thread: Option<Entity<ActiveThread>>, - reset_key: usize, } impl ComponentPreviewPage { @@ -1056,7 +1048,6 @@ impl ComponentPreviewPage { component: ComponentMetadata, workspace: WeakEntity<Workspace>, active_thread: Option<Entity<ActiveThread>>, - reset_key: usize, // languages: Arc<LanguageRegistry> ) -> Self { Self { @@ -1064,7 +1055,6 @@ impl ComponentPreviewPage { component, workspace, active_thread, - reset_key, } } @@ -1165,7 +1155,6 @@ impl ComponentPreviewPage { }; v_flex() - .id(("component-preview", self.reset_key)) .size_full() .flex_1() .px_12() diff --git a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs index de98106fae..825744572d 100644 --- a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs +++ b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs @@ -12,19 +12,21 @@ use ui::{App, Window}; use workspace::Workspace; pub fn load_preview_thread_store( + workspace: WeakEntity<Workspace>, project: Entity<Project>, cx: &mut AsyncApp, ) -> Task<Result<Entity<ThreadStore>>> { - cx.update(|cx| { - ThreadStore::load( - project.clone(), - cx.new(|_| ToolWorkingSet::default()), - None, - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - }) - .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) + workspace + .update(cx, |_, cx| { + ThreadStore::load( + project.clone(), + cx.new(|_| ToolWorkingSet::default()), + None, + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) } pub fn load_preview_text_thread_store( diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index bbecd26417..52b7166a11 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -1,10 +1,10 @@ -use client::{Client, UserStore}; +use client::{Client, DisableAiSettings, UserStore}; use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::Editor; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use language::language_settings::{EditPredictionProvider, all_language_settings}; -use settings::SettingsStore; +use settings::{Settings as _, SettingsStore}; use smol::stream::StreamExt; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; @@ -90,7 +90,10 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) { let new_provider = all_language_settings(None, cx).edit_predictions.provider; if new_provider != provider { - let tos_accepted = user_store.read(cx).has_accepted_terms_of_service(); + let tos_accepted = user_store + .read(cx) + .current_user_has_accepted_terms() + .unwrap_or(false); telemetry::event!( "Edit Prediction Provider Changed", @@ -192,6 +195,18 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed }, )) .detach(); + if !DisableAiSettings::get_global(cx).disable_ai { + editor + .register_action(cx.listener( + |editor, + _: &editor::actions::AcceptPartialCopilotSuggestion, + window: &mut Window, + cx: &mut Context<Editor>| { + editor.accept_partial_inline_completion(&Default::default(), window, cx); + }, + )) + .detach(); + } } fn assign_edit_prediction_provider( @@ -229,7 +244,7 @@ fn assign_edit_prediction_provider( } } EditPredictionProvider::Zed => { - if user_store.read(cx).current_user().is_some() { + if client.status().borrow().is_connected() { let mut worktree = None; if let Some(buffer) = &singleton_buffer { diff --git a/crates/zed/src/zed/quick_action_bar.rs b/crates/zed/src/zed/quick_action_bar.rs index 1164704ce6..aff124a0bc 100644 --- a/crates/zed/src/zed/quick_action_bar.rs +++ b/crates/zed/src/zed/quick_action_bar.rs @@ -192,7 +192,7 @@ impl Render for QuickActionBar { }; v_flex() .child( - IconButton::new("toggle_code_actions_icon", IconName::BoltOutlined) + IconButton::new("toggle_code_actions_icon", IconName::Bolt) .icon_size(IconSize::Small) .style(ButtonStyle::Subtle) .disabled(!has_available_code_actions) diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index 64891b6973..4b4bf016c4 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -260,25 +260,14 @@ pub mod icon_theme_selector { } } -pub mod settings_profile_selector { - use gpui::Action; - use schemars::JsonSchema; - use serde::Deserialize; - - #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] - #[action(namespace = settings_profile_selector)] - pub struct Toggle; -} - pub mod agent { use gpui::actions; actions!( agent, [ - /// Opens the agent settings panel. - #[action(deprecated_aliases = ["agent::OpenConfiguration"])] - OpenSettings, + /// Opens the agent configuration panel. + OpenConfiguration, /// Opens the agent onboarding modal. OpenOnboardingModal, /// Resets the agent onboarding state. diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 26eeda3f22..c2b1de08ae 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -21,7 +21,6 @@ ai_onboarding.workspace = true anyhow.workspace = true arrayvec.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true copilot.workspace = true @@ -40,6 +39,7 @@ log.workspace = true menu.workspace = true postage.workspace = true project.workspace = true +proto.workspace = true regex.workspace = true release_channel.workspace = true serde.workspace = true @@ -52,17 +52,16 @@ thiserror.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true -workspace-hack.workspace = true workspace.workspace = true worktree.workspace = true zed_actions.workspace = true +zed_llm_client.workspace = true +workspace-hack.workspace = true [dev-dependencies] -call = { workspace = true, features = ["test-support"] } +collections = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } -cloud_api_types.workspace = true -collections = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } @@ -78,4 +77,5 @@ tree-sitter-rust.workspace = true unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } +call = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index f130c3a965..d6f033899d 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -17,10 +17,6 @@ pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; -use cloud_llm_client::{ - AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, -}; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; use gpui::{ @@ -57,6 +53,10 @@ use uuid::Uuid; use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use worktree::Worktree; +use zed_llm_client::{ + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, + PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, +}; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>"; @@ -121,10 +121,9 @@ impl Dismissable for ZedPredictUpsell { } pub fn should_show_upsell_modal(user_store: &Entity<UserStore>, cx: &App) -> bool { - if user_store.read(cx).has_accepted_terms_of_service() { - !ZedPredictUpsell::dismissed() - } else { - true + match user_store.read(cx).current_user_has_accepted_terms() { + Some(true) => !ZedPredictUpsell::dismissed(), + Some(false) | None => true, } } @@ -146,14 +145,14 @@ pub struct InlineCompletion { input_events: Arc<str>, input_excerpt: Arc<str>, output_excerpt: Arc<str>, - buffer_snapshotted_at: Instant, + request_sent_at: Instant, response_received_at: Instant, } impl InlineCompletion { fn latency(&self) -> Duration { self.response_received_at - .duration_since(self.buffer_snapshotted_at) + .duration_since(self.request_sent_at) } fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> { @@ -227,9 +226,12 @@ pub struct Zeta { data_collection_choice: Entity<DataCollectionChoice>, llm_token: LlmApiToken, _llm_token_subscription: Subscription, + /// Whether the terms of service have been accepted. + tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, user_store: Entity<UserStore>, + _user_store_subscription: Subscription, license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>, } @@ -304,7 +306,22 @@ impl Zeta { .detach_and_log_err(cx); }, ), + tos_accepted: user_store + .read(cx) + .current_user_has_accepted_terms() + .unwrap_or(false), update_required: false, + _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { + match event { + client::user::Event::PrivateUserInfoUpdated => { + this.tos_accepted = user_store + .read(cx) + .current_user_has_accepted_terms() + .unwrap_or(false); + } + _ => {} + } + }), license_detection_watchers: HashMap::default(), user_store, } @@ -391,48 +408,104 @@ impl Zeta { + Send + 'static, { - let buffer = buffer.clone(); - let buffer_snapshotted_at = Instant::now(); let snapshot = self.report_changes_for_buffer(&buffer, cx); - let zeta = cx.entity(); + let diagnostic_groups = snapshot.diagnostic_groups(None); + let cursor_point = cursor.to_point(&snapshot); + let cursor_offset = cursor_point.to_offset(&snapshot); let events = self.events.clone(); + let path: Arc<Path> = snapshot + .file() + .map(|f| Arc::from(f.full_path(cx).as_path())) + .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + + let zeta = cx.entity(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let full_path: Arc<Path> = snapshot - .file() - .map(|f| Arc::from(f.full_path(cx).as_path())) - .unwrap_or_else(|| Arc::from(Path::new("untitled"))); - let full_path_str = full_path.to_string_lossy().to_string(); - let cursor_point = cursor.to_point(&snapshot); - let cursor_offset = cursor_point.to_offset(&snapshot); - let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS); - let gather_task = gather_context( - project, - full_path_str, - &snapshot, - cursor_point, - make_events_prompt, - can_collect_data, - cx, - ); + let buffer = buffer.clone(); + + let local_lsp_store = + project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); + let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store { + Some( + diagnostic_groups + .into_iter() + .filter_map(|(language_server_id, diagnostic_group)| { + let language_server = + local_lsp_store.running_language_server_for_id(language_server_id)?; + + Some(( + language_server.name(), + diagnostic_group.resolve::<usize>(&snapshot), + )) + }) + .collect::<Vec<_>>(), + ) + } else { + None + }; cx.spawn(async move |this, cx| { - let GatherContextOutput { - body, - editable_range, - } = gather_task.await?; + let request_sent_at = Instant::now(); + + struct BackgroundValues { + input_events: String, + input_excerpt: String, + speculated_output: String, + editable_range: Range<usize>, + input_outline: String, + } + + let values = cx + .background_spawn({ + let snapshot = snapshot.clone(); + let path = path.clone(); + async move { + let path = path.to_string_lossy(); + let input_excerpt = excerpt_for_cursor_position( + cursor_point, + &path, + &snapshot, + MAX_REWRITE_TOKENS, + MAX_CONTEXT_TOKENS, + ); + let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS); + let input_outline = prompt_for_outline(&snapshot); + + anyhow::Ok(BackgroundValues { + input_events, + input_excerpt: input_excerpt.prompt, + speculated_output: input_excerpt.speculated_output, + editable_range: input_excerpt.editable_range.to_offset(&snapshot), + input_outline, + }) + } + }) + .await?; log::debug!( "Events:\n{}\nExcerpt:\n{:?}", - body.input_events, - body.input_excerpt + values.input_events, + values.input_excerpt ); - let input_outline = body.outline.clone().unwrap_or_default(); - let input_events = body.input_events.clone(); - let input_excerpt = body.input_excerpt.clone(); + let body = PredictEditsBody { + input_events: values.input_events.clone(), + input_excerpt: values.input_excerpt.clone(), + speculated_output: Some(values.speculated_output), + outline: Some(values.input_outline.clone()), + can_collect_data, + diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| { + diagnostic_groups + .into_iter() + .map(|(name, diagnostic_group)| { + Ok((name.to_string(), serde_json::to_value(diagnostic_group)?)) + }) + .collect::<Result<Vec<_>>>() + .log_err() + }), + }; let response = perform_predict_edits(PerformPredictEditsParams { client, @@ -490,13 +563,13 @@ impl Zeta { response, buffer, &snapshot, - editable_range, + values.editable_range, cursor_offset, - full_path, - input_outline, - input_events, - input_excerpt, - buffer_snapshotted_at, + path, + values.input_outline, + values.input_events, + values.input_excerpt, + request_sent_at, &cx, ) .await @@ -695,7 +768,7 @@ and then another ) } - pub fn perform_predict_edits( + fn perform_predict_edits( params: PerformPredictEditsParams, ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> { async move { @@ -850,7 +923,7 @@ and then another input_outline: String, input_events: String, input_excerpt: String, - buffer_snapshotted_at: Instant, + request_sent_at: Instant, cx: &AsyncApp, ) -> Task<Result<Option<InlineCompletion>>> { let snapshot = snapshot.clone(); @@ -896,7 +969,7 @@ and then another input_events: input_events.into(), input_excerpt: input_excerpt.into(), output_excerpt, - buffer_snapshotted_at, + request_sent_at, response_received_at: Instant::now(), })) }) @@ -1080,7 +1153,7 @@ and then another } } -pub struct PerformPredictEditsParams { +struct PerformPredictEditsParams { pub client: Arc<Client>, pub llm_token: LlmApiToken, pub app_version: SemanticVersion, @@ -1155,77 +1228,6 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: .sum() } -pub struct GatherContextOutput { - pub body: PredictEditsBody, - pub editable_range: Range<usize>, -} - -pub fn gather_context( - project: Option<&Entity<Project>>, - full_path_str: String, - snapshot: &BufferSnapshot, - cursor_point: language::Point, - make_events_prompt: impl FnOnce() -> String + Send + 'static, - can_collect_data: bool, - cx: &App, -) -> Task<Result<GatherContextOutput>> { - let local_lsp_store = - project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); - let diagnostic_groups: Vec<(String, serde_json::Value)> = - if let Some(local_lsp_store) = local_lsp_store { - snapshot - .diagnostic_groups(None) - .into_iter() - .filter_map(|(language_server_id, diagnostic_group)| { - let language_server = - local_lsp_store.running_language_server_for_id(language_server_id)?; - let diagnostic_group = diagnostic_group.resolve::<usize>(&snapshot); - let language_server_name = language_server.name().to_string(); - let serialized = serde_json::to_value(diagnostic_group).unwrap(); - Some((language_server_name, serialized)) - }) - .collect::<Vec<_>>() - } else { - Vec::new() - }; - - cx.background_spawn({ - let snapshot = snapshot.clone(); - async move { - let diagnostic_groups = if diagnostic_groups.is_empty() { - None - } else { - Some(diagnostic_groups) - }; - - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &full_path_str, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let input_events = make_events_prompt(); - let input_outline = prompt_for_outline(&snapshot); - let editable_range = input_excerpt.editable_range.to_offset(&snapshot); - - let body = PredictEditsBody { - input_events, - input_excerpt: input_excerpt.prompt, - speculated_output: Some(input_excerpt.speculated_output), - outline: Some(input_outline), - can_collect_data, - diagnostic_groups, - }; - - Ok(GatherContextOutput { - body, - editable_range, - }) - } - }) -} - fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { let mut input_outline = String::new(); @@ -1276,7 +1278,7 @@ struct RegisteredBuffer { } #[derive(Clone)] -pub enum Event { +enum Event { BufferChange { old_snapshot: BufferSnapshot, new_snapshot: BufferSnapshot, @@ -1571,12 +1573,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } fn needs_terms_acceptance(&self, cx: &App) -> bool { - !self - .zeta - .read(cx) - .user_store - .read(cx) - .has_accepted_terms_of_service() + !self.zeta.read(cx).tos_accepted } fn is_refreshing(&self) -> bool { @@ -1591,7 +1588,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider _debounce: bool, cx: &mut Context<Self>, ) { - if self.needs_terms_acceptance(cx) { + if !self.zeta.read(cx).tos_accepted { return; } @@ -1603,7 +1600,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider .zeta .read(cx) .user_store - .read_with(cx, |user_store, _cx| { + .read_with(cx, |user_store, _| { user_store.account_too_young() || user_store.has_overdue_invoices() }) { @@ -1820,14 +1817,13 @@ fn tokens_for_bytes(bytes: usize) -> usize { #[cfg(test)] mod tests { - use client::UserStore; use client::test::FakeServer; use clock::FakeSystemClock; - use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; use language::Point; + use rpc::proto; use settings::SettingsStore; use super::*; @@ -1860,7 +1856,7 @@ mod tests { input_events: "".into(), input_excerpt: "".into(), output_excerpt: "".into(), - buffer_snapshotted_at: Instant::now(), + request_sent_at: Instant::now(), response_received_at: Instant::now(), }; @@ -2031,45 +2027,28 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |req| async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&CreateLlmTokenResponse { - token: LlmToken("the-llm-token".to_string()), - }) - .unwrap() - .into(), - ) - .unwrap()), - (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()), - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), - } + let http_client = FakeHttpClient::create(move |_| async move { + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") + .unwrap(), + output_excerpt: completion_response.to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()) }); let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - // Construct the fake server to authenticate. - let _server = FakeServer::for_client(42, &client, cx).await; + let server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -2077,6 +2056,13 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); + server.receive::<proto::GetUsers>().await.unwrap(); + let token_request = server.receive::<proto::GetLlmToken>().await.unwrap(); + server.respond( + token_request.receipt(), + proto::GetLlmTokenResponse { token: "".into() }, + ); + let completion = completion_task.await.unwrap().unwrap(); buffer.update(cx, |buffer, cx| { buffer.edit(completion.edits.iter().cloned(), None, cx) @@ -2093,36 +2079,20 @@ mod tests { cx: &mut TestAppContext, ) -> Vec<(Range<Point>, String)> { let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |req| { + let http_client = FakeHttpClient::create(move |_| { let completion = completion_response.clone(); async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&CreateLlmTokenResponse { - token: LlmToken("the-llm-token".to_string()), - }) - .unwrap() - .into(), - ) - .unwrap()), - (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()), - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), - } + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion, + }) + .unwrap() + .into(), + ) + .unwrap()) } }); @@ -2130,10 +2100,9 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - // Construct the fake server to authenticate. - let _server = FakeServer::for_client(42, &client, cx).await; + let server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); @@ -2142,6 +2111,13 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); + server.receive::<proto::GetUsers>().await.unwrap(); + let token_request = server.receive::<proto::GetLlmToken>().await.unwrap(); + server.respond( + token_request.receipt(), + proto::GetLlmTokenResponse { token: "".into() }, + ); + let completion = completion_task.await.unwrap().unwrap(); completion .edits diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml deleted file mode 100644 index e77351c219..0000000000 --- a/crates/zeta_cli/Cargo.toml +++ /dev/null @@ -1,45 +0,0 @@ -[package] -name = "zeta_cli" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[[bin]] -name = "zeta" -path = "src/main.rs" - -[dependencies] -anyhow.workspace = true -clap.workspace = true -client.workspace = true -debug_adapter_extension.workspace = true -extension.workspace = true -fs.workspace = true -futures.workspace = true -gpui.workspace = true -gpui_tokio.workspace = true -language.workspace = true -language_extension.workspace = true -language_model.workspace = true -language_models.workspace = true -languages = { workspace = true, features = ["load-grammars"] } -node_runtime.workspace = true -paths.workspace = true -project.workspace = true -prompt_store.workspace = true -release_channel.workspace = true -reqwest_client.workspace = true -serde.workspace = true -serde_json.workspace = true -settings.workspace = true -shellexpand.workspace = true -terminal_view.workspace = true -util.workspace = true -watch.workspace = true -workspace-hack.workspace = true -zeta.workspace = true -smol.workspace = true diff --git a/crates/zeta_cli/LICENSE-GPL b/crates/zeta_cli/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/zeta_cli/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta_cli/build.rs b/crates/zeta_cli/build.rs deleted file mode 100644 index ccbb54c5b4..0000000000 --- a/crates/zeta_cli/build.rs +++ /dev/null @@ -1,14 +0,0 @@ -fn main() { - let cargo_toml = - std::fs::read_to_string("../zed/Cargo.toml").expect("Failed to read Cargo.toml"); - let version = cargo_toml - .lines() - .find(|line| line.starts_with("version = ")) - .expect("Version not found in crates/zed/Cargo.toml") - .split('=') - .nth(1) - .expect("Invalid version format") - .trim() - .trim_matches('"'); - println!("cargo:rustc-env=ZED_PKG_VERSION={}", version); -} diff --git a/crates/zeta_cli/src/headless.rs b/crates/zeta_cli/src/headless.rs deleted file mode 100644 index 959bb91a8f..0000000000 --- a/crates/zeta_cli/src/headless.rs +++ /dev/null @@ -1,128 +0,0 @@ -use client::{Client, ProxySettings, UserStore}; -use extension::ExtensionHostProxy; -use fs::RealFs; -use gpui::http_client::read_proxy_from_env; -use gpui::{App, AppContext, Entity}; -use gpui_tokio::Tokio; -use language::LanguageRegistry; -use language_extension::LspAccess; -use node_runtime::{NodeBinaryOptions, NodeRuntime}; -use project::Project; -use project::project_settings::ProjectSettings; -use release_channel::AppVersion; -use reqwest_client::ReqwestClient; -use settings::{Settings, SettingsStore}; -use std::path::PathBuf; -use std::sync::Arc; -use util::ResultExt as _; - -/// Headless subset of `workspace::AppState`. -pub struct ZetaCliAppState { - pub languages: Arc<LanguageRegistry>, - pub client: Arc<Client>, - pub user_store: Entity<UserStore>, - pub fs: Arc<dyn fs::Fs>, - pub node_runtime: NodeRuntime, -} - -// TODO: dedupe with crates/eval/src/eval.rs -pub fn init(cx: &mut App) -> ZetaCliAppState { - let app_version = AppVersion::load(env!("ZED_PKG_VERSION")); - release_channel::init(app_version, cx); - gpui_tokio::init(cx); - - let mut settings_store = SettingsStore::new(cx); - settings_store - .set_default_settings(settings::default_settings().as_ref(), cx) - .unwrap(); - cx.set_global(settings_store); - client::init_settings(cx); - - // Set User-Agent so we can download language servers from GitHub - let user_agent = format!( - "Zed/{} ({}; {})", - app_version, - std::env::consts::OS, - std::env::consts::ARCH - ); - let proxy_str = ProxySettings::get_global(cx).proxy.to_owned(); - let proxy_url = proxy_str - .as_ref() - .and_then(|input| input.parse().ok()) - .or_else(read_proxy_from_env); - let http = { - let _guard = Tokio::handle(cx).enter(); - - ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent) - .expect("could not start HTTP client") - }; - cx.set_http_client(Arc::new(http)); - - Project::init_settings(cx); - - let client = Client::production(cx); - cx.set_http_client(client.http_client()); - - let git_binary_path = None; - let fs = Arc::new(RealFs::new( - git_binary_path, - cx.background_executor().clone(), - )); - - let mut languages = LanguageRegistry::new(cx.background_executor().clone()); - languages.set_language_server_download_dir(paths::languages_dir().clone()); - let languages = Arc::new(languages); - - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - - extension::init(cx); - - let (mut tx, rx) = watch::channel(None); - cx.observe_global::<SettingsStore>(move |cx| { - let settings = &ProjectSettings::get_global(cx).node; - let options = NodeBinaryOptions { - allow_path_lookup: !settings.ignore_system_version, - allow_binary_download: true, - use_paths: settings.path.as_ref().map(|node_path| { - let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref()); - let npm_path = settings - .npm_path - .as_ref() - .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref())); - ( - node_path.clone(), - npm_path.unwrap_or_else(|| { - let base_path = PathBuf::new(); - node_path.parent().unwrap_or(&base_path).join("npm") - }), - ) - }), - }; - tx.send(Some(options)).log_err(); - }) - .detach(); - let node_runtime = NodeRuntime::new(client.http_client(), None, rx); - - let extension_host_proxy = ExtensionHostProxy::global(cx); - - language::init(cx); - debug_adapter_extension::init(extension_host_proxy.clone(), cx); - language_extension::init( - LspAccess::Noop, - extension_host_proxy.clone(), - languages.clone(), - ); - language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); - languages::init(languages.clone(), node_runtime.clone(), cx); - prompt_store::init(cx); - terminal_view::init(cx); - - ZetaCliAppState { - languages, - client, - user_store, - fs, - node_runtime, - } -} diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs deleted file mode 100644 index c5374b56c9..0000000000 --- a/crates/zeta_cli/src/main.rs +++ /dev/null @@ -1,376 +0,0 @@ -mod headless; - -use anyhow::{Result, anyhow}; -use clap::{Args, Parser, Subcommand}; -use futures::channel::mpsc; -use futures::{FutureExt as _, StreamExt as _}; -use gpui::{AppContext, Application, AsyncApp}; -use gpui::{Entity, Task}; -use language::Bias; -use language::Buffer; -use language::Point; -use language_model::LlmApiToken; -use project::{Project, ProjectPath}; -use release_channel::AppVersion; -use reqwest_client::ReqwestClient; -use std::path::{Path, PathBuf}; -use std::process::exit; -use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; -use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; - -use crate::headless::ZetaCliAppState; - -#[derive(Parser, Debug)] -#[command(name = "zeta")] -struct ZetaCliArgs { - #[command(subcommand)] - command: Commands, -} - -#[derive(Subcommand, Debug)] -enum Commands { - Context(ContextArgs), - Predict { - #[arg(long)] - predict_edits_body: Option<FileOrStdin>, - #[clap(flatten)] - context_args: Option<ContextArgs>, - }, -} - -#[derive(Debug, Args)] -#[group(requires = "worktree")] -struct ContextArgs { - #[arg(long)] - worktree: PathBuf, - #[arg(long)] - cursor: CursorPosition, - #[arg(long)] - use_language_server: bool, - #[arg(long)] - events: Option<FileOrStdin>, -} - -#[derive(Debug, Clone)] -enum FileOrStdin { - File(PathBuf), - Stdin, -} - -impl FileOrStdin { - async fn read_to_string(&self) -> Result<String, std::io::Error> { - match self { - FileOrStdin::File(path) => smol::fs::read_to_string(path).await, - FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await, - } - } -} - -impl FromStr for FileOrStdin { - type Err = <PathBuf as FromStr>::Err; - - fn from_str(s: &str) -> Result<Self, Self::Err> { - match s { - "-" => Ok(Self::Stdin), - _ => Ok(Self::File(PathBuf::from_str(s)?)), - } - } -} - -#[derive(Debug, Clone)] -struct CursorPosition { - path: PathBuf, - point: Point, -} - -impl FromStr for CursorPosition { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result<Self> { - let parts: Vec<&str> = s.split(':').collect(); - if parts.len() != 3 { - return Err(anyhow!( - "Invalid cursor format. Expected 'file.rs:line:column', got '{}'", - s - )); - } - - let path = PathBuf::from(parts[0]); - let line: u32 = parts[1] - .parse() - .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?; - let column: u32 = parts[2] - .parse() - .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?; - - // Convert from 1-based to 0-based indexing - let point = Point::new(line.saturating_sub(1), column.saturating_sub(1)); - - Ok(CursorPosition { path, point }) - } -} - -async fn get_context( - args: ContextArgs, - app_state: &Arc<ZetaCliAppState>, - cx: &mut AsyncApp, -) -> Result<GatherContextOutput> { - let ContextArgs { - worktree: worktree_path, - cursor, - use_language_server, - events, - } = args; - - let worktree_path = worktree_path.canonicalize()?; - if cursor.path.is_absolute() { - return Err(anyhow!("Absolute paths are not supported in --cursor")); - } - - let (project, _lsp_open_handle, buffer) = if use_language_server { - let (project, lsp_open_handle, buffer) = - open_buffer_with_language_server(&worktree_path, &cursor.path, &app_state, cx).await?; - (Some(project), Some(lsp_open_handle), buffer) - } else { - let abs_path = worktree_path.join(&cursor.path); - let content = smol::fs::read_to_string(&abs_path).await?; - let buffer = cx.new(|cx| Buffer::local(content, cx))?; - (None, None, buffer) - }; - - let worktree_name = worktree_path - .file_name() - .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?; - let full_path_str = PathBuf::from(worktree_name) - .join(&cursor.path) - .to_string_lossy() - .to_string(); - - let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?; - let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left); - if clipped_cursor != cursor.point { - let max_row = snapshot.max_point().row; - if cursor.point.row < max_row { - return Err(anyhow!( - "Cursor position {:?} is out of bounds (line length is {})", - cursor.point, - snapshot.line_len(cursor.point.row) - )); - } else { - return Err(anyhow!( - "Cursor position {:?} is out of bounds (max row is {})", - cursor.point, - max_row - )); - } - } - - let events = match events { - Some(events) => events.read_to_string().await?, - None => String::new(), - }; - let can_collect_data = false; - cx.update(|cx| { - gather_context( - project.as_ref(), - full_path_str, - &snapshot, - clipped_cursor, - move || events, - can_collect_data, - cx, - ) - })? - .await -} - -pub async fn open_buffer_with_language_server( - worktree_path: &Path, - path: &Path, - app_state: &Arc<ZetaCliAppState>, - cx: &mut AsyncApp, -) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> { - let project = cx.update(|cx| { - Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - cx, - ) - })?; - - let worktree = project - .update(cx, |project, cx| { - project.create_worktree(worktree_path, true, cx) - })? - .await?; - - let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath { - worktree_id: worktree.id(), - path: path.to_path_buf().into(), - })?; - - let buffer = project - .update(cx, |project, cx| project.open_buffer(project_path, cx))? - .await?; - - let lsp_open_handle = project.update(cx, |project, cx| { - project.register_buffer_with_language_servers(&buffer, cx) - })?; - - let log_prefix = path.to_string_lossy().to_string(); - wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; - - Ok((project, lsp_open_handle, buffer)) -} - -// TODO: Dedupe with similar function in crates/eval/src/instance.rs -pub fn wait_for_lang_server( - project: &Entity<Project>, - buffer: &Entity<Buffer>, - log_prefix: String, - cx: &mut AsyncApp, -) -> Task<Result<()>> { - println!("{}⏵ Waiting for language server", log_prefix); - - let (mut tx, mut rx) = mpsc::channel(1); - - let lsp_store = project - .read_with(cx, |project, _| project.lsp_store()) - .unwrap(); - - let has_lang_server = buffer - .update(cx, |buffer, cx| { - lsp_store.update(cx, |lsp_store, cx| { - lsp_store - .language_servers_for_local_buffer(&buffer, cx) - .next() - .is_some() - }) - }) - .unwrap_or(false); - - if has_lang_server { - project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) - .unwrap() - .detach(); - } - - let subscriptions = [ - cx.subscribe(&lsp_store, { - let log_prefix = log_prefix.clone(); - move |_, event, _| match event { - project::LspStoreEvent::LanguageServerUpdate { - message: - client::proto::update_language_server::Variant::WorkProgress( - client::proto::LspWorkProgress { - message: Some(message), - .. - }, - ), - .. - } => println!("{}⟲ {message}", log_prefix), - _ => {} - } - }), - cx.subscribe(&project, { - let buffer = buffer.clone(); - move |project, event, cx| match event { - project::Event::LanguageServerAdded(_, _, _) => { - let buffer = buffer.clone(); - project - .update(cx, |project, cx| project.save_buffer(buffer, cx)) - .detach(); - } - project::Event::DiskBasedDiagnosticsFinished { .. } => { - tx.try_send(()).ok(); - } - _ => {} - } - }), - ]; - - cx.spawn(async move |cx| { - let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); - let result = futures::select! { - _ = rx.next() => { - println!("{}⚑ Language server idle", log_prefix); - anyhow::Ok(()) - }, - _ = timeout.fuse() => { - anyhow::bail!("LSP wait timed out after 5 minutes"); - } - }; - drop(subscriptions); - result - }) -} - -fn main() { - let args = ZetaCliArgs::parse(); - let http_client = Arc::new(ReqwestClient::new()); - let app = Application::headless().with_http_client(http_client); - - app.run(move |cx| { - let app_state = Arc::new(headless::init(cx)); - cx.spawn(async move |cx| { - let result = match args.command { - Commands::Context(context_args) => get_context(context_args, &app_state, cx) - .await - .map(|output| serde_json::to_string_pretty(&output.body).unwrap()), - Commands::Predict { - predict_edits_body, - context_args, - } => { - cx.spawn(async move |cx| { - let app_version = cx.update(|cx| AppVersion::global(cx))?; - app_state.client.sign_in(true, cx).await?; - let llm_token = LlmApiToken::default(); - llm_token.refresh(&app_state.client).await?; - - let predict_edits_body = - if let Some(predict_edits_body) = predict_edits_body { - serde_json::from_str(&predict_edits_body.read_to_string().await?)? - } else if let Some(context_args) = context_args { - get_context(context_args, &app_state, cx).await?.body - } else { - return Err(anyhow!( - "Expected either --predict-edits-body-file \ - or the required args of the `context` command." - )); - }; - - let (response, _usage) = - Zeta::perform_predict_edits(PerformPredictEditsParams { - client: app_state.client.clone(), - llm_token, - app_version, - body: predict_edits_body, - }) - .await?; - - Ok(response.output_excerpt) - }) - .await - } - }; - match result { - Ok(output) => { - println!("{}", output); - let _ = cx.update(|cx| cx.quit()); - } - Err(e) => { - eprintln!("Failed: {:?}", e); - exit(1); - } - } - }) - .detach(); - }); -} diff --git a/crates/zlog/src/sink.rs b/crates/zlog/src/sink.rs index 17aa08026e..acf0469c77 100644 --- a/crates/zlog/src/sink.rs +++ b/crates/zlog/src/sink.rs @@ -21,8 +21,6 @@ const ANSI_MAGENTA: &str = "\x1b[35m"; /// Whether stdout output is enabled. static mut ENABLED_SINKS_STDOUT: bool = false; -/// Whether stderr output is enabled. -static mut ENABLED_SINKS_STDERR: bool = false; /// Is Some(file) if file output is enabled. static ENABLED_SINKS_FILE: Mutex<Option<std::fs::File>> = Mutex::new(None); @@ -47,12 +45,6 @@ pub fn init_output_stdout() { } } -pub fn init_output_stderr() { - unsafe { - ENABLED_SINKS_STDERR = true; - } -} - pub fn init_output_file( path: &'static PathBuf, path_rotate: Option<&'static PathBuf>, @@ -123,21 +115,6 @@ pub fn submit(record: Record) { }, record.message ); - } else if unsafe { ENABLED_SINKS_STDERR } { - let mut stdout = std::io::stderr().lock(); - _ = writeln!( - &mut stdout, - "{} {ANSI_BOLD}{}{}{ANSI_RESET} {} {}", - chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%:z"), - LEVEL_ANSI_COLORS[record.level as usize], - LEVEL_OUTPUT_STRINGS[record.level as usize], - SourceFmt { - scope: record.scope, - module_path: record.module_path, - ansi: true, - }, - record.message - ); } let mut file = ENABLED_SINKS_FILE.lock().unwrap_or_else(|handle| { ENABLED_SINKS_FILE.clear_poison(); diff --git a/crates/zlog/src/zlog.rs b/crates/zlog/src/zlog.rs index 5b40278f3f..570c82314c 100644 --- a/crates/zlog/src/zlog.rs +++ b/crates/zlog/src/zlog.rs @@ -5,7 +5,7 @@ mod env_config; pub mod filter; pub mod sink; -pub use sink::{flush, init_output_file, init_output_stderr, init_output_stdout}; +pub use sink::{flush, init_output_file, init_output_stdout}; pub const SCOPE_DEPTH_MAX: usize = 4; diff --git a/docs/README.md b/docs/README.md index a225903674..55993c9e36 100644 --- a/docs/README.md +++ b/docs/README.md @@ -69,64 +69,3 @@ Templates are just functions that modify the source of the docs pages (usually w - Template Trait: crates/docs_preprocessor/src/templates.rs - Example template: crates/docs_preprocessor/src/templates/keybinding.rs - Client-side plugins: docs/theme/plugins.js - -## Postprocessor - -A postprocessor is implemented as a sub-command of `docs_preprocessor` that wraps the builtin `html` renderer and applies post-processing to the `html` files, to add support for page-specific title and meta description values. - -An example of the syntax can be found in `git.md`, as well as below - -```md ---- -title: Some more detailed title for this page -description: A page-specific description ---- - -# Editor -``` - -The above will be transformed into (with non-relevant tags removed) - -```html -<head> - <title>Editor | Some more detailed title for this page - - - -

Editor

- -``` - -If no front-matter is provided, or If one or both keys aren't provided, the title and description will be set based on the `default-title` and `default-description` keys in `book.toml` respectively. - -### Implementation details - -Unfortunately, `mdbook` does not support post-processing like it does pre-processing, and only supports defining one description to put in the meta tag per book rather than per file. So in order to apply post-processing (necessary to modify the html head tags) the global book description is set to a marker value `#description#` and the html renderer is replaced with a sub-command of `docs_preprocessor` that wraps the builtin `html` renderer and applies post-processing to the `html` files, replacing the marker value and the `(.*)` with the contents of the front-matter if there is one. - -### Known limitations - -The front-matter parsing is extremely simple, which avoids needing to take on an additional dependency, or implement full yaml parsing. - -- Double quotes and multi-line values are not supported, i.e. Keys and values must be entirely on the same line, with no double quotes around the value. - -The following will not work: - -```md ---- -title: Some - Multi-line - Title ---- -``` - -And neither will: - -```md ---- -title: "Some title" ---- -``` - -- The front-matter must be at the top of the file, with only white-space preceding it - -- The contents of the title and description will not be html-escaped. They should be simple ascii text with no unicode or emoji characters diff --git a/docs/book.toml b/docs/book.toml index 60ddc5ac51..f5d186f377 100644 --- a/docs/book.toml +++ b/docs/book.toml @@ -6,27 +6,13 @@ src = "src" title = "Zed" site-url = "/docs/" -[build] -extra-watch-dirs = ["../crates/docs_preprocessor"] - -# zed-html is a "custom" renderer that just wraps the -# builtin mdbook html renderer, and applies post-processing -# as post-processing is not possible with mdbook in the same way -# pre-processing is -# The config is passed directly to the html renderer, so all config -# options that apply to html apply to zed-html -[output.zed-html] -command = "cargo run -p docs_preprocessor -- postprocess" -# Set here instead of above as we only use it replace the `#description#` we set in the template -# when no front-matter is provided value -default-description = "Learn how to use and customize Zed, the fast, collaborative code editor. Official docs on features, configuration, AI tools, and workflows." -default-title = "Zed Code Editor Documentation" +[output.html] no-section-label = true preferred-dark-theme = "dark" additional-css = ["theme/page-toc.css", "theme/plugins.css", "theme/highlight.css"] additional-js = ["theme/page-toc.js", "theme/plugins.js"] -[output.zed-html.print] +[output.html.print] enable = false # Redirects for `/docs` pages. @@ -38,7 +24,7 @@ enable = false # The destination URLs are interpreted relative to `https://zed.dev`. # - Redirects to other docs pages should end in `.html` # - You can link to pages on the Zed site by omitting the `/docs` in front of it. -[output.zed-html.redirect] +[output.html.redirect] # AI "/ai.html" = "/docs/ai/overview.html" "/assistant-panel.html" = "/docs/ai/agent-panel.html" @@ -54,7 +40,6 @@ enable = false "/assistant/prompting.html" = "/docs/ai/rules.html" "/language-model-integration.html" = "/docs/assistant/assistant.html" "/model-improvement.html" = "/docs/ai/ai-improvement.html" -"/ai/temperature.html" = "/docs/ai/agent-settings.html#model-temperature" # Community "/community/feedback.html" = "/community-links" diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index fc936d6bd0..1d43872547 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -45,14 +45,13 @@ - [Overview](./ai/overview.md) - [Agent Panel](./ai/agent-panel.md) - [Tools](./ai/tools.md) + - [Model Temperature](./ai/temperature.md) - [Inline Assistant](./ai/inline-assistant.md) - [Edit Prediction](./ai/edit-prediction.md) - [Text Threads](./ai/text-threads.md) - [Rules](./ai/rules.md) - [Model Context Protocol](./ai/mcp.md) - [Configuration](./ai/configuration.md) - - [LLM Providers](./ai/llm-providers.md) - - [Agent Settings](./ai/agent-settings.md) - [Subscription](./ai/subscription.md) - [Plans and Usage](./ai/plans-and-usage.md) - [Billing](./ai/billing.md) diff --git a/docs/src/accounts.md b/docs/src/accounts.md index 1ce23cf902..c13c98ad9a 100644 --- a/docs/src/accounts.md +++ b/docs/src/accounts.md @@ -5,7 +5,7 @@ Signing in to Zed is not a requirement. You can use most features you'd expect i ## What Features Require Signing In? 1. All real-time [collaboration features](./collaboration.md). -2. [LLM-powered features](./ai/overview.md), if you are using Zed as the provider of your LLM models. Alternatively, you can [bring and configure your own API keys](./ai/llm-providers.md#use-your-own-keys) if you'd prefer, and avoid having to sign in. +2. [LLM-powered features](./ai/overview.md), if you are using Zed as the provider of your LLM models. Alternatively, you can [bring and configure your own API keys](./ai/configuration.md#use-your-own-keys) if you'd prefer, and avoid having to sign in. ## Signing In diff --git a/docs/src/ai/agent-panel.md b/docs/src/ai/agent-panel.md index f944eb88b0..97568d6643 100644 --- a/docs/src/ai/agent-panel.md +++ b/docs/src/ai/agent-panel.md @@ -8,7 +8,7 @@ If you're using the Agent Panel for the first time, you need to have at least on You can do that by: 1. [subscribing to our Pro plan](https://zed.dev/pricing), so you have access to our hosted models -2. or by [bringing your own API keys](./llm-providers.md#use-your-own-keys) for your desired provider +2. or by [bringing your own API keys](./configuration.md#use-your-own-keys) for your desired provider ## Overview {#overview} @@ -87,7 +87,7 @@ You can also do this at any time with an ongoing thread via the "Agent Options" ## Changing Models {#changing-models} -After you've configured your LLM providers—either via [a custom API key](./llm-providers.md#use-your-own-keys) or through [Zed's hosted models](./models.md)—you can switch between them by clicking on the model selector on the message editor or by using the {#kb agent::ToggleModelSelector} keybinding. +After you've configured your LLM providers—either via [a custom API key](./configuration.md#use-your-own-keys) or through [Zed's hosted models](./models.md)—you can switch between them by clicking on the model selector on the message editor or by using the {#kb agent::ToggleModelSelector} keybinding. ## Using Tools {#using-tools} diff --git a/docs/src/ai/agent-settings.md b/docs/src/ai/agent-settings.md deleted file mode 100644 index ff97bcb8ee..0000000000 --- a/docs/src/ai/agent-settings.md +++ /dev/null @@ -1,226 +0,0 @@ -# Agent Settings - -Learn about all the settings you can customize in Zed's Agent Panel. - -## Model Settings {#model-settings} - -### Default Model {#default-model} - -If you're using [Zed's hosted LLM service](./plans-and-usage.md), it sets `claude-sonnet-4` as the default model. -But if you're not subscribed to it or simply just want to change it, you can do it so either via the model dropdown in the Agent Panel's bottom-right corner or by manually editing the `default_model` object in your settings: - -```json -{ - "agent": { - "default_model": { - "provider": "zed.dev", - "model": "gpt-4o" - } - } -} -``` - -### Feature-specific Models {#feature-specific-models} - -Assign distinct and specific models for the following AI-powered features in Zed: - -- Thread summary model: Used for generating thread summaries -- Inline assistant model: Used for the inline assistant feature -- Commit message model: Used for generating Git commit messages - -```json -{ - "agent": { - "default_model": { - "provider": "zed.dev", - "model": "claude-sonnet-4" - }, - "inline_assistant_model": { - "provider": "anthropic", - "model": "claude-3-5-sonnet" - }, - "commit_message_model": { - "provider": "openai", - "model": "gpt-4o-mini" - }, - "thread_summary_model": { - "provider": "google", - "model": "gemini-2.0-flash" - } - } -} -``` - -> If a custom model isn't set for one of these features, they automatically fall back to using the default model. - -### Alternative Models for Inline Assists {#alternative-assists} - -The Inline Assist feature in particular has the capacity to perform multiple generations in parallel using different models. -That is possible by assigning more than one model to it, taking the configuration shown above one step further. - -When configured, the inline assist UI will surface controls to cycle between the outputs generated by each model. - -The models you specify here are always used in _addition_ to your [default model](#default-model). - -For example, the following configuration will generate two outputs for every assist. -One with Claude Sonnet 4 (the default model), and one with GPT-4o. - -```json -{ - "agent": { - "default_model": { - "provider": "zed.dev", - "model": "claude-sonnet-4" - }, - "inline_alternatives": [ - { - "provider": "zed.dev", - "model": "gpt-4o" - } - ] - } -} -``` - -### Model Temperature - -Specify a custom temperature for a provider and/or model: - -```json -"model_parameters": [ - // To set parameters for all requests to OpenAI models: - { - "provider": "openai", - "temperature": 0.5 - }, - // To set parameters for all requests in general: - { - "temperature": 0 - }, - // To set parameters for a specific provider and model: - { - "provider": "zed.dev", - "model": "claude-sonnet-4", - "temperature": 1.0 - } -], -``` - -## Agent Panel Settings {#agent-panel-settings} - -Note that some of these settings are also surfaced in the Agent Panel's settings UI, which you can access either via the `agent: open settings` action or by the dropdown menu on the top-right corner of the panel. - -### Default View - -Use the `default_view` setting to change the default view of the Agent Panel. -You can choose between `thread` (the default) and `text_thread`: - -```json -{ - "agent": { - "default_view": "text_thread" - } -} -``` - -### Auto-run Commands - -Control whether you want to allow the agent to run commands without asking you for permission. -The default value is `false`. - -```json -{ - "agent": { - "always_allow_tool_actions": "true" - } -} -``` - -> This setting is available via the Agent Panel's settings UI. - -### Single-file Review - -Control whether you want to see review actions (accept & reject) in single buffers after the agent is done performing edits. -The default value is `false`. - -```json -{ - "agent": { - "single_file_review": "true" - } -} -``` - -When set to false, these controls are only available in the multibuffer review tab. - -> This setting is available via the Agent Panel's settings UI. - -### Sound Notification - -Control whether you want to hear a notification sound when the agent is done generating changes or needs your input. -The default value is `false`. - -```json -{ - "agent": { - "play_sound_when_agent_done": "true" - } -} -``` - -> This setting is available via the Agent Panel's settings UI. - -### Modifier to Send - -Make a modifier (`cmd` on macOS, `ctrl` on Linux) required to send messages. -This is encouraged for more thoughtful prompt crafting. -The default value is `false`. - -```json -{ - "agent": { - "use_modifier_to_send": "true" - } -} -``` - -> This setting is available via the Agent Panel's settings UI. - -### Edit Card - -Use the `expand_edit_card` setting to control whether edit cards show the full diff in the Agent Panel. -It is set to `true` by default, but if set to false, the card's height is capped to a certain number of lines, requiring a click to be expanded. - -```json -{ - "agent": { - "expand_edit_card": "false" - } -} -``` - -### Terminal Card - -Use the `expand_terminal_card` setting to control whether terminal cards show the command output in the Agent Panel. -It is set to `true` by default, but if set to false, the card will be fully collapsed even while the command is running, requiring a click to be expanded. - -```json -{ - "agent": { - "expand_terminal_card": "false" - } -} -``` - -### Feedback Controls - -Control whether you want to see the thumbs up/down buttons to give Zed feedback about the agent's performance. -The default value is `true`. - -```json -{ - "agent": { - "enable_feedback": "false" - } -} -``` diff --git a/docs/src/ai/billing.md b/docs/src/ai/billing.md index d519b136ae..c49bacd883 100644 --- a/docs/src/ai/billing.md +++ b/docs/src/ai/billing.md @@ -1,7 +1,7 @@ # Billing We use Stripe as our billing and payments provider. All Pro plans require payment via credit card. -For invoice-based billing, a Business plan is required. Contact [sales@zed.dev](mailto:sales@zed.dev) for more information. +For invoice-based billing, a Business plan is required. Contact sales@zed.dev for more information. ## Settings {#settings} @@ -12,8 +12,7 @@ Clicking the button under Account Settings will navigate you to Stripe’s secur Zed is billed on a monthly basis based on the date you initially subscribe. -We’ll also bill in-month for additional prompts used beyond your plan’s prompt limit, if usage exceeds $20 before month end. -See [usage-based pricing](./plans-and-usage.md#ubp) for more. +We’ll also bill in-month for additional prompts used beyond your plan’s prompt limit, if usage exceeds $20 before month end. See [usage-based pricing](./plans-and-usage.md#ubp) for more. ## Invoice History {#invoice-history} @@ -26,12 +25,3 @@ From Stripe’s secure portal, you can download all current and historical invoi You can update your payment method, company name, address, and tax information through the billing portal. Please note that changes to billing information will **only** affect future invoices — **we cannot modify historical invoices**. - -## Sales Tax {#sales-tax} - -Zed partners with [Sphere](https://www.getsphere.com/) to calculate indirect tax rate for invoices, based on customer location and the product being sold. Tax is listed as a separate line item on invoices, based preferentially on your billing address, followed by the card issue country known to Stripe. - -If you have a VAT/GST ID, you can add it at [zed.dev/account](https://zed.dev/account) by clicking "Manage" on your subscription. Check the box that denotes you as a business. - -Please note that changes to VAT/GST IDs and address will **only** affect future invoices — **we cannot modify historical invoices**. -Questions or issues can be directed to [billing-support@zed.dev](mailto:billing-support@zed.dev). diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md index d28a7e8ed0..414da2206f 100644 --- a/docs/src/ai/configuration.md +++ b/docs/src/ai/configuration.md @@ -1,20 +1,735 @@ # Configuration -When using AI in Zed, you can customize several aspects: +There are various aspects about the Agent Panel that you can customize. +All of them can be seen by either visiting [the Configuring Zed page](../configuring-zed.md#agent) or by running the `zed: open default settings` action and searching for `"agent"`. -1. Which [LLM providers](./llm-providers.md) you can use -2. [Model parameters and usage](./agent-settings.md#model-settings) -3. [Interactions with the Agent Panel](./agent-settings.md#agent-panel-settings) +Alternatively, you can also visit the panel's Settings view by running the `agent: open configuration` action or going to the top-right menu and hitting "Settings". -## Turning AI Off Entirely +## LLM Providers -We want to respect users who want to use Zed without interacting with AI whatsoever. -To do that, add the following key to your `settings.json`: +Zed supports multiple large language model providers. +Here's an overview of the supported providers and tool call support: + +| Provider | Tool Use Supported | +| ----------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [Amazon Bedrock](#amazon-bedrock) | Depends on the model | +| [Anthropic](#anthropic) | ✅ | +| [DeepSeek](#deepseek) | ✅ | +| [GitHub Copilot Chat](#github-copilot-chat) | For some models ([link](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198)) | +| [Google AI](#google-ai) | ✅ | +| [LM Studio](#lmstudio) | ✅ | +| [Mistral](#mistral) | ✅ | +| [Ollama](#ollama) | ✅ | +| [OpenAI](#openai) | ✅ | +| [OpenAI API Compatible](#openai-api-compatible) | 🚫 | +| [OpenRouter](#openrouter) | ✅ | +| [Vercel](#vercel-v0) | ✅ | +| [xAI](#xai) | ✅ | + +## Use Your Own Keys {#use-your-own-keys} + +While Zed offers hosted versions of models through [our various plans](./plans-and-usage.md), we're always happy to support users wanting to supply their own API keys. +Below, you can learn how to do that for each provider. + +> Using your own API keys is _free_—you do not need to subscribe to a Zed plan to use our AI features with your own keys. + +### Amazon Bedrock {#amazon-bedrock} + +> ✅ Supports tool use with models that support streaming tool use. +> More details can be found in the [Amazon Bedrock's Tool Use documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html). + +To use Amazon Bedrock's models, an AWS authentication is required. +Ensure your credentials have the following permissions set up: + +- `bedrock:InvokeModelWithResponseStream` +- `bedrock:InvokeModel` +- `bedrock:ConverseStream` + +Your IAM policy should look similar to: ```json { - "disable_ai": true + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModel", + "bedrock:InvokeModelWithResponseStream", + "bedrock:ConverseStream" + ], + "Resource": "*" + } + ] } ``` -Read [the following blog post](https://zed.dev/blog/disable-ai-features) to learn more about our motivation to promote this, as much as we also encourage users to explore AI-assisted programming. +With that done, choose one of the two authentication methods: + +#### Authentication via Named Profile (Recommended) + +1. Ensure you have the AWS CLI installed and configured with a named profile +2. Open your `settings.json` (`zed: open settings`) and include the `bedrock` key under `language_models` with the following settings: + ```json + { + "language_models": { + "bedrock": { + "authentication_method": "named_profile", + "region": "your-aws-region", + "profile": "your-profile-name" + } + } + } + ``` + +#### Authentication via Static Credentials + +While it's possible to configure through the Agent Panel settings UI by entering your AWS access key and secret directly, we recommend using named profiles instead for better security practices. +To do this: + +1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). +2. Create security credentials for that User, save them and keep them secure. +3. Open the Agent Configuration with (`agent: open configuration`) and go to the Amazon Bedrock section +4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. + +#### Cross-Region Inference + +The Zed implementation of Amazon Bedrock uses [Cross-Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html) for all the models and region combinations that support it. +With Cross-Region inference, you can distribute traffic across multiple AWS Regions, enabling higher throughput. + +For example, if you use `Claude Sonnet 3.7 Thinking` from `us-east-1`, it may be processed across the US regions, namely: `us-east-1`, `us-east-2`, or `us-west-2`. +Cross-Region inference requests are kept within the AWS Regions that are part of the geography where the data originally resides. +For example, a request made within the US is kept within the AWS Regions in the US. + +Although the data remains stored only in the source Region, your input prompts and output results might move outside of your source Region during cross-Region inference. +All data will be transmitted encrypted across Amazon's secure network. + +We will support Cross-Region inference for each of the models on a best-effort basis, please refer to the [Cross-Region Inference method Code](https://github.com/zed-industries/zed/blob/main/crates/bedrock/src/models.rs#L297). + +For the most up-to-date supported regions and models, refer to the [Supported Models and Regions for Cross Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). + +### Anthropic {#anthropic} + +> ✅ Supports tool use + +You can use Anthropic models by choosing it via the model dropdown in the Agent Panel. + +1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys) +2. Make sure that your Anthropic account has credits +3. Open the settings view (`agent: open configuration`) and go to the Anthropic section +4. Enter your Anthropic API key + +Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API. + +Zed will also use the `ANTHROPIC_API_KEY` environment variable if it's defined. + +#### Custom Models {#anthropic-custom-models} + +You can add custom models to the Anthropic provider by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "anthropic": { + "available_models": [ + { + "name": "claude-3-5-sonnet-20240620", + "display_name": "Sonnet 2024-June", + "max_tokens": 128000, + "max_output_tokens": 2560, + "cache_configuration": { + "max_cache_anchors": 10, + "min_total_token": 10000, + "should_speculate": false + }, + "tool_override": "some-model-that-supports-toolcalling" + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. + +You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it) by changing the mode in your model's configuration to `thinking`, for example: + +```json +{ + "name": "claude-sonnet-4-latest", + "display_name": "claude-sonnet-4-thinking", + "max_tokens": 200000, + "mode": { + "type": "thinking", + "budget_tokens": 4_096 + } +} +``` + +### DeepSeek {#deepseek} + +> ✅ Supports tool use + +1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys) +2. Open the settings view (`agent: open configuration`) and go to the DeepSeek section +3. Enter your DeepSeek API key + +The DeepSeek API key will be saved in your keychain. + +Zed will also use the `DEEPSEEK_API_KEY` environment variable if it's defined. + +#### Custom Models {#deepseek-custom-models} + +The Zed agent comes pre-configured to use the latest version for common models (DeepSeek Chat, DeepSeek Reasoner). +If you wish to use alternate models or customize the API endpoint, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "deepseek": { + "api_url": "https://api.deepseek.com", + "available_models": [ + { + "name": "deepseek-chat", + "display_name": "DeepSeek Chat", + "max_tokens": 64000 + }, + { + "name": "deepseek-reasoner", + "display_name": "DeepSeek Reasoner", + "max_tokens": 64000, + "max_output_tokens": 4096 + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. +You can also modify the `api_url` to use a custom endpoint if needed. + +### GitHub Copilot Chat {#github-copilot-chat} + +> ✅ Supports tool use in some cases. +> Visit [the Copilot Chat code](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198) for the supported subset. + +You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. + +1. Open the settings view (`agent: open configuration`) and go to the GitHub Copilot Chat section +2. Click on `Sign in to use GitHub Copilot`, follow the steps shown in the modal. + +Alternatively, you can provide an OAuth token via the `GH_COPILOT_TOKEN` environment variable. + +> **Note**: If you don't see specific models in the dropdown, you may need to enable them in your [GitHub Copilot settings](https://github.com/settings/copilot/features). + +To use Copilot Enterprise with Zed (for both agent and inline completions), you must configure your enterprise endpoint as described in [Configuring GitHub Copilot Enterprise](./edit-prediction.md#github-copilot-enterprise). + +### Google AI {#google-ai} + +> ✅ Supports tool use + +You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. + +1. Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). +2. Open the settings view (`agent: open configuration`) and go to the Google AI section +3. Enter your Google AI API key and press enter. + +The Google AI API key will be saved in your keychain. + +Zed will also use the `GEMINI_API_KEY` environment variable if it's defined. See [Using Gemini API keys](Using Gemini API keys) in the Gemini docs for more. + +#### Custom Models {#google-ai-custom-models} + +By default, Zed will use `stable` versions of models, but you can use specific versions of models, including [experimental models](https://ai.google.dev/gemini-api/docs/models/experimental-models). You can configure a model to use [thinking mode](https://ai.google.dev/gemini-api/docs/thinking) (if it supports it) by adding a `mode` configuration to your model. This is useful for controlling reasoning token usage and response speed. If not specified, Gemini will automatically choose the thinking budget. + +Here is an example of a custom Google AI model you could add to your Zed `settings.json`: + +```json +{ + "language_models": { + "google": { + "available_models": [ + { + "name": "gemini-2.5-flash-preview-05-20", + "display_name": "Gemini 2.5 Flash (Thinking)", + "max_tokens": 1000000, + "mode": { + "type": "thinking", + "budget_tokens": 24000 + } + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. + +### LM Studio {#lmstudio} + +> ✅ Supports tool use + +1. Download and install [the latest version of LM Studio](https://lmstudio.ai/download) +2. In the app press `cmd/ctrl-shift-m` and download at least one model (e.g., qwen2.5-coder-7b). Alternatively, you can get models via the LM Studio CLI: + + ```sh + lms get qwen2.5-coder-7b + ``` + +3. Make sure the LM Studio API server is running by executing: + + ```sh + lms server start + ``` + +Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless#run-the-llm-service-on-machine-login) to automate running the LM Studio server. + +### Mistral {#mistral} + +> ✅ Supports tool use + +1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) +2. Open the configuration view (`agent: open configuration`) and navigate to the Mistral section +3. Enter your Mistral API key + +The Mistral API key will be saved in your keychain. + +Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined. + +#### Custom Models {#mistral-custom-models} + +The Zed agent comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). +All the default models support tool use. +If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "mistral": { + "api_url": "https://api.mistral.ai/v1", + "available_models": [ + { + "name": "mistral-tiny-latest", + "display_name": "Mistral Tiny", + "max_tokens": 32000, + "max_output_tokens": 4096, + "max_completion_tokens": 1024, + "supports_tools": true, + "supports_images": false + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. + +### Ollama {#ollama} + +> ✅ Supports tool use + +Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`. + +1. Download one of the [available models](https://ollama.com/models), for example, for `mistral`: + + ```sh + ollama pull mistral + ``` + +2. Make sure that the Ollama server is running. You can start it either via running Ollama.app (macOS) or launching: + + ```sh + ollama serve + ``` + +3. In the Agent Panel, select one of the Ollama models using the model dropdown. + +#### Ollama Context Length {#ollama-context} + +Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. +Zed API requests to Ollama include this as the `num_ctx` parameter, but the default values do not exceed `16384` so users with ~16GB of RAM are able to use most models out of the box. + +See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults. + +> **Note**: Token counts displayed in the Agent Panel are only estimates and will differ from the model's native tokenizer. + +Depending on your hardware or use-case you may wish to limit or increase the context length for a specific model via settings.json: + +```json +{ + "language_models": { + "ollama": { + "api_url": "http://localhost:11434", + "available_models": [ + { + "name": "qwen2.5-coder", + "display_name": "qwen 2.5 coder 32K", + "max_tokens": 32768, + "supports_tools": true, + "supports_thinking": true, + "supports_images": true + } + ] + } + } +} +``` + +If you specify a context length that is too large for your hardware, Ollama will log an error. +You can watch these logs by running: `tail -f ~/.ollama/logs/ollama.log` (macOS) or `journalctl -u ollama -f` (Linux). +Depending on the memory available on your machine, you may need to adjust the context length to a smaller value. + +You may also optionally specify a value for `keep_alive` for each available model. +This can be an integer (seconds) or alternatively a string duration like "5m", "10m", "1h", "1d", etc. +For example, `"keep_alive": "120s"` will allow the remote server to unload the model (freeing up GPU VRAM) after 120 seconds. + +The `supports_tools` option controls whether the model will use additional tools. +If the model is tagged with `tools` in the Ollama catalog, this option should be supplied, and the built-in profiles `Ask` and `Write` can be used. +If the model is not tagged with `tools` in the Ollama catalog, this option can still be supplied with the value `true`; however, be aware that only the `Minimal` built-in profile will work. + +The `supports_thinking` option controls whether the model will perform an explicit "thinking" (reasoning) pass before producing its final answer. +If the model is tagged with `thinking` in the Ollama catalog, set this option and you can use it in Zed. + +The `supports_images` option enables the model's vision capabilities, allowing it to process images included in the conversation context. +If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed. + +### OpenAI {#openai} + +> ✅ Supports tool use + +1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) +2. Make sure that your OpenAI account has credits +3. Open the settings view (`agent: open configuration`) and go to the OpenAI section +4. Enter your OpenAI API key + +The OpenAI API key will be saved in your keychain. + +Zed will also use the `OPENAI_API_KEY` environment variable if it's defined. + +#### Custom Models {#openai-custom-models} + +The Zed agent comes pre-configured to use the latest version for common models (GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, GPT-4o, GPT-4o mini). +To use alternate models, perhaps a preview release or a dated model release, or if you wish to control the request parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "openai": { + "available_models": [ + { + "name": "gpt-4o-2024-08-06", + "display_name": "GPT 4o Summer 2024", + "max_tokens": 128000 + }, + { + "name": "o1-mini", + "display_name": "o1-mini", + "max_tokens": 128000, + "max_completion_tokens": 20000 + } + ], + "version": "1" + } + } +} +``` + +You must provide the model's context window in the `max_tokens` parameter; this can be found in the [OpenAI model documentation](https://platform.openai.com/docs/models). + +OpenAI `o1` models should set `max_completion_tokens` as well to avoid incurring high reasoning token costs. +Custom models will be listed in the model dropdown in the Agent Panel. + +### OpenAI API Compatible {#openai-api-compatible} + +Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models. + +To configure a compatible API, you can add a custom API URL for OpenAI either via the UI (currently available only in Preview) or by editing your `settings.json`. + +For example, to connect to [Together AI](https://www.together.ai/) via the UI: + +1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys). +2. Go to the Agent Panel's settings view, click on the "Add Provider" button, and then on the "OpenAI" menu item +3. Add the requested fields, such as `api_url`, `api_key`, available models, and others + +Alternatively, you can also add it via the `settings.json`: + +```json +{ + "language_models": { + "openai": { + "api_url": "https://api.together.xyz/v1", + "api_key": "YOUR_TOGETHER_AI_API_KEY", + "available_models": [ + { + "name": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "display_name": "Together Mixtral 8x7B", + "max_tokens": 32768, + "supports_tools": true + } + ] + } + } +} +``` + +### OpenRouter {#openrouter} + +> ✅ Supports tool use + +OpenRouter provides access to multiple AI models through a single API. It supports tool use for compatible models. + +1. Visit [OpenRouter](https://openrouter.ai) and create an account +2. Generate an API key from your [OpenRouter keys page](https://openrouter.ai/keys) +3. Open the settings view (`agent: open configuration`) and go to the OpenRouter section +4. Enter your OpenRouter API key + +The OpenRouter API key will be saved in your keychain. + +Zed will also use the `OPENROUTER_API_KEY` environment variable if it's defined. + +#### Custom Models {#openrouter-custom-models} + +You can add custom models to the OpenRouter provider by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "open_router": { + "api_url": "https://openrouter.ai/api/v1", + "available_models": [ + { + "name": "google/gemini-2.0-flash-thinking-exp", + "display_name": "Gemini 2.0 Flash (Thinking)", + "max_tokens": 200000, + "max_output_tokens": 8192, + "supports_tools": true, + "supports_images": true, + "mode": { + "type": "thinking", + "budget_tokens": 8000 + } + } + ] + } + } +} +``` + +The available configuration options for each model are: + +- `name` (required): The model identifier used by OpenRouter +- `display_name` (optional): A human-readable name shown in the UI +- `max_tokens` (required): The model's context window size +- `max_output_tokens` (optional): Maximum tokens the model can generate +- `max_completion_tokens` (optional): Maximum completion tokens +- `supports_tools` (optional): Whether the model supports tool/function calling +- `supports_images` (optional): Whether the model supports image inputs +- `mode` (optional): Special mode configuration for thinking models + +You can find available models and their specifications on the [OpenRouter models page](https://openrouter.ai/models). + +Custom models will be listed in the model dropdown in the Agent Panel. + +### Vercel v0 {#vercel-v0} + +> ✅ Supports tool use + +[Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel. +It supports text and image inputs and provides fast streaming responses. + +The v0 models are [OpenAI-compatible models](/#openai-api-compatible), but Vercel is listed as first-class provider in the panel's settings view. + +To start using it with Zed, ensure you have first created a [v0 API key](https://v0.dev/chat/settings/keys). +Once you have it, paste it directly into the Vercel provider section in the panel's settings view. + +You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. + +### xAI {#xai} + +> ✅ Supports tool use + +Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. + +1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) +2. Open the settings view (`agent: open configuration`) and go to the **xAI** section +3. Enter your xAI API key + +The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined. + +> **Note:** While the xAI API is OpenAI-compatible, Zed has first-class support for it as a dedicated provider. For the best experience, we recommend using the dedicated `x_ai` provider configuration instead of the [OpenAI API Compatible](#openai-api-compatible) method. + +#### Custom Models {#xai-custom-models} + +The Zed agent comes pre-configured with common Grok models. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "x_ai": { + "api_url": "https://api.x.ai/v1", + "available_models": [ + { + "name": "grok-1.5", + "display_name": "Grok 1.5", + "max_tokens": 131072, + "max_output_tokens": 8192 + }, + { + "name": "grok-1.5v", + "display_name": "Grok 1.5V (Vision)", + "max_tokens": 131072, + "max_output_tokens": 8192, + "supports_images": true + } + ] + } + } +} +``` + +## Advanced Configuration {#advanced-configuration} + +### Custom Provider Endpoints {#custom-provider-endpoint} + +You can use a custom API endpoint for different providers, as long as it's compatible with the provider's API structure. +To do so, add the following to your `settings.json`: + +```json +{ + "language_models": { + "some-provider": { + "api_url": "http://localhost:11434" + } + } +} +``` + +Where `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`. + +### Default Model {#default-model} + +Zed's hosted LLM service sets `claude-sonnet-4` as the default model. +However, you can change it either via the model dropdown in the Agent Panel's bottom-right corner or by manually editing the `default_model` object in your settings: + +```json +{ + "agent": { + "version": "2", + "default_model": { + "provider": "zed.dev", + "model": "gpt-4o" + } + } +} +``` + +### Feature-specific Models {#feature-specific-models} + +If a feature-specific model is not set, it will fall back to using the default model, which is the one you set on the Agent Panel. + +You can configure the following feature-specific models: + +- Thread summary model: Used for generating thread summaries +- Inline assistant model: Used for the inline assistant feature +- Commit message model: Used for generating Git commit messages + +Example configuration: + +```json +{ + "agent": { + "version": "2", + "default_model": { + "provider": "zed.dev", + "model": "claude-sonnet-4" + }, + "inline_assistant_model": { + "provider": "anthropic", + "model": "claude-3-5-sonnet" + }, + "commit_message_model": { + "provider": "openai", + "model": "gpt-4o-mini" + }, + "thread_summary_model": { + "provider": "google", + "model": "gemini-2.0-flash" + } + } +} +``` + +### Alternative Models for Inline Assists {#alternative-assists} + +You can configure additional models that will be used to perform inline assists in parallel. +When you do this, the inline assist UI will surface controls to cycle between the alternatives generated by each model. + +The models you specify here are always used in _addition_ to your [default model](#default-model). +For example, the following configuration will generate two outputs for every assist. +One with Claude 3.7 Sonnet, and one with GPT-4o. + +```json +{ + "agent": { + "default_model": { + "provider": "zed.dev", + "model": "claude-sonnet-4" + }, + "inline_alternatives": [ + { + "provider": "zed.dev", + "model": "gpt-4o" + } + ], + "version": "2" + } +} +``` + +### Default View + +Use the `default_view` setting to set change the default view of the Agent Panel. +You can choose between `thread` (the default) and `text_thread`: + +```json +{ + "agent": { + "default_view": "text_thread" + } +} +``` + +### Edit Card + +Use the `expand_edit_card` setting to control whether edit cards show the full diff in the Agent Panel. +It is set to `true` by default, but if set to false, the card's height is capped to a certain number of lines, requiring a click to be expanded. + +```json +{ + "agent": { + "expand_edit_card": "false" + } +} +``` + +This setting is currently only available in Preview. +It should be up in Stable by the next release. + +### Terminal Card + +Use the `expand_terminal_card` setting to control whether terminal cards show the command output in the Agent Panel. +It is set to `true` by default, but if set to false, the card will be fully collapsed even while the command is running, requiring a click to be expanded. + +```json +{ + "agent": { + "expand_terminal_card": "false" + } +} +``` + +This setting is currently only available in Preview. +It should be up in Stable by the next release. diff --git a/docs/src/ai/inline-assistant.md b/docs/src/ai/inline-assistant.md index da894e2cd8..cd0ace3ce6 100644 --- a/docs/src/ai/inline-assistant.md +++ b/docs/src/ai/inline-assistant.md @@ -12,7 +12,7 @@ You can also perform multiple generation requests in parallel by pressing `ctrl- Give the Inline Assistant context the same way you can in [the Agent Panel](./agent-panel.md), allowing you to provide additional instructions or rules for code transformations with @-mentions. -A useful pattern here is to create a thread in the Agent Panel, and then mention that thread with `@thread` in the Inline Assistant to include it as context. +A useful pattern here is to create a thread in the Agent Panel, and then use the mention that thread with `@thread` in the Inline Assistant to include it as context. > The Inline Assistant is limited to normal mode context windows ([see Models](./models.md) for more). diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md deleted file mode 100644 index a6e6f7c774..0000000000 --- a/docs/src/ai/llm-providers.md +++ /dev/null @@ -1,606 +0,0 @@ -# LLM Providers - -To use AI in Zed, you need to have at least one large language model provider set up. - -You can do that by either subscribing to [one of Zed's plans](./plans-and-usage.md), or by using API keys you already have for the supported providers. - -## Use Your Own Keys {#use-your-own-keys} - -If you already have an API key for an existing LLM provider—say Anthropic or OpenAI, for example—you can insert them in Zed and use the Agent Panel **_for free_**. - -You can add your API key to a given provider either via the Agent Panel's settings UI or directly via the `settings.json` through the `language_models` key. - -## Supported Providers - -Here's all the supported LLM providers for which you can use your own API keys: - -| Provider | Tool Use Supported | -| ----------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [Amazon Bedrock](#amazon-bedrock) | Depends on the model | -| [Anthropic](#anthropic) | ✅ | -| [DeepSeek](#deepseek) | ✅ | -| [GitHub Copilot Chat](#github-copilot-chat) | For some models ([link](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198)) | -| [Google AI](#google-ai) | ✅ | -| [LM Studio](#lmstudio) | ✅ | -| [Mistral](#mistral) | ✅ | -| [Ollama](#ollama) | ✅ | -| [OpenAI](#openai) | ✅ | -| [OpenAI API Compatible](#openai-api-compatible) | ✅ | -| [OpenRouter](#openrouter) | ✅ | -| [Vercel](#vercel-v0) | ✅ | -| [xAI](#xai) | ✅ | - -### Amazon Bedrock {#amazon-bedrock} - -> ✅ Supports tool use with models that support streaming tool use. -> More details can be found in the [Amazon Bedrock's Tool Use documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html). - -To use Amazon Bedrock's models, an AWS authentication is required. -Ensure your credentials have the following permissions set up: - -- `bedrock:InvokeModelWithResponseStream` -- `bedrock:InvokeModel` -- `bedrock:ConverseStream` - -Your IAM policy should look similar to: - -```json -{ - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": [ - "bedrock:InvokeModel", - "bedrock:InvokeModelWithResponseStream", - "bedrock:ConverseStream" - ], - "Resource": "*" - } - ] -} -``` - -With that done, choose one of the two authentication methods: - -#### Authentication via Named Profile (Recommended) - -1. Ensure you have the AWS CLI installed and configured with a named profile -2. Open your `settings.json` (`zed: open settings`) and include the `bedrock` key under `language_models` with the following settings: - ```json - { - "language_models": { - "bedrock": { - "authentication_method": "named_profile", - "region": "your-aws-region", - "profile": "your-profile-name" - } - } - } - ``` - -#### Authentication via Static Credentials - -While it's possible to configure through the Agent Panel settings UI by entering your AWS access key and secret directly, we recommend using named profiles instead for better security practices. -To do this: - -1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). -2. Create security credentials for that User, save them and keep them secure. -3. Open the Agent Configuration with (`agent: open settings`) and go to the Amazon Bedrock section -4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. - -#### Cross-Region Inference - -The Zed implementation of Amazon Bedrock uses [Cross-Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html) for all the models and region combinations that support it. -With Cross-Region inference, you can distribute traffic across multiple AWS Regions, enabling higher throughput. - -For example, if you use `Claude Sonnet 3.7 Thinking` from `us-east-1`, it may be processed across the US regions, namely: `us-east-1`, `us-east-2`, or `us-west-2`. -Cross-Region inference requests are kept within the AWS Regions that are part of the geography where the data originally resides. -For example, a request made within the US is kept within the AWS Regions in the US. - -Although the data remains stored only in the source Region, your input prompts and output results might move outside of your source Region during cross-Region inference. -All data will be transmitted encrypted across Amazon's secure network. - -We will support Cross-Region inference for each of the models on a best-effort basis, please refer to the [Cross-Region Inference method Code](https://github.com/zed-industries/zed/blob/main/crates/bedrock/src/models.rs#L297). - -For the most up-to-date supported regions and models, refer to the [Supported Models and Regions for Cross Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). - -### Anthropic {#anthropic} - -> ✅ Supports tool use - -You can use Anthropic models by choosing them via the model dropdown in the Agent Panel. - -1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys) -2. Make sure that your Anthropic account has credits -3. Open the settings view (`agent: open settings`) and go to the Anthropic section -4. Enter your Anthropic API key - -Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API. - -Zed will also use the `ANTHROPIC_API_KEY` environment variable if it's defined. - -#### Custom Models {#anthropic-custom-models} - -You can add custom models to the Anthropic provider by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "anthropic": { - "available_models": [ - { - "name": "claude-3-5-sonnet-20240620", - "display_name": "Sonnet 2024-June", - "max_tokens": 128000, - "max_output_tokens": 2560, - "cache_configuration": { - "max_cache_anchors": 10, - "min_total_token": 10000, - "should_speculate": false - }, - "tool_override": "some-model-that-supports-toolcalling" - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. - -You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it) by changing the mode in your model's configuration to `thinking`, for example: - -```json -{ - "name": "claude-sonnet-4-latest", - "display_name": "claude-sonnet-4-thinking", - "max_tokens": 200000, - "mode": { - "type": "thinking", - "budget_tokens": 4_096 - } -} -``` - -### DeepSeek {#deepseek} - -> ✅ Supports tool use - -1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys) -2. Open the settings view (`agent: open settings`) and go to the DeepSeek section -3. Enter your DeepSeek API key - -The DeepSeek API key will be saved in your keychain. - -Zed will also use the `DEEPSEEK_API_KEY` environment variable if it's defined. - -#### Custom Models {#deepseek-custom-models} - -The Zed agent comes pre-configured to use the latest version for common models (DeepSeek Chat, DeepSeek Reasoner). -If you wish to use alternate models or customize the API endpoint, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "deepseek": { - "api_url": "https://api.deepseek.com", - "available_models": [ - { - "name": "deepseek-chat", - "display_name": "DeepSeek Chat", - "max_tokens": 64000 - }, - { - "name": "deepseek-reasoner", - "display_name": "DeepSeek Reasoner", - "max_tokens": 64000, - "max_output_tokens": 4096 - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. -You can also modify the `api_url` to use a custom endpoint if needed. - -### GitHub Copilot Chat {#github-copilot-chat} - -> ✅ Supports tool use in some cases. -> Visit [the Copilot Chat code](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198) for the supported subset. - -You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. - -1. Open the settings view (`agent: open settings`) and go to the GitHub Copilot Chat section -2. Click on `Sign in to use GitHub Copilot`, follow the steps shown in the modal. - -Alternatively, you can provide an OAuth token via the `GH_COPILOT_TOKEN` environment variable. - -> **Note**: If you don't see specific models in the dropdown, you may need to enable them in your [GitHub Copilot settings](https://github.com/settings/copilot/features). - -To use Copilot Enterprise with Zed (for both agent and inline completions), you must configure your enterprise endpoint as described in [Configuring GitHub Copilot Enterprise](./edit-prediction.md#github-copilot-enterprise). - -### Google AI {#google-ai} - -> ✅ Supports tool use - -You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. - -1. Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). -2. Open the settings view (`agent: open settings`) and go to the Google AI section -3. Enter your Google AI API key and press enter. - -The Google AI API key will be saved in your keychain. - -Zed will also use the `GEMINI_API_KEY` environment variable if it's defined. See [Using Gemini API keys](https://ai.google.dev/gemini-api/docs/api-key) in the Gemini docs for more. - -#### Custom Models {#google-ai-custom-models} - -By default, Zed will use `stable` versions of models, but you can use specific versions of models, including [experimental models](https://ai.google.dev/gemini-api/docs/models/experimental-models). You can configure a model to use [thinking mode](https://ai.google.dev/gemini-api/docs/thinking) (if it supports it) by adding a `mode` configuration to your model. This is useful for controlling reasoning token usage and response speed. If not specified, Gemini will automatically choose the thinking budget. - -Here is an example of a custom Google AI model you could add to your Zed `settings.json`: - -```json -{ - "language_models": { - "google": { - "available_models": [ - { - "name": "gemini-2.5-flash-preview-05-20", - "display_name": "Gemini 2.5 Flash (Thinking)", - "max_tokens": 1000000, - "mode": { - "type": "thinking", - "budget_tokens": 24000 - } - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. - -### LM Studio {#lmstudio} - -> ✅ Supports tool use - -1. Download and install [the latest version of LM Studio](https://lmstudio.ai/download) -2. In the app press `cmd/ctrl-shift-m` and download at least one model (e.g., qwen2.5-coder-7b). Alternatively, you can get models via the LM Studio CLI: - - ```sh - lms get qwen2.5-coder-7b - ``` - -3. Make sure the LM Studio API server is running by executing: - - ```sh - lms server start - ``` - -Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless#run-the-llm-service-on-machine-login) to automate running the LM Studio server. - -### Mistral {#mistral} - -> ✅ Supports tool use - -1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) -2. Open the configuration view (`agent: open settings`) and navigate to the Mistral section -3. Enter your Mistral API key - -The Mistral API key will be saved in your keychain. - -Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined. - -#### Custom Models {#mistral-custom-models} - -The Zed agent comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). -All the default models support tool use. -If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "mistral": { - "api_url": "https://api.mistral.ai/v1", - "available_models": [ - { - "name": "mistral-tiny-latest", - "display_name": "Mistral Tiny", - "max_tokens": 32000, - "max_output_tokens": 4096, - "max_completion_tokens": 1024, - "supports_tools": true, - "supports_images": false - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. - -### Ollama {#ollama} - -> ✅ Supports tool use - -Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`. - -1. Download one of the [available models](https://ollama.com/models), for example, for `mistral`: - - ```sh - ollama pull mistral - ``` - -2. Make sure that the Ollama server is running. You can start it either via running Ollama.app (macOS) or launching: - - ```sh - ollama serve - ``` - -3. In the Agent Panel, select one of the Ollama models using the model dropdown. - -#### Ollama Context Length {#ollama-context} - -Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. -Zed API requests to Ollama include this as the `num_ctx` parameter, but the default values do not exceed `16384` so users with ~16GB of RAM are able to use most models out of the box. - -See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults. - -> **Note**: Token counts displayed in the Agent Panel are only estimates and will differ from the model's native tokenizer. - -Depending on your hardware or use-case you may wish to limit or increase the context length for a specific model via settings.json: - -```json -{ - "language_models": { - "ollama": { - "api_url": "http://localhost:11434", - "available_models": [ - { - "name": "qwen2.5-coder", - "display_name": "qwen 2.5 coder 32K", - "max_tokens": 32768, - "supports_tools": true, - "supports_thinking": true, - "supports_images": true - } - ] - } - } -} -``` - -If you specify a context length that is too large for your hardware, Ollama will log an error. -You can watch these logs by running: `tail -f ~/.ollama/logs/ollama.log` (macOS) or `journalctl -u ollama -f` (Linux). -Depending on the memory available on your machine, you may need to adjust the context length to a smaller value. - -You may also optionally specify a value for `keep_alive` for each available model. -This can be an integer (seconds) or alternatively a string duration like "5m", "10m", "1h", "1d", etc. -For example, `"keep_alive": "120s"` will allow the remote server to unload the model (freeing up GPU VRAM) after 120 seconds. - -The `supports_tools` option controls whether the model will use additional tools. -If the model is tagged with `tools` in the Ollama catalog, this option should be supplied, and the built-in profiles `Ask` and `Write` can be used. -If the model is not tagged with `tools` in the Ollama catalog, this option can still be supplied with the value `true`; however, be aware that only the `Minimal` built-in profile will work. - -The `supports_thinking` option controls whether the model will perform an explicit "thinking" (reasoning) pass before producing its final answer. -If the model is tagged with `thinking` in the Ollama catalog, set this option and you can use it in Zed. - -The `supports_images` option enables the model's vision capabilities, allowing it to process images included in the conversation context. -If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed. - -### OpenAI {#openai} - -> ✅ Supports tool use - -1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) -2. Make sure that your OpenAI account has credits -3. Open the settings view (`agent: open settings`) and go to the OpenAI section -4. Enter your OpenAI API key - -The OpenAI API key will be saved in your keychain. - -Zed will also use the `OPENAI_API_KEY` environment variable if it's defined. - -#### Custom Models {#openai-custom-models} - -The Zed agent comes pre-configured to use the latest version for common models (GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, GPT-4o, GPT-4o mini). -To use alternate models, perhaps a preview release or a dated model release, or if you wish to control the request parameters, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "openai": { - "available_models": [ - { - "name": "gpt-4o-2024-08-06", - "display_name": "GPT 4o Summer 2024", - "max_tokens": 128000 - }, - { - "name": "o1-mini", - "display_name": "o1-mini", - "max_tokens": 128000, - "max_completion_tokens": 20000 - } - ], - "version": "1" - } - } -} -``` - -You must provide the model's context window in the `max_tokens` parameter; this can be found in the [OpenAI model documentation](https://platform.openai.com/docs/models). - -OpenAI `o1` models should set `max_completion_tokens` as well to avoid incurring high reasoning token costs. -Custom models will be listed in the model dropdown in the Agent Panel. - -### OpenAI API Compatible {#openai-api-compatible} - -Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider. -This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models. - -You can add a custom, OpenAI-compatible model via either via the UI or by editing your `settings.json`. - -To do it via the UI, go to the Agent Panel settings (`agent: open settings`) and look for the "Add Provider" button to the right of the "LLM Providers" section title. -Then, fill up the input fields available in the modal. - -To do it via your `settings.json`, add the following snippet under `language_models`: - -```json -{ - "language_models": { - "openai": { - "api_url": "https://api.together.xyz/v1", // Using Together AI as an example - "available_models": [ - { - "name": "mistralai/Mixtral-8x7B-Instruct-v0.1", - "display_name": "Together Mixtral 8x7B", - "max_tokens": 32768 - } - ] - } - } -} -``` - -Note that LLM API keys aren't stored in your settings file. -So, ensure you have it set in your environment variables (`OPENAI_API_KEY=`) so your settings can pick it up. - -### OpenRouter {#openrouter} - -> ✅ Supports tool use - -OpenRouter provides access to multiple AI models through a single API. It supports tool use for compatible models. - -1. Visit [OpenRouter](https://openrouter.ai) and create an account -2. Generate an API key from your [OpenRouter keys page](https://openrouter.ai/keys) -3. Open the settings view (`agent: open settings`) and go to the OpenRouter section -4. Enter your OpenRouter API key - -The OpenRouter API key will be saved in your keychain. - -Zed will also use the `OPENROUTER_API_KEY` environment variable if it's defined. - -#### Custom Models {#openrouter-custom-models} - -You can add custom models to the OpenRouter provider by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "open_router": { - "api_url": "https://openrouter.ai/api/v1", - "available_models": [ - { - "name": "google/gemini-2.0-flash-thinking-exp", - "display_name": "Gemini 2.0 Flash (Thinking)", - "max_tokens": 200000, - "max_output_tokens": 8192, - "supports_tools": true, - "supports_images": true, - "mode": { - "type": "thinking", - "budget_tokens": 8000 - } - } - ] - } - } -} -``` - -The available configuration options for each model are: - -- `name` (required): The model identifier used by OpenRouter -- `display_name` (optional): A human-readable name shown in the UI -- `max_tokens` (required): The model's context window size -- `max_output_tokens` (optional): Maximum tokens the model can generate -- `max_completion_tokens` (optional): Maximum completion tokens -- `supports_tools` (optional): Whether the model supports tool/function calling -- `supports_images` (optional): Whether the model supports image inputs -- `mode` (optional): Special mode configuration for thinking models - -You can find available models and their specifications on the [OpenRouter models page](https://openrouter.ai/models). - -Custom models will be listed in the model dropdown in the Agent Panel. - -### Vercel v0 {#vercel-v0} - -> ✅ Supports tool use - -[Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel. -It supports text and image inputs and provides fast streaming responses. - -The v0 models are [OpenAI-compatible models](/#openai-api-compatible), but Vercel is listed as first-class provider in the panel's settings view. - -To start using it with Zed, ensure you have first created a [v0 API key](https://v0.dev/chat/settings/keys). -Once you have it, paste it directly into the Vercel provider section in the panel's settings view. - -You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. - -### xAI {#xai} - -> ✅ Supports tool use - -Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. - -1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) -2. Open the settings view (`agent: open settings`) and go to the **xAI** section -3. Enter your xAI API key - -The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined. - -> **Note:** While the xAI API is OpenAI-compatible, Zed has first-class support for it as a dedicated provider. For the best experience, we recommend using the dedicated `x_ai` provider configuration instead of the [OpenAI API Compatible](#openai-api-compatible) method. - -#### Custom Models {#xai-custom-models} - -The Zed agent comes pre-configured with common Grok models. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "x_ai": { - "api_url": "https://api.x.ai/v1", - "available_models": [ - { - "name": "grok-1.5", - "display_name": "Grok 1.5", - "max_tokens": 131072, - "max_output_tokens": 8192 - }, - { - "name": "grok-1.5v", - "display_name": "Grok 1.5V (Vision)", - "max_tokens": 131072, - "max_output_tokens": 8192, - "supports_images": true - } - ] - } - } -} -``` - -## Custom Provider Endpoints {#custom-provider-endpoint} - -You can use a custom API endpoint for different providers, as long as it's compatible with the provider's API structure. -To do so, add the following to your `settings.json`: - -```json -{ - "language_models": { - "some-provider": { - "api_url": "http://localhost:11434" - } - } -} -``` - -Currently, `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`. - -This is the same infrastructure that powers models that are, for example, [OpenAI-compatible](#openai-api-compatible). diff --git a/docs/src/ai/mcp.md b/docs/src/ai/mcp.md index dfe3e4bdb9..95929b2d7e 100644 --- a/docs/src/ai/mcp.md +++ b/docs/src/ai/mcp.md @@ -50,7 +50,7 @@ You can connect them by adding their commands directly to your `settings.json`, } ``` -Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open settings` action). +Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open configuration` action). From there, you can add it through the modal that appears when you click the "Add Custom Server" button. ## Using MCP Servers @@ -75,7 +75,7 @@ Mentioning your MCP server by name helps the agent pick it up. If you want to ensure a given server will be used, you can create [a custom profile](./agent-panel.md#custom-profiles) by turning off the built-in tools (either all of them or the ones that would cause conflicts) and turning on only the tools coming from the MCP server. -As an example, [the Dagger team suggests](https://container-use.com/agent-integrations#add-container-use-agent-profile-optional) doing that with their [Container Use MCP server](https://zed.dev/extensions/mcp-server-container-use): +As an example, [the Dagger team suggests](https://container-use.com/agent-integrations#add-container-use-agent-profile-optional) doing that with their [Container Use MCP server](https://zed.dev/extensions/container-use-mcp-server): ```json "agent": { diff --git a/docs/src/ai/overview.md b/docs/src/ai/overview.md index 6f081cb243..f437b24ba6 100644 --- a/docs/src/ai/overview.md +++ b/docs/src/ai/overview.md @@ -1,12 +1,15 @@ # AI -Learn how to get started using AI with Zed and all its capabilities. +Zed smoothly integrates LLMs in multiple ways across the editor. +Learn how to get started with AI on Zed and all its capabilities. ## Setting up AI in Zed - [Configuration](./configuration.md): Learn how to set up different language model providers like Anthropic, OpenAI, Ollama, Google AI, and more. -- [Subscription](./subscription.md): Learn about Zed's hosted model service and other billing-related information. +- [Models](./models.md): Learn about the various language models available in Zed. + +- [Subscription](./subscription.md): Learn about Zed's subscriptions and other billing-related information. - [Privacy and Security](./privacy-and-security.md): Understand how Zed handles privacy and security with AI features. diff --git a/docs/src/ai/plans-and-usage.md b/docs/src/ai/plans-and-usage.md index 1e6616c79b..a1da17f50d 100644 --- a/docs/src/ai/plans-and-usage.md +++ b/docs/src/ai/plans-and-usage.md @@ -11,7 +11,7 @@ Please note that if you’re interested in just using Zed as the world’s faste ## Usage {#usage} -- A `prompt` in Zed is an input from the user, initiated by pressing enter, composed of one or many `requests`. A `prompt` can be initiated from the Agent Panel, or via Inline Assist. +- A `prompt` in Zed is an input from the user, initiated on pressing enter, composed of one or many `requests`. A `prompt` can be initiated from the Agent Panel, or via Inline Assist. - A `request` in Zed is a response to a `prompt`, plus any tool calls that are initiated as part of that response. There may be one `request` per `prompt`, or many. Most models offered by Zed are metered per-prompt. diff --git a/docs/src/ai/rules.md b/docs/src/ai/rules.md index 653b907a7d..ed916874ca 100644 --- a/docs/src/ai/rules.md +++ b/docs/src/ai/rules.md @@ -5,7 +5,7 @@ Currently, Zed supports `.rules` files at the directory's root and the Rules Lib ## `.rules` files -Zed supports including `.rules` files at the top level of worktrees, and they act as project-level instructions that are included in all of your interactions with the Agent Panel. +Zed supports including `.rules` files at the top level of worktrees, and act as project-level instructions that are included in all of your interactions with the Agent Panel. Other names for this file are also supported for compatibility with other agents, but note that the first file which matches in this list will be used: - `.rules` diff --git a/docs/src/ai/temperature.md b/docs/src/ai/temperature.md new file mode 100644 index 0000000000..bb0cef6b51 --- /dev/null +++ b/docs/src/ai/temperature.md @@ -0,0 +1,23 @@ +# Model Temperature + +Zed's settings allow you to specify a custom temperature for a provider and/or model: + +```json +"model_parameters": [ + // To set parameters for all requests to OpenAI models: + { + "provider": "openai", + "temperature": 0.5 + }, + // To set parameters for all requests in general: + { + "temperature": 0 + }, + // To set parameters for a specific provider and model: + { + "provider": "zed.dev", + "model": "claude-sonnet-4", + "temperature": 1.0 + } + ], +``` diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index 5fd27abad6..cc4800fd6d 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -2588,7 +2588,6 @@ List of `integer` column numbers "font_features": null, "font_size": null, "line_height": "comfortable", - "minimum_contrast": 45, "option_as_meta": false, "button": true, "shell": "system", @@ -2884,30 +2883,6 @@ See Buffer Font Features } ``` -### Terminal: Minimum Contrast - -- Description: Controls the minimum contrast between foreground and background colors in the terminal. Uses the APCA (Accessible Perceptual Contrast Algorithm) for color adjustments. Set this to 0 to disable this feature. -- Setting: `minimum_contrast` -- Default: `45` - -**Options** - -`integer` values from 0 to 106. Common recommended values: - -- `0`: No contrast adjustment -- `45`: Minimum for large fluent text (default) -- `60`: Minimum for other content text -- `75`: Minimum for body text -- `90`: Preferred for body text - -```json -{ - "terminal": { - "minimum_contrast": 45 - } -} -``` - ### Terminal: Option As Meta - Description: Re-interprets the option keys to act like a 'meta' key, like in Emacs. @@ -3415,7 +3390,26 @@ Run the `theme selector: toggle` action in the command palette to see a current ## Agent -Visit [the Configuration page](./ai/configuration.md) under the AI section to learn more about all the agent-related settings. +- Description: Customize agent behavior +- Setting: `agent` +- Default: + +```json +"agent": { + "version": "2", + "enabled": true, + "button": true, + "dock": "right", + "default_width": 640, + "default_height": 320, + "default_view": "thread", + "default_model": { + "provider": "zed.dev", + "model": "claude-sonnet-4" + }, + "single_file_review": true, +} +``` ## Outline Panel diff --git a/docs/src/extensions/installing-extensions.md b/docs/src/extensions/installing-extensions.md index 801fe5c55c..aed8bef428 100644 --- a/docs/src/extensions/installing-extensions.md +++ b/docs/src/extensions/installing-extensions.md @@ -1,6 +1,6 @@ # Installing Extensions -You can search for extensions by launching the Zed Extension Gallery by pressing {#kb zed::Extensions} , opening the command palette and selecting {#action zed::Extensions} or by selecting "Zed > Extensions" from the menu bar. +You can search for extensions by launching the Zed Extension Gallery by pressing `cmd-shift-x` (macOS) or `ctrl-shift-x` (Linux), opening the command palette and selecting `zed: extensions` or by selecting "Zed > Extensions" from the menu bar. Here you can view the extensions that you currently have installed or search and install new ones. diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md index 22af3b36d7..5940c74b21 100644 --- a/docs/src/getting-started.md +++ b/docs/src/getting-started.md @@ -83,6 +83,6 @@ Visit [the AI overview page](./ai/overview.md) to learn how to quickly get start ## Set up your key bindings -To edit your custom keymap and add or remap bindings, you can either use {#kb zed::OpenKeymapEditor} to spawn the Zed Keymap Editor ({#action zed::OpenKeymapEditor}) or you can directly open your Zed Keymap json (`~/.config/zed/keymap.json`) with {#action zed::OpenKeymap}. +To open your custom keymap to add your key bindings, use the {#kb zed::OpenKeymap} keybinding. To access the default key binding set, open the Command Palette with {#kb command_palette::Toggle} and search for "zed: open default keymap". See [Key Bindings](./key-bindings.md) for more info. diff --git a/docs/src/git.md b/docs/src/git.md index cccbad9b2e..76db15a767 100644 --- a/docs/src/git.md +++ b/docs/src/git.md @@ -1,8 +1,3 @@ ---- -description: Zed is a text editor that supports lots of Git features -title: Zed Editor Git integration documentation ---- - # Git Zed currently offers a set of fundamental Git features, with support coming in the future for more advanced ones, like conflict resolution tools, line by line staging, and more. @@ -81,7 +76,7 @@ You can ask AI to generate a commit message by focusing on the message editor wi > Note that you need to have an LLM provider configured. Visit [the AI configuration page](./ai/configuration.md) to learn how to do so. -You can specify your preferred model to use by providing a `commit_message_model` agent setting. See [Feature-specific models](./ai/agent-settings.md#feature-specific-models) for more information. +You can specify your preferred model to use by providing a `commit_message_model` agent setting. See [Feature-specific models](./ai/configuration.md#feature-specific-models) for more information. ```json { diff --git a/docs/src/key-bindings.md b/docs/src/key-bindings.md index 9984f234ad..90aa400bb4 100644 --- a/docs/src/key-bindings.md +++ b/docs/src/key-bindings.md @@ -18,7 +18,7 @@ You can also enable `vim_mode`, which adds vim bindings too. ## User keymaps -Zed reads your keymap from `~/.config/zed/keymap.json`. You can open the file within Zed with {#action zed::OpenKeymap} from the command palette or to spawn the Zed Keymap Editor ({#action zed::OpenKeymapEditor}) use {#kb zed::OpenKeymapEditor}. +Zed reads your keymap from `~/.config/zed/keymap.json`. You can open the file within Zed with {#kb zed::OpenKeymap}, or via `zed: Open Keymap` in the command palette. The file contains a JSON array of objects with `"bindings"`. If no `"context"` is set the bindings are always active. If it is set the binding is only active when the [context matches](#contexts). diff --git a/docs/src/languages/c.md b/docs/src/languages/c.md index 8db1bb6712..14a11c0d66 100644 --- a/docs/src/languages/c.md +++ b/docs/src/languages/c.md @@ -77,7 +77,7 @@ You can use CodeLLDB or GDB to debug native binaries. (Make sure that your build "command": "make", "args": ["-j8"], "cwd": "$ZED_WORKTREE_ROOT" - }, + } "program": "$ZED_WORKTREE_ROOT/build/prog", "request": "launch", "adapter": "CodeLLDB" diff --git a/docs/src/languages/cpp.md b/docs/src/languages/cpp.md index e84bb6ea50..1273bce2ac 100644 --- a/docs/src/languages/cpp.md +++ b/docs/src/languages/cpp.md @@ -127,7 +127,7 @@ You can use CodeLLDB or GDB to debug native binaries. (Make sure that your build "command": "make", "args": ["-j8"], "cwd": "$ZED_WORKTREE_ROOT" - }, + } "program": "$ZED_WORKTREE_ROOT/build/prog", "request": "launch", "adapter": "CodeLLDB" diff --git a/docs/src/languages/deno.md b/docs/src/languages/deno.md index c40b6531e6..c18b112326 100644 --- a/docs/src/languages/deno.md +++ b/docs/src/languages/deno.md @@ -57,40 +57,6 @@ See [Configuring supported languages](../configuring-languages.md) in the Zed do TBD: Deno Typescript REPL instructions [docs/repl#typescript-deno](../repl.md#typescript-deno) --> -## DAP support - -To debug deno programs, add this to `.zed/debug.json` - -```json -[ - { - "adapter": "JavaScript", - "label": "Deno", - "request": "launch", - "type": "pwa-node", - "cwd": "$ZED_WORKTREE_ROOT", - "program": "$ZED_FILE", - "runtimeExecutable": "deno", - "runtimeArgs": ["run", "--allow-all", "--inspect-wait"], - "attachSimplePort": 9229 - } -] -``` - -## Runnable support - -To run deno tasks like tests from the ui, add this to `.zed/tasks.json` - -```json -[ - { - "label": "deno test", - "command": "deno test -A --filter '/^$ZED_CUSTOM_DENO_TEST_NAME$/' $ZED_FILE", - "tags": ["js-test"] - } -] -``` - ## See also: - [TypeScript](./typescript.md) diff --git a/docs/src/linux.md b/docs/src/linux.md index 309354de6d..ca65da2969 100644 --- a/docs/src/linux.md +++ b/docs/src/linux.md @@ -294,78 +294,3 @@ If your system uses PipeWire: ``` 3. **Restart your system** - -### Forcing X11 scale factor - -On X11 systems, Zed automatically detects the appropriate scale factor for high-DPI displays. The scale factor is determined using the following priority order: - -1. `GPUI_X11_SCALE_FACTOR` environment variable (if set) -2. `Xft.dpi` from X resources database (xrdb) -3. Automatic detection via RandR based on monitor resolution and physical size - -If you want to customize the scale factor beyond what Zed detects automatically, you have several options: - -#### Check your current scale factor - -You can verify if you have `Xft.dpi` set: - -```sh -xrdb -query | grep Xft.dpi -``` - -If this command returns no output, Zed is using RandR (X11's monitor management extension) to automatically calculate the scale factor based on your monitor's reported resolution and physical dimensions. - -#### Option 1: Set Xft.dpi (X Resources Database) - -`Xft.dpi` is a standard X11 setting that many applications use for consistent font and UI scaling. Setting this ensures Zed scales the same way as other X11 applications that respect this setting. - -Edit or create the `~/.Xresources` file: - -```sh -vim ~/.Xresources -``` - -Add this line with your desired DPI: - -```sh -Xft.dpi: 96 -``` - -Common DPI values: - -- `96` for standard 1x scaling -- `144` for 1.5x scaling -- `192` for 2x scaling -- `288` for 3x scaling - -Load the configuration: - -```sh -xrdb -merge ~/.Xresources -``` - -Restart Zed for the changes to take effect. - -#### Option 2: Use the GPUI_X11_SCALE_FACTOR environment variable - -This Zed-specific environment variable directly sets the scale factor, bypassing all automatic detection. - -```sh -GPUI_X11_SCALE_FACTOR=1.5 zed -``` - -You can use decimal values (e.g., `1.25`, `1.5`, `2.0`) or set `GPUI_X11_SCALE_FACTOR=randr` to force RandR-based detection even when `Xft.dpi` is set. - -To make this permanent, add it to your shell profile or desktop entry. - -#### Option 3: Adjust system-wide RandR DPI - -This changes the reported DPI for your entire X11 session, affecting how RandR calculates scaling for all applications that use it. - -Add this to your `.xprofile` or `.xinitrc`: - -```sh -xrandr --dpi 192 -``` - -Replace `192` with your desired DPI value. This affects the system globally and will be used by Zed's automatic RandR detection when `Xft.dpi` is not set. diff --git a/docs/src/telemetry.md b/docs/src/telemetry.md index 7f5994be0c..20018b920a 100644 --- a/docs/src/telemetry.md +++ b/docs/src/telemetry.md @@ -22,9 +22,8 @@ The telemetry settings can also be configured via the welcome screen, which can Telemetry is sent from the application to our servers. Data is proxied through our servers to enable us to easily switch analytics services. We currently use: - [Axiom](https://axiom.co): Cloud-monitoring service - stores diagnostic events -- [Snowflake](https://snowflake.com): Data warehouse - stores both diagnostic and metric events -- [Hex](https://www.hex.tech): Dashboards and data exploration - accesses data stored in Snowflake -- [Amplitude](https://www.amplitude.com): Dashboards and data exploration - accesses data stored in Snowflake +- [Snowflake](https://snowflake.com): Business Intelligence platform - stores both diagnostic and metric events +- [Metabase](https://www.metabase.com): Dashboards - dashboards built around data pulled from Snowflake ## Types of Telemetry @@ -34,7 +33,7 @@ Diagnostic events include debug information (stack traces) from crash reports. R You can see what data is sent when a panic occurs by inspecting the `Panic` struct in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repo. You can find additional information in the [Debugging Crashes](./development/debugging-crashes.md) documentation. -### Client-Side Usage Data {#client-metrics} +### Usage Data (Metrics) {#metrics} To improve Zed and understand how it is being used in the wild, Zed optionally collects usage data like the following: @@ -51,12 +50,6 @@ You can audit the metrics data that Zed has reported by running the command {#ac You can see the full list of the event types and exactly the data sent for each by inspecting the `Event` enum and the associated structs in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repository. -### Server-Side Usage Data {#metrics} - -When using Zed's hosted services, we may collect, generate, and Process data to allow us to support users and improve our hosted offering. Examples include metadata around rate limiting and billing metrics/token usage. Zed does not persistently store user content or use user content to evaluate and/or improve our AI features, unless it is explicitly shared with Zed, and we have a zero-data retention agreement with Anthropic. - -You can see more about our stance on data collection (and that any prompt data shared with Zed is explicitly opt-in) at [AI Improvement](./ai/ai-improvement.md). - ## Concerns and Questions If you have concerns about telemetry, please feel free to [open an issue](https://github.com/zed-industries/zed/issues/new/choose). diff --git a/docs/src/visual-customization.md b/docs/src/visual-customization.md index 8b307d97d5..197c9b80f8 100644 --- a/docs/src/visual-customization.md +++ b/docs/src/visual-customization.md @@ -267,7 +267,7 @@ TBD: Centered layout related settings "display_in": "active_editor", // Where to show (active_editor, all_editor) "thumb": "always", // When to show thumb (always, hover) "thumb_border": "left_open", // Thumb border (left_open, right_open, full, none) - "max_width_columns": 80, // Maximum width of minimap + "max_width_columns": 80 // Maximum width of minimap "current_line_highlight": null // Highlight current line (null, line, gutter) }, diff --git a/docs/theme/index.hbs b/docs/theme/index.hbs index 4339a02d17..8ab4f21cf1 100644 --- a/docs/theme/index.hbs +++ b/docs/theme/index.hbs @@ -15,7 +15,7 @@ {{> head}} - + diff --git a/extensions/emmet/Cargo.toml b/extensions/emmet/Cargo.toml index 9d72a6c5c4..db8aaaae41 100644 --- a/extensions/emmet/Cargo.toml +++ b/extensions/emmet/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zed_emmet" -version = "0.0.4" +version = "0.0.3" edition.workspace = true publish.workspace = true license = "Apache-2.0" diff --git a/script/bundle-linux b/script/bundle-linux index 64de62ce9b..c52312015b 100755 --- a/script/bundle-linux +++ b/script/bundle-linux @@ -83,23 +83,6 @@ if [[ "$remote_server_triple" == "$musl_triple" ]]; then fi cargo build --release --target "${remote_server_triple}" --package remote_server -# Upload debug info to sentry.io -if ! command -v sentry-cli >/dev/null 2>&1; then - echo "sentry-cli not found. skipping sentry upload." - echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" -else - if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then - echo "Uploading zed debug symbols to sentry..." - # note: this uploads the unstripped binary which is needed because it contains - # .eh_frame data for stack unwinindg. see https://github.com/getsentry/symbolic/issues/783 - sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ - "${target_dir}/${target_triple}"/release/zed \ - "${target_dir}/${remote_server_triple}"/release/remote_server - else - echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload." - fi -fi - # Strip debug symbols and save them for upload to DigitalOcean objcopy --only-keep-debug "${target_dir}/${target_triple}/release/zed" "${target_dir}/${target_triple}/release/zed.dbg" objcopy --only-keep-debug "${target_dir}/${remote_server_triple}/release/remote_server" "${target_dir}/${remote_server_triple}/release/remote_server.dbg" diff --git a/script/bundle-mac b/script/bundle-mac index b2be573235..18dfe90815 100755 --- a/script/bundle-mac +++ b/script/bundle-mac @@ -366,20 +366,3 @@ else gzip -f --stdout --best target/x86_64-apple-darwin/release/remote_server > target/zed-remote-server-macos-x86_64.gz gzip -f --stdout --best target/aarch64-apple-darwin/release/remote_server > target/zed-remote-server-macos-aarch64.gz fi - -# Upload debug info to sentry.io -if ! command -v sentry-cli >/dev/null 2>&1; then - echo "sentry-cli not found. skipping sentry upload." - echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" -else - if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then - echo "Uploading zed debug symbols to sentry..." - # note: this uploads the unstripped binary which is needed because it contains - # .eh_frame data for stack unwinindg. see https://github.com/getsentry/symbolic/issues/783 - sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ - "target/x86_64-apple-darwin/${target_dir}/" \ - "target/aarch64-apple-darwin/${target_dir}/" - else - echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload." - fi -fi diff --git a/script/bundle-windows.ps1 b/script/bundle-windows.ps1 index 8ae0212491..01a1114c26 100644 --- a/script/bundle-windows.ps1 +++ b/script/bundle-windows.ps1 @@ -26,7 +26,6 @@ if ($Help) { Push-Location -Path crates/zed $channel = Get-Content "RELEASE_CHANNEL" $env:ZED_RELEASE_CHANNEL = $channel -$env:RELEASE_CHANNEL = $channel Pop-Location function CheckEnvironmentVariables { @@ -97,21 +96,6 @@ function ZipZedAndItsFriendsDebug { Compress-Archive -Path $items -DestinationPath ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" -Force } - -function UploadToSentry { - if (-not (Get-Command "sentry-cli" -ErrorAction SilentlyContinue)) { - Write-Output "sentry-cli not found. skipping sentry upload." - Write-Output "install with: 'winget install -e --id=Sentry.sentry-cli'" - return - } - if (-not (Test-Path "env:SENTRY_AUTH_TOKEN")) { - Write-Output "missing SENTRY_AUTH_TOKEN. skipping sentry upload." - return - } - Write-Output "Uploading zed debug symbols to sentry..." - sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev .\target\release\ -} - function MakeAppx { switch ($channel) { "stable" { @@ -136,22 +120,11 @@ function SignZedAndItsFriends { & "$innoDir\sign.ps1" $files } -function DownloadAMDGpuServices { - # If you update the AGS SDK version, please also update the version in `crates/gpui/src/platform/windows/directx_renderer.rs` - $url = "https://codeload.github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/zip/refs/tags/v6.3.0" - $zipPath = ".\AGS_SDK_v6.3.0.zip" - # Download the AGS SDK zip file - Invoke-WebRequest -Uri $url -OutFile $zipPath - # Extract the AGS SDK zip file - Expand-Archive -Path $zipPath -DestinationPath "." -Force -} - function CollectFiles { Move-Item -Path "$innoDir\zed_explorer_command_injector.appx" -Destination "$innoDir\appx\zed_explorer_command_injector.appx" -Force Move-Item -Path "$innoDir\zed_explorer_command_injector.dll" -Destination "$innoDir\appx\zed_explorer_command_injector.dll" -Force Move-Item -Path "$innoDir\cli.exe" -Destination "$innoDir\bin\zed.exe" -Force Move-Item -Path "$innoDir\auto_update_helper.exe" -Destination "$innoDir\tools\auto_update_helper.exe" -Force - Move-Item -Path ".\AGS_SDK-6.3.0\ags_lib\lib\amd_ags_x64.dll" -Destination "$innoDir\amd_ags_x64.dll" -Force } function BuildInstaller { @@ -222,6 +195,7 @@ function BuildInstaller { # Windows runner 2022 default has iscc in PATH, https://github.com/actions/runner-images/blob/main/images/windows/Windows2022-Readme.md # Currently, we are using Windows 2022 runner. # Windows runner 2025 doesn't have iscc in PATH for now, https://github.com/actions/runner-images/issues/11228 + # $innoSetupPath = "iscc.exe" $innoSetupPath = "C:\Program Files (x86)\Inno Setup 6\ISCC.exe" $definitions = @{ @@ -268,8 +242,6 @@ function BuildInstaller { ParseZedWorkspace $innoDir = "$env:ZED_WORKSPACE\inno" -$debugArchive = ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" -$debugStoreKey = "$env:ZED_RELEASE_CHANNEL/zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" CheckEnvironmentVariables PrepareForBundle @@ -278,12 +250,12 @@ BuildZedAndItsFriends MakeAppx SignZedAndItsFriends ZipZedAndItsFriendsDebug -DownloadAMDGpuServices CollectFiles BuildInstaller +$debugArchive = ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" +$debugStoreKey = "$env:ZED_RELEASE_CHANNEL/zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" UploadToBlobStorePublic -BucketName "zed-debug-symbols" -FileToUpload $debugArchive -BlobStoreKey $debugStoreKey -UploadToSentry if ($buildSuccess) { Write-Output "Build successful" diff --git a/script/linux b/script/linux index 029278bea3..98ae026896 100755 --- a/script/linux +++ b/script/linux @@ -143,7 +143,6 @@ if [[ -n $zyp ]]; then gzip jq libvulkan1 - libx11-devel libxcb-devel libxkbcommon-devel libxkbcommon-x11-devel diff --git a/script/zed-local b/script/zed-local index 99d9308232..2568931246 100755 --- a/script/zed-local +++ b/script/zed-local @@ -213,7 +213,7 @@ setTimeout(() => { platform === "win32" ? "http://127.0.0.1:8080/rpc" : "http://localhost:8080/rpc", - ZED_ADMIN_API_TOKEN: "internal-api-key-secret", + ZED_ADMIN_API_TOKEN: "secret", ZED_WINDOW_SIZE: size, ZED_CLIENT_CHECKSUM_SEED: "development-checksum-seed", RUST_LOG: process.env.RUST_LOG || "info", diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index 4196696f47..f84682ad89 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -284,6 +284,7 @@ winnow = { version = "0.7", features = ["simd"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } +coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } @@ -309,9 +310,11 @@ tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } [target.x86_64-apple-darwin.build-dependencies] +clang-sys = { version = "1", default-features = false, features = ["clang_11_0", "runtime"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } +coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } @@ -341,6 +344,7 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } +coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } @@ -366,9 +370,11 @@ tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } [target.aarch64-apple-darwin.build-dependencies] +clang-sys = { version = "1", default-features = false, features = ["clang_11_0", "runtime"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } +coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } @@ -558,6 +564,7 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } +naga = { version = "25", features = ["spv-out", "wgsl-in"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } scopeguard = { version = "1" } @@ -581,6 +588,7 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } +naga = { version = "25", features = ["spv-out", "wgsl-in"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } diff --git a/typos.toml b/typos.toml index 336a829a44..7f1c6e04f1 100644 --- a/typos.toml +++ b/typos.toml @@ -71,10 +71,6 @@ extend-ignore-re = [ # Not an actual typo but an intentionally invalid color, in `color_extractor` "#fof", # Stripped version of reserved keyword `type` - "typ", - # AMD GPU Services - "ags", - # AMD GPU Services - "AGS" + "typ" ] check-filename = true