diff --git a/.github/actionlint.yml b/.github/actionlint.yml index d93ec5b15e..6bfbc27705 100644 --- a/.github/actionlint.yml +++ b/.github/actionlint.yml @@ -24,7 +24,6 @@ self-hosted-runner: - buildjet-8vcpu-ubuntu-2204-arm - buildjet-16vcpu-ubuntu-2204-arm - buildjet-32vcpu-ubuntu-2204-arm - - buildjet-64vcpu-ubuntu-2204-arm # Self Hosted Runners - self-mini-macos - self-32vcpu-windows-2022 diff --git a/.github/actions/build_docs/action.yml b/.github/actions/build_docs/action.yml index 9a2d7e1ec7..a7effad247 100644 --- a/.github/actions/build_docs/action.yml +++ b/.github/actions/build_docs/action.yml @@ -19,7 +19,7 @@ runs: shell: bash -euxo pipefail {0} run: ./script/linux - - name: Check for broken links + - name: Check for broken links (in MD) uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1 with: args: --no-progress --exclude '^http' './docs/src/**/*' @@ -30,3 +30,9 @@ runs: run: | mkdir -p target/deploy mdbook build ./docs --dest-dir=../target/deploy/docs/ + + - name: Check for broken links (in HTML) + uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1 + with: + args: --no-progress --exclude '^http' 'target/deploy/docs/' + fail: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a9ef1531e7..43d305faae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,6 +24,7 @@ env: DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} + ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }} jobs: job_spec: @@ -649,7 +650,7 @@ jobs: timeout-minutes: 60 name: Linux arm64 release bundle runs-on: - - buildjet-16vcpu-ubuntu-2204-arm + - buildjet-32vcpu-ubuntu-2204-arm if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') @@ -771,7 +772,8 @@ jobs: timeout-minutes: 120 name: Create a Windows installer runs-on: [self-hosted, Windows, X64] - if: false && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) + if: contains(github.event.pull_request.labels.*.name, 'run-bundling') + # if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) needs: [windows_tests] env: AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }} diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml index beacd27774..6c3a97c163 100644 --- a/.github/workflows/nix.yml +++ b/.github/workflows/nix.yml @@ -29,6 +29,7 @@ jobs: runs-on: ${{ matrix.system.runner }} env: ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} + ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }} ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }} GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on steps: diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index f799133ea7..c847149984 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -13,6 +13,7 @@ env: CARGO_INCREMENTAL: 0 RUST_BACKTRACE: 1 ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} + ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }} DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} @@ -111,6 +112,11 @@ jobs: echo "Publishing version: ${version} on release channel nightly" echo "nightly" > crates/zed/RELEASE_CHANNEL + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Create macOS app bundle run: script/bundle-mac @@ -136,6 +142,11 @@ jobs: - name: Install Linux dependencies run: ./script/linux && ./script/install-mold 2.34.0 + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Limit target directory size run: script/clear-target-dir-if-larger-than 100 @@ -157,7 +168,7 @@ jobs: name: Create a Linux *.tar.gz bundle for ARM if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2204-arm + - buildjet-32vcpu-ubuntu-2204-arm needs: tests steps: - name: Checkout repo @@ -168,6 +179,11 @@ jobs: - name: Install Linux dependencies run: ./script/linux + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Limit target directory size run: script/clear-target-dir-if-larger-than 100 @@ -262,6 +278,11 @@ jobs: Write-Host "Publishing version: $version on release channel nightly" "nightly" | Set-Content -Path "crates/zed/RELEASE_CHANNEL" + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Build Zed installer working-directory: ${{ env.ZED_WORKSPACE }} run: script/bundle-windows.ps1 diff --git a/Cargo.lock b/Cargo.lock index 3477d1270d..cb493b2a05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,10 +7,8 @@ name = "acp_thread" version = "0.1.0" dependencies = [ "agent-client-protocol", - "agentic-coding-protocol", "anyhow", "assistant_tool", - "async-pipe", "buffer_diff", "editor", "env_logger 0.11.8", @@ -20,7 +18,9 @@ dependencies = [ "itertools 0.14.0", "language", "markdown", + "parking_lot", "project", + "rand 0.8.5", "serde", "serde_json", "settings", @@ -90,6 +90,7 @@ dependencies = [ "assistant_tools", "chrono", "client", + "cloud_llm_client", "collections", "component", "context_server", @@ -113,7 +114,6 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", - "proto", "rand 0.8.5", "ref-cast", "rope", @@ -132,16 +132,19 @@ dependencies = [ "uuid", "workspace", "workspace-hack", - "zed_llm_client", "zstd", ] [[package]] name = "agent-client-protocol" -version = "0.0.11" +version = "0.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b" +checksum = "f8e4c1dccb35e69d32566f0d11948d902f9942fc3f038821816c1150cf5925f4" dependencies = [ + "anyhow", + "futures 0.3.31", + "log", + "parking_lot", "schemars", "serde", "serde_json", @@ -176,6 +179,7 @@ dependencies = [ "smol", "strum 0.27.1", "tempfile", + "thiserror 2.0.12", "ui", "util", "uuid", @@ -189,6 +193,7 @@ name = "agent_settings" version = "0.1.0" dependencies = [ "anyhow", + "cloud_llm_client", "collections", "fs", "gpui", @@ -200,7 +205,6 @@ dependencies = [ "serde_json_lenient", "settings", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -223,6 +227,7 @@ dependencies = [ "buffer_diff", "chrono", "client", + "cloud_llm_client", "collections", "command_palette_hooks", "component", @@ -294,7 +299,6 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", - "zed_llm_client", ] [[package]] @@ -355,10 +359,10 @@ name = "ai_onboarding" version = "0.1.0" dependencies = [ "client", + "cloud_llm_client", "component", "gpui", "language_model", - "proto", "serde", "smallvec", "telemetry", @@ -687,6 +691,7 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collections", "context_server", "fs", @@ -720,7 +725,6 @@ dependencies = [ "uuid", "workspace", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -828,6 +832,7 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collections", "component", "derive_more 0.99.19", @@ -881,7 +886,6 @@ dependencies = [ "which 6.0.3", "workspace", "workspace-hack", - "zed_llm_client", "zlog", ] @@ -1075,17 +1079,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "async-recursion" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "async-recursion" version = "1.1.1" @@ -1179,7 +1172,7 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_qs 0.10.1", - "smart-default", + "smart-default 0.6.0", "smol_str 0.1.24", "thiserror 1.0.69", "tokio", @@ -2971,11 +2964,12 @@ name = "client" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 0.3.2", "async-tungstenite", "base64 0.22.1", "chrono", "clock", + "cloud_api_client", + "cloud_llm_client", "cocoa 0.26.0", "collections", "credentials_provider", @@ -3018,7 +3012,6 @@ dependencies = [ "windows 0.61.1", "workspace-hack", "worktree", - "zed_llm_client", ] [[package]] @@ -3031,6 +3024,44 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "cloud_api_client" +version = "0.1.0" +dependencies = [ + "anyhow", + "cloud_api_types", + "futures 0.3.31", + "http_client", + "parking_lot", + "serde_json", + "workspace-hack", +] + +[[package]] +name = "cloud_api_types" +version = "0.1.0" +dependencies = [ + "chrono", + "cloud_llm_client", + "pretty_assertions", + "serde", + "serde_json", + "workspace-hack", +] + +[[package]] +name = "cloud_llm_client" +version = "0.1.0" +dependencies = [ + "anyhow", + "pretty_assertions", + "serde", + "serde_json", + "strum 0.27.1", + "uuid", + "workspace-hack", +] + [[package]] name = "clru" version = "0.6.2" @@ -3157,6 +3188,7 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collab_ui", "collections", "command_palette_hooks", @@ -3243,7 +3275,6 @@ dependencies = [ "workspace", "workspace-hack", "worktree", - "zed_llm_client", "zlog", ] @@ -3511,13 +3542,13 @@ dependencies = [ "command_palette_hooks", "ctor", "dirs 4.0.0", + "edit_prediction", "editor", "fs", "futures 0.3.31", "gpui", "http_client", "indoc", - "inline_completion", "itertools 0.14.0", "language", "log", @@ -3531,6 +3562,7 @@ dependencies = [ "serde", "serde_json", "settings", + "sum_tree", "task", "theme", "ui", @@ -3684,17 +3716,6 @@ dependencies = [ "libm", ] -[[package]] -name = "coreaudio-rs" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "321077172d79c662f64f5071a03120748d5bb652f5231570141be24cfcd2bace" -dependencies = [ - "bitflags 1.3.2", - "core-foundation-sys", - "coreaudio-sys", -] - [[package]] name = "coreaudio-rs" version = "0.12.1" @@ -3752,29 +3773,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "cpal" -version = "0.15.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "873dab07c8f743075e57f524c583985fbaf745602acbe916a01539364369a779" -dependencies = [ - "alsa", - "core-foundation-sys", - "coreaudio-rs 0.11.3", - "dasp_sample", - "jni", - "js-sys", - "libc", - "mach2", - "ndk 0.8.0", - "ndk-context", - "oboe", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", - "windows 0.54.0", -] - [[package]] name = "cpal" version = "0.16.0" @@ -3788,7 +3786,7 @@ dependencies = [ "js-sys", "libc", "mach2", - "ndk 0.9.0", + "ndk", "ndk-context", "num-derive", "num-traits", @@ -3929,6 +3927,42 @@ dependencies = [ "target-lexicon 0.13.2", ] +[[package]] +name = "crash-context" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031ed29858d90cfdf27fe49fae28028a1f20466db97962fa2f4ea34809aeebf3" +dependencies = [ + "cfg-if", + "libc", + "mach2", +] + +[[package]] +name = "crash-handler" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2066907075af649bcb8bcb1b9b986329b243677e6918b2d920aa64b0aac5ace3" +dependencies = [ + "cfg-if", + "crash-context", + "libc", + "mach2", + "parking_lot", +] + +[[package]] +name = "crashes" +version = "0.1.0" +dependencies = [ + "crash-handler", + "log", + "minidumper", + "paths", + "smol", + "workspace-hack", +] + [[package]] name = "crc" version = "3.2.1" @@ -4290,41 +4324,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "darling" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.101", -] - -[[package]] -name = "darling_macro" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" -dependencies = [ - "darling_core", - "quote", - "syn 2.0.101", -] - [[package]] name = "dashmap" version = "5.5.3" @@ -4490,6 +4489,15 @@ dependencies = [ "zlog", ] +[[package]] +name = "debugid" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" +dependencies = [ + "uuid", +] + [[package]] name = "deepseek" version = "0.1.0" @@ -4540,37 +4548,6 @@ dependencies = [ "serde", ] -[[package]] -name = "derive_builder" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 2.0.101", -] - -[[package]] -name = "derive_builder_macro" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" -dependencies = [ - "derive_builder_core", - "syn 2.0.101", -] - [[package]] name = "derive_more" version = "0.99.19" @@ -4792,7 +4769,6 @@ name = "docs_preprocessor" version = "0.1.0" dependencies = [ "anyhow", - "clap", "command_palette", "gpui", "mdbook", @@ -4803,6 +4779,7 @@ dependencies = [ "util", "workspace-hack", "zed", + "zlog", ] [[package]] @@ -4924,6 +4901,49 @@ dependencies = [ "signature 1.6.4", ] +[[package]] +name = "edit_prediction" +version = "0.1.0" +dependencies = [ + "client", + "gpui", + "language", + "project", + "workspace-hack", +] + +[[package]] +name = "edit_prediction_button" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "cloud_llm_client", + "copilot", + "edit_prediction", + "editor", + "feature_flags", + "fs", + "futures 0.3.31", + "gpui", + "indoc", + "language", + "lsp", + "paths", + "project", + "regex", + "serde_json", + "settings", + "supermaven", + "telemetry", + "theme", + "ui", + "workspace", + "workspace-hack", + "zed_actions", + "zeta", +] + [[package]] name = "editor" version = "0.1.0" @@ -4939,6 +4959,7 @@ dependencies = [ "ctor", "dap", "db", + "edit_prediction", "emojis", "file_icons", "fs", @@ -4948,7 +4969,6 @@ dependencies = [ "gpui", "http_client", "indoc", - "inline_completion", "itertools 0.14.0", "language", "languages", @@ -4981,6 +5001,7 @@ dependencies = [ "theme", "time", "tree-sitter-bash", + "tree-sitter-c", "tree-sitter-html", "tree-sitter-python", "tree-sitter-rust", @@ -5263,6 +5284,7 @@ dependencies = [ "chrono", "clap", "client", + "cloud_llm_client", "collections", "debug_adapter_extension", "dirs 4.0.0", @@ -5302,7 +5324,6 @@ dependencies = [ "uuid", "watch", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -5367,6 +5388,12 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "extended" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" + [[package]] name = "extension" version = "0.1.0" @@ -5943,7 +5970,7 @@ dependencies = [ "ignore", "libc", "log", - "notify", + "notify 8.0.0", "objc", "parking_lot", "paths", @@ -6377,7 +6404,7 @@ dependencies = [ "buffer_diff", "call", "chrono", - "client", + "cloud_llm_client", "collections", "command_palette_hooks", "component", @@ -6388,6 +6415,7 @@ dependencies = [ "fuzzy", "git", "gpui", + "indoc", "itertools 0.14.0", "language", "language_model", @@ -6420,7 +6448,6 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", - "zed_llm_client", "zlog", ] @@ -7253,6 +7280,17 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "goblin" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b363a30c165f666402fe6a3024d3bec7ebc898f96a4a23bd1c99f8dbf3f4f47" +dependencies = [ + "log", + "plain", + "scroll", +] + [[package]] name = "google_ai" version = "0.1.0" @@ -7500,18 +7538,16 @@ dependencies = [ [[package]] name = "handlebars" -version = "6.3.2" +version = "5.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "759e2d5aea3287cb1190c8ec394f42866cb5bf74fcbf213f354e3c856ea26098" +checksum = "d08485b96a0e6393e9e4d1b8d48cf74ad6c063cd905eb33f42c1ce3f0377539b" dependencies = [ - "derive_builder", "log", - "num-order", "pest", "pest_derive", "serde", "serde_json", - "thiserror 2.0.12", + "thiserror 1.0.69", ] [[package]] @@ -7692,12 +7728,6 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" -[[package]] -name = "hex-literal" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcaaec4551594c969335c98c903c1397853d4198408ea609190f420500f6be71" - [[package]] name = "hexf-parse" version = "0.2.1" @@ -7742,12 +7772,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "hound" -version = "3.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" - [[package]] name = "html5ever" version = "0.27.0" @@ -7881,6 +7905,8 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "log", + "parking_lot", + "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "serde", "serde_json", "url", @@ -8188,12 +8214,6 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "1.0.3" @@ -8370,46 +8390,14 @@ dependencies = [ ] [[package]] -name = "inline_completion" -version = "0.1.0" +name = "inotify" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" dependencies = [ - "client", - "gpui", - "language", - "project", - "workspace-hack", -] - -[[package]] -name = "inline_completion_button" -version = "0.1.0" -dependencies = [ - "anyhow", - "client", - "copilot", - "editor", - "feature_flags", - "fs", - "futures 0.3.31", - "gpui", - "indoc", - "inline_completion", - "language", - "lsp", - "paths", - "project", - "regex", - "serde_json", - "settings", - "supermaven", - "telemetry", - "theme", - "ui", - "workspace", - "workspace-hack", - "zed_actions", - "zed_llm_client", - "zeta", + "bitflags 1.3.2", + "inotify-sys", + "libc", ] [[package]] @@ -8565,7 +8553,7 @@ dependencies = [ "fnv", "lazy_static", "libc", - "mio", + "mio 1.0.3", "rand 0.8.5", "serde", "tempfile", @@ -9090,6 +9078,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "client", + "cloud_llm_client", "collections", "futures 0.3.31", "gpui", @@ -9107,7 +9096,6 @@ dependencies = [ "thiserror 2.0.12", "util", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -9123,6 +9111,7 @@ dependencies = [ "bedrock", "chrono", "client", + "cloud_llm_client", "collections", "component", "convert_case 0.8.0", @@ -9147,7 +9136,6 @@ dependencies = [ "open_router", "partial-json-fixer", "project", - "proto", "release_channel", "schemars", "serde", @@ -9165,7 +9153,6 @@ dependencies = [ "vercel", "workspace-hack", "x_ai", - "zed_llm_client", ] [[package]] @@ -9227,6 +9214,7 @@ dependencies = [ "chrono", "collections", "dap", + "feature_flags", "futures 0.3.31", "gpui", "http_client", @@ -9419,7 +9407,7 @@ dependencies = [ [[package]] name = "libwebrtc" version = "0.3.10" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "cxx", "jni", @@ -9499,7 +9487,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "livekit" version = "0.7.8" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "chrono", "futures-util", @@ -9522,7 +9510,7 @@ dependencies = [ [[package]] name = "livekit-api" version = "0.4.2" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "futures-util", "http 0.2.12", @@ -9546,7 +9534,7 @@ dependencies = [ [[package]] name = "livekit-protocol" version = "0.3.9" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "futures-util", "livekit-runtime", @@ -9563,7 +9551,7 @@ dependencies = [ [[package]] name = "livekit-runtime" version = "0.4.0" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "tokio", "tokio-stream", @@ -9595,7 +9583,7 @@ dependencies = [ "core-foundation 0.10.0", "core-video", "coreaudio-rs 0.12.1", - "cpal 0.16.0", + "cpal", "futures 0.3.31", "gpui", "gpui_tokio", @@ -9646,9 +9634,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -9885,7 +9873,7 @@ name = "markdown_preview" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 1.1.1", + "async-recursion", "collections", "editor", "fs", @@ -10005,9 +9993,9 @@ dependencies = [ [[package]] name = "mdbook" -version = "0.4.48" +version = "0.4.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6fbb4ac2d9fd7aa987c3510309ea3c80004a968d063c42f0d34fea070817c1" +checksum = "b45a38e19bd200220ef07c892b0157ad3d2365e5b5a267ca01ad12182491eea5" dependencies = [ "ammonia", "anyhow", @@ -10017,12 +10005,11 @@ dependencies = [ "elasticlunr-rs", "env_logger 0.11.8", "futures-util", - "handlebars 6.3.2", - "hex", + "handlebars 5.1.2", "ignore", "log", "memchr", - "notify", + "notify 6.1.1", "notify-debouncer-mini", "once_cell", "opener", @@ -10031,7 +10018,6 @@ dependencies = [ "regex", "serde", "serde_json", - "sha2", "shlex", "tempfile", "tokio", @@ -10152,6 +10138,63 @@ dependencies = [ "unicase", ] +[[package]] +name = "minidump-common" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c4d14bcca0fd3ed165a03000480aaa364c6860c34e900cb2dafdf3b95340e77" +dependencies = [ + "bitflags 2.9.0", + "debugid", + "num-derive", + "num-traits", + "range-map", + "scroll", + "smart-default 0.7.1", +] + +[[package]] +name = "minidump-writer" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abcd9c8a1e6e1e9d56ce3627851f39a17ea83e17c96bc510f29d7e43d78a7d" +dependencies = [ + "bitflags 2.9.0", + "byteorder", + "cfg-if", + "crash-context", + "goblin", + "libc", + "log", + "mach2", + "memmap2", + "memoffset", + "minidump-common", + "nix 0.28.0", + "procfs-core", + "scroll", + "tempfile", + "thiserror 1.0.69", +] + +[[package]] +name = "minidumper" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4ebc9d1f8847ec1d078f78b35ed598e0ebefa1f242d5f83cd8d7f03960a7d1" +dependencies = [ + "cfg-if", + "crash-context", + "libc", + "log", + "minidump-writer", + "parking_lot", + "polling", + "scroll", + "thiserror 1.0.69", + "uds", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -10174,6 +10217,18 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e53debba6bda7a793e5f99b8dacf19e626084f525f7829104ba9898f367d85ff" +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.48.0", +] + [[package]] name = "mio" version = "1.0.3" @@ -10366,20 +10421,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "ndk" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7" -dependencies = [ - "bitflags 2.9.0", - "jni-sys", - "log", - "ndk-sys 0.5.0+25.2.9519653", - "num_enum", - "thiserror 1.0.69", -] - [[package]] name = "ndk" version = "0.9.0" @@ -10389,7 +10430,7 @@ dependencies = [ "bitflags 2.9.0", "jni-sys", "log", - "ndk-sys 0.6.0+11769913", + "ndk-sys", "num_enum", "thiserror 1.0.69", ] @@ -10400,15 +10441,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" -[[package]] -name = "ndk-sys" -version = "0.5.0+25.2.9519653" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691" -dependencies = [ - "jni-sys", -] - [[package]] name = "ndk-sys" version = "0.6.0+11769913" @@ -10543,6 +10575,25 @@ dependencies = [ "zed_actions", ] +[[package]] +name = "notify" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" +dependencies = [ + "bitflags 2.9.0", + "crossbeam-channel", + "filetime", + "fsevent-sys 4.1.0", + "inotify 0.9.6", + "kqueue", + "libc", + "log", + "mio 0.8.11", + "walkdir", + "windows-sys 0.48.0", +] + [[package]] name = "notify" version = "8.0.0" @@ -10551,11 +10602,11 @@ dependencies = [ "bitflags 2.9.0", "filetime", "fsevent-sys 4.1.0", - "inotify", + "inotify 0.11.0", "kqueue", "libc", "log", - "mio", + "mio 1.0.3", "notify-types", "walkdir", "windows-sys 0.59.0", @@ -10563,14 +10614,13 @@ dependencies = [ [[package]] name = "notify-debouncer-mini" -version = "0.6.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a689eb4262184d9a1727f9087cd03883ea716682ab03ed24efec57d7716dccb8" +checksum = "5d40b221972a1fc5ef4d858a2f671fb34c75983eb385463dff3780eeff6a9d43" dependencies = [ + "crossbeam-channel", "log", - "notify", - "notify-types", - "tempfile", + "notify 6.1.1", ] [[package]] @@ -10710,21 +10760,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-modular" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17bb261bf36fa7d83f4c294f834e91256769097b3cb505d44831e0a179ac647f" - -[[package]] -name = "num-order" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537b596b97c40fcf8056d153049eb22f481c17ebce72a513ec9286e4986d1bb6" -dependencies = [ - "num-modular", -] - [[package]] name = "num-rational" version = "0.4.2" @@ -10978,29 +11013,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "oboe" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8b61bebd49e5d43f5f8cc7ee2891c16e0f41ec7954d36bcb6c14c5e0de867fb" -dependencies = [ - "jni", - "ndk 0.8.0", - "ndk-context", - "num-derive", - "num-traits", - "oboe-sys", -] - -[[package]] -name = "oboe-sys" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8bb09a4a2b1d668170cfe0a7d5bc103f8999fb316c98099b6a9939c9f2e79d" -dependencies = [ - "cc", -] - [[package]] name = "ollama" version = "0.1.0" @@ -11018,17 +11030,36 @@ dependencies = [ name = "onboarding" version = "0.1.0" dependencies = [ + "ai_onboarding", "anyhow", + "client", "command_palette_hooks", + "component", "db", + "documented", + "editor", "feature_flags", "fs", + "fuzzy", "gpui", + "itertools 0.14.0", + "language", + "language_model", + "menu", + "notifications", + "picker", + "project", + "schemars", + "serde", "settings", "theme", "ui", + "util", + "vim_mode_setting", "workspace", "workspace-hack", + "zed_actions", + "zlog", ] [[package]] @@ -11379,9 +11410,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -11389,9 +11420,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -12155,6 +12186,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plist" version = "1.7.1" @@ -12415,6 +12452,16 @@ dependencies = [ "yansi", ] +[[package]] +name = "procfs-core" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d3554923a69f4ce04c4a754260c338f505ce22642d3830e049a399fc2059a29" +dependencies = [ + "bitflags 2.9.0", + "hex", +] + [[package]] name = "prodash" version = "29.0.2" @@ -13065,6 +13112,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "range-map" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12a5a2d6c7039059af621472a4389be1215a816df61aa4d531cfe85264aee95f" +dependencies = [ + "num-traits", +] + [[package]] name = "rangemap" version = "1.5.1" @@ -13407,6 +13463,8 @@ dependencies = [ "clap", "client", "clock", + "crash-handler", + "crashes", "dap", "dap_adapters", "debug_adapter_extension", @@ -13430,6 +13488,7 @@ dependencies = [ "libc", "log", "lsp", + "minidumper", "node_runtime", "paths", "project", @@ -13618,6 +13677,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "once_cell", "percent-encoding", "pin-project-lite", @@ -13779,12 +13839,15 @@ dependencies = [ [[package]] name = "rodio" -version = "0.20.1" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ceb6607dd738c99bc8cb28eff249b7cd5c8ec88b9db96c0608c1480d140fb1" +checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183" dependencies = [ - "cpal 0.15.3", - "hound", + "cpal", + "dasp_sample", + "num-rational", + "symphonia", + "tracing", ] [[package]] @@ -14343,6 +14406,26 @@ dependencies = [ "once_cell", ] +[[package]] +name = "scroll" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab8598aa408498679922eff7fa985c25d58a90771bd6be794434c5277eab1a6" +dependencies = [ + "scroll_derive", +] + +[[package]] +name = "scroll_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1783eabc414609e28a5ba76aee5ddd52199f7107a0b24c2e9746a1ecc34a683d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "scrypt" version = "0.11.0" @@ -14789,6 +14872,27 @@ dependencies = [ "zlog", ] +[[package]] +name = "settings_profile_selector" +version = "0.1.0" +dependencies = [ + "client", + "editor", + "fuzzy", + "gpui", + "language", + "menu", + "picker", + "project", + "serde_json", + "settings", + "theme", + "ui", + "workspace", + "workspace-hack", + "zed_actions", +] + [[package]] name = "settings_ui" version = "0.1.0" @@ -14811,7 +14915,6 @@ dependencies = [ "notifications", "paths", "project", - "schemars", "search", "serde", "serde_json", @@ -15068,6 +15171,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "smart-default" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eb01866308440fc64d6c44d9e86c5cc17adfe33c4d6eed55da9145044d0ffc1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "smol" version = "2.0.2" @@ -15650,12 +15764,12 @@ dependencies = [ "anyhow", "client", "collections", + "edit_prediction", "editor", "env_logger 0.11.8", "futures 0.3.31", "gpui", "http_client", - "inline_completion", "language", "log", "postage", @@ -15805,6 +15919,66 @@ dependencies = [ "zeno", ] +[[package]] +name = "symphonia" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9" +dependencies = [ + "lazy_static", + "symphonia-codec-pcm", + "symphonia-core", + "symphonia-format-riff", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-codec-pcm" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f395a67057c2ebc5e84d7bb1be71cce1a7ba99f64e0f0f0e303a03f79116f89b" +dependencies = [ + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-core" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "798306779e3dc7d5231bd5691f5a813496dc79d3f56bf82e25789f2094e022c3" +dependencies = [ + "arrayvec", + "bitflags 1.3.2", + "bytemuck", + "lazy_static", + "log", +] + +[[package]] +name = "symphonia-format-riff" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f7be232f962f937f4b7115cbe62c330929345434c834359425e043bfd15f50" +dependencies = [ + "extended", + "log", + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-metadata" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc622b9841a10089c5b18e99eb904f4341615d5aa55bbf4eedde1be721a4023c" +dependencies = [ + "encoding_rs", + "lazy_static", + "log", + "symphonia-core", +] + [[package]] name = "syn" version = "1.0.109" @@ -16188,7 +16362,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_slash_command", - "async-recursion 1.1.1", + "async-recursion", "breadcrumbs", "client", "collections", @@ -16537,6 +16711,7 @@ dependencies = [ "call", "chrono", "client", + "cloud_llm_client", "collections", "db", "gpui", @@ -16572,7 +16747,7 @@ dependencies = [ "backtrace", "bytes 1.10.1", "libc", - "mio", + "mio 1.0.3", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -17290,6 +17465,15 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "uds" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "885c31f06fce836457fe3ef09a59f83fe8db95d270b11cd78f40a4666c4d1661" +dependencies = [ + "libc", +] + [[package]] name = "uds_windows" version = "1.1.0" @@ -18505,11 +18689,11 @@ name = "web_search" version = "0.1.0" dependencies = [ "anyhow", + "cloud_llm_client", "collections", "gpui", "serde", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -18518,6 +18702,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "cloud_llm_client", "futures 0.3.31", "gpui", "http_client", @@ -18526,7 +18711,6 @@ dependencies = [ "serde_json", "web_search", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -18550,7 +18734,7 @@ dependencies = [ [[package]] name = "webrtc-sys" version = "0.3.7" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "cc", "cxx", @@ -18563,15 +18747,13 @@ dependencies = [ [[package]] name = "webrtc-sys-build" version = "0.3.6" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "fs2", - "hex-literal", "regex", "reqwest 0.11.27", "scratch", "semver", - "sha2", "zip", ] @@ -18600,7 +18782,6 @@ dependencies = [ "serde", "settings", "telemetry", - "theme", "ui", "util", "vim_mode_setting", @@ -19615,7 +19796,7 @@ version = "0.1.0" dependencies = [ "any_vec", "anyhow", - "async-recursion 1.1.1", + "async-recursion", "bincode", "call", "client", @@ -19692,14 +19873,12 @@ dependencies = [ "cc", "chrono", "cipher", - "clang-sys", "clap", "clap_builder", "codespan-reporting 0.12.0", "concurrent-queue", "core-foundation 0.9.4", "core-foundation-sys", - "coreaudio-sys", "cranelift-codegen", "crc32fast", "crossbeam-epoch", @@ -19750,9 +19929,11 @@ dependencies = [ "lyon_path", "md-5", "memchr", + "mime_guess", "miniz_oxide", - "mio", + "mio 1.0.3", "naga", + "nix 0.28.0", "nix 0.29.0", "nom", "num-bigint", @@ -20142,7 +20323,7 @@ dependencies = [ "async-io", "async-lock", "async-process", - "async-recursion 1.1.1", + "async-recursion", "async-task", "async-trait", "blocking", @@ -20195,7 +20376,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.198.0" +version = "0.199.0" dependencies = [ "activity_indicator", "agent", @@ -20224,6 +20405,7 @@ dependencies = [ "command_palette", "component", "copilot", + "crashes", "dap", "dap_adapters", "db", @@ -20231,6 +20413,7 @@ dependencies = [ "debugger_tools", "debugger_ui", "diagnostics", + "edit_prediction_button", "editor", "env_logger 0.11.8", "extension", @@ -20250,7 +20433,6 @@ dependencies = [ "http_client", "image_viewer", "indoc", - "inline_completion_button", "inspector_ui", "install_cli", "itertools 0.14.0", @@ -20291,6 +20473,7 @@ dependencies = [ "release_channel", "remote", "repl", + "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "reqwest_client", "rope", "search", @@ -20298,6 +20481,7 @@ dependencies = [ "serde_json", "session", "settings", + "settings_profile_selector", "settings_ui", "shellexpand 2.1.2", "smol", @@ -20356,7 +20540,7 @@ dependencies = [ [[package]] name = "zed_emmet" -version = "0.0.3" +version = "0.0.4" dependencies = [ "zed_extension_api 0.1.0", ] @@ -20395,19 +20579,6 @@ dependencies = [ "zed_extension_api 0.1.0", ] -[[package]] -name = "zed_llm_client" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc" -dependencies = [ - "anyhow", - "serde", - "serde_json", - "strum 0.27.1", - "uuid", -] - [[package]] name = "zed_proto" version = "0.2.2" @@ -20417,7 +20588,7 @@ dependencies = [ [[package]] name = "zed_ruff" -version = "0.1.0" +version = "0.1.1" dependencies = [ "zed_extension_api 0.1.0", ] @@ -20587,11 +20758,14 @@ dependencies = [ "call", "client", "clock", + "cloud_api_types", + "cloud_llm_client", "collections", "command_palette_hooks", "copilot", "ctor", "db", + "edit_prediction", "editor", "feature_flags", "fs", @@ -20599,14 +20773,12 @@ dependencies = [ "gpui", "http_client", "indoc", - "inline_completion", "language", "language_model", "log", "menu", "postage", "project", - "proto", "regex", "release_channel", "reqwest_client", @@ -20628,10 +20800,45 @@ dependencies = [ "workspace-hack", "worktree", "zed_actions", - "zed_llm_client", "zlog", ] +[[package]] +name = "zeta_cli" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "client", + "debug_adapter_extension", + "extension", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "language", + "language_extension", + "language_model", + "language_models", + "languages", + "node_runtime", + "paths", + "project", + "prompt_store", + "release_channel", + "reqwest_client", + "serde", + "serde_json", + "settings", + "shellexpand 2.1.2", + "smol", + "terminal_view", + "util", + "watch", + "workspace-hack", + "zeta", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index 16ace7dee0..733db92ce9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [workspace] resolver = "2" members = [ - "crates/activity_indicator", "crates/acp_thread", - "crates/agent_ui", + "crates/activity_indicator", "crates/agent", - "crates/agent_settings", - "crates/ai_onboarding", "crates/agent_servers", + "crates/agent_settings", + "crates/agent_ui", + "crates/ai_onboarding", "crates/anthropic", "crates/askpass", "crates/assets", @@ -29,6 +29,9 @@ members = [ "crates/cli", "crates/client", "crates/clock", + "crates/cloud_api_client", + "crates/cloud_api_types", + "crates/cloud_llm_client", "crates/collab", "crates/collab_ui", "crates/collections", @@ -37,6 +40,7 @@ members = [ "crates/component", "crates/context_server", "crates/copilot", + "crates/crashes", "crates/credentials_provider", "crates/dap", "crates/dap_adapters", @@ -48,8 +52,8 @@ members = [ "crates/diagnostics", "crates/docs_preprocessor", "crates/editor", - "crates/explorer_command_injector", "crates/eval", + "crates/explorer_command_injector", "crates/extension", "crates/extension_api", "crates/extension_cli", @@ -70,15 +74,14 @@ members = [ "crates/gpui", "crates/gpui_macros", "crates/gpui_tokio", - "crates/html_to_markdown", "crates/http_client", "crates/http_client_tls", "crates/icons", "crates/image_viewer", "crates/indexed_docs", - "crates/inline_completion", - "crates/inline_completion_button", + "crates/edit_prediction", + "crates/edit_prediction_button", "crates/inspector_ui", "crates/install_cli", "crates/jj", @@ -99,7 +102,6 @@ members = [ "crates/markdown_preview", "crates/media", "crates/menu", - "crates/svg_preview", "crates/migrator", "crates/mistral", "crates/multi_buffer", @@ -140,6 +142,7 @@ members = [ "crates/semantic_version", "crates/session", "crates/settings", + "crates/settings_profile_selector", "crates/settings_ui", "crates/snippet", "crates/snippet_provider", @@ -152,6 +155,7 @@ members = [ "crates/sum_tree", "crates/supermaven", "crates/supermaven_api", + "crates/svg_preview", "crates/tab_switcher", "crates/task", "crates/tasks_ui", @@ -186,6 +190,7 @@ members = [ "crates/zed", "crates/zed_actions", "crates/zeta", + "crates/zeta_cli", "crates/zlog", "crates/zlog_settings", @@ -251,6 +256,9 @@ channel = { path = "crates/channel" } cli = { path = "crates/cli" } client = { path = "crates/client" } clock = { path = "crates/clock" } +cloud_api_client = { path = "crates/cloud_api_client" } +cloud_api_types = { path = "crates/cloud_api_types" } +cloud_llm_client = { path = "crates/cloud_llm_client" } collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } @@ -259,6 +267,7 @@ command_palette_hooks = { path = "crates/command_palette_hooks" } component = { path = "crates/component" } context_server = { path = "crates/context_server" } copilot = { path = "crates/copilot" } +crashes = { path = "crates/crashes" } credentials_provider = { path = "crates/credentials_provider" } dap = { path = "crates/dap" } dap_adapters = { path = "crates/dap_adapters" } @@ -295,8 +304,8 @@ http_client_tls = { path = "crates/http_client_tls" } icons = { path = "crates/icons" } image_viewer = { path = "crates/image_viewer" } indexed_docs = { path = "crates/indexed_docs" } -inline_completion = { path = "crates/inline_completion" } -inline_completion_button = { path = "crates/inline_completion_button" } +edit_prediction = { path = "crates/edit_prediction" } +edit_prediction_button = { path = "crates/edit_prediction_button" } inspector_ui = { path = "crates/inspector_ui" } install_cli = { path = "crates/install_cli" } jj = { path = "crates/jj" } @@ -337,6 +346,7 @@ picker = { path = "crates/picker" } plugin = { path = "crates/plugin" } plugin_macros = { path = "crates/plugin_macros" } prettier = { path = "crates/prettier" } +settings_profile_selector = { path = "crates/settings_profile_selector" } project = { path = "crates/project" } project_panel = { path = "crates/project_panel" } project_symbols = { path = "crates/project_symbols" } @@ -413,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.11" +agent-client-protocol = "0.0.18" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -458,6 +468,7 @@ core-foundation = "0.10.0" core-foundation-sys = "0.8.6" core-video = { version = "0.4.3", features = ["metal"] } cpal = "0.16" +crash-handler = "0.6" criterion = { version = "0.5", features = ["html_reports"] } ctor = "0.4.0" dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "1b461b310481d01e02b2603c16d7144b926339f8" } @@ -505,6 +516,7 @@ log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" } markup5ever_rcdom = "0.3.0" metal = "0.29" +minidumper = "0.8" moka = { version = "0.12.10", features = ["sync"] } naga = { version = "25.0", features = ["wgsl-in"] } nanoid = "0.4" @@ -544,6 +556,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77 "charset", "http2", "macos-system-configuration", + "multipart", "rustls-tls-native-roots", "socks", "stream", @@ -645,7 +658,6 @@ which = "6.0.0" windows-core = "0.61" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "= 0.8.6" zstd = "0.11" [workspace.dependencies.async-stripe] @@ -672,14 +684,16 @@ features = [ "UI_ViewManagement", "Wdk_System_SystemServices", "Win32_Globalization", - "Win32_Graphics_Direct2D", - "Win32_Graphics_Direct2D_Common", + "Win32_Graphics_Direct3D", + "Win32_Graphics_Direct3D11", + "Win32_Graphics_Direct3D_Fxc", + "Win32_Graphics_DirectComposition", "Win32_Graphics_DirectWrite", "Win32_Graphics_Dwm", + "Win32_Graphics_Dxgi", "Win32_Graphics_Dxgi_Common", "Win32_Graphics_Gdi", "Win32_Graphics_Imaging", - "Win32_Graphics_Imaging_D2D", "Win32_Networking_WinSock", "Win32_Security", "Win32_Security_Credentials", @@ -747,7 +761,7 @@ feature_flags = { codegen-units = 1 } file_icons = { codegen-units = 1 } fsevent = { codegen-units = 1 } image_viewer = { codegen-units = 1 } -inline_completion_button = { codegen-units = 1 } +edit_prediction_button = { codegen-units = 1 } install_cli = { codegen-units = 1 } journal = { codegen-units = 1 } lmstudio = { codegen-units = 1 } diff --git a/README.md b/README.md index 4c794efc3d..38547c1ca4 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Zed +[![Zed](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/zed-industries/zed/main/assets/badge/v0.json)](https://zed.dev) [![CI](https://github.com/zed-industries/zed/actions/workflows/ci.yml/badge.svg)](https://github.com/zed-industries/zed/actions/workflows/ci.yml) Welcome to Zed, a high-performance, multiplayer code editor from the creators of [Atom](https://github.com/atom/atom) and [Tree-sitter](https://github.com/tree-sitter/tree-sitter). diff --git a/assets/badge/v0.json b/assets/badge/v0.json new file mode 100644 index 0000000000..c7d18bb42b --- /dev/null +++ b/assets/badge/v0.json @@ -0,0 +1,8 @@ +{ + "label": "", + "message": "Zed", + "logoSvg": "", + "logoWidth": 16, + "labelColor": "black", + "color": "white" +} diff --git a/assets/icons/ai_bedrock.svg b/assets/icons/ai_bedrock.svg index 2b672c364e..c9bbcc82e1 100644 --- a/assets/icons/ai_bedrock.svg +++ b/assets/icons/ai_bedrock.svg @@ -1,4 +1,8 @@ - - - + + + + + + + diff --git a/assets/icons/ai_deep_seek.svg b/assets/icons/ai_deep_seek.svg index cf480c834c..c8e5483fb3 100644 --- a/assets/icons/ai_deep_seek.svg +++ b/assets/icons/ai_deep_seek.svg @@ -1 +1,3 @@ -DeepSeek + + + diff --git a/assets/icons/ai_lm_studio.svg b/assets/icons/ai_lm_studio.svg index 0b455f48a7..5cfdeb5578 100644 --- a/assets/icons/ai_lm_studio.svg +++ b/assets/icons/ai_lm_studio.svg @@ -1,33 +1,15 @@ - - - Artboard - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + diff --git a/assets/icons/ai_mistral.svg b/assets/icons/ai_mistral.svg index 23b8f2ef6c..f11c177e2f 100644 --- a/assets/icons/ai_mistral.svg +++ b/assets/icons/ai_mistral.svg @@ -1 +1,8 @@ -Mistral \ No newline at end of file + + + + + + + + diff --git a/assets/icons/ai_ollama.svg b/assets/icons/ai_ollama.svg index d433df3981..36a88c1ad6 100644 --- a/assets/icons/ai_ollama.svg +++ b/assets/icons/ai_ollama.svg @@ -1,14 +1,7 @@ - - - - - - - - - - - - + + + + + diff --git a/assets/icons/ai_open_ai.svg b/assets/icons/ai_open_ai.svg index e659a472d8..e45ac315a0 100644 --- a/assets/icons/ai_open_ai.svg +++ b/assets/icons/ai_open_ai.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/ai_open_router.svg b/assets/icons/ai_open_router.svg index 94f2849146..b6f5164e0b 100644 --- a/assets/icons/ai_open_router.svg +++ b/assets/icons/ai_open_router.svg @@ -1,8 +1,8 @@ - - - - - - - + + + + + + + diff --git a/assets/icons/ai_x_ai.svg b/assets/icons/ai_x_ai.svg index 289525c8ef..d3400fbe9c 100644 --- a/assets/icons/ai_x_ai.svg +++ b/assets/icons/ai_x_ai.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/ai_zed.svg b/assets/icons/ai_zed.svg index 1c6bb8ad63..6d78efacd5 100644 --- a/assets/icons/ai_zed.svg +++ b/assets/icons/ai_zed.svg @@ -1,10 +1,3 @@ - - - - - - - - + diff --git a/assets/icons/bolt.svg b/assets/icons/bolt.svg deleted file mode 100644 index 2688ede2a5..0000000000 --- a/assets/icons/bolt.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/bolt_filled.svg b/assets/icons/bolt_filled.svg index 543e72adf8..14d8f53e02 100644 --- a/assets/icons/bolt_filled.svg +++ b/assets/icons/bolt_filled.svg @@ -1,3 +1,3 @@ - - + + diff --git a/assets/icons/bolt_filled_alt.svg b/assets/icons/bolt_filled_alt.svg deleted file mode 100644 index 141e1c5f57..0000000000 --- a/assets/icons/bolt_filled_alt.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/bolt_outlined.svg b/assets/icons/bolt_outlined.svg new file mode 100644 index 0000000000..58fccf7788 --- /dev/null +++ b/assets/icons/bolt_outlined.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/book_plus.svg b/assets/icons/book_plus.svg deleted file mode 100644 index 2868f07cd0..0000000000 --- a/assets/icons/book_plus.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/brain.svg b/assets/icons/brain.svg deleted file mode 100644 index 80c93814f7..0000000000 --- a/assets/icons/brain.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/chat.svg b/assets/icons/chat.svg new file mode 100644 index 0000000000..a0548c3d3e --- /dev/null +++ b/assets/icons/chat.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/at_sign.svg b/assets/icons/cloud_download.svg similarity index 51% rename from assets/icons/at_sign.svg rename to assets/icons/cloud_download.svg index 4cf8cd468f..bc7a8376d1 100644 --- a/assets/icons/at_sign.svg +++ b/assets/icons/cloud_download.svg @@ -1 +1 @@ - + \ No newline at end of file diff --git a/assets/icons/editor_atom.svg b/assets/icons/editor_atom.svg new file mode 100644 index 0000000000..cc5fa83843 --- /dev/null +++ b/assets/icons/editor_atom.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/editor_cursor.svg b/assets/icons/editor_cursor.svg new file mode 100644 index 0000000000..338697be8a --- /dev/null +++ b/assets/icons/editor_cursor.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/assets/icons/editor_emacs.svg b/assets/icons/editor_emacs.svg new file mode 100644 index 0000000000..951d7b2be1 --- /dev/null +++ b/assets/icons/editor_emacs.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/assets/icons/editor_jet_brains.svg b/assets/icons/editor_jet_brains.svg new file mode 100644 index 0000000000..7d9cf0c65c --- /dev/null +++ b/assets/icons/editor_jet_brains.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/editor_sublime.svg b/assets/icons/editor_sublime.svg new file mode 100644 index 0000000000..95a04f6b54 --- /dev/null +++ b/assets/icons/editor_sublime.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/editor_vs_code.svg b/assets/icons/editor_vs_code.svg new file mode 100644 index 0000000000..2a71ad52af --- /dev/null +++ b/assets/icons/editor_vs_code.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/file_icons/kdl.svg b/assets/icons/file_icons/kdl.svg new file mode 100644 index 0000000000..92d9f28428 --- /dev/null +++ b/assets/icons/file_icons/kdl.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/assets/icons/file_text.svg b/assets/icons/file_text.svg index 7c602f2ac7..a9b8f971e0 100644 --- a/assets/icons/file_text.svg +++ b/assets/icons/file_text.svg @@ -1 +1,6 @@ - + + + + + + diff --git a/assets/icons/git_onboarding_bg.svg b/assets/icons/git_onboarding_bg.svg deleted file mode 100644 index 18da0230a2..0000000000 --- a/assets/icons/git_onboarding_bg.svg +++ /dev/null @@ -1,40 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/assets/icons/message_bubbles.svg b/assets/icons/message_bubbles.svg deleted file mode 100644 index 03a6c7760c..0000000000 --- a/assets/icons/message_bubbles.svg +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - diff --git a/assets/icons/microscope.svg b/assets/icons/microscope.svg deleted file mode 100644 index 2b3009a28b..0000000000 --- a/assets/icons/microscope.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/new_from_summary.svg b/assets/icons/new_from_summary.svg deleted file mode 100644 index 3b61ca51a0..0000000000 --- a/assets/icons/new_from_summary.svg +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - - diff --git a/assets/icons/play.svg b/assets/icons/play.svg deleted file mode 100644 index 2481bda7d6..0000000000 --- a/assets/icons/play.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/play_bug.svg b/assets/icons/play_bug.svg deleted file mode 100644 index 7d265dd42a..0000000000 --- a/assets/icons/play_bug.svg +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - diff --git a/assets/icons/play_filled.svg b/assets/icons/play_filled.svg index 387304ef04..c632434305 100644 --- a/assets/icons/play_filled.svg +++ b/assets/icons/play_filled.svg @@ -1,3 +1,3 @@ - - + + diff --git a/assets/icons/play_alt.svg b/assets/icons/play_outlined.svg similarity index 70% rename from assets/icons/play_alt.svg rename to assets/icons/play_outlined.svg index b327ab07b5..7e1cacd5af 100644 --- a/assets/icons/play_alt.svg +++ b/assets/icons/play_outlined.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/reveal.svg b/assets/icons/reveal.svg deleted file mode 100644 index ff5444d8f8..0000000000 --- a/assets/icons/reveal.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/shield_check.svg b/assets/icons/shield_check.svg new file mode 100644 index 0000000000..6e58c31468 --- /dev/null +++ b/assets/icons/shield_check.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/spinner.svg b/assets/icons/spinner.svg deleted file mode 100644 index 4f4034ae89..0000000000 --- a/assets/icons/spinner.svg +++ /dev/null @@ -1,13 +0,0 @@ - - - - - - - - - - - - - diff --git a/assets/icons/strikethrough.svg b/assets/icons/strikethrough.svg deleted file mode 100644 index d7d0905912..0000000000 --- a/assets/icons/strikethrough.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/new_text_thread.svg b/assets/icons/text_thread.svg similarity index 100% rename from assets/icons/new_text_thread.svg rename to assets/icons/text_thread.svg diff --git a/assets/icons/new_thread.svg b/assets/icons/thread.svg similarity index 100% rename from assets/icons/new_thread.svg rename to assets/icons/thread.svg diff --git a/assets/icons/thread_from_summary.svg b/assets/icons/thread_from_summary.svg new file mode 100644 index 0000000000..7519935aff --- /dev/null +++ b/assets/icons/thread_from_summary.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/assets/icons/trash.svg b/assets/icons/trash.svg index b71035b99c..1322e90f9f 100644 --- a/assets/icons/trash.svg +++ b/assets/icons/trash.svg @@ -1 +1,5 @@ - + + + + + diff --git a/assets/icons/trash_alt.svg b/assets/icons/trash_alt.svg deleted file mode 100644 index 6867b42147..0000000000 --- a/assets/icons/trash_alt.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/zed_predict_bg.svg b/assets/icons/zed_predict_bg.svg deleted file mode 100644 index 1dccbb51af..0000000000 --- a/assets/icons/zed_predict_bg.svg +++ /dev/null @@ -1,19 +0,0 @@ - - - - - - - - - - - - - - - - - - - diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 31adef8cd5..81f5c695a2 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -232,7 +232,7 @@ "ctrl-n": "agent::NewThread", "ctrl-alt-n": "agent::NewTextThread", "ctrl-shift-h": "agent::OpenHistory", - "ctrl-alt-c": "agent::OpenConfiguration", + "ctrl-alt-c": "agent::OpenSettings", "ctrl-alt-p": "agent::OpenRulesLibrary", "ctrl-i": "agent::ToggleProfileSelector", "ctrl-alt-/": "agent::ToggleModelSelector", @@ -495,7 +495,7 @@ "shift-f12": "editor::GoToImplementation", "alt-ctrl-f12": "editor::GoToTypeDefinitionSplit", "alt-shift-f12": "editor::FindAllReferences", - "ctrl-m": "editor::MoveToEnclosingBracket", + "ctrl-m": "editor::MoveToEnclosingBracket", // from jetbrains "ctrl-|": "editor::MoveToEnclosingBracket", "ctrl-{": "editor::Fold", "ctrl-}": "editor::UnfoldLines", @@ -598,6 +598,7 @@ "ctrl-shift-t": "pane::ReopenClosedItem", "ctrl-k ctrl-s": "zed::OpenKeymapEditor", "ctrl-k ctrl-t": "theme_selector::Toggle", + "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -872,8 +873,6 @@ "tab": "git_panel::FocusEditor", "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", - "ctrl-enter": "git::Commit", - "ctrl-shift-enter": "git::Amend", "alt-enter": "menu::SecondaryConfirm", "delete": ["git::RestoreFile", { "skip_prompt": false }], "backspace": ["git::RestoreFile", { "skip_prompt": false }], @@ -910,7 +909,9 @@ "ctrl-g backspace": "git::RestoreTrackedFiles", "ctrl-g shift-backspace": "git::TrashUntrackedFiles", "ctrl-space": "git::StageAll", - "ctrl-shift-space": "git::UnstageAll" + "ctrl-shift-space": "git::UnstageAll", + "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend" } }, { @@ -1167,5 +1168,15 @@ "up": "menu::SelectPrevious", "down": "menu::SelectNext" } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "ctrl-1": "onboarding::ActivateBasicsPage", + "ctrl-2": "onboarding::ActivateEditingPage", + "ctrl-3": "onboarding::ActivateAISetupPage", + "ctrl-escape": "onboarding::Finish" + } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index f942c6f8ae..69958fd1f8 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -272,7 +272,7 @@ "cmd-n": "agent::NewThread", "cmd-alt-n": "agent::NewTextThread", "cmd-shift-h": "agent::OpenHistory", - "cmd-alt-c": "agent::OpenConfiguration", + "cmd-alt-c": "agent::OpenSettings", "cmd-alt-p": "agent::OpenRulesLibrary", "cmd-i": "agent::ToggleProfileSelector", "cmd-alt-/": "agent::ToggleModelSelector", @@ -549,7 +549,7 @@ "alt-cmd-f12": "editor::GoToTypeDefinitionSplit", "alt-shift-f12": "editor::FindAllReferences", "cmd-|": "editor::MoveToEnclosingBracket", - "ctrl-m": "editor::MoveToEnclosingBracket", + "ctrl-m": "editor::MoveToEnclosingBracket", // From Jetbrains "alt-cmd-[": "editor::Fold", "alt-cmd-]": "editor::UnfoldLines", "cmd-k cmd-l": "editor::ToggleFold", @@ -665,6 +665,7 @@ "cmd-shift-t": "pane::ReopenClosedItem", "cmd-k cmd-s": "zed::OpenKeymapEditor", "cmd-k cmd-t": "theme_selector::Toggle", + "ctrl-alt-cmd-p": "settings_profile_selector::Toggle", "cmd-t": "project_symbols::Toggle", "cmd-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -950,8 +951,6 @@ "tab": "git_panel::FocusEditor", "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", - "cmd-enter": "git::Commit", - "cmd-shift-enter": "git::Amend", "backspace": ["git::RestoreFile", { "skip_prompt": false }], "delete": ["git::RestoreFile", { "skip_prompt": false }], "cmd-backspace": ["git::RestoreFile", { "skip_prompt": true }], @@ -1001,7 +1000,9 @@ "ctrl-g backspace": "git::RestoreTrackedFiles", "ctrl-g shift-backspace": "git::TrashUntrackedFiles", "cmd-ctrl-y": "git::StageAll", - "cmd-ctrl-shift-y": "git::UnstageAll" + "cmd-ctrl-shift-y": "git::UnstageAll", + "cmd-enter": "git::Commit", + "cmd-shift-enter": "git::Amend" } }, { @@ -1269,5 +1270,15 @@ "up": "menu::SelectPrevious", "down": "menu::SelectNext" } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "cmd-1": "onboarding::ActivateBasicsPage", + "cmd-2": "onboarding::ActivateEditingPage", + "cmd-3": "onboarding::ActivateAISetupPage", + "cmd-escape": "onboarding::Finish" + } } ] diff --git a/assets/keymaps/linux/cursor.json b/assets/keymaps/linux/cursor.json index 347b7885fc..1c381b0cf0 100644 --- a/assets/keymaps/linux/cursor.json +++ b/assets/keymaps/linux/cursor.json @@ -8,7 +8,7 @@ "ctrl-shift-i": "agent::ToggleFocus", "ctrl-l": "agent::ToggleFocus", "ctrl-shift-l": "agent::ToggleFocus", - "ctrl-shift-j": "agent::OpenConfiguration" + "ctrl-shift-j": "agent::OpenSettings" } }, { diff --git a/assets/keymaps/linux/jetbrains.json b/assets/keymaps/linux/jetbrains.json index 629333663d..3df1243fed 100644 --- a/assets/keymaps/linux/jetbrains.json +++ b/assets/keymaps/linux/jetbrains.json @@ -4,6 +4,7 @@ "ctrl-alt-s": "zed::OpenSettings", "ctrl-{": "pane::ActivatePreviousItem", "ctrl-}": "pane::ActivateNextItem", + "shift-escape": null, // Unmap workspace::zoom "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -44,8 +45,8 @@ "ctrl-alt-right": "pane::GoForward", "alt-f7": "editor::FindAllReferences", "ctrl-alt-f7": "editor::FindAllReferences", - // "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock - // "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleLeftDock + "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock + "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleRightDock "ctrl-shift-b": "editor::GoToTypeDefinition", "ctrl-alt-shift-b": "editor::GoToTypeDefinitionSplit", "f2": "editor::GoToDiagnostic", @@ -94,18 +95,33 @@ "ctrl-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], "alt-shift-f10": "task::Spawn", "ctrl-e": "file_finder::Toggle", - "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor + // "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "ctrl-shift-n": "file_finder::Toggle", "ctrl-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", "ctrl-alt-shift-n": "project_symbols::Toggle", "alt-0": "git_panel::ToggleFocus", - "alt-1": "workspace::ToggleLeftDock", + "alt-1": "project_panel::ToggleFocus", "alt-5": "debug_panel::ToggleFocus", "alt-6": "diagnostics::Deploy", "alt-7": "outline_panel::ToggleFocus" } }, + { + "context": "Pane", // this is to override the default Pane mappings to switch tabs + "bindings": { + "alt-1": "project_panel::ToggleFocus", + "alt-2": null, // Bookmarks (left dock) + "alt-3": null, // Find Panel (bottom dock) + "alt-4": null, // Run Panel (bottom dock) + "alt-5": "debug_panel::ToggleFocus", + "alt-6": "diagnostics::Deploy", + "alt-7": "outline_panel::ToggleFocus", + "alt-8": null, // Services (bottom dock) + "alt-9": null, // Git History (bottom dock) + "alt-0": "git_panel::ToggleFocus" + } + }, { "context": "Workspace || Editor", "bindings": { @@ -150,7 +166,10 @@ { "context": "Diagnostics > Editor", "bindings": { "alt-6": "pane::CloseActiveItem" } }, { "context": "OutlinePanel", "bindings": { "alt-7": "workspace::CloseActiveDock" } }, { - "context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", - "bindings": { "escape": "editor::ToggleFocus" } + "context": "Dock || Workspace || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", + "bindings": { + "escape": "editor::ToggleFocus", + "shift-escape": "workspace::CloseActiveDock" + } } ] diff --git a/assets/keymaps/macos/cursor.json b/assets/keymaps/macos/cursor.json index b1d39bef9e..fdf9c437cf 100644 --- a/assets/keymaps/macos/cursor.json +++ b/assets/keymaps/macos/cursor.json @@ -8,7 +8,7 @@ "cmd-shift-i": "agent::ToggleFocus", "cmd-l": "agent::ToggleFocus", "cmd-shift-l": "agent::ToggleFocus", - "cmd-shift-j": "agent::OpenConfiguration" + "cmd-shift-j": "agent::OpenSettings" } }, { diff --git a/assets/keymaps/macos/jetbrains.json b/assets/keymaps/macos/jetbrains.json index e8b796f534..66962811f4 100644 --- a/assets/keymaps/macos/jetbrains.json +++ b/assets/keymaps/macos/jetbrains.json @@ -4,6 +4,7 @@ "cmd-{": "pane::ActivatePreviousItem", "cmd-}": "pane::ActivateNextItem", "cmd-0": "git_panel::ToggleFocus", // overrides `cmd-0` zoom reset + "shift-escape": null, // Unmap workspace::zoom "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -96,7 +97,7 @@ "cmd-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], "ctrl-alt-r": "task::Spawn", "cmd-e": "file_finder::Toggle", - "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor + // "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "cmd-shift-o": "file_finder::Toggle", "cmd-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", @@ -108,6 +109,21 @@ "cmd-7": "outline_panel::ToggleFocus" } }, + { + "context": "Pane", // this is to override the default Pane mappings to switch tabs + "bindings": { + "cmd-1": "project_panel::ToggleFocus", + "cmd-2": null, // Bookmarks (left dock) + "cmd-3": null, // Find Panel (bottom dock) + "cmd-4": null, // Run Panel (bottom dock) + "cmd-5": "debug_panel::ToggleFocus", + "cmd-6": "diagnostics::Deploy", + "cmd-7": "outline_panel::ToggleFocus", + "cmd-8": null, // Services (bottom dock) + "cmd-9": null, // Git History (bottom dock) + "cmd-0": "git_panel::ToggleFocus" + } + }, { "context": "Workspace || Editor", "bindings": { @@ -146,11 +162,15 @@ } }, { "context": "GitPanel", "bindings": { "cmd-0": "workspace::CloseActiveDock" } }, + { "context": "ProjectPanel", "bindings": { "cmd-1": "workspace::CloseActiveDock" } }, { "context": "DebugPanel", "bindings": { "cmd-5": "workspace::CloseActiveDock" } }, { "context": "Diagnostics > Editor", "bindings": { "cmd-6": "pane::CloseActiveItem" } }, { "context": "OutlinePanel", "bindings": { "cmd-7": "workspace::CloseActiveDock" } }, { - "context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", - "bindings": { "escape": "editor::ToggleFocus" } + "context": "Dock || Workspace || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", + "bindings": { + "escape": "editor::ToggleFocus", + "shift-escape": "workspace::CloseActiveDock" + } } ] diff --git a/assets/settings/default.json b/assets/settings/default.json index 3a7a48efc2..4734b5d118 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1877,5 +1877,25 @@ "save_breakpoints": true, "dock": "bottom", "button": true - } + }, + // Configures any number of settings profiles that are temporarily applied on + // top of your existing user settings when selected from + // `settings profile selector: toggle`. + // Examples: + // "profiles": { + // "Presenting": { + // "agent_font_size": 20.0, + // "buffer_font_size": 20.0, + // "theme": "One Light", + // "ui_font_size": 20.0 + // }, + // "Python (ty)": { + // "languages": { + // "Python": { + // "language_servers": ["ty"] + // } + // } + // } + // } + "profiles": [] } diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 011f26f364..225597415c 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -17,7 +17,6 @@ test-support = ["gpui/test-support", "project/test-support"] [dependencies] agent-client-protocol.workspace = true -agentic-coding-protocol.workspace = true anyhow.workspace = true assistant_tool.workspace = true buffer_diff.workspace = true @@ -37,11 +36,12 @@ util.workspace = true workspace-hack.workspace = true [dev-dependencies] -async-pipe.workspace = true env_logger.workspace = true gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true +parking_lot.workspace = true project = { workspace = true, "features" = ["test-support"] } +rand.workspace = true tempfile.workspace = true util.workspace = true settings.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d572992c54..44190a4860 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,7 +1,5 @@ mod connection; -mod old_acp_support; pub use connection::*; -pub use old_acp_support::*; use agent_client_protocol as acp; use anyhow::{Context as _, Result}; @@ -180,7 +178,7 @@ impl ToolCall { id: tool_call.id, label: cx.new(|cx| { Markdown::new( - tool_call.label.into(), + tool_call.title.into(), Some(language_registry.clone()), None, cx, @@ -207,7 +205,7 @@ impl ToolCall { let acp::ToolCallUpdateFields { kind, status, - label, + title, content, locations, raw_input, @@ -221,8 +219,8 @@ impl ToolCall { self.status = ToolCallStatus::Allowed { status }; } - if let Some(label) = label { - self.label = cx.new(|cx| Markdown::new_text(label.into(), cx)); + if let Some(title) = title { + self.label = cx.new(|cx| Markdown::new_text(title.into(), cx)); } if let Some(content) = content { @@ -391,7 +389,7 @@ impl ToolCallContent { cx: &mut App, ) -> Self { match content { - acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock { + acp::ToolCallContent::Content { content } => Self::ContentBlock { content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { @@ -580,6 +578,9 @@ pub struct AcpThread { pub enum AcpThreadEvent { NewEntry, EntryUpdated(usize), + ToolAuthorizationRequired, + Stopped, + Error, } impl EventEmitter for AcpThread {} @@ -616,6 +617,7 @@ impl Error for LoadError {} impl AcpThread { pub fn new( + title: impl Into, connection: Rc, project: Entity, session_id: acp::SessionId, @@ -628,7 +630,7 @@ impl AcpThread { shared_buffers: Default::default(), entries: Default::default(), plan: Default::default(), - title: connection.name().into(), + title: title.into(), project, send_task: None, connection, @@ -668,7 +670,18 @@ impl AcpThread { for entry in self.entries.iter().rev() { match entry { AgentThreadEntry::UserMessage(_) => return false, - AgentThreadEntry::ToolCall(call) if call.diffs().next().is_some() => return true, + AgentThreadEntry::ToolCall( + call @ ToolCall { + status: + ToolCallStatus::Allowed { + status: + acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending, + }, + .. + }, + ) if call.diffs().next().is_some() => { + return true; + } AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {} } } @@ -676,20 +689,32 @@ impl AcpThread { false } + pub fn used_tools_since_last_user_message(&self) -> bool { + for entry in self.entries.iter().rev() { + match entry { + AgentThreadEntry::UserMessage(..) => return false, + AgentThreadEntry::AssistantMessage(..) => continue, + AgentThreadEntry::ToolCall(..) => return true, + } + } + + false + } + pub fn handle_session_update( &mut self, update: acp::SessionUpdate, cx: &mut Context, ) -> Result<()> { match update { - acp::SessionUpdate::UserMessage(content_block) => { - self.push_user_content_block(content_block, cx); + acp::SessionUpdate::UserMessageChunk { content } => { + self.push_user_content_block(content, cx); } - acp::SessionUpdate::AgentMessageChunk(content_block) => { - self.push_assistant_content_block(content_block, false, cx); + acp::SessionUpdate::AgentMessageChunk { content } => { + self.push_assistant_content_block(content, false, cx); } - acp::SessionUpdate::AgentThoughtChunk(content_block) => { - self.push_assistant_content_block(content_block, true, cx); + acp::SessionUpdate::AgentThoughtChunk { content } => { + self.push_assistant_content_block(content, true, cx); } acp::SessionUpdate::ToolCall(tool_call) => { self.upsert_tool_call(tool_call, cx); @@ -879,6 +904,7 @@ impl AcpThread { }; self.upsert_tool_call_inner(tool_call, status, cx); + cx.emit(AcpThreadEvent::ToolAuthorizationRequired); rx } @@ -957,10 +983,6 @@ impl AcpThread { cx.notify(); } - pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future> { - self.connection.authenticate(cx) - } - #[cfg(any(test, feature = "test-support"))] pub fn send_raw( &mut self, @@ -1002,7 +1024,7 @@ impl AcpThread { let result = this .update(cx, |this, cx| { this.connection.prompt( - acp::PromptArguments { + acp::PromptRequest { prompt: message, session_id: this.session_id.clone(), }, @@ -1018,12 +1040,18 @@ impl AcpThread { .log_err(); })); - async move { - match rx.await { - Ok(Err(e)) => Err(e)?, - _ => Ok(()), + cx.spawn(async move |this, cx| match rx.await { + Ok(Err(e)) => { + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error)) + .log_err(); + Err(e)? } - } + _ => { + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) + .log_err(); + Ok(()) + } + }) .boxed() } @@ -1206,16 +1234,15 @@ impl AcpThread { #[cfg(test)] mod tests { use super::*; - use agentic_coding_protocol as acp_old; use anyhow::anyhow; - use async_pipe::{PipeReader, PipeWriter}; use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext}; + use gpui::{AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; use project::FakeFs; + use rand::Rng as _; use serde_json::json; use settings::SettingsStore; - use smol::{future::BoxedLocal, stream::StreamExt as _}; + use smol::stream::StreamExt as _; use std::{cell::RefCell, rc::Rc, time::Duration}; use util::path; @@ -1236,7 +1263,15 @@ mod tests { let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let (thread, _fake_server) = fake_acp_thread(project, cx); + let connection = Rc::new(FakeAgentConnection::new()); + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) + .await + }) + .await + .unwrap(); // Test creating a new user message thread.update(cx, |thread, cx| { @@ -1316,34 +1351,40 @@ mod tests { let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let (thread, fake_server) = fake_acp_thread(project, cx); + let connection = Rc::new(FakeAgentConnection::new().on_user_message( + |_, thread, mut cx| { + async move { + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: "Thinking ".into(), + }, + cx, + ) + .unwrap(); + thread + .handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: "hard!".into(), + }, + cx, + ) + .unwrap(); + }) + } + .boxed_local() + }, + )); - fake_server.update(cx, |fake_server, _| { - fake_server.on_user_message(move |_, server, mut cx| async move { - server - .update(&mut cx, |server, _| { - server.send_to_zed(acp_old::StreamAssistantMessageChunkParams { - chunk: acp_old::AssistantMessageChunk::Thought { - thought: "Thinking ".into(), - }, - }) - })? + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) .await - .unwrap(); - server - .update(&mut cx, |server, _| { - server.send_to_zed(acp_old::StreamAssistantMessageChunkParams { - chunk: acp_old::AssistantMessageChunk::Thought { - thought: "hard!".into(), - }, - }) - })? - .await - .unwrap(); - - Ok(()) }) - }); + .await + .unwrap(); thread .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) @@ -1376,7 +1417,38 @@ mod tests { fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"})) .await; let project = Project::test(fs.clone(), [], cx).await; - let (thread, fake_server) = fake_acp_thread(project.clone(), cx); + let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); + let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); + let connection = Rc::new(FakeAgentConnection::new().on_user_message( + move |_, thread, mut cx| { + let read_file_tx = read_file_tx.clone(); + async move { + let content = thread + .update(&mut cx, |thread, cx| { + thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!(content, "one\ntwo\nthree\n"); + read_file_tx.take().unwrap().send(()).unwrap(); + thread + .update(&mut cx, |thread, cx| { + thread.write_text_file( + path!("/tmp/foo").into(), + "one\ntwo\nthree\nfour\nfive\n".to_string(), + cx, + ) + }) + .unwrap() + .await + .unwrap(); + Ok(()) + } + .boxed_local() + }, + )); + let (worktree, pathbuf) = project .update(cx, |project, cx| { project.find_or_create_worktree(path!("/tmp/foo"), true, cx) @@ -1390,38 +1462,10 @@ mod tests { .await .unwrap(); - let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); - let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); - - fake_server.update(cx, |fake_server, _| { - fake_server.on_user_message(move |_, server, mut cx| { - let read_file_tx = read_file_tx.clone(); - async move { - let content = server - .update(&mut cx, |server, _| { - server.send_to_zed(acp_old::ReadTextFileParams { - path: path!("/tmp/foo").into(), - line: None, - limit: None, - }) - })? - .await - .unwrap(); - assert_eq!(content.content, "one\ntwo\nthree\n"); - read_file_tx.take().unwrap().send(()).unwrap(); - server - .update(&mut cx, |server, _| { - server.send_to_zed(acp_old::WriteTextFileParams { - path: path!("/tmp/foo").into(), - content: "one\ntwo\nthree\nfour\nfive\n".to_string(), - }) - })? - .await - .unwrap(); - Ok(()) - } - }) - }); + let thread = cx + .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx)) + .await + .unwrap(); let request = thread.update(cx, |thread, cx| { thread.send_raw("Extend the count in /tmp/foo", cx) @@ -1448,36 +1492,44 @@ mod tests { let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let (thread, fake_server) = fake_acp_thread(project, cx); + let id = acp::ToolCallId("test".into()); - let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>(); - - let tool_call_id = Rc::new(RefCell::new(None)); - let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx))); - fake_server.update(cx, |fake_server, _| { - let tool_call_id = tool_call_id.clone(); - fake_server.on_user_message(move |_, server, mut cx| { - let end_turn_rx = end_turn_rx.clone(); - let tool_call_id = tool_call_id.clone(); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let id = id.clone(); + move |_, thread, mut cx| { + let id = id.clone(); async move { - let tool_call_result = server - .update(&mut cx, |server, _| { - server.send_to_zed(acp_old::PushToolCallParams { - label: "Fetch".to_string(), - icon: acp_old::Icon::Globe, - content: None, - locations: vec![], - }) - })? - .await + thread + .update(&mut cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: id.clone(), + title: "Label".into(), + kind: acp::ToolKind::Fetch, + status: acp::ToolCallStatus::InProgress, + content: vec![], + locations: vec![], + raw_input: None, + }), + cx, + ) + }) + .unwrap() .unwrap(); - *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id); - end_turn_rx.take().unwrap().await.ok(); - Ok(()) } + .boxed_local() + } + })); + + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) + .await }) - }); + .await + .unwrap(); let request = thread.update(cx, |thread, cx| { thread.send_raw("Fetch https://example.com", cx) @@ -1498,8 +1550,6 @@ mod tests { )); }); - cx.run_until_parked(); - thread.update(cx, |thread, cx| thread.cancel(cx)).await; thread.read_with(cx, |thread, _| { @@ -1512,19 +1562,22 @@ mod tests { )); }); - fake_server - .update(cx, |fake_server, _| { - fake_server.send_to_zed(acp_old::UpdateToolCallParams { - tool_call_id: tool_call_id.borrow().unwrap(), - status: acp_old::ToolCallStatus::Finished, - content: None, - }) + thread + .update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate { + id, + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + }), + cx, + ) }) - .await .unwrap(); - drop(end_turn_tx); - assert!(request.await.unwrap_err().to_string().contains("canceled")); + request.await.unwrap(); thread.read_with(cx, |thread, _| { assert!(matches!( @@ -1540,6 +1593,56 @@ mod tests { }); } + #[gpui::test] + async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + move |_, thread, mut cx| { + async move { + thread + .update(&mut cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("test".into()), + title: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/test/test.txt".into(), + old_text: None, + new_text: "foo".into(), + }, + }], + locations: vec![], + raw_input: None, + }), + cx, + ) + }) + .unwrap() + .unwrap(); + Ok(()) + } + .boxed_local() + } + })); + + let thread = connection + .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + .await + .unwrap(); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) + .await + .unwrap(); + + assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); + } + async fn run_until_first_tool_call( thread: &Entity, cx: &mut TestAppContext, @@ -1567,168 +1670,108 @@ mod tests { } } - pub fn fake_acp_thread( - project: Entity, - cx: &mut TestAppContext, - ) -> (Entity, Entity) { - let (stdin_tx, stdin_rx) = async_pipe::pipe(); - let (stdout_tx, stdout_rx) = async_pipe::pipe(); - - let thread = cx.new(|cx| { - let foreground_executor = cx.foreground_executor().clone(); - let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade())); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()), - stdin_tx, - stdout_rx, - move |fut| { - foreground_executor.spawn(fut).detach(); - }, - ); - - let io_task = cx.background_spawn({ - async move { - io_fut.await.log_err(); - Ok(()) - } - }); - let connection = OldAcpAgentConnection { - name: "test", - connection, - child_status: io_task, - }; - - AcpThread::new( - Rc::new(connection), - project, - acp::SessionId("test".into()), - cx, - ) - }); - let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); - (thread, agent) - } - - pub struct FakeAcpServer { - connection: acp_old::ClientConnection, - - _io_task: Task<()>, + #[derive(Clone, Default)] + struct FakeAgentConnection { + auth_methods: Vec, + sessions: Arc>>>, on_user_message: Option< Rc< dyn Fn( - acp_old::SendUserMessageParams, - Entity, - AsyncApp, - ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>, + acp::PromptRequest, + WeakEntity, + AsyncApp, + ) -> LocalBoxFuture<'static, Result<()>> + + 'static, >, >, } - #[derive(Clone)] - struct FakeAgent { - server: Entity, - cx: AsyncApp, - cancel_tx: Rc>>>, - } - - impl acp_old::Agent for FakeAgent { - async fn initialize( - &self, - params: acp_old::InitializeParams, - ) -> Result { - Ok(acp_old::InitializeResponse { - protocol_version: params.protocol_version, - is_authenticated: true, - }) - } - - async fn authenticate(&self) -> Result<(), acp_old::Error> { - Ok(()) - } - - async fn cancel_send_message(&self) -> Result<(), acp_old::Error> { - if let Some(cancel_tx) = self.cancel_tx.take() { - cancel_tx.send(()).log_err(); - } - Ok(()) - } - - async fn send_user_message( - &self, - request: acp_old::SendUserMessageParams, - ) -> Result<(), acp_old::Error> { - let (cancel_tx, cancel_rx) = oneshot::channel(); - self.cancel_tx.replace(Some(cancel_tx)); - - let mut cx = self.cx.clone(); - let handler = self - .server - .update(&mut cx, |server, _| server.on_user_message.clone()) - .ok() - .flatten(); - if let Some(handler) = handler { - select! { - _ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()), - _ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()), - } - } else { - Err(anyhow::anyhow!("No handler for on_user_message").into()) - } - } - } - - impl FakeAcpServer { - fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context) -> Self { - let agent = FakeAgent { - server: cx.entity(), - cx: cx.to_async(), - cancel_tx: Default::default(), - }; - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp_old::ClientConnection::connect_to_client( - agent.clone(), - stdout, - stdin, - move |fut| { - foreground_executor.spawn(fut).detach(); - }, - ); - FakeAcpServer { - connection: connection, + impl FakeAgentConnection { + fn new() -> Self { + Self { + auth_methods: Vec::new(), on_user_message: None, - _io_task: cx.background_spawn(async move { - io_fut.await.log_err(); - }), + sessions: Arc::default(), } } - fn on_user_message( - &mut self, - handler: impl for<'a> Fn( - acp_old::SendUserMessageParams, - Entity, - AsyncApp, - ) -> F - + 'static, - ) where - F: Future> + 'static, - { - self.on_user_message - .replace(Rc::new(move |request, server, cx| { - handler(request, server, cx).boxed_local() - })); + #[expect(unused)] + fn with_auth_methods(mut self, auth_methods: Vec) -> Self { + self.auth_methods = auth_methods; + self } - fn send_to_zed( - &self, - message: T, - ) -> BoxedLocal> { - self.connection - .request(message) - .map(|f| f.map_err(|err| anyhow!(err))) - .boxed_local() + fn on_user_message( + mut self, + handler: impl Fn( + acp::PromptRequest, + WeakEntity, + AsyncApp, + ) -> LocalBoxFuture<'static, Result<()>> + + 'static, + ) -> Self { + self.on_user_message.replace(Rc::new(handler)); + self + } + } + + impl AgentConnection for FakeAgentConnection { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + let session_id = acp::SessionId( + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(7) + .map(char::from) + .collect::() + .into(), + ); + let thread = cx + .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) + .unwrap(); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task> { + if self.auth_methods().iter().any(|m| m.id == method) { + Task::ready(Ok(())) + } else { + Task::ready(Err(anyhow!("Invalid Auth Method"))) + } + } + + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + if let Some(handler) = &self.on_user_message { + let handler = handler.clone(); + let thread = thread.clone(); + cx.spawn(async move |cx| handler(params, thread, cx.clone()).await) + } else { + Task::ready(Ok(())) + } + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let sessions = self.sessions.lock(); + let thread = sessions.get(&session_id).unwrap().clone(); + + cx.spawn(async move |cx| { + thread + .update(cx, |thread, cx| thread.cancel(cx)) + .unwrap() + .await + }) + .detach(); } } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 5b25b71863..929500a67b 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,6 +1,6 @@ -use std::{path::Path, rc::Rc}; +use std::{error::Error, fmt, path::Path, rc::Rc}; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp}; use anyhow::Result; use gpui::{AsyncApp, Entity, Task}; use project::Project; @@ -9,8 +9,6 @@ use ui::App; use crate::AcpThread; pub trait AgentConnection { - fn name(&self) -> &'static str; - fn new_thread( self: Rc, project: Entity, @@ -18,9 +16,21 @@ pub trait AgentConnection { cx: &mut AsyncApp, ) -> Task>>; - fn authenticate(&self, cx: &mut App) -> Task>; + fn auth_methods(&self) -> &[acp::AuthMethod]; - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task>; + fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; + + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task>; fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); } + +#[derive(Debug)] +pub struct AuthRequired; + +impl Error for AuthRequired {} +impl fmt::Display for AuthRequired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthRequired") + } +} diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 135363ab65..7bc0e82cad 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -25,6 +25,7 @@ assistant_context.workspace = true assistant_tool.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true component.workspace = true context_server.workspace = true @@ -35,9 +36,9 @@ futures.workspace = true git.workspace = true gpui.workspace = true heed.workspace = true +http_client.workspace = true icons.workspace = true indoc.workspace = true -http_client.workspace = true itertools.workspace = true language.workspace = true language_model.workspace = true @@ -46,7 +47,6 @@ paths.workspace = true postage.workspace = true project.workspace = true prompt_store.workspace = true -proto.workspace = true ref-cast.workspace = true rope.workspace = true schemars.workspace = true @@ -63,7 +63,6 @@ time.workspace = true util.workspace = true uuid.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true zstd.workspace = true [dev-dependencies] diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs index a89857e71a..34ea1c8df7 100644 --- a/crates/agent/src/agent_profile.rs +++ b/crates/agent/src/agent_profile.rs @@ -308,7 +308,12 @@ mod tests { unimplemented!() } - fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + fn needs_confirmation( + &self, + _input: &serde_json::Value, + _project: &Entity, + _cx: &App, + ) -> bool { unimplemented!() } diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index ddd13de491..cd366b8308 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -42,8 +42,8 @@ impl ContextKind { ContextKind::Symbol => IconName::Code, ContextKind::Selection => IconName::Context, ContextKind::FetchedUrl => IconName::Globe, - ContextKind::Thread => IconName::MessageBubbles, - ContextKind::TextThread => IconName::MessageBubbles, + ContextKind::Thread => IconName::Thread, + ContextKind::TextThread => IconName::TextThread, ContextKind::Rules => RULES_ICON, ContextKind::Image => IconName::Image, } diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index 4c6d2b2b0b..85e8ac7451 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -47,7 +47,7 @@ impl Tool for ContextServerTool { } } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 1af27ca8a7..8558dd528d 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -13,6 +13,7 @@ use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; use futures::{FutureExt, StreamExt as _, future::Shared}; @@ -36,7 +37,6 @@ use project::{ git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, }; use prompt_store::{ModelContext, PromptBuilder}; -use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -49,7 +49,6 @@ use std::{ use thiserror::Error; use util::{ResultExt as _, post_inc}; use uuid::Uuid; -use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; const MAX_RETRY_ATTEMPTS: u8 = 4; const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); @@ -942,7 +941,7 @@ impl Thread { } pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { - self.tool_use.tool_uses_for_message(id, cx) + self.tool_use.tool_uses_for_message(id, &self.project, cx) } pub fn tool_results_for_message( @@ -1681,7 +1680,7 @@ impl Thread { let completion_mode = request .mode - .unwrap_or(zed_llm_client::CompletionMode::Normal); + .unwrap_or(cloud_llm_client::CompletionMode::Normal); self.last_received_chunk_at = Some(Instant::now()); @@ -2557,7 +2556,7 @@ impl Thread { return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx); } - if tool.needs_confirmation(&tool_use.input, cx) + if tool.needs_confirmation(&tool_use.input, &self.project, cx) && !AgentSettings::get_global(cx).always_allow_tool_actions { self.tool_use.confirm_tool_use( @@ -3255,8 +3254,10 @@ impl Thread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.project.update(cx, |project, cx| { - project.user_store().update(cx, |user_store, cx| { + self.project + .read(cx) + .user_store() + .update(cx, |user_store, cx| { user_store.update_model_request_usage( ModelRequestUsage(RequestUsage { amount: amount as i32, @@ -3264,8 +3265,7 @@ impl Thread { }), cx, ) - }) - }); + }); } pub fn deny_tool_use( diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 74c719b4e6..7392c0878d 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -165,7 +165,12 @@ impl ToolUseState { self.pending_tool_uses_by_id.values().collect() } - pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { + pub fn tool_uses_for_message( + &self, + id: MessageId, + project: &Entity, + cx: &App, + ) -> Vec { let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { return Vec::new(); }; @@ -211,7 +216,10 @@ impl ToolUseState { let (icon, needs_confirmation) = if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) { - (tool.icon(), tool.needs_confirmation(&tool_use.input, cx)) + ( + tool.icon(), + tool.needs_confirmation(&tool_use.input, project, cx), + ) } else { (IconName::Cog, false) }; diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index dcffb05bc0..81c97c8aa6 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -25,6 +25,7 @@ collections.workspace = true context_server.workspace = true futures.workspace = true gpui.workspace = true +indoc.workspace = true itertools.workspace = true log.workspace = true paths.workspace = true @@ -37,11 +38,11 @@ settings.workspace = true smol.workspace = true strum.workspace = true tempfile.workspace = true +thiserror.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true watch.workspace = true -indoc.workspace = true which.workspace = true workspace-hack.workspace = true diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs new file mode 100644 index 0000000000..00e3e3df50 --- /dev/null +++ b/crates/agent_servers/src/acp.rs @@ -0,0 +1,34 @@ +use std::{path::Path, rc::Rc}; + +use crate::AgentServerCommand; +use acp_thread::AgentConnection; +use anyhow::Result; +use gpui::AsyncApp; +use thiserror::Error; + +mod v0; +mod v1; + +#[derive(Debug, Error)] +#[error("Unsupported version")] +pub struct UnsupportedVersion; + +pub async fn connect( + server_name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, +) -> Result> { + let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await; + + match conn { + Ok(conn) => Ok(Rc::new(conn) as _), + Err(err) if err.is::() => { + // Consider re-using initialize response and subprocess when adding another version here + let conn: Rc = + Rc::new(v0::AcpConnection::stdio(server_name, command, &root_dir, cx).await?); + Ok(conn) + } + Err(err) => Err(err), + } +} diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/agent_servers/src/acp/v0.rs similarity index 82% rename from crates/acp_thread/src/old_acp_support.rs rename to crates/agent_servers/src/acp/v0.rs index 44cd00348f..3dcda4ce8d 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -1,17 +1,19 @@ // Translates old acp agents into the new schema use agent_client_protocol as acp; use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; -use anyhow::{Context as _, Result}; +use anyhow::{Context as _, Result, anyhow}; use futures::channel::oneshot; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; -use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; +use std::{cell::RefCell, path::Path, rc::Rc}; use ui::App; +use util::ResultExt as _; -use crate::{AcpThread, AgentConnection}; +use crate::AgentServerCommand; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; #[derive(Clone)] -pub struct OldAcpClientDelegate { +struct OldAcpClientDelegate { thread: Rc>>, cx: AsyncApp, next_tool_call_id: Rc>, @@ -19,7 +21,7 @@ pub struct OldAcpClientDelegate { } impl OldAcpClientDelegate { - pub fn new(thread: Rc>>, cx: AsyncApp) -> Self { + fn new(thread: Rc>>, cx: AsyncApp) -> Self { Self { thread, cx, @@ -46,7 +48,7 @@ impl acp_old::Client for OldAcpClientDelegate { thread.push_assistant_content_block(thought.into(), true, cx) } }) - .ok(); + .log_err(); })?; Ok(()) @@ -125,7 +127,7 @@ impl acp_old::Client for OldAcpClientDelegate { outcomes.push(outcome); acp_options.push(acp::PermissionOption { id: acp::PermissionOptionId(index.to_string().into()), - label, + name: label, kind, }) } @@ -264,7 +266,7 @@ impl acp_old::Client for OldAcpClientDelegate { fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall { acp::ToolCall { id: id, - label: request.label, + title: request.label, kind: acp_kind_from_old_icon(request.icon), status: acp::ToolCallStatus::InProgress, content: request @@ -350,27 +352,71 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu } } -#[derive(Debug)] -pub struct Unauthenticated; - -impl Error for Unauthenticated {} -impl fmt::Display for Unauthenticated { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Unauthenticated") - } -} - -pub struct OldAcpAgentConnection { +pub struct AcpConnection { pub name: &'static str, pub connection: acp_old::AgentConnection, - pub child_status: Task>, + pub _child_status: Task>, + pub current_thread: Rc>>, } -impl AgentConnection for OldAcpAgentConnection { - fn name(&self) -> &'static str { - self.name - } +impl AcpConnection { + pub fn stdio( + name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, + ) -> Task> { + let root_dir = root_dir.to_path_buf(); + cx.spawn(async move |cx| { + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + let foreground_executor = cx.foreground_executor().clone(); + + let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => Err(anyhow!(result)), + }; + drop(io_task); + result + }); + + Ok(Self { + name, + connection, + _child_status: child_status, + current_thread: thread_rc, + }) + }) + } +} + +impl AgentConnection for AcpConnection { fn new_thread( self: Rc, project: Entity, @@ -383,25 +429,31 @@ impl AgentConnection for OldAcpAgentConnection { } .into_any(), ); + let current_thread = self.current_thread.clone(); cx.spawn(async move |cx| { let result = task.await?; let result = acp_old::InitializeParams::response_from_any(result)?; if !result.is_authenticated { - anyhow::bail!(Unauthenticated) + anyhow::bail!(AuthRequired) } cx.update(|cx| { let thread = cx.new(|cx| { let session_id = acp::SessionId("acp-old-no-id".into()); - AcpThread::new(self.clone(), project, session_id, cx) + AcpThread::new(self.name, self.clone(), project, session_id, cx) }); + current_thread.replace(thread.downgrade()); thread }) }) } - fn authenticate(&self, cx: &mut App) -> Task> { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { let task = self .connection .request_any(acp_old::AuthenticateParams.into_any()); @@ -411,7 +463,7 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let chunks = params .prompt .into_iter() diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs new file mode 100644 index 0000000000..a4f0e996b5 --- /dev/null +++ b/crates/agent_servers/src/acp/v1.rs @@ -0,0 +1,260 @@ +use agent_client_protocol::{self as acp, Agent as _}; +use anyhow::anyhow; +use collections::HashMap; +use futures::channel::oneshot; +use project::Project; +use std::cell::RefCell; +use std::path::Path; +use std::rc::Rc; + +use anyhow::{Context as _, Result}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; + +use crate::{AgentServerCommand, acp::UnsupportedVersion}; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; + +pub struct AcpConnection { + server_name: &'static str, + connection: Rc, + sessions: Rc>>, + auth_methods: Vec, + _io_task: Task>, + _child: smol::process::Child, +} + +pub struct AcpSession { + thread: WeakEntity, +} + +const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; + +impl AcpConnection { + pub async fn stdio( + server_name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, + ) -> Result { + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter().map(|arg| arg.as_str())) + .envs(command.env.iter().flatten()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdout = child.stdout.take().expect("Failed to take stdout"); + let stdin = child.stdin.take().expect("Failed to take stdin"); + + let sessions = Rc::new(RefCell::new(HashMap::default())); + + let client = ClientDelegate { + sessions: sessions.clone(), + cx: cx.clone(), + }; + let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { + let foreground_executor = cx.foreground_executor().clone(); + move |fut| { + foreground_executor.spawn(fut).detach(); + } + }); + + let io_task = cx.background_spawn(io_task); + + let response = connection + .initialize(acp::InitializeRequest { + protocol_version: acp::VERSION, + client_capabilities: acp::ClientCapabilities { + fs: acp::FileSystemCapability { + read_text_file: true, + write_text_file: true, + }, + }, + }) + .await?; + + if response.protocol_version < MINIMUM_SUPPORTED_VERSION { + return Err(UnsupportedVersion.into()); + } + + Ok(Self { + auth_methods: response.auth_methods, + connection: connection.into(), + server_name, + sessions, + _child: child, + _io_task: io_task, + }) + } +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let conn = self.connection.clone(); + let sessions = self.sessions.clone(); + let cwd = cwd.to_path_buf(); + cx.spawn(async move |cx| { + let response = conn + .new_session(acp::NewSessionRequest { + mcp_servers: vec![], + cwd, + }) + .await + .map_err(|err| { + if err.code == acp::ErrorCode::AUTH_REQUIRED.code { + anyhow!(AuthRequired) + } else { + anyhow!(err) + } + })?; + + let session_id = response.session_id; + + let thread = cx.new(|cx| { + AcpThread::new( + self.server_name, + self.clone(), + project, + session_id.clone(), + cx, + ) + })?; + + let session = AcpSession { + thread: thread.downgrade(), + }; + sessions.borrow_mut().insert(session_id, session); + + Ok(thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods + } + + fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor().spawn(async move { + let result = conn + .authenticate(acp::AuthenticateRequest { + method_id: method_id.clone(), + }) + .await?; + + Ok(result) + }) + } + + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor() + .spawn(async move { Ok(conn.prompt(params).await?) }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let conn = self.connection.clone(); + let params = acp::CancelNotification { + session_id: session_id.clone(), + }; + cx.foreground_executor() + .spawn(async move { conn.cancel(params).await }) + .detach(); + } +} + +struct ClientDelegate { + sessions: Rc>>, + cx: AsyncApp, +} + +impl acp::Client for ClientDelegate { + async fn request_permission( + &self, + arguments: acp::RequestPermissionRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let rx = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx) + })?; + + let result = rx.await; + + let outcome = match result { + Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, + Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, + }; + + Ok(acp::RequestPermissionResponse { outcome }) + } + + async fn write_text_file( + &self, + arguments: acp::WriteTextFileRequest, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let task = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.write_text_file(arguments.path, arguments.content, cx) + })?; + + task.await?; + + Ok(()) + } + + async fn read_text_file( + &self, + arguments: acp::ReadTextFileRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let task = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) + })?; + + let content = task.await?; + + Ok(acp::ReadTextFileResponse { content }) + } + + async fn session_notification( + &self, + notification: acp::SessionNotification, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let sessions = self.sessions.borrow(); + let session = sessions + .get(¬ification.session_id) + .context("Failed to get session")?; + + session.thread.update(cx, |thread, cx| { + thread.handle_session_update(notification.update, cx) + })??; + + Ok(()) + } +} diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 212bb74d8a..ec69290206 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,14 +1,12 @@ +mod acp; mod claude; -mod codex; mod gemini; -mod mcp_server; mod settings; #[cfg(test)] mod e2e_tests; pub use claude::*; -pub use codex::*; pub use gemini::*; pub use settings::*; @@ -38,7 +36,6 @@ pub trait AgentServer: Send { fn connect( &self, - // these will go away when old_acp is fully removed root_dir: &Path, project: &Entity, cx: &mut App, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 4b48dbf3c1..9040b83085 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -70,10 +70,6 @@ struct ClaudeAgentConnection { } impl AgentConnection for ClaudeAgentConnection { - fn name(&self) -> &'static str { - ClaudeCode.name() - } - fn new_thread( self: Rc, project: Entity, @@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection { } }); - let thread = - cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; + let thread = cx.new(|cx| { + AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) + })?; thread_tx.send(thread.downgrade())?; @@ -186,11 +183,15 @@ impl AgentConnection for ClaudeAgentConnection { }) } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { Task::ready(Err(anyhow!("Authentication not supported"))) } - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let sessions = self.sessions.borrow(); let Some(session) = sessions.get(¶ms.session_id) else { return Task::ready(Err(anyhow!( @@ -438,7 +439,7 @@ impl ClaudeAgentSession { } } } - SdkMessage::System { .. } => {} + SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {} } } @@ -642,6 +643,8 @@ enum SdkMessage { request_id: String, request: ControlRequest, }, + /// Response to a control request + ControlResponse { response: ControlResponse }, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -651,6 +654,12 @@ enum ControlRequest { Interrupt, } +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ControlResponse { + request_id: String, + subtype: ResultErrorType, +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] enum ResultErrorType { @@ -707,7 +716,7 @@ pub(crate) mod tests { use super::*; use serde_json::json; - crate::common_e2e_tests!(ClaudeCode); + crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow"); pub fn local_command() -> AgentServerCommand { AgentServerCommand { diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index a320a6d37f..c6f8bb5b69 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -42,9 +42,13 @@ impl ClaudeZedMcpServer { } pub fn server_config(&self) -> Result { + #[cfg(not(test))] let zed_path = std::env::current_exe() .context("finding current executable path for use in mcp_server")?; + #[cfg(test)] + let zed_path = crate::e2e_tests::get_zed_path(); + Ok(McpServerConfig { command: zed_path, args: vec![ @@ -154,12 +158,12 @@ impl McpServerTool for PermissionTool { vec![ acp::PermissionOption { id: allow_option_id.clone(), - label: "Allow".into(), + name: "Allow".into(), kind: acp::PermissionOptionKind::AllowOnce, }, acp::PermissionOption { id: reject_option_id.clone(), - label: "Reject".into(), + name: "Reject".into(), kind: acp::PermissionOptionKind::RejectOnce, }, ], @@ -174,6 +178,7 @@ impl McpServerTool for PermissionTool { updated_input: input.input, } } else { + debug_assert_eq!(chosen_option, reject_option_id); PermissionToolResponse { behavior: PermissionToolBehavior::Deny, updated_input: input.input, diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index 6acb6355aa..e7d33e5298 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -308,7 +308,7 @@ impl ClaudeTool { id, kind: self.kind(), status: acp::ToolCallStatus::InProgress, - label: self.label(), + title: self.label(), content: self.content(), locations: self.locations(), raw_input: None, diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs deleted file mode 100644 index 3eb95a6841..0000000000 --- a/crates/agent_servers/src/codex.rs +++ /dev/null @@ -1,317 +0,0 @@ -use agent_client_protocol as acp; -use anyhow::anyhow; -use collections::HashMap; -use context_server::listener::McpServerTool; -use context_server::types::requests; -use context_server::{ContextServer, ContextServerCommand, ContextServerId}; -use futures::channel::{mpsc, oneshot}; -use project::Project; -use settings::SettingsStore; -use smol::stream::StreamExt as _; -use std::cell::RefCell; -use std::rc::Rc; -use std::{path::Path, sync::Arc}; -use util::ResultExt; - -use anyhow::{Context, Result}; -use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; - -use crate::mcp_server::ZedMcpServer; -use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server}; -use acp_thread::{AcpThread, AgentConnection}; - -#[derive(Clone)] -pub struct Codex; - -impl AgentServer for Codex { - fn name(&self) -> &'static str { - "Codex" - } - - fn empty_state_headline(&self) -> &'static str { - "Welcome to Codex" - } - - fn empty_state_message(&self) -> &'static str { - "What can I help with?" - } - - fn logo(&self) -> ui::IconName { - ui::IconName::AiOpenAi - } - - fn connect( - &self, - _root_dir: &Path, - project: &Entity, - cx: &mut App, - ) -> Task>> { - let project = project.clone(); - cx.spawn(async move |cx| { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).codex.clone() - })?; - - let Some(command) = - AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await - else { - anyhow::bail!("Failed to find codex binary"); - }; - - let client: Arc = ContextServer::stdio( - ContextServerId("codex-mcp-server".into()), - ContextServerCommand { - path: command.path, - args: command.args, - env: command.env, - }, - ) - .into(); - ContextServer::start(client.clone(), cx).await?; - - let (notification_tx, mut notification_rx) = mpsc::unbounded(); - client - .client() - .context("Failed to subscribe")? - .on_notification(acp::SESSION_UPDATE_METHOD_NAME, { - move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(notification) = - serde_json::from_value::(notification) - .log_err() - { - notification_tx.unbounded_send(notification).ok(); - } - } - }); - - let sessions = Rc::new(RefCell::new(HashMap::default())); - - let notification_handler_task = cx.spawn({ - let sessions = sessions.clone(); - async move |cx| { - while let Some(notification) = notification_rx.next().await { - CodexConnection::handle_session_notification( - notification, - sessions.clone(), - cx, - ) - } - } - }); - - let connection = CodexConnection { - client, - sessions, - _notification_handler_task: notification_handler_task, - }; - Ok(Rc::new(connection) as _) - }) - } -} - -struct CodexConnection { - client: Arc, - sessions: Rc>>, - _notification_handler_task: Task<()>, -} - -struct CodexSession { - thread: WeakEntity, - cancel_tx: Option>, - _mcp_server: ZedMcpServer, -} - -impl AgentConnection for CodexConnection { - fn name(&self) -> &'static str { - "Codex" - } - - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let client = self.client.client(); - let sessions = self.sessions.clone(); - let cwd = cwd.to_path_buf(); - cx.spawn(async move |cx| { - let client = client.context("MCP server is not initialized yet")?; - let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); - - let mcp_server = ZedMcpServer::new(thread_rx, cx).await?; - - let response = client - .request::(context_server::types::CallToolParams { - name: acp::NEW_SESSION_TOOL_NAME.into(), - arguments: Some(serde_json::to_value(acp::NewSessionArguments { - mcp_servers: [( - mcp_server::SERVER_NAME.to_string(), - mcp_server.server_config()?, - )] - .into(), - client_tools: acp::ClientTools { - request_permission: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::RequestPermissionTool::NAME.into(), - }), - read_text_file: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::ReadTextFileTool::NAME.into(), - }), - write_text_file: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::WriteTextFileTool::NAME.into(), - }), - }, - cwd, - })?), - meta: None, - }) - .await?; - - if response.is_error.unwrap_or_default() { - return Err(anyhow!(response.text_contents())); - } - - let result = serde_json::from_value::( - response.structured_content.context("Empty response")?, - )?; - - let thread = - cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?; - - thread_tx.send(thread.downgrade())?; - - let session = CodexSession { - thread: thread.downgrade(), - cancel_tx: None, - _mcp_server: mcp_server, - }; - sessions.borrow_mut().insert(result.session_id, session); - - Ok(thread) - }) - } - - fn authenticate(&self, _cx: &mut App) -> Task> { - Task::ready(Err(anyhow!("Authentication not supported"))) - } - - fn prompt( - &self, - params: agent_client_protocol::PromptArguments, - cx: &mut App, - ) -> Task> { - let client = self.client.client(); - let sessions = self.sessions.clone(); - - cx.foreground_executor().spawn(async move { - let client = client.context("MCP server is not initialized yet")?; - - let (new_cancel_tx, cancel_rx) = oneshot::channel(); - { - let mut sessions = sessions.borrow_mut(); - let session = sessions - .get_mut(¶ms.session_id) - .context("Session not found")?; - session.cancel_tx.replace(new_cancel_tx); - } - - let result = client - .request_with::( - context_server::types::CallToolParams { - name: acp::PROMPT_TOOL_NAME.into(), - arguments: Some(serde_json::to_value(params)?), - meta: None, - }, - Some(cancel_rx), - None, - ) - .await; - - if let Err(err) = &result - && err.is::() - { - return Ok(()); - } - - let response = result?; - - if response.is_error.unwrap_or_default() { - return Err(anyhow!(response.text_contents())); - } - - Ok(()) - }) - } - - fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) { - let mut sessions = self.sessions.borrow_mut(); - - if let Some(cancel_tx) = sessions - .get_mut(session_id) - .and_then(|session| session.cancel_tx.take()) - { - cancel_tx.send(()).ok(); - } - } -} - -impl CodexConnection { - pub fn handle_session_notification( - notification: acp::SessionNotification, - threads: Rc>>, - cx: &mut AsyncApp, - ) { - let threads = threads.borrow(); - let Some(thread) = threads - .get(¬ification.session_id) - .and_then(|session| session.thread.upgrade()) - else { - log::error!( - "Thread not found for session ID: {}", - notification.session_id - ); - return; - }; - - thread - .update(cx, |thread, cx| { - thread.handle_session_update(notification.update, cx) - }) - .log_err(); - } -} - -impl Drop for CodexConnection { - fn drop(&mut self) { - self.client.stop().log_err(); - } -} - -#[cfg(test)] -pub(crate) mod tests { - use super::*; - use crate::AgentServerCommand; - use std::path::Path; - - crate::common_e2e_tests!(Codex); - - pub fn local_command() -> AgentServerCommand { - let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) - .join("../../../codex/codex-rs/target/debug/codex"); - - AgentServerCommand { - path: cli_path, - args: vec!["mcp".into()], - env: None, - } - } -} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 905f06a148..a60aefb7b9 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -1,4 +1,8 @@ -use std::{path::Path, sync::Arc, time::Duration}; +use std::{ + path::{Path, PathBuf}, + sync::Arc, + time::Duration, +}; use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; @@ -8,7 +12,6 @@ use futures::{FutureExt, StreamExt, channel::mpsc, select}; use gpui::{Entity, TestAppContext}; use indoc::indoc; use project::{FakeFs, Project}; -use serde_json::json; use settings::{Settings, SettingsStore}; use util::path; @@ -23,7 +26,11 @@ pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppCont .unwrap(); thread.read_with(cx, |thread, _| { - assert_eq!(thread.entries().len(), 2); + assert!( + thread.entries().len() >= 2, + "Expected at least 2 entries. Got: {:?}", + thread.entries() + ); assert!(matches!( thread.entries()[0], AgentThreadEntry::UserMessage(_) @@ -79,37 +86,44 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes .unwrap(); thread.read_with(cx, |thread, cx| { - assert_eq!(thread.entries().len(), 3); assert!(matches!( thread.entries()[0], AgentThreadEntry::UserMessage(_) )); - assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_))); - let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else { - panic!("Expected AssistantMessage") - }; + let assistant_message = &thread + .entries() + .iter() + .rev() + .find_map(|entry| match entry { + AgentThreadEntry::AssistantMessage(msg) => Some(msg), + _ => None, + }) + .unwrap(); + assert!( assistant_message.to_markdown(cx).contains("Hello, world!"), "unexpected assistant message: {:?}", assistant_message.to_markdown(cx) ); }); + + drop(tempdir); } pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) { - let fs = init_test(cx).await; - fs.insert_tree( - path!("/private/tmp"), - json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), - ) - .await; - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let _fs = init_test(cx).await; + + let tempdir = tempfile::tempdir().unwrap(); + let foo_path = tempdir.path().join("foo"); + std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file"); + + let project = Project::example([tempdir.path()], &mut cx.to_async()).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; thread .update(cx, |thread, cx| { thread.send_raw( - "Read the '/private/tmp/foo' file and tell me what you see.", + &format!("Read {} and tell me what you see.", foo_path.display()), cx, ) }) @@ -132,10 +146,13 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) }) ); }); + + drop(tempdir); } -pub async fn test_tool_call_with_confirmation( +pub async fn test_tool_call_with_permission( server: impl AgentServer + 'static, + allow_option_id: acp::PermissionOptionId, cx: &mut TestAppContext, ) { let fs = init_test(cx).await; @@ -143,7 +160,7 @@ pub async fn test_tool_call_with_confirmation( let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let full_turn = thread.update(cx, |thread, cx| { thread.send_raw( - r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#, + r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, cx, ) }); @@ -163,10 +180,10 @@ pub async fn test_tool_call_with_confirmation( ) .await; - let tool_call_id = thread.read_with(cx, |thread, _cx| { + let tool_call_id = thread.read_with(cx, |thread, cx| { let AgentThreadEntry::ToolCall(ToolCall { id, - content, + label, status: ToolCallStatus::WaitingForConfirmation { .. }, .. }) = &thread @@ -178,7 +195,8 @@ pub async fn test_tool_call_with_confirmation( panic!(); }; - assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch"))); + let label = label.read(cx).source(); + assert!(label.contains("touch"), "Got: {}", label); id.clone() }); @@ -186,7 +204,7 @@ pub async fn test_tool_call_with_confirmation( thread.update(cx, |thread, cx| { thread.authorize_tool_call( tool_call_id, - acp::PermissionOptionId("0".into()), + allow_option_id, acp::PermissionOptionKind::AllowOnce, cx, ); @@ -230,7 +248,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let full_turn = thread.update(cx, |thread, cx| { thread.send_raw( - r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#, + r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, cx, ) }); @@ -250,10 +268,10 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon ) .await; - thread.read_with(cx, |thread, _cx| { + thread.read_with(cx, |thread, cx| { let AgentThreadEntry::ToolCall(ToolCall { id, - content, + label, status: ToolCallStatus::WaitingForConfirmation { .. }, .. }) = &thread.entries()[first_tool_call_ix] @@ -261,7 +279,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon panic!("{:?}", thread.entries()[1]); }; - assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch"))); + let label = label.read(cx).source(); + assert!(label.contains("touch"), "Got: {}", label); id.clone() }); @@ -294,7 +313,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon #[macro_export] macro_rules! common_e2e_tests { - ($server:expr) => { + ($server:expr, allow_option_id = $allow_option_id:expr) => { mod common_e2e { use super::*; @@ -318,8 +337,13 @@ macro_rules! common_e2e_tests { #[::gpui::test] #[cfg_attr(not(feature = "e2e"), ignore)] - async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) { - $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await; + async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call_with_permission( + $server, + ::agent_client_protocol::PermissionOptionId($allow_option_id.into()), + cx, + ) + .await; } #[::gpui::test] @@ -351,9 +375,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { gemini: Some(AgentServerSettings { command: crate::gemini::tests::local_command(), }), - codex: Some(AgentServerSettings { - command: crate::codex::tests::local_command(), - }), }, cx, ); @@ -412,3 +433,24 @@ pub async fn run_until_first_tool_call( } } } + +pub fn get_zed_path() -> PathBuf { + let mut zed_path = std::env::current_exe().unwrap(); + + while zed_path + .file_name() + .map_or(true, |name| name.to_string_lossy() != "debug") + { + if !zed_path.pop() { + panic!("Could not find target directory"); + } + } + + zed_path.push("zed"); + + if !zed_path.exists() { + panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n"); + } + + zed_path +} diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 47b965cdad..2366783d22 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,14 +1,10 @@ -use anyhow::anyhow; -use std::cell::RefCell; use std::path::Path; use std::rc::Rc; -use util::ResultExt as _; -use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; -use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; -use agentic_coding_protocol as acp_old; -use anyhow::{Context as _, Result}; -use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use crate::{AgentServer, AgentServerCommand}; +use acp_thread::AgentConnection; +use anyhow::Result; +use gpui::{Entity, Task}; use project::Project; use settings::SettingsStore; use ui::App; @@ -43,152 +39,32 @@ impl AgentServer for Gemini { project: &Entity, cx: &mut App, ) -> Task>> { - let root_dir = root_dir.to_path_buf(); let project = project.clone(); - let this = self.clone(); - let name = self.name(); - + let root_dir = root_dir.to_path_buf(); + let server_name = self.name(); cx.spawn(async move |cx| { - let command = this.command(&project, cx).await?; + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() + })?; - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; + let Some(command) = + AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + else { + anyhow::bail!("Failed to find gemini binary"); + }; - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - let foreground_executor = cx.foreground_executor().clone(); - - let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(AgentServerVersion::Unsupported { - error_message, - upgrade_message, - upgrade_command, - }) = this.version(&command).await.log_err() - { - Err(anyhow!(LoadError::Unsupported { - error_message, - upgrade_message, - upgrade_command - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - }; - drop(io_task); - result - }); - - let connection: Rc = Rc::new(OldAcpAgentConnection { - name, - connection, - child_status, - }); - - Ok(connection) + crate::acp::connect(server_name, command, &root_dir, cx).await }) } } -impl Gemini { - async fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).gemini.clone() - })?; - - if let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await - { - return Ok(command); - }; - - let (fs, node_runtime) = project.update(cx, |project, _| { - (project.fs().clone(), project.node_runtime().cloned()) - })?; - let node_runtime = node_runtime.context("gemini not found on path")?; - - let directory = ::paths::agent_servers_dir().join("gemini"); - fs.create_dir(&directory).await?; - node_runtime - .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) - .await?; - let path = directory.join("node_modules/.bin/gemini"); - - Ok(AgentServerCommand { - path, - args: vec![ACP_ARG.into()], - env: None, - }) - } - - async fn version(&self, command: &AgentServerCommand) -> Result { - let version_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--version") - .kill_on_drop(true) - .output(); - - let help_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--help") - .kill_on_drop(true) - .output(); - - let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; - - let current_version = String::from_utf8(version_output?.stdout)?; - let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); - - if supported { - Ok(AgentServerVersion::Supported) - } else { - Ok(AgentServerVersion::Unsupported { - error_message: format!( - "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", - current_version - ).into(), - upgrade_message: "Upgrade Gemini to Latest".into(), - upgrade_command: "npm install -g @google/gemini-cli@latest".into(), - }) - } - } -} - #[cfg(test)] pub(crate) mod tests { use super::*; use crate::AgentServerCommand; use std::path::Path; - crate::common_e2e_tests!(Gemini); + crate::common_e2e_tests!(Gemini, allow_option_id = "proceed_once"); pub fn local_command() -> AgentServerCommand { let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) @@ -198,7 +74,7 @@ pub(crate) mod tests { AgentServerCommand { path: "node".into(), - args: vec![cli_path, ACP_ARG.into()], + args: vec![cli_path], env: None, } } diff --git a/crates/agent_servers/src/mcp_server.rs b/crates/agent_servers/src/mcp_server.rs deleted file mode 100644 index 47575fa3ea..0000000000 --- a/crates/agent_servers/src/mcp_server.rs +++ /dev/null @@ -1,201 +0,0 @@ -use acp_thread::AcpThread; -use agent_client_protocol as acp; -use anyhow::{Context, Result}; -use context_server::listener::{McpServerTool, ToolResponse}; -use context_server::types::{ - Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, - ToolsCapabilities, requests, -}; -use futures::channel::oneshot; -use gpui::{App, AsyncApp, Task, WeakEntity}; -use indoc::indoc; - -pub struct ZedMcpServer { - server: context_server::listener::McpServer, -} - -pub const SERVER_NAME: &str = "zed"; - -impl ZedMcpServer { - pub async fn new( - thread_rx: watch::Receiver>, - cx: &AsyncApp, - ) -> Result { - let mut mcp_server = context_server::listener::McpServer::new(cx).await?; - mcp_server.handle_request::(Self::handle_initialize); - - mcp_server.add_tool(RequestPermissionTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(ReadTextFileTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(WriteTextFileTool { - thread_rx: thread_rx.clone(), - }); - - Ok(Self { server: mcp_server }) - } - - pub fn server_config(&self) -> Result { - let zed_path = std::env::current_exe() - .context("finding current executable path for use in mcp_server")?; - - Ok(acp::McpServerConfig { - command: zed_path, - args: vec![ - "--nc".into(), - self.server.socket_path().display().to_string(), - ], - env: None, - }) - } - - fn handle_initialize(_: InitializeParams, cx: &App) -> Task> { - cx.foreground_executor().spawn(async move { - Ok(InitializeResponse { - protocol_version: ProtocolVersion("2025-06-18".into()), - capabilities: ServerCapabilities { - experimental: None, - logging: None, - completions: None, - prompts: None, - resources: None, - tools: Some(ToolsCapabilities { - list_changed: Some(false), - }), - }, - server_info: Implementation { - name: SERVER_NAME.into(), - version: "0.1.0".into(), - }, - meta: None, - }) - }) - } -} - -// Tools - -#[derive(Clone)] -pub struct RequestPermissionTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for RequestPermissionTool { - type Input = acp::RequestPermissionArguments; - type Output = acp::RequestPermissionOutput; - - const NAME: &'static str = "Confirmation"; - - fn description(&self) -> &'static str { - indoc! {" - Request permission for tool calls. - - This tool is meant to be called programmatically by the agent loop, not the LLM. - "} - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let result = thread - .update(cx, |thread, cx| { - thread.request_tool_call_permission(input.tool_call, input.options, cx) - })? - .await; - - let outcome = match result { - Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id }, - Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled, - }; - - Ok(ToolResponse { - content: vec![], - structured_content: acp::RequestPermissionOutput { outcome }, - }) - } -} - -#[derive(Clone)] -pub struct ReadTextFileTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for ReadTextFileTool { - type Input = acp::ReadTextFileArguments; - type Output = acp::ReadTextFileOutput; - - const NAME: &'static str = "Read"; - - fn description(&self) -> &'static str { - "Reads the content of the given file in the project including unsaved changes." - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.path, input.line, input.limit, false, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: acp::ReadTextFileOutput { content }, - }) - } -} - -#[derive(Clone)] -pub struct WriteTextFileTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for WriteTextFileTool { - type Input = acp::WriteTextFileArguments; - type Output = (); - - const NAME: &'static str = "Write"; - - fn description(&self) -> &'static str { - "Write to a file replacing its contents" - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - thread - .update(cx, |thread, cx| { - thread.write_text_file(input.path, input.content, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: (), - }) - } -} diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs index aeb34a5e61..645674b5f1 100644 --- a/crates/agent_servers/src/settings.rs +++ b/crates/agent_servers/src/settings.rs @@ -13,7 +13,6 @@ pub fn init(cx: &mut App) { pub struct AllAgentServersSettings { pub gemini: Option, pub claude: Option, - pub codex: Option, } #[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] @@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings { fn load(sources: SettingsSources, _: &mut App) -> Result { let mut settings = AllAgentServersSettings::default(); - for AllAgentServersSettings { - gemini, - claude, - codex, - } in sources.defaults_and_customizations() - { + for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() { if gemini.is_some() { settings.gemini = gemini.clone(); } if claude.is_some() { settings.claude = claude.clone(); } - if codex.is_some() { - settings.codex = codex.clone(); - } } Ok(settings) diff --git a/crates/agent_settings/Cargo.toml b/crates/agent_settings/Cargo.toml index 3afe5ae547..d34396a5d3 100644 --- a/crates/agent_settings/Cargo.toml +++ b/crates/agent_settings/Cargo.toml @@ -13,6 +13,7 @@ path = "src/agent_settings.rs" [dependencies] anyhow.workspace = true +cloud_llm_client.workspace = true collections.workspace = true gpui.workspace = true language_model.workspace = true @@ -20,7 +21,6 @@ schemars.workspace = true serde.workspace = true settings.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true [dev-dependencies] fs.workspace = true diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 13b966608c..4e872c78d7 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -321,11 +321,11 @@ pub enum CompletionMode { Burn, } -impl From for zed_llm_client::CompletionMode { +impl From for cloud_llm_client::CompletionMode { fn from(value: CompletionMode) -> Self { match value { - CompletionMode::Normal => zed_llm_client::CompletionMode::Normal, - CompletionMode::Burn => zed_llm_client::CompletionMode::Max, + CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal, + CompletionMode::Burn => cloud_llm_client::CompletionMode::Max, } } } diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index fbd53e8d09..95fd2b1757 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -31,6 +31,7 @@ audio.workspace = true buffer_diff.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true component.workspace = true @@ -46,9 +47,9 @@ futures.workspace = true fuzzy.workspace = true gpui.workspace = true html_to_markdown.workspace = true -indoc.workspace = true http_client.workspace = true indexed_docs.workspace = true +indoc.workspace = true inventory.workspace = true itertools.workspace = true jsonschema.workspace = true @@ -97,7 +98,6 @@ watch.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true [dev-dependencies] assistant_tools.workspace = true diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index e46e1ae3ab..a8e2d59b62 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,5 +1,7 @@ use acp_thread::{AgentConnection, Plan}; use agent_servers::AgentServer; +use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; +use audio::{Audio, Sound}; use std::cell::RefCell; use std::collections::BTreeMap; use std::path::Path; @@ -18,10 +20,10 @@ use editor::{ use file_icons::FileIcons; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, - FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, - Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, - Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*, - pulsating_between, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString, + StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation, + UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient, + list, percentage, point, prelude::*, pulsating_between, }; use language::language_settings::SoftWrap; use language::{Buffer, Language}; @@ -29,7 +31,7 @@ use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::Project; use settings::Settings as _; -use text::Anchor; +use text::{Anchor, BufferSnapshot}; use theme::ThemeSettings; use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*}; use util::ResultExt; @@ -45,7 +47,10 @@ use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSe use crate::acp::message_history::MessageHistory; use crate::agent_diff::AgentDiff; use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; -use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll}; +use crate::ui::{AgentNotification, AgentNotificationEvent}; +use crate::{ + AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll, +}; const RESPONSE_PADDING_X: Pixels = px(19.); @@ -56,9 +61,11 @@ pub struct AcpThreadView { thread_state: ThreadState, diff_editors: HashMap>, message_editor: Entity, - message_set_from_history: bool, + message_set_from_history: Option, _message_editor_subscription: Subscription, mention_set: Arc>, + notifications: Vec>, + notification_subscriptions: HashMap, Vec>, last_error: Option>, list_state: ListState, auth_task: Option>, @@ -137,14 +144,28 @@ impl AcpThreadView { editor }); - let message_editor_subscription = cx.subscribe(&message_editor, |this, _, event, _| { - if let editor::EditorEvent::BufferEdited = &event { - if !this.message_set_from_history { - this.message_history.borrow_mut().reset_position(); + let message_editor_subscription = + cx.subscribe(&message_editor, |this, editor, event, cx| { + if let editor::EditorEvent::BufferEdited = &event { + let buffer = editor + .read(cx) + .buffer() + .read(cx) + .as_singleton() + .unwrap() + .read(cx) + .snapshot(); + if let Some(message) = this.message_set_from_history.clone() + && message.version() != buffer.version() + { + this.message_set_from_history = None; + } + + if this.message_set_from_history.is_none() { + this.message_history.borrow_mut().reset_position(); + } } - this.message_set_from_history = false; - } - }); + }); let mention_set = mention_set.clone(); @@ -171,9 +192,11 @@ impl AcpThreadView { project: project.clone(), thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, - message_set_from_history: false, + message_set_from_history: None, _message_editor_subscription: message_editor_subscription, mention_set, + notifications: Vec::new(), + notification_subscriptions: HashMap::default(), diff_editors: Default::default(), list_state: list_state, last_error: None, @@ -223,7 +246,7 @@ impl AcpThreadView { { Err(e) => { let mut cx = cx.clone(); - if e.downcast_ref::().is_some() { + if e.is::() { this.update(&mut cx, |this, cx| { this.thread_state = ThreadState::Unauthenticated { connection }; cx.notify(); @@ -381,7 +404,9 @@ impl AcpThreadView { return; } - let Some(thread) = self.thread() else { return }; + let Some(thread) = self.thread() else { + return; + }; let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); cx.spawn(async move |this, cx| { @@ -399,11 +424,14 @@ impl AcpThreadView { let mention_set = self.mention_set.clone(); self.set_editor_is_expanded(false, cx); + self.message_editor.update(cx, |editor, cx| { editor.clear(window, cx); editor.remove_creases(mention_set.lock().drain(), cx) }); + self.scroll_to_bottom(cx); + self.message_history.borrow_mut().push(chunks); } @@ -413,11 +441,21 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { + if self.message_set_from_history.is_none() && !self.message_editor.read(cx).is_empty(cx) { + self.message_editor.update(cx, |editor, cx| { + editor.move_up(&Default::default(), window, cx); + }); + return; + } + self.message_set_from_history = Self::set_draft_message( self.message_editor.clone(), self.mention_set.clone(), self.project.clone(), - self.message_history.borrow_mut().prev(), + self.message_history + .borrow_mut() + .prev() + .map(|blocks| blocks.as_slice()), window, cx, ); @@ -429,14 +467,35 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - self.message_set_from_history = Self::set_draft_message( + if self.message_set_from_history.is_none() { + self.message_editor.update(cx, |editor, cx| { + editor.move_down(&Default::default(), window, cx); + }); + return; + } + + let mut message_history = self.message_history.borrow_mut(); + let next_history = message_history.next(); + + let set_draft_message = Self::set_draft_message( self.message_editor.clone(), self.mention_set.clone(), self.project.clone(), - self.message_history.borrow_mut().next(), + Some( + next_history + .map(|blocks| blocks.as_slice()) + .unwrap_or_else(|| &[]), + ), window, cx, ); + // If we reset the text to an empty string because we ran out of history, + // we don't want to mark it as coming from the history + self.message_set_from_history = if next_history.is_some() { + set_draft_message + } else { + None + }; } fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { @@ -470,15 +529,13 @@ impl AcpThreadView { message_editor: Entity, mention_set: Arc>, project: Entity, - message: Option<&Vec>, + message: Option<&[acp::ContentBlock]>, window: &mut Window, cx: &mut Context, - ) -> bool { + ) -> Option { cx.notify(); - let Some(message) = message else { - return false; - }; + let message = message?; let mut text = String::new(); let mut mentions = Vec::new(); @@ -542,7 +599,8 @@ impl AcpThreadView { } } - true + let snapshot = snapshot.as_singleton().unwrap().2.clone(); + Some(snapshot.text) } fn handle_thread_event( @@ -564,6 +622,30 @@ impl AcpThreadView { self.sync_thread_entry_view(index, window, cx); self.list_state.splice(index..index + 1, 1); } + AcpThreadEvent::ToolAuthorizationRequired => { + self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); + } + AcpThreadEvent::Stopped => { + let used_tools = thread.read(cx).used_tools_since_last_user_message(); + self.notify_with_sound( + if used_tools { + "Finished running tools" + } else { + "New message" + }, + IconName::ZedAssistant, + window, + cx, + ); + } + AcpThreadEvent::Error => { + self.notify_with_sound( + "Agent stopped due to an error", + IconName::Warning, + window, + cx, + ); + } } cx.notify(); } @@ -640,13 +722,18 @@ impl AcpThreadView { Some(entry.diffs().map(|diff| diff.multibuffer.clone())) } - fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { + fn authenticate( + &mut self, + method: acp::AuthMethodId, + window: &mut Window, + cx: &mut Context, + ) { let ThreadState::Unauthenticated { ref connection } = self.thread_state else { return; }; self.last_error.take(); - let authenticate = connection.authenticate(cx); + let authenticate = connection.authenticate(method, cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); let agent = self.agent.clone(); @@ -1146,7 +1233,7 @@ impl AcpThreadView { }) .children(options.iter().map(|option| { let option_id = SharedString::from(option.id.0.clone()); - Button::new((option_id, entry_ix), option.label.clone()) + Button::new((option_id, entry_ix), option.name.clone()) .map(|this| match option.kind { acp::PermissionOptionKind::AllowOnce => { this.icon(IconName::Check).icon_color(Color::Success) @@ -1938,15 +2025,15 @@ impl AcpThreadView { .icon_color(Color::Accent) .style(ButtonStyle::Filled) .disabled(self.thread().is_none() || is_editor_empty) - .on_click(cx.listener(|this, _, window, cx| { - this.chat(&Chat, window, cx); - })) .when(!is_editor_empty, |button| { button.tooltip(move |window, cx| Tooltip::for_action("Send", &Chat, window, cx)) }) .when(is_editor_empty, |button| { button.tooltip(Tooltip::text("Type a message to submit")) }) + .on_click(cx.listener(|this, _, window, cx| { + this.chat(&Chat, window, cx); + })) .into_any_element() } else { IconButton::new("stop-generation", IconName::StopFilled) @@ -2160,17 +2247,165 @@ impl AcpThreadView { self.list_state.scroll_to(ListOffset::default()); cx.notify(); } -} -impl Focusable for AcpThreadView { - fn focus_handle(&self, cx: &App) -> FocusHandle { - self.message_editor.focus_handle(cx) + pub fn scroll_to_bottom(&mut self, cx: &mut Context) { + if let Some(thread) = self.thread() { + let entry_count = thread.read(cx).entries().len(); + self.list_state.reset(entry_count); + cx.notify(); + } } -} -impl Render for AcpThreadView { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let open_as_markdown = IconButton::new("open-as-markdown", IconName::DocumentText) + fn notify_with_sound( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + self.play_notification_sound(window, cx); + self.show_notification(caption, icon, window, cx); + } + + fn play_notification_sound(&self, window: &Window, cx: &mut App) { + let settings = AgentSettings::get_global(cx); + if settings.play_sound_when_agent_done && !window.is_window_active() { + Audio::play_sound(Sound::AgentDone, cx); + } + } + + fn show_notification( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + if window.is_window_active() || !self.notifications.is_empty() { + return; + } + + let title = self.title(cx); + + match AgentSettings::get_global(cx).notify_when_agent_waiting { + NotifyWhenAgentWaiting::PrimaryScreen => { + if let Some(primary) = cx.primary_display() { + self.pop_up(icon, caption.into(), title, window, primary, cx); + } + } + NotifyWhenAgentWaiting::AllScreens => { + let caption = caption.into(); + for screen in cx.displays() { + self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx); + } + } + NotifyWhenAgentWaiting::Never => { + // Don't show anything + } + } + } + + fn pop_up( + &mut self, + icon: IconName, + caption: SharedString, + title: SharedString, + window: &mut Window, + screen: Rc, + cx: &mut Context, + ) { + let options = AgentNotification::window_options(screen, cx); + + let project_name = self.workspace.upgrade().and_then(|workspace| { + workspace + .read(cx) + .project() + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).root_name().to_string()) + }); + + if let Some(screen_window) = cx + .open_window(options, |_, cx| { + cx.new(|_| { + AgentNotification::new(title.clone(), caption.clone(), icon, project_name) + }) + }) + .log_err() + { + if let Some(pop_up) = screen_window.entity(cx).log_err() { + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push(cx.subscribe_in(&pop_up, window, { + |this, _, event, window, cx| match event { + AgentNotificationEvent::Accepted => { + let handle = window.window_handle(); + cx.activate(true); + + let workspace_handle = this.workspace.clone(); + + // If there are multiple Zed windows, activate the correct one. + cx.defer(move |cx| { + handle + .update(cx, |_view, window, _cx| { + window.activate_window(); + + if let Some(workspace) = workspace_handle.upgrade() { + workspace.update(_cx, |workspace, cx| { + workspace.focus_panel::(window, cx); + }); + } + }) + .log_err(); + }); + + this.dismiss_notifications(cx); + } + AgentNotificationEvent::Dismissed => { + this.dismiss_notifications(cx); + } + } + })); + + self.notifications.push(screen_window); + + // If the user manually refocuses the original window, dismiss the popup. + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push({ + let pop_up_weak = pop_up.downgrade(); + + cx.observe_window_activation(window, move |_, window, cx| { + if window.is_window_active() { + if let Some(pop_up) = pop_up_weak.upgrade() { + pop_up.update(cx, |_, cx| { + cx.emit(AgentNotificationEvent::Dismissed); + }); + } + } + }) + }); + } + } + } + + fn dismiss_notifications(&mut self, cx: &mut Context) { + for window in self.notifications.drain(..) { + window + .update(cx, |_, window, _| { + window.remove_window(); + }) + .ok(); + + self.notification_subscriptions.remove(&window); + } + } + + fn render_thread_controls(&mut self, cx: &mut Context) -> impl IntoElement { + let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileText) .icon_size(IconSize::XSmall) .icon_color(Color::Ignored) .tooltip(Tooltip::text("Open Thread as Markdown")) @@ -2189,6 +2424,28 @@ impl Render for AcpThreadView { this.scroll_to_top(cx); })); + h_flex() + .mt_1() + .mr_1() + .py_2() + .px(RESPONSE_PADDING_X) + .opacity(0.4) + .hover(|style| style.opacity(1.)) + .flex_wrap() + .justify_end() + .child(open_as_markdown) + .child(scroll_to_top) + } +} + +impl Focusable for AcpThreadView { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.message_editor.focus_handle(cx) + } +} + +impl Render for AcpThreadView { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() .size_full() .key_context("AcpThread") @@ -2196,23 +2453,28 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::previous_history_message)) .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) + .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { - ThreadState::Unauthenticated { .. } => { - v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child( - h_flex().mt_1p5().justify_center().child( - Button::new("sign-in", format!("Sign in to {}", self.agent.name())) - .on_click(cx.listener(|this, _, window, cx| { - this.authenticate(window, cx) - })), - ), - ) - } + ThreadState::Unauthenticated { connection } => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_pending_auth_state()) + .child(h_flex().mt_1p5().justify_center().children( + connection.auth_methods().into_iter().map(|method| { + Button::new( + SharedString::from(method.id.0.clone()), + method.name.clone(), + ) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), + )), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() @@ -2220,42 +2482,39 @@ impl Render for AcpThreadView { .items_center() .justify_center() .child(self.render_error_state(e, cx)), - ThreadState::Ready { thread, .. } => v_flex().flex_1().map(|this| { - if self.list_state.item_count() > 0 { - this.child( - list(self.list_state.clone()) - .with_sizing_behavior(gpui::ListSizingBehavior::Auto) - .flex_grow() - .into_any(), - ) - .child( - h_flex() - .group("controls") - .mt_1() - .mr_1() - .py_2() - .px(RESPONSE_PADDING_X) - .opacity(0.4) - .hover(|style| style.opacity(1.)) - .flex_wrap() - .justify_end() - .child(open_as_markdown) - .child(scroll_to_top) - .into_any_element(), - ) - .children(match thread.read(cx).status() { - ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => None, - ThreadStatus::Generating => div() - .px_5() - .py_2() - .child(LoadingLabel::new("").size(LabelSize::Small)) - .into(), - }) - .children(self.render_activity_bar(&thread, window, cx)) - } else { - this.child(self.render_empty_state(cx)) - } - }), + ThreadState::Ready { thread, .. } => { + let thread_clone = thread.clone(); + + v_flex().flex_1().map(|this| { + if self.list_state.item_count() > 0 { + let is_generating = + matches!(thread_clone.read(cx).status(), ThreadStatus::Generating); + + this.child( + list(self.list_state.clone()) + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .flex_grow() + .into_any(), + ) + .when(!is_generating, |this| { + this.child(self.render_thread_controls(cx)) + }) + .children(match thread_clone.read(cx).status() { + ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => { + None + } + ThreadStatus::Generating => div() + .px_5() + .py_2() + .child(LoadingLabel::new("").size(LabelSize::Small)) + .into(), + }) + .children(self.render_activity_bar(&thread_clone, window, cx)) + } else { + this.child(self.render_empty_state(cx)) + } + }) + } }) .when_some(self.last_error.clone(), |el, error| { el.child( @@ -2441,3 +2700,347 @@ fn plan_label_markdown_style( ..default_md_style } } + +#[cfg(test)] +mod tests { + use agent_client_protocol::SessionId; + use editor::EditorSettings; + use fs::FakeFs; + use futures::future::try_join_all; + use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; + use rand::Rng; + use settings::SettingsStore; + + use super::*; + + #[gpui::test] + async fn test_notification_for_stop_event(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_error(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { + init_test(cx); + + let tool_call_id = acp::ToolCallId("1".into()); + let tool_call = acp::ToolCall { + id: tool_call_id.clone(), + title: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Pending, + content: vec!["hi".into()], + locations: vec![], + raw_input: None, + }; + let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) + .with_permission_requests(HashMap::from_iter([( + tool_call_id, + vec![acp::PermissionOption { + id: acp::PermissionOptionId("1".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }], + )])); + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + async fn setup_thread_view( + agent: impl AgentServer + 'static, + cx: &mut TestAppContext, + ) -> (Entity, &mut VisualTestContext) { + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(agent), + workspace.downgrade(), + project, + Rc::new(RefCell::new(MessageHistory::default())), + 1, + None, + window, + cx, + ) + }) + }); + cx.run_until_parked(); + (thread_view, cx) + } + + struct StubAgentServer { + connection: C, + } + + impl StubAgentServer { + fn new(connection: C) -> Self { + Self { connection } + } + } + + impl StubAgentServer { + fn default() -> Self { + Self::new(StubAgentConnection::default()) + } + } + + impl AgentServer for StubAgentServer + where + C: 'static + AgentConnection + Send + Clone, + { + fn logo(&self) -> ui::IconName { + unimplemented!() + } + + fn name(&self) -> &'static str { + unimplemented!() + } + + fn empty_state_headline(&self) -> &'static str { + unimplemented!() + } + + fn empty_state_message(&self) -> &'static str { + unimplemented!() + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, + ) -> Task>> { + Task::ready(Ok(Rc::new(self.connection.clone()))) + } + } + + #[derive(Clone, Default)] + struct StubAgentConnection { + sessions: Arc>>>, + permission_requests: HashMap>, + updates: Vec, + } + + impl StubAgentConnection { + fn new(updates: Vec) -> Self { + Self { + updates, + permission_requests: HashMap::default(), + sessions: Arc::default(), + } + } + + fn with_permission_requests( + mut self, + permission_requests: HashMap>, + ) -> Self { + self.permission_requests = permission_requests; + self + } + } + + impl AgentConnection for StubAgentConnection { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + let session_id = SessionId( + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(7) + .map(char::from) + .collect::() + .into(), + ); + let thread = cx + .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) + .unwrap(); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + let mut tasks = vec![]; + for update in &self.updates { + let thread = thread.clone(); + let update = update.clone(); + let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update + && let Some(options) = self.permission_requests.get(&tool_call.id) + { + Some((tool_call.clone(), options.clone())) + } else { + None + }; + let task = cx.spawn(async move |cx| { + if let Some((tool_call, options)) = permission_request { + let permission = thread.update(cx, |thread, cx| { + thread.request_tool_call_permission( + tool_call.clone(), + options.clone(), + cx, + ) + })?; + permission.await?; + } + thread.update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + })?; + anyhow::Ok(()) + }); + tasks.push(task); + } + cx.spawn(async move |_| { + try_join_all(tasks).await?; + Ok(()) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + + #[derive(Clone)] + struct SaboteurAgentConnection; + + impl AgentConnection for SaboteurAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + Task::ready(Ok(cx + .new(|cx| { + AcpThread::new( + "SaboteurAgentConnection", + self, + project, + SessionId("test".into()), + cx, + ) + }) + .unwrap())) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow::anyhow!("Error prompting"))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AgentSettings::register(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + release_channel::init(SemanticVersion::default(), cx); + EditorSettings::register(cx); + }); + } +} diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index e27c318221..04a093c7d0 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -14,6 +14,7 @@ use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; use anyhow::Context as _; use assistant_tool::ToolUseStatus; use audio::{Audio, Sound}; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::scroll::Autoscroll; @@ -52,7 +53,6 @@ use util::ResultExt as _; use util::markdown::MarkdownCodeBlock; use workspace::{CollaboratorId, Workspace}; use zed_actions::assistant::OpenRulesLibrary; -use zed_llm_client::CompletionIntent; const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container"; const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1; diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index fae04188eb..02c15b7e41 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -7,6 +7,7 @@ use std::{sync::Arc, time::Duration}; use agent_settings::AgentSettings; use assistant_tool::{ToolSource, ToolWorkingSet}; +use cloud_llm_client::Plan; use collections::HashMap; use context_server::ContextServerId; use extension::ExtensionManifest; @@ -25,7 +26,6 @@ use project::{ context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore}, project_settings::{ContextServerSettings, ProjectSettings}, }; -use proto::Plan; use settings::{Settings, update_settings_file}; use ui::{ Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, @@ -180,7 +180,7 @@ impl AgentConfiguration { let current_plan = if is_zed_provider { self.workspace .upgrade() - .and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan()) + .and_then(|workspace| workspace.read(cx).user_store().read(cx).plan()) } else { None }; @@ -406,7 +406,9 @@ impl AgentConfiguration { SwitchField::new( "always-allow-tool-actions-switch", "Allow running commands without asking for confirmation", - "The agent can perform potentially destructive actions without asking for your confirmation.", + Some( + "The agent can perform potentially destructive actions without asking for your confirmation.".into(), + ), always_allow_tool_actions, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -424,7 +426,7 @@ impl AgentConfiguration { SwitchField::new( "single-file-review", "Enable single-file agent reviews", - "Agent edits are also displayed in single-file editors for review.", + Some("Agent edits are also displayed in single-file editors for review.".into()), single_file_review, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -442,7 +444,9 @@ impl AgentConfiguration { SwitchField::new( "sound-notification", "Play sound when finished generating", - "Hear a notification sound when the agent is done generating changes or needs your input.", + Some( + "Hear a notification sound when the agent is done generating changes or needs your input.".into(), + ), play_sound_when_agent_done, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -460,7 +464,9 @@ impl AgentConfiguration { SwitchField::new( "modifier-send", "Use modifier to submit a message", - "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.", + Some( + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), + ), use_modifier_to_send, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -502,7 +508,7 @@ impl AgentConfiguration { .blend(cx.theme().colors().text_accent.opacity(0.2)); let (plan_name, label_color, bg_color) = match plan { - Plan::Free => ("Free", Color::Default, free_chip_bg), + Plan::ZedFree => ("Free", Color::Default, free_chip_bg), Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), }; @@ -533,7 +539,7 @@ impl AgentConfiguration { v_flex() .gap_0p5() .child(Headline::new("Model Context Protocol (MCP) Servers")) - .child(Label::new("Connect to context servers via the Model Context Protocol either via Zed extensions or directly.").color(Color::Muted)), + .child(Label::new("Connect to context servers through the Model Context Protocol, either using Zed extensions or directly.").color(Color::Muted)), ) .children( context_server_ids.into_iter().map(|context_server_id| { diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs index 94b32d156b..401a633488 100644 --- a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs +++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs @@ -272,42 +272,34 @@ impl AddLlmProviderModal { cx.emit(DismissEvent); } - fn render_section(&self) -> Section { - Section::new() - .child(self.input.provider_name.clone()) - .child(self.input.api_url.clone()) - .child(self.input.api_key.clone()) - } - - fn render_model_section(&self, cx: &mut Context) -> Section { - Section::new().child( - v_flex() - .gap_2() - .child( - h_flex() - .justify_between() - .child(Label::new("Models").size(LabelSize::Small)) - .child( - Button::new("add-model", "Add Model") - .icon(IconName::Plus) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .label_size(LabelSize::Small) - .on_click(cx.listener(|this, _, window, cx| { - this.input.add_model(window, cx); - cx.notify(); - })), - ), - ) - .children( - self.input - .models - .iter() - .enumerate() - .map(|(ix, _)| self.render_model(ix, cx)), - ), - ) + fn render_model_section(&self, cx: &mut Context) -> impl IntoElement { + v_flex() + .mt_1() + .gap_2() + .child( + h_flex() + .justify_between() + .child(Label::new("Models").size(LabelSize::Small)) + .child( + Button::new("add-model", "Add Model") + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + this.input.add_model(window, cx); + cx.notify(); + })), + ), + ) + .children( + self.input + .models + .iter() + .enumerate() + .map(|(ix, _)| self.render_model(ix, cx)), + ) } fn render_model(&self, ix: usize, cx: &mut Context) -> impl IntoElement + use<> { @@ -393,10 +385,14 @@ impl Render for AddLlmProviderModal { .child( v_flex() .id("modal_content") + .size_full() .max_h_128() .overflow_y_scroll() - .gap_2() - .child(self.render_section()) + .px(DynamicSpacing::Base12.rems(cx)) + .gap(DynamicSpacing::Base04.rems(cx)) + .child(self.input.provider_name.clone()) + .child(self.input.api_url.clone()) + .child(self.input.api_key.clone()) .child(self.render_model_section(cx)), ) .footer( diff --git a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs index 45536ff13b..5d44bb2d92 100644 --- a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs @@ -483,7 +483,7 @@ impl ManageProfilesModal { let icon = match mode.profile_id.as_str() { "write" => IconName::Pencil, - "ask" => IconName::MessageBubbles, + "ask" => IconName::Chat, _ => IconName::UserRoundPen, }; diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index ec0a11f86b..c4dc359093 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1521,6 +1521,9 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } + AcpThreadEvent::Stopped + | AcpThreadEvent::ToolAuthorizationRequired + | AcpThreadEvent::Error => {} } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 61a65de50b..5f3315f69a 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -43,7 +43,8 @@ use anyhow::{Result, anyhow}; use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; -use client::{DisableAiSettings, UserStore, zed_urls}; +use client::{UserStore, zed_urls}; +use cloud_llm_client::{CompletionIntent, Plan, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; use feature_flags::{self, FeatureFlagAppExt}; use fs::Fs; @@ -57,9 +58,8 @@ use language::LanguageRegistry; use language_model::{ ConfigurationError, ConfiguredModel, LanguageModelProviderTosView, LanguageModelRegistry, }; -use project::{Project, ProjectPath, Worktree}; +use project::{DisableAiSettings, Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; -use proto::Plan; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, update_settings_file}; @@ -77,10 +77,9 @@ use workspace::{ }; use zed_actions::{ DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize, - agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector}, + agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector}, assistant::{OpenRulesLibrary, ToggleFocus}, }; -use zed_llm_client::{CompletionIntent, UsageLimit}; const AGENT_PANEL_KEY: &str = "agent_panel"; @@ -105,7 +104,7 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.open_history(window, cx)); } }) - .register_action(|workspace, _: &OpenConfiguration, window, cx| { + .register_action(|workspace, _: &OpenSettings, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| panel.open_configuration(window, cx)); @@ -579,7 +578,6 @@ impl AgentPanel { MessageEditor::new( fs.clone(), workspace.clone(), - user_store.clone(), message_editor_context_store.clone(), prompt_store.clone(), thread_store.downgrade(), @@ -848,7 +846,6 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), context_store.clone(), self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1122,7 +1119,6 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), context_store, self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1884,10 +1880,10 @@ impl AgentPanel { }), ); - let zoom_in_label = if self.is_zoomed(window, cx) { - "Zoom Out" + let full_screen_label = if self.is_zoomed(window, cx) { + "Disable Full Screen" } else { - "Zoom In" + "Enable Full Screen" }; let active_thread = match &self.active_view { @@ -1915,27 +1911,6 @@ impl AgentPanel { .when(cx.has_flag::(), |this| { this.header("Zed Agent") }) - .item( - ContextMenuEntry::new("New Thread") - .icon(IconName::NewThread) - .icon_color(Color::Muted) - .action(NewThread::default().boxed_clone()) - .handler(move |window, cx| { - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Text Thread") - .icon(IconName::NewTextThread) - .icon_color(Color::Muted) - .action(NewTextThread.boxed_clone()) - .handler(move |window, cx| { - window.dispatch_action(NewTextThread.boxed_clone(), cx); - }), - ) .when_some(active_thread, |this, active_thread| { let thread = active_thread.read(cx); @@ -1943,7 +1918,7 @@ impl AgentPanel { let thread_id = thread.id().clone(); this.item( ContextMenuEntry::new("New From Summary") - .icon(IconName::NewFromSummary) + .icon(IconName::ThreadFromSummary) .icon_color(Color::Muted) .handler(move |window, cx| { window.dispatch_action( @@ -1958,6 +1933,27 @@ impl AgentPanel { this } }) + .item( + ContextMenuEntry::new("New Thread") + .icon(IconName::Thread) + .icon_color(Color::Muted) + .action(NewThread::default().boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Text Thread") + .icon(IconName::TextThread) + .icon_color(Color::Muted) + .action(NewTextThread.boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action(NewTextThread.boxed_clone(), cx); + }), + ) .when(cx.has_flag::(), |this| { this.separator() .header("External Agents") @@ -1991,20 +1987,6 @@ impl AgentPanel { ); }), ) - .item( - ContextMenuEntry::new("New Codex Thread") - .icon(IconName::AiOpenAi) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::Codex), - } - .boxed_clone(), - cx, - ); - }), - ) }); menu })) @@ -2088,8 +2070,9 @@ impl AgentPanel { menu = menu .action("Rules…", Box::new(OpenRulesLibrary::default())) - .action("Settings", Box::new(OpenConfiguration)) - .action(zoom_in_label, Box::new(ToggleZoom)); + .action("Settings", Box::new(OpenSettings)) + .separator() + .action(full_screen_label, Box::new(ToggleZoom)); menu })) } @@ -2293,10 +2276,10 @@ impl AgentPanel { | ActiveView::Configuration => return false, } - let plan = self.user_store.read(cx).current_plan(); + let plan = self.user_store.read(cx).plan(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - matches!(plan, Some(Plan::Free)) && has_previous_trial + matches!(plan, Some(Plan::ZedFree)) && has_previous_trial } fn should_render_onboarding(&self, cx: &mut Context) -> bool { @@ -2482,14 +2465,14 @@ impl AgentPanel { .icon_color(Color::Muted) .full_width() .key_binding(KeyBinding::for_action_in( - &OpenConfiguration, + &OpenSettings, &focus_handle, window, cx, )) .on_click(|_event, window, cx| { window.dispatch_action( - OpenConfiguration.boxed_clone(), + OpenSettings.boxed_clone(), cx, ) }), @@ -2576,7 +2559,7 @@ impl AgentPanel { NewThreadButton::new( "new-thread-btn", "New Thread", - IconName::NewThread, + IconName::Thread, ) .keybinding(KeyBinding::for_action_in( &NewThread::default(), @@ -2597,7 +2580,7 @@ impl AgentPanel { NewThreadButton::new( "new-text-thread-btn", "New Text Thread", - IconName::NewTextThread, + IconName::TextThread, ) .keybinding(KeyBinding::for_action_in( &NewTextThread, @@ -2666,25 +2649,6 @@ impl AgentPanel { ) }, ), - ) - .child( - NewThreadButton::new( - "new-codex-thread-btn", - "New Codex Thread", - IconName::AiOpenAi, - ) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::Codex, - ), - }), - cx, - ) - }, - ), ), ) }), @@ -2713,16 +2677,11 @@ impl AgentPanel { .style(ButtonStyle::Tinted(ui::TintColor::Warning)) .label_size(LabelSize::Small) .key_binding( - KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - ) - .map(|kb| kb.size(rems_from_px(12.))), + KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))), ) .on_click(|_event, window, cx| { - window.dispatch_action(OpenConfiguration.boxed_clone(), cx) + window.dispatch_action(OpenSettings.boxed_clone(), cx) }), ), ConfigurationError::ProviderPendingTermsAcceptance(provider) => { @@ -2916,7 +2875,7 @@ impl AgentPanel { ) -> AnyElement { let error_message = match plan { Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", - Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.", + Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.", }; let icon = Icon::new(IconName::XCircle) @@ -3226,7 +3185,7 @@ impl Render for AgentPanel { .on_action(cx.listener(|this, _: &OpenHistory, window, cx| { this.open_history(window, cx); })) - .on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| { + .on_action(cx.listener(|this, _: &OpenSettings, window, cx| { this.open_configuration(window, cx); })) .on_action(cx.listener(Self::open_active_thread_as_markdown)) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 4b75cc9e77..30faf5ef2e 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -31,7 +31,7 @@ use std::sync::Arc; use agent::{Thread, ThreadId}; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection}; use assistant_slash_command::SlashCommandRegistry; -use client::{Client, DisableAiSettings}; +use client::Client; use command_palette_hooks::CommandPaletteFilter; use feature_flags::FeatureFlagAppExt as _; use fs::Fs; @@ -40,6 +40,7 @@ use language::LanguageRegistry; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, }; +use project::DisableAiSettings; use prompt_store::PromptBuilder; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -150,7 +151,6 @@ enum ExternalAgent { #[default] Gemini, ClaudeCode, - Codex, } impl ExternalAgent { @@ -158,7 +158,6 @@ impl ExternalAgent { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), - ExternalAgent::Codex => Rc::new(agent_servers::Codex), } } } @@ -265,8 +264,8 @@ fn update_command_palette_filter(cx: &mut App) { filter.hide_namespace("agent"); filter.hide_namespace("assistant"); filter.hide_namespace("copilot"); + filter.hide_namespace("supermaven"); filter.hide_namespace("zed_predict_onboarding"); - filter.hide_namespace("edit_prediction"); use editor::actions::{ diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 64498e9281..615142b73d 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -6,6 +6,7 @@ use agent::{ use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; +use cloud_llm_client::CompletionIntent; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; use futures::{ @@ -35,7 +36,6 @@ use std::{ }; use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; -use zed_llm_client::CompletionIntent; pub struct BufferCodegen { alternatives: Vec>, diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 5cc56b014e..32f9a096d9 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -148,7 +148,7 @@ impl ContextPickerMode { Self::File => IconName::File, Self::Symbol => IconName::Code, Self::Fetch => IconName::Globe, - Self::Thread => IconName::MessageBubbles, + Self::Thread => IconName::Thread, Self::Rules => RULES_ICON, } } diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index b377e40b19..5ca0913be7 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -423,7 +423,7 @@ impl ContextPickerCompletionProvider { let icon_for_completion = if recent { IconName::HistoryRerun } else { - IconName::MessageBubbles + IconName::Thread }; let new_text = format!("{} ", MentionLink::for_thread(&thread_entry)); let new_text_len = new_text.len(); @@ -436,7 +436,7 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_for_completion.path().into()), confirm: Some(confirm_completion_callback( - IconName::MessageBubbles.path().into(), + IconName::Thread.path().into(), thread_entry.title().clone(), excerpt_id, source_range.start, diff --git a/crates/agent_ui/src/context_picker/thread_context_picker.rs b/crates/agent_ui/src/context_picker/thread_context_picker.rs index cb2e97a493..15cc731f8f 100644 --- a/crates/agent_ui/src/context_picker/thread_context_picker.rs +++ b/crates/agent_ui/src/context_picker/thread_context_picker.rs @@ -253,7 +253,7 @@ pub fn render_thread_context_entry( .gap_1p5() .max_w_72() .child( - Icon::new(IconName::MessageBubbles) + Icon::new(IconName::Thread) .size(IconSize::XSmall) .color(Color::Muted), ) diff --git a/crates/agent_ui/src/debug.rs b/crates/agent_ui/src/debug.rs index ff6538dc85..bd34659210 100644 --- a/crates/agent_ui/src/debug.rs +++ b/crates/agent_ui/src/debug.rs @@ -1,10 +1,10 @@ #![allow(unused, dead_code)] use client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{Plan, UsageLimit}; use gpui::Global; use std::ops::{Deref, DerefMut}; use ui::prelude::*; -use zed_llm_client::{Plan, UsageLimit}; /// Debug only: Used for testing various account states /// diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 44ec050ae2..4a4a747899 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -16,7 +16,7 @@ use agent::{ }; use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; -use client::{DisableAiSettings, telemetry::Telemetry}; +use client::telemetry::Telemetry; use collections::{HashMap, HashSet, VecDeque, hash_map}; use editor::SelectionEffects; use editor::{ @@ -39,7 +39,7 @@ use language_model::{ }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; -use project::{CodeAction, LspAction, Project, ProjectTransaction}; +use project::{CodeAction, DisableAiSettings, LspAction, Project, ProjectTransaction}; use prompt_store::{PromptBuilder, PromptStore}; use settings::{Settings, SettingsStore}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; @@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _}; use ui::prelude::*; use util::{RangeExt, ResultExt, maybe}; use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId}; -use zed_actions::agent::OpenConfiguration; +use zed_actions::agent::OpenSettings; pub fn init( fs: Arc, @@ -162,7 +162,7 @@ impl InlineAssistant { let window = windows[0]; let _ = window.update(cx, |_, window, cx| { editor.update(cx, |editor, cx| { - if editor.has_active_inline_completion() { + if editor.has_active_edit_prediction() { editor.cancel(&Default::default(), window, cx); } }); @@ -231,8 +231,8 @@ impl InlineAssistant { ); if DisableAiSettings::get_global(cx).disable_ai { - // Cancel any active completions - if editor.has_active_inline_completion() { + // Cancel any active edit predictions + if editor.has_active_edit_prediction() { editor.cancel(&Default::default(), window, cx); } } @@ -345,7 +345,7 @@ impl InlineAssistant { if let Some(answer) = answer { if answer == 0 { cx.update(|window, cx| { - window.dispatch_action(Box::new(OpenConfiguration), cx) + window.dispatch_action(Box::new(OpenSettings), cx) }) .ok(); } diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index ade7a5e13d..a5f90edb57 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -541,7 +541,7 @@ impl PromptEditor { match &self.mode { PromptEditorMode::Terminal { .. } => vec![ accept, - IconButton::new("confirm", IconName::Play) + IconButton::new("confirm", IconName::PlayOutlined) .icon_color(Color::Info) .shape(IconButtonShape::Square) .tooltip(|window, cx| { diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 655e87d7cd..7121624c87 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -576,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { .icon_position(IconPosition::Start) .on_click(|_, window, cx| { window.dispatch_action( - zed_actions::agent::OpenConfiguration.boxed_clone(), + zed_actions::agent::OpenSettings.boxed_clone(), cx, ); }), diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index c160f1de04..2185885347 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -17,7 +17,7 @@ use agent::{ use agent_settings::{AgentSettings, CompletionMode}; use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; -use client::UserStore; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::display_map::CreaseId; @@ -42,7 +42,6 @@ use language_model::{ use multi_buffer; use project::Project; use prompt_store::PromptStore; -use proto::Plan; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; @@ -53,7 +52,6 @@ use util::ResultExt as _; use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::Chat; use zed_actions::agent::ToggleModelSelector; -use zed_llm_client::CompletionIntent; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; @@ -79,7 +77,6 @@ pub struct MessageEditor { editor: Entity, workspace: WeakEntity, project: Entity, - user_store: Entity, context_store: Entity, prompt_store: Option>, history_store: Option>, @@ -159,7 +156,6 @@ impl MessageEditor { pub fn new( fs: Arc, workspace: WeakEntity, - user_store: Entity, context_store: Entity, prompt_store: Option>, thread_store: WeakEntity, @@ -231,7 +227,6 @@ impl MessageEditor { Self { editor: editor.clone(), project: thread.read(cx).project().clone(), - user_store, thread, incompatible_tools_state: incompatible_tools.clone(), workspace, @@ -1287,24 +1282,12 @@ impl MessageEditor { return None; } - let user_store = self.user_store.read(cx); - - let ubb_enable = user_store - .usage_based_billing_enabled() - .map_or(false, |enabled| enabled); - - if ubb_enable { + let user_store = self.project.read(cx).user_store().read(cx); + if user_store.is_usage_based_billing_enabled() { return None; } - let plan = user_store - .current_plan() - .map(|plan| match plan { - Plan::Free => zed_llm_client::Plan::ZedFree, - Plan::ZedPro => zed_llm_client::Plan::ZedPro, - Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, - }) - .unwrap_or(zed_llm_client::Plan::ZedFree); + let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree); let usage = user_store.model_request_usage()?; @@ -1769,7 +1752,6 @@ impl AgentPreview for MessageEditor { ) -> Option { if let Some(workspace) = workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); - let user_store = workspace.read(cx).app_state().user_store.clone(); let project = workspace.read(cx).project().clone(); let weak_project = project.downgrade(); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); @@ -1782,7 +1764,6 @@ impl AgentPreview for MessageEditor { MessageEditor::new( fs, workspace.downgrade(), - user_store, context_store, None, thread_store.downgrade(), diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs index 91867957cd..bcbc308c99 100644 --- a/crates/agent_ui/src/terminal_inline_assistant.rs +++ b/crates/agent_ui/src/terminal_inline_assistant.rs @@ -10,6 +10,7 @@ use agent::{ use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, VecDeque}; use editor::{MultiBuffer, actions::SelectAll}; use fs::Fs; @@ -27,7 +28,6 @@ use terminal_view::TerminalView; use ui::prelude::*; use util::ResultExt; use workspace::{Toast, Workspace, notifications::NotificationId}; -use zed_llm_client::CompletionIntent; pub fn init( fs: Arc, diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 3df0a48aa4..4836a95c8e 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -12,7 +12,7 @@ use assistant_slash_commands::{ use client::{proto, zed_urls}; use collections::{BTreeSet, HashMap, HashSet, hash_map}; use editor::{ - Anchor, Editor, EditorEvent, MenuInlineCompletionsPolicy, MultiBuffer, MultiBufferSnapshot, + Anchor, Editor, EditorEvent, MenuEditPredictionsPolicy, MultiBuffer, MultiBufferSnapshot, RowExt, ToOffset as _, ToPoint, actions::{MoveToEndOfLine, Newline, ShowCompletions}, display_map::{ @@ -254,7 +254,7 @@ impl TextThreadEditor { editor.set_show_wrap_guides(false, cx); editor.set_show_indent_guides(false, cx); editor.set_completion_provider(Some(Rc::new(completion_provider))); - editor.set_menu_inline_completions_policy(MenuInlineCompletionsPolicy::Never); + editor.set_menu_edit_predictions_policy(MenuEditPredictionsPolicy::Never); editor.set_collaboration_hub(Box::new(project.clone())); let show_edit_predictions = all_language_settings(None, cx) diff --git a/crates/agent_ui/src/thread_history.rs b/crates/agent_ui/src/thread_history.rs index a2ee816f73..b8d1db88d6 100644 --- a/crates/agent_ui/src/thread_history.rs +++ b/crates/agent_ui/src/thread_history.rs @@ -701,7 +701,7 @@ impl RenderOnce for HistoryEntryElement { .on_hover(self.on_hover) .end_slot::(if self.hovered || self.selected { Some( - IconButton::new("delete", IconName::TrashAlt) + IconButton::new("delete", IconName::Trash) .shape(IconButtonShape::Square) .icon_size(IconSize::XSmall) .icon_color(Color::Muted) diff --git a/crates/agent_ui/src/ui/preview/usage_callouts.rs b/crates/agent_ui/src/ui/preview/usage_callouts.rs index 45af41395b..64869a6ec7 100644 --- a/crates/agent_ui/src/ui/preview/usage_callouts.rs +++ b/crates/agent_ui/src/ui/preview/usage_callouts.rs @@ -1,8 +1,8 @@ use client::{ModelRequestUsage, RequestUsage, zed_urls}; +use cloud_llm_client::{Plan, UsageLimit}; use component::{empty_example, example_group_with_title, single_example}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use ui::{Callout, prelude::*}; -use zed_llm_client::{Plan, UsageLimit}; #[derive(IntoElement, RegisterComponent)] pub struct UsageCallout { diff --git a/crates/ai_onboarding/Cargo.toml b/crates/ai_onboarding/Cargo.toml index 9031e14e29..95a45b1a6f 100644 --- a/crates/ai_onboarding/Cargo.toml +++ b/crates/ai_onboarding/Cargo.toml @@ -16,10 +16,10 @@ default = [] [dependencies] client.workspace = true +cloud_llm_client.workspace = true component.workspace = true gpui.workspace = true language_model.workspace = true -proto.workspace = true serde.workspace = true smallvec.workspace = true telemetry.workspace = true diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index 5f56e4d26e..e86568fe7a 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -136,10 +136,7 @@ impl RenderOnce for ApiKeysWithoutProviders { .full_width() .style(ButtonStyle::Outlined) .on_click(move |_, window, cx| { - window.dispatch_action( - zed_actions::agent::OpenConfiguration.boxed_clone(), - cx, - ); + window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx); }), ) } diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index e8a62f7ff2..f1629eeff8 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use client::{Client, UserStore}; +use cloud_llm_client::Plan; use gpui::{Entity, IntoElement, ParentElement}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use ui::prelude::*; @@ -56,15 +57,8 @@ impl AgentPanelOnboarding { impl Render for AgentPanelOnboarding { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let enrolled_in_trial = matches!( - self.user_store.read(cx).current_plan(), - Some(proto::Plan::ZedProTrial) - ); - - let is_pro_user = matches!( - self.user_store.read(cx).current_plan(), - Some(proto::Plan::ZedPro) - ); + let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial); + let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro); AgentPanelOnboardingCard::new() .child( diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs index 3aec9c62cd..c252b65f20 100644 --- a/crates/ai_onboarding/src/ai_onboarding.rs +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -9,6 +9,7 @@ pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProvider pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; pub use agent_panel_onboarding_content::AgentPanelOnboarding; pub use ai_upsell_card::AiUpsellCard; +use cloud_llm_client::Plan; pub use edit_prediction_onboarding_content::EditPredictionOnboarding; pub use young_account_banner::YoungAccountBanner; @@ -79,7 +80,7 @@ impl From for SignInStatus { pub struct ZedAiOnboarding { pub sign_in_status: SignInStatus, pub has_accepted_terms_of_service: bool, - pub plan: Option, + pub plan: Option, pub account_too_young: bool, pub continue_with_zed_ai: Arc, pub sign_in: Arc, @@ -99,8 +100,8 @@ impl ZedAiOnboarding { Self { sign_in_status: status.into(), - has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false), - plan: store.current_plan(), + has_accepted_terms_of_service: store.has_accepted_terms_of_service(), + plan: store.plan(), account_too_young: store.account_too_young(), continue_with_zed_ai, accept_terms_of_service: Arc::new({ @@ -113,11 +114,9 @@ impl ZedAiOnboarding { sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); - async move |cx| { - client.authenticate_and_connect(true, cx).await; - } + async move |cx| client.sign_in_with_optional_connect(true, cx).await }) - .detach(); + .detach_and_log_err(cx); }), dismiss_onboarding: None, } @@ -411,9 +410,9 @@ impl RenderOnce for ZedAiOnboarding { if matches!(self.sign_in_status, SignInStatus::SignedIn) { if self.has_accepted_terms_of_service { match self.plan { - None | Some(proto::Plan::Free) => self.render_free_plan_state(cx), - Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx), - Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx), + None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), + Some(Plan::ZedProTrial) => self.render_trial_state(cx), + Some(Plan::ZedPro) => self.render_pro_plan_state(cx), } } else { self.render_accept_terms_of_service() @@ -433,7 +432,7 @@ impl Component for ZedAiOnboarding { fn onboarding( sign_in_status: SignInStatus, has_accepted_terms_of_service: bool, - plan: Option, + plan: Option, account_too_young: bool, ) -> AnyElement { ZedAiOnboarding { @@ -468,25 +467,15 @@ impl Component for ZedAiOnboarding { ), single_example( "Free Plan", - onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false), ), single_example( "Pro Trial", - onboarding( - SignInStatus::SignedIn, - true, - Some(proto::Plan::ZedProTrial), - false, - ), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false), ), single_example( "Pro Plan", - onboarding( - SignInStatus::SignedIn, - true, - Some(proto::Plan::ZedPro), - false, - ), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false), ), ]) .into_any_element(), diff --git a/crates/ai_onboarding/src/ai_upsell_card.rs b/crates/ai_onboarding/src/ai_upsell_card.rs index 041e0d87ec..89a782a7c2 100644 --- a/crates/ai_onboarding/src/ai_upsell_card.rs +++ b/crates/ai_onboarding/src/ai_upsell_card.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use client::{Client, zed_urls}; +use cloud_llm_client::Plan; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use ui::{Divider, List, Vector, VectorName, prelude::*}; @@ -10,23 +11,25 @@ use crate::{BulletItem, SignInStatus}; pub struct AiUpsellCard { pub sign_in_status: SignInStatus, pub sign_in: Arc, + pub user_plan: Option, + pub tab_index: Option, } impl AiUpsellCard { - pub fn new(client: Arc) -> Self { + pub fn new(client: Arc, user_plan: Option) -> Self { let status = *client.status().borrow(); Self { + user_plan, sign_in_status: status.into(), sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); - async move |cx| { - client.authenticate_and_connect(true, cx).await; - } + async move |cx| client.sign_in_with_optional_connect(true, cx).await }) - .detach(); + .detach_and_log_err(cx); }), + tab_index: None, } } } @@ -34,6 +37,7 @@ impl AiUpsellCard { impl RenderOnce for AiUpsellCard { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { let pro_section = v_flex() + .flex_grow() .w_full() .gap_1() .child( @@ -56,6 +60,7 @@ impl RenderOnce for AiUpsellCard { ); let free_section = v_flex() + .flex_grow() .w_full() .gap_1() .child( @@ -71,7 +76,7 @@ impl RenderOnce for AiUpsellCard { ) .child( List::new() - .child(BulletItem::new("50 prompts with the Claude models")) + .child(BulletItem::new("50 prompts with Claude models")) .child(BulletItem::new("2,000 accepted edit predictions")), ); @@ -109,7 +114,8 @@ impl RenderOnce for AiUpsellCard { .on_click(move |_, _window, cx| { telemetry::event!("Start Trial Clicked", state = "post-sign-in"); cx.open_url(&zed_urls::start_trial_url(cx)) - }), + }) + .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)), ) .child( Label::new("No credit card required") @@ -120,6 +126,7 @@ impl RenderOnce for AiUpsellCard { _ => Button::new("sign_in", "Sign In") .full_width() .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)) .on_click({ let callback = self.sign_in.clone(); move |_, window, cx| { @@ -132,22 +139,28 @@ impl RenderOnce for AiUpsellCard { v_flex() .relative() - .p_6() - .pt_4() + .p_4() + .pt_3() .border_1() .border_color(cx.theme().colors().border) .rounded_lg() .overflow_hidden() .child(grid_bg) .child(gradient_bg) - .child(Headline::new("Try Zed AI")) - .child(Label::new(DESCRIPTION).color(Color::Muted).mb_2()) + .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .child( + div() + .max_w_3_4() + .mb_2() + .child(Label::new(DESCRIPTION).color(Color::Muted)), + ) .child( h_flex() + .w_full() .mt_1p5() .mb_2p5() .items_start() - .gap_12() + .gap_6() .child(free_section) .child(pro_section), ) @@ -183,6 +196,8 @@ impl Component for AiUpsellCard { AiUpsellCard { sign_in_status: SignInStatus::SignedOut, sign_in: Arc::new(|_, _| {}), + user_plan: None, + tab_index: Some(0), } .into_any_element(), ), @@ -191,6 +206,8 @@ impl Component for AiUpsellCard { AiUpsellCard { sign_in_status: SignInStatus::SignedIn, sign_in: Arc::new(|_, _| {}), + user_plan: None, + tab_index: Some(1), } .into_any_element(), ), diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index c73f606045..3ff1666755 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -36,11 +36,18 @@ pub enum AnthropicModelMode { pub enum Model { #[serde(rename = "claude-opus-4", alias = "claude-opus-4-latest")] ClaudeOpus4, + #[serde(rename = "claude-opus-4-1", alias = "claude-opus-4-1-latest")] + ClaudeOpus4_1, #[serde( rename = "claude-opus-4-thinking", alias = "claude-opus-4-thinking-latest" )] ClaudeOpus4Thinking, + #[serde( + rename = "claude-opus-4-1-thinking", + alias = "claude-opus-4-1-thinking-latest" + )] + ClaudeOpus4_1Thinking, #[default] #[serde(rename = "claude-sonnet-4", alias = "claude-sonnet-4-latest")] ClaudeSonnet4, @@ -91,10 +98,18 @@ impl Model { } pub fn from_id(id: &str) -> Result { + if id.starts_with("claude-opus-4-1-thinking") { + return Ok(Self::ClaudeOpus4_1Thinking); + } + if id.starts_with("claude-opus-4-thinking") { return Ok(Self::ClaudeOpus4Thinking); } + if id.starts_with("claude-opus-4-1") { + return Ok(Self::ClaudeOpus4_1); + } + if id.starts_with("claude-opus-4") { return Ok(Self::ClaudeOpus4); } @@ -141,7 +156,9 @@ impl Model { pub fn id(&self) -> &str { match self { Self::ClaudeOpus4 => "claude-opus-4-latest", + Self::ClaudeOpus4_1 => "claude-opus-4-1-latest", Self::ClaudeOpus4Thinking => "claude-opus-4-thinking-latest", + Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-thinking-latest", Self::ClaudeSonnet4 => "claude-sonnet-4-latest", Self::ClaudeSonnet4Thinking => "claude-sonnet-4-thinking-latest", Self::Claude3_5Sonnet => "claude-3-5-sonnet-latest", @@ -159,6 +176,7 @@ impl Model { pub fn request_id(&self) -> &str { match self { Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking => "claude-opus-4-20250514", + Self::ClaudeOpus4_1 | Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-20250805", Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => "claude-sonnet-4-20250514", Self::Claude3_5Sonnet => "claude-3-5-sonnet-latest", Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => "claude-3-7-sonnet-latest", @@ -173,7 +191,9 @@ impl Model { pub fn display_name(&self) -> &str { match self { Self::ClaudeOpus4 => "Claude Opus 4", + Self::ClaudeOpus4_1 => "Claude Opus 4.1", Self::ClaudeOpus4Thinking => "Claude Opus 4 Thinking", + Self::ClaudeOpus4_1Thinking => "Claude Opus 4.1 Thinking", Self::ClaudeSonnet4 => "Claude Sonnet 4", Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking", Self::Claude3_7Sonnet => "Claude 3.7 Sonnet", @@ -192,7 +212,9 @@ impl Model { pub fn cache_configuration(&self) -> Option { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -215,7 +237,9 @@ impl Model { pub fn max_token_count(&self) -> u64 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -232,7 +256,9 @@ impl Model { pub fn max_output_tokens(&self) -> u64 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -249,7 +275,9 @@ impl Model { pub fn default_temperature(&self) -> f32 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -269,6 +297,7 @@ impl Model { pub fn mode(&self) -> AnthropicModelMode { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeSonnet4 | Self::Claude3_5Sonnet | Self::Claude3_7Sonnet @@ -277,6 +306,7 @@ impl Model { | Self::Claude3Sonnet | Self::Claude3Haiku => AnthropicModelMode::Default, Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4Thinking | Self::Claude3_7SonnetThinking => AnthropicModelMode::Thinking { budget_tokens: Some(4_096), diff --git a/crates/assistant_context/Cargo.toml b/crates/assistant_context/Cargo.toml index f35dc43340..8f5ff98790 100644 --- a/crates/assistant_context/Cargo.toml +++ b/crates/assistant_context/Cargo.toml @@ -19,6 +19,7 @@ assistant_slash_commands.workspace = true chrono.workspace = true client.workspace = true clock.workspace = true +cloud_llm_client.workspace = true collections.workspace = true context_server.workspace = true fs.workspace = true @@ -48,7 +49,6 @@ util.workspace = true uuid.workspace = true workspace-hack.workspace = true workspace.workspace = true -zed_llm_client.workspace = true [dev-dependencies] indoc.workspace = true diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_context/src/assistant_context.rs index 136468e084..4518bbff79 100644 --- a/crates/assistant_context/src/assistant_context.rs +++ b/crates/assistant_context/src/assistant_context.rs @@ -11,6 +11,7 @@ use assistant_slash_command::{ use assistant_slash_commands::FileCommandMetadata; use client::{self, Client, proto, telemetry::Telemetry}; use clock::ReplicaId; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use fs::{Fs, RenameOptions}; use futures::{FutureExt, StreamExt, future::Shared}; @@ -46,7 +47,6 @@ use text::{BufferSnapshot, ToPoint}; use ui::IconName; use util::{ResultExt, TryFutureExt, post_inc}; use uuid::Uuid; -use zed_llm_client::CompletionIntent; pub use crate::context_store::*; diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index 672c048872..025aba060d 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -630,6 +630,11 @@ impl ActionLog { false } }); + if tracked_buffer.unreviewed_edits.is_empty() { + if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status { + tracked_buffer.status = TrackedBufferStatus::Modified; + } + } tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx); } } @@ -775,6 +780,9 @@ impl ActionLog { .retain(|_buffer, tracked_buffer| match tracked_buffer.status { TrackedBufferStatus::Deleted => false, _ => { + if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status { + tracked_buffer.status = TrackedBufferStatus::Modified; + } tracked_buffer.unreviewed_edits.clear(); tracked_buffer.diff_base = tracked_buffer.snapshot.as_rope().clone(); tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx); @@ -2075,6 +2083,134 @@ mod tests { assert_eq!(content, "ai content\nuser added this line"); } + #[gpui::test] + async fn test_reject_after_accepting_hunk_on_created_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/new_file", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path.clone(), cx)) + .await + .unwrap(); + + // AI creates file with initial content + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v1", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User accepts the single hunk + action_log.update(cx, |log, cx| { + log.keep_edits_in_range(buffer.clone(), Anchor::MIN..Anchor::MAX, cx) + }); + cx.run_until_parked(); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + + // AI modifies the file + cx.update(|cx| { + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v2", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User rejects the hunk + action_log + .update(cx, |log, cx| { + log.reject_edits_in_ranges(buffer.clone(), vec![Anchor::MIN..Anchor::MAX], cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await,); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "ai content v1" + ); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + + #[gpui::test] + async fn test_reject_edits_on_previously_accepted_created_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/new_file", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path.clone(), cx)) + .await + .unwrap(); + + // AI creates file with initial content + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v1", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + + // User clicks "Accept All" + action_log.update(cx, |log, cx| log.keep_all_edits(cx)); + cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); // Hunks are cleared + + // AI modifies file again + cx.update(|cx| { + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v2", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User clicks "Reject All" + action_log + .update(cx, |log, cx| log.reject_all_edits(cx)) + .await; + cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "ai content v1" + ); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + #[gpui::test(iterations = 100)] async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 554b3f3f3c..22cbaac3f8 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -216,7 +216,12 @@ pub trait Tool: 'static + Send + Sync { /// Returns true if the tool needs the users's confirmation /// before having permission to run. - fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool; + fn needs_confirmation( + &self, + input: &serde_json::Value, + project: &Entity, + cx: &App, + ) -> bool; /// Returns true if the tool may perform edits. fn may_perform_edits(&self) -> bool; diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 9a6ec49914..c0a358917b 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -375,7 +375,12 @@ mod tests { false } - fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + fn needs_confirmation( + &self, + _input: &serde_json::Value, + _project: &Entity, + _cx: &App, + ) -> bool { true } diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 146800e094..d4b8fa3afc 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -21,9 +21,11 @@ assistant_tool.workspace = true buffer_diff.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true component.workspace = true derive_more.workspace = true +diffy = "0.4.2" editor.workspace = true feature_flags.workspace = true futures.workspace = true @@ -63,8 +65,6 @@ web_search.workspace = true which.workspace = true workspace-hack.workspace = true workspace.workspace = true -zed_llm_client.workspace = true -diffy = "0.4.2" [dev-dependencies] lsp = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs index 1922b5677a..e34ae9ff93 100644 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ b/crates/assistant_tools/src/copy_path_tool.rs @@ -44,7 +44,7 @@ impl Tool for CopyPathTool { "copy_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs index 224e8357e5..11d969d234 100644 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ b/crates/assistant_tools/src/create_directory_tool.rs @@ -37,7 +37,7 @@ impl Tool for CreateDirectoryTool { include_str!("./create_directory_tool/description.md").into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index b13f9863c9..9e69c18b65 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -33,7 +33,7 @@ impl Tool for DeletePathTool { "delete_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 84595a37b7..12ab97f820 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -46,7 +46,7 @@ impl Tool for DiagnosticsTool { "diagnostics".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index 0184dff36c..fed79434bb 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -7,6 +7,7 @@ mod streaming_fuzzy_matcher; use crate::{Template, Templates}; use anyhow::Result; use assistant_tool::ActionLog; +use cloud_llm_client::CompletionIntent; use create_file_parser::{CreateFileParser, CreateFileParserEvent}; pub use edit_parser::EditFormat; use edit_parser::{EditParser, EditParserEvent, EditParserMetrics}; @@ -29,7 +30,6 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task:: use streaming_diff::{CharOperation, StreamingDiff}; use streaming_fuzzy_matcher::StreamingFuzzyMatcher; use util::debug_panic; -use zed_llm_client::CompletionIntent; #[derive(Serialize)] struct CreateFilePromptTemplate { diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index eda7eee0e3..9a8e762455 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -1658,23 +1658,24 @@ impl EditAgentTest { } async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> Result { + const MAX_RETRIES: usize = 20; let mut attempt = 0; + loop { attempt += 1; - match request().await { - Ok(result) => return Ok(result), - Err(err) => match err.downcast::() { - Ok(err) => match &err { + let response = request().await; + + if attempt >= MAX_RETRIES { + return response; + } + + let retry_delay = match &response { + Ok(_) => None, + Err(err) => match err.downcast_ref::() { + Some(err) => match &err { LanguageModelCompletionError::RateLimitExceeded { retry_after, .. } | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => { - let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); - // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time. - let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); - eprintln!( - "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}" - ); - Timer::after(retry_after + jitter).await; - continue; + Some(retry_after.unwrap_or(Duration::from_secs(5))) } LanguageModelCompletionError::UpstreamProviderError { status, @@ -1687,23 +1688,31 @@ async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE ) || status.as_u16() == 529; - if !should_retry { - return Err(err.into()); + if should_retry { + // Use server-provided retry_after if available, otherwise use default + Some(retry_after.unwrap_or(Duration::from_secs(5))) + } else { + None } - - // Use server-provided retry_after if available, otherwise use default - let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); - let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); - eprintln!( - "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}" - ); - Timer::after(retry_after + jitter).await; - continue; } - _ => return Err(err.into()), + LanguageModelCompletionError::ApiReadResponseError { .. } + | LanguageModelCompletionError::ApiInternalServerError { .. } + | LanguageModelCompletionError::HttpSend { .. } => { + // Exponential backoff for transient I/O and internal server errors + Some(Duration::from_secs(2_u64.pow((attempt - 1) as u32).min(30))) + } + _ => None, }, - Err(err) => return Err(err), + _ => None, }, + }; + + if let Some(retry_after) = retry_delay { + let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); + eprintln!("Attempt #{attempt}: Retry after {retry_after:?} + jitter of {jitter:?}"); + Timer::after(retry_after + jitter).await; + } else { + return response; } } } diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 6413677bd9..1c41b26092 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -25,6 +25,7 @@ use language::{ }; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; +use paths; use project::{ Project, ProjectPath, lsp_store::{FormatTrigger, LspFormatTarget}, @@ -126,8 +127,47 @@ impl Tool for EditFileTool { "edit_file".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { - false + fn needs_confirmation( + &self, + input: &serde_json::Value, + project: &Entity, + cx: &App, + ) -> bool { + if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { + return false; + } + + let Ok(input) = serde_json::from_value::(input.clone()) else { + // If it's not valid JSON, it's going to error and confirming won't do anything. + return false; + }; + + // If any path component matches the local settings folder, then this could affect + // the editor in ways beyond the project source, so prompt. + let local_settings_folder = paths::local_settings_folder_relative_path(); + let path = Path::new(&input.path); + if path + .components() + .any(|component| component.as_os_str() == local_settings_folder.as_os_str()) + { + return true; + } + + // It's also possible that the global config dir is configured to be inside the project, + // so check for that edge case too. + if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { + if canonical_path.starts_with(paths::config_dir()) { + return true; + } + } + + // Check if path is inside the global config directory + // First check if it's already inside project - if not, try to canonicalize + let project_path = project.read(cx).find_project_path(&input.path, cx); + + // If the path is inside the project, and it's not one of the above edge cases, + // then no confirmation is necessary. Otherwise, confirmation is necessary. + project_path.is_none() } fn may_perform_edits(&self) -> bool { @@ -148,7 +188,25 @@ impl Tool for EditFileTool { fn ui_text(&self, input: &serde_json::Value) -> String { match serde_json::from_value::(input.clone()) { - Ok(input) => input.display_description, + Ok(input) => { + let path = Path::new(&input.path); + let mut description = input.display_description.clone(); + + // Add context about why confirmation may be needed + let local_settings_folder = paths::local_settings_folder_relative_path(); + if path + .components() + .any(|c| c.as_os_str() == local_settings_folder.as_os_str()) + { + description.push_str(" (local settings)"); + } else if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { + if canonical_path.starts_with(paths::config_dir()) { + description.push_str(" (global settings)"); + } + } + + description + } Err(_) => "Editing file".to_string(), } } @@ -1175,19 +1233,20 @@ async fn build_buffer_diff( #[cfg(test)] mod tests { use super::*; + use ::fs::Fs; use client::TelemetrySettings; - use fs::{FakeFs, Fs}; use gpui::{TestAppContext, UpdateGlobal}; use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; + use std::fs; use util::path; #[gpui::test] async fn test_edit_nonexistent_file(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -1277,7 +1336,7 @@ mod tests { ) -> anyhow::Result { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree( "/root", json!({ @@ -1384,6 +1443,21 @@ mod tests { cx.set_global(settings_store); language::init(cx); TelemetrySettings::register(cx); + agent_settings::AgentSettings::register(cx); + Project::init_settings(cx); + }); + } + + fn init_test_with_config(cx: &mut TestAppContext, data_dir: &Path) { + cx.update(|cx| { + // Set custom data directory (config will be under data_dir/config) + paths::set_custom_data_dir(data_dir.to_str().unwrap()); + + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + TelemetrySettings::register(cx); + agent_settings::AgentSettings::register(cx); Project::init_settings(cx); }); } @@ -1392,7 +1466,7 @@ mod tests { async fn test_format_on_save(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; @@ -1591,7 +1665,7 @@ mod tests { async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; // Create a simple file with trailing whitespace @@ -1723,4 +1797,641 @@ mod tests { "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" ); } + + #[gpui::test] + async fn test_needs_confirmation(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({})).await; + + // Test 1: Path with .zed component should require confirmation + let input_with_zed = json!({ + "display_description": "Edit settings", + "path": ".zed/settings.json", + "mode": "edit" + }); + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_with_zed, &project, cx), + "Path with .zed component should require confirmation" + ); + }); + + // Test 2: Absolute path should require confirmation + let input_absolute = json!({ + "display_description": "Edit file", + "path": "/etc/hosts", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_absolute, &project, cx), + "Absolute path should require confirmation" + ); + }); + + // Test 3: Relative path without .zed should not require confirmation + let input_relative = json!({ + "display_description": "Edit file", + "path": "root/src/main.rs", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_relative, &project, cx), + "Relative path without .zed should not require confirmation" + ); + }); + + // Test 4: Path with .zed in the middle should require confirmation + let input_zed_middle = json!({ + "display_description": "Edit settings", + "path": "root/.zed/tasks.json", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_zed_middle, &project, cx), + "Path with .zed in any component should require confirmation" + ); + }); + + // Test 5: When always_allow_tool_actions is enabled, no confirmation needed + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = true; + agent_settings::AgentSettings::override_global(settings, cx); + + assert!( + !tool.needs_confirmation(&input_with_zed, &project, cx), + "When always_allow_tool_actions is true, no confirmation should be needed" + ); + assert!( + !tool.needs_confirmation(&input_absolute, &project, cx), + "When always_allow_tool_actions is true, no confirmation should be needed for absolute paths" + ); + }); + } + + #[gpui::test] + async fn test_ui_text_shows_correct_context(cx: &mut TestAppContext) { + // Set up a custom config directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + + // Test ui_text shows context for various paths + let test_cases = vec![ + ( + json!({ + "display_description": "Update config", + "path": ".zed/settings.json", + "mode": "edit" + }), + "Update config (local settings)", + ".zed path should show local settings context", + ), + ( + json!({ + "display_description": "Fix bug", + "path": "src/.zed/local.json", + "mode": "edit" + }), + "Fix bug (local settings)", + "Nested .zed path should show local settings context", + ), + ( + json!({ + "display_description": "Update readme", + "path": "README.md", + "mode": "edit" + }), + "Update readme", + "Normal path should not show additional context", + ), + ( + json!({ + "display_description": "Edit config", + "path": "config.zed", + "mode": "edit" + }), + "Edit config", + ".zed as extension should not show context", + ), + ]; + + for (input, expected_text, description) in test_cases { + cx.update(|_cx| { + let ui_text = tool.ui_text(&input); + assert_eq!(ui_text, expected_text, "Failed for case: {}", description); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_outside_project(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create a project in /project directory + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test file outside project requires confirmation + let input_outside = json!({ + "display_description": "Edit file", + "path": "/outside/file.txt", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_outside, &project, cx), + "File outside project should require confirmation" + ); + }); + + // Test file inside project doesn't require confirmation + let input_inside = json!({ + "display_description": "Edit file", + "path": "project/file.txt", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_inside, &project, cx), + "File inside project should not require confirmation" + ); + }); + } + + #[gpui::test] + async fn test_needs_confirmation_config_paths(cx: &mut TestAppContext) { + // Set up a custom data directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/home/user/myproject", json!({})).await; + let project = Project::test(fs.clone(), [path!("/home/user/myproject").as_ref()], cx).await; + + // Get the actual local settings folder name + let local_settings_folder = paths::local_settings_folder_relative_path(); + + // Test various config path patterns + let test_cases = vec![ + ( + format!("{}/settings.json", local_settings_folder.display()), + true, + "Top-level local settings file".to_string(), + ), + ( + format!( + "myproject/{}/settings.json", + local_settings_folder.display() + ), + true, + "Local settings in project path".to_string(), + ), + ( + format!("src/{}/config.toml", local_settings_folder.display()), + true, + "Local settings in subdirectory".to_string(), + ), + ( + ".zed.backup/file.txt".to_string(), + true, + ".zed.backup is outside project".to_string(), + ), + ( + "my.zed/file.txt".to_string(), + true, + "my.zed is outside project".to_string(), + ), + ( + "myproject/src/file.zed".to_string(), + false, + ".zed as file extension".to_string(), + ), + ( + "myproject/normal/path/file.rs".to_string(), + false, + "Normal file without config paths".to_string(), + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_global_config(cx: &mut TestAppContext) { + // Set up a custom data directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create test files in the global config directory + let global_config_dir = paths::config_dir(); + fs::create_dir_all(&global_config_dir).unwrap(); + let global_settings_path = global_config_dir.join("settings.json"); + fs::write(&global_settings_path, "{}").unwrap(); + + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test global config paths + let test_cases = vec![ + ( + global_settings_path.to_str().unwrap().to_string(), + true, + "Global settings file should require confirmation", + ), + ( + global_config_dir + .join("keymap.json") + .to_str() + .unwrap() + .to_string(), + true, + "Global keymap file should require confirmation", + ), + ( + "project/normal_file.rs".to_string(), + false, + "Normal project file should not require confirmation", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {}", + description + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create multiple worktree directories + fs.insert_tree( + "/workspace/frontend", + json!({ + "src": { + "main.js": "console.log('frontend');" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/backend", + json!({ + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/shared", + json!({ + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + + // Create project with multiple worktrees + let project = Project::test( + fs.clone(), + [ + path!("/workspace/frontend").as_ref(), + path!("/workspace/backend").as_ref(), + path!("/workspace/shared").as_ref(), + ], + cx, + ) + .await; + + // Test files in different worktrees + let test_cases = vec![ + ("frontend/src/main.js", false, "File in first worktree"), + ("backend/src/main.rs", false, "File in second worktree"), + ( + "shared/.zed/settings.json", + true, + ".zed file in third worktree", + ), + ("/etc/hosts", true, "Absolute path outside all worktrees"), + ( + "../outside/file.txt", + true, + "Relative path outside worktrees", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".zed": { + "settings.json": "{}" + }, + "src": { + ".zed": { + "local.json": "{}" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test edge cases + let test_cases = vec![ + // Empty path - find_project_path returns Some for empty paths + ("", false, "Empty path is treated as project root"), + // Root directory + ("/", true, "Root directory should be outside project"), + // Parent directory references - find_project_path resolves these + ( + "project/../other", + false, + "Path with .. is resolved by find_project_path", + ), + ( + "project/./src/file.rs", + false, + "Path with . should work normally", + ), + // Windows-style paths (if on Windows) + #[cfg(target_os = "windows")] + ("C:\\Windows\\System32\\hosts", true, "Windows system path"), + #[cfg(target_os = "windows")] + ("project\\src\\main.rs", false, "Windows-style project path"), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_ui_text_with_all_path_types(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + + // Test UI text for various scenarios + let test_cases = vec![ + ( + json!({ + "display_description": "Update config", + "path": ".zed/settings.json", + "mode": "edit" + }), + "Update config (local settings)", + ".zed path should show local settings context", + ), + ( + json!({ + "display_description": "Fix bug", + "path": "src/.zed/local.json", + "mode": "edit" + }), + "Fix bug (local settings)", + "Nested .zed path should show local settings context", + ), + ( + json!({ + "display_description": "Update readme", + "path": "README.md", + "mode": "edit" + }), + "Update readme", + "Normal path should not show additional context", + ), + ( + json!({ + "display_description": "Edit config", + "path": "config.zed", + "mode": "edit" + }), + "Edit config", + ".zed as extension should not show context", + ), + ]; + + for (input, expected_text, description) in test_cases { + cx.update(|_cx| { + let ui_text = tool.ui_text(&input); + assert_eq!(ui_text, expected_text, "Failed for case: {}", description); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "existing.txt": "content", + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test different EditFileMode values + let modes = vec![ + EditFileMode::Edit, + EditFileMode::Create, + EditFileMode::Overwrite, + ]; + + for mode in modes { + // Test .zed path with different modes + let input_zed = json!({ + "display_description": "Edit settings", + "path": "project/.zed/settings.json", + "mode": mode + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_zed, &project, cx), + ".zed path should require confirmation regardless of mode: {:?}", + mode + ); + }); + + // Test outside path with different modes + let input_outside = json!({ + "display_description": "Edit file", + "path": "/outside/file.txt", + "mode": mode + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_outside, &project, cx), + "Outside path should require confirmation regardless of mode: {:?}", + mode + ); + }); + + // Test normal path with different modes + let input_normal = json!({ + "display_description": "Edit file", + "path": "project/normal.txt", + "mode": mode + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_normal, &project, cx), + "Normal path should not require confirmation regardless of mode: {:?}", + mode + ); + }); + } + } + + #[gpui::test] + async fn test_always_allow_tool_actions_bypasses_all_checks(cx: &mut TestAppContext) { + // Set up with custom directories for deterministic testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Enable always_allow_tool_actions + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = true; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + // Test that all paths that normally require confirmation are bypassed + let global_settings_path = paths::config_dir().join("settings.json"); + fs::create_dir_all(paths::config_dir()).unwrap(); + fs::write(&global_settings_path, "{}").unwrap(); + + let test_cases = vec![ + ".zed/settings.json", + "project/.zed/config.toml", + global_settings_path.to_str().unwrap(), + "/etc/hosts", + "/absolute/path/file.txt", + "../outside/project.txt", + ]; + + for path in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input, &project, cx), + "Path {} should not require confirmation when always_allow_tool_actions is true", + path + ); + }); + } + + // Disable always_allow_tool_actions and verify confirmation is required again + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = false; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + // Verify .zed path requires confirmation again + let input = json!({ + "display_description": "Edit file", + "path": ".zed/settings.json", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input, &project, cx), + ".zed path should require confirmation when always_allow_tool_actions is false" + ); + }); + } } diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index 54d49359ba..a31ec39268 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -116,7 +116,7 @@ impl Tool for FetchTool { "fetch".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index fd0e44e42c..affc019417 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -55,7 +55,7 @@ impl Tool for FindPathTool { "find_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index 053273d71b..43c3d1d990 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -57,7 +57,7 @@ impl Tool for GrepTool { "grep".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index 723416e2ce..b1980615d6 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -45,7 +45,7 @@ impl Tool for ListDirectoryTool { "list_directory".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/move_path_tool.rs b/crates/assistant_tools/src/move_path_tool.rs index 27ae10151d..c1cbbf848d 100644 --- a/crates/assistant_tools/src/move_path_tool.rs +++ b/crates/assistant_tools/src/move_path_tool.rs @@ -42,7 +42,7 @@ impl Tool for MovePathTool { "move_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index b6b1cf90a4..b51b91d3d5 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -33,7 +33,7 @@ impl Tool for NowTool { "now".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/open_tool.rs b/crates/assistant_tools/src/open_tool.rs index 97a4769e19..8fddbb0431 100644 --- a/crates/assistant_tools/src/open_tool.rs +++ b/crates/assistant_tools/src/open_tool.rs @@ -23,7 +23,7 @@ impl Tool for OpenTool { "open".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } fn may_perform_edits(&self) -> bool { diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs index 7567926dca..03487e5419 100644 --- a/crates/assistant_tools/src/project_notifications_tool.rs +++ b/crates/assistant_tools/src/project_notifications_tool.rs @@ -19,7 +19,7 @@ impl Tool for ProjectNotificationsTool { "project_notifications".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } fn may_perform_edits(&self) -> bool { diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index dc504e2dc4..ee38273cc0 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -54,7 +54,7 @@ impl Tool for ReadFileTool { "read_file".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 03e76f6a5b..58833c5208 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -77,7 +77,7 @@ impl Tool for TerminalTool { Self::NAME.to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index 422204f97d..443c2930be 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -24,7 +24,7 @@ impl Tool for ThinkingTool { "thinking".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index 24bc8e9cba..d4a12f22c5 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -6,6 +6,7 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; +use cloud_llm_client::{WebSearchResponse, WebSearchResult}; use futures::{Future, FutureExt, TryFutureExt}; use gpui::{ AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, @@ -17,7 +18,6 @@ use serde::{Deserialize, Serialize}; use ui::{IconName, Tooltip, prelude::*}; use web_search::WebSearchRegistry; use workspace::Workspace; -use zed_llm_client::{WebSearchResponse, WebSearchResult}; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct WebSearchToolInput { @@ -32,7 +32,7 @@ impl Tool for WebSearchTool { "web_search".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index 960aaf8e08..d857a3eb2f 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -18,6 +18,6 @@ collections.workspace = true derive_more.workspace = true gpui.workspace = true parking_lot.workspace = true -rodio = { version = "0.20.0", default-features = false, features = ["wav"] } +rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/audio/src/assets.rs b/crates/audio/src/assets.rs index 02da79dc24..fd5c935d87 100644 --- a/crates/audio/src/assets.rs +++ b/crates/audio/src/assets.rs @@ -3,12 +3,9 @@ use std::{io::Cursor, sync::Arc}; use anyhow::{Context as _, Result}; use collections::HashMap; use gpui::{App, AssetSource, Global}; -use rodio::{ - Decoder, Source, - source::{Buffered, SamplesConverter}, -}; +use rodio::{Decoder, Source, source::Buffered}; -type Sound = Buffered>>, f32>>; +type Sound = Buffered>>>; pub struct SoundRegistry { cache: Arc>>, @@ -48,7 +45,7 @@ impl SoundRegistry { .with_context(|| format!("No asset available for path {path}"))?? .into_owned(); let cursor = Cursor::new(bytes); - let source = Decoder::new(cursor)?.convert_samples::().buffered(); + let source = Decoder::new(cursor)?.buffered(); self.cache.lock().insert(name.to_string(), source.clone()); diff --git a/crates/audio/src/audio.rs b/crates/audio/src/audio.rs index e7b9a59e8f..44baa16aa2 100644 --- a/crates/audio/src/audio.rs +++ b/crates/audio/src/audio.rs @@ -1,7 +1,7 @@ use assets::SoundRegistry; use derive_more::{Deref, DerefMut}; use gpui::{App, AssetSource, BorrowAppContext, Global}; -use rodio::{OutputStream, OutputStreamHandle}; +use rodio::{OutputStream, OutputStreamBuilder}; use util::ResultExt; mod assets; @@ -37,8 +37,7 @@ impl Sound { #[derive(Default)] pub struct Audio { - _output_stream: Option, - output_handle: Option, + output_handle: Option, } #[derive(Deref, DerefMut)] @@ -51,11 +50,9 @@ impl Audio { Self::default() } - fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> { + fn ensure_output_exists(&mut self) -> Option<&OutputStream> { if self.output_handle.is_none() { - let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip(); - self.output_handle = output_handle; - self._output_stream = _output_stream; + self.output_handle = OutputStreamBuilder::open_default_stream().log_err(); } self.output_handle.as_ref() @@ -69,7 +66,7 @@ impl Audio { cx.update_global::(|this, cx| { let output_handle = this.ensure_output_exists()?; let source = SoundRegistry::global(cx).get(sound.file()).log_err()?; - output_handle.play_raw(source).log_err()?; + output_handle.mixer().add(source); Some(()) }); } @@ -80,7 +77,6 @@ impl Audio { } cx.update_global::(|this, _| { - this._output_stream.take(); this.output_handle.take(); }); } diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index b6eeafa2d6..69d2ffb845 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -32,11 +32,18 @@ pub enum Model { ClaudeSonnet4Thinking, #[serde(rename = "claude-opus-4", alias = "claude-opus-4-latest")] ClaudeOpus4, + #[serde(rename = "claude-opus-4-1", alias = "claude-opus-4-1-latest")] + ClaudeOpus4_1, #[serde( rename = "claude-opus-4-thinking", alias = "claude-opus-4-thinking-latest" )] ClaudeOpus4Thinking, + #[serde( + rename = "claude-opus-4-1-thinking", + alias = "claude-opus-4-1-thinking-latest" + )] + ClaudeOpus4_1Thinking, #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")] Claude3_5SonnetV2, #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")] @@ -147,7 +154,9 @@ impl Model { Model::ClaudeSonnet4 => "claude-4-sonnet", Model::ClaudeSonnet4Thinking => "claude-4-sonnet-thinking", Model::ClaudeOpus4 => "claude-4-opus", + Model::ClaudeOpus4_1 => "claude-4-opus-1", Model::ClaudeOpus4Thinking => "claude-4-opus-thinking", + Model::ClaudeOpus4_1Thinking => "claude-4-opus-1-thinking", Model::Claude3_5SonnetV2 => "claude-3-5-sonnet-v2", Model::Claude3_5Sonnet => "claude-3-5-sonnet", Model::Claude3Opus => "claude-3-opus", @@ -208,6 +217,9 @@ impl Model { Model::ClaudeOpus4 | Model::ClaudeOpus4Thinking => { "anthropic.claude-opus-4-20250514-v1:0" } + Model::ClaudeOpus4_1 | Model::ClaudeOpus4_1Thinking => { + "anthropic.claude-opus-4-1-20250805-v1:0" + } Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0", Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0", Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0", @@ -266,7 +278,9 @@ impl Model { Self::ClaudeSonnet4 => "Claude Sonnet 4", Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking", Self::ClaudeOpus4 => "Claude Opus 4", + Self::ClaudeOpus4_1 => "Claude Opus 4.1", Self::ClaudeOpus4Thinking => "Claude Opus 4 Thinking", + Self::ClaudeOpus4_1Thinking => "Claude Opus 4.1 Thinking", Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2", Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", Self::Claude3Opus => "Claude 3 Opus", @@ -330,8 +344,10 @@ impl Model { | Self::Claude3_7Sonnet | Self::ClaudeSonnet4 | Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeSonnet4Thinking - | Self::ClaudeOpus4Thinking => 200_000, + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking => 200_000, Self::AmazonNovaPremier => 1_000_000, Self::PalmyraWriterX5 => 1_000_000, Self::PalmyraWriterX4 => 128_000, @@ -348,7 +364,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Model::ClaudeOpus4Thinking => 128_000, + | Model::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Model::ClaudeOpus4_1Thinking => 128_000, Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192, Self::Custom { max_output_tokens, .. @@ -366,6 +384,8 @@ impl Model { | Self::Claude3_7Sonnet | Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => 1.0, Self::Custom { @@ -387,6 +407,8 @@ impl Model { | Self::Claude3_7SonnetThinking | Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Haiku => true, @@ -420,7 +442,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Self::ClaudeOpus4Thinking => true, + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking => true, // Custom models - check if they have cache configuration Self::Custom { @@ -440,7 +464,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration { + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking => Some(BedrockModelCacheConfiguration { max_cache_anchors: 4, min_total_token: 1024, }), @@ -467,9 +493,11 @@ impl Model { Model::ClaudeSonnet4Thinking => BedrockModelMode::Thinking { budget_tokens: Some(4096), }, - Model::ClaudeOpus4Thinking => BedrockModelMode::Thinking { - budget_tokens: Some(4096), - }, + Model::ClaudeOpus4Thinking | Model::ClaudeOpus4_1Thinking => { + BedrockModelMode::Thinking { + budget_tokens: Some(4096), + } + } _ => BedrockModelMode::Default, } } @@ -518,6 +546,8 @@ impl Model { | Model::ClaudeSonnet4Thinking | Model::ClaudeOpus4 | Model::ClaudeOpus4Thinking + | Model::ClaudeOpus4_1 + | Model::ClaudeOpus4_1Thinking | Model::Claude3Haiku | Model::Claude3Opus | Model::Claude3Sonnet diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 866e3ccd90..4ac37ffd14 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -13,7 +13,7 @@ use std::{ ops::{ControlFlow, Range}, sync::Arc, }; -use sum_tree::{Bias, SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; use time::OffsetDateTime; use util::{ResultExt as _, TryFutureExt, post_inc}; @@ -331,7 +331,9 @@ impl ChannelChat { .update(&mut cx, |chat, cx| { if let Some(first_id) = chat.first_loaded_message_id() { if first_id <= message_id { - let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(&()); + let mut cursor = chat + .messages + .cursor::>(&()); let message_id = ChannelMessageId::Saved(message_id); cursor.seek(&message_id, Bias::Left); return ControlFlow::Break( @@ -587,7 +589,9 @@ impl ChannelChat { .map(|m| m.nonce) .collect::>(); - let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>(&()); + let mut old_cursor = self + .messages + .cursor::>(&()); let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left); let start_ix = old_cursor.start().1.0; let removed_messages = old_cursor.slice(&last_message.id, Bias::Right); diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index b7ba811421..4ad156b9fb 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -126,7 +126,7 @@ impl ChannelMembership { proto::channel_member::Kind::Member => 0, proto::channel_member::Kind::Invitee => 1, }, - username_order: self.user.github_login.as_str(), + username_order: &self.user.github_login, } } } diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index f8f5de3c39..c92226eeeb 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -259,20 +259,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) { assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx); }); - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![5]); - server.respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 5, - github_login: "nathansobo".into(), - avatar_url: "http://avatar.com/nathansobo".into(), - name: None, - }], - }, - ); - // Join a channel and populate its existing messages. let channel = channel_store.update(cx, |store, cx| { let channel_id = store.ordered_channels().next().unwrap().1.id; @@ -334,7 +320,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "a".into()), + ("user-5".into(), "a".into()), ("maxbrunsfeld".into(), "b".into()) ] ); @@ -437,7 +423,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "y".into()), + ("user-5".into(), "y".into()), ("maxbrunsfeld".into(), "z".into()) ] ); diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index b741f515fd..365625b445 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -17,11 +17,12 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup [dependencies] anyhow.workspace = true -async-recursion = "0.3" async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] } base64.workspace = true chrono = { workspace = true, features = ["serde"] } clock.workspace = true +cloud_api_client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true credentials_provider.workspace = true derive_more.workspace = true @@ -33,8 +34,8 @@ http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" log.workspace = true -paths.workspace = true parking_lot.workspace = true +paths.workspace = true postage.workspace = true rand.workspace = true regex.workspace = true @@ -46,19 +47,18 @@ serde_json.workspace = true settings.workspace = true sha2.workspace = true smol.workspace = true +telemetry.workspace = true telemetry_events.workspace = true text.workspace = true thiserror.workspace = true time.workspace = true tiny_http.workspace = true tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] } +tokio.workspace = true url.workspace = true util.workspace = true -worktree.workspace = true -telemetry.workspace = true -tokio.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true +worktree.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 81bb95b514..b4894cddcf 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -6,22 +6,21 @@ pub mod telemetry; pub mod user; pub mod zed_urls; -use anyhow::{Context as _, Result, anyhow, bail}; -use async_recursion::async_recursion; +use anyhow::{Context as _, Result, anyhow}; use async_tungstenite::tungstenite::{ client::IntoClientRequest, error::Error as WebsocketError, http::{HeaderValue, Request, StatusCode}, }; -use chrono::{DateTime, Utc}; use clock::SystemClock; +use cloud_api_client::CloudApiClient; use credentials_provider::CredentialsProvider; use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, channel::oneshot, future::BoxFuture, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; -use http_client::{AsyncBody, HttpClient, HttpClientWithUrl}; +use http_client::{HttpClient, HttpClientWithUrl, http}; use parking_lot::RwLock; use postage::watch; use proxy::connect_proxy_stream; @@ -31,7 +30,6 @@ use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use std::pin::Pin; use std::{ any::TypeId, convert::TryFrom, @@ -45,6 +43,7 @@ use std::{ }, time::{Duration, Instant}, }; +use std::{cmp, pin::Pin}; use telemetry::Telemetry; use thiserror::Error; use tokio::net::TcpStream; @@ -78,7 +77,7 @@ pub static ZED_ALWAYS_ACTIVE: LazyLock = LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty())); pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500); -pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10); +pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30); pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20); actions!( @@ -151,7 +150,6 @@ impl Settings for ProxySettings { pub fn init_settings(cx: &mut App) { TelemetrySettings::register(cx); - DisableAiSettings::register(cx); ClientSettings::register(cx); ProxySettings::register(cx); } @@ -162,20 +160,8 @@ pub fn init(client: &Arc, cx: &mut App) { let client = client.clone(); move |_: &SignIn, cx| { if let Some(client) = client.upgrade() { - cx.spawn( - async move |cx| match client.authenticate_and_connect(true, &cx).await { - ConnectionResult::Timeout => { - log::error!("Initial authentication timed out"); - } - ConnectionResult::ConnectionReset => { - log::error!("Initial authentication connection reset"); - } - ConnectionResult::Result(r) => { - r.log_err(); - } - }, - ) - .detach(); + cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await) + .detach_and_log_err(cx); } } }); @@ -213,6 +199,7 @@ pub struct Client { id: AtomicU64, peer: Arc, http: Arc, + cloud_client: Arc, telemetry: Arc, credentials_provider: ClientCredentialsProvider, state: RwLock, @@ -283,6 +270,8 @@ pub enum Status { SignedOut, UpgradeRequired, Authenticating, + Authenticated, + AuthenticationError, Connecting, ConnectionError, Connected { @@ -549,33 +538,6 @@ impl settings::Settings for TelemetrySettings { } } -/// Whether to disable all AI features in Zed. -/// -/// Default: false -#[derive(Copy, Clone, Debug)] -pub struct DisableAiSettings { - pub disable_ai: bool, -} - -impl settings::Settings for DisableAiSettings { - const KEY: Option<&'static str> = Some("disable_ai"); - - type FileContent = Option; - - fn load(sources: SettingsSources, _: &mut App) -> Result { - Ok(Self { - disable_ai: sources - .user - .or(sources.server) - .copied() - .flatten() - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?), - }) - } - - fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} -} - impl Client { pub fn new( clock: Arc, @@ -586,6 +548,7 @@ impl Client { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), + cloud_client: Arc::new(CloudApiClient::new(http.clone())), http, credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), @@ -618,6 +581,10 @@ impl Client { self.http.clone() } + pub fn cloud_client(&self) -> Arc { + self.cloud_client.clone() + } + pub fn set_id(&self, id: u64) -> &Self { self.id.store(id, Ordering::SeqCst); self @@ -704,7 +671,7 @@ impl Client { let mut delay = INITIAL_RECONNECTION_DELAY; loop { - match client.authenticate_and_connect(true, &cx).await { + match client.connect(true, &cx).await { ConnectionResult::Timeout => { log::error!("client connect attempt timed out") } @@ -720,18 +687,20 @@ impl Client { } } - if matches!(*client.status().borrow(), Status::ConnectionError) { + if matches!( + *client.status().borrow(), + Status::AuthenticationError | Status::ConnectionError + ) { client.set_status( Status::ReconnectionError { next_reconnection: Instant::now() + delay, }, &cx, ); - cx.background_executor().timer(delay).await; - delay = delay - .mul_f32(rng.gen_range(0.5..=2.5)) - .max(INITIAL_RECONNECTION_DELAY) - .min(MAX_RECONNECTION_DELAY); + let jitter = + Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64)); + cx.background_executor().timer(delay + jitter).await; + delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY); } else { break; } @@ -875,17 +844,127 @@ impl Client { .is_some() } - #[async_recursion(?Send)] - pub async fn authenticate_and_connect( + pub async fn sign_in( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result { + if self.status().borrow().is_signed_out() { + self.set_status(Status::Authenticating, cx); + } else { + self.set_status(Status::Reauthenticating, cx); + } + + let mut credentials = None; + + let old_credentials = self.state.read().credentials.clone(); + if let Some(old_credentials) = old_credentials { + if self.validate_credentials(&old_credentials, cx).await? { + credentials = Some(old_credentials); + } + } + + if credentials.is_none() && try_provider { + if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await { + if self.validate_credentials(&stored_credentials, cx).await? { + credentials = Some(stored_credentials); + } else { + self.credentials_provider + .delete_credentials(cx) + .await + .log_err(); + } + } + } + + if credentials.is_none() { + let mut status_rx = self.status(); + let _ = status_rx.next().await; + futures::select_biased! { + authenticate = self.authenticate(cx).fuse() => { + match authenticate { + Ok(creds) => { + if IMPERSONATE_LOGIN.is_none() { + self.credentials_provider + .write_credentials(creds.user_id, creds.access_token.clone(), cx) + .await + .log_err(); + } + + credentials = Some(creds); + }, + Err(err) => { + self.set_status(Status::AuthenticationError, cx); + return Err(err); + } + } + } + _ = status_rx.next().fuse() => { + return Err(anyhow!("authentication canceled")); + } + } + } + + let credentials = credentials.unwrap(); + self.set_id(credentials.user_id); + self.cloud_client + .set_credentials(credentials.user_id as u32, credentials.access_token.clone()); + self.state.write().credentials = Some(credentials.clone()); + self.set_status(Status::Authenticated, cx); + + Ok(credentials) + } + + async fn validate_credentials( + self: &Arc, + credentials: &Credentials, + cx: &AsyncApp, + ) -> Result { + match self + .cloud_client + .validate_credentials(credentials.user_id as u32, &credentials.access_token) + .await + { + Ok(valid) => Ok(valid), + Err(err) => { + self.set_status(Status::AuthenticationError, cx); + Err(anyhow!("failed to validate credentials: {}", err)) + } + } + } + + /// Performs a sign-in and also connects to Collab. + /// + /// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls + /// to `sign_in` when we're ready to remove auto-connection to Collab. + pub async fn sign_in_with_optional_connect( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result<()> { + let credentials = self.sign_in(try_provider, cx).await?; + + let connect_result = match self.connect_with_credentials(credentials, cx).await { + ConnectionResult::Timeout => Err(anyhow!("connection timed out")), + ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")), + ConnectionResult::Result(result) => result.context("client auth and connect"), + }; + connect_result.log_err(); + + Ok(()) + } + + pub async fn connect( self: &Arc, try_provider: bool, cx: &AsyncApp, ) -> ConnectionResult<()> { let was_disconnected = match *self.status().borrow() { - Status::SignedOut => true, + Status::SignedOut | Status::Authenticated => true, Status::ConnectionError | Status::ConnectionLost | Status::Authenticating { .. } + | Status::AuthenticationError | Status::Reauthenticating { .. } | Status::ReconnectionError { .. } => false, Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { @@ -898,39 +977,10 @@ impl Client { ); } }; - if was_disconnected { - self.set_status(Status::Authenticating, cx); - } else { - self.set_status(Status::Reauthenticating, cx) - } - - let mut read_from_provider = false; - let mut credentials = self.state.read().credentials.clone(); - if credentials.is_none() && try_provider { - credentials = self.credentials_provider.read_credentials(cx).await; - read_from_provider = credentials.is_some(); - } - - if credentials.is_none() { - let mut status_rx = self.status(); - let _ = status_rx.next().await; - futures::select_biased! { - authenticate = self.authenticate(cx).fuse() => { - match authenticate { - Ok(creds) => credentials = Some(creds), - Err(err) => { - self.set_status(Status::ConnectionError, cx); - return ConnectionResult::Result(Err(err)); - } - } - } - _ = status_rx.next().fuse() => { - return ConnectionResult::Result(Err(anyhow!("authentication canceled"))); - } - } - } - let credentials = credentials.unwrap(); - self.set_id(credentials.user_id); + let credentials = match self.sign_in(try_provider, cx).await { + Ok(credentials) => credentials, + Err(err) => return ConnectionResult::Result(Err(err)), + }; if was_disconnected { self.set_status(Status::Connecting, cx); @@ -938,17 +988,20 @@ impl Client { self.set_status(Status::Reconnecting, cx); } + self.connect_with_credentials(credentials, cx).await + } + + async fn connect_with_credentials( + self: &Arc, + credentials: Credentials, + cx: &AsyncApp, + ) -> ConnectionResult<()> { let mut timeout = futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT)); futures::select_biased! { connection = self.establish_connection(&credentials, cx).fuse() => { match connection { Ok(conn) => { - self.state.write().credentials = Some(credentials.clone()); - if !read_from_provider && IMPERSONATE_LOGIN.is_none() { - self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); - } - futures::select_biased! { result = self.set_connection(conn, cx).fuse() => { match result.context("client auth and connect") { @@ -966,15 +1019,8 @@ impl Client { } } Err(EstablishConnectionError::Unauthorized) => { - self.state.write().credentials.take(); - if read_from_provider { - self.credentials_provider.delete_credentials(cx).await.log_err(); - self.set_status(Status::SignedOut, cx); - self.authenticate_and_connect(false, cx).await - } else { - self.set_status(Status::ConnectionError, cx); - ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) - } + self.set_status(Status::ConnectionError, cx); + ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) } Err(EstablishConnectionError::UpgradeRequired) => { self.set_status(Status::UpgradeRequired, cx); @@ -1138,7 +1184,7 @@ impl Client { .to_str() .map_err(EstablishConnectionError::other)? .to_string(); - Url::parse(&collab_url).with_context(|| format!("parsing colab rpc url {collab_url}")) + Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}")) } } @@ -1158,6 +1204,7 @@ impl Client { let http = self.http.clone(); let proxy = http.proxy().cloned(); + let user_agent = http.user_agent().cloned(); let credentials = credentials.clone(); let rpc_url = self.rpc_url(http, release_channel); let system_id = self.telemetry.system_id(); @@ -1209,7 +1256,7 @@ impl Client { // We then modify the request to add our desired headers. let request_headers = request.headers_mut(); request_headers.insert( - "Authorization", + http::header::AUTHORIZATION, HeaderValue::from_str(&credentials.authorization_header())?, ); request_headers.insert( @@ -1221,6 +1268,9 @@ impl Client { "x-zed-release-channel", HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?, ); + if let Some(user_agent) = user_agent { + request_headers.insert(http::header::USER_AGENT, user_agent); + } if let Some(system_id) = system_id { request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?); } @@ -1365,96 +1415,31 @@ impl Client { self: &Arc, http: Arc, login: String, - mut api_token: String, + api_token: String, ) -> Result { - #[derive(Deserialize)] - struct AuthenticatedUserResponse { - user: User, + #[derive(Serialize)] + struct ImpersonateUserBody { + github_login: String, } #[derive(Deserialize)] - struct User { - id: u64, + struct ImpersonateUserResponse { + user_id: u64, + access_token: String, } - let github_user = { - #[derive(Deserialize)] - struct GithubUser { - id: i32, - login: String, - created_at: DateTime, - } - - let request = { - let mut request_builder = - Request::get(&format!("https://api.github.com/users/{login}")); - if let Ok(github_token) = std::env::var("GITHUB_TOKEN") { - request_builder = - request_builder.header("Authorization", format!("Bearer {}", github_token)); - } - - request_builder.body(AsyncBody::empty())? - }; - - let mut response = http - .send(request) - .await - .context("error fetching GitHub user")?; - - let mut body = Vec::new(); - response - .body_mut() - .read_to_end(&mut body) - .await - .context("error reading GitHub user")?; - - if !response.status().is_success() { - let text = String::from_utf8_lossy(body.as_slice()); - bail!( - "status error {}, response: {text:?}", - response.status().as_u16() - ); - } - - serde_json::from_slice::(body.as_slice()).map_err(|err| { - log::error!("Error deserializing: {:?}", err); - log::error!( - "GitHub API response text: {:?}", - String::from_utf8_lossy(body.as_slice()) - ); - anyhow!("error deserializing GitHub user") - })? - }; - - let query_params = [ - ("github_login", &github_user.login), - ("github_user_id", &github_user.id.to_string()), - ( - "github_user_created_at", - &github_user.created_at.to_rfc3339(), - ), - ]; - - // Use the collab server's admin API to retrieve the ID - // of the impersonated user. - let mut url = self.rpc_url(http.clone(), None).await?; - url.set_path("/user"); - url.set_query(Some( - &query_params - .iter() - .map(|(key, value)| { - format!( - "{}={}", - key, - url::form_urlencoded::byte_serialize(value.as_bytes()).collect::() - ) - }) - .collect::>() - .join("&"), - )); - let request: http_client::Request = Request::get(url.as_str()) - .header("Authorization", format!("token {api_token}")) - .body("".into())?; + let url = self + .http + .build_zed_cloud_url("/internal/users/impersonate", &[])?; + let request = Request::post(url.as_str()) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {api_token}")) + .body( + serde_json::to_string(&ImpersonateUserBody { + github_login: login, + })? + .into(), + )?; let mut response = http.send(request).await?; let mut body = String::new(); @@ -1465,18 +1450,17 @@ impl Client { response.status().as_u16(), body, ); - let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; + let response: ImpersonateUserResponse = serde_json::from_str(&body)?; - // Use the admin API token to authenticate as the impersonated user. - api_token.insert_str(0, "ADMIN_TOKEN:"); Ok(Credentials { - user_id: response.user.id, - access_token: api_token, + user_id: response.user_id, + access_token: response.access_token, }) } pub async fn sign_out(self: &Arc, cx: &AsyncApp) { self.state.write().credentials = None; + self.cloud_client.clear_credentials(); self.disconnect(cx); if self.has_credentials(cx).await { @@ -1705,7 +1689,7 @@ pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> { #[cfg(test)] mod tests { use super::*; - use crate::test::FakeServer; + use crate::test::{FakeServer, parse_authorization_header}; use clock::FakeSystemClock; use gpui::{AppContext as _, BackgroundExecutor, TestAppContext}; @@ -1756,6 +1740,46 @@ mod tests { assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } + #[gpui::test(iterations = 10)] + async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) { + init_test(cx); + let http_client = FakeHttpClient::with_200_response(); + let client = + cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx)); + let server = FakeServer::for_client(42, &client, cx).await; + let mut status = client.status(); + assert!(matches!( + status.next().await, + Some(Status::Connected { .. }) + )); + assert_eq!(server.auth_count(), 1); + + // Simulate an auth failure during reconnection. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(503) + .body("".into()) + .unwrap()) + }); + server.disconnect(); + while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} + + // Restore the ability to authenticate. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + }); + cx.executor().advance_clock(Duration::from_secs(10)); + while !matches!(status.next().await, Some(Status::Connected { .. })) {} + assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting + } + #[gpui::test(iterations = 10)] async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx); @@ -1786,7 +1810,7 @@ mod tests { }); let auth_and_connect = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert!(matches!(status.next().await, Some(Status::Connecting))); @@ -1831,6 +1855,75 @@ mod tests { )); } + #[gpui::test(iterations = 10)] + async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) { + init_test(cx); + let auth_count = Arc::new(Mutex::new(0)); + let http_client = FakeHttpClient::create(|_request| async move { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + }); + let client = + cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx)); + client.override_authenticate({ + let auth_count = auth_count.clone(); + move |cx| { + let auth_count = auth_count.clone(); + cx.background_spawn(async move { + *auth_count.lock() += 1; + Ok(Credentials { + user_id: 1, + access_token: auth_count.lock().to_string(), + }) + }) + } + }); + + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(credentials.access_token, "1"); + + // If credentials are still valid, signing in doesn't trigger authentication. + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(credentials.access_token, "1"); + + // If the server is unavailable, signing in doesn't trigger authentication. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(503) + .body("".into()) + .unwrap()) + }); + client.sign_in(false, &cx.to_async()).await.unwrap_err(); + assert_eq!(*auth_count.lock(), 1); + + // If credentials became invalid, signing in triggers authentication. + http_client + .as_fake() + .replace_handler(|_, request| async move { + let credentials = parse_authorization_header(&request).unwrap(); + if credentials.access_token == "2" { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + } else { + Ok(http_client::Response::builder() + .status(401) + .body("".into()) + .unwrap()) + } + }); + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 2); + assert_eq!(credentials.access_token, "2"); + } + #[gpui::test(iterations = 10)] async fn test_authenticating_more_than_once( cx: &mut TestAppContext, @@ -1863,7 +1956,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - move |cx| async move { client.authenticate_and_connect(false, &cx).await } + move |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 1); @@ -1871,7 +1964,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 2); diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 7d39464e4a..43a1a0b7a4 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -74,6 +74,12 @@ static ZED_CLIENT_CHECKSUM_SEED: LazyLock>> = LazyLock::new(|| { }) }); +pub static MINIDUMP_ENDPOINT: LazyLock> = LazyLock::new(|| { + option_env!("ZED_MINIDUMP_ENDPOINT") + .map(|s| s.to_owned()) + .or_else(|| env::var("ZED_MINIDUMP_ENDPOINT").ok()) +}); + static DOTNET_PROJECT_FILES_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"^(global\.json|Directory\.Build\.props|.*\.(csproj|fsproj|vbproj|sln))$").unwrap() }); diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 6ce79fa9c5..439fb100d2 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,8 +1,11 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; use chrono::Duration; +use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; +use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; +use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; use rpc::{ ConnectionId, Peer, Receipt, TypedEnvelope, @@ -39,6 +42,44 @@ impl FakeServer { executor: cx.executor(), }; + client.http_client().as_fake().replace_handler({ + let state = server.state.clone(); + move |old_handler, req| { + let state = state.clone(); + let old_handler = old_handler.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: client_user_id, + access_token: state.lock().access_token.to_string(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + client_user_id as i32, + format!("user-{client_user_id}"), + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => old_handler(req).await, + } + } + } + }); client .override_authenticate({ let state = Arc::downgrade(&server.state); @@ -105,7 +146,7 @@ impl FakeServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -223,3 +264,54 @@ impl Drop for FakeServer { self.disconnect(); } } + +pub fn parse_authorization_header(req: &Request) -> Option { + let mut auth_header = req + .headers() + .get(http::header::AUTHORIZATION)? + .to_str() + .ok()? + .split_whitespace(); + let user_id = auth_header.next()?.parse().ok()?; + let access_token = auth_header.next()?; + Some(Credentials { + user_id, + access_token: access_token.to_string(), + }) +} + +pub fn make_get_authenticated_user_response( + user_id: i32, + github_login: String, +) -> GetAuthenticatedUserResponse { + GetAuthenticatedUserResponse { + user: AuthenticatedUser { + id: user_id, + metrics_id: format!("metrics-id-{user_id}"), + avatar_url: "".to_string(), + github_login, + name: None, + is_staff: false, + accepted_tos_at: None, + }, + feature_flags: vec![], + plan: PlanInfo { + plan: Plan::ZedPro, + subscription_period: None, + usage: CurrentUsage { + model_requests: UsageData { + used: 0, + limit: UsageLimit::Limited(500), + }, + edit_predictions: UsageData { + used: 250, + limit: UsageLimit::Unlimited, + }, + }, + trial_started_at: None, + is_usage_based_billing_enabled: false, + is_account_too_young: false, + has_overdue_invoices: false, + }, + } +} diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 5ed258aa8e..3c125a0882 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,6 +1,11 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; +use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; +use cloud_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, +}; use collections::{HashMap, HashSet, hash_map::Entry}; use derive_more::Deref; use feature_flags::FeatureFlagAppExt; @@ -16,11 +21,7 @@ use std::{ sync::{Arc, Weak}, }; use text::ReplicaId; -use util::{TryFutureExt as _, maybe}; -use zed_llm_client::{ - EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, - MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; +use util::{ResultExt, TryFutureExt as _}; pub type UserId = u64; @@ -55,7 +56,7 @@ pub struct ParticipantIndex(pub u32); #[derive(Default, Debug)] pub struct User { pub id: UserId, - pub github_login: String, + pub github_login: SharedString, pub avatar_uri: SharedUri, pub name: Option, } @@ -107,19 +108,14 @@ pub enum ContactRequestStatus { pub struct UserStore { users: HashMap>, - by_github_login: HashMap, + by_github_login: HashMap, participant_indices: HashMap, update_contacts_tx: mpsc::UnboundedSender, - current_plan: Option, - subscription_period: Option<(DateTime, DateTime)>, - trial_started_at: Option>, model_request_usage: Option, edit_prediction_usage: Option, - is_usage_based_billing_enabled: Option, - account_too_young: Option, - has_overdue_invoices: Option, + plan_info: Option, current_user: watch::Receiver>>, - accepted_tos_at: Option>>, + accepted_tos_at: Option>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -145,6 +141,7 @@ pub enum Event { ShowContacts, ParticipantIndicesChanged, PrivateUserInfoUpdated, + PlanUpdated, } #[derive(Clone, Copy)] @@ -188,14 +185,9 @@ impl UserStore { users: Default::default(), by_github_login: Default::default(), current_user: current_user_rx, - current_plan: None, - subscription_period: None, - trial_started_at: None, + plan_info: None, model_request_usage: None, edit_prediction_usage: None, - is_usage_based_billing_enabled: None, - account_too_young: None, - has_overdue_invoices: None, accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), @@ -225,53 +217,30 @@ impl UserStore { return Ok(()); }; match status { - Status::Connected { .. } => { + Status::Authenticated | Status::Connected { .. } => { if let Some(user_id) = client.user_id() { - let fetch_user = if let Ok(fetch_user) = - this.update(cx, |this, cx| this.get_user(user_id, cx).log_err()) - { - fetch_user - } else { - break; - }; - let fetch_private_user_info = - client.request(proto::GetPrivateUserInfo {}).log_err(); - let (user, info) = - futures::join!(fetch_user, fetch_private_user_info); - + let response = client.cloud_client().get_authenticated_user().await; + let mut current_user = None; cx.update(|cx| { - if let Some(info) = info { - let staff = - info.staff && !*feature_flags::ZED_DISABLE_STAFF; - cx.update_flags(staff, info.flags); - client.telemetry.set_authenticated_user_info( - Some(info.metrics_id.clone()), - staff, - ); - + if let Some(response) = response.log_err() { + let user = Arc::new(User { + id: user_id, + github_login: response.user.github_login.clone().into(), + avatar_uri: response.user.avatar_url.clone().into(), + name: response.user.name.clone(), + }); + current_user = Some(user.clone()); this.update(cx, |this, cx| { - let accepted_tos_at = { - #[cfg(debug_assertions)] - if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() - { - None - } else { - info.accepted_tos_at - } - - #[cfg(not(debug_assertions))] - info.accepted_tos_at - }; - - this.set_current_user_accepted_tos_at(accepted_tos_at); - cx.emit(Event::PrivateUserInfoUpdated); + this.by_github_login + .insert(user.github_login.clone(), user_id); + this.users.insert(user_id, user); + this.update_authenticated_user(response, cx) }) } else { anyhow::Ok(()) } })??; - - current_user_tx.send(user).await.ok(); + current_user_tx.send(current_user).await.ok(); this.update(cx, |_, cx| cx.notify())?; } @@ -352,59 +321,22 @@ impl UserStore { async fn handle_update_plan( this: Entity, - message: TypedEnvelope, + _message: TypedEnvelope, mut cx: AsyncApp, ) -> Result<()> { + let client = this + .read_with(&cx, |this, _| this.client.upgrade())? + .context("client was dropped")?; + + let response = client + .cloud_client() + .get_authenticated_user() + .await + .context("failed to fetch authenticated user")?; + this.update(&mut cx, |this, cx| { - this.current_plan = Some(message.payload.plan()); - this.subscription_period = maybe!({ - let period = message.payload.subscription_period?; - let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?; - let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?; - - Some((started_at, ended_at)) - }); - this.trial_started_at = message - .payload - .trial_started_at - .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0)); - this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled; - this.account_too_young = message.payload.account_too_young; - this.has_overdue_invoices = message.payload.has_overdue_invoices; - - if let Some(usage) = message.payload.usage { - // limits are always present even though they are wrapped in Option - this.model_request_usage = usage - .model_requests_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(ModelRequestUsage); - this.edit_prediction_usage = usage - .edit_predictions_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(EditPredictionUsage); - } - - cx.notify(); - })?; - Ok(()) - } - - pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { - self.model_request_usage = Some(usage); - cx.notify(); - } - - pub fn update_edit_prediction_usage( - &mut self, - usage: EditPredictionUsage, - cx: &mut Context, - ) { - self.edit_prediction_usage = Some(usage); - cx.notify(); + this.update_authenticated_user(response, cx); + }) } fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { @@ -763,59 +695,131 @@ impl UserStore { self.current_user.borrow().clone() } - pub fn current_plan(&self) -> Option { + pub fn plan(&self) -> Option { #[cfg(debug_assertions)] if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { return match plan.as_str() { - "free" => Some(proto::Plan::Free), - "trial" => Some(proto::Plan::ZedProTrial), - "pro" => Some(proto::Plan::ZedPro), + "free" => Some(cloud_llm_client::Plan::ZedFree), + "trial" => Some(cloud_llm_client::Plan::ZedProTrial), + "pro" => Some(cloud_llm_client::Plan::ZedPro), _ => { panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); } }; } - self.current_plan + self.plan_info.as_ref().map(|info| info.plan) } pub fn subscription_period(&self) -> Option<(DateTime, DateTime)> { - self.subscription_period + self.plan_info + .as_ref() + .and_then(|plan| plan.subscription_period) + .map(|subscription_period| { + ( + subscription_period.started_at.0, + subscription_period.ended_at.0, + ) + }) } pub fn trial_started_at(&self) -> Option> { - self.trial_started_at + self.plan_info + .as_ref() + .and_then(|plan| plan.trial_started_at) + .map(|trial_started_at| trial_started_at.0) } - pub fn usage_based_billing_enabled(&self) -> Option { - self.is_usage_based_billing_enabled + /// Returns whether the user's account is too new to use the service. + pub fn account_too_young(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_account_too_young) + .unwrap_or_default() + } + + /// Returns whether the current user has overdue invoices and usage should be blocked. + pub fn has_overdue_invoices(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.has_overdue_invoices) + .unwrap_or_default() + } + + pub fn is_usage_based_billing_enabled(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_usage_based_billing_enabled) + .unwrap_or_default() } pub fn model_request_usage(&self) -> Option { self.model_request_usage } + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + pub fn edit_prediction_usage(&self) -> Option { self.edit_prediction_usage } + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + + fn update_authenticated_user( + &mut self, + response: GetAuthenticatedUserResponse, + cx: &mut Context, + ) { + let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF; + cx.update_flags(staff, response.feature_flags); + if let Some(client) = self.client.upgrade() { + client + .telemetry + .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff); + } + + let accepted_tos_at = { + #[cfg(debug_assertions)] + if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() { + None + } else { + response.user.accepted_tos_at + } + + #[cfg(not(debug_assertions))] + response.user.accepted_tos_at + }; + + self.accepted_tos_at = Some(accepted_tos_at); + self.model_request_usage = Some(ModelRequestUsage(RequestUsage { + limit: response.plan.usage.model_requests.limit, + amount: response.plan.usage.model_requests.used as i32, + })); + self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { + limit: response.plan.usage.edit_predictions.limit, + amount: response.plan.usage.edit_predictions.used as i32, + })); + self.plan_info = Some(response.plan); + cx.emit(Event::PrivateUserInfoUpdated); + } + pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } - /// Returns whether the user's account is too new to use the service. - pub fn account_too_young(&self) -> bool { - self.account_too_young.unwrap_or(false) - } - - /// Returns whether the current user has overdue invoices and usage should be blocked. - pub fn has_overdue_invoices(&self) -> bool { - self.has_overdue_invoices.unwrap_or(false) - } - - pub fn current_user_has_accepted_terms(&self) -> Option { + pub fn has_accepted_terms_of_service(&self) -> bool { self.accepted_tos_at - .map(|accepted_tos_at| accepted_tos_at.is_some()) + .map_or(false, |accepted_tos_at| accepted_tos_at.is_some()) } pub fn accept_terms_of_service(&self, cx: &Context) -> Task> { @@ -827,23 +831,18 @@ impl UserStore { cx.spawn(async move |this, cx| -> anyhow::Result<()> { let client = client.upgrade().context("client not found")?; let response = client - .request(proto::AcceptTermsOfService {}) + .cloud_client() + .accept_terms_of_service() .await .context("error accepting tos")?; this.update(cx, |this, cx| { - this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)); + this.accepted_tos_at = Some(response.user.accepted_tos_at); cx.emit(Event::PrivateUserInfoUpdated); })?; Ok(()) }) } - fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option) { - self.accepted_tos_at = Some( - accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)), - ); - } - fn load_users( &self, request: impl RequestMessage, @@ -902,7 +901,7 @@ impl UserStore { let mut missing_user_ids = Vec::new(); for id in user_ids { if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) { - ret.insert(id, github_login.into()); + ret.insert(id, github_login); } else { missing_user_ids.push(id) } @@ -923,7 +922,7 @@ impl User { fn new(message: proto::User) -> Arc { Arc::new(User { id: message.id, - github_login: message.github_login, + github_login: message.github_login.into(), avatar_uri: message.avatar_url.into(), name: message.name, }) diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml new file mode 100644 index 0000000000..d56aa94c6e --- /dev/null +++ b/crates/cloud_api_client/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "cloud_api_client" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_api_client.rs" + +[dependencies] +anyhow.workspace = true +cloud_api_types.workspace = true +futures.workspace = true +http_client.workspace = true +parking_lot.workspace = true +serde_json.workspace = true +workspace-hack.workspace = true diff --git a/crates/cloud_api_client/LICENSE-APACHE b/crates/cloud_api_client/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_api_client/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs new file mode 100644 index 0000000000..edac051a0e --- /dev/null +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -0,0 +1,188 @@ +use std::sync::Arc; + +use anyhow::{Context, Result, anyhow}; +pub use cloud_api_types::*; +use futures::AsyncReadExt as _; +use http_client::http::request; +use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode}; +use parking_lot::RwLock; + +struct Credentials { + user_id: u32, + access_token: String, +} + +pub struct CloudApiClient { + credentials: RwLock>, + http_client: Arc, +} + +impl CloudApiClient { + pub fn new(http_client: Arc) -> Self { + Self { + credentials: RwLock::new(None), + http_client, + } + } + + pub fn has_credentials(&self) -> bool { + self.credentials.read().is_some() + } + + pub fn set_credentials(&self, user_id: u32, access_token: String) { + *self.credentials.write() = Some(Credentials { + user_id, + access_token, + }); + } + + pub fn clear_credentials(&self) { + *self.credentials.write() = None; + } + + fn build_request( + &self, + req: request::Builder, + body: impl Into, + ) -> Result> { + let credentials = self.credentials.read(); + let credentials = credentials.as_ref().context("no credentials provided")?; + build_request(req, body, credentials) + } + + pub async fn get_authenticated_user(&self) -> Result { + let request = self.build_request( + Request::builder().method(Method::GET).uri( + self.http_client + .build_zed_cloud_url("/client/users/me", &[])? + .as_ref(), + ), + AsyncBody::default(), + )?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn accept_terms_of_service(&self) -> Result { + let request = self.build_request( + Request::builder().method(Method::POST).uri( + self.http_client + .build_zed_cloud_url("/client/terms_of_service/accept", &[])? + .as_ref(), + ), + AsyncBody::default(), + )?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to accept terms of service.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn create_llm_token( + &self, + system_id: Option, + ) -> Result { + let mut request_builder = Request::builder().method(Method::POST).uri( + self.http_client + .build_zed_cloud_url("/client/llm_tokens", &[])? + .as_ref(), + ); + + if let Some(system_id) = system_id { + request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id); + } + + let request = self.build_request(request_builder, AsyncBody::default())?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to create LLM token.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result { + let request = build_request( + Request::builder().method(Method::GET).uri( + self.http_client + .build_zed_cloud_url("/client/users/me", &[])? + .as_ref(), + ), + AsyncBody::default(), + &Credentials { + user_id, + access_token: access_token.into(), + }, + )?; + + let mut response = self.http_client.send(request).await?; + + if response.status().is_success() { + Ok(true) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + if response.status() == StatusCode::UNAUTHORIZED { + return Ok(false); + } else { + return Err(anyhow!( + "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", + response.status() + )); + } + } + } +} + +fn build_request( + req: request::Builder, + body: impl Into, + credentials: &Credentials, +) -> Result> { + Ok(req + .header("Content-Type", "application/json") + .header( + "Authorization", + format!("{} {}", credentials.user_id, credentials.access_token), + ) + .body(body.into())?) +} diff --git a/crates/cloud_api_types/Cargo.toml b/crates/cloud_api_types/Cargo.toml new file mode 100644 index 0000000000..868797df3b --- /dev/null +++ b/crates/cloud_api_types/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "cloud_api_types" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_api_types.rs" + +[dependencies] +chrono.workspace = true +cloud_llm_client.workspace = true +serde.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true +serde_json.workspace = true diff --git a/crates/cloud_api_types/LICENSE-APACHE b/crates/cloud_api_types/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_api_types/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs new file mode 100644 index 0000000000..b38b38cde1 --- /dev/null +++ b/crates/cloud_api_types/src/cloud_api_types.rs @@ -0,0 +1,55 @@ +mod timestamp; + +use serde::{Deserialize, Serialize}; + +pub use crate::timestamp::Timestamp; + +pub const ZED_SYSTEM_ID_HEADER_NAME: &str = "x-zed-system-id"; + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct GetAuthenticatedUserResponse { + pub user: AuthenticatedUser, + pub feature_flags: Vec, + pub plan: PlanInfo, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct AuthenticatedUser { + pub id: i32, + pub metrics_id: String, + pub avatar_url: String, + pub github_login: String, + pub name: Option, + pub is_staff: bool, + pub accepted_tos_at: Option, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct PlanInfo { + pub plan: cloud_llm_client::Plan, + pub subscription_period: Option, + pub usage: cloud_llm_client::CurrentUsage, + pub trial_started_at: Option, + pub is_usage_based_billing_enabled: bool, + pub is_account_too_young: bool, + pub has_overdue_invoices: bool, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +pub struct SubscriptionPeriod { + pub started_at: Timestamp, + pub ended_at: Timestamp, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct AcceptTermsOfServiceResponse { + pub user: AuthenticatedUser, +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct LlmToken(pub String); + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct CreateLlmTokenResponse { + pub token: LlmToken, +} diff --git a/crates/cloud_api_types/src/timestamp.rs b/crates/cloud_api_types/src/timestamp.rs new file mode 100644 index 0000000000..1f055d58ef --- /dev/null +++ b/crates/cloud_api_types/src/timestamp.rs @@ -0,0 +1,166 @@ +use chrono::{DateTime, NaiveDateTime, SecondsFormat, Utc}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A timestamp with a serialized representation in RFC 3339 format. +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub struct Timestamp(pub DateTime); + +impl Timestamp { + pub fn new(datetime: DateTime) -> Self { + Self(datetime) + } +} + +impl From> for Timestamp { + fn from(value: DateTime) -> Self { + Self(value) + } +} + +impl From for Timestamp { + fn from(value: NaiveDateTime) -> Self { + Self(value.and_utc()) + } +} + +impl Serialize for Timestamp { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let rfc3339_string = self.0.to_rfc3339_opts(SecondsFormat::Millis, true); + serializer.serialize_str(&rfc3339_string) + } +} + +impl<'de> Deserialize<'de> for Timestamp { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = String::deserialize(deserializer)?; + let datetime = DateTime::parse_from_rfc3339(&value) + .map_err(serde::de::Error::custom)? + .to_utc(); + Ok(Self(datetime)) + } +} + +#[cfg(test)] +mod tests { + use chrono::NaiveDate; + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_timestamp_serialization() { + let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + let timestamp = Timestamp::new(datetime); + + let json = serde_json::to_string(×tamp).unwrap(); + assert_eq!(json, "\"2023-12-25T14:30:45.123Z\""); + } + + #[test] + fn test_timestamp_deserialization() { + let json = "\"2023-12-25T14:30:45.123Z\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + let expected = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_roundtrip() { + let original = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + let timestamp = Timestamp::new(original); + let json = serde_json::to_string(×tamp).unwrap(); + let deserialized: Timestamp = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.0, original); + } + + #[test] + fn test_timestamp_from_datetime_utc() { + let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + let timestamp = Timestamp::from(datetime); + assert_eq!(timestamp.0, datetime); + } + + #[test] + fn test_timestamp_from_naive_datetime() { + let naive_dt = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_milli_opt(14, 30, 45, 123) + .unwrap(); + + let timestamp = Timestamp::from(naive_dt); + let expected = naive_dt.and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_serialization_with_microseconds() { + // Test that microseconds are truncated to milliseconds + let datetime = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_micro_opt(14, 30, 45, 123456) + .unwrap() + .and_utc(); + + let timestamp = Timestamp::new(datetime); + let json = serde_json::to_string(×tamp).unwrap(); + + // Should be truncated to milliseconds + assert_eq!(json, "\"2023-12-25T14:30:45.123Z\""); + } + + #[test] + fn test_timestamp_deserialization_without_milliseconds() { + let json = "\"2023-12-25T14:30:45Z\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + let expected = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_opt(14, 30, 45) + .unwrap() + .and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_deserialization_with_timezone() { + let json = "\"2023-12-25T14:30:45.123+05:30\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + // Should be converted to UTC + let expected = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_milli_opt(9, 0, 45, 123) // 14:30:45 + 5:30 = 20:00:45, but we want UTC so subtract 5:30 + .unwrap() + .and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_deserialization_with_invalid_format() { + let json = "\"invalid-date\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } +} diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml new file mode 100644 index 0000000000..6f090d3c6e --- /dev/null +++ b/crates/cloud_llm_client/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "cloud_llm_client" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_llm_client.rs" + +[dependencies] +anyhow.workspace = true +serde = { workspace = true, features = ["derive", "rc"] } +serde_json.workspace = true +strum = { workspace = true, features = ["derive"] } +uuid = { workspace = true, features = ["serde"] } +workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/cloud_llm_client/LICENSE-APACHE b/crates/cloud_llm_client/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_llm_client/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs new file mode 100644 index 0000000000..e78957ec49 --- /dev/null +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -0,0 +1,386 @@ +use std::str::FromStr; +use std::sync::Arc; + +use anyhow::Context as _; +use serde::{Deserialize, Serialize}; +use strum::{Display, EnumIter, EnumString}; +use uuid::Uuid; + +/// The name of the header used to indicate which version of Zed the client is running. +pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version"; + +/// The name of the header used to indicate when a request failed due to an +/// expired LLM token. +/// +/// The client may use this as a signal to refresh the token. +pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; + +/// The name of the header used to indicate what plan the user is currently on. +pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan"; + +/// The name of the header used to indicate the usage limit for model requests. +pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit"; + +/// The name of the header used to indicate the usage amount for model requests. +pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount"; + +/// The name of the header used to indicate the usage limit for edit predictions. +pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit"; + +/// The name of the header used to indicate the usage amount for edit predictions. +pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount"; + +/// The name of the header used to indicate the resource for which the subscription limit has been reached. +pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource"; + +pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests"; +pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions"; + +/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached. +pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached"; + +/// The name of the header used to indicate the the minimum required Zed version. +/// +/// This can be used to force a Zed upgrade in order to continue communicating +/// with the LLM service. +pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version"; + +/// The name of the header used by the client to indicate to the server that it supports receiving status messages. +pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = + "x-zed-client-supports-status-messages"; + +/// The name of the header used by the server to indicate to the client that it supports sending status messages. +pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = + "x-zed-server-supports-status-messages"; + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum UsageLimit { + Limited(i32), + Unlimited, +} + +impl FromStr for UsageLimit { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result { + match value { + "unlimited" => Ok(Self::Unlimited), + limit => limit + .parse::() + .map(Self::Limited) + .context("failed to parse limit"), + } + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Plan { + #[default] + #[serde(alias = "Free")] + ZedFree, + #[serde(alias = "ZedPro")] + ZedPro, + #[serde(alias = "ZedProTrial")] + ZedProTrial, +} + +impl Plan { + pub fn as_str(&self) -> &'static str { + match self { + Plan::ZedFree => "zed_free", + Plan::ZedPro => "zed_pro", + Plan::ZedProTrial => "zed_pro_trial", + } + } + + pub fn model_requests_limit(&self) -> UsageLimit { + match self { + Plan::ZedPro => UsageLimit::Limited(500), + Plan::ZedProTrial => UsageLimit::Limited(150), + Plan::ZedFree => UsageLimit::Limited(50), + } + } + + pub fn edit_predictions_limit(&self) -> UsageLimit { + match self { + Plan::ZedPro => UsageLimit::Unlimited, + Plan::ZedProTrial => UsageLimit::Unlimited, + Plan::ZedFree => UsageLimit::Limited(2_000), + } + } +} + +impl FromStr for Plan { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result { + match value { + "zed_free" => Ok(Plan::ZedFree), + "zed_pro" => Ok(Plan::ZedPro), + "zed_pro_trial" => Ok(Plan::ZedProTrial), + plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")), + } + } +} + +#[derive( + Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum LanguageModelProvider { + Anthropic, + OpenAi, + Google, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsBody { + #[serde(skip_serializing_if = "Option::is_none", default)] + pub outline: Option, + pub input_events: String, + pub input_excerpt: String, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub speculated_output: Option, + /// Whether the user provided consent for sampling this interaction. + #[serde(default, alias = "data_collection_permission")] + pub can_collect_data: bool, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub diagnostic_groups: Option>, + /// Info about the git repository state, only present when can_collect_data is true. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub git_info: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsGitInfo { + /// SHA of git HEAD commit at time of prediction. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub head_sha: Option, + /// URL of the remote called `origin`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub remote_origin_url: Option, + /// URL of the remote called `upstream`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub remote_upstream_url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsResponse { + pub request_id: Uuid, + pub output_excerpt: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcceptEditPredictionBody { + pub request_id: Uuid, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionMode { + Normal, + Max, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionIntent { + UserPrompt, + ToolResults, + ThreadSummarization, + ThreadContextSummarization, + CreateFile, + EditFile, + InlineAssist, + TerminalInlineAssist, + GenerateGitCommitMessage, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionBody { + #[serde(skip_serializing_if = "Option::is_none", default)] + pub thread_id: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub prompt_id: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub intent: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub mode: Option, + pub provider: LanguageModelProvider, + pub model: String, + pub provider_request: serde_json::Value, +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionRequestStatus { + Queued { + position: usize, + }, + Started, + Failed { + code: String, + message: String, + request_id: Uuid, + /// Retry duration in seconds. + retry_after: Option, + }, + UsageUpdated { + amount: usize, + limit: UsageLimit, + }, + ToolUseLimitReached, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionEvent { + Status(CompletionRequestStatus), + Event(T), +} + +impl CompletionEvent { + pub fn into_status(self) -> Option { + match self { + Self::Status(status) => Some(status), + Self::Event(_) => None, + } + } + + pub fn into_event(self) -> Option { + match self { + Self::Event(event) => Some(event), + Self::Status(_) => None, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct WebSearchBody { + pub query: String, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct WebSearchResponse { + pub results: Vec, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct WebSearchResult { + pub title: String, + pub url: String, + pub text: String, +} + +#[derive(Serialize, Deserialize)] +pub struct CountTokensBody { + pub provider: LanguageModelProvider, + pub model: String, + pub provider_request: serde_json::Value, +} + +#[derive(Serialize, Deserialize)] +pub struct CountTokensResponse { + pub tokens: usize, +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelId(pub Arc); + +impl std::fmt::Display for LanguageModelId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LanguageModel { + pub provider: LanguageModelProvider, + pub id: LanguageModelId, + pub display_name: String, + pub max_token_count: usize, + pub max_token_count_in_max_mode: Option, + pub max_output_tokens: usize, + pub supports_tools: bool, + pub supports_images: bool, + pub supports_thinking: bool, + pub supports_max_mode: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ListModelsResponse { + pub models: Vec, + pub default_model: LanguageModelId, + pub default_fast_model: LanguageModelId, + pub recommended_models: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GetSubscriptionResponse { + pub plan: Plan, + pub usage: Option, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct CurrentUsage { + pub model_requests: UsageData, + pub edit_predictions: UsageData, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct UsageData { + pub used: u32, + pub limit: UsageLimit, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::*; + + #[test] + fn test_plan_deserialize_snake_case() { + let plan = serde_json::from_value::(json!("zed_free")).unwrap(); + assert_eq!(plan, Plan::ZedFree); + + let plan = serde_json::from_value::(json!("zed_pro")).unwrap(); + assert_eq!(plan, Plan::ZedPro); + + let plan = serde_json::from_value::(json!("zed_pro_trial")).unwrap(); + assert_eq!(plan, Plan::ZedProTrial); + } + + #[test] + fn test_plan_deserialize_aliases() { + let plan = serde_json::from_value::(json!("Free")).unwrap(); + assert_eq!(plan, Plan::ZedFree); + + let plan = serde_json::from_value::(json!("ZedPro")).unwrap(); + assert_eq!(plan, Plan::ZedPro); + + let plan = serde_json::from_value::(json!("ZedProTrial")).unwrap(); + assert_eq!(plan, Plan::ZedProTrial); + } + + #[test] + fn test_usage_limit_from_str() { + let limit = UsageLimit::from_str("unlimited").unwrap(); + assert!(matches!(limit, UsageLimit::Unlimited)); + + let limit = UsageLimit::from_str(&0.to_string()).unwrap(); + assert!(matches!(limit, UsageLimit::Limited(0))); + + let limit = UsageLimit::from_str(&50.to_string()).unwrap(); + assert!(matches!(limit, UsageLimit::Limited(50))); + + for value in ["not_a_number", "50xyz"] { + let limit = UsageLimit::from_str(value); + assert!(limit.is_err()); + } + } +} diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index d3b5048283..9af95317e6 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -23,13 +23,14 @@ async-stripe.workspace = true async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } -aws-sdk-s3 = { version = "1.15.0" } aws-sdk-kinesis = "1.51.0" +aws-sdk-s3 = { version = "1.15.0" } axum = { version = "0.6", features = ["json", "headers", "ws"] } axum-extra = { version = "0.4", features = ["erased-json"] } base64.workspace = true chrono.workspace = true clock.workspace = true +cloud_llm_client.workspace = true collections.workspace = true dashmap.workspace = true derive_more.workspace = true @@ -75,7 +76,6 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re util.workspace = true uuid.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true [dev-dependencies] agent_settings.workspace = true diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index ca840493ad..73d473ab76 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -173,6 +173,7 @@ CREATE TABLE "language_servers" ( "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "name" VARCHAR NOT NULL, + "capabilities" TEXT NOT NULL, PRIMARY KEY (project_id, id) ); diff --git a/crates/collab/migrations/20250804080620_language_server_capabilities.sql b/crates/collab/migrations/20250804080620_language_server_capabilities.sql new file mode 100644 index 0000000000..f74f094ed2 --- /dev/null +++ b/crates/collab/migrations/20250804080620_language_server_capabilities.sql @@ -0,0 +1,5 @@ +ALTER TABLE language_servers + ADD COLUMN capabilities TEXT NOT NULL DEFAULT '{}'; + +ALTER TABLE language_servers + ALTER COLUMN capabilities DROP DEFAULT; diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 3b0f5396a7..6cf3f68f54 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -100,13 +100,11 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() - .route("/user", get(update_or_create_authenticated_user)) .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) .route("/users/:id/update_plan", post(update_plan)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) - .merge(billing::router()) .merge(contributors::router()) .layer( ServiceBuilder::new() @@ -146,48 +144,6 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR Ok::<_, Error>(next.run(req).await) } -#[derive(Debug, Deserialize)] -struct AuthenticatedUserParams { - github_user_id: i32, - github_login: String, - github_email: Option, - github_name: Option, - github_user_created_at: chrono::DateTime, -} - -#[derive(Debug, Serialize)] -struct AuthenticatedUserResponse { - user: User, - metrics_id: String, - feature_flags: Vec, -} - -async fn update_or_create_authenticated_user( - Query(params): Query, - Extension(app): Extension>, -) -> Result> { - let initial_channel_id = app.config.auto_join_channel_id; - - let user = app - .db - .update_or_create_user_by_github_account( - ¶ms.github_login, - params.github_user_id, - params.github_email.as_deref(), - params.github_name.as_deref(), - params.github_user_created_at, - initial_channel_id, - ) - .await?; - let metrics_id = app.db.get_user_metrics_id(user.id).await?; - let feature_flags = app.db.get_user_flags(user.id).await?; - Ok(Json(AuthenticatedUserResponse { - user, - metrics_id, - feature_flags, - })) -} - #[derive(Debug, Deserialize)] struct LookUpUserParams { identifier: String, @@ -354,9 +310,9 @@ async fn refresh_llm_tokens( #[derive(Debug, Serialize, Deserialize)] struct UpdatePlanBody { - pub plan: zed_llm_client::Plan, + pub plan: cloud_llm_client::Plan, pub subscription_period: SubscriptionPeriod, - pub usage: zed_llm_client::CurrentUsage, + pub usage: cloud_llm_client::CurrentUsage, pub trial_started_at: Option>, pub is_usage_based_billing_enabled: bool, pub is_account_too_young: bool, @@ -378,9 +334,9 @@ async fn update_plan( extract::Json(body): extract::Json, ) -> Result> { let plan = match body.plan { - zed_llm_client::Plan::ZedFree => proto::Plan::Free, - zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro, - zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, + cloud_llm_client::Plan::ZedFree => proto::Plan::Free, + cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro, + cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, }; let update_user_plan = proto::UpdateUserPlan { @@ -412,15 +368,15 @@ async fn update_plan( Ok(Json(UpdatePlanResponse {})) } -fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit { +fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit { proto::UsageLimit { variant: Some(match limit { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 9a27e22f87..0e15308ffe 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,15 +1,13 @@ use anyhow::{Context as _, bail}; -use axum::{Extension, Json, Router, extract, routing::post}; use chrono::{DateTime, Utc}; +use cloud_llm_client::LanguageModelProvider; use collections::{HashMap, HashSet}; -use reqwest::StatusCode; use sea_orm::ActiveValue; -use serde::{Deserialize, Serialize}; use std::{sync::Arc, time::Duration}; use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus}; use util::{ResultExt, maybe}; -use zed_llm_client::LanguageModelProvider; +use crate::AppState; use crate::db::billing_subscription::{ StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, }; @@ -19,7 +17,6 @@ use crate::stripe_client::{ StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription, StripeSubscriptionId, }; -use crate::{AppState, Error, Result}; use crate::{db::UserId, llm::db::LlmDatabase}; use crate::{ db::{ @@ -30,70 +27,6 @@ use crate::{ stripe_billing::StripeBilling, }; -pub fn router() -> Router { - Router::new().route( - "/billing/subscriptions/sync", - post(sync_billing_subscription), - ) -} - -#[derive(Debug, Deserialize)] -struct SyncBillingSubscriptionBody { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct SyncBillingSubscriptionResponse { - stripe_customer_id: String, -} - -async fn sync_billing_subscription( - Extension(app): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let Some(stripe_client) = app.stripe_client.clone() else { - log::error!("failed to retrieve Stripe client"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; - - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? - .context("user not found")?; - - let billing_customer = app - .db - .get_billing_customer_by_user_id(user.id) - .await? - .context("billing customer not found")?; - let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - let subscriptions = stripe_client - .list_subscriptions_for_customer(&stripe_customer_id) - .await?; - - for subscription in subscriptions { - let subscription_id = subscription.id.clone(); - - sync_subscription(&app, &stripe_client, subscription) - .await - .with_context(|| { - format!( - "failed to sync subscription {subscription_id} for user {}", - user.id, - ) - })?; - } - - Ok(Json(SyncBillingSubscriptionResponse { - stripe_customer_id: billing_customer.stripe_customer_id.clone(), - })) -} - /// The amount of time we wait in between each poll of Stripe events. /// /// This value should strike a balance between: @@ -154,6 +87,14 @@ async fn poll_stripe_events( stripe_client: &Arc, real_stripe_client: &stripe::Client, ) -> anyhow::Result<()> { + let feature_flags = app.db.list_feature_flags().await?; + let sync_events_using_cloud = feature_flags + .iter() + .any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all); + if sync_events_using_cloud { + return Ok(()); + } + fn event_type_to_string(event_type: EventType) -> String { // Calling `to_string` on `stripe::EventType` members gives us a quoted string, // so we need to unquote it. @@ -636,6 +577,14 @@ async fn sync_model_request_usage_with_stripe( llm_db: &Arc, stripe_billing: &Arc, ) -> anyhow::Result<()> { + let feature_flags = app.db.list_feature_flags().await?; + let sync_model_request_usage_using_cloud = feature_flags + .iter() + .any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all); + if sync_model_request_usage_using_cloud { + return Ok(()); + } + log::info!("Stripe usage sync: Starting"); let started_at = Utc::now(); diff --git a/crates/collab/src/api/contributors.rs b/crates/collab/src/api/contributors.rs index 9296c1d428..8cfef0ad7e 100644 --- a/crates/collab/src/api/contributors.rs +++ b/crates/collab/src/api/contributors.rs @@ -8,7 +8,6 @@ use axum::{ use chrono::{NaiveDateTime, SecondsFormat}; use serde::{Deserialize, Serialize}; -use crate::api::AuthenticatedUserParams; use crate::db::ContributorSelector; use crate::{AppState, Result}; @@ -104,9 +103,18 @@ impl RenovateBot { } } +#[derive(Debug, Deserialize)] +struct AddContributorBody { + github_user_id: i32, + github_login: String, + github_email: Option, + github_name: Option, + github_user_created_at: chrono::DateTime, +} + async fn add_contributor( Extension(app): Extension>, - extract::Json(params): extract::Json, + extract::Json(params): extract::Json, ) -> Result<()> { let initial_channel_id = app.config.auto_join_channel_id; app.db diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index bc7dd152b0..2f34a843a8 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -580,7 +580,7 @@ fn for_snowflake( }, serde_json::to_value(e).unwrap(), ), - Event::InlineCompletion(e) => ( + Event::EditPrediction(e) => ( format!( "Edit Prediction {}", if e.suggestion_accepted { @@ -591,7 +591,7 @@ fn for_snowflake( ), serde_json::to_value(e).unwrap(), ), - Event::InlineCompletionRating(e) => ( + Event::EditPredictionRating(e) => ( "Edit Prediction Rated".to_string(), serde_json::to_value(e).unwrap(), ), diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 8cd1e3ea83..2c22ca2069 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -529,11 +529,17 @@ pub struct RejoinedProject { pub worktrees: Vec, pub updated_repositories: Vec, pub removed_repositories: Vec, - pub language_servers: Vec, + pub language_servers: Vec, } impl RejoinedProject { pub fn to_proto(&self) -> proto::RejoinedProject { + let (language_servers, language_server_capabilities) = self + .language_servers + .clone() + .into_iter() + .map(|server| (server.server, server.capabilities)) + .unzip(); proto::RejoinedProject { id: self.id.to_proto(), worktrees: self @@ -551,7 +557,8 @@ impl RejoinedProject { .iter() .map(|collaborator| collaborator.to_proto()) .collect(), - language_servers: self.language_servers.clone(), + language_servers, + language_server_capabilities, } } } @@ -598,7 +605,7 @@ pub struct Project { pub collaborators: Vec, pub worktrees: BTreeMap, pub repositories: Vec, - pub language_servers: Vec, + pub language_servers: Vec, } pub struct ProjectCollaborator { @@ -623,6 +630,12 @@ impl ProjectCollaborator { } } +#[derive(Debug, Clone)] +pub struct LanguageServer { + pub server: proto::LanguageServer, + pub capabilities: String, +} + #[derive(Debug)] pub struct LeftProject { pub id: ProjectId, diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index a288a4e7eb..2e6b4719d1 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -786,6 +786,32 @@ impl Database { }) .collect()) } + + /// Update language server capabilities for a given id. + pub async fn update_server_capabilities( + &self, + project_id: ProjectId, + server_id: u64, + new_capabilities: String, + ) -> Result<()> { + self.transaction(|tx| { + let new_capabilities = new_capabilities.clone(); + async move { + Ok( + language_server::Entity::update(language_server::ActiveModel { + project_id: ActiveValue::unchanged(project_id), + id: ActiveValue::unchanged(server_id as i64), + capabilities: ActiveValue::set(new_capabilities), + ..Default::default() + }) + .exec(&*tx) + .await?, + ) + } + }) + .await?; + Ok(()) + } } fn operation_to_storage( diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index ba22a7b4e3..31635575a8 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -692,6 +692,7 @@ impl Database { project_id: ActiveValue::set(project_id), id: ActiveValue::set(server.id as i64), name: ActiveValue::set(server.name.clone()), + capabilities: ActiveValue::set(update.capabilities.clone()), }) .on_conflict( OnConflict::columns([ @@ -1054,10 +1055,13 @@ impl Database { repositories, language_servers: language_servers .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - worktree_id: None, + .map(|language_server| LanguageServer { + server: proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + worktree_id: None, + }, + capabilities: language_server.capabilities, }) .collect(), }; diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index cb805786dd..c63d7133be 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -804,10 +804,13 @@ impl Database { .all(tx) .await? .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - worktree_id: None, + .map(|language_server| LanguageServer { + server: proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + worktree_id: None, + }, + capabilities: language_server.capabilities, }) .collect::>(); diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 43198f9859..522973dbc9 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -95,7 +95,7 @@ pub enum SubscriptionKind { ZedFree, } -impl From for zed_llm_client::Plan { +impl From for cloud_llm_client::Plan { fn from(value: SubscriptionKind) -> Self { match value { SubscriptionKind::ZedPro => Self::ZedPro, diff --git a/crates/collab/src/db/tables/language_server.rs b/crates/collab/src/db/tables/language_server.rs index 9ff8c75fc6..34c7514d91 100644 --- a/crates/collab/src/db/tables/language_server.rs +++ b/crates/collab/src/db/tables/language_server.rs @@ -9,6 +9,7 @@ pub struct Model { #[sea_orm(primary_key)] pub id: i64, pub name: String, + pub capabilities: String, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 6a6efca0de..18ad624dab 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -6,11 +6,11 @@ mod tables; #[cfg(test)] mod tests; +use cloud_llm_client::LanguageModelProvider; use collections::HashMap; pub use ids::*; pub use seed::*; pub use tables::*; -use zed_llm_client::LanguageModelProvider; #[cfg(test)] pub use tests::TestLlmDb; diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs index 7d52964b93..f4e1de40ec 100644 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ b/crates/collab/src/llm/db/tests/provider_tests.rs @@ -1,5 +1,5 @@ +use cloud_llm_client::LanguageModelProvider; use pretty_assertions::assert_eq; -use zed_llm_client::LanguageModelProvider; use crate::llm::db::LlmDatabase; use crate::test_llm_db; diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index d4566ffcb4..da01c7f3be 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -4,12 +4,12 @@ use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEA use crate::{Config, db::billing_preference}; use anyhow::{Context as _, Result}; use chrono::{NaiveDateTime, Utc}; +use cloud_llm_client::Plan; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::time::Duration; use thiserror::Error; use uuid::Uuid; -use zed_llm_client::Plan; #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 515647f97d..22b21f2c7a 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -23,6 +23,7 @@ use anyhow::{Context as _, anyhow, bail}; use async_tungstenite::tungstenite::{ Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame, }; +use axum::headers::UserAgent; use axum::{ Extension, Router, TypedHeader, body::Body, @@ -41,7 +42,7 @@ use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; use reqwest_client::ReqwestClient; -use rpc::proto::split_repository_update; +use rpc::proto::{MultiLspQuery, split_repository_update}; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; use futures::{ @@ -314,7 +315,7 @@ impl Server { .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_find_search_candidates_request) + .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) @@ -373,7 +374,7 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) + .add_request_handler(multi_lsp_query) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) @@ -665,7 +666,6 @@ impl Server { let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0; let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; let queue_duration_ms = total_duration_ms - processing_duration_ms; - let payload_type = M::NAME; match result { Err(error) => { @@ -674,7 +674,6 @@ impl Server { total_duration_ms, processing_duration_ms, queue_duration_ms, - payload_type, "error handling message" ) } @@ -750,6 +749,7 @@ impl Server { address: String, principal: Principal, zed_version: ZedVersion, + user_agent: Option, geoip_country_code: Option, system_id: Option, send_connection_id: Option>, @@ -762,9 +762,14 @@ impl Server { user_id=field::Empty, login=field::Empty, impersonator=field::Empty, + user_agent=field::Empty, geoip_country_code=field::Empty ); principal.update_span(&span); + if let Some(user_agent) = user_agent { + span.record("user_agent", user_agent); + } + if let Some(country_code) = geoip_country_code.as_ref() { span.record("geoip_country_code", country_code); } @@ -773,12 +778,11 @@ impl Server { async move { if *teardown.borrow() { tracing::error!("server is tearing down"); - return + return; } - let (connection_id, handle_io, mut incoming_rx) = this - .peer - .add_connection(connection, { + let (connection_id, handle_io, mut incoming_rx) = + this.peer.add_connection(connection, { let executor = executor.clone(); move |duration| executor.sleep(duration) }); @@ -795,10 +799,14 @@ impl Server { } }; - let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new( - supermaven_admin_api_key.to_string(), - http_client.clone(), - ))); + let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map( + |supermaven_admin_api_key| { + Arc::new(SupermavenAdminApi::new( + supermaven_admin_api_key.to_string(), + http_client.clone(), + )) + }, + ); let session = Session { principal: principal.clone(), @@ -813,7 +821,15 @@ impl Server { supermaven_client, }; - if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await { + if let Err(error) = this + .send_initial_client_update( + connection_id, + zed_version, + send_connection_id, + &session, + ) + .await + { tracing::error!(?error, "failed to send initial client update"); return; } @@ -830,14 +846,22 @@ impl Server { // // This arrangement ensures we will attempt to process earlier messages first, but fall // back to processing messages arrived later in the spirit of making progress. + const MAX_CONCURRENT_HANDLERS: usize = 256; let mut foreground_message_handlers = FuturesUnordered::new(); - let concurrent_handlers = Arc::new(Semaphore::new(512)); + let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS)); + let get_concurrent_handlers = { + let concurrent_handlers = concurrent_handlers.clone(); + move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits() + }; loop { let next_message = async { let permit = concurrent_handlers.clone().acquire_owned().await.unwrap(); let message = incoming_rx.next().await; - (permit, message) - }.fuse(); + // Cache the concurrent_handlers here, so that we know what the + // queue looks like as each handler starts + (permit, message, get_concurrent_handlers()) + } + .fuse(); futures::pin_mut!(next_message); futures::select_biased! { _ = teardown.changed().fuse() => return, @@ -849,15 +873,20 @@ impl Server { } _ = foreground_message_handlers.next() => {} next_message = next_message => { - let (permit, message) = next_message; + let (permit, message, concurrent_handlers) = next_message; if let Some(message) = message { let type_name = message.payload_type_name(); // note: we copy all the fields from the parent span so we can query them in the logs. // (https://github.com/tokio-rs/tracing/issues/2670). - let span = tracing::info_span!("receive message", %connection_id, %address, type_name, + let span = tracing::info_span!("receive message", + %connection_id, + %address, + type_name, + concurrent_handlers, user_id=field::Empty, login=field::Empty, impersonator=field::Empty, + multi_lsp_query_request=field::Empty, ); principal.update_span(&span); let span_enter = span.enter(); @@ -887,12 +916,13 @@ impl Server { } drop(foreground_message_handlers); - tracing::info!("signing out"); + let concurrent_handlers = get_concurrent_handlers(); + tracing::info!(concurrent_handlers, "signing out"); if let Err(error) = connection_lost(session, teardown, executor).await { tracing::error!(?error, "error signing out"); } - - }.instrument(span) + } + .instrument(span) } async fn send_initial_client_update( @@ -1172,6 +1202,7 @@ pub async fn handle_websocket_request( ConnectInfo(socket_address): ConnectInfo, Extension(server): Extension>, Extension(principal): Extension, + user_agent: Option>, country_code_header: Option>, system_id_header: Option>, ws: WebSocketUpgrade, @@ -1227,6 +1258,7 @@ pub async fn handle_websocket_request( socket_address, principal, version, + user_agent.map(|header| header.to_string()), country_code_header.map(|header| header.to_string()), system_id_header.map(|header| header.to_string()), None, @@ -1958,12 +1990,19 @@ async fn join_project( } // First, we send the metadata associated with each worktree. + let (language_servers, language_server_capabilities) = project + .language_servers + .clone() + .into_iter() + .map(|server| (server.server, server.capabilities)) + .unzip(); response.send(proto::JoinProjectResponse { project_id: project.id.0 as u64, worktrees: worktrees.clone(), replica_id: replica_id.0 as u32, collaborators: collaborators.clone(), - language_servers: project.language_servers.clone(), + language_servers, + language_server_capabilities, role: project.role.into(), })?; @@ -2022,8 +2061,8 @@ async fn join_project( session.connection_id, proto::UpdateLanguageServer { project_id: project_id.to_proto(), - server_name: Some(language_server.name.clone()), - language_server_id: language_server.id, + server_name: Some(language_server.server.name.clone()), + language_server_id: language_server.server.id, variant: Some( proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( proto::LspDiskBasedDiagnosticsUpdated {}, @@ -2235,9 +2274,17 @@ async fn update_language_server( session: Session, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = session - .db() - .await + let db = session.db().await; + + if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant + { + if let Some(capabilities) = update.capabilities.clone() { + db.update_server_capabilities(project_id, request.language_server_id, capabilities) + .await?; + } + } + + let project_connection_ids = db .project_connection_ids(project_id, session.connection_id, true) .await?; broadcast( @@ -2276,25 +2323,6 @@ where Ok(()) } -async fn forward_find_search_candidates_request( - request: proto::FindSearchCandidates, - response: Response, - session: Session, -) -> Result<()> { - let project_id = ProjectId::from_proto(request.remote_entity_id()); - let host_connection_id = session - .db() - .await - .host_for_read_only_project_request(project_id, session.connection_id) - .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; - response.send(payload)?; - Ok(()) -} - /// forward a project request to the host. These requests are disallowed /// for guests. async fn forward_mutating_project_request( @@ -2320,6 +2348,16 @@ where Ok(()) } +async fn multi_lsp_query( + request: MultiLspQuery, + response: Response, + session: Session, +) -> Result<()> { + tracing::Span::current().record("multi_lsp_query_request", request.request_str()); + tracing::info!("multi_lsp_query message received"); + forward_mutating_project_request(request, response, session).await +} + /// Notify other participants that a new buffer has been created async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, @@ -2859,12 +2897,12 @@ async fn make_update_user_plan_message( } fn model_requests_limit( - plan: zed_llm_client::Plan, + plan: cloud_llm_client::Plan, feature_flags: &Vec, -) -> zed_llm_client::UsageLimit { +) -> cloud_llm_client::UsageLimit { match plan.model_requests_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == zed_llm_client::Plan::ZedProTrial + cloud_llm_client::UsageLimit::Limited(limit) => { + let limit = if plan == cloud_llm_client::Plan::ZedProTrial && feature_flags .iter() .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) @@ -2874,9 +2912,9 @@ fn model_requests_limit( limit }; - zed_llm_client::UsageLimit::Limited(limit) + cloud_llm_client::UsageLimit::Limited(limit) } - zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited, + cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited, } } @@ -2886,21 +2924,21 @@ fn subscription_usage_to_proto( feature_flags: &Vec, ) -> proto::SubscriptionUsage { let plan = match plan { - proto::Plan::Free => zed_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + proto::Plan::Free => cloud_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, }; proto::SubscriptionUsage { model_requests_usage_amount: usage.model_requests as u32, model_requests_usage_limit: Some(proto::UsageLimit { variant: Some(match model_requests_limit(plan, feature_flags) { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2908,12 +2946,12 @@ fn subscription_usage_to_proto( edit_predictions_usage_amount: usage.edit_predictions as u32, edit_predictions_usage_limit: Some(proto::UsageLimit { variant: Some(match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2926,21 +2964,21 @@ fn make_default_subscription_usage( feature_flags: &Vec, ) -> proto::SubscriptionUsage { let plan = match plan { - proto::Plan::Free => zed_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + proto::Plan::Free => cloud_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, }; proto::SubscriptionUsage { model_requests_usage_amount: 0, model_requests_usage_limit: Some(proto::UsageLimit { variant: Some(match model_requests_limit(plan, feature_flags) { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2948,12 +2986,12 @@ fn make_default_subscription_usage( edit_predictions_usage_amount: 0, edit_predictions_usage_limit: Some(proto::UsageLimit { variant: Some(match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 19e410de5b..8d5d076780 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -38,12 +38,12 @@ fn room_participants(room: &Entity, cx: &mut TestAppContext) -> RoomPartic let mut remote = room .remote_participants() .values() - .map(|participant| participant.user.github_login.clone()) + .map(|participant| participant.user.github_login.clone().to_string()) .collect::>(); let mut pending = room .pending_participants() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect::>(); remote.sort(); pending.sort(); diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 73ab2b8167..1d28c7f6ef 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -296,19 +296,28 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + trigger_characters: Some(vec![".".to_string()]), + resolve_provider: Some(true), + ..lsp::CompletionOptions::default() + }), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - completion_provider: Some(lsp::CompletionOptions { - trigger_characters: Some(vec![".".to_string()]), - resolve_provider: Some(true), - ..Default::default() - }), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -566,11 +575,14 @@ async fn test_collaborating_with_code_actions( cx_b.update(editor::init); - // Set up a fake language server. client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a .language_registry() .register_fake_lsp("Rust", FakeLspAdapter::default()); + client_b.language_registry().add(rust_lang()); + client_b + .language_registry() + .register_fake_lsp("Rust", FakeLspAdapter::default()); client_a .fs() @@ -775,19 +787,27 @@ async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut T cx_b.update(editor::init); - // Set up a fake language server. + let capabilities = lsp::ServerCapabilities { + rename_provider: Some(lsp::OneOf::Right(lsp::RenameOptions { + prepare_provider: Some(true), + work_done_progress_options: Default::default(), + })), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - rename_provider: Some(lsp::OneOf::Right(lsp::RenameOptions { - prepare_provider: Some(true), - work_done_progress_options: Default::default(), - })), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -818,6 +838,8 @@ async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut T .downcast::() .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); // Move cursor to a location that can be renamed. let prepare_rename = editor_b.update_in(cx_b, |editor, window, cx| { @@ -1055,7 +1077,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_a.read_with(cx_a, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1072,7 +1094,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); }); executor.advance_clock(SERVER_PROGRESS_THROTTLE_TIMEOUT); @@ -1089,7 +1111,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_a.read_with(cx_a, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1099,7 +1121,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1422,18 +1444,27 @@ async fn test_on_input_format_from_guest_to_host( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + document_on_type_formatting_provider: Some(lsp::DocumentOnTypeFormattingOptions { + first_trigger_character: ":".to_string(), + more_trigger_character: Some(vec![">".to_string()]), + }), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - document_on_type_formatting_provider: Some(lsp::DocumentOnTypeFormattingOptions { - first_trigger_character: ":".to_string(), - more_trigger_character: Some(vec![">".to_string()]), - }), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -1588,16 +1619,24 @@ async fn test_mutual_editor_inlay_hint_cache_update( }); }); + let capabilities = lsp::ServerCapabilities { + inlay_hint_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - inlay_hint_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -1830,16 +1869,24 @@ async fn test_inlay_hint_refresh_is_forwarded( }); }); + let capabilities = lsp::ServerCapabilities { + inlay_hint_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - inlay_hint_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -2004,15 +2051,23 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo }); }); + let capabilities = lsp::ServerCapabilities { + color_provider: Some(lsp::ColorProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - color_provider: Some(lsp::ColorProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, ..FakeLspAdapter::default() }, ); @@ -2063,6 +2118,8 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let requests_made = Arc::new(AtomicUsize::new(0)); let closure_requests_made = Arc::clone(&requests_made); @@ -2264,24 +2321,32 @@ async fn test_lsp_pull_diagnostics( cx_a.update(editor::init); cx_b.update(editor::init); + let capabilities = lsp::ServerCapabilities { + diagnostic_provider: Some(lsp::DiagnosticServerCapabilities::Options( + lsp::DiagnosticOptions { + identifier: Some("test-pulls".to_string()), + inter_file_dependencies: true, + workspace_diagnostics: true, + work_done_progress_options: lsp::WorkDoneProgressOptions { + work_done_progress: None, + }, + }, + )), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - diagnostic_provider: Some(lsp::DiagnosticServerCapabilities::Options( - lsp::DiagnosticOptions { - identifier: Some("test-pulls".to_string()), - inter_file_dependencies: true, - workspace_diagnostics: true, - work_done_progress_options: lsp::WorkDoneProgressOptions { - work_done_progress: None, - }, - }, - )), - ..lsp::ServerCapabilities::default() - }, + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, ..FakeLspAdapter::default() }, ); @@ -2334,6 +2399,8 @@ async fn test_lsp_pull_diagnostics( .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let expected_push_diagnostic_main_message = "pushed main diagnostic"; let expected_push_diagnostic_lib_message = "pushed lib diagnostic"; let expected_pull_diagnostic_main_message = "pulled main diagnostic"; @@ -2689,6 +2756,7 @@ async fn test_lsp_pull_diagnostics( .unwrap() .downcast::() .unwrap(); + cx_b.run_until_parked(); pull_diagnostics_handle.next().await.unwrap(); assert_eq!( diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 9795c27574..5a2c40b890 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -842,7 +842,7 @@ async fn test_client_disconnecting_from_room( // Allow user A to reconnect to the server. server.allow_connections(); - executor.advance_clock(RECEIVE_TIMEOUT); + executor.advance_clock(RECONNECT_TIMEOUT); // Call user B again from client A. active_call_a @@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections( client_b1.disconnect(&cx_b1.to_async()); executor.advance_clock(RECEIVE_TIMEOUT); client_b1 - .authenticate_and_connect(false, &cx_b1.to_async()) + .connect(false, &cx_b1.to_async()) .await .into_response() .unwrap(); @@ -1358,7 +1358,7 @@ async fn test_calls_on_multiple_connections( // User A reconnects automatically, then calls user B again. server.allow_connections(); - executor.advance_clock(RECEIVE_TIMEOUT); + executor.advance_clock(RECONNECT_TIMEOUT); active_call_a .update(cx_a, |call, cx| { call.invite(client_b1.user_id().unwrap(), None, cx) @@ -1667,7 +1667,7 @@ async fn test_project_reconnect( // Client A reconnects. Their project is re-shared, and client B re-joins it. server.allow_connections(); client_a - .authenticate_and_connect(false, &cx_a.to_async()) + .connect(false, &cx_a.to_async()) .await .into_response() .unwrap(); @@ -1796,7 +1796,7 @@ async fn test_project_reconnect( // Client B reconnects. They re-join the room and the remaining shared project. server.allow_connections(); client_b - .authenticate_and_connect(false, &cx_b.to_async()) + .connect(false, &cx_b.to_async()) .await .into_response() .unwrap(); @@ -1881,7 +1881,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_a.user_id().unwrap(), - github_login: "user_a".to_string(), + github_login: "user_a".into(), avatar_uri: "avatar_a".into(), name: None, }), @@ -1900,7 +1900,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_b.user_id().unwrap(), - github_login: "user_b".to_string(), + github_login: "user_b".into(), avatar_uri: "avatar_b".into(), name: None, }), @@ -4778,10 +4778,27 @@ async fn test_definition( .await; let active_call_a = cx_a.read(ActiveCall::global); - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); + let capabilities = lsp::ServerCapabilities { + definition_provider: Some(OneOf::Left(true)), + type_definition_provider: Some(lsp::TypeDefinitionProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); client_a .fs() @@ -4827,13 +4844,19 @@ async fn test_definition( ))) }, ); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let definitions_1 = project_b .update(cx_b, |p, cx| p.definitions(&buffer_b, 23, cx)) .await .unwrap(); cx_b.read(|cx| { - assert_eq!(definitions_1.len(), 1); + assert_eq!( + definitions_1.len(), + 1, + "Unexpected definitions: {definitions_1:?}" + ); assert_eq!(project_b.read(cx).worktrees(cx).count(), 2); let target_buffer = definitions_1[0].target.buffer.read(cx); assert_eq!( @@ -4901,7 +4924,11 @@ async fn test_definition( .await .unwrap(); cx_b.read(|cx| { - assert_eq!(type_definitions.len(), 1); + assert_eq!( + type_definitions.len(), + 1, + "Unexpected type definitions: {type_definitions:?}" + ); let target_buffer = type_definitions[0].target.buffer.read(cx); assert_eq!(target_buffer.text(), "type T2 = usize;"); assert_eq!( @@ -4925,16 +4952,26 @@ async fn test_references( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + references_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { name: "my-fake-lsp-adapter", - capabilities: lsp::ServerCapabilities { - references_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: "my-fake-lsp-adapter", + capabilities: capabilities, + ..FakeLspAdapter::default() }, ); @@ -4989,6 +5026,8 @@ async fn test_references( } } }); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let references = project_b.update(cx_b, |p, cx| p.references(&buffer_b, 7, cx)); @@ -4996,7 +5035,7 @@ async fn test_references( executor.run_until_parked(); project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "my-fake-lsp-adapter"); + assert_eq!(status.name.0, "my-fake-lsp-adapter"); assert_eq!( status.pending_work.values().next().unwrap().message, Some("Finding references...".into()) @@ -5054,7 +5093,7 @@ async fn test_references( executor.run_until_parked(); project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "my-fake-lsp-adapter"); + assert_eq!(status.name.0, "my-fake-lsp-adapter"); assert_eq!( status.pending_work.values().next().unwrap().message, Some("Finding references...".into()) @@ -5204,10 +5243,26 @@ async fn test_document_highlights( ) .await; - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); client_a.language_registry().add(rust_lang()); + let capabilities = lsp::ServerCapabilities { + document_highlight_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); let (project_a, worktree_id) = client_a.build_local_project(path!("/root-1"), cx_a).await; let project_id = active_call_a @@ -5256,6 +5311,8 @@ async fn test_document_highlights( ])) }, ); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let highlights = project_b .update(cx_b, |p, cx| p.document_highlights(&buffer_b, 34, cx)) @@ -5306,30 +5363,49 @@ async fn test_lsp_hover( client_a.language_registry().add(rust_lang()); let language_server_names = ["rust-analyzer", "CrabLang-ls"]; + let capabilities_1 = lsp::ServerCapabilities { + hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; + let capabilities_2 = lsp::ServerCapabilities { + hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; let mut language_servers = [ client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: "rust-analyzer", - capabilities: lsp::ServerCapabilities { - hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + name: language_server_names[0], + capabilities: capabilities_1.clone(), ..FakeLspAdapter::default() }, ), client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: "CrabLang-ls", - capabilities: lsp::ServerCapabilities { - hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + name: language_server_names[1], + capabilities: capabilities_2.clone(), ..FakeLspAdapter::default() }, ), ]; + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: language_server_names[0], + capabilities: capabilities_1, + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: language_server_names[1], + capabilities: capabilities_2, + ..FakeLspAdapter::default() + }, + ); let (project_a, worktree_id) = client_a.build_local_project(path!("/root-1"), cx_a).await; let project_id = active_call_a @@ -5423,6 +5499,8 @@ async fn test_lsp_hover( unexpected => panic!("Unexpected server name: {unexpected}"), } } + cx_a.run_until_parked(); + cx_b.run_until_parked(); // Request hover information as the guest. let mut hovers = project_b @@ -5605,10 +5683,26 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + definition_provider: Some(OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); client_a .fs() @@ -5649,6 +5743,8 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( let definitions; let buffer_b2; if rng.r#gen() { + cx_a.run_until_parked(); + cx_b.run_until_parked(); definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx)); (buffer_b2, _) = project_b .update(cx_b, |p, cx| { @@ -5663,11 +5759,17 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( }) .await .unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx)); } let definitions = definitions.await.unwrap(); - assert_eq!(definitions.len(), 1); + assert_eq!( + definitions.len(), + 1, + "Unexpected definitions: {definitions:?}" + ); assert_eq!(definitions[0].target.buffer, buffer_b2); } @@ -5738,7 +5840,7 @@ async fn test_contacts( server.allow_connections(); client_c - .authenticate_and_connect(false, &cx_c.to_async()) + .connect(false, &cx_c.to_async()) .await .into_response() .unwrap(); @@ -6079,7 +6181,7 @@ async fn test_contacts( .iter() .map(|contact| { ( - contact.user.github_login.clone(), + contact.user.github_login.clone().to_string(), if contact.online { "online" } else { "offline" }, if contact.busy { "busy" } else { "free" }, ) @@ -6269,7 +6371,7 @@ async fn test_contact_requests( client.disconnect(&cx.to_async()); client.clear_contacts(cx).await; client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs index 4e64b5526b..9bf906694e 100644 --- a/crates/collab/src/tests/notification_tests.rs +++ b/crates/collab/src/tests/notification_tests.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use gpui::{BackgroundExecutor, TestAppContext}; use notifications::NotificationEvent; use parking_lot::Mutex; +use pretty_assertions::assert_eq; use rpc::{Notification, proto}; use crate::tests::TestServer; @@ -17,6 +18,9 @@ async fn test_notifications( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + // Wait for authentication/connection to Collab to be established. + executor.run_until_parked(); + let notification_events_a = Arc::new(Mutex::new(Vec::new())); let notification_events_b = Arc::new(Mutex::new(Vec::new())); client_a.notification_store().update(cx_a, |_, cx| { diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index ab84e02b19..5fcc622fc1 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -8,6 +8,7 @@ use crate::{ use anyhow::anyhow; use call::ActiveCall; use channel::{ChannelBuffer, ChannelStore}; +use client::test::{make_get_authenticated_user_response, parse_authorization_header}; use client::{ self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore, proto::PeerId, @@ -20,7 +21,7 @@ use fs::FakeFs; use futures::{StreamExt as _, channel::oneshot}; use git::GitHostingProviderRegistry; use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext}; -use http_client::FakeHttpClient; +use http_client::{FakeHttpClient, Method}; use language::LanguageRegistry; use node_runtime::NodeRuntime; use notifications::NotificationStore; @@ -161,6 +162,8 @@ impl TestServer { } pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + const ACCESS_TOKEN: &str = "the-token"; + let fs = FakeFs::new(cx.executor()); cx.update(|cx| { @@ -175,7 +178,7 @@ impl TestServer { }); let clock = Arc::new(FakeSystemClock::new()); - let http = FakeHttpClient::with_404_response(); + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await { user.id @@ -197,6 +200,47 @@ impl TestServer { .expect("creating user failed") .user_id }; + + let http = FakeHttpClient::create({ + let name = name.to_string(); + move |req| { + let name = name.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: user_id.to_proto(), + access_token: ACCESS_TOKEN.into(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + user_id.0, name, + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } + } + } + }); + let client_name = name.to_string(); let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx)); let server = self.server.clone(); @@ -208,11 +252,10 @@ impl TestServer { .unwrap() .set_id(user_id.to_proto()) .override_authenticate(move |cx| { - let access_token = "the-token".to_string(); cx.spawn(async move |_| { Ok(Credentials { user_id: user_id.to_proto(), - access_token, + access_token: ACCESS_TOKEN.into(), }) }) }) @@ -221,7 +264,7 @@ impl TestServer { credentials, &Credentials { user_id: user_id.0 as u64, - access_token: "the-token".into() + access_token: ACCESS_TOKEN.into(), } ); @@ -256,6 +299,7 @@ impl TestServer { ZedVersion(SemanticVersion::new(1, 0, 0)), None, None, + None, Some(connection_id_tx), Executor::Deterministic(cx.background_executor().clone()), None, @@ -318,7 +362,7 @@ impl TestServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -691,17 +735,17 @@ impl TestClient { current: store .contacts() .iter() - .map(|contact| contact.user.github_login.clone()) + .map(|contact| contact.user.github_login.clone().to_string()) .collect(), outgoing_requests: store .outgoing_contact_requests() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect(), incoming_requests: store .incoming_contact_requests() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect(), }) } diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 3e2d813f1b..3a9b568264 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -1162,7 +1162,7 @@ impl Panel for ChatPanel { } fn icon(&self, _window: &Window, cx: &App) -> Option { - self.enabled(cx).then(|| ui::IconName::MessageBubbles) + self.enabled(cx).then(|| ui::IconName::Chat) } fn icon_tooltip(&self, _: &Window, _: &App) -> Option<&'static str> { diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 4d5973481e..689591df12 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -940,7 +940,7 @@ impl CollabPanel { room.read(cx).local_participant().role == proto::ChannelRole::Admin }); - ListItem::new(SharedString::from(user.github_login.clone())) + ListItem::new(user.github_login.clone()) .start_slot(Avatar::new(user.avatar_uri.clone())) .child(Label::new(user.github_login.clone())) .toggle_state(is_selected) @@ -1124,7 +1124,7 @@ impl CollabPanel { .relative() .gap_1() .child(render_tree_branch(false, false, window, cx)) - .child(IconButton::new(0, IconName::MessageBubbles)) + .child(IconButton::new(0, IconName::Chat)) .children(has_messages_notification.then(|| { div() .w_1p5() @@ -2331,7 +2331,7 @@ impl CollabPanel { let client = this.client.clone(); cx.spawn_in(window, async move |_, cx| { client - .authenticate_and_connect(true, &cx) + .connect(true, &cx) .await .into_response() .notify_async_err(cx); @@ -2583,7 +2583,7 @@ impl CollabPanel { ) -> impl IntoElement { let online = contact.online; let busy = contact.busy || calling; - let github_login = SharedString::from(contact.user.github_login.clone()); + let github_login = contact.user.github_login.clone(); let item = ListItem::new(github_login.clone()) .indent_level(1) .indent_step_size(px(20.)) @@ -2662,7 +2662,7 @@ impl CollabPanel { is_selected: bool, cx: &mut Context, ) -> impl IntoElement { - let github_login = SharedString::from(user.github_login.clone()); + let github_login = user.github_login.clone(); let user_id = user.id; let is_response_pending = self.user_store.read(cx).is_contact_request_pending(user); let color = if is_response_pending { @@ -2923,7 +2923,7 @@ impl CollabPanel { .gap_1() .px_1() .child( - IconButton::new("channel_chat", IconName::MessageBubbles) + IconButton::new("channel_chat", IconName::Chat) .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) @@ -2939,7 +2939,7 @@ impl CollabPanel { .visible_on_hover(""), ) .child( - IconButton::new("channel_notes", IconName::File) + IconButton::new("channel_notes", IconName::FileText) .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index fba8f66c2d..c3e834b645 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -634,13 +634,13 @@ impl Render for NotificationPanel { .child(Icon::new(IconName::Envelope)), ) .map(|this| { - if self.client.user_id().is_none() { + if !self.client.status().borrow().is_connected() { this.child( v_flex() .gap_2() .p_4() .child( - Button::new("sign_in_prompt_button", "Sign in") + Button::new("connect_prompt_button", "Connect") .icon_color(Color::Muted) .icon(IconName::Github) .icon_position(IconPosition::Start) @@ -652,10 +652,7 @@ impl Render for NotificationPanel { let client = client.clone(); window .spawn(cx, async move |cx| { - match client - .authenticate_and_connect(true, &cx) - .await - { + match client.connect(true, &cx).await { util::ConnectionResult::Timeout => { log::error!("Connection timeout"); } @@ -673,7 +670,7 @@ impl Render for NotificationPanel { ) .child( div().flex().w_full().items_center().child( - Label::new("Sign in to view notifications.") + Label::new("Connect to view notifications.") .color(Color::Muted) .size(LabelSize::Small), ), diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index ff4d79c07d..65283afa87 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -158,6 +158,7 @@ impl Client { pub fn stdio( server_id: ContextServerId, binary: ModelContextServerBinary, + working_directory: &Option, cx: AsyncApp, ) -> Result { log::info!( @@ -172,7 +173,7 @@ impl Client { .map(|name| name.to_string_lossy().to_string()) .unwrap_or_else(String::new); - let transport = Arc::new(StdioTransport::new(binary, &cx)?); + let transport = Arc::new(StdioTransport::new(binary, working_directory, &cx)?); Self::new(server_id, server_name.into(), transport, cx) } @@ -440,14 +441,12 @@ impl Client { Ok(()) } - #[allow(unused)] - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { - self.notification_handlers - .lock() - .insert(method, Box::new(f)); + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { + self.notification_handlers.lock().insert(method, f); } } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index f2517feb27..34fa29678d 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -53,7 +53,7 @@ impl std::fmt::Debug for ContextServerCommand { } enum ContextServerTransport { - Stdio(ContextServerCommand), + Stdio(ContextServerCommand, Option), Custom(Arc), } @@ -64,11 +64,18 @@ pub struct ContextServer { } impl ContextServer { - pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self { + pub fn stdio( + id: ContextServerId, + command: ContextServerCommand, + working_directory: Option>, + ) -> Self { Self { id, client: RwLock::new(None), - configuration: ContextServerTransport::Stdio(command), + configuration: ContextServerTransport::Stdio( + command, + working_directory.map(|directory| directory.to_path_buf()), + ), } } @@ -88,15 +95,36 @@ impl ContextServer { self.client.read().clone() } - pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - let client = match &self.configuration { - ContextServerTransport::Stdio(command) => Client::stdio( + pub async fn start(&self, cx: &AsyncApp) -> Result<()> { + self.initialize(self.new_client(cx)?).await + } + + /// Starts the context server, making sure handlers are registered before initialization happens + pub async fn start_with_handlers( + &self, + notification_handlers: Vec<( + &'static str, + Box, + )>, + cx: &AsyncApp, + ) -> Result<()> { + let client = self.new_client(cx)?; + for (method, handler) in notification_handlers { + client.on_notification(method, handler); + } + self.initialize(client).await + } + + fn new_client(&self, cx: &AsyncApp) -> Result { + Ok(match &self.configuration { + ContextServerTransport::Stdio(command, working_directory) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { executable: Path::new(&command.path).to_path_buf(), args: command.args.clone(), env: command.env.clone(), }, + working_directory, cx.clone(), )?, ContextServerTransport::Custom(transport) => Client::new( @@ -105,8 +133,7 @@ impl ContextServer { transport.clone(), cx.clone(), )?, - }; - self.initialize(client).await + }) } async fn initialize(&self, client: Client) -> Result<()> { diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 34e3a9a78c..0e85fb2129 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -83,14 +83,18 @@ impl McpServer { } pub fn add_tool(&mut self, tool: T) { - let output_schema = schemars::schema_for!(T::Output); - let unit_schema = schemars::schema_for!(()); + let mut settings = schemars::generate::SchemaSettings::draft07(); + settings.inline_subschemas = true; + let mut generator = settings.into_generator(); + + let output_schema = generator.root_schema_for::(); + let unit_schema = generator.root_schema_for::(); let registered_tool = RegisteredTool { tool: Tool { name: T::NAME.into(), description: Some(tool.description().into()), - input_schema: schemars::schema_for!(T::Input).into(), + input_schema: generator.root_schema_for::().into(), output_schema: if output_schema == unit_schema { None } else { diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 9ccbc8a553..5355f20f62 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -115,10 +115,11 @@ impl InitializedContextServerProtocol { self.inner.notify(T::METHOD, params) } - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { self.inner.on_notification(method, f); } } diff --git a/crates/context_server/src/transport/stdio_transport.rs b/crates/context_server/src/transport/stdio_transport.rs index 56d0240fa5..443b8c16f1 100644 --- a/crates/context_server/src/transport/stdio_transport.rs +++ b/crates/context_server/src/transport/stdio_transport.rs @@ -1,3 +1,4 @@ +use std::path::PathBuf; use std::pin::Pin; use anyhow::{Context as _, Result}; @@ -22,7 +23,11 @@ pub struct StdioTransport { } impl StdioTransport { - pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result { + pub fn new( + binary: ModelContextServerBinary, + working_directory: &Option, + cx: &AsyncApp, + ) -> Result { let mut command = util::command::new_smol_command(&binary.executable); command .args(&binary.args) @@ -32,6 +37,10 @@ impl StdioTransport { .stderr(std::process::Stdio::piped()) .kill_on_drop(true); + if let Some(working_directory) = working_directory { + command.current_dir(working_directory); + } + let mut server = command.spawn().with_context(|| { format!( "failed to spawn command. (path={:?}, args={:?})", diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 234875d420..0fc119f311 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -34,7 +34,7 @@ fs.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true language.workspace = true log.workspace = true lsp.workspace = true @@ -46,6 +46,7 @@ project.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +sum_tree.workspace = true task.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index e11242cb15..49ae2b9d9c 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -6,7 +6,6 @@ mod sign_in; use crate::sign_in::initiate_sign_in_within_workspace; use ::fs::Fs; use anyhow::{Context as _, Result, anyhow}; -use client::DisableAiSettings; use collections::{HashMap, HashSet}; use command_palette_hooks::CommandPaletteFilter; use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared}; @@ -24,6 +23,7 @@ use language::{ use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName}; use node_runtime::NodeRuntime; use parking_lot::Mutex; +use project::DisableAiSettings; use request::StatusNotification; use serde_json::json; use settings::Settings; @@ -39,6 +39,7 @@ use std::{ path::{Path, PathBuf}, sync::Arc, }; +use sum_tree::Dimensions; use util::{ResultExt, fs::remove_matching}; use workspace::Workspace; @@ -85,45 +86,13 @@ pub fn init( move |cx| Copilot::start(new_server_id, fs, node_runtime, cx) }); Copilot::set_global(copilot.clone(), cx); - cx.observe(&copilot, |handle, cx| { - let copilot_action_types = [ - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - ]; - let copilot_auth_action_types = [TypeId::of::()]; - let copilot_no_auth_action_types = [TypeId::of::()]; - let status = handle.read(cx).status(); - - let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; - let filter = CommandPaletteFilter::global_mut(cx); - - if is_ai_disabled { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.hide_action_types(&copilot_no_auth_action_types); - } else { - match status { - Status::Disabled => { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.hide_action_types(&copilot_no_auth_action_types); - } - Status::Authorized => { - filter.hide_action_types(&copilot_no_auth_action_types); - filter.show_action_types( - copilot_action_types - .iter() - .chain(&copilot_auth_action_types), - ); - } - _ => { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.show_action_types(copilot_no_auth_action_types.iter()); - } - } + cx.observe(&copilot, |copilot, cx| { + copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); + }) + .detach(); + cx.observe_global::(|cx| { + if let Some(copilot) = Copilot::global(cx) { + copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); } }) .detach(); @@ -271,7 +240,7 @@ impl RegisteredBuffer { let new_snapshot = new_snapshot.clone(); async move { new_snapshot - .edits_since::<(PointUtf16, usize)>(&old_version) + .edits_since::>(&old_version) .map(|edit| { let edit_start = edit.new.start.0; let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); @@ -1131,6 +1100,44 @@ impl Copilot { cx.notify(); } } + + fn update_action_visibilities(&self, cx: &mut App) { + let signed_in_actions = [ + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + let auth_actions = [TypeId::of::()]; + let no_auth_actions = [TypeId::of::()]; + let status = self.status(); + + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let filter = CommandPaletteFilter::global_mut(cx); + + if is_ai_disabled { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.hide_action_types(&no_auth_actions); + } else { + match status { + Status::Disabled => { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.hide_action_types(&no_auth_actions); + } + Status::Authorized => { + filter.hide_action_types(&no_auth_actions); + filter.show_action_types(signed_in_actions.iter().chain(&auth_actions)); + } + _ => { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.show_action_types(no_auth_actions.iter()); + } + } + } + } } fn id_for_language(language: Option<&Arc>) -> String { diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 8dc04622f9..2a7225c4e3 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -1,7 +1,7 @@ use crate::{Completion, Copilot}; use anyhow::Result; +use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use gpui::{App, Context, Entity, EntityId, Task}; -use inline_completion::{Direction, EditPredictionProvider, InlineCompletion}; use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings}; use project::Project; use settings::Settings; @@ -210,7 +210,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { buffer: &Entity, cursor_position: language::Anchor, cx: &mut Context, - ) -> Option { + ) -> Option { let buffer_id = buffer.entity_id(); let buffer = buffer.read(cx); let completion = self.active_completion()?; @@ -241,7 +241,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { None } else { let position = cursor_position.bias_right(buffer); - Some(InlineCompletion { + Some(EditPrediction { id: None, edits: vec![(position..position, completion_text.into())], edit_preview: None, @@ -343,7 +343,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { assert!(editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); // Since we have both, the copilot suggestion is not shown inline assert_eq!(editor.text(cx), "one.\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.\ntwo\nthree\n"); @@ -355,7 +355,7 @@ mod tests { .unwrap() .detach(); assert!(!editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.completion_a\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.completion_a\ntwo\nthree\n"); }); @@ -389,7 +389,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Since only the copilot is available, it's shown inline assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.\ntwo\nthree\n"); @@ -400,7 +400,7 @@ mod tests { executor.run_until_parked(); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); }); @@ -418,25 +418,25 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); // Canceling should remove the active Copilot suggestion. editor.cancel(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.c\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); // After canceling, tabbing shouldn't insert the previously shown suggestion. editor.tab(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.c \ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c \ntwo\nthree\n"); // When undoing the previously active suggestion is shown again. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); }); @@ -444,25 +444,25 @@ mod tests { // If an edit occurs outside of this editor, the suggestion is still correctly interpolated. cx.update_buffer(|buffer, cx| buffer.edit([(5..5, "o")], None, cx)); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); // AcceptEditPrediction when there is an active suggestion inserts it. editor.accept_edit_prediction(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.copilot2\ntwo\nthree\n"); // When undoing the previously active suggestion is shown again. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); // Hide suggestion. editor.cancel(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.co\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); }); @@ -471,7 +471,7 @@ mod tests { // we won't make it visible. cx.update_buffer(|buffer, cx| buffer.edit([(6..6, "p")], None, cx)); cx.update_editor(|editor, _, cx| { - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.cop\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.cop\ntwo\nthree\n"); }); @@ -498,19 +498,19 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); assert_eq!(editor.text(cx), "fn foo() {\n \n}"); // Tabbing inside of leading whitespace inserts indentation without accepting the suggestion. editor.tab(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "fn foo() {\n \n}"); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); // Using AcceptEditPrediction again accepts the suggestion. editor.accept_edit_prediction(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "fn foo() {\n let x = 4;\n}"); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); }); @@ -575,17 +575,17 @@ mod tests { ); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Accepting the first word of the suggestion should only accept the first word and still show the rest. - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.copilot\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); // Accepting next word should accept the non-word and copilot suggestion should be gone - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); }); @@ -617,11 +617,11 @@ mod tests { ); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Accepting the first word (non-word) of the suggestion should only accept the first word and still show the rest. - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. \ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -629,8 +629,8 @@ mod tests { ); // Accepting next word should accept the next word and copilot suggestion should still exist - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. copilot\ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -638,8 +638,8 @@ mod tests { ); // Accepting the whitespace should accept the non-word/whitespaces with newline and copilot suggestion should be gone - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. copilot\n 456\ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -692,29 +692,29 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntw\nthree\n"); editor.backspace(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\nt\nthree\n"); editor.backspace(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\n\nthree\n"); // Deleting across the original suggestion range invalidates it. editor.backspace(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\nthree\n"); assert_eq!(editor.text(cx), "one\nthree\n"); // Undoing the deletion restores the suggestion. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\n\nthree\n"); }); @@ -775,7 +775,7 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); _ = editor.update(cx, |editor, _, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2 + a\n\n\n\nc = 3\nd = 4\n" @@ -797,7 +797,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([Point::new(4, 5)..Point::new(4, 5)]) }); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4\n" @@ -806,7 +806,7 @@ mod tests { // Type a character, ensuring we don't even try to interpolate the previous suggestion. editor.handle_input(" ", window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4 \n" @@ -817,7 +817,7 @@ mod tests { // Ensure the new suggestion is displayed when the debounce timeout expires. executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); _ = editor.update(cx, |editor, _, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4 + c\n" @@ -880,7 +880,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntw\nthree\n"); }); @@ -907,7 +907,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntwo\nthree\n"); }); @@ -934,7 +934,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion(),); + assert!(!editor.has_active_edit_prediction(),); assert_eq!(editor.text(cx), "one\ntwo.\nthree\n"); }); } @@ -1023,7 +1023,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |selections| { selections.select_ranges([Point::new(0, 0)..Point::new(0, 0)]) }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); @@ -1033,7 +1033,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([Point::new(5, 0)..Point::new(5, 0)]) }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); diff --git a/crates/crashes/Cargo.toml b/crates/crashes/Cargo.toml new file mode 100644 index 0000000000..641a97765a --- /dev/null +++ b/crates/crashes/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "crashes" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "GPL-3.0-or-later" + +[dependencies] +crash-handler.workspace = true +log.workspace = true +minidumper.workspace = true +paths.workspace = true +smol.workspace = true +workspace-hack.workspace = true + +[lints] +workspace = true + +[lib] +path = "src/crashes.rs" diff --git a/crates/inline_completion/LICENSE-GPL b/crates/crashes/LICENSE-GPL similarity index 100% rename from crates/inline_completion/LICENSE-GPL rename to crates/crashes/LICENSE-GPL diff --git a/crates/crashes/src/crashes.rs b/crates/crashes/src/crashes.rs new file mode 100644 index 0000000000..cfb4b57d5d --- /dev/null +++ b/crates/crashes/src/crashes.rs @@ -0,0 +1,172 @@ +use crash_handler::CrashHandler; +use log::info; +use minidumper::{Client, LoopAction, MinidumpBinary}; + +use std::{ + env, + fs::File, + io, + path::{Path, PathBuf}, + process::{self, Command}, + sync::{ + OnceLock, + atomic::{AtomicBool, Ordering}, + }, + thread, + time::Duration, +}; + +// set once the crash handler has initialized and the client has connected to it +pub static CRASH_HANDLER: AtomicBool = AtomicBool::new(false); +// set when the first minidump request is made to avoid generating duplicate crash reports +pub static REQUESTED_MINIDUMP: AtomicBool = AtomicBool::new(false); +const CRASH_HANDLER_TIMEOUT: Duration = Duration::from_secs(60); + +pub async fn init(id: String) { + let exe = env::current_exe().expect("unable to find ourselves"); + let zed_pid = process::id(); + // TODO: we should be able to get away with using 1 crash-handler process per machine, + // but for now we append the PID of the current process which makes it unique per remote + // server or interactive zed instance. This solves an issue where occasionally the socket + // used by the crash handler isn't destroyed correctly which causes it to stay on the file + // system and block further attempts to initialize crash handlers with that socket path. + let socket_name = paths::temp_dir().join(format!("zed-crash-handler-{zed_pid}")); + #[allow(unused)] + let server_pid = Command::new(exe) + .arg("--crash-handler") + .arg(&socket_name) + .spawn() + .expect("unable to spawn server process") + .id(); + info!("spawning crash handler process"); + + let mut elapsed = Duration::ZERO; + let retry_frequency = Duration::from_millis(100); + let mut maybe_client = None; + while maybe_client.is_none() { + if let Ok(client) = Client::with_name(socket_name.as_path()) { + maybe_client = Some(client); + info!("connected to crash handler process after {elapsed:?}"); + break; + } + elapsed += retry_frequency; + smol::Timer::after(retry_frequency).await; + } + let client = maybe_client.unwrap(); + client.send_message(1, id).unwrap(); // set session id on the server + + let client = std::sync::Arc::new(client); + let handler = crash_handler::CrashHandler::attach(unsafe { + let client = client.clone(); + crash_handler::make_crash_event(move |crash_context: &crash_handler::CrashContext| { + // only request a minidump once + let res = if REQUESTED_MINIDUMP + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + client.send_message(2, "mistakes were made").unwrap(); + client.ping().unwrap(); + client.request_dump(crash_context).is_ok() + } else { + true + }; + crash_handler::CrashEventResult::Handled(res) + }) + }) + .expect("failed to attach signal handler"); + + #[cfg(target_os = "linux")] + { + handler.set_ptracer(Some(server_pid)); + } + CRASH_HANDLER.store(true, Ordering::Release); + std::mem::forget(handler); + info!("crash handler registered"); + + loop { + client.ping().ok(); + smol::Timer::after(Duration::from_secs(10)).await; + } +} + +pub struct CrashServer { + session_id: OnceLock, +} + +impl minidumper::ServerHandler for CrashServer { + fn create_minidump_file(&self) -> Result<(File, PathBuf), io::Error> { + let err_message = "Need to send a message with the ID upon starting the crash handler"; + let dump_path = paths::logs_dir() + .join(self.session_id.get().expect(err_message)) + .with_extension("dmp"); + let file = File::create(&dump_path)?; + Ok((file, dump_path)) + } + + fn on_minidump_created(&self, result: Result) -> LoopAction { + match result { + Ok(mut md_bin) => { + use io::Write; + let _ = md_bin.file.flush(); + info!("wrote minidump to disk {:?}", md_bin.path); + } + Err(e) => { + info!("failed to write minidump: {:#}", e); + } + } + LoopAction::Exit + } + + fn on_message(&self, kind: u32, buffer: Vec) { + let message = String::from_utf8(buffer).expect("invalid utf-8"); + info!("kind: {kind}, message: {message}",); + if kind == 1 { + self.session_id + .set(message) + .expect("session id already initialized"); + } + } + + fn on_client_disconnected(&self, clients: usize) -> LoopAction { + info!("client disconnected, {clients} remaining"); + if clients == 0 { + LoopAction::Exit + } else { + LoopAction::Continue + } + } +} + +pub fn handle_panic() { + // wait 500ms for the crash handler process to start up + // if it's still not there just write panic info and no minidump + let retry_frequency = Duration::from_millis(100); + for _ in 0..5 { + if CRASH_HANDLER.load(Ordering::Acquire) { + log::error!("triggering a crash to generate a minidump..."); + #[cfg(target_os = "linux")] + CrashHandler.simulate_signal(crash_handler::Signal::Trap as u32); + #[cfg(not(target_os = "linux"))] + CrashHandler.simulate_exception(None); + break; + } + thread::sleep(retry_frequency); + } +} + +pub fn crash_server(socket: &Path) { + let Ok(mut server) = minidumper::Server::with_name(socket) else { + log::info!("Couldn't create socket, there may already be a running crash server"); + return; + }; + let ab = AtomicBool::new(false); + server + .run( + Box::new(CrashServer { + session_id: OnceLock::new(), + }), + &ab, + Some(CRASH_HANDLER_TIMEOUT), + ) + .expect("failed to run server"); +} diff --git a/crates/dap/src/client.rs b/crates/dap/src/client.rs index 86a15b2d8a..7b791450ec 100644 --- a/crates/dap/src/client.rs +++ b/crates/dap/src/client.rs @@ -295,7 +295,7 @@ mod tests { request: dap_types::StartDebuggingRequestArgumentsRequest::Launch, }, }, - Box::new(|_| panic!("Did not expect to hit this code path")), + Box::new(|_| {}), &mut cx.to_async(), ) .await diff --git a/crates/dap/src/transport.rs b/crates/dap/src/transport.rs index 6dadf1cf35..f9fbbfc842 100644 --- a/crates/dap/src/transport.rs +++ b/crates/dap/src/transport.rs @@ -883,6 +883,7 @@ impl FakeTransport { break Err(anyhow!("exit in response to request")); } }; + let success = response.success; let message = serde_json::to_string(&Message::Response(response)).unwrap(); @@ -893,6 +894,25 @@ impl FakeTransport { ) .await .unwrap(); + + if request.command == dap_types::requests::Initialize::COMMAND + && success + { + let message = serde_json::to_string(&Message::Event(Box::new( + dap_types::messages::Events::Initialized(Some( + Default::default(), + )), + ))) + .unwrap(); + writer + .write_all( + TransportDelegate::build_rpc_message(message) + .as_bytes(), + ) + .await + .unwrap(); + } + writer.flush().await.unwrap(); } } diff --git a/crates/dap_adapters/src/python.rs b/crates/dap_adapters/src/python.rs index aa64fea6ed..461ce6fbb3 100644 --- a/crates/dap_adapters/src/python.rs +++ b/crates/dap_adapters/src/python.rs @@ -1,38 +1,36 @@ use crate::*; use anyhow::Context as _; use dap::{DebugRequest, StartDebuggingRequestArguments, adapters::DebugTaskDefinition}; +use fs::RemoveOptions; +use futures::{StreamExt, TryStreamExt}; +use gpui::http_client::AsyncBody; use gpui::{AsyncApp, SharedString}; use json_dotpath::DotPaths; use language::LanguageName; use paths::debug_adapters_dir; use serde_json::Value; +use smol::fs::File; +use smol::io::AsyncReadExt; use smol::lock::OnceCell; +use std::ffi::OsString; use std::net::Ipv4Addr; +use std::str::FromStr; use std::{ collections::HashMap, ffi::OsStr, path::{Path, PathBuf}, }; +use util::{ResultExt, maybe}; #[derive(Default)] pub(crate) struct PythonDebugAdapter { - python_venv_base: OnceCell, String>>, + debugpy_whl_base_path: OnceCell, String>>, } impl PythonDebugAdapter { const ADAPTER_NAME: &'static str = "Debugpy"; const DEBUG_ADAPTER_NAME: DebugAdapterName = DebugAdapterName(SharedString::new_static(Self::ADAPTER_NAME)); - const PYTHON_ADAPTER_IN_VENV: &'static str = if cfg!(target_os = "windows") { - "Scripts/python3" - } else { - "bin/python3" - }; - const ADAPTER_PATH: &'static str = if cfg!(target_os = "windows") { - "debugpy-venv/Scripts/debugpy-adapter" - } else { - "debugpy-venv/bin/debugpy-adapter" - }; const LANGUAGE_NAME: &'static str = "Python"; @@ -41,7 +39,6 @@ impl PythonDebugAdapter { port: u16, user_installed_path: Option<&Path>, user_args: Option>, - installed_in_venv: bool, ) -> Result> { let mut args = if let Some(user_installed_path) = user_installed_path { log::debug!( @@ -49,13 +46,11 @@ impl PythonDebugAdapter { user_installed_path.display() ); vec![user_installed_path.to_string_lossy().to_string()] - } else if installed_in_venv { - log::debug!("Using venv-installed debugpy"); - vec!["-m".to_string(), "debugpy.adapter".to_string()] } else { let adapter_path = paths::debug_adapters_dir().join(Self::DEBUG_ADAPTER_NAME.as_ref()); let path = adapter_path - .join(Self::ADAPTER_PATH) + .join("debugpy") + .join("adapter") .to_string_lossy() .into_owned(); log::debug!("Using pip debugpy adapter from: {path}"); @@ -96,68 +91,145 @@ impl PythonDebugAdapter { }) } - async fn ensure_venv(delegate: &dyn DapDelegate) -> Result> { - let python_path = Self::find_base_python(delegate) + async fn fetch_wheel(delegate: &Arc) -> Result, String> { + let system_python = Self::system_python_name(delegate) .await - .context("Could not find Python installation for DebugPy")?; - let work_dir = debug_adapters_dir().join(Self::ADAPTER_NAME); - let mut path = work_dir.clone(); - path.push("debugpy-venv"); - if !path.exists() { - util::command::new_smol_command(python_path) - .arg("-m") - .arg("venv") - .arg("debugpy-venv") - .current_dir(work_dir) - .spawn()? - .output() - .await?; + .ok_or_else(|| String::from("Could not find a Python installation"))?; + let command: &OsStr = system_python.as_ref(); + let download_dir = debug_adapters_dir().join(Self::ADAPTER_NAME).join("wheels"); + std::fs::create_dir_all(&download_dir).map_err(|e| e.to_string())?; + let installation_succeeded = util::command::new_smol_command(command) + .args([ + "-m", + "pip", + "download", + "debugpy", + "--only-binary=:all:", + "-d", + download_dir.to_string_lossy().as_ref(), + ]) + .output() + .await + .map_err(|e| format!("{e}"))? + .status + .success(); + if !installation_succeeded { + return Err("debugpy installation failed".into()); } - Ok(path.into()) + let wheel_path = std::fs::read_dir(&download_dir) + .map_err(|e| e.to_string())? + .find_map(|entry| { + entry.ok().filter(|e| { + e.file_type().is_ok_and(|typ| typ.is_file()) + && Path::new(&e.file_name()).extension() == Some("whl".as_ref()) + }) + }) + .ok_or_else(|| String::from("Did not find a .whl in {download_dir}"))?; + + util::archive::extract_zip( + &debug_adapters_dir().join(Self::ADAPTER_NAME), + File::open(&wheel_path.path()) + .await + .map_err(|e| e.to_string())?, + ) + .await + .map_err(|e| e.to_string())?; + + Ok(Arc::from(wheel_path.path())) } - // Find "baseline", user python version from which we'll create our own venv. - async fn find_base_python(delegate: &dyn DapDelegate) -> Option { - for path in ["python3", "python"] { - if let Some(path) = delegate.which(path.as_ref()).await { - return Some(path); + async fn maybe_fetch_new_wheel(delegate: &Arc) { + let latest_release = delegate + .http_client() + .get( + "https://pypi.org/pypi/debugpy/json", + AsyncBody::empty(), + false, + ) + .await + .log_err(); + maybe!(async move { + let response = latest_release.filter(|response| response.status().is_success())?; + + let mut output = String::new(); + response + .into_body() + .read_to_string(&mut output) + .await + .ok()?; + let as_json = serde_json::Value::from_str(&output).ok()?; + let latest_version = as_json.get("info").and_then(|info| { + info.get("version") + .and_then(|version| version.as_str()) + .map(ToOwned::to_owned) + })?; + let dist_info_dirname: OsString = format!("debugpy-{latest_version}.dist-info").into(); + let is_up_to_date = delegate + .fs() + .read_dir(&debug_adapters_dir().join(Self::ADAPTER_NAME)) + .await + .ok()? + .into_stream() + .any(async |entry| { + entry.is_ok_and(|e| e.file_name().is_some_and(|name| name == dist_info_dirname)) + }) + .await; + + if !is_up_to_date { + delegate + .fs() + .remove_dir( + &debug_adapters_dir().join(Self::ADAPTER_NAME), + RemoveOptions { + recursive: true, + ignore_if_not_exists: true, + }, + ) + .await + .ok()?; + Self::fetch_wheel(delegate).await.ok()?; } - } - None + Some(()) + }) + .await; } - async fn base_venv(&self, delegate: &dyn DapDelegate) -> Result, String> { - const BINARY_DIR: &str = if cfg!(target_os = "windows") { - "Scripts" - } else { - "bin" - }; - self.python_venv_base - .get_or_init(move || async move { - let venv_base = Self::ensure_venv(delegate) - .await - .map_err(|e| format!("{e}"))?; - let pip_path = venv_base.join(BINARY_DIR).join("pip3"); - let installation_succeeded = util::command::new_smol_command(pip_path.as_path()) - .arg("install") - .arg("debugpy") - .arg("-U") - .output() - .await - .map_err(|e| format!("{e}"))? - .status - .success(); - if !installation_succeeded { - return Err("debugpy installation failed".into()); - } - - Ok(venv_base) + async fn fetch_debugpy_whl( + &self, + delegate: &Arc, + ) -> Result, String> { + self.debugpy_whl_base_path + .get_or_init(|| async move { + Self::maybe_fetch_new_wheel(delegate).await; + Ok(Arc::from( + debug_adapters_dir() + .join(Self::ADAPTER_NAME) + .join("debugpy") + .join("adapter") + .as_ref(), + )) }) .await .clone() } + async fn system_python_name(delegate: &Arc) -> Option { + const BINARY_NAMES: [&str; 3] = ["python3", "python", "py"]; + let mut name = None; + + for cmd in BINARY_NAMES { + name = delegate + .which(OsStr::new(cmd)) + .await + .map(|path| path.to_string_lossy().to_string()); + if name.is_some() { + break; + } + } + name + } + async fn get_installed_binary( &self, delegate: &Arc, @@ -165,27 +237,14 @@ impl PythonDebugAdapter { user_installed_path: Option, user_args: Option>, python_from_toolchain: Option, - installed_in_venv: bool, ) -> Result { - const BINARY_NAMES: [&str; 3] = ["python3", "python", "py"]; let tcp_connection = config.tcp_connection.clone().unwrap_or_default(); let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?; let python_path = if let Some(toolchain) = python_from_toolchain { Some(toolchain) } else { - let mut name = None; - - for cmd in BINARY_NAMES { - name = delegate - .which(OsStr::new(cmd)) - .await - .map(|path| path.to_string_lossy().to_string()); - if name.is_some() { - break; - } - } - name + Self::system_python_name(delegate).await }; let python_command = python_path.context("failed to find binary path for Python")?; @@ -196,7 +255,6 @@ impl PythonDebugAdapter { port, user_installed_path.as_deref(), user_args, - installed_in_venv, ) .await?; @@ -618,63 +676,53 @@ impl DebugAdapter for PythonDebugAdapter { local_path.display() ); return self - .get_installed_binary( - delegate, - &config, - Some(local_path.clone()), - user_args, - None, - false, - ) + .get_installed_binary(delegate, &config, Some(local_path.clone()), user_args, None) .await; } + let base_path = config + .config + .get("cwd") + .and_then(|cwd| { + cwd.as_str() + .map(Path::new)? + .strip_prefix(delegate.worktree_root_path()) + .ok() + }) + .unwrap_or_else(|| "".as_ref()) + .into(); let toolchain = delegate .toolchain_store() .active_toolchain( delegate.worktree_id(), - Arc::from("".as_ref()), + base_path, language::LanguageName::new(Self::LANGUAGE_NAME), cx, ) .await; - if let Some(toolchain) = &toolchain { - if let Some(path) = Path::new(&toolchain.path.to_string()).parent() { - let debugpy_path = path.join("debugpy"); - if delegate.fs().is_file(&debugpy_path).await { - log::debug!( - "Found debugpy in toolchain environment: {}", - debugpy_path.display() - ); - return self - .get_installed_binary( - delegate, - &config, - None, - user_args, - Some(toolchain.path.to_string()), - true, - ) - .await; - } - } - } - let toolchain = self - .base_venv(&**delegate) + let debugpy_path = self + .fetch_debugpy_whl(delegate) .await - .map_err(|e| anyhow::anyhow!(e))? - .join(Self::PYTHON_ADAPTER_IN_VENV); + .map_err(|e| anyhow::anyhow!("{e}"))?; + if let Some(toolchain) = &toolchain { + log::debug!( + "Found debugpy in toolchain environment: {}", + debugpy_path.display() + ); + return self + .get_installed_binary( + delegate, + &config, + None, + user_args, + Some(toolchain.path.to_string()), + ) + .await; + } - self.get_installed_binary( - delegate, - &config, - None, - user_args, - Some(toolchain.to_string_lossy().into_owned()), - false, - ) - .await + self.get_installed_binary(delegate, &config, None, user_args, None) + .await } fn label_for_child_session(&self, args: &StartDebuggingRequestArguments) -> Option { @@ -689,6 +737,8 @@ impl DebugAdapter for PythonDebugAdapter { #[cfg(test)] mod tests { + use util::path; + use super::*; use std::{net::Ipv4Addr, path::PathBuf}; @@ -699,30 +749,24 @@ mod tests { // Case 1: User-defined debugpy path (highest precedence) let user_path = PathBuf::from("/custom/path/to/debugpy/src/debugpy/adapter"); - let user_args = PythonDebugAdapter::generate_debugpy_arguments( - &host, - port, - Some(&user_path), - None, - false, - ) - .await - .unwrap(); - - // Case 2: Venv-installed debugpy (uses -m debugpy.adapter) - let venv_args = - PythonDebugAdapter::generate_debugpy_arguments(&host, port, None, None, true) + let user_args = + PythonDebugAdapter::generate_debugpy_arguments(&host, port, Some(&user_path), None) .await .unwrap(); + // Case 2: Venv-installed debugpy (uses -m debugpy.adapter) + let venv_args = PythonDebugAdapter::generate_debugpy_arguments(&host, port, None, None) + .await + .unwrap(); + assert_eq!(user_args[0], "/custom/path/to/debugpy/src/debugpy/adapter"); assert_eq!(user_args[1], "--host=127.0.0.1"); assert_eq!(user_args[2], "--port=5678"); - assert_eq!(venv_args[0], "-m"); - assert_eq!(venv_args[1], "debugpy.adapter"); - assert_eq!(venv_args[2], "--host=127.0.0.1"); - assert_eq!(venv_args[3], "--port=5678"); + let expected_suffix = path!("debug_adapters/Debugpy/debugpy/adapter"); + assert!(venv_args[0].ends_with(expected_suffix)); + assert_eq!(venv_args[1], "--host=127.0.0.1"); + assert_eq!(venv_args[2], "--port=5678"); // The same cases, with arguments overridden by the user let user_args = PythonDebugAdapter::generate_debugpy_arguments( @@ -730,7 +774,6 @@ mod tests { port, Some(&user_path), Some(vec!["foo".into()]), - false, ) .await .unwrap(); @@ -739,7 +782,6 @@ mod tests { port, None, Some(vec!["foo".into()]), - true, ) .await .unwrap(); @@ -747,9 +789,8 @@ mod tests { assert!(user_args[0].ends_with("src/debugpy/adapter")); assert_eq!(user_args[1], "foo"); - assert_eq!(venv_args[0], "-m"); - assert_eq!(venv_args[1], "debugpy.adapter"); - assert_eq!(venv_args[2], "foo"); + assert!(venv_args[0].ends_with(expected_suffix)); + assert_eq!(venv_args[1], "foo"); // Note: Case 3 (GitHub-downloaded debugpy) is not tested since this requires mocking the Github API. } diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 1364aaf853..1bb84488e8 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -873,7 +873,7 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S editor.splice_inlays( &[], - vec![Inlay::inline_completion( + vec![Inlay::edit_prediction( post_inc(&mut next_inlay_id), snapshot.buffer_snapshot.anchor_before(position), format!("Test inlay {next_inlay_id}"), diff --git a/crates/docs_preprocessor/Cargo.toml b/crates/docs_preprocessor/Cargo.toml index a0df669abe..e46ceb18db 100644 --- a/crates/docs_preprocessor/Cargo.toml +++ b/crates/docs_preprocessor/Cargo.toml @@ -7,17 +7,19 @@ license = "GPL-3.0-or-later" [dependencies] anyhow.workspace = true -clap.workspace = true -mdbook = "0.4.40" +command_palette.workspace = true +gpui.workspace = true +# We are specifically pinning this version of mdbook, as later versions introduce issues with double-nested subdirectories. +# Ask @maxdeviant about this before bumping. +mdbook = "= 0.4.40" +regex.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true -regex.workspace = true util.workspace = true workspace-hack.workspace = true zed.workspace = true -gpui.workspace = true -command_palette.workspace = true +zlog.workspace = true [lints] workspace = true diff --git a/crates/docs_preprocessor/src/main.rs b/crates/docs_preprocessor/src/main.rs index 8eeeb6f0c5..1448f4cb52 100644 --- a/crates/docs_preprocessor/src/main.rs +++ b/crates/docs_preprocessor/src/main.rs @@ -1,14 +1,15 @@ -use anyhow::Result; -use clap::{Arg, ArgMatches, Command}; +use anyhow::{Context, Result}; use mdbook::BookItem; use mdbook::book::{Book, Chapter}; use mdbook::preprocess::CmdPreprocessor; use regex::Regex; use settings::KeymapFile; -use std::collections::HashSet; +use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; use std::io::{self, Read}; use std::process; use std::sync::LazyLock; +use util::paths::PathExt; static KEYMAP_MACOS: LazyLock = LazyLock::new(|| { load_keymap("keymaps/default-macos.json").expect("Failed to load MacOS keymap") @@ -20,60 +21,68 @@ static KEYMAP_LINUX: LazyLock = LazyLock::new(|| { static ALL_ACTIONS: LazyLock> = LazyLock::new(dump_all_gpui_actions); -pub fn make_app() -> Command { - Command::new("zed-docs-preprocessor") - .about("Preprocesses Zed Docs content to provide rich action & keybinding support and more") - .subcommand( - Command::new("supports") - .arg(Arg::new("renderer").required(true)) - .about("Check whether a renderer is supported by this preprocessor"), - ) -} +const FRONT_MATTER_COMMENT: &'static str = ""; fn main() -> Result<()> { - let matches = make_app().get_matches(); + zlog::init(); + zlog::init_output_stderr(); // call a zed:: function so everything in `zed` crate is linked and // all actions in the actual app are registered zed::stdout_is_a_pty(); + let args = std::env::args().skip(1).collect::>(); - if let Some(sub_args) = matches.subcommand_matches("supports") { - handle_supports(sub_args); - } else { - handle_preprocessing()?; + match args.get(0).map(String::as_str) { + Some("supports") => { + let renderer = args.get(1).expect("Required argument"); + let supported = renderer != "not-supported"; + if supported { + process::exit(0); + } else { + process::exit(1); + } + } + Some("postprocess") => handle_postprocessing()?, + _ => handle_preprocessing()?, } Ok(()) } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum Error { +enum PreprocessorError { ActionNotFound { action_name: String }, DeprecatedActionUsed { used: String, should_be: String }, + InvalidFrontmatterLine(String), } -impl Error { +impl PreprocessorError { fn new_for_not_found_action(action_name: String) -> Self { for action in &*ALL_ACTIONS { for alias in action.deprecated_aliases { if alias == &action_name { - return Error::DeprecatedActionUsed { + return PreprocessorError::DeprecatedActionUsed { used: action_name.clone(), should_be: action.name.to_string(), }; } } } - Error::ActionNotFound { + PreprocessorError::ActionNotFound { action_name: action_name.to_string(), } } } -impl std::fmt::Display for Error { +impl std::fmt::Display for PreprocessorError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Error::ActionNotFound { action_name } => write!(f, "Action not found: {}", action_name), - Error::DeprecatedActionUsed { used, should_be } => write!( + PreprocessorError::InvalidFrontmatterLine(line) => { + write!(f, "Invalid frontmatter line: {}", line) + } + PreprocessorError::ActionNotFound { action_name } => { + write!(f, "Action not found: {}", action_name) + } + PreprocessorError::DeprecatedActionUsed { used, should_be } => write!( f, "Deprecated action used: {} should be {}", used, should_be @@ -89,8 +98,9 @@ fn handle_preprocessing() -> Result<()> { let (_ctx, mut book) = CmdPreprocessor::parse_input(input.as_bytes())?; - let mut errors = HashSet::::new(); + let mut errors = HashSet::::new(); + handle_frontmatter(&mut book, &mut errors); template_and_validate_keybindings(&mut book, &mut errors); template_and_validate_actions(&mut book, &mut errors); @@ -108,19 +118,41 @@ fn handle_preprocessing() -> Result<()> { Ok(()) } -fn handle_supports(sub_args: &ArgMatches) -> ! { - let renderer = sub_args - .get_one::("renderer") - .expect("Required argument"); - let supported = renderer != "not-supported"; - if supported { - process::exit(0); - } else { - process::exit(1); - } +fn handle_frontmatter(book: &mut Book, errors: &mut HashSet) { + let frontmatter_regex = Regex::new(r"(?s)^\s*---(.*?)---").unwrap(); + for_each_chapter_mut(book, |chapter| { + let new_content = frontmatter_regex.replace(&chapter.content, |caps: ®ex::Captures| { + let frontmatter = caps[1].trim(); + let frontmatter = frontmatter.trim_matches(&[' ', '-', '\n']); + let mut metadata = HashMap::::default(); + for line in frontmatter.lines() { + let Some((name, value)) = line.split_once(':') else { + errors.insert(PreprocessorError::InvalidFrontmatterLine(format!( + "{}: {}", + chapter_breadcrumbs(&chapter), + line + ))); + continue; + }; + let name = name.trim(); + let value = value.trim(); + metadata.insert(name.to_string(), value.to_string()); + } + FRONT_MATTER_COMMENT.replace( + "{}", + &serde_json::to_string(&metadata).expect("Failed to serialize metadata"), + ) + }); + match new_content { + Cow::Owned(content) => { + chapter.content = content; + } + Cow::Borrowed(_) => {} + } + }); } -fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { +fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { let regex = Regex::new(r"\{#kb (.*?)\}").unwrap(); for_each_chapter_mut(book, |chapter| { @@ -128,7 +160,9 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { +fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet) { let regex = Regex::new(r"\{#action (.*?)\}").unwrap(); for_each_chapter_mut(book, |chapter| { @@ -152,7 +186,9 @@ fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet) { .replace_all(&chapter.content, |caps: ®ex::Captures| { let name = caps[1].trim(); let Some(action) = find_action_by_name(name) else { - errors.insert(Error::new_for_not_found_action(name.to_string())); + errors.insert(PreprocessorError::new_for_not_found_action( + name.to_string(), + )); return String::new(); }; format!("{}", &action.human_name) @@ -217,6 +253,13 @@ fn name_for_action(action_as_str: String) -> String { .unwrap_or(action_as_str) } +fn chapter_breadcrumbs(chapter: &Chapter) -> String { + let mut breadcrumbs = Vec::with_capacity(chapter.parent_names.len() + 1); + breadcrumbs.extend(chapter.parent_names.iter().map(String::as_str)); + breadcrumbs.push(chapter.name.as_str()); + format!("[{:?}] {}", chapter.source_path, breadcrumbs.join(" > ")) +} + fn load_keymap(asset_path: &str) -> Result { let content = util::asset_str::(asset_path); KeymapFile::parse(content.as_ref()) @@ -254,3 +297,126 @@ fn dump_all_gpui_actions() -> Vec { return actions; } + +fn handle_postprocessing() -> Result<()> { + let logger = zlog::scoped!("render"); + let mut ctx = mdbook::renderer::RenderContext::from_json(io::stdin())?; + let output = ctx + .config + .get_mut("output") + .expect("has output") + .as_table_mut() + .expect("output is table"); + let zed_html = output.remove("zed-html").expect("zed-html output defined"); + let default_description = zed_html + .get("default-description") + .expect("Default description not found") + .as_str() + .expect("Default description not a string") + .to_string(); + let default_title = zed_html + .get("default-title") + .expect("Default title not found") + .as_str() + .expect("Default title not a string") + .to_string(); + + output.insert("html".to_string(), zed_html); + mdbook::Renderer::render(&mdbook::renderer::HtmlHandlebars::new(), &ctx)?; + let ignore_list = ["toc.html"]; + + let root_dir = ctx.destination.clone(); + let mut files = Vec::with_capacity(128); + let mut queue = Vec::with_capacity(64); + queue.push(root_dir.clone()); + while let Some(dir) = queue.pop() { + for entry in std::fs::read_dir(&dir).context(dir.to_sanitized_string())? { + let Ok(entry) = entry else { + continue; + }; + let file_type = entry.file_type().context("Failed to determine file type")?; + if file_type.is_dir() { + queue.push(entry.path()); + } + if file_type.is_file() + && matches!( + entry.path().extension().and_then(std::ffi::OsStr::to_str), + Some("html") + ) + { + if ignore_list.contains(&&*entry.file_name().to_string_lossy()) { + zlog::info!(logger => "Ignoring {}", entry.path().to_string_lossy()); + } else { + files.push(entry.path()); + } + } + } + } + + zlog::info!(logger => "Processing {} `.html` files", files.len()); + let meta_regex = Regex::new(&FRONT_MATTER_COMMENT.replace("{}", "(.*)")).unwrap(); + for file in files { + let contents = std::fs::read_to_string(&file)?; + let mut meta_description = None; + let mut meta_title = None; + let contents = meta_regex.replace(&contents, |caps: ®ex::Captures| { + let metadata: HashMap = serde_json::from_str(&caps[1]).with_context(|| format!("JSON Metadata: {:?}", &caps[1])).expect("Failed to deserialize metadata"); + for (kind, content) in metadata { + match kind.as_str() { + "description" => { + meta_description = Some(content); + } + "title" => { + meta_title = Some(content); + } + _ => { + zlog::warn!(logger => "Unrecognized frontmatter key: {} in {:?}", kind, pretty_path(&file, &root_dir)); + } + } + } + String::new() + }); + let meta_description = meta_description.as_ref().unwrap_or_else(|| { + zlog::warn!(logger => "No meta description found for {:?}", pretty_path(&file, &root_dir)); + &default_description + }); + let page_title = extract_title_from_page(&contents, pretty_path(&file, &root_dir)); + let meta_title = meta_title.as_ref().unwrap_or_else(|| { + zlog::debug!(logger => "No meta title found for {:?}", pretty_path(&file, &root_dir)); + &default_title + }); + let meta_title = format!("{} | {}", page_title, meta_title); + zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir)); + let contents = contents.replace("#description#", meta_description); + let contents = TITLE_REGEX + .replace(&contents, |_: ®ex::Captures| { + format!("{}", meta_title) + }) + .to_string(); + // let contents = contents.replace("#title#", &meta_title); + std::fs::write(file, contents)?; + } + return Ok(()); + + fn pretty_path<'a>( + path: &'a std::path::PathBuf, + root: &'a std::path::PathBuf, + ) -> &'a std::path::Path { + &path.strip_prefix(&root).unwrap_or(&path) + } + const TITLE_REGEX: std::cell::LazyCell = + std::cell::LazyCell::new(|| Regex::new(r"\s*(.*?)\s*").unwrap()); + fn extract_title_from_page(contents: &str, pretty_path: &std::path::Path) -> String { + let title_tag_contents = &TITLE_REGEX + .captures(&contents) + .with_context(|| format!("Failed to find title in {:?}", pretty_path)) + .expect("Page has element")[1]; + let title = title_tag_contents + .trim() + .strip_suffix("- Zed") + .unwrap_or(title_tag_contents) + .trim() + .to_string(); + title + } +} diff --git a/crates/inline_completion/Cargo.toml b/crates/edit_prediction/Cargo.toml similarity index 82% rename from crates/inline_completion/Cargo.toml rename to crates/edit_prediction/Cargo.toml index 3a90875def..81c1e5dec2 100644 --- a/crates/inline_completion/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "inline_completion" +name = "edit_prediction" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,7 +9,7 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/inline_completion.rs" +path = "src/edit_prediction.rs" [dependencies] client.workspace = true diff --git a/crates/inline_completion_button/LICENSE-GPL b/crates/edit_prediction/LICENSE-GPL similarity index 100% rename from crates/inline_completion_button/LICENSE-GPL rename to crates/edit_prediction/LICENSE-GPL diff --git a/crates/inline_completion/src/inline_completion.rs b/crates/edit_prediction/src/edit_prediction.rs similarity index 95% rename from crates/inline_completion/src/inline_completion.rs rename to crates/edit_prediction/src/edit_prediction.rs index c8f35bf16a..fd4e9bb21d 100644 --- a/crates/inline_completion/src/inline_completion.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -7,7 +7,7 @@ use project::Project; // TODO: Find a better home for `Direction`. // -// This should live in an ancestor crate of `editor` and `inline_completion`, +// This should live in an ancestor crate of `editor` and `edit_prediction`, // but at time of writing there isn't an obvious spot. #[derive(Copy, Clone, PartialEq, Eq)] pub enum Direction { @@ -16,7 +16,7 @@ pub enum Direction { } #[derive(Clone)] -pub struct InlineCompletion { +pub struct EditPrediction { /// The ID of the completion, if it has one. pub id: Option<SharedString>, pub edits: Vec<(Range<language::Anchor>, String)>, @@ -102,10 +102,10 @@ pub trait EditPredictionProvider: 'static + Sized { buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut Context<Self>, - ) -> Option<InlineCompletion>; + ) -> Option<EditPrediction>; } -pub trait InlineCompletionProviderHandle { +pub trait EditPredictionProviderHandle { fn name(&self) -> &'static str; fn display_name(&self) -> &'static str; fn is_enabled( @@ -143,10 +143,10 @@ pub trait InlineCompletionProviderHandle { buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut App, - ) -> Option<InlineCompletion>; + ) -> Option<EditPrediction>; } -impl<T> InlineCompletionProviderHandle for Entity<T> +impl<T> EditPredictionProviderHandle for Entity<T> where T: EditPredictionProvider, { @@ -233,7 +233,7 @@ where buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut App, - ) -> Option<InlineCompletion> { + ) -> Option<EditPrediction> { self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx)) } } diff --git a/crates/inline_completion_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml similarity index 86% rename from crates/inline_completion_button/Cargo.toml rename to crates/edit_prediction_button/Cargo.toml index c2a619d500..07447280fa 100644 --- a/crates/inline_completion_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "inline_completion_button" +name = "edit_prediction_button" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,21 +9,23 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/inline_completion_button.rs" +path = "src/edit_prediction_button.rs" doctest = false [dependencies] anyhow.workspace = true client.workspace = true +cloud_llm_client.workspace = true copilot.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true gpui.workspace = true indoc.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true language.workspace = true paths.workspace = true +project.workspace = true regex.workspace = true settings.workspace = true supermaven.workspace = true @@ -32,7 +34,6 @@ ui.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true zeta.workspace = true [dev-dependencies] diff --git a/crates/edit_prediction_button/LICENSE-GPL b/crates/edit_prediction_button/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/edit_prediction_button/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs similarity index 95% rename from crates/inline_completion_button/src/inline_completion_button.rs rename to crates/edit_prediction_button/src/edit_prediction_button.rs index 2615a8beef..9ab94a4095 100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -1,11 +1,8 @@ use anyhow::Result; -use client::{DisableAiSettings, UserStore, zed_urls}; +use client::{UserStore, zed_urls}; +use cloud_llm_client::UsageLimit; use copilot::{Copilot, Status}; -use editor::{ - Editor, SelectionEffects, - actions::{ShowEditPrediction, ToggleEditPrediction}, - scroll::Autoscroll, -}; +use editor::{Editor, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll}; use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag}; use fs::Fs; use gpui::{ @@ -18,6 +15,7 @@ use language::{ EditPredictionsMode, File, Language, language_settings::{self, AllLanguageSettings, EditPredictionProvider, all_language_settings}, }; +use project::DisableAiSettings; use regex::Regex; use settings::{Settings, SettingsStore, update_settings_file}; use std::{ @@ -34,13 +32,12 @@ use workspace::{ notifications::NotificationId, }; use zed_actions::OpenBrowser; -use zed_llm_client::UsageLimit; use zeta::RateCompletions; actions!( edit_prediction, [ - /// Toggles the inline completion menu. + /// Toggles the edit prediction menu. ToggleMenu ] ); @@ -50,14 +47,14 @@ const PRIVACY_DOCS: &str = "https://zed.dev/docs/ai/privacy-and-security"; struct CopilotErrorToast; -pub struct InlineCompletionButton { +pub struct EditPredictionButton { editor_subscription: Option<(Subscription, usize)>, editor_enabled: Option<bool>, editor_show_predictions: bool, editor_focus_handle: Option<FocusHandle>, language: Option<Arc<Language>>, file: Option<Arc<dyn File>>, - edit_prediction_provider: Option<Arc<dyn inline_completion::InlineCompletionProviderHandle>>, + edit_prediction_provider: Option<Arc<dyn edit_prediction::EditPredictionProviderHandle>>, fs: Arc<dyn Fs>, user_store: Entity<UserStore>, popover_menu_handle: PopoverMenuHandle<ContextMenu>, @@ -70,7 +67,7 @@ enum SupermavenButtonStatus { Initializing, } -impl Render for InlineCompletionButton { +impl Render for EditPredictionButton { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { // Return empty div if AI is disabled if DisableAiSettings::get_global(cx).disable_ai { @@ -246,12 +243,15 @@ impl Render for InlineCompletionButton { }; if zeta::should_show_upsell_modal(&self.user_store, cx) { - let tooltip_meta = - match self.user_store.read(cx).current_user_has_accepted_terms() { - Some(true) => "Choose a Plan", - Some(false) => "Accept the Terms of Service", - None => "Sign In", - }; + let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { + if self.user_store.read(cx).has_accepted_terms_of_service() { + "Choose a Plan" + } else { + "Accept the Terms of Service" + } + } else { + "Sign In" + }; return div().child( IconButton::new("zed-predict-pending-button", zeta_icon) @@ -365,7 +365,7 @@ impl Render for InlineCompletionButton { } } -impl InlineCompletionButton { +impl EditPredictionButton { pub fn new( fs: Arc<dyn Fs>, user_store: Entity<UserStore>, @@ -387,9 +387,9 @@ impl InlineCompletionButton { language: None, file: None, edit_prediction_provider: None, + user_store, popover_menu_handle, fs, - user_store, } } @@ -437,9 +437,13 @@ impl InlineCompletionButton { if let Some(editor_focus_handle) = self.editor_focus_handle.clone() { let entry = ContextMenuEntry::new("This Buffer") .toggleable(IconPosition::Start, self.editor_show_predictions) - .action(Box::new(ToggleEditPrediction)) + .action(Box::new(editor::actions::ToggleEditPrediction)) .handler(move |window, cx| { - editor_focus_handle.dispatch_action(&ToggleEditPrediction, window, cx); + editor_focus_handle.dispatch_action( + &editor::actions::ToggleEditPrediction, + window, + cx, + ); }); match language_state.clone() { @@ -466,7 +470,7 @@ impl InlineCompletionButton { IconPosition::Start, None, move |_, cx| { - toggle_show_inline_completions_for_language(language.clone(), fs.clone(), cx) + toggle_show_edit_predictions_for_language(language.clone(), fs.clone(), cx) }, ); } @@ -474,10 +478,13 @@ impl InlineCompletionButton { let settings = AllLanguageSettings::get_global(cx); let globally_enabled = settings.show_edit_predictions(None, cx); - menu = menu.toggleable_entry("All Files", globally_enabled, IconPosition::Start, None, { - let fs = fs.clone(); - move |_, cx| toggle_inline_completions_globally(fs.clone(), cx) - }); + let entry = ContextMenuEntry::new("All Files") + .toggleable(IconPosition::Start, globally_enabled) + .action(workspace::ToggleEditPrediction.boxed_clone()) + .handler(|window, cx| { + window.dispatch_action(workspace::ToggleEditPrediction.boxed_clone(), cx) + }); + menu = menu.item(entry); let provider = settings.edit_predictions.provider; let current_mode = settings.edit_predictions_mode(); @@ -831,7 +838,7 @@ impl InlineCompletionButton { } } -impl StatusItemView for InlineCompletionButton { +impl StatusItemView for EditPredictionButton { fn set_active_pane_item( &mut self, item: Option<&dyn ItemHandle>, @@ -901,7 +908,7 @@ async fn open_disabled_globs_setting_in_editor( let settings = cx.global::<SettingsStore>(); - // Ensure that we always have "inline_completions { "disabled_globs": [] }" + // Ensure that we always have "edit_predictions { "disabled_globs": [] }" let edits = settings.edits_for_update::<AllLanguageSettings>(&text, |file| { file.edit_predictions .get_or_insert_with(Default::default) @@ -939,13 +946,6 @@ async fn open_disabled_globs_setting_in_editor( anyhow::Ok(()) } -fn toggle_inline_completions_globally(fs: Arc<dyn Fs>, cx: &mut App) { - let show_edit_predictions = all_language_settings(None, cx).show_edit_predictions(None, cx); - update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { - file.defaults.show_edit_predictions = Some(!show_edit_predictions) - }); -} - fn set_completion_provider(fs: Arc<dyn Fs>, cx: &mut App, provider: EditPredictionProvider) { update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { file.features @@ -954,7 +954,7 @@ fn set_completion_provider(fs: Arc<dyn Fs>, cx: &mut App, provider: EditPredicti }); } -fn toggle_show_inline_completions_for_language( +fn toggle_show_edit_predictions_for_language( language: Arc<Language>, fs: Arc<dyn Fs>, cx: &mut App, diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 0692c7fbe6..339f98ae8b 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -22,6 +22,7 @@ test-support = [ "theme/test-support", "util/test-support", "workspace/test-support", + "tree-sitter-c", "tree-sitter-rust", "tree-sitter-typescript", "tree-sitter-html", @@ -47,7 +48,7 @@ fs.workspace = true git.workspace = true gpui.workspace = true indoc.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true itertools.workspace = true language.workspace = true linkify.workspace = true @@ -76,6 +77,7 @@ telemetry.workspace = true text.workspace = true time.workspace = true theme.workspace = true +tree-sitter-c = { workspace = true, optional = true } tree-sitter-html = { workspace = true, optional = true } tree-sitter-rust = { workspace = true, optional = true } tree-sitter-typescript = { workspace = true, optional = true } @@ -106,6 +108,7 @@ settings = { workspace = true, features = ["test-support"] } tempfile.workspace = true text = { workspace = true, features = ["test-support"] } theme = { workspace = true, features = ["test-support"] } +tree-sitter-c.workspace = true tree-sitter-html.workspace = true tree-sitter-rust.workspace = true tree-sitter-typescript.workspace = true diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index 1212651cb3..3a3a57ca64 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -315,9 +315,8 @@ actions!( [ /// Accepts the full edit prediction. AcceptEditPrediction, - /// Accepts a partial Copilot suggestion. - AcceptPartialCopilotSuggestion, /// Accepts a partial edit prediction. + #[action(deprecated_aliases = ["editor::AcceptPartialCopilotSuggestion"])] AcceptPartialEditPrediction, /// Adds a cursor above the current selection. AddSelectionAbove, diff --git a/crates/editor/src/code_completion_tests.rs b/crates/editor/src/code_completion_tests.rs index 4f9822b597..fd8db29584 100644 --- a/crates/editor/src/code_completion_tests.rs +++ b/crates/editor/src/code_completion_tests.rs @@ -94,7 +94,7 @@ async fn test_fuzzy_score(cx: &mut TestAppContext) { filter_and_sort_matches("set_text", &completions, SnippetSortOrder::Top, cx).await; assert_eq!(matches[0].string, "set_text"); assert_eq!(matches[1].string, "set_text_style_refinement"); - assert_eq!(matches[2].string, "set_context_menu_options"); + assert_eq!(matches[2].string, "set_placeholder_text"); } // fuzzy filter text over label, sort_text and sort_kind @@ -216,6 +216,28 @@ async fn test_sort_positions(cx: &mut TestAppContext) { assert_eq!(matches[0].string, "rounded-full"); } +#[gpui::test] +async fn test_fuzzy_over_sort_positions(cx: &mut TestAppContext) { + let completions = vec![ + CompletionBuilder::variable("lsp_document_colors", None, "7fffffff"), // 0.29 fuzzy score + CompletionBuilder::function( + "language_servers_running_disk_based_diagnostics", + None, + "7fffffff", + ), // 0.168 fuzzy score + CompletionBuilder::function("code_lens", None, "7fffffff"), // 3.2 fuzzy score + CompletionBuilder::variable("lsp_code_lens", None, "7fffffff"), // 3.2 fuzzy score + CompletionBuilder::function("fetch_code_lens", None, "7fffffff"), // 3.2 fuzzy score + ]; + + let matches = + filter_and_sort_matches("lens", &completions, SnippetSortOrder::default(), cx).await; + + assert_eq!(matches[0].string, "code_lens"); + assert_eq!(matches[1].string, "lsp_code_lens"); + assert_eq!(matches[2].string, "fetch_code_lens"); +} + async fn test_for_each_prefix<F>( target: &str, completions: &Vec<Completion>, diff --git a/crates/editor/src/code_context_menus.rs b/crates/editor/src/code_context_menus.rs index 52446ceafc..4ae2a14ca7 100644 --- a/crates/editor/src/code_context_menus.rs +++ b/crates/editor/src/code_context_menus.rs @@ -844,7 +844,7 @@ impl CompletionsMenu { .with_sizing_behavior(ListSizingBehavior::Infer) .w(rems(34.)); - Popover::new().child(div().child(list)).into_any_element() + Popover::new().child(list).into_any_element() } fn render_aside( @@ -1057,9 +1057,9 @@ impl CompletionsMenu { enum MatchTier<'a> { WordStartMatch { sort_exact: Reverse<i32>, - sort_positions: Vec<usize>, sort_snippet: Reverse<i32>, sort_score: Reverse<OrderedFloat<f64>>, + sort_positions: Vec<usize>, sort_text: Option<&'a str>, sort_kind: usize, sort_label: &'a str, @@ -1137,9 +1137,9 @@ impl CompletionsMenu { MatchTier::WordStartMatch { sort_exact, - sort_positions, sort_snippet, sort_score, + sort_positions, sort_text, sort_kind, sort_label, diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index 5425d5a8b9..a16e516a70 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -635,7 +635,7 @@ pub(crate) struct Highlights<'a> { } #[derive(Clone, Copy, Debug)] -pub struct InlineCompletionStyles { +pub struct EditPredictionStyles { pub insertion: HighlightStyle, pub whitespace: HighlightStyle, } @@ -643,7 +643,7 @@ pub struct InlineCompletionStyles { #[derive(Default, Debug, Clone, Copy)] pub struct HighlightStyles { pub inlay_hint: Option<HighlightStyle>, - pub inline_completion: Option<InlineCompletionStyles>, + pub edit_prediction: Option<EditPredictionStyles>, } #[derive(Clone)] @@ -958,7 +958,7 @@ impl DisplaySnapshot { language_aware, HighlightStyles { inlay_hint: Some(editor_style.inlay_hints_style), - inline_completion: Some(editor_style.inline_completion_styles), + edit_prediction: Some(editor_style.edit_prediction_styles), }, ) .flat_map(|chunk| { @@ -2036,7 +2036,7 @@ pub mod tests { map.update(cx, |map, cx| { map.splice_inlays( &[], - vec![Inlay::inline_completion( + vec![Inlay::edit_prediction( 0, buffer_snapshot.anchor_after(0), "\n", diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index 85495a2611..e25c02432d 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -22,7 +22,7 @@ use std::{ atomic::{AtomicUsize, Ordering::SeqCst}, }, }; -use sum_tree::{Bias, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Dimensions, SumTree, Summary, TreeMap}; use text::{BufferId, Edit}; use ui::ElementId; @@ -416,7 +416,7 @@ struct TransformSummary { } pub struct BlockChunks<'a> { - transforms: sum_tree::Cursor<'a, Transform, (BlockRow, WrapRow)>, + transforms: sum_tree::Cursor<'a, Transform, Dimensions<BlockRow, WrapRow>>, input_chunks: wrap_map::WrapChunks<'a>, input_chunk: Chunk<'a>, output_row: u32, @@ -426,7 +426,7 @@ pub struct BlockChunks<'a> { #[derive(Clone)] pub struct BlockRows<'a> { - transforms: sum_tree::Cursor<'a, Transform, (BlockRow, WrapRow)>, + transforms: sum_tree::Cursor<'a, Transform, Dimensions<BlockRow, WrapRow>>, input_rows: wrap_map::WrapRows<'a>, output_row: BlockRow, started: bool, @@ -970,7 +970,7 @@ impl BlockMapReader<'_> { .unwrap_or(self.wrap_snapshot.max_point().row() + 1), ); - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); cursor.seek(&start_wrap_row, Bias::Left); while let Some(transform) = cursor.item() { if cursor.start().0 > end_wrap_row { @@ -1292,7 +1292,7 @@ impl BlockSnapshot { ) -> BlockChunks<'a> { let max_output_row = cmp::min(rows.end, self.transforms.summary().output_rows); - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&BlockRow(rows.start), Bias::Right); let transform_output_start = cursor.start().0.0; let transform_input_start = cursor.start().1.0; @@ -1324,9 +1324,9 @@ impl BlockSnapshot { } pub(super) fn row_infos(&self, start_row: BlockRow) -> BlockRows<'_> { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&start_row, Bias::Right); - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = if cursor .item() .map_or(false, |transform| transform.block.is_none()) @@ -1441,14 +1441,14 @@ impl BlockSnapshot { } pub fn longest_row_in_range(&self, range: Range<BlockRow>) -> BlockRow { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&range.start, Bias::Right); let mut longest_row = range.start; let mut longest_row_chars = 0; if let Some(transform) = cursor.item() { if transform.block.is_none() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = range.start.0 - output_start.0; let wrap_start_row = input_start.0 + overshoot; let wrap_end_row = cmp::min( @@ -1474,7 +1474,7 @@ impl BlockSnapshot { if let Some(transform) = cursor.item() { if transform.block.is_none() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = range.end.0 - output_start.0; let wrap_start_row = input_start.0; let wrap_end_row = input_start.0 + overshoot; @@ -1492,10 +1492,10 @@ impl BlockSnapshot { } pub(super) fn line_len(&self, row: BlockRow) -> u32 { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&BlockRow(row.0), Bias::Right); if let Some(transform) = cursor.item() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = row.0 - output_start.0; if transform.block.is_some() { 0 @@ -1510,13 +1510,13 @@ impl BlockSnapshot { } pub(super) fn is_block_line(&self, row: BlockRow) -> bool { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&row, Bias::Right); cursor.item().map_or(false, |t| t.block.is_some()) } pub(super) fn is_folded_buffer_header(&self, row: BlockRow) -> bool { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&row, Bias::Right); let Some(transform) = cursor.item() else { return false; @@ -1528,7 +1528,7 @@ impl BlockSnapshot { let wrap_point = self .wrap_snapshot .make_wrap_point(Point::new(row.0, 0), Bias::Left); - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); cursor.item().map_or(false, |transform| { transform @@ -1539,7 +1539,7 @@ impl BlockSnapshot { } pub fn clip_point(&self, point: BlockPoint, bias: Bias) -> BlockPoint { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&BlockRow(point.row), Bias::Right); let max_input_row = WrapRow(self.transforms.summary().input_rows); @@ -1549,8 +1549,8 @@ impl BlockSnapshot { loop { if let Some(transform) = cursor.item() { - let (output_start_row, input_start_row) = cursor.start(); - let (output_end_row, input_end_row) = cursor.end(); + let Dimensions(output_start_row, input_start_row, _) = cursor.start(); + let Dimensions(output_end_row, input_end_row, _) = cursor.end(); let output_start = Point::new(output_start_row.0, 0); let input_start = Point::new(input_start_row.0, 0); let input_end = Point::new(input_end_row.0, 0); @@ -1599,13 +1599,13 @@ impl BlockSnapshot { } pub fn to_block_point(&self, wrap_point: WrapPoint) -> BlockPoint { - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); if let Some(transform) = cursor.item() { if transform.block.is_some() { BlockPoint::new(cursor.start().1.0, 0) } else { - let (input_start_row, output_start_row) = cursor.start(); + let Dimensions(input_start_row, output_start_row, _) = cursor.start(); let input_start = Point::new(input_start_row.0, 0); let output_start = Point::new(output_start_row.0, 0); let input_overshoot = wrap_point.0 - input_start; @@ -1617,7 +1617,7 @@ impl BlockSnapshot { } pub fn to_wrap_point(&self, block_point: BlockPoint, bias: Bias) -> WrapPoint { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&BlockRow(block_point.row), Bias::Right); if let Some(transform) = cursor.item() { match transform.block.as_ref() { diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index 829d34ff58..c4e53a0f43 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -17,7 +17,7 @@ use std::{ sync::Arc, usize, }; -use sum_tree::{Bias, Cursor, FilterCursor, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Cursor, Dimensions, FilterCursor, SumTree, Summary, TreeMap}; use ui::IntoElement as _; use util::post_inc; @@ -98,7 +98,9 @@ impl FoldPoint { } pub fn to_inlay_point(self, snapshot: &FoldSnapshot) -> InlayPoint { - let mut cursor = snapshot.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().0.0; InlayPoint(cursor.start().1.0 + overshoot) @@ -107,7 +109,7 @@ impl FoldPoint { pub fn to_offset(self, snapshot: &FoldSnapshot) -> FoldOffset { let mut cursor = snapshot .transforms - .cursor::<(FoldPoint, TransformSummary)>(&()); + .cursor::<Dimensions<FoldPoint, TransformSummary>>(&()); cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().1.output.lines; let mut offset = cursor.start().1.output.len; @@ -567,8 +569,9 @@ impl FoldMap { let mut old_transforms = self .snapshot .transforms - .cursor::<(InlayOffset, FoldOffset)>(&()); - let mut new_transforms = new_transforms.cursor::<(InlayOffset, FoldOffset)>(&()); + .cursor::<Dimensions<InlayOffset, FoldOffset>>(&()); + let mut new_transforms = + new_transforms.cursor::<Dimensions<InlayOffset, FoldOffset>>(&()); for mut edit in inlay_edits { old_transforms.seek(&edit.old.start, Bias::Left); @@ -651,7 +654,9 @@ impl FoldSnapshot { pub fn text_summary_for_range(&self, range: Range<FoldPoint>) -> TextSummary { let mut summary = TextSummary::default(); - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); cursor.seek(&range.start, Bias::Right); if let Some(transform) = cursor.item() { let start_in_transform = range.start.0 - cursor.start().0.0; @@ -700,7 +705,9 @@ impl FoldSnapshot { } pub fn to_fold_point(&self, point: InlayPoint, bias: Bias) -> FoldPoint { - let mut cursor = self.transforms.cursor::<(InlayPoint, FoldPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayPoint, FoldPoint>>(&()); cursor.seek(&point, Bias::Right); if cursor.item().map_or(false, |t| t.is_fold()) { if bias == Bias::Left || point == cursor.start().0 { @@ -734,7 +741,9 @@ impl FoldSnapshot { } let fold_point = FoldPoint::new(start_row, 0); - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); cursor.seek(&fold_point, Bias::Left); let overshoot = fold_point.0 - cursor.start().0.0; @@ -816,7 +825,9 @@ impl FoldSnapshot { language_aware: bool, highlights: Highlights<'a>, ) -> FoldChunks<'a> { - let mut transform_cursor = self.transforms.cursor::<(FoldOffset, InlayOffset)>(&()); + let mut transform_cursor = self + .transforms + .cursor::<Dimensions<FoldOffset, InlayOffset>>(&()); transform_cursor.seek(&range.start, Bias::Right); let inlay_start = { @@ -871,7 +882,9 @@ impl FoldSnapshot { } pub fn clip_point(&self, point: FoldPoint, bias: Bias) -> FoldPoint { - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); cursor.seek(&point, Bias::Right); if let Some(transform) = cursor.item() { let transform_start = cursor.start().0.0; @@ -1196,7 +1209,7 @@ impl<'a> sum_tree::Dimension<'a, FoldSummary> for usize { #[derive(Clone)] pub struct FoldRows<'a> { - cursor: Cursor<'a, Transform, (FoldPoint, InlayPoint)>, + cursor: Cursor<'a, Transform, Dimensions<FoldPoint, InlayPoint>>, input_rows: InlayBufferRows<'a>, fold_point: FoldPoint, } @@ -1313,7 +1326,7 @@ impl DerefMut for ChunkRendererContext<'_, '_> { } pub struct FoldChunks<'a> { - transform_cursor: Cursor<'a, Transform, (FoldOffset, InlayOffset)>, + transform_cursor: Cursor<'a, Transform, Dimensions<FoldOffset, InlayOffset>>, inlay_chunks: InlayChunks<'a>, inlay_chunk: Option<(InlayOffset, InlayChunk<'a>)>, inlay_offset: InlayOffset, @@ -1448,7 +1461,7 @@ impl FoldOffset { pub fn to_point(self, snapshot: &FoldSnapshot) -> FoldPoint { let mut cursor = snapshot .transforms - .cursor::<(FoldOffset, TransformSummary)>(&()); + .cursor::<Dimensions<FoldOffset, TransformSummary>>(&()); cursor.seek(&self, Bias::Right); let overshoot = if cursor.item().map_or(true, |t| t.is_fold()) { Point::new(0, (self.0 - cursor.start().0.0) as u32) @@ -1462,7 +1475,9 @@ impl FoldOffset { #[cfg(test)] pub fn to_inlay_offset(self, snapshot: &FoldSnapshot) -> InlayOffset { - let mut cursor = snapshot.transforms.cursor::<(FoldOffset, InlayOffset)>(&()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<FoldOffset, InlayOffset>>(&()); cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().0.0; InlayOffset(cursor.start().1.0 + overshoot) diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index a36d18ff6d..fd49c262c6 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -10,7 +10,7 @@ use std::{ ops::{Add, AddAssign, Range, Sub, SubAssign}, sync::Arc, }; -use sum_tree::{Bias, Cursor, SumTree}; +use sum_tree::{Bias, Cursor, Dimensions, SumTree}; use text::{Patch, Rope}; use ui::{ActiveTheme, IntoElement as _, ParentElement as _, Styled as _, div}; @@ -81,9 +81,9 @@ impl Inlay { } } - pub fn inline_completion<T: Into<Rope>>(id: usize, position: Anchor, text: T) -> Self { + pub fn edit_prediction<T: Into<Rope>>(id: usize, position: Anchor, text: T) -> Self { Self { - id: InlayId::InlineCompletion(id), + id: InlayId::EditPrediction(id), position, text: text.into(), color: None, @@ -235,14 +235,14 @@ impl<'a> sum_tree::Dimension<'a, TransformSummary> for Point { #[derive(Clone)] pub struct InlayBufferRows<'a> { - transforms: Cursor<'a, Transform, (InlayPoint, Point)>, + transforms: Cursor<'a, Transform, Dimensions<InlayPoint, Point>>, buffer_rows: MultiBufferRows<'a>, inlay_row: u32, max_buffer_row: MultiBufferRow, } pub struct InlayChunks<'a> { - transforms: Cursor<'a, Transform, (InlayOffset, usize)>, + transforms: Cursor<'a, Transform, Dimensions<InlayOffset, usize>>, buffer_chunks: CustomHighlightsChunks<'a>, buffer_chunk: Option<Chunk<'a>>, inlay_chunks: Option<text::Chunks<'a>>, @@ -340,15 +340,13 @@ impl<'a> Iterator for InlayChunks<'a> { let mut renderer = None; let mut highlight_style = match inlay.id { - InlayId::InlineCompletion(_) => { - self.highlight_styles.inline_completion.map(|s| { - if inlay.text.chars().all(|c| c.is_whitespace()) { - s.whitespace - } else { - s.insertion - } - }) - } + InlayId::EditPrediction(_) => self.highlight_styles.edit_prediction.map(|s| { + if inlay.text.chars().all(|c| c.is_whitespace()) { + s.whitespace + } else { + s.insertion + } + }), InlayId::Hint(_) => self.highlight_styles.inlay_hint, InlayId::DebuggerValue(_) => self.highlight_styles.inlay_hint, InlayId::Color(_) => { @@ -553,7 +551,9 @@ impl InlayMap { } else { let mut inlay_edits = Patch::default(); let mut new_transforms = SumTree::default(); - let mut cursor = snapshot.transforms.cursor::<(usize, InlayOffset)>(&()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<usize, InlayOffset>>(&()); let mut buffer_edits_iter = buffer_edits.iter().peekable(); while let Some(buffer_edit) = buffer_edits_iter.next() { new_transforms.append(cursor.slice(&buffer_edit.old.start, Bias::Left), &()); @@ -740,7 +740,7 @@ impl InlayMap { text.clone(), ) } else { - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(next_inlay_id), snapshot.buffer.anchor_at(position, bias), text.clone(), @@ -772,20 +772,20 @@ impl InlaySnapshot { pub fn to_point(&self, offset: InlayOffset) -> InlayPoint { let mut cursor = self .transforms - .cursor::<(InlayOffset, (InlayPoint, usize))>(&()); + .cursor::<Dimensions<InlayOffset, InlayPoint, usize>>(&()); cursor.seek(&offset, Bias::Right); let overshoot = offset.0 - cursor.start().0.0; match cursor.item() { Some(Transform::Isomorphic(_)) => { - let buffer_offset_start = cursor.start().1.1; + let buffer_offset_start = cursor.start().2; let buffer_offset_end = buffer_offset_start + overshoot; let buffer_start = self.buffer.offset_to_point(buffer_offset_start); let buffer_end = self.buffer.offset_to_point(buffer_offset_end); - InlayPoint(cursor.start().1.0.0 + (buffer_end - buffer_start)) + InlayPoint(cursor.start().1.0 + (buffer_end - buffer_start)) } Some(Transform::Inlay(inlay)) => { let overshoot = inlay.text.offset_to_point(overshoot); - InlayPoint(cursor.start().1.0.0 + overshoot) + InlayPoint(cursor.start().1.0 + overshoot) } None => self.max_point(), } @@ -802,26 +802,26 @@ impl InlaySnapshot { pub fn to_offset(&self, point: InlayPoint) -> InlayOffset { let mut cursor = self .transforms - .cursor::<(InlayPoint, (InlayOffset, Point))>(&()); + .cursor::<Dimensions<InlayPoint, InlayOffset, Point>>(&()); cursor.seek(&point, Bias::Right); let overshoot = point.0 - cursor.start().0.0; match cursor.item() { Some(Transform::Isomorphic(_)) => { - let buffer_point_start = cursor.start().1.1; + let buffer_point_start = cursor.start().2; let buffer_point_end = buffer_point_start + overshoot; let buffer_offset_start = self.buffer.point_to_offset(buffer_point_start); let buffer_offset_end = self.buffer.point_to_offset(buffer_point_end); - InlayOffset(cursor.start().1.0.0 + (buffer_offset_end - buffer_offset_start)) + InlayOffset(cursor.start().1.0 + (buffer_offset_end - buffer_offset_start)) } Some(Transform::Inlay(inlay)) => { let overshoot = inlay.text.point_to_offset(overshoot); - InlayOffset(cursor.start().1.0.0 + overshoot) + InlayOffset(cursor.start().1.0 + overshoot) } None => self.len(), } } pub fn to_buffer_point(&self, point: InlayPoint) -> Point { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); cursor.seek(&point, Bias::Right); match cursor.item() { Some(Transform::Isomorphic(_)) => { @@ -833,7 +833,9 @@ impl InlaySnapshot { } } pub fn to_buffer_offset(&self, offset: InlayOffset) -> usize { - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); cursor.seek(&offset, Bias::Right); match cursor.item() { Some(Transform::Isomorphic(_)) => { @@ -846,7 +848,9 @@ impl InlaySnapshot { } pub fn to_inlay_offset(&self, offset: usize) -> InlayOffset { - let mut cursor = self.transforms.cursor::<(usize, InlayOffset)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<usize, InlayOffset>>(&()); cursor.seek(&offset, Bias::Left); loop { match cursor.item() { @@ -879,7 +883,7 @@ impl InlaySnapshot { } } pub fn to_inlay_point(&self, point: Point) -> InlayPoint { - let mut cursor = self.transforms.cursor::<(Point, InlayPoint)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<Point, InlayPoint>>(&()); cursor.seek(&point, Bias::Left); loop { match cursor.item() { @@ -913,7 +917,7 @@ impl InlaySnapshot { } pub fn clip_point(&self, mut point: InlayPoint, mut bias: Bias) -> InlayPoint { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); cursor.seek(&point, Bias::Left); loop { match cursor.item() { @@ -1010,7 +1014,9 @@ impl InlaySnapshot { pub fn text_summary_for_range(&self, range: Range<InlayOffset>) -> TextSummary { let mut summary = TextSummary::default(); - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); cursor.seek(&range.start, Bias::Right); let overshoot = range.start.0 - cursor.start().0.0; @@ -1058,7 +1064,7 @@ impl InlaySnapshot { } pub fn row_infos(&self, row: u32) -> InlayBufferRows<'_> { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); let inlay_point = InlayPoint::new(row, 0); cursor.seek(&inlay_point, Bias::Left); @@ -1100,7 +1106,9 @@ impl InlaySnapshot { language_aware: bool, highlights: Highlights<'a>, ) -> InlayChunks<'a> { - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); cursor.seek(&range.start, Bias::Right); let buffer_range = self.to_buffer_offset(range.start)..self.to_buffer_offset(range.end); @@ -1389,7 +1397,7 @@ mod tests { buffer.read(cx).snapshot(cx).anchor_before(3), "|123|", ), - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut next_inlay_id), buffer.read(cx).snapshot(cx).anchor_after(3), "|456|", @@ -1609,7 +1617,7 @@ mod tests { buffer.read(cx).snapshot(cx).anchor_before(4), "|456|", ), - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut next_inlay_id), buffer.read(cx).snapshot(cx).anchor_before(7), "\n|567|\n", diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index d55577826e..269f8f0c40 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -9,7 +9,7 @@ use multi_buffer::{MultiBufferSnapshot, RowInfo}; use smol::future::yield_now; use std::sync::LazyLock; use std::{cmp, collections::VecDeque, mem, ops::Range, time::Duration}; -use sum_tree::{Bias, Cursor, SumTree}; +use sum_tree::{Bias, Cursor, Dimensions, SumTree}; use text::Patch; pub use super::tab_map::TextSummary; @@ -55,7 +55,7 @@ pub struct WrapChunks<'a> { input_chunk: Chunk<'a>, output_position: WrapPoint, max_output_row: u32, - transforms: Cursor<'a, Transform, (WrapPoint, TabPoint)>, + transforms: Cursor<'a, Transform, Dimensions<WrapPoint, TabPoint>>, snapshot: &'a WrapSnapshot, } @@ -66,7 +66,7 @@ pub struct WrapRows<'a> { output_row: u32, soft_wrapped: bool, max_output_row: u32, - transforms: Cursor<'a, Transform, (WrapPoint, TabPoint)>, + transforms: Cursor<'a, Transform, Dimensions<WrapPoint, TabPoint>>, } impl WrapRows<'_> { @@ -598,7 +598,9 @@ impl WrapSnapshot { ) -> WrapChunks<'a> { let output_start = WrapPoint::new(rows.start, 0); let output_end = WrapPoint::new(rows.end, 0); - let mut transforms = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut transforms = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); transforms.seek(&output_start, Bias::Right); let mut input_start = TabPoint(transforms.start().1.0); if transforms.item().map_or(false, |t| t.is_isomorphic()) { @@ -626,7 +628,9 @@ impl WrapSnapshot { } pub fn line_len(&self, row: u32) -> u32 { - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Left); if cursor .item() @@ -651,7 +655,9 @@ impl WrapSnapshot { let start = WrapPoint::new(rows.start, 0); let end = WrapPoint::new(rows.end, 0); - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&start, Bias::Right); if let Some(transform) = cursor.item() { let start_in_transform = start.0 - cursor.start().0.0; @@ -721,7 +727,9 @@ impl WrapSnapshot { } pub fn row_infos(&self, start_row: u32) -> WrapRows<'_> { - let mut transforms = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut transforms = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); transforms.seek(&WrapPoint::new(start_row, 0), Bias::Left); let mut input_row = transforms.start().1.row(); if transforms.item().map_or(false, |t| t.is_isomorphic()) { @@ -741,7 +749,9 @@ impl WrapSnapshot { } pub fn to_tab_point(&self, point: WrapPoint) -> TabPoint { - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&point, Bias::Right); let mut tab_point = cursor.start().1.0; if cursor.item().map_or(false, |t| t.is_isomorphic()) { @@ -759,7 +769,9 @@ impl WrapSnapshot { } pub fn tab_point_to_wrap_point(&self, point: TabPoint) -> WrapPoint { - let mut cursor = self.transforms.cursor::<(TabPoint, WrapPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<TabPoint, WrapPoint>>(&()); cursor.seek(&point, Bias::Right); WrapPoint(cursor.start().1.0 + (point.0 - cursor.start().0.0)) } @@ -784,7 +796,9 @@ impl WrapSnapshot { *point.column_mut() = 0; - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&point, Bias::Right); if cursor.item().is_none() { cursor.prev(); @@ -804,7 +818,9 @@ impl WrapSnapshot { pub fn next_row_boundary(&self, mut point: WrapPoint) -> Option<u32> { point.0 += Point::new(1, 0); - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&point, Bias::Right); while let Some(transform) = cursor.item() { if transform.is_isomorphic() && cursor.start().1.column() == 0 { diff --git a/crates/editor/src/inline_completion_tests.rs b/crates/editor/src/edit_prediction_tests.rs similarity index 81% rename from crates/editor/src/inline_completion_tests.rs rename to crates/editor/src/edit_prediction_tests.rs index 5ac34c94f5..527dfb8832 100644 --- a/crates/editor/src/inline_completion_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -1,26 +1,26 @@ +use edit_prediction::EditPredictionProvider; use gpui::{Entity, prelude::*}; use indoc::indoc; -use inline_completion::EditPredictionProvider; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; use project::Project; use std::ops::Range; use text::{Point, ToOffset}; use crate::{ - InlineCompletion, editor_tests::init_test, test::editor_test_context::EditorTestContext, + EditPrediction, editor_tests::init_test, test::editor_test_context::EditorTestContext, }; #[gpui::test] -async fn test_inline_completion_insert(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let absolute_zero_celsius = ˇ;"); propose_edits(&provider, vec![(28..28, "-273.15")], &mut cx); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_edit_completion(&mut cx, |_, edits| { assert_eq!(edits.len(), 1); @@ -33,16 +33,16 @@ async fn test_inline_completion_insert(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_modification(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let pi = ˇ\"foo\";"); propose_edits(&provider, vec![(9..14, "3.14159")], &mut cx); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_edit_completion(&mut cx, |_, edits| { assert_eq!(edits.len(), 1); @@ -55,11 +55,11 @@ async fn test_inline_completion_modification(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_jump_button(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 2+ lines above the proposed edit @@ -77,7 +77,7 @@ async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), Point::new(4, 3)); }); @@ -107,7 +107,7 @@ async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), Point::new(1, 3)); }); @@ -124,11 +124,11 @@ async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 3+ lines above the proposed edit @@ -148,7 +148,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), edit_location); }); @@ -176,7 +176,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext line "}); cx.editor(|editor, _, _| { - assert!(editor.active_inline_completion.is_none()); + assert!(editor.active_edit_prediction.is_none()); }); // Cursor is 3+ lines below the proposed edit @@ -196,7 +196,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), edit_location); }); @@ -224,7 +224,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext line ˇ5 "}); cx.editor(|editor, _, _| { - assert!(editor.active_inline_completion.is_none()); + assert!(editor.active_edit_prediction.is_none()); }); } @@ -234,11 +234,11 @@ fn assert_editor_active_edit_completion( ) { cx.editor(|editor, _, cx| { let completion_state = editor - .active_inline_completion + .active_edit_prediction .as_ref() .expect("editor has no active completion"); - if let InlineCompletion::Edit { edits, .. } = &completion_state.completion { + if let EditPrediction::Edit { edits, .. } = &completion_state.completion { assert(editor.buffer().read(cx).snapshot(cx), edits); } else { panic!("expected edit completion"); @@ -252,11 +252,11 @@ fn assert_editor_active_move_completion( ) { cx.editor(|editor, _, cx| { let completion_state = editor - .active_inline_completion + .active_edit_prediction .as_ref() .expect("editor has no active completion"); - if let InlineCompletion::Move { target, .. } = &completion_state.completion { + if let EditPrediction::Move { target, .. } = &completion_state.completion { assert(editor.buffer().read(cx).snapshot(cx), *target); } else { panic!("expected move completion"); @@ -271,7 +271,7 @@ fn accept_completion(cx: &mut EditorTestContext) { } fn propose_edits<T: ToOffset>( - provider: &Entity<FakeInlineCompletionProvider>, + provider: &Entity<FakeEditPredictionProvider>, edits: Vec<(Range<T>, &str)>, cx: &mut EditorTestContext, ) { @@ -283,7 +283,7 @@ fn propose_edits<T: ToOffset>( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_inline_completion(Some(inline_completion::InlineCompletion { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction { id: None, edits: edits.collect(), edit_preview: None, @@ -293,7 +293,7 @@ fn propose_edits<T: ToOffset>( } fn assign_editor_completion_provider( - provider: Entity<FakeInlineCompletionProvider>, + provider: Entity<FakeEditPredictionProvider>, cx: &mut EditorTestContext, ) { cx.update_editor(|editor, window, cx| { @@ -302,20 +302,17 @@ fn assign_editor_completion_provider( } #[derive(Default, Clone)] -pub struct FakeInlineCompletionProvider { - pub completion: Option<inline_completion::InlineCompletion>, +pub struct FakeEditPredictionProvider { + pub completion: Option<edit_prediction::EditPrediction>, } -impl FakeInlineCompletionProvider { - pub fn set_inline_completion( - &mut self, - completion: Option<inline_completion::InlineCompletion>, - ) { +impl FakeEditPredictionProvider { + pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) { self.completion = completion; } } -impl EditPredictionProvider for FakeInlineCompletionProvider { +impl EditPredictionProvider for FakeEditPredictionProvider { fn name() -> &'static str { "fake-completion-provider" } @@ -355,7 +352,7 @@ impl EditPredictionProvider for FakeInlineCompletionProvider { &mut self, _buffer: gpui::Entity<language::Buffer>, _cursor_position: language::Anchor, - _direction: inline_completion::Direction, + _direction: edit_prediction::Direction, _cx: &mut gpui::Context<Self>, ) { } @@ -369,7 +366,7 @@ impl EditPredictionProvider for FakeInlineCompletionProvider { _buffer: &gpui::Entity<language::Buffer>, _cursor_position: language::Anchor, _cx: &mut gpui::Context<Self>, - ) -> Option<inline_completion::InlineCompletion> { + ) -> Option<edit_prediction::EditPrediction> { self.completion.clone() } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 8f57fb1a20..ff9b703d66 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -43,50 +43,65 @@ pub mod tasks; #[cfg(test)] mod code_completion_tests; #[cfg(test)] -mod editor_tests; +mod edit_prediction_tests; #[cfg(test)] -mod inline_completion_tests; +mod editor_tests; mod signature_help; #[cfg(any(test, feature = "test-support"))] pub mod test; pub(crate) use actions::*; -pub use actions::{AcceptEditPrediction, OpenExcerpts, OpenExcerptsSplit}; +pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; +pub use edit_prediction::Direction; +pub use editor_settings::{ + CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, + ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap, ShowScrollbar, +}; +pub use editor_settings_controls::*; +pub use element::{ + CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition, +}; +pub use git::blame::BlameRenderer; +pub use hover_popover::hover_markdown_style; +pub use items::MAX_TAB_TITLE_LEN; +pub use lsp::CompletionContext; +pub use lsp_ext::lsp_tasks; +pub use multi_buffer::{ + Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, PathKey, + RowInfo, ToOffset, ToPoint, +}; +pub use proposed_changes_editor::{ + ProposedChangeLocation, ProposedChangesEditor, ProposedChangesEditorToolbar, +}; +pub use text::Bias; + +use ::git::{ + Restore, + blame::{BlameEntry, ParsedCommitMessage}, +}; use aho_corasick::AhoCorasick; use anyhow::{Context as _, Result, anyhow}; use blink_manager::BlinkManager; use buffer_diff::DiffHunkStatus; use client::{Collaborator, ParticipantIndex}; use clock::{AGENT_REPLICA_ID, ReplicaId}; +use code_context_menus::{ + AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, + CompletionsMenu, ContextMenuOrigin, +}; use collections::{BTreeMap, HashMap, HashSet, VecDeque}; use convert_case::{Case, Casing}; use dap::TelemetrySpawnLocation; use display_map::*; -pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; -pub use editor_settings::{ - CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, - ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowScrollbar, -}; +use edit_prediction::{EditPredictionProvider, EditPredictionProviderHandle}; use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings}; -pub use editor_settings_controls::*; use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line}; -pub use element::{ - CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition, -}; use futures::{ FutureExt, StreamExt as _, future::{self, Shared, join}, stream::FuturesUnordered, }; use fuzzy::{StringMatch, StringMatchCandidate}; -use lsp_colors::LspColorData; - -use ::git::blame::BlameEntry; -use ::git::{Restore, blame::ParsedCommitMessage}; -use code_context_menus::{ - AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, - CompletionsMenu, ContextMenuOrigin, -}; use git::blame::{GitBlame, GlobalBlameRenderer}; use gpui::{ Action, Animation, AnimationExt, AnyElement, App, AppContext, AsyncWindowContext, @@ -100,32 +115,42 @@ use gpui::{ }; use highlight_matching_bracket::refresh_matching_bracket_highlights; use hover_links::{HoverLink, HoveredLinkState, InlayHighlight, find_file}; -pub use hover_popover::hover_markdown_style; use hover_popover::{HoverState, hide_hover}; use indent_guides::ActiveIndentGuidesState; use inlay_hint_cache::{InlayHintCache, InlaySplice, InvalidationStrategy}; -pub use inline_completion::Direction; -use inline_completion::{EditPredictionProvider, InlineCompletionProviderHandle}; -pub use items::MAX_TAB_TITLE_LEN; use itertools::Itertools; use language::{ - AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, Capability, CharKind, - CodeLabel, CursorShape, DiagnosticEntry, DiffOptions, EditPredictionsMode, EditPreview, - HighlightedText, IndentKind, IndentSize, Language, OffsetRangeExt, Point, Selection, - SelectionGoal, TextObject, TransactionId, TreeSitterOptions, WordsQuery, + AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, BufferRow, + BufferSnapshot, Capability, CharClassifier, CharKind, CodeLabel, CursorShape, DiagnosticEntry, + DiffOptions, EditPredictionsMode, EditPreview, HighlightedText, IndentKind, IndentSize, + Language, OffsetRangeExt, Point, Runnable, RunnableRange, Selection, SelectionGoal, TextObject, + TransactionId, TreeSitterOptions, WordsQuery, language_settings::{ self, InlayHintSettings, LspInsertMode, RewrapBehavior, WordsCompletionMode, all_language_settings, language_settings, }, - point_from_lsp, text_diff_with_options, + point_from_lsp, point_to_lsp, text_diff_with_options, }; -use language::{BufferRow, CharClassifier, Runnable, RunnableRange, point_to_lsp}; use linked_editing_ranges::refresh_linked_ranges; +use lsp::{ + CodeActionKind, CompletionItemKind, CompletionTriggerKind, InsertTextFormat, InsertTextMode, + LanguageServerId, +}; +use lsp_colors::LspColorData; use markdown::Markdown; use mouse_context_menu::MouseContextMenu; +use movement::TextLayoutDetails; +use multi_buffer::{ + ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow, + MultiOrSingleBufferOffsetRange, ToOffsetUtf16, +}; +use parking_lot::Mutex; use persistence::DB; use project::{ - BreakpointWithPosition, CompletionResponse, ProjectPath, + BreakpointWithPosition, CodeAction, Completion, CompletionIntent, CompletionResponse, + CompletionSource, DisableAiSettings, DocumentHighlight, InlayHint, Location, LocationLink, + PrepareRenameResponse, Project, ProjectItem, ProjectPath, ProjectTransaction, TaskSourceKind, + debugger::breakpoint_store::Breakpoint, debugger::{ breakpoint_store::{ BreakpointEditAction, BreakpointSessionState, BreakpointState, BreakpointStore, @@ -134,44 +159,12 @@ use project::{ session::{Session, SessionEvent}, }, git_store::{GitStoreEvent, RepositoryEvent}, - project_settings::{DiagnosticSeverity, GoToDiagnosticSeverityFilter}, -}; - -pub use git::blame::BlameRenderer; -pub use proposed_changes_editor::{ - ProposedChangeLocation, ProposedChangesEditor, ProposedChangesEditorToolbar, -}; -use std::{cell::OnceCell, iter::Peekable, ops::Not}; -use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; - -pub use lsp::CompletionContext; -use lsp::{ - CodeActionKind, CompletionItemKind, CompletionTriggerKind, InsertTextFormat, InsertTextMode, - LanguageServerId, LanguageServerName, -}; - -use language::BufferSnapshot; -pub use lsp_ext::lsp_tasks; -use movement::TextLayoutDetails; -pub use multi_buffer::{ - Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, PathKey, - RowInfo, ToOffset, ToPoint, -}; -use multi_buffer::{ - ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow, - MultiOrSingleBufferOffsetRange, ToOffsetUtf16, -}; -use parking_lot::Mutex; -use project::{ - CodeAction, Completion, CompletionIntent, CompletionSource, DocumentHighlight, InlayHint, - Location, LocationLink, PrepareRenameResponse, Project, ProjectItem, ProjectTransaction, - TaskSourceKind, - debugger::breakpoint_store::Breakpoint, lsp_store::{CompletionDocumentation, FormatTrigger, LspFormatTarget, OpenLspBufferHandle}, + project_settings::{DiagnosticSeverity, GoToDiagnosticSeverityFilter}, project_settings::{GitGutterSetting, ProjectSettings}, }; -use rand::prelude::*; -use rpc::{ErrorExt, proto::*}; +use rand::{seq::SliceRandom, thread_rng}; +use rpc::{ErrorCode, ErrorExt, proto::PeerId}; use scroll::{Autoscroll, OngoingScroll, ScrollAnchor, ScrollManager, ScrollbarAutoHide}; use selections_collection::{ MutableSelectionsCollection, SelectionsCollection, resolve_selections, @@ -180,21 +173,24 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsLocation, SettingsStore, update_settings_file}; use smallvec::{SmallVec, smallvec}; use snippet::Snippet; -use std::sync::Arc; use std::{ any::TypeId, borrow::Cow, + cell::OnceCell, cell::RefCell, cmp::{self, Ordering, Reverse}, + iter::Peekable, mem, num::NonZeroU32, + ops::Not, ops::{ControlFlow, Deref, DerefMut, Range, RangeInclusive}, path::{Path, PathBuf}, rc::Rc, + sync::Arc, time::{Duration, Instant}, }; -pub use sum_tree::Bias; use sum_tree::TreeMap; +use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; use text::{BufferId, FromAnchor, OffsetUtf16, Rope}; use theme::{ ActiveTheme, PlayerColor, StatusColors, SyntaxTheme, Theme, ThemeSettings, @@ -213,14 +209,11 @@ use workspace::{ notifications::{DetachAndPromptErr, NotificationId, NotifyTaskExt}, searchable::SearchEvent, }; -use zed_actions; use crate::{ code_context_menus::CompletionsMenuSource, - hover_links::{find_url, find_url_from_range}, -}; -use crate::{ editor_settings::MultiCursorModifier, + hover_links::{find_url, find_url_from_range}, signature_help::{SignatureHelpHiddenBy, SignatureHelpState}, }; @@ -275,7 +268,7 @@ impl InlineValueCache { #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum InlayId { - InlineCompletion(usize), + EditPrediction(usize), DebuggerValue(usize), // LSP Hint(usize), @@ -285,7 +278,7 @@ pub enum InlayId { impl InlayId { fn id(&self) -> usize { match self { - Self::InlineCompletion(id) => *id, + Self::EditPrediction(id) => *id, Self::DebuggerValue(id) => *id, Self::Hint(id) => *id, Self::Color(id) => *id, @@ -554,7 +547,7 @@ pub struct EditorStyle { pub syntax: Arc<SyntaxTheme>, pub status: StatusColors, pub inlay_hints_style: HighlightStyle, - pub inline_completion_styles: InlineCompletionStyles, + pub edit_prediction_styles: EditPredictionStyles, pub unnecessary_code_fade: f32, pub show_underlines: bool, } @@ -573,7 +566,7 @@ impl Default for EditorStyle { // style and retrieve them directly from the theme. status: StatusColors::dark(), inlay_hints_style: HighlightStyle::default(), - inline_completion_styles: InlineCompletionStyles { + edit_prediction_styles: EditPredictionStyles { insertion: HighlightStyle::default(), whitespace: HighlightStyle::default(), }, @@ -595,8 +588,8 @@ pub fn make_inlay_hints_style(cx: &mut App) -> HighlightStyle { } } -pub fn make_suggestion_styles(cx: &mut App) -> InlineCompletionStyles { - InlineCompletionStyles { +pub fn make_suggestion_styles(cx: &mut App) -> EditPredictionStyles { + EditPredictionStyles { insertion: HighlightStyle { color: Some(cx.theme().status().predictive), ..HighlightStyle::default() @@ -616,7 +609,7 @@ pub(crate) enum EditDisplayMode { Inline, } -enum InlineCompletion { +enum EditPrediction { Edit { edits: Vec<(Range<Anchor>, String)>, edit_preview: Option<EditPreview>, @@ -629,9 +622,9 @@ enum InlineCompletion { }, } -struct InlineCompletionState { +struct EditPredictionState { inlay_ids: Vec<InlayId>, - completion: InlineCompletion, + completion: EditPrediction, completion_id: Option<SharedString>, invalidation_range: Range<Anchor>, } @@ -644,7 +637,7 @@ enum EditPredictionSettings { }, } -enum InlineCompletionHighlight {} +enum EditPredictionHighlight {} #[derive(Debug, Clone)] struct InlineDiagnostic { @@ -655,7 +648,7 @@ struct InlineDiagnostic { severity: lsp::DiagnosticSeverity, } -pub enum MenuInlineCompletionsPolicy { +pub enum MenuEditPredictionsPolicy { Never, ByProvider, } @@ -1094,15 +1087,15 @@ pub struct Editor { pending_mouse_down: Option<Rc<RefCell<Option<MouseDownEvent>>>>, gutter_hovered: bool, hovered_link_state: Option<HoveredLinkState>, - edit_prediction_provider: Option<RegisteredInlineCompletionProvider>, + edit_prediction_provider: Option<RegisteredEditPredictionProvider>, code_action_providers: Vec<Rc<dyn CodeActionProvider>>, - active_inline_completion: Option<InlineCompletionState>, + active_edit_prediction: Option<EditPredictionState>, /// Used to prevent flickering as the user types while the menu is open - stale_inline_completion_in_menu: Option<InlineCompletionState>, + stale_edit_prediction_in_menu: Option<EditPredictionState>, edit_prediction_settings: EditPredictionSettings, - inline_completions_hidden_for_vim_mode: bool, - show_inline_completions_override: Option<bool>, - menu_inline_completions_policy: MenuInlineCompletionsPolicy, + edit_predictions_hidden_for_vim_mode: bool, + show_edit_predictions_override: Option<bool>, + menu_edit_predictions_policy: MenuEditPredictionsPolicy, edit_prediction_preview: EditPredictionPreview, edit_prediction_indent_conflict: bool, edit_prediction_requires_modifier_in_indent_conflict: bool, @@ -1305,6 +1298,7 @@ impl Default for SelectionHistoryMode { /// /// Similarly, you might want to disable scrolling if you don't want the viewport to /// move. +#[derive(Clone)] pub struct SelectionEffects { nav_history: Option<bool>, completions: bool, @@ -1516,8 +1510,8 @@ pub struct RenameState { struct InvalidationStack<T>(Vec<T>); -struct RegisteredInlineCompletionProvider { - provider: Arc<dyn InlineCompletionProviderHandle>, +struct RegisteredEditPredictionProvider { + provider: Arc<dyn EditPredictionProviderHandle>, _subscription: Subscription, } @@ -1774,7 +1768,7 @@ impl Editor { ) -> Self { debug_assert!( display_map.is_none() || mode.is_minimap(), - "Providing a display map for a new editor is only intended for the minimap and might have unindended side effects otherwise!" + "Providing a display map for a new editor is only intended for the minimap and might have unintended side effects otherwise!" ); let full_mode = mode.is_full(); @@ -1870,7 +1864,6 @@ impl Editor { editor.tasks_update_task = Some(editor.refresh_runnables(window, cx)); } - editor.update_lsp_data(true, None, window, cx); } project::Event::SnippetEdit(id, snippet_edits) => { if let Some(buffer) = editor.buffer.read(cx).buffer(*id) { @@ -1892,6 +1885,11 @@ impl Editor { } } } + project::Event::LanguageServerBufferRegistered { buffer_id, .. } => { + if editor.buffer().read(cx).buffer(*buffer_id).is_some() { + editor.update_lsp_data(false, Some(*buffer_id), window, cx); + } + } _ => {} }, )); @@ -2102,8 +2100,8 @@ impl Editor { pending_mouse_down: None, hovered_link_state: None, edit_prediction_provider: None, - active_inline_completion: None, - stale_inline_completion_in_menu: None, + active_edit_prediction: None, + stale_edit_prediction_in_menu: None, edit_prediction_preview: EditPredictionPreview::Inactive { released_too_fast: false, }, @@ -2122,9 +2120,9 @@ impl Editor { hovered_cursors: HashMap::default(), next_editor_action_id: EditorActionId::default(), editor_actions: Rc::default(), - inline_completions_hidden_for_vim_mode: false, - show_inline_completions_override: None, - menu_inline_completions_policy: MenuInlineCompletionsPolicy::ByProvider, + edit_predictions_hidden_for_vim_mode: false, + show_edit_predictions_override: None, + menu_edit_predictions_policy: MenuEditPredictionsPolicy::ByProvider, edit_prediction_settings: EditPredictionSettings::Disabled, edit_prediction_indent_conflict: false, edit_prediction_requires_modifier_in_indent_conflict: true, @@ -2356,7 +2354,7 @@ impl Editor { } pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext { - self.key_context_internal(self.has_active_inline_completion(), window, cx) + self.key_context_internal(self.has_active_edit_prediction(), window, cx) } fn key_context_internal( @@ -2723,17 +2721,16 @@ impl Editor { ) where T: EditPredictionProvider, { - self.edit_prediction_provider = - provider.map(|provider| RegisteredInlineCompletionProvider { - _subscription: cx.observe_in(&provider, window, |this, _, window, cx| { - if this.focus_handle.is_focused(window) { - this.update_visible_inline_completion(window, cx); - } - }), - provider: Arc::new(provider), - }); + self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionProvider { + _subscription: cx.observe_in(&provider, window, |this, _, window, cx| { + if this.focus_handle.is_focused(window) { + this.update_visible_edit_prediction(window, cx); + } + }), + provider: Arc::new(provider), + }); self.update_edit_prediction_settings(cx); - self.refresh_inline_completion(false, false, window, cx); + self.refresh_edit_prediction(false, false, window, cx); } pub fn placeholder_text(&self) -> Option<&str> { @@ -2804,24 +2801,24 @@ impl Editor { self.input_enabled = input_enabled; } - pub fn set_inline_completions_hidden_for_vim_mode( + pub fn set_edit_predictions_hidden_for_vim_mode( &mut self, hidden: bool, window: &mut Window, cx: &mut Context<Self>, ) { - if hidden != self.inline_completions_hidden_for_vim_mode { - self.inline_completions_hidden_for_vim_mode = hidden; + if hidden != self.edit_predictions_hidden_for_vim_mode { + self.edit_predictions_hidden_for_vim_mode = hidden; if hidden { - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); } else { - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); } } } - pub fn set_menu_inline_completions_policy(&mut self, value: MenuInlineCompletionsPolicy) { - self.menu_inline_completions_policy = value; + pub fn set_menu_edit_predictions_policy(&mut self, value: MenuEditPredictionsPolicy) { + self.menu_edit_predictions_policy = value; } pub fn set_autoindent(&mut self, autoindent: bool) { @@ -2858,7 +2855,7 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.show_inline_completions_override.is_some() { + if self.show_edit_predictions_override.is_some() { self.set_show_edit_predictions(None, window, cx); } else { let show_edit_predictions = !self.edit_predictions_enabled(); @@ -2872,17 +2869,17 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - self.show_inline_completions_override = show_edit_predictions; + self.show_edit_predictions_override = show_edit_predictions; self.update_edit_prediction_settings(cx); if let Some(false) = show_edit_predictions { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); } else { - self.refresh_inline_completion(false, true, window, cx); + self.refresh_edit_prediction(false, true, window, cx); } } - fn inline_completions_disabled_in_scope( + fn edit_predictions_disabled_in_scope( &self, buffer: &Entity<Buffer>, buffer_position: language::Anchor, @@ -2944,10 +2941,12 @@ impl Editor { } } + let selection_anchors = self.selections.disjoint_anchors(); + if self.focus_handle.is_focused(window) && self.leader_id.is_none() { self.buffer.update(cx, |buffer, cx| { buffer.set_active_selections( - &self.selections.disjoint_anchors(), + &selection_anchors, self.selections.line_mode, self.cursor_shape, cx, @@ -2964,9 +2963,8 @@ impl Editor { self.select_next_state = None; self.select_prev_state = None; self.select_syntax_node_history.try_clear(); - self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), buffer); - self.snippet_stack - .invalidate(&self.selections.disjoint_anchors(), buffer); + self.invalidate_autoclose_regions(&selection_anchors, buffer); + self.snippet_stack.invalidate(&selection_anchors, buffer); self.take_rename(false, window, cx); let newest_selection = self.selections.newest_anchor(); @@ -3048,7 +3046,7 @@ impl Editor { self.refresh_document_highlights(cx); self.refresh_selected_text_highlights(false, window, cx); refresh_matching_bracket_highlights(self, window, cx); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); self.edit_prediction_requires_modifier_in_indent_conflict = true; linked_editing_ranges::refresh_linked_ranges(self, window, cx); self.inline_blame_popover.take(); @@ -3838,7 +3836,7 @@ impl Editor { return true; } - if is_user_requested && self.discard_inline_completion(true, cx) { + if is_user_requested && self.discard_edit_prediction(true, cx) { return true; } @@ -4047,7 +4045,8 @@ impl Editor { // then don't insert that closing bracket again; just move the selection // past the closing bracket. let should_skip = selection.end == region.range.end.to_point(&snapshot) - && text.as_ref() == region.pair.end.as_str(); + && text.as_ref() == region.pair.end.as_str() + && snapshot.contains_str_at(region.range.end, text.as_ref()); if should_skip { let anchor = snapshot.anchor_after(selection.end); new_selections @@ -4243,7 +4242,7 @@ impl Editor { ); } - let had_active_inline_completion = this.has_active_inline_completion(); + let had_active_edit_prediction = this.has_active_edit_prediction(); this.change_selections( SelectionEffects::scroll(Autoscroll::fit()).completions(false), window, @@ -4268,7 +4267,7 @@ impl Editor { } let trigger_in_words = - this.show_edit_predictions_in_menu() || !had_active_inline_completion; + this.show_edit_predictions_in_menu() || !had_active_edit_prediction; if this.hard_wrap.is_some() { let latest: Range<Point> = this.selections.newest(cx).range(); if latest.is_empty() @@ -4290,7 +4289,7 @@ impl Editor { } this.trigger_completion_on_input(&text, trigger_in_words, window, cx); linked_editing_ranges::refresh_linked_ranges(this, window, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); jsx_tag_auto_close::handle_from(this, initial_buffer_versions, window, cx); }); } @@ -4625,7 +4624,7 @@ impl Editor { .collect(); this.change_selections(Default::default(), window, cx, |s| s.select(new_selections)); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -4973,13 +4972,17 @@ impl Editor { }) } - /// Remove any autoclose regions that no longer contain their selection. + /// Remove any autoclose regions that no longer contain their selection or have invalid anchors in ranges. fn invalidate_autoclose_regions( &mut self, mut selections: &[Selection<Anchor>], buffer: &MultiBufferSnapshot, ) { self.autoclose_regions.retain(|state| { + if !state.range.start.is_valid(buffer) || !state.range.end.is_valid(buffer) { + return false; + } + let mut i = 0; while let Some(selection) = selections.get(i) { if selection.end.cmp(&state.range.start, buffer).is_lt() { @@ -5669,9 +5672,9 @@ impl Editor { crate::hover_popover::hide_hover(editor, cx); if editor.show_edit_predictions_in_menu() { - editor.update_visible_inline_completion(window, cx); + editor.update_visible_edit_prediction(window, cx); } else { - editor.discard_inline_completion(false, cx); + editor.discard_edit_prediction(false, cx); } cx.notify(); @@ -5682,10 +5685,10 @@ impl Editor { if editor.completion_tasks.len() <= 1 { // If there are no more completion tasks and the last menu was empty, we should hide it. let was_hidden = editor.hide_context_menu(window, cx).is_none(); - // If it was already hidden and we don't show inline completions in the menu, we should - // also show the inline-completion when available. + // If it was already hidden and we don't show edit predictions in the menu, + // we should also show the edit prediction when available. if was_hidden && editor.show_edit_predictions_in_menu() { - editor.update_visible_inline_completion(window, cx); + editor.update_visible_edit_prediction(window, cx); } } }) @@ -5779,7 +5782,7 @@ impl Editor { let entries = completions_menu.entries.borrow(); let mat = entries.get(item_ix.unwrap_or(completions_menu.selected_item))?; if self.show_edit_predictions_in_menu() { - self.discard_inline_completion(true, cx); + self.discard_edit_prediction(true, cx); } mat.candidate_id }; @@ -5891,18 +5894,20 @@ impl Editor { text: new_text[common_prefix_len..].into(), }); - self.transact(window, cx, |this, window, cx| { + self.transact(window, cx, |editor, window, cx| { if let Some(mut snippet) = snippet { snippet.text = new_text.to_string(); - this.insert_snippet(&ranges, snippet, window, cx).log_err(); + editor + .insert_snippet(&ranges, snippet, window, cx) + .log_err(); } else { - this.buffer.update(cx, |buffer, cx| { + editor.buffer.update(cx, |multi_buffer, cx| { let auto_indent = match completion.insert_text_mode { Some(InsertTextMode::AS_IS) => None, - _ => this.autoindent_mode.clone(), + _ => editor.autoindent_mode.clone(), }; let edits = ranges.into_iter().map(|range| (range, new_text.as_str())); - buffer.edit(edits, auto_indent, cx); + multi_buffer.edit(edits, auto_indent, cx); }); } for (buffer, edits) in linked_edits { @@ -5921,8 +5926,9 @@ impl Editor { }) } - this.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); + self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), &snapshot); let show_new_completions_on_confirm = completion .confirm @@ -5980,7 +5986,7 @@ impl Editor { let deployed_from = action.deployed_from.clone(); let action = action.clone(); self.completion_tasks.clear(); - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); let multibuffer_point = match &action.deployed_from { Some(CodeActionSource::Indicator(row)) | Some(CodeActionSource::RunMenu(row)) => { @@ -6400,7 +6406,6 @@ impl Editor { IconButton::new("inline_code_actions", ui::IconName::BoltFilled) .icon_size(icon_size) .shape(ui::IconButtonShape::Square) - .style(ButtonStyle::Transparent) .icon_color(ui::Color::Hidden) .toggle_state(is_active) .when(show_tooltip, |this| { @@ -6986,20 +6991,24 @@ impl Editor { } } - pub fn refresh_inline_completion( + pub fn refresh_edit_prediction( &mut self, debounce: bool, user_requested: bool, window: &mut Window, cx: &mut Context<Self>, ) -> Option<()> { + if DisableAiSettings::get_global(cx).disable_ai { + return None; + } + let provider = self.edit_prediction_provider()?; let cursor = self.selections.newest_anchor().head(); let (buffer, cursor_buffer_position) = self.buffer.read(cx).text_anchor_for_position(cursor, cx)?; if !self.edit_predictions_enabled_in_buffer(&buffer, cursor_buffer_position, cx) { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } @@ -7008,11 +7017,11 @@ impl Editor { || !self.is_focused(window) || buffer.read(cx).is_empty()) { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); provider.refresh( self.project.clone(), buffer, @@ -7048,8 +7057,9 @@ impl Editor { } pub fn update_edit_prediction_settings(&mut self, cx: &mut Context<Self>) { - if self.edit_prediction_provider.is_none() { + if self.edit_prediction_provider.is_none() || DisableAiSettings::get_global(cx).disable_ai { self.edit_prediction_settings = EditPredictionSettings::Disabled; + self.discard_edit_prediction(false, cx); } else { let selection = self.selections.newest_anchor(); let cursor = selection.head(); @@ -7070,8 +7080,8 @@ impl Editor { cx: &App, ) -> EditPredictionSettings { if !self.mode.is_full() - || !self.show_inline_completions_override.unwrap_or(true) - || self.inline_completions_disabled_in_scope(buffer, buffer_position, cx) + || !self.show_edit_predictions_override.unwrap_or(true) + || self.edit_predictions_disabled_in_scope(buffer, buffer_position, cx) { return EditPredictionSettings::Disabled; } @@ -7085,8 +7095,8 @@ impl Editor { }; let by_provider = matches!( - self.menu_inline_completions_policy, - MenuInlineCompletionsPolicy::ByProvider + self.menu_edit_predictions_policy, + MenuEditPredictionsPolicy::ByProvider ); let show_in_menu = by_provider @@ -7156,7 +7166,7 @@ impl Editor { .unwrap_or(false) } - fn cycle_inline_completion( + fn cycle_edit_prediction( &mut self, direction: Direction, window: &mut Window, @@ -7166,28 +7176,28 @@ impl Editor { let cursor = self.selections.newest_anchor().head(); let (buffer, cursor_buffer_position) = self.buffer.read(cx).text_anchor_for_position(cursor, cx)?; - if self.inline_completions_hidden_for_vim_mode || !self.should_show_edit_predictions() { + if self.edit_predictions_hidden_for_vim_mode || !self.should_show_edit_predictions() { return None; } provider.cycle(buffer, cursor_buffer_position, direction, cx); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); Some(()) } - pub fn show_inline_completion( + pub fn show_edit_prediction( &mut self, _: &ShowEditPrediction, window: &mut Window, cx: &mut Context<Self>, ) { - if !self.has_active_inline_completion() { - self.refresh_inline_completion(false, true, window, cx); + if !self.has_active_edit_prediction() { + self.refresh_edit_prediction(false, true, window, cx); return; } - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); } pub fn display_cursor_names( @@ -7219,11 +7229,11 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.has_active_inline_completion() { - self.cycle_inline_completion(Direction::Next, window, cx); + if self.has_active_edit_prediction() { + self.cycle_edit_prediction(Direction::Next, window, cx); } else { let is_copilot_disabled = self - .refresh_inline_completion(false, true, window, cx) + .refresh_edit_prediction(false, true, window, cx) .is_none(); if is_copilot_disabled { cx.propagate(); @@ -7237,11 +7247,11 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.has_active_inline_completion() { - self.cycle_inline_completion(Direction::Prev, window, cx); + if self.has_active_edit_prediction() { + self.cycle_edit_prediction(Direction::Prev, window, cx); } else { let is_copilot_disabled = self - .refresh_inline_completion(false, true, window, cx) + .refresh_edit_prediction(false, true, window, cx) .is_none(); if is_copilot_disabled { cx.propagate(); @@ -7259,18 +7269,14 @@ impl Editor { self.hide_context_menu(window, cx); } - let Some(active_inline_completion) = self.active_inline_completion.as_ref() else { + let Some(active_edit_prediction) = self.active_edit_prediction.as_ref() else { return; }; - self.report_inline_completion_event( - active_inline_completion.completion_id.clone(), - true, - cx, - ); + self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx); - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target = *target; if let Some(position_map) = &self.last_position_map { @@ -7312,7 +7318,7 @@ impl Editor { } } } - InlineCompletion::Edit { edits, .. } => { + EditPrediction::Edit { edits, .. } => { if let Some(provider) = self.edit_prediction_provider() { provider.accept(cx); } @@ -7340,9 +7346,9 @@ impl Editor { } } - self.update_visible_inline_completion(window, cx); - if self.active_inline_completion.is_none() { - self.refresh_inline_completion(true, true, window, cx); + self.update_visible_edit_prediction(window, cx); + if self.active_edit_prediction.is_none() { + self.refresh_edit_prediction(true, true, window, cx); } cx.notify(); @@ -7352,27 +7358,23 @@ impl Editor { self.edit_prediction_requires_modifier_in_indent_conflict = false; } - pub fn accept_partial_inline_completion( + pub fn accept_partial_edit_prediction( &mut self, _: &AcceptPartialEditPrediction, window: &mut Window, cx: &mut Context<Self>, ) { - let Some(active_inline_completion) = self.active_inline_completion.as_ref() else { + let Some(active_edit_prediction) = self.active_edit_prediction.as_ref() else { return; }; if self.selections.count() != 1 { return; } - self.report_inline_completion_event( - active_inline_completion.completion_id.clone(), - true, - cx, - ); + self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx); - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target = *target; self.change_selections( SelectionEffects::scroll(Autoscroll::newest()), @@ -7383,7 +7385,7 @@ impl Editor { }, ); } - InlineCompletion::Edit { edits, .. } => { + EditPrediction::Edit { edits, .. } => { // Find an insertion that starts at the cursor position. let snapshot = self.buffer.read(cx).snapshot(cx); let cursor_offset = self.selections.newest::<usize>(cx).head(); @@ -7417,7 +7419,7 @@ impl Editor { self.insert_with_autoindent_mode(&partial_completion, None, window, cx); - self.refresh_inline_completion(true, true, window, cx); + self.refresh_edit_prediction(true, true, window, cx); cx.notify(); } else { self.accept_edit_prediction(&Default::default(), window, cx); @@ -7426,28 +7428,28 @@ impl Editor { } } - fn discard_inline_completion( + fn discard_edit_prediction( &mut self, - should_report_inline_completion_event: bool, + should_report_edit_prediction_event: bool, cx: &mut Context<Self>, ) -> bool { - if should_report_inline_completion_event { + if should_report_edit_prediction_event { let completion_id = self - .active_inline_completion + .active_edit_prediction .as_ref() .and_then(|active_completion| active_completion.completion_id.clone()); - self.report_inline_completion_event(completion_id, false, cx); + self.report_edit_prediction_event(completion_id, false, cx); } if let Some(provider) = self.edit_prediction_provider() { provider.discard(cx); } - self.take_active_inline_completion(cx) + self.take_active_edit_prediction(cx) } - fn report_inline_completion_event(&self, id: Option<SharedString>, accepted: bool, cx: &App) { + fn report_edit_prediction_event(&self, id: Option<SharedString>, accepted: bool, cx: &App) { let Some(provider) = self.edit_prediction_provider() else { return; }; @@ -7478,18 +7480,18 @@ impl Editor { ); } - pub fn has_active_inline_completion(&self) -> bool { - self.active_inline_completion.is_some() + pub fn has_active_edit_prediction(&self) -> bool { + self.active_edit_prediction.is_some() } - fn take_active_inline_completion(&mut self, cx: &mut Context<Self>) -> bool { - let Some(active_inline_completion) = self.active_inline_completion.take() else { + fn take_active_edit_prediction(&mut self, cx: &mut Context<Self>) -> bool { + let Some(active_edit_prediction) = self.active_edit_prediction.take() else { return false; }; - self.splice_inlays(&active_inline_completion.inlay_ids, Default::default(), cx); - self.clear_highlights::<InlineCompletionHighlight>(cx); - self.stale_inline_completion_in_menu = Some(active_inline_completion); + self.splice_inlays(&active_edit_prediction.inlay_ids, Default::default(), cx); + self.clear_highlights::<EditPredictionHighlight>(cx); + self.stale_edit_prediction_in_menu = Some(active_edit_prediction); true } @@ -7634,7 +7636,7 @@ impl Editor { since: Instant::now(), }; - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); cx.notify(); } } else if let EditPredictionPreview::Active { @@ -7657,16 +7659,20 @@ impl Editor { released_too_fast: since.elapsed() < Duration::from_millis(200), }; self.clear_row_highlights::<EditPredictionPreview>(); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); cx.notify(); } } - fn update_visible_inline_completion( + fn update_visible_edit_prediction( &mut self, _window: &mut Window, cx: &mut Context<Self>, ) -> Option<()> { + if DisableAiSettings::get_global(cx).disable_ai { + return None; + } + let selection = self.selections.newest_anchor(); let cursor = selection.head(); let multibuffer = self.buffer.read(cx).snapshot(cx); @@ -7676,12 +7682,12 @@ impl Editor { let show_in_menu = self.show_edit_predictions_in_menu(); let completions_menu_has_precedence = !show_in_menu && (self.context_menu.borrow().is_some() - || (!self.completion_tasks.is_empty() && !self.has_active_inline_completion())); + || (!self.completion_tasks.is_empty() && !self.has_active_edit_prediction())); if completions_menu_has_precedence || !offset_selection.is_empty() || self - .active_inline_completion + .active_edit_prediction .as_ref() .map_or(false, |completion| { let invalidation_range = completion.invalidation_range.to_offset(&multibuffer); @@ -7689,11 +7695,11 @@ impl Editor { !invalidation_range.contains(&offset_selection.head()) }) { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } - self.take_active_inline_completion(cx); + self.take_active_edit_prediction(cx); let Some(provider) = self.edit_prediction_provider() else { self.edit_prediction_settings = EditPredictionSettings::Disabled; return None; @@ -7719,8 +7725,8 @@ impl Editor { } } - let inline_completion = provider.suggest(&buffer, cursor_buffer_position, cx)?; - let edits = inline_completion + let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?; + let edits = edit_prediction .edits .into_iter() .flat_map(|(range, new_text)| { @@ -7755,15 +7761,15 @@ impl Editor { None }; let is_move = - move_invalidation_row_range.is_some() || self.inline_completions_hidden_for_vim_mode; + move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode; let completion = if is_move { invalidation_row_range = move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row); let target = first_edit_start; - InlineCompletion::Move { target, snapshot } + EditPrediction::Move { target, snapshot } } else { let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true) - && !self.inline_completions_hidden_for_vim_mode; + && !self.edit_predictions_hidden_for_vim_mode; if show_completions_in_buffer { if edits @@ -7772,7 +7778,7 @@ impl Editor { { let mut inlays = Vec::new(); for (range, new_text) in &edits { - let inlay = Inlay::inline_completion( + let inlay = Inlay::edit_prediction( post_inc(&mut self.next_inlay_id), range.start, new_text.as_str(), @@ -7784,7 +7790,7 @@ impl Editor { self.splice_inlays(&[], inlays, cx); } else { let background_color = cx.theme().status().deleted_background; - self.highlight_text::<InlineCompletionHighlight>( + self.highlight_text::<EditPredictionHighlight>( edits.iter().map(|(range, _)| range.clone()).collect(), HighlightStyle { background_color: Some(background_color), @@ -7807,9 +7813,9 @@ impl Editor { EditDisplayMode::DiffPopover }; - InlineCompletion::Edit { + EditPrediction::Edit { edits, - edit_preview: inline_completion.edit_preview, + edit_preview: edit_prediction.edit_preview, display_mode, snapshot, } @@ -7822,11 +7828,11 @@ impl Editor { multibuffer.line_len(MultiBufferRow(invalidation_row_range.end)), )); - self.stale_inline_completion_in_menu = None; - self.active_inline_completion = Some(InlineCompletionState { + self.stale_edit_prediction_in_menu = None; + self.active_edit_prediction = Some(EditPredictionState { inlay_ids, completion, - completion_id: inline_completion.id, + completion_id: edit_prediction.id, invalidation_range, }); @@ -7835,7 +7841,7 @@ impl Editor { Some(()) } - pub fn edit_prediction_provider(&self) -> Option<Arc<dyn InlineCompletionProviderHandle>> { + pub fn edit_prediction_provider(&self) -> Option<Arc<dyn EditPredictionProviderHandle>> { Some(self.edit_prediction_provider.as_ref()?.provider.clone()) } @@ -8235,8 +8241,7 @@ impl Editor { return; }; - // Try to find a closest, enclosing node using tree-sitter that has a - // task + // Try to find a closest, enclosing node using tree-sitter that has a task let Some((buffer, buffer_row, tasks)) = self .find_enclosing_node_task(cx) // Or find the task that's closest in row-distance. @@ -8336,26 +8341,29 @@ impl Editor { let color = Color::Muted; let position = breakpoint.as_ref().map(|(anchor, _, _)| *anchor); - IconButton::new(("run_indicator", row.0 as usize), ui::IconName::Play) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(color) - .toggle_state(is_active) - .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { - let quick_launch = e.down.button == MouseButton::Left; - window.focus(&editor.focus_handle(cx)); - editor.toggle_code_actions( - &ToggleCodeActions { - deployed_from: Some(CodeActionSource::RunMenu(row)), - quick_launch, - }, - window, - cx, - ); - })) - .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { - editor.set_breakpoint_context_menu(row, position, event.down.position, window, cx); - })) + IconButton::new( + ("run_indicator", row.0 as usize), + ui::IconName::PlayOutlined, + ) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(color) + .toggle_state(is_active) + .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { + let quick_launch = e.down.button == MouseButton::Left; + window.focus(&editor.focus_handle(cx)); + editor.toggle_code_actions( + &ToggleCodeActions { + deployed_from: Some(CodeActionSource::RunMenu(row)), + quick_launch, + }, + window, + cx, + ); + })) + .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { + editor.set_breakpoint_context_menu(row, position, event.down.position, window, cx); + })) } pub fn context_menu_visible(&self) -> bool { @@ -8402,14 +8410,14 @@ impl Editor { if self.mode().is_minimap() { return None; } - let active_inline_completion = self.active_inline_completion.as_ref()?; + let active_edit_prediction = self.active_edit_prediction.as_ref()?; if self.edit_prediction_visible_in_cursor_popover(true) { return None; } - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target_display_point = target.to_display_point(editor_snapshot); if self.edit_prediction_requires_modifier() { @@ -8446,11 +8454,11 @@ impl Editor { ) } } - InlineCompletion::Edit { + EditPrediction::Edit { display_mode: EditDisplayMode::Inline, .. } => None, - InlineCompletion::Edit { + EditPrediction::Edit { display_mode: EditDisplayMode::TabAccept, edits, .. @@ -8471,7 +8479,7 @@ impl Editor { cx, ) } - InlineCompletion::Edit { + EditPrediction::Edit { edits, edit_preview, display_mode: EditDisplayMode::DiffPopover, @@ -8788,7 +8796,7 @@ impl Editor { } let highlighted_edits = - crate::inline_completion_edit_text(&snapshot, edits, edit_preview.as_ref()?, false, cx); + crate::edit_prediction_edit_text(&snapshot, edits, edit_preview.as_ref()?, false, cx); let styled_text = highlighted_edits.to_styled_text(&style.text); let line_count = highlighted_edits.text.lines().count(); @@ -9118,7 +9126,7 @@ impl Editor { .child(Icon::new(IconName::ZedPredict)) } - let completion = match &self.active_inline_completion { + let completion = match &self.active_edit_prediction { Some(prediction) => { if !self.has_visible_completions_menu() { const RADIUS: Pixels = px(6.); @@ -9136,7 +9144,7 @@ impl Editor { .rounded_tl(px(0.)) .overflow_hidden() .child(div().px_1p5().child(match &prediction.completion { - InlineCompletion::Move { target, snapshot } => { + EditPrediction::Move { target, snapshot } => { use text::ToPoint as _; if target.text_anchor.to_point(&snapshot).row > cursor_point.row { @@ -9145,7 +9153,7 @@ impl Editor { Icon::new(IconName::ZedPredictUp) } } - InlineCompletion::Edit { .. } => Icon::new(IconName::ZedPredict), + EditPrediction::Edit { .. } => Icon::new(IconName::ZedPredict), })) .child( h_flex() @@ -9204,7 +9212,7 @@ impl Editor { )? } - None if is_refreshing => match &self.stale_inline_completion_in_menu { + None if is_refreshing => match &self.stale_edit_prediction_in_menu { Some(stale_completion) => self.render_edit_prediction_cursor_popover_preview( stale_completion, cursor_point, @@ -9234,7 +9242,7 @@ impl Editor { completion.into_any_element() }; - let has_completion = self.active_inline_completion.is_some(); + let has_completion = self.active_edit_prediction.is_some(); let is_platform_style_mac = PlatformStyle::platform() == PlatformStyle::Mac; Some( @@ -9293,7 +9301,7 @@ impl Editor { fn render_edit_prediction_cursor_popover_preview( &self, - completion: &InlineCompletionState, + completion: &EditPredictionState, cursor_point: Point, style: &EditorStyle, cx: &mut Context<Editor>, @@ -9321,7 +9329,7 @@ impl Editor { } match &completion.completion { - InlineCompletion::Move { + EditPrediction::Move { target, snapshot, .. } => Some( h_flex() @@ -9338,7 +9346,7 @@ impl Editor { .child(Label::new("Jump to Edit")), ), - InlineCompletion::Edit { + EditPrediction::Edit { edits, edit_preview, snapshot, @@ -9346,7 +9354,7 @@ impl Editor { } => { let first_edit_row = edits.first()?.0.start.text_anchor.to_point(&snapshot).row; - let (highlighted_edits, has_more_lines) = crate::inline_completion_edit_text( + let (highlighted_edits, has_more_lines) = crate::edit_prediction_edit_text( &snapshot, &edits, edit_preview.as_ref()?, @@ -9424,8 +9432,8 @@ impl Editor { cx.notify(); self.completion_tasks.clear(); let context_menu = self.context_menu.borrow_mut().take(); - self.stale_inline_completion_in_menu.take(); - self.update_visible_inline_completion(window, cx); + self.stale_edit_prediction_in_menu.take(); + self.update_visible_edit_prediction(window, cx); if let Some(CodeContextMenu::Completions(_)) = &context_menu { if let Some(completion_provider) = &self.completion_provider { completion_provider.selection_changed(None, window, cx); @@ -9563,27 +9571,46 @@ impl Editor { // Check whether the just-entered snippet ends with an auto-closable bracket. if self.autoclose_regions.is_empty() { let snapshot = self.buffer.read(cx).snapshot(cx); - for selection in &mut self.selections.all::<Point>(cx) { + let mut all_selections = self.selections.all::<Point>(cx); + for selection in &mut all_selections { let selection_head = selection.head(); let Some(scope) = snapshot.language_scope_at(selection_head) else { continue; }; let mut bracket_pair = None; - let next_chars = snapshot.chars_at(selection_head).collect::<String>(); - let prev_chars = snapshot - .reversed_chars_at(selection_head) - .collect::<String>(); - for (pair, enabled) in scope.brackets() { - if enabled - && pair.close - && prev_chars.starts_with(pair.start.as_str()) - && next_chars.starts_with(pair.end.as_str()) - { - bracket_pair = Some(pair.clone()); - break; + let max_lookup_length = scope + .brackets() + .map(|(pair, _)| { + pair.start + .as_str() + .chars() + .count() + .max(pair.end.as_str().chars().count()) + }) + .max(); + if let Some(max_lookup_length) = max_lookup_length { + let next_text = snapshot + .chars_at(selection_head) + .take(max_lookup_length) + .collect::<String>(); + let prev_text = snapshot + .reversed_chars_at(selection_head) + .take(max_lookup_length) + .collect::<String>(); + + for (pair, enabled) in scope.brackets() { + if enabled + && pair.close + && prev_text.starts_with(pair.start.as_str()) + && next_text.starts_with(pair.end.as_str()) + { + bracket_pair = Some(pair.clone()); + break; + } } } + if let Some(pair) = bracket_pair { let snapshot_settings = snapshot.language_settings_at(selection_head, cx); let autoclose_enabled = @@ -9764,7 +9791,7 @@ impl Editor { this.edit(edits, None, cx); }) } - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); linked_editing_ranges::refresh_linked_ranges(this, window, cx); }); } @@ -9783,7 +9810,7 @@ impl Editor { }) }); this.insert("", window, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -9916,7 +9943,7 @@ impl Editor { self.transact(window, cx, |this, window, cx| { this.buffer.update(cx, |b, cx| b.edit(edits, None, cx)); this.change_selections(Default::default(), window, cx, |s| s.select(selections)); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -12245,7 +12272,7 @@ impl Editor { } self.request_autoscroll(Autoscroll::fit(), cx); self.unmark_text(window, cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); cx.emit(EditorEvent::Edited { transaction_id }); cx.emit(EditorEvent::TransactionUndone { transaction_id }); } @@ -12275,7 +12302,7 @@ impl Editor { } self.request_autoscroll(Autoscroll::fit(), cx); self.unmark_text(window, cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); cx.emit(EditorEvent::Edited { transaction_id }); } } @@ -15262,7 +15289,7 @@ impl Editor { ]) }); self.activate_diagnostics(buffer_id, next_diagnostic, window, cx); - self.refresh_inline_completion(false, true, window, cx); + self.refresh_edit_prediction(false, true, window, cx); } pub fn go_to_next_hunk(&mut self, _: &GoToHunk, window: &mut Window, cx: &mut Context<Self>) { @@ -15823,7 +15850,7 @@ impl Editor { let language_server_name = project .language_server_statuses(cx) .find(|(id, _)| server_id == *id) - .map(|(_, status)| LanguageServerName::from(status.name.as_str())); + .map(|(_, status)| status.name.clone()); language_server_name.map(|language_server_name| { project.open_local_buffer_via_lsp( lsp_location.uri.clone(), @@ -16226,7 +16253,7 @@ impl Editor { font_weight: Some(FontWeight::BOLD), ..make_inlay_hints_style(cx.app) }, - inline_completion_styles: make_suggestion_styles( + edit_prediction_styles: make_suggestion_styles( cx.app, ), ..EditorStyle::default() @@ -19000,7 +19027,7 @@ impl Editor { (selection.range(), uuid.to_string()) }); this.edit(edits, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -19853,8 +19880,8 @@ impl Editor { self.refresh_selected_text_highlights(true, window, cx); self.refresh_single_line_folds(window, cx); refresh_matching_bracket_highlights(self, window, cx); - if self.has_active_inline_completion() { - self.update_visible_inline_completion(window, cx); + if self.has_active_edit_prediction() { + self.update_visible_edit_prediction(window, cx); } if let Some(project) = self.project.as_ref() { if let Some(edited_buffer) = edited_buffer { @@ -20056,7 +20083,7 @@ impl Editor { } self.tasks_update_task = Some(self.refresh_runnables(window, cx)); self.update_edit_prediction_settings(cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); self.refresh_inline_values(cx); self.refresh_inlay_hints( InlayHintRefreshReason::SettingsChange(inlay_hint_settings( @@ -20688,7 +20715,7 @@ impl Editor { { self.hide_context_menu(window, cx); } - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); cx.emit(EditorEvent::Blurred); cx.notify(); } @@ -21101,13 +21128,6 @@ fn process_completion_for_edit( .is_le(), "replace_range should start before or at cursor position" ); - debug_assert!( - insert_range - .end - .cmp(&cursor_position, &buffer_snapshot) - .is_le(), - "insert_range should end before or at cursor position" - ); let should_replace = match intent { CompletionIntent::CompleteWithInsert => false, @@ -21835,11 +21855,11 @@ impl CodeActionProvider for Entity<Project> { cx: &mut App, ) -> Task<Result<Vec<CodeAction>>> { self.update(cx, |project, cx| { - let code_lens = project.code_lens(buffer, range.clone(), cx); + let code_lens_actions = project.code_lens_actions(buffer, range.clone(), cx); let code_actions = project.code_actions(buffer, range, None, cx); cx.background_spawn(async move { - let (code_lens, code_actions) = join(code_lens, code_actions).await; - Ok(code_lens + let (code_lens_actions, code_actions) = join(code_lens_actions, code_actions).await; + Ok(code_lens_actions .context("code lens fetch")? .into_iter() .chain(code_actions.context("code action fetch")?) @@ -22757,7 +22777,7 @@ impl Render for Editor { syntax: cx.theme().syntax().clone(), status: cx.theme().status().clone(), inlay_hints_style: make_inlay_hints_style(cx), - inline_completion_styles: make_suggestion_styles(cx), + edit_prediction_styles: make_suggestion_styles(cx), unnecessary_code_fade: ThemeSettings::get_global(cx).unnecessary_code_fade, show_underlines: self.diagnostics_enabled(), }, @@ -23152,7 +23172,7 @@ impl InvalidationRegion for SnippetState { } } -fn inline_completion_edit_text( +fn edit_prediction_edit_text( current_snapshot: &BufferSnapshot, edits: &[(Range<Anchor>, String)], edit_preview: &EditPreview, diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index a0333bb494..1cb3565733 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -2,7 +2,7 @@ use super::*; use crate::{ JoinLines, code_context_menus::CodeContextMenu, - inline_completion_tests::FakeInlineCompletionProvider, + edit_prediction_tests::FakeEditPredictionProvider, linked_editing_ranges::LinkedEditingRanges, scroll::scroll_amount::ScrollAmount, test::{ @@ -7251,12 +7251,12 @@ async fn test_undo_format_scrolls_to_last_edit_pos(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext) { +async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(provider.clone()), window, cx); }); @@ -7279,7 +7279,7 @@ async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_inline_completion(Some(inline_completion::InlineCompletion { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction { id: None, edits: vec![(edit_position..edit_position, "X".into())], edit_preview: None, @@ -7287,7 +7287,7 @@ async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext }) }); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); cx.update_editor(|editor, window, cx| { editor.accept_edit_prediction(&crate::AcceptEditPrediction, window, cx) }); @@ -8612,6 +8612,7 @@ async fn test_autoclose_with_embedded_language(cx: &mut TestAppContext) { cx.language_registry().add(html_language.clone()); cx.language_registry().add(javascript_language.clone()); + cx.executor().run_until_parked(); cx.update_buffer(|buffer, cx| { buffer.set_language(Some(html_language), cx); @@ -10072,8 +10073,14 @@ async fn test_autosave_with_dirty_buffers(cx: &mut TestAppContext) { ); } -#[gpui::test] -async fn test_range_format_during_save(cx: &mut TestAppContext) { +async fn setup_range_format_test( + cx: &mut TestAppContext, +) -> ( + Entity<Project>, + Entity<Editor>, + &mut gpui::VisualTestContext, + lsp::FakeLanguageServer, +) { init_test(cx, |_| {}); let fs = FakeFs::new(cx.executor()); @@ -10088,9 +10095,9 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { FakeLspAdapter { capabilities: lsp::ServerCapabilities { document_range_formatting_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() + ..lsp::ServerCapabilities::default() }, - ..Default::default() + ..FakeLspAdapter::default() }, ); @@ -10105,14 +10112,22 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { let (editor, cx) = cx.add_window_view(|window, cx| { build_editor_with_project(project.clone(), buffer, window, cx) }); + + cx.executor().start_waiting(); + let fake_server = fake_servers.next().await.unwrap(); + + (project, editor, cx, fake_server) +} + +#[gpui::test] +async fn test_range_format_on_save_success(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; + editor.update_in(cx, |editor, window, cx| { editor.set_text("one\ntwo\nthree\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); - cx.executor().start_waiting(); - let fake_server = fake_servers.next().await.unwrap(); - let save = editor .update_in(cx, |editor, window, cx| { editor.save( @@ -10147,13 +10162,18 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { "one, two\nthree\n" ); assert!(!cx.read(|cx| editor.is_dirty(cx))); +} + +#[gpui::test] +async fn test_range_format_on_save_timeout(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; editor.update_in(cx, |editor, window, cx| { editor.set_text("one\ntwo\nthree\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); - // Ensure we can still save even if formatting hangs. + // Test that save still works when formatting hangs fake_server.set_request_handler::<lsp::request::RangeFormatting, _, _>( move |params, _| async move { assert_eq!( @@ -10185,8 +10205,13 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { "one\ntwo\nthree\n" ); assert!(!cx.read(|cx| editor.is_dirty(cx))); +} - // For non-dirty buffer, no formatting request should be sent +#[gpui::test] +async fn test_range_format_not_called_for_clean_buffer(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; + + // Buffer starts clean, no formatting should be requested let save = editor .update_in(cx, |editor, window, cx| { editor.save( @@ -10207,6 +10232,12 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { .next(); cx.executor().start_waiting(); save.await; + cx.run_until_parked(); +} + +#[gpui::test] +async fn test_range_format_respects_language_tab_size_override(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; // Set Rust language override and assert overridden tabsize is sent to language server update_test_language_settings(cx, |settings| { @@ -10220,7 +10251,7 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { }); editor.update_in(cx, |editor, window, cx| { - editor.set_text("somehting_new\n", window, cx) + editor.set_text("something_new\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); let save = editor @@ -13370,6 +13401,178 @@ async fn test_as_is_completions(cx: &mut TestAppContext) { cx.assert_editor_state("fn a() {}\n unsafeˇ"); } +#[gpui::test] +async fn test_panic_during_c_completions(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + let language = + Arc::try_unwrap(languages::language("c", tree_sitter_c::LANGUAGE.into())).unwrap(); + let mut cx = EditorLspTestContext::new( + language, + lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + ..lsp::CompletionOptions::default() + }), + ..lsp::ServerCapabilities::default() + }, + cx, + ) + .await; + + cx.set_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +ˇ", + ); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("#", window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("i", window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("n", window, cx); + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#inˇ", + ); + + cx.lsp + .set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { + is_incomplete: false, + item_defaults: None, + items: vec![lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::SNIPPET), + label_details: Some(lsp::CompletionItemLabelDetails { + detail: Some("header".to_string()), + description: None, + }), + label: " include".to_string(), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range { + start: lsp::Position { + line: 8, + character: 1, + }, + end: lsp::Position { + line: 8, + character: 1, + }, + }, + new_text: "include \"$0\"".to_string(), + })), + sort_text: Some("40b67681include".to_string()), + insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), + filter_text: Some("include".to_string()), + insert_text: Some("include \"$0\"".to_string()), + ..lsp::CompletionItem::default() + }], + }))) + }); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.confirm_completion(&ConfirmCompletion::default(), window, cx) + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include \"ˇ\"", + ); + + cx.lsp + .set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { + is_incomplete: true, + item_defaults: None, + items: vec![lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::FILE), + label: "AGL/".to_string(), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range { + start: lsp::Position { + line: 8, + character: 10, + }, + end: lsp::Position { + line: 8, + character: 11, + }, + }, + new_text: "AGL/".to_string(), + })), + sort_text: Some("40b67681AGL/".to_string()), + insert_text_format: Some(lsp::InsertTextFormat::PLAIN_TEXT), + filter_text: Some("AGL/".to_string()), + insert_text: Some("AGL/".to_string()), + ..lsp::CompletionItem::default() + }], + }))) + }); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.confirm_completion(&ConfirmCompletion::default(), window, cx) + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + r##"#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include "AGL/ˇ"##, + ); + + cx.update_editor(|editor, window, cx| { + editor.handle_input("\"", window, cx); + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + r##"#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include "AGL/"ˇ"##, + ); +} + #[gpui::test] async fn test_no_duplicated_completion_requests(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -20349,7 +20552,7 @@ async fn test_multi_buffer_navigation_with_folded_buffers(cx: &mut TestAppContex } #[gpui::test] -async fn test_inline_completion_text(cx: &mut TestAppContext) { +async fn test_edit_prediction_text(cx: &mut TestAppContext) { init_test(cx, |_| {}); // Simple insertion @@ -20448,7 +20651,7 @@ async fn test_inline_completion_text(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_inline_completion_text_with_deletions(cx: &mut TestAppContext) { +async fn test_edit_prediction_text_with_deletions(cx: &mut TestAppContext) { init_test(cx, |_| {}); // Deletion @@ -20538,7 +20741,7 @@ async fn assert_highlighted_edits( .await; cx.update(|_window, cx| { - let highlighted_edits = inline_completion_edit_text( + let highlighted_edits = edit_prediction_edit_text( &snapshot.as_singleton().unwrap().2, &edits, &edit_preview, @@ -21310,16 +21513,32 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex }, ); - let (buffer, _handle) = project - .update(cx, |p, cx| { - p.open_local_buffer_with_lsp(path!("/dir/a.ts"), cx) + let editor = workspace + .update(cx, |workspace, window, cx| { + workspace.open_abs_path( + PathBuf::from(path!("/dir/a.ts")), + OpenOptions::default(), + window, + cx, + ) }) + .unwrap() .await + .unwrap() + .downcast::<Editor>() .unwrap(); cx.executor().run_until_parked(); let fake_server = fake_language_servers.next().await.unwrap(); + let buffer = editor.update(cx, |editor, cx| { + editor + .buffer() + .read(cx) + .as_singleton() + .expect("have opened a single file by path") + }); + let buffer_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); let anchor = buffer_snapshot.anchor_at(0, text::Bias::Left); drop(buffer_snapshot); @@ -21377,7 +21596,7 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex assert_eq!( actions.len(), 1, - "Should have only one valid action for the 0..0 range" + "Should have only one valid action for the 0..0 range, got: {actions:#?}" ); let action = actions[0].clone(); let apply = project.update(cx, |project, cx| { @@ -21423,7 +21642,7 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex .into_iter() .collect(), ), - ..Default::default() + ..lsp::WorkspaceEdit::default() }, }, ) @@ -21446,6 +21665,38 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex buffer.undo(cx); assert_eq!(buffer.text(), "a"); }); + + let actions_after_edits = cx + .update_window(*workspace, |_, window, cx| { + project.code_actions(&buffer, anchor..anchor, window, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!( + actions, actions_after_edits, + "For the same selection, same code lens actions should be returned" + ); + + let _responses = + fake_server.set_request_handler::<lsp::request::CodeLensRequest, _, _>(|_, _| async move { + panic!("No more code lens requests are expected"); + }); + editor.update_in(cx, |editor, window, cx| { + editor.select_all(&SelectAll, window, cx); + }); + cx.executor().run_until_parked(); + let new_actions = cx + .update_window(*workspace, |_, window, cx| { + project.code_actions(&buffer, anchor..anchor, window, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!( + actions, new_actions, + "Code lens are queried for the same range and should get the same set back, but without additional LSP queries now" + ); } #[gpui::test] diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 7e77f113ac..268855ab61 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -3,11 +3,11 @@ use crate::{ CodeActionSource, ColumnarMode, ConflictsOurs, ConflictsOursMarker, ConflictsOuter, ConflictsTheirs, ConflictsTheirsMarker, ContextMenuPlacement, CursorShape, CustomBlockId, DisplayDiffHunk, DisplayPoint, DisplayRow, DocumentHighlightRead, DocumentHighlightWrite, - EditDisplayMode, Editor, EditorMode, EditorSettings, EditorSnapshot, EditorStyle, - FILE_HEADER_HEIGHT, FocusedBlock, GutterDimensions, HalfPageDown, HalfPageUp, HandleInput, - HoveredCursor, InlayHintRefreshReason, InlineCompletion, JumpData, LineDown, LineHighlight, - LineUp, MAX_LINE_LEN, MINIMAP_FONT_SIZE, MULTI_BUFFER_EXCERPT_HEADER_HEIGHT, OpenExcerpts, - PageDown, PageUp, PhantomBreakpointIndicator, Point, RowExt, RowRangeExt, SelectPhase, + EditDisplayMode, EditPrediction, Editor, EditorMode, EditorSettings, EditorSnapshot, + EditorStyle, FILE_HEADER_HEIGHT, FocusedBlock, GutterDimensions, HalfPageDown, HalfPageUp, + HandleInput, HoveredCursor, InlayHintRefreshReason, JumpData, LineDown, LineHighlight, LineUp, + MAX_LINE_LEN, MINIMAP_FONT_SIZE, MULTI_BUFFER_EXCERPT_HEADER_HEIGHT, OpenExcerpts, PageDown, + PageUp, PhantomBreakpointIndicator, Point, RowExt, RowRangeExt, SelectPhase, SelectedTextHighlight, Selection, SelectionDragState, SoftWrap, StickyHeaderExcerpt, ToPoint, ToggleFold, ToggleFoldAll, code_context_menus::{CodeActionsMenu, MENU_ASIDE_MAX_WIDTH, MENU_ASIDE_MIN_WIDTH, MENU_GAP}, @@ -554,7 +554,7 @@ impl EditorElement { register_action(editor, window, Editor::signature_help_next); register_action(editor, window, Editor::next_edit_prediction); register_action(editor, window, Editor::previous_edit_prediction); - register_action(editor, window, Editor::show_inline_completion); + register_action(editor, window, Editor::show_edit_prediction); register_action(editor, window, Editor::context_menu_first); register_action(editor, window, Editor::context_menu_prev); register_action(editor, window, Editor::context_menu_next); @@ -562,7 +562,7 @@ impl EditorElement { register_action(editor, window, Editor::display_cursor_names); register_action(editor, window, Editor::unique_lines_case_insensitive); register_action(editor, window, Editor::unique_lines_case_sensitive); - register_action(editor, window, Editor::accept_partial_inline_completion); + register_action(editor, window, Editor::accept_partial_edit_prediction); register_action(editor, window, Editor::accept_edit_prediction); register_action(editor, window, Editor::restore_file); register_action(editor, window, Editor::git_restore); @@ -2093,7 +2093,7 @@ impl EditorElement { row_block_types: &HashMap<DisplayRow, bool>, content_origin: gpui::Point<Pixels>, scroll_pixel_position: gpui::Point<Pixels>, - inline_completion_popover_origin: Option<gpui::Point<Pixels>>, + edit_prediction_popover_origin: Option<gpui::Point<Pixels>>, start_row: DisplayRow, end_row: DisplayRow, line_height: Pixels, @@ -2210,12 +2210,13 @@ impl EditorElement { cmp::max(padded_line, min_start) }; - let behind_inline_completion_popover = inline_completion_popover_origin - .as_ref() - .map_or(false, |inline_completion_popover_origin| { - (pos_y..pos_y + line_height).contains(&inline_completion_popover_origin.y) - }); - let opacity = if behind_inline_completion_popover { + let behind_edit_prediction_popover = edit_prediction_popover_origin.as_ref().map_or( + false, + |edit_prediction_popover_origin| { + (pos_y..pos_y + line_height).contains(&edit_prediction_popover_origin.y) + }, + ); + let opacity = if behind_edit_prediction_popover { 0.5 } else { 1.0 @@ -2427,9 +2428,9 @@ impl EditorElement { let mut padding = INLINE_BLAME_PADDING_EM_WIDTHS; - if let Some(inline_completion) = editor.active_inline_completion.as_ref() { - match &inline_completion.completion { - InlineCompletion::Edit { + if let Some(edit_prediction) = editor.active_edit_prediction.as_ref() { + match &edit_prediction.completion { + EditPrediction::Edit { display_mode: EditDisplayMode::TabAccept, .. } => padding += INLINE_ACCEPT_SUGGESTION_EM_WIDTHS, @@ -4086,8 +4087,7 @@ impl EditorElement { { let editor = self.editor.read(cx); - if editor - .edit_prediction_visible_in_cursor_popover(editor.has_active_inline_completion()) + if editor.edit_prediction_visible_in_cursor_popover(editor.has_active_edit_prediction()) { height_above_menu += editor.edit_prediction_cursor_popover_height() + POPOVER_Y_PADDING; @@ -6676,14 +6676,14 @@ impl EditorElement { } } - fn paint_inline_completion_popover( + fn paint_edit_prediction_popover( &mut self, layout: &mut EditorLayout, window: &mut Window, cx: &mut App, ) { - if let Some(inline_completion_popover) = layout.inline_completion_popover.as_mut() { - inline_completion_popover.paint(window, cx); + if let Some(edit_prediction_popover) = layout.edit_prediction_popover.as_mut() { + edit_prediction_popover.paint(window, cx); } } @@ -8501,7 +8501,7 @@ impl Element for EditorElement { ) }); - let (inline_completion_popover, inline_completion_popover_origin) = self + let (edit_prediction_popover, edit_prediction_popover_origin) = self .editor .update(cx, |editor, cx| { editor.render_edit_prediction_popover( @@ -8530,7 +8530,7 @@ impl Element for EditorElement { &row_block_types, content_origin, scroll_pixel_position, - inline_completion_popover_origin, + edit_prediction_popover_origin, start_row, end_row, line_height, @@ -8919,7 +8919,7 @@ impl Element for EditorElement { cursors, visible_cursors, selections, - inline_completion_popover, + edit_prediction_popover, diff_hunk_controls, mouse_context_menu, test_indicators, @@ -9001,7 +9001,7 @@ impl Element for EditorElement { self.paint_minimap(layout, window, cx); self.paint_scrollbars(layout, window, cx); - self.paint_inline_completion_popover(layout, window, cx); + self.paint_edit_prediction_popover(layout, window, cx); self.paint_mouse_context_menu(layout, window, cx); }); }) @@ -9102,7 +9102,7 @@ pub struct EditorLayout { expand_toggles: Vec<Option<(AnyElement, gpui::Point<Pixels>)>>, diff_hunk_controls: Vec<AnyElement>, crease_trailers: Vec<Option<CreaseTrailerLayout>>, - inline_completion_popover: Option<AnyElement>, + edit_prediction_popover: Option<AnyElement>, mouse_context_menu: Option<AnyElement>, tab_invisible: ShapedLine, space_invisible: ShapedLine, diff --git a/crates/editor/src/linked_editing_ranges.rs b/crates/editor/src/linked_editing_ranges.rs index 7c2672fc0d..a185de33ca 100644 --- a/crates/editor/src/linked_editing_ranges.rs +++ b/crates/editor/src/linked_editing_ranges.rs @@ -95,7 +95,7 @@ pub(super) fn refresh_linked_ranges( let snapshot = buffer.read(cx).snapshot(); let buffer_id = buffer.read(cx).remote_id(); - let linked_edits_task = project.linked_edit(buffer, *start, cx); + let linked_edits_task = project.linked_edits(buffer, *start, cx); let highlights = move || async move { let edits = linked_edits_task.await.log_err()?; // Find the range containing our current selection. diff --git a/crates/editor/src/lsp_colors.rs b/crates/editor/src/lsp_colors.rs index ce07dd43fe..08cf9078f2 100644 --- a/crates/editor/src/lsp_colors.rs +++ b/crates/editor/src/lsp_colors.rs @@ -6,7 +6,7 @@ use gpui::{Hsla, Rgba}; use itertools::Itertools; use language::point_from_lsp; use multi_buffer::Anchor; -use project::{DocumentColor, lsp_store::ColorFetchStrategy}; +use project::{DocumentColor, lsp_store::LspFetchStrategy}; use settings::Settings as _; use text::{Bias, BufferId, OffsetRangeExt as _}; use ui::{App, Context, Window}; @@ -180,9 +180,9 @@ impl Editor { .filter_map(|buffer| { let buffer_id = buffer.read(cx).remote_id(); let fetch_strategy = if ignore_cache { - ColorFetchStrategy::IgnoreCache + LspFetchStrategy::IgnoreCache } else { - ColorFetchStrategy::UseCache { + LspFetchStrategy::UseCache { known_cache_version: self.colors.as_ref().and_then(|colors| { Some(colors.buffer_colors.get(&buffer_id)?.cache_version_used) }), diff --git a/crates/editor/src/movement.rs b/crates/editor/src/movement.rs index b9b7cb2e58..a8850984a1 100644 --- a/crates/editor/src/movement.rs +++ b/crates/editor/src/movement.rs @@ -907,12 +907,12 @@ mod tests { let inlays = (0..buffer_snapshot.len()) .flat_map(|offset| { [ - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut id), buffer_snapshot.anchor_at(offset, Bias::Left), "test", ), - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut id), buffer_snapshot.anchor_at(offset, Bias::Right), "test", diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index d5db7f71a4..a0214c76a1 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -19,8 +19,8 @@ path = "src/explorer.rs" [dependencies] agent.workspace = true -agent_ui.workspace = true agent_settings.workspace = true +agent_ui.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true @@ -29,6 +29,7 @@ buffer_diff.workspace = true chrono.workspace = true clap.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true debug_adapter_extension.workspace = true dirs.workspace = true @@ -68,4 +69,3 @@ util.workspace = true uuid.workspace = true watch.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index a02b4a7f0b..d638ac171f 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -18,7 +18,7 @@ use collections::{HashMap, HashSet}; use extension::ExtensionHostProxy; use futures::future; use gpui::http_client::read_proxy_from_env; -use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal}; +use gpui::{App, AppContext, Application, AsyncApp, Entity, UpdateGlobal}; use gpui_tokio::Tokio; use language::LanguageRegistry; use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel}; @@ -337,7 +337,8 @@ pub struct AgentAppState { } pub fn init(cx: &mut App) -> Arc<AgentAppState> { - release_channel::init(SemanticVersion::default(), cx); + let app_version = AppVersion::global(cx); + release_channel::init(app_version, cx); gpui_tokio::init(cx); let mut settings_store = SettingsStore::new(cx); @@ -350,7 +351,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> { // Set User-Agent so we can download language servers from GitHub let user_agent = format!( "Zed/{} ({}; {})", - AppVersion::global(cx), + app_version, std::env::consts::OS, std::env::consts::ARCH ); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 7ce3b1fdf1..23c8814916 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -15,11 +15,11 @@ use agent_settings::AgentProfileId; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; +use cloud_llm_client::CompletionIntent; use collections::HashMap; use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; use gpui::{App, AppContext, AsyncApp, Entity}; use language_model::{LanguageModel, Role, StopReason}; -use zed_llm_client::CompletionIntent; pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index e3235cf561..5852b3e3fc 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -163,7 +163,7 @@ pub struct LanguageServerManifestEntry { #[serde(default)] languages: Vec<LanguageName>, #[serde(default)] - pub language_ids: HashMap<String, String>, + pub language_ids: HashMap<LanguageName, String>, #[serde(default)] pub code_action_kinds: Option<Vec<lsp::CodeActionKind>>, } diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs index 1f6f5035e3..d990b670f4 100644 --- a/crates/extension_host/src/wasm_host.rs +++ b/crates/extension_host/src/wasm_host.rs @@ -106,7 +106,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_initialization_options( @@ -131,7 +131,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_workspace_configuration( @@ -154,7 +154,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_additional_initialization_options( @@ -179,7 +179,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_additional_workspace_configuration( @@ -204,7 +204,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn labels_for_completions( @@ -230,7 +230,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn labels_for_symbols( @@ -256,7 +256,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn complete_slash_command_argument( @@ -275,7 +275,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn run_slash_command( @@ -301,7 +301,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn context_server_command( @@ -320,7 +320,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn context_server_configuration( @@ -347,7 +347,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> { @@ -362,7 +362,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn index_docs( @@ -388,7 +388,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn get_dap_binary( @@ -410,7 +410,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_request_kind( &self, @@ -427,7 +427,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_config_to_scenario(&self, config: ZedDebugConfig) -> Result<DebugScenario> { @@ -441,7 +441,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_locator_create_scenario( @@ -465,7 +465,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn run_dap_locator( &self, @@ -481,7 +481,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } } @@ -761,7 +761,7 @@ impl WasmExtension { .with_context(|| format!("failed to load wasm extension {}", manifest.id)) } - pub async fn call<T, Fn>(&self, f: Fn) -> T + pub async fn call<T, Fn>(&self, f: Fn) -> Result<T> where T: 'static + Send, Fn: 'static @@ -777,8 +777,19 @@ impl WasmExtension { } .boxed() })) - .expect("wasm extension channel should not be closed yet"); - return_rx.await.expect("wasm extension channel") + .map_err(|_| { + anyhow!( + "wasm extension channel should not be closed yet, extension {} (id {})", + self.manifest.name, + self.manifest.id, + ) + })?; + return_rx.await.with_context(|| { + format!( + "wasm extension channel, extension {} (id {})", + self.manifest.name, self.manifest.id, + ) + }) } } @@ -799,8 +810,19 @@ impl WasmState { } .boxed_local() })) - .expect("main thread message channel should not be closed yet"); - async move { return_rx.await.expect("main thread message channel") } + .unwrap_or_else(|_| { + panic!( + "main thread message channel should not be closed yet, extension {} (id {})", + self.manifest.name, self.manifest.id, + ) + }); + let name = self.manifest.name.clone(); + let id = self.manifest.id.clone(); + async move { + return_rx.await.unwrap_or_else(|_| { + panic!("main thread message channel, extension {name} (id {id})") + }) + } } fn work_dir(&self) -> PathBuf { diff --git a/crates/file_finder/src/file_finder.rs b/crates/file_finder/src/file_finder.rs index a4d61dd56f..e5ac70bb58 100644 --- a/crates/file_finder/src/file_finder.rs +++ b/crates/file_finder/src/file_finder.rs @@ -1404,14 +1404,21 @@ impl PickerDelegate for FileFinderDelegate { } else { let path_position = PathWithPosition::parse_str(&raw_query); + #[cfg(windows)] + let raw_query = raw_query.trim().to_owned().replace("/", "\\"); + #[cfg(not(windows))] + let raw_query = raw_query.trim().to_owned(); + + let file_query_end = if path_position.path.to_str().unwrap_or(&raw_query) == raw_query { + None + } else { + // Safe to unwrap as we won't get here when the unwrap in if fails + Some(path_position.path.to_str().unwrap().len()) + }; + let query = FileSearchQuery { - raw_query: raw_query.trim().to_owned(), - file_query_end: if path_position.path.to_str().unwrap_or(raw_query) == raw_query { - None - } else { - // Safe to unwrap as we won't get here when the unwrap in if fails - Some(path_position.path.to_str().unwrap().len()) - }, + raw_query, + file_query_end, path_position, }; diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 378a8fb7df..04ba656232 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -10,7 +10,7 @@ use git::{ }, status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus}, }; -use gpui::{AsyncApp, BackgroundExecutor}; +use gpui::{AsyncApp, BackgroundExecutor, SharedString}; use ignore::gitignore::GitignoreBuilder; use rope::Rope; use smol::future::FutureExt as _; @@ -491,4 +491,8 @@ impl GitRepository for FakeGitRepository { ) -> BoxFuture<'_, Result<String>> { unimplemented!() } + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>> { + unimplemented!() + } } diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index a63315e69e..b536bed710 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -463,6 +463,8 @@ pub trait GitRepository: Send + Sync { base_checkpoint: GitRepositoryCheckpoint, target_checkpoint: GitRepositoryCheckpoint, ) -> BoxFuture<'_, Result<String>>; + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>>; } pub enum DiffType { @@ -1607,6 +1609,37 @@ impl GitRepository for RealGitRepository { }) .boxed() } + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>> { + let working_directory = self.working_directory(); + let git_binary_path = self.git_binary_path.clone(); + + let executor = self.executor.clone(); + self.executor + .spawn(async move { + let working_directory = working_directory?; + let git = GitBinary::new(git_binary_path, working_directory, executor); + + if let Ok(output) = git + .run(&["symbolic-ref", "refs/remotes/upstream/HEAD"]) + .await + { + let output = output + .strip_prefix("refs/remotes/upstream/") + .map(|s| SharedString::from(s.to_owned())); + return Ok(output); + } + + let output = git + .run(&["symbolic-ref", "refs/remotes/origin/HEAD"]) + .await?; + + Ok(output + .strip_prefix("refs/remotes/origin/") + .map(|s| SharedString::from(s.to_owned()))) + }) + .boxed() + } } fn git_status_args(path_prefixes: &[RepoPath]) -> Vec<OsString> { diff --git a/crates/git_ui/Cargo.toml b/crates/git_ui/Cargo.toml index 2fb80b7e73..35f7a60354 100644 --- a/crates/git_ui/Cargo.toml +++ b/crates/git_ui/Cargo.toml @@ -23,7 +23,7 @@ askpass.workspace = true buffer_diff.workspace = true call.workspace = true chrono.workspace = true -client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true component.workspace = true @@ -62,7 +62,6 @@ watch.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true [target.'cfg(windows)'.dependencies] windows.workspace = true @@ -71,6 +70,7 @@ windows.workspace = true ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } +indoc.workspace = true pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 9eac3ce5af..1092ba33d1 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -13,7 +13,7 @@ use project::git_store::Repository; use std::sync::Arc; use time::OffsetDateTime; use time_format::format_local_timestamp; -use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*}; +use ui::{HighlightedLabel, ListItem, ListItemSpacing, Tooltip, prelude::*}; use util::ResultExt; use workspace::notifications::DetachAndPromptErr; use workspace::{ModalView, Workspace}; @@ -90,11 +90,21 @@ impl BranchList { let all_branches_request = repository .clone() .map(|repository| repository.update(cx, |repository, _| repository.branches())); + let default_branch_request = repository + .clone() + .map(|repository| repository.update(cx, |repository, _| repository.default_branch())); cx.spawn_in(window, async move |this, cx| { let mut all_branches = all_branches_request .context("No active repository")? .await??; + let default_branch = default_branch_request + .context("No active repository")? + .await + .map(Result::ok) + .ok() + .flatten() + .flatten(); let all_branches = cx .background_spawn(async move { @@ -124,6 +134,7 @@ impl BranchList { this.update_in(cx, |this, window, cx| { this.picker.update(cx, |picker, cx| { + picker.delegate.default_branch = default_branch; picker.delegate.all_branches = Some(all_branches); picker.refresh(window, cx); }) @@ -192,6 +203,7 @@ struct BranchEntry { pub struct BranchListDelegate { matches: Vec<BranchEntry>, all_branches: Option<Vec<Branch>>, + default_branch: Option<SharedString>, repo: Option<Entity<Repository>>, style: BranchListStyle, selected_index: usize, @@ -206,6 +218,7 @@ impl BranchListDelegate { repo, style, all_branches: None, + default_branch: None, selected_index: 0, last_query: Default::default(), modifiers: Default::default(), @@ -214,6 +227,7 @@ impl BranchListDelegate { fn create_branch( &self, + from_branch: Option<SharedString>, new_branch_name: SharedString, window: &mut Window, cx: &mut Context<Picker<Self>>, @@ -223,6 +237,11 @@ impl BranchListDelegate { }; let new_branch_name = new_branch_name.to_string().replace(' ', "-"); cx.spawn(async move |_, cx| { + if let Some(based_branch) = from_branch { + repo.update(cx, |repo, _| repo.change_branch(based_branch.to_string()))? + .await??; + } + repo.update(cx, |repo, _| { repo.create_branch(new_branch_name.to_string()) })? @@ -353,12 +372,22 @@ impl PickerDelegate for BranchListDelegate { }) } - fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) { + fn confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) { let Some(entry) = self.matches.get(self.selected_index()) else { return; }; if entry.is_new { - self.create_branch(entry.branch.name().to_owned().into(), window, cx); + let from_branch = if secondary { + self.default_branch.clone() + } else { + None + }; + self.create_branch( + from_branch, + entry.branch.name().to_owned().into(), + window, + cx, + ); return; } @@ -439,6 +468,28 @@ impl PickerDelegate for BranchListDelegate { }) .unwrap_or_else(|| (None, None)); + let icon = if let Some(default_branch) = self.default_branch.clone() + && entry.is_new + { + Some( + IconButton::new("branch-from-default", IconName::GitBranchSmall) + .on_click(cx.listener(move |this, _, window, cx| { + this.delegate.set_selected_index(ix, window, cx); + this.delegate.confirm(true, window, cx); + })) + .tooltip(move |window, cx| { + Tooltip::for_action( + format!("Create branch based off default: {default_branch}"), + &menu::SecondaryConfirm, + window, + cx, + ) + }), + ) + } else { + None + }; + let branch_name = if entry.is_new { h_flex() .gap_1() @@ -504,7 +555,8 @@ impl PickerDelegate for BranchListDelegate { .color(Color::Muted) })) }), - ), + ) + .end_slot::<IconButton>(icon), ) } diff --git a/crates/git_ui/src/commit_modal.rs b/crates/git_ui/src/commit_modal.rs index b99f628806..5dfa800ae5 100644 --- a/crates/git_ui/src/commit_modal.rs +++ b/crates/git_ui/src/commit_modal.rs @@ -1,9 +1,9 @@ use crate::branch_picker::{self, BranchList}; use crate::git_panel::{GitPanel, commit_message_editor}; -use client::DisableAiSettings; use git::repository::CommitOptions; use git::{Amend, Commit, GenerateCommitMessage, Signoff}; use panel::{panel_button, panel_editor_style}; +use project::DisableAiSettings; use settings::Settings; use ui::{ ContextMenu, KeybindingHint, PopoverMenu, PopoverMenuHandle, SplitButton, Tooltip, prelude::*, @@ -295,11 +295,13 @@ impl CommitModal { IconPosition::Start, Some(Box::new(Amend)), { - let git_panel = git_panel_entity.clone(); - move |window, cx| { - git_panel.update(cx, |git_panel, cx| { - git_panel.toggle_amend_pending(&Amend, window, cx); - }) + let git_panel = git_panel_entity.downgrade(); + move |_, cx| { + git_panel + .update(cx, |git_panel, cx| { + git_panel.toggle_amend_pending(cx); + }) + .ok(); } }, ) diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 725a1b6db5..44222b8299 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -12,7 +12,6 @@ use crate::{ use agent_settings::AgentSettings; use anyhow::Context as _; use askpass::AskPassDelegate; -use client::DisableAiSettings; use db::kvp::KEY_VALUE_STORE; use editor::{ Editor, EditorElement, EditorMode, EditorSettings, MultiBuffer, ShowScrollbar, @@ -51,10 +50,9 @@ use panel::{ PanelHeader, panel_button, panel_editor_container, panel_editor_style, panel_filled_button, panel_icon_button, }; -use project::git_store::{RepositoryEvent, RepositoryId}; use project::{ - Fs, Project, ProjectPath, - git_store::{GitStoreEvent, Repository}, + DisableAiSettings, Fs, Project, ProjectPath, + git_store::{GitStoreEvent, Repository, RepositoryEvent, RepositoryId}, }; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -71,12 +69,12 @@ use ui::{ use util::{ResultExt, TryFutureExt, maybe}; use workspace::SERIALIZATION_THROTTLE_TIME; +use cloud_llm_client::CompletionIntent; use workspace::{ Workspace, dock::{DockPosition, Panel, PanelEvent}, notifications::{DetachAndPromptErr, ErrorMessagePrompt, NotificationId}, }; -use zed_llm_client::CompletionIntent; actions!( git_panel, @@ -2416,7 +2414,7 @@ impl GitPanel { .committer_name .clone() .or_else(|| participant.user.name.clone()) - .unwrap_or_else(|| participant.user.github_login.clone()); + .unwrap_or_else(|| participant.user.github_login.clone().to_string()); new_co_authors.push((name.clone(), email.clone())) } } @@ -2436,7 +2434,7 @@ impl GitPanel { .name .clone() .or_else(|| user.name.clone()) - .unwrap_or_else(|| user.github_login.clone()); + .unwrap_or_else(|| user.github_login.clone().to_string()); Some((name, email)) } @@ -2901,7 +2899,9 @@ impl GitPanel { let status_toast = StatusToast::new(message, cx, move |this, _cx| { use remote_output::SuccessStyle::*; match style { - Toast { .. } => this, + Toast { .. } => { + this.icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) + } ToastWithLog { output } => this .icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) .action("View Log", move |window, cx| { @@ -2914,9 +2914,9 @@ impl GitPanel { }) .ok(); }), - PushPrLink { link } => this + PushPrLink { text, link } => this .icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) - .action("Open Pull Request", move |_, cx| cx.open_url(&link)), + .action(text, move |_, cx| cx.open_url(&link)), } }); workspace.toggle_status_toast(status_toast, cx) @@ -3113,6 +3113,7 @@ impl GitPanel { ), ) .menu({ + let git_panel = cx.entity(); let has_previous_commit = self.head_commit(cx).is_some(); let amend = self.amend_pending(); let signoff = self.signoff_enabled; @@ -3129,7 +3130,16 @@ impl GitPanel { amend, IconPosition::Start, Some(Box::new(Amend)), - move |window, cx| window.dispatch_action(Box::new(Amend), cx), + { + let git_panel = git_panel.downgrade(); + move |_, cx| { + git_panel + .update(cx, |git_panel, cx| { + git_panel.toggle_amend_pending(cx); + }) + .ok(); + } + }, ) }) .toggleable_entry( @@ -3500,9 +3510,11 @@ impl GitPanel { .truncate(), ), ) - .child(panel_button("Cancel").size(ButtonSize::Default).on_click( - cx.listener(|this, _, window, cx| this.toggle_amend_pending(&Amend, window, cx)), - )) + .child( + panel_button("Cancel") + .size(ButtonSize::Default) + .on_click(cx.listener(|this, _, _, cx| this.set_amend_pending(false, cx))), + ) } fn render_previous_commit(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> { @@ -4263,17 +4275,8 @@ impl GitPanel { pub fn set_amend_pending(&mut self, value: bool, cx: &mut Context<Self>) { self.amend_pending = value; - cx.notify(); - } - - pub fn toggle_amend_pending( - &mut self, - _: &Amend, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - self.set_amend_pending(!self.amend_pending, cx); self.serialize(cx); + cx.notify(); } pub fn signoff_enabled(&self) -> bool { @@ -4367,6 +4370,13 @@ impl GitPanel { anchor: path, }); } + + pub(crate) fn toggle_amend_pending(&mut self, cx: &mut Context<Self>) { + self.set_amend_pending(!self.amend_pending, cx); + if self.amend_pending { + self.load_last_commit_message_if_empty(cx); + } + } } fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn LanguageModel>> { @@ -4411,7 +4421,6 @@ impl Render for GitPanel { .on_action(cx.listener(Self::stage_range)) .on_action(cx.listener(GitPanel::commit)) .on_action(cx.listener(GitPanel::amend)) - .on_action(cx.listener(GitPanel::toggle_amend_pending)) .on_action(cx.listener(GitPanel::toggle_signoff_enabled)) .on_action(cx.listener(Self::stage_all)) .on_action(cx.listener(Self::unstage_all)) @@ -5106,7 +5115,6 @@ mod tests { language::init(cx); editor::init(cx); Project::init_settings(cx); - client::DisableAiSettings::register(cx); crate::init(cx); }); } diff --git a/crates/git_ui/src/remote_output.rs b/crates/git_ui/src/remote_output.rs index 03fbf4f917..8437bf0d0d 100644 --- a/crates/git_ui/src/remote_output.rs +++ b/crates/git_ui/src/remote_output.rs @@ -24,7 +24,7 @@ impl RemoteAction { pub enum SuccessStyle { Toast, ToastWithLog { output: RemoteCommandOutput }, - PushPrLink { link: String }, + PushPrLink { text: String, link: String }, } pub struct SuccessMessage { @@ -37,7 +37,7 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ RemoteAction::Fetch(remote) => { if output.stderr.is_empty() { SuccessMessage { - message: "Already up to date".into(), + message: "Fetch: Already up to date".into(), style: SuccessStyle::Toast, } } else { @@ -68,10 +68,9 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ Ok(files_changed) }; - - if output.stderr.starts_with("Everything up to date") { + if output.stdout.ends_with("Already up to date.\n") { SuccessMessage { - message: output.stderr.trim().to_owned(), + message: "Pull: Already up to date".into(), style: SuccessStyle::Toast, } } else if output.stdout.starts_with("Updating") { @@ -119,48 +118,42 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ } } RemoteAction::Push(branch_name, remote_ref) => { - if output.stderr.contains("* [new branch]") { - let pr_hints = [ - // GitHub - "Create a pull request", - // Bitbucket - "Create pull request", - // GitLab - "create a merge request", - ]; - let style = if pr_hints - .iter() - .any(|indicator| output.stderr.contains(indicator)) - { - let finder = LinkFinder::new(); - let first_link = finder - .links(&output.stderr) - .filter(|link| *link.kind() == LinkKind::Url) - .map(|link| link.start()..link.end()) - .next(); - if let Some(link) = first_link { - let link = output.stderr[link].to_string(); - SuccessStyle::PushPrLink { link } - } else { - SuccessStyle::ToastWithLog { output } - } - } else { - SuccessStyle::ToastWithLog { output } - }; - SuccessMessage { - message: format!("Published {} to {}", branch_name, remote_ref.name), - style, - } - } else if output.stderr.starts_with("Everything up to date") { - SuccessMessage { - message: output.stderr.trim().to_owned(), - style: SuccessStyle::Toast, - } + let message = if output.stderr.ends_with("Everything up-to-date\n") { + "Push: Everything is up-to-date".to_string() } else { - SuccessMessage { - message: format!("Pushed {} to {}", branch_name, remote_ref.name), - style: SuccessStyle::ToastWithLog { output }, - } + format!("Pushed {} to {}", branch_name, remote_ref.name) + }; + + let style = if output.stderr.ends_with("Everything up-to-date\n") { + Some(SuccessStyle::Toast) + } else if output.stderr.contains("\nremote: ") { + let pr_hints = [ + ("Create a pull request", "Create Pull Request"), // GitHub + ("Create pull request", "Create Pull Request"), // Bitbucket + ("create a merge request", "Create Merge Request"), // GitLab + ("View merge request", "View Merge Request"), // GitLab + ]; + pr_hints + .iter() + .find(|(indicator, _)| output.stderr.contains(indicator)) + .and_then(|(_, mapped)| { + let finder = LinkFinder::new(); + finder + .links(&output.stderr) + .filter(|link| *link.kind() == LinkKind::Url) + .map(|link| link.start()..link.end()) + .next() + .map(|link| SuccessStyle::PushPrLink { + text: mapped.to_string(), + link: output.stderr[link].to_string(), + }) + }) + } else { + None + }; + SuccessMessage { + message, + style: style.unwrap_or(SuccessStyle::ToastWithLog { output }), } } } @@ -169,6 +162,7 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ #[cfg(test)] mod tests { use super::*; + use indoc::indoc; #[test] fn test_push_new_branch_pull_request() { @@ -181,8 +175,7 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from( - " + stderr: indoc! { " Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) remote: remote: Create a pull request for 'test' on GitHub by visiting: @@ -190,13 +183,14 @@ mod tests { remote: To example.com:test/test.git * [new branch] test -> test - ", - ), + "} + .to_string(), }; let msg = format_output(&action, output); - if let SuccessStyle::PushPrLink { link } = &msg.style { + if let SuccessStyle::PushPrLink { text: hint, link } = &msg.style { + assert_eq!(hint, "Create Pull Request"); assert_eq!(link, "https://example.com/test/test/pull/new/test"); } else { panic!("Expected PushPrLink variant"); @@ -214,7 +208,7 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from(" + stderr: indoc! {" Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) remote: remote: To create a merge request for test, visit: @@ -222,12 +216,14 @@ mod tests { remote: To example.com:test/test.git * [new branch] test -> test - "), - }; + "} + .to_string() + }; let msg = format_output(&action, output); - if let SuccessStyle::PushPrLink { link } = &msg.style { + if let SuccessStyle::PushPrLink { text, link } = &msg.style { + assert_eq!(text, "Create Merge Request"); assert_eq!( link, "https://example.com/test/test/-/merge_requests/new?merge_request%5Bsource_branch%5D=test" @@ -237,6 +233,39 @@ mod tests { } } + #[test] + fn test_push_branch_existing_merge_request() { + let action = RemoteAction::Push( + SharedString::new("test_branch"), + Remote { + name: SharedString::new("test_remote"), + }, + ); + + let output = RemoteCommandOutput { + stdout: String::new(), + stderr: indoc! {" + Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) + remote: + remote: View merge request for test: + remote: https://example.com/test/test/-/merge_requests/99999 + remote: + To example.com:test/test.git + + 80bd3c83be...e03d499d2e test -> test + "} + .to_string(), + }; + + let msg = format_output(&action, output); + + if let SuccessStyle::PushPrLink { text, link } = &msg.style { + assert_eq!(text, "View Merge Request"); + assert_eq!(link, "https://example.com/test/test/-/merge_requests/99999"); + } else { + panic!("Expected PushPrLink variant"); + } + } + #[test] fn test_push_new_branch_no_link() { let action = RemoteAction::Push( @@ -248,12 +277,12 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from( - " + stderr: indoc! { " To http://example.com/test/test.git * [new branch] test -> test ", - ), + } + .to_string(), }; let msg = format_output(&action, output); @@ -261,10 +290,7 @@ mod tests { if let SuccessStyle::ToastWithLog { output } = &msg.style { assert_eq!( output.stderr, - " - To http://example.com/test/test.git - * [new branch] test -> test - " + "To http://example.com/test/test.git\n * [new branch] test -> test\n" ); } else { panic!("Expected ToastWithLog variant"); diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 29e81269e3..2bf49fa7d8 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -216,10 +216,6 @@ xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf x11-clipboard = { version = "0.9.3", optional = true } [target.'cfg(target_os = "windows")'.dependencies] -blade-util.workspace = true -bytemuck = "1" -blade-graphics.workspace = true -blade-macros.workspace = true flume = "0.11" rand.workspace = true windows.workspace = true @@ -240,7 +236,6 @@ util = { workspace = true, features = ["test-support"] } [target.'cfg(target_os = "windows")'.build-dependencies] embed-resource = "3.0" -naga.workspace = true [target.'cfg(target_os = "macos")'.build-dependencies] bindgen = "0.71" @@ -287,6 +282,10 @@ path = "examples/shadow.rs" name = "svg" path = "examples/svg/svg.rs" +[[example]] +name = "tab_stop" +path = "examples/tab_stop.rs" + [[example]] name = "text" path = "examples/text.rs" diff --git a/crates/gpui/build.rs b/crates/gpui/build.rs index 7ab44a73f5..93a1c15c41 100644 --- a/crates/gpui/build.rs +++ b/crates/gpui/build.rs @@ -9,7 +9,10 @@ fn main() { let target = env::var("CARGO_CFG_TARGET_OS"); println!("cargo::rustc-check-cfg=cfg(gles)"); - #[cfg(any(not(target_os = "macos"), feature = "macos-blade"))] + #[cfg(any( + not(any(target_os = "macos", target_os = "windows")), + all(target_os = "macos", feature = "macos-blade") + ))] check_wgsl_shaders(); match target.as_deref() { @@ -17,21 +20,18 @@ fn main() { #[cfg(target_os = "macos")] macos::build(); } - #[cfg(all(target_os = "windows", feature = "windows-manifest"))] Ok("windows") => { - let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); - let rc_file = std::path::Path::new("resources/windows/gpui.rc"); - println!("cargo:rerun-if-changed={}", manifest.display()); - println!("cargo:rerun-if-changed={}", rc_file.display()); - embed_resource::compile(rc_file, embed_resource::NONE) - .manifest_required() - .unwrap(); + #[cfg(target_os = "windows")] + windows::build(); } _ => (), }; } -#[allow(dead_code)] +#[cfg(any( + not(any(target_os = "macos", target_os = "windows")), + all(target_os = "macos", feature = "macos-blade") +))] fn check_wgsl_shaders() { use std::path::PathBuf; use std::process; @@ -243,3 +243,215 @@ mod macos { } } } + +#[cfg(target_os = "windows")] +mod windows { + use std::{ + fs, + io::Write, + path::{Path, PathBuf}, + process::{self, Command}, + }; + + pub(super) fn build() { + // Compile HLSL shaders + #[cfg(not(debug_assertions))] + compile_shaders(); + + // Embed the Windows manifest and resource file + #[cfg(feature = "windows-manifest")] + embed_resource(); + } + + #[cfg(feature = "windows-manifest")] + fn embed_resource() { + let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); + let rc_file = std::path::Path::new("resources/windows/gpui.rc"); + println!("cargo:rerun-if-changed={}", manifest.display()); + println!("cargo:rerun-if-changed={}", rc_file.display()); + embed_resource::compile(rc_file, embed_resource::NONE) + .manifest_required() + .unwrap(); + } + + /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. + fn compile_shaders() { + let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("src/platform/windows/shaders.hlsl"); + let out_dir = std::env::var("OUT_DIR").unwrap(); + + println!("cargo:rerun-if-changed={}", shader_path.display()); + + // Check if fxc.exe is available + let fxc_path = find_fxc_compiler(); + + // Define all modules + let modules = [ + "quad", + "shadow", + "path_rasterization", + "path_sprite", + "underline", + "monochrome_sprite", + "polychrome_sprite", + ]; + + let rust_binding_path = format!("{}/shaders_bytes.rs", out_dir); + if Path::new(&rust_binding_path).exists() { + fs::remove_file(&rust_binding_path) + .expect("Failed to remove existing Rust binding file"); + } + for module in modules { + compile_shader_for_module( + module, + &out_dir, + &fxc_path, + shader_path.to_str().unwrap(), + &rust_binding_path, + ); + } + + { + let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("src/platform/windows/color_text_raster.hlsl"); + compile_shader_for_module( + "emoji_rasterization", + &out_dir, + &fxc_path, + shader_path.to_str().unwrap(), + &rust_binding_path, + ); + } + } + + /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. + fn find_fxc_compiler() -> String { + // Check environment variable + if let Ok(path) = std::env::var("GPUI_FXC_PATH") { + if Path::new(&path).exists() { + return path; + } + } + + // Try to find in PATH + // NOTE: This has to be `where.exe` on Windows, not `where`, it must be ended with `.exe` + if let Ok(output) = std::process::Command::new("where.exe") + .arg("fxc.exe") + .output() + { + if output.status.success() { + let path = String::from_utf8_lossy(&output.stdout); + return path.trim().to_string(); + } + } + + // Check the default path + if Path::new(r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe") + .exists() + { + return r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe" + .to_string(); + } + + panic!("Failed to find fxc.exe"); + } + + fn compile_shader_for_module( + module: &str, + out_dir: &str, + fxc_path: &str, + shader_path: &str, + rust_binding_path: &str, + ) { + // Compile vertex shader + let output_file = format!("{}/{}_vs.h", out_dir, module); + let const_name = format!("{}_VERTEX_BYTES", module.to_uppercase()); + compile_shader_impl( + fxc_path, + &format!("{module}_vertex"), + &output_file, + &const_name, + shader_path, + "vs_4_1", + ); + generate_rust_binding(&const_name, &output_file, &rust_binding_path); + + // Compile fragment shader + let output_file = format!("{}/{}_ps.h", out_dir, module); + let const_name = format!("{}_FRAGMENT_BYTES", module.to_uppercase()); + compile_shader_impl( + fxc_path, + &format!("{module}_fragment"), + &output_file, + &const_name, + shader_path, + "ps_4_1", + ); + generate_rust_binding(&const_name, &output_file, &rust_binding_path); + } + + fn compile_shader_impl( + fxc_path: &str, + entry_point: &str, + output_path: &str, + var_name: &str, + shader_path: &str, + target: &str, + ) { + let output = Command::new(fxc_path) + .args([ + "/T", + target, + "/E", + entry_point, + "/Fh", + output_path, + "/Vn", + var_name, + "/O3", + shader_path, + ]) + .output(); + + match output { + Ok(result) => { + if result.status.success() { + return; + } + eprintln!( + "Shader compilation failed for {}:\n{}", + entry_point, + String::from_utf8_lossy(&result.stderr) + ); + process::exit(1); + } + Err(e) => { + eprintln!("Failed to run fxc for {}: {}", entry_point, e); + process::exit(1); + } + } + } + + fn generate_rust_binding(const_name: &str, head_file: &str, output_path: &str) { + let header_content = fs::read_to_string(head_file).expect("Failed to read header file"); + let const_definition = { + let global_var_start = header_content.find("const BYTE").unwrap(); + let global_var = &header_content[global_var_start..]; + let equal = global_var.find('=').unwrap(); + global_var[equal + 1..].trim() + }; + let rust_binding = format!( + "const {}: &[u8] = &{}\n", + const_name, + const_definition.replace('{', "[").replace('}', "]") + ); + let mut options = fs::OpenOptions::new() + .create(true) + .append(true) + .open(output_path) + .expect("Failed to open Rust binding file"); + options + .write_all(rust_binding.as_bytes()) + .expect("Failed to write Rust binding file"); + } +} diff --git a/crates/gpui/examples/tab_stop.rs b/crates/gpui/examples/tab_stop.rs index 9c58b52a5e..1f6500f3e6 100644 --- a/crates/gpui/examples/tab_stop.rs +++ b/crates/gpui/examples/tab_stop.rs @@ -6,6 +6,7 @@ use gpui::{ actions!(example, [Tab, TabPrev]); struct Example { + focus_handle: FocusHandle, items: Vec<FocusHandle>, message: SharedString, } @@ -20,8 +21,11 @@ impl Example { cx.focus_handle().tab_index(2).tab_stop(true), ]; - window.focus(items.first().unwrap()); + let focus_handle = cx.focus_handle(); + window.focus(&focus_handle); + Self { + focus_handle, items, message: SharedString::from("Press `Tab`, `Shift-Tab` to switch focus."), } @@ -40,6 +44,10 @@ impl Example { impl Render for Example { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + fn tab_stop_style<T: Styled>(this: T) -> T { + this.border_3().border_color(gpui::blue()) + } + fn button(id: impl Into<ElementId>) -> Stateful<Div> { div() .id(id) @@ -52,12 +60,13 @@ impl Render for Example { .border_color(gpui::black()) .bg(gpui::black()) .text_color(gpui::white()) - .focus(|this| this.border_color(gpui::blue())) + .focus(tab_stop_style) .shadow_sm() } div() .id("app") + .track_focus(&self.focus_handle) .on_action(cx.listener(Self::on_tab)) .on_action(cx.listener(Self::on_tab_prev)) .size_full() @@ -86,7 +95,7 @@ impl Render for Example { .border_color(gpui::black()) .when( item_handle.tab_stop && item_handle.is_focused(window), - |this| this.border_color(gpui::blue()), + tab_stop_style, ) .map(|this| match item_handle.tab_stop { true => this diff --git a/crates/gpui/examples/text.rs b/crates/gpui/examples/text.rs index 19214aebde..1166bb2795 100644 --- a/crates/gpui/examples/text.rs +++ b/crates/gpui/examples/text.rs @@ -198,7 +198,7 @@ impl RenderOnce for CharacterGrid { "χ", "ψ", "∂", "а", "в", "Ж", "ж", "З", "з", "К", "к", "л", "м", "Н", "н", "Р", "р", "У", "у", "ф", "ч", "ь", "ы", "Э", "э", "Я", "я", "ij", "öẋ", ".,", "⣝⣑", "~", "*", "_", "^", "`", "'", "(", "{", "«", "#", "&", "@", "$", "¢", "%", "|", "?", "¶", "µ", - "❮", "<=", "!=", "==", "--", "++", "=>", "->", + "❮", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎", ]; let columns = 11; diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index 759d33563e..ded7bae316 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -2023,6 +2023,10 @@ impl HttpClient for NullHttpClient { .boxed() } + fn user_agent(&self) -> Option<&http_client::http::HeaderValue> { + None + } + fn proxy(&self) -> Option<&Url> { None } diff --git a/crates/gpui/src/color.rs b/crates/gpui/src/color.rs index a16c8f46be..639c84c101 100644 --- a/crates/gpui/src/color.rs +++ b/crates/gpui/src/color.rs @@ -35,6 +35,7 @@ pub(crate) fn swap_rgba_pa_to_bgra(color: &mut [u8]) { /// An RGBA color #[derive(PartialEq, Clone, Copy, Default)] +#[repr(C)] pub struct Rgba { /// The red component of the color, in the range 0.0 to 1.0 pub r: f32, diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs index 328a6a4cc1..709323ef58 100644 --- a/crates/gpui/src/elements/list.rs +++ b/crates/gpui/src/elements/list.rs @@ -16,7 +16,7 @@ use crate::{ use collections::VecDeque; use refineable::Refineable as _; use std::{cell::RefCell, ops::Range, rc::Rc}; -use sum_tree::{Bias, SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; /// Construct a new list element pub fn list(state: ListState) -> List { @@ -371,14 +371,14 @@ impl ListState { return None; } - let mut cursor = state.items.cursor::<(Count, Height)>(&()); + let mut cursor = state.items.cursor::<Dimensions<Count, Height>>(&()); cursor.seek(&Count(scroll_top.item_ix), Bias::Right); let scroll_top = cursor.start().1.0 + scroll_top.offset_in_item; cursor.seek_forward(&Count(ix), Bias::Right); if let Some(&ListItem::Measured { size, .. }) = cursor.item() { - let &(Count(count), Height(top)) = cursor.start(); + let &Dimensions(Count(count), Height(top), _) = cursor.start(); if count == ix { let top = bounds.top() + top - scroll_top; return Some(Bounds::from_corners( diff --git a/crates/gpui/src/geometry.rs b/crates/gpui/src/geometry.rs index 74be6344f9..3d2d9cd9db 100644 --- a/crates/gpui/src/geometry.rs +++ b/crates/gpui/src/geometry.rs @@ -3522,7 +3522,7 @@ impl Serialize for Length { /// # Returns /// /// A `DefiniteLength` representing the relative length as a fraction of the parent's size. -pub fn relative(fraction: f32) -> DefiniteLength { +pub const fn relative(fraction: f32) -> DefiniteLength { DefiniteLength::Fraction(fraction) } diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index 1e72d23868..b495d70dfd 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -13,8 +13,7 @@ mod mac; any(target_os = "linux", target_os = "freebsd"), any(feature = "x11", feature = "wayland") ), - target_os = "windows", - feature = "macos-blade" + all(target_os = "macos", feature = "macos-blade") ))] mod blade; @@ -448,6 +447,8 @@ impl Tiling { #[derive(Debug, Copy, Clone, Eq, PartialEq, Default)] pub(crate) struct RequestFrameOptions { pub(crate) require_presentation: bool, + /// Force refresh of all rendering states when true + pub(crate) force_render: bool, } pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle { diff --git a/crates/gpui/src/platform/keystroke.rs b/crates/gpui/src/platform/keystroke.rs index 8b6e72d150..24601eefd6 100644 --- a/crates/gpui/src/platform/keystroke.rs +++ b/crates/gpui/src/platform/keystroke.rs @@ -417,17 +417,6 @@ impl Modifiers { self.control || self.alt || self.shift || self.platform || self.function } - /// Returns the XOR of two modifier sets - pub fn xor(&self, other: &Modifiers) -> Modifiers { - Modifiers { - control: self.control ^ other.control, - alt: self.alt ^ other.alt, - shift: self.shift ^ other.shift, - platform: self.platform ^ other.platform, - function: self.function ^ other.function, - } - } - /// Whether the semantically 'secondary' modifier key is pressed. /// /// On macOS, this is the command key. @@ -545,11 +534,62 @@ impl Modifiers { /// Checks if this [`Modifiers`] is a subset of another [`Modifiers`]. pub fn is_subset_of(&self, other: &Modifiers) -> bool { - (other.control || !self.control) - && (other.alt || !self.alt) - && (other.shift || !self.shift) - && (other.platform || !self.platform) - && (other.function || !self.function) + (*other & *self) == *self + } +} + +impl std::ops::BitOr for Modifiers { + type Output = Self; + + fn bitor(mut self, other: Self) -> Self::Output { + self |= other; + self + } +} + +impl std::ops::BitOrAssign for Modifiers { + fn bitor_assign(&mut self, other: Self) { + self.control |= other.control; + self.alt |= other.alt; + self.shift |= other.shift; + self.platform |= other.platform; + self.function |= other.function; + } +} + +impl std::ops::BitXor for Modifiers { + type Output = Self; + fn bitxor(mut self, rhs: Self) -> Self::Output { + self ^= rhs; + self + } +} + +impl std::ops::BitXorAssign for Modifiers { + fn bitxor_assign(&mut self, other: Self) { + self.control ^= other.control; + self.alt ^= other.alt; + self.shift ^= other.shift; + self.platform ^= other.platform; + self.function ^= other.function; + } +} + +impl std::ops::BitAnd for Modifiers { + type Output = Self; + fn bitand(mut self, rhs: Self) -> Self::Output { + self &= rhs; + self + } +} + +impl std::ops::BitAndAssign for Modifiers { + fn bitand_assign(&mut self, other: Self) { + self.control &= other.control; + self.alt &= other.alt; + self.shift &= other.shift; + self.platform &= other.platform; + self.function &= other.function; } } diff --git a/crates/gpui/src/platform/linux/x11/client.rs b/crates/gpui/src/platform/linux/x11/client.rs index d1cb7d00cc..573e4addf7 100644 --- a/crates/gpui/src/platform/linux/x11/client.rs +++ b/crates/gpui/src/platform/linux/x11/client.rs @@ -1004,12 +1004,13 @@ impl X11Client { let mut keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code); let keysym = state.xkb.key_get_one_sym(code); - // should be called after key_get_one_sym - state.xkb.update_key(code, xkbc::KeyDirection::Down); - if keysym.is_modifier_key() { return Some(()); } + + // should be called after key_get_one_sym + state.xkb.update_key(code, xkbc::KeyDirection::Down); + if let Some(mut compose_state) = state.compose_state.take() { compose_state.feed(keysym); match compose_state.status() { @@ -1067,12 +1068,13 @@ impl X11Client { let keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code); let keysym = state.xkb.key_get_one_sym(code); - // should be called after key_get_one_sym - state.xkb.update_key(code, xkbc::KeyDirection::Up); - if keysym.is_modifier_key() { return Some(()); } + + // should be called after key_get_one_sym + state.xkb.update_key(code, xkbc::KeyDirection::Up); + keystroke }; drop(state); @@ -1793,6 +1795,7 @@ impl X11ClientState { drop(state); window.refresh(RequestFrameOptions { require_presentation: expose_event_received, + force_render: false, }); } xcb_connection diff --git a/crates/gpui/src/platform/windows.rs b/crates/gpui/src/platform/windows.rs index 4bdf42080d..5268d3ccba 100644 --- a/crates/gpui/src/platform/windows.rs +++ b/crates/gpui/src/platform/windows.rs @@ -1,6 +1,8 @@ mod clipboard; mod destination_list; mod direct_write; +mod directx_atlas; +mod directx_renderer; mod dispatcher; mod display; mod events; @@ -14,6 +16,8 @@ mod wrapper; pub(crate) use clipboard::*; pub(crate) use destination_list::*; pub(crate) use direct_write::*; +pub(crate) use directx_atlas::*; +pub(crate) use directx_renderer::*; pub(crate) use dispatcher::*; pub(crate) use display::*; pub(crate) use events::*; diff --git a/crates/gpui/src/platform/windows/color_text_raster.hlsl b/crates/gpui/src/platform/windows/color_text_raster.hlsl new file mode 100644 index 0000000000..ccc5fa26f0 --- /dev/null +++ b/crates/gpui/src/platform/windows/color_text_raster.hlsl @@ -0,0 +1,39 @@ +struct RasterVertexOutput { + float4 position : SV_Position; + float2 texcoord : TEXCOORD0; +}; + +RasterVertexOutput emoji_rasterization_vertex(uint vertexID : SV_VERTEXID) +{ + RasterVertexOutput output; + output.texcoord = float2((vertexID << 1) & 2, vertexID & 2); + output.position = float4(output.texcoord * 2.0f - 1.0f, 0.0f, 1.0f); + output.position.y = -output.position.y; + + return output; +} + +struct PixelInput { + float4 position: SV_Position; + float2 texcoord : TEXCOORD0; +}; + +struct Bounds { + int2 origin; + int2 size; +}; + +Texture2D<float4> t_layer : register(t0); +SamplerState s_layer : register(s0); + +cbuffer GlyphLayerTextureParams : register(b0) { + Bounds bounds; + float4 run_color; +}; + +float4 emoji_rasterization_fragment(PixelInput input): SV_Target { + float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb; + float alpha = (sampled.r + sampled.g + sampled.b) / 3; + + return float4(run_color.rgb, alpha); +} diff --git a/crates/gpui/src/platform/windows/direct_write.rs b/crates/gpui/src/platform/windows/direct_write.rs index ada306c15c..587cb7b4a6 100644 --- a/crates/gpui/src/platform/windows/direct_write.rs +++ b/crates/gpui/src/platform/windows/direct_write.rs @@ -10,10 +10,11 @@ use windows::{ Foundation::*, Globalization::GetUserDefaultLocaleName, Graphics::{ - Direct2D::{Common::*, *}, + Direct3D::D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + Direct3D11::*, DirectWrite::*, Dxgi::Common::*, - Gdi::LOGFONTW, + Gdi::{IsRectEmpty, LOGFONTW}, Imaging::*, }, System::SystemServices::LOCALE_NAME_MAX_LENGTH, @@ -40,16 +41,21 @@ struct DirectWriteComponent { locale: String, factory: IDWriteFactory5, bitmap_factory: AgileReference<IWICImagingFactory>, - d2d1_factory: ID2D1Factory, in_memory_loader: IDWriteInMemoryFontFileLoader, builder: IDWriteFontSetBuilder1, text_renderer: Arc<TextRendererWrapper>, - render_context: GlyphRenderContext, + + render_params: IDWriteRenderingParams3, + gpu_state: GPUState, } -struct GlyphRenderContext { - params: IDWriteRenderingParams3, - dc_target: ID2D1DeviceContext4, +struct GPUState { + device: ID3D11Device, + device_context: ID3D11DeviceContext, + sampler: [Option<ID3D11SamplerState>; 1], + blend_state: ID3D11BlendState, + vertex_shader: ID3D11VertexShader, + pixel_shader: ID3D11PixelShader, } struct DirectWriteState { @@ -70,12 +76,11 @@ struct FontIdentifier { } impl DirectWriteComponent { - pub fn new(bitmap_factory: &IWICImagingFactory) -> Result<Self> { + pub fn new(bitmap_factory: &IWICImagingFactory, gpu_context: &DirectXDevices) -> Result<Self> { + // todo: ideally this would not be a large unsafe block but smaller isolated ones for easier auditing unsafe { let factory: IDWriteFactory5 = DWriteCreateFactory(DWRITE_FACTORY_TYPE_SHARED)?; let bitmap_factory = AgileReference::new(bitmap_factory)?; - let d2d1_factory: ID2D1Factory = - D2D1CreateFactory(D2D1_FACTORY_TYPE_MULTI_THREADED, None)?; // The `IDWriteInMemoryFontFileLoader` here is supported starting from // Windows 10 Creators Update, which consequently requires the entire // `DirectWriteTextSystem` to run on `win10 1703`+. @@ -86,60 +91,132 @@ impl DirectWriteComponent { GetUserDefaultLocaleName(&mut locale_vec); let locale = String::from_utf16_lossy(&locale_vec); let text_renderer = Arc::new(TextRendererWrapper::new(&locale)); - let render_context = GlyphRenderContext::new(&factory, &d2d1_factory)?; + + let render_params = { + let default_params: IDWriteRenderingParams3 = + factory.CreateRenderingParams()?.cast()?; + let gamma = default_params.GetGamma(); + let enhanced_contrast = default_params.GetEnhancedContrast(); + let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); + let cleartype_level = default_params.GetClearTypeLevel(); + let grid_fit_mode = default_params.GetGridFitMode(); + + factory.CreateCustomRenderingParams( + gamma, + enhanced_contrast, + gray_contrast, + cleartype_level, + DWRITE_PIXEL_GEOMETRY_RGB, + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + grid_fit_mode, + )? + }; + + let gpu_state = GPUState::new(gpu_context)?; Ok(DirectWriteComponent { locale, factory, bitmap_factory, - d2d1_factory, in_memory_loader, builder, text_renderer, - render_context, + render_params, + gpu_state, }) } } } -impl GlyphRenderContext { - pub fn new(factory: &IDWriteFactory5, d2d1_factory: &ID2D1Factory) -> Result<Self> { - unsafe { - let default_params: IDWriteRenderingParams3 = - factory.CreateRenderingParams()?.cast()?; - let gamma = default_params.GetGamma(); - let enhanced_contrast = default_params.GetEnhancedContrast(); - let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); - let cleartype_level = default_params.GetClearTypeLevel(); - let grid_fit_mode = default_params.GetGridFitMode(); +impl GPUState { + fn new(gpu_context: &DirectXDevices) -> Result<Self> { + let device = gpu_context.device.clone(); + let device_context = gpu_context.device_context.clone(); - let params = factory.CreateCustomRenderingParams( - gamma, - enhanced_contrast, - gray_contrast, - cleartype_level, - DWRITE_PIXEL_GEOMETRY_RGB, - DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, - grid_fit_mode, - )?; - let dc_target = { - let target = d2d1_factory.CreateDCRenderTarget(&get_render_target_property( - DXGI_FORMAT_B8G8R8A8_UNORM, - D2D1_ALPHA_MODE_PREMULTIPLIED, - ))?; - let target = target.cast::<ID2D1DeviceContext4>()?; - target.SetTextRenderingParams(¶ms); - target + let blend_state = { + let mut blend_state = None; + let desc = D3D11_BLEND_DESC { + AlphaToCoverageEnable: false.into(), + IndependentBlendEnable: false.into(), + RenderTarget: [ + D3D11_RENDER_TARGET_BLEND_DESC { + BlendEnable: true.into(), + SrcBlend: D3D11_BLEND_SRC_ALPHA, + DestBlend: D3D11_BLEND_INV_SRC_ALPHA, + BlendOp: D3D11_BLEND_OP_ADD, + SrcBlendAlpha: D3D11_BLEND_SRC_ALPHA, + DestBlendAlpha: D3D11_BLEND_INV_SRC_ALPHA, + BlendOpAlpha: D3D11_BLEND_OP_ADD, + RenderTargetWriteMask: D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8, + }, + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + ], }; + unsafe { device.CreateBlendState(&desc, Some(&mut blend_state)) }?; + blend_state.unwrap() + }; - Ok(Self { params, dc_target }) - } + let sampler = { + let mut sampler = None; + let desc = D3D11_SAMPLER_DESC { + Filter: D3D11_FILTER_MIN_MAG_MIP_POINT, + AddressU: D3D11_TEXTURE_ADDRESS_BORDER, + AddressV: D3D11_TEXTURE_ADDRESS_BORDER, + AddressW: D3D11_TEXTURE_ADDRESS_BORDER, + MipLODBias: 0.0, + MaxAnisotropy: 1, + ComparisonFunc: D3D11_COMPARISON_ALWAYS, + BorderColor: [0.0, 0.0, 0.0, 0.0], + MinLOD: 0.0, + MaxLOD: 0.0, + }; + unsafe { device.CreateSamplerState(&desc, Some(&mut sampler)) }?; + [sampler] + }; + + let vertex_shader = { + let source = shader_resources::RawShaderBytes::new( + shader_resources::ShaderModule::EmojiRasterization, + shader_resources::ShaderTarget::Vertex, + )?; + let mut shader = None; + unsafe { device.CreateVertexShader(source.as_bytes(), None, Some(&mut shader)) }?; + shader.unwrap() + }; + + let pixel_shader = { + let source = shader_resources::RawShaderBytes::new( + shader_resources::ShaderModule::EmojiRasterization, + shader_resources::ShaderTarget::Fragment, + )?; + let mut shader = None; + unsafe { device.CreatePixelShader(source.as_bytes(), None, Some(&mut shader)) }?; + shader.unwrap() + }; + + Ok(Self { + device, + device_context, + sampler, + blend_state, + vertex_shader, + pixel_shader, + }) } } impl DirectWriteTextSystem { - pub(crate) fn new(bitmap_factory: &IWICImagingFactory) -> Result<Self> { - let components = DirectWriteComponent::new(bitmap_factory)?; + pub(crate) fn new( + gpu_context: &DirectXDevices, + bitmap_factory: &IWICImagingFactory, + ) -> Result<Self> { + let components = DirectWriteComponent::new(bitmap_factory, gpu_context)?; let system_font_collection = unsafe { let mut result = std::mem::zeroed(); components @@ -648,15 +725,13 @@ impl DirectWriteState { } } - fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> { - let render_target = &self.components.render_context.dc_target; - unsafe { - render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); - render_target.SetDpi(96.0 * params.scale_factor, 96.0 * params.scale_factor); - } + fn create_glyph_run_analysis( + &self, + params: &RenderGlyphParams, + ) -> Result<IDWriteGlyphRunAnalysis> { let font = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16]; - let advance = [0.0f32]; + let advance = [0.0]; let offset = [DWRITE_GLYPH_OFFSET::default()]; let glyph_run = DWRITE_GLYPH_RUN { fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, @@ -668,44 +743,87 @@ impl DirectWriteState { isSideways: BOOL(0), bidiLevel: 0, }; - let bounds = unsafe { - render_target.GetGlyphRunWorldBounds( - Vector2 { X: 0.0, Y: 0.0 }, - &glyph_run, - DWRITE_MEASURING_MODE_NATURAL, - )? + let transform = DWRITE_MATRIX { + m11: params.scale_factor, + m12: 0.0, + m21: 0.0, + m22: params.scale_factor, + dx: 0.0, + dy: 0.0, }; - // todo(windows) - // This is a walkaround, deleted when figured out. - let y_offset; - let extra_height; - if params.is_emoji { - y_offset = 0; - extra_height = 0; - } else { - // make some room for scaler. - y_offset = -1; - extra_height = 2; + let subpixel_shift = params + .subpixel_variant + .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); + let baseline_origin_x = subpixel_shift.x / params.scale_factor; + let baseline_origin_y = subpixel_shift.y / params.scale_factor; + + let mut rendering_mode = DWRITE_RENDERING_MODE1::default(); + let mut grid_fit_mode = DWRITE_GRID_FIT_MODE::default(); + unsafe { + font.font_face.GetRecommendedRenderingMode( + params.font_size.0, + // The dpi here seems that it has the same effect with `Some(&transform)` + 1.0, + 1.0, + Some(&transform), + false, + DWRITE_OUTLINE_THRESHOLD_ANTIALIASED, + DWRITE_MEASURING_MODE_NATURAL, + &self.components.render_params, + &mut rendering_mode, + &mut grid_fit_mode, + )?; } - if bounds.right < bounds.left { + let glyph_analysis = unsafe { + self.components.factory.CreateGlyphRunAnalysis( + &glyph_run, + Some(&transform), + rendering_mode, + DWRITE_MEASURING_MODE_NATURAL, + grid_fit_mode, + // We're using cleartype not grayscale for monochrome is because it provides better quality + DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + baseline_origin_x, + baseline_origin_y, + ) + }?; + Ok(glyph_analysis) + } + + fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> { + let glyph_analysis = self.create_glyph_run_analysis(params)?; + + let bounds = unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1)? }; + // Some glyphs cannot be drawn with ClearType, such as bitmap fonts. In that case + // GetAlphaTextureBounds() supposedly returns an empty RECT, but I haven't tested that yet. + if !unsafe { IsRectEmpty(&bounds) }.as_bool() { Ok(Bounds { - origin: point(0.into(), 0.into()), - size: size(0.into(), 0.into()), + origin: point(bounds.left.into(), bounds.top.into()), + size: size( + (bounds.right - bounds.left).into(), + (bounds.bottom - bounds.top).into(), + ), }) } else { - Ok(Bounds { - origin: point( - ((bounds.left * params.scale_factor).ceil() as i32).into(), - ((bounds.top * params.scale_factor).ceil() as i32 + y_offset).into(), - ), - size: size( - (((bounds.right - bounds.left) * params.scale_factor).ceil() as i32).into(), - (((bounds.bottom - bounds.top) * params.scale_factor).ceil() as i32 - + extra_height) - .into(), - ), - }) + // If it's empty, retry with grayscale AA. + let bounds = + unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1)? }; + + if bounds.right < bounds.left { + Ok(Bounds { + origin: point(0.into(), 0.into()), + size: size(0.into(), 0.into()), + }) + } else { + Ok(Bounds { + origin: point(bounds.left.into(), bounds.top.into()), + size: size( + (bounds.right - bounds.left).into(), + (bounds.bottom - bounds.top).into(), + ), + }) + } } } @@ -731,7 +849,95 @@ impl DirectWriteState { anyhow::bail!("glyph bounds are empty"); } - let font_info = &self.fonts[params.font_id.0]; + let bitmap_data = if params.is_emoji { + if let Ok(color) = self.rasterize_color(¶ms, glyph_bounds) { + color + } else { + let monochrome = self.rasterize_monochrome(params, glyph_bounds)?; + monochrome + .into_iter() + .flat_map(|pixel| [0, 0, 0, pixel]) + .collect::<Vec<_>>() + } + } else { + self.rasterize_monochrome(params, glyph_bounds)? + }; + + Ok((glyph_bounds.size, bitmap_data)) + } + + fn rasterize_monochrome( + &self, + params: &RenderGlyphParams, + glyph_bounds: Bounds<DevicePixels>, + ) -> Result<Vec<u8>> { + let mut bitmap_data = + vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize * 3]; + + let glyph_analysis = self.create_glyph_run_analysis(params)?; + unsafe { + glyph_analysis.CreateAlphaTexture( + // We're using cleartype not grayscale for monochrome is because it provides better quality + DWRITE_TEXTURE_CLEARTYPE_3x1, + &RECT { + left: glyph_bounds.origin.x.0, + top: glyph_bounds.origin.y.0, + right: glyph_bounds.size.width.0 + glyph_bounds.origin.x.0, + bottom: glyph_bounds.size.height.0 + glyph_bounds.origin.y.0, + }, + &mut bitmap_data, + )?; + } + + let bitmap_factory = self.components.bitmap_factory.resolve()?; + let bitmap = unsafe { + bitmap_factory.CreateBitmapFromMemory( + glyph_bounds.size.width.0 as u32, + glyph_bounds.size.height.0 as u32, + &GUID_WICPixelFormat24bppRGB, + glyph_bounds.size.width.0 as u32 * 3, + &bitmap_data, + ) + }?; + + let grayscale_bitmap = + unsafe { WICConvertBitmapSource(&GUID_WICPixelFormat8bppGray, &bitmap) }?; + + let mut bitmap_data = + vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize]; + unsafe { + grayscale_bitmap.CopyPixels( + std::ptr::null() as _, + glyph_bounds.size.width.0 as u32, + &mut bitmap_data, + ) + }?; + + Ok(bitmap_data) + } + + fn rasterize_color( + &self, + params: &RenderGlyphParams, + glyph_bounds: Bounds<DevicePixels>, + ) -> Result<Vec<u8>> { + let bitmap_size = glyph_bounds.size; + let subpixel_shift = params + .subpixel_variant + .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); + let baseline_origin_x = subpixel_shift.x / params.scale_factor; + let baseline_origin_y = subpixel_shift.y / params.scale_factor; + + let transform = DWRITE_MATRIX { + m11: params.scale_factor, + m12: 0.0, + m21: 0.0, + m22: params.scale_factor, + dx: 0.0, + dy: 0.0, + }; + + let font = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16]; let advance = [glyph_bounds.size.width.0 as f32]; let offset = [DWRITE_GLYPH_OFFSET { @@ -739,7 +945,7 @@ impl DirectWriteState { ascenderOffset: glyph_bounds.origin.y.0 as f32 / params.scale_factor, }]; let glyph_run = DWRITE_GLYPH_RUN { - fontFace: unsafe { std::mem::transmute_copy(&font_info.font_face) }, + fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, fontEmSize: params.font_size.0, glyphCount: 1, glyphIndices: glyph_id.as_ptr(), @@ -749,160 +955,254 @@ impl DirectWriteState { bidiLevel: 0, }; - // Add an extra pixel when the subpixel variant isn't zero to make room for anti-aliasing. - let mut bitmap_size = glyph_bounds.size; - if params.subpixel_variant.x > 0 { - bitmap_size.width += DevicePixels(1); - } - if params.subpixel_variant.y > 0 { - bitmap_size.height += DevicePixels(1); - } - let bitmap_size = bitmap_size; + // todo: support formats other than COLR + let color_enumerator = unsafe { + self.components.factory.TranslateColorGlyphRun( + Vector2::new(baseline_origin_x, baseline_origin_y), + &glyph_run, + None, + DWRITE_GLYPH_IMAGE_FORMATS_COLR, + DWRITE_MEASURING_MODE_NATURAL, + Some(&transform), + 0, + ) + }?; - let total_bytes; - let bitmap_format; - let render_target_property; - let bitmap_width; - let bitmap_height; - let bitmap_stride; - let bitmap_dpi; - if params.is_emoji { - total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize * 4; - bitmap_format = &GUID_WICPixelFormat32bppPBGRA; - render_target_property = get_render_target_property( - DXGI_FORMAT_B8G8R8A8_UNORM, - D2D1_ALPHA_MODE_PREMULTIPLIED, - ); - bitmap_width = bitmap_size.width.0 as u32; - bitmap_height = bitmap_size.height.0 as u32; - bitmap_stride = bitmap_size.width.0 as u32 * 4; - bitmap_dpi = 96.0; - } else { - total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize; - bitmap_format = &GUID_WICPixelFormat8bppAlpha; - render_target_property = - get_render_target_property(DXGI_FORMAT_A8_UNORM, D2D1_ALPHA_MODE_STRAIGHT); - bitmap_width = bitmap_size.width.0 as u32 * 2; - bitmap_height = bitmap_size.height.0 as u32 * 2; - bitmap_stride = bitmap_size.width.0 as u32; - bitmap_dpi = 192.0; + let mut glyph_layers = Vec::new(); + loop { + let color_run = unsafe { color_enumerator.GetCurrentRun() }?; + let color_run = unsafe { &*color_run }; + let image_format = color_run.glyphImageFormat & !DWRITE_GLYPH_IMAGE_FORMATS_TRUETYPE; + if image_format == DWRITE_GLYPH_IMAGE_FORMATS_COLR { + let color_analysis = unsafe { + self.components.factory.CreateGlyphRunAnalysis( + &color_run.Base.glyphRun as *const _, + Some(&transform), + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + DWRITE_MEASURING_MODE_NATURAL, + DWRITE_GRID_FIT_MODE_DEFAULT, + DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + baseline_origin_x, + baseline_origin_y, + ) + }?; + + let color_bounds = + unsafe { color_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1) }?; + + let color_size = size( + color_bounds.right - color_bounds.left, + color_bounds.bottom - color_bounds.top, + ); + if color_size.width > 0 && color_size.height > 0 { + let mut alpha_data = + vec![0u8; (color_size.width * color_size.height * 3) as usize]; + unsafe { + color_analysis.CreateAlphaTexture( + DWRITE_TEXTURE_CLEARTYPE_3x1, + &color_bounds, + &mut alpha_data, + ) + }?; + + let run_color = { + let run_color = color_run.Base.runColor; + Rgba { + r: run_color.r, + g: run_color.g, + b: run_color.b, + a: run_color.a, + } + }; + let bounds = bounds(point(color_bounds.left, color_bounds.top), color_size); + let alpha_data = alpha_data + .chunks_exact(3) + .flat_map(|chunk| [chunk[0], chunk[1], chunk[2], 255]) + .collect::<Vec<_>>(); + glyph_layers.push(GlyphLayerTexture::new( + &self.components.gpu_state, + run_color, + bounds, + &alpha_data, + )?); + } + } + + let has_next = unsafe { color_enumerator.MoveNext() } + .map(|e| e.as_bool()) + .unwrap_or(false); + if !has_next { + break; + } } - let bitmap_factory = self.components.bitmap_factory.resolve()?; - unsafe { - let bitmap = bitmap_factory.CreateBitmap( - bitmap_width, - bitmap_height, - bitmap_format, - WICBitmapCacheOnLoad, - )?; - let render_target = self - .components - .d2d1_factory - .CreateWicBitmapRenderTarget(&bitmap, &render_target_property)?; - let brush = render_target.CreateSolidColorBrush(&BRUSH_COLOR, None)?; - let subpixel_shift = params - .subpixel_variant - .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); - let baseline_origin = Vector2 { - X: subpixel_shift.x / params.scale_factor, - Y: subpixel_shift.y / params.scale_factor, + let gpu_state = &self.components.gpu_state; + let params_buffer = { + let desc = D3D11_BUFFER_DESC { + ByteWidth: std::mem::size_of::<GlyphLayerTextureParams>() as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + StructureByteStride: 0, }; - // This `cast()` action here should never fail since we are running on Win10+, and - // ID2D1DeviceContext4 requires Win8+ - let render_target = render_target.cast::<ID2D1DeviceContext4>().unwrap(); - render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); - render_target.SetDpi( - bitmap_dpi * params.scale_factor, - bitmap_dpi * params.scale_factor, - ); - render_target.SetTextRenderingParams(&self.components.render_context.params); - render_target.BeginDraw(); + let mut buffer = None; + unsafe { + gpu_state + .device + .CreateBuffer(&desc, None, Some(&mut buffer)) + }?; + [buffer] + }; - if params.is_emoji { - // WARN: only DWRITE_GLYPH_IMAGE_FORMATS_COLR has been tested - let enumerator = self.components.factory.TranslateColorGlyphRun( - baseline_origin, - &glyph_run as _, - None, - DWRITE_GLYPH_IMAGE_FORMATS_COLR - | DWRITE_GLYPH_IMAGE_FORMATS_SVG - | DWRITE_GLYPH_IMAGE_FORMATS_PNG - | DWRITE_GLYPH_IMAGE_FORMATS_JPEG - | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8, - DWRITE_MEASURING_MODE_NATURAL, - None, + let render_target_texture = { + let mut texture = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: bitmap_size.width.0 as u32, + Height: bitmap_size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture)) + }?; + texture.unwrap() + }; + + let render_target_view = { + let desc = D3D11_RENDER_TARGET_VIEW_DESC { + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + ViewDimension: D3D11_RTV_DIMENSION_TEXTURE2D, + Anonymous: D3D11_RENDER_TARGET_VIEW_DESC_0 { + Texture2D: D3D11_TEX2D_RTV { MipSlice: 0 }, + }, + }; + let mut rtv = None; + unsafe { + gpu_state.device.CreateRenderTargetView( + &render_target_texture, + Some(&desc), + Some(&mut rtv), + ) + }?; + [rtv] + }; + + let staging_texture = { + let mut texture = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: bitmap_size.width.0 as u32, + Height: bitmap_size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_STAGING, + BindFlags: 0, + CPUAccessFlags: D3D11_CPU_ACCESS_READ.0 as u32, + MiscFlags: 0, + }; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture)) + }?; + texture.unwrap() + }; + + let device_context = &gpu_state.device_context; + unsafe { device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP) }; + unsafe { device_context.VSSetShader(&gpu_state.vertex_shader, None) }; + unsafe { device_context.PSSetShader(&gpu_state.pixel_shader, None) }; + unsafe { device_context.VSSetConstantBuffers(0, Some(¶ms_buffer)) }; + unsafe { device_context.PSSetConstantBuffers(0, Some(¶ms_buffer)) }; + unsafe { device_context.OMSetRenderTargets(Some(&render_target_view), None) }; + unsafe { device_context.PSSetSamplers(0, Some(&gpu_state.sampler)) }; + unsafe { device_context.OMSetBlendState(&gpu_state.blend_state, None, 0xffffffff) }; + + for layer in glyph_layers { + let params = GlyphLayerTextureParams { + run_color: layer.run_color, + bounds: layer.bounds, + }; + unsafe { + let mut dest = std::mem::zeroed(); + gpu_state.device_context.Map( + params_buffer[0].as_ref().unwrap(), 0, + D3D11_MAP_WRITE_DISCARD, + 0, + Some(&mut dest), )?; - while enumerator.MoveNext().is_ok() { - let Ok(color_glyph) = enumerator.GetCurrentRun() else { - break; - }; - let color_glyph = &*color_glyph; - let brush_color = translate_color(&color_glyph.Base.runColor); - brush.SetColor(&brush_color); - match color_glyph.glyphImageFormat { - DWRITE_GLYPH_IMAGE_FORMATS_PNG - | DWRITE_GLYPH_IMAGE_FORMATS_JPEG - | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8 => render_target - .DrawColorBitmapGlyphRun( - color_glyph.glyphImageFormat, - baseline_origin, - &color_glyph.Base.glyphRun, - color_glyph.measuringMode, - D2D1_COLOR_BITMAP_GLYPH_SNAP_OPTION_DEFAULT, - ), - DWRITE_GLYPH_IMAGE_FORMATS_SVG => render_target.DrawSvgGlyphRun( - baseline_origin, - &color_glyph.Base.glyphRun, - &brush, - None, - color_glyph.Base.paletteIndex as u32, - color_glyph.measuringMode, - ), - _ => render_target.DrawGlyphRun( - baseline_origin, - &color_glyph.Base.glyphRun, - Some(color_glyph.Base.glyphRunDescription as *const _), - &brush, - color_glyph.measuringMode, - ), - } - } - } else { - render_target.DrawGlyphRun( - baseline_origin, - &glyph_run, - None, - &brush, - DWRITE_MEASURING_MODE_NATURAL, - ); - } - render_target.EndDraw(None, None)?; + std::ptr::copy_nonoverlapping(¶ms as *const _, dest.pData as *mut _, 1); + gpu_state + .device_context + .Unmap(params_buffer[0].as_ref().unwrap(), 0); + }; - let mut raw_data = vec![0u8; total_bytes]; - if params.is_emoji { - bitmap.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; - // Convert from BGRA with premultiplied alpha to BGRA with straight alpha. - for pixel in raw_data.chunks_exact_mut(4) { - let a = pixel[3] as f32 / 255.; - pixel[0] = (pixel[0] as f32 / a) as u8; - pixel[1] = (pixel[1] as f32 / a) as u8; - pixel[2] = (pixel[2] as f32 / a) as u8; - } - } else { - let scaler = bitmap_factory.CreateBitmapScaler()?; - scaler.Initialize( - &bitmap, - bitmap_size.width.0 as u32, - bitmap_size.height.0 as u32, - WICBitmapInterpolationModeHighQualityCubic, - )?; - scaler.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; - } - Ok((bitmap_size, raw_data)) + let texture = [Some(layer.texture_view)]; + unsafe { device_context.PSSetShaderResources(0, Some(&texture)) }; + + let viewport = [D3D11_VIEWPORT { + TopLeftX: layer.bounds.origin.x as f32, + TopLeftY: layer.bounds.origin.y as f32, + Width: layer.bounds.size.width as f32, + Height: layer.bounds.size.height as f32, + MinDepth: 0.0, + MaxDepth: 1.0, + }]; + unsafe { device_context.RSSetViewports(Some(&viewport)) }; + + unsafe { device_context.Draw(4, 0) }; } + + unsafe { device_context.CopyResource(&staging_texture, &render_target_texture) }; + + let mapped_data = { + let mut mapped_data = D3D11_MAPPED_SUBRESOURCE::default(); + unsafe { + device_context.Map( + &staging_texture, + 0, + D3D11_MAP_READ, + 0, + Some(&mut mapped_data), + ) + }?; + mapped_data + }; + let mut rasterized = + vec![0u8; (bitmap_size.width.0 as u32 * bitmap_size.height.0 as u32 * 4) as usize]; + + for y in 0..bitmap_size.height.0 as usize { + let width = bitmap_size.width.0 as usize; + unsafe { + std::ptr::copy_nonoverlapping::<u8>( + (mapped_data.pData as *const u8).byte_add(mapped_data.RowPitch as usize * y), + rasterized + .as_mut_ptr() + .byte_add(width * y * std::mem::size_of::<u32>()), + width * std::mem::size_of::<u32>(), + ) + }; + } + + Ok(rasterized) } fn get_typographic_bounds(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Bounds<f32>> { @@ -976,6 +1276,84 @@ impl Drop for DirectWriteState { } } +struct GlyphLayerTexture { + run_color: Rgba, + bounds: Bounds<i32>, + texture_view: ID3D11ShaderResourceView, + // holding on to the texture to not RAII drop it + _texture: ID3D11Texture2D, +} + +impl GlyphLayerTexture { + pub fn new( + gpu_state: &GPUState, + run_color: Rgba, + bounds: Bounds<i32>, + alpha_data: &[u8], + ) -> Result<Self> { + let texture_size = bounds.size; + + let desc = D3D11_TEXTURE2D_DESC { + Width: texture_size.width as u32, + Height: texture_size.height as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_R8G8B8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + }; + + let texture = { + let mut texture: Option<ID3D11Texture2D> = None; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture))? + }; + texture.unwrap() + }; + let texture_view = { + let mut view: Option<ID3D11ShaderResourceView> = None; + unsafe { + gpu_state + .device + .CreateShaderResourceView(&texture, None, Some(&mut view))? + }; + view.unwrap() + }; + + unsafe { + gpu_state.device_context.UpdateSubresource( + &texture, + 0, + None, + alpha_data.as_ptr() as _, + (texture_size.width * 4) as u32, + 0, + ) + }; + + Ok(GlyphLayerTexture { + run_color, + bounds, + texture_view, + _texture: texture, + }) + } +} + +#[repr(C)] +struct GlyphLayerTextureParams { + bounds: Bounds<i32>, + run_color: Rgba, +} + struct TextRendererWrapper(pub IDWriteTextRenderer); impl TextRendererWrapper { @@ -1470,16 +1848,6 @@ fn get_name(string: IDWriteLocalizedStrings, locale: &str) -> Result<String> { Ok(String::from_utf16_lossy(&name_vec[..name_length])) } -#[inline] -fn translate_color(color: &DWRITE_COLOR_F) -> D2D1_COLOR_F { - D2D1_COLOR_F { - r: color.r, - g: color.g, - b: color.b, - a: color.a, - } -} - fn get_system_ui_font_name() -> SharedString { unsafe { let mut info: LOGFONTW = std::mem::zeroed(); @@ -1504,24 +1872,6 @@ fn get_system_ui_font_name() -> SharedString { } } -#[inline] -fn get_render_target_property( - pixel_format: DXGI_FORMAT, - alpha_mode: D2D1_ALPHA_MODE, -) -> D2D1_RENDER_TARGET_PROPERTIES { - D2D1_RENDER_TARGET_PROPERTIES { - r#type: D2D1_RENDER_TARGET_TYPE_DEFAULT, - pixelFormat: D2D1_PIXEL_FORMAT { - format: pixel_format, - alphaMode: alpha_mode, - }, - dpiX: 96.0, - dpiY: 96.0, - usage: D2D1_RENDER_TARGET_USAGE_NONE, - minLevel: D2D1_FEATURE_LEVEL_DEFAULT, - } -} - // One would think that with newer DirectWrite method: IDWriteFontFace4::GetGlyphImageFormats // but that doesn't seem to work for some glyphs, say ❤ fn is_color_glyph( @@ -1561,12 +1911,6 @@ fn is_color_glyph( } const DEFAULT_LOCALE_NAME: PCWSTR = windows::core::w!("en-US"); -const BRUSH_COLOR: D2D1_COLOR_F = D2D1_COLOR_F { - r: 1.0, - g: 1.0, - b: 1.0, - a: 1.0, -}; #[cfg(test)] mod tests { diff --git a/crates/gpui/src/platform/windows/directx_atlas.rs b/crates/gpui/src/platform/windows/directx_atlas.rs new file mode 100644 index 0000000000..6bced4c11d --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_atlas.rs @@ -0,0 +1,309 @@ +use collections::FxHashMap; +use etagere::BucketedAtlasAllocator; +use parking_lot::Mutex; +use windows::Win32::Graphics::{ + Direct3D11::{ + D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_CPU_ACCESS_WRITE, D3D11_TEXTURE2D_DESC, + D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView, + ID3D11Texture2D, + }, + Dxgi::Common::*, +}; + +use crate::{ + AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, DevicePixels, PlatformAtlas, + Point, Size, platform::AtlasTextureList, +}; + +pub(crate) struct DirectXAtlas(Mutex<DirectXAtlasState>); + +struct DirectXAtlasState { + device: ID3D11Device, + device_context: ID3D11DeviceContext, + monochrome_textures: AtlasTextureList<DirectXAtlasTexture>, + polychrome_textures: AtlasTextureList<DirectXAtlasTexture>, + tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, +} + +struct DirectXAtlasTexture { + id: AtlasTextureId, + bytes_per_pixel: u32, + allocator: BucketedAtlasAllocator, + texture: ID3D11Texture2D, + view: [Option<ID3D11ShaderResourceView>; 1], + live_atlas_keys: u32, +} + +impl DirectXAtlas { + pub(crate) fn new(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Self { + DirectXAtlas(Mutex::new(DirectXAtlasState { + device: device.clone(), + device_context: device_context.clone(), + monochrome_textures: Default::default(), + polychrome_textures: Default::default(), + tiles_by_key: Default::default(), + })) + } + + pub(crate) fn get_texture_view( + &self, + id: AtlasTextureId, + ) -> [Option<ID3D11ShaderResourceView>; 1] { + let lock = self.0.lock(); + let tex = lock.texture(id); + tex.view.clone() + } + + pub(crate) fn handle_device_lost( + &self, + device: &ID3D11Device, + device_context: &ID3D11DeviceContext, + ) { + let mut lock = self.0.lock(); + lock.device = device.clone(); + lock.device_context = device_context.clone(); + lock.monochrome_textures = AtlasTextureList::default(); + lock.polychrome_textures = AtlasTextureList::default(); + lock.tiles_by_key.clear(); + } +} + +impl PlatformAtlas for DirectXAtlas { + fn get_or_insert_with<'a>( + &self, + key: &AtlasKey, + build: &mut dyn FnMut() -> anyhow::Result< + Option<(Size<DevicePixels>, std::borrow::Cow<'a, [u8]>)>, + >, + ) -> anyhow::Result<Option<AtlasTile>> { + let mut lock = self.0.lock(); + if let Some(tile) = lock.tiles_by_key.get(key) { + Ok(Some(tile.clone())) + } else { + let Some((size, bytes)) = build()? else { + return Ok(None); + }; + let tile = lock + .allocate(size, key.texture_kind()) + .ok_or_else(|| anyhow::anyhow!("failed to allocate"))?; + let texture = lock.texture(tile.texture_id); + texture.upload(&lock.device_context, tile.bounds, &bytes); + lock.tiles_by_key.insert(key.clone(), tile.clone()); + Ok(Some(tile)) + } + } + + fn remove(&self, key: &AtlasKey) { + let mut lock = self.0.lock(); + + let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else { + return; + }; + + let textures = match id.kind { + AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, + AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, + }; + + let Some(texture_slot) = textures.textures.get_mut(id.index as usize) else { + return; + }; + + if let Some(mut texture) = texture_slot.take() { + texture.decrement_ref_count(); + if texture.is_unreferenced() { + textures.free_list.push(texture.id.index as usize); + lock.tiles_by_key.remove(key); + } else { + *texture_slot = Some(texture); + } + } + } +} + +impl DirectXAtlasState { + fn allocate( + &mut self, + size: Size<DevicePixels>, + texture_kind: AtlasTextureKind, + ) -> Option<AtlasTile> { + { + let textures = match texture_kind { + AtlasTextureKind::Monochrome => &mut self.monochrome_textures, + AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + }; + + if let Some(tile) = textures + .iter_mut() + .rev() + .find_map(|texture| texture.allocate(size)) + { + return Some(tile); + } + } + + let texture = self.push_texture(size, texture_kind)?; + texture.allocate(size) + } + + fn push_texture( + &mut self, + min_size: Size<DevicePixels>, + kind: AtlasTextureKind, + ) -> Option<&mut DirectXAtlasTexture> { + const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size { + width: DevicePixels(1024), + height: DevicePixels(1024), + }; + // Max texture size for DirectX. See: + // https://learn.microsoft.com/en-us/windows/win32/direct3d11/overviews-direct3d-11-resources-limits + const MAX_ATLAS_SIZE: Size<DevicePixels> = Size { + width: DevicePixels(16384), + height: DevicePixels(16384), + }; + let size = min_size.min(&MAX_ATLAS_SIZE).max(&DEFAULT_ATLAS_SIZE); + let pixel_format; + let bind_flag; + let bytes_per_pixel; + match kind { + AtlasTextureKind::Monochrome => { + pixel_format = DXGI_FORMAT_R8_UNORM; + bind_flag = D3D11_BIND_SHADER_RESOURCE; + bytes_per_pixel = 1; + } + AtlasTextureKind::Polychrome => { + pixel_format = DXGI_FORMAT_B8G8R8A8_UNORM; + bind_flag = D3D11_BIND_SHADER_RESOURCE; + bytes_per_pixel = 4; + } + } + let texture_desc = D3D11_TEXTURE2D_DESC { + Width: size.width.0 as u32, + Height: size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: pixel_format, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: bind_flag.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + }; + let mut texture: Option<ID3D11Texture2D> = None; + unsafe { + // This only returns None if the device is lost, which we will recreate later. + // So it's ok to return None here. + self.device + .CreateTexture2D(&texture_desc, None, Some(&mut texture)) + .ok()?; + } + let texture = texture.unwrap(); + + let texture_list = match kind { + AtlasTextureKind::Monochrome => &mut self.monochrome_textures, + AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + }; + let index = texture_list.free_list.pop(); + let view = unsafe { + let mut view = None; + self.device + .CreateShaderResourceView(&texture, None, Some(&mut view)) + .ok()?; + [view] + }; + let atlas_texture = DirectXAtlasTexture { + id: AtlasTextureId { + index: index.unwrap_or(texture_list.textures.len()) as u32, + kind, + }, + bytes_per_pixel, + allocator: etagere::BucketedAtlasAllocator::new(size.into()), + texture, + view, + live_atlas_keys: 0, + }; + if let Some(ix) = index { + texture_list.textures[ix] = Some(atlas_texture); + texture_list.textures.get_mut(ix).unwrap().as_mut() + } else { + texture_list.textures.push(Some(atlas_texture)); + texture_list.textures.last_mut().unwrap().as_mut() + } + } + + fn texture(&self, id: AtlasTextureId) -> &DirectXAtlasTexture { + let textures = match id.kind { + crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, + crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + }; + textures[id.index as usize].as_ref().unwrap() + } +} + +impl DirectXAtlasTexture { + fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { + let allocation = self.allocator.allocate(size.into())?; + let tile = AtlasTile { + texture_id: self.id, + tile_id: allocation.id.into(), + bounds: Bounds { + origin: allocation.rectangle.min.into(), + size, + }, + padding: 0, + }; + self.live_atlas_keys += 1; + Some(tile) + } + + fn upload( + &self, + device_context: &ID3D11DeviceContext, + bounds: Bounds<DevicePixels>, + bytes: &[u8], + ) { + unsafe { + device_context.UpdateSubresource( + &self.texture, + 0, + Some(&D3D11_BOX { + left: bounds.left().0 as u32, + top: bounds.top().0 as u32, + front: 0, + right: bounds.right().0 as u32, + bottom: bounds.bottom().0 as u32, + back: 1, + }), + bytes.as_ptr() as _, + bounds.size.width.to_bytes(self.bytes_per_pixel as u8), + 0, + ); + } + } + + fn decrement_ref_count(&mut self) { + self.live_atlas_keys -= 1; + } + + fn is_unreferenced(&mut self) -> bool { + self.live_atlas_keys == 0 + } +} + +impl From<Size<DevicePixels>> for etagere::Size { + fn from(size: Size<DevicePixels>) -> Self { + etagere::Size::new(size.width.into(), size.height.into()) + } +} + +impl From<etagere::Point> for Point<DevicePixels> { + fn from(value: etagere::Point) -> Self { + Point { + x: DevicePixels::from(value.x), + y: DevicePixels::from(value.y), + } + } +} diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs new file mode 100644 index 0000000000..72cc12a5b4 --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_renderer.rs @@ -0,0 +1,1807 @@ +use std::{mem::ManuallyDrop, sync::Arc}; + +use ::util::ResultExt; +use anyhow::{Context, Result}; +use windows::{ + Win32::{ + Foundation::{HMODULE, HWND}, + Graphics::{ + Direct3D::*, + Direct3D11::*, + DirectComposition::*, + Dxgi::{Common::*, *}, + }, + }, + core::Interface, +}; + +use crate::{ + platform::windows::directx_renderer::shader_resources::{ + RawShaderBytes, ShaderModule, ShaderTarget, + }, + *, +}; + +pub(crate) const DISABLE_DIRECT_COMPOSITION: &str = "GPUI_DISABLE_DIRECT_COMPOSITION"; +const RENDER_TARGET_FORMAT: DXGI_FORMAT = DXGI_FORMAT_B8G8R8A8_UNORM; +// This configuration is used for MSAA rendering on paths only, and it's guaranteed to be supported by DirectX 11. +const PATH_MULTISAMPLE_COUNT: u32 = 4; + +pub(crate) struct DirectXRenderer { + hwnd: HWND, + atlas: Arc<DirectXAtlas>, + devices: ManuallyDrop<DirectXDevices>, + resources: ManuallyDrop<DirectXResources>, + globals: DirectXGlobalElements, + pipelines: DirectXRenderPipelines, + direct_composition: Option<DirectComposition>, +} + +/// Direct3D objects +#[derive(Clone)] +pub(crate) struct DirectXDevices { + adapter: IDXGIAdapter1, + dxgi_factory: IDXGIFactory6, + pub(crate) device: ID3D11Device, + pub(crate) device_context: ID3D11DeviceContext, + dxgi_device: Option<IDXGIDevice>, +} + +struct DirectXResources { + // Direct3D rendering objects + swap_chain: IDXGISwapChain1, + render_target: ManuallyDrop<ID3D11Texture2D>, + render_target_view: [Option<ID3D11RenderTargetView>; 1], + + // Path intermediate textures (with MSAA) + path_intermediate_texture: ID3D11Texture2D, + path_intermediate_srv: [Option<ID3D11ShaderResourceView>; 1], + path_intermediate_msaa_texture: ID3D11Texture2D, + path_intermediate_msaa_view: [Option<ID3D11RenderTargetView>; 1], + + // Cached window size and viewport + width: u32, + height: u32, + viewport: [D3D11_VIEWPORT; 1], +} + +struct DirectXRenderPipelines { + shadow_pipeline: PipelineState<Shadow>, + quad_pipeline: PipelineState<Quad>, + path_rasterization_pipeline: PipelineState<PathRasterizationSprite>, + path_sprite_pipeline: PipelineState<PathSprite>, + underline_pipeline: PipelineState<Underline>, + mono_sprites: PipelineState<MonochromeSprite>, + poly_sprites: PipelineState<PolychromeSprite>, +} + +struct DirectXGlobalElements { + global_params_buffer: [Option<ID3D11Buffer>; 1], + sampler: [Option<ID3D11SamplerState>; 1], +} + +struct DirectComposition { + comp_device: IDCompositionDevice, + comp_target: IDCompositionTarget, + comp_visual: IDCompositionVisual, +} + +impl DirectXDevices { + pub(crate) fn new(disable_direct_composition: bool) -> Result<ManuallyDrop<Self>> { + let debug_layer_available = check_debug_layer_available(); + let dxgi_factory = + get_dxgi_factory(debug_layer_available).context("Creating DXGI factory")?; + let adapter = + get_adapter(&dxgi_factory, debug_layer_available).context("Getting DXGI adapter")?; + let (device, device_context) = { + let mut device: Option<ID3D11Device> = None; + let mut context: Option<ID3D11DeviceContext> = None; + let mut feature_level = D3D_FEATURE_LEVEL::default(); + get_device( + &adapter, + Some(&mut device), + Some(&mut context), + Some(&mut feature_level), + debug_layer_available, + ) + .context("Creating Direct3D device")?; + match feature_level { + D3D_FEATURE_LEVEL_11_1 => { + log::info!("Created device with Direct3D 11.1 feature level.") + } + D3D_FEATURE_LEVEL_11_0 => { + log::info!("Created device with Direct3D 11.0 feature level.") + } + D3D_FEATURE_LEVEL_10_1 => { + log::info!("Created device with Direct3D 10.1 feature level.") + } + _ => unreachable!(), + } + (device.unwrap(), context.unwrap()) + }; + let dxgi_device = if disable_direct_composition { + None + } else { + Some(device.cast().context("Creating DXGI device")?) + }; + + Ok(ManuallyDrop::new(Self { + adapter, + dxgi_factory, + dxgi_device, + device, + device_context, + })) + } +} + +impl DirectXRenderer { + pub(crate) fn new(hwnd: HWND, disable_direct_composition: bool) -> Result<Self> { + if disable_direct_composition { + log::info!("Direct Composition is disabled."); + } + + let devices = + DirectXDevices::new(disable_direct_composition).context("Creating DirectX devices")?; + let atlas = Arc::new(DirectXAtlas::new(&devices.device, &devices.device_context)); + + let resources = DirectXResources::new(&devices, 1, 1, hwnd, disable_direct_composition) + .context("Creating DirectX resources")?; + let globals = DirectXGlobalElements::new(&devices.device) + .context("Creating DirectX global elements")?; + let pipelines = DirectXRenderPipelines::new(&devices.device) + .context("Creating DirectX render pipelines")?; + + let direct_composition = if disable_direct_composition { + None + } else { + let composition = DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), hwnd) + .context("Creating DirectComposition")?; + composition + .set_swap_chain(&resources.swap_chain) + .context("Setting swap chain for DirectComposition")?; + Some(composition) + }; + + Ok(DirectXRenderer { + hwnd, + atlas, + devices, + resources, + globals, + pipelines, + direct_composition, + }) + } + + pub(crate) fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> { + self.atlas.clone() + } + + fn pre_draw(&self) -> Result<()> { + update_buffer( + &self.devices.device_context, + self.globals.global_params_buffer[0].as_ref().unwrap(), + &[GlobalParams { + viewport_size: [ + self.resources.viewport[0].Width, + self.resources.viewport[0].Height, + ], + _pad: 0, + }], + )?; + unsafe { + self.devices.device_context.ClearRenderTargetView( + self.resources.render_target_view[0].as_ref().unwrap(), + &[0.0; 4], + ); + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + self.devices + .device_context + .RSSetViewports(Some(&self.resources.viewport)); + } + Ok(()) + } + + fn present(&mut self) -> Result<()> { + unsafe { + let result = self.resources.swap_chain.Present(1, DXGI_PRESENT(0)); + // Presenting the swap chain can fail if the DirectX device was removed or reset. + if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET { + let reason = self.devices.device.GetDeviceRemovedReason(); + log::error!( + "DirectX device removed or reset when drawing. Reason: {:?}", + reason + ); + self.handle_device_lost()?; + } else { + result.ok()?; + } + } + Ok(()) + } + + fn handle_device_lost(&mut self) -> Result<()> { + // Here we wait a bit to ensure the the system has time to recover from the device lost state. + // If we don't wait, the final drawing result will be blank. + std::thread::sleep(std::time::Duration::from_millis(300)); + let disable_direct_composition = self.direct_composition.is_none(); + + unsafe { + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device) + .context("Failed to report live objects after device lost") + .log_err(); + + ManuallyDrop::drop(&mut self.resources); + self.devices.device_context.OMSetRenderTargets(None, None); + self.devices.device_context.ClearState(); + self.devices.device_context.Flush(); + + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device) + .context("Failed to report live objects after device lost") + .log_err(); + + drop(self.direct_composition.take()); + ManuallyDrop::drop(&mut self.devices); + } + + let devices = DirectXDevices::new(disable_direct_composition) + .context("Recreating DirectX devices")?; + let resources = DirectXResources::new( + &devices, + self.resources.width, + self.resources.height, + self.hwnd, + disable_direct_composition, + )?; + let globals = DirectXGlobalElements::new(&devices.device)?; + let pipelines = DirectXRenderPipelines::new(&devices.device)?; + + let direct_composition = if disable_direct_composition { + None + } else { + let composition = + DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), self.hwnd)?; + composition.set_swap_chain(&resources.swap_chain)?; + Some(composition) + }; + + self.atlas + .handle_device_lost(&devices.device, &devices.device_context); + self.devices = devices; + self.resources = resources; + self.globals = globals; + self.pipelines = pipelines; + self.direct_composition = direct_composition; + + unsafe { + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + Ok(()) + } + + pub(crate) fn draw(&mut self, scene: &Scene) -> Result<()> { + self.pre_draw()?; + for batch in scene.batches() { + match batch { + PrimitiveBatch::Shadows(shadows) => self.draw_shadows(shadows), + PrimitiveBatch::Quads(quads) => self.draw_quads(quads), + PrimitiveBatch::Paths(paths) => { + self.draw_paths_to_intermediate(paths)?; + self.draw_paths_from_intermediate(paths) + } + PrimitiveBatch::Underlines(underlines) => self.draw_underlines(underlines), + PrimitiveBatch::MonochromeSprites { + texture_id, + sprites, + } => self.draw_monochrome_sprites(texture_id, sprites), + PrimitiveBatch::PolychromeSprites { + texture_id, + sprites, + } => self.draw_polychrome_sprites(texture_id, sprites), + PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces(surfaces), + }.context(format!("scene too large: {} paths, {} shadows, {} quads, {} underlines, {} mono, {} poly, {} surfaces", + scene.paths.len(), + scene.shadows.len(), + scene.quads.len(), + scene.underlines.len(), + scene.monochrome_sprites.len(), + scene.polychrome_sprites.len(), + scene.surfaces.len(),))?; + } + self.present() + } + + pub(crate) fn resize(&mut self, new_size: Size<DevicePixels>) -> Result<()> { + let width = new_size.width.0.max(1) as u32; + let height = new_size.height.0.max(1) as u32; + if self.resources.width == width && self.resources.height == height { + return Ok(()); + } + unsafe { + // Clear the render target before resizing + self.devices.device_context.OMSetRenderTargets(None, None); + ManuallyDrop::drop(&mut self.resources.render_target); + drop(self.resources.render_target_view[0].take().unwrap()); + + let result = self.resources.swap_chain.ResizeBuffers( + BUFFER_COUNT as u32, + width, + height, + RENDER_TARGET_FORMAT, + DXGI_SWAP_CHAIN_FLAG(0), + ); + // Resizing the swap chain requires a call to the underlying DXGI adapter, which can return the device removed error. + // The app might have moved to a monitor that's attached to a different graphics device. + // When a graphics device is removed or reset, the desktop resolution often changes, resulting in a window size change. + match result { + Ok(_) => {} + Err(e) => { + if e.code() == DXGI_ERROR_DEVICE_REMOVED || e.code() == DXGI_ERROR_DEVICE_RESET + { + let reason = self.devices.device.GetDeviceRemovedReason(); + log::error!( + "DirectX device removed or reset when resizing. Reason: {:?}", + reason + ); + self.resources.width = width; + self.resources.height = height; + self.handle_device_lost()?; + return Ok(()); + } else { + log::error!("Failed to resize swap chain: {:?}", e); + return Err(e.into()); + } + } + } + + self.resources + .recreate_resources(&self.devices, width, height)?; + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + Ok(()) + } + + fn draw_shadows(&mut self, shadows: &[Shadow]) -> Result<()> { + if shadows.is_empty() { + return Ok(()); + } + self.pipelines.shadow_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + shadows, + )?; + self.pipelines.shadow_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + shadows.len() as u32, + ) + } + + fn draw_quads(&mut self, quads: &[Quad]) -> Result<()> { + if quads.is_empty() { + return Ok(()); + } + self.pipelines.quad_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + quads, + )?; + self.pipelines.quad_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + quads.len() as u32, + ) + } + + fn draw_paths_to_intermediate(&mut self, paths: &[Path<ScaledPixels>]) -> Result<()> { + if paths.is_empty() { + return Ok(()); + } + + // Clear intermediate MSAA texture + unsafe { + self.devices.device_context.ClearRenderTargetView( + self.resources.path_intermediate_msaa_view[0] + .as_ref() + .unwrap(), + &[0.0; 4], + ); + // Set intermediate MSAA texture as render target + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.path_intermediate_msaa_view), None); + } + + // Collect all vertices and sprites for a single draw call + let mut vertices = Vec::new(); + + for path in paths { + vertices.extend(path.vertices.iter().map(|v| PathRasterizationSprite { + xy_position: v.xy_position, + st_position: v.st_position, + color: path.color, + bounds: path.bounds.intersect(&path.content_mask.bounds), + })); + } + + self.pipelines.path_rasterization_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + &vertices, + )?; + self.pipelines.path_rasterization_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST, + vertices.len() as u32, + 1, + )?; + + // Resolve MSAA to non-MSAA intermediate texture + unsafe { + self.devices.device_context.ResolveSubresource( + &self.resources.path_intermediate_texture, + 0, + &self.resources.path_intermediate_msaa_texture, + 0, + RENDER_TARGET_FORMAT, + ); + // Restore main render target + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + + Ok(()) + } + + fn draw_paths_from_intermediate(&mut self, paths: &[Path<ScaledPixels>]) -> Result<()> { + let Some(first_path) = paths.first() else { + return Ok(()); + }; + + // When copying paths from the intermediate texture to the drawable, + // each pixel must only be copied once, in case of transparent paths. + // + // If all paths have the same draw order, then their bounds are all + // disjoint, so we can copy each path's bounds individually. If this + // batch combines different draw orders, we perform a single copy + // for a minimal spanning rect. + let sprites = if paths.last().unwrap().order == first_path.order { + paths + .iter() + .map(|path| PathSprite { + bounds: path.bounds, + }) + .collect::<Vec<_>>() + } else { + let mut bounds = first_path.bounds; + for path in paths.iter().skip(1) { + bounds = bounds.union(&path.bounds); + } + vec![PathSprite { bounds }] + }; + + self.pipelines.path_sprite_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + &sprites, + )?; + + // Draw the sprites with the path texture + self.pipelines.path_sprite_pipeline.draw_with_texture( + &self.devices.device_context, + &self.resources.path_intermediate_srv, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_underlines(&mut self, underlines: &[Underline]) -> Result<()> { + if underlines.is_empty() { + return Ok(()); + } + self.pipelines.underline_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + underlines, + )?; + self.pipelines.underline_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + underlines.len() as u32, + ) + } + + fn draw_monochrome_sprites( + &mut self, + texture_id: AtlasTextureId, + sprites: &[MonochromeSprite], + ) -> Result<()> { + if sprites.is_empty() { + return Ok(()); + } + self.pipelines.mono_sprites.update_buffer( + &self.devices.device, + &self.devices.device_context, + sprites, + )?; + let texture_view = self.atlas.get_texture_view(texture_id); + self.pipelines.mono_sprites.draw_with_texture( + &self.devices.device_context, + &texture_view, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_polychrome_sprites( + &mut self, + texture_id: AtlasTextureId, + sprites: &[PolychromeSprite], + ) -> Result<()> { + if sprites.is_empty() { + return Ok(()); + } + self.pipelines.poly_sprites.update_buffer( + &self.devices.device, + &self.devices.device_context, + sprites, + )?; + let texture_view = self.atlas.get_texture_view(texture_id); + self.pipelines.poly_sprites.draw_with_texture( + &self.devices.device_context, + &texture_view, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_surfaces(&mut self, surfaces: &[PaintSurface]) -> Result<()> { + if surfaces.is_empty() { + return Ok(()); + } + Ok(()) + } + + pub(crate) fn gpu_specs(&self) -> Result<GpuSpecs> { + let desc = unsafe { self.devices.adapter.GetDesc1() }?; + let is_software_emulated = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE.0 as u32) != 0; + let device_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + let driver_name = match desc.VendorId { + 0x10DE => "NVIDIA Corporation".to_string(), + 0x1002 => "AMD Corporation".to_string(), + 0x8086 => "Intel Corporation".to_string(), + id => format!("Unknown Vendor (ID: {:#X})", id), + }; + let driver_version = match desc.VendorId { + 0x10DE => nvidia::get_driver_version(), + 0x1002 => amd::get_driver_version(), + // For Intel and other vendors, we use the DXGI API to get the driver version. + _ => dxgi::get_driver_version(&self.devices.adapter), + } + .context("Failed to get gpu driver info") + .log_err() + .unwrap_or("Unknown Driver".to_string()); + Ok(GpuSpecs { + is_software_emulated, + device_name, + driver_name, + driver_info: driver_version, + }) + } +} + +impl DirectXResources { + pub fn new( + devices: &DirectXDevices, + width: u32, + height: u32, + hwnd: HWND, + disable_direct_composition: bool, + ) -> Result<ManuallyDrop<Self>> { + let swap_chain = if disable_direct_composition { + create_swap_chain(&devices.dxgi_factory, &devices.device, hwnd, width, height)? + } else { + create_swap_chain_for_composition( + &devices.dxgi_factory, + &devices.device, + width, + height, + )? + }; + + let ( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + ) = create_resources(devices, &swap_chain, width, height)?; + set_rasterizer_state(&devices.device, &devices.device_context)?; + + Ok(ManuallyDrop::new(Self { + swap_chain, + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + path_intermediate_srv, + viewport, + width, + height, + })) + } + + #[inline] + fn recreate_resources( + &mut self, + devices: &DirectXDevices, + width: u32, + height: u32, + ) -> Result<()> { + let ( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + ) = create_resources(devices, &self.swap_chain, width, height)?; + self.render_target = render_target; + self.render_target_view = render_target_view; + self.path_intermediate_texture = path_intermediate_texture; + self.path_intermediate_msaa_texture = path_intermediate_msaa_texture; + self.path_intermediate_msaa_view = path_intermediate_msaa_view; + self.path_intermediate_srv = path_intermediate_srv; + self.viewport = viewport; + self.width = width; + self.height = height; + Ok(()) + } +} + +impl DirectXRenderPipelines { + pub fn new(device: &ID3D11Device) -> Result<Self> { + let shadow_pipeline = PipelineState::new( + device, + "shadow_pipeline", + ShaderModule::Shadow, + 4, + create_blend_state(device)?, + )?; + let quad_pipeline = PipelineState::new( + device, + "quad_pipeline", + ShaderModule::Quad, + 64, + create_blend_state(device)?, + )?; + let path_rasterization_pipeline = PipelineState::new( + device, + "path_rasterization_pipeline", + ShaderModule::PathRasterization, + 32, + create_blend_state_for_path_rasterization(device)?, + )?; + let path_sprite_pipeline = PipelineState::new( + device, + "path_sprite_pipeline", + ShaderModule::PathSprite, + 4, + create_blend_state_for_path_sprite(device)?, + )?; + let underline_pipeline = PipelineState::new( + device, + "underline_pipeline", + ShaderModule::Underline, + 4, + create_blend_state(device)?, + )?; + let mono_sprites = PipelineState::new( + device, + "monochrome_sprite_pipeline", + ShaderModule::MonochromeSprite, + 512, + create_blend_state(device)?, + )?; + let poly_sprites = PipelineState::new( + device, + "polychrome_sprite_pipeline", + ShaderModule::PolychromeSprite, + 16, + create_blend_state(device)?, + )?; + + Ok(Self { + shadow_pipeline, + quad_pipeline, + path_rasterization_pipeline, + path_sprite_pipeline, + underline_pipeline, + mono_sprites, + poly_sprites, + }) + } +} + +impl DirectComposition { + pub fn new(dxgi_device: &IDXGIDevice, hwnd: HWND) -> Result<Self> { + let comp_device = get_comp_device(&dxgi_device)?; + let comp_target = unsafe { comp_device.CreateTargetForHwnd(hwnd, true) }?; + let comp_visual = unsafe { comp_device.CreateVisual() }?; + + Ok(Self { + comp_device, + comp_target, + comp_visual, + }) + } + + pub fn set_swap_chain(&self, swap_chain: &IDXGISwapChain1) -> Result<()> { + unsafe { + self.comp_visual.SetContent(swap_chain)?; + self.comp_target.SetRoot(&self.comp_visual)?; + self.comp_device.Commit()?; + } + Ok(()) + } +} + +impl DirectXGlobalElements { + pub fn new(device: &ID3D11Device) -> Result<Self> { + let global_params_buffer = unsafe { + let desc = D3D11_BUFFER_DESC { + ByteWidth: std::mem::size_of::<GlobalParams>() as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + ..Default::default() + }; + let mut buffer = None; + device.CreateBuffer(&desc, None, Some(&mut buffer))?; + [buffer] + }; + + let sampler = unsafe { + let desc = D3D11_SAMPLER_DESC { + Filter: D3D11_FILTER_MIN_MAG_MIP_LINEAR, + AddressU: D3D11_TEXTURE_ADDRESS_WRAP, + AddressV: D3D11_TEXTURE_ADDRESS_WRAP, + AddressW: D3D11_TEXTURE_ADDRESS_WRAP, + MipLODBias: 0.0, + MaxAnisotropy: 1, + ComparisonFunc: D3D11_COMPARISON_ALWAYS, + BorderColor: [0.0; 4], + MinLOD: 0.0, + MaxLOD: D3D11_FLOAT32_MAX, + }; + let mut output = None; + device.CreateSamplerState(&desc, Some(&mut output))?; + [output] + }; + + Ok(Self { + global_params_buffer, + sampler, + }) + } +} + +#[derive(Debug, Default)] +#[repr(C)] +struct GlobalParams { + viewport_size: [f32; 2], + _pad: u64, +} + +struct PipelineState<T> { + label: &'static str, + vertex: ID3D11VertexShader, + fragment: ID3D11PixelShader, + buffer: ID3D11Buffer, + buffer_size: usize, + view: [Option<ID3D11ShaderResourceView>; 1], + blend_state: ID3D11BlendState, + _marker: std::marker::PhantomData<T>, +} + +impl<T> PipelineState<T> { + fn new( + device: &ID3D11Device, + label: &'static str, + shader_module: ShaderModule, + buffer_size: usize, + blend_state: ID3D11BlendState, + ) -> Result<Self> { + let vertex = { + let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Vertex)?; + create_vertex_shader(device, raw_shader.as_bytes())? + }; + let fragment = { + let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Fragment)?; + create_fragment_shader(device, raw_shader.as_bytes())? + }; + let buffer = create_buffer(device, std::mem::size_of::<T>(), buffer_size)?; + let view = create_buffer_view(device, &buffer)?; + + Ok(PipelineState { + label, + vertex, + fragment, + buffer, + buffer_size, + view, + blend_state, + _marker: std::marker::PhantomData, + }) + } + + fn update_buffer( + &mut self, + device: &ID3D11Device, + device_context: &ID3D11DeviceContext, + data: &[T], + ) -> Result<()> { + if self.buffer_size < data.len() { + let new_buffer_size = data.len().next_power_of_two(); + log::info!( + "Updating {} buffer size from {} to {}", + self.label, + self.buffer_size, + new_buffer_size + ); + let buffer = create_buffer(device, std::mem::size_of::<T>(), new_buffer_size)?; + let view = create_buffer_view(device, &buffer)?; + self.buffer = buffer; + self.view = view; + self.buffer_size = new_buffer_size; + } + update_buffer(device_context, &self.buffer, data) + } + + fn draw( + &self, + device_context: &ID3D11DeviceContext, + viewport: &[D3D11_VIEWPORT], + global_params: &[Option<ID3D11Buffer>], + topology: D3D_PRIMITIVE_TOPOLOGY, + vertex_count: u32, + instance_count: u32, + ) -> Result<()> { + set_pipeline_state( + device_context, + &self.view, + topology, + viewport, + &self.vertex, + &self.fragment, + global_params, + &self.blend_state, + ); + unsafe { + device_context.DrawInstanced(vertex_count, instance_count, 0, 0); + } + Ok(()) + } + + fn draw_with_texture( + &self, + device_context: &ID3D11DeviceContext, + texture: &[Option<ID3D11ShaderResourceView>], + viewport: &[D3D11_VIEWPORT], + global_params: &[Option<ID3D11Buffer>], + sampler: &[Option<ID3D11SamplerState>], + instance_count: u32, + ) -> Result<()> { + set_pipeline_state( + device_context, + &self.view, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + viewport, + &self.vertex, + &self.fragment, + global_params, + &self.blend_state, + ); + unsafe { + device_context.PSSetSamplers(0, Some(sampler)); + device_context.VSSetShaderResources(0, Some(texture)); + device_context.PSSetShaderResources(0, Some(texture)); + + device_context.DrawInstanced(4, instance_count, 0, 0); + } + Ok(()) + } +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct PathRasterizationSprite { + xy_position: Point<ScaledPixels>, + st_position: Point<f32>, + color: Background, + bounds: Bounds<ScaledPixels>, +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct PathSprite { + bounds: Bounds<ScaledPixels>, +} + +impl Drop for DirectXRenderer { + fn drop(&mut self) { + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device).ok(); + unsafe { + ManuallyDrop::drop(&mut self.devices); + ManuallyDrop::drop(&mut self.resources); + } + } +} + +impl Drop for DirectXResources { + fn drop(&mut self) { + unsafe { + ManuallyDrop::drop(&mut self.render_target); + } + } +} + +#[inline] +fn check_debug_layer_available() -> bool { + #[cfg(debug_assertions)] + { + unsafe { DXGIGetDebugInterface1::<IDXGIInfoQueue>(0) } + .log_err() + .is_some() + } + #[cfg(not(debug_assertions))] + { + false + } +} + +#[inline] +fn get_dxgi_factory(debug_layer_available: bool) -> Result<IDXGIFactory6> { + let factory_flag = if debug_layer_available { + DXGI_CREATE_FACTORY_DEBUG + } else { + #[cfg(debug_assertions)] + log::warn!( + "Failed to get DXGI debug interface. DirectX debugging features will be disabled." + ); + DXGI_CREATE_FACTORY_FLAGS::default() + }; + unsafe { Ok(CreateDXGIFactory2(factory_flag)?) } +} + +fn get_adapter(dxgi_factory: &IDXGIFactory6, debug_layer_available: bool) -> Result<IDXGIAdapter1> { + for adapter_index in 0.. { + let adapter: IDXGIAdapter1 = unsafe { + dxgi_factory + .EnumAdapterByGpuPreference(adapter_index, DXGI_GPU_PREFERENCE_MINIMUM_POWER) + }?; + if let Ok(desc) = unsafe { adapter.GetDesc1() } { + let gpu_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + log::info!("Using GPU: {}", gpu_name); + } + // Check to see whether the adapter supports Direct3D 11, but don't + // create the actual device yet. + if get_device(&adapter, None, None, None, debug_layer_available) + .log_err() + .is_some() + { + return Ok(adapter); + } + } + + unreachable!() +} + +fn get_device( + adapter: &IDXGIAdapter1, + device: Option<*mut Option<ID3D11Device>>, + context: Option<*mut Option<ID3D11DeviceContext>>, + feature_level: Option<*mut D3D_FEATURE_LEVEL>, + debug_layer_available: bool, +) -> Result<()> { + let device_flags = if debug_layer_available { + D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG + } else { + D3D11_CREATE_DEVICE_BGRA_SUPPORT + }; + unsafe { + D3D11CreateDevice( + adapter, + D3D_DRIVER_TYPE_UNKNOWN, + HMODULE::default(), + device_flags, + // 4x MSAA is required for Direct3D Feature Level 10.1 or better + Some(&[ + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_10_1, + ]), + D3D11_SDK_VERSION, + device, + feature_level, + context, + )?; + } + Ok(()) +} + +#[inline] +fn get_comp_device(dxgi_device: &IDXGIDevice) -> Result<IDCompositionDevice> { + Ok(unsafe { DCompositionCreateDevice(dxgi_device)? }) +} + +fn create_swap_chain_for_composition( + dxgi_factory: &IDXGIFactory6, + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<IDXGISwapChain1> { + let desc = DXGI_SWAP_CHAIN_DESC1 { + Width: width, + Height: height, + Format: RENDER_TARGET_FORMAT, + Stereo: false.into(), + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, + BufferCount: BUFFER_COUNT as u32, + // Composition SwapChains only support the DXGI_SCALING_STRETCH Scaling. + Scaling: DXGI_SCALING_STRETCH, + SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, + AlphaMode: DXGI_ALPHA_MODE_PREMULTIPLIED, + Flags: 0, + }; + Ok(unsafe { dxgi_factory.CreateSwapChainForComposition(device, &desc, None)? }) +} + +fn create_swap_chain( + dxgi_factory: &IDXGIFactory6, + device: &ID3D11Device, + hwnd: HWND, + width: u32, + height: u32, +) -> Result<IDXGISwapChain1> { + use windows::Win32::Graphics::Dxgi::DXGI_MWA_NO_ALT_ENTER; + + let desc = DXGI_SWAP_CHAIN_DESC1 { + Width: width, + Height: height, + Format: RENDER_TARGET_FORMAT, + Stereo: false.into(), + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, + BufferCount: BUFFER_COUNT as u32, + Scaling: DXGI_SCALING_NONE, + SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, + AlphaMode: DXGI_ALPHA_MODE_IGNORE, + Flags: 0, + }; + let swap_chain = + unsafe { dxgi_factory.CreateSwapChainForHwnd(device, hwnd, &desc, None, None) }?; + unsafe { dxgi_factory.MakeWindowAssociation(hwnd, DXGI_MWA_NO_ALT_ENTER) }?; + Ok(swap_chain) +} + +#[inline] +fn create_resources( + devices: &DirectXDevices, + swap_chain: &IDXGISwapChain1, + width: u32, + height: u32, +) -> Result<( + ManuallyDrop<ID3D11Texture2D>, + [Option<ID3D11RenderTargetView>; 1], + ID3D11Texture2D, + [Option<ID3D11ShaderResourceView>; 1], + ID3D11Texture2D, + [Option<ID3D11RenderTargetView>; 1], + [D3D11_VIEWPORT; 1], +)> { + let (render_target, render_target_view) = + create_render_target_and_its_view(&swap_chain, &devices.device)?; + let (path_intermediate_texture, path_intermediate_srv) = + create_path_intermediate_texture(&devices.device, width, height)?; + let (path_intermediate_msaa_texture, path_intermediate_msaa_view) = + create_path_intermediate_msaa_texture_and_view(&devices.device, width, height)?; + let viewport = set_viewport(&devices.device_context, width as f32, height as f32); + Ok(( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + )) +} + +#[inline] +fn create_render_target_and_its_view( + swap_chain: &IDXGISwapChain1, + device: &ID3D11Device, +) -> Result<( + ManuallyDrop<ID3D11Texture2D>, + [Option<ID3D11RenderTargetView>; 1], +)> { + let render_target: ID3D11Texture2D = unsafe { swap_chain.GetBuffer(0) }?; + let mut render_target_view = None; + unsafe { device.CreateRenderTargetView(&render_target, None, Some(&mut render_target_view))? }; + Ok(( + ManuallyDrop::new(render_target), + [Some(render_target_view.unwrap())], + )) +} + +#[inline] +fn create_path_intermediate_texture( + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<(ID3D11Texture2D, [Option<ID3D11ShaderResourceView>; 1])> { + let texture = unsafe { + let mut output = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: width, + Height: height, + MipLevels: 1, + ArraySize: 1, + Format: RENDER_TARGET_FORMAT, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: (D3D11_BIND_RENDER_TARGET.0 | D3D11_BIND_SHADER_RESOURCE.0) as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + device.CreateTexture2D(&desc, None, Some(&mut output))?; + output.unwrap() + }; + + let mut shader_resource_view = None; + unsafe { device.CreateShaderResourceView(&texture, None, Some(&mut shader_resource_view))? }; + + Ok((texture, [Some(shader_resource_view.unwrap())])) +} + +#[inline] +fn create_path_intermediate_msaa_texture_and_view( + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<(ID3D11Texture2D, [Option<ID3D11RenderTargetView>; 1])> { + let msaa_texture = unsafe { + let mut output = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: width, + Height: height, + MipLevels: 1, + ArraySize: 1, + Format: RENDER_TARGET_FORMAT, + SampleDesc: DXGI_SAMPLE_DESC { + Count: PATH_MULTISAMPLE_COUNT, + Quality: D3D11_STANDARD_MULTISAMPLE_PATTERN.0 as u32, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + device.CreateTexture2D(&desc, None, Some(&mut output))?; + output.unwrap() + }; + let mut msaa_view = None; + unsafe { device.CreateRenderTargetView(&msaa_texture, None, Some(&mut msaa_view))? }; + Ok((msaa_texture, [Some(msaa_view.unwrap())])) +} + +#[inline] +fn set_viewport( + device_context: &ID3D11DeviceContext, + width: f32, + height: f32, +) -> [D3D11_VIEWPORT; 1] { + let viewport = [D3D11_VIEWPORT { + TopLeftX: 0.0, + TopLeftY: 0.0, + Width: width, + Height: height, + MinDepth: 0.0, + MaxDepth: 1.0, + }]; + unsafe { device_context.RSSetViewports(Some(&viewport)) }; + viewport +} + +#[inline] +fn set_rasterizer_state(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Result<()> { + let desc = D3D11_RASTERIZER_DESC { + FillMode: D3D11_FILL_SOLID, + CullMode: D3D11_CULL_NONE, + FrontCounterClockwise: false.into(), + DepthBias: 0, + DepthBiasClamp: 0.0, + SlopeScaledDepthBias: 0.0, + DepthClipEnable: true.into(), + ScissorEnable: false.into(), + MultisampleEnable: true.into(), + AntialiasedLineEnable: false.into(), + }; + let rasterizer_state = unsafe { + let mut state = None; + device.CreateRasterizerState(&desc, Some(&mut state))?; + state.unwrap() + }; + unsafe { device_context.RSSetState(&rasterizer_state) }; + Ok(()) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_blend_desc +#[inline] +fn create_blend_state(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_SRC_ALPHA; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_blend_state_for_path_rasterization(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_blend_state_for_path_sprite(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_vertex_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11VertexShader> { + unsafe { + let mut shader = None; + device.CreateVertexShader(bytes, None, Some(&mut shader))?; + Ok(shader.unwrap()) + } +} + +#[inline] +fn create_fragment_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11PixelShader> { + unsafe { + let mut shader = None; + device.CreatePixelShader(bytes, None, Some(&mut shader))?; + Ok(shader.unwrap()) + } +} + +#[inline] +fn create_buffer( + device: &ID3D11Device, + element_size: usize, + buffer_size: usize, +) -> Result<ID3D11Buffer> { + let desc = D3D11_BUFFER_DESC { + ByteWidth: (element_size * buffer_size) as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: D3D11_RESOURCE_MISC_BUFFER_STRUCTURED.0 as u32, + StructureByteStride: element_size as u32, + }; + let mut buffer = None; + unsafe { device.CreateBuffer(&desc, None, Some(&mut buffer)) }?; + Ok(buffer.unwrap()) +} + +#[inline] +fn create_buffer_view( + device: &ID3D11Device, + buffer: &ID3D11Buffer, +) -> Result<[Option<ID3D11ShaderResourceView>; 1]> { + let mut view = None; + unsafe { device.CreateShaderResourceView(buffer, None, Some(&mut view)) }?; + Ok([view]) +} + +#[inline] +fn update_buffer<T>( + device_context: &ID3D11DeviceContext, + buffer: &ID3D11Buffer, + data: &[T], +) -> Result<()> { + unsafe { + let mut dest = std::mem::zeroed(); + device_context.Map(buffer, 0, D3D11_MAP_WRITE_DISCARD, 0, Some(&mut dest))?; + std::ptr::copy_nonoverlapping(data.as_ptr(), dest.pData as _, data.len()); + device_context.Unmap(buffer, 0); + } + Ok(()) +} + +#[inline] +fn set_pipeline_state( + device_context: &ID3D11DeviceContext, + buffer_view: &[Option<ID3D11ShaderResourceView>], + topology: D3D_PRIMITIVE_TOPOLOGY, + viewport: &[D3D11_VIEWPORT], + vertex_shader: &ID3D11VertexShader, + fragment_shader: &ID3D11PixelShader, + global_params: &[Option<ID3D11Buffer>], + blend_state: &ID3D11BlendState, +) { + unsafe { + device_context.VSSetShaderResources(1, Some(buffer_view)); + device_context.PSSetShaderResources(1, Some(buffer_view)); + device_context.IASetPrimitiveTopology(topology); + device_context.RSSetViewports(Some(viewport)); + device_context.VSSetShader(vertex_shader, None); + device_context.PSSetShader(fragment_shader, None); + device_context.VSSetConstantBuffers(0, Some(global_params)); + device_context.PSSetConstantBuffers(0, Some(global_params)); + device_context.OMSetBlendState(blend_state, None, 0xFFFFFFFF); + } +} + +#[cfg(debug_assertions)] +fn report_live_objects(device: &ID3D11Device) -> Result<()> { + let debug_device: ID3D11Debug = device.cast()?; + unsafe { + debug_device.ReportLiveDeviceObjects(D3D11_RLDO_DETAIL)?; + } + Ok(()) +} + +const BUFFER_COUNT: usize = 3; + +pub(crate) mod shader_resources { + use anyhow::Result; + + #[cfg(debug_assertions)] + use windows::{ + Win32::Graphics::Direct3D::{ + Fxc::{D3DCOMPILE_DEBUG, D3DCOMPILE_SKIP_OPTIMIZATION, D3DCompileFromFile}, + ID3DBlob, + }, + core::{HSTRING, PCSTR}, + }; + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(crate) enum ShaderModule { + Quad, + Shadow, + Underline, + PathRasterization, + PathSprite, + MonochromeSprite, + PolychromeSprite, + EmojiRasterization, + } + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(crate) enum ShaderTarget { + Vertex, + Fragment, + } + + pub(crate) struct RawShaderBytes<'t> { + inner: &'t [u8], + + #[cfg(debug_assertions)] + _blob: ID3DBlob, + } + + impl<'t> RawShaderBytes<'t> { + pub(crate) fn new(module: ShaderModule, target: ShaderTarget) -> Result<Self> { + #[cfg(not(debug_assertions))] + { + Ok(Self::from_bytes(module, target)) + } + #[cfg(debug_assertions)] + { + let blob = build_shader_blob(module, target)?; + let inner = unsafe { + std::slice::from_raw_parts( + blob.GetBufferPointer() as *const u8, + blob.GetBufferSize(), + ) + }; + Ok(Self { inner, _blob: blob }) + } + } + + pub(crate) fn as_bytes(&'t self) -> &'t [u8] { + self.inner + } + + #[cfg(not(debug_assertions))] + fn from_bytes(module: ShaderModule, target: ShaderTarget) -> Self { + let bytes = match module { + ShaderModule::Quad => match target { + ShaderTarget::Vertex => QUAD_VERTEX_BYTES, + ShaderTarget::Fragment => QUAD_FRAGMENT_BYTES, + }, + ShaderModule::Shadow => match target { + ShaderTarget::Vertex => SHADOW_VERTEX_BYTES, + ShaderTarget::Fragment => SHADOW_FRAGMENT_BYTES, + }, + ShaderModule::Underline => match target { + ShaderTarget::Vertex => UNDERLINE_VERTEX_BYTES, + ShaderTarget::Fragment => UNDERLINE_FRAGMENT_BYTES, + }, + ShaderModule::PathRasterization => match target { + ShaderTarget::Vertex => PATH_RASTERIZATION_VERTEX_BYTES, + ShaderTarget::Fragment => PATH_RASTERIZATION_FRAGMENT_BYTES, + }, + ShaderModule::PathSprite => match target { + ShaderTarget::Vertex => PATH_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => PATH_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::MonochromeSprite => match target { + ShaderTarget::Vertex => MONOCHROME_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => MONOCHROME_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::PolychromeSprite => match target { + ShaderTarget::Vertex => POLYCHROME_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => POLYCHROME_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::EmojiRasterization => match target { + ShaderTarget::Vertex => EMOJI_RASTERIZATION_VERTEX_BYTES, + ShaderTarget::Fragment => EMOJI_RASTERIZATION_FRAGMENT_BYTES, + }, + }; + Self { inner: bytes } + } + } + + #[cfg(debug_assertions)] + pub(super) fn build_shader_blob(entry: ShaderModule, target: ShaderTarget) -> Result<ID3DBlob> { + unsafe { + let shader_name = if matches!(entry, ShaderModule::EmojiRasterization) { + "color_text_raster.hlsl" + } else { + "shaders.hlsl" + }; + + let entry = format!( + "{}_{}\0", + entry.as_str(), + match target { + ShaderTarget::Vertex => "vertex", + ShaderTarget::Fragment => "fragment", + } + ); + let target = match target { + ShaderTarget::Vertex => "vs_4_1\0", + ShaderTarget::Fragment => "ps_4_1\0", + }; + + let mut compile_blob = None; + let mut error_blob = None; + let shader_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(&format!("src/platform/windows/{}", shader_name)) + .canonicalize()?; + + let entry_point = PCSTR::from_raw(entry.as_ptr()); + let target_cstr = PCSTR::from_raw(target.as_ptr()); + + let ret = D3DCompileFromFile( + &HSTRING::from(shader_path.to_str().unwrap()), + None, + None, + entry_point, + target_cstr, + D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION, + 0, + &mut compile_blob, + Some(&mut error_blob), + ); + if ret.is_err() { + let Some(error_blob) = error_blob else { + return Err(anyhow::anyhow!("{ret:?}")); + }; + + let error_string = + std::ffi::CStr::from_ptr(error_blob.GetBufferPointer() as *const i8) + .to_string_lossy(); + log::error!("Shader compile error: {}", error_string); + return Err(anyhow::anyhow!("Compile error: {}", error_string)); + } + Ok(compile_blob.unwrap()) + } + } + + #[cfg(not(debug_assertions))] + include!(concat!(env!("OUT_DIR"), "/shaders_bytes.rs")); + + #[cfg(debug_assertions)] + impl ShaderModule { + pub fn as_str(&self) -> &str { + match self { + ShaderModule::Quad => "quad", + ShaderModule::Shadow => "shadow", + ShaderModule::Underline => "underline", + ShaderModule::PathRasterization => "path_rasterization", + ShaderModule::PathSprite => "path_sprite", + ShaderModule::MonochromeSprite => "monochrome_sprite", + ShaderModule::PolychromeSprite => "polychrome_sprite", + ShaderModule::EmojiRasterization => "emoji_rasterization", + } + } + } +} + +mod nvidia { + use std::{ + ffi::CStr, + os::raw::{c_char, c_int, c_uint}, + }; + + use anyhow::{Context, Result}; + use windows::{ + Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, + core::s, + }; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 + const NVAPI_SHORT_STRING_MAX: usize = 64; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L235 + #[allow(non_camel_case_types)] + type NvAPI_ShortString = [c_char; NVAPI_SHORT_STRING_MAX]; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L447 + #[allow(non_camel_case_types)] + type NvAPI_SYS_GetDriverAndBranchVersion_t = unsafe extern "C" fn( + driver_version: *mut c_uint, + build_branch_string: *mut NvAPI_ShortString, + ) -> c_int; + + pub(super) fn get_driver_version() -> Result<String> { + unsafe { + // Try to load the NVIDIA driver DLL + #[cfg(target_pointer_width = "64")] + let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?; + #[cfg(target_pointer_width = "32")] + let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?; + + let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface")) + .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?; + let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr); + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_interface.h#L41 + let nvapi_get_driver_version_ptr = nvapi_query(0x2926aaad); + if nvapi_get_driver_version_ptr.is_null() { + anyhow::bail!("Failed to get NVIDIA driver version function pointer"); + } + let nvapi_get_driver_version: NvAPI_SYS_GetDriverAndBranchVersion_t = + std::mem::transmute(nvapi_get_driver_version_ptr); + + let mut driver_version: c_uint = 0; + let mut build_branch_string: NvAPI_ShortString = [0; NVAPI_SHORT_STRING_MAX]; + let result = nvapi_get_driver_version( + &mut driver_version as *mut c_uint, + &mut build_branch_string as *mut NvAPI_ShortString, + ); + + if result != 0 { + anyhow::bail!( + "Failed to get NVIDIA driver version, error code: {}", + result + ); + } + let major = driver_version / 100; + let minor = driver_version % 100; + let branch_string = CStr::from_ptr(build_branch_string.as_ptr()); + Ok(format!( + "{}.{} {}", + major, + minor, + branch_string.to_string_lossy() + )) + } + } +} + +mod amd { + use std::os::raw::{c_char, c_int, c_void}; + + use anyhow::{Context, Result}; + use windows::{ + Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, + core::s, + }; + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 + const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L204 + // This is an opaque type, using struct to represent it properly for FFI + #[repr(C)] + struct AGSContext { + _private: [u8; 0], + } + + #[repr(C)] + pub struct AGSGPUInfo { + pub driver_version: *const c_char, + pub radeon_software_version: *const c_char, + pub num_devices: c_int, + pub devices: *mut c_void, + } + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L429 + #[allow(non_camel_case_types)] + type agsInitialize_t = unsafe extern "C" fn( + version: c_int, + config: *const c_void, + context: *mut *mut AGSContext, + gpu_info: *mut AGSGPUInfo, + ) -> c_int; + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L436 + #[allow(non_camel_case_types)] + type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int; + + pub(super) fn get_driver_version() -> Result<String> { + unsafe { + #[cfg(target_pointer_width = "64")] + let amd_dll = + LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?; + #[cfg(target_pointer_width = "32")] + let amd_dll = + LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?; + + let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize")) + .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?; + let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize")) + .ok_or_else(|| anyhow::anyhow!("Failed to get agsDeInitialize address"))?; + + let ags_initialize: agsInitialize_t = std::mem::transmute(ags_initialize_addr); + let ags_deinitialize: agsDeInitialize_t = std::mem::transmute(ags_deinitialize_addr); + + let mut context: *mut AGSContext = std::ptr::null_mut(); + let mut gpu_info: AGSGPUInfo = AGSGPUInfo { + driver_version: std::ptr::null(), + radeon_software_version: std::ptr::null(), + num_devices: 0, + devices: std::ptr::null_mut(), + }; + + let result = ags_initialize( + AGS_CURRENT_VERSION, + std::ptr::null(), + &mut context, + &mut gpu_info, + ); + if result != 0 { + anyhow::bail!("Failed to initialize AMD AGS, error code: {}", result); + } + + // Vulkan acctually returns this as the driver version + let software_version = if !gpu_info.radeon_software_version.is_null() { + std::ffi::CStr::from_ptr(gpu_info.radeon_software_version) + .to_string_lossy() + .into_owned() + } else { + "Unknown Radeon Software Version".to_string() + }; + + let driver_version = if !gpu_info.driver_version.is_null() { + std::ffi::CStr::from_ptr(gpu_info.driver_version) + .to_string_lossy() + .into_owned() + } else { + "Unknown Radeon Driver Version".to_string() + }; + + ags_deinitialize(context); + Ok(format!("{} ({})", software_version, driver_version)) + } + } +} + +mod dxgi { + use windows::{ + Win32::Graphics::Dxgi::{IDXGIAdapter1, IDXGIDevice}, + core::Interface, + }; + + pub(super) fn get_driver_version(adapter: &IDXGIAdapter1) -> anyhow::Result<String> { + let number = unsafe { adapter.CheckInterfaceSupport(&IDXGIDevice::IID as _) }?; + Ok(format!( + "{}.{}.{}.{}", + number >> 48, + (number >> 32) & 0xFFFF, + (number >> 16) & 0xFFFF, + number & 0xFFFF + )) + } +} diff --git a/crates/gpui/src/platform/windows/events.rs b/crates/gpui/src/platform/windows/events.rs index 839fd10375..00b22fa807 100644 --- a/crates/gpui/src/platform/windows/events.rs +++ b/crates/gpui/src/platform/windows/events.rs @@ -23,1027 +23,868 @@ pub(crate) const WM_GPUI_CURSOR_STYLE_CHANGED: u32 = WM_USER + 1; pub(crate) const WM_GPUI_CLOSE_ONE_WINDOW: u32 = WM_USER + 2; pub(crate) const WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD: u32 = WM_USER + 3; pub(crate) const WM_GPUI_DOCK_MENU_ACTION: u32 = WM_USER + 4; +pub(crate) const WM_GPUI_FORCE_UPDATE_WINDOW: u32 = WM_USER + 5; const SIZE_MOVE_LOOP_TIMER_ID: usize = 1; const AUTO_HIDE_TASKBAR_THICKNESS_PX: i32 = 1; -pub(crate) fn handle_msg( - handle: HWND, - msg: u32, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> LRESULT { - let handled = match msg { - WM_ACTIVATE => handle_activate_msg(wparam, state_ptr), - WM_CREATE => handle_create_msg(handle, state_ptr), - WM_MOVE => handle_move_msg(handle, lparam, state_ptr), - WM_SIZE => handle_size_msg(wparam, lparam, state_ptr), - WM_GETMINMAXINFO => handle_get_min_max_info_msg(lparam, state_ptr), - WM_ENTERSIZEMOVE | WM_ENTERMENULOOP => handle_size_move_loop(handle), - WM_EXITSIZEMOVE | WM_EXITMENULOOP => handle_size_move_loop_exit(handle), - WM_TIMER => handle_timer_msg(handle, wparam, state_ptr), - WM_NCCALCSIZE => handle_calc_client_size(handle, wparam, lparam, state_ptr), - WM_DPICHANGED => handle_dpi_changed_msg(handle, wparam, lparam, state_ptr), - WM_DISPLAYCHANGE => handle_display_change_msg(handle, state_ptr), - WM_NCHITTEST => handle_hit_test_msg(handle, msg, wparam, lparam, state_ptr), - WM_PAINT => handle_paint_msg(handle, state_ptr), - WM_CLOSE => handle_close_msg(handle, state_ptr), - WM_DESTROY => handle_destroy_msg(handle, state_ptr), - WM_MOUSEMOVE => handle_mouse_move_msg(handle, lparam, wparam, state_ptr), - WM_MOUSELEAVE | WM_NCMOUSELEAVE => handle_mouse_leave_msg(state_ptr), - WM_NCMOUSEMOVE => handle_nc_mouse_move_msg(handle, lparam, state_ptr), - WM_NCLBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, MouseButton::Left, wparam, lparam, state_ptr) +impl WindowsWindowInner { + pub(crate) fn handle_msg( + self: &Rc<Self>, + handle: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, + ) -> LRESULT { + let handled = match msg { + WM_ACTIVATE => self.handle_activate_msg(wparam), + WM_CREATE => self.handle_create_msg(handle), + WM_DEVICECHANGE => self.handle_device_change_msg(handle, wparam), + WM_MOVE => self.handle_move_msg(handle, lparam), + WM_SIZE => self.handle_size_msg(wparam, lparam), + WM_GETMINMAXINFO => self.handle_get_min_max_info_msg(lparam), + WM_ENTERSIZEMOVE | WM_ENTERMENULOOP => self.handle_size_move_loop(handle), + WM_EXITSIZEMOVE | WM_EXITMENULOOP => self.handle_size_move_loop_exit(handle), + WM_TIMER => self.handle_timer_msg(handle, wparam), + WM_NCCALCSIZE => self.handle_calc_client_size(handle, wparam, lparam), + WM_DPICHANGED => self.handle_dpi_changed_msg(handle, wparam, lparam), + WM_DISPLAYCHANGE => self.handle_display_change_msg(handle), + WM_NCHITTEST => self.handle_hit_test_msg(handle, msg, wparam, lparam), + WM_PAINT => self.handle_paint_msg(handle), + WM_CLOSE => self.handle_close_msg(), + WM_DESTROY => self.handle_destroy_msg(handle), + WM_MOUSEMOVE => self.handle_mouse_move_msg(handle, lparam, wparam), + WM_MOUSELEAVE | WM_NCMOUSELEAVE => self.handle_mouse_leave_msg(), + WM_NCMOUSEMOVE => self.handle_nc_mouse_move_msg(handle, lparam), + WM_NCLBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Left, wparam, lparam) + } + WM_NCRBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Right, wparam, lparam) + } + WM_NCMBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Middle, wparam, lparam) + } + WM_NCLBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Left, wparam, lparam) + } + WM_NCRBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Right, wparam, lparam) + } + WM_NCMBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Middle, wparam, lparam) + } + WM_LBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Left, lparam), + WM_RBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Right, lparam), + WM_MBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Middle, lparam), + WM_XBUTTONDOWN => { + self.handle_xbutton_msg(handle, wparam, lparam, Self::handle_mouse_down_msg) + } + WM_LBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Left, lparam), + WM_RBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Right, lparam), + WM_MBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Middle, lparam), + WM_XBUTTONUP => { + self.handle_xbutton_msg(handle, wparam, lparam, Self::handle_mouse_up_msg) + } + WM_MOUSEWHEEL => self.handle_mouse_wheel_msg(handle, wparam, lparam), + WM_MOUSEHWHEEL => self.handle_mouse_horizontal_wheel_msg(handle, wparam, lparam), + WM_SYSKEYDOWN => self.handle_syskeydown_msg(handle, wparam, lparam), + WM_SYSKEYUP => self.handle_syskeyup_msg(handle, wparam, lparam), + WM_SYSCOMMAND => self.handle_system_command(wparam), + WM_KEYDOWN => self.handle_keydown_msg(handle, wparam, lparam), + WM_KEYUP => self.handle_keyup_msg(handle, wparam, lparam), + WM_CHAR => self.handle_char_msg(wparam), + WM_DEADCHAR => self.handle_dead_char_msg(wparam), + WM_IME_STARTCOMPOSITION => self.handle_ime_position(handle), + WM_IME_COMPOSITION => self.handle_ime_composition(handle, lparam), + WM_SETCURSOR => self.handle_set_cursor(handle, lparam), + WM_SETTINGCHANGE => self.handle_system_settings_changed(handle, wparam, lparam), + WM_INPUTLANGCHANGE => self.handle_input_language_changed(lparam), + WM_GPUI_CURSOR_STYLE_CHANGED => self.handle_cursor_changed(lparam), + WM_GPUI_FORCE_UPDATE_WINDOW => self.draw_window(handle, true), + _ => None, + }; + if let Some(n) = handled { + LRESULT(n) + } else { + unsafe { DefWindowProcW(handle, msg, wparam, lparam) } } - WM_NCRBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, MouseButton::Right, wparam, lparam, state_ptr) - } - WM_NCMBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, MouseButton::Middle, wparam, lparam, state_ptr) - } - WM_NCLBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Left, wparam, lparam, state_ptr) - } - WM_NCRBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Right, wparam, lparam, state_ptr) - } - WM_NCMBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Middle, wparam, lparam, state_ptr) - } - WM_LBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Left, lparam, state_ptr), - WM_RBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Right, lparam, state_ptr), - WM_MBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Middle, lparam, state_ptr), - WM_XBUTTONDOWN => { - handle_xbutton_msg(handle, wparam, lparam, handle_mouse_down_msg, state_ptr) - } - WM_LBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Left, lparam, state_ptr), - WM_RBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Right, lparam, state_ptr), - WM_MBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Middle, lparam, state_ptr), - WM_XBUTTONUP => handle_xbutton_msg(handle, wparam, lparam, handle_mouse_up_msg, state_ptr), - WM_MOUSEWHEEL => handle_mouse_wheel_msg(handle, wparam, lparam, state_ptr), - WM_MOUSEHWHEEL => handle_mouse_horizontal_wheel_msg(handle, wparam, lparam, state_ptr), - WM_SYSKEYDOWN => handle_syskeydown_msg(handle, wparam, lparam, state_ptr), - WM_SYSKEYUP => handle_syskeyup_msg(handle, wparam, lparam, state_ptr), - WM_SYSCOMMAND => handle_system_command(wparam, state_ptr), - WM_KEYDOWN => handle_keydown_msg(handle, wparam, lparam, state_ptr), - WM_KEYUP => handle_keyup_msg(handle, wparam, lparam, state_ptr), - WM_CHAR => handle_char_msg(wparam, state_ptr), - WM_DEADCHAR => handle_dead_char_msg(wparam, state_ptr), - WM_IME_STARTCOMPOSITION => handle_ime_position(handle, state_ptr), - WM_IME_COMPOSITION => handle_ime_composition(handle, lparam, state_ptr), - WM_SETCURSOR => handle_set_cursor(handle, lparam, state_ptr), - WM_SETTINGCHANGE => handle_system_settings_changed(handle, wparam, lparam, state_ptr), - WM_INPUTLANGCHANGE => handle_input_language_changed(lparam, state_ptr), - WM_GPUI_CURSOR_STYLE_CHANGED => handle_cursor_changed(lparam, state_ptr), - _ => None, - }; - if let Some(n) = handled { - LRESULT(n) - } else { - unsafe { DefWindowProcW(handle, msg, wparam, lparam) } - } -} - -fn handle_move_msg( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let origin = logical_point( - lparam.signed_loword() as f32, - lparam.signed_hiword() as f32, - lock.scale_factor, - ); - lock.origin = origin; - let size = lock.logical_size; - let center_x = origin.x.0 + size.width.0 / 2.; - let center_y = origin.y.0 + size.height.0 / 2.; - let monitor_bounds = lock.display.bounds(); - if center_x < monitor_bounds.left().0 - || center_x > monitor_bounds.right().0 - || center_y < monitor_bounds.top().0 - || center_y > monitor_bounds.bottom().0 - { - // center of the window may have moved to another monitor - let monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; - // minimize the window can trigger this event too, in this case, - // monitor is invalid, we do nothing. - if !monitor.is_invalid() && lock.display.handle != monitor { - // we will get the same monitor if we only have one - lock.display = WindowsDisplay::new_with_handle(monitor); - } - } - if let Some(mut callback) = lock.callbacks.moved.take() { - drop(lock); - callback(); - state_ptr.state.borrow_mut().callbacks.moved = Some(callback); - } - Some(0) -} - -fn handle_get_min_max_info_msg( - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let lock = state_ptr.state.borrow(); - let min_size = lock.min_size?; - let scale_factor = lock.scale_factor; - let boarder_offset = lock.border_offset; - drop(lock); - unsafe { - let minmax_info = &mut *(lparam.0 as *mut MINMAXINFO); - minmax_info.ptMinTrackSize.x = - min_size.width.scale(scale_factor).0 as i32 + boarder_offset.width_offset; - minmax_info.ptMinTrackSize.y = - min_size.height.scale(scale_factor).0 as i32 + boarder_offset.height_offset; - } - Some(0) -} - -fn handle_size_msg( - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - - // Don't resize the renderer when the window is minimized, but record that it was minimized so - // that on restore the swap chain can be recreated via `update_drawable_size_even_if_unchanged`. - if wparam.0 == SIZE_MINIMIZED as usize { - lock.restore_from_minimized = lock.callbacks.request_frame.take(); - return Some(0); } - let width = lparam.loword().max(1) as i32; - let height = lparam.hiword().max(1) as i32; - let new_size = size(DevicePixels(width), DevicePixels(height)); - let scale_factor = lock.scale_factor; - if lock.restore_from_minimized.is_some() { - lock.renderer - .update_drawable_size_even_if_unchanged(new_size); - lock.callbacks.request_frame = lock.restore_from_minimized.take(); - } else { - lock.renderer.update_drawable_size(new_size); - } - let new_size = new_size.to_pixels(scale_factor); - lock.logical_size = new_size; - if let Some(mut callback) = lock.callbacks.resize.take() { - drop(lock); - callback(new_size, scale_factor); - state_ptr.state.borrow_mut().callbacks.resize = Some(callback); - } - Some(0) -} - -fn handle_size_move_loop(handle: HWND) -> Option<isize> { - unsafe { - let ret = SetTimer( - Some(handle), - SIZE_MOVE_LOOP_TIMER_ID, - USER_TIMER_MINIMUM, - None, + fn handle_move_msg(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let origin = logical_point( + lparam.signed_loword() as f32, + lparam.signed_hiword() as f32, + lock.scale_factor, ); - if ret == 0 { - log::error!( - "unable to create timer: {}", - std::io::Error::last_os_error() - ); - } - } - None -} - -fn handle_size_move_loop_exit(handle: HWND) -> Option<isize> { - unsafe { - KillTimer(Some(handle), SIZE_MOVE_LOOP_TIMER_ID).log_err(); - } - None -} - -fn handle_timer_msg( - handle: HWND, - wparam: WPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID { - for runnable in state_ptr.main_receiver.drain() { - runnable.run(); - } - handle_paint_msg(handle, state_ptr) - } else { - None - } -} - -fn handle_paint_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut request_frame) = lock.callbacks.request_frame.take() { - drop(lock); - request_frame(Default::default()); - state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame); - } - unsafe { ValidateRect(Some(handle), None).ok().log_err() }; - Some(0) -} - -fn handle_close_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let output = if let Some(mut callback) = lock.callbacks.should_close.take() { - drop(lock); - let should_close = callback(); - state_ptr.state.borrow_mut().callbacks.should_close = Some(callback); - if should_close { None } else { Some(0) } - } else { - None - }; - - // Workaround as window close animation is not played with `WS_EX_LAYERED` enabled. - if output.is_none() { - unsafe { - let current_style = get_window_long(handle, GWL_EXSTYLE); - set_window_long( - handle, - GWL_EXSTYLE, - current_style & !WS_EX_LAYERED.0 as isize, - ); - } - } - - output -} - -fn handle_destroy_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let callback = { - let mut lock = state_ptr.state.borrow_mut(); - lock.callbacks.close.take() - }; - if let Some(callback) = callback { - callback(); - } - unsafe { - PostThreadMessageW( - state_ptr.main_thread_id_win32, - WM_GPUI_CLOSE_ONE_WINDOW, - WPARAM(state_ptr.validation_number), - LPARAM(handle.0 as isize), - ) - .log_err(); - } - Some(0) -} - -fn handle_mouse_move_msg( - handle: HWND, - lparam: LPARAM, - wparam: WPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - start_tracking_mouse(handle, &state_ptr, TME_LEAVE); - - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - drop(lock); - - let pressed_button = match MODIFIERKEYS_FLAGS(wparam.loword() as u32) { - flags if flags.contains(MK_LBUTTON) => Some(MouseButton::Left), - flags if flags.contains(MK_RBUTTON) => Some(MouseButton::Right), - flags if flags.contains(MK_MBUTTON) => Some(MouseButton::Middle), - flags if flags.contains(MK_XBUTTON1) => { - Some(MouseButton::Navigate(NavigationDirection::Back)) - } - flags if flags.contains(MK_XBUTTON2) => { - Some(MouseButton::Navigate(NavigationDirection::Forward)) - } - _ => None, - }; - let x = lparam.signed_loword() as f32; - let y = lparam.signed_hiword() as f32; - let input = PlatformInput::MouseMove(MouseMoveEvent { - position: logical_point(x, y, scale_factor), - pressed_button, - modifiers: current_modifiers(), - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_mouse_leave_msg(state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - lock.hovered = false; - if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { - drop(lock); - callback(false); - state_ptr.state.borrow_mut().callbacks.hovered_status_change = Some(callback); - } - - Some(0) -} - -fn handle_syskeydown_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyDown(KeyDownEvent { - keystroke, - is_held: lparam.0 & (0x1 << 30) > 0, - }) - })?; - let mut func = lock.callbacks.input.take()?; - drop(lock); - - let handled = !func(input).propagate; - - let mut lock = state_ptr.state.borrow_mut(); - lock.callbacks.input = Some(func); - - if handled { - lock.system_key_handled = true; - Some(0) - } else { - // we need to call `DefWindowProcW`, or we will lose the system-wide `Alt+F4`, `Alt+{other keys}` - // shortcuts. - None - } -} - -fn handle_syskeyup_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyUp(KeyUpEvent { keystroke }) - })?; - let mut func = lock.callbacks.input.take()?; - drop(lock); - func(input); - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - // Always return 0 to indicate that the message was handled, so we could properly handle `ModifiersChanged` event. - Some(0) -} - -// It's a known bug that you can't trigger `ctrl-shift-0`. See: -// https://superuser.com/questions/1455762/ctrl-shift-number-key-combination-has-stopped-working-for-a-few-numbers -fn handle_keydown_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyDown(KeyDownEvent { - keystroke, - is_held: lparam.0 & (0x1 << 30) > 0, - }) - }) else { - return Some(1); - }; - drop(lock); - - let is_composing = with_input_handler(&state_ptr, |input_handler| { - input_handler.marked_text_range() - }) - .flatten() - .is_some(); - if is_composing { - translate_message(handle, wparam, lparam); - return Some(0); - } - - let Some(mut func) = state_ptr.state.borrow_mut().callbacks.input.take() else { - return Some(1); - }; - - let handled = !func(input).propagate; - - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { - Some(0) - } else { - translate_message(handle, wparam, lparam); - Some(1) - } -} - -fn handle_keyup_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyUp(KeyUpEvent { keystroke }) - }) else { - return Some(1); - }; - - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - drop(lock); - - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_char_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let input = parse_char_message(wparam, &state_ptr)?; - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, &input); - }); - - Some(0) -} - -fn handle_dead_char_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let ch = char::from_u32(wparam.0 as u32)?.to_string(); - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_and_mark_text_in_range(None, &ch, None); - }); - None -} - -fn handle_mouse_down_msg( - handle: HWND, - button: MouseButton, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - unsafe { SetCapture(handle) }; - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let x = lparam.signed_loword(); - let y = lparam.signed_hiword(); - let physical_point = point(DevicePixels(x as i32), DevicePixels(y as i32)); - let click_count = lock.click_state.update(button, physical_point); - let scale_factor = lock.scale_factor; - drop(lock); - - let input = PlatformInput::MouseDown(MouseDownEvent { - button, - position: logical_point(x as f32, y as f32, scale_factor), - modifiers: current_modifiers(), - click_count, - first_mouse: false, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_mouse_up_msg( - _handle: HWND, - button: MouseButton, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - unsafe { ReleaseCapture().log_err() }; - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let x = lparam.signed_loword() as f32; - let y = lparam.signed_hiword() as f32; - let click_count = lock.click_state.current_count; - let scale_factor = lock.scale_factor; - drop(lock); - - let input = PlatformInput::MouseUp(MouseUpEvent { - button, - position: logical_point(x, y, scale_factor), - modifiers: current_modifiers(), - click_count, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_xbutton_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - handler: impl Fn(HWND, MouseButton, LPARAM, Rc<WindowsWindowStatePtr>) -> Option<isize>, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let nav_dir = match wparam.hiword() { - XBUTTON1 => NavigationDirection::Back, - XBUTTON2 => NavigationDirection::Forward, - _ => return Some(1), - }; - handler(handle, MouseButton::Navigate(nav_dir), lparam, state_ptr) -} - -fn handle_mouse_wheel_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let modifiers = current_modifiers(); - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - let wheel_scroll_amount = match modifiers.shift { - true => lock.system_settings.mouse_wheel_settings.wheel_scroll_chars, - false => lock.system_settings.mouse_wheel_settings.wheel_scroll_lines, - }; - drop(lock); - - let wheel_distance = - (wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_amount as f32; - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::ScrollWheel(ScrollWheelEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - delta: ScrollDelta::Lines(match modifiers.shift { - true => Point { - x: wheel_distance, - y: 0.0, - }, - false => Point { - y: wheel_distance, - x: 0.0, - }, - }), - modifiers, - touch_phase: TouchPhase::Moved, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_mouse_horizontal_wheel_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - let wheel_scroll_chars = lock.system_settings.mouse_wheel_settings.wheel_scroll_chars; - drop(lock); - - let wheel_distance = - (-wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_chars as f32; - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let event = PlatformInput::ScrollWheel(ScrollWheelEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - delta: ScrollDelta::Lines(Point { - x: wheel_distance, - y: 0.0, - }), - modifiers: current_modifiers(), - touch_phase: TouchPhase::Moved, - }); - let handled = !func(event).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn retrieve_caret_position(state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<POINT> { - with_input_handler_and_scale_factor(state_ptr, |input_handler, scale_factor| { - let caret_range = input_handler.selected_text_range(false)?; - let caret_position = input_handler.bounds_for_range(caret_range.range)?; - Some(POINT { - // logical to physical - x: (caret_position.origin.x.0 * scale_factor) as i32, - y: (caret_position.origin.y.0 * scale_factor) as i32 - + ((caret_position.size.height.0 * scale_factor) as i32 / 2), - }) - }) -} - -fn handle_ime_position(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - unsafe { - let ctx = ImmGetContext(handle); - - let Some(caret_position) = retrieve_caret_position(&state_ptr) else { - return Some(0); - }; + lock.origin = origin; + let size = lock.logical_size; + let center_x = origin.x.0 + size.width.0 / 2.; + let center_y = origin.y.0 + size.height.0 / 2.; + let monitor_bounds = lock.display.bounds(); + if center_x < monitor_bounds.left().0 + || center_x > monitor_bounds.right().0 + || center_y < monitor_bounds.top().0 + || center_y > monitor_bounds.bottom().0 { - let config = COMPOSITIONFORM { - dwStyle: CFS_POINT, - ptCurrentPos: caret_position, - ..Default::default() - }; - ImmSetCompositionWindow(ctx, &config as _).ok().log_err(); - } - { - let config = CANDIDATEFORM { - dwStyle: CFS_CANDIDATEPOS, - ptCurrentPos: caret_position, - ..Default::default() - }; - ImmSetCandidateWindow(ctx, &config as _).ok().log_err(); - } - ImmReleaseContext(handle, ctx).ok().log_err(); - Some(0) - } -} - -fn handle_ime_composition( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let ctx = unsafe { ImmGetContext(handle) }; - let result = handle_ime_composition_inner(ctx, lparam, state_ptr); - unsafe { ImmReleaseContext(handle, ctx).ok().log_err() }; - result -} - -fn handle_ime_composition_inner( - ctx: HIMC, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let lparam = lparam.0 as u32; - if lparam == 0 { - // Japanese IME may send this message with lparam = 0, which indicates that - // there is no composition string. - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, ""); - })?; - Some(0) - } else { - if lparam & GCS_COMPSTR.0 > 0 { - let comp_string = parse_ime_composition_string(ctx, GCS_COMPSTR)?; - let caret_pos = (!comp_string.is_empty() && lparam & GCS_CURSORPOS.0 > 0).then(|| { - let pos = retrieve_composition_cursor_position(ctx); - pos..pos - }); - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_and_mark_text_in_range(None, &comp_string, caret_pos); - })?; - } - if lparam & GCS_RESULTSTR.0 > 0 { - let comp_result = parse_ime_composition_string(ctx, GCS_RESULTSTR)?; - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, &comp_result); - })?; - return Some(0); - } - - // currently, we don't care other stuff - None - } -} - -/// SEE: https://learn.microsoft.com/en-us/windows/win32/winmsg/wm-nccalcsize -fn handle_calc_client_size( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if !state_ptr.hide_title_bar || state_ptr.state.borrow().is_fullscreen() || wparam.0 == 0 { - return None; - } - - let is_maximized = state_ptr.state.borrow().is_maximized(); - let insets = get_client_area_insets(handle, is_maximized, state_ptr.windows_version); - // wparam is TRUE so lparam points to an NCCALCSIZE_PARAMS structure - let mut params = lparam.0 as *mut NCCALCSIZE_PARAMS; - let mut requested_client_rect = unsafe { &mut ((*params).rgrc) }; - - requested_client_rect[0].left += insets.left; - requested_client_rect[0].top += insets.top; - requested_client_rect[0].right -= insets.right; - requested_client_rect[0].bottom -= insets.bottom; - - // Fix auto hide taskbar not showing. This solution is based on the approach - // used by Chrome. However, it may result in one row of pixels being obscured - // in our client area. But as Chrome says, "there seems to be no better solution." - if is_maximized { - if let Some(ref taskbar_position) = state_ptr - .state - .borrow() - .system_settings - .auto_hide_taskbar_position - { - // Fot the auto-hide taskbar, adjust in by 1 pixel on taskbar edge, - // so the window isn't treated as a "fullscreen app", which would cause - // the taskbar to disappear. - match taskbar_position { - AutoHideTaskbarPosition::Left => { - requested_client_rect[0].left += AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Top => { - requested_client_rect[0].top += AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Right => { - requested_client_rect[0].right -= AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Bottom => { - requested_client_rect[0].bottom -= AUTO_HIDE_TASKBAR_THICKNESS_PX - } + // center of the window may have moved to another monitor + let monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; + // minimize the window can trigger this event too, in this case, + // monitor is invalid, we do nothing. + if !monitor.is_invalid() && lock.display.handle != monitor { + // we will get the same monitor if we only have one + lock.display = WindowsDisplay::new_with_handle(monitor); } } - } - - Some(0) -} - -fn handle_activate_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let activated = wparam.loword() > 0; - let this = state_ptr.clone(); - state_ptr - .executor - .spawn(async move { - let mut lock = this.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.active_status_change.take() { - drop(lock); - func(activated); - this.state.borrow_mut().callbacks.active_status_change = Some(func); - } - }) - .detach(); - - None -} - -fn handle_create_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - if state_ptr.hide_title_bar { - notify_frame_changed(handle); - Some(0) - } else { - None - } -} - -fn handle_dpi_changed_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let new_dpi = wparam.loword() as f32; - let mut lock = state_ptr.state.borrow_mut(); - lock.scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32; - lock.border_offset.update(handle).log_err(); - drop(lock); - - let rect = unsafe { &*(lparam.0 as *const RECT) }; - let width = rect.right - rect.left; - let height = rect.bottom - rect.top; - // this will emit `WM_SIZE` and `WM_MOVE` right here - // even before this function returns - // the new size is handled in `WM_SIZE` - unsafe { - SetWindowPos( - handle, - None, - rect.left, - rect.top, - width, - height, - SWP_NOZORDER | SWP_NOACTIVATE, - ) - .context("unable to set window position after dpi has changed") - .log_err(); - } - - Some(0) -} - -/// The following conditions will trigger this event: -/// 1. The monitor on which the window is located goes offline or changes resolution. -/// 2. Another monitor goes offline, is plugged in, or changes resolution. -/// -/// In either case, the window will only receive information from the monitor on which -/// it is located. -/// -/// For example, in the case of condition 2, where the monitor on which the window is -/// located has actually changed nothing, it will still receive this event. -fn handle_display_change_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - // NOTE: - // Even the `lParam` holds the resolution of the screen, we just ignore it. - // Because WM_DPICHANGED, WM_MOVE, WM_SIZE will come first, window reposition and resize - // are handled there. - // So we only care about if monitor is disconnected. - let previous_monitor = state_ptr.state.borrow().display; - if WindowsDisplay::is_connected(previous_monitor.handle) { - // we are fine, other display changed - return None; - } - // display disconnected - // in this case, the OS will move our window to another monitor, and minimize it. - // we deminimize the window and query the monitor after moving - unsafe { - let _ = ShowWindow(handle, SW_SHOWNORMAL); - }; - let new_monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; - // all monitors disconnected - if new_monitor.is_invalid() { - log::error!("No monitor detected!"); - return None; - } - let new_display = WindowsDisplay::new_with_handle(new_monitor); - state_ptr.state.borrow_mut().display = new_display; - Some(0) -} - -fn handle_hit_test_msg( - handle: HWND, - msg: u32, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if !state_ptr.is_movable || state_ptr.state.borrow().is_fullscreen() { - return None; - } - - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut callback) = lock.callbacks.hit_test_window_control.take() { - drop(lock); - let area = callback(); - state_ptr - .state - .borrow_mut() - .callbacks - .hit_test_window_control = Some(callback); - if let Some(area) = area { - return match area { - WindowControlArea::Drag => Some(HTCAPTION as _), - WindowControlArea::Close => Some(HTCLOSE as _), - WindowControlArea::Max => Some(HTMAXBUTTON as _), - WindowControlArea::Min => Some(HTMINBUTTON as _), - }; + if let Some(mut callback) = lock.callbacks.moved.take() { + drop(lock); + callback(); + self.state.borrow_mut().callbacks.moved = Some(callback); } - } else { - drop(lock); + Some(0) } - if !state_ptr.hide_title_bar { - // If the OS draws the title bar, we don't need to handle hit test messages. - return None; - } - - // default handler for resize areas - let hit = unsafe { DefWindowProcW(handle, msg, wparam, lparam) }; - if matches!( - hit.0 as u32, - HTNOWHERE - | HTRIGHT - | HTLEFT - | HTTOPLEFT - | HTTOP - | HTTOPRIGHT - | HTBOTTOMRIGHT - | HTBOTTOM - | HTBOTTOMLEFT - ) { - return Some(hit.0); - } - - if state_ptr.state.borrow().is_fullscreen() { - return Some(HTCLIENT as _); - } - - let dpi = unsafe { GetDpiForWindow(handle) }; - let frame_y = unsafe { GetSystemMetricsForDpi(SM_CYFRAME, dpi) }; - - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - if !state_ptr.state.borrow().is_maximized() && cursor_point.y >= 0 && cursor_point.y <= frame_y - { - return Some(HTTOP as _); - } - - Some(HTCLIENT as _) -} - -fn handle_nc_mouse_move_msg( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - start_tracking_mouse(handle, &state_ptr, TME_LEAVE | TME_NONCLIENT); - - let mut lock = state_ptr.state.borrow_mut(); - let mut func = lock.callbacks.input.take()?; - let scale_factor = lock.scale_factor; - drop(lock); - - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::MouseMove(MouseMoveEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - pressed_button: None, - modifiers: current_modifiers(), - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { None } -} - -fn handle_nc_mouse_down_msg( - handle: HWND, - button: MouseButton, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.input.take() { + fn handle_get_min_max_info_msg(&self, lparam: LPARAM) -> Option<isize> { + let lock = self.state.borrow(); + let min_size = lock.min_size?; let scale_factor = lock.scale_factor; - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), + let boarder_offset = lock.border_offset; + drop(lock); + unsafe { + let minmax_info = &mut *(lparam.0 as *mut MINMAXINFO); + minmax_info.ptMinTrackSize.x = + min_size.width.scale(scale_factor).0 as i32 + boarder_offset.width_offset; + minmax_info.ptMinTrackSize.y = + min_size.height.scale(scale_factor).0 as i32 + boarder_offset.height_offset; + } + Some(0) + } + + fn handle_size_msg(&self, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + + // Don't resize the renderer when the window is minimized, but record that it was minimized so + // that on restore the swap chain can be recreated via `update_drawable_size_even_if_unchanged`. + if wparam.0 == SIZE_MINIMIZED as usize { + lock.restore_from_minimized = lock.callbacks.request_frame.take(); + return Some(0); + } + + let width = lparam.loword().max(1) as i32; + let height = lparam.hiword().max(1) as i32; + let new_size = size(DevicePixels(width), DevicePixels(height)); + let scale_factor = lock.scale_factor; + if lock.restore_from_minimized.is_some() { + lock.callbacks.request_frame = lock.restore_from_minimized.take(); + } else { + lock.renderer.resize(new_size).log_err(); + } + let new_size = new_size.to_pixels(scale_factor); + lock.logical_size = new_size; + if let Some(mut callback) = lock.callbacks.resize.take() { + drop(lock); + callback(new_size, scale_factor); + self.state.borrow_mut().callbacks.resize = Some(callback); + } + Some(0) + } + + fn handle_size_move_loop(&self, handle: HWND) -> Option<isize> { + unsafe { + let ret = SetTimer( + Some(handle), + SIZE_MOVE_LOOP_TIMER_ID, + USER_TIMER_MINIMUM, + None, + ); + if ret == 0 { + log::error!( + "unable to create timer: {}", + std::io::Error::last_os_error() + ); + } + } + None + } + + fn handle_size_move_loop_exit(&self, handle: HWND) -> Option<isize> { + unsafe { + KillTimer(Some(handle), SIZE_MOVE_LOOP_TIMER_ID).log_err(); + } + None + } + + fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> { + if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID { + for runnable in self.main_receiver.drain() { + runnable.run(); + } + self.handle_paint_msg(handle) + } else { + None + } + } + + fn handle_paint_msg(&self, handle: HWND) -> Option<isize> { + self.draw_window(handle, false) + } + + fn handle_close_msg(&self) -> Option<isize> { + let mut callback = self.state.borrow_mut().callbacks.should_close.take()?; + let should_close = callback(); + self.state.borrow_mut().callbacks.should_close = Some(callback); + if should_close { None } else { Some(0) } + } + + fn handle_destroy_msg(&self, handle: HWND) -> Option<isize> { + let callback = { + let mut lock = self.state.borrow_mut(); + lock.callbacks.close.take() }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let physical_point = point(DevicePixels(cursor_point.x), DevicePixels(cursor_point.y)); + if let Some(callback) = callback { + callback(); + } + unsafe { + PostThreadMessageW( + self.main_thread_id_win32, + WM_GPUI_CLOSE_ONE_WINDOW, + WPARAM(self.validation_number), + LPARAM(handle.0 as isize), + ) + .log_err(); + } + Some(0) + } + + fn handle_mouse_move_msg(&self, handle: HWND, lparam: LPARAM, wparam: WPARAM) -> Option<isize> { + self.start_tracking_mouse(handle, TME_LEAVE); + + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + drop(lock); + + let pressed_button = match MODIFIERKEYS_FLAGS(wparam.loword() as u32) { + flags if flags.contains(MK_LBUTTON) => Some(MouseButton::Left), + flags if flags.contains(MK_RBUTTON) => Some(MouseButton::Right), + flags if flags.contains(MK_MBUTTON) => Some(MouseButton::Middle), + flags if flags.contains(MK_XBUTTON1) => { + Some(MouseButton::Navigate(NavigationDirection::Back)) + } + flags if flags.contains(MK_XBUTTON2) => { + Some(MouseButton::Navigate(NavigationDirection::Forward)) + } + _ => None, + }; + let x = lparam.signed_loword() as f32; + let y = lparam.signed_hiword() as f32; + let input = PlatformInput::MouseMove(MouseMoveEvent { + position: logical_point(x, y, scale_factor), + pressed_button, + modifiers: current_modifiers(), + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_mouse_leave_msg(&self) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + lock.hovered = false; + if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { + drop(lock); + callback(false); + self.state.borrow_mut().callbacks.hovered_status_change = Some(callback); + } + + Some(0) + } + + fn handle_syskeydown_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyDown(KeyDownEvent { + keystroke, + is_held: lparam.0 & (0x1 << 30) > 0, + }) + })?; + let mut func = lock.callbacks.input.take()?; + drop(lock); + + let handled = !func(input).propagate; + + let mut lock = self.state.borrow_mut(); + lock.callbacks.input = Some(func); + + if handled { + lock.system_key_handled = true; + Some(0) + } else { + // we need to call `DefWindowProcW`, or we will lose the system-wide `Alt+F4`, `Alt+{other keys}` + // shortcuts. + None + } + } + + fn handle_syskeyup_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyUp(KeyUpEvent { keystroke }) + })?; + let mut func = lock.callbacks.input.take()?; + drop(lock); + func(input); + self.state.borrow_mut().callbacks.input = Some(func); + + // Always return 0 to indicate that the message was handled, so we could properly handle `ModifiersChanged` event. + Some(0) + } + + // It's a known bug that you can't trigger `ctrl-shift-0`. See: + // https://superuser.com/questions/1455762/ctrl-shift-number-key-combination-has-stopped-working-for-a-few-numbers + fn handle_keydown_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyDown(KeyDownEvent { + keystroke, + is_held: lparam.0 & (0x1 << 30) > 0, + }) + }) else { + return Some(1); + }; + drop(lock); + + let is_composing = self + .with_input_handler(|input_handler| input_handler.marked_text_range()) + .flatten() + .is_some(); + if is_composing { + translate_message(handle, wparam, lparam); + return Some(0); + } + + let Some(mut func) = self.state.borrow_mut().callbacks.input.take() else { + return Some(1); + }; + + let handled = !func(input).propagate; + + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + Some(0) + } else { + translate_message(handle, wparam, lparam); + Some(1) + } + } + + fn handle_keyup_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyUp(KeyUpEvent { keystroke }) + }) else { + return Some(1); + }; + + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + drop(lock); + + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_char_msg(&self, wparam: WPARAM) -> Option<isize> { + let input = self.parse_char_message(wparam)?; + self.with_input_handler(|input_handler| { + input_handler.replace_text_in_range(None, &input); + }); + + Some(0) + } + + fn handle_dead_char_msg(&self, wparam: WPARAM) -> Option<isize> { + let ch = char::from_u32(wparam.0 as u32)?.to_string(); + self.with_input_handler(|input_handler| { + input_handler.replace_and_mark_text_in_range(None, &ch, None); + }); + None + } + + fn handle_mouse_down_msg( + &self, + handle: HWND, + button: MouseButton, + lparam: LPARAM, + ) -> Option<isize> { + unsafe { SetCapture(handle) }; + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let x = lparam.signed_loword(); + let y = lparam.signed_hiword(); + let physical_point = point(DevicePixels(x as i32), DevicePixels(y as i32)); let click_count = lock.click_state.update(button, physical_point); + let scale_factor = lock.scale_factor; drop(lock); let input = PlatformInput::MouseDown(MouseDownEvent { button, - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + position: logical_point(x as f32, y as f32, scale_factor), modifiers: current_modifiers(), click_count, first_mouse: false, }); - let result = func(input.clone()); - let handled = !result.propagate || result.default_prevented; - state_ptr.state.borrow_mut().callbacks.input = Some(func); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); - if handled { - return Some(0); - } - } else { - drop(lock); - }; + if handled { Some(0) } else { Some(1) } + } - // Since these are handled in handle_nc_mouse_up_msg we must prevent the default window proc - if button == MouseButton::Left { - match wparam.0 as u32 { - HTMINBUTTON => state_ptr.state.borrow_mut().nc_button_pressed = Some(HTMINBUTTON), - HTMAXBUTTON => state_ptr.state.borrow_mut().nc_button_pressed = Some(HTMAXBUTTON), - HTCLOSE => state_ptr.state.borrow_mut().nc_button_pressed = Some(HTCLOSE), - _ => return None, + fn handle_mouse_up_msg( + &self, + _handle: HWND, + button: MouseButton, + lparam: LPARAM, + ) -> Option<isize> { + unsafe { ReleaseCapture().log_err() }; + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); }; + let x = lparam.signed_loword() as f32; + let y = lparam.signed_hiword() as f32; + let click_count = lock.click_state.current_count; + let scale_factor = lock.scale_factor; + drop(lock); + + let input = PlatformInput::MouseUp(MouseUpEvent { + button, + position: logical_point(x, y, scale_factor), + modifiers: current_modifiers(), + click_count, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_xbutton_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + handler: impl Fn(&Self, HWND, MouseButton, LPARAM) -> Option<isize>, + ) -> Option<isize> { + let nav_dir = match wparam.hiword() { + XBUTTON1 => NavigationDirection::Back, + XBUTTON2 => NavigationDirection::Forward, + _ => return Some(1), + }; + handler(self, handle, MouseButton::Navigate(nav_dir), lparam) + } + + fn handle_mouse_wheel_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let modifiers = current_modifiers(); + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + let wheel_scroll_amount = match modifiers.shift { + true => lock.system_settings.mouse_wheel_settings.wheel_scroll_chars, + false => lock.system_settings.mouse_wheel_settings.wheel_scroll_lines, + }; + drop(lock); + + let wheel_distance = + (wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_amount as f32; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let input = PlatformInput::ScrollWheel(ScrollWheelEvent { + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + delta: ScrollDelta::Lines(match modifiers.shift { + true => Point { + x: wheel_distance, + y: 0.0, + }, + false => Point { + y: wheel_distance, + x: 0.0, + }, + }), + modifiers, + touch_phase: TouchPhase::Moved, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_mouse_horizontal_wheel_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + let wheel_scroll_chars = lock.system_settings.mouse_wheel_settings.wheel_scroll_chars; + drop(lock); + + let wheel_distance = + (-wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_chars as f32; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let event = PlatformInput::ScrollWheel(ScrollWheelEvent { + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + delta: ScrollDelta::Lines(Point { + x: wheel_distance, + y: 0.0, + }), + modifiers: current_modifiers(), + touch_phase: TouchPhase::Moved, + }); + let handled = !func(event).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn retrieve_caret_position(&self) -> Option<POINT> { + self.with_input_handler_and_scale_factor(|input_handler, scale_factor| { + let caret_range = input_handler.selected_text_range(false)?; + let caret_position = input_handler.bounds_for_range(caret_range.range)?; + Some(POINT { + // logical to physical + x: (caret_position.origin.x.0 * scale_factor) as i32, + y: (caret_position.origin.y.0 * scale_factor) as i32 + + ((caret_position.size.height.0 * scale_factor) as i32 / 2), + }) + }) + } + + fn handle_ime_position(&self, handle: HWND) -> Option<isize> { + unsafe { + let ctx = ImmGetContext(handle); + + let Some(caret_position) = self.retrieve_caret_position() else { + return Some(0); + }; + { + let config = COMPOSITIONFORM { + dwStyle: CFS_POINT, + ptCurrentPos: caret_position, + ..Default::default() + }; + ImmSetCompositionWindow(ctx, &config as _).ok().log_err(); + } + { + let config = CANDIDATEFORM { + dwStyle: CFS_CANDIDATEPOS, + ptCurrentPos: caret_position, + ..Default::default() + }; + ImmSetCandidateWindow(ctx, &config as _).ok().log_err(); + } + ImmReleaseContext(handle, ctx).ok().log_err(); + Some(0) + } + } + + fn handle_ime_composition(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + let ctx = unsafe { ImmGetContext(handle) }; + let result = self.handle_ime_composition_inner(ctx, lparam); + unsafe { ImmReleaseContext(handle, ctx).ok().log_err() }; + result + } + + fn handle_ime_composition_inner(&self, ctx: HIMC, lparam: LPARAM) -> Option<isize> { + let lparam = lparam.0 as u32; + if lparam == 0 { + // Japanese IME may send this message with lparam = 0, which indicates that + // there is no composition string. + self.with_input_handler(|input_handler| { + input_handler.replace_text_in_range(None, ""); + })?; + Some(0) + } else { + if lparam & GCS_COMPSTR.0 > 0 { + let comp_string = parse_ime_composition_string(ctx, GCS_COMPSTR)?; + let caret_pos = + (!comp_string.is_empty() && lparam & GCS_CURSORPOS.0 > 0).then(|| { + let pos = retrieve_composition_cursor_position(ctx); + pos..pos + }); + self.with_input_handler(|input_handler| { + input_handler.replace_and_mark_text_in_range(None, &comp_string, caret_pos); + })?; + } + if lparam & GCS_RESULTSTR.0 > 0 { + let comp_result = parse_ime_composition_string(ctx, GCS_RESULTSTR)?; + self.with_input_handler(|input_handler| { + input_handler.replace_text_in_range(None, &comp_result); + })?; + return Some(0); + } + + // currently, we don't care other stuff + None + } + } + + /// SEE: https://learn.microsoft.com/en-us/windows/win32/winmsg/wm-nccalcsize + fn handle_calc_client_size( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if !self.hide_title_bar || self.state.borrow().is_fullscreen() || wparam.0 == 0 { + return None; + } + + let is_maximized = self.state.borrow().is_maximized(); + let insets = get_client_area_insets(handle, is_maximized, self.windows_version); + // wparam is TRUE so lparam points to an NCCALCSIZE_PARAMS structure + let mut params = lparam.0 as *mut NCCALCSIZE_PARAMS; + let mut requested_client_rect = unsafe { &mut ((*params).rgrc) }; + + requested_client_rect[0].left += insets.left; + requested_client_rect[0].top += insets.top; + requested_client_rect[0].right -= insets.right; + requested_client_rect[0].bottom -= insets.bottom; + + // Fix auto hide taskbar not showing. This solution is based on the approach + // used by Chrome. However, it may result in one row of pixels being obscured + // in our client area. But as Chrome says, "there seems to be no better solution." + if is_maximized { + if let Some(ref taskbar_position) = self + .state + .borrow() + .system_settings + .auto_hide_taskbar_position + { + // Fot the auto-hide taskbar, adjust in by 1 pixel on taskbar edge, + // so the window isn't treated as a "fullscreen app", which would cause + // the taskbar to disappear. + match taskbar_position { + AutoHideTaskbarPosition::Left => { + requested_client_rect[0].left += AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Top => { + requested_client_rect[0].top += AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Right => { + requested_client_rect[0].right -= AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Bottom => { + requested_client_rect[0].bottom -= AUTO_HIDE_TASKBAR_THICKNESS_PX + } + } + } + } + Some(0) - } else { + } + + fn handle_activate_msg(self: &Rc<Self>, wparam: WPARAM) -> Option<isize> { + let activated = wparam.loword() > 0; + let this = self.clone(); + self.executor + .spawn(async move { + let mut lock = this.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.active_status_change.take() { + drop(lock); + func(activated); + this.state.borrow_mut().callbacks.active_status_change = Some(func); + } + }) + .detach(); + None } -} -fn handle_nc_mouse_up_msg( - handle: HWND, - button: MouseButton, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.input.take() { + fn handle_create_msg(&self, handle: HWND) -> Option<isize> { + if self.hide_title_bar { + notify_frame_changed(handle); + Some(0) + } else { + None + } + } + + fn handle_dpi_changed_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let new_dpi = wparam.loword() as f32; + let mut lock = self.state.borrow_mut(); + lock.scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32; + lock.border_offset.update(handle).log_err(); + drop(lock); + + let rect = unsafe { &*(lparam.0 as *const RECT) }; + let width = rect.right - rect.left; + let height = rect.bottom - rect.top; + // this will emit `WM_SIZE` and `WM_MOVE` right here + // even before this function returns + // the new size is handled in `WM_SIZE` + unsafe { + SetWindowPos( + handle, + None, + rect.left, + rect.top, + width, + height, + SWP_NOZORDER | SWP_NOACTIVATE, + ) + .context("unable to set window position after dpi has changed") + .log_err(); + } + + Some(0) + } + + /// The following conditions will trigger this event: + /// 1. The monitor on which the window is located goes offline or changes resolution. + /// 2. Another monitor goes offline, is plugged in, or changes resolution. + /// + /// In either case, the window will only receive information from the monitor on which + /// it is located. + /// + /// For example, in the case of condition 2, where the monitor on which the window is + /// located has actually changed nothing, it will still receive this event. + fn handle_display_change_msg(&self, handle: HWND) -> Option<isize> { + // NOTE: + // Even the `lParam` holds the resolution of the screen, we just ignore it. + // Because WM_DPICHANGED, WM_MOVE, WM_SIZE will come first, window reposition and resize + // are handled there. + // So we only care about if monitor is disconnected. + let previous_monitor = self.state.borrow().display; + if WindowsDisplay::is_connected(previous_monitor.handle) { + // we are fine, other display changed + return None; + } + // display disconnected + // in this case, the OS will move our window to another monitor, and minimize it. + // we deminimize the window and query the monitor after moving + unsafe { + let _ = ShowWindow(handle, SW_SHOWNORMAL); + }; + let new_monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; + // all monitors disconnected + if new_monitor.is_invalid() { + log::error!("No monitor detected!"); + return None; + } + let new_display = WindowsDisplay::new_with_handle(new_monitor); + self.state.borrow_mut().display = new_display; + Some(0) + } + + fn handle_hit_test_msg( + &self, + handle: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if !self.is_movable || self.state.borrow().is_fullscreen() { + return None; + } + + let mut lock = self.state.borrow_mut(); + if let Some(mut callback) = lock.callbacks.hit_test_window_control.take() { + drop(lock); + let area = callback(); + self.state.borrow_mut().callbacks.hit_test_window_control = Some(callback); + if let Some(area) = area { + return match area { + WindowControlArea::Drag => Some(HTCAPTION as _), + WindowControlArea::Close => Some(HTCLOSE as _), + WindowControlArea::Max => Some(HTMAXBUTTON as _), + WindowControlArea::Min => Some(HTMINBUTTON as _), + }; + } + } else { + drop(lock); + } + + if !self.hide_title_bar { + // If the OS draws the title bar, we don't need to handle hit test messages. + return None; + } + + // default handler for resize areas + let hit = unsafe { DefWindowProcW(handle, msg, wparam, lparam) }; + if matches!( + hit.0 as u32, + HTNOWHERE + | HTRIGHT + | HTLEFT + | HTTOPLEFT + | HTTOP + | HTTOPRIGHT + | HTBOTTOMRIGHT + | HTBOTTOM + | HTBOTTOMLEFT + ) { + return Some(hit.0); + } + + if self.state.borrow().is_fullscreen() { + return Some(HTCLIENT as _); + } + + let dpi = unsafe { GetDpiForWindow(handle) }; + let frame_y = unsafe { GetSystemMetricsForDpi(SM_CYFRAME, dpi) }; + + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + if !self.state.borrow().is_maximized() && cursor_point.y >= 0 && cursor_point.y <= frame_y { + return Some(HTTOP as _); + } + + Some(HTCLIENT as _) + } + + fn handle_nc_mouse_move_msg(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + self.start_tracking_mouse(handle, TME_LEAVE | TME_NONCLIENT); + + let mut lock = self.state.borrow_mut(); + let mut func = lock.callbacks.input.take()?; let scale_factor = lock.scale_factor; drop(lock); @@ -1052,206 +893,355 @@ fn handle_nc_mouse_up_msg( y: lparam.signed_hiword().into(), }; unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::MouseUp(MouseUpEvent { - button, + let input = PlatformInput::MouseMove(MouseMoveEvent { position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + pressed_button: None, modifiers: current_modifiers(), - click_count: 1, }); let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); + self.state.borrow_mut().callbacks.input = Some(func); - if handled { - return Some(0); - } - } else { - drop(lock); + if handled { Some(0) } else { None } } - let last_pressed = state_ptr.state.borrow_mut().nc_button_pressed.take(); - if button == MouseButton::Left - && let Some(last_pressed) = last_pressed - { - let handled = match (wparam.0 as u32, last_pressed) { - (HTMINBUTTON, HTMINBUTTON) => { - unsafe { ShowWindowAsync(handle, SW_MINIMIZE).ok().log_err() }; - true + fn handle_nc_mouse_down_msg( + &self, + handle: HWND, + button: MouseButton, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.input.take() { + let scale_factor = lock.scale_factor; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let physical_point = point(DevicePixels(cursor_point.x), DevicePixels(cursor_point.y)); + let click_count = lock.click_state.update(button, physical_point); + drop(lock); + + let input = PlatformInput::MouseDown(MouseDownEvent { + button, + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + modifiers: current_modifiers(), + click_count, + first_mouse: false, + }); + let result = func(input.clone()); + let handled = !result.propagate || result.default_prevented; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + return Some(0); } - (HTMAXBUTTON, HTMAXBUTTON) => { - if state_ptr.state.borrow().is_maximized() { - unsafe { ShowWindowAsync(handle, SW_NORMAL).ok().log_err() }; - } else { - unsafe { ShowWindowAsync(handle, SW_MAXIMIZE).ok().log_err() }; - } - true - } - (HTCLOSE, HTCLOSE) => { - unsafe { - PostMessageW(Some(handle), WM_CLOSE, WPARAM::default(), LPARAM::default()) - .log_err() - }; - true - } - _ => false, + } else { + drop(lock); }; - if handled { - return Some(0); - } - } - None -} - -fn handle_cursor_changed(lparam: LPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut state = state_ptr.state.borrow_mut(); - let had_cursor = state.current_cursor.is_some(); - - state.current_cursor = if lparam.0 == 0 { - None - } else { - Some(HCURSOR(lparam.0 as _)) - }; - - if had_cursor != state.current_cursor.is_some() { - unsafe { SetCursor(state.current_cursor) }; - } - - Some(0) -} - -fn handle_set_cursor( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if unsafe { !IsWindowEnabled(handle).as_bool() } - || matches!( - lparam.loword() as u32, - HTLEFT - | HTRIGHT - | HTTOP - | HTTOPLEFT - | HTTOPRIGHT - | HTBOTTOM - | HTBOTTOMLEFT - | HTBOTTOMRIGHT - ) - { - return None; - } - unsafe { - SetCursor(state_ptr.state.borrow().current_cursor); - }; - Some(1) -} - -fn handle_system_settings_changed( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if wparam.0 != 0 { - let mut lock = state_ptr.state.borrow_mut(); - let display = lock.display; - lock.system_settings.update(display, wparam.0); - lock.click_state.system_update(wparam.0); - lock.border_offset.update(handle).log_err(); - } else { - handle_system_theme_changed(handle, lparam, state_ptr)?; - }; - // Force to trigger WM_NCCALCSIZE event to ensure that we handle auto hide - // taskbar correctly. - notify_frame_changed(handle); - - Some(0) -} - -fn handle_system_command(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - if wparam.0 == SC_KEYMENU as usize { - let mut lock = state_ptr.state.borrow_mut(); - if lock.system_key_handled { - lock.system_key_handled = false; - return Some(0); - } - } - None -} - -fn handle_system_theme_changed( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - // lParam is a pointer to a string that indicates the area containing the system parameter - // that was changed. - let parameter = PCWSTR::from_raw(lparam.0 as _); - if unsafe { !parameter.is_null() && !parameter.is_empty() } { - if let Some(parameter_string) = unsafe { parameter.to_string() }.log_err() { - log::info!("System settings changed: {}", parameter_string); - match parameter_string.as_str() { - "ImmersiveColorSet" => { - let new_appearance = system_appearance() - .context("unable to get system appearance when handling ImmersiveColorSet") - .log_err()?; - let mut lock = state_ptr.state.borrow_mut(); - if new_appearance != lock.appearance { - lock.appearance = new_appearance; - let mut callback = lock.callbacks.appearance_changed.take()?; - drop(lock); - callback(); - state_ptr.state.borrow_mut().callbacks.appearance_changed = Some(callback); - configure_dwm_dark_mode(handle, new_appearance); - } - } - _ => {} - } - } - } - Some(0) -} - -fn handle_input_language_changed( - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let thread = state_ptr.main_thread_id_win32; - let validation = state_ptr.validation_number; - unsafe { - PostThreadMessageW(thread, WM_INPUTLANGCHANGE, WPARAM(validation), lparam).log_err(); - } - Some(0) -} - -#[inline] -fn parse_char_message(wparam: WPARAM, state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<String> { - let code_point = wparam.loword(); - let mut lock = state_ptr.state.borrow_mut(); - // https://www.unicode.org/versions/Unicode16.0.0/core-spec/chapter-3/#G2630 - match code_point { - 0xD800..=0xDBFF => { - // High surrogate, wait for low surrogate - lock.pending_surrogate = Some(code_point); + // Since these are handled in handle_nc_mouse_up_msg we must prevent the default window proc + if button == MouseButton::Left { + match wparam.0 as u32 { + HTMINBUTTON => self.state.borrow_mut().nc_button_pressed = Some(HTMINBUTTON), + HTMAXBUTTON => self.state.borrow_mut().nc_button_pressed = Some(HTMAXBUTTON), + HTCLOSE => self.state.borrow_mut().nc_button_pressed = Some(HTCLOSE), + _ => return None, + }; + Some(0) + } else { None } - 0xDC00..=0xDFFF => { - if let Some(high_surrogate) = lock.pending_surrogate.take() { - // Low surrogate, combine with pending high surrogate - String::from_utf16(&[high_surrogate, code_point]).ok() - } else { - // Invalid low surrogate without a preceding high surrogate - log::warn!( - "Received low surrogate without a preceding high surrogate: {code_point:x}" - ); - None + } + + fn handle_nc_mouse_up_msg( + &self, + handle: HWND, + button: MouseButton, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.input.take() { + let scale_factor = lock.scale_factor; + drop(lock); + + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let input = PlatformInput::MouseUp(MouseUpEvent { + button, + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + modifiers: current_modifiers(), + click_count: 1, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + return Some(0); + } + } else { + drop(lock); + } + + let last_pressed = self.state.borrow_mut().nc_button_pressed.take(); + if button == MouseButton::Left + && let Some(last_pressed) = last_pressed + { + let handled = match (wparam.0 as u32, last_pressed) { + (HTMINBUTTON, HTMINBUTTON) => { + unsafe { ShowWindowAsync(handle, SW_MINIMIZE).ok().log_err() }; + true + } + (HTMAXBUTTON, HTMAXBUTTON) => { + if self.state.borrow().is_maximized() { + unsafe { ShowWindowAsync(handle, SW_NORMAL).ok().log_err() }; + } else { + unsafe { ShowWindowAsync(handle, SW_MAXIMIZE).ok().log_err() }; + } + true + } + (HTCLOSE, HTCLOSE) => { + unsafe { + PostMessageW(Some(handle), WM_CLOSE, WPARAM::default(), LPARAM::default()) + .log_err() + }; + true + } + _ => false, + }; + if handled { + return Some(0); } } - _ => { - lock.pending_surrogate = None; - char::from_u32(code_point as u32) - .filter(|c| !c.is_control()) - .map(|c| c.to_string()) + + None + } + + fn handle_cursor_changed(&self, lparam: LPARAM) -> Option<isize> { + let mut state = self.state.borrow_mut(); + let had_cursor = state.current_cursor.is_some(); + + state.current_cursor = if lparam.0 == 0 { + None + } else { + Some(HCURSOR(lparam.0 as _)) + }; + + if had_cursor != state.current_cursor.is_some() { + unsafe { SetCursor(state.current_cursor) }; } + + Some(0) + } + + fn handle_set_cursor(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + if unsafe { !IsWindowEnabled(handle).as_bool() } + || matches!( + lparam.loword() as u32, + HTLEFT + | HTRIGHT + | HTTOP + | HTTOPLEFT + | HTTOPRIGHT + | HTBOTTOM + | HTBOTTOMLEFT + | HTBOTTOMRIGHT + ) + { + return None; + } + unsafe { + SetCursor(self.state.borrow().current_cursor); + }; + Some(1) + } + + fn handle_system_settings_changed( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if wparam.0 != 0 { + let mut lock = self.state.borrow_mut(); + let display = lock.display; + lock.system_settings.update(display, wparam.0); + lock.click_state.system_update(wparam.0); + lock.border_offset.update(handle).log_err(); + } else { + self.handle_system_theme_changed(handle, lparam)?; + }; + // Force to trigger WM_NCCALCSIZE event to ensure that we handle auto hide + // taskbar correctly. + notify_frame_changed(handle); + + Some(0) + } + + fn handle_system_command(&self, wparam: WPARAM) -> Option<isize> { + if wparam.0 == SC_KEYMENU as usize { + let mut lock = self.state.borrow_mut(); + if lock.system_key_handled { + lock.system_key_handled = false; + return Some(0); + } + } + None + } + + fn handle_system_theme_changed(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + // lParam is a pointer to a string that indicates the area containing the system parameter + // that was changed. + let parameter = PCWSTR::from_raw(lparam.0 as _); + if unsafe { !parameter.is_null() && !parameter.is_empty() } { + if let Some(parameter_string) = unsafe { parameter.to_string() }.log_err() { + log::info!("System settings changed: {}", parameter_string); + match parameter_string.as_str() { + "ImmersiveColorSet" => { + let new_appearance = system_appearance() + .context( + "unable to get system appearance when handling ImmersiveColorSet", + ) + .log_err()?; + let mut lock = self.state.borrow_mut(); + if new_appearance != lock.appearance { + lock.appearance = new_appearance; + let mut callback = lock.callbacks.appearance_changed.take()?; + drop(lock); + callback(); + self.state.borrow_mut().callbacks.appearance_changed = Some(callback); + configure_dwm_dark_mode(handle, new_appearance); + } + } + _ => {} + } + } + } + Some(0) + } + + fn handle_input_language_changed(&self, lparam: LPARAM) -> Option<isize> { + let thread = self.main_thread_id_win32; + let validation = self.validation_number; + unsafe { + PostThreadMessageW(thread, WM_INPUTLANGCHANGE, WPARAM(validation), lparam).log_err(); + } + Some(0) + } + + fn handle_device_change_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> { + if wparam.0 == DBT_DEVNODES_CHANGED as usize { + // The reason for sending this message is to actually trigger a redraw of the window. + unsafe { + PostMessageW( + Some(handle), + WM_GPUI_FORCE_UPDATE_WINDOW, + WPARAM(0), + LPARAM(0), + ) + .log_err(); + } + // If the GPU device is lost, this redraw will take care of recreating the device context. + // The WM_GPUI_FORCE_UPDATE_WINDOW message will take care of redrawing the window, after + // the device context has been recreated. + self.draw_window(handle, true) + } else { + // Other device change messages are not handled. + None + } + } + + #[inline] + fn draw_window(&self, handle: HWND, force_render: bool) -> Option<isize> { + let mut request_frame = self.state.borrow_mut().callbacks.request_frame.take()?; + request_frame(RequestFrameOptions { + require_presentation: false, + force_render, + }); + self.state.borrow_mut().callbacks.request_frame = Some(request_frame); + unsafe { ValidateRect(Some(handle), None).ok().log_err() }; + Some(0) + } + + #[inline] + fn parse_char_message(&self, wparam: WPARAM) -> Option<String> { + let code_point = wparam.loword(); + let mut lock = self.state.borrow_mut(); + // https://www.unicode.org/versions/Unicode16.0.0/core-spec/chapter-3/#G2630 + match code_point { + 0xD800..=0xDBFF => { + // High surrogate, wait for low surrogate + lock.pending_surrogate = Some(code_point); + None + } + 0xDC00..=0xDFFF => { + if let Some(high_surrogate) = lock.pending_surrogate.take() { + // Low surrogate, combine with pending high surrogate + String::from_utf16(&[high_surrogate, code_point]).ok() + } else { + // Invalid low surrogate without a preceding high surrogate + log::warn!( + "Received low surrogate without a preceding high surrogate: {code_point:x}" + ); + None + } + } + _ => { + lock.pending_surrogate = None; + char::from_u32(code_point as u32) + .filter(|c| !c.is_control()) + .map(|c| c.to_string()) + } + } + } + + fn start_tracking_mouse(&self, handle: HWND, flags: TRACKMOUSEEVENT_FLAGS) { + let mut lock = self.state.borrow_mut(); + if !lock.hovered { + lock.hovered = true; + unsafe { + TrackMouseEvent(&mut TRACKMOUSEEVENT { + cbSize: std::mem::size_of::<TRACKMOUSEEVENT>() as u32, + dwFlags: flags, + hwndTrack: handle, + dwHoverTime: HOVER_DEFAULT, + }) + .log_err() + }; + if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { + drop(lock); + callback(true); + self.state.borrow_mut().callbacks.hovered_status_change = Some(callback); + } + } + } + + fn with_input_handler<F, R>(&self, f: F) -> Option<R> + where + F: FnOnce(&mut PlatformInputHandler) -> R, + { + let mut input_handler = self.state.borrow_mut().input_handler.take()?; + let result = f(&mut input_handler); + self.state.borrow_mut().input_handler = Some(input_handler); + Some(result) + } + + fn with_input_handler_and_scale_factor<F, R>(&self, f: F) -> Option<R> + where + F: FnOnce(&mut PlatformInputHandler, f32) -> Option<R>, + { + let mut lock = self.state.borrow_mut(); + let mut input_handler = lock.input_handler.take()?; + let scale_factor = lock.scale_factor; + drop(lock); + let result = f(&mut input_handler, scale_factor); + self.state.borrow_mut().input_handler = Some(input_handler); + result } } @@ -1521,54 +1511,3 @@ fn notify_frame_changed(handle: HWND) { .log_err(); } } - -fn start_tracking_mouse( - handle: HWND, - state_ptr: &Rc<WindowsWindowStatePtr>, - flags: TRACKMOUSEEVENT_FLAGS, -) { - let mut lock = state_ptr.state.borrow_mut(); - if !lock.hovered { - lock.hovered = true; - unsafe { - TrackMouseEvent(&mut TRACKMOUSEEVENT { - cbSize: std::mem::size_of::<TRACKMOUSEEVENT>() as u32, - dwFlags: flags, - hwndTrack: handle, - dwHoverTime: HOVER_DEFAULT, - }) - .log_err() - }; - if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { - drop(lock); - callback(true); - state_ptr.state.borrow_mut().callbacks.hovered_status_change = Some(callback); - } - } -} - -fn with_input_handler<F, R>(state_ptr: &Rc<WindowsWindowStatePtr>, f: F) -> Option<R> -where - F: FnOnce(&mut PlatformInputHandler) -> R, -{ - let mut input_handler = state_ptr.state.borrow_mut().input_handler.take()?; - let result = f(&mut input_handler); - state_ptr.state.borrow_mut().input_handler = Some(input_handler); - Some(result) -} - -fn with_input_handler_and_scale_factor<F, R>( - state_ptr: &Rc<WindowsWindowStatePtr>, - f: F, -) -> Option<R> -where - F: FnOnce(&mut PlatformInputHandler, f32) -> Option<R>, -{ - let mut lock = state_ptr.state.borrow_mut(); - let mut input_handler = lock.input_handler.take()?; - let scale_factor = lock.scale_factor; - drop(lock); - let result = f(&mut input_handler, scale_factor); - state_ptr.state.borrow_mut().input_handler = Some(input_handler); - result -} diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index 401ecdeffe..01b043a755 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -28,13 +28,12 @@ use windows::{ core::*, }; -use crate::{platform::blade::BladeContext, *}; +use crate::*; pub(crate) struct WindowsPlatform { state: RefCell<WindowsPlatformState>, raw_window_handles: RwLock<SmallVec<[HWND; 4]>>, // The below members will never change throughout the entire lifecycle of the app. - gpu_context: BladeContext, icon: HICON, main_receiver: flume::Receiver<Runnable>, background_executor: BackgroundExecutor, @@ -45,6 +44,7 @@ pub(crate) struct WindowsPlatform { drop_target_helper: IDropTargetHelper, validation_number: usize, main_thread_id_win32: u32, + disable_direct_composition: bool, } pub(crate) struct WindowsPlatformState { @@ -94,14 +94,18 @@ impl WindowsPlatform { main_thread_id_win32, validation_number, )); + let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION) + .is_ok_and(|value| value == "true" || value == "1"); let background_executor = BackgroundExecutor::new(dispatcher.clone()); let foreground_executor = ForegroundExecutor::new(dispatcher); + let directx_devices = DirectXDevices::new(disable_direct_composition) + .context("Unable to init directx devices.")?; let bitmap_factory = ManuallyDrop::new(unsafe { CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER) .context("Error creating bitmap factory.")? }); let text_system = Arc::new( - DirectWriteTextSystem::new(&bitmap_factory) + DirectWriteTextSystem::new(&directx_devices, &bitmap_factory) .context("Error creating DirectWriteTextSystem")?, ); let drop_target_helper: IDropTargetHelper = unsafe { @@ -111,18 +115,17 @@ impl WindowsPlatform { let icon = load_icon().unwrap_or_default(); let state = RefCell::new(WindowsPlatformState::new()); let raw_window_handles = RwLock::new(SmallVec::new()); - let gpu_context = BladeContext::new().context("Unable to init GPU context")?; let windows_version = WindowsVersion::new().context("Error retrieve windows version")?; Ok(Self { state, raw_window_handles, - gpu_context, icon, main_receiver, background_executor, foreground_executor, text_system, + disable_direct_composition, windows_version, bitmap_factory, drop_target_helper, @@ -141,12 +144,12 @@ impl WindowsPlatform { } } - pub fn try_get_windows_inner_from_hwnd(&self, hwnd: HWND) -> Option<Rc<WindowsWindowStatePtr>> { + pub fn window_from_hwnd(&self, hwnd: HWND) -> Option<Rc<WindowsWindowInner>> { self.raw_window_handles .read() .iter() .find(|entry| *entry == &hwnd) - .and_then(|hwnd| try_get_window_inner(*hwnd)) + .and_then(|hwnd| window_from_hwnd(*hwnd)) } #[inline] @@ -187,6 +190,7 @@ impl WindowsPlatform { validation_number: self.validation_number, main_receiver: self.main_receiver.clone(), main_thread_id_win32: self.main_thread_id_win32, + disable_direct_composition: self.disable_direct_composition, } } @@ -343,27 +347,11 @@ impl Platform for WindowsPlatform { fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>) { on_finish_launching(); - let vsync_event = unsafe { Owned::new(CreateEventW(None, false, false, None).unwrap()) }; - begin_vsync(*vsync_event); - 'a: loop { - let wait_result = unsafe { - MsgWaitForMultipleObjects(Some(&[*vsync_event]), false, INFINITE, QS_ALLINPUT) - }; - - match wait_result { - // compositor clock ticked so we should draw a frame - WAIT_EVENT(0) => self.redraw_all(), - // Windows thread messages are posted - WAIT_EVENT(1) => { - if self.handle_events() { - break 'a; - } - } - _ => { - log::error!("Something went wrong while waiting {:?}", wait_result); - break; - } + loop { + if self.handle_events() { + break; } + self.redraw_all(); } if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit { @@ -446,7 +434,7 @@ impl Platform for WindowsPlatform { fn active_window(&self) -> Option<AnyWindowHandle> { let active_window_hwnd = unsafe { GetActiveWindow() }; - self.try_get_windows_inner_from_hwnd(active_window_hwnd) + self.window_from_hwnd(active_window_hwnd) .map(|inner| inner.handle) } @@ -455,12 +443,7 @@ impl Platform for WindowsPlatform { handle: AnyWindowHandle, options: WindowParams, ) -> Result<Box<dyn PlatformWindow>> { - let window = WindowsWindow::new( - handle, - options, - self.generate_creation_info(), - &self.gpu_context, - )?; + let window = WindowsWindow::new(handle, options, self.generate_creation_info())?; let handle = window.get_raw_handle(); self.raw_window_handles.write().push(handle); @@ -739,6 +722,7 @@ pub(crate) struct WindowCreationInfo { pub(crate) validation_number: usize, pub(crate) main_receiver: flume::Receiver<Runnable>, pub(crate) main_thread_id_win32: u32, + pub(crate) disable_direct_composition: bool, } fn open_target(target: &str) { @@ -846,16 +830,6 @@ fn file_save_dialog(directory: PathBuf, window: Option<HWND>) -> Result<Option<P Ok(Some(PathBuf::from(file_path_string))) } -fn begin_vsync(vsync_event: HANDLE) { - let event: SafeHandle = vsync_event.into(); - std::thread::spawn(move || unsafe { - loop { - windows::Win32::Graphics::Dwm::DwmFlush().log_err(); - SetEvent(*event).log_err(); - } - }); -} - fn load_icon() -> Result<HICON> { let module = unsafe { GetModuleHandleW(None).context("unable to get module handle")? }; let handle = unsafe { diff --git a/crates/gpui/src/platform/windows/shaders.hlsl b/crates/gpui/src/platform/windows/shaders.hlsl new file mode 100644 index 0000000000..25830e4b6c --- /dev/null +++ b/crates/gpui/src/platform/windows/shaders.hlsl @@ -0,0 +1,1159 @@ +cbuffer GlobalParams: register(b0) { + float2 global_viewport_size; + uint2 _pad; +}; + +Texture2D<float4> t_sprite: register(t0); +SamplerState s_sprite: register(s0); + +struct Bounds { + float2 origin; + float2 size; +}; + +struct Corners { + float top_left; + float top_right; + float bottom_right; + float bottom_left; +}; + +struct Edges { + float top; + float right; + float bottom; + float left; +}; + +struct Hsla { + float h; + float s; + float l; + float a; +}; + +struct LinearColorStop { + Hsla color; + float percentage; +}; + +struct Background { + // 0u is Solid + // 1u is LinearGradient + // 2u is PatternSlash + uint tag; + // 0u is sRGB linear color + // 1u is Oklab color + uint color_space; + Hsla solid; + float gradient_angle_or_pattern_height; + LinearColorStop colors[2]; + uint pad; +}; + +struct GradientColor { + float4 solid; + float4 color0; + float4 color1; +}; + +struct AtlasTextureId { + uint index; + uint kind; +}; + +struct AtlasBounds { + int2 origin; + int2 size; +}; + +struct AtlasTile { + AtlasTextureId texture_id; + uint tile_id; + uint padding; + AtlasBounds bounds; +}; + +struct TransformationMatrix { + float2x2 rotation_scale; + float2 translation; +}; + +static const float M_PI_F = 3.141592653f; +static const float3 GRAYSCALE_FACTORS = float3(0.2126f, 0.7152f, 0.0722f); + +float4 to_device_position_impl(float2 position) { + float2 device_position = position / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); + return float4(device_position, 0., 1.); +} + +float4 to_device_position(float2 unit_vertex, Bounds bounds) { + float2 position = unit_vertex * bounds.size + bounds.origin; + return to_device_position_impl(position); +} + +float4 distance_from_clip_rect_impl(float2 position, Bounds clip_bounds) { + float2 tl = position - clip_bounds.origin; + float2 br = clip_bounds.origin + clip_bounds.size - position; + return float4(tl.x, br.x, tl.y, br.y); +} + +float4 distance_from_clip_rect(float2 unit_vertex, Bounds bounds, Bounds clip_bounds) { + float2 position = unit_vertex * bounds.size + bounds.origin; + return distance_from_clip_rect_impl(position, clip_bounds); +} + +// Convert linear RGB to sRGB +float3 linear_to_srgb(float3 color) { + return pow(color, float3(2.2, 2.2, 2.2)); +} + +// Convert sRGB to linear RGB +float3 srgb_to_linear(float3 color) { + return pow(color, float3(1.0 / 2.2, 1.0 / 2.2, 1.0 / 2.2)); +} + +/// Hsla to linear RGBA conversion. +float4 hsla_to_rgba(Hsla hsla) { + float h = hsla.h * 6.0; // Now, it's an angle but scaled in [0, 6) range + float s = hsla.s; + float l = hsla.l; + float a = hsla.a; + + float c = (1.0 - abs(2.0 * l - 1.0)) * s; + float x = c * (1.0 - abs(fmod(h, 2.0) - 1.0)); + float m = l - c / 2.0; + + float r = 0.0; + float g = 0.0; + float b = 0.0; + + if (h >= 0.0 && h < 1.0) { + r = c; + g = x; + b = 0.0; + } else if (h >= 1.0 && h < 2.0) { + r = x; + g = c; + b = 0.0; + } else if (h >= 2.0 && h < 3.0) { + r = 0.0; + g = c; + b = x; + } else if (h >= 3.0 && h < 4.0) { + r = 0.0; + g = x; + b = c; + } else if (h >= 4.0 && h < 5.0) { + r = x; + g = 0.0; + b = c; + } else { + r = c; + g = 0.0; + b = x; + } + + float4 rgba; + rgba.x = (r + m); + rgba.y = (g + m); + rgba.z = (b + m); + rgba.w = a; + return rgba; +} + +// Converts a sRGB color to the Oklab color space. +// Reference: https://bottosson.github.io/posts/oklab/#converting-from-linear-srgb-to-oklab +float4 srgb_to_oklab(float4 color) { + // Convert non-linear sRGB to linear sRGB + color = float4(srgb_to_linear(color.rgb), color.a); + + float l = 0.4122214708 * color.r + 0.5363325363 * color.g + 0.0514459929 * color.b; + float m = 0.2119034982 * color.r + 0.6806995451 * color.g + 0.1073969566 * color.b; + float s = 0.0883024619 * color.r + 0.2817188376 * color.g + 0.6299787005 * color.b; + + float l_ = pow(l, 1.0/3.0); + float m_ = pow(m, 1.0/3.0); + float s_ = pow(s, 1.0/3.0); + + return float4( + 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, + 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, + 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, + color.a + ); +} + +// Converts an Oklab color to the sRGB color space. +float4 oklab_to_srgb(float4 color) { + float l_ = color.r + 0.3963377774 * color.g + 0.2158037573 * color.b; + float m_ = color.r - 0.1055613458 * color.g - 0.0638541728 * color.b; + float s_ = color.r - 0.0894841775 * color.g - 1.2914855480 * color.b; + + float l = l_ * l_ * l_; + float m = m_ * m_ * m_; + float s = s_ * s_ * s_; + + float3 linear_rgb = float3( + 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s, + -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s, + -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s + ); + + // Convert linear sRGB to non-linear sRGB + return float4(linear_to_srgb(linear_rgb), color.a); +} + +// This approximates the error function, needed for the gaussian integral +float2 erf(float2 x) { + float2 s = sign(x); + float2 a = abs(x); + x = 1. + (0.278393 + (0.230389 + 0.078108 * (a * a)) * a) * a; + x *= x; + return s - s / (x * x); +} + +float blur_along_x(float x, float y, float sigma, float corner, float2 half_size) { + float delta = min(half_size.y - corner - abs(y), 0.); + float curved = half_size.x - corner + sqrt(max(0., corner * corner - delta * delta)); + float2 integral = 0.5 + 0.5 * erf((x + float2(-curved, curved)) * (sqrt(0.5) / sigma)); + return integral.y - integral.x; +} + +// A standard gaussian function, used for weighting samples +float gaussian(float x, float sigma) { + return exp(-(x * x) / (2. * sigma * sigma)) / (sqrt(2. * M_PI_F) * sigma); +} + +float4 over(float4 below, float4 above) { + float4 result; + float alpha = above.a + below.a * (1.0 - above.a); + result.rgb = (above.rgb * above.a + below.rgb * below.a * (1.0 - above.a)) / alpha; + result.a = alpha; + return result; +} + +float2 to_tile_position(float2 unit_vertex, AtlasTile tile) { + float2 atlas_size; + t_sprite.GetDimensions(atlas_size.x, atlas_size.y); + return (float2(tile.bounds.origin) + unit_vertex * float2(tile.bounds.size)) / atlas_size; +} + +// Selects corner radius based on quadrant. +float pick_corner_radius(float2 center_to_point, Corners corner_radii) { + if (center_to_point.x < 0.) { + if (center_to_point.y < 0.) { + return corner_radii.top_left; + } else { + return corner_radii.bottom_left; + } + } else { + if (center_to_point.y < 0.) { + return corner_radii.top_right; + } else { + return corner_radii.bottom_right; + } + } +} + +float4 to_device_position_transformed(float2 unit_vertex, Bounds bounds, + TransformationMatrix transformation) { + float2 position = unit_vertex * bounds.size + bounds.origin; + float2 transformed = mul(position, transformation.rotation_scale) + transformation.translation; + float2 device_position = transformed / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); + return float4(device_position, 0.0, 1.0); +} + +// Implementation of quad signed distance field +float quad_sdf_impl(float2 corner_center_to_point, float corner_radius) { + if (corner_radius == 0.0) { + // Fast path for unrounded corners + return max(corner_center_to_point.x, corner_center_to_point.y); + } else { + // Signed distance of the point from a quad that is inset by corner_radius + // It is negative inside this quad, and positive outside + float signed_distance_to_inset_quad = + // 0 inside the inset quad, and positive outside + length(max(float2(0.0, 0.0), corner_center_to_point)) + + // 0 outside the inset quad, and negative inside + min(0.0, max(corner_center_to_point.x, corner_center_to_point.y)); + + return signed_distance_to_inset_quad - corner_radius; + } +} + +float quad_sdf(float2 pt, Bounds bounds, Corners corner_radii) { + float2 half_size = bounds.size / 2.; + float2 center = bounds.origin + half_size; + float2 center_to_point = pt - center; + float corner_radius = pick_corner_radius(center_to_point, corner_radii); + float2 corner_to_point = abs(center_to_point) - half_size; + float2 corner_center_to_point = corner_to_point + corner_radius; + return quad_sdf_impl(corner_center_to_point, corner_radius); +} + +GradientColor prepare_gradient_color(uint tag, uint color_space, Hsla solid, LinearColorStop colors[2]) { + GradientColor output; + if (tag == 0 || tag == 2) { + output.solid = hsla_to_rgba(solid); + } else if (tag == 1) { + output.color0 = hsla_to_rgba(colors[0].color); + output.color1 = hsla_to_rgba(colors[1].color); + + // Prepare color space in vertex for avoid conversion + // in fragment shader for performance reasons + if (color_space == 1) { + // Oklab + output.color0 = srgb_to_oklab(output.color0); + output.color1 = srgb_to_oklab(output.color1); + } + } + + return output; +} + +float2x2 rotate2d(float angle) { + float s = sin(angle); + float c = cos(angle); + return float2x2(c, -s, s, c); +} + +float4 gradient_color(Background background, + float2 position, + Bounds bounds, + float4 solid_color, float4 color0, float4 color1) { + float4 color; + + switch (background.tag) { + case 0: + color = solid_color; + break; + case 1: { + // -90 degrees to match the CSS gradient angle. + float gradient_angle = background.gradient_angle_or_pattern_height; + float radians = (fmod(gradient_angle, 360.0) - 90.0) * (M_PI_F / 180.0); + float2 direction = float2(cos(radians), sin(radians)); + + // Expand the short side to be the same as the long side + if (bounds.size.x > bounds.size.y) { + direction.y *= bounds.size.y / bounds.size.x; + } else { + direction.x *= bounds.size.x / bounds.size.y; + } + + // Get the t value for the linear gradient with the color stop percentages. + float2 half_size = bounds.size * 0.5; + float2 center = bounds.origin + half_size; + float2 center_to_point = position - center; + float t = dot(center_to_point, direction) / length(direction); + // Check the direct to determine the use x or y + if (abs(direction.x) > abs(direction.y)) { + t = (t + half_size.x) / bounds.size.x; + } else { + t = (t + half_size.y) / bounds.size.y; + } + + // Adjust t based on the stop percentages + t = (t - background.colors[0].percentage) + / (background.colors[1].percentage + - background.colors[0].percentage); + t = clamp(t, 0.0, 1.0); + + switch (background.color_space) { + case 0: + color = lerp(color0, color1, t); + break; + case 1: { + float4 oklab_color = lerp(color0, color1, t); + color = oklab_to_srgb(oklab_color); + break; + } + } + break; + } + case 2: { + float gradient_angle_or_pattern_height = background.gradient_angle_or_pattern_height; + float pattern_width = (gradient_angle_or_pattern_height / 65535.0f) / 255.0f; + float pattern_interval = fmod(gradient_angle_or_pattern_height, 65535.0f) / 255.0f; + float pattern_height = pattern_width + pattern_interval; + float stripe_angle = M_PI_F / 4.0; + float pattern_period = pattern_height * sin(stripe_angle); + float2x2 rotation = rotate2d(stripe_angle); + float2 relative_position = position - bounds.origin; + float2 rotated_point = mul(rotation, relative_position); + float pattern = fmod(rotated_point.x, pattern_period); + float distance = min(pattern, pattern_period - pattern) - pattern_period * (pattern_width / pattern_height) / 2.0f; + color = solid_color; + color.a *= saturate(0.5 - distance); + break; + } + } + + return color; +} + +// Returns the dash velocity of a corner given the dash velocity of the two +// sides, by returning the slower velocity (larger dashes). +// +// Since 0 is used for dash velocity when the border width is 0 (instead of +// +inf), this returns the other dash velocity in that case. +// +// An alternative to this might be to appropriately interpolate the dash +// velocity around the corner, but that seems overcomplicated. +float corner_dash_velocity(float dv1, float dv2) { + if (dv1 == 0.0) { + return dv2; + } else if (dv2 == 0.0) { + return dv1; + } else { + return min(dv1, dv2); + } +} + +// Returns alpha used to render antialiased dashes. +// `t` is within the dash when `fmod(t, period) < length`. +float dash_alpha( + float t, float period, float length, float dash_velocity, + float antialias_threshold +) { + float half_period = period / 2.0; + float half_length = length / 2.0; + // Value in [-half_period, half_period] + // The dash is in [-half_length, half_length] + float centered = fmod(t + half_period - half_length, period) - half_period; + // Signed distance for the dash, negative values are inside the dash + float signed_distance = abs(centered) - half_length; + // Antialiased alpha based on the signed distance + return saturate(antialias_threshold - signed_distance / dash_velocity); +} + +// This approximates distance to the nearest point to a quarter ellipse in a way +// that is sufficient for anti-aliasing when the ellipse is not very eccentric. +// The components of `point` are expected to be positive. +// +// Negative on the outside and positive on the inside. +float quarter_ellipse_sdf(float2 pt, float2 radii) { + // Scale the space to treat the ellipse like a unit circle + float2 circle_vec = pt / radii; + float unit_circle_sdf = length(circle_vec) - 1.0; + // Approximate up-scaling of the length by using the average of the radii. + // + // TODO: A better solution would be to use the gradient of the implicit + // function for an ellipse to approximate a scaling factor. + return unit_circle_sdf * (radii.x + radii.y) * -0.5; +} + +/* +** +** Quads +** +*/ + +struct Quad { + uint order; + uint border_style; + Bounds bounds; + Bounds content_mask; + Background background; + Hsla border_color; + Corners corner_radii; + Edges border_widths; +}; + +struct QuadVertexOutput { + nointerpolation uint quad_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 border_color: COLOR0; + nointerpolation float4 background_solid: COLOR1; + nointerpolation float4 background_color0: COLOR2; + nointerpolation float4 background_color1: COLOR3; + float4 clip_distance: SV_ClipDistance; +}; + +struct QuadFragmentInput { + nointerpolation uint quad_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 border_color: COLOR0; + nointerpolation float4 background_solid: COLOR1; + nointerpolation float4 background_color0: COLOR2; + nointerpolation float4 background_color1: COLOR3; +}; + +StructuredBuffer<Quad> quads: register(t1); + +QuadVertexOutput quad_vertex(uint vertex_id: SV_VertexID, uint quad_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Quad quad = quads[quad_id]; + float4 device_position = to_device_position(unit_vertex, quad.bounds); + + GradientColor gradient = prepare_gradient_color( + quad.background.tag, + quad.background.color_space, + quad.background.solid, + quad.background.colors + ); + float4 clip_distance = distance_from_clip_rect(unit_vertex, quad.bounds, quad.content_mask); + float4 border_color = hsla_to_rgba(quad.border_color); + + QuadVertexOutput output; + output.position = device_position; + output.border_color = border_color; + output.quad_id = quad_id; + output.background_solid = gradient.solid; + output.background_color0 = gradient.color0; + output.background_color1 = gradient.color1; + output.clip_distance = clip_distance; + return output; +} + +float4 quad_fragment(QuadFragmentInput input): SV_Target { + Quad quad = quads[input.quad_id]; + float4 background_color = gradient_color(quad.background, input.position.xy, quad.bounds, + input.background_solid, input.background_color0, input.background_color1); + + bool unrounded = quad.corner_radii.top_left == 0.0 && + quad.corner_radii.top_right == 0.0 && + quad.corner_radii.bottom_left == 0.0 && + quad.corner_radii.bottom_right == 0.0; + + // Fast path when the quad is not rounded and doesn't have any border + if (quad.border_widths.top == 0.0 && + quad.border_widths.left == 0.0 && + quad.border_widths.right == 0.0 && + quad.border_widths.bottom == 0.0 && + unrounded) { + return background_color; + } + + float2 size = quad.bounds.size; + float2 half_size = size / 2.; + float2 the_point = input.position.xy - quad.bounds.origin; + float2 center_to_point = the_point - half_size; + + // Signed distance field threshold for inclusion of pixels. 0.5 is the + // minimum distance between the center of the pixel and the edge. + const float antialias_threshold = 0.5; + + // Radius of the nearest corner + float corner_radius = pick_corner_radius(center_to_point, quad.corner_radii); + + float2 border = float2( + center_to_point.x < 0.0 ? quad.border_widths.left : quad.border_widths.right, + center_to_point.y < 0.0 ? quad.border_widths.top : quad.border_widths.bottom + ); + + // 0-width borders are reduced so that `inner_sdf >= antialias_threshold`. + // The purpose of this is to not draw antialiasing pixels in this case. + float2 reduced_border = float2( + border.x == 0.0 ? -antialias_threshold : border.x, + border.y == 0.0 ? -antialias_threshold : border.y + ); + + // Vector from the corner of the quad bounds to the point, after mirroring + // the point into the bottom right quadrant. Both components are <= 0. + float2 corner_to_point = abs(center_to_point) - half_size; + + // Vector from the point to the center of the rounded corner's circle, also + // mirrored into bottom right quadrant. + float2 corner_center_to_point = corner_to_point + corner_radius; + + // Whether the nearest point on the border is rounded + bool is_near_rounded_corner = + corner_center_to_point.x >= 0.0 && + corner_center_to_point.y >= 0.0; + + // Vector from straight border inner corner to point. + // + // 0-width borders are turned into width -1 so that inner_sdf is > 1.0 near + // the border. Without this, antialiasing pixels would be drawn. + float2 straight_border_inner_corner_to_point = corner_to_point + reduced_border; + + // Whether the point is beyond the inner edge of the straight border + bool is_beyond_inner_straight_border = + straight_border_inner_corner_to_point.x > 0.0 || + straight_border_inner_corner_to_point.y > 0.0; + + // Whether the point is far enough inside the quad, such that the pixels are + // not affected by the straight border. + bool is_within_inner_straight_border = + straight_border_inner_corner_to_point.x < -antialias_threshold && + straight_border_inner_corner_to_point.y < -antialias_threshold; + + // Fast path for points that must be part of the background + if (is_within_inner_straight_border && !is_near_rounded_corner) { + return background_color; + } + + // Signed distance of the point to the outside edge of the quad's border + float outer_sdf = quad_sdf_impl(corner_center_to_point, corner_radius); + + // Approximate signed distance of the point to the inside edge of the quad's + // border. It is negative outside this edge (within the border), and + // positive inside. + // + // This is not always an accurate signed distance: + // * The rounded portions with varying border width use an approximation of + // nearest-point-on-ellipse. + // * When it is quickly known to be outside the edge, -1.0 is used. + float inner_sdf = 0.0; + if (corner_center_to_point.x <= 0.0 || corner_center_to_point.y <= 0.0) { + // Fast paths for straight borders + inner_sdf = -max(straight_border_inner_corner_to_point.x, + straight_border_inner_corner_to_point.y); + } else if (is_beyond_inner_straight_border) { + // Fast path for points that must be outside the inner edge + inner_sdf = -1.0; + } else if (reduced_border.x == reduced_border.y) { + // Fast path for circular inner edge. + inner_sdf = -(outer_sdf + reduced_border.x); + } else { + float2 ellipse_radii = max(float2(0.0, 0.0), float2(corner_radius, corner_radius) - reduced_border); + inner_sdf = quarter_ellipse_sdf(corner_center_to_point, ellipse_radii); + } + + // Negative when inside the border + float border_sdf = max(inner_sdf, outer_sdf); + + float4 color = background_color; + if (border_sdf < antialias_threshold) { + float4 border_color = input.border_color; + // Dashed border logic when border_style == 1 + if (quad.border_style == 1) { + // Position along the perimeter in "dash space", where each dash + // period has length 1 + float t = 0.0; + + // Total number of dash periods, so that the dash spacing can be + // adjusted to evenly divide it + float max_t = 0.0; + + // Border width is proportional to dash size. This is the behavior + // used by browsers, but also avoids dashes from different segments + // overlapping when dash size is smaller than the border width. + // + // Dash pattern: (2 * border width) dash, (1 * border width) gap + const float dash_length_per_width = 2.0; + const float dash_gap_per_width = 1.0; + const float dash_period_per_width = dash_length_per_width + dash_gap_per_width; + + // Since the dash size is determined by border width, the density of + // dashes varies. Multiplying a pixel distance by this returns a + // position in dash space - it has units (dash period / pixels). So + // a dash velocity of (1 / 10) is 1 dash every 10 pixels. + float dash_velocity = 0.0; + + // Dividing this by the border width gives the dash velocity + const float dv_numerator = 1.0 / dash_period_per_width; + + if (unrounded) { + // When corners aren't rounded, the dashes are separately laid + // out on each straight line, rather than around the whole + // perimeter. This way each line starts and ends with a dash. + bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; + float border_width = is_horizontal ? border.x : border.y; + dash_velocity = dv_numerator / border_width; + t = is_horizontal ? the_point.x : the_point.y; + t *= dash_velocity; + max_t = is_horizontal ? size.x : size.y; + max_t *= dash_velocity; + } else { + // When corners are rounded, the dashes are laid out clockwise + // around the whole perimeter. + + float r_tr = quad.corner_radii.top_right; + float r_br = quad.corner_radii.bottom_right; + float r_bl = quad.corner_radii.bottom_left; + float r_tl = quad.corner_radii.top_left; + + float w_t = quad.border_widths.top; + float w_r = quad.border_widths.right; + float w_b = quad.border_widths.bottom; + float w_l = quad.border_widths.left; + + // Straight side dash velocities + float dv_t = w_t <= 0.0 ? 0.0 : dv_numerator / w_t; + float dv_r = w_r <= 0.0 ? 0.0 : dv_numerator / w_r; + float dv_b = w_b <= 0.0 ? 0.0 : dv_numerator / w_b; + float dv_l = w_l <= 0.0 ? 0.0 : dv_numerator / w_l; + + // Straight side lengths in dash space + float s_t = (size.x - r_tl - r_tr) * dv_t; + float s_r = (size.y - r_tr - r_br) * dv_r; + float s_b = (size.x - r_br - r_bl) * dv_b; + float s_l = (size.y - r_bl - r_tl) * dv_l; + + float corner_dash_velocity_tr = corner_dash_velocity(dv_t, dv_r); + float corner_dash_velocity_br = corner_dash_velocity(dv_b, dv_r); + float corner_dash_velocity_bl = corner_dash_velocity(dv_b, dv_l); + float corner_dash_velocity_tl = corner_dash_velocity(dv_t, dv_l); + + // Corner lengths in dash space + float c_tr = r_tr * (M_PI_F / 2.0) * corner_dash_velocity_tr; + float c_br = r_br * (M_PI_F / 2.0) * corner_dash_velocity_br; + float c_bl = r_bl * (M_PI_F / 2.0) * corner_dash_velocity_bl; + float c_tl = r_tl * (M_PI_F / 2.0) * corner_dash_velocity_tl; + + // Cumulative dash space upto each segment + float upto_tr = s_t; + float upto_r = upto_tr + c_tr; + float upto_br = upto_r + s_r; + float upto_b = upto_br + c_br; + float upto_bl = upto_b + s_b; + float upto_l = upto_bl + c_bl; + float upto_tl = upto_l + s_l; + max_t = upto_tl + c_tl; + + if (is_near_rounded_corner) { + float radians = atan2(corner_center_to_point.y, corner_center_to_point.x); + float corner_t = radians * corner_radius; + + if (center_to_point.x >= 0.0) { + if (center_to_point.y < 0.0) { + dash_velocity = corner_dash_velocity_tr; + // Subtracted because radians is pi/2 to 0 when + // going clockwise around the top right corner, + // since the y axis has been flipped + t = upto_r - corner_t * dash_velocity; + } else { + dash_velocity = corner_dash_velocity_br; + // Added because radians is 0 to pi/2 when going + // clockwise around the bottom-right corner + t = upto_br + corner_t * dash_velocity; + } + } else { + if (center_to_point.y >= 0.0) { + dash_velocity = corner_dash_velocity_bl; + // Subtracted because radians is pi/1 to 0 when + // going clockwise around the bottom-left corner, + // since the x axis has been flipped + t = upto_l - corner_t * dash_velocity; + } else { + dash_velocity = corner_dash_velocity_tl; + // Added because radians is 0 to pi/2 when going + // clockwise around the top-left corner, since both + // axis were flipped + t = upto_tl + corner_t * dash_velocity; + } + } + } else { + // Straight borders + bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; + if (is_horizontal) { + if (center_to_point.y < 0.0) { + dash_velocity = dv_t; + t = (the_point.x - r_tl) * dash_velocity; + } else { + dash_velocity = dv_b; + t = upto_bl - (the_point.x - r_bl) * dash_velocity; + } + } else { + if (center_to_point.x < 0.0) { + dash_velocity = dv_l; + t = upto_tl - (the_point.y - r_tl) * dash_velocity; + } else { + dash_velocity = dv_r; + t = upto_r + (the_point.y - r_tr) * dash_velocity; + } + } + } + } + float dash_length = dash_length_per_width / dash_period_per_width; + float desired_dash_gap = dash_gap_per_width / dash_period_per_width; + + // Straight borders should start and end with a dash, so max_t is + // reduced to cause this. + max_t -= unrounded ? dash_length : 0.0; + if (max_t >= 1.0) { + // Adjust dash gap to evenly divide max_t + float dash_count = floor(max_t); + float dash_period = max_t / dash_count; + border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); + } else if (unrounded) { + // When there isn't enough space for the full gap between the + // two start / end dashes of a straight border, reduce gap to + // make them fit. + float dash_gap = max_t - dash_length; + if (dash_gap > 0.0) { + float dash_period = dash_length + dash_gap; + border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); + } + } + } + + // Blend the border on top of the background and then linearly interpolate + // between the two as we slide inside the background. + float4 blended_border = over(background_color, border_color); + color = lerp(background_color, blended_border, + saturate(antialias_threshold - inner_sdf)); + } + + return color * float4(1.0, 1.0, 1.0, saturate(antialias_threshold - outer_sdf)); +} + +/* +** +** Shadows +** +*/ + +struct Shadow { + uint order; + float blur_radius; + Bounds bounds; + Corners corner_radii; + Bounds content_mask; + Hsla color; +}; + +struct ShadowVertexOutput { + nointerpolation uint shadow_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct ShadowFragmentInput { + nointerpolation uint shadow_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; +}; + +StructuredBuffer<Shadow> shadows: register(t1); + +ShadowVertexOutput shadow_vertex(uint vertex_id: SV_VertexID, uint shadow_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Shadow shadow = shadows[shadow_id]; + + float margin = 3.0 * shadow.blur_radius; + Bounds bounds = shadow.bounds; + bounds.origin -= margin; + bounds.size += 2.0 * margin; + + float4 device_position = to_device_position(unit_vertex, bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, bounds, shadow.content_mask); + float4 color = hsla_to_rgba(shadow.color); + + ShadowVertexOutput output; + output.position = device_position; + output.color = color; + output.shadow_id = shadow_id; + output.clip_distance = clip_distance; + + return output; +} + +float4 shadow_fragment(ShadowFragmentInput input): SV_TARGET { + Shadow shadow = shadows[input.shadow_id]; + + float2 half_size = shadow.bounds.size / 2.; + float2 center = shadow.bounds.origin + half_size; + float2 point0 = input.position.xy - center; + float corner_radius = pick_corner_radius(point0, shadow.corner_radii); + + // The signal is only non-zero in a limited range, so don't waste samples + float low = point0.y - half_size.y; + float high = point0.y + half_size.y; + float start = clamp(-3. * shadow.blur_radius, low, high); + float end = clamp(3. * shadow.blur_radius, low, high); + + // Accumulate samples (we can get away with surprisingly few samples) + float step = (end - start) / 4.; + float y = start + step * 0.5; + float alpha = 0.; + for (int i = 0; i < 4; i++) { + alpha += blur_along_x(point0.x, point0.y - y, shadow.blur_radius, + corner_radius, half_size) * + gaussian(y, shadow.blur_radius) * step; + y += step; + } + + return input.color * float4(1., 1., 1., alpha); +} + +/* +** +** Path Rasterization +** +*/ + +struct PathRasterizationSprite { + float2 xy_position; + float2 st_position; + Background color; + Bounds bounds; +}; + +StructuredBuffer<PathRasterizationSprite> path_rasterization_sprites: register(t1); + +struct PathVertexOutput { + float4 position: SV_Position; + float2 st_position: TEXCOORD0; + nointerpolation uint vertex_id: TEXCOORD1; + float4 clip_distance: SV_ClipDistance; +}; + +struct PathFragmentInput { + float4 position: SV_Position; + float2 st_position: TEXCOORD0; + nointerpolation uint vertex_id: TEXCOORD1; +}; + +PathVertexOutput path_rasterization_vertex(uint vertex_id: SV_VertexID) { + PathRasterizationSprite sprite = path_rasterization_sprites[vertex_id]; + + PathVertexOutput output; + output.position = to_device_position_impl(sprite.xy_position); + output.st_position = sprite.st_position; + output.vertex_id = vertex_id; + output.clip_distance = distance_from_clip_rect_impl(sprite.xy_position, sprite.bounds); + + return output; +} + +float4 path_rasterization_fragment(PathFragmentInput input): SV_Target { + float2 dx = ddx(input.st_position); + float2 dy = ddy(input.st_position); + PathRasterizationSprite sprite = path_rasterization_sprites[input.vertex_id]; + + Background background = sprite.color; + Bounds bounds = sprite.bounds; + + float alpha; + if (length(float2(dx.x, dy.x))) { + alpha = 1.0; + } else { + float2 gradient = 2.0 * input.st_position.xx * float2(dx.x, dy.x) - float2(dx.y, dy.y); + float f = input.st_position.x * input.st_position.x - input.st_position.y; + float distance = f / length(gradient); + alpha = saturate(0.5 - distance); + } + + GradientColor gradient = prepare_gradient_color( + background.tag, background.color_space, background.solid, background.colors); + + float4 color = gradient_color(background, input.position.xy, bounds, + gradient.solid, gradient.color0, gradient.color1); + return float4(color.rgb * color.a * alpha, alpha * color.a); +} + +/* +** +** Path Sprites +** +*/ + +struct PathSprite { + Bounds bounds; +}; + +struct PathSpriteVertexOutput { + float4 position: SV_Position; + float2 texture_coords: TEXCOORD0; +}; + +StructuredBuffer<PathSprite> path_sprites: register(t1); + +PathSpriteVertexOutput path_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + PathSprite sprite = path_sprites[sprite_id]; + + // Don't apply content mask because it was already accounted for when rasterizing the path + float4 device_position = to_device_position(unit_vertex, sprite.bounds); + + float2 screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size; + float2 texture_coords = screen_position / global_viewport_size; + + PathSpriteVertexOutput output; + output.position = device_position; + output.texture_coords = texture_coords; + return output; +} + +float4 path_sprite_fragment(PathSpriteVertexOutput input): SV_Target { + return t_sprite.Sample(s_sprite, input.texture_coords); +} + +/* +** +** Underlines +** +*/ + +struct Underline { + uint order; + uint pad; + Bounds bounds; + Bounds content_mask; + Hsla color; + float thickness; + uint wavy; +}; + +struct UnderlineVertexOutput { + nointerpolation uint underline_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct UnderlineFragmentInput { + nointerpolation uint underline_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; +}; + +StructuredBuffer<Underline> underlines: register(t1); + +UnderlineVertexOutput underline_vertex(uint vertex_id: SV_VertexID, uint underline_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Underline underline = underlines[underline_id]; + float4 device_position = to_device_position(unit_vertex, underline.bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, underline.bounds, + underline.content_mask); + float4 color = hsla_to_rgba(underline.color); + + UnderlineVertexOutput output; + output.position = device_position; + output.color = color; + output.underline_id = underline_id; + output.clip_distance = clip_distance; + return output; +} + +float4 underline_fragment(UnderlineFragmentInput input): SV_Target { + Underline underline = underlines[input.underline_id]; + if (underline.wavy) { + float half_thickness = underline.thickness * 0.5; + float2 origin = underline.bounds.origin; + float2 st = ((input.position.xy - origin) / underline.bounds.size.y) - float2(0., 0.5); + float frequency = (M_PI_F * (3. * underline.thickness)) / 8.; + float amplitude = 1. / (2. * underline.thickness); + float sine = sin(st.x * frequency) * amplitude; + float dSine = cos(st.x * frequency) * amplitude * frequency; + float distance = (st.y - sine) / sqrt(1. + dSine * dSine); + float distance_in_pixels = distance * underline.bounds.size.y; + float distance_from_top_border = distance_in_pixels - half_thickness; + float distance_from_bottom_border = distance_in_pixels + half_thickness; + float alpha = saturate( + 0.5 - max(-distance_from_bottom_border, distance_from_top_border)); + return input.color * float4(1., 1., 1., alpha); + } else { + return input.color; + } +} + +/* +** +** Monochrome sprites +** +*/ + +struct MonochromeSprite { + uint order; + uint pad; + Bounds bounds; + Bounds content_mask; + Hsla color; + AtlasTile tile; + TransformationMatrix transformation; +}; + +struct MonochromeSpriteVertexOutput { + float4 position: SV_Position; + float2 tile_position: POSITION; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct MonochromeSpriteFragmentInput { + float4 position: SV_Position; + float2 tile_position: POSITION; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +StructuredBuffer<MonochromeSprite> mono_sprites: register(t1); + +MonochromeSpriteVertexOutput monochrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + MonochromeSprite sprite = mono_sprites[sprite_id]; + float4 device_position = + to_device_position_transformed(unit_vertex, sprite.bounds, sprite.transformation); + float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, sprite.content_mask); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile); + float4 color = hsla_to_rgba(sprite.color); + + MonochromeSpriteVertexOutput output; + output.position = device_position; + output.tile_position = tile_position; + output.color = color; + output.clip_distance = clip_distance; + return output; +} + +float4 monochrome_sprite_fragment(MonochromeSpriteFragmentInput input): SV_Target { + float sample = t_sprite.Sample(s_sprite, input.tile_position).r; + return float4(input.color.rgb, input.color.a * sample); +} + +/* +** +** Polychrome sprites +** +*/ + +struct PolychromeSprite { + uint order; + uint pad; + uint grayscale; + float opacity; + Bounds bounds; + Bounds content_mask; + Corners corner_radii; + AtlasTile tile; +}; + +struct PolychromeSpriteVertexOutput { + nointerpolation uint sprite_id: TEXCOORD0; + float4 position: SV_Position; + float2 tile_position: POSITION; + float4 clip_distance: SV_ClipDistance; +}; + +struct PolychromeSpriteFragmentInput { + nointerpolation uint sprite_id: TEXCOORD0; + float4 position: SV_Position; + float2 tile_position: POSITION; +}; + +StructuredBuffer<PolychromeSprite> poly_sprites: register(t1); + +PolychromeSpriteVertexOutput polychrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + PolychromeSprite sprite = poly_sprites[sprite_id]; + float4 device_position = to_device_position(unit_vertex, sprite.bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, + sprite.content_mask); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile); + + PolychromeSpriteVertexOutput output; + output.position = device_position; + output.tile_position = tile_position; + output.sprite_id = sprite_id; + output.clip_distance = clip_distance; + return output; +} + +float4 polychrome_sprite_fragment(PolychromeSpriteFragmentInput input): SV_Target { + PolychromeSprite sprite = poly_sprites[input.sprite_id]; + float4 sample = t_sprite.Sample(s_sprite, input.tile_position); + float distance = quad_sdf(input.position.xy, sprite.bounds, sprite.corner_radii); + + float4 color = sample; + if ((sprite.grayscale & 0xFFu) != 0u) { + float3 grayscale = dot(color.rgb, GRAYSCALE_FACTORS); + color = float4(grayscale, sample.a); + } + color.a *= sprite.opacity * saturate(0.5 - distance); + return color; +} diff --git a/crates/gpui/src/platform/windows/window.rs b/crates/gpui/src/platform/windows/window.rs index 5703a82815..32a6da2391 100644 --- a/crates/gpui/src/platform/windows/window.rs +++ b/crates/gpui/src/platform/windows/window.rs @@ -26,10 +26,9 @@ use windows::{ core::*, }; -use crate::platform::blade::{BladeContext, BladeRenderer}; use crate::*; -pub(crate) struct WindowsWindow(pub Rc<WindowsWindowStatePtr>); +pub(crate) struct WindowsWindow(pub Rc<WindowsWindowInner>); pub struct WindowsWindowState { pub origin: Point<Pixels>, @@ -49,7 +48,7 @@ pub struct WindowsWindowState { pub system_key_handled: bool, pub hovered: bool, - pub renderer: BladeRenderer, + pub renderer: DirectXRenderer, pub click_state: ClickState, pub system_settings: WindowsSystemSettings, @@ -62,9 +61,9 @@ pub struct WindowsWindowState { hwnd: HWND, } -pub(crate) struct WindowsWindowStatePtr { +pub(crate) struct WindowsWindowInner { hwnd: HWND, - this: Weak<Self>, + pub(super) this: Weak<Self>, drop_target_helper: IDropTargetHelper, pub(crate) state: RefCell<WindowsWindowState>, pub(crate) handle: AnyWindowHandle, @@ -80,21 +79,23 @@ pub(crate) struct WindowsWindowStatePtr { impl WindowsWindowState { fn new( hwnd: HWND, - transparent: bool, - cs: &CREATESTRUCTW, + window_params: &CREATESTRUCTW, current_cursor: Option<HCURSOR>, display: WindowsDisplay, - gpu_context: &BladeContext, min_size: Option<Size<Pixels>>, appearance: WindowAppearance, + disable_direct_composition: bool, ) -> Result<Self> { let scale_factor = { let monitor_dpi = unsafe { GetDpiForWindow(hwnd) } as f32; monitor_dpi / USER_DEFAULT_SCREEN_DPI as f32 }; - let origin = logical_point(cs.x as f32, cs.y as f32, scale_factor); + let origin = logical_point(window_params.x as f32, window_params.y as f32, scale_factor); let logical_size = { - let physical_size = size(DevicePixels(cs.cx), DevicePixels(cs.cy)); + let physical_size = size( + DevicePixels(window_params.cx), + DevicePixels(window_params.cy), + ); physical_size.to_pixels(scale_factor) }; let fullscreen_restore_bounds = Bounds { @@ -103,7 +104,8 @@ impl WindowsWindowState { }; let border_offset = WindowBorderOffset::default(); let restore_from_minimized = None; - let renderer = windows_renderer::init(gpu_context, hwnd, transparent)?; + let renderer = DirectXRenderer::new(hwnd, disable_direct_composition) + .context("Creating DirectX renderer")?; let callbacks = Callbacks::default(); let input_handler = None; let pending_surrogate = None; @@ -202,17 +204,16 @@ impl WindowsWindowState { } } -impl WindowsWindowStatePtr { +impl WindowsWindowInner { fn new(context: &WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result<Rc<Self>> { let state = RefCell::new(WindowsWindowState::new( hwnd, - context.transparent, cs, context.current_cursor, context.display, - context.gpu_context, context.min_size, context.appearance, + context.disable_direct_composition, )?); Ok(Rc::new_cyclic(|this| Self { @@ -232,13 +233,13 @@ impl WindowsWindowStatePtr { } fn toggle_fullscreen(&self) { - let Some(state_ptr) = self.this.upgrade() else { + let Some(this) = self.this.upgrade() else { log::error!("Unable to toggle fullscreen: window has been dropped"); return; }; self.executor .spawn(async move { - let mut lock = state_ptr.state.borrow_mut(); + let mut lock = this.state.borrow_mut(); let StyleAndBounds { style, x, @@ -250,10 +251,9 @@ impl WindowsWindowStatePtr { } else { let (window_bounds, _) = lock.calculate_window_bounds(); lock.fullscreen_restore_bounds = window_bounds; - let style = - WINDOW_STYLE(unsafe { get_window_long(state_ptr.hwnd, GWL_STYLE) } as _); + let style = WINDOW_STYLE(unsafe { get_window_long(this.hwnd, GWL_STYLE) } as _); let mut rc = RECT::default(); - unsafe { GetWindowRect(state_ptr.hwnd, &mut rc) }.log_err(); + unsafe { GetWindowRect(this.hwnd, &mut rc) }.log_err(); let _ = lock.fullscreen.insert(StyleAndBounds { style, x: rc.left, @@ -277,10 +277,10 @@ impl WindowsWindowStatePtr { } }; drop(lock); - unsafe { set_window_long(state_ptr.hwnd, GWL_STYLE, style.0 as isize) }; + unsafe { set_window_long(this.hwnd, GWL_STYLE, style.0 as isize) }; unsafe { SetWindowPos( - state_ptr.hwnd, + this.hwnd, None, x, y, @@ -329,12 +329,11 @@ pub(crate) struct Callbacks { pub(crate) appearance_changed: Option<Box<dyn FnMut()>>, } -struct WindowCreateContext<'a> { - inner: Option<Result<Rc<WindowsWindowStatePtr>>>, +struct WindowCreateContext { + inner: Option<Result<Rc<WindowsWindowInner>>>, handle: AnyWindowHandle, hide_title_bar: bool, display: WindowsDisplay, - transparent: bool, is_movable: bool, min_size: Option<Size<Pixels>>, executor: ForegroundExecutor, @@ -343,9 +342,9 @@ struct WindowCreateContext<'a> { drop_target_helper: IDropTargetHelper, validation_number: usize, main_receiver: flume::Receiver<Runnable>, - gpu_context: &'a BladeContext, main_thread_id_win32: u32, appearance: WindowAppearance, + disable_direct_composition: bool, } impl WindowsWindow { @@ -353,7 +352,6 @@ impl WindowsWindow { handle: AnyWindowHandle, params: WindowParams, creation_info: WindowCreationInfo, - gpu_context: &BladeContext, ) -> Result<Self> { let WindowCreationInfo { icon, @@ -364,14 +362,15 @@ impl WindowsWindow { validation_number, main_receiver, main_thread_id_win32, + disable_direct_composition, } = creation_info; - let classname = register_wnd_class(icon); + register_window_class(icon); let hide_title_bar = params .titlebar .as_ref() .map(|titlebar| titlebar.appears_transparent) .unwrap_or(true); - let windowname = HSTRING::from( + let window_name = HSTRING::from( params .titlebar .as_ref() @@ -379,14 +378,18 @@ impl WindowsWindow { .map(|title| title.as_ref()) .unwrap_or(""), ); - let (dwexstyle, mut dwstyle) = if params.kind == WindowKind::PopUp { - (WS_EX_TOOLWINDOW | WS_EX_LAYERED, WINDOW_STYLE(0x0)) + + let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp { + (WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0)) } else { ( - WS_EX_APPWINDOW | WS_EX_LAYERED, + WS_EX_APPWINDOW, WS_THICKFRAME | WS_SYSMENU | WS_MAXIMIZEBOX | WS_MINIMIZEBOX, ) }; + if !disable_direct_composition { + dwexstyle |= WS_EX_NOREDIRECTIONBITMAP; + } let hinstance = get_module_handle(); let display = if let Some(display_id) = params.display_id { @@ -401,7 +404,6 @@ impl WindowsWindow { handle, hide_title_bar, display, - transparent: true, is_movable: params.is_movable, min_size: params.window_min_size, executor, @@ -410,16 +412,15 @@ impl WindowsWindow { drop_target_helper, validation_number, main_receiver, - gpu_context, main_thread_id_win32, appearance, + disable_direct_composition, }; - let lpparam = Some(&context as *const _ as *const _); let creation_result = unsafe { CreateWindowExW( dwexstyle, - classname, - &windowname, + WINDOW_CLASS_NAME, + &window_name, dwstyle, CW_USEDEFAULT, CW_USEDEFAULT, @@ -428,41 +429,35 @@ impl WindowsWindow { None, None, Some(hinstance.into()), - lpparam, + Some(&context as *const _ as *const _), ) }; - // We should call `?` on state_ptr first, then call `?` on hwnd. - // Or, we will lose the error info reported by `WindowsWindowState::new` - let state_ptr = context.inner.take().unwrap()?; + + // Failure to create a `WindowsWindowState` can cause window creation to fail, + // so check the inner result first. + let this = context.inner.take().unwrap()?; let hwnd = creation_result?; - register_drag_drop(state_ptr.clone())?; + + register_drag_drop(&this)?; configure_dwm_dark_mode(hwnd, appearance); - state_ptr.state.borrow_mut().border_offset.update(hwnd)?; + this.state.borrow_mut().border_offset.update(hwnd)?; let placement = retrieve_window_placement( hwnd, display, params.bounds, - state_ptr.state.borrow().scale_factor, - state_ptr.state.borrow().border_offset, + this.state.borrow().scale_factor, + this.state.borrow().border_offset, )?; if params.show { unsafe { SetWindowPlacement(hwnd, &placement)? }; } else { - state_ptr.state.borrow_mut().initial_placement = Some(WindowOpenStatus { + this.state.borrow_mut().initial_placement = Some(WindowOpenStatus { placement, state: WindowOpenState::Windowed, }); } - // The render pipeline will perform compositing on the GPU when the - // swapchain is configured correctly (see downstream of - // update_transparency). - // The following configuration is a one-time setup to ensure that the - // window is going to be composited with per-pixel alpha, but the render - // pipeline is responsible for effectively calling UpdateLayeredWindow - // at the appropriate time. - unsafe { SetLayeredWindowAttributes(hwnd, COLORREF(0), 255, LWA_ALPHA)? }; - Ok(Self(state_ptr)) + Ok(Self(this)) } } @@ -485,7 +480,6 @@ impl rwh::HasDisplayHandle for WindowsWindow { impl Drop for WindowsWindow { fn drop(&mut self) { - self.0.state.borrow_mut().renderer.destroy(); // clone this `Rc` to prevent early release of the pointer let this = self.0.clone(); self.0 @@ -683,6 +677,36 @@ impl PlatformWindow for WindowsWindow { this.set_window_placement().log_err(); unsafe { SetActiveWindow(hwnd).log_err() }; unsafe { SetFocus(Some(hwnd)).log_err() }; + + // premium ragebait by windows, this is needed because the window + // must have received an input event to be able to set itself to foreground + // so let's just simulate user input as that seems to be the most reliable way + // some more info: https://gist.github.com/Aetopia/1581b40f00cc0cadc93a0e8ccb65dc8c + // bonus: this bug also doesn't manifest if you have vs attached to the process + let inputs = [ + INPUT { + r#type: INPUT_KEYBOARD, + Anonymous: INPUT_0 { + ki: KEYBDINPUT { + wVk: VK_MENU, + dwFlags: KEYBD_EVENT_FLAGS(0), + ..Default::default() + }, + }, + }, + INPUT { + r#type: INPUT_KEYBOARD, + Anonymous: INPUT_0 { + ki: KEYBDINPUT { + wVk: VK_MENU, + dwFlags: KEYEVENTF_KEYUP, + ..Default::default() + }, + }, + }, + ]; + unsafe { SendInput(&inputs, std::mem::size_of::<INPUT>() as i32) }; + // todo(windows) // crate `windows 0.56` reports true as Err unsafe { SetForegroundWindow(hwnd).as_bool() }; @@ -705,24 +729,21 @@ impl PlatformWindow for WindowsWindow { } fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) { - let mut window_state = self.0.state.borrow_mut(); - window_state - .renderer - .update_transparency(background_appearance != WindowBackgroundAppearance::Opaque); + let hwnd = self.0.hwnd; match background_appearance { WindowBackgroundAppearance::Opaque => { // ACCENT_DISABLED - set_window_composition_attribute(window_state.hwnd, None, 0); + set_window_composition_attribute(hwnd, None, 0); } WindowBackgroundAppearance::Transparent => { // Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background - set_window_composition_attribute(window_state.hwnd, None, 2); + set_window_composition_attribute(hwnd, None, 2); } WindowBackgroundAppearance::Blurred => { // Enable acrylic blur // ACCENT_ENABLE_ACRYLICBLURBEHIND - set_window_composition_attribute(window_state.hwnd, Some((0, 0, 0, 0)), 4); + set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4); } } } @@ -794,11 +815,11 @@ impl PlatformWindow for WindowsWindow { } fn draw(&self, scene: &Scene) { - self.0.state.borrow_mut().renderer.draw(scene) + self.0.state.borrow_mut().renderer.draw(scene).log_err(); } fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> { - self.0.state.borrow().renderer.sprite_atlas().clone() + self.0.state.borrow().renderer.sprite_atlas() } fn get_raw_handle(&self) -> HWND { @@ -806,16 +827,16 @@ impl PlatformWindow for WindowsWindow { } fn gpu_specs(&self) -> Option<GpuSpecs> { - Some(self.0.state.borrow().renderer.gpu_specs()) + self.0.state.borrow().renderer.gpu_specs().log_err() } fn update_ime_position(&self, _bounds: Bounds<ScaledPixels>) { - // todo(windows) + // There is no such thing on Windows. } } #[implement(IDropTarget)] -struct WindowsDragDropHandler(pub Rc<WindowsWindowStatePtr>); +struct WindowsDragDropHandler(pub Rc<WindowsWindowInner>); impl WindowsDragDropHandler { fn handle_drag_drop(&self, input: PlatformInput) { @@ -1096,15 +1117,15 @@ enum WindowOpenState { Windowed, } -fn register_wnd_class(icon_handle: HICON) -> PCWSTR { - const CLASS_NAME: PCWSTR = w!("Zed::Window"); +const WINDOW_CLASS_NAME: PCWSTR = w!("Zed::Window"); +fn register_window_class(icon_handle: HICON) { static ONCE: Once = Once::new(); ONCE.call_once(|| { let wc = WNDCLASSW { - lpfnWndProc: Some(wnd_proc), + lpfnWndProc: Some(window_procedure), hIcon: icon_handle, - lpszClassName: PCWSTR(CLASS_NAME.as_ptr()), + lpszClassName: PCWSTR(WINDOW_CLASS_NAME.as_ptr()), style: CS_HREDRAW | CS_VREDRAW, hInstance: get_module_handle().into(), hbrBackground: unsafe { CreateSolidBrush(COLORREF(0x00000000)) }, @@ -1112,54 +1133,58 @@ fn register_wnd_class(icon_handle: HICON) -> PCWSTR { }; unsafe { RegisterClassW(&wc) }; }); - - CLASS_NAME } -unsafe extern "system" fn wnd_proc( +unsafe extern "system" fn window_procedure( hwnd: HWND, msg: u32, wparam: WPARAM, lparam: LPARAM, ) -> LRESULT { if msg == WM_NCCREATE { - let cs = lparam.0 as *const CREATESTRUCTW; - let cs = unsafe { &*cs }; - let ctx = cs.lpCreateParams as *mut WindowCreateContext; - let ctx = unsafe { &mut *ctx }; - let creation_result = WindowsWindowStatePtr::new(ctx, hwnd, cs); - if creation_result.is_err() { - ctx.inner = Some(creation_result); - return LRESULT(0); - } - let weak = Box::new(Rc::downgrade(creation_result.as_ref().unwrap())); - unsafe { set_window_long(hwnd, GWLP_USERDATA, Box::into_raw(weak) as isize) }; - ctx.inner = Some(creation_result); - return unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) }; + let window_params = lparam.0 as *const CREATESTRUCTW; + let window_params = unsafe { &*window_params }; + let window_creation_context = window_params.lpCreateParams as *mut WindowCreateContext; + let window_creation_context = unsafe { &mut *window_creation_context }; + return match WindowsWindowInner::new(window_creation_context, hwnd, window_params) { + Ok(window_state) => { + let weak = Box::new(Rc::downgrade(&window_state)); + unsafe { set_window_long(hwnd, GWLP_USERDATA, Box::into_raw(weak) as isize) }; + window_creation_context.inner = Some(Ok(window_state)); + unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) } + } + Err(error) => { + window_creation_context.inner = Some(Err(error)); + LRESULT(0) + } + }; } - let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowStatePtr>; + + let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowInner>; if ptr.is_null() { return unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) }; } let inner = unsafe { &*ptr }; - let r = if let Some(state) = inner.upgrade() { - handle_msg(hwnd, msg, wparam, lparam, state) + let result = if let Some(inner) = inner.upgrade() { + inner.handle_msg(hwnd, msg, wparam, lparam) } else { unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) } }; + if msg == WM_NCDESTROY { unsafe { set_window_long(hwnd, GWLP_USERDATA, 0) }; unsafe { drop(Box::from_raw(ptr)) }; } - r + + result } -pub(crate) fn try_get_window_inner(hwnd: HWND) -> Option<Rc<WindowsWindowStatePtr>> { +pub(crate) fn window_from_hwnd(hwnd: HWND) -> Option<Rc<WindowsWindowInner>> { if hwnd.is_invalid() { return None; } - let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowStatePtr>; + let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowInner>; if !ptr.is_null() { let inner = unsafe { &*ptr }; inner.upgrade() @@ -1182,9 +1207,9 @@ fn get_module_handle() -> HMODULE { } } -fn register_drag_drop(state_ptr: Rc<WindowsWindowStatePtr>) -> Result<()> { - let window_handle = state_ptr.hwnd; - let handler = WindowsDragDropHandler(state_ptr); +fn register_drag_drop(window: &Rc<WindowsWindowInner>) -> Result<()> { + let window_handle = window.hwnd; + let handler = WindowsDragDropHandler(window.clone()); // The lifetime of `IDropTarget` is handled by Windows, it won't release until // we call `RevokeDragDrop`. // So, it's safe to drop it here. @@ -1306,52 +1331,6 @@ fn set_window_composition_attribute(hwnd: HWND, color: Option<Color>, state: u32 } } -mod windows_renderer { - use crate::platform::blade::{BladeContext, BladeRenderer, BladeSurfaceConfig}; - use raw_window_handle as rwh; - use std::num::NonZeroIsize; - use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::GWLP_HINSTANCE}; - - use crate::{get_window_long, show_error}; - - pub(super) fn init( - context: &BladeContext, - hwnd: HWND, - transparent: bool, - ) -> anyhow::Result<BladeRenderer> { - let raw = RawWindow { hwnd }; - let config = BladeSurfaceConfig { - size: Default::default(), - transparent, - }; - BladeRenderer::new(context, &raw, config) - .inspect_err(|err| show_error("Failed to initialize BladeRenderer", err.to_string())) - } - - struct RawWindow { - hwnd: HWND, - } - - impl rwh::HasWindowHandle for RawWindow { - fn window_handle(&self) -> Result<rwh::WindowHandle<'_>, rwh::HandleError> { - Ok(unsafe { - let hwnd = NonZeroIsize::new_unchecked(self.hwnd.0 as isize); - let mut handle = rwh::Win32WindowHandle::new(hwnd); - let hinstance = get_window_long(self.hwnd, GWLP_HINSTANCE); - handle.hinstance = NonZeroIsize::new(hinstance); - rwh::WindowHandle::borrow_raw(handle.into()) - }) - } - } - - impl rwh::HasDisplayHandle for RawWindow { - fn display_handle(&self) -> Result<rwh::DisplayHandle<'_>, rwh::HandleError> { - let handle = rwh::WindowsDisplayHandle::new(); - Ok(unsafe { rwh::DisplayHandle::borrow_raw(handle.into()) }) - } - } -} - #[cfg(test)] mod tests { use super::ClickState; diff --git a/crates/gpui/src/tab_stop.rs b/crates/gpui/src/tab_stop.rs index 2ec3f560e8..7dde42efed 100644 --- a/crates/gpui/src/tab_stop.rs +++ b/crates/gpui/src/tab_stop.rs @@ -5,7 +5,7 @@ use crate::{FocusHandle, FocusId}; /// Used to manage the `Tab` event to switch between focus handles. #[derive(Default)] pub(crate) struct TabHandles { - handles: Vec<FocusHandle>, + pub(crate) handles: Vec<FocusHandle>, } impl TabHandles { @@ -32,20 +32,18 @@ impl TabHandles { self.handles.clear(); } - fn current_index(&self, focused_id: Option<&FocusId>) -> usize { - self.handles - .iter() - .position(|h| Some(&h.id) == focused_id) - .unwrap_or_default() + fn current_index(&self, focused_id: Option<&FocusId>) -> Option<usize> { + self.handles.iter().position(|h| Some(&h.id) == focused_id) } pub(crate) fn next(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { - let ix = self.current_index(focused_id); - - let mut next_ix = ix + 1; - if next_ix + 1 > self.handles.len() { - next_ix = 0; - } + let next_ix = self + .current_index(focused_id) + .and_then(|ix| { + let next_ix = ix + 1; + (next_ix < self.handles.len()).then_some(next_ix) + }) + .unwrap_or_default(); if let Some(next_handle) = self.handles.get(next_ix) { Some(next_handle.clone()) @@ -55,7 +53,7 @@ impl TabHandles { } pub(crate) fn prev(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { - let ix = self.current_index(focused_id); + let ix = self.current_index(focused_id).unwrap_or_default(); let prev_ix; if ix == 0 { prev_ix = self.handles.len().saturating_sub(1); @@ -108,8 +106,14 @@ mod tests { ] ); - // next - assert_eq!(tab.next(None), Some(tab.handles[1].clone())); + // Select first tab index if no handle is currently focused. + assert_eq!(tab.next(None), Some(tab.handles[0].clone())); + // Select last tab index if no handle is currently focused. + assert_eq!( + tab.prev(None), + Some(tab.handles[tab.handles.len() - 1].clone()) + ); + assert_eq!( tab.next(Some(&tab.handles[0].id)), Some(tab.handles[1].clone()) diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 963d2bb45c..9e4c1c26c5 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -702,6 +702,7 @@ pub(crate) struct PaintIndex { input_handlers_index: usize, cursor_styles_index: usize, accessed_element_states_index: usize, + tab_handle_index: usize, line_layout_index: LineLayoutIndex, } @@ -1019,7 +1020,7 @@ impl Window { || (active.get() && last_input_timestamp.get().elapsed() < Duration::from_secs(1)); - if invalidator.is_dirty() { + if invalidator.is_dirty() || request_frame_options.force_render { measure("frame duration", || { handle .update(&mut cx, |_, window, cx| { @@ -2208,6 +2209,7 @@ impl Window { input_handlers_index: self.next_frame.input_handlers.len(), cursor_styles_index: self.next_frame.cursor_styles.len(), accessed_element_states_index: self.next_frame.accessed_element_states.len(), + tab_handle_index: self.next_frame.tab_handles.handles.len(), line_layout_index: self.text_system.layout_index(), } } @@ -2237,6 +2239,12 @@ impl Window { .iter() .map(|(id, type_id)| (GlobalElementId(id.0.clone()), *type_id)), ); + self.next_frame.tab_handles.handles.extend( + self.rendered_frame.tab_handles.handles + [range.start.tab_handle_index..range.end.tab_handle_index] + .iter() + .cloned(), + ); self.text_system .reuse_layouts(range.start.line_layout_index..range.end.line_layout_index); @@ -4691,6 +4699,8 @@ pub enum ElementId { Path(Arc<std::path::Path>), /// A code location. CodeLocation(core::panic::Location<'static>), + /// A labeled child of an element. + NamedChild(Box<ElementId>, SharedString), } impl ElementId { @@ -4711,6 +4721,7 @@ impl Display for ElementId { ElementId::Uuid(uuid) => write!(f, "{}", uuid)?, ElementId::Path(path) => write!(f, "{}", path.display())?, ElementId::CodeLocation(location) => write!(f, "{}", location)?, + ElementId::NamedChild(id, name) => write!(f, "{}-{}", id, name)?, } Ok(()) @@ -4801,6 +4812,12 @@ impl From<(&'static str, u32)> for ElementId { } } +impl<T: Into<SharedString>> From<(ElementId, T)> for ElementId { + fn from((id, name): (ElementId, T)) -> Self { + ElementId::NamedChild(Box::new(id), name.into()) + } +} + /// A rectangle to be rendered in the window at the given position and size. /// Passed as an argument [`Window::paint_quad`]. #[derive(Clone)] diff --git a/crates/http_client/Cargo.toml b/crates/http_client/Cargo.toml index 2045708ff2..f63bff295e 100644 --- a/crates/http_client/Cargo.toml +++ b/crates/http_client/Cargo.toml @@ -23,6 +23,8 @@ futures.workspace = true http.workspace = true http-body.workspace = true log.workspace = true +parking_lot.workspace = true +reqwest.workspace = true serde.workspace = true serde_json.workspace = true url.workspace = true diff --git a/crates/http_client/src/async_body.rs b/crates/http_client/src/async_body.rs index 88972d279c..473849f3cd 100644 --- a/crates/http_client/src/async_body.rs +++ b/crates/http_client/src/async_body.rs @@ -88,6 +88,17 @@ impl From<&'static str> for AsyncBody { } } +impl TryFrom<reqwest::Body> for AsyncBody { + type Error = anyhow::Error; + + fn try_from(value: reqwest::Body) -> Result<Self, Self::Error> { + value + .as_bytes() + .ok_or_else(|| anyhow::anyhow!("Underlying data is a stream")) + .map(|bytes| Self::from_bytes(Bytes::copy_from_slice(bytes))) + } +} + impl<T: Into<Self>> From<Option<T>> for AsyncBody { fn from(body: Option<T>) -> Self { match body { diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index eebab86e21..a7f75b0962 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -4,16 +4,18 @@ pub mod github; pub use anyhow::{Result, anyhow}; pub use async_body::{AsyncBody, Inner}; use derive_more::Deref; +use http::HeaderValue; pub use http::{self, Method, Request, Response, StatusCode, Uri}; -use futures::future::BoxFuture; +use futures::{ + FutureExt as _, + future::{self, BoxFuture}, +}; use http::request::Builder; +use parking_lot::Mutex; #[cfg(feature = "test-support")] use std::fmt; -use std::{ - any::type_name, - sync::{Arc, Mutex}, -}; +use std::{any::type_name, sync::Arc}; pub use url::Url; #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] @@ -39,6 +41,8 @@ impl HttpRequestExt for http::request::Builder { pub trait HttpClient: 'static + Send + Sync { fn type_name(&self) -> &'static str; + fn user_agent(&self) -> Option<&HeaderValue>; + fn send( &self, req: http::Request<AsyncBody>, @@ -83,6 +87,19 @@ pub trait HttpClient: 'static + Send + Sync { } fn proxy(&self) -> Option<&Url>; + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::<Self>()) + } + + fn send_multipart_form<'a>( + &'a self, + _url: &str, + _request: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + future::ready(Err(anyhow!("not implemented"))).boxed() + } } /// An [`HttpClient`] that may have a proxy. @@ -118,21 +135,8 @@ impl HttpClient for HttpClientWithProxy { self.client.send(req) } - fn proxy(&self) -> Option<&Url> { - self.proxy.as_ref() - } - - fn type_name(&self) -> &'static str { - self.client.type_name() - } -} - -impl HttpClient for Arc<HttpClientWithProxy> { - fn send( - &self, - req: Request<AsyncBody>, - ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - self.client.send(req) + fn user_agent(&self) -> Option<&HeaderValue> { + self.client.user_agent() } fn proxy(&self) -> Option<&Url> { @@ -142,6 +146,19 @@ impl HttpClient for Arc<HttpClientWithProxy> { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } + + fn send_multipart_form<'a>( + &'a self, + url: &str, + form: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + self.client.send_multipart_form(url, form) + } } /// An [`HttpClient`] that has a base URL. @@ -188,20 +205,13 @@ impl HttpClientWithUrl { /// Returns the base URL. pub fn base_url(&self) -> String { - self.base_url - .lock() - .map_or_else(|_| Default::default(), |url| url.clone()) + self.base_url.lock().clone() } /// Sets the base URL. pub fn set_base_url(&self, base_url: impl Into<String>) { let base_url = base_url.into(); - self.base_url - .lock() - .map(|mut url| { - *url = base_url; - }) - .ok(); + *self.base_url.lock() = base_url; } /// Builds a URL using the given path. @@ -225,6 +235,22 @@ impl HttpClientWithUrl { )?) } + /// Builds a Zed Cloud URL using the given path. + pub fn build_zed_cloud_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> { + let base_url = self.base_url(); + let base_api_url = match base_url.as_ref() { + "https://zed.dev" => "https://cloud.zed.dev", + "https://staging.zed.dev" => "https://cloud.zed.dev", + "http://localhost:3000" => "http://localhost:8787", + other => other, + }; + + Ok(Url::parse_with_params( + &format!("{}{}", base_api_url, path), + query, + )?) + } + /// Builds a Zed LLM URL using the given path. pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> { let base_url = self.base_url(); @@ -242,23 +268,6 @@ impl HttpClientWithUrl { } } -impl HttpClient for Arc<HttpClientWithUrl> { - fn send( - &self, - req: Request<AsyncBody>, - ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - self.client.send(req) - } - - fn proxy(&self) -> Option<&Url> { - self.client.proxy.as_ref() - } - - fn type_name(&self) -> &'static str { - self.client.type_name() - } -} - impl HttpClient for HttpClientWithUrl { fn send( &self, @@ -267,6 +276,10 @@ impl HttpClient for HttpClientWithUrl { self.client.send(req) } + fn user_agent(&self) -> Option<&HeaderValue> { + self.client.user_agent() + } + fn proxy(&self) -> Option<&Url> { self.client.proxy.as_ref() } @@ -274,6 +287,19 @@ impl HttpClient for HttpClientWithUrl { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } + + fn send_multipart_form<'a>( + &'a self, + url: &str, + request: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + self.client.send_multipart_form(url, request) + } } pub fn read_proxy_from_env() -> Option<Url> { @@ -314,6 +340,10 @@ impl HttpClient for BlockedHttpClient { }) } + fn user_agent(&self) -> Option<&HeaderValue> { + None + } + fn proxy(&self) -> Option<&Url> { None } @@ -321,10 +351,15 @@ impl HttpClient for BlockedHttpClient { fn type_name(&self) -> &'static str { type_name::<Self>() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::<Self>()) + } } #[cfg(feature = "test-support")] -type FakeHttpHandler = Box< +type FakeHttpHandler = Arc< dyn Fn(Request<AsyncBody>) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> + Send + Sync @@ -333,7 +368,8 @@ type FakeHttpHandler = Box< #[cfg(feature = "test-support")] pub struct FakeHttpClient { - handler: FakeHttpHandler, + handler: Mutex<Option<FakeHttpHandler>>, + user_agent: HeaderValue, } #[cfg(feature = "test-support")] @@ -347,7 +383,8 @@ impl FakeHttpClient { base_url: Mutex::new("http://test.example".into()), client: HttpClientWithProxy { client: Arc::new(Self { - handler: Box::new(move |req| Box::pin(handler(req))), + handler: Mutex::new(Some(Arc::new(move |req| Box::pin(handler(req))))), + user_agent: HeaderValue::from_static(type_name::<Self>()), }), proxy: None, }, @@ -371,6 +408,18 @@ impl FakeHttpClient { .unwrap()) }) } + + pub fn replace_handler<Fut, F>(&self, new_handler: F) + where + Fut: futures::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send + 'static, + F: Fn(FakeHttpHandler, Request<AsyncBody>) -> Fut + Send + Sync + 'static, + { + let mut handler = self.handler.lock(); + let old_handler = handler.take().unwrap(); + *handler = Some(Arc::new(move |req| { + Box::pin(new_handler(old_handler.clone(), req)) + })); + } } #[cfg(feature = "test-support")] @@ -386,10 +435,14 @@ impl HttpClient for FakeHttpClient { &self, req: Request<AsyncBody>, ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - let future = (self.handler)(req); + let future = (self.handler.lock().as_ref().unwrap())(req); future } + fn user_agent(&self) -> Option<&HeaderValue> { + Some(&self.user_agent) + } + fn proxy(&self) -> Option<&Url> { None } @@ -397,4 +450,8 @@ impl HttpClient for FakeHttpClient { fn type_name(&self) -> &'static str { type_name::<Self>() } + + fn as_fake(&self) -> &FakeHttpClient { + self + } } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index e7066ae151..a94d89bdc8 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -38,7 +38,6 @@ pub enum IconName { ArrowUpFromLine, ArrowUpRight, ArrowUpRightAlt, - AtSign, AudioOff, AudioOn, Backspace, @@ -48,15 +47,13 @@ pub enum IconName { BellRing, Binary, Blocks, - Bolt, + BoltOutlined, BoltFilled, - BoltFilledAlt, Book, BookCopy, - BookPlus, - Brain, BugOff, CaseSensitive, + Chat, Check, CheckDouble, ChevronDown, @@ -71,6 +68,7 @@ pub enum IconName { CircleHelp, Close, Cloud, + CloudDownload, Code, Cog, Command, @@ -106,6 +104,12 @@ pub enum IconName { Disconnected, DocumentText, Download, + EditorAtom, + EditorCursor, + EditorEmacs, + EditorJetBrains, + EditorSublime, + EditorVsCode, Ellipsis, EllipsisVertical, Envelope, @@ -177,14 +181,9 @@ pub enum IconName { Maximize, Menu, MenuAlt, - MessageBubbles, Mic, MicMute, - Microscope, Minimize, - NewFromSummary, - NewTextThread, - NewThread, Option, PageDown, PageUp, @@ -195,9 +194,7 @@ pub enum IconName { PersonCircle, PhoneIncoming, Pin, - Play, - PlayAlt, - PlayBug, + PlayOutlined, PlayFilled, Plus, PocketKnife, @@ -214,7 +211,6 @@ pub enum IconName { ReplyArrowRight, Rerun, Return, - Reveal, RotateCcw, RotateCw, Route, @@ -228,6 +224,7 @@ pub enum IconName { Server, Settings, SettingsAlt, + ShieldCheck, Shift, Slash, SlashSquare, @@ -238,7 +235,6 @@ pub enum IconName { Sparkle, SparkleAlt, SparkleFilled, - Spinner, Split, SplitAlt, SquareDot, @@ -248,7 +244,6 @@ pub enum IconName { StarFilled, Stop, StopFilled, - Strikethrough, Supermaven, SupermavenDisabled, SupermavenError, @@ -258,6 +253,9 @@ pub enum IconName { Terminal, TerminalAlt, TextSnippet, + TextThread, + Thread, + ThreadFromSummary, ThumbsDown, ThumbsUp, TodoComplete, @@ -277,7 +275,6 @@ pub enum IconName { ToolTerminal, ToolWeb, Trash, - TrashAlt, Triangle, TriangleRight, Undo, diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 1df33286ee..894625b982 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -161,12 +161,11 @@ pub struct CachedLspAdapter { pub name: LanguageServerName, pub disk_based_diagnostic_sources: Vec<String>, pub disk_based_diagnostics_progress_token: Option<String>, - language_ids: HashMap<String, String>, + language_ids: HashMap<LanguageName, String>, pub adapter: Arc<dyn LspAdapter>, pub reinstall_attempt_count: AtomicU64, cached_binary: futures::lock::Mutex<Option<LanguageServerBinary>>, manifest_name: OnceLock<Option<ManifestName>>, - attach_kind: OnceLock<Attach>, } impl Debug for CachedLspAdapter { @@ -202,7 +201,6 @@ impl CachedLspAdapter { adapter, cached_binary: Default::default(), reinstall_attempt_count: AtomicU64::new(0), - attach_kind: Default::default(), manifest_name: Default::default(), }) } @@ -279,38 +277,25 @@ impl CachedLspAdapter { pub fn language_id(&self, language_name: &LanguageName) -> String { self.language_ids - .get(language_name.as_ref()) + .get(language_name) .cloned() .unwrap_or_else(|| language_name.lsp_id()) } + pub fn manifest_name(&self) -> Option<ManifestName> { self.manifest_name .get_or_init(|| self.adapter.manifest_name()) .clone() } - pub fn attach_kind(&self) -> Attach { - *self.attach_kind.get_or_init(|| self.adapter.attach_kind()) - } } +/// Determines what gets sent out as a workspace folders content #[derive(Clone, Copy, Debug, PartialEq)] -pub enum Attach { - /// Create a single language server instance per subproject root. - InstancePerRoot, - /// Use one shared language server instance for all subprojects within a project. - Shared, -} - -impl Attach { - pub fn root_path( - &self, - root_subproject_path: (WorktreeId, Arc<Path>), - ) -> (WorktreeId, Arc<Path>) { - match self { - Attach::InstancePerRoot => root_subproject_path, - Attach::Shared => (root_subproject_path.0, Arc::from(Path::new(""))), - } - } +pub enum WorkspaceFoldersContent { + /// Send out a single entry with the root of the workspace. + WorktreeRoot, + /// Send out a list of subproject roots. + SubprojectRoots, } /// [`LspAdapterDelegate`] allows [`LspAdapter]` implementations to interface with the application @@ -589,8 +574,8 @@ pub trait LspAdapter: 'static + Send + Sync { None } - fn language_ids(&self) -> HashMap<String, String> { - Default::default() + fn language_ids(&self) -> HashMap<LanguageName, String> { + HashMap::default() } /// Support custom initialize params. @@ -602,8 +587,11 @@ pub trait LspAdapter: 'static + Send + Sync { Ok(original) } - fn attach_kind(&self) -> Attach { - Attach::Shared + /// Determines whether a language server supports workspace folders. + /// + /// And does not trip over itself in the process. + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::SubprojectRoots } fn manifest_name(&self) -> Option<ManifestName> { diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs index ab3c0f9b37..85123d2373 100644 --- a/crates/language/src/language_registry.rs +++ b/crates/language/src/language_registry.rs @@ -411,30 +411,6 @@ impl LanguageRegistry { cached } - pub fn get_or_register_lsp_adapter( - &self, - language_name: LanguageName, - server_name: LanguageServerName, - build_adapter: impl FnOnce() -> Arc<dyn LspAdapter> + 'static, - ) -> Arc<CachedLspAdapter> { - let registered = self - .state - .write() - .lsp_adapters - .entry(language_name.clone()) - .or_default() - .iter() - .find(|cached_adapter| cached_adapter.name == server_name) - .cloned(); - - if let Some(found) = registered { - found - } else { - let adapter = build_adapter(); - self.register_lsp_adapter(language_name, adapter) - } - } - /// Register a fake language server and adapter /// The returned channel receives a new instance of the language server every time it is started #[cfg(any(feature = "test-support", test))] diff --git a/crates/language/src/syntax_map.rs b/crates/language/src/syntax_map.rs index f441114a90..c56ffed066 100644 --- a/crates/language/src/syntax_map.rs +++ b/crates/language/src/syntax_map.rs @@ -17,7 +17,7 @@ use std::{ sync::Arc, }; use streaming_iterator::StreamingIterator; -use sum_tree::{Bias, SeekTarget, SumTree}; +use sum_tree::{Bias, Dimensions, SeekTarget, SumTree}; use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint}; use tree_sitter::{Node, Query, QueryCapture, QueryCaptures, QueryCursor, QueryMatches, Tree}; @@ -285,7 +285,7 @@ impl SyntaxSnapshot { pub fn interpolate(&mut self, text: &BufferSnapshot) { let edits = text - .anchored_edits_since::<(usize, Point)>(&self.interpolated_version) + .anchored_edits_since::<Dimensions<usize, Point>>(&self.interpolated_version) .collect::<Vec<_>>(); self.interpolated_version = text.version().clone(); @@ -333,7 +333,8 @@ impl SyntaxSnapshot { }; let Some(layer) = cursor.item() else { break }; - let (start_byte, start_point) = layer.range.start.summary::<(usize, Point)>(text); + let Dimensions(start_byte, start_point, _) = + layer.range.start.summary::<Dimensions<usize, Point>>(text); // Ignore edits that end before the start of this layer, and don't consider them // for any subsequent layers at this same depth. @@ -562,8 +563,8 @@ impl SyntaxSnapshot { } let Some(step) = step else { break }; - let (step_start_byte, step_start_point) = - step.range.start.summary::<(usize, Point)>(text); + let Dimensions(step_start_byte, step_start_point, _) = + step.range.start.summary::<Dimensions<usize, Point>>(text); let step_end_byte = step.range.end.to_offset(text); let mut old_layer = cursor.item(); diff --git a/crates/language_extension/src/extension_lsp_adapter.rs b/crates/language_extension/src/extension_lsp_adapter.rs index 58fbe6cda2..98b6fd4b5a 100644 --- a/crates/language_extension/src/extension_lsp_adapter.rs +++ b/crates/language_extension/src/extension_lsp_adapter.rs @@ -242,7 +242,7 @@ impl LspAdapter for ExtensionLspAdapter { ])) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { // TODO: The language IDs can be provided via the language server options // in `extension.toml now but we're leaving these existing usages in place temporarily // to avoid any compatibility issues between Zed and the extension versions. @@ -250,7 +250,7 @@ impl LspAdapter for ExtensionLspAdapter { // We can remove once the following extension versions no longer see any use: // - php@0.0.1 if self.extension.manifest().id.as_ref() == "php" { - return HashMap::from_iter([("PHP".into(), "php".into())]); + return HashMap::from_iter([(LanguageName::new("PHP"), "php".into())]); } self.extension diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index b718c530f5..841be60b0e 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,6 +20,7 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true base64.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true @@ -37,7 +38,6 @@ telemetry_events.workspace = true thiserror.workspace = true util.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 54640419b6..1637d2de8a 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -11,6 +11,7 @@ pub mod fake_provider; use anthropic::{AnthropicError, parse_prompt_too_long}; use anyhow::{Result, anyhow}; use client::Client; +use cloud_llm_client::{CompletionMode, CompletionRequestStatus}; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; @@ -26,7 +27,6 @@ use std::time::Duration; use std::{fmt, io}; use thiserror::Error; use util::serde::is_default; -use zed_llm_client::{CompletionMode, CompletionRequestStatus}; pub use crate::model::*; pub use crate::rate_limiter::*; diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 72b7132c60..8ae5893410 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use anyhow::Result; use client::Client; +use cloud_llm_client::Plan; use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, }; -use proto::{Plan, TypedEnvelope}; +use proto::TypedEnvelope; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -30,7 +31,7 @@ pub struct ModelRequestLimitReachedError { impl fmt::Display for ModelRequestLimitReachedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let message = match self.plan { - Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.", + Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.", Plan::ZedPro => { "Model request limit reached. Upgrade to usage-based billing for more requests." } @@ -64,9 +65,14 @@ impl LlmApiToken { mut lock: RwLockWriteGuard<'_, Option<String>>, client: &Arc<Client>, ) -> Result<String> { - let response = client.request(proto::GetLlmToken {}).await?; - *lock = Some(response.token.clone()); - Ok(response.token.clone()) + let system_id = client + .telemetry() + .system_id() + .map(|system_id| system_id.to_string()); + + let response = client.cloud_client().create_llm_token(system_id).await?; + *lock = Some(response.token.0.clone()); + Ok(response.token.0.clone()) } } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 6f3d420ad5..dc485e9937 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,10 +1,9 @@ use std::io::{Cursor, Write}; use std::sync::Arc; -use crate::role::Role; -use crate::{LanguageModelToolUse, LanguageModelToolUseId}; use anyhow::Result; use base64::write::EncoderWriter; +use cloud_llm_client::{CompletionIntent, CompletionMode}; use gpui::{ App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task, point, px, size, @@ -12,7 +11,9 @@ use gpui::{ use image::codecs::png::PngEncoder; use serde::{Deserialize, Serialize}; use util::ResultExt; -use zed_llm_client::{CompletionIntent, CompletionMode}; + +use crate::role::Role; +use crate::{LanguageModelToolUse, LanguageModelToolUseId}; #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct LanguageModelImage { diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 6c2bf6739a..ad4e593d4f 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -16,18 +16,17 @@ ai_onboarding.workspace = true anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true aws-config = { workspace = true, features = ["behavior-version-latest"] } -aws-credential-types = { workspace = true, features = [ - "hardcoded-credentials", -] } +aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] } aws_http_client.workspace = true bedrock.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true component.workspace = true -credentials_provider.workspace = true convert_case.workspace = true copilot.workspace = true +credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } editor.workspace = true fs.workspace = true @@ -36,6 +35,7 @@ google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true gpui_tokio.workspace = true http_client.workspace = true +language.workspace = true language_model.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true @@ -44,10 +44,7 @@ mistral = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } -vercel = { workspace = true, features = ["schemars"] } -x_ai = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true -proto.workspace = true release_channel.workspace = true schemars.workspace = true serde.workspace = true @@ -62,9 +59,9 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } ui.workspace = true ui_input.workspace = true util.workspace = true +vercel = { workspace = true, features = ["schemars"] } workspace-hack.workspace = true -zed_llm_client.workspace = true -language.workspace = true +x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] editor = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 959cbccf39..ef21e85f71 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1012,7 +1012,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Anthropic, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:")) .child( List::new() .child( diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index a86b3e78f5..6df96c5c56 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -1251,7 +1251,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(ConfigurationView::save_credentials)) - .child(Label::new("To use Zed's assistant with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials.")) + .child(Label::new("To use Zed's agent with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials.")) .child(Label::new("But, to access models on AWS, you need to:").mt_1()) .child( List::new() diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 09a2ac6e0a..2108547c4f 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -3,6 +3,13 @@ use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use client::{Client, ModelRequestUsage, UserStore, zed_urls}; +use cloud_llm_client::{ + CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, + CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, + EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, + SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, + TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, +}; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; @@ -20,7 +27,6 @@ use language_model::{ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; -use proto::Plan; use release_channel::AppVersion; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -33,13 +39,6 @@ use std::time::Duration; use thiserror::Error; use ui::{TintColor, prelude::*}; use util::{ResultExt as _, maybe}; -use zed_llm_client::{ - CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, - CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, - ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, - SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, - TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, -}; use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic}; use crate::provider::google::{GoogleEventMapper, into_google}; @@ -120,10 +119,10 @@ pub struct State { user_store: Entity<UserStore>, status: client::Status, accept_terms_of_service_task: Option<Task<Result<()>>>, - models: Vec<Arc<zed_llm_client::LanguageModel>>, - default_model: Option<Arc<zed_llm_client::LanguageModel>>, - default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>, - recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>, + models: Vec<Arc<cloud_llm_client::LanguageModel>>, + default_model: Option<Arc<cloud_llm_client::LanguageModel>>, + default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>, + recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>, _fetch_models_task: Task<()>, _settings_subscription: Subscription, _llm_token_subscription: Subscription, @@ -137,11 +136,10 @@ impl State { cx: &mut Context<Self>, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - Self { client: client.clone(), llm_api_token: LlmApiToken::default(), - user_store, + user_store: user_store.clone(), status, accept_terms_of_service_task: None, models: Vec::new(), @@ -154,8 +152,9 @@ impl State { .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; loop { - let status = this.read_with(cx, |this, _cx| this.status)?; - if matches!(status, client::Status::Connected { .. }) { + let is_authenticated = user_store + .read_with(cx, |user_store, _cx| user_store.current_user().is_some())?; + if is_authenticated { break; } @@ -194,26 +193,20 @@ impl State { } } - fn is_signed_out(&self) -> bool { - self.status.is_signed_out() + fn is_signed_out(&self, cx: &App) -> bool { + self.user_store.read(cx).current_user().is_none() } fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { let client = self.client.clone(); cx.spawn(async move |state, cx| { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.sign_in_with_optional_connect(true, &cx).await?; state.update(cx, |_, cx| cx.notify()) }) } fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false) + self.user_store.read(cx).has_accepted_terms_of_service() } fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) { @@ -238,8 +231,8 @@ impl State { // Right now we represent thinking variants of models as separate models on the client, // so we need to insert variants for any model that supports thinking. if model.supports_thinking { - models.push(Arc::new(zed_llm_client::LanguageModel { - id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), + models.push(Arc::new(cloud_llm_client::LanguageModel { + id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), display_name: format!("{} Thinking", model.display_name), ..model })); @@ -328,7 +321,7 @@ impl CloudLanguageModelProvider { fn create_language_model( &self, - model: Arc<zed_llm_client::LanguageModel>, + model: Arc<cloud_llm_client::LanguageModel>, llm_api_token: LlmApiToken, ) -> Arc<dyn LanguageModel> { Arc::new(CloudLanguageModel { @@ -398,7 +391,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn is_authenticated(&self, cx: &App) -> bool { let state = self.state.read(cx); - !state.is_signed_out() && state.has_accepted_terms_of_service(cx) + !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx) } fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> { @@ -518,7 +511,7 @@ fn render_accept_terms( pub struct CloudLanguageModel { id: LanguageModelId, - model: Arc<zed_llm_client::LanguageModel>, + model: Arc<cloud_llm_client::LanguageModel>, llm_api_token: LlmApiToken, client: Arc<Client>, request_limiter: RateLimiter, @@ -611,13 +604,8 @@ impl CloudLanguageModel { .headers() .get(CURRENT_PLAN_HEADER_NAME) .and_then(|plan| plan.to_str().ok()) - .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok()) + .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) { - let plan = match plan { - zed_llm_client::Plan::ZedFree => Plan::Free, - zed_llm_client::Plan::ZedPro => Plan::ZedPro, - zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial, - }; return Err(anyhow!(ModelRequestLimitReachedError { plan })); } } @@ -729,7 +717,7 @@ impl LanguageModel for CloudLanguageModel { } fn upstream_provider_id(&self) -> LanguageModelProviderId { - use zed_llm_client::LanguageModelProvider::*; + use cloud_llm_client::LanguageModelProvider::*; match self.model.provider { Anthropic => language_model::ANTHROPIC_PROVIDER_ID, OpenAi => language_model::OPEN_AI_PROVIDER_ID, @@ -738,7 +726,7 @@ impl LanguageModel for CloudLanguageModel { } fn upstream_provider_name(&self) -> LanguageModelProviderName { - use zed_llm_client::LanguageModelProvider::*; + use cloud_llm_client::LanguageModelProvider::*; match self.model.provider { Anthropic => language_model::ANTHROPIC_PROVIDER_NAME, OpenAi => language_model::OPEN_AI_PROVIDER_NAME, @@ -772,11 +760,11 @@ impl LanguageModel for CloudLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic - | zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::Anthropic + | cloud_llm_client::LanguageModelProvider::OpenAi => { LanguageModelToolSchemaFormat::JsonSchema } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { LanguageModelToolSchemaFormat::JsonSchemaSubset } } @@ -795,15 +783,15 @@ impl LanguageModel for CloudLanguageModel { fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> { match &self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => { + cloud_llm_client::LanguageModelProvider::Anthropic => { Some(LanguageModelCacheConfiguration { min_total_token: 2_048, should_speculate: true, max_cache_anchors: 4, }) } - zed_llm_client::LanguageModelProvider::OpenAi - | zed_llm_client::LanguageModelProvider::Google => None, + cloud_llm_client::LanguageModelProvider::OpenAi + | cloud_llm_client::LanguageModelProvider::Google => None, } } @@ -813,15 +801,17 @@ impl LanguageModel for CloudLanguageModel { cx: &App, ) -> BoxFuture<'static, Result<u64>> { match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx), - zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::Anthropic => { + count_anthropic_tokens(request, cx) + } + cloud_llm_client::LanguageModelProvider::OpenAi => { let model = match open_ai::Model::from_id(&self.model.id.0) { Ok(model) => model, Err(err) => return async move { Err(anyhow!(err)) }.boxed(), }; count_open_ai_tokens(request, model, cx) } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); let model_id = self.model.id.to_string(); @@ -832,7 +822,7 @@ impl LanguageModel for CloudLanguageModel { let token = llm_api_token.acquire(&client).await?; let request_body = CountTokensBody { - provider: zed_llm_client::LanguageModelProvider::Google, + provider: cloud_llm_client::LanguageModelProvider::Google, model: model_id, provider_request: serde_json::to_value(&google_ai::CountTokensRequest { generate_content_request, @@ -893,7 +883,7 @@ impl LanguageModel for CloudLanguageModel { let app_version = cx.update(|cx| AppVersion::global(cx)).ok(); let thinking_allowed = request.thinking_allowed; match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => { + cloud_llm_client::LanguageModelProvider::Anthropic => { let request = into_anthropic( request, self.model.id.to_string(), @@ -924,7 +914,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::Anthropic, + provider: cloud_llm_client::LanguageModelProvider::Anthropic, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -948,7 +938,7 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } - zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::OpenAi => { let client = self.client.clone(); let model = match open_ai::Model::from_id(&self.model.id.0) { Ok(model) => model, @@ -976,7 +966,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::OpenAi, + provider: cloud_llm_client::LanguageModelProvider::OpenAi, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -996,7 +986,7 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let request = into_google(request, self.model.id.to_string(), GoogleModelMode::Default); @@ -1016,7 +1006,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::Google, + provider: cloud_llm_client::LanguageModelProvider::Google, model: request.model.model_id.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -1040,15 +1030,8 @@ impl LanguageModel for CloudLanguageModel { } } -#[derive(Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CloudCompletionEvent<T> { - Status(CompletionRequestStatus), - Event(T), -} - fn map_cloud_completion_events<T, F>( - stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>, + stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>, mut map_callback: F, ) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> where @@ -1063,10 +1046,10 @@ where Err(error) => { vec![Err(LanguageModelCompletionError::from(error))] } - Ok(CloudCompletionEvent::Status(event)) => { + Ok(CompletionEvent::Status(event)) => { vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))] } - Ok(CloudCompletionEvent::Event(event)) => map_callback(event), + Ok(CompletionEvent::Event(event)) => map_callback(event), }) }) .boxed() @@ -1074,9 +1057,9 @@ where fn usage_updated_event<T>( usage: Option<ModelRequestUsage>, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { futures::stream::iter(usage.map(|usage| { - Ok(CloudCompletionEvent::Status( + Ok(CompletionEvent::Status( CompletionRequestStatus::UsageUpdated { amount: usage.amount as usize, limit: usage.limit, @@ -1087,9 +1070,9 @@ fn usage_updated_event<T>( fn tool_use_limit_reached_event<T>( tool_use_limit_reached: bool, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { futures::stream::iter(tool_use_limit_reached.then(|| { - Ok(CloudCompletionEvent::Status( + Ok(CompletionEvent::Status( CompletionRequestStatus::ToolUseLimitReached, )) })) @@ -1098,7 +1081,7 @@ fn tool_use_limit_reached_event<T>( fn response_lines<T: DeserializeOwned>( response: Response<AsyncBody>, includes_status_messages: bool, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { futures::stream::try_unfold( (String::new(), BufReader::new(response.into_body())), move |(mut line, mut body)| async move { @@ -1106,9 +1089,9 @@ fn response_lines<T: DeserializeOwned>( Ok(0) => Ok(None), Ok(_) => { let event = if includes_status_messages { - serde_json::from_str::<CloudCompletionEvent<T>>(&line)? + serde_json::from_str::<CompletionEvent<T>>(&line)? } else { - CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?) + CompletionEvent::Event(serde_json::from_str::<T>(&line)?) }; line.clear(); @@ -1123,7 +1106,7 @@ fn response_lines<T: DeserializeOwned>( #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, - plan: Option<proto::Plan>, + plan: Option<Plan>, subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>, eligible_for_trial: bool, has_accepted_terms_of_service: bool, @@ -1137,15 +1120,15 @@ impl RenderOnce for ZedAiConfiguration { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let young_account_banner = YoungAccountBanner; - let is_pro = self.plan == Some(proto::Plan::ZedPro); + let is_pro = self.plan == Some(Plan::ZedPro); let subscription_text = match (self.plan, self.subscription_period) { - (Some(proto::Plan::ZedPro), Some(_)) => { + (Some(Plan::ZedPro), Some(_)) => { "You have access to Zed's hosted models through your Pro subscription." } - (Some(proto::Plan::ZedProTrial), Some(_)) => { + (Some(Plan::ZedProTrial), Some(_)) => { "You have access to Zed's hosted models through your Pro trial." } - (Some(proto::Plan::Free), Some(_)) => { + (Some(Plan::ZedFree), Some(_)) => { "You have basic access to Zed's hosted models through the Free plan." } _ => { @@ -1270,8 +1253,8 @@ impl Render for ConfigurationView { let user_store = state.user_store.read(cx); ZedAiConfiguration { - is_connected: !state.is_signed_out(), - plan: user_store.current_plan(), + is_connected: !state.is_signed_out(cx), + plan: user_store.plan(), subscription_period: user_store.subscription_period(), eligible_for_trial: user_store.trial_started_at().is_none(), has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), @@ -1291,7 +1274,7 @@ impl Component for ZedAiConfiguration { fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { fn configuration( is_connected: bool, - plan: Option<proto::Plan>, + plan: Option<Plan>, eligible_for_trial: bool, account_too_young: bool, has_accepted_terms_of_service: bool, @@ -1335,15 +1318,15 @@ impl Component for ZedAiConfiguration { ), single_example( "Free Plan", - configuration(true, Some(proto::Plan::Free), true, false, true), + configuration(true, Some(Plan::ZedFree), true, false, true), ), single_example( "Zed Pro Trial Plan", - configuration(true, Some(proto::Plan::ZedProTrial), true, false, true), + configuration(true, Some(Plan::ZedProTrial), true, false, true), ), single_example( "Zed Pro Plan", - configuration(true, Some(proto::Plan::ZedPro), true, false, true), + configuration(true, Some(Plan::ZedPro), true, false, true), ), ]) .into_any_element(), diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index d9a84f1eb7..73f73a9a31 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -3,6 +3,7 @@ use std::str::FromStr as _; use std::sync::Arc; use anyhow::{Result, anyhow}; +use cloud_llm_client::CompletionIntent; use collections::HashMap; use copilot::copilot_chat::{ ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl, @@ -30,7 +31,6 @@ use settings::SettingsStore; use std::time::Duration; use ui::prelude::*; use util::debug_panic; -use zed_llm_client::CompletionIntent; use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; @@ -706,7 +706,8 @@ impl Render for ConfigurationView { .child(svg().size_8().path(IconName::CopilotError.path())) } _ => { - const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription."; + const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription."; + v_flex().gap_2().child(Label::new(LABEL)).child( Button::new("sign_in", "Sign in to use GitHub Copilot") .icon_color(Color::Muted) diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index bd8a09970a..b287e8181a 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -880,7 +880,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Google AI, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Google AI, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 01600f3646..9792b4f27b 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -744,7 +744,7 @@ impl Render for ConfigurationView { Button::new("retry_lmstudio_models", "Connect") .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon(IconName::Play) + .icon(IconName::PlayOutlined) .on_click(cx.listener(move |this, _, _window, cx| { this.retry_connection(cx) })), diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index fb385308fa..02e53cb99a 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -807,7 +807,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Mistral, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index cb341def2f..f4914ff91e 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -23,7 +23,7 @@ use settings::{Settings, SettingsStore, update_settings_file}; use std::pin::Pin; use std::sync::atomic::{AtomicU64, Ordering}; use std::{collections::HashMap, sync::Arc}; -use ui::{Indicator, List, prelude::*}; +use ui::{ButtonLike, Indicator, List, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; @@ -1004,63 +1004,73 @@ impl Render for ConfigurationView { .w_full() .justify_between() .gap_2() - .child({ - let mut buttons = h_flex() + .child( + h_flex() .w_full() - .gap_2(); - if is_authenticated { - buttons = buttons.child( - Button::new("ollama-site", "Ollama Homepage") + .gap_2() + .map(|this| { + if is_authenticated { + this.child( + Button::new("ollama-site", "Ollama") + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE)) + .into_any_element(), + ) + } else { + this.child( + Button::new( + "download_ollama_button", + "Download Ollama", + ) + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _, cx| { + cx.open_url(OLLAMA_DOWNLOAD_URL) + }) + .into_any_element(), + ) + } + }) + .child( + Button::new("view-models", "View All Models") .style(ButtonStyle::Subtle) .icon(IconName::ArrowUpRight) .icon_size(IconSize::XSmall) .icon_color(Color::Muted) - .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE)) - .into_any_element(), - ); - } else { - buttons = buttons.child( - Button::new( - "download_ollama_button", - "Download Ollama", - ) - .style(ButtonStyle::Filled) - .icon(IconName::Download) - .icon_size(IconSize::XSmall) - .on_click(move |_, _, cx| { - cx.open_url(OLLAMA_DOWNLOAD_URL) - }) - .into_any_element(), - ); - } - buttons.child( - Button::new("view-models", "Browse Models") - .style(ButtonStyle::Subtle) - .icon(IconName::Library) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)), - ) - }) - .child( - if is_authenticated { - h_flex() - .gap_2() - .child(Indicator::dot().color(Color::Success)) - .child(Label::new("Connected").size(LabelSize::Small)) - .into_any_element() - } else { - Button::new("retry_ollama_models", "Connect") - .style(ButtonStyle::Filled) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon(IconName::Play) - .on_click(cx.listener(move |this, _, _, cx| { - this.retry_connection(cx) - })) - .into_any_element() - } + .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)), + ), ) + .map(|this| { + if is_authenticated { + this.child( + ButtonLike::new("connected") + .disabled(true) + .cursor_style(gpui::CursorStyle::Arrow) + .child( + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child(Label::new("Connected")) + .into_any_element(), + ), + ) + } else { + this.child( + Button::new("retry_ollama_models", "Connect") + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon(IconName::PlayOutlined) + .on_click(cx.listener(move |this, _, _, cx| { + this.retry_connection(cx) + })), + ) + } + }) ) ) .into_any() diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 6c4d4c9b3e..5185e979b7 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -780,7 +780,7 @@ impl Render for ConfigurationView { let api_key_section = if self.should_render_editor(cx) { v_flex() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( @@ -868,7 +868,7 @@ impl Render for ConfigurationView { .icon_size(IconSize::XSmall) .icon_color(Color::Muted) .on_click(move |_, _window, cx| { - cx.open_url("https://zed.dev/docs/ai/configuration#openai-api-compatible") + cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible") }), ); diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 64add5483d..38bd7cee06 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -466,7 +466,7 @@ impl Render for ConfigurationView { let api_key_section = if self.should_render_editor(cx) { v_flex() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with an OpenAI compatible provider, you need to add an API key.")) + .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key.")) .child( div() .pt(DynamicSpacing::Base04.rems(cx)) diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 5a6acc4329..3a492086f1 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -855,7 +855,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with OpenRouter, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index d1a90d7dbb..606f3a3f0e 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -253,8 +253,8 @@ impl LogStore { let copilot_subscription = Copilot::global(cx).map(|copilot| { let copilot = &copilot; - cx.subscribe(copilot, |this, copilot, inline_completion_event, cx| { - if let copilot::Event::CopilotLanguageServerStarted = inline_completion_event { + cx.subscribe(copilot, |this, copilot, edit_prediction_event, cx| { + if let copilot::Event::CopilotLanguageServerStarted = edit_prediction_event { if let Some(server) = copilot.read(cx).language_server() { let server_id = server.server_id(); let weak_this = cx.weak_entity(); @@ -867,7 +867,7 @@ impl LspLogView { BINARY = server.binary(), WORKSPACE_FOLDERS = server .workspace_folders() - .iter() + .into_iter() .filter_map(|path| path .to_file_path() .ok() diff --git a/crates/language_tools/src/lsp_tool.rs b/crates/language_tools/src/lsp_tool.rs index 9e95ed4673..50547253a9 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_tool.rs @@ -1015,7 +1015,7 @@ impl Render for LspTool { .anchor(Corner::BottomLeft) .with_handle(self.popover_menu_handle.clone()) .trigger_with_tooltip( - IconButton::new("zed-lsp-tool-button", IconName::BoltFilledAlt) + IconButton::new("zed-lsp-tool-button", IconName::BoltOutlined) .when_some(indicator, IconButton::indicator) .icon_size(IconSize::Small) .indicator_border_color(Some(cx.theme().colors().status_bar_background)), diff --git a/crates/languages/Cargo.toml b/crates/languages/Cargo.toml index 2e8f007cff..260126da63 100644 --- a/crates/languages/Cargo.toml +++ b/crates/languages/Cargo.toml @@ -41,6 +41,7 @@ async-trait.workspace = true chrono.workspace = true collections.workspace = true dap.workspace = true +feature_flags.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true diff --git a/crates/languages/src/json.rs b/crates/languages/src/json.rs index 15818730b8..601b4620c5 100644 --- a/crates/languages/src/json.rs +++ b/crates/languages/src/json.rs @@ -8,8 +8,8 @@ use futures::StreamExt; use gpui::{App, AsyncApp, Task}; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; use language::{ - ContextProvider, LanguageRegistry, LanguageToolchainStore, LocalFile as _, LspAdapter, - LspAdapterDelegate, + ContextProvider, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile as _, + LspAdapter, LspAdapterDelegate, }; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; @@ -408,10 +408,10 @@ impl LspAdapter for JsonLspAdapter { Ok(config) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { [ - ("JSON".into(), "json".into()), - ("JSONC".into(), "jsonc".into()), + (LanguageName::new("JSON"), "json".into()), + (LanguageName::new("JSONC"), "jsonc".into()), ] .into_iter() .collect() diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index a224111002..001fd15200 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -1,4 +1,5 @@ use anyhow::Context as _; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use gpui::{App, UpdateGlobal}; use node_runtime::NodeRuntime; use python::PyprojectTomlManifestProvider; @@ -11,7 +12,7 @@ use util::{ResultExt, asset_str}; pub use language::*; -use crate::json::JsonTaskProvider; +use crate::{json::JsonTaskProvider, python::BasedPyrightLspAdapter}; mod bash; mod c; @@ -52,6 +53,12 @@ pub static LANGUAGE_GIT_COMMIT: std::sync::LazyLock<Arc<Language>> = )) }); +struct BasedPyrightFeatureFlag; + +impl FeatureFlag for BasedPyrightFeatureFlag { + const NAME: &'static str = "basedpyright"; +} + pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { #[cfg(feature = "load-grammars")] languages.register_native_grammars([ @@ -88,6 +95,7 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { let py_lsp_adapter = Arc::new(python::PyLspAdapter::new()); let python_context_provider = Arc::new(python::PythonContextProvider); let python_lsp_adapter = Arc::new(python::PythonLspAdapter::new(node.clone())); + let basedpyright_lsp_adapter = Arc::new(BasedPyrightLspAdapter::new()); let python_toolchain_provider = Arc::new(python::PythonToolchainProvider::default()); let rust_context_provider = Arc::new(rust::RustContextProvider); let rust_lsp_adapter = Arc::new(rust::RustLspAdapter); @@ -228,6 +236,20 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { ); } + let mut basedpyright_lsp_adapter = Some(basedpyright_lsp_adapter); + cx.observe_flag::<BasedPyrightFeatureFlag, _>({ + let languages = languages.clone(); + move |enabled, _| { + if enabled { + if let Some(adapter) = basedpyright_lsp_adapter.take() { + languages + .register_available_lsp_adapter(adapter.name(), move || adapter.clone()); + } + } + } + }) + .detach(); + // Register globally available language servers. // // This will allow users to add support for a built-in language server (e.g., Tailwind) diff --git a/crates/languages/src/python.rs b/crates/languages/src/python.rs index dc6996d399..0524c02fd5 100644 --- a/crates/languages/src/python.rs +++ b/crates/languages/src/python.rs @@ -4,13 +4,13 @@ use async_trait::async_trait; use collections::HashMap; use gpui::{App, Task}; use gpui::{AsyncApp, SharedString}; -use language::Toolchain; use language::ToolchainList; use language::ToolchainLister; use language::language_settings::language_settings; use language::{ContextLocation, LanguageToolchainStore}; use language::{ContextProvider, LspAdapter, LspAdapterDelegate}; use language::{LanguageName, ManifestName, ManifestProvider, ManifestQuery}; +use language::{Toolchain, WorkspaceFoldersContent}; use lsp::LanguageServerBinary; use lsp::LanguageServerName; use node_runtime::NodeRuntime; @@ -400,6 +400,9 @@ impl LspAdapter for PythonLspAdapter { fn manifest_name(&self) -> Option<ManifestName> { Some(SharedString::new_static("pyproject.toml").into()) } + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot + } } async fn get_cached_server_binary( @@ -1282,6 +1285,350 @@ impl LspAdapter for PyLspAdapter { fn manifest_name(&self) -> Option<ManifestName> { Some(SharedString::new_static("pyproject.toml").into()) } + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot + } +} + +pub(crate) struct BasedPyrightLspAdapter { + python_venv_base: OnceCell<Result<Arc<Path>, String>>, +} + +impl BasedPyrightLspAdapter { + const SERVER_NAME: LanguageServerName = LanguageServerName::new_static("basedpyright"); + const BINARY_NAME: &'static str = "basedpyright-langserver"; + + pub(crate) fn new() -> Self { + Self { + python_venv_base: OnceCell::new(), + } + } + + async fn ensure_venv(delegate: &dyn LspAdapterDelegate) -> Result<Arc<Path>> { + let python_path = Self::find_base_python(delegate) + .await + .context("Could not find Python installation for basedpyright")?; + let work_dir = delegate + .language_server_download_dir(&Self::SERVER_NAME) + .await + .context("Could not get working directory for basedpyright")?; + let mut path = PathBuf::from(work_dir.as_ref()); + path.push("basedpyright-venv"); + if !path.exists() { + util::command::new_smol_command(python_path) + .arg("-m") + .arg("venv") + .arg("basedpyright-venv") + .current_dir(work_dir) + .spawn()? + .output() + .await?; + } + + Ok(path.into()) + } + + // Find "baseline", user python version from which we'll create our own venv. + async fn find_base_python(delegate: &dyn LspAdapterDelegate) -> Option<PathBuf> { + for path in ["python3", "python"] { + if let Some(path) = delegate.which(path.as_ref()).await { + return Some(path); + } + } + None + } + + async fn base_venv(&self, delegate: &dyn LspAdapterDelegate) -> Result<Arc<Path>, String> { + self.python_venv_base + .get_or_init(move || async move { + Self::ensure_venv(delegate) + .await + .map_err(|e| format!("{e}")) + }) + .await + .clone() + } +} + +#[async_trait(?Send)] +impl LspAdapter for BasedPyrightLspAdapter { + fn name(&self) -> LanguageServerName { + Self::SERVER_NAME.clone() + } + + async fn initialization_options( + self: Arc<Self>, + _: &dyn Fs, + _: &Arc<dyn LspAdapterDelegate>, + ) -> Result<Option<Value>> { + // Provide minimal initialization options + // Virtual environment configuration will be handled through workspace configuration + Ok(Some(json!({ + "python": { + "analysis": { + "autoSearchPaths": true, + "useLibraryCodeForTypes": true, + "autoImportCompletions": true + } + } + }))) + } + + async fn check_if_user_installed( + &self, + delegate: &dyn LspAdapterDelegate, + toolchains: Arc<dyn LanguageToolchainStore>, + cx: &AsyncApp, + ) -> Option<LanguageServerBinary> { + if let Some(bin) = delegate.which(Self::BINARY_NAME.as_ref()).await { + let env = delegate.shell_env().await; + Some(LanguageServerBinary { + path: bin, + env: Some(env), + arguments: vec!["--stdio".into()], + }) + } else { + let venv = toolchains + .active_toolchain( + delegate.worktree_id(), + Arc::from("".as_ref()), + LanguageName::new("Python"), + &mut cx.clone(), + ) + .await?; + let path = Path::new(venv.path.as_ref()) + .parent()? + .join(Self::BINARY_NAME); + path.exists().then(|| LanguageServerBinary { + path, + arguments: vec!["--stdio".into()], + env: None, + }) + } + } + + async fn fetch_latest_server_version( + &self, + _: &dyn LspAdapterDelegate, + ) -> Result<Box<dyn 'static + Any + Send>> { + Ok(Box::new(()) as Box<_>) + } + + async fn fetch_server_binary( + &self, + _latest_version: Box<dyn 'static + Send + Any>, + _container_dir: PathBuf, + delegate: &dyn LspAdapterDelegate, + ) -> Result<LanguageServerBinary> { + let venv = self.base_venv(delegate).await.map_err(|e| anyhow!(e))?; + let pip_path = venv.join(BINARY_DIR).join("pip3"); + ensure!( + util::command::new_smol_command(pip_path.as_path()) + .arg("install") + .arg("basedpyright") + .arg("-U") + .output() + .await? + .status + .success(), + "basedpyright installation failed" + ); + let pylsp = venv.join(BINARY_DIR).join(Self::BINARY_NAME); + Ok(LanguageServerBinary { + path: pylsp, + env: None, + arguments: vec!["--stdio".into()], + }) + } + + async fn cached_server_binary( + &self, + _container_dir: PathBuf, + delegate: &dyn LspAdapterDelegate, + ) -> Option<LanguageServerBinary> { + let venv = self.base_venv(delegate).await.ok()?; + let pylsp = venv.join(BINARY_DIR).join(Self::BINARY_NAME); + Some(LanguageServerBinary { + path: pylsp, + env: None, + arguments: vec!["--stdio".into()], + }) + } + + async fn process_completions(&self, items: &mut [lsp::CompletionItem]) { + // Pyright assigns each completion item a `sortText` of the form `XX.YYYY.name`. + // Where `XX` is the sorting category, `YYYY` is based on most recent usage, + // and `name` is the symbol name itself. + // + // Because the symbol name is included, there generally are not ties when + // sorting by the `sortText`, so the symbol's fuzzy match score is not taken + // into account. Here, we remove the symbol name from the sortText in order + // to allow our own fuzzy score to be used to break ties. + // + // see https://github.com/microsoft/pyright/blob/95ef4e103b9b2f129c9320427e51b73ea7cf78bd/packages/pyright-internal/src/languageService/completionProvider.ts#LL2873 + for item in items { + let Some(sort_text) = &mut item.sort_text else { + continue; + }; + let mut parts = sort_text.split('.'); + let Some(first) = parts.next() else { continue }; + let Some(second) = parts.next() else { continue }; + let Some(_) = parts.next() else { continue }; + sort_text.replace_range(first.len() + second.len() + 1.., ""); + } + } + + async fn label_for_completion( + &self, + item: &lsp::CompletionItem, + language: &Arc<language::Language>, + ) -> Option<language::CodeLabel> { + let label = &item.label; + let grammar = language.grammar()?; + let highlight_id = match item.kind? { + lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method")?, + lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function")?, + lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type")?, + lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant")?, + _ => return None, + }; + let filter_range = item + .filter_text + .as_deref() + .and_then(|filter| label.find(filter).map(|ix| ix..ix + filter.len())) + .unwrap_or(0..label.len()); + Some(language::CodeLabel { + text: label.clone(), + runs: vec![(0..label.len(), highlight_id)], + filter_range, + }) + } + + async fn label_for_symbol( + &self, + name: &str, + kind: lsp::SymbolKind, + language: &Arc<language::Language>, + ) -> Option<language::CodeLabel> { + let (text, filter_range, display_range) = match kind { + lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => { + let text = format!("def {}():\n", name); + let filter_range = 4..4 + name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + lsp::SymbolKind::CLASS => { + let text = format!("class {}:", name); + let filter_range = 6..6 + name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + lsp::SymbolKind::CONSTANT => { + let text = format!("{} = 0", name); + let filter_range = 0..name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + _ => return None, + }; + + Some(language::CodeLabel { + runs: language.highlight_text(&text.as_str().into(), display_range.clone()), + text: text[display_range].to_string(), + filter_range, + }) + } + + async fn workspace_configuration( + self: Arc<Self>, + _: &dyn Fs, + adapter: &Arc<dyn LspAdapterDelegate>, + toolchains: Arc<dyn LanguageToolchainStore>, + cx: &mut AsyncApp, + ) -> Result<Value> { + let toolchain = toolchains + .active_toolchain( + adapter.worktree_id(), + Arc::from("".as_ref()), + LanguageName::new("Python"), + cx, + ) + .await; + cx.update(move |cx| { + let mut user_settings = + language_server_settings(adapter.as_ref(), &Self::SERVER_NAME, cx) + .and_then(|s| s.settings.clone()) + .unwrap_or_default(); + + // If we have a detected toolchain, configure Pyright to use it + if let Some(toolchain) = toolchain { + if user_settings.is_null() { + user_settings = Value::Object(serde_json::Map::default()); + } + let object = user_settings.as_object_mut().unwrap(); + + let interpreter_path = toolchain.path.to_string(); + + // Detect if this is a virtual environment + if let Some(interpreter_dir) = Path::new(&interpreter_path).parent() { + if let Some(venv_dir) = interpreter_dir.parent() { + // Check if this looks like a virtual environment + if venv_dir.join("pyvenv.cfg").exists() + || venv_dir.join("bin/activate").exists() + || venv_dir.join("Scripts/activate.bat").exists() + { + // Set venvPath and venv at the root level + // This matches the format of a pyrightconfig.json file + if let Some(parent) = venv_dir.parent() { + // Use relative path if the venv is inside the workspace + let venv_path = if parent == adapter.worktree_root_path() { + ".".to_string() + } else { + parent.to_string_lossy().into_owned() + }; + object.insert("venvPath".to_string(), Value::String(venv_path)); + } + + if let Some(venv_name) = venv_dir.file_name() { + object.insert( + "venv".to_owned(), + Value::String(venv_name.to_string_lossy().into_owned()), + ); + } + } + } + } + + // Always set the python interpreter path + // Get or create the python section + let python = object + .entry("python") + .or_insert(Value::Object(serde_json::Map::default())) + .as_object_mut() + .unwrap(); + + // Set both pythonPath and defaultInterpreterPath for compatibility + python.insert( + "pythonPath".to_owned(), + Value::String(interpreter_path.clone()), + ); + python.insert( + "defaultInterpreterPath".to_owned(), + Value::String(interpreter_path), + ); + } + + user_settings + }) + } + + fn manifest_name(&self) -> Option<ManifestName> { + Some(SharedString::new_static("pyproject.toml").into()) + } + + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot + } } #[cfg(test)] diff --git a/crates/languages/src/tailwind.rs b/crates/languages/src/tailwind.rs index cb4e939083..a7edbb148c 100644 --- a/crates/languages/src/tailwind.rs +++ b/crates/languages/src/tailwind.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use collections::HashMap; use futures::StreamExt; use gpui::AsyncApp; -use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; @@ -168,20 +168,20 @@ impl LspAdapter for TailwindLspAdapter { })) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("Astro".to_string(), "astro".to_string()), - ("HTML".to_string(), "html".to_string()), - ("CSS".to_string(), "css".to_string()), - ("JavaScript".to_string(), "javascript".to_string()), - ("TSX".to_string(), "typescriptreact".to_string()), - ("Svelte".to_string(), "svelte".to_string()), - ("Elixir".to_string(), "phoenix-heex".to_string()), - ("HEEX".to_string(), "phoenix-heex".to_string()), - ("ERB".to_string(), "erb".to_string()), - ("HTML/ERB".to_string(), "erb".to_string()), - ("PHP".to_string(), "php".to_string()), - ("Vue.js".to_string(), "vue".to_string()), + (LanguageName::new("Astro"), "astro".to_string()), + (LanguageName::new("HTML"), "html".to_string()), + (LanguageName::new("CSS"), "css".to_string()), + (LanguageName::new("JavaScript"), "javascript".to_string()), + (LanguageName::new("TSX"), "typescriptreact".to_string()), + (LanguageName::new("Svelte"), "svelte".to_string()), + (LanguageName::new("Elixir"), "phoenix-heex".to_string()), + (LanguageName::new("HEEX"), "phoenix-heex".to_string()), + (LanguageName::new("ERB"), "erb".to_string()), + (LanguageName::new("HTML/ERB"), "erb".to_string()), + (LanguageName::new("PHP"), "php".to_string()), + (LanguageName::new("Vue.js"), "vue".to_string()), ]) } } diff --git a/crates/languages/src/typescript.rs b/crates/languages/src/typescript.rs index fb51544841..9dc3ee303d 100644 --- a/crates/languages/src/typescript.rs +++ b/crates/languages/src/typescript.rs @@ -8,7 +8,8 @@ use futures::future::join_all; use gpui::{App, AppContext, AsyncApp, Task}; use http_client::github::{AssetKind, GitHubLspBinaryVersion, build_asset_url}; use language::{ - ContextLocation, ContextProvider, File, LanguageToolchainStore, LspAdapter, LspAdapterDelegate, + ContextLocation, ContextProvider, File, LanguageName, LanguageToolchainStore, LspAdapter, + LspAdapterDelegate, }; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; @@ -741,11 +742,11 @@ impl LspAdapter for TypeScriptLspAdapter { })) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("TypeScript".into(), "typescript".into()), - ("JavaScript".into(), "javascript".into()), - ("TSX".into(), "typescriptreact".into()), + (LanguageName::new("TypeScript"), "typescript".into()), + (LanguageName::new("JavaScript"), "javascript".into()), + (LanguageName::new("TSX"), "typescriptreact".into()), ]) } } diff --git a/crates/languages/src/typescript/runnables.scm b/crates/languages/src/typescript/runnables.scm index 85702cf99d..6bfc536329 100644 --- a/crates/languages/src/typescript/runnables.scm +++ b/crates/languages/src/typescript/runnables.scm @@ -1,4 +1,4 @@ -; Add support for (node:test, bun:test and Jest) runnable +; Add support for (node:test, bun:test, Jest and Deno.test) runnable ; Function expression that has `it`, `test` or `describe` as the function name ( (call_expression @@ -44,3 +44,42 @@ (#set! tag js-test) ) + +; Add support for Deno.test with string names +( + (call_expression + function: (member_expression + object: (identifier) @_namespace + property: (property_identifier) @_method + ) + (#eq? @_namespace "Deno") + (#eq? @_method "test") + arguments: ( + arguments . [ + (string (string_fragment) @run @DENO_TEST_NAME) + (identifier) @run @DENO_TEST_NAME + ] + ) + ) @_js-test + + (#set! tag js-test) +) + +; Add support for Deno.test with named function expressions +( + (call_expression + function: (member_expression + object: (identifier) @_namespace + property: (property_identifier) @_method + ) + (#eq? @_namespace "Deno") + (#eq? @_method "test") + arguments: ( + arguments . (function_expression + name: (identifier) @run @DENO_TEST_NAME + ) + ) + ) @_js-test + + (#set! tag js-test) +) diff --git a/crates/languages/src/vtsls.rs b/crates/languages/src/vtsls.rs index ca07673d5f..33751f733e 100644 --- a/crates/languages/src/vtsls.rs +++ b/crates/languages/src/vtsls.rs @@ -2,7 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use collections::HashMap; use gpui::AsyncApp; -use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; @@ -273,11 +273,11 @@ impl LspAdapter for VtslsLspAdapter { Ok(default_workspace_configuration) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("TypeScript".into(), "typescript".into()), - ("JavaScript".into(), "javascript".into()), - ("TSX".into(), "typescriptreact".into()), + (LanguageName::new("TypeScript"), "typescript".into()), + (LanguageName::new("JavaScript"), "javascript".into()), + (LanguageName::new("TSX"), "typescriptreact".into()), ]) } } diff --git a/crates/languages/src/yaml/outline.scm b/crates/languages/src/yaml/outline.scm index 7ab007835f..c5a7f8e5d4 100644 --- a/crates/languages/src/yaml/outline.scm +++ b/crates/languages/src/yaml/outline.scm @@ -1 +1,9 @@ -(block_mapping_pair key: (flow_node (plain_scalar (string_scalar) @name))) @item +(block_mapping_pair + key: + (flow_node + (plain_scalar + (string_scalar) @name)) + value: + (flow_node + (plain_scalar + (string_scalar) @context))?) @item diff --git a/crates/livekit_client/Cargo.toml b/crates/livekit_client/Cargo.toml index c367e03bb7..821fd5d390 100644 --- a/crates/livekit_client/Cargo.toml +++ b/crates/livekit_client/Cargo.toml @@ -40,8 +40,8 @@ util.workspace = true workspace-hack.workspace = true [target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies] -libwebrtc = { rev = "383e5377f8b7de1f8627ee16f0cf11c5293337bd", git = "https://github.com/zed-industries/livekit-rust-sdks" } -livekit = { rev = "383e5377f8b7de1f8627ee16f0cf11c5293337bd", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ +libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" } +livekit = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ "__rustls-tls" ] } diff --git a/crates/livekit_client/src/livekit_client/playback.rs b/crates/livekit_client/src/livekit_client/playback.rs index c62b8853b4..f14e156125 100644 --- a/crates/livekit_client/src/livekit_client/playback.rs +++ b/crates/livekit_client/src/livekit_client/playback.rs @@ -1,6 +1,7 @@ use anyhow::{Context as _, Result}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait as _}; +use cpal::{Data, FromSample, I24, SampleFormat, SizedSample}; use futures::channel::mpsc::UnboundedSender; use futures::{Stream, StreamExt as _}; use gpui::{ @@ -258,9 +259,15 @@ impl AudioStack { let stream = device .build_input_stream_raw( &config.config(), - cpal::SampleFormat::I16, + config.sample_format(), move |data, _: &_| { - let mut data = data.as_slice::<i16>().unwrap(); + let data = + Self::get_sample_data(config.sample_format(), data).log_err(); + let Some(data) = data else { + return; + }; + let mut data = data.as_slice(); + while data.len() > 0 { let remainder = (buf.capacity() - buf.len()).min(data.len()); buf.extend_from_slice(&data[..remainder]); @@ -313,6 +320,33 @@ impl AudioStack { drop(end_on_drop_tx) } } + + fn get_sample_data(sample_format: SampleFormat, data: &Data) -> Result<Vec<i16>> { + match sample_format { + SampleFormat::I8 => Ok(Self::convert_sample_data::<i8, i16>(data)), + SampleFormat::I16 => Ok(data.as_slice::<i16>().unwrap().to_vec()), + SampleFormat::I24 => Ok(Self::convert_sample_data::<I24, i16>(data)), + SampleFormat::I32 => Ok(Self::convert_sample_data::<i32, i16>(data)), + SampleFormat::I64 => Ok(Self::convert_sample_data::<i64, i16>(data)), + SampleFormat::U8 => Ok(Self::convert_sample_data::<u8, i16>(data)), + SampleFormat::U16 => Ok(Self::convert_sample_data::<u16, i16>(data)), + SampleFormat::U32 => Ok(Self::convert_sample_data::<u32, i16>(data)), + SampleFormat::U64 => Ok(Self::convert_sample_data::<u64, i16>(data)), + SampleFormat::F32 => Ok(Self::convert_sample_data::<f32, i16>(data)), + SampleFormat::F64 => Ok(Self::convert_sample_data::<f64, i16>(data)), + _ => anyhow::bail!("Unsupported sample format"), + } + } + + fn convert_sample_data<TSource: SizedSample, TDest: SizedSample + FromSample<TSource>>( + data: &Data, + ) -> Vec<TDest> { + data.as_slice::<TSource>() + .unwrap() + .iter() + .map(|e| e.to_sample::<TDest>()) + .collect() + } } use super::LocalVideoTrack; diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index 9978d7ebb1..b9701a83d2 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -29,7 +29,7 @@ use std::{ ffi::{OsStr, OsString}, fmt, io::Write, - ops::{Deref, DerefMut}, + ops::DerefMut, path::PathBuf, pin::Pin, sync::{ @@ -100,7 +100,7 @@ pub struct LanguageServer { io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>, output_done_rx: Mutex<Option<barrier::Receiver>>, server: Arc<Mutex<Option<Child>>>, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, root_uri: Url, } @@ -307,7 +307,7 @@ impl LanguageServer { binary: LanguageServerBinary, root_path: &Path, code_action_kinds: Option<Vec<CodeActionKind>>, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, cx: &mut AsyncApp, ) -> Result<Self> { let working_dir = if root_path.is_dir() { @@ -381,7 +381,7 @@ impl LanguageServer { code_action_kinds: Option<Vec<CodeActionKind>>, binary: LanguageServerBinary, root_uri: Url, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, cx: &mut AsyncApp, on_unhandled_notification: F, ) -> Self @@ -421,14 +421,14 @@ impl LanguageServer { .map(|stderr| { let io_handlers = io_handlers.clone(); let stderr_captures = stderr_capture.clone(); - cx.spawn(async move |_| { + cx.background_spawn(async move { Self::handle_stderr(stderr, io_handlers, stderr_captures) .log_err() .await }) }) .unwrap_or_else(|| Task::ready(None)); - let input_task = cx.spawn(async move |_| { + let input_task = cx.background_spawn(async move { let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); stdout.or(stderr) }); @@ -595,16 +595,26 @@ impl LanguageServer { } pub fn default_initialize_params(&self, pull_diagnostics: bool, cx: &App) -> InitializeParams { - let workspace_folders = self - .workspace_folders - .lock() - .iter() - .cloned() - .map(|uri| WorkspaceFolder { - name: Default::default(), - uri, - }) - .collect::<Vec<_>>(); + let workspace_folders = self.workspace_folders.as_ref().map_or_else( + || { + vec![WorkspaceFolder { + name: Default::default(), + uri: self.root_uri.clone(), + }] + }, + |folders| { + folders + .lock() + .iter() + .cloned() + .map(|uri| WorkspaceFolder { + name: Default::default(), + uri, + }) + .collect() + }, + ); + #[allow(deprecated)] InitializeParams { process_id: None, @@ -836,7 +846,7 @@ impl LanguageServer { configuration: Arc<DidChangeConfigurationParams>, cx: &App, ) -> Task<Result<Arc<Self>>> { - cx.spawn(async move |_| { + cx.background_spawn(async move { let response = self .request::<request::Initialize>(params) .await @@ -1315,7 +1325,10 @@ impl LanguageServer { return; } - let is_new_folder = self.workspace_folders.lock().insert(uri.clone()); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let is_new_folder = workspace_folders.lock().insert(uri.clone()); if is_new_folder { let params = DidChangeWorkspaceFoldersParams { event: WorkspaceFoldersChangeEvent { @@ -1345,7 +1358,10 @@ impl LanguageServer { { return; } - let was_removed = self.workspace_folders.lock().remove(&uri); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let was_removed = workspace_folders.lock().remove(&uri); if was_removed { let params = DidChangeWorkspaceFoldersParams { event: WorkspaceFoldersChangeEvent { @@ -1360,7 +1376,10 @@ impl LanguageServer { } } pub fn set_workspace_folders(&self, folders: BTreeSet<Url>) { - let mut workspace_folders = self.workspace_folders.lock(); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let mut workspace_folders = workspace_folders.lock(); let old_workspace_folders = std::mem::take(&mut *workspace_folders); let added: Vec<_> = folders @@ -1389,8 +1408,11 @@ impl LanguageServer { } } - pub fn workspace_folders(&self) -> impl Deref<Target = BTreeSet<Url>> + '_ { - self.workspace_folders.lock() + pub fn workspace_folders(&self) -> BTreeSet<Url> { + self.workspace_folders.as_ref().map_or_else( + || BTreeSet::from_iter([self.root_uri.clone()]), + |folders| folders.lock().clone(), + ) } pub fn register_buffer( @@ -1535,7 +1557,7 @@ impl FakeLanguageServer { None, binary.clone(), root, - workspace_folders.clone(), + Some(workspace_folders.clone()), cx, |_| {}, ); @@ -1554,7 +1576,7 @@ impl FakeLanguageServer { None, binary, Self::root_path(), - workspace_folders, + Some(workspace_folders), cx, move |msg| { notifications_tx diff --git a/crates/migrator/src/migrations/m_2025_01_29/keymap.rs b/crates/migrator/src/migrations/m_2025_01_29/keymap.rs index c32da88229..646af8f63d 100644 --- a/crates/migrator/src/migrations/m_2025_01_29/keymap.rs +++ b/crates/migrator/src/migrations/m_2025_01_29/keymap.rs @@ -242,22 +242,22 @@ static STRING_REPLACE: LazyLock<HashMap<&str, &str>> = LazyLock::new(|| { "inline_completion::ToggleMenu", "edit_prediction::ToggleMenu", ), - ("editor::NextInlineCompletion", "editor::NextEditPrediction"), + ("editor::NextEditPrediction", "editor::NextEditPrediction"), ( - "editor::PreviousInlineCompletion", + "editor::PreviousEditPrediction", "editor::PreviousEditPrediction", ), ( - "editor::AcceptPartialInlineCompletion", + "editor::AcceptPartialEditPrediction", "editor::AcceptPartialEditPrediction", ), - ("editor::ShowInlineCompletion", "editor::ShowEditPrediction"), + ("editor::ShowEditPrediction", "editor::ShowEditPrediction"), ( - "editor::AcceptInlineCompletion", + "editor::AcceptEditPrediction", "editor::AcceptEditPrediction", ), ( - "editor::ToggleInlineCompletions", + "editor::ToggleEditPredictions", "editor::ToggleEditPrediction", ), ]) diff --git a/crates/multi_buffer/src/anchor.rs b/crates/multi_buffer/src/anchor.rs index 9e28295c56..1305328d38 100644 --- a/crates/multi_buffer/src/anchor.rs +++ b/crates/multi_buffer/src/anchor.rs @@ -167,10 +167,10 @@ impl Anchor { if *self == Anchor::min() || *self == Anchor::max() { true } else if let Some(excerpt) = snapshot.excerpt(self.excerpt_id) { - excerpt.contains(self) - && (self.text_anchor == excerpt.range.context.start - || self.text_anchor == excerpt.range.context.end - || self.text_anchor.is_valid(&excerpt.buffer)) + (self.text_anchor == excerpt.range.context.start + || self.text_anchor == excerpt.range.context.end + || self.text_anchor.is_valid(&excerpt.buffer)) + && excerpt.contains(self) } else { false } diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index f0913e30fb..eb12e6929c 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -43,7 +43,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use sum_tree::{Bias, Cursor, Dimension, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Cursor, Dimension, Dimensions, SumTree, Summary, TreeMap}; use text::{ BufferId, Edit, LineIndent, TextSummary, locator::Locator, @@ -474,7 +474,7 @@ pub struct MultiBufferRows<'a> { pub struct MultiBufferChunks<'a> { excerpts: Cursor<'a, Excerpt, ExcerptOffset>, - diff_transforms: Cursor<'a, DiffTransform, (usize, ExcerptOffset)>, + diff_transforms: Cursor<'a, DiffTransform, Dimensions<usize, ExcerptOffset>>, diffs: &'a TreeMap<BufferId, BufferDiffSnapshot>, diff_base_chunks: Option<(BufferId, BufferChunks<'a>)>, buffer_chunk: Option<Chunk<'a>>, @@ -2120,10 +2120,10 @@ impl MultiBuffer { let buffers = self.buffers.borrow(); let mut excerpts = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<Point>)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<Point>>>(&()); let mut diff_transforms = snapshot .diff_transforms - .cursor::<(ExcerptDimension<Point>, OutputDimension<Point>)>(&()); + .cursor::<Dimensions<ExcerptDimension<Point>, OutputDimension<Point>>>(&()); diff_transforms.next(); let locators = buffers .get(&buffer_id) @@ -2281,7 +2281,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::new(); let mut excerpt_ids = ids.iter().copied().peekable(); let mut removed_buffer_ids = Vec::new(); @@ -2492,7 +2492,7 @@ impl MultiBuffer { for locator in &buffer_state.excerpts { let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); cursor.seek_forward(&Some(locator), Bias::Left); if let Some(excerpt) = cursor.item() { if excerpt.locator == *locator { @@ -2845,7 +2845,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::<Edit<ExcerptOffset>>::new(); let prefix = cursor.slice(&Some(locator), Bias::Left); @@ -2921,7 +2921,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::<Edit<ExcerptOffset>>::new(); for locator in &locators { @@ -3067,7 +3067,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); for (locator, buffer, buffer_edited) in excerpts_to_edit { new_excerpts.append(cursor.slice(&Some(locator), Bias::Left), &()); @@ -3135,7 +3135,7 @@ impl MultiBuffer { let mut excerpts = snapshot.excerpts.cursor::<ExcerptOffset>(&()); let mut old_diff_transforms = snapshot .diff_transforms - .cursor::<(ExcerptOffset, usize)>(&()); + .cursor::<Dimensions<ExcerptOffset, usize>>(&()); let mut new_diff_transforms = SumTree::default(); let mut old_expanded_hunks = HashSet::default(); let mut output_edits = Vec::new(); @@ -3260,7 +3260,7 @@ impl MultiBuffer { &self, edit: &Edit<TypedOffset<Excerpt>>, excerpts: &mut Cursor<Excerpt, TypedOffset<Excerpt>>, - old_diff_transforms: &mut Cursor<DiffTransform, (TypedOffset<Excerpt>, usize)>, + old_diff_transforms: &mut Cursor<DiffTransform, Dimensions<TypedOffset<Excerpt>, usize>>, new_diff_transforms: &mut SumTree<DiffTransform>, end_of_current_insert: &mut Option<(TypedOffset<Excerpt>, DiffTransformHunkInfo)>, old_expanded_hunks: &mut HashSet<DiffTransformHunkInfo>, @@ -4713,7 +4713,9 @@ impl MultiBufferSnapshot { O: ToOffset, { let range = range.start.to_offset(self)..range.end.to_offset(self); - let mut cursor = self.diff_transforms.cursor::<(usize, ExcerptOffset)>(&()); + let mut cursor = self + .diff_transforms + .cursor::<Dimensions<usize, ExcerptOffset>>(&()); cursor.seek(&range.start, Bias::Right); let Some(first_transform) = cursor.item() else { @@ -4867,7 +4869,10 @@ impl MultiBufferSnapshot { &self, anchor: &Anchor, excerpt_position: D, - diff_transforms: &mut Cursor<DiffTransform, (ExcerptDimension<D>, OutputDimension<D>)>, + diff_transforms: &mut Cursor< + DiffTransform, + Dimensions<ExcerptDimension<D>, OutputDimension<D>>, + >, ) -> D where D: TextDimension + Ord + Sub<D, Output = D>, @@ -4927,7 +4932,7 @@ impl MultiBufferSnapshot { fn excerpt_offset_for_anchor(&self, anchor: &Anchor) -> ExcerptOffset { let mut cursor = self .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let locator = self.excerpt_locator_for_id(anchor.excerpt_id); cursor.seek(&Some(locator), Bias::Left); @@ -4971,7 +4976,7 @@ impl MultiBufferSnapshot { let mut cursor = self.excerpts.cursor::<ExcerptSummary>(&()); let mut diff_transforms_cursor = self .diff_transforms - .cursor::<(ExcerptDimension<D>, OutputDimension<D>)>(&()); + .cursor::<Dimensions<ExcerptDimension<D>, OutputDimension<D>>>(&()); diff_transforms_cursor.next(); let mut summaries = Vec::new(); @@ -5201,7 +5206,9 @@ impl MultiBufferSnapshot { // Find the given position in the diff transforms. Determine the corresponding // offset in the excerpts, and whether the position is within a deleted hunk. - let mut diff_transforms = self.diff_transforms.cursor::<(usize, ExcerptOffset)>(&()); + let mut diff_transforms = self + .diff_transforms + .cursor::<Dimensions<usize, ExcerptOffset>>(&()); diff_transforms.seek(&offset, Bias::Right); if offset == diff_transforms.start().0 && bias == Bias::Left { @@ -5250,7 +5257,7 @@ impl MultiBufferSnapshot { let mut excerpts = self .excerpts - .cursor::<(ExcerptOffset, Option<ExcerptId>)>(&()); + .cursor::<Dimensions<ExcerptOffset, Option<ExcerptId>>>(&()); excerpts.seek(&excerpt_offset, Bias::Right); if excerpts.item().is_none() && excerpt_offset == excerpts.start().0 && bias == Bias::Left { excerpts.prev(); @@ -5341,7 +5348,7 @@ impl MultiBufferSnapshot { let start_locator = self.excerpt_locator_for_id(id); let mut excerpts = self .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<usize>)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<usize>>>(&()); excerpts.seek(&Some(start_locator), Bias::Left); excerpts.prev(); @@ -6242,14 +6249,14 @@ impl MultiBufferSnapshot { pub fn range_for_excerpt(&self, excerpt_id: ExcerptId) -> Option<Range<Point>> { let mut cursor = self .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<Point>)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<Point>>>(&()); let locator = self.excerpt_locator_for_id(excerpt_id); if cursor.seek(&Some(locator), Bias::Left) { let start = cursor.start().1.clone(); let end = cursor.end().1; let mut diff_transforms = self .diff_transforms - .cursor::<(ExcerptDimension<Point>, OutputDimension<Point>)>(&()); + .cursor::<Dimensions<ExcerptDimension<Point>, OutputDimension<Point>>>(&()); diff_transforms.seek(&start, Bias::Left); let overshoot = start.0 - diff_transforms.start().0.0; let start = diff_transforms.start().1.0 + overshoot; diff --git a/crates/notifications/src/notification_store.rs b/crates/notifications/src/notification_store.rs index 0329a53cc7..29653748e4 100644 --- a/crates/notifications/src/notification_store.rs +++ b/crates/notifications/src/notification_store.rs @@ -6,7 +6,7 @@ use db::smol::stream::StreamExt; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task}; use rpc::{Notification, TypedEnvelope, proto}; use std::{ops::Range, sync::Arc}; -use sum_tree::{Bias, SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; use time::OffsetDateTime; use util::ResultExt; @@ -360,7 +360,9 @@ impl NotificationStore { is_new: bool, cx: &mut Context<NotificationStore>, ) { - let mut cursor = self.notifications.cursor::<(NotificationId, Count)>(&()); + let mut cursor = self + .notifications + .cursor::<Dimensions<NotificationId, Count>>(&()); let mut new_notifications = SumTree::default(); let mut old_range = 0..0; diff --git a/crates/onboarding/Cargo.toml b/crates/onboarding/Cargo.toml index 693e39d4ca..436c714cf3 100644 --- a/crates/onboarding/Cargo.toml +++ b/crates/onboarding/Cargo.toml @@ -15,14 +15,33 @@ path = "src/onboarding.rs" default = [] [dependencies] +ai_onboarding.workspace = true anyhow.workspace = true +client.workspace = true command_palette_hooks.workspace = true +component.workspace = true db.workspace = true +documented.workspace = true +editor.workspace = true feature_flags.workspace = true fs.workspace = true +fuzzy.workspace = true gpui.workspace = true +itertools.workspace = true +language.workspace = true +language_model.workspace = true +menu.workspace = true +notifications.workspace = true +picker.workspace = true +project.workspace = true +schemars.workspace = true +serde.workspace = true settings.workspace = true theme.workspace = true ui.workspace = true -workspace.workspace = true +util.workspace = true +vim_mode_setting.workspace = true workspace-hack.workspace = true +workspace.workspace = true +zed_actions.workspace = true +zlog.workspace = true diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs new file mode 100644 index 0000000000..098907870b --- /dev/null +++ b/crates/onboarding/src/ai_setup_page.rs @@ -0,0 +1,420 @@ +use std::sync::Arc; + +use ai_onboarding::{AiUpsellCard, SignInStatus}; +use client::UserStore; +use fs::Fs; +use gpui::{ + Action, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, + Window, prelude::*, +}; +use itertools; +use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry}; +use project::DisableAiSettings; +use settings::{Settings, update_settings_file}; +use ui::{ + Badge, ButtonLike, Divider, Modal, ModalFooter, ModalHeader, Section, SwitchField, ToggleState, + prelude::*, tooltip_container, +}; +use util::ResultExt; +use workspace::{ModalView, Workspace}; +use zed_actions::agent::OpenSettings; + +const FEATURED_PROVIDERS: [&'static str; 4] = ["anthropic", "google", "openai", "ollama"]; + +fn render_llm_provider_section( + tab_index: &mut isize, + workspace: WeakEntity<Workspace>, + disabled: bool, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + v_flex() + .gap_4() + .child( + v_flex() + .child(Label::new("Or use other LLM providers").size(LabelSize::Large)) + .child( + Label::new("Bring your API keys to use the available providers with Zed's UI for free.") + .color(Color::Muted), + ), + ) + .child(render_llm_provider_card(tab_index, workspace, disabled, window, cx)) +} + +fn render_privacy_card(tab_index: &mut isize, disabled: bool, cx: &mut App) -> impl IntoElement { + let privacy_badge = || { + Badge::new("Privacy") + .icon(IconName::ShieldCheck) + .tooltip(move |_, cx| cx.new(|_| AiPrivacyTooltip::new()).into()) + }; + + v_flex() + .relative() + .pt_2() + .pb_2p5() + .pl_3() + .pr_2() + .border_1() + .border_dashed() + .border_color(cx.theme().colors().border.opacity(0.5)) + .bg(cx.theme().colors().surface_background.opacity(0.3)) + .rounded_lg() + .overflow_hidden() + .map(|this| { + if disabled { + this.child( + h_flex() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1() + .child(Label::new("AI is disabled across Zed")) + .child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ), + ) + .child(privacy_badge()), + ) + .child( + Label::new("Re-enable it any time in Settings.") + .size(LabelSize::Small) + .color(Color::Muted), + ) + } else { + this.child( + h_flex() + .gap_2() + .justify_between() + .child(Label::new("We don't train models using your data")) + .child( + h_flex().gap_1().child(privacy_badge()).child( + Button::new("learn_more", "Learn More") + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(|_, _, cx| { + cx.open_url("https://zed.dev/docs/ai/privacy-and-security"); + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ), + ), + ) + .child( + Label::new( + "Feel confident in the security and privacy of your projects using Zed.", + ) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } + }) +} + +fn render_llm_provider_card( + tab_index: &mut isize, + workspace: WeakEntity<Workspace>, + disabled: bool, + _: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let registry = LanguageModelRegistry::read_global(cx); + + v_flex() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().surface_background.opacity(0.5)) + .rounded_lg() + .overflow_hidden() + .children(itertools::intersperse_with( + FEATURED_PROVIDERS + .into_iter() + .flat_map(|provider_name| { + registry.provider(&LanguageModelProviderId::new(provider_name)) + }) + .enumerate() + .map(|(index, provider)| { + let group_name = SharedString::new(format!("onboarding-hover-group-{}", index)); + let is_authenticated = provider.is_authenticated(cx); + + ButtonLike::new(("onboarding-ai-setup-buttons", index)) + .size(ButtonSize::Large) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .group(&group_name) + .px_0p5() + .w_full() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1() + .child( + Icon::new(provider.icon()) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(provider.name().0)), + ) + .child( + h_flex() + .gap_1() + .when(!is_authenticated, |el| { + el.visible_on_hover(group_name.clone()) + .child( + Icon::new(IconName::Settings) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child( + Label::new("Configure") + .color(Color::Muted) + .size(LabelSize::Small), + ) + }) + .when(is_authenticated && !disabled, |el| { + el.child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ) + .child( + Label::new("Configured") + .color(Color::Muted) + .size(LabelSize::Small), + ) + }), + ), + ) + .on_click({ + let workspace = workspace.clone(); + move |_, window, cx| { + workspace + .update(cx, |workspace, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + let modal = AiConfigurationModal::new( + provider.clone(), + window, + cx, + ); + window.focus(&modal.focus_handle(cx)); + modal + }); + }) + .log_err(); + } + }) + .into_any_element() + }), + || Divider::horizontal().into_any_element(), + )) + .child(Divider::horizontal()) + .child( + Button::new("agent_settings", "Add Many Others") + .size(ButtonSize::Large) + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(|_event, window, cx| { + window.dispatch_action(OpenSettings.boxed_clone(), cx) + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) +} + +pub(crate) fn render_ai_setup_page( + workspace: WeakEntity<Workspace>, + user_store: Entity<UserStore>, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let mut tab_index = 0; + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + + v_flex() + .gap_2() + .child( + SwitchField::new( + "enable_ai", + "Enable AI features", + None, + if is_ai_disabled { + ToggleState::Unselected + } else { + ToggleState::Selected + }, + |&toggle_state, _, cx| { + let fs = <dyn Fs>::global(cx); + update_settings_file::<DisableAiSettings>( + fs, + cx, + move |ai_settings: &mut Option<bool>, _| { + *ai_settings = match toggle_state { + ToggleState::Indeterminate => None, + ToggleState::Unselected => Some(true), + ToggleState::Selected => Some(false), + }; + }, + ); + }, + ) + .tab_index({ + tab_index += 1; + tab_index - 1 + }), + ) + .child(render_privacy_card(&mut tab_index, is_ai_disabled, cx)) + .child( + v_flex() + .mt_2() + .gap_6() + .child(AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + user_plan: user_store.read(cx).plan(), + tab_index: Some({ + tab_index += 1; + tab_index - 1 + }), + }) + .child(render_llm_provider_section( + &mut tab_index, + workspace, + is_ai_disabled, + window, + cx, + )) + .when(is_ai_disabled, |this| { + this.child( + div() + .id("backdrop") + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().editor_background) + .opacity(0.8) + .block_mouse_except_scroll(), + ) + }), + ) +} + +struct AiConfigurationModal { + focus_handle: FocusHandle, + selected_provider: Arc<dyn LanguageModelProvider>, + configuration_view: AnyView, +} + +impl AiConfigurationModal { + fn new( + selected_provider: Arc<dyn LanguageModelProvider>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let focus_handle = cx.focus_handle(); + let configuration_view = selected_provider.configuration_view(window, cx); + + Self { + focus_handle, + configuration_view, + selected_provider, + } + } +} + +impl ModalView for AiConfigurationModal {} + +impl EventEmitter<DismissEvent> for AiConfigurationModal {} + +impl Focusable for AiConfigurationModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for AiConfigurationModal { + fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + v_flex() + .w(rems(34.)) + .elevation_3(cx) + .track_focus(&self.focus_handle) + .child( + Modal::new("onboarding-ai-setup-modal", None) + .header( + ModalHeader::new() + .icon( + Icon::new(self.selected_provider.icon()) + .color(Color::Muted) + .size(IconSize::Small), + ) + .headline(self.selected_provider.name().0), + ) + .section(Section::new().child(self.configuration_view.clone())) + .footer( + ModalFooter::new().end_slot( + h_flex() + .gap_1() + .child( + Button::new("onboarding-closing-cancel", "Cancel") + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + .child(Button::new("save-btn", "Done").on_click(cx.listener( + |_, _, window, cx| { + window.dispatch_action(menu::Confirm.boxed_clone(), cx); + cx.emit(DismissEvent); + }, + ))), + ), + ), + ) + } +} + +pub struct AiPrivacyTooltip {} + +impl AiPrivacyTooltip { + pub fn new() -> Self { + Self {} + } +} + +impl Render for AiPrivacyTooltip { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + const DESCRIPTION: &'static str = "One of Zed's most important principles is transparency. This is why we are and value open-source so much. And it wouldn't be any different with AI."; + + tooltip_container(window, cx, move |this, _, _| { + this.child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::ShieldCheck) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(Label::new("Privacy Principle")), + ) + .child( + div().max_w_64().child( + Label::new(DESCRIPTION) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + }) + } +} diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs new file mode 100644 index 0000000000..a4e4028051 --- /dev/null +++ b/crates/onboarding/src/basics_page.rs @@ -0,0 +1,361 @@ +use std::sync::Arc; + +use client::TelemetrySettings; +use fs::Fs; +use gpui::{App, IntoElement}; +use settings::{BaseKeymap, Settings, update_settings_file}; +use theme::{ + Appearance, SystemAppearance, ThemeMode, ThemeName, ThemeRegistry, ThemeSelection, + ThemeSettings, +}; +use ui::{ + ParentElement as _, StatefulInteractiveElement, SwitchField, ToggleButtonGroup, + ToggleButtonSimple, ToggleButtonWithIcon, prelude::*, rems_from_px, +}; +use vim_mode_setting::VimModeSetting; + +use crate::theme_preview::{ThemePreviewStyle, ThemePreviewTile}; + +fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let theme_selection = ThemeSettings::get_global(cx).theme_selection.clone(); + let system_appearance = theme::SystemAppearance::global(cx); + let theme_selection = theme_selection.unwrap_or_else(|| ThemeSelection::Dynamic { + mode: match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }, + light: ThemeName("One Light".into()), + dark: ThemeName("One Dark".into()), + }); + + let theme_mode = theme_selection + .mode() + .unwrap_or_else(|| match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }); + + return v_flex() + .gap_2() + .child( + h_flex().justify_between().child(Label::new("Theme")).child( + ToggleButtonGroup::single_row( + "theme-selector-onboarding-dark-light", + [ThemeMode::Light, ThemeMode::Dark, ThemeMode::System].map(|mode| { + const MODE_NAMES: [SharedString; 3] = [ + SharedString::new_static("Light"), + SharedString::new_static("Dark"), + SharedString::new_static("System"), + ]; + ToggleButtonSimple::new( + MODE_NAMES[mode as usize].clone(), + move |_, _, cx| { + write_mode_change(mode, cx); + }, + ) + }), + ) + .tab_index(tab_index) + .selected_index(theme_mode as usize) + .style(ui::ToggleButtonGroupStyle::Outlined) + .button_width(rems_from_px(64.)), + ), + ) + .child( + h_flex() + .gap_4() + .justify_between() + .children(render_theme_previews(tab_index, &theme_selection, cx)), + ); + + fn render_theme_previews( + tab_index: &mut isize, + theme_selection: &ThemeSelection, + cx: &mut App, + ) -> [impl IntoElement; 3] { + let system_appearance = SystemAppearance::global(cx); + let theme_registry = ThemeRegistry::global(cx); + + let theme_seed = 0xBEEF as f32; + let theme_mode = theme_selection + .mode() + .unwrap_or_else(|| match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }); + let appearance = match theme_mode { + ThemeMode::Light => Appearance::Light, + ThemeMode::Dark => Appearance::Dark, + ThemeMode::System => *system_appearance, + }; + let current_theme_name = theme_selection.theme(appearance); + + const LIGHT_THEMES: [&'static str; 3] = ["One Light", "Ayu Light", "Gruvbox Light"]; + const DARK_THEMES: [&'static str; 3] = ["One Dark", "Ayu Dark", "Gruvbox Dark"]; + const FAMILY_NAMES: [SharedString; 3] = [ + SharedString::new_static("One"), + SharedString::new_static("Ayu"), + SharedString::new_static("Gruvbox"), + ]; + + let theme_names = match appearance { + Appearance::Light => LIGHT_THEMES, + Appearance::Dark => DARK_THEMES, + }; + + let themes = theme_names.map(|theme| theme_registry.get(theme).unwrap()); + + let theme_previews = [0, 1, 2].map(|index| { + let theme = &themes[index]; + let is_selected = theme.name == current_theme_name; + let name = theme.name.clone(); + let colors = cx.theme().colors(); + + v_flex() + .w_full() + .items_center() + .gap_1() + .child( + h_flex() + .id(name.clone()) + .relative() + .w_full() + .border_2() + .border_color(colors.border_transparent) + .rounded(ThemePreviewTile::ROOT_RADIUS) + .map(|this| { + if is_selected { + this.border_color(colors.border_selected) + } else { + this.opacity(0.8).hover(|s| s.border_color(colors.border)) + } + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .focus(|mut style| { + style.border_color = Some(colors.border_focused); + style + }) + .on_click({ + let theme_name = theme.name.clone(); + move |_, _, cx| { + write_theme_change(theme_name.clone(), theme_mode, cx); + } + }) + .map(|this| { + if theme_mode == ThemeMode::System { + let (light, dark) = ( + theme_registry.get(LIGHT_THEMES[index]).unwrap(), + theme_registry.get(DARK_THEMES[index]).unwrap(), + ); + this.child( + ThemePreviewTile::new(light, theme_seed) + .style(ThemePreviewStyle::SideBySide(dark)), + ) + } else { + this.child( + ThemePreviewTile::new(theme.clone(), theme_seed) + .style(ThemePreviewStyle::Bordered), + ) + } + }), + ) + .child( + Label::new(FAMILY_NAMES[index].clone()) + .color(Color::Muted) + .size(LabelSize::Small), + ) + }); + + theme_previews + } + + fn write_mode_change(mode: ThemeMode, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + update_settings_file::<ThemeSettings>(fs, cx, move |settings, _cx| { + settings.set_mode(mode); + }); + } + + fn write_theme_change(theme: impl Into<Arc<str>>, theme_mode: ThemeMode, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + let theme = theme.into(); + update_settings_file::<ThemeSettings>(fs, cx, move |settings, cx| { + if theme_mode == ThemeMode::System { + settings.theme = Some(ThemeSelection::Dynamic { + mode: ThemeMode::System, + light: ThemeName(theme.clone()), + dark: ThemeName(theme.clone()), + }); + } else { + let appearance = *SystemAppearance::global(cx); + settings.set_theme(theme.clone(), appearance); + } + }); + } +} + +fn render_telemetry_section(tab_index: &mut isize, cx: &App) -> impl IntoElement { + let fs = <dyn Fs>::global(cx); + + v_flex() + .gap_4() + .child(Label::new("Telemetry").size(LabelSize::Large)) + .child(SwitchField::new( + "onboarding-telemetry-metrics", + "Help Improve Zed", + Some("Sending anonymous usage data helps us build the right features and create the best experience.".into()), + if TelemetrySettings::get_global(cx).metrics { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = fs.clone(); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::<TelemetrySettings>( + fs.clone(), + cx, + move |setting, _| setting.metrics = Some(enabled), + ); + }}, + ).tab_index({ + *tab_index += 1; + *tab_index + })) + .child(SwitchField::new( + "onboarding-telemetry-crash-reports", + "Help Fix Zed", + Some("Send crash reports so we can fix critical issues fast.".into()), + if TelemetrySettings::get_global(cx).diagnostics { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = fs.clone(); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::<TelemetrySettings>( + fs.clone(), + cx, + move |setting, _| setting.diagnostics = Some(enabled), + ); + } + } + ).tab_index({ + *tab_index += 1; + *tab_index + })) +} + +fn render_base_keymap_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let base_keymap = match BaseKeymap::get_global(cx) { + BaseKeymap::VSCode => Some(0), + BaseKeymap::JetBrains => Some(1), + BaseKeymap::SublimeText => Some(2), + BaseKeymap::Atom => Some(3), + BaseKeymap::Emacs => Some(4), + BaseKeymap::Cursor => Some(5), + BaseKeymap::TextMate | BaseKeymap::None => None, + }; + + return v_flex().gap_2().child(Label::new("Base Keymap")).child( + ToggleButtonGroup::two_rows( + "base_keymap_selection", + [ + ToggleButtonWithIcon::new("VS Code", IconName::EditorVsCode, |_, _, cx| { + write_keymap_base(BaseKeymap::VSCode, cx); + }), + ToggleButtonWithIcon::new("Jetbrains", IconName::EditorJetBrains, |_, _, cx| { + write_keymap_base(BaseKeymap::JetBrains, cx); + }), + ToggleButtonWithIcon::new("Sublime Text", IconName::EditorSublime, |_, _, cx| { + write_keymap_base(BaseKeymap::SublimeText, cx); + }), + ], + [ + ToggleButtonWithIcon::new("Atom", IconName::EditorAtom, |_, _, cx| { + write_keymap_base(BaseKeymap::Atom, cx); + }), + ToggleButtonWithIcon::new("Emacs", IconName::EditorEmacs, |_, _, cx| { + write_keymap_base(BaseKeymap::Emacs, cx); + }), + ToggleButtonWithIcon::new("Cursor (Beta)", IconName::EditorCursor, |_, _, cx| { + write_keymap_base(BaseKeymap::Cursor, cx); + }), + ], + ) + .when_some(base_keymap, |this, base_keymap| { + this.selected_index(base_keymap) + }) + .tab_index(tab_index) + .button_width(rems_from_px(216.)) + .size(ui::ToggleButtonGroupSize::Medium) + .style(ui::ToggleButtonGroupStyle::Outlined), + ); + + fn write_keymap_base(keymap_base: BaseKeymap, cx: &App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<BaseKeymap>(fs, cx, move |setting, _| { + *setting = Some(keymap_base); + }); + } +} + +fn render_vim_mode_switch(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let toggle_state = if VimModeSetting::get_global(cx).0 { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }; + SwitchField::new( + "onboarding-vim-mode", + "Vim Mode", + Some( + "Coming from Neovim? Zed's first-class implementation of Vim Mode has got your back." + .into(), + ), + toggle_state, + { + let fs = <dyn Fs>::global(cx); + move |&selection, _, cx| { + update_settings_file::<VimModeSetting>(fs.clone(), cx, move |setting, _| { + *setting = match selection { + ToggleState::Selected => Some(true), + ToggleState::Unselected => Some(false), + ToggleState::Indeterminate => None, + } + }); + } + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) +} + +pub(crate) fn render_basics_page(cx: &mut App) -> impl IntoElement { + let mut tab_index = 0; + v_flex() + .gap_6() + .child(render_theme_section(&mut tab_index, cx)) + .child(render_base_keymap_section(&mut tab_index, cx)) + .child(render_vim_mode_switch(&mut tab_index, cx)) + .child(render_telemetry_section(&mut tab_index, cx)) +} diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs new file mode 100644 index 0000000000..a8f0265b6b --- /dev/null +++ b/crates/onboarding/src/editing_page.rs @@ -0,0 +1,713 @@ +use std::sync::Arc; + +use editor::{EditorSettings, ShowMinimap}; +use fs::Fs; +use fuzzy::{StringMatch, StringMatchCandidate}; +use gpui::{ + Action, AnyElement, App, Context, FontFeatures, IntoElement, Pixels, SharedString, Task, Window, +}; +use language::language_settings::{AllLanguageSettings, FormatOnSave}; +use picker::{Picker, PickerDelegate}; +use project::project_settings::ProjectSettings; +use settings::{Settings as _, update_settings_file}; +use theme::{FontFamilyCache, FontFamilyName, ThemeSettings}; +use ui::{ + ButtonLike, ListItem, ListItemSpacing, NumericStepper, PopoverMenu, SwitchField, + ToggleButtonGroup, ToggleButtonGroupStyle, ToggleButtonSimple, ToggleState, Tooltip, + prelude::*, +}; + +use crate::{ImportCursorSettings, ImportVsCodeSettings, SettingsImportState}; + +fn read_show_mini_map(cx: &App) -> ShowMinimap { + editor::EditorSettings::get_global(cx).minimap.show +} + +fn write_show_mini_map(show: ShowMinimap, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + // This is used to speed up the UI + // the UI reads the current values to get what toggle state to show on buttons + // there's a slight delay if we just call update_settings_file so we manually set + // the value here then call update_settings file to get around the delay + let mut curr_settings = EditorSettings::get_global(cx).clone(); + curr_settings.minimap.show = show; + EditorSettings::override_global(curr_settings, cx); + + update_settings_file::<EditorSettings>(fs, cx, move |editor_settings, _| { + editor_settings.minimap.get_or_insert_default().show = Some(show); + }); +} + +fn read_inlay_hints(cx: &App) -> bool { + AllLanguageSettings::get_global(cx) + .defaults + .inlay_hints + .enabled +} + +fn write_inlay_hints(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + let mut curr_settings = AllLanguageSettings::get_global(cx).clone(); + curr_settings.defaults.inlay_hints.enabled = enabled; + AllLanguageSettings::override_global(curr_settings, cx); + + update_settings_file::<AllLanguageSettings>(fs, cx, move |all_language_settings, cx| { + all_language_settings + .defaults + .inlay_hints + .get_or_insert_with(|| { + AllLanguageSettings::get_global(cx) + .clone() + .defaults + .inlay_hints + }) + .enabled = enabled; + }); +} + +fn read_git_blame(cx: &App) -> bool { + ProjectSettings::get_global(cx).git.inline_blame_enabled() +} + +fn set_git_blame(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + let mut curr_settings = ProjectSettings::get_global(cx).clone(); + curr_settings + .git + .inline_blame + .get_or_insert_default() + .enabled = enabled; + ProjectSettings::override_global(curr_settings, cx); + + update_settings_file::<ProjectSettings>(fs, cx, move |project_settings, _| { + project_settings + .git + .inline_blame + .get_or_insert_default() + .enabled = enabled; + }); +} + +fn write_ui_font_family(font: SharedString, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.ui_font_family = Some(FontFamilyName(font.into())); + }); +} + +fn write_ui_font_size(size: Pixels, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.ui_font_size = Some(size.into()); + }); +} + +fn write_buffer_font_size(size: Pixels, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.buffer_font_size = Some(size.into()); + }); +} + +fn write_buffer_font_family(font_family: SharedString, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.buffer_font_family = Some(FontFamilyName(font_family.into())); + }); +} + +fn read_font_ligatures(cx: &App) -> bool { + ThemeSettings::get_global(cx) + .buffer_font + .features + .is_calt_enabled() + .unwrap_or(true) +} + +fn write_font_ligatures(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + let bit = if enabled { 1 } else { 0 }; + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + let mut features = theme_settings + .buffer_font_features + .as_mut() + .map(|features| features.tag_value_list().to_vec()) + .unwrap_or_default(); + + if let Some(calt_index) = features.iter().position(|(tag, _)| tag == "calt") { + features[calt_index].1 = bit; + } else { + features.push(("calt".into(), bit)); + } + + theme_settings.buffer_font_features = Some(FontFeatures(Arc::new(features))); + }); +} + +fn read_format_on_save(cx: &App) -> bool { + match AllLanguageSettings::get_global(cx).defaults.format_on_save { + FormatOnSave::On | FormatOnSave::List(_) => true, + FormatOnSave::Off => false, + } +} + +fn write_format_on_save(format_on_save: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<AllLanguageSettings>(fs, cx, move |language_settings, _| { + language_settings.defaults.format_on_save = Some(match format_on_save { + true => FormatOnSave::On, + false => FormatOnSave::Off, + }); + }); +} + +fn render_setting_import_button( + tab_index: isize, + label: SharedString, + icon_name: IconName, + action: &dyn Action, + imported: bool, +) -> impl IntoElement { + let action = action.boxed_clone(); + h_flex().w_full().child( + ButtonLike::new(label.clone()) + .full_width() + .style(ButtonStyle::Outlined) + .size(ButtonSize::Large) + .tab_index(tab_index) + .child( + h_flex() + .w_full() + .justify_between() + .child( + h_flex() + .gap_1p5() + .px_1() + .child( + Icon::new(icon_name) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(label)), + ) + .when(imported, |this| { + this.child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ) + .child(Label::new("Imported").size(LabelSize::Small)), + ) + }), + ) + .on_click(move |_, window, cx| window.dispatch_action(action.boxed_clone(), cx)), + ) +} + +fn render_import_settings_section(tab_index: &mut isize, cx: &App) -> impl IntoElement { + let import_state = SettingsImportState::global(cx); + let imports: [(SharedString, IconName, &dyn Action, bool); 2] = [ + ( + "VS Code".into(), + IconName::EditorVsCode, + &ImportVsCodeSettings { skip_prompt: false }, + import_state.vscode, + ), + ( + "Cursor".into(), + IconName::EditorCursor, + &ImportCursorSettings { skip_prompt: false }, + import_state.cursor, + ), + ]; + + let [vscode, cursor] = imports.map(|(label, icon_name, action, imported)| { + *tab_index += 1; + render_setting_import_button(*tab_index - 1, label, icon_name, action, imported) + }); + + v_flex() + .gap_4() + .child( + v_flex() + .child(Label::new("Import Settings").size(LabelSize::Large)) + .child( + Label::new("Automatically pull your settings from other editors.") + .color(Color::Muted), + ), + ) + .child(h_flex().w_full().gap_4().child(vscode).child(cursor)) +} + +fn render_font_customization_section( + tab_index: &mut isize, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let theme_settings = ThemeSettings::get_global(cx); + let ui_font_size = theme_settings.ui_font_size(cx); + let ui_font_family = theme_settings.ui_font.family.clone(); + let buffer_font_family = theme_settings.buffer_font.family.clone(); + let buffer_font_size = theme_settings.buffer_font_size(cx); + + let ui_font_picker = + cx.new(|cx| font_picker(ui_font_family.clone(), write_ui_font_family, window, cx)); + + let buffer_font_picker = cx.new(|cx| { + font_picker( + buffer_font_family.clone(), + write_buffer_font_family, + window, + cx, + ) + }); + + let ui_font_handle = ui::PopoverMenuHandle::default(); + let buffer_font_handle = ui::PopoverMenuHandle::default(); + + h_flex() + .w_full() + .gap_4() + .child( + v_flex() + .w_full() + .gap_1() + .child(Label::new("UI Font")) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + PopoverMenu::new("ui-font-picker") + .menu({ + let ui_font_picker = ui_font_picker.clone(); + move |_window, _cx| Some(ui_font_picker.clone()) + }) + .trigger( + ButtonLike::new("ui-font-family-button") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .full_width() + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .w_full() + .justify_between() + .child(Label::new(ui_font_family)) + .child( + Icon::new(IconName::ChevronUpDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + ), + ) + .full_width(true) + .anchor(gpui::Corner::TopLeft) + .offset(gpui::Point { + x: px(0.0), + y: px(4.0), + }) + .with_handle(ui_font_handle), + ) + .child( + NumericStepper::new( + "ui-font-size", + ui_font_size.to_string(), + move |_, _, cx| { + write_ui_font_size(ui_font_size - px(1.), cx); + }, + move |_, _, cx| { + write_ui_font_size(ui_font_size + px(1.), cx); + }, + ) + .style(ui::NumericStepperStyle::Outlined) + .tab_index({ + *tab_index += 2; + *tab_index - 2 + }), + ), + ), + ) + .child( + v_flex() + .w_full() + .gap_1() + .child(Label::new("Editor Font")) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + PopoverMenu::new("buffer-font-picker") + .menu({ + let buffer_font_picker = buffer_font_picker.clone(); + move |_window, _cx| Some(buffer_font_picker.clone()) + }) + .trigger( + ButtonLike::new("buffer-font-family-button") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .full_width() + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .w_full() + .justify_between() + .child(Label::new(buffer_font_family)) + .child( + Icon::new(IconName::ChevronUpDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + ), + ) + .full_width(true) + .anchor(gpui::Corner::TopLeft) + .offset(gpui::Point { + x: px(0.0), + y: px(4.0), + }) + .with_handle(buffer_font_handle), + ) + .child( + NumericStepper::new( + "buffer-font-size", + buffer_font_size.to_string(), + move |_, _, cx| { + write_buffer_font_size(buffer_font_size - px(1.), cx); + }, + move |_, _, cx| { + write_buffer_font_size(buffer_font_size + px(1.), cx); + }, + ) + .style(ui::NumericStepperStyle::Outlined) + .tab_index({ + *tab_index += 2; + *tab_index - 2 + }), + ), + ), + ) +} + +type FontPicker = Picker<FontPickerDelegate>; + +pub struct FontPickerDelegate { + fonts: Vec<SharedString>, + filtered_fonts: Vec<StringMatch>, + selected_index: usize, + current_font: SharedString, + on_font_changed: Arc<dyn Fn(SharedString, &mut App) + 'static>, +} + +impl FontPickerDelegate { + fn new( + current_font: SharedString, + on_font_changed: impl Fn(SharedString, &mut App) + 'static, + cx: &mut Context<FontPicker>, + ) -> Self { + let font_family_cache = FontFamilyCache::global(cx); + + let fonts: Vec<SharedString> = font_family_cache + .list_font_families(cx) + .into_iter() + .collect(); + + let selected_index = fonts + .iter() + .position(|font| *font == current_font) + .unwrap_or(0); + + Self { + fonts: fonts.clone(), + filtered_fonts: fonts + .iter() + .enumerate() + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect(), + selected_index, + current_font, + on_font_changed: Arc::new(on_font_changed), + } + } +} + +impl PickerDelegate for FontPickerDelegate { + type ListItem = AnyElement; + + fn match_count(&self) -> usize { + self.filtered_fonts.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<FontPicker>) { + self.selected_index = ix.min(self.filtered_fonts.len().saturating_sub(1)); + cx.notify(); + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> { + "Search fonts…".into() + } + + fn update_matches( + &mut self, + query: String, + _window: &mut Window, + cx: &mut Context<FontPicker>, + ) -> Task<()> { + let fonts = self.fonts.clone(); + let current_font = self.current_font.clone(); + + let matches: Vec<StringMatch> = if query.is_empty() { + fonts + .iter() + .enumerate() + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect() + } else { + let _candidates: Vec<StringMatchCandidate> = fonts + .iter() + .enumerate() + .map(|(id, font)| StringMatchCandidate::new(id, font.as_ref())) + .collect(); + + fonts + .iter() + .enumerate() + .filter(|(_, font)| font.to_lowercase().contains(&query.to_lowercase())) + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect() + }; + + let selected_index = if query.is_empty() { + fonts + .iter() + .position(|font| *font == current_font) + .unwrap_or(0) + } else { + matches + .iter() + .position(|m| fonts[m.candidate_id] == current_font) + .unwrap_or(0) + }; + + self.filtered_fonts = matches; + self.selected_index = selected_index; + cx.notify(); + + Task::ready(()) + } + + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<FontPicker>) { + if let Some(font_match) = self.filtered_fonts.get(self.selected_index) { + let font = font_match.string.clone(); + (self.on_font_changed)(font.into(), cx); + } + } + + fn dismissed(&mut self, _window: &mut Window, _cx: &mut Context<FontPicker>) {} + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + _cx: &mut Context<FontPicker>, + ) -> Option<Self::ListItem> { + let font_match = self.filtered_fonts.get(ix)?; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .child(Label::new(font_match.string.clone())) + .into_any_element(), + ) + } +} + +fn font_picker( + current_font: SharedString, + on_font_changed: impl Fn(SharedString, &mut App) + 'static, + window: &mut Window, + cx: &mut Context<FontPicker>, +) -> FontPicker { + let delegate = FontPickerDelegate::new(current_font, on_font_changed, cx); + + Picker::list(delegate, window, cx) + .show_scrollbar(true) + .width(rems_from_px(210.)) + .max_height(Some(rems(20.).into())) +} + +fn render_popular_settings_section( + tab_index: &mut isize, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + const LIGATURE_TOOLTIP: &'static str = "Ligatures are when a font creates a special character out of combining two characters into one. For example, with ligatures turned on, =/= would become ≠."; + + v_flex() + .gap_5() + .child(Label::new("Popular Settings").size(LabelSize::Large).mt_8()) + .child(render_font_customization_section(tab_index, window, cx)) + .child( + SwitchField::new( + "onboarding-font-ligatures", + "Font Ligatures", + Some("Combine text characters into their associated symbols.".into()), + if read_font_ligatures(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_font_ligatures(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .tooltip(Tooltip::text(LIGATURE_TOOLTIP)), + ) + .child( + SwitchField::new( + "onboarding-format-on-save", + "Format on Save", + Some("Format code automatically when saving.".into()), + if read_format_on_save(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_format_on_save(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + SwitchField::new( + "onboarding-enable-inlay-hints", + "Inlay Hints", + Some("See parameter names for function and method calls inline.".into()), + if read_inlay_hints(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_inlay_hints(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + SwitchField::new( + "onboarding-git-blame-switch", + "Git Blame", + Some("See who committed each line on a given file.".into()), + if read_git_blame(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + set_git_blame(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + h_flex() + .items_start() + .justify_between() + .child( + v_flex().child(Label::new("Mini Map")).child( + Label::new("See a high-level overview of your source code.") + .color(Color::Muted), + ), + ) + .child( + ToggleButtonGroup::single_row( + "onboarding-show-mini-map", + [ + ToggleButtonSimple::new("Auto", |_, _, cx| { + write_show_mini_map(ShowMinimap::Auto, cx); + }), + ToggleButtonSimple::new("Always", |_, _, cx| { + write_show_mini_map(ShowMinimap::Always, cx); + }), + ToggleButtonSimple::new("Never", |_, _, cx| { + write_show_mini_map(ShowMinimap::Never, cx); + }), + ], + ) + .selected_index(match read_show_mini_map(cx) { + ShowMinimap::Auto => 0, + ShowMinimap::Always => 1, + ShowMinimap::Never => 2, + }) + .tab_index(tab_index) + .style(ToggleButtonGroupStyle::Outlined) + .button_width(ui::rems_from_px(64.)), + ), + ) +} + +pub(crate) fn render_editing_page(window: &mut Window, cx: &mut App) -> impl IntoElement { + let mut tab_index = 0; + v_flex() + .gap_4() + .child(render_import_settings_section(&mut tab_index, cx)) + .child(render_popular_settings_section(&mut tab_index, window, cx)) +} diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index dfdea1ca5b..c4d2b6847c 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -1,31 +1,61 @@ +use crate::welcome::{ShowWelcome, WelcomePage}; +use client::{Client, UserStore}; use command_palette_hooks::CommandPaletteFilter; use db::kvp::KEY_VALUE_STORE; use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; use fs::Fs; use gpui::{ - AnyElement, App, AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, - IntoElement, Render, SharedString, Subscription, Task, WeakEntity, Window, actions, + Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter, + FocusHandle, Focusable, Global, IntoElement, KeyContext, Render, SharedString, Subscription, + Task, WeakEntity, Window, actions, }; -use settings::{Settings, SettingsStore, update_settings_file}; +use notifications::status_toast::{StatusToast, ToastIcon}; +use schemars::JsonSchema; +use serde::Deserialize; +use settings::{SettingsStore, VsCodeSettingsSource}; use std::sync::Arc; -use theme::{ThemeMode, ThemeSettings}; use ui::{ - Divider, FluentBuilder, Headline, KeyBinding, ParentElement as _, StatefulInteractiveElement, - ToggleButtonGroup, ToggleButtonSimple, Vector, VectorName, prelude::*, rems_from_px, + Avatar, ButtonLike, FluentBuilder, Headline, KeyBinding, ParentElement as _, + StatefulInteractiveElement, Vector, VectorName, prelude::*, rems_from_px, }; use workspace::{ AppState, Workspace, WorkspaceId, dock::DockPosition, item::{Item, ItemEvent}, - open_new, with_active_or_new_workspace, + notifications::NotifyResultExt as _, + open_new, register_serializable_item, with_active_or_new_workspace, }; +mod ai_setup_page; +mod basics_page; +mod editing_page; +mod theme_preview; +mod welcome; + pub struct OnBoardingFeatureFlag {} impl FeatureFlag for OnBoardingFeatureFlag { const NAME: &'static str = "onboarding"; } +/// Imports settings from Visual Studio Code. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportVsCodeSettings { + #[serde(default)] + pub skip_prompt: bool, +} + +/// Imports settings from Cursor editor. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportCursorSettings { + #[serde(default)] + pub skip_prompt: bool, +} + pub const FIRST_OPEN: &str = "first_open"; actions!( @@ -36,6 +66,20 @@ actions!( ] ); +actions!( + onboarding, + [ + /// Activates the Basics page. + ActivateBasicsPage, + /// Activates the Editing page. + ActivateEditingPage, + /// Activates the AI Setup page. + ActivateAISetupPage, + /// Finish the onboarding process. + Finish, + ] +); + pub fn init(cx: &mut App) { cx.on_action(|_: &OpenOnboarding, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { @@ -50,7 +94,7 @@ pub fn init(cx: &mut App) { if let Some(existing) = existing { workspace.activate_item(&existing, true, true, window, cx); } else { - let settings_page = Onboarding::new(workspace.weak_handle(), cx); + let settings_page = Onboarding::new(workspace, cx); workspace.add_item_to_active_pane( Box::new(settings_page), None, @@ -63,12 +107,86 @@ pub fn init(cx: &mut App) { .detach(); }); }); + + cx.on_action(|_: &ShowWelcome, cx| { + with_active_or_new_workspace(cx, |workspace, window, cx| { + workspace + .with_local_workspace(window, cx, |workspace, window, cx| { + let existing = workspace + .active_pane() + .read(cx) + .items() + .find_map(|item| item.downcast::<WelcomePage>()); + + if let Some(existing) = existing { + workspace.activate_item(&existing, true, true, window, cx); + } else { + let settings_page = WelcomePage::new(window, cx); + workspace.add_item_to_active_pane( + Box::new(settings_page), + None, + true, + window, + cx, + ) + } + }) + .detach(); + }); + }); + + cx.observe_new(|workspace: &mut Workspace, _window, _cx| { + workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { + let fs = <dyn Fs>::global(cx); + let action = *action; + + let workspace = cx.weak_entity(); + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + workspace, + VsCodeSettingsSource::VsCode, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + + workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { + let fs = <dyn Fs>::global(cx); + let action = *action; + + let workspace = cx.weak_entity(); + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + workspace, + VsCodeSettingsSource::Cursor, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + }) + .detach(); + cx.observe_new::<Workspace>(|_, window, cx| { let Some(window) = window else { return; }; - let onboarding_actions = [std::any::TypeId::of::<OpenOnboarding>()]; + let onboarding_actions = [ + std::any::TypeId::of::<OpenOnboarding>(), + std::any::TypeId::of::<ShowWelcome>(), + ]; CommandPaletteFilter::update_global(cx, |filter, _cx| { filter.hide_action_types(&onboarding_actions); @@ -88,6 +206,7 @@ pub fn init(cx: &mut App) { .detach(); }) .detach(); + register_serializable_item::<Onboarding>(cx); } pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyhow::Result<()>> { @@ -98,7 +217,7 @@ pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyh |workspace, window, cx| { { workspace.toggle_dock(DockPosition::Left, window, cx); - let onboarding_page = Onboarding::new(workspace.weak_handle(), cx); + let onboarding_page = Onboarding::new(workspace, cx); workspace.add_item_to_center(Box::new(onboarding_page.clone()), window, cx); window.focus(&onboarding_page.focus_handle(cx)); @@ -112,23 +231,6 @@ pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyh ) } -fn read_theme_selection(cx: &App) -> ThemeMode { - let settings = ThemeSettings::get_global(cx); - settings - .theme_selection - .as_ref() - .and_then(|selection| selection.mode()) - .unwrap_or_default() -} - -fn write_theme_selection(theme_mode: ThemeMode, cx: &App) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<ThemeSettings>(fs, cx, move |settings, _| { - settings.set_mode(theme_mode); - }); -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum SelectedPage { Basics, @@ -140,120 +242,233 @@ struct Onboarding { workspace: WeakEntity<Workspace>, focus_handle: FocusHandle, selected_page: SelectedPage, + user_store: Entity<UserStore>, _settings_subscription: Subscription, } impl Onboarding { - fn new(workspace: WeakEntity<Workspace>, cx: &mut App) -> Entity<Self> { + fn new(workspace: &Workspace, cx: &mut App) -> Entity<Self> { cx.new(|cx| Self { - workspace, + workspace: workspace.weak_handle(), focus_handle: cx.focus_handle(), selected_page: SelectedPage::Basics, + user_store: workspace.user_store().clone(), _settings_subscription: cx.observe_global::<SettingsStore>(move |_, cx| cx.notify()), }) } - fn render_page_nav( + fn set_page(&mut self, page: SelectedPage, cx: &mut Context<Self>) { + self.selected_page = page; + cx.notify(); + cx.emit(ItemEvent::UpdateTab); + } + + fn render_nav_buttons( &mut self, - page: SelectedPage, - _: &mut Window, + window: &mut Window, cx: &mut Context<Self>, - ) -> impl IntoElement { - let text = match page { - SelectedPage::Basics => "Basics", - SelectedPage::Editing => "Editing", - SelectedPage::AiSetup => "AI Setup", - }; - let binding = match page { - SelectedPage::Basics => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-1").unwrap()], cx) - } - SelectedPage::Editing => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-2").unwrap()], cx) - } - SelectedPage::AiSetup => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-3").unwrap()], cx) - } - }; - let selected = self.selected_page == page; - h_flex() - .id(text) - .rounded_sm() - .child(text) - .child(binding) - .h_8() - .gap_2() - .px_2() - .py_0p5() - .w_full() + ) -> [impl IntoElement; 3] { + let pages = [ + SelectedPage::Basics, + SelectedPage::Editing, + SelectedPage::AiSetup, + ]; + + let text = ["Basics", "Editing", "AI Setup"]; + + let actions: [&dyn Action; 3] = [ + &ActivateBasicsPage, + &ActivateEditingPage, + &ActivateAISetupPage, + ]; + + let mut binding = actions.map(|action| { + KeyBinding::for_action_in(action, &self.focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))) + }); + + pages.map(|page| { + let i = page as usize; + let selected = self.selected_page == page; + h_flex() + .id(text[i]) + .relative() + .w_full() + .gap_2() + .px_2() + .py_0p5() + .justify_between() + .rounded_sm() + .when(selected, |this| { + this.child( + div() + .h_4() + .w_px() + .bg(cx.theme().colors().text_accent) + .absolute() + .left_0(), + ) + }) + .hover(|style| style.bg(cx.theme().colors().element_hover)) + .child(Label::new(text[i]).map(|this| { + if selected { + this.color(Color::Default) + } else { + this.color(Color::Muted) + } + })) + .child(binding[i].take().map_or( + gpui::Empty.into_any_element(), + IntoElement::into_any_element, + )) + .on_click(cx.listener(move |this, _, _, cx| { + this.set_page(page, cx); + })) + }) + } + + fn render_nav(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let ai_setup_page = matches!(self.selected_page, SelectedPage::AiSetup); + + v_flex() + .h_full() + .w(rems_from_px(220.)) + .flex_shrink_0() + .gap_4() .justify_between() - .map(|this| { - if selected { - this.bg(Color::Selected.color(cx)) - .border_l_1() - .border_color(Color::Accent.color(cx)) + .child( + v_flex() + .gap_6() + .child( + h_flex() + .px_2() + .gap_4() + .child(Vector::square(VectorName::ZedLogo, rems(2.5))) + .child( + v_flex() + .child( + Headline::new("Welcome to Zed").size(HeadlineSize::Small), + ) + .child( + Label::new("The editor for what's next") + .color(Color::Muted) + .size(LabelSize::Small) + .italic(), + ), + ), + ) + .child( + v_flex() + .gap_4() + .child( + v_flex() + .py_4() + .border_y_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .gap_1() + .children(self.render_nav_buttons(window, cx)), + ) + .map(|this| { + let keybinding = KeyBinding::for_action_in( + &Finish, + &self.focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))); + if ai_setup_page { + this.child( + ButtonLike::new("start_building") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .child( + h_flex() + .ml_1() + .w_full() + .justify_between() + .child(Label::new("Start Building")) + .child(keybinding.map_or_else( + || { + Icon::new(IconName::Check) + .size(IconSize::Small) + .into_any_element() + }, + IntoElement::into_any_element, + )), + ) + .on_click(|_, window, cx| { + window.dispatch_action(Finish.boxed_clone(), cx); + }), + ) + } else { + this.child( + ButtonLike::new("skip_all") + .size(ButtonSize::Medium) + .child( + h_flex() + .ml_1() + .w_full() + .justify_between() + .child(Label::new("Skip All")) + .child(keybinding.map_or_else( + || gpui::Empty.into_any_element(), + IntoElement::into_any_element, + )), + ) + .on_click(|_, window, cx| { + window.dispatch_action(Finish.boxed_clone(), cx); + }), + ) + } + }), + ), + ) + .child( + if let Some(user) = self.user_store.read(cx).current_user() { + h_flex() + .pl_1p5() + .gap_2() + .child(Avatar::new(user.avatar_uri.clone())) + .child(Label::new(user.github_login.clone())) + .into_any_element() } else { - this.text_color(Color::Muted.color(cx)) - } - }) - .hover(|style| { - if selected { - style.bg(Color::Selected.color(cx).opacity(0.6)) - } else { - style.bg(Color::Selected.color(cx).opacity(0.3)) - } - }) - .on_click(cx.listener(move |this, _, _, cx| { - this.selected_page = page; - cx.notify(); - })) + Button::new("sign_in", "Sign In") + .full_width() + .style(ButtonStyle::Outlined) + .on_click(|_, window, cx| { + let client = Client::global(cx); + window + .spawn(cx, async move |cx| { + client + .sign_in_with_optional_connect(true, &cx) + .await + .notify_async_err(cx); + }) + .detach(); + }) + .into_any_element() + }, + ) } fn render_page(&mut self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement { match self.selected_page { - SelectedPage::Basics => self.render_basics_page(window, cx).into_any_element(), - SelectedPage::Editing => self.render_editing_page(window, cx).into_any_element(), - SelectedPage::AiSetup => self.render_ai_setup_page(window, cx).into_any_element(), + SelectedPage::Basics => crate::basics_page::render_basics_page(cx).into_any_element(), + SelectedPage::Editing => { + crate::editing_page::render_editing_page(window, cx).into_any_element() + } + SelectedPage::AiSetup => crate::ai_setup_page::render_ai_setup_page( + self.workspace.clone(), + self.user_store.clone(), + window, + cx, + ) + .into_any_element(), } } - fn render_basics_page(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - let theme_mode = read_theme_selection(cx); - - v_flex().child( - h_flex().justify_between().child(Label::new("Theme")).child( - ToggleButtonGroup::single_row( - "theme-selector-onboarding", - [ - ToggleButtonSimple::new("Light", |_, _, cx| { - write_theme_selection(ThemeMode::Light, cx) - }), - ToggleButtonSimple::new("Dark", |_, _, cx| { - write_theme_selection(ThemeMode::Dark, cx) - }), - ToggleButtonSimple::new("System", |_, _, cx| { - write_theme_selection(ThemeMode::System, cx) - }), - ], - ) - .selected_index(match theme_mode { - ThemeMode::Light => 0, - ThemeMode::Dark => 1, - ThemeMode::System => 2, - }) - .style(ui::ToggleButtonGroupStyle::Outlined) - .button_width(rems_from_px(64.)), - ), - ) - } - - fn render_editing_page(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - // div().child("editing page") - "Right" - } - - fn render_ai_setup_page(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - div().child("ai setup page") + fn on_finish(_: &Finish, _: &mut Window, cx: &mut App) { + go_to_welcome_page(cx); } } @@ -261,45 +476,54 @@ impl Render for Onboarding { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { h_flex() .image_cache(gpui::retain_all("onboarding-page")) - .key_context("onboarding-page") - .px_24() - .py_12() - .items_start() + .key_context({ + let mut ctx = KeyContext::new_with_defaults(); + ctx.add("Onboarding"); + ctx.add("menu"); + ctx + }) + .track_focus(&self.focus_handle) + .size_full() + .bg(cx.theme().colors().editor_background) + .on_action(Self::on_finish) + .on_action(cx.listener(|this, _: &ActivateBasicsPage, _, cx| { + this.set_page(SelectedPage::Basics, cx); + })) + .on_action(cx.listener(|this, _: &ActivateEditingPage, _, cx| { + this.set_page(SelectedPage::Editing, cx); + })) + .on_action(cx.listener(|this, _: &ActivateAISetupPage, _, cx| { + this.set_page(SelectedPage::AiSetup, cx); + })) + .on_action(cx.listener(|_, _: &menu::SelectNext, window, cx| { + window.focus_next(); + cx.notify(); + })) + .on_action(cx.listener(|_, _: &menu::SelectPrevious, window, cx| { + window.focus_prev(); + cx.notify(); + })) .child( - v_flex() - .w_1_3() - .h_full() + h_flex() + .max_w(rems_from_px(1100.)) + .size_full() + .m_auto() + .py_20() + .px_12() + .items_start() + .gap_12() + .child(self.render_nav(window, cx)) .child( - h_flex() - .pt_0p5() - .child(Vector::square(VectorName::ZedLogo, rems(2.))) - .child( - v_flex() - .left_1() - .items_center() - .child(Headline::new("Welcome to Zed")) - .child( - Label::new("The editor for what's next") - .color(Color::Muted) - .italic(), - ), - ), - ) - .p_1() - .child(Divider::horizontal_dashed()) - .child( - v_flex().gap_1().children([ - self.render_page_nav(SelectedPage::Basics, window, cx) - .into_element(), - self.render_page_nav(SelectedPage::Editing, window, cx) - .into_element(), - self.render_page_nav(SelectedPage::AiSetup, window, cx) - .into_element(), - ]), + v_flex() + .max_w_full() + .min_w_0() + .pl_12() + .border_l_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .size_full() + .child(self.render_page(window, cx)), ), ) - // .child(Divider::vertical_dashed()) - .child(div().w_2_3().h_full().child(self.render_page(window, cx))) } } @@ -332,10 +556,279 @@ impl Item for Onboarding { _: &mut Window, cx: &mut Context<Self>, ) -> Option<Entity<Self>> { - Some(Onboarding::new(self.workspace.clone(), cx)) + self.workspace + .update(cx, |workspace, cx| Onboarding::new(workspace, cx)) + .ok() } fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { f(*event) } } + +fn go_to_welcome_page(cx: &mut App) { + with_active_or_new_workspace(cx, |workspace, window, cx| { + let Some((onboarding_id, onboarding_idx)) = workspace + .active_pane() + .read(cx) + .items() + .enumerate() + .find_map(|(idx, item)| { + let _ = item.downcast::<Onboarding>()?; + Some((item.item_id(), idx)) + }) + else { + return; + }; + + workspace.active_pane().update(cx, |pane, cx| { + // Get the index here to get around the borrow checker + let idx = pane.items().enumerate().find_map(|(idx, item)| { + let _ = item.downcast::<WelcomePage>()?; + Some(idx) + }); + + if let Some(idx) = idx { + pane.activate_item(idx, true, true, window, cx); + } else { + let item = Box::new(WelcomePage::new(window, cx)); + pane.add_item(item, true, true, Some(onboarding_idx), window, cx); + } + + pane.remove_item(onboarding_id, false, false, window, cx); + }); + }); +} + +pub async fn handle_import_vscode_settings( + workspace: WeakEntity<Workspace>, + source: VsCodeSettingsSource, + skip_prompt: bool, + fs: Arc<dyn Fs>, + cx: &mut AsyncWindowContext, +) { + use util::truncate_and_remove_front; + + let vscode_settings = + match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { + Ok(vscode_settings) => vscode_settings, + Err(err) => { + zlog::error!("{err}"); + let _ = cx.prompt( + gpui::PromptLevel::Info, + &format!("Could not find or load a {source} settings file"), + None, + &["Ok"], + ); + return; + } + }; + + if !skip_prompt { + let prompt = cx.prompt( + gpui::PromptLevel::Warning, + &format!( + "Importing {} settings may overwrite your existing settings. \ + Will import settings from {}", + vscode_settings.source, + truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), + ), + None, + &["Ok", "Cancel"], + ); + let result = cx.spawn(async move |_| prompt.await.ok()).await; + if result != Some(0) { + return; + } + }; + + let Ok(result_channel) = cx.update(|_, cx| { + let source = vscode_settings.source; + let path = vscode_settings.path.clone(); + let result_channel = cx + .global::<SettingsStore>() + .import_vscode_settings(fs, vscode_settings); + zlog::info!("Imported {source} settings from {}", path.display()); + result_channel + }) else { + return; + }; + + let result = result_channel.await; + workspace + .update_in(cx, |workspace, _, cx| match result { + Ok(_) => { + let confirmation_toast = StatusToast::new( + format!("Your {} settings were successfully imported.", source), + cx, + |this, _| { + this.icon(ToastIcon::new(IconName::Check).color(Color::Success)) + .dismiss_button(true) + }, + ); + SettingsImportState::update(cx, |state, _| match source { + VsCodeSettingsSource::VsCode => { + state.vscode = true; + } + VsCodeSettingsSource::Cursor => { + state.cursor = true; + } + }); + workspace.toggle_status_toast(confirmation_toast, cx); + } + Err(_) => { + let error_toast = StatusToast::new( + "Failed to import settings. See log for details", + cx, + |this, _| { + this.icon(ToastIcon::new(IconName::X).color(Color::Error)) + .action("Open Log", |window, cx| { + window.dispatch_action(workspace::OpenLog.boxed_clone(), cx) + }) + .dismiss_button(true) + }, + ); + workspace.toggle_status_toast(error_toast, cx); + } + }) + .ok(); +} + +#[derive(Default, Copy, Clone)] +pub struct SettingsImportState { + pub cursor: bool, + pub vscode: bool, +} + +impl Global for SettingsImportState {} + +impl SettingsImportState { + pub fn global(cx: &App) -> Self { + cx.try_global().cloned().unwrap_or_default() + } + pub fn update<R>(cx: &mut App, f: impl FnOnce(&mut Self, &mut App) -> R) -> R { + cx.update_default_global(f) + } +} + +impl workspace::SerializableItem for Onboarding { + fn serialized_item_kind() -> &'static str { + "OnboardingPage" + } + + fn cleanup( + workspace_id: workspace::WorkspaceId, + alive_items: Vec<workspace::ItemId>, + _window: &mut Window, + cx: &mut App, + ) -> gpui::Task<gpui::Result<()>> { + workspace::delete_unloaded_items( + alive_items, + workspace_id, + "onboarding_pages", + &persistence::ONBOARDING_PAGES, + cx, + ) + } + + fn deserialize( + _project: Entity<project::Project>, + workspace: WeakEntity<Workspace>, + workspace_id: workspace::WorkspaceId, + item_id: workspace::ItemId, + window: &mut Window, + cx: &mut App, + ) -> gpui::Task<gpui::Result<Entity<Self>>> { + window.spawn(cx, async move |cx| { + if let Some(page_number) = + persistence::ONBOARDING_PAGES.get_onboarding_page(item_id, workspace_id)? + { + let page = match page_number { + 0 => Some(SelectedPage::Basics), + 1 => Some(SelectedPage::Editing), + 2 => Some(SelectedPage::AiSetup), + _ => None, + }; + workspace.update(cx, |workspace, cx| { + let onboarding_page = Onboarding::new(workspace, cx); + if let Some(page) = page { + zlog::info!("Onboarding page {page:?} loaded"); + onboarding_page.update(cx, |onboarding_page, cx| { + onboarding_page.set_page(page, cx); + }) + } + onboarding_page + }) + } else { + Err(anyhow::anyhow!("No onboarding page to deserialize")) + } + }) + } + + fn serialize( + &mut self, + workspace: &mut Workspace, + item_id: workspace::ItemId, + _closing: bool, + _window: &mut Window, + cx: &mut ui::Context<Self>, + ) -> Option<gpui::Task<gpui::Result<()>>> { + let workspace_id = workspace.database_id()?; + let page_number = self.selected_page as u16; + Some(cx.background_spawn(async move { + persistence::ONBOARDING_PAGES + .save_onboarding_page(item_id, workspace_id, page_number) + .await + })) + } + + fn should_serialize(&self, event: &Self::Event) -> bool { + event == &ItemEvent::UpdateTab + } +} + +mod persistence { + use db::{define_connection, query, sqlez_macros::sql}; + use workspace::WorkspaceDb; + + define_connection! { + pub static ref ONBOARDING_PAGES: OnboardingPagesDb<WorkspaceDb> = + &[ + sql!( + CREATE TABLE onboarding_pages ( + workspace_id INTEGER, + item_id INTEGER UNIQUE, + page_number INTEGER, + + PRIMARY KEY(workspace_id, item_id), + FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) + ON DELETE CASCADE + ) STRICT; + ), + ]; + } + + impl OnboardingPagesDb { + query! { + pub async fn save_onboarding_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId, + page_number: u16 + ) -> Result<()> { + INSERT OR REPLACE INTO onboarding_pages(item_id, workspace_id, page_number) + VALUES (?, ?, ?) + } + } + + query! { + pub fn get_onboarding_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId + ) -> Result<Option<u16>> { + SELECT page_number + FROM onboarding_pages + WHERE item_id = ? AND workspace_id = ? + } + } + } +} diff --git a/crates/onboarding/src/theme_preview.rs b/crates/onboarding/src/theme_preview.rs new file mode 100644 index 0000000000..53631be1c9 --- /dev/null +++ b/crates/onboarding/src/theme_preview.rs @@ -0,0 +1,366 @@ +#![allow(unused, dead_code)] +use gpui::{Hsla, Length}; +use std::sync::Arc; +use theme::{Theme, ThemeColors, ThemeRegistry}; +use ui::{ + IntoElement, RenderOnce, component_prelude::Documented, prelude::*, utils::inner_corner_radius, +}; + +#[derive(Clone, PartialEq)] +pub enum ThemePreviewStyle { + Bordered, + Borderless, + SideBySide(Arc<Theme>), +} + +/// Shows a preview of a theme as an abstract illustration +/// of a thumbnail-sized editor. +#[derive(IntoElement, RegisterComponent, Documented)] +pub struct ThemePreviewTile { + theme: Arc<Theme>, + seed: f32, + style: ThemePreviewStyle, +} + +impl ThemePreviewTile { + pub const SKELETON_HEIGHT_DEFAULT: Pixels = px(2.); + pub const SIDEBAR_SKELETON_ITEM_COUNT: usize = 8; + pub const SIDEBAR_WIDTH_DEFAULT: DefiniteLength = relative(0.25); + pub const ROOT_RADIUS: Pixels = px(8.0); + pub const ROOT_BORDER: Pixels = px(2.0); + pub const ROOT_PADDING: Pixels = px(2.0); + pub const CHILD_BORDER: Pixels = px(1.0); + pub const CHILD_RADIUS: std::cell::LazyCell<Pixels> = std::cell::LazyCell::new(|| { + inner_corner_radius( + Self::ROOT_RADIUS, + Self::ROOT_BORDER, + Self::ROOT_PADDING, + Self::CHILD_BORDER, + ) + }); + + pub fn new(theme: Arc<Theme>, seed: f32) -> Self { + Self { + theme, + seed, + style: ThemePreviewStyle::Bordered, + } + } + + pub fn style(mut self, style: ThemePreviewStyle) -> Self { + self.style = style; + self + } + + pub fn item_skeleton(w: Length, h: Length, bg: Hsla) -> impl IntoElement { + div().w(w).h(h).rounded_full().bg(bg) + } + + pub fn render_sidebar_skeleton_items( + seed: f32, + colors: &ThemeColors, + skeleton_height: impl Into<Length> + Clone, + ) -> [impl IntoElement; Self::SIDEBAR_SKELETON_ITEM_COUNT] { + let skeleton_height = skeleton_height.into(); + std::array::from_fn(|index| { + let width = { + let value = (seed * 1000.0 + index as f32 * 10.0).sin() * 0.5 + 0.5; + 0.5 + value * 0.45 + }; + Self::item_skeleton( + relative(width).into(), + skeleton_height, + colors.text.alpha(0.45), + ) + }) + } + + pub fn render_pseudo_code_skeleton( + seed: f32, + theme: Arc<Theme>, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + let colors = theme.colors(); + let syntax = theme.syntax(); + + let keyword_color = syntax.get("keyword").color; + let function_color = syntax.get("function").color; + let string_color = syntax.get("string").color; + let comment_color = syntax.get("comment").color; + let variable_color = syntax.get("variable").color; + let type_color = syntax.get("type").color; + let punctuation_color = syntax.get("punctuation").color; + + let syntax_colors = [ + keyword_color, + function_color, + string_color, + variable_color, + type_color, + punctuation_color, + comment_color, + ]; + + let skeleton_height = skeleton_height.into(); + + let line_width = |line_idx: usize, block_idx: usize| -> f32 { + let val = + (seed * 100.0 + line_idx as f32 * 20.0 + block_idx as f32 * 5.0).sin() * 0.5 + 0.5; + 0.05 + val * 0.2 + }; + + let indentation = |line_idx: usize| -> f32 { + let step = line_idx % 6; + if step < 3 { + step as f32 * 0.1 + } else { + (5 - step) as f32 * 0.1 + } + }; + + let pick_color = |line_idx: usize, block_idx: usize| -> Hsla { + let idx = ((seed * 10.0 + line_idx as f32 * 7.0 + block_idx as f32 * 3.0).sin() * 3.5) + .abs() as usize + % syntax_colors.len(); + syntax_colors[idx].unwrap_or(colors.text) + }; + + let line_count = 13; + + let lines = (0..line_count) + .map(|line_idx| { + let block_count = (((seed * 30.0 + line_idx as f32 * 12.0).sin() * 0.5 + 0.5) * 3.0) + .round() as usize + + 2; + + let indent = indentation(line_idx); + + let blocks = (0..block_count) + .map(|block_idx| { + let width = line_width(line_idx, block_idx); + let color = pick_color(line_idx, block_idx); + Self::item_skeleton(relative(width).into(), skeleton_height, color) + }) + .collect::<Vec<_>>(); + + h_flex().gap(px(2.)).ml(relative(indent)).children(blocks) + }) + .collect::<Vec<_>>(); + + v_flex().size_full().p_1().gap_1p5().children(lines) + } + + pub fn render_sidebar( + seed: f32, + colors: &ThemeColors, + width: impl Into<Length> + Clone, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + div() + .h_full() + .w(width) + .border_r(px(1.)) + .border_color(colors.border_transparent) + .bg(colors.panel_background) + .child(v_flex().p_2().size_full().gap_1().children( + Self::render_sidebar_skeleton_items(seed, colors, skeleton_height.into()), + )) + } + + pub fn render_pane( + seed: f32, + theme: Arc<Theme>, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + v_flex().h_full().flex_grow().child( + div() + .size_full() + .overflow_hidden() + .bg(theme.colors().editor_background) + .p_2() + .child(Self::render_pseudo_code_skeleton( + seed, + theme, + skeleton_height.into(), + )), + ) + } + + pub fn render_editor( + seed: f32, + theme: Arc<Theme>, + sidebar_width: impl Into<Length> + Clone, + skeleton_height: impl Into<Length> + Clone, + ) -> impl IntoElement { + div() + .size_full() + .flex() + .bg(theme.colors().background.alpha(1.00)) + .child(Self::render_sidebar( + seed, + theme.colors(), + sidebar_width, + skeleton_height.clone(), + )) + .child(Self::render_pane(seed, theme, skeleton_height.clone())) + } + + fn render_borderless(seed: f32, theme: Arc<Theme>) -> impl IntoElement { + return Self::render_editor( + seed, + theme, + Self::SIDEBAR_WIDTH_DEFAULT, + Self::SKELETON_HEIGHT_DEFAULT, + ); + } + + fn render_border(seed: f32, theme: Arc<Theme>) -> impl IntoElement { + div() + .size_full() + .p(Self::ROOT_PADDING) + .rounded(Self::ROOT_RADIUS) + .child( + div() + .size_full() + .rounded(*Self::CHILD_RADIUS) + .border(Self::CHILD_BORDER) + .border_color(theme.colors().border) + .child(Self::render_editor( + seed, + theme.clone(), + Self::SIDEBAR_WIDTH_DEFAULT, + Self::SKELETON_HEIGHT_DEFAULT, + )), + ) + } + + fn render_side_by_side( + seed: f32, + theme: Arc<Theme>, + other_theme: Arc<Theme>, + border_color: Hsla, + ) -> impl IntoElement { + let sidebar_width = relative(0.20); + + return div() + .size_full() + .p(Self::ROOT_PADDING) + .rounded(Self::ROOT_RADIUS) + .child( + h_flex() + .size_full() + .relative() + .rounded(*Self::CHILD_RADIUS) + .border(Self::CHILD_BORDER) + .border_color(border_color) + .overflow_hidden() + .child(div().size_full().child(Self::render_editor( + seed, + theme.clone(), + sidebar_width, + Self::SKELETON_HEIGHT_DEFAULT, + ))) + .child( + div() + .size_full() + .absolute() + .left_1_2() + .bg(other_theme.colors().editor_background) + .child(Self::render_editor( + seed, + other_theme, + sidebar_width, + Self::SKELETON_HEIGHT_DEFAULT, + )), + ), + ) + .into_any_element(); + } +} + +impl RenderOnce for ThemePreviewTile { + fn render(self, _window: &mut ui::Window, _cx: &mut ui::App) -> impl IntoElement { + match self.style { + ThemePreviewStyle::Bordered => { + Self::render_border(self.seed, self.theme).into_any_element() + } + ThemePreviewStyle::Borderless => { + Self::render_borderless(self.seed, self.theme).into_any_element() + } + ThemePreviewStyle::SideBySide(other_theme) => Self::render_side_by_side( + self.seed, + self.theme, + other_theme, + _cx.theme().colors().border, + ) + .into_any_element(), + } + } +} + +impl Component for ThemePreviewTile { + fn description() -> Option<&'static str> { + Some(Self::DOCS) + } + + fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> { + let theme_registry = ThemeRegistry::global(cx); + + let one_dark = theme_registry.get("One Dark"); + let one_light = theme_registry.get("One Light"); + let gruvbox_dark = theme_registry.get("Gruvbox Dark"); + let gruvbox_light = theme_registry.get("Gruvbox Light"); + + let themes_to_preview = vec![ + one_dark.clone().ok(), + one_light.clone().ok(), + gruvbox_dark.clone().ok(), + gruvbox_light.clone().ok(), + ] + .into_iter() + .flatten() + .collect::<Vec<_>>(); + + Some( + v_flex() + .gap_6() + .p_4() + .children({ + if let Some(one_dark) = one_dark.ok() { + vec![example_group(vec![single_example( + "Default", + div() + .w(px(240.)) + .h(px(180.)) + .child(ThemePreviewTile::new(one_dark.clone(), 0.42)) + .into_any_element(), + )])] + } else { + vec![] + } + }) + .child( + example_group(vec![single_example( + "Default Themes", + h_flex() + .gap_4() + .children( + themes_to_preview + .iter() + .enumerate() + .map(|(_, theme)| { + div() + .w(px(200.)) + .h(px(140.)) + .child(ThemePreviewTile::new(theme.clone(), 0.42)) + }) + .collect::<Vec<_>>(), + ) + .into_any_element(), + )]) + .grow(), + ) + .into_any_element(), + ) + } +} diff --git a/crates/onboarding/src/welcome.rs b/crates/onboarding/src/welcome.rs new file mode 100644 index 0000000000..d4d6c3f701 --- /dev/null +++ b/crates/onboarding/src/welcome.rs @@ -0,0 +1,355 @@ +use gpui::{ + Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, + NoAction, ParentElement, Render, Styled, Window, actions, +}; +use menu::{SelectNext, SelectPrevious}; +use ui::{ButtonLike, Divider, DividerColor, KeyBinding, Vector, VectorName, prelude::*}; +use workspace::{ + NewFile, Open, WorkspaceId, + item::{Item, ItemEvent}, + with_active_or_new_workspace, +}; +use zed_actions::{Extensions, OpenSettings, agent, command_palette}; + +use crate::{Onboarding, OpenOnboarding}; + +actions!( + zed, + [ + /// Show the Zed welcome screen + ShowWelcome + ] +); + +const CONTENT: (Section<4>, Section<3>) = ( + Section { + title: "Get Started", + entries: [ + SectionEntry { + icon: IconName::Plus, + title: "New File", + action: &NewFile, + }, + SectionEntry { + icon: IconName::FolderOpen, + title: "Open Project", + action: &Open, + }, + SectionEntry { + icon: IconName::CloudDownload, + title: "Clone a Repo", + // TODO: use proper action + action: &NoAction, + }, + SectionEntry { + icon: IconName::ListCollapse, + title: "Open Command Palette", + action: &command_palette::Toggle, + }, + ], + }, + Section { + title: "Configure", + entries: [ + SectionEntry { + icon: IconName::Settings, + title: "Open Settings", + action: &OpenSettings, + }, + SectionEntry { + icon: IconName::ZedAssistant, + title: "View AI Settings", + action: &agent::OpenSettings, + }, + SectionEntry { + icon: IconName::Blocks, + title: "Explore Extensions", + action: &Extensions { + category_filter: None, + id: None, + }, + }, + ], + }, +); + +struct Section<const COLS: usize> { + title: &'static str, + entries: [SectionEntry; COLS], +} + +impl<const COLS: usize> Section<COLS> { + fn render( + self, + index_offset: usize, + focus: &FocusHandle, + window: &mut Window, + cx: &mut App, + ) -> impl IntoElement { + v_flex() + .min_w_full() + .child( + h_flex() + .px_1() + .mb_2() + .gap_2() + .child( + Label::new(self.title.to_ascii_uppercase()) + .buffer_font(cx) + .color(Color::Muted) + .size(LabelSize::XSmall), + ) + .child(Divider::horizontal().color(DividerColor::BorderVariant)), + ) + .children( + self.entries + .iter() + .enumerate() + .map(|(index, entry)| entry.render(index_offset + index, &focus, window, cx)), + ) + } +} + +struct SectionEntry { + icon: IconName, + title: &'static str, + action: &'static dyn Action, +} + +impl SectionEntry { + fn render( + &self, + button_index: usize, + focus: &FocusHandle, + window: &Window, + cx: &App, + ) -> impl IntoElement { + ButtonLike::new(("onboarding-button-id", button_index)) + .tab_index(button_index as isize) + .full_width() + .size(ButtonSize::Medium) + .child( + h_flex() + .w_full() + .justify_between() + .child( + h_flex() + .gap_2() + .child( + Icon::new(self.icon) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(self.title)), + ) + .children( + KeyBinding::for_action_in(self.action, focus, window, cx) + .map(|s| s.size(rems_from_px(12.))), + ), + ) + .on_click(|_, window, cx| window.dispatch_action(self.action.boxed_clone(), cx)) + } +} + +pub struct WelcomePage { + focus_handle: FocusHandle, +} + +impl WelcomePage { + fn select_next(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context<Self>) { + window.focus_next(); + cx.notify(); + } + + fn select_previous(&mut self, _: &SelectPrevious, window: &mut Window, cx: &mut Context<Self>) { + window.focus_prev(); + cx.notify(); + } +} + +impl Render for WelcomePage { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let (first_section, second_section) = CONTENT; + let first_section_entries = first_section.entries.len(); + let last_index = first_section_entries + second_section.entries.len(); + + h_flex() + .size_full() + .justify_center() + .overflow_hidden() + .bg(cx.theme().colors().editor_background) + .key_context("Welcome") + .track_focus(&self.focus_handle(cx)) + .on_action(cx.listener(Self::select_previous)) + .on_action(cx.listener(Self::select_next)) + .child( + h_flex() + .px_12() + .py_40() + .size_full() + .relative() + .max_w(px(1100.)) + .child( + div() + .size_full() + .max_w_128() + .mx_auto() + .child( + h_flex() + .w_full() + .justify_center() + .gap_4() + .child(Vector::square(VectorName::ZedLogo, rems(2.))) + .child( + div().child(Headline::new("Welcome to Zed")).child( + Label::new("The editor for what's next") + .size(LabelSize::Small) + .color(Color::Muted) + .italic(), + ), + ), + ) + .child( + v_flex() + .mt_10() + .gap_6() + .child(first_section.render( + Default::default(), + &self.focus_handle, + window, + cx, + )) + .child(second_section.render( + first_section_entries, + &self.focus_handle, + window, + cx, + )) + .child( + h_flex() + .w_full() + .pt_4() + .justify_center() + // We call this a hack + .rounded_b_xs() + .border_t_1() + .border_color(cx.theme().colors().border.opacity(0.6)) + .border_dashed() + .child( + Button::new("welcome-exit", "Return to Setup") + .tab_index(last_index as isize) + .full_width() + .label_size(LabelSize::XSmall) + .on_click(|_, window, cx| { + window.dispatch_action( + OpenOnboarding.boxed_clone(), + cx, + ); + + with_active_or_new_workspace(cx, |workspace, window, cx| { + let Some((welcome_id, welcome_idx)) = workspace + .active_pane() + .read(cx) + .items() + .enumerate() + .find_map(|(idx, item)| { + let _ = item.downcast::<WelcomePage>()?; + Some((item.item_id(), idx)) + }) + else { + return; + }; + + workspace.active_pane().update(cx, |pane, cx| { + // Get the index here to get around the borrow checker + let idx = pane.items().enumerate().find_map( + |(idx, item)| { + let _ = + item.downcast::<Onboarding>()?; + Some(idx) + }, + ); + + if let Some(idx) = idx { + pane.activate_item( + idx, true, true, window, cx, + ); + } else { + let item = + Box::new(Onboarding::new(workspace, cx)); + pane.add_item( + item, + true, + true, + Some(welcome_idx), + window, + cx, + ); + } + + pane.remove_item( + welcome_id, + false, + false, + window, + cx, + ); + }); + }); + }), + ), + ), + ), + ), + ) + } +} + +impl WelcomePage { + pub fn new(window: &mut Window, cx: &mut App) -> Entity<Self> { + cx.new(|cx| { + let focus_handle = cx.focus_handle(); + cx.on_focus(&focus_handle, window, |_, _, cx| cx.notify()) + .detach(); + + WelcomePage { focus_handle } + }) + } +} + +impl EventEmitter<ItemEvent> for WelcomePage {} + +impl Focusable for WelcomePage { + fn focus_handle(&self, _: &App) -> gpui::FocusHandle { + self.focus_handle.clone() + } +} + +impl Item for WelcomePage { + type Event = ItemEvent; + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Welcome".into() + } + + fn telemetry_event_text(&self) -> Option<&'static str> { + Some("New Welcome Page Opened") + } + + fn show_toolbar(&self) -> bool { + false + } + + fn clone_on_split( + &self, + _workspace_id: Option<WorkspaceId>, + _: &mut Window, + _: &mut Context<Self>, + ) -> Option<Entity<Self>> { + None + } + + fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { + f(*event) + } +} diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index 50c6c2dcce..ad96670db9 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -1041,7 +1041,7 @@ impl OutlinePanel { fn open_excerpts( &mut self, - action: &editor::OpenExcerpts, + action: &editor::actions::OpenExcerpts, window: &mut Window, cx: &mut Context<Self>, ) { @@ -1057,7 +1057,7 @@ impl OutlinePanel { fn open_excerpts_split( &mut self, - action: &editor::OpenExcerptsSplit, + action: &editor::actions::OpenExcerptsSplit, window: &mut Window, cx: &mut Context<Self>, ) { @@ -5958,7 +5958,7 @@ mod tests { }); outline_panel.update_in(cx, |outline_panel, window, cx| { - outline_panel.open_excerpts(&editor::OpenExcerpts, window, cx); + outline_panel.open_excerpts(&editor::actions::OpenExcerpts, window, cx); }); cx.executor() .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs index 2f3b188980..47a0f12c06 100644 --- a/crates/paths/src/paths.rs +++ b/crates/paths/src/paths.rs @@ -35,6 +35,7 @@ pub fn remote_server_dir_relative() -> &'static Path { /// Sets a custom directory for all user data, overriding the default data directory. /// This function must be called before any other path operations that depend on the data directory. +/// The directory's path will be canonicalized to an absolute path by a blocking FS operation. /// The directory will be created if it doesn't exist. /// /// # Arguments @@ -50,13 +51,20 @@ pub fn remote_server_dir_relative() -> &'static Path { /// /// Panics if: /// * Called after the data directory has been initialized (e.g., via `data_dir` or `config_dir`) +/// * The directory's path cannot be canonicalized to an absolute path /// * The directory cannot be created pub fn set_custom_data_dir(dir: &str) -> &'static PathBuf { if CURRENT_DATA_DIR.get().is_some() || CONFIG_DIR.get().is_some() { panic!("set_custom_data_dir called after data_dir or config_dir was initialized"); } CUSTOM_DATA_DIR.get_or_init(|| { - let path = PathBuf::from(dir); + let mut path = PathBuf::from(dir); + if path.is_relative() { + let abs_path = path + .canonicalize() + .expect("failed to canonicalize custom data directory's path to an absolute path"); + path = PathBuf::from(util::paths::SanitizedPath::from(abs_path)) + } std::fs::create_dir_all(&path).expect("failed to create custom data directory"); path }) diff --git a/crates/picker/src/popover_menu.rs b/crates/picker/src/popover_menu.rs index dd1d9c2865..d05308ee71 100644 --- a/crates/picker/src/popover_menu.rs +++ b/crates/picker/src/popover_menu.rs @@ -80,6 +80,7 @@ where { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let picker = self.picker.clone(); + PopoverMenu::new("popover-menu") .menu(move |_window, _cx| Some(picker.clone())) .trigger_with_tooltip(self.trigger, self.tooltip) diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index ceec0c0a52..c96ab4e8f3 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -13,6 +13,7 @@ use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; use crate::{ + Project, project_settings::{ContextServerSettings, ProjectSettings}, worktree_store::WorktreeStore, }; @@ -144,6 +145,7 @@ pub struct ContextServerStore { context_server_settings: HashMap<Arc<str>, ContextServerSettings>, servers: HashMap<ContextServerId, ContextServerState>, worktree_store: Entity<WorktreeStore>, + project: WeakEntity<Project>, registry: Entity<ContextServerDescriptorRegistry>, update_servers_task: Option<Task<Result<()>>>, context_server_factory: Option<ContextServerFactory>, @@ -161,12 +163,17 @@ pub enum Event { impl EventEmitter<Event> for ContextServerStore {} impl ContextServerStore { - pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self { + pub fn new( + worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, + cx: &mut Context<Self>, + ) -> Self { Self::new_internal( true, None, ContextServerDescriptorRegistry::default_global(cx), worktree_store, + weak_project, cx, ) } @@ -184,9 +191,10 @@ impl ContextServerStore { pub fn test( registry: Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { - Self::new_internal(false, None, registry, worktree_store, cx) + Self::new_internal(false, None, registry, worktree_store, weak_project, cx) } #[cfg(any(test, feature = "test-support"))] @@ -194,6 +202,7 @@ impl ContextServerStore { context_server_factory: ContextServerFactory, registry: Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { Self::new_internal( @@ -201,6 +210,7 @@ impl ContextServerStore { Some(context_server_factory), registry, worktree_store, + weak_project, cx, ) } @@ -210,6 +220,7 @@ impl ContextServerStore { context_server_factory: Option<ContextServerFactory>, registry: Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { let subscriptions = if maintain_server_loop { @@ -235,6 +246,7 @@ impl ContextServerStore { context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx) .clone(), worktree_store, + project: weak_project, registry, needs_server_update: false, servers: HashMap::default(), @@ -360,7 +372,7 @@ impl ContextServerStore { let configuration = state.configuration(); self.stop_server(&state.server().id(), cx)?; - let new_server = self.create_context_server(id.clone(), configuration.clone())?; + let new_server = self.create_context_server(id.clone(), configuration.clone(), cx); self.run_server(new_server, configuration, cx); } Ok(()) @@ -449,14 +461,33 @@ impl ContextServerStore { &self, id: ContextServerId, configuration: Arc<ContextServerConfiguration>, - ) -> Result<Arc<ContextServer>> { + cx: &mut Context<Self>, + ) -> Arc<ContextServer> { + let root_path = self + .project + .read_with(cx, |project, cx| project.active_project_directory(cx)) + .ok() + .flatten() + .or_else(|| { + self.worktree_store.read_with(cx, |store, cx| { + store.visible_worktrees(cx).fold(None, |acc, item| { + if acc.is_none() { + item.read(cx).root_dir() + } else { + acc + } + }) + }) + }); + if let Some(factory) = self.context_server_factory.as_ref() { - Ok(factory(id, configuration)) + factory(id, configuration) } else { - Ok(Arc::new(ContextServer::stdio( + Arc::new(ContextServer::stdio( id, configuration.command().clone(), - ))) + root_path, + )) } } @@ -553,7 +584,7 @@ impl ContextServerStore { let mut servers_to_remove = HashSet::default(); let mut servers_to_stop = HashSet::default(); - this.update(cx, |this, _cx| { + this.update(cx, |this, cx| { for server_id in this.servers.keys() { // All servers that are not in desired_servers should be removed from the store. // This can happen if the user removed a server from the context server settings. @@ -572,14 +603,10 @@ impl ContextServerStore { let existing_config = state.as_ref().map(|state| state.configuration()); if existing_config.as_deref() != Some(&config) || is_stopped { let config = Arc::new(config); - if let Some(server) = this - .create_context_server(id.clone(), config.clone()) - .log_err() - { - servers_to_start.push((server, config)); - if this.servers.contains_key(&id) { - servers_to_stop.insert(id); - } + let server = this.create_context_server(id.clone(), config.clone(), cx); + servers_to_start.push((server, config)); + if this.servers.contains_key(&id) { + servers_to_stop.insert(id); } } } @@ -630,7 +657,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_1_id = ContextServerId(SERVER_1_ID.into()); @@ -705,7 +737,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_1_id = ContextServerId(SERVER_1_ID.into()); @@ -758,7 +795,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_id = ContextServerId(SERVER_1_ID.into()); @@ -842,6 +884,7 @@ mod tests { }), registry.clone(), project.read(cx).worktree_store(), + project.downgrade(), cx, ) }); @@ -1074,6 +1117,7 @@ mod tests { }), registry.clone(), project.read(cx).worktree_store(), + project.downgrade(), cx, ) }); diff --git a/crates/project/src/debugger/test.rs b/crates/project/src/debugger/test.rs index 3b9425e369..53b88323e6 100644 --- a/crates/project/src/debugger/test.rs +++ b/crates/project/src/debugger/test.rs @@ -1,7 +1,7 @@ use std::{path::Path, sync::Arc}; use dap::client::DebugAdapterClient; -use gpui::{App, AppContext, Subscription}; +use gpui::{App, Subscription}; use super::session::{Session, SessionStateEvent}; @@ -19,14 +19,6 @@ pub fn intercept_debug_sessions<T: Fn(&Arc<DebugAdapterClient>) + 'static>( let client = session.adapter_client().unwrap(); register_default_handlers(session, &client, cx); configure(&client); - cx.background_spawn(async move { - client - .fake_event(dap::messages::Events::Initialized( - Some(Default::default()), - )) - .await - }) - .detach(); } }) .detach(); diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 28dd0e91e3..01fc987816 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -246,6 +246,8 @@ pub struct RepositorySnapshot { pub head_commit: Option<CommitDetails>, pub scan_id: u64, pub merge: MergeDetails, + pub remote_origin_url: Option<String>, + pub remote_upstream_url: Option<String>, } type JobId = u64; @@ -2673,6 +2675,8 @@ impl RepositorySnapshot { head_commit: None, scan_id: 0, merge: Default::default(), + remote_origin_url: None, + remote_upstream_url: None, } } @@ -4025,6 +4029,25 @@ impl Repository { }) } + pub fn default_branch(&mut self) -> oneshot::Receiver<Result<Option<SharedString>>> { + let id = self.id; + self.send_job(None, move |repo, _| async move { + match repo { + RepositoryState::Local { backend, .. } => backend.default_branch().await, + RepositoryState::Remote { project_id, client } => { + let response = client + .request(proto::GetDefaultBranch { + project_id: project_id.0, + repository_id: id.to_proto(), + }) + .await?; + + anyhow::Ok(response.branch.map(SharedString::from)) + } + } + }) + } + pub fn diff(&mut self, diff_type: DiffType, _cx: &App) -> oneshot::Receiver<Result<String>> { let id = self.id; self.send_job(None, move |repo, _cx| async move { @@ -4799,6 +4822,10 @@ async fn compute_snapshot( None => None, }; + // Used by edit prediction data collection + let remote_origin_url = backend.remote_url("origin"); + let remote_upstream_url = backend.remote_url("upstream"); + let snapshot = RepositorySnapshot { id, statuses_by_path, @@ -4807,6 +4834,8 @@ async fn compute_snapshot( branch, head_commit, merge: merge_details, + remote_origin_url, + remote_upstream_url, }; Ok((snapshot, events)) diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index a2f6de44c9..f8e69e2185 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -2154,6 +2154,16 @@ impl LspCommand for GetHover { } } +impl GetCompletions { + pub fn can_resolve_completions(capabilities: &lsp::ServerCapabilities) -> bool { + capabilities + .completion_provider + .as_ref() + .and_then(|options| options.resolve_provider) + .unwrap_or(false) + } +} + #[async_trait(?Send)] impl LspCommand for GetCompletions { type Response = CoreCompletionResponse; @@ -2269,7 +2279,7 @@ impl LspCommand for GetCompletions { // the range based on the syntax tree. None => { if self.position != clipped_position { - log::info!("completion out of expected range"); + log::info!("completion out of expected range "); return false; } @@ -2483,7 +2493,9 @@ pub(crate) fn parse_completion_text_edit( let start = snapshot.clip_point_utf16(range.start, Bias::Left); let end = snapshot.clip_point_utf16(range.end, Bias::Left); if start != range.start.0 || end != range.end.0 { - log::info!("completion out of expected range"); + log::info!( + "completion out of expected range, start: {start:?}, end: {end:?}, range: {range:?}" + ); return None; } snapshot.anchor_before(start)..snapshot.anchor_after(end) @@ -2760,6 +2772,23 @@ impl GetCodeActions { } } +impl OnTypeFormatting { + pub fn supports_on_type_formatting(trigger: &str, capabilities: &ServerCapabilities) -> bool { + let Some(on_type_formatting_options) = &capabilities.document_on_type_formatting_provider + else { + return false; + }; + on_type_formatting_options + .first_trigger_character + .contains(trigger) + || on_type_formatting_options + .more_trigger_character + .iter() + .flatten() + .any(|chars| chars.contains(trigger)) + } +} + #[async_trait(?Send)] impl LspCommand for OnTypeFormatting { type Response = Option<Transaction>; @@ -2771,20 +2800,7 @@ impl LspCommand for OnTypeFormatting { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(on_type_formatting_options) = &capabilities - .server_capabilities - .document_on_type_formatting_provider - else { - return false; - }; - on_type_formatting_options - .first_trigger_character - .contains(&self.trigger) - || on_type_formatting_options - .more_trigger_character - .iter() - .flatten() - .any(|chars| chars.contains(&self.trigger)) + Self::supports_on_type_formatting(&self.trigger, &capabilities.server_capabilities) } fn to_lsp( @@ -3578,6 +3594,18 @@ impl LspCommand for GetCodeLens { } } +impl LinkedEditingRange { + pub fn check_server_capabilities(capabilities: ServerCapabilities) -> bool { + let Some(linked_editing_options) = capabilities.linked_editing_range_provider else { + return false; + }; + if let LinkedEditingRangeServerCapabilities::Simple(false) = linked_editing_options { + return false; + } + true + } +} + #[async_trait(?Send)] impl LspCommand for LinkedEditingRange { type Response = Vec<Range<Anchor>>; @@ -3589,16 +3617,7 @@ impl LspCommand for LinkedEditingRange { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(linked_editing_options) = &capabilities - .server_capabilities - .linked_editing_range_provider - else { - return false; - }; - if let LinkedEditingRangeServerCapabilities::Simple(false) = linked_editing_options { - return false; - } - true + Self::check_server_capabilities(capabilities.server_capabilities) } fn to_lsp( @@ -4216,8 +4235,9 @@ impl LspCommand for GetDocumentColor { server_capabilities .server_capabilities .color_provider + .as_ref() .is_some_and(|capability| match capability { - lsp::ColorProviderCapability::Simple(supported) => supported, + lsp::ColorProviderCapability::Simple(supported) => *supported, lsp::ColorProviderCapability::ColorProvider(..) => true, lsp::ColorProviderCapability::Options(..) => true, }) diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 0cd375e0c5..6d448a6fea 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -46,6 +46,7 @@ use language::{ DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, Diff, File as _, Language, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile, LspAdapter, LspAdapterDelegate, Patch, PointUtf16, TextBufferSnapshot, ToOffset, ToPointUtf16, Transaction, Unclipped, + WorkspaceFoldersContent, language_settings::{ FormatOnSave, Formatter, LanguageSettings, SelectedFormatter, language_settings, }, @@ -57,12 +58,12 @@ use language::{ range_from_lsp, range_to_lsp, }; use lsp::{ - CodeActionKind, CompletionContext, DiagnosticSeverity, DiagnosticTag, - DidChangeWatchedFilesRegistrationOptions, Edit, FileOperationFilter, FileOperationPatternKind, - FileOperationRegistrationOptions, FileRename, FileSystemWatcher, LanguageServer, - LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerId, LanguageServerName, - LanguageServerSelector, LspRequestFuture, MessageActionItem, MessageType, OneOf, - RenameFilesParams, SymbolKind, TextEdit, WillRenameFiles, WorkDoneProgressCancelParams, + AdapterServerCapabilities, CodeActionKind, CompletionContext, DiagnosticSeverity, + DiagnosticTag, DidChangeWatchedFilesRegistrationOptions, Edit, FileOperationFilter, + FileOperationPatternKind, FileOperationRegistrationOptions, FileRename, FileSystemWatcher, + LanguageServer, LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerId, + LanguageServerName, LanguageServerSelector, LspRequestFuture, MessageActionItem, MessageType, + OneOf, RenameFilesParams, SymbolKind, TextEdit, WillRenameFiles, WorkDoneProgressCancelParams, WorkspaceFolder, notification::DidRenameFiles, }; use node_runtime::read_package_installed_version; @@ -95,6 +96,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; +use sum_tree::Dimensions; use text::{Anchor, BufferId, LineEnding, OffsetRangeExt}; use url::Url; use util::{ @@ -217,6 +219,7 @@ impl LocalLspStore { let binary = self.get_language_server_binary(adapter.clone(), delegate.clone(), true, cx); let pending_workspace_folders: Arc<Mutex<BTreeSet<Url>>> = Default::default(); + let pending_server = cx.spawn({ let adapter = adapter.clone(); let server_name = adapter.name.clone(); @@ -242,14 +245,18 @@ impl LocalLspStore { return Ok(server); } + let code_action_kinds = adapter.code_action_kinds(); lsp::LanguageServer::new( stderr_capture, server_id, server_name, binary, &root_path, - adapter.code_action_kinds(), - pending_workspace_folders, + code_action_kinds, + Some(pending_workspace_folders).filter(|_| { + adapter.adapter.workspace_folders_content() + == WorkspaceFoldersContent::SubprojectRoots + }), cx, ) } @@ -418,7 +425,7 @@ impl LocalLspStore { if settings.as_ref().is_some_and(|b| b.path.is_some()) { let settings = settings.unwrap(); - return cx.spawn(async move |_| { + return cx.background_spawn(async move { let mut env = delegate.shell_env().await; env.extend(settings.env.unwrap_or_default()); @@ -575,8 +582,7 @@ impl LocalLspStore { }; let root = server.workspace_folders(); Ok(Some( - root.iter() - .cloned() + root.into_iter() .map(|uri| WorkspaceFolder { uri, name: Default::default(), @@ -616,7 +622,7 @@ impl LocalLspStore { .on_request::<lsp::request::RegisterCapability, _, _>({ let this = this.clone(); move |params, cx| { - let this = this.clone(); + let lsp_store = this.clone(); let mut cx = cx.clone(); async move { for reg in params.registrations { @@ -624,7 +630,7 @@ impl LocalLspStore { "workspace/didChangeWatchedFiles" => { if let Some(options) = reg.register_options { let options = serde_json::from_value(options)?; - this.update(&mut cx, |this, cx| { + lsp_store.update(&mut cx, |this, cx| { this.as_local_mut()?.on_lsp_did_change_watched_files( server_id, ®.id, options, cx, ); @@ -633,8 +639,9 @@ impl LocalLspStore { } } "textDocument/rangeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -653,14 +660,16 @@ impl LocalLspStore { server.update_capabilities(|capabilities| { capabilities.document_range_formatting_provider = Some(provider); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; } "textDocument/onTypeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -677,15 +686,17 @@ impl LocalLspStore { capabilities .document_on_type_formatting_provider = Some(options); - }) + }); + notify_server_capabilities_updated(&server, cx); } } anyhow::Ok(()) })??; } "textDocument/formatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -704,7 +715,8 @@ impl LocalLspStore { server.update_capabilities(|capabilities| { capabilities.document_formatting_provider = Some(provider); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; @@ -713,8 +725,9 @@ impl LocalLspStore { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. } "textDocument/rename" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -731,7 +744,8 @@ impl LocalLspStore { server.update_capabilities(|capabilities| { capabilities.rename_provider = Some(options); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; @@ -749,14 +763,15 @@ impl LocalLspStore { .on_request::<lsp::request::UnregisterCapability, _, _>({ let this = this.clone(); move |params, cx| { - let this = this.clone(); + let lsp_store = this.clone(); let mut cx = cx.clone(); async move { for unreg in params.unregisterations.iter() { match unreg.method.as_str() { "workspace/didChangeWatchedFiles" => { - this.update(&mut cx, |this, cx| { - this.as_local_mut()? + lsp_store.update(&mut cx, |lsp_store, cx| { + lsp_store + .as_local_mut()? .on_lsp_unregister_did_change_watched_files( server_id, &unreg.id, cx, ); @@ -767,44 +782,52 @@ impl LocalLspStore { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. } "textDocument/rename" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.rename_provider = None - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/rangeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_range_formatting_provider = None - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/onTypeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_on_type_formatting_provider = None; - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/formatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_formatting_provider = None; - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } @@ -2420,36 +2443,11 @@ impl LocalLspStore { let server_id = server_node.server_id_or_init( |LaunchDisposition { server_name, - attach, path, settings, }| { - let server_id = match attach { - language::Attach::InstancePerRoot => { - // todo: handle instance per root proper. - if let Some(server_ids) = self - .language_server_ids - .get(&(worktree_id, server_name.clone())) - { - server_ids.iter().cloned().next().unwrap() - } else { - let language_name = language.name(); - let adapter = self.languages - .lsp_adapters(&language_name) - .into_iter() - .find(|adapter| &adapter.name() == server_name) - .expect("To find LSP adapter"); - let server_id = self.start_language_server( - &worktree, - delegate.clone(), - adapter, - settings, - cx, - ); - server_id - } - } - language::Attach::Shared => { + let server_id = + { let uri = Url::from_file_path( worktree.read(cx).abs_path().join(&path.path), ); @@ -2484,20 +2482,8 @@ impl LocalLspStore { } else { unreachable!("Language server ID should be available, as it's registered on demand") } - } + }; - let lsp_store = self.weak.clone(); - let server_name = server_node.name(); - let buffer_abs_path = abs_path.to_string_lossy().to_string(); - cx.defer(move |cx| { - lsp_store.update(cx, |_, cx| cx.emit(LspStoreEvent::LanguageServerUpdate { - language_server_id: server_id, - name: server_name, - message: proto::update_language_server::Variant::RegisteredForBuffer(proto::RegisteredForBuffer { - buffer_abs_path, - }) - })).ok(); - }); server_id }, )?; @@ -2533,11 +2519,13 @@ impl LocalLspStore { snapshot: initial_snapshot.clone(), }; + let mut registered = false; self.buffer_snapshots .entry(buffer_id) .or_default() .entry(server.server_id()) .or_insert_with(|| { + registered = true; server.register_buffer( uri.clone(), adapter.language_id(&language.name()), @@ -2552,15 +2540,18 @@ impl LocalLspStore { .entry(buffer_id) .or_default() .insert(server.server_id()); - cx.emit(LspStoreEvent::LanguageServerUpdate { - language_server_id: server.server_id(), - name: None, - message: proto::update_language_server::Variant::RegisteredForBuffer( - proto::RegisteredForBuffer { - buffer_abs_path: abs_path.to_string_lossy().to_string(), - }, - ), - }); + if registered { + cx.emit(LspStoreEvent::LanguageServerUpdate { + language_server_id: server.server_id(), + name: None, + message: proto::update_language_server::Variant::RegisteredForBuffer( + proto::RegisteredForBuffer { + buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), + }, + ), + }); + } } } @@ -3512,6 +3503,20 @@ impl LocalLspStore { } } +fn notify_server_capabilities_updated(server: &LanguageServer, cx: &mut Context<LspStore>) { + if let Some(capabilities) = serde_json::to_string(&server.capabilities()).ok() { + cx.emit(LspStoreEvent::LanguageServerUpdate { + language_server_id: server.server_id(), + name: Some(server.name()), + message: proto::update_language_server::Variant::MetadataUpdated( + proto::ServerMetadataUpdated { + capabilities: Some(capabilities), + }, + ), + }); + } +} + #[derive(Debug)] pub struct FormattableBuffer { handle: Entity<Buffer>, @@ -3551,7 +3556,9 @@ pub struct LspStore { _maintain_buffer_languages: Task<()>, diagnostic_summaries: HashMap<WorktreeId, HashMap<Arc<Path>, HashMap<LanguageServerId, DiagnosticSummary>>>, - lsp_data: HashMap<BufferId, DocumentColorData>, + pub(super) lsp_server_capabilities: HashMap<LanguageServerId, lsp::ServerCapabilities>, + lsp_document_colors: HashMap<BufferId, DocumentColorData>, + lsp_code_lens: HashMap<BufferId, CodeLensData>, } #[derive(Debug, Default, Clone)] @@ -3561,6 +3568,7 @@ pub struct DocumentColors { } type DocumentColorTask = Shared<Task<std::result::Result<DocumentColors, Arc<anyhow::Error>>>>; +type CodeLensTask = Shared<Task<std::result::Result<Vec<CodeAction>, Arc<anyhow::Error>>>>; #[derive(Debug, Default)] struct DocumentColorData { @@ -3570,8 +3578,15 @@ struct DocumentColorData { colors_update: Option<(Global, DocumentColorTask)>, } +#[derive(Debug, Default)] +struct CodeLensData { + lens_for_version: Global, + lens: HashMap<LanguageServerId, Vec<CodeAction>>, + update: Option<(Global, CodeLensTask)>, +} + #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum ColorFetchStrategy { +pub enum LspFetchStrategy { IgnoreCache, UseCache { known_cache_version: Option<usize> }, } @@ -3613,7 +3628,7 @@ pub enum LspStoreEvent { #[derive(Clone, Debug, Serialize)] pub struct LanguageServerStatus { - pub name: String, + pub name: LanguageServerName, pub pending_work: BTreeMap<String, LanguageServerProgress>, pub has_pending_diagnostic_updates: bool, progress_tokens: HashSet<String>, @@ -3804,7 +3819,9 @@ impl LspStore { language_server_statuses: Default::default(), nonce: StdRng::from_entropy().r#gen(), diagnostic_summaries: HashMap::default(), - lsp_data: HashMap::default(), + lsp_server_capabilities: HashMap::default(), + lsp_document_colors: HashMap::default(), + lsp_code_lens: HashMap::default(), active_entry: None, _maintain_workspace_config, _maintain_buffer_languages: Self::maintain_buffer_languages(languages, cx), @@ -3819,6 +3836,9 @@ impl LspStore { request: R, cx: &mut Context<LspStore>, ) -> Task<anyhow::Result<<R as LspCommand>::Response>> { + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(R::Response::default())); + } let message = request.to_proto(upstream_project_id, buffer.read(cx)); cx.spawn(async move |this, cx| { let response = client.request(message).await?; @@ -3861,7 +3881,9 @@ impl LspStore { language_server_statuses: Default::default(), nonce: StdRng::from_entropy().r#gen(), diagnostic_summaries: HashMap::default(), - lsp_data: HashMap::default(), + lsp_server_capabilities: HashMap::default(), + lsp_document_colors: HashMap::default(), + lsp_code_lens: HashMap::default(), active_entry: None, toolchain_store, _maintain_workspace_config, @@ -4162,7 +4184,8 @@ impl LspStore { *refcount }; if refcount == 0 { - lsp_store.lsp_data.remove(&buffer_id); + lsp_store.lsp_document_colors.remove(&buffer_id); + lsp_store.lsp_code_lens.remove(&buffer_id); let local = lsp_store.as_local_mut().unwrap(); local.registered_buffers.remove(&buffer_id); local.buffers_opened_in_servers.remove(&buffer_id); @@ -4434,20 +4457,73 @@ impl LspStore { } } - pub fn request_lsp<R: LspCommand>( + // TODO: remove MultiLspQuery: instead, the proto handler should pick appropriate server(s) + // Then, use `send_lsp_proto_request` or analogue for most of the LSP proto requests and inline this check inside + fn is_capable_for_proto_request<R>( + &self, + buffer: &Entity<Buffer>, + request: &R, + cx: &Context<Self>, + ) -> bool + where + R: LspCommand, + { + self.check_if_capable_for_proto_request( + buffer, + |capabilities| { + request.check_capabilities(AdapterServerCapabilities { + server_capabilities: capabilities.clone(), + code_action_kinds: None, + }) + }, + cx, + ) + } + + fn check_if_capable_for_proto_request<F>( + &self, + buffer: &Entity<Buffer>, + check: F, + cx: &Context<Self>, + ) -> bool + where + F: Fn(&lsp::ServerCapabilities) -> bool, + { + let Some(language) = buffer.read(cx).language().cloned() else { + return false; + }; + let relevant_language_servers = self + .languages + .lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::<HashSet<_>>(); + self.language_server_statuses + .iter() + .filter_map(|(server_id, server_status)| { + relevant_language_servers + .contains(&server_status.name) + .then_some(server_id) + }) + .filter_map(|server_id| self.lsp_server_capabilities.get(&server_id)) + .any(check) + } + + pub fn request_lsp<R>( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, server: LanguageServerToQuery, request: R, cx: &mut Context<Self>, ) -> Task<Result<R::Response>> where + R: LspCommand, <R::LspRequest as lsp::request::Request>::Result: Send, <R::LspRequest as lsp::request::Request>::Params: Send, { if let Some((upstream_client, upstream_project_id)) = self.upstream_client() { return self.send_lsp_proto_request( - buffer_handle, + buffer, upstream_client, upstream_project_id, request, @@ -4455,7 +4531,7 @@ impl LspStore { ); } - let Some(language_server) = buffer_handle.update(cx, |buffer, cx| match server { + let Some(language_server) = buffer.update(cx, |buffer, cx| match server { LanguageServerToQuery::FirstCapable => self.as_local().and_then(|local| { local .language_servers_for_buffer(buffer, cx) @@ -4475,8 +4551,7 @@ impl LspStore { return Task::ready(Ok(Default::default())); }; - let buffer = buffer_handle.read(cx); - let file = File::from_dyn(buffer.file()).and_then(File::as_local); + let file = File::from_dyn(buffer.read(cx).file()).and_then(File::as_local); let Some(file) = file else { return Task::ready(Ok(Default::default())); @@ -4484,7 +4559,7 @@ impl LspStore { let lsp_params = match request.to_lsp_params_or_response( &file.abs_path(cx), - buffer, + buffer.read(cx), &language_server, cx, ) { @@ -4560,7 +4635,7 @@ impl LspStore { .response_from_lsp( response, this.upgrade().context("no app context")?, - buffer_handle, + buffer, language_server.server_id(), cx.clone(), ) @@ -4630,7 +4705,8 @@ impl LspStore { ) }) { let buffer = buffer_handle.read(cx); - if !local.registered_buffers.contains_key(&buffer.remote_id()) { + let buffer_id = buffer.remote_id(); + if !local.registered_buffers.contains_key(&buffer_id) { continue; } if let Some((file, language)) = File::from_dyn(buffer.file()) @@ -4688,35 +4764,11 @@ impl LspStore { let server_id = node.server_id_or_init( |LaunchDisposition { server_name, - attach, + path, settings, - }| match attach { - language::Attach::InstancePerRoot => { - // todo: handle instance per root proper. - if let Some(server_ids) = local - .language_server_ids - .get(&(worktree_id, server_name.clone())) - { - server_ids.iter().cloned().next().unwrap() - } else { - let adapter = local - .languages - .lsp_adapters(&language) - .into_iter() - .find(|adapter| &adapter.name() == server_name) - .expect("To find LSP adapter"); - let server_id = local.start_language_server( - &worktree, - delegate.clone(), - adapter, - settings, - cx, - ); - server_id - } - } - language::Attach::Shared => { + }| + { let uri = Url::from_file_path( worktree.read(cx).abs_path().join(&path.path), ); @@ -4745,7 +4797,6 @@ impl LspStore { } server_id } - }, ); if let Some(language_server_id) = server_id { @@ -4756,6 +4807,7 @@ impl LspStore { proto::update_language_server::Variant::RegisteredForBuffer( proto::RegisteredForBuffer { buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), }, ), }); @@ -4931,19 +4983,24 @@ impl LspStore { pub fn resolve_inlay_hint( &self, - hint: InlayHint, - buffer_handle: Entity<Buffer>, + mut hint: InlayHint, + buffer: Entity<Buffer>, server_id: LanguageServerId, cx: &mut Context<Self>, ) -> Task<anyhow::Result<InlayHint>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + if !self.check_if_capable_for_proto_request(&buffer, InlayHints::can_resolve_inlays, cx) + { + hint.resolve_state = ResolveState::Resolved; + return Task::ready(Ok(hint)); + } let request = proto::ResolveInlayHint { project_id, - buffer_id: buffer_handle.read(cx).remote_id().into(), + buffer_id: buffer.read(cx).remote_id().into(), language_server_id: server_id.0 as u64, hint: Some(InlayHints::project_to_proto_hint(hint.clone())), }; - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let response = upstream_client .request(request) .await @@ -4955,7 +5012,7 @@ impl LspStore { } }) } else { - let Some(lang_server) = buffer_handle.update(cx, |buffer, cx| { + let Some(lang_server) = buffer.update(cx, |buffer, cx| { self.language_server_for_local_buffer(buffer, server_id, cx) .map(|(_, server)| server.clone()) }) else { @@ -4964,7 +5021,7 @@ impl LspStore { if !InlayHints::can_resolve_inlays(&lang_server.capabilities()) { return Task::ready(Ok(hint)); } - let buffer_snapshot = buffer_handle.read(cx).snapshot(); + let buffer_snapshot = buffer.read(cx).snapshot(); cx.spawn(async move |_, cx| { let resolve_task = lang_server.request::<lsp::request::InlayHintResolveRequest>( InlayHints::project_to_lsp_hint(hint, &buffer_snapshot), @@ -4975,7 +5032,7 @@ impl LspStore { .context("inlay hint resolve LSP request")?; let resolved_hint = InlayHints::lsp_to_project_hint( resolved_hint, - &buffer_handle, + &buffer, server_id, ResolveState::Resolved, false, @@ -5086,7 +5143,7 @@ impl LspStore { } } - pub(crate) fn linked_edit( + pub(crate) fn linked_edits( &mut self, buffer: &Entity<Buffer>, position: Anchor, @@ -5101,10 +5158,7 @@ impl LspStore { local .language_servers_for_buffer(buffer, cx) .filter(|(_, server)| { - server - .capabilities() - .linked_editing_range_provider - .is_some() + LinkedEditingRange::check_server_capabilities(server.capabilities()) }) .filter(|(adapter, _)| { scope @@ -5131,7 +5185,7 @@ impl LspStore { }) == Some(true) }) else { - return Task::ready(Ok(vec![])); + return Task::ready(Ok(Vec::new())); }; self.request_lsp( @@ -5150,6 +5204,15 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Option<Transaction>>> { if let Some((client, project_id)) = self.upstream_client() { + if !self.check_if_capable_for_proto_request( + &buffer, + |capabilities| { + OnTypeFormatting::supports_on_type_formatting(&trigger, capabilities) + }, + cx, + ) { + return Task::ready(Ok(None)); + } let request = proto::OnTypeFormatting { project_id, buffer_id: buffer.read(cx).remote_id().into(), @@ -5157,7 +5220,7 @@ impl LspStore { trigger, version: serialize_version(&buffer.read(cx).version()), }; - cx.spawn(async move |_, _| { + cx.background_spawn(async move { client .request(request) .await? @@ -5261,6 +5324,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetDefinitions { position }; + if !self.is_capable_for_proto_request(buffer_handle, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { buffer_id: buffer_handle.read(cx).remote_id().into(), version: serialize_version(&buffer_handle.read(cx).version()), @@ -5269,7 +5336,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDefinition( - GetDefinitions { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer_handle.read(cx)), )), }); let buffer = buffer_handle.clone(); @@ -5316,7 +5383,7 @@ impl LspStore { GetDefinitions { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(definitions_task .await .into_iter() @@ -5334,6 +5401,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetDeclarations { position }; + if !self.is_capable_for_proto_request(buffer_handle, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { buffer_id: buffer_handle.read(cx).remote_id().into(), version: serialize_version(&buffer_handle.read(cx).version()), @@ -5342,7 +5413,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDeclaration( - GetDeclarations { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer_handle.read(cx)), )), }); let buffer = buffer_handle.clone(); @@ -5389,7 +5460,7 @@ impl LspStore { GetDeclarations { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(declarations_task .await .into_iter() @@ -5402,23 +5473,27 @@ impl LspStore { pub fn type_definitions( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetTypeDefinitions { position }; + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetTypeDefinition( - GetTypeDefinitions { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5457,12 +5532,12 @@ impl LspStore { }) } else { let type_definitions_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetTypeDefinitions { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(type_definitions_task .await .into_iter() @@ -5475,23 +5550,27 @@ impl LspStore { pub fn implementations( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetImplementations { position }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetImplementation( - GetImplementations { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5530,12 +5609,12 @@ impl LspStore { }) } else { let implementations_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetImplementations { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(implementations_task .await .into_iter() @@ -5548,23 +5627,27 @@ impl LspStore { pub fn references( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<Location>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetReferences { position }; + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetReferences( - GetReferences { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5603,12 +5686,12 @@ impl LspStore { }) } else { let references_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetReferences { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(references_task .await .into_iter() @@ -5621,28 +5704,31 @@ impl LspStore { pub fn code_actions( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, range: Range<Anchor>, kinds: Option<Vec<CodeActionKind>>, cx: &mut Context<Self>, ) -> Task<Result<Vec<CodeAction>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCodeActions { + range: range.clone(), + kinds: kinds.clone(), + }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetCodeActions( - GetCodeActions { - range: range.clone(), - kinds: kinds.clone(), - } - .to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5684,7 +5770,7 @@ impl LspStore { }) } else { let all_actions_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(range.start), GetCodeActions { range: range.clone(), @@ -5692,7 +5778,7 @@ impl LspStore { }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(all_actions_task .await .into_iter() @@ -5702,69 +5788,172 @@ impl LspStore { } } - pub fn code_lens( + pub fn code_lens_actions( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, cx: &mut Context<Self>, - ) -> Task<Result<Vec<CodeAction>>> { + ) -> CodeLensTask { + let version_queried_for = buffer.read(cx).version(); + let buffer_id = buffer.read(cx).remote_id(); + + if let Some(cached_data) = self.lsp_code_lens.get(&buffer_id) { + if !version_queried_for.changed_since(&cached_data.lens_for_version) { + let has_different_servers = self.as_local().is_some_and(|local| { + local + .buffers_opened_in_servers + .get(&buffer_id) + .cloned() + .unwrap_or_default() + != cached_data.lens.keys().copied().collect() + }); + if !has_different_servers { + return Task::ready(Ok(cached_data.lens.values().flatten().cloned().collect())) + .shared(); + } + } + } + + let lsp_data = self.lsp_code_lens.entry(buffer_id).or_default(); + if let Some((updating_for, running_update)) = &lsp_data.update { + if !version_queried_for.changed_since(&updating_for) { + return running_update.clone(); + } + } + let buffer = buffer.clone(); + let query_version_queried_for = version_queried_for.clone(); + let new_task = cx + .spawn(async move |lsp_store, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + let fetched_lens = lsp_store + .update(cx, |lsp_store, cx| lsp_store.fetch_code_lens(&buffer, cx)) + .map_err(Arc::new)? + .await + .context("fetching code lens") + .map_err(Arc::new); + let fetched_lens = match fetched_lens { + Ok(fetched_lens) => fetched_lens, + Err(e) => { + lsp_store + .update(cx, |lsp_store, _| { + lsp_store.lsp_code_lens.entry(buffer_id).or_default().update = None; + }) + .ok(); + return Err(e); + } + }; + + lsp_store + .update(cx, |lsp_store, _| { + let lsp_data = lsp_store.lsp_code_lens.entry(buffer_id).or_default(); + if lsp_data.lens_for_version == query_version_queried_for { + lsp_data.lens.extend(fetched_lens.clone()); + } else if !lsp_data + .lens_for_version + .changed_since(&query_version_queried_for) + { + lsp_data.lens_for_version = query_version_queried_for; + lsp_data.lens = fetched_lens.clone(); + } + lsp_data.update = None; + lsp_data.lens.values().flatten().cloned().collect() + }) + .map_err(Arc::new) + }) + .shared(); + lsp_data.update = Some((version_queried_for, new_task.clone())); + new_task + } + + fn fetch_code_lens( + &mut self, + buffer: &Entity<Buffer>, + cx: &mut Context<Self>, + ) -> Task<Result<HashMap<LanguageServerId, Vec<CodeAction>>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCodeLens; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(HashMap::default())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetCodeLens( - GetCodeLens.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); - cx.spawn(async move |weak_project, cx| { - let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + let buffer = buffer.clone(); + cx.spawn(async move |weak_lsp_store, cx| { + let Some(lsp_store) = weak_lsp_store.upgrade() else { + return Ok(HashMap::default()); }; let responses = request_task.await?.responses; - let code_lens = join_all( + let code_lens_actions = join_all( responses .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetCodeLensResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } + .filter_map(|lsp_response| { + let response = match lsp_response.response? { + proto::lsp_response::Response::GetCodeLensResponse(response) => { + Some(response) + } + unexpected => { + debug_panic!("Unexpected response: {unexpected:?}"); + None + } + }?; + let server_id = LanguageServerId::from_proto(lsp_response.server_id); + Some((server_id, response)) }) - .map(|code_lens_response| { - GetCodeLens.response_from_proto( - code_lens_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) + .map(|(server_id, code_lens_response)| { + let lsp_store = lsp_store.clone(); + let buffer = buffer.clone(); + let cx = cx.clone(); + async move { + ( + server_id, + GetCodeLens + .response_from_proto( + code_lens_response, + lsp_store, + buffer, + cx, + ) + .await, + ) + } }), ) .await; - Ok(code_lens + let mut has_errors = false; + let code_lens_actions = code_lens_actions .into_iter() - .collect::<Result<Vec<Vec<_>>>>()? - .into_iter() - .flatten() - .collect()) + .filter_map(|(server_id, code_lens)| match code_lens { + Ok(code_lens) => Some((server_id, code_lens)), + Err(e) => { + has_errors = true; + log::error!("{e:#}"); + None + } + }) + .collect::<HashMap<_, _>>(); + anyhow::ensure!( + !has_errors || !code_lens_actions.is_empty(), + "Failed to fetch code lens" + ); + Ok(code_lens_actions) }) } else { - let code_lens_task = - self.request_multiple_lsp_locally(buffer_handle, None::<usize>, GetCodeLens, cx); - cx.spawn(async move |_, _| { - Ok(code_lens_task - .await - .into_iter() - .flat_map(|(_, code_lens)| code_lens) - .collect()) - }) + let code_lens_actions_task = + self.request_multiple_lsp_locally(buffer, None::<usize>, GetCodeLens, cx); + cx.background_spawn( + async move { Ok(code_lens_actions_task.await.into_iter().collect()) }, + ) } } @@ -5779,11 +5968,15 @@ impl LspStore { let language_registry = self.languages.clone(); if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCompletions { position, context }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let task = self.send_lsp_proto_request( buffer.clone(), upstream_client, project_id, - GetCompletions { position, context }, + request, cx, ); let language = buffer.read(cx).language().cloned(); @@ -5921,11 +6114,17 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<bool>> { let client = self.upstream_client(); - let buffer_id = buffer.read(cx).remote_id(); let buffer_snapshot = buffer.read(cx).snapshot(); - cx.spawn(async move |this, cx| { + if !self.check_if_capable_for_proto_request( + &buffer, + GetCompletions::can_resolve_completions, + cx, + ) { + return Task::ready(Ok(false)); + } + cx.spawn(async move |lsp_store, cx| { let mut did_resolve = false; if let Some((client, project_id)) = client { for completion_index in completion_indices { @@ -5962,7 +6161,7 @@ impl LspStore { completion.source.server_id() }; if let Some(server_id) = server_id { - let server_and_adapter = this + let server_and_adapter = lsp_store .read_with(cx, |lsp_store, _| { let server = lsp_store.language_server_for_id(server_id)?; let adapter = @@ -5977,7 +6176,6 @@ impl LspStore { let resolved = Self::resolve_completion_local( server, - &buffer_snapshot, completions.clone(), completion_index, ) @@ -6010,18 +6208,11 @@ impl LspStore { async fn resolve_completion_local( server: Arc<lsp::LanguageServer>, - snapshot: &BufferSnapshot, completions: Rc<RefCell<Box<[Completion]>>>, completion_index: usize, ) -> Result<()> { let server_id = server.server_id(); - let can_resolve = server - .capabilities() - .completion_provider - .as_ref() - .and_then(|options| options.resolve_provider) - .unwrap_or(false); - if !can_resolve { + if !GetCompletions::can_resolve_completions(&server.capabilities()) { return Ok(()); } @@ -6055,26 +6246,8 @@ impl LspStore { .into_response() .context("resolve completion")?; - if let Some(text_edit) = resolved_completion.text_edit.as_ref() { - // Technically we don't have to parse the whole `text_edit`, since the only - // language server we currently use that does update `text_edit` in `completionItem/resolve` - // is `typescript-language-server` and they only update `text_edit.new_text`. - // But we should not rely on that. - let edit = parse_completion_text_edit(text_edit, snapshot); - - if let Some(mut parsed_edit) = edit { - LineEnding::normalize(&mut parsed_edit.new_text); - - let mut completions = completions.borrow_mut(); - let completion = &mut completions[completion_index]; - - completion.new_text = parsed_edit.new_text; - completion.replace_range = parsed_edit.replace_range; - if let CompletionSource::Lsp { insert_range, .. } = &mut completion.source { - *insert_range = parsed_edit.insert_range; - } - } - } + // We must not use any data such as sortText, filterText, insertText and textEdit to edit `Completion` since they are not suppose change during resolve. + // Refer: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_completion let mut completions = completions.borrow_mut(); let completion = &mut completions[completion_index]; @@ -6324,12 +6497,10 @@ impl LspStore { }) else { return Task::ready(Ok(None)); }; - let snapshot = buffer_handle.read(&cx).snapshot(); cx.spawn(async move |this, cx| { Self::resolve_completion_local( server.clone(), - &snapshot, completions.clone(), completion_index, ) @@ -6392,16 +6563,24 @@ impl LspStore { pub fn pull_diagnostics( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, cx: &mut Context<Self>, ) -> Task<Result<Vec<LspPullDiagnostics>>> { - let buffer = buffer_handle.read(cx); - let buffer_id = buffer.remote_id(); + let buffer_id = buffer.read(cx).remote_id(); if let Some((client, upstream_project_id)) = self.upstream_client() { + if !self.is_capable_for_proto_request( + &buffer, + &GetDocumentDiagnostics { + previous_result_id: None, + }, + cx, + ) { + return Task::ready(Ok(Vec::new())); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer_id.to_proto(), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), project_id: upstream_project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, @@ -6410,7 +6589,7 @@ impl LspStore { proto::GetDocumentDiagnostics { project_id: upstream_project_id, buffer_id: buffer_id.to_proto(), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), }, )), }); @@ -6432,7 +6611,7 @@ impl LspStore { .collect()) }) } else { - let server_ids = buffer_handle.update(cx, |buffer, cx| { + let server_ids = buffer.update(cx, |buffer, cx| { self.language_servers_for_local_buffer(buffer, cx) .map(|(_, server)| server.server_id()) .collect::<Vec<_>>() @@ -6442,7 +6621,7 @@ impl LspStore { .map(|server_id| { let result_id = self.result_id(server_id, buffer_id, cx); self.request_lsp( - buffer_handle.clone(), + buffer.clone(), LanguageServerToQuery::Other(server_id), GetDocumentDiagnostics { previous_result_id: result_id, @@ -6464,34 +6643,36 @@ impl LspStore { pub fn inlay_hints( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, range: Range<Anchor>, cx: &mut Context<Self>, ) -> Task<anyhow::Result<Vec<InlayHint>>> { - let buffer = buffer_handle.read(cx); let range_start = range.start; let range_end = range.end; - let buffer_id = buffer.remote_id().into(); - let lsp_request = InlayHints { range }; + let buffer_id = buffer.read(cx).remote_id().into(); + let request = InlayHints { range }; if let Some((client, project_id)) = self.upstream_client() { - let request = proto::InlayHints { + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } + let proto_request = proto::InlayHints { project_id, buffer_id, start: Some(serialize_anchor(&range_start)), end: Some(serialize_anchor(&range_end)), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), }; cx.spawn(async move |project, cx| { let response = client - .request(request) + .request(proto_request) .await .context("inlay hints proto request")?; LspCommand::response_from_proto( - lsp_request, + request, response, project.upgrade().context("No project")?, - buffer_handle.clone(), + buffer.clone(), cx.clone(), ) .await @@ -6499,13 +6680,13 @@ impl LspStore { }) } else { let lsp_request_task = self.request_lsp( - buffer_handle.clone(), + buffer.clone(), LanguageServerToQuery::FirstCapable, - lsp_request, + request, cx, ); cx.spawn(async move |_, cx| { - buffer_handle + buffer .update(cx, |buffer, _| { buffer.wait_for_edits(vec![range_start.timestamp, range_end.timestamp]) })? @@ -6597,7 +6778,7 @@ impl LspStore { pub fn document_colors( &mut self, - fetch_strategy: ColorFetchStrategy, + fetch_strategy: LspFetchStrategy, buffer: Entity<Buffer>, cx: &mut Context<Self>, ) -> Option<DocumentColorTask> { @@ -6605,11 +6786,11 @@ impl LspStore { let buffer_id = buffer.read(cx).remote_id(); match fetch_strategy { - ColorFetchStrategy::IgnoreCache => {} - ColorFetchStrategy::UseCache { + LspFetchStrategy::IgnoreCache => {} + LspFetchStrategy::UseCache { known_cache_version, } => { - if let Some(cached_data) = self.lsp_data.get(&buffer_id) { + if let Some(cached_data) = self.lsp_document_colors.get(&buffer_id) { if !version_queried_for.changed_since(&cached_data.colors_for_version) { let has_different_servers = self.as_local().is_some_and(|local| { local @@ -6642,7 +6823,7 @@ impl LspStore { } } - let lsp_data = self.lsp_data.entry(buffer_id).or_default(); + let lsp_data = self.lsp_document_colors.entry(buffer_id).or_default(); if let Some((updating_for, running_update)) = &lsp_data.colors_update { if !version_queried_for.changed_since(&updating_for) { return Some(running_update.clone()); @@ -6656,14 +6837,14 @@ impl LspStore { .await; let fetched_colors = lsp_store .update(cx, |lsp_store, cx| { - lsp_store.fetch_document_colors_for_buffer(buffer.clone(), cx) + lsp_store.fetch_document_colors_for_buffer(&buffer, cx) })? .await .context("fetching document colors") .map_err(Arc::new); let fetched_colors = match fetched_colors { Ok(fetched_colors) => { - if fetch_strategy != ColorFetchStrategy::IgnoreCache + if fetch_strategy != LspFetchStrategy::IgnoreCache && Some(true) == buffer .update(cx, |buffer, _| { @@ -6679,7 +6860,7 @@ impl LspStore { lsp_store .update(cx, |lsp_store, _| { lsp_store - .lsp_data + .lsp_document_colors .entry(buffer_id) .or_default() .colors_update = None; @@ -6691,7 +6872,7 @@ impl LspStore { lsp_store .update(cx, |lsp_store, _| { - let lsp_data = lsp_store.lsp_data.entry(buffer_id).or_default(); + let lsp_data = lsp_store.lsp_document_colors.entry(buffer_id).or_default(); if lsp_data.colors_for_version == query_version_queried_for { lsp_data.colors.extend(fetched_colors.clone()); @@ -6725,10 +6906,15 @@ impl LspStore { fn fetch_document_colors_for_buffer( &mut self, - buffer: Entity<Buffer>, + buffer: &Entity<Buffer>, cx: &mut Context<Self>, ) -> Task<anyhow::Result<HashMap<LanguageServerId, HashSet<DocumentColor>>>> { if let Some((client, project_id)) = self.upstream_client() { + let request = GetDocumentColor {}; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(HashMap::default())); + } + let request_task = client.request(proto::MultiLspQuery { project_id, buffer_id: buffer.read(cx).remote_id().to_proto(), @@ -6737,9 +6923,10 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDocumentColor( - GetDocumentColor {}.to_proto(project_id, buffer.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); + let buffer = buffer.clone(); cx.spawn(async move |project, cx| { let Some(project) = project.upgrade() else { return Ok(HashMap::default()); @@ -6764,7 +6951,7 @@ impl LspStore { } }) .map(|(server_id, color_response)| { - let response = GetDocumentColor {}.response_from_proto( + let response = request.response_from_proto( color_response, project.clone(), buffer.clone(), @@ -6785,8 +6972,8 @@ impl LspStore { }) } else { let document_colors_task = - self.request_multiple_lsp_locally(&buffer, None::<usize>, GetDocumentColor, cx); - cx.spawn(async move |_, _| { + self.request_multiple_lsp_locally(buffer, None::<usize>, GetDocumentColor, cx); + cx.background_spawn(async move { Ok(document_colors_task .await .into_iter() @@ -6811,6 +6998,10 @@ impl LspStore { let position = position.to_point_utf16(buffer.read(cx)); if let Some((client, upstream_project_id)) = self.upstream_client() { + let request = GetSignatureHelp { position }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Vec::new()); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer.read(cx).remote_id().into(), version: serialize_version(&buffer.read(cx).version()), @@ -6819,7 +7010,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetSignatureHelp( - GetSignatureHelp { position }.to_proto(upstream_project_id, buffer.read(cx)), + request.to_proto(upstream_project_id, buffer.read(cx)), )), }); let buffer = buffer.clone(); @@ -6865,7 +7056,7 @@ impl LspStore { GetSignatureHelp { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { all_actions_task .await .into_iter() @@ -6882,6 +7073,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Vec<Hover>> { if let Some((client, upstream_project_id)) = self.upstream_client() { + let request = GetHover { position }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Vec::new()); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer.read(cx).remote_id().into(), version: serialize_version(&buffer.read(cx).version()), @@ -6890,7 +7085,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetHover( - GetHover { position }.to_proto(upstream_project_id, buffer.read(cx)), + request.to_proto(upstream_project_id, buffer.read(cx)), )), }); let buffer = buffer.clone(); @@ -6942,7 +7137,7 @@ impl LspStore { GetHover { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { all_actions_task .await .into_iter() @@ -7210,7 +7405,9 @@ impl LspStore { let build_incremental_change = || { buffer - .edits_since::<(PointUtf16, usize)>(previous_snapshot.snapshot.version()) + .edits_since::<Dimensions<PointUtf16, usize>>( + previous_snapshot.snapshot.version(), + ) .map(|edit| { let edit_start = edit.new.start.0; let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); @@ -7325,21 +7522,23 @@ impl LspStore { } pub(crate) async fn refresh_workspace_configurations( - this: &WeakEntity<Self>, + lsp_store: &WeakEntity<Self>, fs: Arc<dyn Fs>, cx: &mut AsyncApp, ) { maybe!(async move { - let servers = this - .update(cx, |this, cx| { - let Some(local) = this.as_local() else { + let mut refreshed_servers = HashSet::default(); + let servers = lsp_store + .update(cx, |lsp_store, cx| { + let toolchain_store = lsp_store.toolchain_store(cx); + let Some(local) = lsp_store.as_local() else { return Vec::default(); }; local .language_server_ids .iter() .flat_map(|((worktree_id, _), server_ids)| { - let worktree = this + let worktree = lsp_store .worktree_store .read(cx) .worktree_for_id(*worktree_id, cx); @@ -7355,43 +7554,54 @@ impl LspStore { ) }); - server_ids.iter().filter_map(move |server_id| { + let fs = fs.clone(); + let toolchain_store = toolchain_store.clone(); + server_ids.iter().filter_map(|server_id| { + let delegate = delegate.clone()? as Arc<dyn LspAdapterDelegate>; let states = local.language_servers.get(server_id)?; match states { LanguageServerState::Starting { .. } => None, LanguageServerState::Running { adapter, server, .. - } => Some(( - adapter.adapter.clone(), - server.clone(), - delegate.clone()? as Arc<dyn LspAdapterDelegate>, - )), + } => { + let fs = fs.clone(); + let toolchain_store = toolchain_store.clone(); + let adapter = adapter.clone(); + let server = server.clone(); + refreshed_servers.insert(server.name()); + Some(cx.spawn(async move |_, cx| { + let settings = + LocalLspStore::workspace_configuration_for_adapter( + adapter.adapter.clone(), + fs.as_ref(), + &delegate, + toolchain_store, + cx, + ) + .await + .ok()?; + server + .notify::<lsp::notification::DidChangeConfiguration>( + &lsp::DidChangeConfigurationParams { settings }, + ) + .ok()?; + Some(()) + })) + } } - }) + }).collect::<Vec<_>>() }) .collect::<Vec<_>>() }) .ok()?; - let toolchain_store = this.update(cx, |this, cx| this.toolchain_store(cx)).ok()?; - for (adapter, server, delegate) in servers { - let settings = LocalLspStore::workspace_configuration_for_adapter( - adapter, - fs.as_ref(), - &delegate, - toolchain_store.clone(), - cx, - ) - .await - .ok()?; - - server - .notify::<lsp::notification::DidChangeConfiguration>( - &lsp::DidChangeConfigurationParams { settings }, - ) - .ok(); - } + log::info!("Refreshing workspace configurations for servers {refreshed_servers:?}"); + // TODO this asynchronous job runs concurrently with extension (de)registration and may take enough time for a certain extension + // to stop and unregister its language server wrapper. + // This is racy : an extension might have already removed all `local.language_servers` state, but here we `.clone()` and hold onto it anyway. + // This now causes errors in the logs, we should find a way to remove such servers from the processing everywhere. + let _: Vec<Option<()>> = join_all(servers).await; Some(()) }) .await; @@ -7480,16 +7690,20 @@ impl LspStore { self.downstream_client = Some((downstream_client.clone(), project_id)); for (server_id, status) in &self.language_server_statuses { - downstream_client - .send(proto::StartLanguageServer { - project_id, - server: Some(proto::LanguageServer { - id: server_id.0 as u64, - name: status.name.clone(), - worktree_id: None, - }), - }) - .log_err(); + if let Some(server) = self.language_server_for_id(*server_id) { + downstream_client + .send(proto::StartLanguageServer { + project_id, + server: Some(proto::LanguageServer { + id: server_id.to_proto(), + name: status.name.to_string(), + worktree_id: None, + }), + capabilities: serde_json::to_string(&server.capabilities()) + .expect("serializing server LSP capabilities"), + }) + .log_err(); + } } } @@ -7516,7 +7730,7 @@ impl LspStore { ( LanguageServerId(server.id as usize), LanguageServerStatus { - name: server.name, + name: LanguageServerName::from_proto(server.name), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -7932,7 +8146,7 @@ impl LspStore { }) .collect::<FuturesUnordered<_>>(); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let mut responses = Vec::with_capacity(response_results.len()); while let Some((server_id, response_result)) = response_results.next().await { if let Some(response) = response_result.log_err() { @@ -8665,18 +8879,29 @@ impl LspStore { } async fn handle_start_language_server( - this: Entity<Self>, + lsp_store: Entity<Self>, envelope: TypedEnvelope<proto::StartLanguageServer>, mut cx: AsyncApp, ) -> Result<()> { let server = envelope.payload.server.context("invalid server")?; - - this.update(&mut cx, |this, cx| { + let server_capabilities = + serde_json::from_str::<lsp::ServerCapabilities>(&envelope.payload.capabilities) + .with_context(|| { + format!( + "incorrect server capabilities {}", + envelope.payload.capabilities + ) + })?; + lsp_store.update(&mut cx, |lsp_store, cx| { let server_id = LanguageServerId(server.id as usize); - this.language_server_statuses.insert( + let server_name = LanguageServerName::from_proto(server.name.clone()); + lsp_store + .lsp_server_capabilities + .insert(server_id, server_capabilities); + lsp_store.language_server_statuses.insert( server_id, LanguageServerStatus { - name: server.name.clone(), + name: server_name.clone(), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -8684,7 +8909,7 @@ impl LspStore { ); cx.emit(LspStoreEvent::LanguageServerAdded( server_id, - LanguageServerName(server.name.into()), + server_name, server.worktree_id.map(WorktreeId::from_proto), )); cx.notify(); @@ -8745,7 +8970,8 @@ impl LspStore { } non_lsp @ proto::update_language_server::Variant::StatusUpdate(_) - | non_lsp @ proto::update_language_server::Variant::RegisteredForBuffer(_) => { + | non_lsp @ proto::update_language_server::Variant::RegisteredForBuffer(_) + | non_lsp @ proto::update_language_server::Variant::MetadataUpdated(_) => { cx.emit(LspStoreEvent::LanguageServerUpdate { language_server_id, name: envelope @@ -10192,7 +10418,7 @@ impl LspStore { let name = self .language_server_statuses .remove(&server_id) - .map(|status| LanguageServerName::from(status.name.as_str())) + .map(|status| status.name.clone()) .or_else(|| { if let Some(LanguageServerState::Running { adapter, .. }) = server_state.as_ref() { Some(adapter.name()) @@ -10685,7 +10911,7 @@ impl LspStore { self.language_server_statuses.insert( server_id, LanguageServerStatus { - name: language_server.name().to_string(), + name: language_server.name(), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -10699,18 +10925,23 @@ impl LspStore { )); cx.emit(LspStoreEvent::RefreshInlayHints); + let server_capabilities = language_server.capabilities(); if let Some((downstream_client, project_id)) = self.downstream_client.as_ref() { downstream_client .send(proto::StartLanguageServer { project_id: *project_id, server: Some(proto::LanguageServer { - id: server_id.0 as u64, + id: server_id.to_proto(), name: language_server.name().to_string(), worktree_id: Some(key.0.to_proto()), }), + capabilities: serde_json::to_string(&server_capabilities) + .expect("serializing server LSP capabilities"), }) .log_err(); } + self.lsp_server_capabilities + .insert(server_id, server_capabilities); // Tell the language server about every open buffer in the worktree that matches the language. // Also check for buffers in worktrees that reused this server @@ -10758,10 +10989,11 @@ impl LspStore { let local = self.as_local_mut().unwrap(); - if local.registered_buffers.contains_key(&buffer.remote_id()) { + let buffer_id = buffer.remote_id(); + if local.registered_buffers.contains_key(&buffer_id) { let versions = local .buffer_snapshots - .entry(buffer.remote_id()) + .entry(buffer_id) .or_default() .entry(server_id) .and_modify(|_| { @@ -10787,10 +11019,10 @@ impl LspStore { version, initial_snapshot.text(), ); - buffer_paths_registered.push(file.abs_path(cx)); + buffer_paths_registered.push((buffer_id, file.abs_path(cx))); local .buffers_opened_in_servers - .entry(buffer.remote_id()) + .entry(buffer_id) .or_default() .insert(server_id); } @@ -10814,13 +11046,14 @@ impl LspStore { } }); - for abs_path in buffer_paths_registered { + for (buffer_id, abs_path) in buffer_paths_registered { cx.emit(LspStoreEvent::LanguageServerUpdate { language_server_id: server_id, name: Some(adapter.name()), message: proto::update_language_server::Variant::RegisteredForBuffer( proto::RegisteredForBuffer { buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), }, ), }); @@ -11278,9 +11511,13 @@ impl LspStore { } fn cleanup_lsp_data(&mut self, for_server: LanguageServerId) { - for buffer_lsp_data in self.lsp_data.values_mut() { - buffer_lsp_data.colors.remove(&for_server); - buffer_lsp_data.cache_version += 1; + self.lsp_server_capabilities.remove(&for_server); + for buffer_colors in self.lsp_document_colors.values_mut() { + buffer_colors.colors.remove(&for_server); + buffer_colors.cache_version += 1; + } + for buffer_lens in self.lsp_code_lens.values_mut() { + buffer_lens.lens.remove(&for_server); } if let Some(local) = self.as_local_mut() { local.buffer_pull_diagnostics_result_ids.remove(&for_server); diff --git a/crates/project/src/manifest_tree/server_tree.rs b/crates/project/src/manifest_tree/server_tree.rs index 0283f06eec..81cb1c450c 100644 --- a/crates/project/src/manifest_tree/server_tree.rs +++ b/crates/project/src/manifest_tree/server_tree.rs @@ -13,10 +13,10 @@ use std::{ sync::{Arc, Weak}, }; -use collections::{HashMap, IndexMap}; +use collections::IndexMap; use gpui::{App, AppContext as _, Entity, Subscription}; use language::{ - Attach, CachedLspAdapter, LanguageName, LanguageRegistry, ManifestDelegate, + CachedLspAdapter, LanguageName, LanguageRegistry, ManifestDelegate, language_settings::AllLanguageSettings, }; use lsp::LanguageServerName; @@ -38,7 +38,6 @@ pub(crate) struct ServersForWorktree { pub struct LanguageServerTree { manifest_tree: Entity<ManifestTree>, pub(crate) instances: BTreeMap<WorktreeId, ServersForWorktree>, - attach_kind_cache: HashMap<LanguageServerName, Attach>, languages: Arc<LanguageRegistry>, _subscriptions: Subscription, } @@ -53,7 +52,6 @@ pub struct LanguageServerTreeNode(Weak<InnerTreeNode>); #[derive(Debug)] pub(crate) struct LaunchDisposition<'a> { pub(crate) server_name: &'a LanguageServerName, - pub(crate) attach: Attach, pub(crate) path: ProjectPath, pub(crate) settings: Arc<LspSettings>, } @@ -62,7 +60,6 @@ impl<'a> From<&'a InnerTreeNode> for LaunchDisposition<'a> { fn from(value: &'a InnerTreeNode) -> Self { LaunchDisposition { server_name: &value.name, - attach: value.attach, path: value.path.clone(), settings: value.settings.clone(), } @@ -105,7 +102,6 @@ impl From<Weak<InnerTreeNode>> for LanguageServerTreeNode { pub struct InnerTreeNode { id: OnceLock<LanguageServerId>, name: LanguageServerName, - attach: Attach, path: ProjectPath, settings: Arc<LspSettings>, } @@ -113,14 +109,12 @@ pub struct InnerTreeNode { impl InnerTreeNode { fn new( name: LanguageServerName, - attach: Attach, path: ProjectPath, settings: impl Into<Arc<LspSettings>>, ) -> Self { InnerTreeNode { id: Default::default(), name, - attach, path, settings: settings.into(), } @@ -130,8 +124,11 @@ impl InnerTreeNode { /// Determines how the list of adapters to query should be constructed. pub(crate) enum AdapterQuery<'a> { /// Search for roots of all adapters associated with a given language name. + /// Layman: Look for all project roots along the queried path that have any + /// language server associated with this language running. Language(&'a LanguageName), /// Search for roots of adapter with a given name. + /// Layman: Look for all project roots along the queried path that have this server running. Adapter(&'a LanguageServerName), } @@ -147,7 +144,7 @@ impl LanguageServerTree { }), manifest_tree, instances: Default::default(), - attach_kind_cache: Default::default(), + languages, }) } @@ -223,7 +220,6 @@ impl LanguageServerTree { .and_then(|name| roots.get(&name)) .cloned() .unwrap_or_else(|| root_path.clone()); - let attach = adapter.attach_kind(); let inner_node = self .instances @@ -237,7 +233,6 @@ impl LanguageServerTree { ( Arc::new(InnerTreeNode::new( adapter.name(), - attach, root_path.clone(), settings.clone(), )), @@ -379,7 +374,6 @@ pub(crate) struct ServerTreeRebase<'a> { impl<'tree> ServerTreeRebase<'tree> { fn new(new_tree: &'tree mut LanguageServerTree) -> Self { let old_contents = std::mem::take(&mut new_tree.instances); - new_tree.attach_kind_cache.clear(); let all_server_ids = old_contents .values() .flat_map(|nodes| { @@ -446,10 +440,7 @@ impl<'tree> ServerTreeRebase<'tree> { .get(&disposition.path.worktree_id) .and_then(|worktree_nodes| worktree_nodes.roots.get(&disposition.path.path)) .and_then(|roots| roots.get(&disposition.name)) - .filter(|(old_node, _)| { - disposition.attach == old_node.attach - && disposition.settings == old_node.settings - }) + .filter(|(old_node, _)| disposition.settings == old_node.settings) else { return Some(node); }; diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index f9c59d2e95..398e8bde87 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -97,7 +97,7 @@ use rpc::{ }; use search::{SearchInputKind, SearchQuery, SearchResult}; use search_history::SearchHistory; -use settings::{InvalidSettingsError, Settings, SettingsLocation, SettingsStore}; +use settings::{InvalidSettingsError, Settings, SettingsLocation, SettingsSources, SettingsStore}; use smol::channel::Receiver; use snippet::Snippet; use snippet_provider::SnippetProvider; @@ -113,7 +113,7 @@ use std::{ use task_store::TaskStore; use terminals::Terminals; -use text::{Anchor, BufferId, Point}; +use text::{Anchor, BufferId, OffsetRangeExt, Point}; use toolchain_store::EmptyToolchainStore; use util::{ ResultExt as _, @@ -277,6 +277,13 @@ pub enum Event { LanguageServerAdded(LanguageServerId, LanguageServerName, Option<WorktreeId>), LanguageServerRemoved(LanguageServerId), LanguageServerLog(LanguageServerId, LanguageServerLogType, String), + // [`lsp::notification::DidOpenTextDocument`] was sent to this server using the buffer data. + // Zed's buffer-related data is updated accordingly. + LanguageServerBufferRegistered { + server_id: LanguageServerId, + buffer_id: BufferId, + buffer_abs_path: PathBuf, + }, Toast { notification_id: SharedString, message: String, @@ -590,7 +597,7 @@ pub(crate) struct CoreCompletion { } /// A code action provided by a language server. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct CodeAction { /// The id of the language server that produced this code action. pub server_id: LanguageServerId, @@ -604,7 +611,7 @@ pub struct CodeAction { } /// An action sent back by a language server. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum LspAction { /// An action with the full data, may have a command or may not. /// May require resolving. @@ -942,10 +949,38 @@ pub enum PulledDiagnostics { }, } +/// Whether to disable all AI features in Zed. +/// +/// Default: false +#[derive(Copy, Clone, Debug)] +pub struct DisableAiSettings { + pub disable_ai: bool, +} + +impl settings::Settings for DisableAiSettings { + const KEY: Option<&'static str> = Some("disable_ai"); + + type FileContent = Option<bool>; + + fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> { + Ok(Self { + disable_ai: sources + .user + .or(sources.server) + .copied() + .flatten() + .unwrap_or(sources.default.ok_or_else(Self::missing_default)?), + }) + } + + fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} +} + impl Project { pub fn init_settings(cx: &mut App) { WorktreeSettings::register(cx); ProjectSettings::register(cx); + DisableAiSettings::register(cx); } pub fn init(client: &Arc<Client>, cx: &mut App) { @@ -998,8 +1033,9 @@ impl Project { cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); + let weak_self = cx.weak_entity(); let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx)); + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); let environment = cx.new(|_| ProjectEnvironment::new(env)); let manifest_tree = ManifestTree::new(worktree_store.clone(), cx); @@ -1167,8 +1203,9 @@ impl Project { cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); + let weak_self = cx.weak_entity(); let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx)); + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); let buffer_store = cx.new(|cx| { BufferStore::remote( @@ -1360,10 +1397,7 @@ impl Project { fs: Arc<dyn Fs>, cx: AsyncApp, ) -> Result<Entity<Self>> { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.connect(true, &cx).await.into_response()?; let subscriptions = [ EntitySubscription::Project(client.subscribe_to_entity::<Self>(remote_id)?), @@ -1428,8 +1462,6 @@ impl Project { let image_store = cx.new(|cx| { ImageStore::remote(worktree_store.clone(), client.clone().into(), remote_id, cx) })?; - let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx))?; let environment = cx.new(|_| ProjectEnvironment::new(None))?; @@ -1496,6 +1528,10 @@ impl Project { let snippets = SnippetProvider::new(fs.clone(), BTreeSet::from_iter([]), cx); + let weak_self = cx.weak_entity(); + let context_server_store = + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); + let mut worktrees = Vec::new(); for worktree in response.payload.worktrees { let worktree = @@ -2902,8 +2938,8 @@ impl Project { } LspStoreEvent::LanguageServerUpdate { language_server_id, - message, name, + message, } => { if self.is_local() { self.enqueue_buffer_ordered_message( @@ -2915,6 +2951,32 @@ impl Project { ) .ok(); } + + match message { + proto::update_language_server::Variant::MetadataUpdated(update) => { + if let Some(capabilities) = update + .capabilities + .as_ref() + .and_then(|capabilities| serde_json::from_str(capabilities).ok()) + { + self.lsp_store.update(cx, |lsp_store, _| { + lsp_store + .lsp_server_capabilities + .insert(*language_server_id, capabilities); + }); + } + } + proto::update_language_server::Variant::RegisteredForBuffer(update) => { + if let Some(buffer_id) = BufferId::new(update.buffer_id).ok() { + cx.emit(Event::LanguageServerBufferRegistered { + buffer_id, + server_id: *language_server_id, + buffer_abs_path: PathBuf::from(&update.buffer_abs_path), + }); + } + } + _ => (), + } } LspStoreEvent::Notification(message) => cx.emit(Event::Toast { notification_id: "lsp".into(), @@ -3368,7 +3430,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.definitions(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3386,7 +3448,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.declarations(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3404,7 +3466,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.type_definitions(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3422,7 +3484,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.implementations(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3440,27 +3502,13 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.references(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result }) } - fn document_highlights_impl( - &mut self, - buffer: &Entity<Buffer>, - position: PointUtf16, - cx: &mut Context<Self>, - ) -> Task<Result<Vec<DocumentHighlight>>> { - self.request_lsp( - buffer.clone(), - LanguageServerToQuery::FirstCapable, - GetDocumentHighlights { position }, - cx, - ) - } - pub fn document_highlights<T: ToPointUtf16>( &mut self, buffer: &Entity<Buffer>, @@ -3468,7 +3516,12 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<Vec<DocumentHighlight>>> { let position = position.to_point_utf16(buffer.read(cx)); - self.document_highlights_impl(buffer, position, cx) + self.request_lsp( + buffer.clone(), + LanguageServerToQuery::FirstCapable, + GetDocumentHighlights { position }, + cx, + ) } pub fn document_symbols( @@ -3569,14 +3622,14 @@ impl Project { .update(cx, |lsp_store, cx| lsp_store.hover(buffer, position, cx)) } - pub fn linked_edit( + pub fn linked_edits( &self, buffer: &Entity<Buffer>, position: Anchor, cx: &mut Context<Self>, ) -> Task<Result<Vec<Range<Anchor>>>> { self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store.linked_edit(buffer, position, cx) + lsp_store.linked_edits(buffer, position, cx) }) } @@ -3607,20 +3660,29 @@ impl Project { }) } - pub fn code_lens<T: Clone + ToOffset>( + pub fn code_lens_actions<T: Clone + ToOffset>( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, range: Range<T>, cx: &mut Context<Self>, ) -> Task<Result<Vec<CodeAction>>> { - let snapshot = buffer_handle.read(cx).snapshot(); - let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end); + let snapshot = buffer.read(cx).snapshot(); + let range = range.clone().to_owned().to_point(&snapshot); + let range_start = snapshot.anchor_before(range.start); + let range_end = if range.start == range.end { + range_start + } else { + snapshot.anchor_after(range.end) + }; + let range = range_start..range_end; let code_lens_actions = self .lsp_store - .update(cx, |lsp_store, cx| lsp_store.code_lens(buffer_handle, cx)); + .update(cx, |lsp_store, cx| lsp_store.code_lens_actions(buffer, cx)); cx.background_spawn(async move { - let mut code_lens_actions = code_lens_actions.await?; + let mut code_lens_actions = code_lens_actions + .await + .map_err(|e| anyhow!("code lens fetch failed: {e:#}"))?; code_lens_actions.retain(|code_lens_action| { range .start @@ -3659,19 +3721,6 @@ impl Project { }) } - fn prepare_rename_impl( - &mut self, - buffer: Entity<Buffer>, - position: PointUtf16, - cx: &mut Context<Self>, - ) -> Task<Result<PrepareRenameResponse>> { - self.request_lsp( - buffer, - LanguageServerToQuery::FirstCapable, - PrepareRename { position }, - cx, - ) - } pub fn prepare_rename<T: ToPointUtf16>( &mut self, buffer: Entity<Buffer>, @@ -3679,7 +3728,12 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<PrepareRenameResponse>> { let position = position.to_point_utf16(buffer.read(cx)); - self.prepare_rename_impl(buffer, position, cx) + self.request_lsp( + buffer, + LanguageServerToQuery::FirstCapable, + PrepareRename { position }, + cx, + ) } pub fn perform_rename<T: ToPointUtf16>( @@ -3983,7 +4037,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.request_lsp(buffer_handle, server, request, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index 779cf95add..75ebc8339a 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -1100,7 +1100,7 @@ async fn test_reporting_fs_changes_to_language_servers(cx: &mut gpui::TestAppCon let fake_server = fake_servers.next().await.unwrap(); let (server_id, server_name) = lsp_store.read_with(cx, |lsp_store, _| { let (id, status) = lsp_store.language_server_statuses().next().unwrap(); - (id, LanguageServerName::from(status.name.as_str())) + (id, status.name.clone()) }); // Simulate jumping to a definition in a dependency outside of the worktree. @@ -1698,7 +1698,7 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC name: "the-language-server", disk_based_diagnostics_sources: vec!["disk".into()], disk_based_diagnostics_progress_token: Some(progress_token.into()), - ..Default::default() + ..FakeLspAdapter::default() }, ); @@ -1710,6 +1710,7 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC }) .await .unwrap(); + let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id()); // Simulate diagnostics starting to update. let fake_server = fake_servers.next().await.unwrap(); fake_server.start_progress(progress_token).await; @@ -1736,6 +1737,14 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC ); assert_eq!(events.next().await.unwrap(), Event::RefreshInlayHints); fake_server.start_progress(progress_token).await; + assert_eq!( + events.next().await.unwrap(), + Event::LanguageServerBufferRegistered { + server_id: LanguageServerId(1), + buffer_id, + buffer_abs_path: PathBuf::from(path!("/dir/a.rs")), + } + ); assert_eq!( events.next().await.unwrap(), Event::DiskBasedDiagnosticsStarted { diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index 5330ee506a..353f19adb2 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -79,11 +79,16 @@ message OpenServerSettings { uint64 project_id = 1; } -message GetPanicFiles { +message GetCrashFiles { } -message GetPanicFilesResponse { - repeated string file_contents = 2; +message GetCrashFilesResponse { + repeated CrashReport crashes = 1; +} + +message CrashReport { + optional string panic_contents = 1; + optional bytes minidump_contents = 2; } message Extension { diff --git a/crates/proto/proto/call.proto b/crates/proto/proto/call.proto index 5212f3b43f..b5c882db56 100644 --- a/crates/proto/proto/call.proto +++ b/crates/proto/proto/call.proto @@ -71,6 +71,7 @@ message RejoinedProject { repeated WorktreeMetadata worktrees = 2; repeated Collaborator collaborators = 3; repeated LanguageServer language_servers = 4; + repeated string language_server_capabilities = 5; } message LeaveRoom {} @@ -199,6 +200,7 @@ message JoinProjectResponse { repeated WorktreeMetadata worktrees = 2; repeated Collaborator collaborators = 3; repeated LanguageServer language_servers = 4; + repeated string language_server_capabilities = 8; ChannelRole role = 6; reserved 7; } diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index ea08d36371..c32da9b110 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -422,3 +422,12 @@ message BlameBufferResponse { reserved 1 to 4; } + +message GetDefaultBranch { + uint64 project_id = 1; + uint64 repository_id = 2; +} + +message GetDefaultBranchResponse { + optional string branch = 1; +} diff --git a/crates/proto/proto/lsp.proto b/crates/proto/proto/lsp.proto index e3c2f69c0b..1e693dfdf3 100644 --- a/crates/proto/proto/lsp.proto +++ b/crates/proto/proto/lsp.proto @@ -518,6 +518,7 @@ message LanguageServer { message StartLanguageServer { uint64 project_id = 1; LanguageServer server = 2; + string capabilities = 3; } message UpdateDiagnosticSummary { @@ -545,6 +546,7 @@ message UpdateLanguageServer { LspDiskBasedDiagnosticsUpdated disk_based_diagnostics_updated = 7; StatusUpdate status_update = 9; RegisteredForBuffer registered_for_buffer = 10; + ServerMetadataUpdated metadata_updated = 11; } } @@ -597,6 +599,11 @@ enum ServerBinaryStatus { message RegisteredForBuffer { string buffer_abs_path = 1; + uint64 buffer_id = 2; +} + +message ServerMetadataUpdated { + optional string capabilities = 1; } message LanguageServerLog { diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 29ab2b1e90..9de5c2c0c7 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -294,9 +294,6 @@ message Envelope { GetPathMetadata get_path_metadata = 278; GetPathMetadataResponse get_path_metadata_response = 279; - GetPanicFiles get_panic_files = 280; - GetPanicFilesResponse get_panic_files_response = 281; - CancelLanguageServerWork cancel_language_server_work = 282; LspExtOpenDocs lsp_ext_open_docs = 283; @@ -399,7 +396,13 @@ message Envelope { GetColorPresentationResponse get_color_presentation_response = 356; Stash stash = 357; - StashPop stash_pop = 358; // current max + StashPop stash_pop = 358; + + GetDefaultBranch get_default_branch = 359; + GetDefaultBranchResponse get_default_branch_response = 360; + + GetCrashFiles get_crash_files = 361; + GetCrashFilesResponse get_crash_files_response = 362; // current max } reserved 87 to 88; @@ -420,6 +423,7 @@ message Envelope { reserved 270; reserved 247 to 254; reserved 255 to 256; + reserved 280 to 281; } message Hello { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 9f586a7839..4c447e2eca 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -99,8 +99,8 @@ messages!( (GetHoverResponse, Background), (GetNotifications, Foreground), (GetNotificationsResponse, Foreground), - (GetPanicFiles, Background), - (GetPanicFilesResponse, Background), + (GetCrashFiles, Background), + (GetCrashFilesResponse, Background), (GetPathMetadata, Background), (GetPathMetadataResponse, Background), (GetPermalinkToLine, Foreground), @@ -315,7 +315,9 @@ messages!( (LogToDebugConsole, Background), (GetDocumentDiagnostics, Background), (GetDocumentDiagnosticsResponse, Background), - (PullWorkspaceDiagnostics, Background) + (PullWorkspaceDiagnostics, Background), + (GetDefaultBranch, Background), + (GetDefaultBranchResponse, Background), ); request_messages!( @@ -460,7 +462,7 @@ request_messages!( (ActivateToolchain, Ack), (ActiveToolchain, ActiveToolchainResponse), (GetPathMetadata, GetPathMetadataResponse), - (GetPanicFiles, GetPanicFilesResponse), + (GetCrashFiles, GetCrashFilesResponse), (CancelLanguageServerWork, Ack), (SyncExtensions, SyncExtensionsResponse), (InstallExtension, Ack), @@ -483,7 +485,8 @@ request_messages!( (GetDebugAdapterBinary, DebugAdapterBinary), (RunDebugLocators, DebugRequest), (GetDocumentDiagnostics, GetDocumentDiagnosticsResponse), - (PullWorkspaceDiagnostics, Ack) + (PullWorkspaceDiagnostics, Ack), + (GetDefaultBranch, GetDefaultBranchResponse), ); entity_messages!( @@ -615,7 +618,8 @@ entity_messages!( GetDebugAdapterBinary, LogToDebugConsole, GetDocumentDiagnostics, - PullWorkspaceDiagnostics + PullWorkspaceDiagnostics, + GetDefaultBranch ); entity_messages!( @@ -784,6 +788,25 @@ pub fn split_repository_update( }]) } +impl MultiLspQuery { + pub fn request_str(&self) -> &str { + match self.request { + Some(multi_lsp_query::Request::GetHover(_)) => "GetHover", + Some(multi_lsp_query::Request::GetCodeActions(_)) => "GetCodeActions", + Some(multi_lsp_query::Request::GetSignatureHelp(_)) => "GetSignatureHelp", + Some(multi_lsp_query::Request::GetCodeLens(_)) => "GetCodeLens", + Some(multi_lsp_query::Request::GetDocumentDiagnostics(_)) => "GetDocumentDiagnostics", + Some(multi_lsp_query::Request::GetDocumentColor(_)) => "GetDocumentColor", + Some(multi_lsp_query::Request::GetDefinition(_)) => "GetDefinition", + Some(multi_lsp_query::Request::GetDeclaration(_)) => "GetDeclaration", + Some(multi_lsp_query::Request::GetTypeDefinition(_)) => "GetTypeDefinition", + Some(multi_lsp_query::Request::GetImplementation(_)) => "GetImplementation", + Some(multi_lsp_query::Request::GetReferences(_)) => "GetReferences", + None => "<unknown>", + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index aa5103e62b..655e24860a 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -963,7 +963,7 @@ impl RemoteServerProjects { .child({ let project = project.clone(); // Right-margin to offset it from the Scrollbar - IconButton::new("remove-remote-project", IconName::TrashAlt) + IconButton::new("remove-remote-project", IconName::Trash) .icon_size(IconSize::Small) .shape(IconButtonShape::Square) .size(ButtonSize::Large) diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index e31d3dcfd5..4306251e44 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -1742,7 +1742,7 @@ impl SshRemoteConnection { } }); - cx.spawn(async move |_| { + cx.background_spawn(async move { let result = futures::select! { result = stdin_task.fuse() => { result.context("stdin") diff --git a/crates/remote_server/Cargo.toml b/crates/remote_server/Cargo.toml index 443c47919f..c6a546f345 100644 --- a/crates/remote_server/Cargo.toml +++ b/crates/remote_server/Cargo.toml @@ -67,8 +67,11 @@ watch.workspace = true worktree.workspace = true [target.'cfg(not(windows))'.dependencies] +crashes.workspace = true +crash-handler.workspace = true fork.workspace = true libc.workspace = true +minidumper.workspace = true [dev-dependencies] assistant_tool.workspace = true diff --git a/crates/remote_server/src/main.rs b/crates/remote_server/src/main.rs index 98f635d856..03b0c3eda3 100644 --- a/crates/remote_server/src/main.rs +++ b/crates/remote_server/src/main.rs @@ -12,6 +12,10 @@ struct Cli { /// by having Zed act like netcat communicating over a Unix socket. #[arg(long, hide = true)] askpass: Option<String>, + /// Used for recording minidumps on crashes by having the server run a separate + /// process communicating over a socket. + #[arg(long, hide = true)] + crash_handler: Option<PathBuf>, /// Used for loading the environment from the project. #[arg(long, hide = true)] printenv: bool, @@ -58,6 +62,11 @@ fn main() { return; } + if let Some(socket) = &cli.crash_handler { + crashes::crash_server(socket.as_path()); + return; + } + if cli.printenv { util::shell_env::print_env(); return; diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs index 84ce08ff25..9bb5645dc7 100644 --- a/crates/remote_server/src/unix.rs +++ b/crates/remote_server/src/unix.rs @@ -17,6 +17,7 @@ use node_runtime::{NodeBinaryOptions, NodeRuntime}; use paths::logs_dir; use project::project_settings::ProjectSettings; +use proto::CrashReport; use release_channel::{AppVersion, RELEASE_CHANNEL, ReleaseChannel}; use remote::proxy::ProxyLaunchError; use remote::ssh_session::ChannelClient; @@ -33,6 +34,7 @@ use smol::io::AsyncReadExt; use smol::Async; use smol::{net::unix::UnixListener, stream::StreamExt as _}; +use std::collections::HashMap; use std::ffi::OsStr; use std::ops::ControlFlow; use std::str::FromStr; @@ -109,8 +111,9 @@ fn init_logging_server(log_file_path: PathBuf) -> Result<Receiver<Vec<u8>>> { Ok(rx) } -fn init_panic_hook() { - std::panic::set_hook(Box::new(|info| { +fn init_panic_hook(session_id: String) { + std::panic::set_hook(Box::new(move |info| { + crashes::handle_panic(); let payload = info .payload() .downcast_ref::<&str>() @@ -171,9 +174,11 @@ fn init_panic_hook() { architecture: env::consts::ARCH.into(), panicked_on: Utc::now().timestamp_millis(), backtrace, - system_id: None, // Set on SSH client - installation_id: None, // Set on SSH client - session_id: "".to_string(), // Set on SSH client + system_id: None, // Set on SSH client + installation_id: None, // Set on SSH client + + // used on this end to associate panics with minidumps, but will be replaced on the SSH client + session_id: session_id.clone(), }; if let Some(panic_data_json) = serde_json::to_string(&panic_data).log_err() { @@ -194,44 +199,69 @@ fn init_panic_hook() { })); } -fn handle_panic_requests(project: &Entity<HeadlessProject>, client: &Arc<ChannelClient>) { +fn handle_crash_files_requests(project: &Entity<HeadlessProject>, client: &Arc<ChannelClient>) { let client: AnyProtoClient = client.clone().into(); client.add_request_handler( project.downgrade(), - |_, _: TypedEnvelope<proto::GetPanicFiles>, _cx| async move { + |_, _: TypedEnvelope<proto::GetCrashFiles>, _cx| async move { + let mut crashes = Vec::new(); + let mut minidumps_by_session_id = HashMap::new(); let mut children = smol::fs::read_dir(paths::logs_dir()).await?; - let mut panic_files = Vec::new(); while let Some(child) = children.next().await { let child = child?; let child_path = child.path(); - if child_path.extension() != Some(OsStr::new("panic")) { - continue; + let extension = child_path.extension(); + if extension == Some(OsStr::new("panic")) { + let filename = if let Some(filename) = child_path.file_name() { + filename.to_string_lossy() + } else { + continue; + }; + + if !filename.starts_with("zed") { + continue; + } + + let file_contents = smol::fs::read_to_string(&child_path) + .await + .context("error reading panic file")?; + + crashes.push(proto::CrashReport { + panic_contents: Some(file_contents), + minidump_contents: None, + }); + } else if extension == Some(OsStr::new("dmp")) { + let session_id = child_path.file_stem().unwrap().to_string_lossy(); + minidumps_by_session_id + .insert(session_id.to_string(), smol::fs::read(&child_path).await?); } - let filename = if let Some(filename) = child_path.file_name() { - filename.to_string_lossy() - } else { - continue; - }; - - if !filename.starts_with("zed") { - continue; - } - - let file_contents = smol::fs::read_to_string(&child_path) - .await - .context("error reading panic file")?; - - panic_files.push(file_contents); // We've done what we can, delete the file - std::fs::remove_file(child_path) + smol::fs::remove_file(&child_path) + .await .context("error removing panic") .log_err(); } - anyhow::Ok(proto::GetPanicFilesResponse { - file_contents: panic_files, - }) + + for crash in &mut crashes { + let panic: telemetry_events::Panic = + serde_json::from_str(crash.panic_contents.as_ref().unwrap())?; + if let dump @ Some(_) = minidumps_by_session_id.remove(&panic.session_id) { + crash.minidump_contents = dump; + } + } + + crashes.extend( + minidumps_by_session_id + .into_values() + .map(|dmp| CrashReport { + panic_contents: None, + minidump_contents: Some(dmp), + }), + ); + + anyhow::Ok(proto::GetCrashFilesResponse { crashes }) }, ); } @@ -409,7 +439,12 @@ pub fn execute_run( ControlFlow::Continue(_) => {} } - init_panic_hook(); + let app = gpui::Application::headless(); + let id = std::process::id().to_string(); + app.background_executor() + .spawn(crashes::init(id.clone())) + .detach(); + init_panic_hook(id); let log_rx = init_logging_server(log_file)?; log::info!( "starting up. pid_file: {:?}, stdin_socket: {:?}, stdout_socket: {:?}, stderr_socket: {:?}", @@ -425,7 +460,7 @@ pub fn execute_run( let listeners = ServerListeners::new(stdin_socket, stdout_socket, stderr_socket)?; let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new()); - gpui::Application::headless().run(move |cx| { + app.run(move |cx| { settings::init(cx); let app_version = AppVersion::load(env!("ZED_PKG_VERSION")); release_channel::init(app_version, cx); @@ -486,7 +521,7 @@ pub fn execute_run( ) }); - handle_panic_requests(&project, &session); + handle_crash_files_requests(&project, &session); cx.background_spawn(async move { cleanup_old_binaries() }) .detach(); @@ -530,12 +565,15 @@ impl ServerPaths { pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> { init_logging_proxy(); - init_panic_hook(); - - log::info!("starting proxy process. PID: {}", std::process::id()); let server_paths = ServerPaths::new(&identifier)?; + let id = std::process::id().to_string(); + smol::spawn(crashes::init(id.clone())).detach(); + init_panic_hook(id); + + log::info!("starting proxy process. PID: {}", std::process::id()); + let server_pid = check_pid_file(&server_paths.pid_file)?; let server_running = server_pid.is_some(); if is_reconnecting { diff --git a/crates/repl/src/notebook/cell.rs b/crates/repl/src/notebook/cell.rs index 2ed68c17d1..18851417c0 100644 --- a/crates/repl/src/notebook/cell.rs +++ b/crates/repl/src/notebook/cell.rs @@ -38,7 +38,7 @@ pub enum CellControlType { impl CellControlType { fn icon_name(&self) -> IconName { match self { - CellControlType::RunCell => IconName::Play, + CellControlType::RunCell => IconName::PlayOutlined, CellControlType::RerunCell => IconName::ArrowCircle, CellControlType::ClearCell => IconName::ListX, CellControlType::CellOptions => IconName::Ellipsis, diff --git a/crates/repl/src/notebook/notebook_ui.rs b/crates/repl/src/notebook/notebook_ui.rs index d14f458fa9..3e96cc4d11 100644 --- a/crates/repl/src/notebook/notebook_ui.rs +++ b/crates/repl/src/notebook/notebook_ui.rs @@ -343,7 +343,7 @@ impl NotebookEditor { .child( Self::render_notebook_control( "run-all-cells", - IconName::Play, + IconName::PlayOutlined, window, cx, ) diff --git a/crates/reqwest_client/src/reqwest_client.rs b/crates/reqwest_client/src/reqwest_client.rs index daff20ac4a..6461a0ae17 100644 --- a/crates/reqwest_client/src/reqwest_client.rs +++ b/crates/reqwest_client/src/reqwest_client.rs @@ -4,14 +4,13 @@ use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration use anyhow::anyhow; use bytes::{BufMut, Bytes, BytesMut}; -use futures::{AsyncRead, TryStreamExt as _}; +use futures::{AsyncRead, FutureExt as _, TryStreamExt as _}; use http_client::{RedirectPolicy, Url, http}; use regex::Regex; use reqwest::{ header::{HeaderMap, HeaderValue}, redirect, }; -use smol::future::FutureExt; const DEFAULT_CAPACITY: usize = 4096; static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new(); @@ -20,6 +19,7 @@ static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+") pub struct ReqwestClient { client: reqwest::Client, proxy: Option<Url>, + user_agent: Option<HeaderValue>, handle: tokio::runtime::Handle, } @@ -44,9 +44,11 @@ impl ReqwestClient { Ok(client.into()) } - pub fn proxy_and_user_agent(proxy: Option<Url>, agent: &str) -> anyhow::Result<Self> { + pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> { + let user_agent = HeaderValue::from_str(user_agent)?; + let mut map = HeaderMap::new(); - map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?); + map.insert(http::header::USER_AGENT, user_agent.clone()); let mut client = Self::builder().default_headers(map); let client_has_proxy; @@ -73,6 +75,7 @@ impl ReqwestClient { .build()?; let mut client: ReqwestClient = client.into(); client.proxy = client_has_proxy.then_some(proxy).flatten(); + client.user_agent = Some(user_agent); Ok(client) } } @@ -96,6 +99,7 @@ impl From<reqwest::Client> for ReqwestClient { client, handle, proxy: None, + user_agent: None, } } } @@ -216,6 +220,10 @@ impl http_client::HttpClient for ReqwestClient { type_name::<Self>() } + fn user_agent(&self) -> Option<&HeaderValue> { + self.user_agent.as_ref() + } + fn send( &self, req: http::Request<http_client::AsyncBody>, @@ -265,6 +273,26 @@ impl http_client::HttpClient for ReqwestClient { } .boxed() } + + fn send_multipart_form<'a>( + &'a self, + url: &str, + form: reqwest::multipart::Form, + ) -> futures::future::BoxFuture<'a, anyhow::Result<http_client::Response<http_client::AsyncBody>>> + { + let response = self.client.post(url).multipart(form).send(); + self.handle + .spawn(async move { + let response = response.await?; + let mut builder = http::response::Builder::new().status(response.status()); + for (k, v) in response.headers() { + builder = builder.header(k, v) + } + Ok(builder.body(response.bytes().await?.into())?) + }) + .map(|e| e?) + .boxed() + } } #[cfg(test)] diff --git a/crates/rope/src/rope.rs b/crates/rope/src/rope.rs index 515cd71331..aa3ed5db57 100644 --- a/crates/rope/src/rope.rs +++ b/crates/rope/src/rope.rs @@ -12,7 +12,7 @@ use std::{ ops::{self, AddAssign, Range}, str, }; -use sum_tree::{Bias, Dimension, SumTree}; +use sum_tree::{Bias, Dimension, Dimensions, SumTree}; pub use chunk::ChunkSlice; pub use offset_utf16::OffsetUtf16; @@ -282,7 +282,7 @@ impl Rope { if offset >= self.summary().len { return self.summary().len_utf16; } - let mut cursor = self.chunks.cursor::<(usize, OffsetUtf16)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, OffsetUtf16>>(&()); cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 @@ -295,7 +295,7 @@ impl Rope { if offset >= self.summary().len_utf16 { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(OffsetUtf16, usize)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<OffsetUtf16, usize>>(&()); cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 @@ -308,7 +308,7 @@ impl Rope { if offset >= self.summary().len { return self.summary().lines; } - let mut cursor = self.chunks.cursor::<(usize, Point)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, Point>>(&()); cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 @@ -321,7 +321,7 @@ impl Rope { if offset >= self.summary().len { return self.summary().lines_utf16(); } - let mut cursor = self.chunks.cursor::<(usize, PointUtf16)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, PointUtf16>>(&()); cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 @@ -334,7 +334,7 @@ impl Rope { if point >= self.summary().lines { return self.summary().lines_utf16(); } - let mut cursor = self.chunks.cursor::<(Point, PointUtf16)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<Point, PointUtf16>>(&()); cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 @@ -347,7 +347,7 @@ impl Rope { if point >= self.summary().lines { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(Point, usize)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<Point, usize>>(&()); cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 @@ -368,7 +368,7 @@ impl Rope { if point >= self.summary().lines_utf16() { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(PointUtf16, usize)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<PointUtf16, usize>>(&()); cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 @@ -381,7 +381,7 @@ impl Rope { if point.0 >= self.summary().lines_utf16() { return self.summary().lines; } - let mut cursor = self.chunks.cursor::<(PointUtf16, Point)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<PointUtf16, Point>>(&()); cursor.seek(&point.0, Bias::Left); let overshoot = Unclipped(point.0 - cursor.start().0); cursor.start().1 @@ -1168,16 +1168,17 @@ pub trait TextDimension: fn add_assign(&mut self, other: &Self); } -impl<D1: TextDimension, D2: TextDimension> TextDimension for (D1, D2) { +impl<D1: TextDimension, D2: TextDimension> TextDimension for Dimensions<D1, D2, ()> { fn from_text_summary(summary: &TextSummary) -> Self { - ( + Dimensions( D1::from_text_summary(summary), D2::from_text_summary(summary), + (), ) } fn from_chunk(chunk: ChunkSlice) -> Self { - (D1::from_chunk(chunk), D2::from_chunk(chunk)) + Dimensions(D1::from_chunk(chunk), D2::from_chunk(chunk), ()) } fn add_assign(&mut self, other: &Self) { diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 80a104641f..c1fd1df5ff 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -422,8 +422,26 @@ impl Peer { receiver_id: ConnectionId, request: T, ) -> impl Future<Output = Result<T::Response>> { + let request_start_time = Instant::now(); + let payload_type = T::NAME; + let elapsed_time = move || request_start_time.elapsed().as_millis(); + tracing::info!(payload_type, "start forwarding request"); self.request_internal(Some(sender_id), receiver_id, request) .map_ok(|envelope| envelope.payload) + .inspect_err(move |_| { + tracing::error!( + waiting_for_host_ms = elapsed_time(), + payload_type, + "error forwarding request" + ) + }) + .inspect_ok(move |_| { + tracing::info!( + waiting_for_host_ms = elapsed_time(), + payload_type, + "finished forwarding request" + ) + }) } fn request_internal<T: RequestMessage>( diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index be6a69c23b..ebec96dd7b 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -319,7 +319,7 @@ impl PickerDelegate for RulePickerDelegate { }) .into_any() } else { - IconButton::new("delete-rule", IconName::TrashAlt) + IconButton::new("delete-rule", IconName::Trash) .icon_color(Color::Muted) .icon_size(IconSize::Small) .shape(IconButtonShape::Square) @@ -1101,7 +1101,7 @@ impl RulesLibrary { inlay_hints_style: editor::make_inlay_hints_style( cx, ), - inline_completion_styles: + edit_prediction_styles: editor::make_suggestion_styles(cx), ..EditorStyle::default() }, @@ -1163,7 +1163,7 @@ impl RulesLibrary { }) .into_any() } else { - IconButton::new("delete-rule", IconName::TrashAlt) + IconButton::new("delete-rule", IconName::Trash) .icon_size(IconSize::Small) .tooltip(move |window, cx| { Tooltip::for_action( diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 3b9700c5f1..15c1099aec 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -355,8 +355,9 @@ impl ProjectSearch { while let Some(new_ranges) = new_ranges.next().await { project_search - .update(cx, |project_search, _| { + .update(cx, |project_search, cx| { project_search.match_ranges.extend(new_ranges); + cx.notify(); }) .ok()?; } diff --git a/crates/settings/src/settings.rs b/crates/settings/src/settings.rs index 4e6bd94d92..afd4ea0890 100644 --- a/crates/settings/src/settings.rs +++ b/crates/settings/src/settings.rs @@ -7,7 +7,7 @@ mod settings_json; mod settings_store; mod vscode_import; -use gpui::App; +use gpui::{App, Global}; use rust_embed::RustEmbed; use std::{borrow::Cow, fmt, str}; use util::asset_str; @@ -27,6 +27,11 @@ pub use settings_store::{ }; pub use vscode_import::{VsCodeSettings, VsCodeSettingsSource}; +#[derive(Clone, Debug, PartialEq)] +pub struct ActiveSettingsProfileName(pub String); + +impl Global for ActiveSettingsProfileName {} + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)] pub struct WorktreeId(usize); @@ -74,6 +79,7 @@ pub fn init(cx: &mut App) { .unwrap(); cx.set_global(settings); BaseKeymap::register(cx); + SettingsStore::observe_active_settings_profile_name(cx).detach(); } pub fn default_settings() -> Cow<'static, str> { diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 0d23385a68..bc42d2c886 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -2,7 +2,11 @@ use anyhow::{Context as _, Result}; use collections::{BTreeMap, HashMap, btree_map, hash_map}; use ec4rs::{ConfigParser, PropertiesSource, Section}; use fs::Fs; -use futures::{FutureExt, StreamExt, channel::mpsc, future::LocalBoxFuture}; +use futures::{ + FutureExt, StreamExt, + channel::{mpsc, oneshot}, + future::LocalBoxFuture, +}; use gpui::{App, AsyncApp, BorrowAppContext, Global, Task, UpdateGlobal}; use paths::{EDITORCONFIG_NAME, local_settings_file_relative_path, task_file_name}; @@ -26,8 +30,8 @@ use util::{ pub type EditorconfigProperties = ec4rs::Properties; use crate::{ - ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, WorktreeId, - parse_json_with_comments, update_value_in_json_text, + ActiveSettingsProfileName, ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, + WorktreeId, parse_json_with_comments, update_value_in_json_text, }; /// A value that can be defined as a user setting. @@ -122,6 +126,8 @@ pub struct SettingsSources<'a, T> { pub user: Option<&'a T>, /// The user settings for the current release channel. pub release_channel: Option<&'a T>, + /// The settings associated with an enabled settings profile + pub profile: Option<&'a T>, /// The server's settings. pub server: Option<&'a T>, /// The project settings, ordered from least specific to most specific. @@ -141,6 +147,7 @@ impl<'a, T: Serialize> SettingsSources<'a, T> { .chain(self.extensions) .chain(self.user) .chain(self.release_channel) + .chain(self.profile) .chain(self.server) .chain(self.project.iter().copied()) } @@ -282,6 +289,14 @@ impl SettingsStore { } } + pub fn observe_active_settings_profile_name(cx: &mut App) -> gpui::Subscription { + cx.observe_global::<ActiveSettingsProfileName>(|cx| { + Self::update_global(cx, |store, cx| { + store.recompute_values(None, cx).log_err(); + }); + }) + } + pub fn update<C, R>(cx: &mut C, f: impl FnOnce(&mut Self, &mut C) -> R) -> R where C: BorrowAppContext, @@ -321,6 +336,17 @@ impl SettingsStore { .log_err(); } + let mut profile_value = None; + if let Some(active_profile) = cx.try_global::<ActiveSettingsProfileName>() { + if let Some(profiles) = self.raw_user_settings.get("profiles") { + if let Some(profile_settings) = profiles.get(&active_profile.0) { + profile_value = setting_value + .deserialize_setting(profile_settings) + .log_err(); + } + } + } + let server_value = self .raw_server_settings .as_ref() @@ -340,6 +366,7 @@ impl SettingsStore { extensions: extension_value.as_ref(), user: user_value.as_ref(), release_channel: release_channel_value.as_ref(), + profile: profile_value.as_ref(), server: server_value.as_ref(), project: &[], }, @@ -402,6 +429,16 @@ impl SettingsStore { &self.raw_user_settings } + /// Get the configured settings profile names. + pub fn configured_settings_profiles(&self) -> impl Iterator<Item = &str> { + self.raw_user_settings + .get("profiles") + .and_then(|v| v.as_object()) + .into_iter() + .flat_map(|obj| obj.keys()) + .map(|s| s.as_str()) + } + /// Access the raw JSON value of the global settings. pub fn raw_global_settings(&self) -> Option<&Value> { self.raw_global_settings.as_ref() @@ -498,41 +535,64 @@ impl SettingsStore { .ok(); } - pub fn import_vscode_settings(&self, fs: Arc<dyn Fs>, vscode_settings: VsCodeSettings) { + pub fn import_vscode_settings( + &self, + fs: Arc<dyn Fs>, + vscode_settings: VsCodeSettings, + ) -> oneshot::Receiver<Result<()>> { + let (tx, rx) = oneshot::channel::<Result<()>>(); self.setting_file_updates_tx .unbounded_send(Box::new(move |cx: AsyncApp| { async move { - let old_text = Self::load_settings(&fs).await?; - let new_text = cx.read_global(|store: &SettingsStore, _cx| { - store.get_vscode_edits(old_text, &vscode_settings) - })?; - let settings_path = paths::settings_file().as_path(); - if fs.is_file(settings_path).await { - let resolved_path = - fs.canonicalize(settings_path).await.with_context(|| { - format!("Failed to canonicalize settings path {:?}", settings_path) - })?; + let res = async move { + let old_text = Self::load_settings(&fs).await?; + let new_text = cx.read_global(|store: &SettingsStore, _cx| { + store.get_vscode_edits(old_text, &vscode_settings) + })?; + let settings_path = paths::settings_file().as_path(); + if fs.is_file(settings_path).await { + let resolved_path = + fs.canonicalize(settings_path).await.with_context(|| { + format!( + "Failed to canonicalize settings path {:?}", + settings_path + ) + })?; - fs.atomic_write(resolved_path.clone(), new_text) - .await - .with_context(|| { - format!("Failed to write settings to file {:?}", resolved_path) - })?; - } else { - fs.atomic_write(settings_path.to_path_buf(), new_text) - .await - .with_context(|| { - format!("Failed to write settings to file {:?}", settings_path) - })?; + fs.atomic_write(resolved_path.clone(), new_text) + .await + .with_context(|| { + format!("Failed to write settings to file {:?}", resolved_path) + })?; + } else { + fs.atomic_write(settings_path.to_path_buf(), new_text) + .await + .with_context(|| { + format!("Failed to write settings to file {:?}", settings_path) + })?; + } + + anyhow::Ok(()) } + .await; - anyhow::Ok(()) + let new_res = match &res { + Ok(_) => anyhow::Ok(()), + Err(e) => Err(anyhow::anyhow!("Failed to write settings to file {:?}", e)), + }; + + _ = tx.send(new_res); + res } .boxed_local() })) .ok(); - } + rx + } +} + +impl SettingsStore { /// Updates the value of a setting in a JSON file, returning the new text /// for that JSON file. pub fn new_text_for_update<T: Settings>( @@ -1001,18 +1061,18 @@ impl SettingsStore { const ZED_SETTINGS: &str = "ZedSettings"; let zed_settings_ref = add_new_subschema(&mut generator, ZED_SETTINGS, combined_schema); - // add `ZedReleaseStageSettings` which is the same as `ZedSettings` except that unknown - // fields are rejected. - let mut zed_release_stage_settings = zed_settings_ref.clone(); - zed_release_stage_settings.insert("unevaluatedProperties".to_string(), false.into()); - let zed_release_stage_settings_ref = add_new_subschema( + // add `ZedSettingsOverride` which is the same as `ZedSettings` except that unknown + // fields are rejected. This is used for release stage settings and profiles. + let mut zed_settings_override = zed_settings_ref.clone(); + zed_settings_override.insert("unevaluatedProperties".to_string(), false.into()); + let zed_settings_override_ref = add_new_subschema( &mut generator, - "ZedReleaseStageSettings", - zed_release_stage_settings.to_value(), + "ZedSettingsOverride", + zed_settings_override.to_value(), ); // Remove `"additionalProperties": false` added by `DefaultDenyUnknownFields` so that - // unknown fields can be handled by the root schema and `ZedReleaseStageSettings`. + // unknown fields can be handled by the root schema and `ZedSettingsOverride`. let mut definitions = generator.take_definitions(true); definitions .get_mut(ZED_SETTINGS) @@ -1032,15 +1092,20 @@ impl SettingsStore { "$schema": meta_schema, "title": "Zed Settings", "unevaluatedProperties": false, - // ZedSettings + settings overrides for each release stage + // ZedSettings + settings overrides for each release stage / profiles "allOf": [ zed_settings_ref, { "properties": { - "dev": zed_release_stage_settings_ref, - "nightly": zed_release_stage_settings_ref, - "stable": zed_release_stage_settings_ref, - "preview": zed_release_stage_settings_ref, + "dev": zed_settings_override_ref, + "nightly": zed_settings_override_ref, + "stable": zed_settings_override_ref, + "preview": zed_settings_override_ref, + "profiles": { + "type": "object", + "description": "Configures any number of settings profiles.", + "additionalProperties": zed_settings_override_ref + } } } ], @@ -1099,6 +1164,16 @@ impl SettingsStore { } } + let mut profile_settings = None; + if let Some(active_profile) = cx.try_global::<ActiveSettingsProfileName>() { + if let Some(profiles) = self.raw_user_settings.get("profiles") { + if let Some(profile_json) = profiles.get(&active_profile.0) { + profile_settings = + setting_value.deserialize_setting(profile_json).log_err(); + } + } + } + // If the global settings file changed, reload the global value for the field. if changed_local_path.is_none() { if let Some(value) = setting_value @@ -1109,6 +1184,7 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &[], }, @@ -1161,6 +1237,7 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &project_settings_stack.iter().collect::<Vec<_>>(), }, @@ -1286,6 +1363,9 @@ impl<T: Settings> AnySettingValue for SettingValue<T> { release_channel: values .release_channel .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), + profile: values + .profile + .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), server: values .server .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), diff --git a/crates/settings_profile_selector/Cargo.toml b/crates/settings_profile_selector/Cargo.toml new file mode 100644 index 0000000000..189272e54b --- /dev/null +++ b/crates/settings_profile_selector/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "settings_profile_selector" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/settings_profile_selector.rs" +doctest = false + +[dependencies] +fuzzy.workspace = true +gpui.workspace = true +picker.workspace = true +settings.workspace = true +ui.workspace = true +workspace-hack.workspace = true +workspace.workspace = true +zed_actions.workspace = true + +[dev-dependencies] +client = { workspace = true, features = ["test-support"] } +editor = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +menu.workspace = true +project = { workspace = true, features = ["test-support"] } +serde_json.workspace = true +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_profile_selector/LICENSE-GPL b/crates/settings_profile_selector/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/settings_profile_selector/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/settings_profile_selector/src/settings_profile_selector.rs b/crates/settings_profile_selector/src/settings_profile_selector.rs new file mode 100644 index 0000000000..8a34c12051 --- /dev/null +++ b/crates/settings_profile_selector/src/settings_profile_selector.rs @@ -0,0 +1,581 @@ +use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; +use gpui::{ + App, Context, DismissEvent, Entity, EventEmitter, Focusable, Render, Task, WeakEntity, Window, +}; +use picker::{Picker, PickerDelegate}; +use settings::{ActiveSettingsProfileName, SettingsStore}; +use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*}; +use workspace::{ModalView, Workspace}; + +pub fn init(cx: &mut App) { + cx.on_action(|_: &zed_actions::settings_profile_selector::Toggle, cx| { + workspace::with_active_or_new_workspace(cx, |workspace, window, cx| { + toggle_settings_profile_selector(workspace, window, cx); + }); + }); +} + +fn toggle_settings_profile_selector( + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context<Workspace>, +) { + workspace.toggle_modal(window, cx, |window, cx| { + let delegate = SettingsProfileSelectorDelegate::new(cx.entity().downgrade(), window, cx); + SettingsProfileSelector::new(delegate, window, cx) + }); +} + +pub struct SettingsProfileSelector { + picker: Entity<Picker<SettingsProfileSelectorDelegate>>, +} + +impl ModalView for SettingsProfileSelector {} + +impl EventEmitter<DismissEvent> for SettingsProfileSelector {} + +impl Focusable for SettingsProfileSelector { + fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { + self.picker.focus_handle(cx) + } +} + +impl Render for SettingsProfileSelector { + fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + v_flex().w(rems(22.)).child(self.picker.clone()) + } +} + +impl SettingsProfileSelector { + pub fn new( + delegate: SettingsProfileSelectorDelegate, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); + Self { picker } + } +} + +pub struct SettingsProfileSelectorDelegate { + matches: Vec<StringMatch>, + profile_names: Vec<Option<String>>, + original_profile_name: Option<String>, + selected_profile_name: Option<String>, + selected_index: usize, + selection_completed: bool, + selector: WeakEntity<SettingsProfileSelector>, +} + +impl SettingsProfileSelectorDelegate { + fn new( + selector: WeakEntity<SettingsProfileSelector>, + _: &mut Window, + cx: &mut Context<SettingsProfileSelector>, + ) -> Self { + let settings_store = cx.global::<SettingsStore>(); + let mut profile_names: Vec<Option<String>> = settings_store + .configured_settings_profiles() + .map(|s| Some(s.to_string())) + .collect(); + profile_names.insert(0, None); + + let matches = profile_names + .iter() + .enumerate() + .map(|(ix, profile_name)| StringMatch { + candidate_id: ix, + score: 0.0, + positions: Default::default(), + string: display_name(profile_name), + }) + .collect(); + + let profile_name = cx + .try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()); + + let mut this = Self { + matches, + profile_names, + original_profile_name: profile_name.clone(), + selected_profile_name: None, + selected_index: 0, + selection_completed: false, + selector, + }; + + if let Some(profile_name) = profile_name { + this.select_if_matching(&profile_name); + } + + this + } + + fn select_if_matching(&mut self, profile_name: &str) { + self.selected_index = self + .matches + .iter() + .position(|mat| mat.string == profile_name) + .unwrap_or(self.selected_index); + } + + fn set_selected_profile( + &self, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Option<String> { + let mat = self.matches.get(self.selected_index)?; + let profile_name = self.profile_names.get(mat.candidate_id)?; + return Self::update_active_profile_name_global(profile_name.clone(), cx); + } + + fn update_active_profile_name_global( + profile_name: Option<String>, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Option<String> { + if let Some(profile_name) = profile_name { + cx.set_global(ActiveSettingsProfileName(profile_name.clone())); + return Some(profile_name.clone()); + } + + if cx.has_global::<ActiveSettingsProfileName>() { + cx.remove_global::<ActiveSettingsProfileName>(); + } + + None + } +} + +impl PickerDelegate for SettingsProfileSelectorDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _: &mut Window, _: &mut App) -> std::sync::Arc<str> { + "Select a settings profile...".into() + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + self.selected_index = ix; + self.selected_profile_name = self.set_selected_profile(cx); + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Task<()> { + let background = cx.background_executor().clone(); + let candidates = self + .profile_names + .iter() + .enumerate() + .map(|(id, profile_name)| StringMatchCandidate::new(id, &display_name(profile_name))) + .collect::<Vec<_>>(); + + cx.spawn_in(window, async move |this, cx| { + let matches = if query.is_empty() { + candidates + .into_iter() + .enumerate() + .map(|(index, candidate)| StringMatch { + candidate_id: index, + string: candidate.string, + positions: Vec::new(), + score: 0.0, + }) + .collect() + } else { + match_strings( + &candidates, + &query, + false, + true, + 100, + &Default::default(), + background, + ) + .await + }; + + this.update_in(cx, |this, _, cx| { + this.delegate.matches = matches; + this.delegate.selected_index = this + .delegate + .selected_index + .min(this.delegate.matches.len().saturating_sub(1)); + this.delegate.selected_profile_name = this.delegate.set_selected_profile(cx); + }) + .ok(); + }) + } + + fn confirm( + &mut self, + _: bool, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + self.selection_completed = true; + self.selector + .update(cx, |_, cx| { + cx.emit(DismissEvent); + }) + .ok(); + } + + fn dismissed( + &mut self, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + if !self.selection_completed { + SettingsProfileSelectorDelegate::update_active_profile_name_global( + self.original_profile_name.clone(), + cx, + ); + } + self.selector.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + _: &mut Context<Picker<Self>>, + ) -> Option<Self::ListItem> { + let mat = &self.matches[ix]; + let profile_name = &self.profile_names[mat.candidate_id]; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .child(HighlightedLabel::new( + display_name(profile_name), + mat.positions.clone(), + )), + ) + } +} + +fn display_name(profile_name: &Option<String>) -> String { + profile_name.clone().unwrap_or("Disabled".into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use client; + use editor; + use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; + use language; + use menu::{Cancel, Confirm, SelectNext, SelectPrevious}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::Settings; + use theme::{self, ThemeSettings}; + use workspace::{self, AppState}; + use zed_actions::settings_profile_selector; + + async fn init_test( + profiles_json: serde_json::Value, + cx: &mut TestAppContext, + ) -> (Entity<Workspace>, &mut VisualTestContext) { + cx.update(|cx| { + let state = AppState::test(cx); + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + settings::init(cx); + theme::init(theme::LoadThemes::JustBase, cx); + ThemeSettings::register(cx); + client::init_settings(cx); + language::init(cx); + super::init(cx); + editor::init(cx); + workspace::init_settings(cx); + Project::init_settings(cx); + state + }); + + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + let settings_json = json!({ + "buffer_font_size": 10.0, + "profiles": profiles_json, + }); + + store + .set_user_settings(&settings_json.to_string(), cx) + .unwrap(); + }); + }); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, ["/test".as_ref()], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + cx.update(|_, cx| { + assert!(!cx.has_global::<ActiveSettingsProfileName>()); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + (workspace, cx) + } + + #[track_caller] + fn active_settings_profile_picker( + workspace: &Entity<Workspace>, + cx: &mut VisualTestContext, + ) -> Entity<Picker<SettingsProfileSelectorDelegate>> { + workspace.update(cx, |workspace, cx| { + workspace + .active_modal::<SettingsProfileSelector>(cx) + .expect("settings profile selector is not open") + .read(cx) + .picker + .clone() + }) + } + + #[gpui::test] + async fn test_settings_profile_selector_state(cx: &mut TestAppContext) { + let classroom_and_streaming_profile_name = "Classroom / Streaming".to_string(); + let demo_videos_profile_name = "Demo Videos".to_string(); + + let profiles_json = json!({ + classroom_and_streaming_profile_name.clone(): { + "buffer_font_size": 20.0, + }, + demo_videos_profile_name.clone(): { + "buffer_font_size": 15.0 + } + }); + let (workspace, cx) = init_test(profiles_json.clone(), cx).await; + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.matches.len(), 3); + assert_eq!(picker.delegate.matches[0].string, display_name(&None)); + assert_eq!( + picker.delegate.matches[1].string, + classroom_and_streaming_profile_name + ); + assert_eq!(picker.delegate.matches[2].string, demo_videos_profile_name); + assert_eq!(picker.delegate.matches.get(3), None); + + assert_eq!(picker.delegate.selected_index, 0); + assert_eq!(picker.delegate.selected_profile_name, None); + + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(Cancel); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(Cancel); + + cx.update(|_, cx| { + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 0); + assert_eq!(picker.delegate.selected_profile_name, None); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + None + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + } +} diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index 25f033469d..a4c47081c6 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -30,7 +30,6 @@ menu.workspace = true notifications.workspace = true paths.workspace = true project.workspace = true -schemars.workspace = true search.workspace = true serde.workspace = true serde_json.workspace = true @@ -48,3 +47,7 @@ workspace.workspace = true [dev-dependencies] db = {"workspace"= true, "features" = ["test-support"]} +fs = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_ui/src/keybindings.rs b/crates/settings_ui/src/keybindings.rs index a0cbdb9680..70afe1729c 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/settings_ui/src/keybindings.rs @@ -11,11 +11,10 @@ use editor::{CompletionProvider, Editor, EditorEvent}; use fs::Fs; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ - Action, Animation, AnimationExt, AppContext as _, AsyncApp, Axis, ClickEvent, Context, - DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, Global, IsZero, - KeyContext, Keystroke, Modifiers, ModifiersChangedEvent, MouseButton, Point, ScrollStrategy, - ScrollWheelEvent, Stateful, StyledText, Subscription, Task, TextStyleRefinement, WeakEntity, - actions, anchored, deferred, div, + Action, AppContext as _, AsyncApp, Axis, ClickEvent, Context, DismissEvent, Entity, + EventEmitter, FocusHandle, Focusable, Global, IsZero, KeyContext, Keystroke, MouseButton, + Point, ScrollStrategy, ScrollWheelEvent, Stateful, StyledText, Subscription, Task, + TextStyleRefinement, WeakEntity, actions, anchored, deferred, div, }; use language::{Language, LanguageConfig, ToOffset as _}; use notifications::status_toast::{StatusToast, ToastIcon}; @@ -35,7 +34,10 @@ use workspace::{ use crate::{ keybindings::persistence::KEYBINDING_EDITORS, - ui_components::table::{ColumnWidths, ResizeBehavior, Table, TableInteractionState}, + ui_components::{ + keystroke_input::{ClearKeystrokes, KeystrokeInput, StartRecording, StopRecording}, + table::{ColumnWidths, ResizeBehavior, Table, TableInteractionState}, + }, }; const NO_ACTION_ARGUMENTS_TEXT: SharedString = SharedString::new_static("<no arguments>"); @@ -72,18 +74,6 @@ actions!( ] ); -actions!( - keystroke_input, - [ - /// Starts recording keystrokes - StartRecording, - /// Stops recording keystrokes - StopRecording, - /// Clears the recorded keystrokes - ClearKeystrokes, - ] -); - pub fn init(cx: &mut App) { let keymap_event_channel = KeymapEventChannel::new(); cx.set_global(keymap_event_channel); @@ -393,7 +383,7 @@ impl KeymapEditor { let keystroke_editor = cx.new(|cx| { let mut keystroke_editor = KeystrokeInput::new(None, window, cx); - keystroke_editor.search = true; + keystroke_editor.set_search(true); keystroke_editor }); @@ -566,24 +556,40 @@ impl KeymapEditor { && query.modifiers == keystroke.modifiers }, ) + } else if keystroke_query.len() > keystrokes.len() { + return false; } else { - let key_press_query = - KeyPressIterator::new(keystroke_query.as_slice()); - let mut last_match_idx = 0; + for keystroke_offset in 0..keystrokes.len() { + let mut found_count = 0; + let mut query_cursor = 0; + let mut keystroke_cursor = keystroke_offset; + while query_cursor < keystroke_query.len() + && keystroke_cursor < keystrokes.len() + { + let query = &keystroke_query[query_cursor]; + let keystroke = &keystrokes[keystroke_cursor]; + let matches = + query.modifiers.is_subset_of(&keystroke.modifiers) + && ((query.key.is_empty() + || query.key == keystroke.key) + && query + .key_char + .as_ref() + .map_or(true, |q_kc| { + q_kc == &keystroke.key + })); + if matches { + found_count += 1; + query_cursor += 1; + } + keystroke_cursor += 1; + } - key_press_query.into_iter().all(|key| { - let key_presses = KeyPressIterator::new(keystrokes); - key_presses.into_iter().enumerate().any( - |(index, keystroke)| { - if last_match_idx > index || keystroke != key { - return false; - } - - last_match_idx = index; - true - }, - ) - }) + if found_count == keystroke_query.len() { + return true; + } + } + return false; } }) }); @@ -1232,11 +1238,14 @@ impl KeymapEditor { match self.search_mode { SearchMode::KeyStroke { .. } => { - window.focus(&self.keystroke_editor.read(cx).recording_focus_handle(cx)); + self.keystroke_editor.update(cx, |editor, cx| { + editor.start_recording(&StartRecording, window, cx); + }); } SearchMode::Normal => { self.keystroke_editor.update(cx, |editor, cx| { - editor.clear_keystrokes(&ClearKeystrokes, window, cx) + editor.stop_recording(&StopRecording, window, cx); + editor.clear_keystrokes(&ClearKeystrokes, window, cx); }); window.focus(&self.filter_editor.focus_handle(cx)); } @@ -1671,7 +1680,7 @@ impl Render for KeymapEditor { move |window, cx| this.read(cx).render_no_matches_hint(window, cx) }) .column_widths([ - DefiniteLength::Absolute(AbsoluteLength::Pixels(px(40.))), + DefiniteLength::Absolute(AbsoluteLength::Pixels(px(36.))), DefiniteLength::Fraction(0.25), DefiniteLength::Fraction(0.20), DefiniteLength::Fraction(0.14), @@ -1746,6 +1755,7 @@ impl Render for KeymapEditor { }, ) .into_any_element(); + let keystrokes = binding.ui_key_binding().cloned().map_or( binding .keystroke_text() @@ -1754,6 +1764,7 @@ impl Render for KeymapEditor { .into_any_element(), IntoElement::into_any_element, ); + let action_arguments = match binding.action().arguments.clone() { Some(arguments) => arguments.into_any_element(), @@ -1766,6 +1777,7 @@ impl Render for KeymapEditor { } } }; + let context = binding.context().cloned().map_or( gpui::Empty.into_any_element(), |context| { @@ -1790,11 +1802,13 @@ impl Render for KeymapEditor { .into_any_element() }, ); + let source = binding .keybind_source() .map(|source| source.name()) .unwrap_or_default() .into_any_element(); + Some([ icon.into_any_element(), action, @@ -2955,516 +2969,6 @@ async fn remove_keybinding( Ok(()) } -#[derive(PartialEq, Eq, Debug, Copy, Clone)] -enum CloseKeystrokeResult { - Partial, - Close, - None, -} - -#[derive(PartialEq, Eq, Debug, Clone)] -enum KeyPress<'a> { - Alt, - Control, - Function, - Shift, - Platform, - Key(&'a String), -} - -struct KeystrokeInput { - keystrokes: Vec<Keystroke>, - placeholder_keystrokes: Option<Vec<Keystroke>>, - outer_focus_handle: FocusHandle, - inner_focus_handle: FocusHandle, - intercept_subscription: Option<Subscription>, - _focus_subscriptions: [Subscription; 2], - search: bool, - /// Handles tripe escape to stop recording - close_keystrokes: Option<Vec<Keystroke>>, - close_keystrokes_start: Option<usize>, -} - -impl KeystrokeInput { - const KEYSTROKE_COUNT_MAX: usize = 3; - - fn new( - placeholder_keystrokes: Option<Vec<Keystroke>>, - window: &mut Window, - cx: &mut Context<Self>, - ) -> Self { - let outer_focus_handle = cx.focus_handle(); - let inner_focus_handle = cx.focus_handle(); - let _focus_subscriptions = [ - cx.on_focus_in(&inner_focus_handle, window, Self::on_inner_focus_in), - cx.on_focus_out(&inner_focus_handle, window, Self::on_inner_focus_out), - ]; - Self { - keystrokes: Vec::new(), - placeholder_keystrokes, - inner_focus_handle, - outer_focus_handle, - intercept_subscription: None, - _focus_subscriptions, - search: false, - close_keystrokes: None, - close_keystrokes_start: None, - } - } - - fn set_keystrokes(&mut self, keystrokes: Vec<Keystroke>, cx: &mut Context<Self>) { - self.keystrokes = keystrokes; - self.keystrokes_changed(cx); - } - - fn dummy(modifiers: Modifiers) -> Keystroke { - return Keystroke { - modifiers, - key: "".to_string(), - key_char: None, - }; - } - - fn keystrokes_changed(&self, cx: &mut Context<Self>) { - cx.emit(()); - cx.notify(); - } - - fn key_context() -> KeyContext { - let mut key_context = KeyContext::new_with_defaults(); - key_context.add("KeystrokeInput"); - key_context - } - - fn handle_possible_close_keystroke( - &mut self, - keystroke: &Keystroke, - window: &mut Window, - cx: &mut Context<Self>, - ) -> CloseKeystrokeResult { - let Some(keybind_for_close_action) = window - .highest_precedence_binding_for_action_in_context(&StopRecording, Self::key_context()) - else { - log::trace!("No keybinding to stop recording keystrokes in keystroke input"); - self.close_keystrokes.take(); - self.close_keystrokes_start.take(); - return CloseKeystrokeResult::None; - }; - let action_keystrokes = keybind_for_close_action.keystrokes(); - - if let Some(mut close_keystrokes) = self.close_keystrokes.take() { - let mut index = 0; - - while index < action_keystrokes.len() && index < close_keystrokes.len() { - if !close_keystrokes[index].should_match(&action_keystrokes[index]) { - break; - } - index += 1; - } - if index == close_keystrokes.len() { - if index >= action_keystrokes.len() { - self.close_keystrokes_start.take(); - return CloseKeystrokeResult::None; - } - if keystroke.should_match(&action_keystrokes[index]) { - if action_keystrokes.len() >= 1 && index == action_keystrokes.len() - 1 { - self.stop_recording(&StopRecording, window, cx); - return CloseKeystrokeResult::Close; - } else { - close_keystrokes.push(keystroke.clone()); - self.close_keystrokes = Some(close_keystrokes); - return CloseKeystrokeResult::Partial; - } - } else { - self.close_keystrokes_start.take(); - return CloseKeystrokeResult::None; - } - } - } else if let Some(first_action_keystroke) = action_keystrokes.first() - && keystroke.should_match(first_action_keystroke) - { - self.close_keystrokes = Some(vec![keystroke.clone()]); - return CloseKeystrokeResult::Partial; - } - self.close_keystrokes_start.take(); - return CloseKeystrokeResult::None; - } - - fn on_modifiers_changed( - &mut self, - event: &ModifiersChangedEvent, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - let keystrokes_len = self.keystrokes.len(); - - if let Some(last) = self.keystrokes.last_mut() - && last.key.is_empty() - && keystrokes_len <= Self::KEYSTROKE_COUNT_MAX - { - if self.search { - last.modifiers = last.modifiers.xor(&event.modifiers); - } else if !event.modifiers.modified() { - self.keystrokes.pop(); - } else { - last.modifiers = event.modifiers; - } - - self.keystrokes_changed(cx); - } else if keystrokes_len < Self::KEYSTROKE_COUNT_MAX { - self.keystrokes.push(Self::dummy(event.modifiers)); - self.keystrokes_changed(cx); - } - cx.stop_propagation(); - } - - fn handle_keystroke( - &mut self, - keystroke: &Keystroke, - window: &mut Window, - cx: &mut Context<Self>, - ) { - let close_keystroke_result = self.handle_possible_close_keystroke(keystroke, window, cx); - if close_keystroke_result != CloseKeystrokeResult::Close { - let key_len = self.keystrokes.len(); - if let Some(last) = self.keystrokes.last_mut() - && last.key.is_empty() - && key_len <= Self::KEYSTROKE_COUNT_MAX - { - if self.search { - last.key = keystroke.key.clone(); - if close_keystroke_result == CloseKeystrokeResult::Partial - && self.close_keystrokes_start.is_none() - { - self.close_keystrokes_start = Some(self.keystrokes.len() - 1); - } - self.keystrokes_changed(cx); - cx.stop_propagation(); - return; - } else { - self.keystrokes.pop(); - } - } - if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX { - if close_keystroke_result == CloseKeystrokeResult::Partial - && self.close_keystrokes_start.is_none() - { - self.close_keystrokes_start = Some(self.keystrokes.len()); - } - self.keystrokes.push(keystroke.clone()); - if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX { - self.keystrokes.push(Self::dummy(keystroke.modifiers)); - } - } else if close_keystroke_result != CloseKeystrokeResult::Partial { - self.clear_keystrokes(&ClearKeystrokes, window, cx); - } - } - self.keystrokes_changed(cx); - cx.stop_propagation(); - } - - fn on_inner_focus_in(&mut self, _window: &mut Window, cx: &mut Context<Self>) { - if self.intercept_subscription.is_none() { - let listener = cx.listener(|this, event: &gpui::KeystrokeEvent, window, cx| { - this.handle_keystroke(&event.keystroke, window, cx); - }); - self.intercept_subscription = Some(cx.intercept_keystrokes(listener)) - } - } - - fn on_inner_focus_out( - &mut self, - _event: gpui::FocusOutEvent, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - self.intercept_subscription.take(); - cx.notify(); - } - - fn keystrokes(&self) -> &[Keystroke] { - if let Some(placeholders) = self.placeholder_keystrokes.as_ref() - && self.keystrokes.is_empty() - { - return placeholders; - } - if !self.search - && self - .keystrokes - .last() - .map_or(false, |last| last.key.is_empty()) - { - return &self.keystrokes[..self.keystrokes.len() - 1]; - } - return &self.keystrokes; - } - - fn render_keystrokes(&self, is_recording: bool) -> impl Iterator<Item = Div> { - let keystrokes = if let Some(placeholders) = self.placeholder_keystrokes.as_ref() - && self.keystrokes.is_empty() - { - if is_recording { - &[] - } else { - placeholders.as_slice() - } - } else { - &self.keystrokes - }; - keystrokes.iter().map(move |keystroke| { - h_flex().children(ui::render_keystroke( - keystroke, - Some(Color::Default), - Some(rems(0.875).into()), - ui::PlatformStyle::platform(), - false, - )) - }) - } - - fn recording_focus_handle(&self, _cx: &App) -> FocusHandle { - self.inner_focus_handle.clone() - } - - fn start_recording(&mut self, _: &StartRecording, window: &mut Window, cx: &mut Context<Self>) { - if !self.outer_focus_handle.is_focused(window) { - return; - } - self.clear_keystrokes(&ClearKeystrokes, window, cx); - window.focus(&self.inner_focus_handle); - cx.notify(); - } - - fn stop_recording(&mut self, _: &StopRecording, window: &mut Window, cx: &mut Context<Self>) { - if !self.inner_focus_handle.is_focused(window) { - return; - } - window.focus(&self.outer_focus_handle); - if let Some(close_keystrokes_start) = self.close_keystrokes_start.take() - && close_keystrokes_start < self.keystrokes.len() - { - self.keystrokes.drain(close_keystrokes_start..); - } - self.close_keystrokes.take(); - cx.notify(); - } - - fn clear_keystrokes( - &mut self, - _: &ClearKeystrokes, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - self.keystrokes.clear(); - self.keystrokes_changed(cx); - } -} - -impl EventEmitter<()> for KeystrokeInput {} - -impl Focusable for KeystrokeInput { - fn focus_handle(&self, _cx: &App) -> FocusHandle { - self.outer_focus_handle.clone() - } -} - -impl Render for KeystrokeInput { - fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - let colors = cx.theme().colors(); - let is_focused = self.outer_focus_handle.contains_focused(window, cx); - let is_recording = self.inner_focus_handle.is_focused(window); - - let horizontal_padding = rems_from_px(64.); - - let recording_bg_color = colors - .editor_background - .blend(colors.text_accent.opacity(0.1)); - - let recording_pulse = |color: Color| { - Icon::new(IconName::Circle) - .size(IconSize::Small) - .color(Color::Error) - .with_animation( - "recording-pulse", - Animation::new(std::time::Duration::from_secs(2)) - .repeat() - .with_easing(gpui::pulsating_between(0.4, 0.8)), - { - let color = color.color(cx); - move |this, delta| this.color(Color::Custom(color.opacity(delta))) - }, - ) - }; - - let recording_indicator = h_flex() - .h_4() - .pr_1() - .gap_0p5() - .border_1() - .border_color(colors.border) - .bg(colors - .editor_background - .blend(colors.text_accent.opacity(0.1))) - .rounded_sm() - .child(recording_pulse(Color::Error)) - .child( - Label::new("REC") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Error), - ); - - let search_indicator = h_flex() - .h_4() - .pr_1() - .gap_0p5() - .border_1() - .border_color(colors.border) - .bg(colors - .editor_background - .blend(colors.text_accent.opacity(0.1))) - .rounded_sm() - .child(recording_pulse(Color::Accent)) - .child( - Label::new("SEARCH") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Accent), - ); - - let record_icon = if self.search { - IconName::MagnifyingGlass - } else { - IconName::PlayFilled - }; - - h_flex() - .id("keystroke-input") - .track_focus(&self.outer_focus_handle) - .py_2() - .px_3() - .gap_2() - .min_h_10() - .w_full() - .flex_1() - .justify_between() - .rounded_lg() - .overflow_hidden() - .map(|this| { - if is_recording { - this.bg(recording_bg_color) - } else { - this.bg(colors.editor_background) - } - }) - .border_1() - .border_color(colors.border_variant) - .when(is_focused, |parent| { - parent.border_color(colors.border_focused) - }) - .key_context(Self::key_context()) - .on_action(cx.listener(Self::start_recording)) - .on_action(cx.listener(Self::stop_recording)) - .child( - h_flex() - .w(horizontal_padding) - .gap_0p5() - .justify_start() - .flex_none() - .when(is_recording, |this| { - this.map(|this| { - if self.search { - this.child(search_indicator) - } else { - this.child(recording_indicator) - } - }) - }), - ) - .child( - h_flex() - .id("keystroke-input-inner") - .track_focus(&self.inner_focus_handle) - .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) - .size_full() - .when(!self.search, |this| { - this.focus(|mut style| { - style.border_color = Some(colors.border_focused); - style - }) - }) - .w_full() - .min_w_0() - .justify_center() - .flex_wrap() - .gap(ui::DynamicSpacing::Base04.rems(cx)) - .children(self.render_keystrokes(is_recording)), - ) - .child( - h_flex() - .w(horizontal_padding) - .gap_0p5() - .justify_end() - .flex_none() - .map(|this| { - if is_recording { - this.child( - IconButton::new("stop-record-btn", IconName::StopFilled) - .shape(ui::IconButtonShape::Square) - .map(|this| { - this.tooltip(Tooltip::for_action_title( - if self.search { - "Stop Searching" - } else { - "Stop Recording" - }, - &StopRecording, - )) - }) - .icon_color(Color::Error) - .on_click(cx.listener(|this, _event, window, cx| { - this.stop_recording(&StopRecording, window, cx); - })), - ) - } else { - this.child( - IconButton::new("record-btn", record_icon) - .shape(ui::IconButtonShape::Square) - .map(|this| { - this.tooltip(Tooltip::for_action_title( - if self.search { - "Start Searching" - } else { - "Start Recording" - }, - &StartRecording, - )) - }) - .when(!is_focused, |this| this.icon_color(Color::Muted)) - .on_click(cx.listener(|this, _event, window, cx| { - this.start_recording(&StartRecording, window, cx); - })), - ) - } - }) - .child( - IconButton::new("clear-btn", IconName::Delete) - .shape(ui::IconButtonShape::Square) - .tooltip(Tooltip::for_action_title( - "Clear Keystrokes", - &ClearKeystrokes, - )) - .when(!is_recording || !is_focused, |this| { - this.icon_color(Color::Muted) - }) - .on_click(cx.listener(|this, _event, window, cx| { - this.clear_keystrokes(&ClearKeystrokes, window, cx); - })), - ), - ) - } -} - fn collect_contexts_from_assets() -> Vec<SharedString> { let mut keymap_assets = vec![ util::asset_str::<SettingsAssets>(settings::DEFAULT_KEYMAP_PATH), @@ -3633,72 +3137,3 @@ mod persistence { } } } - -/// Iterator that yields KeyPress values from a slice of Keystrokes -struct KeyPressIterator<'a> { - keystrokes: &'a [Keystroke], - current_keystroke_index: usize, - current_key_press_index: usize, -} - -impl<'a> KeyPressIterator<'a> { - fn new(keystrokes: &'a [Keystroke]) -> Self { - Self { - keystrokes, - current_keystroke_index: 0, - current_key_press_index: 0, - } - } -} - -impl<'a> Iterator for KeyPressIterator<'a> { - type Item = KeyPress<'a>; - - fn next(&mut self) -> Option<Self::Item> { - loop { - let keystroke = self.keystrokes.get(self.current_keystroke_index)?; - - match self.current_key_press_index { - 0 => { - self.current_key_press_index = 1; - if keystroke.modifiers.platform { - return Some(KeyPress::Platform); - } - } - 1 => { - self.current_key_press_index = 2; - if keystroke.modifiers.alt { - return Some(KeyPress::Alt); - } - } - 2 => { - self.current_key_press_index = 3; - if keystroke.modifiers.control { - return Some(KeyPress::Control); - } - } - 3 => { - self.current_key_press_index = 4; - if keystroke.modifiers.shift { - return Some(KeyPress::Shift); - } - } - 4 => { - self.current_key_press_index = 5; - if keystroke.modifiers.function { - return Some(KeyPress::Function); - } - } - _ => { - self.current_keystroke_index += 1; - self.current_key_press_index = 0; - - if keystroke.key.is_empty() { - continue; - } - return Some(KeyPress::Key(&keystroke.key)); - } - } - } - } -} diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index 2f0abb4789..3022cc7142 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -1,20 +1,12 @@ mod appearance_settings_controls; use std::any::TypeId; -use std::sync::Arc; use command_palette_hooks::CommandPaletteFilter; use editor::EditorSettingsControls; use feature_flags::{FeatureFlag, FeatureFlagViewExt}; -use fs::Fs; -use gpui::{ - Action, App, AsyncWindowContext, Entity, EventEmitter, FocusHandle, Focusable, Task, actions, -}; -use schemars::JsonSchema; -use serde::Deserialize; -use settings::{SettingsStore, VsCodeSettingsSource}; +use gpui::{App, Entity, EventEmitter, FocusHandle, Focusable, actions}; use ui::prelude::*; -use util::truncate_and_remove_front; use workspace::item::{Item, ItemEvent}; use workspace::{Workspace, with_active_or_new_workspace}; @@ -29,23 +21,6 @@ impl FeatureFlag for SettingsUiFeatureFlag { const NAME: &'static str = "settings-ui"; } -/// Imports settings from Visual Studio Code. -#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportVsCodeSettings { - #[serde(default)] - pub skip_prompt: bool, -} - -/// Imports settings from Cursor editor. -#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportCursorSettings { - #[serde(default)] - pub skip_prompt: bool, -} actions!( zed, [ @@ -72,45 +47,11 @@ pub fn init(cx: &mut App) { }); }); - cx.observe_new(|workspace: &mut Workspace, window, cx| { + cx.observe_new(|_workspace: &mut Workspace, window, cx| { let Some(window) = window else { return; }; - workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { - let fs = <dyn Fs>::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::VsCode, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - - workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { - let fs = <dyn Fs>::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::Cursor, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - let settings_ui_actions = [TypeId::of::<OpenSettingsEditor>()]; CommandPaletteFilter::update_global(cx, |filter, _cx| { @@ -138,57 +79,6 @@ pub fn init(cx: &mut App) { keybindings::init(cx); } -async fn handle_import_vscode_settings( - source: VsCodeSettingsSource, - skip_prompt: bool, - fs: Arc<dyn Fs>, - cx: &mut AsyncWindowContext, -) { - let vscode_settings = - match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { - Ok(vscode_settings) => vscode_settings, - Err(err) => { - log::error!("{err}"); - let _ = cx.prompt( - gpui::PromptLevel::Info, - &format!("Could not find or load a {source} settings file"), - None, - &["Ok"], - ); - return; - } - }; - - let prompt = if skip_prompt { - Task::ready(Some(0)) - } else { - let prompt = cx.prompt( - gpui::PromptLevel::Warning, - &format!( - "Importing {} settings may overwrite your existing settings. \ - Will import settings from {}", - vscode_settings.source, - truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), - ), - None, - &["Ok", "Cancel"], - ); - cx.spawn(async move |_| prompt.await.ok()) - }; - if prompt.await != Some(0) { - return; - } - - cx.update(|_, cx| { - let source = vscode_settings.source; - let path = vscode_settings.path.clone(); - cx.global::<SettingsStore>() - .import_vscode_settings(fs, vscode_settings); - log::info!("Imported {source} settings from {}", path.display()); - }) - .ok(); -} - pub struct SettingsPage { focus_handle: FocusHandle, } diff --git a/crates/settings_ui/src/ui_components/keystroke_input.rs b/crates/settings_ui/src/ui_components/keystroke_input.rs new file mode 100644 index 0000000000..03d27d0ab9 --- /dev/null +++ b/crates/settings_ui/src/ui_components/keystroke_input.rs @@ -0,0 +1,1388 @@ +use gpui::{ + Animation, AnimationExt, Context, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, + Keystroke, Modifiers, ModifiersChangedEvent, Subscription, Task, actions, +}; +use ui::{ + ActiveTheme as _, Color, IconButton, IconButtonShape, IconName, IconSize, Label, LabelSize, + ParentElement as _, Render, Styled as _, Tooltip, Window, prelude::*, +}; + +actions!( + keystroke_input, + [ + /// Starts recording keystrokes + StartRecording, + /// Stops recording keystrokes + StopRecording, + /// Clears the recorded keystrokes + ClearKeystrokes, + ] +); + +const KEY_CONTEXT_VALUE: &'static str = "KeystrokeInput"; + +const CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT: std::time::Duration = + std::time::Duration::from_millis(300); + +enum CloseKeystrokeResult { + Partial, + Close, + None, +} + +impl PartialEq for CloseKeystrokeResult { + fn eq(&self, other: &Self) -> bool { + matches!( + (self, other), + (CloseKeystrokeResult::Partial, CloseKeystrokeResult::Partial) + | (CloseKeystrokeResult::Close, CloseKeystrokeResult::Close) + | (CloseKeystrokeResult::None, CloseKeystrokeResult::None) + ) + } +} + +pub struct KeystrokeInput { + keystrokes: Vec<Keystroke>, + placeholder_keystrokes: Option<Vec<Keystroke>>, + outer_focus_handle: FocusHandle, + inner_focus_handle: FocusHandle, + intercept_subscription: Option<Subscription>, + _focus_subscriptions: [Subscription; 2], + search: bool, + /// The sequence of close keystrokes being typed + close_keystrokes: Option<Vec<Keystroke>>, + close_keystrokes_start: Option<usize>, + previous_modifiers: Modifiers, + /// In order to support inputting keystrokes that end with a prefix of the + /// close keybind keystrokes, we clear the close keystroke capture info + /// on a timeout after a close keystroke is pressed + /// + /// e.g. if close binding is `esc esc esc` and user wants to search for + /// `ctrl-g esc`, after entering the `ctrl-g esc`, hitting `esc` twice would + /// stop recording because of the sequence of three escapes making it + /// impossible to search for anything ending in `esc` + clear_close_keystrokes_timer: Option<Task<()>>, + #[cfg(test)] + recording: bool, +} + +impl KeystrokeInput { + const KEYSTROKE_COUNT_MAX: usize = 3; + + pub fn new( + placeholder_keystrokes: Option<Vec<Keystroke>>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let outer_focus_handle = cx.focus_handle(); + let inner_focus_handle = cx.focus_handle(); + let _focus_subscriptions = [ + cx.on_focus_in(&inner_focus_handle, window, Self::on_inner_focus_in), + cx.on_focus_out(&inner_focus_handle, window, Self::on_inner_focus_out), + ]; + Self { + keystrokes: Vec::new(), + placeholder_keystrokes, + inner_focus_handle, + outer_focus_handle, + intercept_subscription: None, + _focus_subscriptions, + search: false, + close_keystrokes: None, + close_keystrokes_start: None, + previous_modifiers: Modifiers::default(), + clear_close_keystrokes_timer: None, + #[cfg(test)] + recording: false, + } + } + + pub fn set_keystrokes(&mut self, keystrokes: Vec<Keystroke>, cx: &mut Context<Self>) { + self.keystrokes = keystrokes; + self.keystrokes_changed(cx); + } + + pub fn set_search(&mut self, search: bool) { + self.search = search; + } + + pub fn keystrokes(&self) -> &[Keystroke] { + if let Some(placeholders) = self.placeholder_keystrokes.as_ref() + && self.keystrokes.is_empty() + { + return placeholders; + } + if !self.search + && self + .keystrokes + .last() + .map_or(false, |last| last.key.is_empty()) + { + return &self.keystrokes[..self.keystrokes.len() - 1]; + } + return &self.keystrokes; + } + + fn dummy(modifiers: Modifiers) -> Keystroke { + return Keystroke { + modifiers, + key: "".to_string(), + key_char: None, + }; + } + + fn keystrokes_changed(&self, cx: &mut Context<Self>) { + cx.emit(()); + cx.notify(); + } + + fn key_context() -> KeyContext { + let mut key_context = KeyContext::default(); + key_context.add(KEY_CONTEXT_VALUE); + key_context + } + + fn determine_stop_recording_binding(window: &mut Window) -> Option<gpui::KeyBinding> { + if cfg!(test) { + Some(gpui::KeyBinding::new( + "escape escape escape", + StopRecording, + Some(KEY_CONTEXT_VALUE), + )) + } else { + window.highest_precedence_binding_for_action_in_context( + &StopRecording, + Self::key_context(), + ) + } + } + + fn upsert_close_keystrokes_start(&mut self, start: usize, cx: &mut Context<Self>) { + if self.close_keystrokes_start.is_some() { + return; + } + self.close_keystrokes_start = Some(start); + self.update_clear_close_keystrokes_timer(cx); + } + + fn update_clear_close_keystrokes_timer(&mut self, cx: &mut Context<Self>) { + self.clear_close_keystrokes_timer = Some(cx.spawn(async |this, cx| { + cx.background_executor() + .timer(CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT) + .await; + this.update(cx, |this, _cx| { + this.end_close_keystrokes_capture(); + }) + .ok(); + })); + } + + /// Interrupt the capture of close keystrokes, but do not clear the close keystrokes + /// from the input + fn end_close_keystrokes_capture(&mut self) -> Option<usize> { + self.close_keystrokes.take(); + self.clear_close_keystrokes_timer.take(); + return self.close_keystrokes_start.take(); + } + + fn handle_possible_close_keystroke( + &mut self, + keystroke: &Keystroke, + window: &mut Window, + cx: &mut Context<Self>, + ) -> CloseKeystrokeResult { + let Some(keybind_for_close_action) = Self::determine_stop_recording_binding(window) else { + log::trace!("No keybinding to stop recording keystrokes in keystroke input"); + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + }; + let action_keystrokes = keybind_for_close_action.keystrokes(); + + if let Some(mut close_keystrokes) = self.close_keystrokes.take() { + let mut index = 0; + + while index < action_keystrokes.len() && index < close_keystrokes.len() { + if !close_keystrokes[index].should_match(&action_keystrokes[index]) { + break; + } + index += 1; + } + if index == close_keystrokes.len() { + if index >= action_keystrokes.len() { + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + if keystroke.should_match(&action_keystrokes[index]) { + close_keystrokes.push(keystroke.clone()); + if close_keystrokes.len() == action_keystrokes.len() { + return CloseKeystrokeResult::Close; + } else { + self.close_keystrokes = Some(close_keystrokes); + self.update_clear_close_keystrokes_timer(cx); + return CloseKeystrokeResult::Partial; + } + } else { + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + } + } else if let Some(first_action_keystroke) = action_keystrokes.first() + && keystroke.should_match(first_action_keystroke) + { + self.close_keystrokes = Some(vec![keystroke.clone()]); + return CloseKeystrokeResult::Partial; + } + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + + fn on_modifiers_changed( + &mut self, + event: &ModifiersChangedEvent, + window: &mut Window, + cx: &mut Context<Self>, + ) { + cx.stop_propagation(); + let keystrokes_len = self.keystrokes.len(); + + if self.previous_modifiers.modified() + && event.modifiers.is_subset_of(&self.previous_modifiers) + { + self.previous_modifiers &= event.modifiers; + return; + } + self.keystrokes_changed(cx); + + if let Some(last) = self.keystrokes.last_mut() + && last.key.is_empty() + && keystrokes_len <= Self::KEYSTROKE_COUNT_MAX + { + if !self.search && !event.modifiers.modified() { + self.keystrokes.pop(); + return; + } + if self.search { + if self.previous_modifiers.modified() { + last.modifiers |= event.modifiers; + } else { + self.keystrokes.push(Self::dummy(event.modifiers)); + } + self.previous_modifiers |= event.modifiers; + } else { + last.modifiers = event.modifiers; + return; + } + } else if keystrokes_len < Self::KEYSTROKE_COUNT_MAX { + self.keystrokes.push(Self::dummy(event.modifiers)); + if self.search { + self.previous_modifiers |= event.modifiers; + } + } + if keystrokes_len >= Self::KEYSTROKE_COUNT_MAX { + self.clear_keystrokes(&ClearKeystrokes, window, cx); + } + } + + fn handle_keystroke( + &mut self, + keystroke: &Keystroke, + window: &mut Window, + cx: &mut Context<Self>, + ) { + cx.stop_propagation(); + + let close_keystroke_result = self.handle_possible_close_keystroke(keystroke, window, cx); + if close_keystroke_result == CloseKeystrokeResult::Close { + self.stop_recording(&StopRecording, window, cx); + return; + } + + let mut keystroke = keystroke.clone(); + if let Some(last) = self.keystrokes.last() + && last.key.is_empty() + && (!self.search || self.previous_modifiers.modified()) + { + let key = keystroke.key.clone(); + keystroke = last.clone(); + keystroke.key = key; + self.keystrokes.pop(); + } + + if close_keystroke_result == CloseKeystrokeResult::Partial { + self.upsert_close_keystrokes_start(self.keystrokes.len(), cx); + if self.keystrokes.len() >= Self::KEYSTROKE_COUNT_MAX { + return; + } + } + + if self.keystrokes.len() >= Self::KEYSTROKE_COUNT_MAX { + self.clear_keystrokes(&ClearKeystrokes, window, cx); + return; + } + + self.keystrokes.push(keystroke.clone()); + self.keystrokes_changed(cx); + + if self.search { + self.previous_modifiers = keystroke.modifiers; + return; + } + if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX && keystroke.modifiers.modified() { + self.keystrokes.push(Self::dummy(keystroke.modifiers)); + } + } + + fn on_inner_focus_in(&mut self, _window: &mut Window, cx: &mut Context<Self>) { + if self.intercept_subscription.is_none() { + let listener = cx.listener(|this, event: &gpui::KeystrokeEvent, window, cx| { + this.handle_keystroke(&event.keystroke, window, cx); + }); + self.intercept_subscription = Some(cx.intercept_keystrokes(listener)) + } + } + + fn on_inner_focus_out( + &mut self, + _event: gpui::FocusOutEvent, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.intercept_subscription.take(); + cx.notify(); + } + + fn render_keystrokes(&self, is_recording: bool) -> impl Iterator<Item = Div> { + let keystrokes = if let Some(placeholders) = self.placeholder_keystrokes.as_ref() + && self.keystrokes.is_empty() + { + if is_recording { + &[] + } else { + placeholders.as_slice() + } + } else { + &self.keystrokes + }; + keystrokes.iter().map(move |keystroke| { + h_flex().children(ui::render_keystroke( + keystroke, + Some(Color::Default), + Some(rems(0.875).into()), + ui::PlatformStyle::platform(), + false, + )) + }) + } + + pub fn start_recording( + &mut self, + _: &StartRecording, + window: &mut Window, + cx: &mut Context<Self>, + ) { + window.focus(&self.inner_focus_handle); + self.clear_keystrokes(&ClearKeystrokes, window, cx); + self.previous_modifiers = window.modifiers(); + #[cfg(test)] + { + self.recording = true; + } + cx.stop_propagation(); + } + + pub fn stop_recording( + &mut self, + _: &StopRecording, + window: &mut Window, + cx: &mut Context<Self>, + ) { + if !self.is_recording(window) { + return; + } + window.focus(&self.outer_focus_handle); + if let Some(close_keystrokes_start) = self.close_keystrokes_start.take() + && close_keystrokes_start < self.keystrokes.len() + { + self.keystrokes.drain(close_keystrokes_start..); + self.keystrokes_changed(cx); + } + self.end_close_keystrokes_capture(); + #[cfg(test)] + { + self.recording = false; + } + cx.notify(); + } + + pub fn clear_keystrokes( + &mut self, + _: &ClearKeystrokes, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.keystrokes.clear(); + self.keystrokes_changed(cx); + self.end_close_keystrokes_capture(); + } + + fn is_recording(&self, window: &Window) -> bool { + #[cfg(test)] + { + if true { + // in tests, we just need a simple bool that is toggled on start and stop recording + return self.recording; + } + } + // however, in the real world, checking if the inner focus handle is focused + // is a much more reliable check, as the intercept keystroke handlers are installed + // on focus of the inner focus handle, thereby ensuring our recording state does + // not get de-synced + return self.inner_focus_handle.is_focused(window); + } +} + +impl EventEmitter<()> for KeystrokeInput {} + +impl Focusable for KeystrokeInput { + fn focus_handle(&self, _cx: &gpui::App) -> FocusHandle { + self.outer_focus_handle.clone() + } +} + +impl Render for KeystrokeInput { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let colors = cx.theme().colors(); + let is_focused = self.outer_focus_handle.contains_focused(window, cx); + let is_recording = self.is_recording(window); + + let horizontal_padding = rems_from_px(64.); + + let recording_bg_color = colors + .editor_background + .blend(colors.text_accent.opacity(0.1)); + + let recording_pulse = |color: Color| { + Icon::new(IconName::Circle) + .size(IconSize::Small) + .color(Color::Error) + .with_animation( + "recording-pulse", + Animation::new(std::time::Duration::from_secs(2)) + .repeat() + .with_easing(gpui::pulsating_between(0.4, 0.8)), + { + let color = color.color(cx); + move |this, delta| this.color(Color::Custom(color.opacity(delta))) + }, + ) + }; + + let recording_indicator = h_flex() + .h_4() + .pr_1() + .gap_0p5() + .border_1() + .border_color(colors.border) + .bg(colors + .editor_background + .blend(colors.text_accent.opacity(0.1))) + .rounded_sm() + .child(recording_pulse(Color::Error)) + .child( + Label::new("REC") + .size(LabelSize::XSmall) + .weight(FontWeight::SEMIBOLD) + .color(Color::Error), + ); + + let search_indicator = h_flex() + .h_4() + .pr_1() + .gap_0p5() + .border_1() + .border_color(colors.border) + .bg(colors + .editor_background + .blend(colors.text_accent.opacity(0.1))) + .rounded_sm() + .child(recording_pulse(Color::Accent)) + .child( + Label::new("SEARCH") + .size(LabelSize::XSmall) + .weight(FontWeight::SEMIBOLD) + .color(Color::Accent), + ); + + let record_icon = if self.search { + IconName::MagnifyingGlass + } else { + IconName::PlayFilled + }; + + h_flex() + .id("keystroke-input") + .track_focus(&self.outer_focus_handle) + .py_2() + .px_3() + .gap_2() + .min_h_10() + .w_full() + .flex_1() + .justify_between() + .rounded_lg() + .overflow_hidden() + .map(|this| { + if is_recording { + this.bg(recording_bg_color) + } else { + this.bg(colors.editor_background) + } + }) + .border_1() + .border_color(colors.border_variant) + .when(is_focused, |parent| { + parent.border_color(colors.border_focused) + }) + .key_context(Self::key_context()) + .on_action(cx.listener(Self::start_recording)) + .on_action(cx.listener(Self::clear_keystrokes)) + .child( + h_flex() + .w(horizontal_padding) + .gap_0p5() + .justify_start() + .flex_none() + .when(is_recording, |this| { + this.map(|this| { + if self.search { + this.child(search_indicator) + } else { + this.child(recording_indicator) + } + }) + }), + ) + .child( + h_flex() + .id("keystroke-input-inner") + .track_focus(&self.inner_focus_handle) + .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) + .size_full() + .when(!self.search, |this| { + this.focus(|mut style| { + style.border_color = Some(colors.border_focused); + style + }) + }) + .w_full() + .min_w_0() + .justify_center() + .flex_wrap() + .gap(ui::DynamicSpacing::Base04.rems(cx)) + .children(self.render_keystrokes(is_recording)), + ) + .child( + h_flex() + .w(horizontal_padding) + .gap_0p5() + .justify_end() + .flex_none() + .map(|this| { + if is_recording { + this.child( + IconButton::new("stop-record-btn", IconName::StopFilled) + .shape(IconButtonShape::Square) + .map(|this| { + this.tooltip(Tooltip::for_action_title( + if self.search { + "Stop Searching" + } else { + "Stop Recording" + }, + &StopRecording, + )) + }) + .icon_color(Color::Error) + .on_click(cx.listener(|this, _event, window, cx| { + this.stop_recording(&StopRecording, window, cx); + })), + ) + } else { + this.child( + IconButton::new("record-btn", record_icon) + .shape(IconButtonShape::Square) + .map(|this| { + this.tooltip(Tooltip::for_action_title( + if self.search { + "Start Searching" + } else { + "Start Recording" + }, + &StartRecording, + )) + }) + .when(!is_focused, |this| this.icon_color(Color::Muted)) + .on_click(cx.listener(|this, _event, window, cx| { + this.start_recording(&StartRecording, window, cx); + })), + ) + } + }) + .child( + IconButton::new("clear-btn", IconName::Delete) + .shape(IconButtonShape::Square) + .tooltip(Tooltip::for_action_title( + "Clear Keystrokes", + &ClearKeystrokes, + )) + .when(!is_recording || !is_focused, |this| { + this.icon_color(Color::Muted) + }) + .on_click(cx.listener(|this, _event, window, cx| { + this.clear_keystrokes(&ClearKeystrokes, window, cx); + })), + ), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use fs::FakeFs; + use gpui::{Entity, TestAppContext, VisualTestContext}; + use itertools::Itertools as _; + use project::Project; + use settings::SettingsStore; + use workspace::Workspace; + + pub struct KeystrokeInputTestHelper { + input: Entity<KeystrokeInput>, + current_modifiers: Modifiers, + cx: VisualTestContext, + } + + impl KeystrokeInputTestHelper { + /// Creates a new test helper with default settings + pub fn new(mut cx: VisualTestContext) -> Self { + let input = cx.new_window_entity(|window, cx| KeystrokeInput::new(None, window, cx)); + + let mut helper = Self { + input, + current_modifiers: Modifiers::default(), + cx, + }; + + helper.start_recording(); + helper + } + + /// Sets search mode on the input + pub fn with_search_mode(&mut self, search: bool) -> &mut Self { + self.input.update(&mut self.cx, |input, _| { + input.set_search(search); + }); + self + } + + /// Sends a keystroke event based on string description + /// Examples: "a", "ctrl-a", "cmd-shift-z", "escape" + #[track_caller] + pub fn send_keystroke(&mut self, keystroke_input: &str) -> &mut Self { + self.expect_is_recording(true); + let keystroke_str = if keystroke_input.ends_with('-') { + format!("{}_", keystroke_input) + } else { + keystroke_input.to_string() + }; + + let mut keystroke = Keystroke::parse(&keystroke_str) + .unwrap_or_else(|_| panic!("Invalid keystroke: {}", keystroke_input)); + + // Remove the dummy key if we added it for modifier-only keystrokes + if keystroke_input.ends_with('-') && keystroke_str.ends_with("_") { + keystroke.key = "".to_string(); + } + + // Combine current modifiers with keystroke modifiers + keystroke.modifiers |= self.current_modifiers; + + self.update_input(|input, window, cx| { + input.handle_keystroke(&keystroke, window, cx); + }); + + // Don't update current_modifiers for keystrokes with actual keys + if keystroke.key.is_empty() { + self.current_modifiers = keystroke.modifiers; + } + self + } + + /// Sends a modifier change event based on string description + /// Examples: "+ctrl", "-ctrl", "+cmd+shift", "-all" + #[track_caller] + pub fn send_modifiers(&mut self, modifiers: &str) -> &mut Self { + self.expect_is_recording(true); + let new_modifiers = if modifiers == "-all" { + Modifiers::default() + } else { + self.parse_modifier_change(modifiers) + }; + + let event = ModifiersChangedEvent { + modifiers: new_modifiers, + capslock: gpui::Capslock::default(), + }; + + self.update_input(|input, window, cx| { + input.on_modifiers_changed(&event, window, cx); + }); + + self.current_modifiers = new_modifiers; + self + } + + /// Sends multiple events in sequence + /// Each event string is either a keystroke or modifier change + #[track_caller] + pub fn send_events(&mut self, events: &[&str]) -> &mut Self { + self.expect_is_recording(true); + for event in events { + if event.starts_with('+') || event.starts_with('-') { + self.send_modifiers(event); + } else { + self.send_keystroke(event); + } + } + self + } + + #[track_caller] + fn expect_keystrokes_equal(actual: &[Keystroke], expected: &[&str]) { + let expected_keystrokes: Result<Vec<Keystroke>, _> = expected + .iter() + .map(|s| { + let keystroke_str = if s.ends_with('-') { + format!("{}_", s) + } else { + s.to_string() + }; + + let mut keystroke = Keystroke::parse(&keystroke_str)?; + + // Remove the dummy key if we added it for modifier-only keystrokes + if s.ends_with('-') && keystroke_str.ends_with("_") { + keystroke.key = "".to_string(); + } + + Ok(keystroke) + }) + .collect(); + + let expected_keystrokes = expected_keystrokes + .unwrap_or_else(|e: anyhow::Error| panic!("Invalid expected keystroke: {}", e)); + + assert_eq!( + actual.len(), + expected_keystrokes.len(), + "Keystroke count mismatch. Expected: {:?}, Actual: {:?}", + expected_keystrokes + .iter() + .map(|k| k.unparse()) + .collect::<Vec<_>>(), + actual.iter().map(|k| k.unparse()).collect::<Vec<_>>() + ); + + for (i, (actual, expected)) in actual.iter().zip(expected_keystrokes.iter()).enumerate() + { + assert_eq!( + actual.unparse(), + expected.unparse(), + "Keystroke {} mismatch. Expected: '{}', Actual: '{}'", + i, + expected.unparse(), + actual.unparse() + ); + } + } + + /// Verifies that the keystrokes match the expected strings + #[track_caller] + pub fn expect_keystrokes(&mut self, expected: &[&str]) -> &mut Self { + let actual = self + .input + .read_with(&mut self.cx, |input, _| input.keystrokes.clone()); + Self::expect_keystrokes_equal(&actual, expected); + self + } + + #[track_caller] + pub fn expect_close_keystrokes(&mut self, expected: &[&str]) -> &mut Self { + let actual = self + .input + .read_with(&mut self.cx, |input, _| input.close_keystrokes.clone()) + .unwrap_or_default(); + Self::expect_keystrokes_equal(&actual, expected); + self + } + + /// Verifies that there are no keystrokes + #[track_caller] + pub fn expect_empty(&mut self) -> &mut Self { + self.expect_keystrokes(&[]) + } + + /// Starts recording keystrokes + #[track_caller] + pub fn start_recording(&mut self) -> &mut Self { + self.expect_is_recording(false); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.start_recording(&StartRecording, window, cx); + }); + self + } + + /// Stops recording keystrokes + pub fn stop_recording(&mut self) -> &mut Self { + self.expect_is_recording(true); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.stop_recording(&StopRecording, window, cx); + }); + self + } + + /// Clears all keystrokes + #[track_caller] + pub fn clear_keystrokes(&mut self) -> &mut Self { + let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.clear_keystrokes(&ClearKeystrokes, window, cx); + }); + KeystrokeUpdateTracker::finish(change_tracker, &self.cx); + self.current_modifiers = Default::default(); + self + } + + /// Verifies the recording state + #[track_caller] + pub fn expect_is_recording(&mut self, expected: bool) -> &mut Self { + let actual = self + .input + .update_in(&mut self.cx, |input, window, _| input.is_recording(window)); + assert_eq!( + actual, expected, + "Recording state mismatch. Expected: {}, Actual: {}", + expected, actual + ); + self + } + + pub async fn wait_for_close_keystroke_capture_end(&mut self) -> &mut Self { + let task = self.input.update_in(&mut self.cx, |input, _, _| { + input.clear_close_keystrokes_timer.take() + }); + let task = task.expect("No close keystroke capture end timer task"); + self.cx + .executor() + .advance_clock(CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT); + task.await; + self + } + + /// Parses modifier change strings like "+ctrl", "-shift", "+cmd+alt" + #[track_caller] + fn parse_modifier_change(&self, modifiers_str: &str) -> Modifiers { + let mut modifiers = self.current_modifiers; + + assert!(!modifiers_str.is_empty(), "Empty modifier string"); + + let value; + let split_char; + let remaining; + if let Some(to_add) = modifiers_str.strip_prefix('+') { + value = true; + split_char = '+'; + remaining = to_add; + } else { + let to_remove = modifiers_str + .strip_prefix('-') + .expect("Modifier string must start with '+' or '-'"); + value = false; + split_char = '-'; + remaining = to_remove; + } + + for modifier in remaining.split(split_char) { + match modifier { + "ctrl" | "control" => modifiers.control = value, + "alt" | "option" => modifiers.alt = value, + "shift" => modifiers.shift = value, + "cmd" | "command" | "platform" => modifiers.platform = value, + "fn" | "function" => modifiers.function = value, + _ => panic!("Unknown modifier: {}", modifier), + } + } + + modifiers + } + + #[track_caller] + fn update_input<R>( + &mut self, + cb: impl FnOnce(&mut KeystrokeInput, &mut Window, &mut Context<KeystrokeInput>) -> R, + ) -> R { + let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); + let result = self.input.update_in(&mut self.cx, cb); + KeystrokeUpdateTracker::finish(change_tracker, &self.cx); + return result; + } + } + + struct KeystrokeUpdateTracker { + initial_keystrokes: Vec<Keystroke>, + _subscription: Subscription, + input: Entity<KeystrokeInput>, + received_keystrokes_updated: bool, + } + + impl KeystrokeUpdateTracker { + fn new(input: Entity<KeystrokeInput>, cx: &mut VisualTestContext) -> Entity<Self> { + cx.new(|cx| Self { + initial_keystrokes: input.read_with(cx, |input, _| input.keystrokes.clone()), + _subscription: cx.subscribe(&input, |this: &mut Self, _, _, _| { + this.received_keystrokes_updated = true; + }), + input, + received_keystrokes_updated: false, + }) + } + #[track_caller] + fn finish(this: Entity<Self>, cx: &VisualTestContext) { + let (received_keystrokes_updated, initial_keystrokes_str, updated_keystrokes_str) = + this.read_with(cx, |this, cx| { + let updated_keystrokes = this + .input + .read_with(cx, |input, _| input.keystrokes.clone()); + let initial_keystrokes_str = keystrokes_str(&this.initial_keystrokes); + let updated_keystrokes_str = keystrokes_str(&updated_keystrokes); + ( + this.received_keystrokes_updated, + initial_keystrokes_str, + updated_keystrokes_str, + ) + }); + if received_keystrokes_updated { + assert_ne!( + initial_keystrokes_str, updated_keystrokes_str, + "Received keystrokes_updated event, expected different keystrokes" + ); + } else { + assert_eq!( + initial_keystrokes_str, updated_keystrokes_str, + "Received no keystrokes_updated event, expected same keystrokes" + ); + } + + fn keystrokes_str(ks: &[Keystroke]) -> String { + ks.iter().map(|ks| ks.unparse()).join(" ") + } + } + } + + async fn init_test(cx: &mut TestAppContext) -> KeystrokeInputTestHelper { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + theme::init(theme::LoadThemes::JustBase, cx); + language::init(cx); + project::Project::init_settings(cx); + workspace::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let workspace = + cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = VisualTestContext::from_window(*workspace, cx); + KeystrokeInputTestHelper::new(cx) + } + + #[gpui::test] + async fn test_basic_keystroke_input(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("a") + .clear_keystrokes() + .expect_empty(); + } + + #[gpui::test] + async fn test_modifier_handling(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-a"]); + } + + #[gpui::test] + async fn test_multiple_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("cmd-shift-z") + .expect_keystrokes(&["cmd-shift-z", "cmd-shift-"]); + } + + #[gpui::test] + async fn test_search_mode_behavior(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+cmd", "shift-f", "-cmd"]) + // In search mode, when completing a modifier-only keystroke with a key, + // only the original modifiers are preserved, not the keystroke's modifiers + .expect_keystrokes(&["cmd-f"]); + } + + #[gpui::test] + async fn test_keystroke_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("a") + .send_keystroke("b") + .send_keystroke("c") + .expect_keystrokes(&["a", "b", "c"]) // At max limit + .send_keystroke("d") + .expect_empty(); // Should clear when exceeding limit + } + + #[gpui::test] + async fn test_modifier_release_all(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_search_new_modifiers_not_added_until_all_released(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a"]) + .send_events(&["+ctrl"]) + .expect_keystrokes(&["ctrl-shift-a", "ctrl-shift-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_no_effect_when_not_search(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl+shift", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_keystroke_limit_overflow_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["a", "b", "c", "d"]) // 4 keystrokes, exceeds limit of 3 + .expect_empty(); // Should clear when exceeding limit + } + + #[gpui::test] + async fn test_complex_modifier_sequences(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "+alt", "a", "-ctrl", "-shift", "-alt"]) + .expect_keystrokes(&["ctrl-shift-alt-a"]); + } + + #[gpui::test] + async fn test_modifier_only_keystrokes_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "-ctrl", "-shift"]) + .expect_keystrokes(&["ctrl-shift-"]); // Modifier-only sequences create modifier-only keystrokes + } + + #[gpui::test] + async fn test_modifier_only_keystrokes_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "+shift", "-ctrl", "-shift"]) + .expect_empty(); // Modifier-only sequences get filtered in non-search mode + } + + #[gpui::test] + async fn test_rapid_modifier_changes(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl", "+shift", "-shift", "+alt", "a", "-alt"]) + .expect_keystrokes(&["ctrl-", "shift-", "alt-a"]); + } + + #[gpui::test] + async fn test_clear_keystrokes_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "a", "-ctrl", "b"]) + .expect_keystrokes(&["ctrl-a", "b"]) + .clear_keystrokes() + .expect_empty(); + } + + #[gpui::test] + async fn test_non_search_mode_modifier_key_sequence(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "a"]) + .expect_keystrokes(&["ctrl-a", "ctrl-"]) + .send_events(&["-ctrl"]) + .expect_keystrokes(&["ctrl-a"]); // Non-search mode filters trailing empty keystrokes + } + + #[gpui::test] + async fn test_all_modifiers_at_once(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift+alt+cmd", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-alt-cmd-a"]); + } + + #[gpui::test] + async fn test_keystrokes_at_exact_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "b", "c"]) // exactly 3 keystrokes (at limit) + .expect_keystrokes(&["a", "b", "c"]) + .send_events(&["d"]) // should clear when exceeding + .expect_empty(); + } + + #[gpui::test] + async fn test_function_modifier_key(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+fn", "f1", "-fn"]) + .expect_keystrokes(&["fn-f1"]); + } + + #[gpui::test] + async fn test_start_stop_recording(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_events(&["a", "b"]) + .expect_keystrokes(&["a", "b"]) // start_recording clears existing keystrokes + .stop_recording() + .expect_is_recording(false) + .start_recording() + .send_events(&["c"]) + .expect_keystrokes(&["c"]); + } + + #[gpui::test] + async fn test_modifier_sequence_with_interruption(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "a", "-shift", "b", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a", "ctrl-b"]); + } + + #[gpui::test] + async fn test_empty_key_sequence_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&[]) // No events at all + .expect_empty(); + } + + #[gpui::test] + async fn test_modifier_sequence_completion_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "-shift", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_triple_escape_stops_recording_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]) // Triple escape removes final escape, stops recording + .expect_is_recording(false); + } + + #[gpui::test] + async fn test_triple_escape_stops_recording_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]); // Triple escape stops recording but only removes final escape + } + + #[gpui::test] + async fn test_triple_escape_at_keystroke_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "b", "c", "escape", "escape", "escape"]) // 6 keystrokes total, exceeds limit + .expect_keystrokes(&["a", "b", "c"]); // Triple escape stops recording and removes escapes, leaves original keystrokes + } + + #[gpui::test] + async fn test_interrupted_escape_sequence(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "a", "escape"]) // Partial escape sequence interrupted by 'a' + .expect_keystrokes(&["escape", "escape", "a"]); // Escape sequence interrupted by 'a', no close triggered + } + + #[gpui::test] + async fn test_interrupted_escape_sequence_within_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "a"]) // Partial escape sequence interrupted by 'a' (3 keystrokes, at limit) + .expect_keystrokes(&["escape", "escape", "a"]); // Should not trigger close, interruption resets escape detection + } + + #[gpui::test] + async fn test_partial_escape_sequence_no_close(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape"]) // Only 2 escapes, not enough to close + .expect_keystrokes(&["escape", "escape"]) + .expect_is_recording(true); // Should remain in keystrokes, no close triggered + } + + #[gpui::test] + async fn test_recording_state_after_triple_escape(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]) // Triple escape stops recording, removes final escape + .expect_is_recording(false); + } + + #[gpui::test] + async fn test_triple_escape_mixed_with_other_keystrokes(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "b", "escape", "escape"]) // Mixed sequence, should not trigger close + .expect_keystrokes(&["a", "escape", "b"]); // No complete triple escape sequence, stays at limit + } + + #[gpui::test] + async fn test_triple_escape_only(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "escape"]) // Pure triple escape sequence + .expect_empty(); + } + + #[gpui::test] + async fn test_end_close_keystroke_capture(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_events(&["+ctrl", "g", "-ctrl", "escape"]) + .expect_keystrokes(&["ctrl-g", "escape"]) + .wait_for_close_keystroke_capture_end() + .await + .send_events(&["escape", "escape"]) + .expect_keystrokes(&["ctrl-g", "escape", "escape"]) + .expect_close_keystrokes(&["escape", "escape"]) + .send_keystroke("escape") + .expect_keystrokes(&["ctrl-g", "escape"]); + } + + #[gpui::test] + async fn test_search_previous_modifiers_are_sticky(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "j"]) + .expect_keystrokes(&["ctrl-alt-j"]); + } + + #[gpui::test] + async fn test_previous_modifiers_can_be_entered_separately(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl"]) + .expect_keystrokes(&["ctrl-"]) + .send_events(&["+alt", "-alt"]) + .expect_keystrokes(&["ctrl-", "alt-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_reset_on_key(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) + .expect_keystrokes(&["ctrl-shift-alt-"]) + .send_keystroke("j") + .expect_keystrokes(&["ctrl-shift-alt-j"]) + .send_keystroke("i") + .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i"]) + .send_events(&["-shift-alt", "+cmd"]) + .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i", "cmd-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_reset_on_release_all(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) + .expect_keystrokes(&["ctrl-shift-alt-"]) + .send_events(&["-all", "j"]) + .expect_keystrokes(&["ctrl-shift-alt-", "j"]); + } + + #[gpui::test] + async fn test_search_repeat_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) + .expect_keystrokes(&["ctrl-", "alt-", "shift-"]) + .send_events(&["+cmd"]) + .expect_empty(); + } + + #[gpui::test] + async fn test_not_search_repeat_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) + .expect_empty(); + } +} diff --git a/crates/settings_ui/src/ui_components/mod.rs b/crates/settings_ui/src/ui_components/mod.rs index 13971b0a5d..5d6463a61a 100644 --- a/crates/settings_ui/src/ui_components/mod.rs +++ b/crates/settings_ui/src/ui_components/mod.rs @@ -1 +1,2 @@ +pub mod keystroke_input; pub mod table; diff --git a/crates/settings_ui/src/ui_components/table.rs b/crates/settings_ui/src/ui_components/table.rs index 65778c20eb..3c9992bd68 100644 --- a/crates/settings_ui/src/ui_components/table.rs +++ b/crates/settings_ui/src/ui_components/table.rs @@ -17,7 +17,7 @@ use ui::{ StyledTypography, Window, div, example_group_with_title, h_flex, px, single_example, v_flex, }; -const RESIZE_COLUMN_WIDTH: f32 = 5.0; +const RESIZE_COLUMN_WIDTH: f32 = 8.0; #[derive(Debug)] struct DraggedColumn(usize); @@ -214,6 +214,7 @@ impl TableInteractionState { let mut column_ix = 0; let resizable_columns_slice = *resizable_columns; let mut resizable_columns = resizable_columns.into_iter(); + let dividers = intersperse_with(spacers, || { window.with_id(column_ix, |window| { let mut resize_divider = div() @@ -221,9 +222,9 @@ impl TableInteractionState { .id(column_ix) .relative() .top_0() - .w_0p5() + .w_px() .h_full() - .bg(cx.theme().colors().border.opacity(0.5)); + .bg(cx.theme().colors().border.opacity(0.8)); let mut resize_handle = div() .id("column-resize-handle") @@ -237,9 +238,11 @@ impl TableInteractionState { .is_some_and(ResizeBehavior::is_resizable) { let hovered = window.use_state(cx, |_window, _cx| false); + resize_divider = resize_divider.when(*hovered.read(cx), |div| { div.bg(cx.theme().colors().border_focused) }); + resize_handle = resize_handle .on_hover(move |&was_hovered, _, cx| hovered.write(cx, was_hovered)) .cursor_col_resize() @@ -269,12 +272,11 @@ impl TableInteractionState { }) }); - div() + h_flex() .id("resize-handles") - .h_flex() .absolute() - .w_full() .inset_0() + .w_full() .children(dividers) .into_any_element() } @@ -896,7 +898,6 @@ fn base_cell_style(width: Option<Length>) -> Div { .px_1p5() .when_some(width, |this, width| this.w(width)) .when(width.is_none(), |this| this.flex_1()) - .justify_start() .whitespace_nowrap() .text_ellipsis() .overflow_hidden() @@ -941,7 +942,7 @@ pub fn render_row<const COLS: usize>( .map(IntoElement::into_any_element) .into_iter() .zip(column_widths) - .map(|(cell, width)| base_cell_style_text(width, cx).px_1p5().py_1().child(cell)), + .map(|(cell, width)| base_cell_style_text(width, cx).px_1().py_0p5().child(cell)), ); let row = if let Some(map_row) = table_context.map_row { @@ -950,7 +951,7 @@ pub fn render_row<const COLS: usize>( row.into_any_element() }; - div().h_full().w_full().child(row).into_any_element() + div().size_full().child(row).into_any_element() } pub fn render_header<const COLS: usize>( diff --git a/crates/sum_tree/src/sum_tree.rs b/crates/sum_tree/src/sum_tree.rs index 4c5ce39590..3a12e3a681 100644 --- a/crates/sum_tree/src/sum_tree.rs +++ b/crates/sum_tree/src/sum_tree.rs @@ -101,37 +101,32 @@ impl<'a, T: Summary> Dimension<'a, T> for () { fn add_summary(&mut self, _: &'a T, _: &T::Context) {} } -impl<'a, T: Summary, D1: Dimension<'a, T>, D2: Dimension<'a, T>> Dimension<'a, T> for (D1, D2) { +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct Dimensions<D1, D2, D3 = ()>(pub D1, pub D2, pub D3); + +impl<'a, T: Summary, D1: Dimension<'a, T>, D2: Dimension<'a, T>, D3: Dimension<'a, T>> + Dimension<'a, T> for Dimensions<D1, D2, D3> +{ fn zero(cx: &T::Context) -> Self { - (D1::zero(cx), D2::zero(cx)) + Dimensions(D1::zero(cx), D2::zero(cx), D3::zero(cx)) } fn add_summary(&mut self, summary: &'a T, cx: &T::Context) { self.0.add_summary(summary, cx); self.1.add_summary(summary, cx); + self.2.add_summary(summary, cx); } } -impl<'a, S, D1, D2> SeekTarget<'a, S, (D1, D2)> for D1 -where - S: Summary, - D1: SeekTarget<'a, S, D1> + Dimension<'a, S>, - D2: Dimension<'a, S>, -{ - fn cmp(&self, cursor_location: &(D1, D2), cx: &S::Context) -> Ordering { - self.cmp(&cursor_location.0, cx) - } -} - -impl<'a, S, D1, D2, D3> SeekTarget<'a, S, ((D1, D2), D3)> for D1 +impl<'a, S, D1, D2, D3> SeekTarget<'a, S, Dimensions<D1, D2, D3>> for D1 where S: Summary, D1: SeekTarget<'a, S, D1> + Dimension<'a, S>, D2: Dimension<'a, S>, D3: Dimension<'a, S>, { - fn cmp(&self, cursor_location: &((D1, D2), D3), cx: &S::Context) -> Ordering { - self.cmp(&cursor_location.0.0, cx) + fn cmp(&self, cursor_location: &Dimensions<D1, D2, D3>, cx: &S::Context) -> Ordering { + self.cmp(&cursor_location.0, cx) } } diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml index d0451f34f2..4fc6a618ff 100644 --- a/crates/supermaven/Cargo.toml +++ b/crates/supermaven/Cargo.toml @@ -16,9 +16,9 @@ doctest = false anyhow.workspace = true client.workspace = true collections.workspace = true +edit_prediction.workspace = true futures.workspace = true gpui.workspace = true -inline_completion.workspace = true language.workspace = true log.workspace = true postage.workspace = true diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index c49272e66e..2660a03e6f 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -1,8 +1,8 @@ use crate::{Supermaven, SupermavenCompletionStateId}; use anyhow::Result; +use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use futures::StreamExt as _; use gpui::{App, Context, Entity, EntityId, Task}; -use inline_completion::{Direction, EditPredictionProvider, InlineCompletion}; use language::{Anchor, Buffer, BufferSnapshot}; use project::Project; use std::{ @@ -44,7 +44,7 @@ fn completion_from_diff( completion_text: &str, position: Anchor, delete_range: Range<Anchor>, -) -> InlineCompletion { +) -> EditPrediction { let buffer_text = snapshot .text_for_range(delete_range.clone()) .collect::<String>(); @@ -91,7 +91,7 @@ fn completion_from_diff( edits.push((edit_range, edit_text)); } - InlineCompletion { + EditPrediction { id: None, edits, edit_preview: None, @@ -182,7 +182,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { buffer: &Entity<Buffer>, cursor_position: Anchor, cx: &mut Context<Self>, - ) -> Option<InlineCompletion> { + ) -> Option<EditPrediction> { let completion_text = self .supermaven .read(cx) diff --git a/crates/tasks_ui/src/modal.rs b/crates/tasks_ui/src/modal.rs index 1510f613e3..c4b0931c35 100644 --- a/crates/tasks_ui/src/modal.rs +++ b/crates/tasks_ui/src/modal.rs @@ -500,7 +500,7 @@ impl PickerDelegate for TasksModalDelegate { .map(|icon| icon.color(Color::Muted).size(IconSize::Small)); let indicator = if matches!(source_kind, TaskSourceKind::Lsp { .. }) { Some(Indicator::icon( - Icon::new(IconName::Bolt).size(IconSize::Small), + Icon::new(IconName::BoltOutlined).size(IconSize::Small), )) } else { None diff --git a/crates/telemetry_events/src/telemetry_events.rs b/crates/telemetry_events/src/telemetry_events.rs index dfe167fcd4..735a1310ae 100644 --- a/crates/telemetry_events/src/telemetry_events.rs +++ b/crates/telemetry_events/src/telemetry_events.rs @@ -94,8 +94,8 @@ impl Display for AssistantPhase { pub enum Event { Flexible(FlexibleEvent), Editor(EditorEvent), - InlineCompletion(InlineCompletionEvent), - InlineCompletionRating(InlineCompletionRatingEvent), + EditPrediction(EditPredictionEvent), + EditPredictionRating(EditPredictionRatingEvent), Call(CallEvent), Assistant(AssistantEventData), Cpu(CpuEvent), @@ -132,7 +132,7 @@ pub struct EditorEvent { } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct InlineCompletionEvent { +pub struct EditPredictionEvent { /// Provider of the completion suggestion (e.g. copilot, supermaven) pub provider: String, pub suggestion_accepted: bool, @@ -140,14 +140,14 @@ pub struct InlineCompletionEvent { } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum InlineCompletionRating { +pub enum EditPredictionRating { Positive, Negative, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct InlineCompletionRatingEvent { - pub rating: InlineCompletionRating, +pub struct EditPredictionRatingEvent { + pub rating: EditPredictionRating, pub input_events: Arc<str>, pub input_excerpt: Arc<str>, pub output_excerpt: Arc<str>, diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index bf65a736e8..2e6be5aaf4 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -1591,7 +1591,7 @@ impl Item for TerminalView { let (icon, icon_color, rerun_button) = match terminal.task() { Some(terminal_task) => match &terminal_task.status { TaskStatus::Running => ( - IconName::Play, + IconName::PlayOutlined, Color::Disabled, TerminalView::rerun_button(&terminal_task), ), diff --git a/crates/text/src/anchor.rs b/crates/text/src/anchor.rs index 5807d3aae0..c4778216e0 100644 --- a/crates/text/src/anchor.rs +++ b/crates/text/src/anchor.rs @@ -3,7 +3,7 @@ use crate::{ locator::Locator, }; use std::{cmp::Ordering, fmt::Debug, ops::Range}; -use sum_tree::Bias; +use sum_tree::{Bias, Dimensions}; /// A timestamped position in a buffer #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash, Default)] @@ -99,8 +99,12 @@ impl Anchor { } else if self.buffer_id != Some(buffer.remote_id) { false } else { - let fragment_id = buffer.fragment_id_for_anchor(self); - let mut fragment_cursor = buffer.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let Some(fragment_id) = buffer.try_fragment_id_for_anchor(self) else { + return false; + }; + let mut fragment_cursor = buffer + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); fragment_cursor.seek(&Some(fragment_id), Bias::Left); fragment_cursor .item() diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index c1da0649da..68c7b2a2cd 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -37,7 +37,7 @@ use std::{ }; pub use subscription::*; pub use sum_tree::Bias; -use sum_tree::{FilterCursor, SumTree, TreeMap, TreeSet}; +use sum_tree::{Dimensions, FilterCursor, SumTree, TreeMap, TreeSet}; use undo_map::UndoMap; #[cfg(any(test, feature = "test-support"))] @@ -1071,7 +1071,9 @@ impl Buffer { let mut insertion_offset = 0; let mut new_ropes = RopeBuilder::new(self.visible_text.cursor(0), self.deleted_text.cursor(0)); - let mut old_fragments = self.fragments.cursor::<(VersionedFullOffset, usize)>(&cx); + let mut old_fragments = self + .fragments + .cursor::<Dimensions<VersionedFullOffset, usize>>(&cx); let mut new_fragments = old_fragments.slice(&VersionedFullOffset::Offset(ranges[0].start), Bias::Left); new_ropes.append(new_fragments.summary().text); @@ -1298,7 +1300,9 @@ impl Buffer { self.snapshot.undo_map.insert(undo); let mut edits = Patch::default(); - let mut old_fragments = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut old_fragments = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let mut new_fragments = SumTree::new(&None); let mut new_ropes = RopeBuilder::new(self.visible_text.cursor(0), self.deleted_text.cursor(0)); @@ -1561,7 +1565,9 @@ impl Buffer { D: TextDimension, { // get fragment ranges - let mut cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let offset_ranges = self .fragment_ids_for_edits(edit_ids.into_iter()) .into_iter() @@ -2232,7 +2238,9 @@ impl BufferSnapshot { { let anchors = anchors.into_iter(); let mut insertion_cursor = self.insertions.cursor::<InsertionFragmentKey>(&()); - let mut fragment_cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut fragment_cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let mut text_cursor = self.visible_text.cursor(0); let mut position = D::zero(&()); @@ -2318,7 +2326,9 @@ impl BufferSnapshot { ); }; - let mut fragment_cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut fragment_cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); fragment_cursor.seek(&Some(&insertion.fragment_id), Bias::Left); let fragment = fragment_cursor.item().unwrap(); let mut fragment_offset = fragment_cursor.start().1; @@ -2330,10 +2340,19 @@ impl BufferSnapshot { } fn fragment_id_for_anchor(&self, anchor: &Anchor) -> &Locator { + self.try_fragment_id_for_anchor(anchor).unwrap_or_else(|| { + panic!( + "invalid anchor {:?}. buffer id: {}, version: {:?}", + anchor, self.remote_id, self.version, + ) + }) + } + + fn try_fragment_id_for_anchor(&self, anchor: &Anchor) -> Option<&Locator> { if *anchor == Anchor::MIN { - Locator::min_ref() + Some(Locator::min_ref()) } else if *anchor == Anchor::MAX { - Locator::max_ref() + Some(Locator::max_ref()) } else { let anchor_key = InsertionFragmentKey { timestamp: anchor.timestamp, @@ -2354,20 +2373,12 @@ impl BufferSnapshot { insertion_cursor.prev(); } - let Some(insertion) = insertion_cursor.item().filter(|insertion| { - if cfg!(debug_assertions) { - insertion.timestamp == anchor.timestamp - } else { - true - } - }) else { - panic!( - "invalid anchor {:?}. buffer id: {}, version: {:?}", - anchor, self.remote_id, self.version - ); - }; - - &insertion.fragment_id + insertion_cursor + .item() + .filter(|insertion| { + !cfg!(debug_assertions) || insertion.timestamp == anchor.timestamp + }) + .map(|insertion| &insertion.fragment_id) } } @@ -2475,7 +2486,7 @@ impl BufferSnapshot { }; let mut cursor = self .fragments - .cursor::<(Option<&Locator>, FragmentTextSummary)>(&None); + .cursor::<Dimensions<Option<&Locator>, FragmentTextSummary>>(&None); let start_fragment_id = self.fragment_id_for_anchor(&range.start); cursor.seek(&Some(start_fragment_id), Bias::Left); diff --git a/crates/theme/src/icon_theme.rs b/crates/theme/src/icon_theme.rs index baa928d722..10fd1e002d 100644 --- a/crates/theme/src/icon_theme.rs +++ b/crates/theme/src/icon_theme.rs @@ -152,6 +152,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ("javascript", &["cjs", "js", "mjs"]), ("json", &["json"]), ("julia", &["jl"]), + ("kdl", &["kdl"]), ("kotlin", &["kt"]), ("lock", &["lock"]), ("log", &["log"]), @@ -315,6 +316,7 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("javascript", "icons/file_icons/javascript.svg"), ("json", "icons/file_icons/code.svg"), ("julia", "icons/file_icons/julia.svg"), + ("kdl", "icons/file_icons/kdl.svg"), ("kotlin", "icons/file_icons/kotlin.svg"), ("lock", "icons/file_icons/lock.svg"), ("log", "icons/file_icons/info.svg"), diff --git a/crates/theme/src/settings.rs b/crates/theme/src/settings.rs index 1c4c90a475..20c837f287 100644 --- a/crates/theme/src/settings.rs +++ b/crates/theme/src/settings.rs @@ -438,7 +438,7 @@ fn default_font_fallbacks() -> Option<FontFallbacks> { impl ThemeSettingsContent { /// Sets the theme for the given appearance to the theme with the specified name. - pub fn set_theme(&mut self, theme_name: String, appearance: Appearance) { + pub fn set_theme(&mut self, theme_name: impl Into<Arc<str>>, appearance: Appearance) { if let Some(selection) = self.theme.as_mut() { let theme_to_update = match selection { ThemeSelection::Static(theme) => theme, @@ -867,6 +867,7 @@ impl settings::Settings for ThemeSettings { .user .into_iter() .chain(sources.release_channel) + .chain(sources.profile) .chain(sources.server) { if let Some(value) = value.ui_density { diff --git a/crates/title_bar/Cargo.toml b/crates/title_bar/Cargo.toml index 8e95c6f79f..cf178e2850 100644 --- a/crates/title_bar/Cargo.toml +++ b/crates/title_bar/Cargo.toml @@ -32,6 +32,7 @@ auto_update.workspace = true call.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true db.workspace = true gpui = { workspace = true, features = ["screen-capture"] } notifications.workspace = true diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 17c4c85b6d..a8b16d881f 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -21,6 +21,7 @@ use crate::application_menu::{ use auto_update::AutoUpdateStatus; use call::ActiveCall; use client::{Client, UserStore, zed_urls}; +use cloud_llm_client::Plan; use gpui::{ Action, AnyElement, App, Context, Corner, Element, Entity, Focusable, InteractiveElement, IntoElement, MouseButton, ParentElement, Render, StatefulInteractiveElement, Styled, @@ -28,7 +29,6 @@ use gpui::{ }; use onboarding_banner::OnboardingBanner; use project::Project; -use rpc::proto; use settings::Settings as _; use settings_ui::keybindings; use std::sync::Arc; @@ -179,24 +179,23 @@ impl Render for TitleBar { children.push(self.banner.clone().into_any_element()) } + let status = self.client.status(); + let status = &*status.borrow(); + let user = self.user_store.read(cx).current_user(); + children.push( h_flex() .gap_1() .pr_1() .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()) .children(self.render_call_controls(window, cx)) - .map(|el| { - let status = self.client.status(); - let status = &*status.borrow(); - if matches!(status, client::Status::Connected { .. }) { - el.child(self.render_user_menu_button(cx)) - } else { - el.children(self.render_connection_status(status, cx)) - .when(TitleBarSettings::get_global(cx).show_sign_in, |el| { - el.child(self.render_sign_in_button(cx)) - }) - .child(self.render_user_menu_button(cx)) - } + .children(self.render_connection_status(status, cx)) + .when( + user.is_none() && TitleBarSettings::get_global(cx).show_sign_in, + |el| el.child(self.render_sign_in_button(cx)), + ) + .when(user.is_some(), |parent| { + parent.child(self.render_user_menu_button(cx)) }) .into_any_element(), ); @@ -618,9 +617,8 @@ impl TitleBar { window .spawn(cx, async move |cx| { client - .authenticate_and_connect(true, &cx) + .sign_in_with_optional_connect(true, &cx) .await - .into_response() .notify_async_err(cx); }) .detach(); @@ -630,8 +628,8 @@ impl TitleBar { pub fn render_user_menu_button(&mut self, cx: &mut Context<Self>) -> impl Element { let user_store = self.user_store.read(cx); if let Some(user) = user_store.current_user() { - let has_subscription_period = self.user_store.read(cx).subscription_period().is_some(); - let plan = self.user_store.read(cx).current_plan().filter(|_| { + let has_subscription_period = user_store.subscription_period().is_some(); + let plan = user_store.plan().filter(|_| { // Since the user might be on the legacy free plan we filter based on whether we have a subscription period. has_subscription_period }); @@ -658,13 +656,9 @@ impl TitleBar { let user_login = user.github_login.clone(); let (plan_name, label_color, bg_color) = match plan { - None | Some(proto::Plan::Free) => { - ("Free", Color::Default, free_chip_bg) - } - Some(proto::Plan::ZedProTrial) => { - ("Pro Trial", Color::Accent, pro_chip_bg) - } - Some(proto::Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), + None | Some(Plan::ZedFree) => ("Free", Color::Default, free_chip_bg), + Some(Plan::ZedProTrial) => ("Pro Trial", Color::Accent, pro_chip_bg), + Some(Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), }; menu.custom_entry( @@ -688,6 +682,10 @@ impl TitleBar { ) .separator() .action("Settings", zed_actions::OpenSettings.boxed_clone()) + .action( + "Settings Profiles", + zed_actions::settings_profile_selector::Toggle.boxed_clone(), + ) .action("Key Bindings", Box::new(keybindings::OpenKeymapEditor)) .action( "Themes…", @@ -732,6 +730,10 @@ impl TitleBar { .menu(|window, cx| { ContextMenu::build(window, cx, |menu, _, _| { menu.action("Settings", zed_actions::OpenSettings.boxed_clone()) + .action( + "Settings Profiles", + zed_actions::settings_profile_selector::Toggle.boxed_clone(), + ) .action("Key Bindings", Box::new(keybindings::OpenKeymapEditor)) .action( "Themes…", diff --git a/crates/ui/src/components.rs b/crates/ui/src/components.rs index 9c2961c55f..486673e733 100644 --- a/crates/ui/src/components.rs +++ b/crates/ui/src/components.rs @@ -1,4 +1,5 @@ mod avatar; +mod badge; mod banner; mod button; mod callout; @@ -41,6 +42,7 @@ mod tooltip; mod stories; pub use avatar::*; +pub use badge::*; pub use banner::*; pub use button::*; pub use callout::*; diff --git a/crates/ui/src/components/badge.rs b/crates/ui/src/components/badge.rs new file mode 100644 index 0000000000..f36e03291c --- /dev/null +++ b/crates/ui/src/components/badge.rs @@ -0,0 +1,94 @@ +use std::rc::Rc; + +use crate::Divider; +use crate::DividerColor; +use crate::Tooltip; +use crate::component_prelude::*; +use crate::prelude::*; +use gpui::AnyView; +use gpui::{AnyElement, IntoElement, SharedString, Window}; + +#[derive(IntoElement, RegisterComponent)] +pub struct Badge { + label: SharedString, + icon: IconName, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, +} + +impl Badge { + pub fn new(label: impl Into<SharedString>) -> Self { + Self { + label: label.into(), + icon: IconName::Check, + tooltip: None, + } + } + + pub fn icon(mut self, icon: IconName) -> Self { + self.icon = icon; + self + } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } +} + +impl RenderOnce for Badge { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let tooltip = self.tooltip; + + h_flex() + .id(self.label.clone()) + .h_full() + .gap_1() + .pl_1() + .pr_2() + .border_1() + .border_color(cx.theme().colors().border.opacity(0.6)) + .bg(cx.theme().colors().element_background) + .rounded_sm() + .overflow_hidden() + .child( + Icon::new(self.icon) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .child(Label::new(self.label.clone()).size(LabelSize::Small).ml_1()) + .when_some(tooltip, |this, tooltip| { + this.tooltip(move |window, cx| tooltip(window, cx)) + }) + } +} + +impl Component for Badge { + fn scope() -> ComponentScope { + ComponentScope::DataDisplay + } + + fn description() -> Option<&'static str> { + Some( + "A compact, labeled component with optional icon for displaying status, categories, or metadata.", + ) + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .child(single_example( + "Basic Badge", + Badge::new("Default").into_any_element(), + )) + .child(single_example( + "With Tooltip", + Badge::new("Tooltip") + .tooltip(Tooltip::text("This is a tooltip.")) + .into_any_element(), + )) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/button/button.rs b/crates/ui/src/components/button/button.rs index cae5d0e2ca..19f782fb98 100644 --- a/crates/ui/src/components/button/button.rs +++ b/crates/ui/src/components/button/button.rs @@ -393,6 +393,11 @@ impl ButtonCommon for Button { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index 135ecdfe62..15ab00e7e5 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -1,7 +1,7 @@ use documented::Documented; use gpui::{ AnyElement, AnyView, ClickEvent, CursorStyle, DefiniteLength, Hsla, MouseButton, - MouseDownEvent, MouseUpEvent, Rems, relative, transparent_black, + MouseDownEvent, MouseUpEvent, Rems, StyleRefinement, relative, transparent_black, }; use smallvec::SmallVec; @@ -37,6 +37,8 @@ pub trait ButtonCommon: Clickable + Disableable { /// exceptions might a scroll bar, or a slider. fn tooltip(self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self; + fn tab_index(self, tab_index: impl Into<isize>) -> Self; + fn layer(self, elevation: ElevationIndex) -> Self; } @@ -358,6 +360,7 @@ impl ButtonStyle { #[derive(Default, PartialEq, Clone, Copy)] pub enum ButtonSize { Large, + Medium, #[default] Default, Compact, @@ -368,6 +371,7 @@ impl ButtonSize { pub fn rems(self) -> Rems { match self { ButtonSize::Large => rems_from_px(32.), + ButtonSize::Medium => rems_from_px(28.), ButtonSize::Default => rems_from_px(22.), ButtonSize::Compact => rems_from_px(18.), ButtonSize::None => rems_from_px(16.), @@ -391,6 +395,7 @@ pub struct ButtonLike { pub(super) width: Option<DefiniteLength>, pub(super) height: Option<DefiniteLength>, pub(super) layer: Option<ElevationIndex>, + tab_index: Option<isize>, size: ButtonSize, rounding: Option<ButtonLikeRounding>, tooltip: Option<Box<dyn Fn(&mut Window, &mut App) -> AnyView>>, @@ -419,6 +424,7 @@ impl ButtonLike { on_click: None, on_right_click: None, layer: None, + tab_index: None, } } @@ -523,6 +529,11 @@ impl ButtonCommon for ButtonLike { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.tab_index = Some(tab_index.into()); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.layer = Some(elevation); self @@ -552,6 +563,7 @@ impl RenderOnce for ButtonLike { self.base .h_flex() .id(self.id.clone()) + .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)) .font_ui(cx) .group("") .flex_none() @@ -573,7 +585,7 @@ impl RenderOnce for ButtonLike { }) .gap(DynamicSpacing::Base04.rems(cx)) .map(|this| match self.size { - ButtonSize::Large => this.px(DynamicSpacing::Base06.rems(cx)), + ButtonSize::Large | ButtonSize::Medium => this.px(DynamicSpacing::Base06.rems(cx)), ButtonSize::Default | ButtonSize::Compact => { this.px(DynamicSpacing::Base04.rems(cx)) } @@ -589,8 +601,12 @@ impl RenderOnce for ButtonLike { } }) .when(!self.disabled, |this| { + let hovered_style = style.hovered(self.layer, cx); + let focus_color = + |refinement: StyleRefinement| refinement.bg(hovered_style.background); this.cursor(self.cursor_style) - .hover(|hover| hover.bg(style.hovered(self.layer, cx).background)) + .hover(focus_color) + .focus(focus_color) .active(|active| active.bg(style.active(cx).background)) }) .when_some( diff --git a/crates/ui/src/components/button/icon_button.rs b/crates/ui/src/components/button/icon_button.rs index e5d13e09cd..8d8718a634 100644 --- a/crates/ui/src/components/button/icon_button.rs +++ b/crates/ui/src/components/button/icon_button.rs @@ -164,6 +164,11 @@ impl ButtonCommon for IconButton { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self diff --git a/crates/ui/src/components/button/toggle_button.rs b/crates/ui/src/components/button/toggle_button.rs index c6cf7ac62c..6fbf834667 100644 --- a/crates/ui/src/components/button/toggle_button.rs +++ b/crates/ui/src/components/button/toggle_button.rs @@ -121,6 +121,11 @@ impl ButtonCommon for ToggleButton { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self @@ -291,19 +296,25 @@ impl Component for ToggleButton { } } -mod private { - pub trait Sealed {} +pub struct ButtonConfiguration { + label: SharedString, + icon: Option<IconName>, + on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, } -pub trait ButtonBuilder: 'static + private::Sealed { - fn label(&self) -> impl Into<SharedString>; - fn icon(&self) -> Option<IconName>; - fn on_click(self) -> Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>; +mod private { + pub trait ToggleButtonStyle {} +} + +pub trait ButtonBuilder: 'static + private::ToggleButtonStyle { + fn into_configuration(self) -> ButtonConfiguration; } pub struct ToggleButtonSimple { label: SharedString, on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, } impl ToggleButtonSimple { @@ -314,23 +325,26 @@ impl ToggleButtonSimple { Self { label: label.into(), on_click: Box::new(on_click), + selected: false, } } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } } -impl private::Sealed for ToggleButtonSimple {} +impl private::ToggleButtonStyle for ToggleButtonSimple {} impl ButtonBuilder for ToggleButtonSimple { - fn label(&self) -> impl Into<SharedString> { - self.label.clone() - } - - fn icon(&self) -> Option<IconName> { - None - } - - fn on_click(self) -> Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static> { - self.on_click + fn into_configuration(self) -> ButtonConfiguration { + ButtonConfiguration { + label: self.label, + icon: None, + on_click: self.on_click, + selected: self.selected, + } } } @@ -338,6 +352,7 @@ pub struct ToggleButtonWithIcon { label: SharedString, icon: IconName, on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, } impl ToggleButtonWithIcon { @@ -350,62 +365,25 @@ impl ToggleButtonWithIcon { label: label.into(), icon, on_click: Box::new(on_click), + selected: false, } } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } } -impl private::Sealed for ToggleButtonWithIcon {} +impl private::ToggleButtonStyle for ToggleButtonWithIcon {} impl ButtonBuilder for ToggleButtonWithIcon { - fn label(&self) -> impl Into<SharedString> { - self.label.clone() - } - - fn icon(&self) -> Option<IconName> { - Some(self.icon) - } - - fn on_click(self) -> Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static> { - self.on_click - } -} - -struct ToggleButtonRow<T: ButtonBuilder> { - items: Vec<T>, - index_offset: usize, - last_item_idx: usize, - is_last_row: bool, -} - -impl<T: ButtonBuilder> ToggleButtonRow<T> { - fn new(items: Vec<T>, index_offset: usize, is_last_row: bool) -> Self { - Self { - index_offset, - last_item_idx: index_offset + items.len() - 1, - is_last_row, - items, - } - } -} - -enum ToggleButtonGroupRows<T: ButtonBuilder> { - Single(Vec<T>), - Multiple(Vec<T>, Vec<T>), -} - -impl<T: ButtonBuilder> ToggleButtonGroupRows<T> { - fn items(self) -> impl IntoIterator<Item = ToggleButtonRow<T>> { - match self { - ToggleButtonGroupRows::Single(items) => { - vec![ToggleButtonRow::new(items, 0, true)] - } - ToggleButtonGroupRows::Multiple(first_row, second_row) => { - let row_len = first_row.len(); - vec![ - ToggleButtonRow::new(first_row, 0, false), - ToggleButtonRow::new(second_row, row_len, true), - ] - } + fn into_configuration(self) -> ButtonConfiguration { + ButtonConfiguration { + label: self.label, + icon: Some(self.icon), + on_click: self.on_click, + selected: self.selected, } } } @@ -417,54 +395,65 @@ pub enum ToggleButtonGroupStyle { Outlined, } +#[derive(Clone, Copy, PartialEq)] +pub enum ToggleButtonGroupSize { + Default, + Medium, +} + #[derive(IntoElement)] -pub struct ToggleButtonGroup<T> +pub struct ToggleButtonGroup<T, const COLS: usize = 3, const ROWS: usize = 1> where T: ButtonBuilder, { - group_name: SharedString, - rows: ToggleButtonGroupRows<T>, + group_name: &'static str, + rows: [[T; COLS]; ROWS], style: ToggleButtonGroupStyle, + size: ToggleButtonGroupSize, button_width: Rems, selected_index: usize, + tab_index: Option<isize>, } -impl<T: ButtonBuilder> ToggleButtonGroup<T> { - pub fn single_row( - group_name: impl Into<SharedString>, - buttons: impl IntoIterator<Item = T>, - ) -> Self { +impl<T: ButtonBuilder, const COLS: usize> ToggleButtonGroup<T, COLS> { + pub fn single_row(group_name: &'static str, buttons: [T; COLS]) -> Self { Self { - group_name: group_name.into(), - rows: ToggleButtonGroupRows::Single(Vec::from_iter(buttons)), + group_name, + rows: [buttons], style: ToggleButtonGroupStyle::Transparent, + size: ToggleButtonGroupSize::Default, button_width: rems_from_px(100.), selected_index: 0, + tab_index: None, } } +} - pub fn multiple_rows<const ROWS: usize>( - group_name: impl Into<SharedString>, - first_row: [T; ROWS], - second_row: [T; ROWS], - ) -> Self { +impl<T: ButtonBuilder, const COLS: usize> ToggleButtonGroup<T, COLS, 2> { + pub fn two_rows(group_name: &'static str, first_row: [T; COLS], second_row: [T; COLS]) -> Self { Self { - group_name: group_name.into(), - rows: ToggleButtonGroupRows::Multiple( - Vec::from_iter(first_row), - Vec::from_iter(second_row), - ), + group_name, + rows: [first_row, second_row], style: ToggleButtonGroupStyle::Transparent, + size: ToggleButtonGroupSize::Default, button_width: rems_from_px(100.), selected_index: 0, + tab_index: None, } } +} +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> ToggleButtonGroup<T, COLS, ROWS> { pub fn style(mut self, style: ToggleButtonGroupStyle) -> Self { self.style = style; self } + pub fn size(mut self, size: ToggleButtonGroupSize) -> Self { + self.size = size; + self + } + pub fn button_width(mut self, button_width: Rems) -> Self { self.button_width = button_width; self @@ -474,62 +463,79 @@ impl<T: ButtonBuilder> ToggleButtonGroup<T> { self.selected_index = index; self } + + /// Sets the tab index for the toggle button group. + /// The tab index is set to the initial value provided, then the + /// value is incremented by the number of buttons in the group. + pub fn tab_index(mut self, tab_index: &mut isize) -> Self { + self.tab_index = Some(*tab_index); + *tab_index += (COLS * ROWS) as isize; + self + } } -impl<T: ButtonBuilder> RenderOnce for ToggleButtonGroup<T> { +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> RenderOnce + for ToggleButtonGroup<T, COLS, ROWS> +{ fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let rows = self.rows.items().into_iter().map(|row| { - ( - row.items - .into_iter() - .enumerate() - .map(move |(index, item)| (index + row.index_offset, row.last_item_idx, item)) - .map(|(index, last_item_idx, item)| { - ( - ButtonLike::new((self.group_name.clone(), index)) - .when(index == self.selected_index, |this| { - this.toggle_state(true) - .selected_style(ButtonStyle::Tinted(TintColor::Accent)) - }) - .rounding(None) - .when(self.style == ToggleButtonGroupStyle::Filled, |button| { - button.style(ButtonStyle::Filled) - }) - .child( - h_flex() - .min_w(self.button_width) - .gap_1p5() - .justify_center() - .when_some(item.icon(), |this, icon| { - this.child(Icon::new(icon).size(IconSize::XSmall).map( - |this| { - if index == self.selected_index { - this.color(Color::Accent) - } else { - this.color(Color::Muted) - } - }, - )) - }) - .child( - Label::new(item.label()) - .when(index == self.selected_index, |this| { - this.color(Color::Accent) - }), - ), - ) - .on_click(item.on_click()), - index == last_item_idx, - ) - }), - row.is_last_row, - ) - }); + let entries = + self.rows.into_iter().enumerate().map(|(row_index, row)| { + row.into_iter().enumerate().map(move |(col_index, button)| { + let ButtonConfiguration { + label, + icon, + on_click, + selected, + } = button.into_configuration(); + let entry_index = row_index * COLS + col_index; + + ButtonLike::new((self.group_name, entry_index)) + .when_some(self.tab_index, |this, tab_index| { + this.tab_index(tab_index + entry_index as isize) + }) + .when(entry_index == self.selected_index || selected, |this| { + this.toggle_state(true) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + }) + .rounding(None) + .when(self.style == ToggleButtonGroupStyle::Filled, |button| { + button.style(ButtonStyle::Filled) + }) + .when(self.size == ToggleButtonGroupSize::Medium, |button| { + button.size(ButtonSize::Medium) + }) + .child( + h_flex() + .min_w(self.button_width) + .gap_1p5() + .px_3() + .py_1() + .justify_center() + .when_some(icon, |this, icon| { + this.py_2() + .child(Icon::new(icon).size(IconSize::XSmall).map(|this| { + if entry_index == self.selected_index || selected { + this.color(Color::Accent) + } else { + this.color(Color::Muted) + } + })) + }) + .child(Label::new(label).size(LabelSize::Small).when( + entry_index == self.selected_index || selected, + |this| this.color(Color::Accent), + )), + ) + .on_click(on_click) + .into_any_element() + }) + }); + + let border_color = cx.theme().colors().border.opacity(0.6); let is_outlined_or_filled = self.style == ToggleButtonGroupStyle::Outlined || self.style == ToggleButtonGroupStyle::Filled; let is_transparent = self.style == ToggleButtonGroupStyle::Transparent; - let border_color = cx.theme().colors().border.opacity(0.6); v_flex() .rounded_md() @@ -541,13 +547,15 @@ impl<T: ButtonBuilder> RenderOnce for ToggleButtonGroup<T> { this.border_1().border_color(border_color) } }) - .children(rows.map(|(items, last_row)| { + .children(entries.enumerate().map(|(row_index, row)| { + let last_row = row_index == ROWS - 1; h_flex() .when(!is_outlined_or_filled, |this| this.gap_px()) .when(is_outlined_or_filled && !last_row, |this| { this.border_b_1().border_color(border_color) }) - .children(items.map(|(item, last_item)| { + .children(row.enumerate().map(|(item_index, item)| { + let last_item = item_index == COLS - 1; div() .when(is_outlined_or_filled && !last_item, |this| { this.border_r_1().border_color(border_color) @@ -566,7 +574,9 @@ component::__private::inventory::submit! { component::ComponentFn::new(register_toggle_button_group) } -impl<T: ButtonBuilder> Component for ToggleButtonGroup<T> { +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> Component + for ToggleButtonGroup<T, COLS, ROWS> +{ fn name() -> &'static str { "ToggleButtonGroup" } @@ -628,7 +638,7 @@ impl<T: ButtonBuilder> Component for ToggleButtonGroup<T> { ), single_example( "Multiple Row Group", - ToggleButtonGroup::multiple_rows( + ToggleButtonGroup::two_rows( "multiple_row_test", [ ToggleButtonSimple::new("First", |_, _, _| {}), @@ -647,7 +657,7 @@ impl<T: ButtonBuilder> Component for ToggleButtonGroup<T> { ), single_example( "Multiple Row Group with Icons", - ToggleButtonGroup::multiple_rows( + ToggleButtonGroup::two_rows( "multiple_row_test_icons", [ ToggleButtonWithIcon::new( @@ -736,7 +746,7 @@ impl<T: ButtonBuilder> Component for ToggleButtonGroup<T> { ), single_example( "Multiple Row Group", - ToggleButtonGroup::multiple_rows( + ToggleButtonGroup::two_rows( "multiple_row_test", [ ToggleButtonSimple::new("First", |_, _, _| {}), @@ -756,7 +766,7 @@ impl<T: ButtonBuilder> Component for ToggleButtonGroup<T> { ), single_example( "Multiple Row Group with Icons", - ToggleButtonGroup::multiple_rows( + ToggleButtonGroup::two_rows( "multiple_row_test", [ ToggleButtonWithIcon::new( @@ -846,7 +856,7 @@ impl<T: ButtonBuilder> Component for ToggleButtonGroup<T> { ), single_example( "Multiple Row Group", - ToggleButtonGroup::multiple_rows( + ToggleButtonGroup::two_rows( "multiple_row_test", [ ToggleButtonSimple::new("First", |_, _, _| {}), @@ -866,7 +876,7 @@ impl<T: ButtonBuilder> Component for ToggleButtonGroup<T> { ), single_example( "Multiple Row Group with Icons", - ToggleButtonGroup::multiple_rows( + ToggleButtonGroup::two_rows( "multiple_row_test", [ ToggleButtonWithIcon::new( diff --git a/crates/ui/src/components/dropdown_menu.rs b/crates/ui/src/components/dropdown_menu.rs index 189fac930f..7ad9400f0d 100644 --- a/crates/ui/src/components/dropdown_menu.rs +++ b/crates/ui/src/components/dropdown_menu.rs @@ -8,6 +8,7 @@ use super::PopoverMenuHandle; pub enum DropdownStyle { #[default] Solid, + Outlined, Ghost, } @@ -147,6 +148,23 @@ impl Component for DropdownMenu { ), ], ), + example_group_with_title( + "Styles", + vec![ + single_example( + "Outlined", + DropdownMenu::new("outlined", "Outlined Dropdown", menu.clone()) + .style(DropdownStyle::Outlined) + .into_any_element(), + ), + single_example( + "Ghost", + DropdownMenu::new("ghost", "Ghost Dropdown", menu.clone()) + .style(DropdownStyle::Ghost) + .into_any_element(), + ), + ], + ), example_group_with_title( "States", vec![single_example( @@ -170,10 +188,13 @@ pub struct DropdownTriggerStyle { impl DropdownTriggerStyle { pub fn for_style(style: DropdownStyle, cx: &App) -> Self { let colors = cx.theme().colors(); + let bg = match style { DropdownStyle::Solid => colors.editor_background, + DropdownStyle::Outlined => colors.surface_background, DropdownStyle::Ghost => colors.ghost_element_background, }; + Self { bg } } } @@ -244,29 +265,36 @@ impl RenderOnce for DropdownMenuTrigger { let disabled = self.disabled; let style = DropdownTriggerStyle::for_style(self.style, cx); + let is_outlined = matches!(self.style, DropdownStyle::Outlined); h_flex() .id("dropdown-menu-trigger") - .justify_between() - .rounded_sm() - .bg(style.bg) + .min_w_20() .pl_2() .pr_1p5() .py_0p5() .gap_2() - .min_w_20() - .map(|el| { + .justify_between() + .rounded_sm() + .map(|this| { if self.full_width { - el.w_full() + this.w_full() } else { - el.flex_none().w_auto() + this.flex_none().w_auto() } }) - .map(|el| { + .when(is_outlined, |this| { + this.border_1() + .border_color(cx.theme().colors().border) + .overflow_hidden() + }) + .map(|this| { if disabled { - el.cursor_not_allowed() + this.cursor_not_allowed() + .bg(cx.theme().colors().element_disabled) } else { - el.cursor_pointer() + this.bg(style.bg) + .hover(|s| s.bg(cx.theme().colors().element_hover)) } }) .child(match self.label { diff --git a/crates/ui/src/components/keybinding.rs b/crates/ui/src/components/keybinding.rs index 1d91492f26..5779093ccc 100644 --- a/crates/ui/src/components/keybinding.rs +++ b/crates/ui/src/components/keybinding.rs @@ -44,7 +44,7 @@ impl KeyBinding { pub fn for_action_in( action: &dyn Action, focus: &FocusHandle, - window: &mut Window, + window: &Window, cx: &App, ) -> Option<Self> { let key_binding = window.highest_precedence_binding_for_action_in(action, focus)?; diff --git a/crates/ui/src/components/modal.rs b/crates/ui/src/components/modal.rs index 2145b34ef2..a70f5e1ea5 100644 --- a/crates/ui/src/components/modal.rs +++ b/crates/ui/src/components/modal.rs @@ -1,5 +1,5 @@ use crate::{ - Clickable, Color, DynamicSpacing, Headline, HeadlineSize, IconButton, IconButtonShape, + Clickable, Color, DynamicSpacing, Headline, HeadlineSize, Icon, IconButton, IconButtonShape, IconName, Label, LabelCommon, LabelSize, h_flex, v_flex, }; use gpui::{prelude::FluentBuilder, *}; @@ -92,6 +92,7 @@ impl RenderOnce for Modal { #[derive(IntoElement)] pub struct ModalHeader { + icon: Option<Icon>, headline: Option<SharedString>, description: Option<SharedString>, children: SmallVec<[AnyElement; 2]>, @@ -108,6 +109,7 @@ impl Default for ModalHeader { impl ModalHeader { pub fn new() -> Self { Self { + icon: None, headline: None, description: None, children: SmallVec::new(), @@ -116,6 +118,11 @@ impl ModalHeader { } } + pub fn icon(mut self, icon: Icon) -> Self { + self.icon = Some(icon); + self + } + /// Set the headline of the modal. /// /// This will insert the headline as the first item @@ -179,12 +186,17 @@ impl RenderOnce for ModalHeader { ) }) .child( - v_flex().flex_1().children(children).when_some( - self.description, - |this, description| { + v_flex() + .flex_1() + .child( + h_flex() + .gap_1() + .when_some(self.icon, |this, icon| this.child(icon)) + .children(children), + ) + .when_some(self.description, |this, description| { this.child(Label::new(description).color(Color::Muted).mb_2()) - }, - ), + }), ) .when(self.show_dismiss_button, |this| { this.child( diff --git a/crates/ui/src/components/numeric_stepper.rs b/crates/ui/src/components/numeric_stepper.rs index f9e6e88f01..2ddb86d9a0 100644 --- a/crates/ui/src/components/numeric_stepper.rs +++ b/crates/ui/src/components/numeric_stepper.rs @@ -2,15 +2,24 @@ use gpui::ClickEvent; use crate::{IconButtonShape, prelude::*}; -#[derive(IntoElement)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NumericStepperStyle { + Outlined, + #[default] + Ghost, +} + +#[derive(IntoElement, RegisterComponent)] pub struct NumericStepper { id: ElementId, value: SharedString, + style: NumericStepperStyle, on_decrement: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, on_increment: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, /// Whether to reserve space for the reset button. reserve_space_for_reset: bool, on_reset: Option<Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>>, + tab_index: Option<isize>, } impl NumericStepper { @@ -23,13 +32,20 @@ impl NumericStepper { Self { id: id.into(), value: value.into(), + style: NumericStepperStyle::default(), on_decrement: Box::new(on_decrement), on_increment: Box::new(on_increment), reserve_space_for_reset: false, on_reset: None, + tab_index: None, } } + pub fn style(mut self, style: NumericStepperStyle) -> Self { + self.style = style; + self + } + pub fn reserve_space_for_reset(mut self, reserve_space_for_reset: bool) -> Self { self.reserve_space_for_reset = reserve_space_for_reset; self @@ -42,6 +58,11 @@ impl NumericStepper { self.on_reset = Some(Box::new(on_reset)); self } + + pub fn tab_index(mut self, tab_index: isize) -> Self { + self.tab_index = Some(tab_index); + self + } } impl RenderOnce for NumericStepper { @@ -49,6 +70,9 @@ impl RenderOnce for NumericStepper { let shape = IconButtonShape::Square; let icon_size = IconSize::Small; + let is_outlined = matches!(self.style, NumericStepperStyle::Outlined); + let mut tab_index = self.tab_index; + h_flex() .id(self.id) .gap_1() @@ -58,6 +82,10 @@ impl RenderOnce for NumericStepper { IconButton::new("reset", IconName::RotateCcw) .shape(shape) .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) .on_click(on_reset), ) } else if self.reserve_space_for_reset { @@ -74,22 +102,136 @@ impl RenderOnce for NumericStepper { .child( h_flex() .gap_1() - .px_1() - .rounded_xs() - .bg(cx.theme().colors().editor_background) - .child( - IconButton::new("decrement", IconName::Dash) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_decrement), - ) - .child(Label::new(self.value)) - .child( - IconButton::new("increment", IconName::Plus) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_increment), - ), + .rounded_sm() + .map(|this| { + if is_outlined { + this.overflow_hidden() + .bg(cx.theme().colors().surface_background) + .border_1() + .border_color(cx.theme().colors().border_variant) + } else { + this.px_1().bg(cx.theme().colors().editor_background) + } + }) + .map(|decrement| { + if is_outlined { + decrement.child( + h_flex() + .id("decrement_button") + .p_1p5() + .size_full() + .justify_center() + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .border_r_1() + .border_color(cx.theme().colors().border_variant) + .child(Icon::new(IconName::Dash).size(IconSize::Small)) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1).focus(|style| { + style.bg(cx.theme().colors().element_hover) + }) + }) + .on_click(self.on_decrement), + ) + } else { + decrement.child( + IconButton::new("decrement", IconName::Dash) + .shape(shape) + .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) + .on_click(self.on_decrement), + ) + } + }) + .child(Label::new(self.value).mx_3()) + .map(|increment| { + if is_outlined { + increment.child( + h_flex() + .id("increment_button") + .p_1p5() + .size_full() + .justify_center() + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .child(Icon::new(IconName::Plus).size(IconSize::Small)) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1).focus(|style| { + style.bg(cx.theme().colors().element_hover) + }) + }) + .on_click(self.on_increment), + ) + } else { + increment.child( + IconButton::new("increment", IconName::Dash) + .shape(shape) + .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) + .on_click(self.on_increment), + ) + } + }), ) } } + +impl Component for NumericStepper { + fn scope() -> ComponentScope { + ComponentScope::Input + } + + fn name() -> &'static str { + "Numeric Stepper" + } + + fn sort_name() -> &'static str { + Self::name() + } + + fn description() -> Option<&'static str> { + Some("A button used to increment or decrement a numeric value.") + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .children(vec![example_group_with_title( + "Styles", + vec![ + single_example( + "Default", + NumericStepper::new( + "numeric-stepper-component-preview", + "10", + move |_, _, _| {}, + move |_, _, _| {}, + ) + .into_any_element(), + ), + single_example( + "Outlined", + NumericStepper::new( + "numeric-stepper-with-border-component-preview", + "10", + move |_, _, _| {}, + move |_, _, _| {}, + ) + .style(NumericStepperStyle::Outlined) + .into_any_element(), + ), + ], + )]) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/popover.rs b/crates/ui/src/components/popover.rs index 24460f6d9c..7143514c52 100644 --- a/crates/ui/src/components/popover.rs +++ b/crates/ui/src/components/popover.rs @@ -50,7 +50,7 @@ impl RenderOnce for Popover { v_flex() .elevation_2(cx) .py(POPOVER_Y_PADDING / 2.) - .children(self.children), + .child(div().children(self.children)), ) .when_some(self.aside, |this, aside| { this.child( diff --git a/crates/ui/src/components/stories/icon_button.rs b/crates/ui/src/components/stories/icon_button.rs index e787e81b55..ad6886252d 100644 --- a/crates/ui/src/components/stories/icon_button.rs +++ b/crates/ui/src/components/stories/icon_button.rs @@ -77,7 +77,7 @@ impl Render for IconButtonStory { let with_tooltip_button = StoryItem::new( "With `tooltip`", - IconButton::new("with_tooltip_button", IconName::MessageBubbles) + IconButton::new("with_tooltip_button", IconName::Chat) .tooltip(Tooltip::text("Open messages")), ) .description("Displays an icon button that has a tooltip when hovered.") diff --git a/crates/ui/src/components/toggle.rs b/crates/ui/src/components/toggle.rs index cf2a56b1c9..53df4767b0 100644 --- a/crates/ui/src/components/toggle.rs +++ b/crates/ui/src/components/toggle.rs @@ -2,10 +2,10 @@ use gpui::{ AnyElement, AnyView, ClickEvent, ElementId, Hsla, IntoElement, Styled, Window, div, hsla, prelude::*, }; -use std::sync::Arc; +use std::{rc::Rc, sync::Arc}; use crate::utils::is_light; -use crate::{Color, Icon, IconName, ToggleState}; +use crate::{Color, Icon, IconName, ToggleState, Tooltip}; use crate::{ElevationIndex, KeyBinding, prelude::*}; // TODO: Checkbox, CheckboxWithLabel, and Switch could all be @@ -424,6 +424,7 @@ pub struct Switch { label: Option<SharedString>, key_binding: Option<KeyBinding>, color: SwitchColor, + tab_index: Option<isize>, } impl Switch { @@ -437,6 +438,7 @@ impl Switch { label: None, key_binding: None, color: SwitchColor::default(), + tab_index: None, } } @@ -472,6 +474,11 @@ impl Switch { self.key_binding = key_binding.into(); self } + + pub fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.tab_index = Some(tab_index.into()); + self + } } impl RenderOnce for Switch { @@ -501,6 +508,20 @@ impl RenderOnce for Switch { .w(DynamicSpacing::Base32.rems(cx)) .h(DynamicSpacing::Base20.rems(cx)) .group(group_id.clone()) + .border_1() + .p(px(1.0)) + .border_color(cx.theme().colors().border_transparent) + .rounded_full() + .id((self.id.clone(), "switch")) + .when_some( + self.tab_index.filter(|_| !self.disabled), + |this, tab_index| { + this.tab_index(tab_index).focus(|mut style| { + style.border_color = Some(cx.theme().colors().border_focused); + style + }) + }, + ) .child( h_flex() .when(is_on, |on| on.justify_end()) @@ -566,32 +587,41 @@ impl RenderOnce for Switch { pub struct SwitchField { id: ElementId, label: SharedString, - description: SharedString, + description: Option<SharedString>, toggle_state: ToggleState, on_click: Arc<dyn Fn(&ToggleState, &mut Window, &mut App) + 'static>, disabled: bool, color: SwitchColor, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, + tab_index: Option<isize>, } impl SwitchField { pub fn new( id: impl Into<ElementId>, label: impl Into<SharedString>, - description: impl Into<SharedString>, + description: Option<SharedString>, toggle_state: impl Into<ToggleState>, on_click: impl Fn(&ToggleState, &mut Window, &mut App) + 'static, ) -> Self { Self { id: id.into(), label: label.into(), - description: description.into(), + description: description, toggle_state: toggle_state.into(), on_click: Arc::new(on_click), disabled: false, color: SwitchColor::Accent, + tooltip: None, + tab_index: None, } } + pub fn description(mut self, description: impl Into<SharedString>) -> Self { + self.description = Some(description.into()); + self + } + pub fn disabled(mut self, disabled: bool) -> Self { self.disabled = disabled; self @@ -603,36 +633,75 @@ impl SwitchField { self.color = color; self } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } + + pub fn tab_index(mut self, tab_index: isize) -> Self { + self.tab_index = Some(tab_index); + self + } } impl RenderOnce for SwitchField { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { + let tooltip = self.tooltip.map(|tooltip_fn| { + h_flex() + .gap_0p5() + .child(Label::new(self.label.clone())) + .child( + IconButton::new("tooltip_button", IconName::Info) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .shape(crate::IconButtonShape::Square) + .tooltip({ + let tooltip = tooltip_fn.clone(); + move |window, cx| tooltip(window, cx) + }), + ) + }); + h_flex() - .id(SharedString::from(format!("{}-container", self.id))) + .id((self.id.clone(), "container")) + .when(!self.disabled, |this| { + this.hover(|this| this.cursor_pointer()) + }) .w_full() .gap_4() .justify_between() .flex_wrap() - .child( - v_flex() + .child(match (&self.description, tooltip) { + (Some(description), Some(tooltip)) => v_flex() .gap_0p5() .max_w_5_6() - .child(Label::new(self.label)) - .child(Label::new(self.description).color(Color::Muted)), - ) + .child(tooltip) + .child(Label::new(description.clone()).color(Color::Muted)) + .into_any_element(), + (Some(description), None) => v_flex() + .gap_0p5() + .max_w_5_6() + .child(Label::new(self.label.clone())) + .child(Label::new(description.clone()).color(Color::Muted)) + .into_any_element(), + (None, Some(tooltip)) => tooltip.into_any_element(), + (None, None) => Label::new(self.label.clone()).into_any_element(), + }) .child( - Switch::new( - SharedString::from(format!("{}-switch", self.id)), - self.toggle_state, - ) - .color(self.color) - .disabled(self.disabled) - .on_click({ - let on_click = self.on_click.clone(); - move |state, window, cx| { - (on_click)(state, window, cx); - } - }), + Switch::new((self.id.clone(), "switch"), self.toggle_state) + .color(self.color) + .disabled(self.disabled) + .when_some( + self.tab_index.filter(|_| !self.disabled), + |this, tab_index| this.tab_index(tab_index), + ) + .on_click({ + let on_click = self.on_click.clone(); + move |state, window, cx| { + (on_click)(state, window, cx); + } + }), ) .when(!self.disabled, |this| { this.on_click({ @@ -668,7 +737,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_unselected", "Enable notifications", - "Receive notifications when new messages arrive.", + Some("Receive notifications when new messages arrive.".into()), ToggleState::Unselected, |_, _, _| {}, ) @@ -679,7 +748,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_selected", "Enable notifications", - "Receive notifications when new messages arrive.", + Some("Receive notifications when new messages arrive.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -695,7 +764,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_default", "Default color", - "This uses the default switch color.", + Some("This uses the default switch color.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -706,7 +775,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_accent", "Accent color", - "This uses the accent color scheme.", + Some("This uses the accent color scheme.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -722,7 +791,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_disabled", "Disabled field", - "This field is disabled and cannot be toggled.", + Some("This field is disabled and cannot be toggled.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -730,6 +799,49 @@ impl Component for SwitchField { .into_any_element(), )], ), + example_group_with_title( + "No Description", + vec![single_example( + "No Description", + SwitchField::new( + "switch_field_disabled", + "Disabled field", + None, + ToggleState::Selected, + |_, _, _| {}, + ) + .into_any_element(), + )], + ), + example_group_with_title( + "With Tooltip", + vec![ + single_example( + "Tooltip with Description", + SwitchField::new( + "switch_field_tooltip_with_desc", + "Nice Feature", + Some("Enable advanced configuration options.".into()), + ToggleState::Unselected, + |_, _, _| {}, + ) + .tooltip(Tooltip::text("This is content for this tooltip!")) + .into_any_element(), + ), + single_example( + "Tooltip without Description", + SwitchField::new( + "switch_field_tooltip_no_desc", + "Nice Feature", + None, + ToggleState::Selected, + |_, _, _| {}, + ) + .tooltip(Tooltip::text("This is content for this tooltip!")) + .into_any_element(), + ), + ], + ), ]) .into_any_element(), ) diff --git a/crates/ui/src/styles/animation.rs b/crates/ui/src/styles/animation.rs index 50c4e0eb0d..0649bee1f8 100644 --- a/crates/ui/src/styles/animation.rs +++ b/crates/ui/src/styles/animation.rs @@ -109,7 +109,7 @@ impl Component for Animation { fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { let container_size = 128.0; let element_size = 32.0; - let left_offset = element_size - container_size / 2.0; + let offset = container_size / 2.0 - element_size / 2.0; Some( v_flex() .gap_6() @@ -129,7 +129,7 @@ impl Component for Animation { .id("animate-in-from-bottom") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::red()) .animate_in(AnimationDirection::FromBottom, false), @@ -148,7 +148,7 @@ impl Component for Animation { .id("animate-in-from-top") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::blue()) .animate_in(AnimationDirection::FromTop, false), @@ -167,7 +167,7 @@ impl Component for Animation { .id("animate-in-from-left") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::green()) .animate_in(AnimationDirection::FromLeft, false), @@ -186,7 +186,7 @@ impl Component for Animation { .id("animate-in-from-right") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::yellow()) .animate_in(AnimationDirection::FromRight, false), @@ -211,7 +211,7 @@ impl Component for Animation { .id("fade-animate-in-from-bottom") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::red()) .animate_in(AnimationDirection::FromBottom, true), @@ -230,7 +230,7 @@ impl Component for Animation { .id("fade-animate-in-from-top") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::blue()) .animate_in(AnimationDirection::FromTop, true), @@ -249,7 +249,7 @@ impl Component for Animation { .id("fade-animate-in-from-left") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::green()) .animate_in(AnimationDirection::FromLeft, true), @@ -268,7 +268,7 @@ impl Component for Animation { .id("fade-animate-in-from-right") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::yellow()) .animate_in(AnimationDirection::FromRight, true), diff --git a/crates/ui_prompt/src/ui_prompt.rs b/crates/ui_prompt/src/ui_prompt.rs index 2b6a030f26..fe6dc5b3f4 100644 --- a/crates/ui_prompt/src/ui_prompt.rs +++ b/crates/ui_prompt/src/ui_prompt.rs @@ -43,7 +43,7 @@ fn zed_prompt_renderer( let renderer = cx.new({ |cx| ZedPromptRenderer { _level: level, - message: message.to_string(), + message: cx.new(|cx| Markdown::new(SharedString::new(message), None, None, cx)), actions: actions.iter().map(|a| a.label().to_string()).collect(), focus: cx.focus_handle(), active_action_id: 0, @@ -58,7 +58,7 @@ fn zed_prompt_renderer( pub struct ZedPromptRenderer { _level: PromptLevel, - message: String, + message: Entity<Markdown>, actions: Vec<String>, focus: FocusHandle, active_action_id: usize, @@ -114,7 +114,7 @@ impl ZedPromptRenderer { impl Render for ZedPromptRenderer { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); - let font_family = settings.ui_font.family.clone(); + let font_size = settings.ui_font_size(cx).into(); let prompt = v_flex() .key_context("Prompt") .cursor_default() @@ -130,24 +130,38 @@ impl Render for ZedPromptRenderer { .overflow_hidden() .p_4() .gap_4() - .font_family(font_family) + .font_family(settings.ui_font.family.clone()) .child( div() .w_full() - .font_weight(FontWeight::BOLD) - .child(self.message.clone()) - .text_color(ui::Color::Default.color(cx)), + .child(MarkdownElement::new(self.message.clone(), { + let mut base_text_style = window.text_style(); + base_text_style.refine(&TextStyleRefinement { + font_family: Some(settings.ui_font.family.clone()), + font_size: Some(font_size), + font_weight: Some(FontWeight::BOLD), + color: Some(ui::Color::Default.color(cx)), + ..Default::default() + }); + MarkdownStyle { + base_text_style, + selection_background_color: cx + .theme() + .colors() + .element_selection_background, + ..Default::default() + } + })), ) .children(self.detail.clone().map(|detail| { div() .w_full() .text_xs() .child(MarkdownElement::new(detail, { - let settings = ThemeSettings::get_global(cx); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(settings.ui_font.family.clone()), - font_size: Some(settings.ui_font_size(cx).into()), + font_size: Some(font_size), color: Some(ui::Color::Muted.color(cx)), ..Default::default() }); @@ -176,24 +190,28 @@ impl Render for ZedPromptRenderer { }), )); - div().size_full().occlude().child( - div() - .size_full() - .absolute() - .top_0() - .left_0() - .flex() - .flex_col() - .justify_around() - .child( - div() - .w_full() - .flex() - .flex_row() - .justify_around() - .child(prompt), - ), - ) + div() + .size_full() + .occlude() + .bg(gpui::black().opacity(0.2)) + .child( + div() + .size_full() + .absolute() + .top_0() + .left_0() + .flex() + .flex_col() + .justify_around() + .child( + div() + .w_full() + .flex() + .flex_row() + .justify_around() + .child(prompt), + ), + ) } } diff --git a/crates/util/src/shell_env.rs b/crates/util/src/shell_env.rs index d737999e45..2b1063316f 100644 --- a/crates/util/src/shell_env.rs +++ b/crates/util/src/shell_env.rs @@ -30,6 +30,7 @@ pub fn capture(directory: &std::path::Path) -> Result<collections::HashMap<Strin command.stdout(Stdio::piped()); command.stderr(Stdio::piped()); + let mut command_prefix = String::new(); match shell_name { Some("tcsh" | "csh") => { // For csh/tcsh, login shell requires passing `-` as 0th argument (instead of `-l`) @@ -40,13 +41,20 @@ pub fn capture(directory: &std::path::Path) -> Result<collections::HashMap<Strin command_string.push_str("emit fish_prompt;"); command.arg("-l"); } + Some("nu") => { + // nu needs special handling for -- options. + command_prefix = String::from("^"); + } _ => { command.arg("-l"); } } // cd into the directory, triggering directory specific side-effects (asdf, direnv, etc) command_string.push_str(&format!("cd '{}';", directory.display())); - command_string.push_str(&format!("{} --printenv {}", zed_path, redir)); + command_string.push_str(&format!( + "{}{} --printenv {}", + command_prefix, zed_path, redir + )); command.args(["-i", "-c", &command_string]); super::set_pre_exec_to_start_new_session(&mut command); diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index a50b238cc5..0e487f4410 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -987,7 +987,7 @@ impl Motion { SelectionGoal::None, ), NextWordEnd { ignore_punctuation } => ( - next_word_end(map, point, *ignore_punctuation, times, true), + next_word_end(map, point, *ignore_punctuation, times, true, true), SelectionGoal::None, ), PreviousWordStart { ignore_punctuation } => ( @@ -1723,14 +1723,19 @@ pub(crate) fn next_word_end( ignore_punctuation: bool, times: usize, allow_cross_newline: bool, + always_advance: bool, ) -> DisplayPoint { let classifier = map .buffer_snapshot .char_classifier_at(point.to_point(map)) .ignore_punctuation(ignore_punctuation); for _ in 0..times { - let new_point = next_char(map, point, allow_cross_newline); let mut need_next_char = false; + let new_point = if always_advance { + next_char(map, point, allow_cross_newline) + } else { + point + }; let new_point = movement::find_boundary_exclusive( map, new_point, @@ -3803,7 +3808,7 @@ mod test { cx.update_editor(|editor, _window, cx| { let range = editor.selections.newest_anchor().range(); let inlay_text = " field: int,\n field2: string\n field3: float"; - let inlay = Inlay::inline_completion(1, range.start, inlay_text); + let inlay = Inlay::edit_prediction(1, range.start, inlay_text); editor.splice_inlays(&[], vec![inlay], cx); }); @@ -3835,7 +3840,7 @@ mod test { let end_of_line = snapshot.anchor_after(Point::new(0, snapshot.line_len(MultiBufferRow(0)))); let inlay_text = " hint"; - let inlay = Inlay::inline_completion(1, end_of_line, inlay_text); + let inlay = Inlay::edit_prediction(1, end_of_line, inlay_text); editor.splice_inlays(&[], vec![inlay], cx); }); cx.simulate_keystrokes("$"); diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index 9485f17477..c1bc7a70ae 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -51,6 +51,7 @@ impl Vim { ignore_punctuation, &text_layout_details, motion == Motion::NextSubwordStart { ignore_punctuation }, + !matches!(motion, Motion::NextWordStart { .. }), ) } _ => { @@ -89,7 +90,7 @@ impl Vim { if let Some(kind) = motion_kind { vim.copy_selections_content(editor, kind, window, cx); editor.insert("", window, cx); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); } }); }); @@ -122,7 +123,7 @@ impl Vim { if objects_found { vim.copy_selections_content(editor, MotionKind::Exclusive, window, cx); editor.insert("", window, cx); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); } }); }); @@ -148,6 +149,7 @@ fn expand_changed_word_selection( ignore_punctuation: bool, text_layout_details: &TextLayoutDetails, use_subword: bool, + always_advance: bool, ) -> Option<MotionKind> { let is_in_word = || { let classifier = map @@ -173,8 +175,14 @@ fn expand_changed_word_selection( selection.end = motion::next_subword_end(map, selection.end, ignore_punctuation, 1, false); } else { - selection.end = - motion::next_word_end(map, selection.end, ignore_punctuation, 1, false); + selection.end = motion::next_word_end( + map, + selection.end, + ignore_punctuation, + 1, + false, + always_advance, + ); } selection.end = motion::next_char(map, selection.end, false); } @@ -271,6 +279,10 @@ mod test { cx.simulate("c shift-w", "Test teˇst-test test") .await .assert_matches(); + + // on last character of word, `cw` doesn't eat subsequent punctuation + // see https://github.com/zed-industries/zed/issues/35269 + cx.simulate("c w", "tesˇt-test").await.assert_matches(); } #[gpui::test] diff --git a/crates/vim/src/normal/delete.rs b/crates/vim/src/normal/delete.rs index ccbb3dd0fd..2cf40292cf 100644 --- a/crates/vim/src/normal/delete.rs +++ b/crates/vim/src/normal/delete.rs @@ -82,7 +82,7 @@ impl Vim { selection.collapse_to(cursor, selection.goal) }); }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); }); } @@ -169,7 +169,7 @@ impl Vim { selection.collapse_to(cursor, selection.goal) }); }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); }); } diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index c747c30462..72edbe77ed 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -747,7 +747,7 @@ impl Vim { Vim::action( editor, cx, - |vim, action: &editor::AcceptEditPrediction, window, cx| { + |vim, action: &editor::actions::AcceptEditPrediction, window, cx| { vim.update_editor(window, cx, |_, editor, window, cx| { editor.accept_edit_prediction(action, window, cx); }); @@ -1741,11 +1741,11 @@ impl Vim { editor.set_autoindent(vim.should_autoindent()); editor.selections.line_mode = matches!(vim.mode, Mode::VisualLine); - let hide_inline_completions = match vim.mode { + let hide_edit_predictions = match vim.mode { Mode::Insert | Mode::Replace => false, _ => true, }; - editor.set_inline_completions_hidden_for_vim_mode(hide_inline_completions, window, cx); + editor.set_edit_predictions_hidden_for_vim_mode(hide_edit_predictions, window, cx); }); cx.notify() } diff --git a/crates/vim/test_data/test_change_w.json b/crates/vim/test_data/test_change_w.json index 27be543532..149dac8420 100644 --- a/crates/vim/test_data/test_change_w.json +++ b/crates/vim/test_data/test_change_w.json @@ -30,3 +30,7 @@ {"Key":"c"} {"Key":"shift-w"} {"Get":{"state":"Test teˇ test","mode":"Insert"}} +{"Put":{"state":"tesˇt-test"}} +{"Key":"c"} +{"Key":"w"} +{"Get":{"state":"tesˇ-test","mode":"Insert"}} diff --git a/crates/web_search/Cargo.toml b/crates/web_search/Cargo.toml index e5b8ca63b2..4ba46faec4 100644 --- a/crates/web_search/Cargo.toml +++ b/crates/web_search/Cargo.toml @@ -13,8 +13,8 @@ path = "src/web_search.rs" [dependencies] anyhow.workspace = true +cloud_llm_client.workspace = true collections.workspace = true gpui.workspace = true serde.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/web_search/src/web_search.rs b/crates/web_search/src/web_search.rs index a131b0de71..8578cfe4aa 100644 --- a/crates/web_search/src/web_search.rs +++ b/crates/web_search/src/web_search.rs @@ -1,8 +1,9 @@ +use std::sync::Arc; + use anyhow::Result; +use cloud_llm_client::WebSearchResponse; use collections::HashMap; use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task}; -use std::sync::Arc; -use zed_llm_client::WebSearchResponse; pub fn init(cx: &mut App) { let registry = cx.new(|_cx| WebSearchRegistry::default()); diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index 2e052796c4..f7a248d106 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,6 +14,7 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true +cloud_llm_client.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true @@ -22,4 +23,3 @@ serde.workspace = true serde_json.workspace = true web_search.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index adf79b0ff6..52ee0da0d4 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; +use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Subscription, Task}; use http_client::{HttpClient, Method}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use web_search::{WebSearchProvider, WebSearchProviderId}; -use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; pub struct CloudWebSearchProvider { state: Entity<State>, diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml index 769dd8d6aa..acb3fe0f84 100644 --- a/crates/welcome/Cargo.toml +++ b/crates/welcome/Cargo.toml @@ -29,7 +29,6 @@ project.workspace = true serde.workspace = true settings.workspace = true telemetry.workspace = true -theme.workspace = true ui.workspace = true util.workspace = true vim_mode_setting.workspace = true diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index 49bf2031ab..b0a1c316f4 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -1,10 +1,11 @@ -use client::{DisableAiSettings, TelemetrySettings, telemetry::Telemetry}; +use client::{TelemetrySettings, telemetry::Telemetry}; use db::kvp::KEY_VALUE_STORE; use gpui::{ Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, ParentElement, Render, Styled, Subscription, Task, WeakEntity, Window, actions, svg, }; use language::language_settings::{EditPredictionProvider, all_language_settings}; +use project::DisableAiSettings; use settings::{Settings, SettingsStore}; use std::sync::Arc; use ui::{CheckboxWithLabel, ElevationIndex, Tooltip, prelude::*}; @@ -21,7 +22,6 @@ pub use multibuffer_hint::*; mod base_keymap_picker; mod multibuffer_hint; -mod welcome_ui; actions!( welcome, diff --git a/crates/welcome/src/welcome_ui.rs b/crates/welcome/src/welcome_ui.rs deleted file mode 100644 index 622b6f448d..0000000000 --- a/crates/welcome/src/welcome_ui.rs +++ /dev/null @@ -1 +0,0 @@ -mod theme_preview; diff --git a/crates/welcome/src/welcome_ui/theme_preview.rs b/crates/welcome/src/welcome_ui/theme_preview.rs deleted file mode 100644 index b3a80c74c3..0000000000 --- a/crates/welcome/src/welcome_ui/theme_preview.rs +++ /dev/null @@ -1,280 +0,0 @@ -#![allow(unused, dead_code)] -use gpui::{Hsla, Length}; -use std::sync::Arc; -use theme::{Theme, ThemeRegistry}; -use ui::{ - IntoElement, RenderOnce, component_prelude::Documented, prelude::*, utils::inner_corner_radius, -}; - -/// Shows a preview of a theme as an abstract illustration -/// of a thumbnail-sized editor. -#[derive(IntoElement, RegisterComponent, Documented)] -pub struct ThemePreviewTile { - theme: Arc<Theme>, - selected: bool, - seed: f32, -} - -impl ThemePreviewTile { - pub fn new(theme: Arc<Theme>, selected: bool, seed: f32) -> Self { - Self { - theme, - selected, - seed, - } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } -} - -impl RenderOnce for ThemePreviewTile { - fn render(self, _window: &mut ui::Window, _cx: &mut ui::App) -> impl IntoElement { - let color = self.theme.colors(); - - let root_radius = px(8.0); - let root_border = px(2.0); - let root_padding = px(2.0); - let child_border = px(1.0); - let inner_radius = - inner_corner_radius(root_radius, root_border, root_padding, child_border); - - let item_skeleton = |w: Length, h: Pixels, bg: Hsla| div().w(w).h(h).rounded_full().bg(bg); - - let skeleton_height = px(4.); - - let sidebar_seeded_width = |seed: f32, index: usize| { - let value = (seed * 1000.0 + index as f32 * 10.0).sin() * 0.5 + 0.5; - 0.5 + value * 0.45 - }; - - let sidebar_skeleton_items = 8; - - let sidebar_skeleton = (0..sidebar_skeleton_items) - .map(|i| { - let width = sidebar_seeded_width(self.seed, i); - item_skeleton( - relative(width).into(), - skeleton_height, - color.text.alpha(0.45), - ) - }) - .collect::<Vec<_>>(); - - let sidebar = div() - .h_full() - .w(relative(0.25)) - .border_r(px(1.)) - .border_color(color.border_transparent) - .bg(color.panel_background) - .child( - div() - .p_2() - .flex() - .flex_col() - .size_full() - .gap(px(4.)) - .children(sidebar_skeleton), - ); - - let pseudo_code_skeleton = |theme: Arc<Theme>, seed: f32| -> AnyElement { - let colors = theme.colors(); - let syntax = theme.syntax(); - - let keyword_color = syntax.get("keyword").color; - let function_color = syntax.get("function").color; - let string_color = syntax.get("string").color; - let comment_color = syntax.get("comment").color; - let variable_color = syntax.get("variable").color; - let type_color = syntax.get("type").color; - let punctuation_color = syntax.get("punctuation").color; - - let syntax_colors = [ - keyword_color, - function_color, - string_color, - variable_color, - type_color, - punctuation_color, - comment_color, - ]; - - let line_width = |line_idx: usize, block_idx: usize| -> f32 { - let val = (seed * 100.0 + line_idx as f32 * 20.0 + block_idx as f32 * 5.0).sin() - * 0.5 - + 0.5; - 0.05 + val * 0.2 - }; - - let indentation = |line_idx: usize| -> f32 { - let step = line_idx % 6; - if step < 3 { - step as f32 * 0.1 - } else { - (5 - step) as f32 * 0.1 - } - }; - - let pick_color = |line_idx: usize, block_idx: usize| -> Hsla { - let idx = ((seed * 10.0 + line_idx as f32 * 7.0 + block_idx as f32 * 3.0).sin() - * 3.5) - .abs() as usize - % syntax_colors.len(); - syntax_colors[idx].unwrap_or(colors.text) - }; - - let line_count = 13; - - let lines = (0..line_count) - .map(|line_idx| { - let block_count = (((seed * 30.0 + line_idx as f32 * 12.0).sin() * 0.5 + 0.5) - * 3.0) - .round() as usize - + 2; - - let indent = indentation(line_idx); - - let blocks = (0..block_count) - .map(|block_idx| { - let width = line_width(line_idx, block_idx); - let color = pick_color(line_idx, block_idx); - item_skeleton(relative(width).into(), skeleton_height, color) - }) - .collect::<Vec<_>>(); - - h_flex().gap(px(2.)).ml(relative(indent)).children(blocks) - }) - .collect::<Vec<_>>(); - - v_flex() - .size_full() - .p_1() - .gap(px(6.)) - .children(lines) - .into_any_element() - }; - - let pane = div() - .h_full() - .flex_grow() - .flex() - .flex_col() - // .child( - // div() - // .w_full() - // .border_color(color.border) - // .border_b(px(1.)) - // .h(relative(0.1)) - // .bg(color.tab_bar_background), - // ) - .child( - div() - .size_full() - .overflow_hidden() - .bg(color.editor_background) - .p_2() - .child(pseudo_code_skeleton(self.theme.clone(), self.seed)), - ); - - let content = div().size_full().flex().child(sidebar).child(pane); - - div() - .size_full() - .rounded(root_radius) - .p(root_padding) - .border(root_border) - .border_color(color.border_transparent) - .when(self.selected, |this| { - this.border_color(color.border_selected) - }) - .child( - div() - .size_full() - .rounded(inner_radius) - .border(child_border) - .border_color(color.border) - .bg(color.background) - .child(content), - ) - } -} - -impl Component for ThemePreviewTile { - fn description() -> Option<&'static str> { - Some(Self::DOCS) - } - - fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> { - let theme_registry = ThemeRegistry::global(cx); - - let one_dark = theme_registry.get("One Dark"); - let one_light = theme_registry.get("One Light"); - let gruvbox_dark = theme_registry.get("Gruvbox Dark"); - let gruvbox_light = theme_registry.get("Gruvbox Light"); - - let themes_to_preview = vec![ - one_dark.clone().ok(), - one_light.clone().ok(), - gruvbox_dark.clone().ok(), - gruvbox_light.clone().ok(), - ] - .into_iter() - .flatten() - .collect::<Vec<_>>(); - - Some( - v_flex() - .gap_6() - .p_4() - .children({ - if let Some(one_dark) = one_dark.ok() { - vec![example_group(vec![ - single_example( - "Default", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark.clone(), false, 0.42)) - .into_any_element(), - ), - single_example( - "Selected", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark, true, 0.42)) - .into_any_element(), - ), - ])] - } else { - vec![] - } - }) - .child( - example_group(vec![single_example( - "Default Themes", - h_flex() - .gap_4() - .children( - themes_to_preview - .iter() - .enumerate() - .map(|(i, theme)| { - div().w(px(200.)).h(px(140.)).child(ThemePreviewTile::new( - theme.clone(), - false, - 0.42, - )) - }) - .collect::<Vec<_>>(), - ) - .into_any_element(), - )]) - .grow(), - ) - .into_any_element(), - ) - } -} diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index c7a2562a1b..2062255f4b 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -1664,10 +1664,33 @@ impl Pane { } if should_save { - if !Self::save_item(project.clone(), &pane, &*item_to_close, save_intent, cx) - .await? + match Self::save_item(project.clone(), &pane, &*item_to_close, save_intent, cx) + .await { - break; + Ok(success) => { + if !success { + break; + } + } + Err(err) => { + let answer = pane.update_in(cx, |_, window, cx| { + let detail = Self::file_names_for_prompt( + &mut [&item_to_close].into_iter(), + cx, + ); + window.prompt( + PromptLevel::Warning, + &format!("Unable to save file: {}", &err), + Some(&detail), + &["Close Without Saving", "Cancel"], + cx, + ) + })?; + match answer.await { + Ok(0) => {} + Ok(1..) | Err(_) => break, + } + } } } @@ -2832,7 +2855,7 @@ impl Pane { }) .collect::<Vec<_>>(); let tab_count = tab_items.len(); - if self.pinned_tab_count > tab_count { + if self.is_tab_pinned(tab_count) { log::warn!( "Pinned tab count ({}) exceeds actual tab count ({}). \ This should not happen. If possible, add reproduction steps, \ @@ -3030,7 +3053,7 @@ impl Pane { || cfg!(not(target_os = "macos")) && window.modifiers().control; let from_pane = dragged_tab.pane.clone(); - let from_ix = dragged_tab.ix; + self.workspace .update(cx, |_, cx| { cx.defer_in(window, move |workspace, window, cx| { @@ -3062,9 +3085,13 @@ impl Pane { } to_pane.update(cx, |this, _| { if to_pane == from_pane { - let moved_right = ix > from_ix; - let ix = if moved_right { ix - 1 } else { ix }; - let is_pinned_in_to_pane = this.is_tab_pinned(ix); + let actual_ix = this + .items + .iter() + .position(|item| item.item_id() == item_id) + .unwrap_or(0); + + let is_pinned_in_to_pane = this.is_tab_pinned(actual_ix); if !was_pinned_in_from_pane && is_pinned_in_to_pane { this.pinned_tab_count += 1; @@ -4950,6 +4977,43 @@ mod tests { assert_item_labels(&pane_a, ["B!", "A*!"], cx); } + #[gpui::test] + async fn test_dragging_pinned_tab_onto_unpinned_tab_reduces_unpinned_tab_count( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + + let project = Project::test(fs, None, cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let pane_a = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone()); + + // Add A, B to pane A and pin A + let item_a = add_labeled_item(&pane_a, "A", false, cx); + add_labeled_item(&pane_a, "B", false, cx); + pane_a.update_in(cx, |pane, window, cx| { + let ix = pane.index_for_item_id(item_a.item_id()).unwrap(); + pane.pin_tab_at(ix, window, cx); + }); + assert_item_labels(&pane_a, ["A!", "B*"], cx); + + // Drag pinned A on top of B in the same pane, which changes tab order to B, A + pane_a.update_in(cx, |pane, window, cx| { + let dragged_tab = DraggedTab { + pane: pane_a.clone(), + item: item_a.boxed_clone(), + ix: 0, + detail: 0, + is_active: true, + }; + pane.handle_tab_drop(&dragged_tab, 1, window, cx); + }); + + // Neither are pinned + assert_item_labels(&pane_a, ["B", "A*"], cx); + } + #[gpui::test] async fn test_drag_pinned_tab_beyond_unpinned_tab_in_same_pane_becomes_unpinned( cx: &mut TestAppContext, diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 3f8b098203..6fa5c969e7 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -939,6 +939,26 @@ impl WorkspaceDb { } } + query! { + pub async fn update_ssh_project_paths_query(ssh_project_id: u64, paths: String) -> Result<Option<SerializedSshProject>> { + UPDATE ssh_projects + SET paths = ?2 + WHERE id = ?1 + RETURNING id, host, port, paths, user + } + } + + pub(crate) async fn update_ssh_project_paths( + &self, + ssh_project_id: SshProjectId, + new_paths: Vec<String>, + ) -> Result<SerializedSshProject> { + let paths = serde_json::to_string(&new_paths)?; + self.update_ssh_project_paths_query(ssh_project_id.0, paths) + .await? + .context("failed to update ssh project paths") + } + query! { pub async fn next_id() -> Result<WorkspaceId> { INSERT INTO workspaces DEFAULT VALUES RETURNING workspace_id @@ -2624,4 +2644,56 @@ mod tests { assert_eq!(workspace.center_group, new_workspace.center_group); } + + #[gpui::test] + async fn test_update_ssh_project_paths() { + zlog::init_test(); + + let db = WorkspaceDb::open_test_db("test_update_ssh_project_paths").await; + + let (host, port, initial_paths, user) = ( + "example.com".to_string(), + Some(22_u16), + vec!["/home/user".to_string(), "/etc/nginx".to_string()], + Some("user".to_string()), + ); + + let project = db + .get_or_create_ssh_project(host.clone(), port, initial_paths.clone(), user.clone()) + .await + .unwrap(); + + assert_eq!(project.host, host); + assert_eq!(project.paths, initial_paths); + assert_eq!(project.user, user); + + let new_paths = vec![ + "/home/user".to_string(), + "/etc/nginx".to_string(), + "/var/log".to_string(), + "/opt/app".to_string(), + ]; + + let updated_project = db + .update_ssh_project_paths(project.id, new_paths.clone()) + .await + .unwrap(); + + assert_eq!(updated_project.id, project.id); + assert_eq!(updated_project.paths, new_paths); + + let retrieved_project = db + .get_ssh_project( + host.clone(), + port, + serde_json::to_string(&new_paths).unwrap(), + user.clone(), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(retrieved_project.id, project.id); + assert_eq!(retrieved_project.paths, new_paths); + } } diff --git a/crates/workspace/src/tasks.rs b/crates/workspace/src/tasks.rs index 26edbd8d03..32d066c7eb 100644 --- a/crates/workspace/src/tasks.rs +++ b/crates/workspace/src/tasks.rs @@ -73,7 +73,7 @@ impl Workspace { if let Some(terminal_provider) = self.terminal_provider.as_ref() { let task_status = terminal_provider.spawn(spawn_in_terminal, window, cx); - cx.background_spawn(async move { + let task = cx.background_spawn(async move { match task_status.await { Some(Ok(status)) => { if status.success() { @@ -82,11 +82,11 @@ impl Workspace { log::debug!("Task spawn failed, code: {:?}", status.code()); } } - Some(Err(e)) => log::error!("Task spawn failed: {e}"), + Some(Err(e)) => log::error!("Task spawn failed: {e:#}"), None => log::debug!("Task spawn got cancelled"), } - }) - .detach(); + }); + self.scheduled_tasks.push(task); } } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 0ee8177dd8..63953ff802 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -48,7 +48,10 @@ pub use item::{ ProjectItem, SerializableItem, SerializableItemHandle, WeakItemHandle, }; use itertools::Itertools; -use language::{Buffer, LanguageRegistry, Rope}; +use language::{ + Buffer, LanguageRegistry, Rope, + language_settings::{AllLanguageSettings, all_language_settings}, +}; pub use modal_layer::*; use node_runtime::NodeRuntime; use notifications::{ @@ -74,7 +77,7 @@ use remote::{SshClientDelegate, SshConnectionOptions, ssh_session::ConnectionIde use schemars::JsonSchema; use serde::Deserialize; use session::AppSession; -use settings::Settings; +use settings::{Settings, update_settings_file}; use shared_screen::SharedScreen; use sqlez::{ bindable::{Bind, Column, StaticColumnCount}, @@ -233,6 +236,8 @@ actions!( ToggleBottomDock, /// Toggles centered layout mode. ToggleCenteredLayout, + /// Toggles edit prediction feature globally for all files. + ToggleEditPrediction, /// Toggles the left dock. ToggleLeftDock, /// Toggles the right dock. @@ -1065,7 +1070,6 @@ pub struct Workspace { center: PaneGroup, left_dock: Entity<Dock>, bottom_dock: Entity<Dock>, - bottom_dock_layout: BottomDockLayout, right_dock: Entity<Dock>, panes: Vec<Entity<Pane>>, panes_by_item: HashMap<EntityId, WeakEntity<Pane>>, @@ -1091,7 +1095,8 @@ pub struct Workspace { _subscriptions: Vec<Subscription>, _apply_leader_updates: Task<Result<()>>, _observe_current_user: Task<Result<()>>, - _schedule_serialize: Option<Task<()>>, + _schedule_serialize_workspace: Option<Task<()>>, + _schedule_serialize_ssh_paths: Option<Task<()>>, pane_history_timestamp: Arc<AtomicUsize>, bounds: Bounds<Pixels>, pub centered_layout: bool, @@ -1104,6 +1109,7 @@ pub struct Workspace { serialized_ssh_project: Option<SerializedSshProject>, _items_serializer: Task<Result<()>>, session_id: Option<String>, + scheduled_tasks: Vec<Task<()>>, } impl EventEmitter<Event> for Workspace {} @@ -1149,6 +1155,8 @@ impl Workspace { project::Event::WorktreeRemoved(_) | project::Event::WorktreeAdded(_) => { this.update_window_title(window, cx); + this.update_ssh_paths(cx); + this.serialize_ssh_paths(window, cx); this.serialize_workspace(window, cx); // This event could be triggered by `AddFolderToProject` or `RemoveFromProject`. this.update_history(cx); @@ -1306,7 +1314,6 @@ impl Workspace { ) .detach(); - let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; let left_dock = Dock::new(DockPosition::Left, modal_layer.clone(), window, cx); let bottom_dock = Dock::new(DockPosition::Bottom, modal_layer.clone(), window, cx); let right_dock = Dock::new(DockPosition::Right, modal_layer.clone(), window, cx); @@ -1405,7 +1412,6 @@ impl Workspace { suppressed_notifications: HashSet::default(), left_dock, bottom_dock, - bottom_dock_layout, right_dock, project: project.clone(), follower_states: Default::default(), @@ -1418,7 +1424,8 @@ impl Workspace { app_state, _observe_current_user, _apply_leader_updates, - _schedule_serialize: None, + _schedule_serialize_workspace: None, + _schedule_serialize_ssh_paths: None, leader_updates_tx, _subscriptions: subscriptions, pane_history_timestamp, @@ -1435,6 +1442,7 @@ impl Workspace { _items_serializer, session_id: Some(session_id), serialized_ssh_project: None, + scheduled_tasks: Vec::new(), } } @@ -1631,10 +1639,6 @@ impl Workspace { &self.bottom_dock } - pub fn bottom_dock_layout(&self) -> BottomDockLayout { - self.bottom_dock_layout - } - pub fn set_bottom_dock_layout( &mut self, layout: BottomDockLayout, @@ -1646,7 +1650,6 @@ impl Workspace { content.bottom_dock_layout = Some(layout); }); - self.bottom_dock_layout = layout; cx.notify(); self.serialize_workspace(window, cx); } @@ -5079,6 +5082,46 @@ impl Workspace { } } + fn update_ssh_paths(&mut self, cx: &App) { + let project = self.project().read(cx); + if !project.is_local() { + let paths: Vec<String> = project + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).abs_path().to_string_lossy().to_string()) + .collect(); + if let Some(ssh_project) = &mut self.serialized_ssh_project { + ssh_project.paths = paths; + } + } + } + + fn serialize_ssh_paths(&mut self, window: &mut Window, cx: &mut Context<Workspace>) { + if self._schedule_serialize_ssh_paths.is_none() { + self._schedule_serialize_ssh_paths = + Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update_in(cx, |this, window, cx| { + let task = if let Some(ssh_project) = &this.serialized_ssh_project { + let ssh_project_id = ssh_project.id; + let ssh_project_paths = ssh_project.paths.clone(); + window.spawn(cx, async move |_| { + persistence::DB + .update_ssh_project_paths(ssh_project_id, ssh_project_paths) + .await + }) + } else { + Task::ready(Err(anyhow::anyhow!("No SSH project to serialize"))) + }; + task.detach(); + this._schedule_serialize_ssh_paths.take(); + }) + .log_err(); + })); + } + } + fn remove_panes(&mut self, member: Member, window: &mut Window, cx: &mut Context<Workspace>) { match member { Member::Axis(PaneAxis { members, .. }) => { @@ -5122,17 +5165,18 @@ impl Workspace { } fn serialize_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) { - if self._schedule_serialize.is_none() { - self._schedule_serialize = Some(cx.spawn_in(window, async move |this, cx| { - cx.background_executor() - .timer(Duration::from_millis(100)) - .await; - this.update_in(cx, |this, window, cx| { - this.serialize_workspace_internal(window, cx).detach(); - this._schedule_serialize.take(); - }) - .log_err(); - })); + if self._schedule_serialize_workspace.is_none() { + self._schedule_serialize_workspace = + Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update_in(cx, |this, window, cx| { + this.serialize_workspace_internal(window, cx).detach(); + this._schedule_serialize_workspace.take(); + }) + .log_err(); + })); } } @@ -5507,6 +5551,7 @@ impl Workspace { .on_action(cx.listener(Self::activate_pane_at_index)) .on_action(cx.listener(Self::move_item_to_pane_at_index)) .on_action(cx.listener(Self::move_focused_panel_to_next_position)) + .on_action(cx.listener(Self::toggle_edit_predictions_all_files)) .on_action(cx.listener(|workspace, _: &Unfollow, window, cx| { let pane = workspace.active_pane().clone(); workspace.unfollow_in_pane(&pane, window, cx); @@ -5695,7 +5740,6 @@ impl Workspace { let client = project.read(cx).client(); let user_store = project.read(cx).user_store(); - let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx)); window.activate_window(); @@ -5939,6 +5983,19 @@ impl Workspace { } }); } + + fn toggle_edit_predictions_all_files( + &mut self, + _: &ToggleEditPrediction, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + let fs = self.project().read(cx).fs().clone(); + let show_edit_predictions = all_language_settings(None, cx).show_edit_predictions(None, cx); + update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { + file.defaults.show_edit_predictions = Some(!show_edit_predictions) + }); + } } fn leader_border_for_pane( @@ -6244,6 +6301,7 @@ impl Render for Workspace { .iter() .map(|(_, notification)| notification.entity_id()) .collect::<Vec<_>>(); + let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; client_side_decorations( self.actions(div(), window, cx) @@ -6367,7 +6425,7 @@ impl Render for Workspace { )) }) .child({ - match self.bottom_dock_layout { + match bottom_dock_layout { BottomDockLayout::Full => div() .flex() .flex_col() @@ -6899,10 +6957,13 @@ async fn join_channel_internal( match status { Status::Connecting | Status::Authenticating + | Status::Authenticated | Status::Reconnecting | Status::Reauthenticating => continue, Status::Connected { .. } => break 'outer, - Status::SignedOut => return Err(ErrorCode::SignedOut.into()), + Status::SignedOut | Status::AuthenticationError => { + return Err(ErrorCode::SignedOut.into()); + } Status::UpgradeRequired => return Err(ErrorCode::UpgradeRequired.into()), Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { return Err(ErrorCode::Disconnected.into()); diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index e6949f62df..b5a0f71e81 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -62,7 +62,7 @@ use std::{ }, time::{Duration, Instant}, }; -use sum_tree::{Bias, Edit, KeyedItem, SeekTarget, SumTree, Summary, TreeMap, TreeSet}; +use sum_tree::{Bias, Dimensions, Edit, KeyedItem, SeekTarget, SumTree, Summary, TreeMap, TreeSet}; use text::{LineEnding, Rope}; use util::{ ResultExt, @@ -3566,10 +3566,15 @@ impl<'a> sum_tree::Dimension<'a, PathSummary<GitSummary>> for GitSummary { } } -impl<'a> sum_tree::SeekTarget<'a, PathSummary<GitSummary>, (TraversalProgress<'a>, GitSummary)> +impl<'a> + sum_tree::SeekTarget<'a, PathSummary<GitSummary>, Dimensions<TraversalProgress<'a>, GitSummary>> for PathTarget<'_> { - fn cmp(&self, cursor_location: &(TraversalProgress<'a>, GitSummary), _: &()) -> Ordering { + fn cmp( + &self, + cursor_location: &Dimensions<TraversalProgress<'a>, GitSummary>, + _: &(), + ) -> Ordering { self.cmp_path(&cursor_location.0.max_path) } } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index a864ece683..5bd6d981fa 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." edition.workspace = true name = "zed" -version = "0.198.0" +version = "0.199.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team <hi@zed.dev>"] @@ -45,6 +45,7 @@ collections.workspace = true command_palette.workspace = true component.workspace = true copilot.workspace = true +crashes.workspace = true dap_adapters.workspace = true db.workspace = true debug_adapter_extension.workspace = true @@ -76,7 +77,7 @@ gpui_tokio.workspace = true http_client.workspace = true image_viewer.workspace = true indoc.workspace = true -inline_completion_button.workspace = true +edit_prediction_button.workspace = true inspector_ui.workspace = true install_cli.workspace = true jj_ui.workspace = true @@ -106,6 +107,7 @@ outline_panel.workspace = true parking_lot.workspace = true paths.workspace = true picker.workspace = true +settings_profile_selector.workspace = true profiling.workspace = true project.workspace = true project_panel.workspace = true @@ -116,6 +118,7 @@ recent_projects.workspace = true release_channel.workspace = true remote.workspace = true repl.workspace = true +reqwest.workspace = true reqwest_client.workspace = true rope.workspace = true search.workspace = true diff --git a/crates/zed/resources/app-icon-nightly.png b/crates/zed/resources/app-icon-nightly.png index 6c5241f207..776cd06b1b 100644 Binary files a/crates/zed/resources/app-icon-nightly.png and b/crates/zed/resources/app-icon-nightly.png differ diff --git a/crates/zed/resources/app-icon-nightly@2x.png b/crates/zed/resources/app-icon-nightly@2x.png index e31eeb74f2..6d781594ac 100644 Binary files a/crates/zed/resources/app-icon-nightly@2x.png and b/crates/zed/resources/app-icon-nightly@2x.png differ diff --git a/crates/zed/resources/flatpak/manifest-template.json b/crates/zed/resources/flatpak/manifest-template.json index 1560027e9f..0a14a1c2b0 100644 --- a/crates/zed/resources/flatpak/manifest-template.json +++ b/crates/zed/resources/flatpak/manifest-template.json @@ -38,7 +38,7 @@ }, "build-commands": [ "install -Dm644 $ICON_FILE.png /app/share/icons/hicolor/512x512/apps/$APP_ID.png", - "envsubst < zed.desktop.in > zed.desktop && install -Dm644 zed.desktop /app/share/applications/$APP_ID.desktop", + "envsubst < zed.desktop.in > zed.desktop && install -Dm755 zed.desktop /app/share/applications/$APP_ID.desktop", "envsubst < flatpak/zed.metainfo.xml.in > zed.metainfo.xml && install -Dm644 zed.metainfo.xml /app/share/metainfo/$APP_ID.metainfo.xml", "sed -i -e '/@release_info@/{r flatpak/release-info/$CHANNEL' -e 'd}' /app/share/metainfo/$APP_ID.metainfo.xml", "install -Dm755 bin/zed /app/bin/zed", diff --git a/crates/zed/resources/windows/zed.iss b/crates/zed/resources/windows/zed.iss index 9d104d1f15..2e76f35a0b 100644 --- a/crates/zed/resources/windows/zed.iss +++ b/crates/zed/resources/windows/zed.iss @@ -62,6 +62,7 @@ Source: "{#ResourcesDir}\Zed.exe"; DestDir: "{code:GetInstallDir}"; Flags: ignor Source: "{#ResourcesDir}\bin\*"; DestDir: "{code:GetInstallDir}\bin"; Flags: ignoreversion Source: "{#ResourcesDir}\tools\*"; DestDir: "{app}\tools"; Flags: ignoreversion Source: "{#ResourcesDir}\appx\*"; DestDir: "{app}\appx"; BeforeInstall: RemoveAppxPackage; AfterInstall: AddAppxPackage; Flags: ignoreversion; Check: IsWindows11OrLater +Source: "{#ResourcesDir}\amd_ags_x64.dll"; DestDir: "{app}"; Flags: ignoreversion [Icons] Name: "{group}\{#AppName}"; Filename: "{app}\{#AppExeName}.exe"; AppUserModelID: "{#AppUserId}" @@ -1245,16 +1246,6 @@ Root: HKCU; Subkey: "Software\Classes\zed\DefaultIcon"; ValueType: "string"; Val Root: HKCU; Subkey: "Software\Classes\zed\shell\open\command"; ValueType: "string"; ValueData: """{app}\Zed.exe"" ""%1""" [Code] -function InitializeSetup(): Boolean; -begin - Result := True; - - if not WizardSilent() and IsAdmin() then begin - MsgBox('This User Installer is not meant to be run as an Administrator.', mbError, MB_OK); - Result := False; - end; -end; - function WizardNotSilent(): Boolean; begin Result := not WizardSilent(); diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index d0b9c53397..e4a14b5d32 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -42,7 +42,7 @@ use theme::{ ActiveTheme, IconThemeNotFoundError, SystemAppearance, ThemeNotFoundError, ThemeRegistry, ThemeSettings, }; -use util::{ConnectionResult, ResultExt, TryFutureExt, maybe}; +use util::{ResultExt, TryFutureExt, maybe}; use uuid::Uuid; use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::{ @@ -51,9 +51,9 @@ use workspace::{ }; use zed::{ OpenListener, OpenRequest, RawOpenRequest, app_menus, build_window_options, - derive_paths_with_position, handle_cli_connection, handle_keymap_file_changes, - handle_settings_changed, handle_settings_file_changes, initialize_workspace, - inline_completion_registry, open_paths_with_positions, + derive_paths_with_position, edit_prediction_registry, handle_cli_connection, + handle_keymap_file_changes, handle_settings_changed, handle_settings_file_changes, + initialize_workspace, open_paths_with_positions, }; use crate::zed::OpenRequestKind; @@ -172,6 +172,12 @@ pub fn main() { let args = Args::parse(); + // `zed --crash-handler` Makes zed operate in minidump crash handler mode + if let Some(socket) = &args.crash_handler { + crashes::crash_server(socket.as_path()); + return; + } + // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass if let Some(socket) = &args.askpass { askpass::main(socket); @@ -264,6 +270,9 @@ pub fn main() { let session_id = Uuid::new_v4().to_string(); let session = app.background_executor().block(Session::new()); + app.background_executor() + .spawn(crashes::init(session_id.clone())) + .detach(); reliability::init_panic_hook( app_version, app_commit_sha.clone(), @@ -559,11 +568,7 @@ pub fn main() { web_search::init(cx); web_search_providers::init(app_state.client.clone(), cx); snippet_provider::init(cx); - inline_completion_registry::init( - app_state.client.clone(), - app_state.user_store.clone(), - cx, - ); + edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); agent_ui::init( app_state.fs.clone(), @@ -613,6 +618,7 @@ pub fn main() { language_selector::init(cx); toolchain_selector::init(cx); theme_selector::init(cx); + settings_profile_selector::init(cx); language_tools::init(cx); call::init(app_state.client.clone(), app_state.user_store.clone(), cx); notifications::init(app_state.client.clone(), app_state.user_store.clone(), cx); @@ -681,17 +687,9 @@ pub fn main() { cx.spawn({ let client = app_state.client.clone(); - async move |cx| match authenticate(client, &cx).await { - ConnectionResult::Timeout => log::error!("Timeout during initial auth"), - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during initial auth") - } - ConnectionResult::Result(r) => { - r.log_err(); - } - } + async move |cx| authenticate(client, &cx).await }) - .detach(); + .detach_and_log_err(cx); let urls: Vec<_> = args .paths_or_urls @@ -841,15 +839,7 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut let client = app_state.client.clone(); // we continue even if authentication fails as join_channel/ open channel notes will // show a visible error message. - match authenticate(client, &cx).await { - ConnectionResult::Timeout => { - log::error!("Timeout during open request handling") - } - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during open request handling") - } - ConnectionResult::Result(r) => r?, - }; + authenticate(client, &cx).await.log_err(); if let Some(channel_id) = request.join_channel { cx.update(|cx| { @@ -899,18 +889,18 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut } } -async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> ConnectionResult<()> { +async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> Result<()> { if stdout_is_a_pty() { if client::IMPERSONATE_LOGIN.is_some() { - return client.authenticate_and_connect(false, cx).await; + client.sign_in_with_optional_connect(false, cx).await?; } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } - ConnectionResult::Result(Ok(())) + Ok(()) } async fn system_id() -> Result<IdType> { @@ -1144,6 +1134,7 @@ fn init_paths() -> HashMap<io::ErrorKind, Vec<&'static Path>> { paths::config_dir(), paths::extensions_dir(), paths::languages_dir(), + paths::debug_adapters_dir(), paths::database_dir(), paths::logs_dir(), paths::temp_dir(), @@ -1203,6 +1194,11 @@ struct Args { #[arg(long, hide = true)] nc: Option<String>, + /// Used for recording minidumps on crashes by having Zed run a separate + /// process communicating over a socket. + #[arg(long, hide = true)] + crash_handler: Option<PathBuf>, + /// Run zed in the foreground, only used on Windows, to match the behavior on macOS. #[arg(long)] #[cfg(target_os = "windows")] diff --git a/crates/zed/src/reliability.rs b/crates/zed/src/reliability.rs index ccbe57e7b3..ed149a470a 100644 --- a/crates/zed/src/reliability.rs +++ b/crates/zed/src/reliability.rs @@ -2,21 +2,32 @@ use crate::stdout_is_a_pty; use anyhow::{Context as _, Result}; use backtrace::{self, Backtrace}; use chrono::Utc; -use client::{TelemetrySettings, telemetry}; +use client::{ + TelemetrySettings, + telemetry::{self, MINIDUMP_ENDPOINT}, +}; use db::kvp::KEY_VALUE_STORE; +use futures::AsyncReadExt; use gpui::{App, AppContext as _, SemanticVersion}; use http_client::{self, HttpClient, HttpClientWithUrl, HttpRequestExt, Method}; use paths::{crashes_dir, crashes_retired_dir}; use project::Project; use release_channel::{AppCommitSha, RELEASE_CHANNEL, ReleaseChannel}; +use reqwest::multipart::{Form, Part}; use settings::Settings; use smol::stream::StreamExt; use std::{ env, ffi::{OsStr, c_void}, - sync::{Arc, atomic::Ordering}, + fs, + io::Write, + panic, + sync::{ + Arc, + atomic::{AtomicU32, Ordering}, + }, + thread, }; -use std::{io::Write, panic, sync::atomic::AtomicU32, thread}; use telemetry_events::{LocationData, Panic, PanicRequest}; use url::Url; use util::ResultExt; @@ -37,9 +48,10 @@ pub fn init_panic_hook( if prior_panic_count > 0 { // Give the panic-ing thread time to write the panic file loop { - std::thread::yield_now(); + thread::yield_now(); } } + crashes::handle_panic(); let thread = thread::current(); let thread_name = thread.name().unwrap_or("<unnamed>"); @@ -63,7 +75,7 @@ pub fn init_panic_hook( location.column(), match app_commit_sha.as_ref() { Some(commit_sha) => format!( - "https://github.com/zed-industries/zed/blob/{}/src/{}#L{} \ + "https://github.com/zed-industries/zed/blob/{}/{}#L{} \ (may not be uploaded, line may be incorrect if files modified)\n", commit_sha.full(), location.file(), @@ -136,9 +148,8 @@ pub fn init_panic_hook( if let Some(panic_data_json) = serde_json::to_string(&panic_data).log_err() { let timestamp = chrono::Utc::now().format("%Y_%m_%d %H_%M_%S").to_string(); let panic_file_path = paths::logs_dir().join(format!("zed-{timestamp}.panic")); - let panic_file = std::fs::OpenOptions::new() - .append(true) - .create(true) + let panic_file = fs::OpenOptions::new() + .create_new(true) .open(&panic_file_path) .log_err(); if let Some(mut panic_file) = panic_file { @@ -205,27 +216,31 @@ pub fn init( if let Some(ssh_client) = project.ssh_client() { ssh_client.update(cx, |client, cx| { if TelemetrySettings::get_global(cx).diagnostics { - let request = client.proto_client().request(proto::GetPanicFiles {}); + let request = client.proto_client().request(proto::GetCrashFiles {}); cx.background_spawn(async move { - let panic_files = request.await?; - for file in panic_files.file_contents { - let panic: Option<Panic> = serde_json::from_str(&file) - .log_err() - .or_else(|| { - file.lines() - .next() - .and_then(|line| serde_json::from_str(line).ok()) - }) - .unwrap_or_else(|| { - log::error!("failed to deserialize panic file {:?}", file); - None - }); + let crash_files = request.await?; + for crash in crash_files.crashes { + let mut panic: Option<Panic> = crash + .panic_contents + .and_then(|s| serde_json::from_str(&s).log_err()); - if let Some(mut panic) = panic { + if let Some(panic) = panic.as_mut() { panic.session_id = session_id.clone(); panic.system_id = system_id.clone(); panic.installation_id = installation_id.clone(); + } + if let Some(minidump) = crash.minidump_contents { + upload_minidump( + http_client.clone(), + minidump.clone(), + panic.as_ref(), + ) + .await + .log_err(); + } + + if let Some(panic) = panic { upload_panic(&http_client, &panic_report_url, panic, &mut None) .await?; } @@ -510,6 +525,22 @@ async fn upload_previous_panics( }); if let Some(panic) = panic { + let minidump_path = paths::logs_dir() + .join(&panic.session_id) + .with_extension("dmp"); + if minidump_path.exists() { + let minidump = smol::fs::read(&minidump_path) + .await + .context("Failed to read minidump")?; + if upload_minidump(http.clone(), minidump, Some(&panic)) + .await + .log_err() + .is_some() + { + fs::remove_file(minidump_path).ok(); + } + } + if !upload_panic(&http, &panic_report_url, panic, &mut most_recent_panic).await? { continue; } @@ -517,13 +548,75 @@ async fn upload_previous_panics( } // We've done what we can, delete the file - std::fs::remove_file(child_path) + fs::remove_file(child_path) .context("error removing panic") .log_err(); } + + // loop back over the directory again to upload any minidumps that are missing panics + let mut children = smol::fs::read_dir(paths::logs_dir()).await?; + while let Some(child) = children.next().await { + let child = child?; + let child_path = child.path(); + if child_path.extension() != Some(OsStr::new("dmp")) { + continue; + } + if upload_minidump( + http.clone(), + smol::fs::read(&child_path) + .await + .context("Failed to read minidump")?, + None, + ) + .await + .log_err() + .is_some() + { + fs::remove_file(child_path).ok(); + } + } + Ok(most_recent_panic) } +async fn upload_minidump( + http: Arc<HttpClientWithUrl>, + minidump: Vec<u8>, + panic: Option<&Panic>, +) -> Result<()> { + let minidump_endpoint = MINIDUMP_ENDPOINT + .to_owned() + .ok_or_else(|| anyhow::anyhow!("Minidump endpoint not set"))?; + + let mut form = Form::new() + .part( + "upload_file_minidump", + Part::bytes(minidump) + .file_name("minidump.dmp") + .mime_str("application/octet-stream")?, + ) + .text("platform", "rust"); + if let Some(panic) = panic { + form = form.text( + "release", + format!("{}-{}", panic.release_channel, panic.app_version), + ); + // TODO: tack on more fields + } + + let mut response_text = String::new(); + let mut response = http.send_multipart_form(&minidump_endpoint, form).await?; + response + .body_mut() + .read_to_string(&mut response_text) + .await?; + if !response.status().is_success() { + anyhow::bail!("failed to upload minidump: {response_text}"); + } + log::info!("Uploaded minidump. event id: {response_text}"); + Ok(()) +} + async fn upload_panic( http: &Arc<HttpClientWithUrl>, panic_report_url: &Url, diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 0a90f89fa4..ec62ed33fd 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1,6 +1,6 @@ mod app_menus; pub mod component_preview; -pub mod inline_completion_registry; +pub mod edit_prediction_registry; #[cfg(target_os = "macos")] pub(crate) mod mac_only_instance; mod migrate; @@ -126,17 +126,28 @@ pub fn init(cx: &mut App) { cx.on_action(quit); cx.on_action(|_: &RestoreBanner, cx| title_bar::restore_banner(cx)); - if ReleaseChannel::global(cx) == ReleaseChannel::Dev || cx.has_flag::<PanicFeatureFlag>() { - cx.on_action(|_: &TestPanic, _| panic!("Ran the TestPanic action")); - cx.on_action(|_: &TestCrash, _| { - unsafe extern "C" { - fn puts(s: *const i8); - } - unsafe { - puts(0xabad1d3a as *const i8); - } - }); - } + let flag = cx.wait_for_flag::<PanicFeatureFlag>(); + cx.spawn(async |cx| { + if cx + .update(|cx| ReleaseChannel::global(cx) == ReleaseChannel::Dev) + .unwrap_or_default() + || flag.await + { + cx.update(|cx| { + cx.on_action(|_: &TestPanic, _| panic!("Ran the TestPanic action")); + cx.on_action(|_: &TestCrash, _| { + unsafe extern "C" { + fn puts(s: *const i8); + } + unsafe { + puts(0xabad1d3a as *const i8); + } + }); + }) + .ok(); + }; + }) + .detach(); cx.on_action(|_: &OpenLog, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { open_log_file(workspace, window, cx); @@ -321,18 +332,18 @@ pub fn initialize_workspace( show_software_emulation_warning_if_needed(specs, window, cx); } - let inline_completion_menu_handle = PopoverMenuHandle::default(); + let edit_prediction_menu_handle = PopoverMenuHandle::default(); let edit_prediction_button = cx.new(|cx| { - inline_completion_button::InlineCompletionButton::new( + edit_prediction_button::EditPredictionButton::new( app_state.fs.clone(), app_state.user_store.clone(), - inline_completion_menu_handle.clone(), + edit_prediction_menu_handle.clone(), cx, ) }); workspace.register_action({ - move |_, _: &inline_completion_button::ToggleMenu, window, cx| { - inline_completion_menu_handle.toggle(window, cx); + move |_, _: &edit_prediction_button::ToggleMenu, window, cx| { + edit_prediction_menu_handle.toggle(window, cx); } }); @@ -4343,6 +4354,7 @@ mod tests { "menu", "notebook", "notification_panel", + "onboarding", "outline", "outline_panel", "pane", @@ -4355,6 +4367,7 @@ mod tests { "repl", "rules_library", "search", + "settings_profile_selector", "snippets", "supermaven", "svg", diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index 78532b10b4..15d5659f03 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -24,6 +24,10 @@ pub fn app_menus() -> Vec<Menu> { zed_actions::OpenDefaultKeymap, ), MenuItem::action("Open Project Settings", super::OpenProjectSettings), + MenuItem::action( + "Select Settings Profile...", + zed_actions::settings_profile_selector::Toggle, + ), MenuItem::action( "Select Theme...", zed_actions::theme_selector::Toggle::default(), diff --git a/crates/zed/src/zed/component_preview.rs b/crates/zed/src/zed/component_preview.rs index 670793cff3..480505338b 100644 --- a/crates/zed/src/zed/component_preview.rs +++ b/crates/zed/src/zed/component_preview.rs @@ -105,6 +105,7 @@ enum PreviewPage { struct ComponentPreview { active_page: PreviewPage, active_thread: Option<Entity<ActiveThread>>, + reset_key: usize, component_list: ListState, component_map: HashMap<ComponentId, ComponentMetadata>, components: Vec<ComponentMetadata>, @@ -138,8 +139,7 @@ impl ComponentPreview { let project_clone = project.clone(); cx.spawn_in(window, async move |entity, cx| { - let thread_store_future = - load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx); + let thread_store_future = load_preview_thread_store(project_clone.clone(), cx); let text_thread_store_future = load_preview_text_thread_store(workspace_clone.clone(), project_clone.clone(), cx); @@ -188,6 +188,7 @@ impl ComponentPreview { let mut component_preview = Self { active_page, active_thread: None, + reset_key: 0, component_list, component_map: component_registry.component_map(), components: sorted_components, @@ -265,8 +266,13 @@ impl ComponentPreview { } fn set_active_page(&mut self, page: PreviewPage, cx: &mut Context<Self>) { - self.active_page = page; - cx.emit(ItemEvent::UpdateTab); + if self.active_page == page { + // Force the current preview page to render again + self.reset_key = self.reset_key.wrapping_add(1); + } else { + self.active_page = page; + cx.emit(ItemEvent::UpdateTab); + } cx.notify(); } @@ -690,6 +696,7 @@ impl ComponentPreview { component.clone(), self.workspace.clone(), self.active_thread.clone(), + self.reset_key, )) .into_any_element() } else { @@ -1041,6 +1048,7 @@ pub struct ComponentPreviewPage { component: ComponentMetadata, workspace: WeakEntity<Workspace>, active_thread: Option<Entity<ActiveThread>>, + reset_key: usize, } impl ComponentPreviewPage { @@ -1048,6 +1056,7 @@ impl ComponentPreviewPage { component: ComponentMetadata, workspace: WeakEntity<Workspace>, active_thread: Option<Entity<ActiveThread>>, + reset_key: usize, // languages: Arc<LanguageRegistry> ) -> Self { Self { @@ -1055,6 +1064,7 @@ impl ComponentPreviewPage { component, workspace, active_thread, + reset_key, } } @@ -1155,6 +1165,7 @@ impl ComponentPreviewPage { }; v_flex() + .id(("component-preview", self.reset_key)) .size_full() .flex_1() .px_12() diff --git a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs index 825744572d..de98106fae 100644 --- a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs +++ b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs @@ -12,21 +12,19 @@ use ui::{App, Window}; use workspace::Workspace; pub fn load_preview_thread_store( - workspace: WeakEntity<Workspace>, project: Entity<Project>, cx: &mut AsyncApp, ) -> Task<Result<Entity<ThreadStore>>> { - workspace - .update(cx, |_, cx| { - ThreadStore::load( - project.clone(), - cx.new(|_| ToolWorkingSet::default()), - None, - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - }) - .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) + cx.update(|cx| { + ThreadStore::load( + project.clone(), + cx.new(|_| ToolWorkingSet::default()), + None, + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) } pub fn load_preview_text_thread_store( diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs similarity index 91% rename from crates/zed/src/zed/inline_completion_registry.rs rename to crates/zed/src/zed/edit_prediction_registry.rs index f2e9d21b96..b9f561c0e7 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -11,7 +11,7 @@ use supermaven::{Supermaven, SupermavenCompletionProvider}; use ui::Window; use util::ResultExt; use workspace::Workspace; -use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider}; +use zeta::{ProviderDataCollection, ZetaEditPredictionProvider}; pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) { let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default(); @@ -90,10 +90,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) { let new_provider = all_language_settings(None, cx).edit_predictions.provider; if new_provider != provider { - let tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); + let tos_accepted = user_store.read(cx).has_accepted_terms_of_service(); telemetry::event!( "Edit Prediction Provider Changed", @@ -174,7 +171,7 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed editor .register_action(cx.listener( |editor, _: &copilot::Suggest, window: &mut Window, cx: &mut Context<Editor>| { - editor.show_inline_completion(&Default::default(), window, cx); + editor.show_edit_prediction(&Default::default(), window, cx); }, )) .detach(); @@ -195,16 +192,6 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed }, )) .detach(); - editor - .register_action(cx.listener( - |editor, - _: &editor::actions::AcceptPartialCopilotSuggestion, - window: &mut Window, - cx: &mut Context<Editor>| { - editor.accept_partial_inline_completion(&Default::default(), window, cx); - }, - )) - .detach(); } fn assign_edit_prediction_provider( @@ -220,7 +207,7 @@ fn assign_edit_prediction_provider( match provider { EditPredictionProvider::None => { - editor.set_edit_prediction_provider::<ZetaInlineCompletionProvider>(None, window, cx); + editor.set_edit_prediction_provider::<ZetaEditPredictionProvider>(None, window, cx); } EditPredictionProvider::Copilot => { if let Some(copilot) = Copilot::global(cx) { @@ -242,7 +229,7 @@ fn assign_edit_prediction_provider( } } EditPredictionProvider::Zed => { - if client.status().borrow().is_connected() { + if user_store.read(cx).current_user().is_some() { let mut worktree = None; if let Some(buffer) = &singleton_buffer { @@ -278,7 +265,7 @@ fn assign_edit_prediction_provider( ProviderDataCollection::new(zeta.clone(), singleton_buffer, cx); let provider = - cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta, data_collection)); + cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, data_collection)); editor.set_edit_prediction_provider(Some(provider), window, cx); } diff --git a/crates/zed/src/zed/quick_action_bar.rs b/crates/zed/src/zed/quick_action_bar.rs index aff124a0bc..e76bef59a3 100644 --- a/crates/zed/src/zed/quick_action_bar.rs +++ b/crates/zed/src/zed/quick_action_bar.rs @@ -2,7 +2,6 @@ mod preview; mod repl_menu; use agent_settings::AgentSettings; -use client::DisableAiSettings; use editor::actions::{ AddSelectionAbove, AddSelectionBelow, CodeActionSource, DuplicateLineDown, GoToDiagnostic, GoToHunk, GoToPreviousDiagnostic, GoToPreviousHunk, MoveLineDown, MoveLineUp, SelectAll, @@ -16,6 +15,7 @@ use gpui::{ FocusHandle, Focusable, InteractiveElement, ParentElement, Render, Styled, Subscription, WeakEntity, Window, anchored, deferred, point, }; +use project::DisableAiSettings; use project::project_settings::DiagnosticSeverity; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, SettingsStore}; @@ -192,7 +192,7 @@ impl Render for QuickActionBar { }; v_flex() .child( - IconButton::new("toggle_code_actions_icon", IconName::Bolt) + IconButton::new("toggle_code_actions_icon", IconName::BoltOutlined) .icon_size(IconSize::Small) .style(ButtonStyle::Subtle) .disabled(!has_available_code_actions) @@ -381,7 +381,7 @@ impl Render for QuickActionBar { } if has_edit_prediction_provider { - let mut inline_completion_entry = ContextMenuEntry::new("Edit Predictions") + let mut edit_prediction_entry = ContextMenuEntry::new("Edit Predictions") .toggleable(IconPosition::Start, edit_predictions_enabled_at_cursor && show_edit_predictions) .disabled(!edit_predictions_enabled_at_cursor) .action( @@ -401,12 +401,12 @@ impl Render for QuickActionBar { } }); if !edit_predictions_enabled_at_cursor { - inline_completion_entry = inline_completion_entry.documentation_aside(DocumentationSide::Left, |_| { + edit_prediction_entry = edit_prediction_entry.documentation_aside(DocumentationSide::Left, |_| { Label::new("You can't toggle edit predictions for this file as it is within the excluded files list.").into_any_element() }); } - menu = menu.item(inline_completion_entry); + menu = menu.item(edit_prediction_entry); } menu = menu.separator(); diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index 4b4bf016c4..64891b6973 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -260,14 +260,25 @@ pub mod icon_theme_selector { } } +pub mod settings_profile_selector { + use gpui::Action; + use schemars::JsonSchema; + use serde::Deserialize; + + #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] + #[action(namespace = settings_profile_selector)] + pub struct Toggle; +} + pub mod agent { use gpui::actions; actions!( agent, [ - /// Opens the agent configuration panel. - OpenConfiguration, + /// Opens the agent settings panel. + #[action(deprecated_aliases = ["agent::OpenConfiguration"])] + OpenSettings, /// Opens the agent onboarding modal. OpenOnboardingModal, /// Resets the agent onboarding state. diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index c2b1de08ae..9f1d02b790 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -21,6 +21,7 @@ ai_onboarding.workspace = true anyhow.workspace = true arrayvec.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true copilot.workspace = true @@ -32,14 +33,13 @@ futures.workspace = true gpui.workspace = true http_client.workspace = true indoc.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true language.workspace = true language_model.workspace = true log.workspace = true menu.workspace = true postage.workspace = true project.workspace = true -proto.workspace = true regex.workspace = true release_channel.workspace = true serde.workspace = true @@ -52,16 +52,17 @@ thiserror.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true +workspace-hack.workspace = true workspace.workspace = true worktree.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true -workspace-hack.workspace = true [dev-dependencies] -collections = { workspace = true, features = ["test-support"] } +call = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } +cloud_api_types.workspace = true +collections = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } @@ -77,5 +78,4 @@ tree-sitter-rust.workspace = true unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } -call = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/zeta/src/completion_diff_element.rs b/crates/zeta/src/completion_diff_element.rs index 3b7355d797..73c3cb20cd 100644 --- a/crates/zeta/src/completion_diff_element.rs +++ b/crates/zeta/src/completion_diff_element.rs @@ -1,6 +1,6 @@ use std::cmp; -use crate::InlineCompletion; +use crate::EditPrediction; use gpui::{ AnyElement, App, BorderStyle, Bounds, Corners, Edges, HighlightStyle, Hsla, StyledText, TextLayout, TextStyle, point, prelude::*, quad, size, @@ -17,7 +17,7 @@ pub struct CompletionDiffElement { } impl CompletionDiffElement { - pub fn new(completion: &InlineCompletion, cx: &App) -> Self { + pub fn new(completion: &EditPrediction, cx: &App) -> Self { let mut diff = completion .snapshot .text_for_range(completion.excerpt_range.clone()) diff --git a/crates/zeta/src/init.rs b/crates/zeta/src/init.rs index 4a65771223..a01e3a89a2 100644 --- a/crates/zeta/src/init.rs +++ b/crates/zeta/src/init.rs @@ -1,10 +1,10 @@ use std::any::{Any, TypeId}; -use client::DisableAiSettings; use command_palette_hooks::CommandPaletteFilter; use feature_flags::{FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag}; use gpui::actions; use language::language_settings::{AllLanguageSettings, EditPredictionProvider}; +use project::DisableAiSettings; use settings::{Settings, SettingsStore, update_settings_file}; use ui::App; use workspace::Workspace; diff --git a/crates/zeta/src/rate_completion_modal.rs b/crates/zeta/src/rate_completion_modal.rs index 5a873fb8de..ac7fcade91 100644 --- a/crates/zeta/src/rate_completion_modal.rs +++ b/crates/zeta/src/rate_completion_modal.rs @@ -1,4 +1,4 @@ -use crate::{CompletionDiffElement, InlineCompletion, InlineCompletionRating, Zeta}; +use crate::{CompletionDiffElement, EditPrediction, EditPredictionRating, Zeta}; use editor::Editor; use gpui::{App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, actions, prelude::*}; use language::language_settings; @@ -34,7 +34,7 @@ pub struct RateCompletionModal { } struct ActiveCompletion { - completion: InlineCompletion, + completion: EditPrediction, feedback_editor: Entity<Editor>, } @@ -157,7 +157,7 @@ impl RateCompletionModal { if let Some(active) = &self.active_completion { zeta.rate_completion( &active.completion, - InlineCompletionRating::Positive, + EditPredictionRating::Positive, active.feedback_editor.read(cx).text(cx), cx, ); @@ -189,7 +189,7 @@ impl RateCompletionModal { self.zeta.update(cx, |zeta, cx| { zeta.rate_completion( &active.completion, - InlineCompletionRating::Negative, + EditPredictionRating::Negative, active.feedback_editor.read(cx).text(cx), cx, ); @@ -250,7 +250,7 @@ impl RateCompletionModal { pub fn select_completion( &mut self, - completion: Option<InlineCompletion>, + completion: Option<EditPrediction>, focus: bool, window: &mut Window, cx: &mut Context<Self>, diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index d6f033899d..b1bd737dbf 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,8 +8,8 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::{Dismissable, KEY_VALUE_STORE}; +use edit_prediction::DataCollectionState; pub use init::*; -use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; pub use license_detection::is_license_eligible_for_data_collection; pub use rate_completion_modal::*; @@ -17,6 +17,10 @@ pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; +use cloud_llm_client::{ + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, + PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME, +}; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; use gpui::{ @@ -30,7 +34,7 @@ use language::{ }; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use postage::watch; -use project::Project; +use project::{Project, ProjectPath}; use release_channel::AppVersion; use settings::WorktreeId; use std::str::FromStr; @@ -46,17 +50,13 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use telemetry_events::InlineCompletionRating; +use telemetry_events::EditPredictionRating; use thiserror::Error; use util::ResultExt; use uuid::Uuid; use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use worktree::Worktree; -use zed_llm_client::{ - AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, -}; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>"; @@ -81,15 +81,15 @@ actions!( ); #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct InlineCompletionId(Uuid); +pub struct EditPredictionId(Uuid); -impl From<InlineCompletionId> for gpui::ElementId { - fn from(value: InlineCompletionId) -> Self { +impl From<EditPredictionId> for gpui::ElementId { + fn from(value: EditPredictionId) -> Self { gpui::ElementId::Uuid(value.0) } } -impl std::fmt::Display for InlineCompletionId { +impl std::fmt::Display for EditPredictionId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } @@ -121,9 +121,10 @@ impl Dismissable for ZedPredictUpsell { } pub fn should_show_upsell_modal(user_store: &Entity<UserStore>, cx: &App) -> bool { - match user_store.read(cx).current_user_has_accepted_terms() { - Some(true) => !ZedPredictUpsell::dismissed(), - Some(false) | None => true, + if user_store.read(cx).has_accepted_terms_of_service() { + !ZedPredictUpsell::dismissed() + } else { + true } } @@ -133,8 +134,8 @@ struct ZetaGlobal(Entity<Zeta>); impl Global for ZetaGlobal {} #[derive(Clone)] -pub struct InlineCompletion { - id: InlineCompletionId, +pub struct EditPrediction { + id: EditPredictionId, path: Arc<Path>, excerpt_range: Range<usize>, cursor_offset: usize, @@ -145,14 +146,14 @@ pub struct InlineCompletion { input_events: Arc<str>, input_excerpt: Arc<str>, output_excerpt: Arc<str>, - request_sent_at: Instant, + buffer_snapshotted_at: Instant, response_received_at: Instant, } -impl InlineCompletion { +impl EditPrediction { fn latency(&self) -> Duration { self.response_received_at - .duration_since(self.request_sent_at) + .duration_since(self.buffer_snapshotted_at) } fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> { @@ -206,9 +207,9 @@ fn interpolate( if edits.is_empty() { None } else { Some(edits) } } -impl std::fmt::Debug for InlineCompletion { +impl std::fmt::Debug for EditPrediction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InlineCompletion") + f.debug_struct("EditPrediction") .field("id", &self.id) .field("path", &self.path) .field("edits", &self.edits) @@ -221,17 +222,14 @@ pub struct Zeta { client: Arc<Client>, events: VecDeque<Event>, registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>, - shown_completions: VecDeque<InlineCompletion>, - rated_completions: HashSet<InlineCompletionId>, + shown_completions: VecDeque<EditPrediction>, + rated_completions: HashSet<EditPredictionId>, data_collection_choice: Entity<DataCollectionChoice>, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - /// Whether the terms of service have been accepted. - tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, user_store: Entity<UserStore>, - _user_store_subscription: Subscription, license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>, } @@ -306,22 +304,7 @@ impl Zeta { .detach_and_log_err(cx); }, ), - tos_accepted: user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false), update_required: false, - _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { - match event { - client::user::Event::PrivateUserInfoUpdated => { - this.tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); - } - _ => {} - } - }), license_detection_watchers: HashMap::default(), user_store, } @@ -401,111 +384,64 @@ impl Zeta { can_collect_data: bool, cx: &mut Context<Self>, perform_predict_edits: F, - ) -> Task<Result<Option<InlineCompletion>>> + ) -> Task<Result<Option<EditPrediction>>> where F: FnOnce(PerformPredictEditsParams) -> R + 'static, R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> + Send + 'static, { + let buffer = buffer.clone(); + let buffer_snapshotted_at = Instant::now(); let snapshot = self.report_changes_for_buffer(&buffer, cx); - let diagnostic_groups = snapshot.diagnostic_groups(None); - let cursor_point = cursor.to_point(&snapshot); - let cursor_offset = cursor_point.to_offset(&snapshot); - let events = self.events.clone(); - let path: Arc<Path> = snapshot - .file() - .map(|f| Arc::from(f.full_path(cx).as_path())) - .unwrap_or_else(|| Arc::from(Path::new("untitled"))); - let zeta = cx.entity(); + let events = self.events.clone(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let buffer = buffer.clone(); - - let local_lsp_store = - project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); - let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store { - Some( - diagnostic_groups - .into_iter() - .filter_map(|(language_server_id, diagnostic_group)| { - let language_server = - local_lsp_store.running_language_server_for_id(language_server_id)?; - - Some(( - language_server.name(), - diagnostic_group.resolve::<usize>(&snapshot), - )) - }) - .collect::<Vec<_>>(), - ) + let git_info = if let (true, Some(project), Some(file)) = + (can_collect_data, project, snapshot.file()) + { + git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) } else { None }; + let full_path: Arc<Path> = snapshot + .file() + .map(|f| Arc::from(f.full_path(cx).as_path())) + .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + let full_path_str = full_path.to_string_lossy().to_string(); + let cursor_point = cursor.to_point(&snapshot); + let cursor_offset = cursor_point.to_offset(&snapshot); + let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS); + let gather_task = gather_context( + project, + full_path_str, + &snapshot, + cursor_point, + make_events_prompt, + can_collect_data, + git_info, + cx, + ); + cx.spawn(async move |this, cx| { - let request_sent_at = Instant::now(); - - struct BackgroundValues { - input_events: String, - input_excerpt: String, - speculated_output: String, - editable_range: Range<usize>, - input_outline: String, - } - - let values = cx - .background_spawn({ - let snapshot = snapshot.clone(); - let path = path.clone(); - async move { - let path = path.to_string_lossy(); - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &path, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS); - let input_outline = prompt_for_outline(&snapshot); - - anyhow::Ok(BackgroundValues { - input_events, - input_excerpt: input_excerpt.prompt, - speculated_output: input_excerpt.speculated_output, - editable_range: input_excerpt.editable_range.to_offset(&snapshot), - input_outline, - }) - } - }) - .await?; + let GatherContextOutput { + body, + editable_range, + } = gather_task.await?; log::debug!( "Events:\n{}\nExcerpt:\n{:?}", - values.input_events, - values.input_excerpt + body.input_events, + body.input_excerpt ); - let body = PredictEditsBody { - input_events: values.input_events.clone(), - input_excerpt: values.input_excerpt.clone(), - speculated_output: Some(values.speculated_output), - outline: Some(values.input_outline.clone()), - can_collect_data, - diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| { - diagnostic_groups - .into_iter() - .map(|(name, diagnostic_group)| { - Ok((name.to_string(), serde_json::to_value(diagnostic_group)?)) - }) - .collect::<Result<Vec<_>>>() - .log_err() - }), - }; + let input_outline = body.outline.clone().unwrap_or_default(); + let input_events = body.input_events.clone(); + let input_excerpt = body.input_excerpt.clone(); let response = perform_predict_edits(PerformPredictEditsParams { client, @@ -563,13 +499,13 @@ impl Zeta { response, buffer, &snapshot, - values.editable_range, + editable_range, cursor_offset, - path, - values.input_outline, - values.input_events, - values.input_excerpt, - request_sent_at, + full_path, + input_outline, + input_events, + input_excerpt, + buffer_snapshotted_at, &cx, ) .await @@ -737,7 +673,7 @@ and then another position: language::Anchor, response: PredictEditsResponse, cx: &mut Context<Self>, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { use std::future::ready; self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { @@ -752,7 +688,7 @@ and then another position: language::Anchor, can_collect_data: bool, cx: &mut Context<Self>, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { let workspace = self .workspace .as_ref() @@ -768,7 +704,7 @@ and then another ) } - fn perform_predict_edits( + pub fn perform_predict_edits( params: PerformPredictEditsParams, ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> { async move { @@ -846,7 +782,7 @@ and then another fn accept_edit_prediction( &mut self, - request_id: InlineCompletionId, + request_id: EditPredictionId, cx: &mut Context<Self>, ) -> Task<Result<()>> { let client = self.client.clone(); @@ -923,9 +859,9 @@ and then another input_outline: String, input_events: String, input_excerpt: String, - request_sent_at: Instant, + buffer_snapshotted_at: Instant, cx: &AsyncApp, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { let snapshot = snapshot.clone(); let request_id = prediction_response.request_id; let output_excerpt = prediction_response.output_excerpt; @@ -957,8 +893,8 @@ and then another let edit_preview = edit_preview.await; - Ok(Some(InlineCompletion { - id: InlineCompletionId(request_id), + Ok(Some(EditPrediction { + id: EditPredictionId(request_id), path, excerpt_range: editable_range, cursor_offset, @@ -969,7 +905,7 @@ and then another input_events: input_events.into(), input_excerpt: input_excerpt.into(), output_excerpt, - request_sent_at, + buffer_snapshotted_at, response_received_at: Instant::now(), })) }) @@ -1068,11 +1004,11 @@ and then another .collect() } - pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool { + pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool { self.rated_completions.contains(&completion_id) } - pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) { + pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) { self.shown_completions.push_front(completion.clone()); if self.shown_completions.len() > 50 { let completion = self.shown_completions.pop_back().unwrap(); @@ -1083,8 +1019,8 @@ and then another pub fn rate_completion( &mut self, - completion: &InlineCompletion, - rating: InlineCompletionRating, + completion: &EditPrediction, + rating: EditPredictionRating, feedback: String, cx: &mut Context<Self>, ) { @@ -1102,7 +1038,7 @@ and then another cx.notify(); } - pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> { + pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> { self.shown_completions.iter() } @@ -1153,7 +1089,7 @@ and then another } } -struct PerformPredictEditsParams { +pub struct PerformPredictEditsParams { pub client: Arc<Client>, pub llm_token: LlmApiToken, pub app_version: SemanticVersion, @@ -1228,6 +1164,108 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: .sum() } +fn git_info_for_file( + project: &Entity<Project>, + project_path: &ProjectPath, + cx: &App, +) -> Option<PredictEditsGitInfo> { + let git_store = project.read(cx).git_store().read(cx); + if let Some((repository, _repo_path)) = + git_store.repository_and_path_for_project_path(project_path, cx) + { + let repository = repository.read(cx); + let head_sha = repository + .head_commit + .as_ref() + .map(|head_commit| head_commit.sha.to_string()); + let remote_origin_url = repository.remote_origin_url.clone(); + let remote_upstream_url = repository.remote_upstream_url.clone(); + if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { + return None; + } + Some(PredictEditsGitInfo { + head_sha, + remote_origin_url, + remote_upstream_url, + }) + } else { + None + } +} + +pub struct GatherContextOutput { + pub body: PredictEditsBody, + pub editable_range: Range<usize>, +} + +pub fn gather_context( + project: Option<&Entity<Project>>, + full_path_str: String, + snapshot: &BufferSnapshot, + cursor_point: language::Point, + make_events_prompt: impl FnOnce() -> String + Send + 'static, + can_collect_data: bool, + git_info: Option<PredictEditsGitInfo>, + cx: &App, +) -> Task<Result<GatherContextOutput>> { + let local_lsp_store = + project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); + let diagnostic_groups: Vec<(String, serde_json::Value)> = + if let Some(local_lsp_store) = local_lsp_store { + snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(language_server_id, diagnostic_group)| { + let language_server = + local_lsp_store.running_language_server_for_id(language_server_id)?; + let diagnostic_group = diagnostic_group.resolve::<usize>(&snapshot); + let language_server_name = language_server.name().to_string(); + let serialized = serde_json::to_value(diagnostic_group).unwrap(); + Some((language_server_name, serialized)) + }) + .collect::<Vec<_>>() + } else { + Vec::new() + }; + + cx.background_spawn({ + let snapshot = snapshot.clone(); + async move { + let diagnostic_groups = if diagnostic_groups.is_empty() { + None + } else { + Some(diagnostic_groups) + }; + + let input_excerpt = excerpt_for_cursor_position( + cursor_point, + &full_path_str, + &snapshot, + MAX_REWRITE_TOKENS, + MAX_CONTEXT_TOKENS, + ); + let input_events = make_events_prompt(); + let input_outline = prompt_for_outline(&snapshot); + let editable_range = input_excerpt.editable_range.to_offset(&snapshot); + + let body = PredictEditsBody { + input_events, + input_excerpt: input_excerpt.prompt, + speculated_output: Some(input_excerpt.speculated_output), + outline: Some(input_outline), + can_collect_data, + diagnostic_groups, + git_info, + }; + + Ok(GatherContextOutput { + body, + editable_range, + }) + } + }) +} + fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { let mut input_outline = String::new(); @@ -1278,7 +1316,7 @@ struct RegisteredBuffer { } #[derive(Clone)] -enum Event { +pub enum Event { BufferChange { old_snapshot: BufferSnapshot, new_snapshot: BufferSnapshot, @@ -1325,12 +1363,12 @@ impl Event { } #[derive(Debug, Clone)] -struct CurrentInlineCompletion { +struct CurrentEditPrediction { buffer_id: EntityId, - completion: InlineCompletion, + completion: EditPrediction, } -impl CurrentInlineCompletion { +impl CurrentEditPrediction { fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool { if self.buffer_id != old_completion.buffer_id { return true; @@ -1499,17 +1537,17 @@ async fn llm_token_retry( } } -pub struct ZetaInlineCompletionProvider { +pub struct ZetaEditPredictionProvider { zeta: Entity<Zeta>, pending_completions: ArrayVec<PendingCompletion, 2>, next_pending_completion_id: usize, - current_completion: Option<CurrentInlineCompletion>, + current_completion: Option<CurrentEditPrediction>, /// None if this is entirely disabled for this provider provider_data_collection: ProviderDataCollection, last_request_timestamp: Instant, } -impl ZetaInlineCompletionProvider { +impl ZetaEditPredictionProvider { pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self { @@ -1524,7 +1562,7 @@ impl ZetaInlineCompletionProvider { } } -impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider { +impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { fn name() -> &'static str { "zed-predict" } @@ -1573,7 +1611,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } fn needs_terms_acceptance(&self, cx: &App) -> bool { - !self.zeta.read(cx).tos_accepted + !self + .zeta + .read(cx) + .user_store + .read(cx) + .has_accepted_terms_of_service() } fn is_refreshing(&self) -> bool { @@ -1588,7 +1631,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider _debounce: bool, cx: &mut Context<Self>, ) { - if !self.zeta.read(cx).tos_accepted { + if self.needs_terms_acceptance(cx) { return; } @@ -1600,7 +1643,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider .zeta .read(cx) .user_store - .read_with(cx, |user_store, _| { + .read_with(cx, |user_store, _cx| { user_store.account_too_young() || user_store.has_overdue_invoices() }) { @@ -1647,7 +1690,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider Ok(completion_request) => { let completion_request = completion_request.await; completion_request.map(|c| { - c.map(|completion| CurrentInlineCompletion { + c.map(|completion| CurrentEditPrediction { buffer_id: buffer.entity_id(), completion, }) @@ -1720,7 +1763,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider &mut self, _buffer: Entity<Buffer>, _cursor_position: language::Anchor, - _direction: inline_completion::Direction, + _direction: edit_prediction::Direction, _cx: &mut Context<Self>, ) { // Right now we don't support cycling. @@ -1751,8 +1794,8 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut Context<Self>, - ) -> Option<inline_completion::InlineCompletion> { - let CurrentInlineCompletion { + ) -> Option<edit_prediction::EditPrediction> { + let CurrentEditPrediction { buffer_id, completion, .. @@ -1800,7 +1843,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } } - Some(inline_completion::InlineCompletion { + Some(edit_prediction::EditPrediction { id: Some(completion.id.to_string().into()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), edit_preview: Some(completion.edit_preview.clone()), @@ -1817,19 +1860,20 @@ fn tokens_for_bytes(bytes: usize) -> usize { #[cfg(test)] mod tests { + use client::UserStore; use client::test::FakeServer; use clock::FakeSystemClock; + use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; use language::Point; - use rpc::proto; use settings::SettingsStore; use super::*; #[gpui::test] - async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) { + async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| { to_completion_edits( @@ -1844,19 +1888,19 @@ mod tests { .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) .await; - let completion = InlineCompletion { + let completion = EditPrediction { edits, edit_preview, path: Path::new("").into(), snapshot: cx.read(|cx| buffer.read(cx).snapshot()), - id: InlineCompletionId(Uuid::new_v4()), + id: EditPredictionId(Uuid::new_v4()), excerpt_range: 0..0, cursor_offset: 0, input_outline: "".into(), input_events: "".into(), input_excerpt: "".into(), output_excerpt: "".into(), - request_sent_at: Instant::now(), + buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), }; @@ -2010,7 +2054,7 @@ mod tests { } #[gpui::test] - async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) { + async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); @@ -2027,28 +2071,45 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |_| async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()) + let http_client = FakeHttpClient::create(move |req| async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") + .unwrap(), + output_excerpt: completion_response.to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } }); let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -2056,13 +2117,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::<proto::GetUsers>().await.unwrap(); - let token_request = server.receive::<proto::GetLlmToken>().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); buffer.update(cx, |buffer, cx| { buffer.edit(completion.edits.iter().cloned(), None, cx) @@ -2079,20 +2133,36 @@ mod tests { cx: &mut TestAppContext, ) -> Vec<(Range<Point>, String)> { let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |_| { + let http_client = FakeHttpClient::create(move |req| { let completion = completion_response.clone(); async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()) + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion, + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } } }); @@ -2100,9 +2170,10 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); @@ -2111,13 +2182,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::<proto::GetUsers>().await.unwrap(); - let token_request = server.receive::<proto::GetLlmToken>().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); completion .edits diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml new file mode 100644 index 0000000000..e77351c219 --- /dev/null +++ b/crates/zeta_cli/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "zeta_cli" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[[bin]] +name = "zeta" +path = "src/main.rs" + +[dependencies] +anyhow.workspace = true +clap.workspace = true +client.workspace = true +debug_adapter_extension.workspace = true +extension.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +gpui_tokio.workspace = true +language.workspace = true +language_extension.workspace = true +language_model.workspace = true +language_models.workspace = true +languages = { workspace = true, features = ["load-grammars"] } +node_runtime.workspace = true +paths.workspace = true +project.workspace = true +prompt_store.workspace = true +release_channel.workspace = true +reqwest_client.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +shellexpand.workspace = true +terminal_view.workspace = true +util.workspace = true +watch.workspace = true +workspace-hack.workspace = true +zeta.workspace = true +smol.workspace = true diff --git a/crates/zeta_cli/LICENSE-GPL b/crates/zeta_cli/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/zeta_cli/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta_cli/build.rs b/crates/zeta_cli/build.rs new file mode 100644 index 0000000000..ccbb54c5b4 --- /dev/null +++ b/crates/zeta_cli/build.rs @@ -0,0 +1,14 @@ +fn main() { + let cargo_toml = + std::fs::read_to_string("../zed/Cargo.toml").expect("Failed to read Cargo.toml"); + let version = cargo_toml + .lines() + .find(|line| line.starts_with("version = ")) + .expect("Version not found in crates/zed/Cargo.toml") + .split('=') + .nth(1) + .expect("Invalid version format") + .trim() + .trim_matches('"'); + println!("cargo:rustc-env=ZED_PKG_VERSION={}", version); +} diff --git a/crates/zeta_cli/src/headless.rs b/crates/zeta_cli/src/headless.rs new file mode 100644 index 0000000000..959bb91a8f --- /dev/null +++ b/crates/zeta_cli/src/headless.rs @@ -0,0 +1,128 @@ +use client::{Client, ProxySettings, UserStore}; +use extension::ExtensionHostProxy; +use fs::RealFs; +use gpui::http_client::read_proxy_from_env; +use gpui::{App, AppContext, Entity}; +use gpui_tokio::Tokio; +use language::LanguageRegistry; +use language_extension::LspAccess; +use node_runtime::{NodeBinaryOptions, NodeRuntime}; +use project::Project; +use project::project_settings::ProjectSettings; +use release_channel::AppVersion; +use reqwest_client::ReqwestClient; +use settings::{Settings, SettingsStore}; +use std::path::PathBuf; +use std::sync::Arc; +use util::ResultExt as _; + +/// Headless subset of `workspace::AppState`. +pub struct ZetaCliAppState { + pub languages: Arc<LanguageRegistry>, + pub client: Arc<Client>, + pub user_store: Entity<UserStore>, + pub fs: Arc<dyn fs::Fs>, + pub node_runtime: NodeRuntime, +} + +// TODO: dedupe with crates/eval/src/eval.rs +pub fn init(cx: &mut App) -> ZetaCliAppState { + let app_version = AppVersion::load(env!("ZED_PKG_VERSION")); + release_channel::init(app_version, cx); + gpui_tokio::init(cx); + + let mut settings_store = SettingsStore::new(cx); + settings_store + .set_default_settings(settings::default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(settings_store); + client::init_settings(cx); + + // Set User-Agent so we can download language servers from GitHub + let user_agent = format!( + "Zed/{} ({}; {})", + app_version, + std::env::consts::OS, + std::env::consts::ARCH + ); + let proxy_str = ProxySettings::get_global(cx).proxy.to_owned(); + let proxy_url = proxy_str + .as_ref() + .and_then(|input| input.parse().ok()) + .or_else(read_proxy_from_env); + let http = { + let _guard = Tokio::handle(cx).enter(); + + ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent) + .expect("could not start HTTP client") + }; + cx.set_http_client(Arc::new(http)); + + Project::init_settings(cx); + + let client = Client::production(cx); + cx.set_http_client(client.http_client()); + + let git_binary_path = None; + let fs = Arc::new(RealFs::new( + git_binary_path, + cx.background_executor().clone(), + )); + + let mut languages = LanguageRegistry::new(cx.background_executor().clone()); + languages.set_language_server_download_dir(paths::languages_dir().clone()); + let languages = Arc::new(languages); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + + extension::init(cx); + + let (mut tx, rx) = watch::channel(None); + cx.observe_global::<SettingsStore>(move |cx| { + let settings = &ProjectSettings::get_global(cx).node; + let options = NodeBinaryOptions { + allow_path_lookup: !settings.ignore_system_version, + allow_binary_download: true, + use_paths: settings.path.as_ref().map(|node_path| { + let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref()); + let npm_path = settings + .npm_path + .as_ref() + .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref())); + ( + node_path.clone(), + npm_path.unwrap_or_else(|| { + let base_path = PathBuf::new(); + node_path.parent().unwrap_or(&base_path).join("npm") + }), + ) + }), + }; + tx.send(Some(options)).log_err(); + }) + .detach(); + let node_runtime = NodeRuntime::new(client.http_client(), None, rx); + + let extension_host_proxy = ExtensionHostProxy::global(cx); + + language::init(cx); + debug_adapter_extension::init(extension_host_proxy.clone(), cx); + language_extension::init( + LspAccess::Noop, + extension_host_proxy.clone(), + languages.clone(), + ); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + languages::init(languages.clone(), node_runtime.clone(), cx); + prompt_store::init(cx); + terminal_view::init(cx); + + ZetaCliAppState { + languages, + client, + user_store, + fs, + node_runtime, + } +} diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs new file mode 100644 index 0000000000..adf7683152 --- /dev/null +++ b/crates/zeta_cli/src/main.rs @@ -0,0 +1,378 @@ +mod headless; + +use anyhow::{Result, anyhow}; +use clap::{Args, Parser, Subcommand}; +use futures::channel::mpsc; +use futures::{FutureExt as _, StreamExt as _}; +use gpui::{AppContext, Application, AsyncApp}; +use gpui::{Entity, Task}; +use language::Bias; +use language::Buffer; +use language::Point; +use language_model::LlmApiToken; +use project::{Project, ProjectPath}; +use release_channel::AppVersion; +use reqwest_client::ReqwestClient; +use std::path::{Path, PathBuf}; +use std::process::exit; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; + +use crate::headless::ZetaCliAppState; + +#[derive(Parser, Debug)] +#[command(name = "zeta")] +struct ZetaCliArgs { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +enum Commands { + Context(ContextArgs), + Predict { + #[arg(long)] + predict_edits_body: Option<FileOrStdin>, + #[clap(flatten)] + context_args: Option<ContextArgs>, + }, +} + +#[derive(Debug, Args)] +#[group(requires = "worktree")] +struct ContextArgs { + #[arg(long)] + worktree: PathBuf, + #[arg(long)] + cursor: CursorPosition, + #[arg(long)] + use_language_server: bool, + #[arg(long)] + events: Option<FileOrStdin>, +} + +#[derive(Debug, Clone)] +enum FileOrStdin { + File(PathBuf), + Stdin, +} + +impl FileOrStdin { + async fn read_to_string(&self) -> Result<String, std::io::Error> { + match self { + FileOrStdin::File(path) => smol::fs::read_to_string(path).await, + FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await, + } + } +} + +impl FromStr for FileOrStdin { + type Err = <PathBuf as FromStr>::Err; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "-" => Ok(Self::Stdin), + _ => Ok(Self::File(PathBuf::from_str(s)?)), + } + } +} + +#[derive(Debug, Clone)] +struct CursorPosition { + path: PathBuf, + point: Point, +} + +impl FromStr for CursorPosition { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result<Self> { + let parts: Vec<&str> = s.split(':').collect(); + if parts.len() != 3 { + return Err(anyhow!( + "Invalid cursor format. Expected 'file.rs:line:column', got '{}'", + s + )); + } + + let path = PathBuf::from(parts[0]); + let line: u32 = parts[1] + .parse() + .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?; + let column: u32 = parts[2] + .parse() + .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?; + + // Convert from 1-based to 0-based indexing + let point = Point::new(line.saturating_sub(1), column.saturating_sub(1)); + + Ok(CursorPosition { path, point }) + } +} + +async fn get_context( + args: ContextArgs, + app_state: &Arc<ZetaCliAppState>, + cx: &mut AsyncApp, +) -> Result<GatherContextOutput> { + let ContextArgs { + worktree: worktree_path, + cursor, + use_language_server, + events, + } = args; + + let worktree_path = worktree_path.canonicalize()?; + if cursor.path.is_absolute() { + return Err(anyhow!("Absolute paths are not supported in --cursor")); + } + + let (project, _lsp_open_handle, buffer) = if use_language_server { + let (project, lsp_open_handle, buffer) = + open_buffer_with_language_server(&worktree_path, &cursor.path, &app_state, cx).await?; + (Some(project), Some(lsp_open_handle), buffer) + } else { + let abs_path = worktree_path.join(&cursor.path); + let content = smol::fs::read_to_string(&abs_path).await?; + let buffer = cx.new(|cx| Buffer::local(content, cx))?; + (None, None, buffer) + }; + + let worktree_name = worktree_path + .file_name() + .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?; + let full_path_str = PathBuf::from(worktree_name) + .join(&cursor.path) + .to_string_lossy() + .to_string(); + + let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?; + let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left); + if clipped_cursor != cursor.point { + let max_row = snapshot.max_point().row; + if cursor.point.row < max_row { + return Err(anyhow!( + "Cursor position {:?} is out of bounds (line length is {})", + cursor.point, + snapshot.line_len(cursor.point.row) + )); + } else { + return Err(anyhow!( + "Cursor position {:?} is out of bounds (max row is {})", + cursor.point, + max_row + )); + } + } + + let events = match events { + Some(events) => events.read_to_string().await?, + None => String::new(), + }; + let can_collect_data = false; + let git_info = None; + cx.update(|cx| { + gather_context( + project.as_ref(), + full_path_str, + &snapshot, + clipped_cursor, + move || events, + can_collect_data, + git_info, + cx, + ) + })? + .await +} + +pub async fn open_buffer_with_language_server( + worktree_path: &Path, + path: &Path, + app_state: &Arc<ZetaCliAppState>, + cx: &mut AsyncApp, +) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> { + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(worktree_path, true, cx) + })? + .await?; + + let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath { + worktree_id: worktree.id(), + path: path.to_path_buf().into(), + })?; + + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + + let lsp_open_handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + })?; + + let log_prefix = path.to_string_lossy().to_string(); + wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; + + Ok((project, lsp_open_handle, buffer)) +} + +// TODO: Dedupe with similar function in crates/eval/src/instance.rs +pub fn wait_for_lang_server( + project: &Entity<Project>, + buffer: &Entity<Buffer>, + log_prefix: String, + cx: &mut AsyncApp, +) -> Task<Result<()>> { + println!("{}⏵ Waiting for language server", log_prefix); + + let (mut tx, mut rx) = mpsc::channel(1); + + let lsp_store = project + .read_with(cx, |project, _| project.lsp_store()) + .unwrap(); + + let has_lang_server = buffer + .update(cx, |buffer, cx| { + lsp_store.update(cx, |lsp_store, cx| { + lsp_store + .language_servers_for_local_buffer(&buffer, cx) + .next() + .is_some() + }) + }) + .unwrap_or(false); + + if has_lang_server { + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .unwrap() + .detach(); + } + + let subscriptions = [ + cx.subscribe(&lsp_store, { + let log_prefix = log_prefix.clone(); + move |_, event, _| match event { + project::LspStoreEvent::LanguageServerUpdate { + message: + client::proto::update_language_server::Variant::WorkProgress( + client::proto::LspWorkProgress { + message: Some(message), + .. + }, + ), + .. + } => println!("{}⟲ {message}", log_prefix), + _ => {} + } + }), + cx.subscribe(&project, { + let buffer = buffer.clone(); + move |project, event, cx| match event { + project::Event::LanguageServerAdded(_, _, _) => { + let buffer = buffer.clone(); + project + .update(cx, |project, cx| project.save_buffer(buffer, cx)) + .detach(); + } + project::Event::DiskBasedDiagnosticsFinished { .. } => { + tx.try_send(()).ok(); + } + _ => {} + } + }), + ]; + + cx.spawn(async move |cx| { + let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); + let result = futures::select! { + _ = rx.next() => { + println!("{}⚑ Language server idle", log_prefix); + anyhow::Ok(()) + }, + _ = timeout.fuse() => { + anyhow::bail!("LSP wait timed out after 5 minutes"); + } + }; + drop(subscriptions); + result + }) +} + +fn main() { + let args = ZetaCliArgs::parse(); + let http_client = Arc::new(ReqwestClient::new()); + let app = Application::headless().with_http_client(http_client); + + app.run(move |cx| { + let app_state = Arc::new(headless::init(cx)); + cx.spawn(async move |cx| { + let result = match args.command { + Commands::Context(context_args) => get_context(context_args, &app_state, cx) + .await + .map(|output| serde_json::to_string_pretty(&output.body).unwrap()), + Commands::Predict { + predict_edits_body, + context_args, + } => { + cx.spawn(async move |cx| { + let app_version = cx.update(|cx| AppVersion::global(cx))?; + app_state.client.sign_in(true, cx).await?; + let llm_token = LlmApiToken::default(); + llm_token.refresh(&app_state.client).await?; + + let predict_edits_body = + if let Some(predict_edits_body) = predict_edits_body { + serde_json::from_str(&predict_edits_body.read_to_string().await?)? + } else if let Some(context_args) = context_args { + get_context(context_args, &app_state, cx).await?.body + } else { + return Err(anyhow!( + "Expected either --predict-edits-body-file \ + or the required args of the `context` command." + )); + }; + + let (response, _usage) = + Zeta::perform_predict_edits(PerformPredictEditsParams { + client: app_state.client.clone(), + llm_token, + app_version, + body: predict_edits_body, + }) + .await?; + + Ok(response.output_excerpt) + }) + .await + } + }; + match result { + Ok(output) => { + println!("{}", output); + let _ = cx.update(|cx| cx.quit()); + } + Err(e) => { + eprintln!("Failed: {:?}", e); + exit(1); + } + } + }) + .detach(); + }); +} diff --git a/crates/zlog/src/sink.rs b/crates/zlog/src/sink.rs index acf0469c77..17aa08026e 100644 --- a/crates/zlog/src/sink.rs +++ b/crates/zlog/src/sink.rs @@ -21,6 +21,8 @@ const ANSI_MAGENTA: &str = "\x1b[35m"; /// Whether stdout output is enabled. static mut ENABLED_SINKS_STDOUT: bool = false; +/// Whether stderr output is enabled. +static mut ENABLED_SINKS_STDERR: bool = false; /// Is Some(file) if file output is enabled. static ENABLED_SINKS_FILE: Mutex<Option<std::fs::File>> = Mutex::new(None); @@ -45,6 +47,12 @@ pub fn init_output_stdout() { } } +pub fn init_output_stderr() { + unsafe { + ENABLED_SINKS_STDERR = true; + } +} + pub fn init_output_file( path: &'static PathBuf, path_rotate: Option<&'static PathBuf>, @@ -115,6 +123,21 @@ pub fn submit(record: Record) { }, record.message ); + } else if unsafe { ENABLED_SINKS_STDERR } { + let mut stdout = std::io::stderr().lock(); + _ = writeln!( + &mut stdout, + "{} {ANSI_BOLD}{}{}{ANSI_RESET} {} {}", + chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%:z"), + LEVEL_ANSI_COLORS[record.level as usize], + LEVEL_OUTPUT_STRINGS[record.level as usize], + SourceFmt { + scope: record.scope, + module_path: record.module_path, + ansi: true, + }, + record.message + ); } let mut file = ENABLED_SINKS_FILE.lock().unwrap_or_else(|handle| { ENABLED_SINKS_FILE.clear_poison(); diff --git a/crates/zlog/src/zlog.rs b/crates/zlog/src/zlog.rs index 570c82314c..5b40278f3f 100644 --- a/crates/zlog/src/zlog.rs +++ b/crates/zlog/src/zlog.rs @@ -5,7 +5,7 @@ mod env_config; pub mod filter; pub mod sink; -pub use sink::{flush, init_output_file, init_output_stdout}; +pub use sink::{flush, init_output_file, init_output_stderr, init_output_stdout}; pub const SCOPE_DEPTH_MAX: usize = 4; diff --git a/docs/README.md b/docs/README.md index 55993c9e36..a225903674 100644 --- a/docs/README.md +++ b/docs/README.md @@ -69,3 +69,64 @@ Templates are just functions that modify the source of the docs pages (usually w - Template Trait: crates/docs_preprocessor/src/templates.rs - Example template: crates/docs_preprocessor/src/templates/keybinding.rs - Client-side plugins: docs/theme/plugins.js + +## Postprocessor + +A postprocessor is implemented as a sub-command of `docs_preprocessor` that wraps the builtin `html` renderer and applies post-processing to the `html` files, to add support for page-specific title and meta description values. + +An example of the syntax can be found in `git.md`, as well as below + +```md +--- +title: Some more detailed title for this page +description: A page-specific description +--- + +# Editor +``` + +The above will be transformed into (with non-relevant tags removed) + +```html +<head> + <title>Editor | Some more detailed title for this page + + + +

Editor

+ +``` + +If no front-matter is provided, or If one or both keys aren't provided, the title and description will be set based on the `default-title` and `default-description` keys in `book.toml` respectively. + +### Implementation details + +Unfortunately, `mdbook` does not support post-processing like it does pre-processing, and only supports defining one description to put in the meta tag per book rather than per file. So in order to apply post-processing (necessary to modify the html head tags) the global book description is set to a marker value `#description#` and the html renderer is replaced with a sub-command of `docs_preprocessor` that wraps the builtin `html` renderer and applies post-processing to the `html` files, replacing the marker value and the `(.*)` with the contents of the front-matter if there is one. + +### Known limitations + +The front-matter parsing is extremely simple, which avoids needing to take on an additional dependency, or implement full yaml parsing. + +- Double quotes and multi-line values are not supported, i.e. Keys and values must be entirely on the same line, with no double quotes around the value. + +The following will not work: + +```md +--- +title: Some + Multi-line + Title +--- +``` + +And neither will: + +```md +--- +title: "Some title" +--- +``` + +- The front-matter must be at the top of the file, with only white-space preceding it + +- The contents of the title and description will not be html-escaped. They should be simple ascii text with no unicode or emoji characters diff --git a/docs/book.toml b/docs/book.toml index 518fbec819..60ddc5ac51 100644 --- a/docs/book.toml +++ b/docs/book.toml @@ -6,13 +6,27 @@ src = "src" title = "Zed" site-url = "/docs/" -[output.html] +[build] +extra-watch-dirs = ["../crates/docs_preprocessor"] + +# zed-html is a "custom" renderer that just wraps the +# builtin mdbook html renderer, and applies post-processing +# as post-processing is not possible with mdbook in the same way +# pre-processing is +# The config is passed directly to the html renderer, so all config +# options that apply to html apply to zed-html +[output.zed-html] +command = "cargo run -p docs_preprocessor -- postprocess" +# Set here instead of above as we only use it replace the `#description#` we set in the template +# when no front-matter is provided value +default-description = "Learn how to use and customize Zed, the fast, collaborative code editor. Official docs on features, configuration, AI tools, and workflows." +default-title = "Zed Code Editor Documentation" no-section-label = true preferred-dark-theme = "dark" additional-css = ["theme/page-toc.css", "theme/plugins.css", "theme/highlight.css"] additional-js = ["theme/page-toc.js", "theme/plugins.js"] -[output.html.print] +[output.zed-html.print] enable = false # Redirects for `/docs` pages. @@ -24,7 +38,7 @@ enable = false # The destination URLs are interpreted relative to `https://zed.dev`. # - Redirects to other docs pages should end in `.html` # - You can link to pages on the Zed site by omitting the `/docs` in front of it. -[output.html.redirect] +[output.zed-html.redirect] # AI "/ai.html" = "/docs/ai/overview.html" "/assistant-panel.html" = "/docs/ai/agent-panel.html" diff --git a/docs/src/ai/agent-settings.md b/docs/src/ai/agent-settings.md index 315ae21929..ff97bcb8ee 100644 --- a/docs/src/ai/agent-settings.md +++ b/docs/src/ai/agent-settings.md @@ -108,7 +108,7 @@ Specify a custom temperature for a provider and/or model: ## Agent Panel Settings {#agent-panel-settings} -Note that some of these settings are also surfaced in the Agent Panel's settings UI, which you can access either via the `agent: open configuration` action or by the dropdown menu on the top-right corner of the panel. +Note that some of these settings are also surfaced in the Agent Panel's settings UI, which you can access either via the `agent: open settings` action or by the dropdown menu on the top-right corner of the panel. ### Default View diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md index cb55c1c94e..04646213e6 100644 --- a/docs/src/ai/llm-providers.md +++ b/docs/src/ai/llm-providers.md @@ -86,7 +86,7 @@ To do this: 1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). 2. Create security credentials for that User, save them and keep them secure. -3. Open the Agent Configuration with (`agent: open configuration`) and go to the Amazon Bedrock section +3. Open the Agent Configuration with (`agent: open settings`) and go to the Amazon Bedrock section 4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. #### Cross-Region Inference @@ -113,7 +113,7 @@ You can use Anthropic models by choosing them via the model dropdown in the Agen 1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys) 2. Make sure that your Anthropic account has credits -3. Open the settings view (`agent: open configuration`) and go to the Anthropic section +3. Open the settings view (`agent: open settings`) and go to the Anthropic section 4. Enter your Anthropic API key Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API. @@ -168,7 +168,7 @@ You can configure a model to use [extended thinking](https://docs.anthropic.com/ > ✅ Supports tool use 1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys) -2. Open the settings view (`agent: open configuration`) and go to the DeepSeek section +2. Open the settings view (`agent: open settings`) and go to the DeepSeek section 3. Enter your DeepSeek API key The DeepSeek API key will be saved in your keychain. @@ -213,14 +213,14 @@ You can also modify the `api_url` to use a custom endpoint if needed. You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. -1. Open the settings view (`agent: open configuration`) and go to the GitHub Copilot Chat section +1. Open the settings view (`agent: open settings`) and go to the GitHub Copilot Chat section 2. Click on `Sign in to use GitHub Copilot`, follow the steps shown in the modal. Alternatively, you can provide an OAuth token via the `GH_COPILOT_TOKEN` environment variable. > **Note**: If you don't see specific models in the dropdown, you may need to enable them in your [GitHub Copilot settings](https://github.com/settings/copilot/features). -To use Copilot Enterprise with Zed (for both agent and inline completions), you must configure your enterprise endpoint as described in [Configuring GitHub Copilot Enterprise](./edit-prediction.md#github-copilot-enterprise). +To use Copilot Enterprise with Zed (for both agent and completions), you must configure your enterprise endpoint as described in [Configuring GitHub Copilot Enterprise](./edit-prediction.md#github-copilot-enterprise). ### Google AI {#google-ai} @@ -229,7 +229,7 @@ To use Copilot Enterprise with Zed (for both agent and inline completions), you You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. 1. Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). -2. Open the settings view (`agent: open configuration`) and go to the Google AI section +2. Open the settings view (`agent: open settings`) and go to the Google AI section 3. Enter your Google AI API key and press enter. The Google AI API key will be saved in your keychain. @@ -288,7 +288,7 @@ Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless# > ✅ Supports tool use 1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) -2. Open the configuration view (`agent: open configuration`) and navigate to the Mistral section +2. Open the configuration view (`agent: open settings`) and navigate to the Mistral section 3. Enter your Mistral API key The Mistral API key will be saved in your keychain. @@ -399,7 +399,7 @@ If the model is tagged with `vision` in the Ollama catalog, set this option and 1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) 2. Make sure that your OpenAI account has credits -3. Open the settings view (`agent: open configuration`) and go to the OpenAI section +3. Open the settings view (`agent: open settings`) and go to the OpenAI section 4. Enter your OpenAI API key The OpenAI API key will be saved in your keychain. @@ -441,30 +441,26 @@ Custom models will be listed in the model dropdown in the Agent Panel. ### OpenAI API Compatible {#openai-api-compatible} -Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models. +Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider. +This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models. -To configure a compatible API, you can add a custom API URL for OpenAI either via the UI (currently available only in Preview) or by editing your `settings.json`. +You can add a custom, OpenAI-compatible model via either via the UI or by editing your `settings.json`. -For example, to connect to [Together AI](https://www.together.ai/) via the UI: +To do it via the UI, go to the Agent Panel settings (`agent: open settings`) and look for the "Add Provider" button to the right of the "LLM Providers" section title. +Then, fill up the input fields available in the modal. -1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys). -2. Go to the Agent Panel's settings view, click on the "Add Provider" button, and then on the "OpenAI" menu item -3. Add the requested fields, such as `api_url`, `api_key`, available models, and others - -Alternatively, you can also add it via the `settings.json`: +To do it via your `settings.json`, add the following snippet under `language_models`: ```json { "language_models": { "openai": { - "api_url": "https://api.together.xyz/v1", - "api_key": "YOUR_TOGETHER_AI_API_KEY", + "api_url": "https://api.together.xyz/v1", // Using Together AI as an example "available_models": [ { "name": "mistralai/Mixtral-8x7B-Instruct-v0.1", "display_name": "Together Mixtral 8x7B", - "max_tokens": 32768, - "supports_tools": true + "max_tokens": 32768 } ] } @@ -472,6 +468,9 @@ Alternatively, you can also add it via the `settings.json`: } ``` +Note that LLM API keys aren't stored in your settings file. +So, ensure you have it set in your environment variables (`OPENAI_API_KEY=`) so your settings can pick it up. + ### OpenRouter {#openrouter} > ✅ Supports tool use @@ -480,7 +479,7 @@ OpenRouter provides access to multiple AI models through a single API. It suppor 1. Visit [OpenRouter](https://openrouter.ai) and create an account 2. Generate an API key from your [OpenRouter keys page](https://openrouter.ai/keys) -3. Open the settings view (`agent: open configuration`) and go to the OpenRouter section +3. Open the settings view (`agent: open settings`) and go to the OpenRouter section 4. Enter your OpenRouter API key The OpenRouter API key will be saved in your keychain. @@ -551,7 +550,7 @@ You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. 1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) -2. Open the settings view (`agent: open configuration`) and go to the **xAI** section +2. Open the settings view (`agent: open settings`) and go to the **xAI** section 3. Enter your xAI API key The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined. diff --git a/docs/src/ai/mcp.md b/docs/src/ai/mcp.md index 5aef3d3d72..dfe3e4bdb9 100644 --- a/docs/src/ai/mcp.md +++ b/docs/src/ai/mcp.md @@ -50,7 +50,7 @@ You can connect them by adding their commands directly to your `settings.json`, } ``` -Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open configuration` action). +Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open settings` action). From there, you can add it through the modal that appears when you click the "Add Custom Server" button. ## Using MCP Servers diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index fd1761ebfa..5fd27abad6 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -2588,6 +2588,7 @@ List of `integer` column numbers "font_features": null, "font_size": null, "line_height": "comfortable", + "minimum_contrast": 45, "option_as_meta": false, "button": true, "shell": "system", @@ -2883,6 +2884,30 @@ See Buffer Font Features } ``` +### Terminal: Minimum Contrast + +- Description: Controls the minimum contrast between foreground and background colors in the terminal. Uses the APCA (Accessible Perceptual Contrast Algorithm) for color adjustments. Set this to 0 to disable this feature. +- Setting: `minimum_contrast` +- Default: `45` + +**Options** + +`integer` values from 0 to 106. Common recommended values: + +- `0`: No contrast adjustment +- `45`: Minimum for large fluent text (default) +- `60`: Minimum for other content text +- `75`: Minimum for body text +- `90`: Preferred for body text + +```json +{ + "terminal": { + "minimum_contrast": 45 + } +} +``` + ### Terminal: Option As Meta - Description: Re-interprets the option keys to act like a 'meta' key, like in Emacs. @@ -3390,7 +3415,7 @@ Run the `theme selector: toggle` action in the command palette to see a current ## Agent -Visit [the Configuration page](/ai/configuration.md) under the AI section to learn more about all the agent-related settings. +Visit [the Configuration page](./ai/configuration.md) under the AI section to learn more about all the agent-related settings. ## Outline Panel diff --git a/docs/src/development/debugging-crashes.md b/docs/src/development/debugging-crashes.md index d08ab961cc..ed0a5807a3 100644 --- a/docs/src/development/debugging-crashes.md +++ b/docs/src/development/debugging-crashes.md @@ -6,6 +6,7 @@ When an app crashes, - macOS creates a `.ips` file in `~/Library/Logs/DiagnosticReports`. You can view these using the built in Console app (`cmd-space Console`) under "Crash Reports". - Linux creates a core dump. See the [man pages](https://man7.org/linux/man-pages/man5/core.5.html) for pointers to how your system might be configured to manage core dumps. +- Windows doesn't create crash reports by default, but can be configured to create "minidump" memory dumps upon applications crashing. If you have enabled Zed's telemetry these will be uploaded to us when you restart the app. They end up in a [Slack channel (internal only)](https://zed-industries.slack.com/archives/C04S6T1T7TQ). diff --git a/docs/src/development/linux.md b/docs/src/development/linux.md index 6fff25f6c1..d7b586be34 100644 --- a/docs/src/development/linux.md +++ b/docs/src/development/linux.md @@ -91,7 +91,7 @@ Zed has two main binaries: - You will need to build `crates/cli` and make its binary available in `$PATH` with the name `zed`. - You will need to build `crates/zed` and put it at `$PATH/to/cli/../../libexec/zed-editor`. For example, if you are going to put the cli at `~/.local/bin/zed` put zed at `~/.local/libexec/zed-editor`. As some linux distributions (notably Arch) discourage the use of `libexec`, you can also put this binary at `$PATH/to/cli/../../lib/zed/zed-editor` (e.g. `~/.local/lib/zed/zed-editor`) instead. -- If you are going to provide a `.desktop` file you can find a template in `crates/zed/resources/zed.desktop.in`, and use `envsubst` to populate it with the values required. This file should also be renamed to `$APP_ID.desktop` so that the file [follows the FreeDesktop standards](https://github.com/zed-industries/zed/issues/12707#issuecomment-2168742761). +- If you are going to provide a `.desktop` file you can find a template in `crates/zed/resources/zed.desktop.in`, and use `envsubst` to populate it with the values required. This file should also be renamed to `$APP_ID.desktop` so that the file [follows the FreeDesktop standards](https://github.com/zed-industries/zed/issues/12707#issuecomment-2168742761). You should also make this desktop file executable (`chmod 755`). - You will need to ensure that the necessary libraries are installed. You can get the current list by [inspecting the built binary](https://github.com/zed-industries/zed/blob/935cf542aebf55122ce6ed1c91d0fe8711970c82/script/bundle-linux#L65-L67) on your system. - For an example of a complete build script, see [script/bundle-linux](https://github.com/zed-industries/zed/blob/935cf542aebf55122ce6ed1c91d0fe8711970c82/script/bundle-linux). - You can disable Zed's auto updates and provide instructions for users who try to update Zed manually by building (or running) Zed with the environment variable `ZED_UPDATE_EXPLANATION`. For example: `ZED_UPDATE_EXPLANATION="Please use flatpak to update zed."`. diff --git a/docs/src/extensions/installing-extensions.md b/docs/src/extensions/installing-extensions.md index aed8bef428..801fe5c55c 100644 --- a/docs/src/extensions/installing-extensions.md +++ b/docs/src/extensions/installing-extensions.md @@ -1,6 +1,6 @@ # Installing Extensions -You can search for extensions by launching the Zed Extension Gallery by pressing `cmd-shift-x` (macOS) or `ctrl-shift-x` (Linux), opening the command palette and selecting `zed: extensions` or by selecting "Zed > Extensions" from the menu bar. +You can search for extensions by launching the Zed Extension Gallery by pressing {#kb zed::Extensions} , opening the command palette and selecting {#action zed::Extensions} or by selecting "Zed > Extensions" from the menu bar. Here you can view the extensions that you currently have installed or search and install new ones. diff --git a/docs/src/extensions/languages.md b/docs/src/extensions/languages.md index 44c673e3e1..6756cb8a23 100644 --- a/docs/src/extensions/languages.md +++ b/docs/src/extensions/languages.md @@ -402,11 +402,10 @@ If your language server supports additional languages, you can use `language_ids [language-servers.my-language-server] name = "Whatever LSP" -languages = ["JavaScript", "JSX", "HTML", "CSS"] +languages = ["JavaScript", "HTML", "CSS"] [language-servers.my-language-server.language_ids] "JavaScript" = "javascript" -"JSX" = "javascriptreact" "TSX" = "typescriptreact" "HTML" = "html" "CSS" = "css" diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md index 5940c74b21..22af3b36d7 100644 --- a/docs/src/getting-started.md +++ b/docs/src/getting-started.md @@ -83,6 +83,6 @@ Visit [the AI overview page](./ai/overview.md) to learn how to quickly get start ## Set up your key bindings -To open your custom keymap to add your key bindings, use the {#kb zed::OpenKeymap} keybinding. +To edit your custom keymap and add or remap bindings, you can either use {#kb zed::OpenKeymapEditor} to spawn the Zed Keymap Editor ({#action zed::OpenKeymapEditor}) or you can directly open your Zed Keymap json (`~/.config/zed/keymap.json`) with {#action zed::OpenKeymap}. To access the default key binding set, open the Command Palette with {#kb command_palette::Toggle} and search for "zed: open default keymap". See [Key Bindings](./key-bindings.md) for more info. diff --git a/docs/src/git.md b/docs/src/git.md index 5b5c8a3b15..cccbad9b2e 100644 --- a/docs/src/git.md +++ b/docs/src/git.md @@ -1,3 +1,8 @@ +--- +description: Zed is a text editor that supports lots of Git features +title: Zed Editor Git integration documentation +--- + # Git Zed currently offers a set of fundamental Git features, with support coming in the future for more advanced ones, like conflict resolution tools, line by line staging, and more. diff --git a/docs/src/key-bindings.md b/docs/src/key-bindings.md index 90aa400bb4..feed912787 100644 --- a/docs/src/key-bindings.md +++ b/docs/src/key-bindings.md @@ -18,7 +18,7 @@ You can also enable `vim_mode`, which adds vim bindings too. ## User keymaps -Zed reads your keymap from `~/.config/zed/keymap.json`. You can open the file within Zed with {#kb zed::OpenKeymap}, or via `zed: Open Keymap` in the command palette. +Zed reads your keymap from `~/.config/zed/keymap.json`. You can open the file within Zed with {#action zed::OpenKeymap} from the command palette or to spawn the Zed Keymap Editor ({#action zed::OpenKeymapEditor}) use {#kb zed::OpenKeymapEditor}. The file contains a JSON array of objects with `"bindings"`. If no `"context"` is set the bindings are always active. If it is set the binding is only active when the [context matches](#contexts). @@ -93,7 +93,7 @@ For example: # in an editor, it might look like this: Workspace os=macos keyboard_layout=com.apple.keylayout.QWERTY Pane - Editor mode=full extension=md inline_completion vim_mode=insert + Editor mode=full extension=md vim_mode=insert # in the project panel Workspace os=macos diff --git a/docs/src/languages/deno.md b/docs/src/languages/deno.md index c18b112326..c40b6531e6 100644 --- a/docs/src/languages/deno.md +++ b/docs/src/languages/deno.md @@ -57,6 +57,40 @@ See [Configuring supported languages](../configuring-languages.md) in the Zed do TBD: Deno Typescript REPL instructions [docs/repl#typescript-deno](../repl.md#typescript-deno) --> +## DAP support + +To debug deno programs, add this to `.zed/debug.json` + +```json +[ + { + "adapter": "JavaScript", + "label": "Deno", + "request": "launch", + "type": "pwa-node", + "cwd": "$ZED_WORKTREE_ROOT", + "program": "$ZED_FILE", + "runtimeExecutable": "deno", + "runtimeArgs": ["run", "--allow-all", "--inspect-wait"], + "attachSimplePort": 9229 + } +] +``` + +## Runnable support + +To run deno tasks like tests from the ui, add this to `.zed/tasks.json` + +```json +[ + { + "label": "deno test", + "command": "deno test -A --filter '/^$ZED_CUSTOM_DENO_TEST_NAME$/' $ZED_FILE", + "tags": ["js-test"] + } +] +``` + ## See also: - [TypeScript](./typescript.md) diff --git a/docs/src/telemetry.md b/docs/src/telemetry.md index 7f5994be0c..107aef5a96 100644 --- a/docs/src/telemetry.md +++ b/docs/src/telemetry.md @@ -21,7 +21,7 @@ The telemetry settings can also be configured via the welcome screen, which can Telemetry is sent from the application to our servers. Data is proxied through our servers to enable us to easily switch analytics services. We currently use: -- [Axiom](https://axiom.co): Cloud-monitoring service - stores diagnostic events +- [Sentry](https://sentry.io): Crash-monitoring service - stores diagnostic events - [Snowflake](https://snowflake.com): Data warehouse - stores both diagnostic and metric events - [Hex](https://www.hex.tech): Dashboards and data exploration - accesses data stored in Snowflake - [Amplitude](https://www.amplitude.com): Dashboards and data exploration - accesses data stored in Snowflake @@ -30,9 +30,9 @@ Telemetry is sent from the application to our servers. Data is proxied through o ### Diagnostics -Diagnostic events include debug information (stack traces) from crash reports. Reports are sent on the first application launch after the crash occurred. We've built dashboards that allow us to visualize the frequency and severity of issues experienced by users. Having these reports sent automatically allows us to begin implementing fixes without the user needing to file a report in our issue tracker. The plots in the dashboards also give us an informal measurement of the stability of Zed. +Crash reports consist of a [minidump](https://learn.microsoft.com/en-us/windows/win32/debug/minidump-files) and some extra debug information. Reports are sent on the first application launch after the crash occurred. We've built dashboards that allow us to visualize the frequency and severity of issues experienced by users. Having these reports sent automatically allows us to begin implementing fixes without the user needing to file a report in our issue tracker. The plots in the dashboards also give us an informal measurement of the stability of Zed. -You can see what data is sent when a panic occurs by inspecting the `Panic` struct in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repo. You can find additional information in the [Debugging Crashes](./development/debugging-crashes.md) documentation. +You can see what extra data is sent alongside the minidump in the `Panic` struct in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repo. You can find additional information in the [Debugging Crashes](./development/debugging-crashes.md) documentation. ### Client-Side Usage Data {#client-metrics} diff --git a/docs/theme/index.hbs b/docs/theme/index.hbs index 8ab4f21cf1..4339a02d17 100644 --- a/docs/theme/index.hbs +++ b/docs/theme/index.hbs @@ -15,7 +15,7 @@ {{> head}} - + diff --git a/extensions/emmet/Cargo.toml b/extensions/emmet/Cargo.toml index db8aaaae41..9d72a6c5c4 100644 --- a/extensions/emmet/Cargo.toml +++ b/extensions/emmet/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zed_emmet" -version = "0.0.3" +version = "0.0.4" edition.workspace = true publish.workspace = true license = "Apache-2.0" diff --git a/extensions/ruff/Cargo.toml b/extensions/ruff/Cargo.toml index 830897279a..24616f963b 100644 --- a/extensions/ruff/Cargo.toml +++ b/extensions/ruff/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zed_ruff" -version = "0.1.0" +version = "0.1.1" edition.workspace = true publish.workspace = true license = "Apache-2.0" diff --git a/extensions/ruff/extension.toml b/extensions/ruff/extension.toml index 63929fc191..1f5a7314f4 100644 --- a/extensions/ruff/extension.toml +++ b/extensions/ruff/extension.toml @@ -1,7 +1,7 @@ id = "ruff" name = "Ruff" description = "Support for Ruff, the Python linter and formatter" -version = "0.1.0" +version = "0.1.1" schema_version = 1 authors = [] repository = "https://github.com/zed-industries/zed" diff --git a/nix/build.nix b/nix/build.nix index 873431a427..70b4f76932 100644 --- a/nix/build.nix +++ b/nix/build.nix @@ -298,6 +298,7 @@ craneLib.buildPackage ( export APP_ARGS="%U" mkdir -p "$out/share/applications" ${lib.getExe envsubst} < "crates/zed/resources/zed.desktop.in" > "$out/share/applications/dev.zed.Zed-Nightly.desktop" + chmod +x "$out/share/applications/dev.zed.Zed-Nightly.desktop" ) runHook postInstall diff --git a/script/bundle-freebsd b/script/bundle-freebsd index 7222a06256..87c9459ffb 100755 --- a/script/bundle-freebsd +++ b/script/bundle-freebsd @@ -138,6 +138,7 @@ fi # mkdir -p "${zed_dir}/share/applications" # envsubst <"crates/zed/resources/zed.desktop.in" >"${zed_dir}/share/applications/zed$suffix.desktop" +# chmod +x "${zed_dir}/share/applications/zed$suffix.desktop" # Copy generated licenses so they'll end up in archive too # cp "assets/licenses.md" "${zed_dir}/licenses.md" diff --git a/script/bundle-linux b/script/bundle-linux index c52312015b..ad67b7a0f7 100755 --- a/script/bundle-linux +++ b/script/bundle-linux @@ -83,6 +83,23 @@ if [[ "$remote_server_triple" == "$musl_triple" ]]; then fi cargo build --release --target "${remote_server_triple}" --package remote_server +# Upload debug info to sentry.io +if ! command -v sentry-cli >/dev/null 2>&1; then + echo "sentry-cli not found. skipping sentry upload." + echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" +else + if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then + echo "Uploading zed debug symbols to sentry..." + # note: this uploads the unstripped binary which is needed because it contains + # .eh_frame data for stack unwinindg. see https://github.com/getsentry/symbolic/issues/783 + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ + "${target_dir}/${target_triple}"/release/zed \ + "${target_dir}/${remote_server_triple}"/release/remote_server + else + echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload." + fi +fi + # Strip debug symbols and save them for upload to DigitalOcean objcopy --only-keep-debug "${target_dir}/${target_triple}/release/zed" "${target_dir}/${target_triple}/release/zed.dbg" objcopy --only-keep-debug "${target_dir}/${remote_server_triple}/release/remote_server" "${target_dir}/${remote_server_triple}/release/remote_server.dbg" @@ -162,6 +179,7 @@ fi mkdir -p "${zed_dir}/share/applications" envsubst < "crates/zed/resources/zed.desktop.in" > "${zed_dir}/share/applications/zed$suffix.desktop" +chmod +x "${zed_dir}/share/applications/zed$suffix.desktop" # Copy generated licenses so they'll end up in archive too cp "assets/licenses.md" "${zed_dir}/licenses.md" diff --git a/script/bundle-mac b/script/bundle-mac index 18dfe90815..b2be573235 100755 --- a/script/bundle-mac +++ b/script/bundle-mac @@ -366,3 +366,20 @@ else gzip -f --stdout --best target/x86_64-apple-darwin/release/remote_server > target/zed-remote-server-macos-x86_64.gz gzip -f --stdout --best target/aarch64-apple-darwin/release/remote_server > target/zed-remote-server-macos-aarch64.gz fi + +# Upload debug info to sentry.io +if ! command -v sentry-cli >/dev/null 2>&1; then + echo "sentry-cli not found. skipping sentry upload." + echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" +else + if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then + echo "Uploading zed debug symbols to sentry..." + # note: this uploads the unstripped binary which is needed because it contains + # .eh_frame data for stack unwinindg. see https://github.com/getsentry/symbolic/issues/783 + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ + "target/x86_64-apple-darwin/${target_dir}/" \ + "target/aarch64-apple-darwin/${target_dir}/" + else + echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload." + fi +fi diff --git a/script/bundle-windows.ps1 b/script/bundle-windows.ps1 index 01a1114c26..8ae0212491 100644 --- a/script/bundle-windows.ps1 +++ b/script/bundle-windows.ps1 @@ -26,6 +26,7 @@ if ($Help) { Push-Location -Path crates/zed $channel = Get-Content "RELEASE_CHANNEL" $env:ZED_RELEASE_CHANNEL = $channel +$env:RELEASE_CHANNEL = $channel Pop-Location function CheckEnvironmentVariables { @@ -96,6 +97,21 @@ function ZipZedAndItsFriendsDebug { Compress-Archive -Path $items -DestinationPath ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" -Force } + +function UploadToSentry { + if (-not (Get-Command "sentry-cli" -ErrorAction SilentlyContinue)) { + Write-Output "sentry-cli not found. skipping sentry upload." + Write-Output "install with: 'winget install -e --id=Sentry.sentry-cli'" + return + } + if (-not (Test-Path "env:SENTRY_AUTH_TOKEN")) { + Write-Output "missing SENTRY_AUTH_TOKEN. skipping sentry upload." + return + } + Write-Output "Uploading zed debug symbols to sentry..." + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev .\target\release\ +} + function MakeAppx { switch ($channel) { "stable" { @@ -120,11 +136,22 @@ function SignZedAndItsFriends { & "$innoDir\sign.ps1" $files } +function DownloadAMDGpuServices { + # If you update the AGS SDK version, please also update the version in `crates/gpui/src/platform/windows/directx_renderer.rs` + $url = "https://codeload.github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/zip/refs/tags/v6.3.0" + $zipPath = ".\AGS_SDK_v6.3.0.zip" + # Download the AGS SDK zip file + Invoke-WebRequest -Uri $url -OutFile $zipPath + # Extract the AGS SDK zip file + Expand-Archive -Path $zipPath -DestinationPath "." -Force +} + function CollectFiles { Move-Item -Path "$innoDir\zed_explorer_command_injector.appx" -Destination "$innoDir\appx\zed_explorer_command_injector.appx" -Force Move-Item -Path "$innoDir\zed_explorer_command_injector.dll" -Destination "$innoDir\appx\zed_explorer_command_injector.dll" -Force Move-Item -Path "$innoDir\cli.exe" -Destination "$innoDir\bin\zed.exe" -Force Move-Item -Path "$innoDir\auto_update_helper.exe" -Destination "$innoDir\tools\auto_update_helper.exe" -Force + Move-Item -Path ".\AGS_SDK-6.3.0\ags_lib\lib\amd_ags_x64.dll" -Destination "$innoDir\amd_ags_x64.dll" -Force } function BuildInstaller { @@ -195,7 +222,6 @@ function BuildInstaller { # Windows runner 2022 default has iscc in PATH, https://github.com/actions/runner-images/blob/main/images/windows/Windows2022-Readme.md # Currently, we are using Windows 2022 runner. # Windows runner 2025 doesn't have iscc in PATH for now, https://github.com/actions/runner-images/issues/11228 - # $innoSetupPath = "iscc.exe" $innoSetupPath = "C:\Program Files (x86)\Inno Setup 6\ISCC.exe" $definitions = @{ @@ -242,6 +268,8 @@ function BuildInstaller { ParseZedWorkspace $innoDir = "$env:ZED_WORKSPACE\inno" +$debugArchive = ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" +$debugStoreKey = "$env:ZED_RELEASE_CHANNEL/zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" CheckEnvironmentVariables PrepareForBundle @@ -250,12 +278,12 @@ BuildZedAndItsFriends MakeAppx SignZedAndItsFriends ZipZedAndItsFriendsDebug +DownloadAMDGpuServices CollectFiles BuildInstaller -$debugArchive = ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" -$debugStoreKey = "$env:ZED_RELEASE_CHANNEL/zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" UploadToBlobStorePublic -BucketName "zed-debug-symbols" -FileToUpload $debugArchive -BlobStoreKey $debugStoreKey +UploadToSentry if ($buildSuccess) { Write-Output "Build successful" diff --git a/script/linux b/script/linux index 98ae026896..029278bea3 100755 --- a/script/linux +++ b/script/linux @@ -143,6 +143,7 @@ if [[ -n $zyp ]]; then gzip jq libvulkan1 + libx11-devel libxcb-devel libxkbcommon-devel libxkbcommon-x11-devel diff --git a/script/zed-local b/script/zed-local index 2568931246..99d9308232 100755 --- a/script/zed-local +++ b/script/zed-local @@ -213,7 +213,7 @@ setTimeout(() => { platform === "win32" ? "http://127.0.0.1:8080/rpc" : "http://localhost:8080/rpc", - ZED_ADMIN_API_TOKEN: "secret", + ZED_ADMIN_API_TOKEN: "internal-api-key-secret", ZED_WINDOW_SIZE: size, ZED_CLIENT_CHECKSUM_SEED: "development-checksum-seed", RUST_LOG: process.env.RUST_LOG || "info", diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index 1026454026..5678e46236 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -82,6 +82,7 @@ lyon = { version = "1", default-features = false, features = ["extra"] } lyon_path = { version = "1" } md-5 = { version = "0.10" } memchr = { version = "2" } +mime_guess = { version = "2" } miniz_oxide = { version = "0.8", features = ["simd"] } nom = { version = "7" } num-bigint = { version = "0.4" } @@ -212,6 +213,7 @@ lyon = { version = "1", default-features = false, features = ["extra"] } lyon_path = { version = "1" } md-5 = { version = "0.10" } memchr = { version = "2" } +mime_guess = { version = "2" } miniz_oxide = { version = "0.8", features = ["simd"] } nom = { version = "7" } num-bigint = { version = "0.4" } @@ -284,14 +286,13 @@ winnow = { version = "0.7", features = ["simd"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -310,18 +311,16 @@ tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } [target.x86_64-apple-darwin.build-dependencies] -clang-sys = { version = "1", default-features = false, features = ["clang_11_0", "runtime"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -344,14 +343,13 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -370,18 +368,16 @@ tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } [target.aarch64-apple-darwin.build-dependencies] -clang-sys = { version = "1", default-features = false, features = ["clang_11_0", "runtime"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -420,7 +416,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -460,7 +457,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } @@ -498,7 +496,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -538,7 +537,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } @@ -564,7 +564,6 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -naga = { version = "25", features = ["spv-out", "wgsl-in"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } scopeguard = { version = "1" } @@ -578,7 +577,7 @@ windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } -windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } +windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-pc-windows-msvc.build-dependencies] codespan-reporting = { version = "0.12" } @@ -588,7 +587,6 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -naga = { version = "25", features = ["spv-out", "wgsl-in"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } @@ -603,7 +601,7 @@ windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } -windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } +windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-unknown-linux-musl.dependencies] aes = { version = "0.8", default-features = false, features = ["zeroize"] } @@ -625,7 +623,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -665,7 +664,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } diff --git a/typos.toml b/typos.toml index 7f1c6e04f1..336a829a44 100644 --- a/typos.toml +++ b/typos.toml @@ -71,6 +71,10 @@ extend-ignore-re = [ # Not an actual typo but an intentionally invalid color, in `color_extractor` "#fof", # Stripped version of reserved keyword `type` - "typ" + "typ", + # AMD GPU Services + "ags", + # AMD GPU Services + "AGS" ] check-filename = true