diff --git a/.github/actionlint.yml b/.github/actionlint.yml
index d93ec5b15e..6bfbc27705 100644
--- a/.github/actionlint.yml
+++ b/.github/actionlint.yml
@@ -24,7 +24,6 @@ self-hosted-runner:
     - buildjet-8vcpu-ubuntu-2204-arm
     - buildjet-16vcpu-ubuntu-2204-arm
     - buildjet-32vcpu-ubuntu-2204-arm
-    - buildjet-64vcpu-ubuntu-2204-arm
     # Self Hosted Runners
     - self-mini-macos
     - self-32vcpu-windows-2022
diff --git a/.github/actions/build_docs/action.yml b/.github/actions/build_docs/action.yml
index 9a2d7e1ec7..a7effad247 100644
--- a/.github/actions/build_docs/action.yml
+++ b/.github/actions/build_docs/action.yml
@@ -19,7 +19,7 @@ runs:
       shell: bash -euxo pipefail {0}
       run: ./script/linux

-    - name: Check for broken links
+    - name: Check for broken links (in MD)
       uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1
       with:
         args: --no-progress --exclude '^http' './docs/src/**/*'
@@ -30,3 +30,9 @@ runs:
       run: |
         mkdir -p target/deploy
         mdbook build ./docs --dest-dir=../target/deploy/docs/
+
+    - name: Check for broken links (in HTML)
+      uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1
+      with:
+        args: --no-progress --exclude '^http' 'target/deploy/docs/'
+        fail: true
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a4da5e99ba..43d305faae 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -24,6 +24,7 @@ env:
   DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
   DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
   ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
+  ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }}

 jobs:
   job_spec:
@@ -269,6 +270,10 @@
           mkdir -p ./../.cargo
           cp ./.cargo/ci-config.toml ./../.cargo/config.toml

+      - name: Check that Cargo.lock is up to date
+        run: |
+          cargo update --locked --workspace
+
       - name: cargo clippy
         run: ./script/clippy

@@ -645,7 +650,7 @@
     timeout-minutes: 60
     name: Linux arm64 release bundle
     runs-on:
-      - buildjet-16vcpu-ubuntu-2204-arm
+      - buildjet-32vcpu-ubuntu-2204-arm
     if: |
       startsWith(github.ref, 'refs/tags/v') ||
       contains(github.event.pull_request.labels.*.name, 'run-bundling')
@@ -767,7 +772,8 @@
     timeout-minutes: 120
     name: Create a Windows installer
     runs-on: [self-hosted, Windows, X64]
-    if: false && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
+    if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
+    # if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
     needs: [windows_tests]
     env:
       AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}
diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml
index beacd27774..6c3a97c163 100644
--- a/.github/workflows/nix.yml
+++ b/.github/workflows/nix.yml
@@ -29,6 +29,7 @@ jobs:
     runs-on: ${{ matrix.system.runner }}
     env:
       ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
+      ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }}
       ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
       GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on
     steps:
diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml
index f799133ea7..c847149984 100644
--- a/.github/workflows/release_nightly.yml
+++ b/.github/workflows/release_nightly.yml
@@ -13,6 +13,7 @@ env:
CARGO_INCREMENTAL: 0 RUST_BACKTRACE: 1 ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} + ZED_MINIDUMP_ENDPOINT: ${{ secrets.ZED_SENTRY_MINIDUMP_ENDPOINT }} DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} @@ -111,6 +112,11 @@ jobs: echo "Publishing version: ${version} on release channel nightly" echo "nightly" > crates/zed/RELEASE_CHANNEL + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Create macOS app bundle run: script/bundle-mac @@ -136,6 +142,11 @@ jobs: - name: Install Linux dependencies run: ./script/linux && ./script/install-mold 2.34.0 + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Limit target directory size run: script/clear-target-dir-if-larger-than 100 @@ -157,7 +168,7 @@ jobs: name: Create a Linux *.tar.gz bundle for ARM if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2204-arm + - buildjet-32vcpu-ubuntu-2204-arm needs: tests steps: - name: Checkout repo @@ -168,6 +179,11 @@ jobs: - name: Install Linux dependencies run: ./script/linux + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Limit target directory size run: script/clear-target-dir-if-larger-than 100 @@ -262,6 +278,11 @@ jobs: Write-Host "Publishing version: $version on release channel nightly" "nightly" | Set-Content -Path "crates/zed/RELEASE_CHANNEL" + - name: Setup Sentry CLI + uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2 + with: + token: ${{ SECRETS.SENTRY_AUTH_TOKEN }} + - name: Build Zed installer working-directory: ${{ env.ZED_WORKSPACE }} run: script/bundle-windows.ps1 diff --git a/Cargo.lock b/Cargo.lock index 8f791d395a..1eb5669fa2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,10 +6,9 @@ version = 4 name = "acp_thread" version = "0.1.0" dependencies = [ - "agentic-coding-protocol", + "agent-client-protocol", "anyhow", "assistant_tool", - "async-pipe", "buffer_diff", "editor", "env_logger 0.11.8", @@ -18,8 +17,11 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", + "parking_lot", "project", + "rand 0.8.5", "serde", "serde_json", "settings", @@ -89,6 +91,7 @@ dependencies = [ "assistant_tools", "chrono", "client", + "cloud_llm_client", "collections", "component", "context_server", @@ -112,7 +115,6 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", - "proto", "rand 0.8.5", "ref-cast", "rope", @@ -131,15 +133,68 @@ dependencies = [ "uuid", "workspace", "workspace-hack", - "zed_llm_client", "zstd", ] +[[package]] +name = "agent-client-protocol" +version = "0.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12dbfec3d27680337ed9d3064eecafe97acf0b0f190148bb4e29d96707c9e403" +dependencies = [ + "anyhow", + "futures 0.3.31", + "log", + "parking_lot", + "schemars", + "serde", + "serde_json", +] + +[[package]] +name = "agent2" +version = "0.1.0" +dependencies = [ + "acp_thread", + "agent-client-protocol", + "agent_servers", + "anyhow", + "client", + "clock", + "cloud_llm_client", + "collections", + "ctor", + "env_logger 0.11.8", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "handlebars 
4.5.0", + "indoc", + "language_model", + "language_models", + "log", + "project", + "reqwest_client", + "rust-embed", + "schemars", + "serde", + "serde_json", + "settings", + "smol", + "ui", + "util", + "uuid", + "workspace-hack", + "worktree", +] + [[package]] name = "agent_servers" version = "0.1.0" dependencies = [ "acp_thread", + "agent-client-protocol", "agentic-coding-protocol", "anyhow", "collections", @@ -155,6 +210,7 @@ dependencies = [ "nix 0.29.0", "paths", "project", + "rand 0.8.5", "schemars", "serde", "serde_json", @@ -162,6 +218,7 @@ dependencies = [ "smol", "strum 0.27.1", "tempfile", + "thiserror 2.0.12", "ui", "util", "uuid", @@ -175,6 +232,7 @@ name = "agent_settings" version = "0.1.0" dependencies = [ "anyhow", + "cloud_llm_client", "collections", "fs", "gpui", @@ -186,7 +244,6 @@ dependencies = [ "serde_json_lenient", "settings", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -195,9 +252,10 @@ version = "0.1.0" dependencies = [ "acp_thread", "agent", + "agent-client-protocol", + "agent2", "agent_servers", "agent_settings", - "agentic-coding-protocol", "ai_onboarding", "anyhow", "assistant_context", @@ -209,6 +267,7 @@ dependencies = [ "buffer_diff", "chrono", "client", + "cloud_llm_client", "collections", "command_palette_hooks", "component", @@ -280,7 +339,6 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", - "zed_llm_client", ] [[package]] @@ -341,10 +399,10 @@ name = "ai_onboarding" version = "0.1.0" dependencies = [ "client", + "cloud_llm_client", "component", "gpui", "language_model", - "proto", "serde", "smallvec", "telemetry", @@ -673,6 +731,7 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collections", "context_server", "fs", @@ -706,7 +765,6 @@ dependencies = [ "uuid", "workspace", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -814,6 +872,7 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collections", "component", "derive_more 0.99.19", @@ -867,7 +926,6 @@ dependencies = [ "which 6.0.3", "workspace", "workspace-hack", - "zed_llm_client", "zlog", ] @@ -1061,17 +1119,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "async-recursion" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "async-recursion" version = "1.1.1" @@ -1165,7 +1212,7 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_qs 0.10.1", - "smart-default", + "smart-default 0.6.0", "smol_str 0.1.24", "thiserror 1.0.69", "tokio", @@ -2957,11 +3004,12 @@ name = "client" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 0.3.2", "async-tungstenite", "base64 0.22.1", "chrono", "clock", + "cloud_api_client", + "cloud_llm_client", "cocoa 0.26.0", "collections", "credentials_provider", @@ -3004,7 +3052,6 @@ dependencies = [ "windows 0.61.1", "workspace-hack", "worktree", - "zed_llm_client", ] [[package]] @@ -3017,6 +3064,44 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "cloud_api_client" +version = "0.1.0" +dependencies = [ + "anyhow", + "cloud_api_types", + "futures 0.3.31", + "http_client", + "parking_lot", + "serde_json", + "workspace-hack", +] + +[[package]] +name = "cloud_api_types" +version = "0.1.0" +dependencies = [ + "chrono", + "cloud_llm_client", + "pretty_assertions", + "serde", + "serde_json", + "workspace-hack", +] + +[[package]] +name = "cloud_llm_client" +version = 
"0.1.0" +dependencies = [ + "anyhow", + "pretty_assertions", + "serde", + "serde_json", + "strum 0.27.1", + "uuid", + "workspace-hack", +] + [[package]] name = "clru" version = "0.6.2" @@ -3143,6 +3228,7 @@ dependencies = [ "chrono", "client", "clock", + "cloud_llm_client", "collab_ui", "collections", "command_palette_hooks", @@ -3229,7 +3315,6 @@ dependencies = [ "workspace", "workspace-hack", "worktree", - "zed_llm_client", "zlog", ] @@ -3497,13 +3582,13 @@ dependencies = [ "command_palette_hooks", "ctor", "dirs 4.0.0", + "edit_prediction", "editor", "fs", "futures 0.3.31", "gpui", "http_client", "indoc", - "inline_completion", "itertools 0.14.0", "language", "log", @@ -3517,6 +3602,7 @@ dependencies = [ "serde", "serde_json", "settings", + "sum_tree", "task", "theme", "ui", @@ -3670,17 +3756,6 @@ dependencies = [ "libm", ] -[[package]] -name = "coreaudio-rs" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "321077172d79c662f64f5071a03120748d5bb652f5231570141be24cfcd2bace" -dependencies = [ - "bitflags 1.3.2", - "core-foundation-sys", - "coreaudio-sys", -] - [[package]] name = "coreaudio-rs" version = "0.12.1" @@ -3738,29 +3813,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "cpal" -version = "0.15.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "873dab07c8f743075e57f524c583985fbaf745602acbe916a01539364369a779" -dependencies = [ - "alsa", - "core-foundation-sys", - "coreaudio-rs 0.11.3", - "dasp_sample", - "jni", - "js-sys", - "libc", - "mach2", - "ndk 0.8.0", - "ndk-context", - "oboe", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", - "windows 0.54.0", -] - [[package]] name = "cpal" version = "0.16.0" @@ -3774,7 +3826,7 @@ dependencies = [ "js-sys", "libc", "mach2", - "ndk 0.9.0", + "ndk", "ndk-context", "num-derive", "num-traits", @@ -3915,6 +3967,42 @@ dependencies = [ "target-lexicon 0.13.2", ] +[[package]] +name = "crash-context" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031ed29858d90cfdf27fe49fae28028a1f20466db97962fa2f4ea34809aeebf3" +dependencies = [ + "cfg-if", + "libc", + "mach2", +] + +[[package]] +name = "crash-handler" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2066907075af649bcb8bcb1b9b986329b243677e6918b2d920aa64b0aac5ace3" +dependencies = [ + "cfg-if", + "crash-context", + "libc", + "mach2", + "parking_lot", +] + +[[package]] +name = "crashes" +version = "0.1.0" +dependencies = [ + "crash-handler", + "log", + "minidumper", + "paths", + "smol", + "workspace-hack", +] + [[package]] name = "crc" version = "3.2.1" @@ -4244,7 +4332,7 @@ dependencies = [ [[package]] name = "dap-types" version = "0.0.1" -source = "git+https://github.com/zed-industries/dap-types?rev=7f39295b441614ca9dbf44293e53c32f666897f9#7f39295b441614ca9dbf44293e53c32f666897f9" +source = "git+https://github.com/zed-industries/dap-types?rev=1b461b310481d01e02b2603c16d7144b926339f8#1b461b310481d01e02b2603c16d7144b926339f8" dependencies = [ "schemars", "serde", @@ -4276,41 +4364,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "darling" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.101", -] - -[[package]] -name = "darling_macro" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" -dependencies = [ - "darling_core", - "quote", - "syn 2.0.101", -] - [[package]] name = "dashmap" version = "5.5.3" @@ -4476,6 +4529,15 @@ dependencies = [ "zlog", ] +[[package]] +name = "debugid" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" +dependencies = [ + "uuid", +] + [[package]] name = "deepseek" version = "0.1.0" @@ -4526,37 +4588,6 @@ dependencies = [ "serde", ] -[[package]] -name = "derive_builder" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 2.0.101", -] - -[[package]] -name = "derive_builder_macro" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" -dependencies = [ - "derive_builder_core", - "syn 2.0.101", -] - [[package]] name = "derive_more" version = "0.99.19" @@ -4778,7 +4809,6 @@ name = "docs_preprocessor" version = "0.1.0" dependencies = [ "anyhow", - "clap", "command_palette", "gpui", "mdbook", @@ -4789,6 +4819,7 @@ dependencies = [ "util", "workspace-hack", "zed", + "zlog", ] [[package]] @@ -4910,6 +4941,49 @@ dependencies = [ "signature 1.6.4", ] +[[package]] +name = "edit_prediction" +version = "0.1.0" +dependencies = [ + "client", + "gpui", + "language", + "project", + "workspace-hack", +] + +[[package]] +name = "edit_prediction_button" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "cloud_llm_client", + "copilot", + "edit_prediction", + "editor", + "feature_flags", + "fs", + "futures 0.3.31", + "gpui", + "indoc", + "language", + "lsp", + "paths", + "project", + "regex", + "serde_json", + "settings", + "supermaven", + "telemetry", + "theme", + "ui", + "workspace", + "workspace-hack", + "zed_actions", + "zeta", +] + [[package]] name = "editor" version = "0.1.0" @@ -4925,6 +4999,7 @@ dependencies = [ "ctor", "dap", "db", + "edit_prediction", "emojis", "file_icons", "fs", @@ -4934,7 +5009,6 @@ dependencies = [ "gpui", "http_client", "indoc", - "inline_completion", "itertools 0.14.0", "language", "languages", @@ -4966,6 +5040,8 @@ dependencies = [ "text", "theme", "time", + "tree-sitter-bash", + "tree-sitter-c", "tree-sitter-html", "tree-sitter-python", "tree-sitter-rust", @@ -5248,6 +5324,7 @@ dependencies = [ "chrono", "clap", "client", + "cloud_llm_client", "collections", "debug_adapter_extension", "dirs 4.0.0", @@ -5287,7 +5364,6 @@ dependencies = [ "uuid", "watch", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -5352,6 +5428,12 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "extended" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" + [[package]] name = "extension" version = "0.1.0" @@ -5371,11 +5453,13 @@ dependencies = [ "log", "lsp", "parking_lot", + "pretty_assertions", "semantic_version", "serde", "serde_json", "task", "toml 0.8.20", + "url", "util", "wasm-encoder 0.221.3", "wasmparser 0.221.3", @@ -5926,7 +6010,7 @@ dependencies = [ "ignore", "libc", "log", - "notify", + "notify 8.0.0", "objc", "parking_lot", "paths", @@ -6360,7 +6444,7 @@ dependencies = [ "buffer_diff", "call", "chrono", - "client", + "cloud_llm_client", "collections", "command_palette_hooks", "component", @@ -6371,6 +6455,7 @@ dependencies = [ "fuzzy", "git", "gpui", + "indoc", "itertools 0.14.0", "language", "language_model", @@ -6403,7 +6488,6 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", - "zed_llm_client", "zlog", ] @@ -7236,6 +7320,17 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "goblin" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b363a30c165f666402fe6a3024d3bec7ebc898f96a4a23bd1c99f8dbf3f4f47" +dependencies = [ + "log", + "plain", + "scroll", +] + [[package]] name = "google_ai" version = "0.1.0" @@ -7483,18 +7578,16 @@ dependencies = [ [[package]] name = "handlebars" -version = "6.3.2" +version = "5.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "759e2d5aea3287cb1190c8ec394f42866cb5bf74fcbf213f354e3c856ea26098" +checksum = "d08485b96a0e6393e9e4d1b8d48cf74ad6c063cd905eb33f42c1ce3f0377539b" dependencies = [ - "derive_builder", "log", - "num-order", "pest", "pest_derive", "serde", "serde_json", - "thiserror 2.0.12", + "thiserror 1.0.69", ] [[package]] @@ -7719,12 +7812,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "hound" -version = "3.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" - [[package]] name = "html5ever" version = "0.27.0" @@ -7858,6 +7945,8 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "log", + "parking_lot", + "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "serde", "serde_json", "url", @@ -8165,12 +8254,6 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "1.0.3" @@ -8347,46 +8430,14 @@ dependencies = [ ] [[package]] -name = "inline_completion" -version = "0.1.0" +name = "inotify" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" dependencies = [ - "client", - "gpui", - "language", - "project", - "workspace-hack", -] - -[[package]] -name = "inline_completion_button" -version = "0.1.0" -dependencies = [ - "anyhow", - "client", - "copilot", - "editor", - "feature_flags", - "fs", - "futures 0.3.31", - "gpui", - "indoc", - "inline_completion", - "language", - "lsp", - "paths", - "project", - "regex", - "serde_json", - "settings", - "supermaven", - "telemetry", - "theme", - "ui", - "workspace", - "workspace-hack", - "zed_actions", - "zed_llm_client", - "zeta", + "bitflags 1.3.2", + 
"inotify-sys", + "libc", ] [[package]] @@ -8542,7 +8593,7 @@ dependencies = [ "fnv", "lazy_static", "libc", - "mio", + "mio 1.0.3", "rand 0.8.5", "serde", "tempfile", @@ -9067,6 +9118,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "client", + "cloud_llm_client", "collections", "futures 0.3.31", "gpui", @@ -9084,7 +9136,6 @@ dependencies = [ "thiserror 2.0.12", "util", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -9100,6 +9151,7 @@ dependencies = [ "bedrock", "chrono", "client", + "cloud_llm_client", "collections", "component", "convert_case 0.8.0", @@ -9123,7 +9175,6 @@ dependencies = [ "open_router", "partial-json-fixer", "project", - "proto", "release_channel", "schemars", "serde", @@ -9141,7 +9192,6 @@ dependencies = [ "vercel", "workspace-hack", "x_ai", - "zed_llm_client", ] [[package]] @@ -9198,11 +9248,13 @@ version = "0.1.0" dependencies = [ "anyhow", "async-compression", + "async-fs", "async-tar", "async-trait", "chrono", "collections", "dap", + "feature_flags", "futures 0.3.31", "gpui", "http_client", @@ -9228,9 +9280,11 @@ dependencies = [ "serde_json", "serde_json_lenient", "settings", + "sha2", "smol", "snippet_provider", "task", + "tempfile", "text", "theme", "toml 0.8.20", @@ -9395,7 +9449,7 @@ dependencies = [ [[package]] name = "libwebrtc" version = "0.3.10" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "cxx", "jni", @@ -9475,7 +9529,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "livekit" version = "0.7.8" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "chrono", "futures-util", @@ -9498,7 +9552,7 @@ dependencies = [ [[package]] name = "livekit-api" version = "0.4.2" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "futures-util", "http 0.2.12", @@ -9522,7 +9576,7 @@ dependencies = [ [[package]] name = "livekit-protocol" version = "0.3.9" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "futures-util", "livekit-runtime", @@ -9539,7 +9593,7 @@ dependencies = [ [[package]] name = "livekit-runtime" version = "0.4.0" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "tokio", "tokio-stream", @@ -9571,7 +9625,7 @@ dependencies = [ "core-foundation 0.10.0", 
"core-video", "coreaudio-rs 0.12.1", - "cpal 0.16.0", + "cpal", "futures 0.3.31", "gpui", "gpui_tokio", @@ -9622,9 +9676,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -9861,7 +9915,7 @@ name = "markdown_preview" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 1.1.1", + "async-recursion", "collections", "editor", "fs", @@ -9981,9 +10035,9 @@ dependencies = [ [[package]] name = "mdbook" -version = "0.4.48" +version = "0.4.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6fbb4ac2d9fd7aa987c3510309ea3c80004a968d063c42f0d34fea070817c1" +checksum = "b45a38e19bd200220ef07c892b0157ad3d2365e5b5a267ca01ad12182491eea5" dependencies = [ "ammonia", "anyhow", @@ -9993,12 +10047,11 @@ dependencies = [ "elasticlunr-rs", "env_logger 0.11.8", "futures-util", - "handlebars 6.3.2", - "hex", + "handlebars 5.1.2", "ignore", "log", "memchr", - "notify", + "notify 6.1.1", "notify-debouncer-mini", "once_cell", "opener", @@ -10007,7 +10060,6 @@ dependencies = [ "regex", "serde", "serde_json", - "sha2", "shlex", "tempfile", "tokio", @@ -10128,6 +10180,63 @@ dependencies = [ "unicase", ] +[[package]] +name = "minidump-common" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c4d14bcca0fd3ed165a03000480aaa364c6860c34e900cb2dafdf3b95340e77" +dependencies = [ + "bitflags 2.9.0", + "debugid", + "num-derive", + "num-traits", + "range-map", + "scroll", + "smart-default 0.7.1", +] + +[[package]] +name = "minidump-writer" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abcd9c8a1e6e1e9d56ce3627851f39a17ea83e17c96bc510f29d7e43d78a7d" +dependencies = [ + "bitflags 2.9.0", + "byteorder", + "cfg-if", + "crash-context", + "goblin", + "libc", + "log", + "mach2", + "memmap2", + "memoffset", + "minidump-common", + "nix 0.28.0", + "procfs-core", + "scroll", + "tempfile", + "thiserror 1.0.69", +] + +[[package]] +name = "minidumper" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4ebc9d1f8847ec1d078f78b35ed598e0ebefa1f242d5f83cd8d7f03960a7d1" +dependencies = [ + "cfg-if", + "crash-context", + "libc", + "log", + "minidump-writer", + "parking_lot", + "polling", + "scroll", + "thiserror 1.0.69", + "uds", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -10150,6 +10259,18 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e53debba6bda7a793e5f99b8dacf19e626084f525f7829104ba9898f367d85ff" +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.48.0", +] + [[package]] name = "mio" version = "1.0.3" @@ -10342,20 +10463,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "ndk" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7" -dependencies = [ - "bitflags 2.9.0", - "jni-sys", - "log", - "ndk-sys 0.5.0+25.2.9519653", - 
"num_enum", - "thiserror 1.0.69", -] - [[package]] name = "ndk" version = "0.9.0" @@ -10365,7 +10472,7 @@ dependencies = [ "bitflags 2.9.0", "jni-sys", "log", - "ndk-sys 0.6.0+11769913", + "ndk-sys", "num_enum", "thiserror 1.0.69", ] @@ -10376,15 +10483,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" -[[package]] -name = "ndk-sys" -version = "0.5.0+25.2.9519653" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691" -dependencies = [ - "jni-sys", -] - [[package]] name = "ndk-sys" version = "0.6.0+11769913" @@ -10519,6 +10617,25 @@ dependencies = [ "zed_actions", ] +[[package]] +name = "notify" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" +dependencies = [ + "bitflags 2.9.0", + "crossbeam-channel", + "filetime", + "fsevent-sys 4.1.0", + "inotify 0.9.6", + "kqueue", + "libc", + "log", + "mio 0.8.11", + "walkdir", + "windows-sys 0.48.0", +] + [[package]] name = "notify" version = "8.0.0" @@ -10527,11 +10644,11 @@ dependencies = [ "bitflags 2.9.0", "filetime", "fsevent-sys 4.1.0", - "inotify", + "inotify 0.11.0", "kqueue", "libc", "log", - "mio", + "mio 1.0.3", "notify-types", "walkdir", "windows-sys 0.59.0", @@ -10539,14 +10656,13 @@ dependencies = [ [[package]] name = "notify-debouncer-mini" -version = "0.6.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a689eb4262184d9a1727f9087cd03883ea716682ab03ed24efec57d7716dccb8" +checksum = "5d40b221972a1fc5ef4d858a2f671fb34c75983eb385463dff3780eeff6a9d43" dependencies = [ + "crossbeam-channel", "log", - "notify", - "notify-types", - "tempfile", + "notify 6.1.1", ] [[package]] @@ -10686,21 +10802,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-modular" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17bb261bf36fa7d83f4c294f834e91256769097b3cb505d44831e0a179ac647f" - -[[package]] -name = "num-order" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537b596b97c40fcf8056d153049eb22f481c17ebce72a513ec9286e4986d1bb6" -dependencies = [ - "num-modular", -] - [[package]] name = "num-rational" version = "0.4.2" @@ -10954,29 +11055,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "oboe" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8b61bebd49e5d43f5f8cc7ee2891c16e0f41ec7954d36bcb6c14c5e0de867fb" -dependencies = [ - "jni", - "ndk 0.8.0", - "ndk-context", - "num-derive", - "num-traits", - "oboe-sys", -] - -[[package]] -name = "oboe-sys" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8bb09a4a2b1d668170cfe0a7d5bc103f8999fb316c98099b6a9939c9f2e79d" -dependencies = [ - "cc", -] - [[package]] name = "ollama" version = "0.1.0" @@ -10994,17 +11072,36 @@ dependencies = [ name = "onboarding" version = "0.1.0" dependencies = [ + "ai_onboarding", "anyhow", + "client", "command_palette_hooks", + "component", "db", + "documented", + "editor", "feature_flags", "fs", + "fuzzy", "gpui", + "itertools 0.14.0", + "language", + "language_model", + "menu", + "notifications", + "picker", + "project", + "schemars", + "serde", "settings", "theme", "ui", + "util", + "vim_mode_setting", 
"workspace", "workspace-hack", + "zed_actions", + "zlog", ] [[package]] @@ -11355,9 +11452,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -11365,9 +11462,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -12131,6 +12228,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plist" version = "1.7.1" @@ -12391,6 +12494,16 @@ dependencies = [ "yansi", ] +[[package]] +name = "procfs-core" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d3554923a69f4ce04c4a754260c338f505ce22642d3830e049a399fc2059a29" +dependencies = [ + "bitflags 2.9.0", + "hex", +] + [[package]] name = "prodash" version = "29.0.2" @@ -13041,6 +13154,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "range-map" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12a5a2d6c7039059af621472a4389be1215a816df61aa4d531cfe85264aee95f" +dependencies = [ + "num-traits", +] + [[package]] name = "rangemap" version = "1.5.1" @@ -13383,6 +13505,8 @@ dependencies = [ "clap", "client", "clock", + "crash-handler", + "crashes", "dap", "dap_adapters", "debug_adapter_extension", @@ -13406,6 +13530,7 @@ dependencies = [ "libc", "log", "lsp", + "minidumper", "node_runtime", "paths", "project", @@ -13594,6 +13719,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "once_cell", "percent-encoding", "pin-project-lite", @@ -13755,12 +13881,15 @@ dependencies = [ [[package]] name = "rodio" -version = "0.20.1" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ceb6607dd738c99bc8cb28eff249b7cd5c8ec88b9db96c0608c1480d140fb1" +checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183" dependencies = [ - "cpal 0.15.3", - "hound", + "cpal", + "dasp_sample", + "num-rational", + "symphonia", + "tracing", ] [[package]] @@ -14319,6 +14448,26 @@ dependencies = [ "once_cell", ] +[[package]] +name = "scroll" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab8598aa408498679922eff7fa985c25d58a90771bd6be794434c5277eab1a6" +dependencies = [ + "scroll_derive", +] + +[[package]] +name = "scroll_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1783eabc414609e28a5ba76aee5ddd52199f7107a0b24c2e9746a1ecc34a683d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "scrypt" version = "0.11.0" @@ -14765,6 +14914,27 @@ dependencies = [ "zlog", ] 
+[[package]] +name = "settings_profile_selector" +version = "0.1.0" +dependencies = [ + "client", + "editor", + "fuzzy", + "gpui", + "language", + "menu", + "picker", + "project", + "serde_json", + "settings", + "theme", + "ui", + "workspace", + "workspace-hack", + "zed_actions", +] + [[package]] name = "settings_ui" version = "0.1.0" @@ -14787,7 +14957,6 @@ dependencies = [ "notifications", "paths", "project", - "schemars", "search", "serde", "serde_json", @@ -15044,6 +15213,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "smart-default" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eb01866308440fc64d6c44d9e86c5cc17adfe33c4d6eed55da9145044d0ffc1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "smol" version = "2.0.2" @@ -15626,12 +15806,12 @@ dependencies = [ "anyhow", "client", "collections", + "edit_prediction", "editor", "env_logger 0.11.8", "futures 0.3.31", "gpui", "http_client", - "inline_completion", "language", "log", "postage", @@ -15781,6 +15961,66 @@ dependencies = [ "zeno", ] +[[package]] +name = "symphonia" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9" +dependencies = [ + "lazy_static", + "symphonia-codec-pcm", + "symphonia-core", + "symphonia-format-riff", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-codec-pcm" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f395a67057c2ebc5e84d7bb1be71cce1a7ba99f64e0f0f0e303a03f79116f89b" +dependencies = [ + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-core" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "798306779e3dc7d5231bd5691f5a813496dc79d3f56bf82e25789f2094e022c3" +dependencies = [ + "arrayvec", + "bitflags 1.3.2", + "bytemuck", + "lazy_static", + "log", +] + +[[package]] +name = "symphonia-format-riff" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f7be232f962f937f4b7115cbe62c330929345434c834359425e043bfd15f50" +dependencies = [ + "extended", + "log", + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-metadata" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc622b9841a10089c5b18e99eb904f4341615d5aa55bbf4eedde1be721a4023c" +dependencies = [ + "encoding_rs", + "lazy_static", + "log", + "symphonia-core", +] + [[package]] name = "syn" version = "1.0.109" @@ -16164,7 +16404,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_slash_command", - "async-recursion 1.1.1", + "async-recursion", "breadcrumbs", "client", "collections", @@ -16513,6 +16753,7 @@ dependencies = [ "call", "chrono", "client", + "cloud_llm_client", "collections", "db", "gpui", @@ -16548,7 +16789,7 @@ dependencies = [ "backtrace", "bytes 1.10.1", "libc", - "mio", + "mio 1.0.3", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -17266,6 +17507,15 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "uds" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "885c31f06fce836457fe3ef09a59f83fe8db95d270b11cd78f40a4666c4d1661" +dependencies = [ + "libc", +] + [[package]] name = "uds_windows" version = 
"1.1.0" @@ -18481,11 +18731,11 @@ name = "web_search" version = "0.1.0" dependencies = [ "anyhow", + "cloud_llm_client", "collections", "gpui", "serde", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -18494,6 +18744,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "cloud_llm_client", "futures 0.3.31", "gpui", "http_client", @@ -18502,7 +18753,6 @@ dependencies = [ "serde_json", "web_search", "workspace-hack", - "zed_llm_client", ] [[package]] @@ -18526,7 +18776,7 @@ dependencies = [ [[package]] name = "webrtc-sys" version = "0.3.7" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "cc", "cxx", @@ -18539,7 +18789,7 @@ dependencies = [ [[package]] name = "webrtc-sys-build" version = "0.3.6" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "fs2", "regex", @@ -18574,7 +18824,6 @@ dependencies = [ "serde", "settings", "telemetry", - "theme", "ui", "util", "vim_mode_setting", @@ -19589,7 +19838,7 @@ version = "0.1.0" dependencies = [ "any_vec", "anyhow", - "async-recursion 1.1.1", + "async-recursion", "bincode", "call", "client", @@ -19666,14 +19915,12 @@ dependencies = [ "cc", "chrono", "cipher", - "clang-sys", "clap", "clap_builder", "codespan-reporting 0.12.0", "concurrent-queue", "core-foundation 0.9.4", "core-foundation-sys", - "coreaudio-sys", "cranelift-codegen", "crc32fast", "crossbeam-epoch", @@ -19724,9 +19971,11 @@ dependencies = [ "lyon_path", "md-5", "memchr", + "mime_guess", "miniz_oxide", - "mio", + "mio 1.0.3", "naga", + "nix 0.28.0", "nix 0.29.0", "nom", "num-bigint", @@ -20116,7 +20365,7 @@ dependencies = [ "async-io", "async-lock", "async-process", - "async-recursion 1.1.1", + "async-recursion", "async-task", "async-trait", "blocking", @@ -20169,7 +20418,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.198.0" +version = "0.200.0" dependencies = [ "activity_indicator", "agent", @@ -20198,6 +20447,7 @@ dependencies = [ "command_palette", "component", "copilot", + "crashes", "dap", "dap_adapters", "db", @@ -20205,6 +20455,7 @@ dependencies = [ "debugger_tools", "debugger_ui", "diagnostics", + "edit_prediction_button", "editor", "env_logger 0.11.8", "extension", @@ -20224,7 +20475,6 @@ dependencies = [ "http_client", "image_viewer", "indoc", - "inline_completion_button", "inspector_ui", "install_cli", "itertools 0.14.0", @@ -20265,6 +20515,7 @@ dependencies = [ "release_channel", "remote", "repl", + "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "reqwest_client", "rope", "search", @@ -20272,6 +20523,7 @@ dependencies = [ "serde_json", "session", "settings", + "settings_profile_selector", "settings_ui", "shellexpand 2.1.2", "smol", @@ -20330,7 +20582,7 @@ dependencies = [ [[package]] name = "zed_emmet" -version = "0.0.3" +version = "0.0.4" dependencies = [ "zed_extension_api 0.1.0", ] @@ -20369,19 +20621,6 @@ dependencies = [ "zed_extension_api 0.1.0", ] -[[package]] -name = "zed_llm_client" -version = "0.8.6" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc" -dependencies = [ - "anyhow", - "serde", - "serde_json", - "strum 0.27.1", - "uuid", -] - [[package]] name = "zed_proto" version = "0.2.2" @@ -20391,7 +20630,7 @@ dependencies = [ [[package]] name = "zed_ruff" -version = "0.1.0" +version = "0.1.1" dependencies = [ "zed_extension_api 0.1.0", ] @@ -20561,11 +20800,14 @@ dependencies = [ "call", "client", "clock", + "cloud_api_types", + "cloud_llm_client", "collections", "command_palette_hooks", "copilot", "ctor", "db", + "edit_prediction", "editor", "feature_flags", "fs", @@ -20573,14 +20815,12 @@ dependencies = [ "gpui", "http_client", "indoc", - "inline_completion", "language", "language_model", "log", "menu", "postage", "project", - "proto", "regex", "release_channel", "reqwest_client", @@ -20602,10 +20842,45 @@ dependencies = [ "workspace-hack", "worktree", "zed_actions", - "zed_llm_client", "zlog", ] +[[package]] +name = "zeta_cli" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "client", + "debug_adapter_extension", + "extension", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "language", + "language_extension", + "language_model", + "language_models", + "languages", + "node_runtime", + "paths", + "project", + "prompt_store", + "release_channel", + "reqwest_client", + "serde", + "serde_json", + "settings", + "shellexpand 2.1.2", + "smol", + "terminal_view", + "util", + "watch", + "workspace-hack", + "zeta", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index ec793a7429..7b82fd1910 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,14 @@ [workspace] resolver = "2" members = [ - "crates/activity_indicator", "crates/acp_thread", - "crates/agent_ui", + "crates/activity_indicator", "crates/agent", - "crates/agent_settings", - "crates/ai_onboarding", + "crates/agent2", "crates/agent_servers", + "crates/agent_settings", + "crates/agent_ui", + "crates/ai_onboarding", "crates/anthropic", "crates/askpass", "crates/assets", @@ -29,6 +30,9 @@ members = [ "crates/cli", "crates/client", "crates/clock", + "crates/cloud_api_client", + "crates/cloud_api_types", + "crates/cloud_llm_client", "crates/collab", "crates/collab_ui", "crates/collections", @@ -37,6 +41,7 @@ members = [ "crates/component", "crates/context_server", "crates/copilot", + "crates/crashes", "crates/credentials_provider", "crates/dap", "crates/dap_adapters", @@ -48,8 +53,8 @@ members = [ "crates/diagnostics", "crates/docs_preprocessor", "crates/editor", - "crates/explorer_command_injector", "crates/eval", + "crates/explorer_command_injector", "crates/extension", "crates/extension_api", "crates/extension_cli", @@ -70,15 +75,14 @@ members = [ "crates/gpui", "crates/gpui_macros", "crates/gpui_tokio", - "crates/html_to_markdown", "crates/http_client", "crates/http_client_tls", "crates/icons", "crates/image_viewer", "crates/indexed_docs", - "crates/inline_completion", - "crates/inline_completion_button", + "crates/edit_prediction", + "crates/edit_prediction_button", "crates/inspector_ui", "crates/install_cli", "crates/jj", @@ -99,7 +103,6 @@ members = [ "crates/markdown_preview", "crates/media", "crates/menu", - "crates/svg_preview", "crates/migrator", "crates/mistral", "crates/multi_buffer", @@ -140,6 +143,7 @@ members = [ "crates/semantic_version", "crates/session", "crates/settings", + "crates/settings_profile_selector", "crates/settings_ui", "crates/snippet", "crates/snippet_provider", @@ 
-152,6 +156,7 @@ members = [ "crates/sum_tree", "crates/supermaven", "crates/supermaven_api", + "crates/svg_preview", "crates/tab_switcher", "crates/task", "crates/tasks_ui", @@ -186,6 +191,7 @@ members = [ "crates/zed", "crates/zed_actions", "crates/zeta", + "crates/zeta_cli", "crates/zlog", "crates/zlog_settings", @@ -224,6 +230,7 @@ edition = "2024" acp_thread = { path = "crates/acp_thread" } agent = { path = "crates/agent" } +agent2 = { path = "crates/agent2" } activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } @@ -251,6 +258,9 @@ channel = { path = "crates/channel" } cli = { path = "crates/cli" } client = { path = "crates/client" } clock = { path = "crates/clock" } +cloud_api_client = { path = "crates/cloud_api_client" } +cloud_api_types = { path = "crates/cloud_api_types" } +cloud_llm_client = { path = "crates/cloud_llm_client" } collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } @@ -259,6 +269,7 @@ command_palette_hooks = { path = "crates/command_palette_hooks" } component = { path = "crates/component" } context_server = { path = "crates/context_server" } copilot = { path = "crates/copilot" } +crashes = { path = "crates/crashes" } credentials_provider = { path = "crates/credentials_provider" } dap = { path = "crates/dap" } dap_adapters = { path = "crates/dap_adapters" } @@ -295,8 +306,8 @@ http_client_tls = { path = "crates/http_client_tls" } icons = { path = "crates/icons" } image_viewer = { path = "crates/image_viewer" } indexed_docs = { path = "crates/indexed_docs" } -inline_completion = { path = "crates/inline_completion" } -inline_completion_button = { path = "crates/inline_completion_button" } +edit_prediction = { path = "crates/edit_prediction" } +edit_prediction_button = { path = "crates/edit_prediction_button" } inspector_ui = { path = "crates/inspector_ui" } install_cli = { path = "crates/install_cli" } jj = { path = "crates/jj" } @@ -337,6 +348,7 @@ picker = { path = "crates/picker" } plugin = { path = "crates/plugin" } plugin_macros = { path = "crates/plugin_macros" } prettier = { path = "crates/prettier" } +settings_profile_selector = { path = "crates/settings_profile_selector" } project = { path = "crates/project" } project_panel = { path = "crates/project_panel" } project_symbols = { path = "crates/project_symbols" } @@ -413,6 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" +agent-client-protocol = "0.0.20" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -457,9 +470,10 @@ core-foundation = "0.10.0" core-foundation-sys = "0.8.6" core-video = { version = "0.4.3", features = ["metal"] } cpal = "0.16" +crash-handler = "0.6" criterion = { version = "0.5", features = ["html_reports"] } ctor = "0.4.0" -dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "7f39295b441614ca9dbf44293e53c32f666897f9" } +dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "1b461b310481d01e02b2603c16d7144b926339f8" } dashmap = "6.0" derive_more = "0.99.17" dirs = "4.0" @@ -504,6 +518,7 @@ log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" } markup5ever_rcdom = "0.3.0" metal = "0.29" +minidumper = 
"0.8" moka = { version = "0.12.10", features = ["sync"] } naga = { version = "25.0", features = ["wgsl-in"] } nanoid = "0.4" @@ -543,6 +558,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77 "charset", "http2", "macos-system-configuration", + "multipart", "rustls-tls-native-roots", "socks", "stream", @@ -644,7 +660,6 @@ which = "6.0.0" windows-core = "0.61" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "= 0.8.6" zstd = "0.11" [workspace.dependencies.async-stripe] @@ -671,14 +686,16 @@ features = [ "UI_ViewManagement", "Wdk_System_SystemServices", "Win32_Globalization", - "Win32_Graphics_Direct2D", - "Win32_Graphics_Direct2D_Common", + "Win32_Graphics_Direct3D", + "Win32_Graphics_Direct3D11", + "Win32_Graphics_Direct3D_Fxc", + "Win32_Graphics_DirectComposition", "Win32_Graphics_DirectWrite", "Win32_Graphics_Dwm", + "Win32_Graphics_Dxgi", "Win32_Graphics_Dxgi_Common", "Win32_Graphics_Gdi", "Win32_Graphics_Imaging", - "Win32_Graphics_Imaging_D2D", "Win32_Networking_WinSock", "Win32_Security", "Win32_Security_Credentials", @@ -719,6 +736,11 @@ workspace-hack = { path = "tooling/workspace-hack" } split-debuginfo = "unpacked" codegen-units = 16 +# mirror configuration for crates compiled for the build platform +# (without this cargo will compile ~400 crates twice) +[profile.dev.build-override] +codegen-units = 16 + [profile.dev.package] taffy = { opt-level = 3 } cranelift-codegen = { opt-level = 3 } @@ -741,7 +763,7 @@ feature_flags = { codegen-units = 1 } file_icons = { codegen-units = 1 } fsevent = { codegen-units = 1 } image_viewer = { codegen-units = 1 } -inline_completion_button = { codegen-units = 1 } +edit_prediction_button = { codegen-units = 1 } install_cli = { codegen-units = 1 } journal = { codegen-units = 1 } lmstudio = { codegen-units = 1 } diff --git a/Procfile b/Procfile index 5f1231b90a..b3f13f66a6 100644 --- a/Procfile +++ b/Procfile @@ -1,3 +1,4 @@ collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all +cloud: cd ../cloud; cargo make dev livekit: livekit-server --dev blob_store: ./script/run-local-minio diff --git a/README.md b/README.md index 4c794efc3d..38547c1ca4 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Zed +[![Zed](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/zed-industries/zed/main/assets/badge/v0.json)](https://zed.dev) [![CI](https://github.com/zed-industries/zed/actions/workflows/ci.yml/badge.svg)](https://github.com/zed-industries/zed/actions/workflows/ci.yml) Welcome to Zed, a high-performance, multiplayer code editor from the creators of [Atom](https://github.com/atom/atom) and [Tree-sitter](https://github.com/tree-sitter/tree-sitter). 
diff --git a/assets/badge/v0.json b/assets/badge/v0.json new file mode 100644 index 0000000000..c7d18bb42b --- /dev/null +++ b/assets/badge/v0.json @@ -0,0 +1,8 @@ +{ + "label": "", + "message": "Zed", + "logoSvg": "", + "logoWidth": 16, + "labelColor": "black", + "color": "white" +} diff --git a/assets/icons/ai_bedrock.svg b/assets/icons/ai_bedrock.svg index 2b672c364e..c9bbcc82e1 100644 --- a/assets/icons/ai_bedrock.svg +++ b/assets/icons/ai_bedrock.svg @@ -1,4 +1,8 @@ - - - + + + + + + + diff --git a/assets/icons/ai_deep_seek.svg b/assets/icons/ai_deep_seek.svg index cf480c834c..c8e5483fb3 100644 --- a/assets/icons/ai_deep_seek.svg +++ b/assets/icons/ai_deep_seek.svg @@ -1 +1,3 @@ -DeepSeek + + + diff --git a/assets/icons/ai_lm_studio.svg b/assets/icons/ai_lm_studio.svg index 0b455f48a7..5cfdeb5578 100644 --- a/assets/icons/ai_lm_studio.svg +++ b/assets/icons/ai_lm_studio.svg @@ -1,33 +1,15 @@ - - - Artboard - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + diff --git a/assets/icons/ai_mistral.svg b/assets/icons/ai_mistral.svg index 23b8f2ef6c..f11c177e2f 100644 --- a/assets/icons/ai_mistral.svg +++ b/assets/icons/ai_mistral.svg @@ -1 +1,8 @@ -Mistral \ No newline at end of file + + + + + + + + diff --git a/assets/icons/ai_ollama.svg b/assets/icons/ai_ollama.svg index d433df3981..36a88c1ad6 100644 --- a/assets/icons/ai_ollama.svg +++ b/assets/icons/ai_ollama.svg @@ -1,14 +1,7 @@ - - - - - - - - - - - - + + + + + diff --git a/assets/icons/ai_open_ai.svg b/assets/icons/ai_open_ai.svg index e659a472d8..e45ac315a0 100644 --- a/assets/icons/ai_open_ai.svg +++ b/assets/icons/ai_open_ai.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/ai_open_router.svg b/assets/icons/ai_open_router.svg index 94f2849146..b6f5164e0b 100644 --- a/assets/icons/ai_open_router.svg +++ b/assets/icons/ai_open_router.svg @@ -1,8 +1,8 @@ - - - - - - - + + + + + + + diff --git a/assets/icons/ai_x_ai.svg b/assets/icons/ai_x_ai.svg index 289525c8ef..d3400fbe9c 100644 --- a/assets/icons/ai_x_ai.svg +++ b/assets/icons/ai_x_ai.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/ai_zed.svg b/assets/icons/ai_zed.svg index 1c6bb8ad63..6d78efacd5 100644 --- a/assets/icons/ai_zed.svg +++ b/assets/icons/ai_zed.svg @@ -1,10 +1,3 @@ - - - - - - - - + diff --git a/assets/icons/audio_off.svg b/assets/icons/audio_off.svg index 93b98471ca..dfb5a1c458 100644 --- a/assets/icons/audio_off.svg +++ b/assets/icons/audio_off.svg @@ -1 +1,7 @@ - + + + + + + + diff --git a/assets/icons/audio_on.svg b/assets/icons/audio_on.svg index 42310ea32c..d1bef0d337 100644 --- a/assets/icons/audio_on.svg +++ b/assets/icons/audio_on.svg @@ -1 +1,5 @@ - + + + + + diff --git a/assets/icons/bolt.svg b/assets/icons/bolt.svg deleted file mode 100644 index 2688ede2a5..0000000000 --- a/assets/icons/bolt.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/bolt_filled.svg b/assets/icons/bolt_filled.svg index 543e72adf8..14d8f53e02 100644 --- a/assets/icons/bolt_filled.svg +++ b/assets/icons/bolt_filled.svg @@ -1,3 +1,3 @@ - - + + diff --git a/assets/icons/bolt_filled_alt.svg b/assets/icons/bolt_filled_alt.svg deleted file mode 100644 index 141e1c5f57..0000000000 --- a/assets/icons/bolt_filled_alt.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/bolt_outlined.svg b/assets/icons/bolt_outlined.svg new file mode 100644 index 0000000000..58fccf7788 --- /dev/null +++ b/assets/icons/bolt_outlined.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/book_plus.svg b/assets/icons/book_plus.svg deleted file mode 
100644 index 2868f07cd0..0000000000 --- a/assets/icons/book_plus.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/brain.svg b/assets/icons/brain.svg deleted file mode 100644 index 80c93814f7..0000000000 --- a/assets/icons/brain.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/chat.svg b/assets/icons/chat.svg new file mode 100644 index 0000000000..a0548c3d3e --- /dev/null +++ b/assets/icons/chat.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/at_sign.svg b/assets/icons/cloud_download.svg similarity index 51% rename from assets/icons/at_sign.svg rename to assets/icons/cloud_download.svg index 4cf8cd468f..bc7a8376d1 100644 --- a/assets/icons/at_sign.svg +++ b/assets/icons/cloud_download.svg @@ -1 +1 @@ - + \ No newline at end of file diff --git a/assets/icons/editor_atom.svg b/assets/icons/editor_atom.svg new file mode 100644 index 0000000000..cc5fa83843 --- /dev/null +++ b/assets/icons/editor_atom.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/editor_cursor.svg b/assets/icons/editor_cursor.svg new file mode 100644 index 0000000000..338697be8a --- /dev/null +++ b/assets/icons/editor_cursor.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/assets/icons/editor_emacs.svg b/assets/icons/editor_emacs.svg new file mode 100644 index 0000000000..951d7b2be1 --- /dev/null +++ b/assets/icons/editor_emacs.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/assets/icons/editor_jet_brains.svg b/assets/icons/editor_jet_brains.svg new file mode 100644 index 0000000000..7d9cf0c65c --- /dev/null +++ b/assets/icons/editor_jet_brains.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/editor_sublime.svg b/assets/icons/editor_sublime.svg new file mode 100644 index 0000000000..95a04f6b54 --- /dev/null +++ b/assets/icons/editor_sublime.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/editor_vs_code.svg b/assets/icons/editor_vs_code.svg new file mode 100644 index 0000000000..2a71ad52af --- /dev/null +++ b/assets/icons/editor_vs_code.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/exit.svg b/assets/icons/exit.svg index 2cc6ce120d..1ff9d78824 100644 --- a/assets/icons/exit.svg +++ b/assets/icons/exit.svg @@ -1,8 +1,5 @@ - - + + + + diff --git a/assets/icons/file_icons/kdl.svg b/assets/icons/file_icons/kdl.svg new file mode 100644 index 0000000000..92d9f28428 --- /dev/null +++ b/assets/icons/file_icons/kdl.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/assets/icons/file_icons/surrealql.svg b/assets/icons/file_icons/surrealql.svg new file mode 100644 index 0000000000..076f93e808 --- /dev/null +++ b/assets/icons/file_icons/surrealql.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/file_text.svg b/assets/icons/file_text.svg index 7c602f2ac7..a9b8f971e0 100644 --- a/assets/icons/file_text.svg +++ b/assets/icons/file_text.svg @@ -1 +1,6 @@ - + + + + + + diff --git a/assets/icons/git_onboarding_bg.svg b/assets/icons/git_onboarding_bg.svg deleted file mode 100644 index 18da0230a2..0000000000 --- a/assets/icons/git_onboarding_bg.svg +++ /dev/null @@ -1,40 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/assets/icons/message_bubbles.svg b/assets/icons/message_bubbles.svg deleted file mode 100644 index 03a6c7760c..0000000000 --- a/assets/icons/message_bubbles.svg +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - diff --git a/assets/icons/mic.svg b/assets/icons/mic.svg index 01f4c9bf66..1d9c5bc9ed 100644 --- a/assets/icons/mic.svg +++ b/assets/icons/mic.svg @@ -1,3 +1,5 @@ - - + + + + diff --git a/assets/icons/mic_mute.svg 
b/assets/icons/mic_mute.svg index fe5f8201cc..8c61ae2f1c 100644 --- a/assets/icons/mic_mute.svg +++ b/assets/icons/mic_mute.svg @@ -1,3 +1,8 @@ - - + + + + + + + diff --git a/assets/icons/microscope.svg b/assets/icons/microscope.svg deleted file mode 100644 index 2b3009a28b..0000000000 --- a/assets/icons/microscope.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/new_from_summary.svg b/assets/icons/new_from_summary.svg deleted file mode 100644 index 3b61ca51a0..0000000000 --- a/assets/icons/new_from_summary.svg +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - - diff --git a/assets/icons/play.svg b/assets/icons/play.svg deleted file mode 100644 index 2481bda7d6..0000000000 --- a/assets/icons/play.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/play_bug.svg b/assets/icons/play_bug.svg deleted file mode 100644 index 7d265dd42a..0000000000 --- a/assets/icons/play_bug.svg +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - diff --git a/assets/icons/play_filled.svg b/assets/icons/play_filled.svg index 387304ef04..c632434305 100644 --- a/assets/icons/play_filled.svg +++ b/assets/icons/play_filled.svg @@ -1,3 +1,3 @@ - - + + diff --git a/assets/icons/play_alt.svg b/assets/icons/play_outlined.svg similarity index 70% rename from assets/icons/play_alt.svg rename to assets/icons/play_outlined.svg index b327ab07b5..7e1cacd5af 100644 --- a/assets/icons/play_alt.svg +++ b/assets/icons/play_outlined.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/reveal.svg b/assets/icons/reveal.svg deleted file mode 100644 index ff5444d8f8..0000000000 --- a/assets/icons/reveal.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/screen.svg b/assets/icons/screen.svg index ad252e64cf..4b686b58f9 100644 --- a/assets/icons/screen.svg +++ b/assets/icons/screen.svg @@ -1,8 +1,5 @@ - - + + + + diff --git a/assets/icons/shield_check.svg b/assets/icons/shield_check.svg new file mode 100644 index 0000000000..6e58c31468 --- /dev/null +++ b/assets/icons/shield_check.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/spinner.svg b/assets/icons/spinner.svg deleted file mode 100644 index 4f4034ae89..0000000000 --- a/assets/icons/spinner.svg +++ /dev/null @@ -1,13 +0,0 @@ - - - - - - - - - - - - - diff --git a/assets/icons/strikethrough.svg b/assets/icons/strikethrough.svg deleted file mode 100644 index d7d0905912..0000000000 --- a/assets/icons/strikethrough.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/assets/icons/new_text_thread.svg b/assets/icons/text_thread.svg similarity index 100% rename from assets/icons/new_text_thread.svg rename to assets/icons/text_thread.svg diff --git a/assets/icons/new_thread.svg b/assets/icons/thread.svg similarity index 100% rename from assets/icons/new_thread.svg rename to assets/icons/thread.svg diff --git a/assets/icons/thread_from_summary.svg b/assets/icons/thread_from_summary.svg new file mode 100644 index 0000000000..7519935aff --- /dev/null +++ b/assets/icons/thread_from_summary.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/assets/icons/trash.svg b/assets/icons/trash.svg index b71035b99c..1322e90f9f 100644 --- a/assets/icons/trash.svg +++ b/assets/icons/trash.svg @@ -1 +1,5 @@ - + + + + + diff --git a/assets/icons/trash_alt.svg b/assets/icons/trash_alt.svg deleted file mode 100644 index 6867b42147..0000000000 --- a/assets/icons/trash_alt.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/zed_predict_bg.svg b/assets/icons/zed_predict_bg.svg deleted file mode 100644 index 1dccbb51af..0000000000 --- a/assets/icons/zed_predict_bg.svg +++ /dev/null @@ -1,19 
+0,0 @@ - - - - - - - - - - - - - - - - - - - diff --git a/assets/images/certified_user_stamp.svg b/assets/images/certified_user_stamp.svg new file mode 100644 index 0000000000..7e65c4fc9d --- /dev/null +++ b/assets/images/certified_user_stamp.svg @@ -0,0 +1 @@ + diff --git a/assets/images/pro_trial_stamp.svg b/assets/images/pro_trial_stamp.svg new file mode 100644 index 0000000000..501de88a48 --- /dev/null +++ b/assets/images/pro_trial_stamp.svg @@ -0,0 +1 @@ + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 31adef8cd5..81f5c695a2 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -232,7 +232,7 @@ "ctrl-n": "agent::NewThread", "ctrl-alt-n": "agent::NewTextThread", "ctrl-shift-h": "agent::OpenHistory", - "ctrl-alt-c": "agent::OpenConfiguration", + "ctrl-alt-c": "agent::OpenSettings", "ctrl-alt-p": "agent::OpenRulesLibrary", "ctrl-i": "agent::ToggleProfileSelector", "ctrl-alt-/": "agent::ToggleModelSelector", @@ -495,7 +495,7 @@ "shift-f12": "editor::GoToImplementation", "alt-ctrl-f12": "editor::GoToTypeDefinitionSplit", "alt-shift-f12": "editor::FindAllReferences", - "ctrl-m": "editor::MoveToEnclosingBracket", + "ctrl-m": "editor::MoveToEnclosingBracket", // from jetbrains "ctrl-|": "editor::MoveToEnclosingBracket", "ctrl-{": "editor::Fold", "ctrl-}": "editor::UnfoldLines", @@ -598,6 +598,7 @@ "ctrl-shift-t": "pane::ReopenClosedItem", "ctrl-k ctrl-s": "zed::OpenKeymapEditor", "ctrl-k ctrl-t": "theme_selector::Toggle", + "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -872,8 +873,6 @@ "tab": "git_panel::FocusEditor", "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", - "ctrl-enter": "git::Commit", - "ctrl-shift-enter": "git::Amend", "alt-enter": "menu::SecondaryConfirm", "delete": ["git::RestoreFile", { "skip_prompt": false }], "backspace": ["git::RestoreFile", { "skip_prompt": false }], @@ -910,7 +909,9 @@ "ctrl-g backspace": "git::RestoreTrackedFiles", "ctrl-g shift-backspace": "git::TrashUntrackedFiles", "ctrl-space": "git::StageAll", - "ctrl-shift-space": "git::UnstageAll" + "ctrl-shift-space": "git::UnstageAll", + "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend" } }, { @@ -1167,5 +1168,15 @@ "up": "menu::SelectPrevious", "down": "menu::SelectNext" } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "ctrl-1": "onboarding::ActivateBasicsPage", + "ctrl-2": "onboarding::ActivateEditingPage", + "ctrl-3": "onboarding::ActivateAISetupPage", + "ctrl-escape": "onboarding::Finish" + } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index f942c6f8ae..69958fd1f8 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -272,7 +272,7 @@ "cmd-n": "agent::NewThread", "cmd-alt-n": "agent::NewTextThread", "cmd-shift-h": "agent::OpenHistory", - "cmd-alt-c": "agent::OpenConfiguration", + "cmd-alt-c": "agent::OpenSettings", "cmd-alt-p": "agent::OpenRulesLibrary", "cmd-i": "agent::ToggleProfileSelector", "cmd-alt-/": "agent::ToggleModelSelector", @@ -549,7 +549,7 @@ "alt-cmd-f12": "editor::GoToTypeDefinitionSplit", "alt-shift-f12": "editor::FindAllReferences", "cmd-|": "editor::MoveToEnclosingBracket", - "ctrl-m": "editor::MoveToEnclosingBracket", + "ctrl-m": "editor::MoveToEnclosingBracket", // From Jetbrains "alt-cmd-[": "editor::Fold", "alt-cmd-]": 
"editor::UnfoldLines", "cmd-k cmd-l": "editor::ToggleFold", @@ -665,6 +665,7 @@ "cmd-shift-t": "pane::ReopenClosedItem", "cmd-k cmd-s": "zed::OpenKeymapEditor", "cmd-k cmd-t": "theme_selector::Toggle", + "ctrl-alt-cmd-p": "settings_profile_selector::Toggle", "cmd-t": "project_symbols::Toggle", "cmd-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -950,8 +951,6 @@ "tab": "git_panel::FocusEditor", "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", - "cmd-enter": "git::Commit", - "cmd-shift-enter": "git::Amend", "backspace": ["git::RestoreFile", { "skip_prompt": false }], "delete": ["git::RestoreFile", { "skip_prompt": false }], "cmd-backspace": ["git::RestoreFile", { "skip_prompt": true }], @@ -1001,7 +1000,9 @@ "ctrl-g backspace": "git::RestoreTrackedFiles", "ctrl-g shift-backspace": "git::TrashUntrackedFiles", "cmd-ctrl-y": "git::StageAll", - "cmd-ctrl-shift-y": "git::UnstageAll" + "cmd-ctrl-shift-y": "git::UnstageAll", + "cmd-enter": "git::Commit", + "cmd-shift-enter": "git::Amend" } }, { @@ -1269,5 +1270,15 @@ "up": "menu::SelectPrevious", "down": "menu::SelectNext" } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "cmd-1": "onboarding::ActivateBasicsPage", + "cmd-2": "onboarding::ActivateEditingPage", + "cmd-3": "onboarding::ActivateAISetupPage", + "cmd-escape": "onboarding::Finish" + } } ] diff --git a/assets/keymaps/linux/cursor.json b/assets/keymaps/linux/cursor.json index 347b7885fc..1c381b0cf0 100644 --- a/assets/keymaps/linux/cursor.json +++ b/assets/keymaps/linux/cursor.json @@ -8,7 +8,7 @@ "ctrl-shift-i": "agent::ToggleFocus", "ctrl-l": "agent::ToggleFocus", "ctrl-shift-l": "agent::ToggleFocus", - "ctrl-shift-j": "agent::OpenConfiguration" + "ctrl-shift-j": "agent::OpenSettings" } }, { diff --git a/assets/keymaps/linux/jetbrains.json b/assets/keymaps/linux/jetbrains.json index 629333663d..3df1243fed 100644 --- a/assets/keymaps/linux/jetbrains.json +++ b/assets/keymaps/linux/jetbrains.json @@ -4,6 +4,7 @@ "ctrl-alt-s": "zed::OpenSettings", "ctrl-{": "pane::ActivatePreviousItem", "ctrl-}": "pane::ActivateNextItem", + "shift-escape": null, // Unmap workspace::zoom "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -44,8 +45,8 @@ "ctrl-alt-right": "pane::GoForward", "alt-f7": "editor::FindAllReferences", "ctrl-alt-f7": "editor::FindAllReferences", - // "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock - // "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleLeftDock + "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock + "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleRightDock "ctrl-shift-b": "editor::GoToTypeDefinition", "ctrl-alt-shift-b": "editor::GoToTypeDefinitionSplit", "f2": "editor::GoToDiagnostic", @@ -94,18 +95,33 @@ "ctrl-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], "alt-shift-f10": "task::Spawn", "ctrl-e": "file_finder::Toggle", - "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor + // "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "ctrl-shift-n": "file_finder::Toggle", "ctrl-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", "ctrl-alt-shift-n": "project_symbols::Toggle", "alt-0": "git_panel::ToggleFocus", - "alt-1": "workspace::ToggleLeftDock", + "alt-1": "project_panel::ToggleFocus", "alt-5": "debug_panel::ToggleFocus", "alt-6": 
"diagnostics::Deploy", "alt-7": "outline_panel::ToggleFocus" } }, + { + "context": "Pane", // this is to override the default Pane mappings to switch tabs + "bindings": { + "alt-1": "project_panel::ToggleFocus", + "alt-2": null, // Bookmarks (left dock) + "alt-3": null, // Find Panel (bottom dock) + "alt-4": null, // Run Panel (bottom dock) + "alt-5": "debug_panel::ToggleFocus", + "alt-6": "diagnostics::Deploy", + "alt-7": "outline_panel::ToggleFocus", + "alt-8": null, // Services (bottom dock) + "alt-9": null, // Git History (bottom dock) + "alt-0": "git_panel::ToggleFocus" + } + }, { "context": "Workspace || Editor", "bindings": { @@ -150,7 +166,10 @@ { "context": "Diagnostics > Editor", "bindings": { "alt-6": "pane::CloseActiveItem" } }, { "context": "OutlinePanel", "bindings": { "alt-7": "workspace::CloseActiveDock" } }, { - "context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", - "bindings": { "escape": "editor::ToggleFocus" } + "context": "Dock || Workspace || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", + "bindings": { + "escape": "editor::ToggleFocus", + "shift-escape": "workspace::CloseActiveDock" + } } ] diff --git a/assets/keymaps/macos/cursor.json b/assets/keymaps/macos/cursor.json index b1d39bef9e..fdf9c437cf 100644 --- a/assets/keymaps/macos/cursor.json +++ b/assets/keymaps/macos/cursor.json @@ -8,7 +8,7 @@ "cmd-shift-i": "agent::ToggleFocus", "cmd-l": "agent::ToggleFocus", "cmd-shift-l": "agent::ToggleFocus", - "cmd-shift-j": "agent::OpenConfiguration" + "cmd-shift-j": "agent::OpenSettings" } }, { diff --git a/assets/keymaps/macos/jetbrains.json b/assets/keymaps/macos/jetbrains.json index e8b796f534..66962811f4 100644 --- a/assets/keymaps/macos/jetbrains.json +++ b/assets/keymaps/macos/jetbrains.json @@ -4,6 +4,7 @@ "cmd-{": "pane::ActivatePreviousItem", "cmd-}": "pane::ActivateNextItem", "cmd-0": "git_panel::ToggleFocus", // overrides `cmd-0` zoom reset + "shift-escape": null, // Unmap workspace::zoom "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -96,7 +97,7 @@ "cmd-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], "ctrl-alt-r": "task::Spawn", "cmd-e": "file_finder::Toggle", - "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor + // "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "cmd-shift-o": "file_finder::Toggle", "cmd-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", @@ -108,6 +109,21 @@ "cmd-7": "outline_panel::ToggleFocus" } }, + { + "context": "Pane", // this is to override the default Pane mappings to switch tabs + "bindings": { + "cmd-1": "project_panel::ToggleFocus", + "cmd-2": null, // Bookmarks (left dock) + "cmd-3": null, // Find Panel (bottom dock) + "cmd-4": null, // Run Panel (bottom dock) + "cmd-5": "debug_panel::ToggleFocus", + "cmd-6": "diagnostics::Deploy", + "cmd-7": "outline_panel::ToggleFocus", + "cmd-8": null, // Services (bottom dock) + "cmd-9": null, // Git History (bottom dock) + "cmd-0": "git_panel::ToggleFocus" + } + }, { "context": "Workspace || Editor", "bindings": { @@ -146,11 +162,15 @@ } }, { "context": "GitPanel", "bindings": { "cmd-0": "workspace::CloseActiveDock" } }, + { "context": "ProjectPanel", "bindings": { "cmd-1": "workspace::CloseActiveDock" } }, { "context": "DebugPanel", "bindings": { "cmd-5": "workspace::CloseActiveDock" } }, { "context": "Diagnostics > Editor", "bindings": { 
"cmd-6": "pane::CloseActiveItem" } }, { "context": "OutlinePanel", "bindings": { "cmd-7": "workspace::CloseActiveDock" } }, { - "context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", - "bindings": { "escape": "editor::ToggleFocus" } + "context": "Dock || Workspace || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", + "bindings": { + "escape": "editor::ToggleFocus", + "shift-escape": "workspace::CloseActiveDock" + } } ] diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index d0cf4621a5..6458ac1510 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -220,6 +220,8 @@ { "context": "vim_mode == normal", "bindings": { + "i": "vim::InsertBefore", + "a": "vim::InsertAfter", "ctrl-[": "editor::Cancel", ":": "command_palette::Toggle", "c": "vim::PushChange", @@ -353,9 +355,7 @@ "shift-d": "vim::DeleteToEndOfLine", "shift-j": "vim::JoinLines", "shift-y": "vim::YankLine", - "i": "vim::InsertBefore", "shift-i": "vim::InsertFirstNonWhitespace", - "a": "vim::InsertAfter", "shift-a": "vim::InsertEndOfLine", "o": "vim::InsertLineBelow", "shift-o": "vim::InsertLineAbove", @@ -377,6 +377,8 @@ { "context": "vim_mode == helix_normal && !menu", "bindings": { + "i": "vim::HelixInsert", + "a": "vim::HelixAppend", "ctrl-[": "editor::Cancel", ";": "vim::HelixCollapseSelection", ":": "command_palette::Toggle", diff --git a/assets/settings/default.json b/assets/settings/default.json index 3a7a48efc2..4734b5d118 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1877,5 +1877,25 @@ "save_breakpoints": true, "dock": "bottom", "button": true - } + }, + // Configures any number of settings profiles that are temporarily applied on + // top of your existing user settings when selected from + // `settings profile selector: toggle`. 
+ // Examples: + // "profiles": { + // "Presenting": { + // "agent_font_size": 20.0, + // "buffer_font_size": 20.0, + // "theme": "One Light", + // "ui_font_size": 20.0 + // }, + // "Python (ty)": { + // "languages": { + // "Python": { + // "language_servers": ["ty"] + // } + // } + // } + // } + "profiles": [] } diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index b44c25ccc9..1831c7e473 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -16,7 +16,7 @@ doctest = false test-support = ["gpui/test-support", "project/test-support"] [dependencies] -agentic-coding-protocol.workspace = true +agent-client-protocol.workspace = true anyhow.workspace = true assistant_tool.workspace = true buffer_diff.workspace = true @@ -25,6 +25,7 @@ futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true project.workspace = true serde.workspace = true @@ -36,11 +37,12 @@ util.workspace = true workspace-hack.workspace = true [dev-dependencies] -async-pipe.workspace = true env_logger.workspace = true gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true +parking_lot.workspace = true project = { workspace = true, "features" = ["test-support"] } +rand.workspace = true tempfile.workspace = true util.workspace = true settings.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 9af1eeb187..9febd9c46f 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,17 +1,13 @@ mod connection; pub use connection::*; -pub use acp::ToolCallId; -use agentic_coding_protocol::{ - self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation, - UserMessageChunk, -}; +use agent_client_protocol as acp; use anyhow::{Context as _, Result}; use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use editor::{Bias, MultiBuffer, PathKey}; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; -use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; +use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use itertools::Itertools; use language::{ Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, @@ -21,46 +17,38 @@ use markdown::Markdown; use project::{AgentLocation, Project}; use std::collections::HashMap; use std::error::Error; -use std::fmt::{Formatter, Write}; +use std::fmt::Formatter; +use std::process::ExitStatus; +use std::rc::Rc; use std::{ fmt::Display, mem, path::{Path, PathBuf}, sync::Arc, }; -use ui::{App, IconName}; +use ui::App; use util::ResultExt; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug)] pub struct UserMessage { - pub content: Entity, + pub content: ContentBlock, } impl UserMessage { pub fn from_acp( - message: &acp::SendUserMessageParams, + message: impl IntoIterator, language_registry: Arc, cx: &mut App, ) -> Self { - let mut md_source = String::new(); - - for chunk in &message.chunks { - match chunk { - UserMessageChunk::Text { text } => md_source.push_str(&text), - UserMessageChunk::Path { path } => { - write!(&mut md_source, "{}", MentionPath(&path)).unwrap() - } - } - } - - Self { - content: cx - .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)), + let mut content = ContentBlock::Empty; + for chunk in message { + content.append(chunk, &language_registry, cx) } + 
Self { content: content } } fn to_markdown(&self, cx: &App) -> String { - format!("## User\n\n{}\n\n", self.content.read(cx).source()) + format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) } } @@ -96,7 +84,7 @@ impl Display for MentionPath<'_> { } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug, PartialEq)] pub struct AssistantMessage { pub chunks: Vec, } @@ -113,42 +101,24 @@ impl AssistantMessage { } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug, PartialEq)] pub enum AssistantMessageChunk { - Text { chunk: Entity }, - Thought { chunk: Entity }, + Message { block: ContentBlock }, + Thought { block: ContentBlock }, } impl AssistantMessageChunk { - pub fn from_acp( - chunk: acp::AssistantMessageChunk, - language_registry: Arc, - cx: &mut App, - ) -> Self { - match chunk { - acp::AssistantMessageChunk::Text { text } => Self::Text { - chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)), - }, - acp::AssistantMessageChunk::Thought { thought } => Self::Thought { - chunk: cx - .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)), - }, - } - } - - pub fn from_str(chunk: &str, language_registry: Arc, cx: &mut App) -> Self { - Self::Text { - chunk: cx.new(|cx| { - Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx) - }), + pub fn from_str(chunk: &str, language_registry: &Arc, cx: &mut App) -> Self { + Self::Message { + block: ContentBlock::new(chunk.into(), language_registry, cx), } } fn to_markdown(&self, cx: &App) -> String { match self { - Self::Text { chunk } => chunk.read(cx).source().to_string(), - Self::Thought { chunk } => { - format!("\n{}\n", chunk.read(cx).source()) + Self::Message { block } => block.to_markdown(cx).to_string(), + Self::Thought { block } => { + format!("\n{}\n", block.to_markdown(cx)) } } } @@ -166,19 +136,15 @@ impl AgentThreadEntry { match self { Self::UserMessage(message) => message.to_markdown(cx), Self::AssistantMessage(message) => message.to_markdown(cx), - Self::ToolCall(too_call) => too_call.to_markdown(cx), + Self::ToolCall(tool_call) => tool_call.to_markdown(cx), } } - pub fn diff(&self) -> Option<&Diff> { - if let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Diff { diff }), - .. 
- }) = self - { - Some(&diff) + pub fn diffs(&self) -> impl Iterator { + if let AgentThreadEntry::ToolCall(call) = self { + itertools::Either::Left(call.diffs()) } else { - None + itertools::Either::Right(std::iter::empty()) } } @@ -195,20 +161,99 @@ impl AgentThreadEntry { pub struct ToolCall { pub id: acp::ToolCallId, pub label: Entity, - pub icon: IconName, - pub content: Option, + pub kind: acp::ToolKind, + pub content: Vec, pub status: ToolCallStatus, pub locations: Vec, + pub raw_input: Option, } impl ToolCall { + fn from_acp( + tool_call: acp::ToolCall, + status: ToolCallStatus, + language_registry: Arc, + cx: &mut App, + ) -> Self { + Self { + id: tool_call.id, + label: cx.new(|cx| { + Markdown::new( + tool_call.title.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + kind: tool_call.kind, + content: tool_call + .content + .into_iter() + .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx)) + .collect(), + locations: tool_call.locations, + status, + raw_input: tool_call.raw_input, + } + } + + fn update( + &mut self, + fields: acp::ToolCallUpdateFields, + language_registry: Arc, + cx: &mut App, + ) { + let acp::ToolCallUpdateFields { + kind, + status, + title, + content, + locations, + raw_input, + } = fields; + + if let Some(kind) = kind { + self.kind = kind; + } + + if let Some(status) = status { + self.status = ToolCallStatus::Allowed { status }; + } + + if let Some(title) = title { + self.label = cx.new(|cx| Markdown::new_text(title.into(), cx)); + } + + if let Some(content) = content { + self.content = content + .into_iter() + .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx)) + .collect(); + } + + if let Some(locations) = locations { + self.locations = locations; + } + + if let Some(raw_input) = raw_input { + self.raw_input = Some(raw_input); + } + } + + pub fn diffs(&self) -> impl Iterator { + self.content.iter().filter_map(|content| match content { + ToolCallContent::ContentBlock { .. } => None, + ToolCallContent::Diff { diff } => Some(diff), + }) + } + fn to_markdown(&self, cx: &App) -> String { let mut markdown = format!( "**Tool Call: {}**\nStatus: {}\n\n", self.label.read(cx).source(), self.status ); - if let Some(content) = &self.content { + for content in &self.content { markdown.push_str(content.to_markdown(cx).as_str()); markdown.push_str("\n\n"); } @@ -219,8 +264,8 @@ impl ToolCall { #[derive(Debug)] pub enum ToolCallStatus { WaitingForConfirmation { - confirmation: ToolCallConfirmation, - respond_tx: oneshot::Sender, + options: Vec, + respond_tx: oneshot::Sender, }, Allowed { status: acp::ToolCallStatus, @@ -237,9 +282,10 @@ impl Display for ToolCallStatus { match self { ToolCallStatus::WaitingForConfirmation { .. 
} => "Waiting for confirmation", ToolCallStatus::Allowed { status } => match status { - acp::ToolCallStatus::Running => "Running", - acp::ToolCallStatus::Finished => "Finished", - acp::ToolCallStatus::Error => "Error", + acp::ToolCallStatus::Pending => "Pending", + acp::ToolCallStatus::InProgress => "In Progress", + acp::ToolCallStatus::Completed => "Completed", + acp::ToolCallStatus::Failed => "Failed", }, ToolCallStatus::Rejected => "Rejected", ToolCallStatus::Canceled => "Canceled", @@ -248,86 +294,92 @@ impl Display for ToolCallStatus { } } -#[derive(Debug)] -pub enum ToolCallConfirmation { - Edit { - description: Option>, - }, - Execute { - command: String, - root_command: String, - description: Option>, - }, - Mcp { - server_name: String, - tool_name: String, - tool_display_name: String, - description: Option>, - }, - Fetch { - urls: Vec, - description: Option>, - }, - Other { - description: Entity, - }, +#[derive(Debug, PartialEq, Clone)] +pub enum ContentBlock { + Empty, + Markdown { markdown: Entity }, } -impl ToolCallConfirmation { - pub fn from_acp( - confirmation: acp::ToolCallConfirmation, +impl ContentBlock { + pub fn new( + block: acp::ContentBlock, + language_registry: &Arc, + cx: &mut App, + ) -> Self { + let mut this = Self::Empty; + this.append(block, language_registry, cx); + this + } + + pub fn new_combined( + blocks: impl IntoIterator, language_registry: Arc, cx: &mut App, ) -> Self { - let to_md = |description: String, cx: &mut App| -> Entity { - cx.new(|cx| { - Markdown::new( - description.into(), - Some(language_registry.clone()), - None, - cx, - ) - }) + let mut this = Self::Empty; + for block in blocks { + this.append(block, &language_registry, cx); + } + this + } + + pub fn append( + &mut self, + block: acp::ContentBlock, + language_registry: &Arc, + cx: &mut App, + ) { + let new_content = match block { + acp::ContentBlock::Text(text_content) => text_content.text.clone(), + acp::ContentBlock::ResourceLink(resource_link) => { + if let Some(path) = resource_link.uri.strip_prefix("file://") { + format!("{}", MentionPath(path.as_ref())) + } else { + resource_link.uri.clone() + } + } + acp::ContentBlock::Image(_) + | acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) => String::new(), }; - match confirmation { - acp::ToolCallConfirmation::Edit { description } => Self::Edit { - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Execute { - command, - root_command, - description, - } => Self::Execute { - command, - root_command, - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Mcp { - server_name, - tool_name, - tool_display_name, - description, - } => Self::Mcp { - server_name, - tool_name, - tool_display_name, - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch { - urls: urls.iter().map(|url| url.into()).collect(), - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Other { description } => Self::Other { - description: to_md(description, cx), - }, + match self { + ContentBlock::Empty => { + *self = ContentBlock::Markdown { + markdown: cx.new(|cx| { + Markdown::new( + new_content.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + }; + } + ContentBlock::Markdown { markdown } => { + markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx)); + } + } + } + + fn to_markdown<'a>(&'a 
self, cx: &'a App) -> &'a str { + match self { + ContentBlock::Empty => "", + ContentBlock::Markdown { markdown } => markdown.read(cx).source(), + } + } + + pub fn markdown(&self) -> Option<&Entity> { + match self { + ContentBlock::Empty => None, + ContentBlock::Markdown { markdown } => Some(markdown), } } } #[derive(Debug)] pub enum ToolCallContent { - Markdown { markdown: Entity }, + ContentBlock { content: ContentBlock }, Diff { diff: Diff }, } @@ -338,8 +390,8 @@ impl ToolCallContent { cx: &mut App, ) -> Self { match content { - acp::ToolCallContent::Markdown { markdown } => Self::Markdown { - markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)), + acp::ToolCallContent::Content { content } => Self::ContentBlock { + content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { diff: Diff::from_acp(diff, language_registry, cx), @@ -347,9 +399,9 @@ impl ToolCallContent { } } - fn to_markdown(&self, cx: &App) -> String { + pub fn to_markdown(&self, cx: &App) -> String { match self { - Self::Markdown { markdown } => markdown.read(cx).source().to_string(), + Self::ContentBlock { content } => content.to_markdown(cx).to_string(), Self::Diff { diff } => diff.to_markdown(cx), } } @@ -359,8 +411,6 @@ impl ToolCallContent { pub struct Diff { pub multibuffer: Entity, pub path: PathBuf, - pub new_buffer: Entity, - pub old_buffer: Entity, _task: Task>, } @@ -381,23 +431,34 @@ impl Diff { let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); - let old_buffer_snapshot = old_buffer.read(cx).snapshot(); let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); - let diff_task = buffer_diff.update(cx, |diff, cx| { - diff.set_base_text( - old_buffer_snapshot, - Some(language_registry.clone()), - new_buffer_snapshot, - cx, - ) - }); let task = cx.spawn({ let multibuffer = multibuffer.clone(); let path = path.clone(); - let new_buffer = new_buffer.clone(); async move |cx| { - diff_task.await?; + let language = language_registry + .language_for_file_path(&path) + .await + .log_err(); + + new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; + + let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| { + buffer.set_language(language, cx); + buffer.snapshot() + })?; + + buffer_diff + .update(cx, |diff, cx| { + diff.set_base_text( + old_buffer_snapshot, + Some(language_registry), + new_buffer_snapshot, + cx, + ) + })? 
+ .await?; multibuffer .update(cx, |multibuffer, cx| { @@ -416,18 +477,10 @@ impl Diff { editor::DEFAULT_MULTIBUFFER_CONTEXT, cx, ); - multibuffer.add_diff(buffer_diff.clone(), cx); + multibuffer.add_diff(buffer_diff, cx); }) .log_err(); - if let Some(language) = language_registry - .language_for_file_path(&path) - .await - .log_err() - { - new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?; - } - anyhow::Ok(()) } }); @@ -435,8 +488,6 @@ impl Diff { Self { multibuffer, path, - new_buffer, - old_buffer, _task: task, } } @@ -520,13 +571,17 @@ pub struct AcpThread { action_log: Entity, shared_buffers: HashMap, BufferSnapshot>, send_task: Option>, - connection: Arc, - child_status: Option>>, + connection: Rc, + session_id: acp::SessionId, } pub enum AcpThreadEvent { NewEntry, EntryUpdated(usize), + ToolAuthorizationRequired, + Stopped, + Error, + ServerExited(ExitStatus), } impl EventEmitter for AcpThread {} @@ -563,10 +618,10 @@ impl Error for LoadError {} impl AcpThread { pub fn new( - connection: impl AgentConnection + 'static, - title: SharedString, - child_status: Option>>, + title: impl Into, + connection: Rc, project: Entity, + session_id: acp::SessionId, cx: &mut Context, ) -> Self { let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -576,24 +631,11 @@ impl AcpThread { shared_buffers: Default::default(), entries: Default::default(), plan: Default::default(), - title, + title: title.into(), project, send_task: None, - connection: Arc::new(connection), - child_status, - } - } - - /// Send a request to the agent and wait for a response. - pub fn request( - &self, - params: R, - ) -> impl use + Future> { - let params = params.into_any(); - let result = self.connection.request_any(params); - async move { - let result = result.await?; - Ok(R::response_from_any(result)?) + connection, + session_id, } } @@ -613,6 +655,10 @@ impl AcpThread { &self.entries } + pub fn session_id(&self) -> &acp::SessionId { + &self.session_id + } + pub fn status(&self) -> ThreadStatus { if self.send_task.is_some() { if self.waiting_for_tool_confirmation() { @@ -629,15 +675,18 @@ impl AcpThread { for entry in self.entries.iter().rev() { match entry { AgentThreadEntry::UserMessage(_) => return false, - AgentThreadEntry::ToolCall(ToolCall { - status: - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, - .. - }, - content: Some(ToolCallContent::Diff { .. }), - .. - }) => return true, + AgentThreadEntry::ToolCall( + call @ ToolCall { + status: + ToolCallStatus::Allowed { + status: + acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending, + }, + .. + }, + ) if call.diffs().next().is_some() => { + return true; + } AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {} } } @@ -645,49 +694,94 @@ impl AcpThread { false } - pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { - self.entries.push(entry); - cx.emit(AcpThreadEvent::NewEntry); + pub fn used_tools_since_last_user_message(&self) -> bool { + for entry in self.entries.iter().rev() { + match entry { + AgentThreadEntry::UserMessage(..) => return false, + AgentThreadEntry::AssistantMessage(..) => continue, + AgentThreadEntry::ToolCall(..) 
=> return true, + } + } + + false } - pub fn push_assistant_chunk( + pub fn handle_session_update( &mut self, - chunk: acp::AssistantMessageChunk, + update: acp::SessionUpdate, + cx: &mut Context, + ) -> Result<()> { + match update { + acp::SessionUpdate::UserMessageChunk { content } => { + self.push_user_content_block(content, cx); + } + acp::SessionUpdate::AgentMessageChunk { content } => { + self.push_assistant_content_block(content, false, cx); + } + acp::SessionUpdate::AgentThoughtChunk { content } => { + self.push_assistant_content_block(content, true, cx); + } + acp::SessionUpdate::ToolCall(tool_call) => { + self.upsert_tool_call(tool_call, cx); + } + acp::SessionUpdate::ToolCallUpdate(tool_call_update) => { + self.update_tool_call(tool_call_update, cx)?; + } + acp::SessionUpdate::Plan(plan) => { + self.update_plan(plan, cx); + } + } + Ok(()) + } + + pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context) { + let language_registry = self.project.read(cx).languages().clone(); + let entries_len = self.entries.len(); + + if let Some(last_entry) = self.entries.last_mut() + && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry + { + content.append(chunk, &language_registry, cx); + cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + } else { + let content = ContentBlock::new(chunk, &language_registry, cx); + self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx); + } + } + + pub fn push_assistant_content_block( + &mut self, + chunk: acp::ContentBlock, + is_thought: bool, cx: &mut Context, ) { + let language_registry = self.project.read(cx).languages().clone(); let entries_len = self.entries.len(); if let Some(last_entry) = self.entries.last_mut() && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry { cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); - - match (chunks.last_mut(), &chunk) { - ( - Some(AssistantMessageChunk::Text { chunk: old_chunk }), - acp::AssistantMessageChunk::Text { text: new_chunk }, - ) - | ( - Some(AssistantMessageChunk::Thought { chunk: old_chunk }), - acp::AssistantMessageChunk::Thought { thought: new_chunk }, - ) => { - old_chunk.update(cx, |old_chunk, cx| { - old_chunk.append(&new_chunk, cx); - }); + match (chunks.last_mut(), is_thought) { + (Some(AssistantMessageChunk::Message { block }), false) + | (Some(AssistantMessageChunk::Thought { block }), true) => { + block.append(chunk, &language_registry, cx) } _ => { - chunks.push(AssistantMessageChunk::from_acp( - chunk, - self.project.read(cx).languages().clone(), - cx, - )); + let block = ContentBlock::new(chunk, &language_registry, cx); + if is_thought { + chunks.push(AssistantMessageChunk::Thought { block }) + } else { + chunks.push(AssistantMessageChunk::Message { block }) + } } } } else { - let chunk = AssistantMessageChunk::from_acp( - chunk, - self.project.read(cx).languages().clone(), - cx, - ); + let block = ContentBlock::new(chunk, &language_registry, cx); + let chunk = if is_thought { + AssistantMessageChunk::Thought { block } + } else { + AssistantMessageChunk::Message { block } + }; self.push_entry( AgentThreadEntry::AssistantMessage(AssistantMessage { @@ -698,212 +792,79 @@ impl AcpThread { } } - pub fn request_new_tool_call( - &mut self, - tool_call: acp::RequestToolCallConfirmationParams, - cx: &mut Context, - ) -> ToolCallRequest { - let (tx, rx) = oneshot::channel(); - - let status = ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::from_acp( - 
tool_call.confirmation, - self.project.read(cx).languages().clone(), - cx, - ), - respond_tx: tx, - }; - - let id = self.insert_tool_call(tool_call.tool_call, status, cx); - ToolCallRequest { id, outcome: rx } - } - - pub fn request_tool_call_confirmation( - &mut self, - tool_call_id: ToolCallId, - confirmation: acp::ToolCallConfirmation, - cx: &mut Context, - ) -> Result { - let project = self.project.read(cx).languages().clone(); - let Some((idx, call)) = self.tool_call_mut(tool_call_id) else { - anyhow::bail!("Tool call not found"); - }; - - let (tx, rx) = oneshot::channel(); - - call.status = ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx), - respond_tx: tx, - }; - - cx.emit(AcpThreadEvent::EntryUpdated(idx)); - - Ok(ToolCallRequest { - id: tool_call_id, - outcome: rx, - }) - } - - pub fn push_tool_call( - &mut self, - request: acp::PushToolCallParams, - cx: &mut Context, - ) -> acp::ToolCallId { - let status = ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, - }; - - self.insert_tool_call(request, status, cx) - } - - fn insert_tool_call( - &mut self, - tool_call: acp::PushToolCallParams, - status: ToolCallStatus, - cx: &mut Context, - ) -> acp::ToolCallId { - let language_registry = self.project.read(cx).languages().clone(); - let id = acp::ToolCallId(self.entries.len() as u64); - let call = ToolCall { - id, - label: cx.new(|cx| { - Markdown::new( - tool_call.label.into(), - Some(language_registry.clone()), - None, - cx, - ) - }), - icon: acp_icon_to_ui_icon(tool_call.icon), - content: tool_call - .content - .map(|content| ToolCallContent::from_acp(content, language_registry, cx)), - locations: tool_call.locations, - status, - }; - - let location = call.locations.last().cloned(); - if let Some(location) = location { - self.set_project_location(location, cx) - } - - self.push_entry(AgentThreadEntry::ToolCall(call), cx); - - id - } - - pub fn authorize_tool_call( - &mut self, - id: acp::ToolCallId, - outcome: acp::ToolCallConfirmationOutcome, - cx: &mut Context, - ) { - let Some((ix, call)) = self.tool_call_mut(id) else { - return; - }; - - let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject { - ToolCallStatus::Rejected - } else { - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, - } - }; - - let curr_status = mem::replace(&mut call.status, new_status); - - if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status { - respond_tx.send(outcome).log_err(); - } else if cfg!(debug_assertions) { - panic!("tried to authorize an already authorized tool call"); - } - - cx.emit(AcpThreadEvent::EntryUpdated(ix)); + fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { + self.entries.push(entry); + cx.emit(AcpThreadEvent::NewEntry); } pub fn update_tool_call( &mut self, - id: acp::ToolCallId, - new_status: acp::ToolCallStatus, - new_content: Option, + update: acp::ToolCallUpdate, cx: &mut Context, ) -> Result<()> { - let language_registry = self.project.read(cx).languages().clone(); - let (ix, call) = self.tool_call_mut(id).context("Entry not found")?; + let languages = self.project.read(cx).languages().clone(); - if let Some(new_content) = new_content { - call.content = Some(ToolCallContent::from_acp( - new_content, - language_registry, - cx, - )); - } - - match &mut call.status { - ToolCallStatus::Allowed { status } => { - *status = new_status; - } - ToolCallStatus::WaitingForConfirmation { .. 
} => { - anyhow::bail!("Tool call hasn't been authorized yet") - } - ToolCallStatus::Rejected => { - anyhow::bail!("Tool call was rejected and therefore can't be updated") - } - ToolCallStatus::Canceled => { - call.status = ToolCallStatus::Allowed { status: new_status }; - } - } - - let location = call.locations.last().cloned(); - if let Some(location) = location { - self.set_project_location(location, cx) - } + let (ix, current_call) = self + .tool_call_mut(&update.id) + .context("Tool call not found")?; + current_call.update(update.fields, languages, cx); cx.emit(AcpThreadEvent::EntryUpdated(ix)); + Ok(()) } - fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { - let entry = self.entries.get_mut(id.0 as usize); - debug_assert!( - entry.is_some(), - "We shouldn't give out ids to entries that don't exist" - ); - match entry { - Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)), - _ => { - if cfg!(debug_assertions) { - panic!("entry is not a tool call"); - } - None - } + /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. + pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { + let status = ToolCallStatus::Allowed { + status: tool_call.status, + }; + self.upsert_tool_call_inner(tool_call, status, cx) + } + + pub fn upsert_tool_call_inner( + &mut self, + tool_call: acp::ToolCall, + status: ToolCallStatus, + cx: &mut Context, + ) { + let language_registry = self.project.read(cx).languages().clone(); + let call = ToolCall::from_acp(tool_call, status, language_registry, cx); + + let location = call.locations.last().cloned(); + + if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { + *current_call = call; + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } else { + self.push_entry(AgentThreadEntry::ToolCall(call), cx); + } + + if let Some(location) = location { + self.set_project_location(location, cx) } } - pub fn plan(&self) -> &Plan { - &self.plan + fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { + // The tool call we are looking for is typically the last one, or very close to the end. + // At the moment, it doesn't seem like a hashmap would be a good fit for this use case. 
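    // In practice a session only updates its most recent tool call, so the reverse scan
    // below usually terminates after inspecting one or two entries; a separate
    // id-to-index map would mostly add bookkeeping on every push for little gain.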
+ self.entries + .iter_mut() + .enumerate() + .rev() + .find_map(|(index, tool_call)| { + if let AgentThreadEntry::ToolCall(tool_call) = tool_call + && &tool_call.id == id + { + Some((index, tool_call)) + } else { + None + } + }) } - pub fn update_plan(&mut self, request: acp::UpdatePlanParams, cx: &mut Context) { - self.plan = Plan { - entries: request - .entries - .into_iter() - .map(|entry| PlanEntry::from_acp(entry, cx)) - .collect(), - }; - - cx.notify(); - } - - pub fn clear_completed_plan_entries(&mut self, cx: &mut Context) { - self.plan - .entries - .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed)); - cx.notify(); - } - - pub fn set_project_location(&self, location: ToolCallLocation, cx: &mut Context) { + pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { self.project.update(cx, |project, cx| { let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { return; @@ -934,6 +895,57 @@ impl AcpThread { }); } + pub fn request_tool_call_permission( + &mut self, + tool_call: acp::ToolCall, + options: Vec, + cx: &mut Context, + ) -> oneshot::Receiver { + let (tx, rx) = oneshot::channel(); + + let status = ToolCallStatus::WaitingForConfirmation { + options, + respond_tx: tx, + }; + + self.upsert_tool_call_inner(tool_call, status, cx); + cx.emit(AcpThreadEvent::ToolAuthorizationRequired); + rx + } + + pub fn authorize_tool_call( + &mut self, + id: acp::ToolCallId, + option_id: acp::PermissionOptionId, + option_kind: acp::PermissionOptionKind, + cx: &mut Context, + ) { + let Some((ix, call)) = self.tool_call_mut(&id) else { + return; + }; + + let new_status = match option_kind { + acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => { + ToolCallStatus::Rejected + } + acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => { + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, + } + } + }; + + let curr_status = mem::replace(&mut call.status, new_status); + + if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. 
} = curr_status { + respond_tx.send(option_id).log_err(); + } else if cfg!(debug_assertions) { + panic!("tried to authorize an already authorized tool call"); + } + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } + /// Returns true if the last turn is awaiting tool authorization pub fn waiting_for_tool_confirmation(&self) -> bool { for entry in self.entries.iter().rev() { @@ -953,14 +965,27 @@ impl AcpThread { false } - pub fn initialize(&self) -> impl use<> + Future> { - self.request(acp::InitializeParams { - protocol_version: ProtocolVersion::latest(), - }) + pub fn plan(&self) -> &Plan { + &self.plan } - pub fn authenticate(&self) -> impl use<> + Future> { - self.request(acp::AuthenticateParams) + pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context) { + self.plan = Plan { + entries: request + .entries + .into_iter() + .map(|entry| PlanEntry::from_acp(entry, cx)) + .collect(), + }; + + cx.notify(); + } + + fn clear_completed_plan_entries(&mut self, cx: &mut Context) { + self.plan + .entries + .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed)); + cx.notify(); } #[cfg(any(test, feature = "test-support"))] @@ -968,39 +993,50 @@ impl AcpThread { &mut self, message: &str, cx: &mut Context, - ) -> BoxFuture<'static, Result<(), acp::Error>> { + ) -> BoxFuture<'static, Result<()>> { self.send( - acp::SendUserMessageParams { - chunks: vec![acp::UserMessageChunk::Text { - text: message.to_string(), - }], - }, + vec![acp::ContentBlock::Text(acp::TextContent { + text: message.to_string(), + annotations: None, + })], cx, ) } pub fn send( &mut self, - message: acp::SendUserMessageParams, + message: Vec, cx: &mut Context, - ) -> BoxFuture<'static, Result<(), acp::Error>> { - self.push_entry( - AgentThreadEntry::UserMessage(UserMessage::from_acp( - &message, - self.project.read(cx).languages().clone(), - cx, - )), + ) -> BoxFuture<'static, Result<()>> { + let block = ContentBlock::new_combined( + message.clone(), + self.project.read(cx).languages().clone(), cx, ); + self.push_entry( + AgentThreadEntry::UserMessage(UserMessage { content: block }), + cx, + ); + self.clear_completed_plan_entries(cx); let (tx, rx) = oneshot::channel(); - let cancel = self.cancel(cx); + let cancel_task = self.cancel(cx); self.send_task = Some(cx.spawn(async move |this, cx| { async { - cancel.await.log_err(); + cancel_task.await; - let result = this.update(cx, |this, _| this.request(message))?.await; + let result = this + .update(cx, |this, cx| { + this.connection.prompt( + acp::PromptRequest { + prompt: message, + session_id: this.session_id.clone(), + }, + cx, + ) + })? + .await; tx.send(result).log_err(); this.update(cx, |this, _cx| this.send_task.take())?; anyhow::Ok(()) @@ -1009,57 +1045,53 @@ impl AcpThread { .log_err(); })); - async move { - match rx.await { - Ok(Err(e)) => Err(e)?, - _ => Ok(()), + cx.spawn(async move |this, cx| match rx.await { + Ok(Err(e)) => { + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error)) + .log_err(); + Err(e)? } - } + _ => { + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) + .log_err(); + Ok(()) + } + }) .boxed() } - pub fn cancel(&mut self, cx: &mut Context) -> Task> { - if self.send_task.take().is_some() { - let request = self.request(acp::CancelSendMessageParams); - cx.spawn(async move |this, cx| { - request.await?; - this.update(cx, |this, _cx| { - for entry in this.entries.iter_mut() { - if let AgentThreadEntry::ToolCall(call) = entry { - let cancel = matches!( - call.status, - ToolCallStatus::WaitingForConfirmation { .. 
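// A minimal sketch of the confirmation hand-off behind request_tool_call_permission /
// authorize_tool_call above, reduced to plain `futures` types. `PermissionOptionId` here is
// a hypothetical stand-in for acp::PermissionOptionId, and the gpui entity/event plumbing
// (ToolCallStatus::WaitingForConfirmation, AcpThreadEvent::ToolAuthorizationRequired) is
// omitted; only the oneshot-channel shape of the flow is shown.
use futures::channel::oneshot;
use futures::executor::block_on;

#[derive(Debug, Clone, PartialEq)]
struct PermissionOptionId(&'static str);

fn main() {
    // request_tool_call_permission: create the channel, stash the sender with the pending
    // tool call, and hand the receiver back to whoever asked for authorization.
    let (respond_tx, respond_rx) = oneshot::channel::<PermissionOptionId>();

    // authorize_tool_call: once the user picks an option, take the sender back out of the
    // stored status and resolve the waiting caller with the chosen option id.
    respond_tx.send(PermissionOptionId("allow-once")).unwrap();

    // The requesting side awaits the receiver to learn the outcome.
    assert_eq!(
        block_on(respond_rx).unwrap(),
        PermissionOptionId("allow-once")
    );
}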
} - | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running - } - ); + pub fn cancel(&mut self, cx: &mut Context) -> Task<()> { + let Some(send_task) = self.send_task.take() else { + return Task::ready(()); + }; - if cancel { - let curr_status = - mem::replace(&mut call.status, ToolCallStatus::Canceled); - - if let ToolCallStatus::WaitingForConfirmation { - respond_tx, .. - } = curr_status - { - respond_tx - .send(acp::ToolCallConfirmationOutcome::Cancel) - .ok(); - } - } + for entry in self.entries.iter_mut() { + if let AgentThreadEntry::ToolCall(call) = entry { + let cancel = matches!( + call.status, + ToolCallStatus::WaitingForConfirmation { .. } + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress } - } - })?; - Ok(()) - }) - } else { - Task::ready(Ok(())) + ); + + if cancel { + call.status = ToolCallStatus::Canceled; + } + } } + + self.connection.cancel(&self.session_id, cx); + + // Wait for the send task to complete + cx.foreground_executor().spawn(send_task) } pub fn read_text_file( &self, - request: acp::ReadTextFileParams, + path: PathBuf, + line: Option, + limit: Option, reuse_shared_snapshot: bool, cx: &mut Context, ) -> Task> { @@ -1068,7 +1100,7 @@ impl AcpThread { cx.spawn(async move |this, cx| { let load = project.update(cx, |project, cx| { let path = project - .project_path_for_absolute_path(&request.path, cx) + .project_path_for_absolute_path(&path, cx) .context("invalid path")?; anyhow::Ok(project.open_buffer(path, cx)) }); @@ -1094,7 +1126,7 @@ impl AcpThread { let position = buffer .read(cx) .snapshot() - .anchor_before(Point::new(request.line.unwrap_or_default(), 0)); + .anchor_before(Point::new(line.unwrap_or_default(), 0)); project.set_agent_location( Some(AgentLocation { buffer: buffer.downgrade(), @@ -1110,11 +1142,11 @@ impl AcpThread { this.update(cx, |this, _| { let text = snapshot.text(); this.shared_buffers.insert(buffer.clone(), snapshot); - if request.line.is_none() && request.limit.is_none() { + if line.is_none() && limit.is_none() { return Ok(text); } - let limit = request.limit.unwrap_or(u32::MAX) as usize; - let Some(line) = request.line else { + let limit = limit.unwrap_or(u32::MAX) as usize; + let Some(line) = line else { return Ok(text.lines().take(limit).collect::()); }; @@ -1199,207 +1231,29 @@ impl AcpThread { }) } - pub fn child_status(&mut self) -> Option>> { - self.child_status.take() - } - pub fn to_markdown(&self, cx: &App) -> String { self.entries.iter().map(|e| e.to_markdown(cx)).collect() } -} -#[derive(Clone)] -pub struct AcpClientDelegate { - thread: WeakEntity, - cx: AsyncApp, - // sent_buffer_versions: HashMap, HashMap>, -} - -impl AcpClientDelegate { - pub fn new(thread: WeakEntity, cx: AsyncApp) -> Self { - Self { thread, cx } + pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context) { + cx.emit(AcpThreadEvent::ServerExited(status)); } - - pub async fn clear_completed_plan_entries(&self) -> Result<()> { - let cx = &mut self.cx.clone(); - cx.update(|cx| { - self.thread - .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx)) - })? - .context("Failed to update thread")?; - - Ok(()) - } - - pub async fn request_existing_tool_call_confirmation( - &self, - tool_call_id: ToolCallId, - confirmation: acp::ToolCallConfirmation, - ) -> Result { - let cx = &mut self.cx.clone(); - let ToolCallRequest { outcome, .. } = cx - .update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.request_tool_call_confirmation(tool_call_id, confirmation, cx) - }) - })? 
- .context("Failed to update thread")??; - - Ok(outcome.await?) - } - - pub async fn read_text_file_reusing_snapshot( - &self, - request: acp::ReadTextFileParams, - ) -> Result { - let content = self - .cx - .update(|cx| { - self.thread - .update(cx, |thread, cx| thread.read_text_file(request, true, cx)) - })? - .context("Failed to update thread")? - .await?; - Ok(acp::ReadTextFileResponse { content }) - } -} - -impl acp::Client for AcpClientDelegate { - async fn stream_assistant_message_chunk( - &self, - params: acp::StreamAssistantMessageChunkParams, - ) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread - .update(cx, |thread, cx| { - thread.push_assistant_chunk(params.chunk, cx) - }) - .ok(); - })?; - - Ok(()) - } - - async fn request_tool_call_confirmation( - &self, - request: acp::RequestToolCallConfirmationParams, - ) -> Result { - let cx = &mut self.cx.clone(); - let ToolCallRequest { id, outcome } = cx - .update(|cx| { - self.thread - .update(cx, |thread, cx| thread.request_new_tool_call(request, cx)) - })? - .context("Failed to update thread")?; - - Ok(acp::RequestToolCallConfirmationResponse { - id, - outcome: outcome.await.map_err(acp::Error::into_internal_error)?, - }) - } - - async fn push_tool_call( - &self, - request: acp::PushToolCallParams, - ) -> Result { - let cx = &mut self.cx.clone(); - let id = cx - .update(|cx| { - self.thread - .update(cx, |thread, cx| thread.push_tool_call(request, cx)) - })? - .context("Failed to update thread")?; - - Ok(acp::PushToolCallResponse { id }) - } - - async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.update_tool_call(request.tool_call_id, request.status, request.content, cx) - }) - })? - .context("Failed to update thread")??; - - Ok(()) - } - - async fn update_plan(&self, request: acp::UpdatePlanParams) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread - .update(cx, |thread, cx| thread.update_plan(request, cx)) - })? - .context("Failed to update thread")?; - - Ok(()) - } - - async fn read_text_file( - &self, - request: acp::ReadTextFileParams, - ) -> Result { - let content = self - .cx - .update(|cx| { - self.thread - .update(cx, |thread, cx| thread.read_text_file(request, false, cx)) - })? - .context("Failed to update thread")? - .await?; - Ok(acp::ReadTextFileResponse { content }) - } - - async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> { - self.cx - .update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.write_text_file(request.path, request.content, cx) - }) - })? - .context("Failed to update thread")? 
- .await?; - - Ok(()) - } -} - -fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName { - match icon { - acp::Icon::FileSearch => IconName::ToolSearch, - acp::Icon::Folder => IconName::ToolFolder, - acp::Icon::Globe => IconName::ToolWeb, - acp::Icon::Hammer => IconName::ToolHammer, - acp::Icon::LightBulb => IconName::ToolBulb, - acp::Icon::Pencil => IconName::ToolPencil, - acp::Icon::Regex => IconName::ToolRegex, - acp::Icon::Terminal => IconName::ToolTerminal, - } -} - -pub struct ToolCallRequest { - pub id: acp::ToolCallId, - pub outcome: oneshot::Receiver, } #[cfg(test)] mod tests { use super::*; use anyhow::anyhow; - use async_pipe::{PipeReader, PipeWriter}; use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext}; + use gpui::{AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; use project::FakeFs; + use rand::Rng as _; use serde_json::json; use settings::SettingsStore; - use smol::{future::BoxedLocal, stream::StreamExt as _}; + use smol::stream::StreamExt as _; use std::{cell::RefCell, rc::Rc, time::Duration}; + use util::path; fn init_test(cx: &mut TestAppContext) { @@ -1413,40 +1267,137 @@ mod tests { } #[gpui::test] - async fn test_thinking_concatenation(cx: &mut TestAppContext) { + async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let (thread, fake_server) = fake_acp_thread(project, cx); - - fake_server.update(cx, |fake_server, _| { - fake_server.on_user_message(move |_, server, mut cx| async move { - server - .update(&mut cx, |server, _| { - server.send_to_zed(acp::StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Thought { - thought: "Thinking ".into(), - }, - }) - })? + let connection = Rc::new(FakeAgentConnection::new()); + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) .await - .unwrap(); - server - .update(&mut cx, |server, _| { - server.send_to_zed(acp::StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Thought { - thought: "hard!".into(), - }, - }) - })? 
- .await - .unwrap(); - - Ok(()) }) + .await + .unwrap(); + + // Test creating a new user message + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "Hello, ".to_string(), + }), + cx, + ); }); + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 1); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); + } else { + panic!("Expected UserMessage"); + } + }); + + // Test appending to existing user message + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "world!".to_string(), + }), + cx, + ); + }); + + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 1); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); + } else { + panic!("Expected UserMessage"); + } + }); + + // Test creating new user message after assistant message + thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "Assistant response".to_string(), + }), + false, + cx, + ); + }); + + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "New user message".to_string(), + }), + cx, + ); + }); + + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 3); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { + assert_eq!(user_msg.content.to_markdown(cx), "New user message"); + } else { + panic!("Expected UserMessage at index 2"); + } + }); + } + + #[gpui::test] + async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let connection = Rc::new(FakeAgentConnection::new().on_user_message( + |_, thread, mut cx| { + async move { + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: "Thinking ".into(), + }, + cx, + ) + .unwrap(); + thread + .handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: "hard!".into(), + }, + cx, + ) + .unwrap(); + })?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + }, + )); + + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) + .await + }) + .await + .unwrap(); + thread .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) .await @@ -1478,7 +1429,40 @@ mod tests { fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"})) .await; let project = Project::test(fs.clone(), [], cx).await; - let (thread, fake_server) = fake_acp_thread(project.clone(), cx); + let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); + let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); + let connection = Rc::new(FakeAgentConnection::new().on_user_message( + move |_, thread, mut cx| { + let read_file_tx = read_file_tx.clone(); + async move { + let content = thread + .update(&mut cx, |thread, cx| { + thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!(content, "one\ntwo\nthree\n"); + 
read_file_tx.take().unwrap().send(()).unwrap(); + thread + .update(&mut cx, |thread, cx| { + thread.write_text_file( + path!("/tmp/foo").into(), + "one\ntwo\nthree\nfour\nfive\n".to_string(), + cx, + ) + }) + .unwrap() + .await + .unwrap(); + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + }, + )); + let (worktree, pathbuf) = project .update(cx, |project, cx| { project.find_or_create_worktree(path!("/tmp/foo"), true, cx) @@ -1492,38 +1476,10 @@ mod tests { .await .unwrap(); - let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); - let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); - - fake_server.update(cx, |fake_server, _| { - fake_server.on_user_message(move |_, server, mut cx| { - let read_file_tx = read_file_tx.clone(); - async move { - let content = server - .update(&mut cx, |server, _| { - server.send_to_zed(acp::ReadTextFileParams { - path: path!("/tmp/foo").into(), - line: None, - limit: None, - }) - })? - .await - .unwrap(); - assert_eq!(content.content, "one\ntwo\nthree\n"); - read_file_tx.take().unwrap().send(()).unwrap(); - server - .update(&mut cx, |server, _| { - server.send_to_zed(acp::WriteTextFileParams { - path: path!("/tmp/foo").into(), - content: "one\ntwo\nthree\nfour\nfive\n".to_string(), - }) - })? - .await - .unwrap(); - Ok(()) - } - }) - }); + let thread = cx + .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx)) + .await + .unwrap(); let request = thread.update(cx, |thread, cx| { thread.send_raw("Extend the count in /tmp/foo", cx) @@ -1550,36 +1506,46 @@ mod tests { let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let (thread, fake_server) = fake_acp_thread(project, cx); + let id = acp::ToolCallId("test".into()); - let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>(); - - let tool_call_id = Rc::new(RefCell::new(None)); - let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx))); - fake_server.update(cx, |fake_server, _| { - let tool_call_id = tool_call_id.clone(); - fake_server.on_user_message(move |_, server, mut cx| { - let end_turn_rx = end_turn_rx.clone(); - let tool_call_id = tool_call_id.clone(); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let id = id.clone(); + move |_, thread, mut cx| { + let id = id.clone(); async move { - let tool_call_result = server - .update(&mut cx, |server, _| { - server.send_to_zed(acp::PushToolCallParams { - label: "Fetch".to_string(), - icon: acp::Icon::Globe, - content: None, - locations: vec![], - }) - })? 
- .await + thread + .update(&mut cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: id.clone(), + title: "Label".into(), + kind: acp::ToolKind::Fetch, + status: acp::ToolCallStatus::InProgress, + content: vec![], + locations: vec![], + raw_input: None, + }), + cx, + ) + }) + .unwrap() .unwrap(); - *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id); - end_turn_rx.take().unwrap().await.ok(); - - Ok(()) + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) } + .boxed_local() + } + })); + + let thread = cx + .spawn(async move |mut cx| { + connection + .new_thread(project, Path::new(path!("/test")), &mut cx) + .await }) - }); + .await + .unwrap(); let request = thread.update(cx, |thread, cx| { thread.send_raw("Fetch https://example.com", cx) @@ -1592,7 +1558,7 @@ mod tests { thread.entries[1], AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, + status: acp::ToolCallStatus::InProgress, .. }, .. @@ -1600,12 +1566,7 @@ mod tests { )); }); - cx.run_until_parked(); - - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); + thread.update(cx, |thread, cx| thread.cancel(cx)).await; thread.read_with(cx, |thread, _| { assert!(matches!( @@ -1617,18 +1578,21 @@ mod tests { )); }); - fake_server - .update(cx, |fake_server, _| { - fake_server.send_to_zed(acp::UpdateToolCallParams { - tool_call_id: tool_call_id.borrow().unwrap(), - status: acp::ToolCallStatus::Finished, - content: None, - }) + thread + .update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate { + id, + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + }), + cx, + ) }) - .await .unwrap(); - drop(end_turn_tx); request.await.unwrap(); thread.read_with(cx, |thread, _| { @@ -1636,7 +1600,7 @@ mod tests { thread.entries[1], AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Finished, + status: acp::ToolCallStatus::Completed, .. }, .. 
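Illustrative sketch, not part of the patch: the rewritten tests above no longer drive a fake ACP server over a pipe; they feed protocol updates straight into the thread. Assuming `thread` is an `Entity<AcpThread>` and `cx` is a `TestAppContext`, a tool call is introduced with `SessionUpdate::ToolCall` and later resolved with `SessionUpdate::ToolCallUpdate`, mirroring the hunk above:

    let id = acp::ToolCallId("example".into());
    thread.update(cx, |thread, cx| {
        // Report a new in-progress tool call to the thread.
        thread
            .handle_session_update(
                acp::SessionUpdate::ToolCall(acp::ToolCall {
                    id: id.clone(),
                    title: "Fetch".into(),
                    kind: acp::ToolKind::Fetch,
                    status: acp::ToolCallStatus::InProgress,
                    content: vec![],
                    locations: vec![],
                    raw_input: None,
                }),
                cx,
            )
            .unwrap();
        // Later, mark the same call as completed.
        thread
            .handle_session_update(
                acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                    id,
                    fields: acp::ToolCallUpdateFields {
                        status: Some(acp::ToolCallStatus::Completed),
                        ..Default::default()
                    },
                }),
                cx,
            )
            .unwrap();
    });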
@@ -1645,6 +1609,58 @@ mod tests { }); } + #[gpui::test] + async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + move |_, thread, mut cx| { + async move { + thread + .update(&mut cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("test".into()), + title: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/test/test.txt".into(), + old_text: None, + new_text: "foo".into(), + }, + }], + locations: vec![], + raw_input: None, + }), + cx, + ) + }) + .unwrap() + .unwrap(); + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + })); + + let thread = connection + .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + .await + .unwrap(); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) + .await + .unwrap(); + + assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); + } + async fn run_until_first_tool_call( thread: &Entity, cx: &mut TestAppContext, @@ -1672,140 +1688,114 @@ mod tests { } } - pub fn fake_acp_thread( - project: Entity, - cx: &mut TestAppContext, - ) -> (Entity, Entity) { - let (stdin_tx, stdin_rx) = async_pipe::pipe(); - let (stdout_tx, stdout_rx) = async_pipe::pipe(); - - let thread = cx.new(|cx| { - let foreground_executor = cx.foreground_executor().clone(); - let (connection, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), - stdin_tx, - stdout_rx, - move |fut| { - foreground_executor.spawn(fut).detach(); - }, - ); - - let io_task = cx.background_spawn({ - async move { - io_fut.await.log_err(); - Ok(()) - } - }); - AcpThread::new(connection, "Test".into(), Some(io_task), project, cx) - }); - let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); - (thread, agent) - } - - pub struct FakeAcpServer { - connection: acp::ClientConnection, - - _io_task: Task<()>, + #[derive(Clone, Default)] + struct FakeAgentConnection { + auth_methods: Vec, + sessions: Arc>>>, on_user_message: Option< Rc< dyn Fn( - acp::SendUserMessageParams, - Entity, - AsyncApp, - ) -> LocalBoxFuture<'static, Result<(), acp::Error>>, + acp::PromptRequest, + WeakEntity, + AsyncApp, + ) -> LocalBoxFuture<'static, Result> + + 'static, >, >, } - #[derive(Clone)] - struct FakeAgent { - server: Entity, - cx: AsyncApp, - } - - impl acp::Agent for FakeAgent { - async fn initialize( - &self, - params: acp::InitializeParams, - ) -> Result { - Ok(acp::InitializeResponse { - protocol_version: params.protocol_version, - is_authenticated: true, - }) - } - - async fn authenticate(&self) -> Result<(), acp::Error> { - Ok(()) - } - - async fn cancel_send_message(&self) -> Result<(), acp::Error> { - Ok(()) - } - - async fn send_user_message( - &self, - request: acp::SendUserMessageParams, - ) -> Result<(), acp::Error> { - let mut cx = self.cx.clone(); - let handler = self - .server - .update(&mut cx, |server, _| server.on_user_message.clone()) - .ok() - .flatten(); - if let Some(handler) = handler { - handler(request, self.server.clone(), 
self.cx.clone()).await - } else { - Err(anyhow::anyhow!("No handler for on_user_message").into()) - } - } - } - - impl FakeAcpServer { - fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context) -> Self { - let agent = FakeAgent { - server: cx.entity(), - cx: cx.to_async(), - }; - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp::ClientConnection::connect_to_client( - agent.clone(), - stdout, - stdin, - move |fut| { - foreground_executor.spawn(fut).detach(); - }, - ); - FakeAcpServer { - connection: connection, + impl FakeAgentConnection { + fn new() -> Self { + Self { + auth_methods: Vec::new(), on_user_message: None, - _io_task: cx.background_spawn(async move { - io_fut.await.log_err(); - }), + sessions: Arc::default(), } } - fn on_user_message( - &mut self, - handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity, AsyncApp) -> F - + 'static, - ) where - F: Future> + 'static, - { - self.on_user_message - .replace(Rc::new(move |request, server, cx| { - handler(request, server, cx).boxed_local() - })); + #[expect(unused)] + fn with_auth_methods(mut self, auth_methods: Vec) -> Self { + self.auth_methods = auth_methods; + self } - fn send_to_zed( + fn on_user_message( + mut self, + handler: impl Fn( + acp::PromptRequest, + WeakEntity, + AsyncApp, + ) -> LocalBoxFuture<'static, Result> + + 'static, + ) -> Self { + self.on_user_message.replace(Rc::new(handler)); + self + } + } + + impl AgentConnection for FakeAgentConnection { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + let session_id = acp::SessionId( + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(7) + .map(char::from) + .collect::() + .into(), + ); + let thread = cx + .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) + .unwrap(); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task> { + if self.auth_methods().iter().any(|m| m.id == method) { + Task::ready(Ok(())) + } else { + Task::ready(Err(anyhow!("Invalid Auth Method"))) + } + } + + fn prompt( &self, - message: T, - ) -> BoxedLocal> { - self.connection - .request(message) - .map(|f| f.map_err(|err| anyhow!(err))) - .boxed_local() + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + if let Some(handler) = &self.on_user_message { + let handler = handler.clone(); + let thread = thread.clone(); + cx.spawn(async move |cx| handler(params, thread, cx.clone()).await) + } else { + Task::ready(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + })) + } + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let sessions = self.sessions.lock(); + let thread = sessions.get(&session_id).unwrap().clone(); + + cx.spawn(async move |cx| { + thread + .update(cx, |thread, cx| thread.cancel(cx)) + .unwrap() + .await + }) + .detach(); } } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 7c0ba4f41c..cf06563bee 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,20 +1,93 @@ -use agentic_coding_protocol as acp; +use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; + +use agent_client_protocol::{self as acp}; use 
anyhow::Result; -use futures::future::{FutureExt as _, LocalBoxFuture}; +use gpui::{AsyncApp, Entity, Task}; +use language_model::LanguageModel; +use project::Project; +use ui::App; + +use crate::AcpThread; + +/// Trait for agents that support listing, selecting, and querying language models. +/// +/// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. +pub trait ModelSelector: 'static { + /// Lists all available language models for this agent. + /// + /// # Parameters + /// - `cx`: The GPUI app context for async operations and global access. + /// + /// # Returns + /// A task resolving to the list of models or an error (e.g., if no models are configured). + fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>; + + /// Selects a model for a specific session (thread). + /// + /// This sets the default model for future interactions in the session. + /// If the session doesn't exist or the model is invalid, it returns an error. + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to apply the model to. + /// - `model`: The model to select (should be one from [list_models]). + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to `Ok(())` on success or an error. + fn select_model( + &self, + session_id: acp::SessionId, + model: Arc<dyn LanguageModel>, + cx: &mut AsyncApp, + ) -> Task<Result<()>>; + + /// Retrieves the currently selected model for a specific session (thread). + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to query. + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to the selected model (always set) or an error (e.g., session not found). + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task<Result<Arc<dyn LanguageModel>>>; +} pub trait AgentConnection { - fn request_any( - &self, - params: acp::AnyAgentRequest, - ) -> LocalBoxFuture<'static, Result>; -} + fn new_thread( + self: Rc<Self>, + project: Entity<Project>, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task<Result<Entity<AcpThread>>>; -impl AgentConnection for acp::AgentConnection { - fn request_any( - &self, - params: acp::AnyAgentRequest, - ) -> LocalBoxFuture<'static, Result> { - let task = self.request_any(params); - async move { Ok(task.await?) }.boxed_local() + fn auth_methods(&self) -> &[acp::AuthMethod]; + + fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>; + + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) + -> Task<Result<acp::PromptResponse>>; + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + + /// Returns this agent as an [Rc] if the model selection capability is supported. + /// + /// If the agent does not support model selection, returns [None]. + /// This allows sharing the selector in UI components. 
+ fn model_selector(&self) -> Option> { + None // Default impl for agents that don't support it + } +} + +#[derive(Debug)] +pub struct AuthRequired; + +impl Error for AuthRequired {} +impl fmt::Display for AuthRequired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthRequired") } } diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 135363ab65..7bc0e82cad 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -25,6 +25,7 @@ assistant_context.workspace = true assistant_tool.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true component.workspace = true context_server.workspace = true @@ -35,9 +36,9 @@ futures.workspace = true git.workspace = true gpui.workspace = true heed.workspace = true +http_client.workspace = true icons.workspace = true indoc.workspace = true -http_client.workspace = true itertools.workspace = true language.workspace = true language_model.workspace = true @@ -46,7 +47,6 @@ paths.workspace = true postage.workspace = true project.workspace = true prompt_store.workspace = true -proto.workspace = true ref-cast.workspace = true rope.workspace = true schemars.workspace = true @@ -63,7 +63,6 @@ time.workspace = true util.workspace = true uuid.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true zstd.workspace = true [dev-dependencies] diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs index a89857e71a..34ea1c8df7 100644 --- a/crates/agent/src/agent_profile.rs +++ b/crates/agent/src/agent_profile.rs @@ -308,7 +308,12 @@ mod tests { unimplemented!() } - fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + fn needs_confirmation( + &self, + _input: &serde_json::Value, + _project: &Entity, + _cx: &App, + ) -> bool { unimplemented!() } diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index ddd13de491..cd366b8308 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -42,8 +42,8 @@ impl ContextKind { ContextKind::Symbol => IconName::Code, ContextKind::Selection => IconName::Context, ContextKind::FetchedUrl => IconName::Globe, - ContextKind::Thread => IconName::MessageBubbles, - ContextKind::TextThread => IconName::MessageBubbles, + ContextKind::Thread => IconName::Thread, + ContextKind::TextThread => IconName::TextThread, ContextKind::Rules => RULES_ICON, ContextKind::Image => IconName::Image, } diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index 4c6d2b2b0b..85e8ac7451 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -47,7 +47,7 @@ impl Tool for ContextServerTool { } } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 1b3b022ab2..048aa4245d 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -8,11 +8,12 @@ use crate::{ }, tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}, }; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use 
client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; use futures::{FutureExt, StreamExt as _, future::Shared}; @@ -36,7 +37,6 @@ use project::{ git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, }; use prompt_store::{ModelContext, PromptBuilder}; -use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -49,7 +49,6 @@ use std::{ use thiserror::Error; use util::{ResultExt as _, post_inc}; use uuid::Uuid; -use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; const MAX_RETRY_ATTEMPTS: u8 = 4; const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); @@ -942,7 +941,7 @@ impl Thread { } pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { - self.tool_use.tool_uses_for_message(id, cx) + self.tool_use.tool_uses_for_message(id, &self.project, cx) } pub fn tool_results_for_message( @@ -1681,7 +1680,7 @@ impl Thread { let completion_mode = request .mode - .unwrap_or(zed_llm_client::CompletionMode::Normal); + .unwrap_or(cloud_llm_client::CompletionMode::Normal); self.last_received_chunk_at = Some(Instant::now()); @@ -2037,6 +2036,12 @@ impl Thread { if let Some(retry_strategy) = Thread::get_retry_strategy(completion_error) { + log::info!( + "Retrying with {:?} for language model completion error {:?}", + retry_strategy, + completion_error + ); + retry_scheduled = thread .handle_retryable_error_with_delay( &completion_error, @@ -2107,12 +2112,10 @@ impl Thread { return; } - let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt"); - let request = self.to_summarize_request( &model.model, CompletionIntent::ThreadSummarization, - added_user_message.into(), + SUMMARIZE_THREAD_PROMPT.into(), cx, ); @@ -2246,15 +2249,14 @@ impl Thread { .. } | AuthenticationError { .. } - | PermissionError { .. } => None, - // These errors might be transient, so retry them - SerializeRequest { .. } - | BuildRequestBody { .. } - | PromptTooLarge { .. } + | PermissionError { .. } + | NoApiKey { .. } | ApiEndpointNotFound { .. } - | NoApiKey { .. } => Some(RetryStrategy::Fixed { + | PromptTooLarge { .. } => None, + // These errors might be transient, so retry them + SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed { delay: BASE_RETRY_DELAY, - max_attempts: 2, + max_attempts: 1, }), // Retry all other 4xx and 5xx errors once. HttpResponseError { status_code, .. 
} @@ -2552,7 +2554,7 @@ impl Thread { return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx); } - if tool.needs_confirmation(&tool_use.input, cx) + if tool.needs_confirmation(&tool_use.input, &self.project, cx) && !AgentSettings::get_global(cx).always_allow_tool_actions { self.tool_use.confirm_tool_use( @@ -3250,8 +3252,10 @@ impl Thread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.project.update(cx, |project, cx| { - project.user_store().update(cx, |user_store, cx| { + self.project + .read(cx) + .user_store() + .update(cx, |user_store, cx| { user_store.update_model_request_usage( ModelRequestUsage(RequestUsage { amount: amount as i32, @@ -3259,8 +3263,7 @@ impl Thread { }), cx, ) - }) - }); + }); } pub fn deny_tool_use( @@ -4042,8 +4045,8 @@ fn main() {{ }); cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief"); - fake_model.stream_last_completion_response(" Introduction"); + fake_model.send_last_completion_stream_text_chunk("Brief"); + fake_model.send_last_completion_stream_text_chunk(" Introduction"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -4136,7 +4139,7 @@ fn main() {{ }); cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary"); + fake_model.send_last_completion_stream_text_chunk("A successful summary"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -4769,7 +4772,7 @@ fn main() {{ !pending.is_empty(), "Should have a pending completion after retry" ); - fake_model.stream_completion_response(&pending[0], "Success!"); + fake_model.send_completion_stream_text_chunk(&pending[0], "Success!"); fake_model.end_completion_stream(&pending[0]); cx.run_until_parked(); @@ -4937,7 +4940,7 @@ fn main() {{ // Check for pending completions and complete them if let Some(pending) = inner_fake.pending_completions().first() { - inner_fake.stream_completion_response(pending, "Success!"); + inner_fake.send_completion_stream_text_chunk(pending, "Success!"); inner_fake.end_completion_stream(pending); } cx.run_until_parked(); @@ -5422,7 +5425,7 @@ fn main() {{ fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response"); + fake_model.send_last_completion_stream_text_chunk("Assistant response"); fake_model.end_last_completion_stream(); cx.run_until_parked(); } diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 0347156cd4..cc7cb50c91 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -41,6 +41,9 @@ use std::{ }; use util::ResultExt as _; +pub static ZED_STATELESS: std::sync::LazyLock = + std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum DataType { #[serde(rename = "json")] @@ -874,7 +877,11 @@ impl ThreadsDatabase { let needs_migration_from_heed = mdb_path.exists(); - let connection = Connection::open_file(&sqlite_path.to_string_lossy()); + let connection = if *ZED_STATELESS { + Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else { + Connection::open_file(&sqlite_path.to_string_lossy()) + }; connection.exec(indoc! 
{" CREATE TABLE IF NOT EXISTS threads ( diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 74c719b4e6..7392c0878d 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -165,7 +165,12 @@ impl ToolUseState { self.pending_tool_uses_by_id.values().collect() } - pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { + pub fn tool_uses_for_message( + &self, + id: MessageId, + project: &Entity, + cx: &App, + ) -> Vec { let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { return Vec::new(); }; @@ -211,7 +216,10 @@ impl ToolUseState { let (icon, needs_confirmation) = if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) { - (tool.icon(), tool.needs_confirmation(&tool_use.input, cx)) + ( + tool.icon(), + tool.needs_confirmation(&tool_use.input, project, cx), + ) } else { (IconName::Cog, false) }; diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml new file mode 100644 index 0000000000..74aa2993dd --- /dev/null +++ b/crates/agent2/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "agent2" +version = "0.1.0" +edition = "2021" +license = "GPL-3.0-or-later" +publish = false + +[lib] +path = "src/agent2.rs" + +[lints] +workspace = true + +[dependencies] +acp_thread.workspace = true +agent-client-protocol.workspace = true +agent_servers.workspace = true +anyhow.workspace = true +cloud_llm_client.workspace = true +collections.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +handlebars = { workspace = true, features = ["rust-embed"] } +indoc.workspace = true +language_model.workspace = true +language_models.workspace = true +log.workspace = true +project.workspace = true +rust-embed.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +ui.workspace = true +util.workspace = true +uuid.workspace = true +worktree.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +ctor.workspace = true +client = { workspace = true, "features" = ["test-support"] } +clock = { workspace = true, "features" = ["test-support"] } +env_logger.workspace = true +fs = { workspace = true, "features" = ["test-support"] } +gpui = { workspace = true, "features" = ["test-support"] } +gpui_tokio.workspace = true +language_model = { workspace = true, "features" = ["test-support"] } +project = { workspace = true, "features" = ["test-support"] } +reqwest_client.workspace = true +settings = { workspace = true, "features" = ["test-support"] } +worktree = { workspace = true, "features" = ["test-support"] } diff --git a/crates/inline_completion/LICENSE-GPL b/crates/agent2/LICENSE-GPL similarity index 100% rename from crates/inline_completion/LICENSE-GPL rename to crates/agent2/LICENSE-GPL diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs new file mode 100644 index 0000000000..305a31fc98 --- /dev/null +++ b/crates/agent2/src/agent.rs @@ -0,0 +1,345 @@ +use acp_thread::ModelSelector; +use agent_client_protocol as acp; +use anyhow::{anyhow, Result}; +use futures::StreamExt; +use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity}; +use language_model::{LanguageModel, LanguageModelRegistry}; +use project::Project; +use std::collections::HashMap; +use std::path::Path; +use std::rc::Rc; +use std::sync::Arc; + +use crate::{templates::Templates, AgentResponseEvent, Thread}; + +/// Holds both the internal Thread and the AcpThread for a session 
+struct Session { + /// The internal thread that processes messages + thread: Entity, + /// The ACP thread that handles protocol communication + acp_thread: WeakEntity, + _subscription: Subscription, +} + +pub struct NativeAgent { + /// Session ID -> Session mapping + sessions: HashMap, + /// Shared templates for all threads + templates: Arc, +} + +impl NativeAgent { + pub fn new(templates: Arc) -> Self { + log::info!("Creating new NativeAgent"); + Self { + sessions: HashMap::new(), + templates, + } + } +} + +/// Wrapper struct that implements the AgentConnection trait +#[derive(Clone)] +pub struct NativeAgentConnection(pub Entity); + +impl ModelSelector for NativeAgentConnection { + fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { + log::debug!("NativeAgentConnection::list_models called"); + cx.spawn(async move |cx| { + cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let models = registry.available_models(cx).collect::>(); + log::info!("Found {} available models", models.len()); + if models.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(models) + } + })? + }) + } + + fn select_model( + &self, + session_id: acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task> { + log::info!( + "Setting model for session {}: {:?}", + session_id, + model.name() + ); + let agent = self.0.clone(); + + cx.spawn(async move |cx| { + agent.update(cx, |agent, cx| { + if let Some(session) = agent.sessions.get(&session_id) { + session.thread.update(cx, |thread, _cx| { + thread.selected_model = model; + }); + Ok(()) + } else { + Err(anyhow!("Session not found")) + } + })? + }) + } + + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + let session_id = session_id.clone(); + cx.spawn(async move |cx| { + let thread = agent + .read_with(cx, |agent, _| { + agent + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + })? + .ok_or_else(|| anyhow::anyhow!("Session not found"))?; + let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + Ok(selected) + }) + } +} + +impl acp_thread::AgentConnection for NativeAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + log::info!("Creating new thread for project at: {:?}", cwd); + + cx.spawn(async move |cx| { + log::debug!("Starting thread creation in async context"); + // Create Thread + let (session_id, thread) = agent.update( + cx, + |agent, cx: &mut gpui::Context| -> Result<_> { + // Fetch default model from registry settings + let registry = LanguageModelRegistry::read_global(cx); + + // Log available models for debugging + let available_count = registry.available_models(cx).count(); + log::debug!("Total available models: {}", available_count); + + let default_model = registry + .default_model() + .map(|configured| { + log::info!( + "Using configured default model: {:?} from provider: {:?}", + configured.model.name(), + configured.provider.name() + ); + configured.model + }) + .ok_or_else(|| { + log::warn!("No default model configured in settings"); + anyhow!("No default model configured. 
Please configure a default model in settings.") + })?; + + let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model)); + + // Generate session ID + let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + log::info!("Created session with ID: {}", session_id); + Ok((session_id, thread)) + }, + )??; + + // Create AcpThread + let acp_thread = cx.update(|cx| { + cx.new(|cx| { + acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx) + }) + })?; + + // Store the session + agent.update(cx, |agent, cx| { + agent.sessions.insert( + session_id, + Session { + thread, + acp_thread: acp_thread.downgrade(), + _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }) + }, + ); + })?; + + Ok(acp_thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] // No auth for in-process + } + + fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } + + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let session_id = params.session_id.clone(); + let agent = self.0.clone(); + log::info!("Received prompt request for session: {}", session_id); + log::debug!("Prompt blocks count: {}", params.prompt.len()); + + cx.spawn(async move |cx| { + // Get session + let (thread, acp_thread) = agent + .update(cx, |agent, _| { + agent + .sessions + .get_mut(&session_id) + .map(|s| (s.thread.clone(), s.acp_thread.clone())) + })? + .ok_or_else(|| { + log::error!("Session not found: {}", session_id); + anyhow::anyhow!("Session not found") + })?; + log::debug!("Found session for: {}", session_id); + + // Convert prompt to message + let message = convert_prompt_to_message(params.prompt); + log::info!("Converted prompt to message: {} chars", message.len()); + log::debug!("Message content: {}", message); + + // Get model using the ModelSelector capability (always available for agent2) + // Get the selected model from the thread directly + let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + + // Send to thread + log::info!("Sending message to thread with model: {:?}", model.name()); + let mut response_stream = + thread.update(cx, |thread, cx| thread.send(model, message, cx))?; + + // Handle response stream and forward to session.acp_thread + while let Some(result) = response_stream.next().await { + match result { + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + + match event { + AgentResponseEvent::Text(text) => { + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + }, + cx, + ) + })??; + } + AgentResponseEvent::Thinking(text) => { + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + }, + cx, + ) + })??; + } + AgentResponseEvent::ToolCall(tool_call) => { + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(tool_call), + cx, + ) + })??; + } + AgentResponseEvent::ToolCallUpdate(tool_call_update) => { + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + 
acp::SessionUpdate::ToolCallUpdate(tool_call_update), + cx, + ) + })??; + } + AgentResponseEvent::Stop(stop_reason) => { + log::debug!("Assistant message complete: {:?}", stop_reason); + return Ok(acp::PromptResponse { stop_reason }); + } + } + } + Err(e) => { + log::error!("Error in model response stream: {:?}", e); + // TODO: Consider sending an error message to the UI + break; + } + } + } + + log::info!("Response stream completed"); + anyhow::Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + log::info!("Cancelling on session: {}", session_id); + self.0.update(cx, |agent, cx| { + if let Some(agent) = agent.sessions.get(session_id) { + agent.thread.update(cx, |thread, _cx| thread.cancel()); + } + }); + } +} + +/// Convert ACP content blocks to a message string +fn convert_prompt_to_message(blocks: Vec) -> String { + log::debug!("Converting {} content blocks to message", blocks.len()); + let mut message = String::new(); + + for block in blocks { + match block { + acp::ContentBlock::Text(text) => { + log::trace!("Processing text block: {} chars", text.text.len()); + message.push_str(&text.text); + } + acp::ContentBlock::ResourceLink(link) => { + log::trace!("Processing resource link: {}", link.uri); + message.push_str(&format!(" @{} ", link.uri)); + } + acp::ContentBlock::Image(_) => { + log::trace!("Processing image block"); + message.push_str(" [image] "); + } + acp::ContentBlock::Audio(_) => { + log::trace!("Processing audio block"); + message.push_str(" [audio] "); + } + acp::ContentBlock::Resource(resource) => { + log::trace!("Processing resource block: {:?}", resource.resource); + message.push_str(&format!(" [resource: {:?}] ", resource.resource)); + } + } + } + + message +} diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs new file mode 100644 index 0000000000..aa665fe313 --- /dev/null +++ b/crates/agent2/src/agent2.rs @@ -0,0 +1,13 @@ +mod agent; +mod native_agent_server; +mod prompts; +mod templates; +mod thread; +mod tools; + +#[cfg(test)] +mod tests; + +pub use agent::*; +pub use native_agent_server::NativeAgentServer; +pub use thread::*; diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs new file mode 100644 index 0000000000..aafe70a8a2 --- /dev/null +++ b/crates/agent2/src/native_agent_server.rs @@ -0,0 +1,58 @@ +use std::path::Path; +use std::rc::Rc; + +use agent_servers::AgentServer; +use anyhow::Result; +use gpui::{App, AppContext, Entity, Task}; +use project::Project; + +use crate::{templates::Templates, NativeAgent, NativeAgentConnection}; + +#[derive(Clone)] +pub struct NativeAgentServer; + +impl AgentServer for NativeAgentServer { + fn name(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_headline(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_message(&self) -> &'static str { + "How can I help you today?" 
+ } + + fn logo(&self) -> ui::IconName { + // Using the ZedAssistant icon as it's the native built-in agent + ui::IconName::ZedAssistant + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + cx: &mut App, + ) -> Task>> { + log::info!( + "NativeAgentServer::connect called for path: {:?}", + _root_dir + ); + cx.spawn(async move |cx| { + log::debug!("Creating templates for native agent"); + // Create templates (you might want to load these from files or resources) + let templates = Templates::new(); + + // Create the native agent + log::debug!("Creating native agent entity"); + let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?; + + // Create the connection wrapper + let connection = NativeAgentConnection(agent); + log::info!("NativeAgentServer connection established successfully"); + + Ok(Rc::new(connection) as Rc) + }) + } +} diff --git a/crates/agent2/src/prompts.rs b/crates/agent2/src/prompts.rs new file mode 100644 index 0000000000..28507f4968 --- /dev/null +++ b/crates/agent2/src/prompts.rs @@ -0,0 +1,35 @@ +use crate::{ + templates::{BaseTemplate, Template, Templates, WorktreeData}, + thread::Prompt, +}; +use anyhow::Result; +use gpui::{App, Entity}; +use project::Project; + +pub struct BasePrompt { + project: Entity, +} + +impl BasePrompt { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl Prompt for BasePrompt { + fn render(&self, templates: &Templates, cx: &App) -> Result { + BaseTemplate { + os: std::env::consts::OS.to_string(), + shell: util::get_system_shell(), + worktrees: self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| WorktreeData { + root_name: worktree.read(cx).root_name().to_string(), + }) + .collect(), + } + .render(templates) + } +} diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs new file mode 100644 index 0000000000..04569369be --- /dev/null +++ b/crates/agent2/src/templates.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use anyhow::Result; +use handlebars::Handlebars; +use rust_embed::RustEmbed; +use serde::Serialize; + +#[derive(RustEmbed)] +#[folder = "src/templates"] +#[include = "*.hbs"] +struct Assets; + +pub struct Templates(Handlebars<'static>); + +impl Templates { + pub fn new() -> Arc { + let mut handlebars = Handlebars::new(); + handlebars.register_embed_templates::().unwrap(); + Arc::new(Self(handlebars)) + } +} + +pub trait Template: Sized { + const TEMPLATE_NAME: &'static str; + + fn render(&self, templates: &Templates) -> Result + where + Self: Serialize + Sized, + { + Ok(templates.0.render(Self::TEMPLATE_NAME, self)?) + } +} + +#[derive(Serialize)] +pub struct BaseTemplate { + pub os: String, + pub shell: String, + pub worktrees: Vec, +} + +impl Template for BaseTemplate { + const TEMPLATE_NAME: &'static str = "base.hbs"; +} + +#[derive(Serialize)] +pub struct WorktreeData { + pub root_name: String, +} + +#[derive(Serialize)] +pub struct GlobTemplate { + pub project_roots: String, +} + +impl Template for GlobTemplate { + const TEMPLATE_NAME: &'static str = "glob.hbs"; +} diff --git a/crates/agent2/src/templates/base.hbs b/crates/agent2/src/templates/base.hbs new file mode 100644 index 0000000000..7eef231e32 --- /dev/null +++ b/crates/agent2/src/templates/base.hbs @@ -0,0 +1,56 @@ +You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. + +## Communication + +1. Be conversational but professional. +2. 
Refer to the USER in the second person and yourself in the first person. +3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. +4. NEVER lie or make things up. +5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. + +## Tool Use + +1. Make sure to adhere to the tools schema. +2. Provide every required argument. +3. DO NOT use tools to access items that are already available in the context section. +4. Use only the tools that are currently available. +5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off. + +## Searching and Reading + +If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions. + +If appropriate, use tool calls to explore the current project, which contains the following root directories: + +{{#each worktrees}} +- `{{root_name}}` +{{/each}} + +- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above. +- When looking for symbols in the project, prefer the `grep` tool. +- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project. +- Bias towards not asking the user for help if you can find the answer yourself. + +## Fixing Diagnostics + +1. Make 1-2 attempts at fixing diagnostics, then defer to the user. +2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem. + +## Debugging + +When debugging, only make code changes if you are certain that you can solve the problem. +Otherwise, follow debugging best practices: +1. Address the root cause instead of the symptoms. +2. Add descriptive logging statements and error messages to track variable and code state. +3. Add test functions and statements to isolate the problem. + +## Calling External APIs + +1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission. +2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data. +3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed) + +## System Information + +Operating System: {{os}} +Default Shell: {{shell}} diff --git a/crates/agent2/src/templates/glob.hbs b/crates/agent2/src/templates/glob.hbs new file mode 100644 index 0000000000..3bf992b093 --- /dev/null +++ b/crates/agent2/src/templates/glob.hbs @@ -0,0 +1,8 @@ +Find paths on disk with glob patterns. + +Assume that all glob patterns are matched in a project directory with the following entries. + +{{project_roots}} + +When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be +sure to anchor them with one of the directories listed above. 
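Illustrative sketch, not part of the patch: the Handlebars templates above are compiled into the binary by the `Templates` registry from `crates/agent2/src/templates.rs`, and each `Template` struct renders against its embedded file. A minimal rendering of `base.hbs`, assuming it runs inside the `agent2` crate where `Templates`, `Template`, `BaseTemplate`, and `WorktreeData` are in scope:

    // Build the registry once; it registers every embedded *.hbs file.
    let templates = Templates::new();

    // Render the system prompt for a project with a single worktree root.
    let prompt = BaseTemplate {
        os: std::env::consts::OS.to_string(),
        shell: util::get_system_shell(),
        worktrees: vec![WorktreeData {
            root_name: "zed".to_string(),
        }],
    }
    .render(&templates)
    .expect("base.hbs should render");

    assert!(prompt.contains("Operating System:"));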
diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs new file mode 100644 index 0000000000..330d04b60c --- /dev/null +++ b/crates/agent2/src/tests/mod.rs @@ -0,0 +1,534 @@ +use super::*; +use crate::templates::Templates; +use acp_thread::AgentConnection; +use agent_client_protocol as acp; +use client::{Client, UserStore}; +use fs::FakeFs; +use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext}; +use indoc::indoc; +use language_model::{ + fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent, + StopReason, +}; +use project::Project; +use reqwest_client::ReqwestClient; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use smol::stream::StreamExt; +use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; +use util::path; + +mod test_tools; +use test_tools::*; + +#[gpui::test] +#[ignore = "temporarily disabled until it can be run on CI"] +async fn test_echo(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + let events = thread + .update(cx, |thread, cx| { + thread.send(model.clone(), "Testing: Reply with 'Hello'", cx) + }) + .collect() + .await; + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.messages().last().unwrap().content, + vec![MessageContent::Text("Hello".to_string())] + ); + }); + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +#[ignore = "temporarily disabled until it can be run on CI"] +async fn test_thinking(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await; + + let events = thread + .update(cx, |thread, cx| { + thread.send( + model.clone(), + indoc! {" + Testing: + + Generate a thinking step where you just think the word 'Think', + and have your final answer be 'Hello' + "}, + cx, + ) + }) + .collect() + .await; + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.messages().last().unwrap().to_markdown(), + indoc! {" + ## assistant + Think + Hello + "} + ) + }); + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +#[ignore = "temporarily disabled until it can be run on CI"] +async fn test_basic_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test a tool call that's likely to complete *before* streaming stops. + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send( + model.clone(), + "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", + cx, + ) + }) + .collect() + .await; + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); + + // Test a tool calls that's likely to complete *after* streaming stops. + let events = thread + .update(cx, |thread, cx| { + thread.remove_tool(&AgentTool::name(&EchoTool)); + thread.add_tool(DelayTool); + thread.send( + model.clone(), + "Now call the delay tool with 200ms. 
When the timer goes off, then you echo the output of the tool.", + cx, + ) + }) + .collect() + .await; + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); + thread.update(cx, |thread, _cx| { + assert!(thread + .messages() + .last() + .unwrap() + .content + .iter() + .any(|content| { + if let MessageContent::Text(text) = content { + text.contains("Ding") + } else { + false + } + })); + }); +} + +#[gpui::test] +#[ignore = "temporarily disabled until it can be run on CI"] +async fn test_streaming_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test a tool call that's likely to complete *before* streaming stops. + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(WordListTool); + thread.send(model.clone(), "Test the word_list tool.", cx) + }); + + let mut saw_partial_tool_use = false; + while let Some(event) = events.next().await { + if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { + thread.update(cx, |thread, _cx| { + // Look for a tool use in the thread's last message + let last_content = thread.messages().last().unwrap().content.last().unwrap(); + if let MessageContent::ToolUse(last_tool_use) = last_content { + assert_eq!(last_tool_use.name.as_ref(), "word_list"); + if tool_call.status == acp::ToolCallStatus::Pending { + if !last_tool_use.is_input_complete + && last_tool_use.input.get("g").is_none() + { + saw_partial_tool_use = true; + } + } else { + last_tool_use + .input + .get("a") + .expect("'a' has streamed because input is now complete"); + last_tool_use + .input + .get("g") + .expect("'g' has streamed because input is now complete"); + } + } else { + panic!("last content should be a tool use"); + } + }); + } + } + + assert!( + saw_partial_tool_use, + "should see at least one partially streamed tool use in the history" + ); +} + +#[gpui::test] +#[ignore = "temporarily disabled until it can be run on CI"] +async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test concurrent tool calls with different delay times + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(DelayTool); + thread.send( + model.clone(), + "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", + cx, + ) + }) + .collect() + .await; + + let stop_reasons = stop_events(events); + assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]); + + thread.update(cx, |thread, _cx| { + let last_message = thread.messages().last().unwrap(); + let text = last_message + .content + .iter() + .filter_map(|content| { + if let MessageContent::Text(text) = content { + Some(text.as_str()) + } else { + None + } + }) + .collect::(); + + assert!(text.contains("Ding")); + }); +} + +#[gpui::test] +#[ignore = "temporarily disabled until it can be run on CI"] +async fn test_cancellation(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(InfiniteTool); + thread.add_tool(EchoTool); + thread.send( + model.clone(), + "Call the echo tool and then call the infinite tool, then explain their output", + cx, + ) + }); + + // Wait until both tools are called. 
+ let mut expected_tool_calls = vec!["echo", "infinite"]; + let mut echo_id = None; + let mut echo_completed = false; + while let Some(event) = events.next().await { + match event.unwrap() { + AgentResponseEvent::ToolCall(tool_call) => { + assert_eq!(tool_call.title, expected_tool_calls.remove(0)); + if tool_call.title == "echo" { + echo_id = Some(tool_call.id); + } + } + AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate { + id, + fields: + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + .. + }, + }) if Some(&id) == echo_id.as_ref() => { + echo_completed = true; + } + _ => {} + } + + if expected_tool_calls.is_empty() && echo_completed { + break; + } + } + + // Cancel the current send and ensure that the event stream is closed, even + // if one of the tools is still running. + thread.update(cx, |thread, _cx| thread.cancel()); + events.collect::>().await; + + // Ensure we can still send a new message after cancellation. + let events = thread + .update(cx, |thread, cx| { + thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx) + }) + .collect::>() + .await; + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.messages().last().unwrap().content, + vec![MessageContent::Text("Hello".to_string())] + ); + }); + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +async fn test_refusal(cx: &mut TestAppContext) { + let fake_model = Arc::new(FakeLanguageModel::default()); + let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await; + + let events = thread.update(cx, |thread, cx| { + thread.send(fake_model.clone(), "Hello", cx) + }); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## user + Hello + "} + ); + }); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## user + Hello + ## assistant + Hey! + "} + ); + }); + + // If the model refuses to continue, the thread should remove all the messages after the last user message. 
+ fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal)); + let events = events.collect::>().await; + assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]); + thread.read_with(cx, |thread, _| { + assert_eq!(thread.to_markdown(), ""); + }); +} + +#[gpui::test] +async fn test_agent_connection(cx: &mut TestAppContext) { + cx.update(settings::init); + let templates = Templates::new(); + + // Initialize language model system with test provider + cx.update(|cx| { + gpui_tokio::init(cx); + client::init_settings(cx); + + let http_client = FakeHttpClient::with_404_response(); + let clock = Arc::new(clock::FakeSystemClock::new()); + let client = Client::new(clock, http_client, cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + Project::init_settings(cx); + LanguageModelRegistry::test(cx); + }); + cx.executor().forbid_parking(); + + // Create agent and connection + let agent = cx.new(|_| NativeAgent::new(templates.clone())); + let connection = NativeAgentConnection(agent.clone()); + + // Test model_selector returns Some + let selector_opt = connection.model_selector(); + assert!( + selector_opt.is_some(), + "agent2 should always support ModelSelector" + ); + let selector = selector_opt.unwrap(); + + // Test list_models + let listed_models = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.list_models(&mut async_cx) + }) + .await + .expect("list_models should succeed"); + assert!(!listed_models.is_empty(), "should have at least one model"); + assert_eq!(listed_models[0].id().0, "fake"); + + // Create a project for new_thread + let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); + let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + + // Create a thread using new_thread + let cwd = Path::new("/test"); + let connection_rc = Rc::new(connection.clone()); + let acp_thread = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + connection_rc.new_thread(project, cwd, &mut async_cx) + }) + .await + .expect("new_thread should succeed"); + + // Get the session_id from the AcpThread + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + + // Test selected_model returns the default + let model = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await + .expect("selected_model should succeed"); + let model = model.as_fake(); + assert_eq!(model.id().0, "fake", "should return default model"); + + let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("def"); + cx.run_until_parked(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + abc + + ## Assistant + + def + + "} + ) + }); + + // Test cancel + cx.update(|cx| connection.cancel(&session_id, cx)); + request.await.expect("prompt should fail gracefully"); + + // Ensure that dropping the ACP thread causes the native thread to be + // dropped as well. 
+ cx.update(|_| drop(acp_thread)); + let result = cx + .update(|cx| { + connection.prompt( + acp::PromptRequest { + session_id: session_id.clone(), + prompt: vec!["ghi".into()], + }, + cx, + ) + }) + .await; + assert_eq!( + result.as_ref().unwrap_err().to_string(), + "Session not found", + "unexpected result: {:?}", + result + ); +} + +/// Filters out the stop events for asserting against in tests +fn stop_events( + result_events: Vec>, +) -> Vec { + result_events + .into_iter() + .filter_map(|event| match event.unwrap() { + AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), + _ => None, + }) + .collect() +} + +struct ThreadTest { + model: Arc, + thread: Entity, +} + +enum TestModel { + Sonnet4, + Sonnet4Thinking, + Fake(Arc), +} + +impl TestModel { + fn id(&self) -> LanguageModelId { + match self { + TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()), + TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()), + TestModel::Fake(fake_model) => fake_model.id(), + } + } +} + +async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { + cx.executor().allow_parking(); + cx.update(|cx| { + settings::init(cx); + Project::init_settings(cx); + }); + let templates = Templates::new(); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + + let model = cx + .update(|cx| { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + if let TestModel::Fake(model) = model { + Task::ready(model as Arc<_>) + } else { + let model_id = model.id(); + let models = LanguageModelRegistry::read_global(cx); + let model = models + .available_models(cx) + .find(|model| model.id() == model_id) + .unwrap(); + + let provider = models.provider(&model.provider_id()).unwrap(); + let authenticated = provider.authenticate(cx); + + cx.spawn(async move |_cx| { + authenticated.await.unwrap(); + model + }) + } + }) + .await; + + let thread = cx.new(|_| Thread::new(project, templates, model.clone())); + + ThreadTest { model, thread } +} + +#[cfg(test)] +#[ctor::ctor] +fn init_logger() { + if std::env::var("RUST_LOG").is_ok() { + env_logger::init(); + } +} diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs new file mode 100644 index 0000000000..1847a14fee --- /dev/null +++ b/crates/agent2/src/tests/test_tools.rs @@ -0,0 +1,106 @@ +use super::*; +use anyhow::Result; +use gpui::{App, SharedString, Task}; +use std::future; + +/// A tool that echoes its input +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct EchoToolInput { + /// The text to echo. + text: String, +} + +pub struct EchoTool; + +impl AgentTool for EchoTool { + type Input = EchoToolInput; + + fn name(&self) -> SharedString { + "echo".into() + } + + fn run(self: Arc, input: Self::Input, _cx: &mut App) -> Task> { + Task::ready(Ok(input.text)) + } +} + +/// A tool that waits for a specified delay +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct DelayToolInput { + /// The delay in milliseconds. 
+ ms: u64,
+}
+
+pub struct DelayTool;
+
+impl AgentTool for DelayTool {
+ type Input = DelayToolInput;
+
+ fn name(&self) -> SharedString {
+ "delay".into()
+ }
+
+ fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
+ where
+ Self: Sized,
+ {
+ cx.foreground_executor().spawn(async move {
+ smol::Timer::after(Duration::from_millis(input.ms)).await;
+ Ok("Ding".to_string())
+ })
+ }
+}
+
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct InfiniteToolInput {}
+
+pub struct InfiniteTool;
+
+impl AgentTool for InfiniteTool {
+ type Input = InfiniteToolInput;
+
+ fn name(&self) -> SharedString {
+ "infinite".into()
+ }
+
+ fn run(self: Arc<Self>, _input: Self::Input, cx: &mut App) -> Task<Result<String>> {
+ cx.foreground_executor().spawn(async move {
+ future::pending::<()>().await;
+ unreachable!()
+ })
+ }
+}
+
+/// A tool that takes an object with a map from letters to random words starting with that letter.
+/// All fields are required! Pass a word for every letter!
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct WordListInput {
+ /// Provide a random word that starts with A.
+ a: Option<String>,
+ /// Provide a random word that starts with B.
+ b: Option<String>,
+ /// Provide a random word that starts with C.
+ c: Option<String>,
+ /// Provide a random word that starts with D.
+ d: Option<String>,
+ /// Provide a random word that starts with E.
+ e: Option<String>,
+ /// Provide a random word that starts with F.
+ f: Option<String>,
+ /// Provide a random word that starts with G.
+ g: Option<String>,
+}
+
+pub struct WordListTool;
+
+impl AgentTool for WordListTool {
+ type Input = WordListInput;
+
+ fn name(&self) -> SharedString {
+ "word_list".into()
+ }
+
+ fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
+ Task::ready(Ok("ok".to_string()))
+ }
+}
diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs
new file mode 100644
index 0000000000..af3aa17ea8
--- /dev/null
+++ b/crates/agent2/src/thread.rs
@@ -0,0 +1,754 @@
+use crate::{prompts::BasePrompt, templates::Templates};
+use agent_client_protocol as acp;
+use anyhow::{anyhow, Result};
+use cloud_llm_client::{CompletionIntent, CompletionMode};
+use collections::HashMap;
+use futures::{channel::mpsc, stream::FuturesUnordered};
+use gpui::{App, Context, Entity, ImageFormat, SharedString, Task};
+use language_model::{
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
+ LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
+ LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
+ LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
+};
+use log;
+use project::Project;
+use schemars::{JsonSchema, Schema};
+use serde::Deserialize;
+use smol::stream::StreamExt;
+use std::{collections::BTreeMap, fmt::Write, sync::Arc};
+use util::{markdown::MarkdownCodeBlock, ResultExt};
+
+#[derive(Debug, Clone)]
+pub struct AgentMessage {
+ pub role: Role,
+ pub content: Vec<MessageContent>,
+}
+
+impl AgentMessage {
+ pub fn to_markdown(&self) -> String {
+ let mut markdown = format!("## {}\n", self.role);
+
+ for content in &self.content {
+ match content {
+ MessageContent::Text(text) => {
+ markdown.push_str(text);
+ markdown.push('\n');
+ }
+ MessageContent::Thinking { text, ..
} => { + markdown.push_str(""); + markdown.push_str(text); + markdown.push_str("\n"); + } + MessageContent::RedactedThinking(_) => markdown.push_str("\n"), + MessageContent::Image(_) => { + markdown.push_str("\n"); + } + MessageContent::ToolUse(tool_use) => { + markdown.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + markdown.push_str(&format!( + "{}\n", + MarkdownCodeBlock { + tag: "json", + text: &format!("{:#}", tool_use.input) + } + )); + } + MessageContent::ToolResult(tool_result) => { + markdown.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + markdown.push_str("**ERROR:**\n"); + } + + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + writeln!(markdown, "{text}\n").ok(); + } + LanguageModelToolResultContent::Image(_) => { + writeln!(markdown, "\n").ok(); + } + } + + if let Some(output) = tool_result.output.as_ref() { + writeln!( + markdown, + "**Debug Output**:\n\n```json\n{}\n```\n", + serde_json::to_string_pretty(output).unwrap() + ) + .unwrap(); + } + } + } + } + + markdown + } +} + +#[derive(Debug)] +pub enum AgentResponseEvent { + Text(String), + Thinking(String), + ToolCall(acp::ToolCall), + ToolCallUpdate(acp::ToolCallUpdate), + Stop(acp::StopReason), +} + +pub trait Prompt { + fn render(&self, prompts: &Templates, cx: &App) -> Result; +} + +pub struct Thread { + messages: Vec, + completion_mode: CompletionMode, + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + running_turn: Option>, + pending_tool_uses: HashMap, + system_prompts: Vec>, + tools: BTreeMap>, + templates: Arc, + pub selected_model: Arc, + // action_log: Entity, +} + +impl Thread { + pub fn new( + project: Entity, + templates: Arc, + default_model: Arc, + ) -> Self { + Self { + messages: Vec::new(), + completion_mode: CompletionMode::Normal, + system_prompts: vec![Arc::new(BasePrompt::new(project))], + running_turn: None, + pending_tool_uses: HashMap::default(), + tools: BTreeMap::default(), + templates, + selected_model: default_model, + } + } + + pub fn set_mode(&mut self, mode: CompletionMode) { + self.completion_mode = mode; + } + + pub fn messages(&self) -> &[AgentMessage] { + &self.messages + } + + pub fn add_tool(&mut self, tool: impl AgentTool) { + self.tools.insert(tool.name(), tool.erase()); + } + + pub fn remove_tool(&mut self, name: &str) -> bool { + self.tools.remove(name).is_some() + } + + pub fn cancel(&mut self) { + self.running_turn.take(); + + let tool_results = self + .pending_tool_uses + .drain() + .map(|(tool_use_id, tool_use)| { + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id, + tool_name: tool_use.name.clone(), + is_error: true, + content: LanguageModelToolResultContent::Text("Tool canceled by user".into()), + output: None, + }) + }) + .collect::>(); + self.last_user_message().content.extend(tool_results); + } + + /// Sending a message results in the model streaming a response, which could include tool calls. + /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. + /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. 
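+ ///
+ /// Illustrative usage sketch (not part of the original docs; `thread`, `model`, and `cx` are
+ /// assumed to be set up as in the tests above):
+ ///
+ /// ```ignore
+ /// let mut events = thread.update(cx, |thread, cx| {
+ ///     thread.send(model.clone(), "Count to three", cx)
+ /// });
+ /// while let Some(event) = events.next().await {
+ ///     match event.unwrap() {
+ ///         AgentResponseEvent::Text(chunk) => print!("{chunk}"),
+ ///         AgentResponseEvent::Stop(reason) => println!("turn ended: {reason:?}"),
+ ///         _ => {}
+ ///     }
+ /// }
+ /// ```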
+ pub fn send( + &mut self, + model: Arc, + content: impl Into, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let content = content.into(); + log::info!("Thread::send called with model: {:?}", model.name()); + log::debug!("Thread::send content: {:?}", content); + + cx.notify(); + let (events_tx, events_rx) = + mpsc::unbounded::>(); + + let user_message_ix = self.messages.len(); + self.messages.push(AgentMessage { + role: Role::User, + content: vec![content], + }); + log::info!("Total messages in thread: {}", self.messages.len()); + self.running_turn = Some(cx.spawn(async move |thread, cx| { + log::info!("Starting agent turn execution"); + let turn_result = async { + // Perform one request, then keep looping if the model makes tool calls. + let mut completion_intent = CompletionIntent::UserPrompt; + 'outer: loop { + log::debug!( + "Building completion request with intent: {:?}", + completion_intent + ); + let request = thread.update(cx, |thread, cx| { + thread.build_completion_request(completion_intent, cx) + })?; + + // println!( + // "request: {}", + // serde_json::to_string_pretty(&request).unwrap() + // ); + + // Stream events, appending to messages and collecting up tool uses. + log::info!("Calling model.stream_completion"); + let mut events = model.stream_completion(request, cx).await?; + log::debug!("Stream completion started successfully"); + let mut tool_uses = FuturesUnordered::new(); + while let Some(event) = events.next().await { + match event { + Ok(LanguageModelCompletionEvent::Stop(reason)) => { + if let Some(reason) = to_acp_stop_reason(reason) { + events_tx + .unbounded_send(Ok(AgentResponseEvent::Stop(reason))) + .ok(); + } + + if reason == StopReason::Refusal { + thread.update(cx, |thread, _cx| { + thread.messages.truncate(user_message_ix); + })?; + break 'outer; + } + } + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + thread + .update(cx, |thread, cx| { + tool_uses.extend(thread.handle_streamed_completion_event( + event, &events_tx, cx, + )); + }) + .ok(); + } + Err(error) => { + log::error!("Error in completion stream: {:?}", error); + events_tx.unbounded_send(Err(error)).ok(); + break; + } + } + } + + // If there are no tool uses, the turn is done. + if tool_uses.is_empty() { + log::info!("No tool uses found, completing turn"); + break; + } + log::info!("Found {} tool uses to execute", tool_uses.len()); + + // As tool results trickle in, insert them in the last user + // message so that they can be sent on the next tick of the + // agentic loop. 
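+ // Rough shape after this loop (illustrative): the trailing user message ends with one
+ // MessageContent::ToolResult per completed tool use, which the next completion request
+ // sends back to the model.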
+ while let Some(tool_result) = tool_uses.next().await { + log::info!("Tool finished {:?}", tool_result); + + events_tx + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + to_acp_tool_call_update(&tool_result), + ))) + .ok(); + thread + .update(cx, |thread, _cx| { + thread.pending_tool_uses.remove(&tool_result.tool_use_id); + thread + .last_user_message() + .content + .push(MessageContent::ToolResult(tool_result)); + }) + .ok(); + } + + completion_intent = CompletionIntent::ToolResults; + } + + Ok(()) + } + .await; + + if let Err(error) = turn_result { + log::error!("Turn execution failed: {:?}", error); + events_tx.unbounded_send(Err(error)).ok(); + } else { + log::info!("Turn execution completed successfully"); + } + })); + events_rx + } + + pub fn build_system_message(&self, cx: &App) -> Option { + log::debug!("Building system message"); + let mut system_message = AgentMessage { + role: Role::System, + content: Vec::new(), + }; + + for prompt in &self.system_prompts { + if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() { + system_message + .content + .push(MessageContent::Text(rendered_prompt)); + } + } + + let result = (!system_message.content.is_empty()).then_some(system_message); + log::debug!("System message built: {}", result.is_some()); + result + } + + /// A helper method that's called on every streamed completion event. + /// Returns an optional tool result task, which the main agentic loop in + /// send will send back to the model when it resolves. + fn handle_streamed_completion_event( + &mut self, + event: LanguageModelCompletionEvent, + events_tx: &mpsc::UnboundedSender>, + cx: &mut Context, + ) -> Option> { + log::trace!("Handling streamed completion event: {:?}", event); + use LanguageModelCompletionEvent::*; + + match event { + StartMessage { .. 
} => { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + Text(new_text) => self.handle_text_event(new_text, events_tx, cx), + Thinking { text, signature } => { + self.handle_thinking_event(text, signature, events_tx, cx) + } + RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx), + ToolUse(tool_use) => { + return self.handle_tool_use_event(tool_use, events_tx, cx); + } + ToolUseJsonParseError { + id, + tool_name, + raw_input, + json_parse_error, + } => { + return Some(Task::ready(self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + ))); + } + UsageUpdate(_) | StatusUpdate(_) => {} + Stop(_) => unreachable!(), + } + + None + } + + fn handle_text_event( + &mut self, + new_text: String, + events_tx: &mpsc::UnboundedSender>, + cx: &mut Context, + ) { + events_tx + .unbounded_send(Ok(AgentResponseEvent::Text(new_text.clone()))) + .ok(); + + let last_message = self.last_assistant_message(); + if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + text.push_str(&new_text); + } else { + last_message.content.push(MessageContent::Text(new_text)); + } + + cx.notify(); + } + + fn handle_thinking_event( + &mut self, + new_text: String, + new_signature: Option, + events_tx: &mpsc::UnboundedSender>, + cx: &mut Context, + ) { + events_tx + .unbounded_send(Ok(AgentResponseEvent::Thinking(new_text.clone()))) + .ok(); + + let last_message = self.last_assistant_message(); + if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut() + { + text.push_str(&new_text); + *signature = new_signature.or(signature.take()); + } else { + last_message.content.push(MessageContent::Thinking { + text: new_text, + signature: new_signature, + }); + } + + cx.notify(); + } + + fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { + let last_message = self.last_assistant_message(); + last_message + .content + .push(MessageContent::RedactedThinking(data)); + cx.notify(); + } + + fn handle_tool_use_event( + &mut self, + tool_use: LanguageModelToolUse, + events_tx: &mpsc::UnboundedSender>, + cx: &mut Context, + ) -> Option> { + cx.notify(); + + self.pending_tool_uses + .insert(tool_use.id.clone(), tool_use.clone()); + let last_message = self.last_assistant_message(); + + // Ensure the last message ends in the current tool use + let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { + if let MessageContent::ToolUse(last_tool_use) = content { + if last_tool_use.id == tool_use.id { + *last_tool_use = tool_use.clone(); + false + } else { + true + } + } else { + true + } + }); + if push_new_tool_use { + events_tx + .unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(tool_use.input.clone()), + }))) + .ok(); + last_message + .content + .push(MessageContent::ToolUse(tool_use.clone())); + } else { + events_tx + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use.id.to_string().into()), + fields: acp::ToolCallUpdateFields { + raw_input: Some(tool_use.input.clone()), + ..Default::default() + }, + }, + ))) + .ok(); + } + + if !tool_use.is_input_complete { + return None; + } + + if let Some(tool) = self.tools.get(tool_use.name.as_ref()) { + 
events_tx + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use.id.to_string().into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress), + ..Default::default() + }, + }, + ))) + .ok(); + + let pending_tool_result = tool.clone().run(tool_use.input, cx); + + Some(cx.foreground_executor().spawn(async move { + match pending_tool_result.await { + Ok(tool_output) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: false, + content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), + output: None, + }, + Err(error) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), + output: None, + }, + } + })) + } else { + let content = format!("No tool named {} exists", tool_use.name); + Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(content)), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })) + } + } + + fn handle_tool_use_json_parse_error_event( + &mut self, + tool_use_id: LanguageModelToolUseId, + tool_name: Arc, + raw_input: Arc, + json_parse_error: String, + ) -> LanguageModelToolResult { + let tool_output = format!("Error parsing input JSON: {json_parse_error}"); + LanguageModelToolResult { + tool_use_id, + tool_name, + is_error: true, + content: LanguageModelToolResultContent::Text(tool_output.into()), + output: Some(serde_json::Value::String(raw_input.to_string())), + } + } + + /// Guarantees the last message is from the assistant and returns a mutable reference. + fn last_assistant_message(&mut self) -> &mut AgentMessage { + if self + .messages + .last() + .map_or(true, |m| m.role != Role::Assistant) + { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + self.messages.last_mut().unwrap() + } + + /// Guarantees the last message is from the user and returns a mutable reference. 
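+ /// Tool results are appended to this message so they are included in the next completion request.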
+ fn last_user_message(&mut self) -> &mut AgentMessage { + if self.messages.last().map_or(true, |m| m.role != Role::User) { + self.messages.push(AgentMessage { + role: Role::User, + content: Vec::new(), + }); + } + self.messages.last_mut().unwrap() + } + + fn build_completion_request( + &self, + completion_intent: CompletionIntent, + cx: &mut App, + ) -> LanguageModelRequest { + log::debug!("Building completion request"); + log::debug!("Completion intent: {:?}", completion_intent); + log::debug!("Completion mode: {:?}", self.completion_mode); + + let messages = self.build_request_messages(cx); + log::info!("Request will include {} messages", messages.len()); + + let tools: Vec = self + .tools + .values() + .filter_map(|tool| { + let tool_name = tool.name().to_string(); + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name, + description: tool.description(cx).to_string(), + input_schema: tool + .input_schema(LanguageModelToolSchemaFormat::JsonSchema) + .log_err()?, + }) + }) + .collect(); + + log::info!("Request includes {} tools", tools.len()); + + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: Some(completion_intent), + mode: Some(self.completion_mode), + messages, + tools, + tool_choice: None, + stop: Vec::new(), + temperature: None, + thinking_allowed: true, + }; + + log::debug!("Completion request built successfully"); + request + } + + fn build_request_messages(&self, cx: &App) -> Vec { + log::trace!( + "Building request messages from {} thread messages", + self.messages.len() + ); + + let messages = self + .build_system_message(cx) + .iter() + .chain(self.messages.iter()) + .map(|message| { + log::trace!( + " - {} message with {} content items", + match message.role { + Role::System => "System", + Role::User => "User", + Role::Assistant => "Assistant", + }, + message.content.len() + ); + LanguageModelRequestMessage { + role: message.role, + content: message.content.clone(), + cache: false, + } + }) + .collect(); + messages + } + + pub fn to_markdown(&self) -> String { + let mut markdown = String::new(); + for message in &self.messages { + markdown.push_str(&message.to_markdown()); + } + markdown + } +} + +pub trait AgentTool +where + Self: 'static + Sized, +{ + type Input: for<'de> Deserialize<'de> + JsonSchema; + + fn name(&self) -> SharedString; + fn description(&self, _cx: &mut App) -> SharedString { + let schema = schemars::schema_for!(Self::Input); + SharedString::new( + schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap_or_default(), + ) + } + + /// Returns the JSON schema that describes the tool's input. + fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema { + schemars::schema_for!(Self::Input) + } + + /// Runs the tool with the provided input. 
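+ ///
+ /// Minimal implementor sketch (modeled on `EchoTool` from the test suite; the tool name and
+ /// behavior here are illustrative, not an existing tool):
+ ///
+ /// ```ignore
+ /// #[derive(JsonSchema, Serialize, Deserialize)]
+ /// struct ShoutToolInput {
+ ///     /// The text to shout back in upper case.
+ ///     text: String,
+ /// }
+ ///
+ /// struct ShoutTool;
+ ///
+ /// impl AgentTool for ShoutTool {
+ ///     type Input = ShoutToolInput;
+ ///
+ ///     fn name(&self) -> SharedString {
+ ///         "shout".into()
+ ///     }
+ ///
+ ///     fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
+ ///         Task::ready(Ok(input.text.to_uppercase()))
+ ///     }
+ /// }
+ /// ```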
+ fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task>; + + fn erase(self) -> Arc { + Arc::new(Erased(Arc::new(self))) + } +} + +pub struct Erased(T); + +pub trait AnyAgentTool { + fn name(&self) -> SharedString; + fn description(&self, cx: &mut App) -> SharedString; + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task>; +} + +impl AnyAgentTool for Erased> +where + T: AgentTool, +{ + fn name(&self) -> SharedString { + self.0.name() + } + + fn description(&self, cx: &mut App) -> SharedString { + self.0.description(cx) + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + Ok(serde_json::to_value(self.0.input_schema(format))?) + } + + fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task> { + let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); + match parsed_input { + Ok(input) => self.0.clone().run(input, cx), + Err(error) => Task::ready(Err(anyhow!(error))), + } + } +} + +fn to_acp_stop_reason(reason: StopReason) -> Option { + match reason { + StopReason::EndTurn => Some(acp::StopReason::EndTurn), + StopReason::MaxTokens => Some(acp::StopReason::MaxTokens), + StopReason::Refusal => Some(acp::StopReason::Refusal), + StopReason::ToolUse => None, + } +} + +fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCallUpdate { + let status = if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }; + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => text.to_string().into(), + LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => { + acp::ToolCallContent::Content { + content: acp::ContentBlock::Image(acp::ImageContent { + annotations: None, + data: source.to_string(), + mime_type: ImageFormat::Png.mime_type().to_string(), + }), + } + } + }; + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()), + fields: acp::ToolCallUpdateFields { + status: Some(status), + content: Some(vec![content]), + ..Default::default() + }, + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs new file mode 100644 index 0000000000..cf3162abfa --- /dev/null +++ b/crates/agent2/src/tools.rs @@ -0,0 +1 @@ +mod glob; diff --git a/crates/agent2/src/tools/glob.rs b/crates/agent2/src/tools/glob.rs new file mode 100644 index 0000000000..9434311aaf --- /dev/null +++ b/crates/agent2/src/tools/glob.rs @@ -0,0 +1,76 @@ +use anyhow::{anyhow, Result}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::Deserialize; +use std::{path::PathBuf, sync::Arc}; +use util::paths::PathMatcher; +use worktree::Snapshot as WorktreeSnapshot; + +use crate::{ + templates::{GlobTemplate, Template, Templates}, + thread::AgentTool, +}; + +// Description is dynamic, see `fn description` below +#[derive(Deserialize, JsonSchema)] +struct GlobInput { + /// A POSIX glob pattern + glob: SharedString, +} + +struct GlobTool { + project: Entity, + templates: Arc, +} + +impl AgentTool for GlobTool { + type Input = GlobInput; + + fn name(&self) -> SharedString { + "glob".into() + } + + fn description(&self, cx: &mut App) -> SharedString { + let project_roots = self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).root_name().into()) + .collect::>() + .join("\n"); + + GlobTemplate { project_roots } + .render(&self.templates) + 
.expect("template failed to render") + .into() + } + + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> { + let path_matcher = match PathMatcher::new([&input.glob]) { + Ok(matcher) => matcher, + Err(error) => return Task::ready(Err(anyhow!(error))), + }; + + let snapshots: Vec = self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect(); + + cx.background_spawn(async move { + let paths = snapshots.iter().flat_map(|snapshot| { + let root_name = PathBuf::from(snapshot.root_name()); + snapshot + .entries(false, 0) + .map(move |entry| root_name.join(&entry.path)) + .filter(|path| path_matcher.is_match(&path)) + }); + let output = paths + .map(|path| format!("{}\n", path.display())) + .collect::(); + Ok(output) + }) + } +} diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 4714245b94..81c97c8aa6 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -18,16 +18,19 @@ doctest = false [dependencies] acp_thread.workspace = true +agent-client-protocol.workspace = true agentic-coding-protocol.workspace = true anyhow.workspace = true collections.workspace = true context_server.workspace = true futures.workspace = true gpui.workspace = true +indoc.workspace = true itertools.workspace = true log.workspace = true paths.workspace = true project.workspace = true +rand.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true @@ -35,6 +38,7 @@ settings.workspace = true smol.workspace = true strum.workspace = true tempfile.workspace = true +thiserror.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs new file mode 100644 index 0000000000..00e3e3df50 --- /dev/null +++ b/crates/agent_servers/src/acp.rs @@ -0,0 +1,34 @@ +use std::{path::Path, rc::Rc}; + +use crate::AgentServerCommand; +use acp_thread::AgentConnection; +use anyhow::Result; +use gpui::AsyncApp; +use thiserror::Error; + +mod v0; +mod v1; + +#[derive(Debug, Error)] +#[error("Unsupported version")] +pub struct UnsupportedVersion; + +pub async fn connect( + server_name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, +) -> Result> { + let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await; + + match conn { + Ok(conn) => Ok(Rc::new(conn) as _), + Err(err) if err.is::() => { + // Consider re-using initialize response and subprocess when adding another version here + let conn: Rc = + Rc::new(v0::AcpConnection::stdio(server_name, command, &root_dir, cx).await?); + Ok(conn) + } + Err(err) => Err(err), + } +} diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs new file mode 100644 index 0000000000..c0b64fcc41 --- /dev/null +++ b/crates/agent_servers/src/acp/v0.rs @@ -0,0 +1,508 @@ +// Translates old acp agents into the new schema +use agent_client_protocol as acp; +use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; +use anyhow::{Context as _, Result, anyhow}; +use futures::channel::oneshot; +use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use project::Project; +use std::{cell::RefCell, path::Path, rc::Rc}; +use ui::App; +use util::ResultExt as _; + +use crate::AgentServerCommand; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; + +#[derive(Clone)] +struct OldAcpClientDelegate { + thread: Rc>>, + cx: AsyncApp, + next_tool_call_id: Rc>, + // 
sent_buffer_versions: HashMap, HashMap>, +} + +impl OldAcpClientDelegate { + fn new(thread: Rc>>, cx: AsyncApp) -> Self { + Self { + thread, + cx, + next_tool_call_id: Rc::new(RefCell::new(0)), + } + } +} + +impl acp_old::Client for OldAcpClientDelegate { + async fn stream_assistant_message_chunk( + &self, + params: acp_old::StreamAssistantMessageChunkParams, + ) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread + .borrow() + .update(cx, |thread, cx| match params.chunk { + acp_old::AssistantMessageChunk::Text { text } => { + thread.push_assistant_content_block(text.into(), false, cx) + } + acp_old::AssistantMessageChunk::Thought { thought } => { + thread.push_assistant_content_block(thought.into(), true, cx) + } + }) + .log_err(); + })?; + + Ok(()) + } + + async fn request_tool_call_confirmation( + &self, + request: acp_old::RequestToolCallConfirmationParams, + ) -> Result { + let cx = &mut self.cx.clone(); + + let old_acp_id = *self.next_tool_call_id.borrow() + 1; + self.next_tool_call_id.replace(old_acp_id); + + let tool_call = into_new_tool_call( + acp::ToolCallId(old_acp_id.to_string().into()), + request.tool_call, + ); + + let mut options = match request.confirmation { + acp_old::ToolCallConfirmation::Edit { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow Edits".to_string(), + )], + acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", root_command), + )], + acp_old::ToolCallConfirmation::Mcp { + server_name, + tool_name, + .. + } => vec![ + ( + acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", server_name), + ), + ( + acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", tool_name), + ), + ], + acp_old::ToolCallConfirmation::Fetch { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow".to_string(), + )], + acp_old::ToolCallConfirmation::Other { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow".to_string(), + )], + }; + + options.extend([ + ( + acp_old::ToolCallConfirmationOutcome::Allow, + acp::PermissionOptionKind::AllowOnce, + "Allow".to_string(), + ), + ( + acp_old::ToolCallConfirmationOutcome::Reject, + acp::PermissionOptionKind::RejectOnce, + "Reject".to_string(), + ), + ]); + + let mut outcomes = Vec::with_capacity(options.len()); + let mut acp_options = Vec::with_capacity(options.len()); + + for (index, (outcome, kind, label)) in options.into_iter().enumerate() { + outcomes.push(outcome); + acp_options.push(acp::PermissionOption { + id: acp::PermissionOptionId(index.to_string().into()), + name: label, + kind, + }) + } + + let response = cx + .update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.request_tool_call_permission(tool_call, acp_options, cx) + }) + })? + .context("Failed to update thread")? 
+ .await; + + let outcome = match response { + Ok(option_id) => outcomes[option_id.0.parse::().unwrap_or(0)], + Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel, + }; + + Ok(acp_old::RequestToolCallConfirmationResponse { + id: acp_old::ToolCallId(old_acp_id), + outcome: outcome, + }) + } + + async fn push_tool_call( + &self, + request: acp_old::PushToolCallParams, + ) -> Result { + let cx = &mut self.cx.clone(); + + let old_acp_id = *self.next_tool_call_id.borrow() + 1; + self.next_tool_call_id.replace(old_acp_id); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.upsert_tool_call( + into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request), + cx, + ) + }) + })? + .context("Failed to update thread")?; + + Ok(acp_old::PushToolCallResponse { + id: acp_old::ToolCallId(old_acp_id), + }) + } + + async fn update_tool_call( + &self, + request: acp_old::UpdateToolCallParams, + ) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallUpdate { + id: acp::ToolCallId(request.tool_call_id.0.to_string().into()), + fields: acp::ToolCallUpdateFields { + status: Some(into_new_tool_call_status(request.status)), + content: Some( + request + .content + .into_iter() + .map(into_new_tool_call_content) + .collect::>(), + ), + ..Default::default() + }, + }, + cx, + ) + }) + })? + .context("Failed to update thread")??; + + Ok(()) + } + + async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.update_plan( + acp::Plan { + entries: request + .entries + .into_iter() + .map(into_new_plan_entry) + .collect(), + }, + cx, + ) + }) + })? + .context("Failed to update thread")?; + + Ok(()) + } + + async fn read_text_file( + &self, + acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams, + ) -> Result { + let content = self + .cx + .update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.read_text_file(path, line, limit, false, cx) + }) + })? + .context("Failed to update thread")? + .await?; + Ok(acp_old::ReadTextFileResponse { content }) + } + + async fn write_text_file( + &self, + acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams, + ) -> Result<(), acp_old::Error> { + self.cx + .update(|cx| { + self.thread + .borrow() + .update(cx, |thread, cx| thread.write_text_file(path, content, cx)) + })? + .context("Failed to update thread")? 
+ .await?; + + Ok(()) + } +} + +fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall { + acp::ToolCall { + id: id, + title: request.label, + kind: acp_kind_from_old_icon(request.icon), + status: acp::ToolCallStatus::InProgress, + content: request + .content + .into_iter() + .map(into_new_tool_call_content) + .collect(), + locations: request + .locations + .into_iter() + .map(into_new_tool_call_location) + .collect(), + raw_input: None, + } +} + +fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind { + match icon { + acp_old::Icon::FileSearch => acp::ToolKind::Search, + acp_old::Icon::Folder => acp::ToolKind::Search, + acp_old::Icon::Globe => acp::ToolKind::Search, + acp_old::Icon::Hammer => acp::ToolKind::Other, + acp_old::Icon::LightBulb => acp::ToolKind::Think, + acp_old::Icon::Pencil => acp::ToolKind::Edit, + acp_old::Icon::Regex => acp::ToolKind::Search, + acp_old::Icon::Terminal => acp::ToolKind::Execute, + } +} + +fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus { + match status { + acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress, + acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed, + acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed, + } +} + +fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent { + match content { + acp_old::ToolCallContent::Markdown { markdown } => markdown.into(), + acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff { + diff: into_new_diff(diff), + }, + } +} + +fn into_new_diff(diff: acp_old::Diff) -> acp::Diff { + acp::Diff { + path: diff.path, + old_text: diff.old_text, + new_text: diff.new_text, + } +} + +fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation { + acp::ToolCallLocation { + path: location.path, + line: location.line, + } +} + +fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry { + acp::PlanEntry { + content: entry.content, + priority: into_new_plan_priority(entry.priority), + status: into_new_plan_status(entry.status), + } +} + +fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority { + match priority { + acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low, + acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium, + acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High, + } +} + +fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus { + match status { + acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending, + acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress, + acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed, + } +} + +pub struct AcpConnection { + pub name: &'static str, + pub connection: acp_old::AgentConnection, + pub _child_status: Task>, + pub current_thread: Rc>>, +} + +impl AcpConnection { + pub fn stdio( + name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, + ) -> Task> { + let root_dir = root_dir.to_path_buf(); + + cx.spawn(async move |cx| { + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = 
child.stdout.take().unwrap(); + log::trace!("Spawned (pid: {})", child.id()); + + let foreground_executor = cx.foreground_executor().clone(); + + let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => Err(anyhow!(result)), + }; + drop(io_task); + result + }); + + Ok(Self { + name, + connection, + _child_status: child_status, + current_thread: thread_rc, + }) + }) + } +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let task = self.connection.request_any( + acp_old::InitializeParams { + protocol_version: acp_old::ProtocolVersion::latest(), + } + .into_any(), + ); + let current_thread = self.current_thread.clone(); + cx.spawn(async move |cx| { + let result = task.await?; + let result = acp_old::InitializeParams::response_from_any(result)?; + + if !result.is_authenticated { + anyhow::bail!(AuthRequired) + } + + cx.update(|cx| { + let thread = cx.new(|cx| { + let session_id = acp::SessionId("acp-old-no-id".into()); + AcpThread::new(self.name, self.clone(), project, session_id, cx) + }); + current_thread.replace(thread.downgrade()); + thread + }) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let task = self + .connection + .request_any(acp_old::AuthenticateParams.into_any()); + cx.foreground_executor().spawn(async move { + task.await?; + Ok(()) + }) + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let chunks = params + .prompt + .into_iter() + .filter_map(|block| match block { + acp::ContentBlock::Text(text) => { + Some(acp_old::UserMessageChunk::Text { text: text.text }) + } + acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path { + path: link.uri.into(), + }), + _ => None, + }) + .collect(); + + let task = self + .connection + .request_any(acp_old::SendUserMessageParams { chunks }.into_any()); + cx.foreground_executor().spawn(async move { + task.await?; + anyhow::Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) { + let task = self + .connection + .request_any(acp_old::CancelSendMessageParams.into_any()); + cx.foreground_executor() + .spawn(async move { + task.await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx) + } +} diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs new file mode 100644 index 0000000000..178796816a --- /dev/null +++ b/crates/agent_servers/src/acp/v1.rs @@ -0,0 +1,282 @@ +use agent_client_protocol::{self as acp, Agent as _}; +use anyhow::anyhow; +use collections::HashMap; +use futures::channel::oneshot; +use project::Project; +use std::cell::RefCell; +use std::path::Path; +use std::rc::Rc; + +use anyhow::{Context as _, Result}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; + +use crate::{AgentServerCommand, acp::UnsupportedVersion}; +use acp_thread::{AcpThread, 
AgentConnection, AuthRequired}; + +pub struct AcpConnection { + server_name: &'static str, + connection: Rc, + sessions: Rc>>, + auth_methods: Vec, + _io_task: Task>, +} + +pub struct AcpSession { + thread: WeakEntity, +} + +const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; + +impl AcpConnection { + pub async fn stdio( + server_name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, + ) -> Result { + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter().map(|arg| arg.as_str())) + .envs(command.env.iter().flatten()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdout = child.stdout.take().expect("Failed to take stdout"); + let stdin = child.stdin.take().expect("Failed to take stdin"); + log::trace!("Spawned (pid: {})", child.id()); + + let sessions = Rc::new(RefCell::new(HashMap::default())); + + let client = ClientDelegate { + sessions: sessions.clone(), + cx: cx.clone(), + }; + let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { + let foreground_executor = cx.foreground_executor().clone(); + move |fut| { + foreground_executor.spawn(fut).detach(); + } + }); + + let io_task = cx.background_spawn(io_task); + + cx.spawn({ + let sessions = sessions.clone(); + async move |cx| { + let status = child.status().await?; + + for session in sessions.borrow().values() { + session + .thread + .update(cx, |thread, cx| thread.emit_server_exited(status, cx)) + .ok(); + } + + anyhow::Ok(()) + } + }) + .detach(); + + let response = connection + .initialize(acp::InitializeRequest { + protocol_version: acp::VERSION, + client_capabilities: acp::ClientCapabilities { + fs: acp::FileSystemCapability { + read_text_file: true, + write_text_file: true, + }, + }, + }) + .await?; + + if response.protocol_version < MINIMUM_SUPPORTED_VERSION { + return Err(UnsupportedVersion.into()); + } + + Ok(Self { + auth_methods: response.auth_methods, + connection: connection.into(), + server_name, + sessions, + _io_task: io_task, + }) + } +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let conn = self.connection.clone(); + let sessions = self.sessions.clone(); + let cwd = cwd.to_path_buf(); + cx.spawn(async move |cx| { + let response = conn + .new_session(acp::NewSessionRequest { + mcp_servers: vec![], + cwd, + }) + .await + .map_err(|err| { + if err.code == acp::ErrorCode::AUTH_REQUIRED.code { + anyhow!(AuthRequired) + } else { + anyhow!(err) + } + })?; + + let session_id = response.session_id; + + let thread = cx.new(|cx| { + AcpThread::new( + self.server_name, + self.clone(), + project, + session_id.clone(), + cx, + ) + })?; + + let session = AcpSession { + thread: thread.downgrade(), + }; + sessions.borrow_mut().insert(session_id, session); + + Ok(thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods + } + + fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor().spawn(async move { + let result = conn + .authenticate(acp::AuthenticateRequest { + method_id: method_id.clone(), + }) + .await?; + + Ok(result) + }) + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let conn = self.connection.clone(); + 
cx.foreground_executor().spawn(async move { + let response = conn.prompt(params).await?; + Ok(response) + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let conn = self.connection.clone(); + let params = acp::CancelNotification { + session_id: session_id.clone(), + }; + cx.foreground_executor() + .spawn(async move { conn.cancel(params).await }) + .detach(); + } +} + +struct ClientDelegate { + sessions: Rc>>, + cx: AsyncApp, +} + +impl acp::Client for ClientDelegate { + async fn request_permission( + &self, + arguments: acp::RequestPermissionRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let rx = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx) + })?; + + let result = rx.await; + + let outcome = match result { + Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, + Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, + }; + + Ok(acp::RequestPermissionResponse { outcome }) + } + + async fn write_text_file( + &self, + arguments: acp::WriteTextFileRequest, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let task = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.write_text_file(arguments.path, arguments.content, cx) + })?; + + task.await?; + + Ok(()) + } + + async fn read_text_file( + &self, + arguments: acp::ReadTextFileRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let task = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) + })?; + + let content = task.await?; + + Ok(acp::ReadTextFileResponse { content }) + } + + async fn session_notification( + &self, + notification: acp::SessionNotification, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let sessions = self.sessions.borrow(); + let session = sessions + .get(¬ification.session_id) + .context("Failed to get session")?; + + session.thread.update(cx, |thread, cx| { + thread.handle_session_update(notification.update, cx) + })??; + + Ok(()) + } +} diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 6d9c77f296..b3b8a33170 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,7 +1,7 @@ +mod acp; mod claude; mod gemini; mod settings; -mod stdio_agent_server; #[cfg(test)] mod e2e_tests; @@ -9,9 +9,8 @@ mod e2e_tests; pub use claude::*; pub use gemini::*; pub use settings::*; -pub use stdio_agent_server::*; -use acp_thread::AcpThread; +use acp_thread::AgentConnection; use anyhow::Result; use collections::HashMap; use gpui::{App, AsyncApp, Entity, SharedString, Task}; @@ -20,6 +19,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{ path::{Path, PathBuf}, + rc::Rc, sync::Arc, }; use util::ResultExt as _; @@ -33,14 +33,13 @@ pub trait AgentServer: Send { fn name(&self) -> &'static str; fn empty_state_headline(&self) -> &'static str; fn empty_state_message(&self) -> &'static str; - fn supports_always_allow(&self) -> bool; - fn new_thread( + fn connect( &self, root_dir: &Path, project: &Entity, cx: &mut App, - ) -> Task>>; + ) -> Task>>; } impl std::fmt::Debug for 
AgentServerCommand { @@ -90,6 +89,7 @@ impl AgentServerCommand { pub(crate) async fn resolve( path_bin_name: &'static str, extra_args: &[&'static str], + fallback_path: Option<&Path>, settings: Option, project: &Entity, cx: &mut AsyncApp, @@ -106,13 +106,24 @@ impl AgentServerCommand { env: agent_settings.command.env, }); } else { - find_bin_in_path(path_bin_name, project, cx) - .await - .map(|path| Self { + match find_bin_in_path(path_bin_name, project, cx).await { + Some(path) => Some(Self { path, args: extra_args.iter().map(|arg| arg.to_string()).collect(), env: None, - }) + }), + None => fallback_path.and_then(|path| { + if path.exists() { + Some(Self { + path: path.to_path_buf(), + args: extra_args.iter().map(|arg| arg.to_string()).collect(), + env: None, + }) + } else { + None + } + }), + } } } } diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 835efbd655..59e2e87433 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -1,39 +1,35 @@ mod mcp_server; -mod tools; +pub mod tools; use collections::HashMap; +use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; use std::cell::RefCell; use std::fmt::Display; use std::path::Path; -use std::pin::pin; use std::rc::Rc; use uuid::Uuid; -use agentic_coding_protocol::{ - self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion, - StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams, -}; +use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use futures::channel::oneshot; -use futures::future::LocalBoxFuture; -use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt}; +use futures::{AsyncBufReadExt, AsyncWriteExt}; use futures::{ AsyncRead, AsyncWrite, FutureExt, StreamExt, channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, io::BufReader, select_biased, }; -use gpui::{App, AppContext, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use serde::{Deserialize, Serialize}; use util::ResultExt; -use crate::claude::mcp_server::ClaudeMcpServer; +use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection}; +use acp_thread::{AcpThread, AgentConnection}; #[derive(Clone)] pub struct ClaudeCode; @@ -48,36 +44,47 @@ impl AgentServer for ClaudeCode { } fn empty_state_message(&self) -> &'static str { - "" + "How can I help you today?" 
} fn logo(&self) -> ui::IconName { ui::IconName::AiClaude } - fn supports_always_allow(&self) -> bool { - false - } - - fn new_thread( + fn connect( &self, - root_dir: &Path, - project: &Entity, - cx: &mut App, - ) -> Task>> { - let project = project.clone(); - let root_dir = root_dir.to_path_buf(); - let title = self.name().into(); - cx.spawn(async move |cx| { - let (mut delegate_tx, delegate_rx) = watch::channel(None); - let tool_id_map = Rc::new(RefCell::new(HashMap::default())); + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, + ) -> Task>> { + let connection = ClaudeAgentConnection { + sessions: Default::default(), + }; - let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?; + Task::ready(Ok(Rc::new(connection) as _)) + } +} + +struct ClaudeAgentConnection { + sessions: Rc>>, +} + +impl AgentConnection for ClaudeAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let cwd = cwd.to_owned(); + cx.spawn(async move |cx| { + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); + let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; let mut mcp_servers = HashMap::default(); mcp_servers.insert( mcp_server::SERVER_NAME.to_string(), - mcp_server.server_config()?, + permission_mcp_server.server_config()?, ); let mcp_config = McpConfig { mcp_servers }; @@ -94,200 +101,196 @@ impl AgentServer for ClaudeCode { settings.get::(None).claude.clone() })?; - let Some(command) = - AgentServerCommand::resolve("claude", &[], settings, &project, cx).await + let Some(command) = AgentServerCommand::resolve( + "claude", + &[], + Some(&util::paths::home_dir().join(".claude/local/claude")), + settings, + &project, + cx, + ) + .await else { anyhow::bail!("Failed to find claude binary"); }; let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); - let (cancel_tx, mut cancel_rx) = mpsc::unbounded::>>(); - let session_id = Uuid::new_v4(); + let session_id = acp::SessionId(Uuid::new_v4().to_string().into()); log::trace!("Starting session with id: {}", session_id); + let mut child = spawn_claude( + &command, + ClaudeSessionMode::Start, + session_id.clone(), + &mcp_config_path, + &cwd, + )?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + let pid = child.id(); + log::trace!("Spawned (pid: {})", pid); + cx.background_spawn(async move { let mut outgoing_rx = Some(outgoing_rx); - let mut mode = ClaudeSessionMode::Start; - loop { - let mut child = - spawn_claude(&command, mode, session_id, &mcp_config_path, &root_dir) - .await?; - mode = ClaudeSessionMode::Resume; + ClaudeAgentSession::handle_io( + outgoing_rx.take().unwrap(), + incoming_message_tx.clone(), + stdin, + stdout, + ) + .await?; - let pid = child.id(); - log::trace!("Spawned (pid: {})", pid); - - let mut io_fut = pin!( - ClaudeAgentConnection::handle_io( - outgoing_rx.take().unwrap(), - incoming_message_tx.clone(), - child.stdin.take().unwrap(), - child.stdout.take().unwrap(), - ) - .fuse() - ); - - select_biased! 
{ - done_tx = cancel_rx.next() => { - if let Some(done_tx) = done_tx { - log::trace!("Interrupted (pid: {})", pid); - let result = send_interrupt(pid as i32); - outgoing_rx.replace(io_fut.await?); - done_tx.send(result).log_err(); - continue; - } - } - result = io_fut => { - result?; - } - } - - log::trace!("Stopped (pid: {})", pid); - break; - } + log::trace!("Stopped (pid: {})", pid); drop(mcp_config_path); anyhow::Ok(()) }) .detach(); - cx.new(|cx| { - let end_turn_tx = Rc::new(RefCell::new(None)); - let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); - delegate_tx.send(Some(delegate.clone())).log_err(); + let end_turn_tx = Rc::new(RefCell::new(None)); + let handler_task = cx.spawn({ + let end_turn_tx = end_turn_tx.clone(); + let mut thread_rx = thread_rx.clone(); + async move |cx| { + while let Some(message) = incoming_message_rx.next().await { + ClaudeAgentSession::handle_message( + thread_rx.clone(), + message, + end_turn_tx.clone(), + cx, + ) + .await + } - let handler_task = cx.foreground_executor().spawn({ - let end_turn_tx = end_turn_tx.clone(); - let tool_id_map = tool_id_map.clone(); - let delegate = delegate.clone(); - async move { - while let Some(message) = incoming_message_rx.next().await { - ClaudeAgentConnection::handle_message( - delegate.clone(), - message, - end_turn_tx.clone(), - tool_id_map.clone(), - ) - .await + if let Some(status) = child.status().await.log_err() { + if let Some(thread) = thread_rx.recv().await.ok() { + thread + .update(cx, |thread, cx| { + thread.emit_server_exited(status, cx); + }) + .ok(); } } - }); + } + }); - let mut connection = ClaudeAgentConnection { - delegate, - outgoing_tx, - end_turn_tx, - cancel_tx, - session_id, - _handler_task: handler_task, - _mcp_server: None, - }; + let thread = cx.new(|cx| { + AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) + })?; - connection._mcp_server = Some(mcp_server); - acp_thread::AcpThread::new(connection, title, None, project.clone(), cx) - }) + thread_tx.send(thread.downgrade())?; + + let session = ClaudeAgentSession { + outgoing_tx, + end_turn_tx, + _handler_task: handler_task, + _mcp_server: Some(permission_mcp_server), + }; + + self.sessions.borrow_mut().insert(session_id, session); + + Ok(thread) }) } -} -#[cfg(unix)] -fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { - let pid = nix::unistd::Pid::from_raw(pid); + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } - nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) - .map_err(|e| anyhow!("Failed to interrupt process: {}", e)) -} + fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow!("Authentication not supported"))) + } -#[cfg(windows)] -fn send_interrupt(_pid: i32) -> anyhow::Result<()> { - panic!("Cancel not implemented on Windows") -} - -impl AgentConnection for ClaudeAgentConnection { - /// Send a request to the agent and wait for a response. - fn request_any( + fn prompt( &self, - params: AnyAgentRequest, - ) -> LocalBoxFuture<'static, Result> { - let delegate = self.delegate.clone(); - let end_turn_tx = self.end_turn_tx.clone(); - let outgoing_tx = self.outgoing_tx.clone(); - let mut cancel_tx = self.cancel_tx.clone(); - let session_id = self.session_id; - async move { - match params { - // todo: consider sending an empty request so we get the init response? 
- AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse( - acp::InitializeResponse { - is_authenticated: true, - protocol_version: ProtocolVersion::latest(), - }, - )), - AnyAgentRequest::AuthenticateParams(_) => { - Err(anyhow!("Authentication not supported")) - } - AnyAgentRequest::SendUserMessageParams(message) => { - delegate.clear_completed_plan_entries().await?; + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(¶ms.session_id) else { + return Task::ready(Err(anyhow!( + "Attempted to send message to nonexistent session {}", + params.session_id + ))); + }; - let (tx, rx) = oneshot::channel(); - end_turn_tx.borrow_mut().replace(tx); - let mut content = String::new(); - for chunk in message.chunks { - match chunk { - agentic_coding_protocol::UserMessageChunk::Text { text } => { - content.push_str(&text) - } - agentic_coding_protocol::UserMessageChunk::Path { path } => { - content.push_str(&format!("@{path:?}")) - } - } - } - outgoing_tx.unbounded_send(SdkMessage::User { - message: Message { - role: Role::User, - content: Content::UntaggedText(content), - id: None, - model: None, - stop_reason: None, - stop_sequence: None, - usage: None, - }, - session_id: Some(session_id), - })?; - rx.await??; - Ok(AnyAgentResult::SendUserMessageResponse( - acp::SendUserMessageResponse, - )) - } - AnyAgentRequest::CancelSendMessageParams(_) => { - let (done_tx, done_rx) = oneshot::channel(); - cancel_tx.send(done_tx).await?; - done_rx.await??; + let (tx, rx) = oneshot::channel(); + session.end_turn_tx.borrow_mut().replace(tx); - Ok(AnyAgentResult::CancelSendMessageResponse( - acp::CancelSendMessageResponse, - )) + let mut content = String::new(); + for chunk in params.prompt { + match chunk { + acp::ContentBlock::Text(text_content) => { + content.push_str(&text_content.text); + } + acp::ContentBlock::ResourceLink(resource_link) => { + content.push_str(&format!("@{}", resource_link.uri)); + } + acp::ContentBlock::Audio(_) + | acp::ContentBlock::Image(_) + | acp::ContentBlock::Resource(_) => { + // TODO } } } - .boxed_local() + + if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User { + message: Message { + role: Role::User, + content: Content::UntaggedText(content), + id: None, + model: None, + stop_reason: None, + stop_sequence: None, + usage: None, + }, + session_id: Some(params.session_id.to_string()), + }) { + return Task::ready(Err(anyhow!(err))); + } + + cx.foreground_executor().spawn(async move { rx.await? 
}) + } + + fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(&session_id) else { + log::warn!("Attempted to cancel nonexistent session {}", session_id); + return; + }; + + session + .outgoing_tx + .unbounded_send(SdkMessage::new_interrupt_message()) + .log_err(); + + if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() { + end_turn_tx + .send(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Cancelled, + })) + .ok(); + } } } #[derive(Clone, Copy)] enum ClaudeSessionMode { Start, + #[expect(dead_code)] Resume, } -async fn spawn_claude( +fn spawn_claude( command: &AgentServerCommand, mode: ClaudeSessionMode, - session_id: Uuid, + session_id: acp::SessionId, mcp_config_path: &Path, root_dir: &Path, ) -> Result { @@ -305,10 +308,16 @@ async fn spawn_claude( &format!( "mcp__{}__{}", mcp_server::SERVER_NAME, - mcp_server::PERMISSION_TOOL + mcp_server::PermissionTool::NAME, ), "--allowedTools", - "mcp__zed__Read,mcp__zed__Edit", + &format!( + "mcp__{}__{},mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::EditTool::NAME, + mcp_server::SERVER_NAME, + mcp_server::ReadTool::NAME + ), "--disallowedTools", "Read,Edit", ]) @@ -327,105 +336,158 @@ async fn spawn_claude( Ok(child) } -struct ClaudeAgentConnection { - delegate: AcpClientDelegate, - session_id: Uuid, +struct ClaudeAgentSession { outgoing_tx: UnboundedSender, - end_turn_tx: Rc>>>>, - cancel_tx: UnboundedSender>>, - _mcp_server: Option, + end_turn_tx: Rc>>>>, + _mcp_server: Option, _handler_task: Task<()>, } -impl ClaudeAgentConnection { +impl ClaudeAgentSession { async fn handle_message( - delegate: AcpClientDelegate, + mut thread_rx: watch::Receiver>, message: SdkMessage, - end_turn_tx: Rc>>>>, - tool_id_map: Rc>>, + end_turn_tx: Rc>>>>, + cx: &mut AsyncApp, ) { match message { - SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => { + // we should only be sending these out, they don't need to be in the thread + SdkMessage::ControlRequest { .. 
} => {} + SdkMessage::Assistant { + message, + session_id: _, + } + | SdkMessage::User { + message, + session_id: _, + } => { + let Some(thread) = thread_rx + .recv() + .await + .log_err() + .and_then(|entity| entity.upgrade()) + else { + log::error!("Received an SDK message but thread is gone"); + return; + }; + for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - delegate - .stream_assistant_message_chunk(StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Text { text }, + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block(text.into(), false, cx) + }) + .log_err(); + } + ContentChunk::Thinking { thinking } => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block(thinking.into(), true, cx) + }) + .log_err(); + } + ContentChunk::RedactedThinking => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block( + "[REDACTED]".into(), + true, + cx, + ) }) - .await .log_err(); } ContentChunk::ToolUse { id, name, input } => { let claude_tool = ClaudeTool::infer(&name, input); - if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { - delegate - .update_plan(acp::UpdatePlanParams { - entries: params.todos.into_iter().map(Into::into).collect(), - }) - .await - .log_err(); - } else if let Some(resp) = delegate - .push_tool_call(claude_tool.as_acp()) - .await - .log_err() - { - tool_id_map.borrow_mut().insert(id, resp.id); - } + thread + .update(cx, |thread, cx| { + if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { + thread.update_plan( + acp::Plan { + entries: params + .todos + .into_iter() + .map(Into::into) + .collect(), + }, + cx, + ) + } else { + thread.upsert_tool_call( + claude_tool.as_acp(acp::ToolCallId(id.into())), + cx, + ); + } + }) + .log_err(); } ContentChunk::ToolResult { content, tool_use_id, } => { - let id = tool_id_map.borrow_mut().remove(&tool_use_id); - if let Some(id) = id { - let content = content.to_string(); - delegate - .update_tool_call(UpdateToolCallParams { - tool_call_id: id, - status: acp::ToolCallStatus::Finished, - // Don't unset existing content - content: (!content.is_empty()).then_some( - ToolCallContent::Markdown { - // For now we only include text content - markdown: content, + let content = content.to_string(); + thread + .update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use_id.into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + content: (!content.is_empty()) + .then(|| vec![content.into()]), + ..Default::default() }, - ), - }) - .await - .log_err(); - } + }, + cx, + ) + }) + .log_err(); } ContentChunk::Image | ContentChunk::Document - | ContentChunk::Thinking - | ContentChunk::RedactedThinking | ContentChunk::WebSearchToolResult => { - delegate - .stream_assistant_message_chunk(StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Text { - text: format!("Unsupported content: {:?}", chunk), - }, + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block( + format!("Unsupported content: {:?}", chunk).into(), + false, + cx, + ) }) - .await .log_err(); } } } } SdkMessage::Result { - is_error, subtype, .. + is_error, + subtype, + result, + .. 
} => { if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { - if is_error { - end_turn_tx.send(Err(anyhow!("Error: {subtype}"))).ok(); + if is_error || subtype == ResultErrorType::ErrorDuringExecution { + end_turn_tx + .send(Err(anyhow!( + "Error: {}", + result.unwrap_or_else(|| subtype.to_string()) + ))) + .ok(); } else { - end_turn_tx.send(Ok(())).ok(); + let stop_reason = match subtype { + ResultErrorType::Success => acp::StopReason::EndTurn, + ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, + ResultErrorType::ErrorDuringExecution => unreachable!(), + }; + end_turn_tx + .send(Ok(acp::PromptResponse { stop_reason })) + .ok(); } } } - SdkMessage::System { .. } => {} + SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {} } } @@ -534,11 +596,13 @@ enum ContentChunk { content: Content, tool_use_id: String, }, + Thinking { + thinking: String, + }, + RedactedThinking, // TODO Image, Document, - Thinking, - RedactedThinking, WebSearchToolResult, #[serde(untagged)] UntaggedText(String), @@ -548,12 +612,12 @@ impl Display for ContentChunk { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ContentChunk::Text { text } => write!(f, "{}", text), + ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking), + ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"), ContentChunk::UntaggedText(text) => write!(f, "{}", text), ContentChunk::ToolResult { content, .. } => write!(f, "{}", content), ContentChunk::Image | ContentChunk::Document - | ContentChunk::Thinking - | ContentChunk::RedactedThinking | ContentChunk::ToolUse { .. } | ContentChunk::WebSearchToolResult => { write!(f, "\n{:?}\n", &self) @@ -592,16 +656,14 @@ enum SdkMessage { Assistant { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, - // A user message User { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, - // Emitted as the last message in a conversation Result { subtype: ResultErrorType, @@ -626,9 +688,29 @@ enum SdkMessage { #[serde(rename = "permissionMode")] permission_mode: PermissionMode, }, + /// Messages used to control the conversation, outside of chat messages to the model + ControlRequest { + request_id: String, + request: ControlRequest, + }, + /// Response to a control request + ControlResponse { response: ControlResponse }, } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype", rename_all = "snake_case")] +enum ControlRequest { + /// Cancel the current conversation + Interrupt, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ControlResponse { + request_id: String, + subtype: ResultErrorType, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[serde(rename_all = "snake_case")] enum ResultErrorType { Success, @@ -646,6 +728,24 @@ impl Display for ResultErrorType { } } +impl SdkMessage { + fn new_interrupt_message() -> Self { + use rand::Rng; + // In the Claude Code TS SDK they just generate a random 12 character string, + // `Math.random().toString(36).substring(2, 15)` + let request_id = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(12) + .map(char::from) + .collect(); + + Self::ControlRequest { + request_id, + request: ControlRequest::Interrupt, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] struct McpServer { name: String, @@ -661,27 +761,14 @@ 
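For reference, the interrupt that `cancel` enqueues is just another JSON line written to the CLI's stdin. The sketch below reproduces the wire shape on the assumption that `SdkMessage` is internally tagged with a snake_case `type` field (that attribute sits above this hunk, so it is inferred rather than quoted); the request id is a placeholder for the random 12-character string generated above.

```rust
use serde::Serialize;

#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
}

#[derive(Serialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    Interrupt,
}

fn main() {
    let msg = SdkMessage::ControlRequest {
        request_id: "abc123def456".into(),
        request: ControlRequest::Interrupt,
    };
    // => {"type":"control_request","request_id":"abc123def456","request":{"subtype":"interrupt"}}
    println!("{}", serde_json::to_string(&msg).unwrap());
}
```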
enum PermissionMode { Plan, } -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct McpConfig { - mcp_servers: HashMap, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct McpServerConfig { - command: String, - args: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - env: Option>, -} - #[cfg(test)] pub(crate) mod tests { use super::*; + use crate::e2e_tests; + use gpui::TestAppContext; use serde_json::json; - crate::common_e2e_tests!(ClaudeCode); + crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow"); pub fn local_command() -> AgentServerCommand { AgentServerCommand { @@ -691,6 +778,68 @@ pub(crate) mod tests { } } + #[gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn test_todo_plan(cx: &mut TestAppContext) { + let fs = e2e_tests::init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = + e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.", + cx, + ) + }) + .await + .unwrap(); + + let mut entries_len = 0; + + thread.read_with(cx, |thread, _| { + entries_len = thread.plan().entries.len(); + assert!(thread.plan().entries.len() > 0, "Empty plan"); + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Mark the first entry status as in progress without acting on it.", + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.plan().entries[0].status, + acp::PlanEntryStatus::InProgress + )); + assert_eq!(thread.plan().entries.len(), entries_len); + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Now mark the first entry as completed without acting on it.", + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.plan().entries[0].status, + acp::PlanEntryStatus::Completed + )); + assert_eq!(thread.plan().entries.len(), entries_len); + }); + } + #[test] fn test_deserialize_content_untagged_text() { let json = json!("Hello, world!"); diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index 2405603550..c6f8bb5b69 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -1,78 +1,53 @@ -use std::{cell::RefCell, rc::Rc}; +use std::path::PathBuf; -use acp_thread::AcpClientDelegate; -use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams}; +use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; +use acp_thread::AcpThread; +use agent_client_protocol as acp; use anyhow::{Context, Result}; use collections::HashMap; -use context_server::{ - listener::McpServer, - types::{ - CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse, - ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, - ToolResponseContent, ToolsCapabilities, requests, - }, +use context_server::listener::{McpServerTool, ToolResponse}; +use context_server::types::{ + Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, + ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, }; -use gpui::{App, AsyncApp, Task}; +use gpui::{App, AsyncApp, Task, WeakEntity}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use util::debug_panic; -use crate::claude::{ - 
McpServerConfig, - tools::{ClaudeTool, EditToolParams, ReadToolParams}, -}; - -pub struct ClaudeMcpServer { - server: McpServer, +pub struct ClaudeZedMcpServer { + server: context_server::listener::McpServer, } pub const SERVER_NAME: &str = "zed"; -pub const READ_TOOL: &str = "Read"; -pub const EDIT_TOOL: &str = "Edit"; -pub const PERMISSION_TOOL: &str = "Confirmation"; -#[derive(Deserialize, JsonSchema, Debug)] -struct PermissionToolParams { - tool_name: String, - input: serde_json::Value, - tool_use_id: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct PermissionToolResponse { - behavior: PermissionToolBehavior, - updated_input: serde_json::Value, -} - -#[derive(Serialize)] -#[serde(rename_all = "snake_case")] -enum PermissionToolBehavior { - Allow, - Deny, -} - -impl ClaudeMcpServer { +impl ClaudeZedMcpServer { pub async fn new( - delegate: watch::Receiver>, - tool_id_map: Rc>>, + thread_rx: watch::Receiver>, cx: &AsyncApp, ) -> Result { - let mut mcp_server = McpServer::new(cx).await?; + let mut mcp_server = context_server::listener::McpServer::new(cx).await?; mcp_server.handle_request::(Self::handle_initialize); - mcp_server.handle_request::(Self::handle_list_tools); - mcp_server.handle_request::(move |request, cx| { - Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx) + + mcp_server.add_tool(PermissionTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(ReadTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(EditTool { + thread_rx: thread_rx.clone(), }); Ok(Self { server: mcp_server }) } pub fn server_config(&self) -> Result { + #[cfg(not(test))] let zed_path = std::env::current_exe() - .context("finding current executable path for use in mcp_server")? - .to_string_lossy() - .to_string(); + .context("finding current executable path for use in mcp_server")?; + + #[cfg(test)] + let zed_path = crate::e2e_tests::get_zed_path(); Ok(McpServerConfig { command: zed_path, @@ -106,191 +81,222 @@ impl ClaudeMcpServer { }) }) } +} - fn handle_list_tools(_: (), cx: &App) -> Task> { - cx.foreground_executor().spawn(async move { - Ok(ListToolsResponse { - tools: vec![ - Tool { - name: PERMISSION_TOOL.into(), - input_schema: schemars::schema_for!(PermissionToolParams).into(), - description: None, - annotations: None, - }, - Tool { - name: READ_TOOL.into(), - input_schema: schemars::schema_for!(ReadToolParams).into(), - description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()), - annotations: Some(ToolAnnotations { - title: Some("Read file".to_string()), - read_only_hint: Some(true), - destructive_hint: Some(false), - open_world_hint: Some(false), - // if time passes the contents might change, but it's not going to do anything different - // true or false seem too strong, let's try a none. - idempotent_hint: None, - }), - }, - Tool { - name: EDIT_TOOL.into(), - input_schema: schemars::schema_for!(EditToolParams).into(), - description: Some("Edits a file. 
In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()), - annotations: Some(ToolAnnotations { - title: Some("Edit file".to_string()), - read_only_hint: Some(false), - destructive_hint: Some(false), - open_world_hint: Some(false), - idempotent_hint: Some(false), - }), - }, - ], - next_cursor: None, - meta: None, - }) - }) +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfig { + pub mcp_servers: HashMap, +} + +#[derive(Serialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct McpServerConfig { + pub command: PathBuf, + pub args: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, +} + +// Tools + +#[derive(Clone)] +pub struct PermissionTool { + thread_rx: watch::Receiver>, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct PermissionToolParams { + tool_name: String, + input: serde_json::Value, + tool_use_id: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionToolResponse { + behavior: PermissionToolBehavior, + updated_input: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +enum PermissionToolBehavior { + Allow, + Deny, +} + +impl McpServerTool for PermissionTool { + type Input = PermissionToolParams; + type Output = (); + + const NAME: &'static str = "Confirmation"; + + fn description(&self) -> &'static str { + "Request permission for tool calls" } - fn handle_call_tool( - request: CallToolParams, - mut delegate_watch: watch::Receiver>, - tool_id_map: Rc>>, - cx: &App, - ) -> Task> { - cx.spawn(async move |cx| { - let Some(delegate) = delegate_watch.recv().await? else { - debug_panic!("Sent None delegate"); - anyhow::bail!("Server not available"); - }; + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; - if request.name.as_str() == PERMISSION_TOOL { - let input = - serde_json::from_value(request.arguments.context("Arguments required")?)?; + let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); + let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); + let allow_option_id = acp::PermissionOptionId("allow".into()); + let reject_option_id = acp::PermissionOptionId("reject".into()); - let result = - Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?; - Ok(CallToolResponse { - content: vec![ToolResponseContent::Text { - text: serde_json::to_string(&result)?, - }], - is_error: None, - meta: None, - }) - } else if request.name.as_str() == READ_TOOL { - let input = - serde_json::from_value(request.arguments.context("Arguments required")?)?; - - let content = Self::handle_read_tool_call(input, delegate, cx).await?; - Ok(CallToolResponse { - content, - is_error: None, - meta: None, - }) - } else if request.name.as_str() == EDIT_TOOL { - let input = - serde_json::from_value(request.arguments.context("Arguments required")?)?; - - Self::handle_edit_tool_call(input, delegate, cx).await?; - Ok(CallToolResponse { - content: vec![], - is_error: None, - meta: None, - }) - } else { - anyhow::bail!("Unsupported tool"); - } - }) - } - - fn handle_read_tool_call( - params: ReadToolParams, - delegate: AcpClientDelegate, - cx: &AsyncApp, - ) -> Task>> { - cx.foreground_executor().spawn(async move { - let response = delegate - 
.read_text_file(ReadTextFileParams { - path: params.abs_path, - line: params.offset, - limit: params.limit, - }) - .await?; - - Ok(vec![ToolResponseContent::Text { - text: response.content, - }]) - }) - } - - fn handle_edit_tool_call( - params: EditToolParams, - delegate: AcpClientDelegate, - cx: &AsyncApp, - ) -> Task> { - cx.foreground_executor().spawn(async move { - let response = delegate - .read_text_file_reusing_snapshot(ReadTextFileParams { - path: params.abs_path.clone(), - line: None, - limit: None, - }) - .await?; - - let new_content = response.content.replace(¶ms.old_text, ¶ms.new_text); - if new_content == response.content { - return Err(anyhow::anyhow!("The old_text was not found in the content")); - } - - delegate - .write_text_file(WriteTextFileParams { - path: params.abs_path, - content: new_content, - }) - .await?; - - Ok(()) - }) - } - - fn handle_permissions_tool_call( - params: PermissionToolParams, - delegate: AcpClientDelegate, - tool_id_map: Rc>>, - cx: &AsyncApp, - ) -> Task> { - cx.foreground_executor().spawn(async move { - let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone()); - - let tool_call_id = match params.tool_use_id { - Some(tool_use_id) => tool_id_map - .borrow() - .get(&tool_use_id) - .cloned() - .context("Tool call ID not found")?, - - None => delegate.push_tool_call(claude_tool.as_acp()).await?.id, - }; - - let outcome = delegate - .request_existing_tool_call_confirmation( - tool_call_id, - claude_tool.confirmation(None), + let chosen_option = thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission( + claude_tool.as_acp(tool_call_id), + vec![ + acp::PermissionOption { + id: allow_option_id.clone(), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: reject_option_id.clone(), + name: "Reject".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + cx, ) - .await?; + })? + .await?; - match outcome { - acp::ToolCallConfirmationOutcome::Allow - | acp::ToolCallConfirmationOutcome::AlwaysAllow - | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer - | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse { - behavior: PermissionToolBehavior::Allow, - updated_input: params.input, - }), - acp::ToolCallConfirmationOutcome::Reject - | acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse { - behavior: PermissionToolBehavior::Deny, - updated_input: params.input, - }), + let response = if chosen_option == allow_option_id { + PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: input.input, } + } else { + debug_assert_eq!(chosen_option, reject_option_id); + PermissionToolResponse { + behavior: PermissionToolBehavior::Deny, + updated_input: input.input, + } + }; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&response)?, + }], + structured_content: (), + }) + } +} + +#[derive(Clone)] +pub struct ReadTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for ReadTool { + type Input = ReadToolParams; + type Output = (); + + const NAME: &'static str = "Read"; + + fn description(&self) -> &'static str { + "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents." 
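The permission flow above round-trips through JSON text: Zed asks the user to pick between the `allow` and `reject` options, then answers the CLI's `Confirmation` tool call with a small camelCase payload embedded in the tool result. A minimal reproduction of that payload, mirroring the serde attributes in this file (the `abs_path` input is a made-up example):

```rust
use serde::Serialize;
use serde_json::json;

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct PermissionToolResponse {
    behavior: PermissionToolBehavior,
    updated_input: serde_json::Value,
}

#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum PermissionToolBehavior {
    Allow,
    Deny,
}

fn main() {
    let allowed = true; // outcome of the user's PermissionOption choice
    let response = PermissionToolResponse {
        behavior: if allowed {
            PermissionToolBehavior::Allow
        } else {
            PermissionToolBehavior::Deny
        },
        updated_input: json!({ "abs_path": "/private/tmp/foo.rs" }),
    };
    // => {"behavior":"allow","updatedInput":{"abs_path":"/private/tmp/foo.rs"}}
    println!("{}", serde_json::to_string(&response).unwrap());
}
```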
+ } + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Read file".to_string()), + read_only_hint: Some(true), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: None, + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { text: content }], + structured_content: (), + }) + } +} + +#[derive(Clone)] +pub struct EditTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for EditTool { + type Input = EditToolParams; + type Output = (); + + const NAME: &'static str = "Edit"; + + fn description(&self) -> &'static str { + "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better." + } + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Edit file".to_string()), + read_only_hint: Some(false), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: Some(false), + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path.clone(), None, None, true, cx) + })? + .await?; + + let new_content = content.replace(&input.old_text, &input.new_text); + if new_content == content { + return Err(anyhow::anyhow!("The old_text was not found in the content")); + } + + thread + .update(cx, |thread, cx| { + thread.write_text_file(input.abs_path, new_content, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![], + structured_content: (), }) } } diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index 75a26ee230..85b9a13642 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use agentic_coding_protocol::{self as acp, PushToolCallParams, ToolCallLocation}; +use agent_client_protocol as acp; use itertools::Itertools; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -115,93 +115,68 @@ impl ClaudeTool { Self::Other { name, .. } => name.clone(), } } - - pub fn content(&self) -> Option { + pub fn content(&self) -> Vec { match &self { - Self::Other { input, .. } => Some(acp::ToolCallContent::Markdown { - markdown: format!( + Self::Other { input, .. 
} => vec![ + format!( "```json\n{}```", serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) - ), - }), - Self::Task(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.prompt.clone(), - }), - Self::NotebookRead(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.notebook_path.display().to_string(), - }), - Self::NotebookEdit(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.new_source.clone(), - }), - Self::Terminal(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: format!( + ) + .into(), + ], + Self::Task(Some(params)) => vec![params.prompt.clone().into()], + Self::NotebookRead(Some(params)) => { + vec![params.notebook_path.display().to_string().into()] + } + Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()], + Self::Terminal(Some(params)) => vec![ + format!( "`{}`\n\n{}", params.command, params.description.as_deref().unwrap_or_default() - ), - }), - Self::ReadFile(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.abs_path.display().to_string(), - }), - Self::Ls(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.path.display().to_string(), - }), - Self::Glob(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.to_string(), - }), - Self::Grep(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: format!("`{params}`"), - }), - Self::WebFetch(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.prompt.clone(), - }), - Self::WebSearch(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.to_string(), - }), - Self::TodoWrite(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params - .todos - .iter() - .map(|todo| { - format!( - "- {} {}: {}", - match todo.status { - TodoStatus::Completed => "✅", - TodoStatus::InProgress => "🚧", - TodoStatus::Pending => "⬜", - }, - todo.priority, - todo.content - ) - }) - .join("\n"), - }), - Self::ExitPlanMode(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.plan.clone(), - }), - Self::Edit(Some(params)) => Some(acp::ToolCallContent::Diff { + ) + .into(), + ], + Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()], + Self::Ls(Some(params)) => vec![params.path.display().to_string().into()], + Self::Glob(Some(params)) => vec![params.to_string().into()], + Self::Grep(Some(params)) => vec![format!("`{params}`").into()], + Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], + Self::WebSearch(Some(params)) => vec![params.to_string().into()], + Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], + Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { diff: acp::Diff { path: params.abs_path.clone(), old_text: Some(params.old_text.clone()), new_text: params.new_text.clone(), }, - }), - Self::Write(Some(params)) => Some(acp::ToolCallContent::Diff { + }], + Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { diff: acp::Diff { path: params.file_path.clone(), old_text: None, new_text: params.content.clone(), }, - }), + }], Self::MultiEdit(Some(params)) => { // todo: show multiple edits in a multibuffer? 
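Stepping back to the `Edit` MCP tool defined earlier: it applies a change as a plain substring replacement and rejects the edit when `old_text` no longer matches the current file contents, rather than attempting any fuzzy patching. A self-contained sketch of that strategy (the function name and error type are illustrative; the check mirrors `EditTool::run`):

```rust
/// Replace `old_text` with `new_text`, failing if `old_text` is not present.
/// An unchanged result is treated as "not found", exactly like the tool above.
fn apply_edit(content: &str, old_text: &str, new_text: &str) -> anyhow::Result<String> {
    let updated = content.replace(old_text, new_text);
    anyhow::ensure!(updated != content, "The old_text was not found in the content");
    Ok(updated)
}

fn main() -> anyhow::Result<()> {
    let original = r#"println!("Hello, world!");"#;
    let edited = apply_edit(original, "Hello, world!", "Hello, Zed!")?;
    assert_eq!(edited, r#"println!("Hello, Zed!");"#);
    Ok(())
}
```

Note that, like the original check, this also rejects a no-op edit where `old_text` and `new_text` are identical.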
- params.edits.first().map(|edit| acp::ToolCallContent::Diff { - diff: acp::Diff { - path: params.file_path.clone(), - old_text: Some(edit.old_string.clone()), - new_text: edit.new_string.clone(), - }, - }) + params + .edits + .first() + .map(|edit| { + vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: Some(edit.old_string.clone()), + new_text: edit.new_string.clone(), + }, + }] + }) + .unwrap_or_default() + } + Self::TodoWrite(Some(_)) => { + // These are mapped to plan updates later + vec![] } Self::Task(None) | Self::NotebookRead(None) @@ -217,181 +192,80 @@ impl ClaudeTool { | Self::ExitPlanMode(None) | Self::Edit(None) | Self::Write(None) - | Self::MultiEdit(None) => None, + | Self::MultiEdit(None) => vec![], } } - pub fn icon(&self) -> acp::Icon { + pub fn kind(&self) -> acp::ToolKind { match self { - Self::Task(_) => acp::Icon::Hammer, - Self::NotebookRead(_) => acp::Icon::FileSearch, - Self::NotebookEdit(_) => acp::Icon::Pencil, - Self::Edit(_) => acp::Icon::Pencil, - Self::MultiEdit(_) => acp::Icon::Pencil, - Self::Write(_) => acp::Icon::Pencil, - Self::ReadFile(_) => acp::Icon::FileSearch, - Self::Ls(_) => acp::Icon::Folder, - Self::Glob(_) => acp::Icon::FileSearch, - Self::Grep(_) => acp::Icon::Regex, - Self::Terminal(_) => acp::Icon::Terminal, - Self::WebSearch(_) => acp::Icon::Globe, - Self::WebFetch(_) => acp::Icon::Globe, - Self::TodoWrite(_) => acp::Icon::LightBulb, - Self::ExitPlanMode(_) => acp::Icon::Hammer, - Self::Other { .. } => acp::Icon::Hammer, - } - } - - pub fn confirmation(&self, description: Option) -> acp::ToolCallConfirmation { - match &self { - Self::Edit(_) | Self::Write(_) | Self::NotebookEdit(_) | Self::MultiEdit(_) => { - acp::ToolCallConfirmation::Edit { description } - } - Self::WebFetch(params) => acp::ToolCallConfirmation::Fetch { - urls: params - .as_ref() - .map(|p| vec![p.url.clone()]) - .unwrap_or_default(), - description, - }, - Self::Terminal(Some(BashToolParams { - description, - command, - .. - })) => acp::ToolCallConfirmation::Execute { - command: command.clone(), - root_command: command.clone(), - description: description.clone(), - }, - Self::ExitPlanMode(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {}", params.plan) - } else { - params.plan.clone() - }, - }, - Self::Task(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {}", params.description) - } else { - params.description.clone() - }, - }, - Self::Ls(Some(LsToolParams { path, .. })) - | Self::ReadFile(Some(ReadToolParams { abs_path: path, .. })) => { - let path = path.display(); - acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {path}") - } else { - path.to_string() - }, - } - } - Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. 
})) => { - let path = notebook_path.display(); - acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {path}") - } else { - path.to_string() - }, - } - } - Self::Glob(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {params}") - } else { - params.to_string() - }, - }, - Self::Grep(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {params}") - } else { - params.to_string() - }, - }, - Self::WebSearch(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {params}") - } else { - params.to_string() - }, - }, - Self::TodoWrite(Some(params)) => { - let params = params.todos.iter().map(|todo| &todo.content).join(", "); - acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {params}") - } else { - params - }, - } - } - Self::Terminal(None) - | Self::Task(None) - | Self::NotebookRead(None) - | Self::ExitPlanMode(None) - | Self::Ls(None) - | Self::Glob(None) - | Self::Grep(None) - | Self::ReadFile(None) - | Self::WebSearch(None) - | Self::TodoWrite(None) - | Self::Other { .. } => acp::ToolCallConfirmation::Other { - description: description.unwrap_or("".to_string()), - }, + Self::Task(_) => acp::ToolKind::Think, + Self::NotebookRead(_) => acp::ToolKind::Read, + Self::NotebookEdit(_) => acp::ToolKind::Edit, + Self::Edit(_) => acp::ToolKind::Edit, + Self::MultiEdit(_) => acp::ToolKind::Edit, + Self::Write(_) => acp::ToolKind::Edit, + Self::ReadFile(_) => acp::ToolKind::Read, + Self::Ls(_) => acp::ToolKind::Search, + Self::Glob(_) => acp::ToolKind::Search, + Self::Grep(_) => acp::ToolKind::Search, + Self::Terminal(_) => acp::ToolKind::Execute, + Self::WebSearch(_) => acp::ToolKind::Search, + Self::WebFetch(_) => acp::ToolKind::Fetch, + Self::TodoWrite(_) => acp::ToolKind::Think, + Self::ExitPlanMode(_) => acp::ToolKind::Think, + Self::Other { .. } => acp::ToolKind::Other, } } pub fn locations(&self) -> Vec { match &self { - Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![ToolCallLocation { + Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation { path: abs_path.clone(), line: None, }], Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { - vec![ToolCallLocation { + vec![acp::ToolCallLocation { + path: file_path.clone(), + line: None, + }] + } + Self::Write(Some(WriteToolParams { file_path, .. })) => { + vec![acp::ToolCallLocation { path: file_path.clone(), line: None, }] } - Self::Write(Some(WriteToolParams { file_path, .. })) => vec![ToolCallLocation { - path: file_path.clone(), - line: None, - }], Self::ReadFile(Some(ReadToolParams { abs_path, offset, .. - })) => vec![ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: abs_path.clone(), line: *offset, }], Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { - vec![ToolCallLocation { + vec![acp::ToolCallLocation { path: notebook_path.clone(), line: None, }] } Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { - vec![ToolCallLocation { + vec![acp::ToolCallLocation { path: notebook_path.clone(), line: None, }] } Self::Glob(Some(GlobToolParams { path: Some(path), .. 
- })) => vec![ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: path.clone(), line: None, }], - Self::Ls(Some(LsToolParams { path, .. })) => vec![ToolCallLocation { + Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation { path: path.clone(), line: None, }], Self::Grep(Some(GrepToolParams { path: Some(path), .. - })) => vec![ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: PathBuf::from(path), line: None, }], @@ -414,12 +288,15 @@ impl ClaudeTool { } } - pub fn as_acp(&self) -> PushToolCallParams { - PushToolCallParams { - label: self.label(), + pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall { + acp::ToolCall { + id, + kind: self.kind(), + status: acp::ToolCallStatus::InProgress, + title: self.label(), content: self.content(), - icon: self.icon(), locations: self.locations(), + raw_input: None, } } } @@ -596,10 +473,11 @@ impl std::fmt::Display for GrepToolParams { } } -#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)] +#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)] #[serde(rename_all = "snake_case")] pub enum TodoPriority { High, + #[default] Medium, Low, } @@ -634,14 +512,13 @@ impl Into for TodoStatus { #[derive(Deserialize, Serialize, JsonSchema, Debug)] pub struct Todo { - /// Unique identifier - pub id: String, /// Task description pub content: String, - /// Priority level of the todo - pub priority: TodoPriority, /// Current status of the todo pub status: TodoStatus, + /// Priority level of the todo + #[serde(default)] + pub priority: TodoPriority, } impl Into for Todo { diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 12f74cb13e..05f874bd30 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -1,15 +1,17 @@ -use std::{path::Path, sync::Arc, time::Duration}; +use std::{ + path::{Path, PathBuf}, + sync::Arc, + time::Duration, +}; use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; -use acp_thread::{ - AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus, -}; -use agentic_coding_protocol as acp; +use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; +use agent_client_protocol as acp; + use futures::{FutureExt, StreamExt, channel::mpsc, select}; use gpui::{Entity, TestAppContext}; use indoc::indoc; use project::{FakeFs, Project}; -use serde_json::json; use settings::{Settings, SettingsStore}; use util::path; @@ -24,7 +26,11 @@ pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppCont .unwrap(); thread.read_with(cx, |thread, _| { - assert_eq!(thread.entries().len(), 2); + assert!( + thread.entries().len() >= 2, + "Expected at least 2 entries. Got: {:?}", + thread.entries() + ); assert!(matches!( thread.entries()[0], AgentThreadEntry::UserMessage(_) @@ -54,19 +60,25 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes thread .update(cx, |thread, cx| { thread.send( - acp::SendUserMessageParams { - chunks: vec![ - acp::UserMessageChunk::Text { - text: "Read the file ".into(), - }, - acp::UserMessageChunk::Path { - path: Path::new("foo.rs").into(), - }, - acp::UserMessageChunk::Text { - text: " and tell me what the content of the println! 
is".into(), - }, - ], - }, + vec![ + acp::ContentBlock::Text(acp::TextContent { + text: "Read the file ".into(), + annotations: None, + }), + acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: "foo.rs".into(), + name: "foo.rs".into(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + acp::ContentBlock::Text(acp::TextContent { + text: " and tell me what the content of the println! is".into(), + annotations: None, + }), + ], cx, ) }) @@ -74,37 +86,44 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes .unwrap(); thread.read_with(cx, |thread, cx| { - assert_eq!(thread.entries().len(), 3); assert!(matches!( thread.entries()[0], AgentThreadEntry::UserMessage(_) )); - assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_))); - let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else { - panic!("Expected AssistantMessage") - }; + let assistant_message = &thread + .entries() + .iter() + .rev() + .find_map(|entry| match entry { + AgentThreadEntry::AssistantMessage(msg) => Some(msg), + _ => None, + }) + .unwrap(); + assert!( assistant_message.to_markdown(cx).contains("Hello, world!"), "unexpected assistant message: {:?}", assistant_message.to_markdown(cx) ); }); + + drop(tempdir); } pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) { - let fs = init_test(cx).await; - fs.insert_tree( - path!("/private/tmp"), - json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), - ) - .await; - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let _fs = init_test(cx).await; + + let tempdir = tempfile::tempdir().unwrap(); + let foo_path = tempdir.path().join("foo"); + std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file"); + + let project = Project::example([tempdir.path()], &mut cx.to_async()).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; thread .update(cx, |thread, cx| { thread.send_raw( - "Read the '/private/tmp/foo' file and tell me what you see.", + &format!("Read {} and tell me what you see.", foo_path.display()), cx, ) }) @@ -127,10 +146,13 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) }) ); }); + + drop(tempdir); } -pub async fn test_tool_call_with_confirmation( +pub async fn test_tool_call_with_permission( server: impl AgentServer + 'static, + allow_option_id: acp::PermissionOptionId, cx: &mut TestAppContext, ) { let fs = init_test(cx).await; @@ -138,7 +160,7 @@ pub async fn test_tool_call_with_confirmation( let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let full_turn = thread.update(cx, |thread, cx| { thread.send_raw( - r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#, + r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, cx, ) }); @@ -158,14 +180,11 @@ pub async fn test_tool_call_with_confirmation( ) .await; - let tool_call_id = thread.read_with(cx, |thread, _cx| { + let tool_call_id = thread.read_with(cx, |thread, cx| { let AgentThreadEntry::ToolCall(ToolCall { id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, + label, + status: ToolCallStatus::WaitingForConfirmation { .. }, .. 
}) = &thread .entries() @@ -176,13 +195,19 @@ pub async fn test_tool_call_with_confirmation( panic!(); }; - assert!(root_command.contains("touch")); + let label = label.read(cx).source(); + assert!(label.contains("touch"), "Got: {}", label); - *id + id.clone() }); thread.update(cx, |thread, cx| { - thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); + thread.authorize_tool_call( + tool_call_id, + allow_option_id, + acp::PermissionOptionKind::AllowOnce, + cx, + ); assert!(thread.entries().iter().any(|entry| matches!( entry, @@ -197,7 +222,7 @@ pub async fn test_tool_call_with_confirmation( thread.read_with(cx, |thread, cx| { let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Markdown { markdown }), + content, status: ToolCallStatus::Allowed { .. }, .. }) = thread @@ -209,13 +234,10 @@ pub async fn test_tool_call_with_confirmation( panic!(); }; - markdown.read_with(cx, |md, _cx| { - assert!( - md.source().contains("Hello"), - r#"Expected '{}' to contain "Hello""#, - md.source() - ); - }); + assert!( + content.iter().any(|c| c.to_markdown(cx).contains("Hello")), + "Expected content to contain 'Hello'" + ); }); } @@ -226,7 +248,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let full_turn = thread.update(cx, |thread, cx| { thread.send_raw( - r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#, + r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, cx, ) }); @@ -246,29 +268,24 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon ) .await; - thread.read_with(cx, |thread, _cx| { + thread.read_with(cx, |thread, cx| { let AgentThreadEntry::ToolCall(ToolCall { id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, + label, + status: ToolCallStatus::WaitingForConfirmation { .. }, .. }) = &thread.entries()[first_tool_call_ix] else { panic!("{:?}", thread.entries()[1]); }; - assert!(root_command.contains("touch")); + let label = label.read(cx).source(); + assert!(label.contains("touch"), "Got: {}", label); - *id + id.clone() }); - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); + let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); full_turn.await.unwrap(); thread.read_with(cx, |thread, _| { let AgentThreadEntry::ToolCall(ToolCall { @@ -294,9 +311,30 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon }); } +pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx)) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(thread.entries().len() >= 2, "Expected at least 2 entries"); + }); + + let weak_thread = thread.downgrade(); + drop(thread); + + cx.executor().run_until_parked(); + assert!(!weak_thread.is_upgradable()); +} + #[macro_export] macro_rules! common_e2e_tests { - ($server:expr) => { + ($server:expr, allow_option_id = $allow_option_id:expr) => { mod common_e2e { use super::*; @@ -320,8 +358,13 @@ macro_rules! 
common_e2e_tests { #[::gpui::test] #[cfg_attr(not(feature = "e2e"), ignore)] - async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) { - $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await; + async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call_with_permission( + $server, + ::agent_client_protocol::PermissionOptionId($allow_option_id.into()), + cx, + ) + .await; } #[::gpui::test] @@ -329,6 +372,12 @@ macro_rules! common_e2e_tests { async fn cancel(cx: &mut ::gpui::TestAppContext) { $crate::e2e_tests::test_cancel($server, cx).await; } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn thread_drop(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_thread_drop($server, cx).await; + } } }; } @@ -369,15 +418,16 @@ pub async fn new_test_thread( current_dir: impl AsRef, cx: &mut TestAppContext, ) -> Entity { - let thread = cx - .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx)) + let connection = cx + .update(|cx| server.connect(current_dir.as_ref(), &project, cx)) .await .unwrap(); - thread - .update(cx, |thread, _| thread.initialize()) + let thread = connection + .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) .await .unwrap(); + thread } @@ -410,3 +460,24 @@ pub async fn run_until_first_tool_call( } } } + +pub fn get_zed_path() -> PathBuf { + let mut zed_path = std::env::current_exe().unwrap(); + + while zed_path + .file_name() + .map_or(true, |name| name.to_string_lossy() != "debug") + { + if !zed_path.pop() { + panic!("Could not find target directory"); + } + } + + zed_path.push("zed"); + + if !zed_path.exists() { + panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n"); + } + + zed_path +} diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 8ad147cbff..ad883f6da8 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,9 +1,13 @@ -use crate::stdio_agent_server::StdioAgentServer; -use crate::{AgentServerCommand, AgentServerVersion}; -use anyhow::{Context as _, Result}; -use gpui::{AsyncApp, Entity}; +use std::path::Path; +use std::rc::Rc; + +use crate::{AgentServer, AgentServerCommand}; +use acp_thread::{AgentConnection, LoadError}; +use anyhow::Result; +use gpui::{Entity, Task}; use project::Project; use settings::SettingsStore; +use ui::App; use crate::AllAgentServersSettings; @@ -12,7 +16,7 @@ pub struct Gemini; const ACP_ARG: &str = "--experimental-acp"; -impl StdioAgentServer for Gemini { +impl AgentServer for Gemini { fn name(&self) -> &'static str { "Gemini" } @@ -25,79 +29,63 @@ impl StdioAgentServer for Gemini { "Ask questions, edit files, run commands.\nBe specific for the best results." 
} - fn supports_always_allow(&self) -> bool { - true - } - fn logo(&self) -> ui::IconName { ui::IconName::AiGemini } - async fn command( + fn connect( &self, + root_dir: &Path, project: &Entity, - cx: &mut AsyncApp, - ) -> Result { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).gemini.clone() - })?; + cx: &mut App, + ) -> Task>> { + let project = project.clone(); + let root_dir = root_dir.to_path_buf(); + let server_name = self.name(); + cx.spawn(async move |cx| { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() + })?; - if let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await - { - return Ok(command); - }; + let Some(command) = + AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await + else { + anyhow::bail!("Failed to find gemini binary"); + }; - let (fs, node_runtime) = project.update(cx, |project, _| { - (project.fs().clone(), project.node_runtime().cloned()) - })?; - let node_runtime = node_runtime.context("gemini not found on path")?; + let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await; + if result.is_err() { + let version_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--version") + .kill_on_drop(true) + .output(); - let directory = ::paths::agent_servers_dir().join("gemini"); - fs.create_dir(&directory).await?; - node_runtime - .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) - .await?; - let path = directory.join("node_modules/.bin/gemini"); + let help_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--help") + .kill_on_drop(true) + .output(); - Ok(AgentServerCommand { - path, - args: vec![ACP_ARG.into()], - env: None, + let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; + + let current_version = String::from_utf8(version_output?.stdout)?; + let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); + + if !supported { + return Err(LoadError::Unsupported { + error_message: format!( + "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", + current_version + ).into(), + upgrade_message: "Upgrade Gemini to Latest".into(), + upgrade_command: "npm install -g @google/gemini-cli@latest".into(), + }.into()) + } + } + result }) } - - async fn version(&self, command: &AgentServerCommand) -> Result { - let version_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--version") - .kill_on_drop(true) - .output(); - - let help_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--help") - .kill_on_drop(true) - .output(); - - let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; - - let current_version = String::from_utf8(version_output?.stdout)?; - let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); - - if supported { - Ok(AgentServerVersion::Supported) - } else { - Ok(AgentServerVersion::Unsupported { - error_message: format!( - "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", - current_version - ).into(), - upgrade_message: "Upgrade Gemini to Latest".into(), - upgrade_command: "npm install -g @google/gemini-cli@latest".into(), - }) - } - } } #[cfg(test)] @@ -106,7 +94,7 @@ pub(crate) mod tests { use crate::AgentServerCommand; use 
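The capability probe above boils down to asking the installed CLI whether it advertises the `--experimental-acp` flag in its `--help` output, and only falling back to the upgrade hint when it does not. A standalone sketch of that check, using `std::process::Command` in place of the smol-based helper used here:

```rust
use std::process::Command;

/// Returns true if the locally installed gemini CLI advertises ACP support.
/// Sketch only: the real code also captures `--version` to build the error message.
fn gemini_supports_acp() -> std::io::Result<bool> {
    let output = Command::new("gemini").arg("--help").output()?;
    Ok(String::from_utf8_lossy(&output.stdout).contains("--experimental-acp"))
}

fn main() -> std::io::Result<()> {
    if !gemini_supports_acp()? {
        eprintln!("Upgrade Gemini: npm install -g @google/gemini-cli@latest");
    }
    Ok(())
}
```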
std::path::Path; - crate::common_e2e_tests!(Gemini); + crate::common_e2e_tests!(Gemini, allow_option_id = "proceed_once"); pub fn local_command() -> AgentServerCommand { let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) @@ -116,7 +104,7 @@ pub(crate) mod tests { AgentServerCommand { path: "node".into(), - args: vec![cli_path, ACP_ARG.into()], + args: vec![cli_path], env: None, } } diff --git a/crates/agent_servers/src/stdio_agent_server.rs b/crates/agent_servers/src/stdio_agent_server.rs deleted file mode 100644 index e60dd39de4..0000000000 --- a/crates/agent_servers/src/stdio_agent_server.rs +++ /dev/null @@ -1,119 +0,0 @@ -use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; -use acp_thread::{AcpClientDelegate, AcpThread, LoadError}; -use agentic_coding_protocol as acp; -use anyhow::{Result, anyhow}; -use gpui::{App, AsyncApp, Entity, Task, prelude::*}; -use project::Project; -use std::path::Path; -use util::ResultExt; - -pub trait StdioAgentServer: Send + Clone { - fn logo(&self) -> ui::IconName; - fn name(&self) -> &'static str; - fn empty_state_headline(&self) -> &'static str; - fn empty_state_message(&self) -> &'static str; - fn supports_always_allow(&self) -> bool; - - fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> impl Future>; - - fn version( - &self, - command: &AgentServerCommand, - ) -> impl Future> + Send; -} - -impl AgentServer for T { - fn name(&self) -> &'static str { - self.name() - } - - fn empty_state_headline(&self) -> &'static str { - self.empty_state_headline() - } - - fn empty_state_message(&self) -> &'static str { - self.empty_state_message() - } - - fn logo(&self) -> ui::IconName { - self.logo() - } - - fn supports_always_allow(&self) -> bool { - self.supports_always_allow() - } - - fn new_thread( - &self, - root_dir: &Path, - project: &Entity, - cx: &mut App, - ) -> Task>> { - let root_dir = root_dir.to_path_buf(); - let project = project.clone(); - let this = self.clone(); - let title = self.name().into(); - - cx.spawn(async move |cx| { - let command = this.command(&project, cx).await?; - - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - cx.new(|cx| { - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(AgentServerVersion::Unsupported { - error_message, - upgrade_message, - upgrade_command, - }) = this.version(&command).await.log_err() - { - Err(anyhow!(LoadError::Unsupported { - error_message, - upgrade_message, - upgrade_command - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - }; - drop(io_task); - result - }); - - AcpThread::new(connection, title, Some(child_status), project.clone(), cx) - }) - }) - } -} diff --git a/crates/agent_settings/Cargo.toml 
index 3afe5ae547..d34396a5d3 100644
--- a/crates/agent_settings/Cargo.toml
+++ b/crates/agent_settings/Cargo.toml
@@ -13,6 +13,7 @@ path = "src/agent_settings.rs"
 
 [dependencies]
 anyhow.workspace = true
+cloud_llm_client.workspace = true
 collections.workspace = true
 gpui.workspace = true
 language_model.workspace = true
@@ -20,7 +21,6 @@ schemars.workspace = true
 serde.workspace = true
 settings.workspace = true
 workspace-hack.workspace = true
-zed_llm_client.workspace = true
 
 [dev-dependencies]
 fs.workspace = true
diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs
index 13b966608c..e6a79963d6 100644
--- a/crates/agent_settings/src/agent_settings.rs
+++ b/crates/agent_settings/src/agent_settings.rs
@@ -13,6 +13,9 @@ use std::borrow::Cow;
 
 pub use crate::agent_profile::*;
 
+pub const SUMMARIZE_THREAD_PROMPT: &str =
+    include_str!("../../agent/src/prompts/summarize_thread_prompt.txt");
+
 pub fn init(cx: &mut App) {
     AgentSettings::register(cx);
 }
@@ -321,11 +324,11 @@ pub enum CompletionMode {
     Burn,
 }
 
-impl From for zed_llm_client::CompletionMode {
+impl From for cloud_llm_client::CompletionMode {
     fn from(value: CompletionMode) -> Self {
         match value {
-            CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
-            CompletionMode::Burn => zed_llm_client::CompletionMode::Max,
+            CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal,
+            CompletionMode::Burn => cloud_llm_client::CompletionMode::Max,
         }
     }
 }
diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml
index 7d3b84e42e..c145df0eae 100644
--- a/crates/agent_ui/Cargo.toml
+++ b/crates/agent_ui/Cargo.toml
@@ -17,10 +17,11 @@ test-support = ["gpui/test-support", "language/test-support"]
 
 [dependencies]
 acp_thread.workspace = true
+agent-client-protocol.workspace = true
 agent.workspace = true
-agentic-coding-protocol.workspace = true
-agent_settings.workspace = true
+agent2.workspace = true
 agent_servers.workspace = true
+agent_settings.workspace = true
 ai_onboarding.workspace = true
 anyhow.workspace = true
 assistant_context.workspace = true
@@ -31,6 +32,7 @@ audio.workspace = true
 buffer_diff.workspace = true
 chrono.workspace = true
 client.workspace = true
+cloud_llm_client.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
 component.workspace = true
@@ -46,9 +48,9 @@ futures.workspace = true
 fuzzy.workspace = true
 gpui.workspace = true
 html_to_markdown.workspace = true
-indoc.workspace = true
 http_client.workspace = true
 indexed_docs.workspace = true
+indoc.workspace = true
 inventory.workspace = true
 itertools.workspace = true
 jsonschema.workspace = true
@@ -97,7 +99,6 @@ watch.workspace = true
 workspace-hack.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
-zed_llm_client.workspace = true
 
 [dev-dependencies]
 assistant_tools.workspace = true
diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs
index 95f4f81205..6475b7eeee 100644
--- a/crates/agent_ui/src/acp/thread_view.rs
+++ b/crates/agent_ui/src/acp/thread_view.rs
@@ -1,13 +1,16 @@
-use acp_thread::Plan;
+use acp_thread::{AgentConnection, Plan};
 use agent_servers::AgentServer;
+use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
+use audio::{Audio, Sound};
 use std::cell::RefCell;
 use std::collections::BTreeMap;
 use std::path::Path;
+use std::process::ExitStatus;
 use std::rc::Rc;
 use std::sync::Arc;
 use std::time::Duration;
 
-use agentic_coding_protocol::{self as acp};
+use
agent_client_protocol as acp; use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use collections::{HashMap, HashSet}; @@ -16,13 +19,12 @@ use editor::{ EditorStyle, MinimapVisibility, MultiBuffer, PathKey, }; use file_icons::FileIcons; -use futures::channel::oneshot; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, - FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, - Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, - Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*, - pulsating_between, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay, + SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, + Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, + linear_gradient, list, percentage, point, prelude::*, pulsating_between, }; use language::language_settings::SoftWrap; use language::{Buffer, Language}; @@ -30,24 +32,28 @@ use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::Project; use settings::Settings as _; -use text::Anchor; +use text::{Anchor, BufferSnapshot}; use theme::ThemeSettings; -use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*}; +use ui::{ + Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*, +}; use util::ResultExt; use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; use ::acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, - LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent, - ToolCallId, ToolCallStatus, + LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, }; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::message_history::MessageHistory; use crate::agent_diff::AgentDiff; use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; -use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll}; +use crate::ui::{AgentNotification, AgentNotificationEvent}; +use crate::{ + AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll, +}; const RESPONSE_PADDING_X: Pixels = px(19.); @@ -58,18 +64,22 @@ pub struct AcpThreadView { thread_state: ThreadState, diff_editors: HashMap>, message_editor: Entity, - message_set_from_history: bool, + message_set_from_history: Option, _message_editor_subscription: Subscription, mention_set: Arc>, + notifications: Vec>, + notification_subscriptions: HashMap, Vec>, last_error: Option>, list_state: ListState, + scrollbar_state: ScrollbarState, auth_task: Option>, - expanded_tool_calls: HashSet, + expanded_tool_calls: HashSet, expanded_thinking_blocks: HashSet<(usize, usize)>, edits_expanded: bool, plan_expanded: bool, editor_expanded: bool, - message_history: Rc>>, + message_history: Rc>>>, + _cancel_task: Option>, } enum ThreadState { @@ -82,14 +92,11 @@ enum ThreadState { }, LoadError(LoadError), Unauthenticated { - thread: Entity, + connection: Rc, + }, + ServerExited { + status: ExitStatus, }, -} - -struct AlwaysAllowOption { - id: &'static str, - label: SharedString, - outcome: acp::ToolCallConfirmationOutcome, } 
impl AcpThreadView { @@ -97,7 +104,7 @@ impl AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, - message_history: Rc>>, + message_history: Rc>>>, min_lines: usize, max_lines: Option, window: &mut Window, @@ -144,33 +151,32 @@ impl AcpThreadView { editor }); - let message_editor_subscription = cx.subscribe(&message_editor, |this, _, event, _| { - if let editor::EditorEvent::BufferEdited = &event { - if !this.message_set_from_history { - this.message_history.borrow_mut().reset_position(); + let message_editor_subscription = + cx.subscribe(&message_editor, |this, editor, event, cx| { + if let editor::EditorEvent::BufferEdited = &event { + let buffer = editor + .read(cx) + .buffer() + .read(cx) + .as_singleton() + .unwrap() + .read(cx) + .snapshot(); + if let Some(message) = this.message_set_from_history.clone() + && message.version() != buffer.version() + { + this.message_set_from_history = None; + } + + if this.message_set_from_history.is_none() { + this.message_history.borrow_mut().reset_position(); + } } - this.message_set_from_history = false; - } - }); + }); let mention_set = mention_set.clone(); - let list_state = ListState::new( - 0, - gpui::ListAlignment::Bottom, - px(2048.0), - cx.processor({ - move |this: &mut Self, index: usize, window, cx| { - let Some((entry, len)) = this.thread().and_then(|thread| { - let entries = &thread.read(cx).entries(); - Some((entries.get(index)?, entries.len())) - }) else { - return Empty.into_any(); - }; - this.render_entry(index, len, entry, window, cx) - } - }), - ); + let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); Self { agent: agent.clone(), @@ -178,11 +184,14 @@ impl AcpThreadView { project: project.clone(), thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, - message_set_from_history: false, + message_set_from_history: None, _message_editor_subscription: message_editor_subscription, mention_set, + notifications: Vec::new(), + notification_subscriptions: HashMap::default(), diff_editors: Default::default(), - list_state: list_state, + list_state: list_state.clone(), + scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), last_error: None, auth_task: None, expanded_tool_calls: HashSet::default(), @@ -191,6 +200,7 @@ impl AcpThreadView { plan_expanded: false, editor_expanded: false, message_history, + _cancel_task: None, } } @@ -208,10 +218,10 @@ impl AcpThreadView { .map(|worktree| worktree.read(cx).abs_path()) .unwrap_or_else(|| paths::home_dir().as_path().into()); - let task = agent.new_thread(&root_dir, &project, cx); + let connect_task = agent.connect(&root_dir, &project, cx); let load_task = cx.spawn_in(window, async move |this, cx| { - let thread = match task.await { - Ok(thread) => thread, + let connection = match connect_task.await { + Ok(connection) => connection, Err(err) => { this.update(cx, |this, cx| { this.handle_load_error(err, cx); @@ -222,48 +232,44 @@ impl AcpThreadView { } }; - let init_response = async { - let resp = thread - .read_with(cx, |thread, _cx| thread.initialize())? 
- .await?; - anyhow::Ok(resp) - }; + // this.update_in(cx, |_this, _window, cx| { + // let status = connection.exit_status(cx); + // cx.spawn(async move |this, cx| { + // let status = status.await.ok(); + // this.update(cx, |this, cx| { + // this.thread_state = ThreadState::ServerExited { status }; + // cx.notify(); + // }) + // .ok(); + // }) + // .detach(); + // }) + // .ok(); - let result = match init_response.await { + let result = match connection + .clone() + .new_thread(project.clone(), &root_dir, cx) + .await + { Err(e) => { let mut cx = cx.clone(); - if e.downcast_ref::().is_some() { - let child_status = thread - .update(&mut cx, |thread, _| thread.child_status()) - .ok() - .flatten(); - if let Some(child_status) = child_status { - match child_status.await { - Ok(_) => Err(e), - Err(e) => Err(e), - } - } else { - Err(e) - } + if e.is::() { + this.update(&mut cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { connection }; + cx.notify(); + }) + .ok(); + return; } else { Err(e) } } - Ok(response) => { - if !response.is_authenticated { - this.update(cx, |this, _| { - this.thread_state = ThreadState::Unauthenticated { thread }; - }) - .ok(); - return; - }; - Ok(()) - } + Ok(session_id) => Ok(session_id), }; this.update_in(cx, |this, window, cx| { match result { - Ok(()) => { + Ok(thread) => { let thread_subscription = cx.subscribe_in(&thread, window, Self::handle_thread_event); @@ -305,10 +311,11 @@ impl AcpThreadView { pub fn thread(&self) -> Option<&Entity> { match &self.thread_state { - ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { - Some(thread) - } - ThreadState::Loading { .. } | ThreadState::LoadError(..) => None, + ThreadState::Ready { thread, .. } => Some(thread), + ThreadState::Unauthenticated { .. } + | ThreadState::Loading { .. } + | ThreadState::LoadError(..) + | ThreadState::ServerExited { .. } => None, } } @@ -318,6 +325,7 @@ impl AcpThreadView { ThreadState::Loading { .. } => "Loading…".into(), ThreadState::LoadError(_) => "Failed to load".into(), ThreadState::Unauthenticated { .. } => "Not authenticated".into(), + ThreadState::ServerExited { .. 
} => "Server exited unexpectedly".into(), } } @@ -325,7 +333,7 @@ impl AcpThreadView { self.last_error.take(); if let Some(thread) = self.thread() { - thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); + self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); } } @@ -362,7 +370,7 @@ impl AcpThreadView { self.last_error.take(); let mut ix = 0; - let mut chunks: Vec = Vec::new(); + let mut chunks: Vec = Vec::new(); let project = self.project.clone(); self.message_editor.update(cx, |editor, cx| { let text = editor.text(cx); @@ -374,12 +382,19 @@ impl AcpThreadView { { let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); if crease_range.start > ix { - chunks.push(acp::UserMessageChunk::Text { - text: text[ix..crease_range.start].to_string(), - }); + chunks.push(text[ix..crease_range.start].into()); } if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { - chunks.push(acp::UserMessageChunk::Path { path: abs_path }); + let path_str = abs_path.display().to_string(); + chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: path_str.clone(), + name: path_str, + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + })); } ix = crease_range.end; } @@ -388,9 +403,7 @@ impl AcpThreadView { if ix < text.len() { let last_chunk = text[ix..].trim(); if !last_chunk.is_empty() { - chunks.push(acp::UserMessageChunk::Text { - text: last_chunk.into(), - }); + chunks.push(last_chunk.into()); } } }) @@ -400,9 +413,10 @@ impl AcpThreadView { return; } - let Some(thread) = self.thread() else { return }; - let message = acp::SendUserMessageParams { chunks }; - let task = thread.update(cx, |thread, cx| thread.send(message.clone(), cx)); + let Some(thread) = self.thread() else { + return; + }; + let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); cx.spawn(async move |this, cx| { let result = task.await; @@ -419,12 +433,15 @@ impl AcpThreadView { let mention_set = self.mention_set.clone(); self.set_editor_is_expanded(false, cx); + self.message_editor.update(cx, |editor, cx| { editor.clear(window, cx); editor.remove_creases(mention_set.lock().drain(), cx) }); - self.message_history.borrow_mut().push(message); + self.scroll_to_bottom(cx); + + self.message_history.borrow_mut().push(chunks); } fn previous_history_message( @@ -433,11 +450,21 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { + if self.message_set_from_history.is_none() && !self.message_editor.read(cx).is_empty(cx) { + self.message_editor.update(cx, |editor, cx| { + editor.move_up(&Default::default(), window, cx); + }); + return; + } + self.message_set_from_history = Self::set_draft_message( self.message_editor.clone(), self.mention_set.clone(), self.project.clone(), - self.message_history.borrow_mut().prev(), + self.message_history + .borrow_mut() + .prev() + .map(|blocks| blocks.as_slice()), window, cx, ); @@ -449,14 +476,35 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - self.message_set_from_history = Self::set_draft_message( + if self.message_set_from_history.is_none() { + self.message_editor.update(cx, |editor, cx| { + editor.move_down(&Default::default(), window, cx); + }); + return; + } + + let mut message_history = self.message_history.borrow_mut(); + let next_history = message_history.next(); + + let set_draft_message = Self::set_draft_message( self.message_editor.clone(), self.mention_set.clone(), self.project.clone(), - self.message_history.borrow_mut().next(), + 
Some( + next_history + .map(|blocks| blocks.as_slice()) + .unwrap_or_else(|| &[]), + ), window, cx, ); + // If we reset the text to an empty string because we ran out of history, + // we don't want to mark it as coming from the history + self.message_set_from_history = if next_history.is_some() { + set_draft_message + } else { + None + }; } fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { @@ -490,31 +538,30 @@ impl AcpThreadView { message_editor: Entity, mention_set: Arc>, project: Entity, - message: Option<&acp::SendUserMessageParams>, + message: Option<&[acp::ContentBlock]>, window: &mut Window, cx: &mut Context, - ) -> bool { + ) -> Option { cx.notify(); - let Some(message) = message else { - return false; - }; + let message = message?; let mut text = String::new(); let mut mentions = Vec::new(); - for chunk in &message.chunks { + for chunk in message { match chunk { - acp::UserMessageChunk::Text { text: chunk } => { - text.push_str(&chunk); + acp::ContentBlock::Text(text_content) => { + text.push_str(&text_content.text); } - acp::UserMessageChunk::Path { path } => { + acp::ContentBlock::ResourceLink(resource_link) => { + let path = Path::new(&resource_link.uri); let start = text.len(); - let content = MentionPath::new(path).to_string(); + let content = MentionPath::new(&path).to_string(); text.push_str(&content); let end = text.len(); if let Some(project_path) = - project.read(cx).project_path_for_absolute_path(path, cx) + project.read(cx).project_path_for_absolute_path(&path, cx) { let filename: SharedString = path .file_name() @@ -525,6 +572,9 @@ impl AcpThreadView { mentions.push((start..end, project_path, filename)); } } + acp::ContentBlock::Image(_) + | acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) => {} } } @@ -558,7 +608,8 @@ impl AcpThreadView { } } - true + let snapshot = snapshot.as_singleton().unwrap().2.clone(); + Some(snapshot.text) } fn handle_thread_event( @@ -580,6 +631,33 @@ impl AcpThreadView { self.sync_thread_entry_view(index, window, cx); self.list_state.splice(index..index + 1, 1); } + AcpThreadEvent::ToolAuthorizationRequired => { + self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); + } + AcpThreadEvent::Stopped => { + let used_tools = thread.read(cx).used_tools_since_last_user_message(); + self.notify_with_sound( + if used_tools { + "Finished running tools" + } else { + "New message" + }, + IconName::ZedAssistant, + window, + cx, + ); + } + AcpThreadEvent::Error => { + self.notify_with_sound( + "Agent stopped due to an error", + IconName::Warning, + window, + cx, + ); + } + AcpThreadEvent::ServerExited(status) => { + self.thread_state = ThreadState::ServerExited { status: *status }; + } } cx.notify(); } @@ -590,71 +668,84 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let Some(multibuffer) = self.entry_diff_multibuffer(entry_ix, cx) else { + let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { return; }; - if self.diff_editors.contains_key(&multibuffer.entity_id()) { - return; - } + let multibuffers = multibuffers.collect::>(); - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - 
editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) - .into(), - ), - ..Default::default() + for multibuffer in multibuffers { + if self.diff_editors.contains_key(&multibuffer.entity_id()) { + return; + } + + let editor = cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + multibuffer.clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + }); + editor }); - editor - }); - let entity_id = multibuffer.entity_id(); - cx.observe_release(&multibuffer, move |this, _, _| { - this.diff_editors.remove(&entity_id); - }) - .detach(); + let entity_id = multibuffer.entity_id(); + cx.observe_release(&multibuffer, move |this, _, _| { + this.diff_editors.remove(&entity_id); + }) + .detach(); - self.diff_editors.insert(entity_id, editor); + self.diff_editors.insert(entity_id, editor); + } } - fn entry_diff_multibuffer(&self, entry_ix: usize, cx: &App) -> Option> { + fn entry_diff_multibuffers( + &self, + entry_ix: usize, + cx: &App, + ) -> Option>> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - entry.diff().map(|diff| diff.multibuffer.clone()) + Some(entry.diffs().map(|diff| diff.multibuffer.clone())) } - fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { - let Some(thread) = self.thread().cloned() else { + fn authenticate( + &mut self, + method: acp::AuthMethodId, + window: &mut Window, + cx: &mut Context, + ) { + let ThreadState::Unauthenticated { ref connection } = self.thread_state else { return; }; self.last_error.take(); - let authenticate = thread.read(cx).authenticate(); + let authenticate = connection.authenticate(method, cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); let agent = self.agent.clone(); @@ -684,15 +775,16 @@ impl AcpThreadView { fn authorize_tool_call( &mut self, - id: ToolCallId, - outcome: acp::ToolCallConfirmationOutcome, + tool_call_id: acp::ToolCallId, + option_id: acp::PermissionOptionId, + option_kind: acp::PermissionOptionKind, cx: &mut Context, ) { let Some(thread) = self.thread() else { return; }; thread.update(cx, 
|thread, cx| { - thread.authorize_tool_call(id, outcome, cx); + thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx); }); cx.notify(); } @@ -705,7 +797,7 @@ impl AcpThreadView { window: &mut Window, cx: &Context, ) -> AnyElement { - match &entry { + let primary = match &entry { AgentThreadEntry::UserMessage(message) => div() .py_4() .px_2() @@ -719,10 +811,12 @@ impl AcpThreadView { .border_1() .border_color(cx.theme().colors().border) .text_xs() - .child(self.render_markdown( - message.content.clone(), - user_message_markdown_style(window, cx), - )), + .children(message.content.markdown().map(|md| { + self.render_markdown( + md.clone(), + user_message_markdown_style(window, cx), + ) + })), ) .into_any(), AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { @@ -730,20 +824,28 @@ impl AcpThreadView { let message_body = v_flex() .w_full() .gap_2p5() - .children(chunks.iter().enumerate().map(|(chunk_ix, chunk)| { - match chunk { - AssistantMessageChunk::Text { chunk } => self - .render_markdown(chunk.clone(), style.clone()) - .into_any_element(), - AssistantMessageChunk::Thought { chunk } => self.render_thinking_block( - index, - chunk_ix, - chunk.clone(), - window, - cx, - ), - } - })) + .children(chunks.iter().enumerate().filter_map( + |(chunk_ix, chunk)| match chunk { + AssistantMessageChunk::Message { block } => { + block.markdown().map(|md| { + self.render_markdown(md.clone(), style.clone()) + .into_any_element() + }) + } + AssistantMessageChunk::Thought { block } => { + block.markdown().map(|md| { + self.render_thinking_block( + index, + chunk_ix, + md.clone(), + window, + cx, + ) + .into_any_element() + }) + } + }, + )) .into_any(); v_flex() @@ -760,6 +862,20 @@ impl AcpThreadView { .px_5() .child(self.render_tool_call(index, tool_call, window, cx)) .into_any(), + }; + + let Some(thread) = self.thread() else { + return primary; + }; + let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); + if index == total_entries - 1 && !is_generating { + v_flex() + .w_full() + .child(primary) + .child(self.render_thread_controls(cx)) + .into_any_element() + } else { + primary } } @@ -869,9 +985,12 @@ impl AcpThreadView { let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix)); let status_icon = match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { .. } => None, ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, + status: acp::ToolCallStatus::Pending, + } + | ToolCallStatus::WaitingForConfirmation { .. } => None, + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, .. } => Some( Icon::new(IconName::ArrowCircle) @@ -885,13 +1004,13 @@ impl AcpThreadView { .into_any(), ), ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Finished, + status: acp::ToolCallStatus::Completed, .. } => None, ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Error, + status: acp::ToolCallStatus::Failed, .. } => Some( Icon::new(IconName::X) @@ -909,34 +1028,9 @@ impl AcpThreadView { .any(|content| matches!(content, ToolCallContent::Diff { .. })), }; - let is_collapsible = tool_call.content.is_some() && !needs_confirmation; + let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); - let content = if is_open { - match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { confirmation, .. 
} => { - Some(self.render_tool_call_confirmation( - tool_call.id, - confirmation, - tool_call.content.as_ref(), - window, - cx, - )) - } - ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { - tool_call.content.as_ref().map(|content| { - div() - .py_1p5() - .child(self.render_tool_call_content(content, window, cx)) - .into_any_element() - }) - } - ToolCallStatus::Rejected => None, - } - } else { - None - }; - v_flex() .when(needs_confirmation, |this| { this.rounded_lg() @@ -976,9 +1070,19 @@ impl AcpThreadView { }) .gap_1p5() .child( - Icon::new(tool_call.icon) - .size(IconSize::Small) - .color(Color::Muted), + Icon::new(match tool_call.kind { + acp::ToolKind::Read => IconName::ToolRead, + acp::ToolKind::Edit => IconName::ToolPencil, + acp::ToolKind::Delete => IconName::ToolDeleteFile, + acp::ToolKind::Move => IconName::ArrowRightLeft, + acp::ToolKind::Search => IconName::ToolSearch, + acp::ToolKind::Execute => IconName::ToolTerminal, + acp::ToolKind::Think => IconName::ToolBulb, + acp::ToolKind::Fetch => IconName::ToolWeb, + acp::ToolKind::Other => IconName::ToolHammer, + }) + .size(IconSize::Small) + .color(Color::Muted), ) .child(if tool_call.locations.len() == 1 { let name = tool_call.locations[0] @@ -1023,16 +1127,16 @@ impl AcpThreadView { .gap_0p5() .when(is_collapsible, |this| { this.child( - Disclosure::new(("expand", tool_call.id.0), is_open) + Disclosure::new(("expand", entry_ix), is_open) .opened_icon(IconName::ChevronUp) .closed_icon(IconName::ChevronDown) .on_click(cx.listener({ - let id = tool_call.id; + let id = tool_call.id.clone(); move |this: &mut Self, _, _, cx: &mut Context| { if is_open { this.expanded_tool_calls.remove(&id); } else { - this.expanded_tool_calls.insert(id); + this.expanded_tool_calls.insert(id.clone()); } cx.notify(); } @@ -1042,12 +1146,12 @@ impl AcpThreadView { .children(status_icon), ) .on_click(cx.listener({ - let id = tool_call.id; + let id = tool_call.id.clone(); move |this: &mut Self, _, _, cx: &mut Context| { if is_open { this.expanded_tool_calls.remove(&id); } else { - this.expanded_tool_calls.insert(id); + this.expanded_tool_calls.insert(id.clone()); } cx.notify(); } @@ -1055,7 +1159,7 @@ impl AcpThreadView { ) .when(is_open, |this| { this.child( - div() + v_flex() .text_xs() .when(is_collapsible, |this| { this.mt_1() @@ -1064,7 +1168,45 @@ impl AcpThreadView { .bg(cx.theme().colors().editor_background) .rounded_lg() }) - .children(content), + .map(|this| { + if is_open { + match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { options, .. } => this + .children(tool_call.content.iter().map(|content| { + div() + .py_1p5() + .child( + self.render_tool_call_content( + content, window, cx, + ), + ) + .into_any_element() + })) + .child(self.render_permission_buttons( + options, + entry_ix, + tool_call.id.clone(), + tool_call.content.is_empty(), + cx, + )), + ToolCallStatus::Allowed { .. 
} | ToolCallStatus::Canceled => { + this.children(tool_call.content.iter().map(|content| { + div() + .py_1p5() + .child( + self.render_tool_call_content( + content, window, cx, + ), + ) + .into_any_element() + })) + } + ToolCallStatus::Rejected => this, + } + } else { + this + } + }), ) }) } @@ -1076,14 +1218,20 @@ impl AcpThreadView { cx: &Context, ) -> AnyElement { match content { - ToolCallContent::Markdown { markdown } => { - div() - .p_2() - .child(self.render_markdown( - markdown.clone(), - default_markdown_style(false, window, cx), - )) - .into_any_element() + ToolCallContent::ContentBlock { content } => { + if let Some(md) = content.markdown() { + div() + .p_2() + .child( + self.render_markdown( + md.clone(), + default_markdown_style(false, window, cx), + ), + ) + .into_any_element() + } else { + Empty.into_any_element() + } } ToolCallContent::Diff { diff: Diff { multibuffer, .. }, @@ -1092,223 +1240,56 @@ impl AcpThreadView { } } - fn render_tool_call_confirmation( + fn render_permission_buttons( &self, - tool_call_id: ToolCallId, - confirmation: &ToolCallConfirmation, - content: Option<&ToolCallContent>, - window: &Window, - cx: &Context, - ) -> AnyElement { - let confirmation_container = v_flex().mt_1().py_1p5(); - - match confirmation { - ToolCallConfirmation::Edit { description } => confirmation_container - .child( - div() - .px_2() - .children(description.clone().map(|description| { - self.render_markdown( - description, - default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[AlwaysAllowOption { - id: "always_allow", - label: "Always Allow Edits".into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, - }], - tool_call_id, - cx, - )) - .into_any(), - ToolCallConfirmation::Execute { - command, - root_command, - description, - } => confirmation_container - .child(v_flex().px_2().pb_1p5().child(command.clone()).children( - description.clone().map(|description| { - self.render_markdown(description, default_markdown_style(false, window, cx)) - .on_url_click({ - let workspace = self.workspace.clone(); - move |text, window, cx| { - Self::open_link(text, &workspace, window, cx); - } - }) - }), - )) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[AlwaysAllowOption { - id: "always_allow", - label: format!("Always Allow {root_command}").into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, - }], - tool_call_id, - cx, - )) - .into_any(), - ToolCallConfirmation::Mcp { - server_name, - tool_name: _, - tool_display_name, - description, - } => confirmation_container - .child( - v_flex() - .px_2() - .pb_1p5() - .child(format!("{server_name} - {tool_display_name}")) - .children(description.clone().map(|description| { - self.render_markdown( - description, - default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[ - AlwaysAllowOption { - id: "always_allow_server", - label: format!("Always Allow {server_name}").into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, - }, - AlwaysAllowOption { - id: "always_allow_tool", - label: format!("Always Allow {tool_display_name}").into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowTool, - }, - ], - tool_call_id, - cx, - )) - .into_any(), - 
ToolCallConfirmation::Fetch { description, urls } => confirmation_container - .child( - v_flex() - .px_2() - .pb_1p5() - .gap_1() - .children(urls.iter().map(|url| { - h_flex().child( - Button::new(url.clone(), url) - .icon(IconName::ArrowUpRight) - .icon_color(Color::Muted) - .icon_size(IconSize::XSmall) - .on_click({ - let url = url.clone(); - move |_, _, cx| cx.open_url(&url) - }), - ) - })) - .children(description.clone().map(|description| { - self.render_markdown( - description, - default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[AlwaysAllowOption { - id: "always_allow", - label: "Always Allow".into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, - }], - tool_call_id, - cx, - )) - .into_any(), - ToolCallConfirmation::Other { description } => confirmation_container - .child(v_flex().px_2().pb_1p5().child(self.render_markdown( - description.clone(), - default_markdown_style(false, window, cx), - ))) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[AlwaysAllowOption { - id: "always_allow", - label: "Always Allow".into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, - }], - tool_call_id, - cx, - )) - .into_any(), - } - } - - fn render_confirmation_buttons( - &self, - always_allow_options: &[AlwaysAllowOption], - tool_call_id: ToolCallId, + options: &[acp::PermissionOption], + entry_ix: usize, + tool_call_id: acp::ToolCallId, + empty_content: bool, cx: &Context, ) -> Div { h_flex() - .pt_1p5() + .py_1p5() .px_1p5() .gap_1() .justify_end() - .border_t_1() - .border_color(self.tool_card_border_color(cx)) - .when(self.agent.supports_always_allow(), |this| { - this.children(always_allow_options.into_iter().map(|always_allow_option| { - let outcome = always_allow_option.outcome; - Button::new( - (always_allow_option.id, tool_call_id.0), - always_allow_option.label.clone(), - ) - .icon(IconName::CheckDouble) + .when(!empty_content, |this| { + this.border_t_1() + .border_color(self.tool_card_border_color(cx)) + }) + .children(options.iter().map(|option| { + let option_id = SharedString::from(option.id.0.clone()); + Button::new((option_id, entry_ix), option.name.clone()) + .map(|this| match option.kind { + acp::PermissionOptionKind::AllowOnce => { + this.icon(IconName::Check).icon_color(Color::Success) + } + acp::PermissionOptionKind::AllowAlways => { + this.icon(IconName::CheckDouble).icon_color(Color::Success) + } + acp::PermissionOptionKind::RejectOnce => { + this.icon(IconName::X).icon_color(Color::Error) + } + acp::PermissionOptionKind::RejectAlways => { + this.icon(IconName::X).icon_color(Color::Error) + } + }) .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon_color(Color::Success) .on_click(cx.listener({ - let id = tool_call_id; + let tool_call_id = tool_call_id.clone(); + let option_id = option.id.clone(); + let option_kind = option.kind; move |this, _, _, cx| { - this.authorize_tool_call(id, outcome, cx); + this.authorize_tool_call( + tool_call_id.clone(), + option_id.clone(), + option_kind, + cx, + ); } })) - })) - }) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - 
acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ) + })) } fn render_diff_editor(&self, multibuffer: &Entity) -> AnyElement { @@ -1414,7 +1395,29 @@ impl AcpThreadView { .into_any() } - fn render_error_state(&self, e: &LoadError, cx: &Context) -> AnyElement { + fn render_server_exited(&self, status: ExitStatus, _cx: &Context) -> AnyElement { + v_flex() + .items_center() + .justify_center() + .child(self.render_error_agent_logo()) + .child( + v_flex() + .mt_4() + .mb_2() + .gap_0p5() + .text_center() + .items_center() + .child(Headline::new("Server exited unexpectedly").size(HeadlineSize::Medium)) + .child( + Label::new(format!("Exit status: {}", status.code().unwrap_or(-127))) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .into_any_element() + } + + fn render_load_error(&self, e: &LoadError, cx: &Context) -> AnyElement { let mut container = v_flex() .items_center() .justify_center() @@ -2070,15 +2073,15 @@ impl AcpThreadView { .icon_color(Color::Accent) .style(ButtonStyle::Filled) .disabled(self.thread().is_none() || is_editor_empty) - .on_click(cx.listener(|this, _, window, cx| { - this.chat(&Chat, window, cx); - })) .when(!is_editor_empty, |button| { button.tooltip(move |window, cx| Tooltip::for_action("Send", &Chat, window, cx)) }) .when(is_editor_empty, |button| { button.tooltip(Tooltip::text("Type a message to submit")) }) + .on_click(cx.listener(|this, _, window, cx| { + this.chat(&Chat, window, cx); + })) .into_any_element() } else { IconButton::new("stop-generation", IconName::StopFilled) @@ -2245,12 +2248,11 @@ impl AcpThreadView { .languages .language_for_name("Markdown"); - let (thread_summary, markdown) = match &self.thread_state { - ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { - let thread = thread.read(cx); - (thread.title().to_string(), thread.to_markdown(cx)) - } - ThreadState::Loading { .. } | ThreadState::LoadError(..) 
=> return Task::ready(Ok(())), + let (thread_summary, markdown) = if let Some(thread) = self.thread() { + let thread = thread.read(cx); + (thread.title().to_string(), thread.to_markdown(cx)) + } else { + return Task::ready(Ok(())); }; window.spawn(cx, async move |cx| { @@ -2293,17 +2295,165 @@ impl AcpThreadView { self.list_state.scroll_to(ListOffset::default()); cx.notify(); } -} -impl Focusable for AcpThreadView { - fn focus_handle(&self, cx: &App) -> FocusHandle { - self.message_editor.focus_handle(cx) + pub fn scroll_to_bottom(&mut self, cx: &mut Context) { + if let Some(thread) = self.thread() { + let entry_count = thread.read(cx).entries().len(); + self.list_state.reset(entry_count); + cx.notify(); + } } -} -impl Render for AcpThreadView { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let open_as_markdown = IconButton::new("open-as-markdown", IconName::DocumentText) + fn notify_with_sound( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + self.play_notification_sound(window, cx); + self.show_notification(caption, icon, window, cx); + } + + fn play_notification_sound(&self, window: &Window, cx: &mut App) { + let settings = AgentSettings::get_global(cx); + if settings.play_sound_when_agent_done && !window.is_window_active() { + Audio::play_sound(Sound::AgentDone, cx); + } + } + + fn show_notification( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + if window.is_window_active() || !self.notifications.is_empty() { + return; + } + + let title = self.title(cx); + + match AgentSettings::get_global(cx).notify_when_agent_waiting { + NotifyWhenAgentWaiting::PrimaryScreen => { + if let Some(primary) = cx.primary_display() { + self.pop_up(icon, caption.into(), title, window, primary, cx); + } + } + NotifyWhenAgentWaiting::AllScreens => { + let caption = caption.into(); + for screen in cx.displays() { + self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx); + } + } + NotifyWhenAgentWaiting::Never => { + // Don't show anything + } + } + } + + fn pop_up( + &mut self, + icon: IconName, + caption: SharedString, + title: SharedString, + window: &mut Window, + screen: Rc, + cx: &mut Context, + ) { + let options = AgentNotification::window_options(screen, cx); + + let project_name = self.workspace.upgrade().and_then(|workspace| { + workspace + .read(cx) + .project() + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).root_name().to_string()) + }); + + if let Some(screen_window) = cx + .open_window(options, |_, cx| { + cx.new(|_| { + AgentNotification::new(title.clone(), caption.clone(), icon, project_name) + }) + }) + .log_err() + { + if let Some(pop_up) = screen_window.entity(cx).log_err() { + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push(cx.subscribe_in(&pop_up, window, { + |this, _, event, window, cx| match event { + AgentNotificationEvent::Accepted => { + let handle = window.window_handle(); + cx.activate(true); + + let workspace_handle = this.workspace.clone(); + + // If there are multiple Zed windows, activate the correct one. 
+ cx.defer(move |cx| { + handle + .update(cx, |_view, window, _cx| { + window.activate_window(); + + if let Some(workspace) = workspace_handle.upgrade() { + workspace.update(_cx, |workspace, cx| { + workspace.focus_panel::(window, cx); + }); + } + }) + .log_err(); + }); + + this.dismiss_notifications(cx); + } + AgentNotificationEvent::Dismissed => { + this.dismiss_notifications(cx); + } + } + })); + + self.notifications.push(screen_window); + + // If the user manually refocuses the original window, dismiss the popup. + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push({ + let pop_up_weak = pop_up.downgrade(); + + cx.observe_window_activation(window, move |_, window, cx| { + if window.is_window_active() { + if let Some(pop_up) = pop_up_weak.upgrade() { + pop_up.update(cx, |_, cx| { + cx.emit(AgentNotificationEvent::Dismissed); + }); + } + } + }) + }); + } + } + } + + fn dismiss_notifications(&mut self, cx: &mut Context) { + for window in self.notifications.drain(..) { + window + .update(cx, |_, window, _| { + window.remove_window(); + }) + .ok(); + + self.notification_subscriptions.remove(&window); + } + } + + fn render_thread_controls(&self, cx: &Context) -> impl IntoElement { + let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileText) .icon_size(IconSize::XSmall) .icon_color(Color::Ignored) .tooltip(Tooltip::text("Open Thread as Markdown")) @@ -2322,6 +2472,60 @@ impl Render for AcpThreadView { this.scroll_to_top(cx); })); + h_flex() + .mr_1() + .pb_2() + .px(RESPONSE_PADDING_X) + .opacity(0.4) + .hover(|style| style.opacity(1.)) + .flex_wrap() + .justify_end() + .child(open_as_markdown) + .child(scroll_to_top) + } + + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .id("acp-thread-scrollbar") + .occlude() + .on_mouse_move(cx.listener(|_, _, _, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { + cx.stop_propagation(); + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) + } +} + +impl Focusable for AcpThreadView { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.message_editor.focus_handle(cx) + } +} + +impl Render for AcpThreadView { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() .size_full() .key_context("AcpThread") @@ -2329,66 +2533,80 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::previous_history_message)) .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) + .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { - ThreadState::Unauthenticated { .. } => { - v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child( - h_flex().mt_1p5().justify_center().child( - Button::new("sign-in", format!("Sign in to {}", self.agent.name())) - .on_click(cx.listener(|this, _, window, cx| { - this.authenticate(window, cx) - })), - ), - ) - } + ThreadState::Unauthenticated { connection } => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_pending_auth_state()) + .child(h_flex().mt_1p5().justify_center().children( + connection.auth_methods().into_iter().map(|method| { + Button::new( + SharedString::from(method.id.0.clone()), + method.name.clone(), + ) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), + )), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() .flex_1() .items_center() .justify_center() - .child(self.render_error_state(e, cx)), - ThreadState::Ready { thread, .. } => v_flex().flex_1().map(|this| { - if self.list_state.item_count() > 0 { - this.child( - list(self.list_state.clone()) + .child(self.render_load_error(e, cx)), + ThreadState::ServerExited { status } => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_server_exited(*status, cx)), + ThreadState::Ready { thread, .. 
} => { + let thread_clone = thread.clone(); + + v_flex().flex_1().map(|this| { + if self.list_state.item_count() > 0 { + this.child( + list( + self.list_state.clone(), + cx.processor(|this, index: usize, window, cx| { + let Some((entry, len)) = this.thread().and_then(|thread| { + let entries = &thread.read(cx).entries(); + Some((entries.get(index)?, entries.len())) + }) else { + return Empty.into_any(); + }; + this.render_entry(index, len, entry, window, cx) + }), + ) .with_sizing_behavior(gpui::ListSizingBehavior::Auto) .flex_grow() .into_any(), - ) - .child( - h_flex() - .group("controls") - .mt_1() - .mr_1() - .py_2() - .px(RESPONSE_PADDING_X) - .opacity(0.4) - .hover(|style| style.opacity(1.)) - .flex_wrap() - .justify_end() - .child(open_as_markdown) - .child(scroll_to_top) - .into_any_element(), - ) - .children(match thread.read(cx).status() { - ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => None, - ThreadStatus::Generating => div() - .px_5() - .py_2() - .child(LoadingLabel::new("").size(LabelSize::Small)) - .into(), - }) - .children(self.render_activity_bar(&thread, window, cx)) - } else { - this.child(self.render_empty_state(cx)) - } - }), + ) + .child(self.render_vertical_scrollbar(cx)) + .children(match thread_clone.read(cx).status() { + ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => { + None + } + ThreadStatus::Generating => div() + .px_5() + .py_2() + .child(LoadingLabel::new("").size(LabelSize::Small)) + .into(), + }) + .children(self.render_activity_bar(&thread_clone, window, cx)) + } else { + this.child(self.render_empty_state(cx)) + } + }) + } }) .when_some(self.last_error.clone(), |el, error| { el.child( @@ -2574,3 +2792,367 @@ fn plan_label_markdown_style( ..default_md_style } } + +#[cfg(test)] +mod tests { + use agent_client_protocol::SessionId; + use editor::EditorSettings; + use fs::FakeFs; + use futures::future::try_join_all; + use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; + use rand::Rng; + use settings::SettingsStore; + + use super::*; + + #[gpui::test] + async fn test_drop(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, _cx) = setup_thread_view(StubAgentServer::default(), cx).await; + let weak_view = thread_view.downgrade(); + drop(thread_view); + assert!(!weak_view.is_upgradable()); + } + + #[gpui::test] + async fn test_notification_for_stop_event(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_error(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| 
window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { + init_test(cx); + + let tool_call_id = acp::ToolCallId("1".into()); + let tool_call = acp::ToolCall { + id: tool_call_id.clone(), + title: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Pending, + content: vec!["hi".into()], + locations: vec![], + raw_input: None, + }; + let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) + .with_permission_requests(HashMap::from_iter([( + tool_call_id, + vec![acp::PermissionOption { + id: acp::PermissionOptionId("1".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }], + )])); + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + async fn setup_thread_view( + agent: impl AgentServer + 'static, + cx: &mut TestAppContext, + ) -> (Entity, &mut VisualTestContext) { + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(agent), + workspace.downgrade(), + project, + Rc::new(RefCell::new(MessageHistory::default())), + 1, + None, + window, + cx, + ) + }) + }); + cx.run_until_parked(); + (thread_view, cx) + } + + struct StubAgentServer { + connection: C, + } + + impl StubAgentServer { + fn new(connection: C) -> Self { + Self { connection } + } + } + + impl StubAgentServer { + fn default() -> Self { + Self::new(StubAgentConnection::default()) + } + } + + impl AgentServer for StubAgentServer + where + C: 'static + AgentConnection + Send + Clone, + { + fn logo(&self) -> ui::IconName { + unimplemented!() + } + + fn name(&self) -> &'static str { + unimplemented!() + } + + fn empty_state_headline(&self) -> &'static str { + unimplemented!() + } + + fn empty_state_message(&self) -> &'static str { + unimplemented!() + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, + ) -> Task>> { + Task::ready(Ok(Rc::new(self.connection.clone()))) + } + } + + #[derive(Clone, Default)] + struct StubAgentConnection { + sessions: Arc>>>, + permission_requests: HashMap>, + updates: Vec, + } + + impl StubAgentConnection { + fn new(updates: Vec) -> Self { + Self { + updates, + permission_requests: HashMap::default(), + sessions: Arc::default(), + } + } + + fn with_permission_requests( + mut self, + permission_requests: HashMap>, + ) -> Self { + self.permission_requests = permission_requests; + self + } + } + + impl AgentConnection for StubAgentConnection { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + let session_id = SessionId( + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(7) + .map(char::from) + .collect::() + .into(), + ); + let thread = cx + .new(|cx| 
AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) + .unwrap(); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + let mut tasks = vec![]; + for update in &self.updates { + let thread = thread.clone(); + let update = update.clone(); + let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update + && let Some(options) = self.permission_requests.get(&tool_call.id) + { + Some((tool_call.clone(), options.clone())) + } else { + None + }; + let task = cx.spawn(async move |cx| { + if let Some((tool_call, options)) = permission_request { + let permission = thread.update(cx, |thread, cx| { + thread.request_tool_call_permission( + tool_call.clone(), + options.clone(), + cx, + ) + })?; + permission.await?; + } + thread.update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + })?; + anyhow::Ok(()) + }); + tasks.push(task); + } + cx.spawn(async move |_| { + try_join_all(tasks).await?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + + #[derive(Clone)] + struct SaboteurAgentConnection; + + impl AgentConnection for SaboteurAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + Task::ready(Ok(cx + .new(|cx| { + AcpThread::new( + "SaboteurAgentConnection", + self, + project, + SessionId("test".into()), + cx, + ) + }) + .unwrap())) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + Task::ready(Err(anyhow::anyhow!("Error prompting"))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AgentSettings::register(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + release_channel::init(SemanticVersion::default(), cx); + EditorSettings::register(cx); + }); + } +} diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index e27c318221..c4d4e72252 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -14,6 +14,7 @@ use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; use anyhow::Context as _; use assistant_tool::ToolUseStatus; use audio::{Audio, Sound}; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::scroll::Autoscroll; @@ -52,7 +53,6 @@ use util::ResultExt as _; use util::markdown::MarkdownCodeBlock; use workspace::{CollaboratorId, Workspace}; use zed_actions::assistant::OpenRulesLibrary; -use zed_llm_client::CompletionIntent; const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container"; const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1; @@ -69,8 +69,6 @@ pub struct 
ActiveThread { messages: Vec, list_state: ListState, scrollbar_state: ScrollbarState, - show_scrollbar: bool, - hide_scrollbar_task: Option>, rendered_messages_by_id: HashMap, rendered_tool_uses: HashMap, editing_message: Option<(MessageId, EditingMessageState)>, @@ -780,13 +778,7 @@ impl ActiveThread { cx.observe_global::(|_, cx| cx.notify()), ]; - let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), { - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - this.update(cx, |this, cx| this.render_message(ix, window, cx)) - .unwrap() - } - }); + let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.)); let workspace_subscription = if let Some(workspace) = workspace.upgrade() { Some(cx.observe_release(&workspace, |this, _, cx| { @@ -811,9 +803,7 @@ impl ActiveThread { expanded_thinking_segments: HashMap::default(), expanded_code_blocks: HashMap::default(), list_state: list_state.clone(), - scrollbar_state: ScrollbarState::new(list_state), - show_scrollbar: false, - hide_scrollbar_task: None, + scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), editing_message: None, last_error: None, copied_code_block_ids: HashSet::default(), @@ -1846,7 +1836,12 @@ impl ActiveThread { ))) } - fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context) -> AnyElement { + fn render_message( + &mut self, + ix: usize, + window: &mut Window, + cx: &mut Context, + ) -> AnyElement { let message_id = self.messages[ix]; let workspace = self.workspace.clone(); let thread = self.thread.read(cx); @@ -3503,60 +3498,37 @@ impl ActiveThread { } } - fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { - if !self.show_scrollbar && !self.scrollbar_state.is_dragging() { - return None; - } - - Some( - div() - .occlude() - .id("active-thread-scrollbar") - .on_mouse_move(cx.listener(|_, _, _, cx| { - cx.notify(); - cx.stop_propagation() - })) - .on_hover(|_, _, cx| { + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .occlude() + .id("active-thread-scrollbar") + .on_mouse_move(cx.listener(|_, _, _, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { cx.stop_propagation(); - }) - .on_any_mouse_down(|_, _, cx| { - cx.stop_propagation(); - }) - .on_mouse_up( - MouseButton::Left, - cx.listener(|_, _, _, cx| { - cx.stop_propagation(); - }), - ) - .on_scroll_wheel(cx.listener(|_, _, _, cx| { - cx.notify(); - })) - .h_full() - .absolute() - .right_1() - .top_1() - .bottom_0() - .w(px(12.)) - .cursor_default() - .children(Scrollbar::vertical(self.scrollbar_state.clone())), - ) - } - - fn hide_scrollbar_later(&mut self, cx: &mut Context) { - const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); - self.hide_scrollbar_task = Some(cx.spawn(async move |thread, cx| { - cx.background_executor() - .timer(SCROLLBAR_SHOW_INTERVAL) - .await; - thread - .update(cx, |thread, cx| { - if !thread.scrollbar_state.is_dragging() { - thread.show_scrollbar = false; - cx.notify(); - } - }) - .log_err(); - })) + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) } pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool { @@ -3597,26 +3569,8 @@ impl Render for ActiveThread { .size_full() .relative() .bg(cx.theme().colors().panel_background) - .on_mouse_move(cx.listener(|this, _, _, cx| { - this.show_scrollbar = true; - this.hide_scrollbar_later(cx); - cx.notify(); - })) - .on_scroll_wheel(cx.listener(|this, _, _, cx| { - this.show_scrollbar = true; - this.hide_scrollbar_later(cx); - cx.notify(); - })) - .on_mouse_up( - MouseButton::Left, - cx.listener(|this, _, _, cx| { - this.hide_scrollbar_later(cx); - }), - ) - .child(list(self.list_state.clone()).flex_grow()) - .when_some(self.render_vertical_scrollbar(cx), |this, scrollbar| { - this.child(scrollbar) - }) + .child(list(self.list_state.clone(), cx.processor(Self::render_message)).flex_grow()) + .child(self.render_vertical_scrollbar(cx)) } } diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 7a160a5649..02c15b7e41 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -7,6 +7,7 @@ use std::{sync::Arc, time::Duration}; use agent_settings::AgentSettings; use assistant_tool::{ToolSource, ToolWorkingSet}; +use cloud_llm_client::Plan; use collections::HashMap; use context_server::ContextServerId; use extension::ExtensionManifest; @@ -25,7 +26,6 @@ use project::{ context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore}, project_settings::{ContextServerSettings, ProjectSettings}, }; -use proto::Plan; use settings::{Settings, update_settings_file}; use ui::{ Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, @@ -180,11 +180,18 @@ impl AgentConfiguration { let current_plan = if is_zed_provider { self.workspace .upgrade() - .and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan()) + .and_then(|workspace| workspace.read(cx).user_store().read(cx).plan()) } else { None }; + let is_signed_in = self + .workspace + .read_with(cx, |workspace, _| { + 
workspace.client().status().borrow().is_connected() + }) + .unwrap_or(false); + v_flex() .w_full() .when(is_expanded, |this| this.mb_2()) @@ -233,8 +240,8 @@ impl AgentConfiguration { .size(LabelSize::Large), ) .map(|this| { - if is_zed_provider { - this.gap_2().child( + if is_zed_provider && is_signed_in { + this.child( self.render_zed_plan_info(current_plan, cx), ) } else { @@ -321,46 +328,62 @@ impl AgentConfiguration { .justify_between() .child( v_flex() + .w_full() .gap_0p5() - .child(Headline::new("LLM Providers")) + .child( + h_flex() + .w_full() + .gap_2() + .justify_between() + .child(Headline::new("LLM Providers")) + .child( + PopoverMenu::new("add-provider-popover") + .trigger( + Button::new("add-provider", "Add Provider") + .icon_position(IconPosition::Start) + .icon(IconName::Plus) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .label_size(LabelSize::Small), + ) + .anchor(gpui::Corner::TopRight) + .menu({ + let workspace = self.workspace.clone(); + move |window, cx| { + Some(ContextMenu::build( + window, + cx, + |menu, _window, _cx| { + menu.header("Compatible APIs").entry( + "OpenAI", + None, + { + let workspace = + workspace.clone(); + move |window, cx| { + workspace + .update(cx, |workspace, cx| { + AddLlmProviderModal::toggle( + LlmCompatibleProvider::OpenAi, + workspace, + window, + cx, + ); + }) + .log_err(); + } + }, + ) + }, + )) + } + }), + ), + ) .child( Label::new("Add at least one provider to use AI-powered features.") .color(Color::Muted), ), - ) - .child( - PopoverMenu::new("add-provider-popover") - .trigger( - Button::new("add-provider", "Add Provider") - .icon_position(IconPosition::Start) - .icon(IconName::Plus) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .label_size(LabelSize::Small), - ) - .anchor(gpui::Corner::TopRight) - .menu({ - let workspace = self.workspace.clone(); - move |window, cx| { - Some(ContextMenu::build(window, cx, |menu, _window, _cx| { - menu.header("Compatible APIs").entry("OpenAI", None, { - let workspace = workspace.clone(); - move |window, cx| { - workspace - .update(cx, |workspace, cx| { - AddLlmProviderModal::toggle( - LlmCompatibleProvider::OpenAi, - workspace, - window, - cx, - ); - }) - .log_err(); - } - }) - })) - } - }), ), ) .child( @@ -381,9 +404,11 @@ impl AgentConfiguration { let fs = self.fs.clone(); SwitchField::new( - "single-file-review", - "Enable single-file agent reviews", - "Agent edits are also displayed in single-file editors for review.", + "always-allow-tool-actions-switch", + "Allow running commands without asking for confirmation", + Some( + "The agent can perform potentially destructive actions without asking for your confirmation.".into(), + ), always_allow_tool_actions, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -401,7 +426,7 @@ impl AgentConfiguration { SwitchField::new( "single-file-review", "Enable single-file agent reviews", - "Agent edits are also displayed in single-file editors for review.", + Some("Agent edits are also displayed in single-file editors for review.".into()), single_file_review, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -419,7 +444,9 @@ impl AgentConfiguration { SwitchField::new( "sound-notification", "Play sound when finished generating", - "Hear a notification sound when the agent is done generating changes or needs your input.", + Some( + "Hear a notification sound when the agent is done generating changes or needs your input.".into(), + ), play_sound_when_agent_done, move |state, _window, 
cx| { let allow = state == &ToggleState::Selected; @@ -437,7 +464,9 @@ impl AgentConfiguration { SwitchField::new( "modifier-send", "Use modifier to submit a message", - "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.", + Some( + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), + ), use_modifier_to_send, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -479,7 +508,7 @@ impl AgentConfiguration { .blend(cx.theme().colors().text_accent.opacity(0.2)); let (plan_name, label_color, bg_color) = match plan { - Plan::Free => ("Free", Color::Default, free_chip_bg), + Plan::ZedFree => ("Free", Color::Default, free_chip_bg), Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), }; @@ -510,7 +539,7 @@ impl AgentConfiguration { v_flex() .gap_0p5() .child(Headline::new("Model Context Protocol (MCP) Servers")) - .child(Label::new("Connect to context servers via the Model Context Protocol either via Zed extensions or directly.").color(Color::Muted)), + .child(Label::new("Connect to context servers through the Model Context Protocol, either using Zed extensions or directly.").color(Color::Muted)), ) .children( context_server_ids.into_iter().map(|context_server_id| { diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs index 94b32d156b..401a633488 100644 --- a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs +++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs @@ -272,42 +272,34 @@ impl AddLlmProviderModal { cx.emit(DismissEvent); } - fn render_section(&self) -> Section { - Section::new() - .child(self.input.provider_name.clone()) - .child(self.input.api_url.clone()) - .child(self.input.api_key.clone()) - } - - fn render_model_section(&self, cx: &mut Context) -> Section { - Section::new().child( - v_flex() - .gap_2() - .child( - h_flex() - .justify_between() - .child(Label::new("Models").size(LabelSize::Small)) - .child( - Button::new("add-model", "Add Model") - .icon(IconName::Plus) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .label_size(LabelSize::Small) - .on_click(cx.listener(|this, _, window, cx| { - this.input.add_model(window, cx); - cx.notify(); - })), - ), - ) - .children( - self.input - .models - .iter() - .enumerate() - .map(|(ix, _)| self.render_model(ix, cx)), - ), - ) + fn render_model_section(&self, cx: &mut Context) -> impl IntoElement { + v_flex() + .mt_1() + .gap_2() + .child( + h_flex() + .justify_between() + .child(Label::new("Models").size(LabelSize::Small)) + .child( + Button::new("add-model", "Add Model") + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + this.input.add_model(window, cx); + cx.notify(); + })), + ), + ) + .children( + self.input + .models + .iter() + .enumerate() + .map(|(ix, _)| self.render_model(ix, cx)), + ) } fn render_model(&self, ix: usize, cx: &mut Context) -> impl IntoElement + use<> { @@ -393,10 +385,14 @@ impl Render for AddLlmProviderModal { .child( v_flex() .id("modal_content") + .size_full() .max_h_128() .overflow_y_scroll() - .gap_2() - .child(self.render_section()) + .px(DynamicSpacing::Base12.rems(cx)) + .gap(DynamicSpacing::Base04.rems(cx)) + 
.child(self.input.provider_name.clone()) + .child(self.input.api_url.clone()) + .child(self.input.api_key.clone()) .child(self.render_model_section(cx)), ) .footer( diff --git a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs index 45536ff13b..5d44bb2d92 100644 --- a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs @@ -483,7 +483,7 @@ impl ManageProfilesModal { let icon = match mode.profile_id.as_str() { "write" => IconName::Pencil, - "ask" => IconName::MessageBubbles, + "ask" => IconName::Chat, _ => IconName::UserRoundPen, }; diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index e69664ce88..e1ceaf761d 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1506,8 +1506,7 @@ impl AgentDiff { .read(cx) .entries() .last() - .and_then(|entry| entry.diff()) - .is_some() + .map_or(false, |entry| entry.diffs().next().is_some()) { self.update_reviewing_editors(workspace, window, cx); } @@ -1517,12 +1516,15 @@ impl AgentDiff { .read(cx) .entries() .get(*ix) - .and_then(|entry| entry.diff()) - .is_some() + .map_or(false, |entry| entry.diffs().next().is_some()) { self.update_reviewing_editors(workspace, window, cx); } } + AcpThreadEvent::Stopped + | AcpThreadEvent::ToolAuthorizationRequired + | AcpThreadEvent::Error + | AcpThreadEvent::ServerExited(_) => {} } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 6ae2f12b5e..8f5fff5da3 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -43,7 +43,8 @@ use anyhow::{Result, anyhow}; use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; -use client::{DisableAiSettings, UserStore, zed_urls}; +use client::{UserStore, zed_urls}; +use cloud_llm_client::{CompletionIntent, Plan, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; use feature_flags::{self, FeatureFlagAppExt}; use fs::Fs; @@ -57,9 +58,8 @@ use language::LanguageRegistry; use language_model::{ ConfigurationError, ConfiguredModel, LanguageModelProviderTosView, LanguageModelRegistry, }; -use project::{Project, ProjectPath, Worktree}; +use project::{DisableAiSettings, Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; -use proto::Plan; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, update_settings_file}; @@ -77,10 +77,9 @@ use workspace::{ }; use zed_actions::{ DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize, - agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector}, + agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector}, assistant::{OpenRulesLibrary, ToggleFocus}, }; -use zed_llm_client::{CompletionIntent, UsageLimit}; const AGENT_PANEL_KEY: &str = "agent_panel"; @@ -105,7 +104,7 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.open_history(window, cx)); } }) - .register_action(|workspace, _: &OpenConfiguration, window, cx| { + .register_action(|workspace, _: &OpenSettings, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| 
panel.open_configuration(window, cx)); @@ -440,7 +439,7 @@ pub struct AgentPanel { local_timezone: UtcOffset, active_view: ActiveView, acp_message_history: - Rc>>, + Rc>>>, previous_view: Option, history_store: Entity, history: Entity, @@ -564,22 +563,8 @@ impl AgentPanel { let inline_assist_context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade()))); - let message_editor = cx.new(|cx| { - MessageEditor::new( - fs.clone(), - workspace.clone(), - user_store.clone(), - message_editor_context_store.clone(), - prompt_store.clone(), - thread_store.downgrade(), - context_store.downgrade(), - thread.clone(), - window, - cx, - ) - }); - let thread_id = thread.read(cx).id().clone(); + let history_store = cx.new(|cx| { HistoryStore::new( thread_store.clone(), @@ -589,6 +574,21 @@ impl AgentPanel { ) }); + let message_editor = cx.new(|cx| { + MessageEditor::new( + fs.clone(), + workspace.clone(), + message_editor_context_store.clone(), + prompt_store.clone(), + thread_store.downgrade(), + context_store.downgrade(), + Some(history_store.downgrade()), + thread.clone(), + window, + cx, + ) + }); + cx.observe(&history_store, |_, _, cx| cx.notify()).detach(); let active_thread = cx.new(|cx| { @@ -846,11 +846,11 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), context_store.clone(), self.prompt_store.clone(), self.thread_store.downgrade(), self.context_store.downgrade(), + Some(self.history_store.downgrade()), thread.clone(), window, cx, @@ -970,13 +970,7 @@ impl AgentPanel { ) }); - this.set_active_view( - ActiveView::ExternalAgentThread { - thread_view: thread_view.clone(), - }, - window, - cx, - ); + this.set_active_view(ActiveView::ExternalAgentThread { thread_view }, window, cx); }) }) .detach_and_log_err(cx); @@ -1119,11 +1113,11 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), context_store, self.prompt_store.clone(), self.thread_store.downgrade(), self.context_store.downgrade(), + Some(self.history_store.downgrade()), thread.clone(), window, cx, @@ -1880,10 +1874,10 @@ impl AgentPanel { }), ); - let zoom_in_label = if self.is_zoomed(window, cx) { - "Zoom Out" + let full_screen_label = if self.is_zoomed(window, cx) { + "Disable Full Screen" } else { - "Zoom In" + "Enable Full Screen" }; let active_thread = match &self.active_view { @@ -1911,27 +1905,6 @@ impl AgentPanel { .when(cx.has_flag::(), |this| { this.header("Zed Agent") }) - .item( - ContextMenuEntry::new("New Thread") - .icon(IconName::NewThread) - .icon_color(Color::Muted) - .action(NewThread::default().boxed_clone()) - .handler(move |window, cx| { - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Text Thread") - .icon(IconName::NewTextThread) - .icon_color(Color::Muted) - .action(NewTextThread.boxed_clone()) - .handler(move |window, cx| { - window.dispatch_action(NewTextThread.boxed_clone(), cx); - }), - ) .when_some(active_thread, |this, active_thread| { let thread = active_thread.read(cx); @@ -1939,7 +1912,7 @@ impl AgentPanel { let thread_id = thread.id().clone(); this.item( ContextMenuEntry::new("New From Summary") - .icon(IconName::NewFromSummary) + .icon(IconName::ThreadFromSummary) .icon_color(Color::Muted) .handler(move |window, cx| { window.dispatch_action( @@ -1954,6 +1927,27 @@ impl AgentPanel { this } }) + .item( + ContextMenuEntry::new("New Thread") + .icon(IconName::Thread) + 
.icon_color(Color::Muted) + .action(NewThread::default().boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Text Thread") + .icon(IconName::TextThread) + .icon_color(Color::Muted) + .action(NewTextThread.boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action(NewTextThread.boxed_clone(), cx); + }), + ) .when(cx.has_flag::(), |this| { this.separator() .header("External Agents") @@ -1987,6 +1981,22 @@ impl AgentPanel { ); }), ) + .item( + ContextMenuEntry::new("New Native Agent Thread") + .icon(IconName::ZedAssistant) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::NativeAgent, + ), + } + .boxed_clone(), + cx, + ); + }), + ) }); menu })) @@ -2012,65 +2022,70 @@ impl AgentPanel { ) .anchor(Corner::TopRight) .with_handle(self.agent_panel_menu_handle.clone()) - .menu(move |window, cx| { - Some(ContextMenu::build(window, cx, |mut menu, _window, _| { - if let Some(usage) = usage { + .menu({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Some(ContextMenu::build(window, cx, |mut menu, _window, _| { + menu = menu.context(focus_handle.clone()); + if let Some(usage) = usage { + menu = menu + .header_with_link("Prompt Usage", "Manage", account_url.clone()) + .custom_entry( + move |_window, cx| { + let used_percentage = match usage.limit { + UsageLimit::Limited(limit) => { + Some((usage.amount as f32 / limit as f32) * 100.) + } + UsageLimit::Unlimited => None, + }; + + h_flex() + .flex_1() + .gap_1p5() + .children(used_percentage.map(|percent| { + ProgressBar::new("usage", percent, 100., cx) + })) + .child( + Label::new(match usage.limit { + UsageLimit::Limited(limit) => { + format!("{} / {limit}", usage.amount) + } + UsageLimit::Unlimited => { + format!("{} / ∞", usage.amount) + } + }) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any_element() + }, + move |_, cx| cx.open_url(&zed_urls::account_url(cx)), + ) + .separator() + } + menu = menu - .header_with_link("Prompt Usage", "Manage", account_url.clone()) - .custom_entry( - move |_window, cx| { - let used_percentage = match usage.limit { - UsageLimit::Limited(limit) => { - Some((usage.amount as f32 / limit as f32) * 100.) 
- } - UsageLimit::Unlimited => None, - }; - - h_flex() - .flex_1() - .gap_1p5() - .children(used_percentage.map(|percent| { - ProgressBar::new("usage", percent, 100., cx) - })) - .child( - Label::new(match usage.limit { - UsageLimit::Limited(limit) => { - format!("{} / {limit}", usage.amount) - } - UsageLimit::Unlimited => { - format!("{} / ∞", usage.amount) - } - }) - .size(LabelSize::Small) - .color(Color::Muted), - ) - .into_any_element() - }, - move |_, cx| cx.open_url(&zed_urls::account_url(cx)), + .header("MCP Servers") + .action( + "View Server Extensions", + Box::new(zed_actions::Extensions { + category_filter: Some( + zed_actions::ExtensionCategoryFilter::ContextServers, + ), + id: None, + }), ) + .action("Add Custom Server…", Box::new(AddContextServer)) + .separator(); + + menu = menu + .action("Rules…", Box::new(OpenRulesLibrary::default())) + .action("Settings", Box::new(OpenSettings)) .separator() - } - - menu = menu - .header("MCP Servers") - .action( - "View Server Extensions", - Box::new(zed_actions::Extensions { - category_filter: Some( - zed_actions::ExtensionCategoryFilter::ContextServers, - ), - id: None, - }), - ) - .action("Add Custom Server…", Box::new(AddContextServer)) - .separator(); - - menu = menu - .action("Rules…", Box::new(OpenRulesLibrary::default())) - .action("Settings", Box::new(OpenConfiguration)) - .action(zoom_in_label, Box::new(ToggleZoom)); - menu - })) + .action(full_screen_label, Box::new(ToggleZoom)); + menu + })) + } }); h_flex() @@ -2271,10 +2286,10 @@ impl AgentPanel { | ActiveView::Configuration => return false, } - let plan = self.user_store.read(cx).current_plan(); + let plan = self.user_store.read(cx).plan(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - matches!(plan, Some(Plan::Free)) && has_previous_trial + matches!(plan, Some(Plan::ZedFree)) && has_previous_trial } fn should_render_onboarding(&self, cx: &mut Context) -> bool { @@ -2283,20 +2298,21 @@ impl AgentPanel { } match &self.active_view { - ActiveView::Thread { thread, .. } => thread - .read(cx) - .thread() - .read(cx) - .configured_model() - .map_or(true, |model| { - model.provider.id() == language_model::ZED_CLOUD_PROVIDER_ID - }), - ActiveView::TextThread { .. } => LanguageModelRegistry::global(cx) - .read(cx) - .default_model() - .map_or(true, |model| { - model.provider.id() == language_model::ZED_CLOUD_PROVIDER_ID - }), + ActiveView::Thread { .. } | ActiveView::TextThread { .. } => { + let history_is_empty = self + .history_store + .update(cx, |store, cx| store.recent_entries(1, cx).is_empty()); + + let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .any(|provider| { + provider.is_authenticated(cx) + && provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + }); + + history_is_empty || !has_configured_non_zed_providers + } ActiveView::ExternalAgentThread { .. 
} | ActiveView::History | ActiveView::Configuration => false, @@ -2317,9 +2333,8 @@ impl AgentPanel { Some( div() - .size_full() .when(thread_view, |this| { - this.bg(cx.theme().colors().panel_background) + this.size_full().bg(cx.theme().colors().panel_background) }) .when(text_thread_view, |this| { this.bg(cx.theme().colors().editor_background) @@ -2460,14 +2475,14 @@ impl AgentPanel { .icon_color(Color::Muted) .full_width() .key_binding(KeyBinding::for_action_in( - &OpenConfiguration, + &OpenSettings, &focus_handle, window, cx, )) .on_click(|_event, window, cx| { window.dispatch_action( - OpenConfiguration.boxed_clone(), + OpenSettings.boxed_clone(), cx, ) }), @@ -2554,7 +2569,7 @@ impl AgentPanel { NewThreadButton::new( "new-thread-btn", "New Thread", - IconName::NewThread, + IconName::Thread, ) .keybinding(KeyBinding::for_action_in( &NewThread::default(), @@ -2575,7 +2590,7 @@ impl AgentPanel { NewThreadButton::new( "new-text-thread-btn", "New Text Thread", - IconName::NewTextThread, + IconName::TextThread, ) .keybinding(KeyBinding::for_action_in( &NewTextThread, @@ -2644,6 +2659,31 @@ impl AgentPanel { ) }, ), + ) + .child( + NewThreadButton::new( + "new-native-agent-thread-btn", + "New Native Agent Thread", + IconName::ZedAssistant, + ) + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::NativeAgent, + ), + }), + cx, + ) + }, + ), ), ) }), @@ -2672,16 +2712,11 @@ impl AgentPanel { .style(ButtonStyle::Tinted(ui::TintColor::Warning)) .label_size(LabelSize::Small) .key_binding( - KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - ) - .map(|kb| kb.size(rems_from_px(12.))), + KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))), ) .on_click(|_event, window, cx| { - window.dispatch_action(OpenConfiguration.boxed_clone(), cx) + window.dispatch_action(OpenSettings.boxed_clone(), cx) }), ), ConfigurationError::ProviderPendingTermsAcceptance(provider) => { @@ -2875,7 +2910,7 @@ impl AgentPanel { ) -> AnyElement { let error_message = match plan { Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", - Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.", + Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.", }; let icon = Icon::new(IconName::XCircle) @@ -3185,7 +3220,7 @@ impl Render for AgentPanel { .on_action(cx.listener(|this, _: &OpenHistory, window, cx| { this.open_history(window, cx); })) - .on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| { + .on_action(cx.listener(|this, _: &OpenSettings, window, cx| { this.open_configuration(window, cx); })) .on_action(cx.listener(Self::open_active_thread_as_markdown)) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index cac0f1adac..fceb8f4c45 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -31,7 +31,7 @@ use std::sync::Arc; use agent::{Thread, ThreadId}; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection}; use assistant_slash_command::SlashCommandRegistry; -use client::{Client, DisableAiSettings}; +use client::Client; use command_palette_hooks::CommandPaletteFilter; use feature_flags::FeatureFlagAppExt as _; use fs::Fs; @@ -40,6 +40,7 @@ use language::LanguageRegistry; use language_model::{ 
ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, }; +use project::DisableAiSettings; use prompt_store::PromptBuilder; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -150,6 +151,7 @@ enum ExternalAgent { #[default] Gemini, ClaudeCode, + NativeAgent, } impl ExternalAgent { @@ -157,6 +159,7 @@ impl ExternalAgent { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), + ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer), } } } @@ -262,6 +265,8 @@ fn update_command_palette_filter(cx: &mut App) { if disable_ai { filter.hide_namespace("agent"); filter.hide_namespace("assistant"); + filter.hide_namespace("copilot"); + filter.hide_namespace("supermaven"); filter.hide_namespace("zed_predict_onboarding"); filter.hide_namespace("edit_prediction"); @@ -282,6 +287,7 @@ fn update_command_palette_filter(cx: &mut App) { } else { filter.show_namespace("agent"); filter.show_namespace("assistant"); + filter.show_namespace("copilot"); filter.show_namespace("zed_predict_onboarding"); filter.show_namespace("edit_prediction"); diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 64498e9281..615142b73d 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -6,6 +6,7 @@ use agent::{ use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; +use cloud_llm_client::CompletionIntent; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; use futures::{ @@ -35,7 +36,6 @@ use std::{ }; use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; -use zed_llm_client::CompletionIntent; pub struct BufferCodegen { alternatives: Vec>, diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 5cc56b014e..32f9a096d9 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -148,7 +148,7 @@ impl ContextPickerMode { Self::File => IconName::File, Self::Symbol => IconName::Code, Self::Fetch => IconName::Globe, - Self::Thread => IconName::MessageBubbles, + Self::Thread => IconName::Thread, Self::Rules => RULES_ICON, } } diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index b377e40b19..5ca0913be7 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -423,7 +423,7 @@ impl ContextPickerCompletionProvider { let icon_for_completion = if recent { IconName::HistoryRerun } else { - IconName::MessageBubbles + IconName::Thread }; let new_text = format!("{} ", MentionLink::for_thread(&thread_entry)); let new_text_len = new_text.len(); @@ -436,7 +436,7 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_for_completion.path().into()), confirm: Some(confirm_completion_callback( - IconName::MessageBubbles.path().into(), + IconName::Thread.path().into(), thread_entry.title().clone(), excerpt_id, source_range.start, diff --git a/crates/agent_ui/src/context_picker/thread_context_picker.rs b/crates/agent_ui/src/context_picker/thread_context_picker.rs index cb2e97a493..15cc731f8f 100644 --- 
a/crates/agent_ui/src/context_picker/thread_context_picker.rs +++ b/crates/agent_ui/src/context_picker/thread_context_picker.rs @@ -253,7 +253,7 @@ pub fn render_thread_context_entry( .gap_1p5() .max_w_72() .child( - Icon::new(IconName::MessageBubbles) + Icon::new(IconName::Thread) .size(IconSize::XSmall) .color(Color::Muted), ) diff --git a/crates/agent_ui/src/context_strip.rs b/crates/agent_ui/src/context_strip.rs index 080ffd2ea0..369964f165 100644 --- a/crates/agent_ui/src/context_strip.rs +++ b/crates/agent_ui/src/context_strip.rs @@ -504,7 +504,7 @@ impl Render for ContextStrip { ) .on_click({ Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| { - if event.down.click_count > 1 { + if event.click_count() > 1 { this.open_context(&context, window, cx); } else { this.focused_index = Some(i); diff --git a/crates/agent_ui/src/debug.rs b/crates/agent_ui/src/debug.rs index ff6538dc85..bd34659210 100644 --- a/crates/agent_ui/src/debug.rs +++ b/crates/agent_ui/src/debug.rs @@ -1,10 +1,10 @@ #![allow(unused, dead_code)] use client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{Plan, UsageLimit}; use gpui::Global; use std::ops::{Deref, DerefMut}; use ui::prelude::*; -use zed_llm_client::{Plan, UsageLimit}; /// Debug only: Used for testing various account states /// diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 44ec050ae2..4a4a747899 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -16,7 +16,7 @@ use agent::{ }; use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; -use client::{DisableAiSettings, telemetry::Telemetry}; +use client::telemetry::Telemetry; use collections::{HashMap, HashSet, VecDeque, hash_map}; use editor::SelectionEffects; use editor::{ @@ -39,7 +39,7 @@ use language_model::{ }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; -use project::{CodeAction, LspAction, Project, ProjectTransaction}; +use project::{CodeAction, DisableAiSettings, LspAction, Project, ProjectTransaction}; use prompt_store::{PromptBuilder, PromptStore}; use settings::{Settings, SettingsStore}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; @@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _}; use ui::prelude::*; use util::{RangeExt, ResultExt, maybe}; use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId}; -use zed_actions::agent::OpenConfiguration; +use zed_actions::agent::OpenSettings; pub fn init( fs: Arc, @@ -162,7 +162,7 @@ impl InlineAssistant { let window = windows[0]; let _ = window.update(cx, |_, window, cx| { editor.update(cx, |editor, cx| { - if editor.has_active_inline_completion() { + if editor.has_active_edit_prediction() { editor.cancel(&Default::default(), window, cx); } }); @@ -231,8 +231,8 @@ impl InlineAssistant { ); if DisableAiSettings::get_global(cx).disable_ai { - // Cancel any active completions - if editor.has_active_inline_completion() { + // Cancel any active edit predictions + if editor.has_active_edit_prediction() { editor.cancel(&Default::default(), window, cx); } } @@ -345,7 +345,7 @@ impl InlineAssistant { if let Some(answer) = answer { if answer == 0 { cx.update(|window, cx| { - window.dispatch_action(Box::new(OpenConfiguration), cx) + window.dispatch_action(Box::new(OpenSettings), cx) }) .ok(); } diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index ade7a5e13d..a5f90edb57 100644 --- 
a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -541,7 +541,7 @@ impl PromptEditor { match &self.mode { PromptEditorMode::Terminal { .. } => vec![ accept, - IconButton::new("confirm", IconName::Play) + IconButton::new("confirm", IconName::PlayOutlined) .icon_color(Color::Info) .shape(IconButtonShape::Square) .tooltip(|window, cx| { diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 655e87d7cd..7121624c87 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -576,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { .icon_position(IconPosition::Start) .on_click(|_, window, cx| { window.dispatch_action( - zed_actions::agent::OpenConfiguration.boxed_clone(), + zed_actions::agent::OpenSettings.boxed_clone(), cx, ); }), diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 62be5629f1..2185885347 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -9,6 +9,7 @@ use crate::ui::{ MaxModeTooltip, preview::{AgentPreview, UsageCallout}, }; +use agent::history_store::HistoryStore; use agent::{ context::{AgentContextKey, ContextLoadResult, load_context}, context_store::ContextStoreEvent, @@ -16,7 +17,7 @@ use agent::{ use agent_settings::{AgentSettings, CompletionMode}; use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; -use client::UserStore; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::display_map::CreaseId; @@ -29,8 +30,9 @@ use fs::Fs; use futures::future::Shared; use futures::{FutureExt as _, future}; use gpui::{ - Animation, AnimationExt, App, Entity, EventEmitter, Focusable, KeyContext, Subscription, Task, - TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, + Animation, AnimationExt, App, Entity, EventEmitter, Focusable, IntoElement, KeyContext, + Subscription, Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, + pulsating_between, }; use language::{Buffer, Language, Point}; use language_model::{ @@ -40,7 +42,6 @@ use language_model::{ use multi_buffer; use project::Project; use prompt_store::PromptStore; -use proto::Plan; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; @@ -51,7 +52,6 @@ use util::ResultExt as _; use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::Chat; use zed_actions::agent::ToggleModelSelector; -use zed_llm_client::CompletionIntent; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; @@ -77,9 +77,9 @@ pub struct MessageEditor { editor: Entity, workspace: WeakEntity, project: Entity, - user_store: Entity, context_store: Entity, prompt_store: Option>, + history_store: Option>, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, @@ -156,11 +156,11 @@ impl MessageEditor { pub fn new( fs: Arc, workspace: WeakEntity, - user_store: Entity, context_store: Entity, prompt_store: Option>, thread_store: WeakEntity, text_thread_store: WeakEntity, + history_store: Option>, thread: Entity, window: &mut Window, cx: &mut Context, @@ -227,12 +227,12 @@ impl MessageEditor { Self { editor: editor.clone(), project: thread.read(cx).project().clone(), - user_store, 
thread, incompatible_tools_state: incompatible_tools.clone(), workspace, context_store, prompt_store, + history_store, context_strip, context_picker_menu_handle, load_context_task: None, @@ -1282,24 +1282,12 @@ impl MessageEditor { return None; } - let user_store = self.user_store.read(cx); - - let ubb_enable = user_store - .usage_based_billing_enabled() - .map_or(false, |enabled| enabled); - - if ubb_enable { + let user_store = self.project.read(cx).user_store().read(cx); + if user_store.is_usage_based_billing_enabled() { return None; } - let plan = user_store - .current_plan() - .map(|plan| match plan { - Plan::Free => zed_llm_client::Plan::ZedFree, - Plan::ZedPro => zed_llm_client::Plan::ZedPro, - Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, - }) - .unwrap_or(zed_llm_client::Plan::ZedFree); + let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree); let usage = user_store.model_request_usage()?; @@ -1661,32 +1649,36 @@ impl Render for MessageEditor { let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5; - let in_pro_trial = matches!( - self.user_store.read(cx).current_plan(), - Some(proto::Plan::ZedProTrial) - ); + let has_configured_providers = LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .filter(|provider| { + provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID + }) + .count() + > 0; - let pro_user = matches!( - self.user_store.read(cx).current_plan(), - Some(proto::Plan::ZedPro) - ); + let is_signed_out = self + .workspace + .read_with(cx, |workspace, _| { + workspace.client().status().borrow().is_signed_out() + }) + .unwrap_or(true); - let configured_providers: Vec<(IconName, SharedString)> = - LanguageModelRegistry::read_global(cx) - .providers() - .iter() - .filter(|provider| { - provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID - }) - .map(|provider| (provider.icon(), provider.name().0.clone())) - .collect(); - let has_existing_providers = configured_providers.len() > 0; + let has_history = self + .history_store + .as_ref() + .and_then(|hs| hs.update(cx, |hs, cx| hs.entries(cx).len() > 0).ok()) + .unwrap_or(false) + || self + .thread + .read_with(cx, |thread, _| thread.messages().len() > 0); v_flex() .size_full() .bg(cx.theme().colors().panel_background) .when( - has_existing_providers && !in_pro_trial && !pro_user, + !has_history && is_signed_out && has_configured_providers, |this| this.child(cx.new(ApiKeysWithProviders::new)), ) .when(changed_buffers.len() > 0, |parent| { @@ -1760,7 +1752,6 @@ impl AgentPreview for MessageEditor { ) -> Option { if let Some(workspace) = workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); - let user_store = workspace.read(cx).app_state().user_store.clone(); let project = workspace.read(cx).project().clone(); let weak_project = project.downgrade(); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); @@ -1773,11 +1764,11 @@ impl AgentPreview for MessageEditor { MessageEditor::new( fs, workspace.downgrade(), - user_store, context_store, None, thread_store.downgrade(), text_thread_store.downgrade(), + None, thread, window, cx, diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs index 91867957cd..bcbc308c99 100644 --- a/crates/agent_ui/src/terminal_inline_assistant.rs +++ b/crates/agent_ui/src/terminal_inline_assistant.rs @@ -10,6 +10,7 @@ use agent::{ use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use 
client::telemetry::Telemetry; +use cloud_llm_client::CompletionIntent; use collections::{HashMap, VecDeque}; use editor::{MultiBuffer, actions::SelectAll}; use fs::Fs; @@ -27,7 +28,6 @@ use terminal_view::TerminalView; use ui::prelude::*; use util::ResultExt; use workspace::{Toast, Workspace, notifications::NotificationId}; -use zed_llm_client::CompletionIntent; pub fn init( fs: Arc, diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 3df0a48aa4..4836a95c8e 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -12,7 +12,7 @@ use assistant_slash_commands::{ use client::{proto, zed_urls}; use collections::{BTreeSet, HashMap, HashSet, hash_map}; use editor::{ - Anchor, Editor, EditorEvent, MenuInlineCompletionsPolicy, MultiBuffer, MultiBufferSnapshot, + Anchor, Editor, EditorEvent, MenuEditPredictionsPolicy, MultiBuffer, MultiBufferSnapshot, RowExt, ToOffset as _, ToPoint, actions::{MoveToEndOfLine, Newline, ShowCompletions}, display_map::{ @@ -254,7 +254,7 @@ impl TextThreadEditor { editor.set_show_wrap_guides(false, cx); editor.set_show_indent_guides(false, cx); editor.set_completion_provider(Some(Rc::new(completion_provider))); - editor.set_menu_inline_completions_policy(MenuInlineCompletionsPolicy::Never); + editor.set_menu_edit_predictions_policy(MenuEditPredictionsPolicy::Never); editor.set_collaboration_hub(Box::new(project.clone())); let show_edit_predictions = all_language_settings(None, cx) diff --git a/crates/agent_ui/src/thread_history.rs b/crates/agent_ui/src/thread_history.rs index a2ee816f73..b8d1db88d6 100644 --- a/crates/agent_ui/src/thread_history.rs +++ b/crates/agent_ui/src/thread_history.rs @@ -701,7 +701,7 @@ impl RenderOnce for HistoryEntryElement { .on_hover(self.on_hover) .end_slot::(if self.hovered || self.selected { Some( - IconButton::new("delete", IconName::TrashAlt) + IconButton::new("delete", IconName::Trash) .shape(IconButtonShape::Square) .icon_size(IconSize::XSmall) .icon_color(Color::Muted) diff --git a/crates/agent_ui/src/ui.rs b/crates/agent_ui/src/ui.rs index 15f2e28e58..b477a8c385 100644 --- a/crates/agent_ui/src/ui.rs +++ b/crates/agent_ui/src/ui.rs @@ -5,7 +5,6 @@ mod end_trial_upsell; mod new_thread_button; mod onboarding_modal; pub mod preview; -mod upsell; pub use agent_notification::*; pub use burn_mode_tooltip::*; diff --git a/crates/agent_ui/src/ui/end_trial_upsell.rs b/crates/agent_ui/src/ui/end_trial_upsell.rs index 9c2dd98d20..3a8a119800 100644 --- a/crates/agent_ui/src/ui/end_trial_upsell.rs +++ b/crates/agent_ui/src/ui/end_trial_upsell.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use ai_onboarding::{AgentPanelOnboardingCard, BulletItem}; +use ai_onboarding::{AgentPanelOnboardingCard, PlanDefinitions}; use client::zed_urls; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; -use ui::{Divider, List, prelude::*}; +use ui::{Divider, Tooltip, prelude::*}; #[derive(IntoElement, RegisterComponent)] pub struct EndTrialUpsell { @@ -18,6 +18,8 @@ impl EndTrialUpsell { impl RenderOnce for EndTrialUpsell { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let plan_definitions = PlanDefinitions; + let pro_section = v_flex() .gap_1() .child( @@ -31,16 +33,15 @@ impl RenderOnce for EndTrialUpsell { ) .child(Divider::horizontal()), ) - .child( - List::new() - .child(BulletItem::new("500 prompts per month with Claude models")) - .child(BulletItem::new("Unlimited edit predictions")), - ) + 
.child(plan_definitions.pro_plan(false)) .child( Button::new("cta-button", "Upgrade to Zed Pro") .full_width() .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))), + .on_click(move |_, _window, cx| { + telemetry::event!("Upgrade To Pro Clicked", state = "end-of-trial"); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), ); let free_section = v_flex() @@ -55,54 +56,58 @@ impl RenderOnce for EndTrialUpsell { .color(Color::Muted) .buffer_font(cx), ) + .child( + Label::new("(Current Plan)") + .size(LabelSize::Small) + .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6))) + .buffer_font(cx), + ) .child(Divider::horizontal()), ) - .child( - List::new() - .child(BulletItem::new( - "50 prompts per month with the Claude models", - )) - .child(BulletItem::new( - "2000 accepted edit predictions using our open-source Zeta model", - )), - ) - .child( - Button::new("dismiss-button", "Stay on Free") - .full_width() - .style(ButtonStyle::Outlined) - .on_click({ - let callback = self.dismiss_upsell.clone(); - move |_, window, cx| callback(window, cx) - }), - ); + .child(plan_definitions.free_plan()); AgentPanelOnboardingCard::new() - .child(Headline::new("Your Zed Pro trial has expired.")) + .child(Headline::new("Your Zed Pro Trial has expired")) .child( Label::new("You've been automatically reset to the Free plan.") - .size(LabelSize::Small) .color(Color::Muted) - .mb_1(), + .mb_2(), ) .child(pro_section) .child(free_section) + .child( + h_flex().absolute().top_4().right_4().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click({ + let callback = self.dismiss_upsell.clone(); + move |_, window, cx| { + telemetry::event!("Banner Dismissed", source = "AI Onboarding"); + callback(window, cx) + } + }), + ), + ) } } impl Component for EndTrialUpsell { fn scope() -> ComponentScope { - ComponentScope::Agent + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "End of Trial Upsell Banner" } fn sort_name() -> &'static str { - "AgentEndTrialUpsell" + "End of Trial Upsell Banner" } fn preview(_window: &mut Window, _cx: &mut App) -> Option { Some( v_flex() - .p_4() - .gap_4() .child(EndTrialUpsell { dismiss_upsell: Arc::new(|_, _| {}), }) diff --git a/crates/agent_ui/src/ui/preview/usage_callouts.rs b/crates/agent_ui/src/ui/preview/usage_callouts.rs index 45af41395b..64869a6ec7 100644 --- a/crates/agent_ui/src/ui/preview/usage_callouts.rs +++ b/crates/agent_ui/src/ui/preview/usage_callouts.rs @@ -1,8 +1,8 @@ use client::{ModelRequestUsage, RequestUsage, zed_urls}; +use cloud_llm_client::{Plan, UsageLimit}; use component::{empty_example, example_group_with_title, single_example}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use ui::{Callout, prelude::*}; -use zed_llm_client::{Plan, UsageLimit}; #[derive(IntoElement, RegisterComponent)] pub struct UsageCallout { diff --git a/crates/agent_ui/src/ui/upsell.rs b/crates/agent_ui/src/ui/upsell.rs deleted file mode 100644 index f311aade22..0000000000 --- a/crates/agent_ui/src/ui/upsell.rs +++ /dev/null @@ -1,163 +0,0 @@ -use component::{Component, ComponentScope, single_example}; -use gpui::{ - AnyElement, App, ClickEvent, IntoElement, ParentElement, RenderOnce, SharedString, Styled, - Window, -}; -use theme::ActiveTheme; -use ui::{ - Button, ButtonCommon, ButtonStyle, Checkbox, Clickable, Color, Label, LabelCommon, - RegisterComponent, ToggleState, h_flex, v_flex, 
-}; - -/// A component that displays an upsell message with a call-to-action button -/// -/// # Example -/// ``` -/// let upsell = Upsell::new( -/// "Upgrade to Zed Pro", -/// "Get access to advanced AI features and more", -/// "Upgrade Now", -/// Box::new(|_, _window, cx| { -/// cx.open_url("https://zed.dev/pricing"); -/// }), -/// Box::new(|_, _window, cx| { -/// // Handle dismiss -/// }), -/// Box::new(|checked, window, cx| { -/// // Handle don't show again -/// }), -/// ); -/// ``` -#[derive(IntoElement, RegisterComponent)] -pub struct Upsell { - title: SharedString, - message: SharedString, - cta_text: SharedString, - on_click: Box, - on_dismiss: Box, - on_dont_show_again: Box, -} - -impl Upsell { - /// Create a new upsell component - pub fn new( - title: impl Into, - message: impl Into, - cta_text: impl Into, - on_click: Box, - on_dismiss: Box, - on_dont_show_again: Box, - ) -> Self { - Self { - title: title.into(), - message: message.into(), - cta_text: cta_text.into(), - on_click, - on_dismiss, - on_dont_show_again, - } - } -} - -impl RenderOnce for Upsell { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - v_flex() - .w_full() - .p_4() - .gap_3() - .bg(cx.theme().colors().surface_background) - .rounded_md() - .border_1() - .border_color(cx.theme().colors().border) - .child( - v_flex() - .gap_1() - .child( - Label::new(self.title) - .size(ui::LabelSize::Large) - .weight(gpui::FontWeight::BOLD), - ) - .child(Label::new(self.message).color(Color::Muted)), - ) - .child( - h_flex() - .w_full() - .justify_between() - .items_center() - .child( - h_flex() - .items_center() - .gap_1() - .child( - Checkbox::new("dont-show-again", ToggleState::Unselected).on_click( - move |_, window, cx| { - (self.on_dont_show_again)(true, window, cx); - }, - ), - ) - .child( - Label::new("Don't show again") - .color(Color::Muted) - .size(ui::LabelSize::Small), - ), - ) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "No Thanks") - .style(ButtonStyle::Subtle) - .on_click(self.on_dismiss), - ) - .child( - Button::new("cta-button", self.cta_text) - .style(ButtonStyle::Filled) - .on_click(self.on_click), - ), - ), - ) - } -} - -impl Component for Upsell { - fn scope() -> ComponentScope { - ComponentScope::Agent - } - - fn name() -> &'static str { - "Upsell" - } - - fn description() -> Option<&'static str> { - Some("A promotional component that displays a message with a call-to-action.") - } - - fn preview(window: &mut Window, cx: &mut App) -> Option { - let examples = vec![ - single_example( - "Default", - Upsell::new( - "Upgrade to Zed Pro", - "Get unlimited access to AI features and more with Zed Pro. 
Unlock advanced AI capabilities and other premium features.", - "Upgrade Now", - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - ).render(window, cx).into_any_element(), - ), - single_example( - "Short Message", - Upsell::new( - "Try Zed Pro for free", - "Start your 7-day trial today.", - "Start Trial", - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - ).render(window, cx).into_any_element(), - ), - ]; - - Some(v_flex().gap_4().children(examples).into_any_element()) - } -} diff --git a/crates/ai_onboarding/Cargo.toml b/crates/ai_onboarding/Cargo.toml index 9031e14e29..95a45b1a6f 100644 --- a/crates/ai_onboarding/Cargo.toml +++ b/crates/ai_onboarding/Cargo.toml @@ -16,10 +16,10 @@ default = [] [dependencies] client.workspace = true +cloud_llm_client.workspace = true component.workspace = true gpui.workspace = true language_model.workspace = true -proto.workspace = true serde.workspace = true smallvec.workspace = true telemetry.workspace = true diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index 5f56e4d26e..b55ad4c895 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -1,8 +1,6 @@ use gpui::{Action, IntoElement, ParentElement, RenderOnce, point}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; -use ui::{Divider, List, prelude::*}; - -use crate::BulletItem; +use ui::{Divider, List, ListBulletItem, prelude::*}; pub struct ApiKeysWithProviders { configured_providers: Vec<(IconName, SharedString)>, @@ -128,7 +126,7 @@ impl RenderOnce for ApiKeysWithoutProviders { ) .child(Divider::horizontal()), ) - .child(List::new().child(BulletItem::new( + .child(List::new().child(ListBulletItem::new( "Add your own keys to use AI without signing in.", ))) .child( @@ -136,10 +134,7 @@ impl RenderOnce for ApiKeysWithoutProviders { .full_width() .style(ButtonStyle::Outlined) .on_click(move |_, window, cx| { - window.dispatch_action( - zed_actions::agent::OpenConfiguration.boxed_clone(), - cx, - ); + window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx); }), ) } diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index 771482abf3..f1629eeff8 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use client::{Client, UserStore}; +use cloud_llm_client::Plan; use gpui::{Entity, IntoElement, ParentElement}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use ui::prelude::*; @@ -56,10 +57,8 @@ impl AgentPanelOnboarding { impl Render for AgentPanelOnboarding { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let enrolled_in_trial = matches!( - self.user_store.read(cx).current_plan(), - Some(proto::Plan::ZedProTrial) - ); + let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial); + let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro); AgentPanelOnboardingCard::new() .child( @@ -75,7 +74,7 @@ impl Render for AgentPanelOnboarding { }), ) .map(|this| { - if enrolled_in_trial || self.configured_providers.len() >= 1 { + if enrolled_in_trial || is_pro_user || self.configured_providers.len() >= 1 { this } else { this.child(ApiKeysWithoutProviders::new()) diff --git 
a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs index f9a91503ae..b9a1e49a4a 100644 --- a/crates/ai_onboarding/src/ai_onboarding.rs +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -1,49 +1,27 @@ mod agent_api_keys_onboarding; mod agent_panel_onboarding_card; mod agent_panel_onboarding_content; +mod ai_upsell_card; mod edit_prediction_onboarding_content; +mod plan_definitions; mod young_account_banner; pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders}; pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; pub use agent_panel_onboarding_content::AgentPanelOnboarding; +pub use ai_upsell_card::AiUpsellCard; +use cloud_llm_client::Plan; pub use edit_prediction_onboarding_content::EditPredictionOnboarding; +pub use plan_definitions::PlanDefinitions; pub use young_account_banner::YoungAccountBanner; use std::sync::Arc; use client::{Client, UserStore, zed_urls}; -use gpui::{AnyElement, Entity, IntoElement, ParentElement, SharedString}; -use ui::{Divider, List, ListItem, RegisterComponent, TintColor, Tooltip, prelude::*}; - -pub struct BulletItem { - label: SharedString, -} - -impl BulletItem { - pub fn new(label: impl Into) -> Self { - Self { - label: label.into(), - } - } -} - -impl IntoElement for BulletItem { - type Element = AnyElement; - - fn into_element(self) -> Self::Element { - ListItem::new("list-item") - .selectable(false) - .start_slot( - Icon::new(IconName::Dash) - .size(IconSize::XSmall) - .color(Color::Hidden), - ) - .child(div().w_full().child(Label::new(self.label))) - .into_any_element() - } -} +use gpui::{AnyElement, Entity, IntoElement, ParentElement}; +use ui::{Divider, RegisterComponent, TintColor, Tooltip, prelude::*}; +#[derive(PartialEq)] pub enum SignInStatus { SignedIn, SigningIn, @@ -66,7 +44,7 @@ impl From for SignInStatus { pub struct ZedAiOnboarding { pub sign_in_status: SignInStatus, pub has_accepted_terms_of_service: bool, - pub plan: Option, + pub plan: Option, pub account_too_young: bool, pub continue_with_zed_ai: Arc, pub sign_in: Arc, @@ -86,8 +64,8 @@ impl ZedAiOnboarding { Self { sign_in_status: status.into(), - has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false), - plan: store.current_plan(), + has_accepted_terms_of_service: store.has_accepted_terms_of_service(), + plan: store.plan(), account_too_young: store.account_too_young(), continue_with_zed_ai, accept_terms_of_service: Arc::new({ @@ -100,11 +78,9 @@ impl ZedAiOnboarding { sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); - async move |cx| { - client.authenticate_and_connect(true, cx).await; - } + async move |cx| client.sign_in_with_optional_connect(true, cx).await }) - .detach(); + .detach_and_log_err(cx); }), dismiss_onboarding: None, } @@ -118,107 +94,6 @@ impl ZedAiOnboarding { self } - fn free_plan_definition(&self, cx: &mut App) -> impl IntoElement { - v_flex() - .mt_2() - .gap_1() - .child( - h_flex() - .gap_2() - .child( - Label::new("Free") - .size(LabelSize::Small) - .color(Color::Muted) - .buffer_font(cx), - ) - .child( - Label::new("(Current Plan)") - .size(LabelSize::Small) - .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6))) - .buffer_font(cx), - ) - .child(Divider::horizontal()), - ) - .child( - List::new() - .child(BulletItem::new("50 prompts per month with Claude models")) - .child(BulletItem::new( - "2,000 accepted edit predictions with Zeta, our open-source model", - )), - ) - } - - fn pro_trial_definition(&self) -> impl 
IntoElement { - List::new() - .child(BulletItem::new("150 prompts with Claude models")) - .child(BulletItem::new( - "Unlimited accepted edit predictions with Zeta, our open-source model", - )) - } - - fn pro_plan_definition(&self, cx: &mut App) -> impl IntoElement { - v_flex().mt_2().gap_1().map(|this| { - if self.account_too_young { - this.child( - h_flex() - .gap_2() - .child( - Label::new("Pro") - .size(LabelSize::Small) - .color(Color::Accent) - .buffer_font(cx), - ) - .child(Divider::horizontal()), - ) - .child( - List::new() - .child(BulletItem::new("500 prompts per month with Claude models")) - .child(BulletItem::new( - "Unlimited accepted edit predictions with Zeta, our open-source model", - )) - .child(BulletItem::new("$20 USD per month")), - ) - .child( - Button::new("pro", "Get Started") - .full_width() - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .on_click(move |_, _window, cx| { - telemetry::event!("Upgrade To Pro Clicked", state = "young-account"); - cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) - }), - ) - } else { - this.child( - h_flex() - .gap_2() - .child( - Label::new("Pro Trial") - .size(LabelSize::Small) - .color(Color::Accent) - .buffer_font(cx), - ) - .child(Divider::horizontal()), - ) - .child( - List::new() - .child(self.pro_trial_definition()) - .child(BulletItem::new( - "Try it out for 14 days for free, no credit card required", - )), - ) - .child( - Button::new("pro", "Start Free Trial") - .full_width() - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .on_click(move |_, _window, cx| { - telemetry::event!("Start Trial Clicked", state = "post-sign-in"); - cx.open_url(&zed_urls::start_trial_url(cx)) - }), - ) - } - }) - } - fn render_accept_terms_of_service(&self) -> AnyElement { v_flex() .gap_1() @@ -257,6 +132,7 @@ impl ZedAiOnboarding { fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement { let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn); + let plan_definitions = PlanDefinitions; v_flex() .gap_1() @@ -266,7 +142,7 @@ impl ZedAiOnboarding { .color(Color::Muted) .mb_2(), ) - .child(self.pro_trial_definition()) + .child(plan_definitions.pro_plan(false)) .child( Button::new("sign_in", "Try Zed Pro for Free") .disabled(signing_in) @@ -285,43 +161,132 @@ impl ZedAiOnboarding { fn render_free_plan_state(&self, cx: &mut App) -> AnyElement { let young_account_banner = YoungAccountBanner; + let plan_definitions = PlanDefinitions; - v_flex() - .relative() - .gap_1() - .child(Headline::new("Welcome to Zed AI")) - .map(|this| { - if self.account_too_young { - this.child(young_account_banner) - } else { - this.child(self.free_plan_definition(cx)).when_some( - self.dismiss_onboarding.as_ref(), - |this, dismiss_callback| { - let callback = dismiss_callback.clone(); + if self.account_too_young { + v_flex() + .relative() + .max_w_full() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .child(young_account_banner) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_plan(true)) + .child( + Button::new("pro", "Get Started") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Upgrade To Pro Clicked", + state = "young-account" + ); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ), + ) + .into_any_element() + } else { + v_flex() + 
.relative() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Free") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child( + Label::new("(Current Plan)") + .size(LabelSize::Small) + .color(Color::Custom( + cx.theme().colors().text_muted.opacity(0.6), + )) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.free_plan()), + ) + .when_some( + self.dismiss_onboarding.as_ref(), + |this, dismiss_callback| { + let callback = dismiss_callback.clone(); - this.child( - h_flex().absolute().top_0().right_0().child( - IconButton::new("dismiss_onboarding", IconName::Close) - .icon_size(IconSize::Small) - .tooltip(Tooltip::text("Dismiss")) - .on_click(move |_, window, cx| { - telemetry::event!( - "Banner Dismissed", - source = "AI Onboarding", - ); - callback(window, cx) - }), - ), - ) - }, - ) - } - }) - .child(self.pro_plan_definition(cx)) - .into_any_element() + this.child( + h_flex().absolute().top_0().right_0().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click(move |_, window, cx| { + telemetry::event!( + "Banner Dismissed", + source = "AI Onboarding", + ); + callback(window, cx) + }), + ), + ) + }, + ) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro Trial") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_trial(true)) + .child( + Button::new("pro", "Start Free Trial") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Start Trial Clicked", + state = "post-sign-in" + ); + cx.open_url(&zed_urls::start_trial_url(cx)) + }), + ), + ) + .into_any_element() + } } fn render_trial_state(&self, _cx: &mut App) -> AnyElement { + let plan_definitions = PlanDefinitions; + v_flex() .relative() .gap_1() @@ -331,13 +296,7 @@ impl ZedAiOnboarding { .color(Color::Muted) .mb_2(), ) - .child( - List::new() - .child(BulletItem::new("150 prompts with Claude models")) - .child(BulletItem::new( - "Unlimited edit predictions with Zeta, our open-source model", - )), - ) + .child(plan_definitions.pro_trial(false)) .when_some( self.dismiss_onboarding.as_ref(), |this, dismiss_callback| { @@ -362,6 +321,8 @@ impl ZedAiOnboarding { } fn render_pro_plan_state(&self, _cx: &mut App) -> AnyElement { + let plan_definitions = PlanDefinitions; + v_flex() .gap_1() .child(Headline::new("Welcome to Zed Pro")) @@ -370,11 +331,7 @@ impl ZedAiOnboarding { .color(Color::Muted) .mb_2(), ) - .child( - List::new() - .child(BulletItem::new("500 prompts with Claude models")) - .child(BulletItem::new("Unlimited edit predictions")), - ) + .child(plan_definitions.pro_plan(false)) .child( Button::new("pro", "Continue with Zed Pro") .full_width() @@ -396,9 +353,9 @@ impl RenderOnce for ZedAiOnboarding { if matches!(self.sign_in_status, SignInStatus::SignedIn) { if self.has_accepted_terms_of_service { match self.plan { - None | Some(proto::Plan::Free) => self.render_free_plan_state(cx), - Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx), - Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx), + None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), + Some(Plan::ZedProTrial) => self.render_trial_state(cx), + Some(Plan::ZedPro) => 
self.render_pro_plan_state(cx), } } else { self.render_accept_terms_of_service() @@ -411,14 +368,22 @@ impl RenderOnce for ZedAiOnboarding { impl Component for ZedAiOnboarding { fn scope() -> ComponentScope { - ComponentScope::Agent + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "Agent Panel Banners" + } + + fn sort_name() -> &'static str { + "Agent Panel Banners" } fn preview(_window: &mut Window, _cx: &mut App) -> Option { fn onboarding( sign_in_status: SignInStatus, has_accepted_terms_of_service: bool, - plan: Option, + plan: Option, account_too_young: bool, ) -> AnyElement { ZedAiOnboarding { @@ -436,8 +401,9 @@ impl Component for ZedAiOnboarding { Some( v_flex() - .p_4() .gap_4() + .items_center() + .max_w_4_5() .children(vec![ single_example( "Not Signed-in", @@ -448,30 +414,20 @@ impl Component for ZedAiOnboarding { onboarding(SignInStatus::SignedIn, false, None, false), ), single_example( - "Account too young", - onboarding(SignInStatus::SignedIn, false, None, true), + "Young Account", + onboarding(SignInStatus::SignedIn, true, None, true), ), single_example( "Free Plan", - onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false), ), single_example( "Pro Trial", - onboarding( - SignInStatus::SignedIn, - true, - Some(proto::Plan::ZedProTrial), - false, - ), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false), ), single_example( "Pro Plan", - onboarding( - SignInStatus::SignedIn, - true, - Some(proto::Plan::ZedPro), - false, - ), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false), ), ]) .into_any_element(), diff --git a/crates/ai_onboarding/src/ai_upsell_card.rs b/crates/ai_onboarding/src/ai_upsell_card.rs new file mode 100644 index 0000000000..4e4833f770 --- /dev/null +++ b/crates/ai_onboarding/src/ai_upsell_card.rs @@ -0,0 +1,303 @@ +use std::{sync::Arc, time::Duration}; + +use client::{Client, zed_urls}; +use cloud_llm_client::Plan; +use gpui::{ + Animation, AnimationExt, AnyElement, App, IntoElement, RenderOnce, Transformation, Window, + percentage, +}; +use ui::{Divider, Vector, VectorName, prelude::*}; + +use crate::{SignInStatus, plan_definitions::PlanDefinitions}; + +#[derive(IntoElement, RegisterComponent)] +pub struct AiUpsellCard { + pub sign_in_status: SignInStatus, + pub sign_in: Arc, + pub user_plan: Option, + pub tab_index: Option, +} + +impl AiUpsellCard { + pub fn new(client: Arc, user_plan: Option) -> Self { + let status = *client.status().borrow(); + + Self { + user_plan, + sign_in_status: status.into(), + sign_in: Arc::new(move |_window, cx| { + cx.spawn({ + let client = client.clone(); + async move |cx| client.sign_in_with_optional_connect(true, cx).await + }) + .detach_and_log_err(cx); + }), + tab_index: None, + } + } +} + +impl RenderOnce for AiUpsellCard { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let plan_definitions = PlanDefinitions; + + let pro_section = v_flex() + .flex_grow() + .w_full() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_plan(false)); + + let free_section = v_flex() + .flex_grow() + .w_full() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Free") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + 
.child(plan_definitions.free_plan()); + + let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child( + Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.)) + .color(Color::Custom(cx.theme().colors().border.opacity(0.05))), + ); + + let gradient_bg = div() + .absolute() + .inset_0() + .size_full() + .bg(gpui::linear_gradient( + 180., + gpui::linear_color_stop( + cx.theme().colors().elevated_surface_background.opacity(0.8), + 0., + ), + gpui::linear_color_stop( + cx.theme().colors().elevated_surface_background.opacity(0.), + 0.8, + ), + )); + + let description = PlanDefinitions::AI_DESCRIPTION; + + let card = v_flex() + .relative() + .flex_grow() + .p_4() + .pt_3() + .border_1() + .border_color(cx.theme().colors().border) + .rounded_lg() + .overflow_hidden() + .child(grid_bg) + .child(gradient_bg); + + let plans_section = h_flex() + .w_full() + .mt_1p5() + .mb_2p5() + .items_start() + .gap_6() + .child(free_section) + .child(pro_section); + + let footer_container = v_flex().items_center().gap_1(); + + let certified_user_stamp = div() + .absolute() + .top_2() + .right_2() + .size(rems_from_px(72.)) + .child( + Vector::new( + VectorName::CertifiedUserStamp, + rems_from_px(72.), + rems_from_px(72.), + ) + .color(Color::Custom(cx.theme().colors().text_accent.alpha(0.3))) + .with_animation( + "loading_stamp", + Animation::new(Duration::from_secs(10)).repeat(), + |this, delta| this.transform(Transformation::rotate(percentage(delta))), + ), + ); + + let pro_trial_stamp = div() + .absolute() + .top_2() + .right_2() + .size(rems_from_px(72.)) + .child( + Vector::new( + VectorName::ProTrialStamp, + rems_from_px(72.), + rems_from_px(72.), + ) + .color(Color::Custom(cx.theme().colors().text.alpha(0.2))), + ); + + match self.sign_in_status { + SignInStatus::SignedIn => match self.user_plan { + None | Some(Plan::ZedFree) => card + .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .child( + div() + .max_w_3_4() + .mb_2() + .child(Label::new(description).color(Color::Muted)), + ) + .child(plans_section) + .child( + footer_container + .child( + Button::new("start_trial", "Start 14-day Free Pro Trial") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .when_some(self.tab_index, |this, tab_index| { + this.tab_index(tab_index) + }) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Start Trial Clicked", + state = "post-sign-in" + ); + cx.open_url(&zed_urls::start_trial_url(cx)) + }), + ) + .child( + Label::new("No credit card required") + .size(LabelSize::Small) + .color(Color::Muted), + ), + ), + Some(Plan::ZedProTrial) => card + .child(pro_trial_stamp) + .child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large)) + .child( + Label::new("Here's what you get for the next 14 days:") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_trial(false)), + Some(Plan::ZedPro) => card + .child(certified_user_stamp) + .child(Label::new("You're in the Zed Pro plan").size(LabelSize::Large)) + .child( + Label::new("Here's what you get:") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_plan(false)), + }, + // Signed Out State + _ => card + .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .child( + div() + .max_w_3_4() + .mb_2() + .child(Label::new(description).color(Color::Muted)), + ) + .child(plans_section) + .child( + Button::new("sign_in", "Sign In") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)) + 
.on_click({ + let callback = self.sign_in.clone(); + move |_, window, cx| { + telemetry::event!("Start Trial Clicked", state = "pre-sign-in"); + callback(window, cx) + } + }), + ), + } + } +} + +impl Component for AiUpsellCard { + fn scope() -> ComponentScope { + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "AI Upsell Card" + } + + fn sort_name() -> &'static str { + "AI Upsell Card" + } + + fn description() -> Option<&'static str> { + Some("A card presenting the Zed AI product during user's first-open onboarding flow.") + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + Some( + v_flex() + .gap_4() + .children(vec![example_group(vec![ + single_example( + "Signed Out State", + AiUpsellCard { + sign_in_status: SignInStatus::SignedOut, + sign_in: Arc::new(|_, _| {}), + user_plan: None, + tab_index: Some(0), + } + .into_any_element(), + ), + single_example( + "Free Plan", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + user_plan: Some(Plan::ZedFree), + tab_index: Some(1), + } + .into_any_element(), + ), + single_example( + "Pro Trial", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + user_plan: Some(Plan::ZedProTrial), + tab_index: Some(1), + } + .into_any_element(), + ), + single_example( + "Pro Plan", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + user_plan: Some(Plan::ZedPro), + tab_index: Some(1), + } + .into_any_element(), + ), + ])]) + .into_any_element(), + ) + } +} diff --git a/crates/ai_onboarding/src/plan_definitions.rs b/crates/ai_onboarding/src/plan_definitions.rs new file mode 100644 index 0000000000..8d66f6c356 --- /dev/null +++ b/crates/ai_onboarding/src/plan_definitions.rs @@ -0,0 +1,39 @@ +use gpui::{IntoElement, ParentElement}; +use ui::{List, ListBulletItem, prelude::*}; + +/// Centralized definitions for Zed AI plans +pub struct PlanDefinitions; + +impl PlanDefinitions { + pub const AI_DESCRIPTION: &'static str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI."; + + pub fn free_plan(&self) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("50 prompts with Claude models")) + .child(ListBulletItem::new("2,000 accepted edit predictions")) + } + + pub fn pro_trial(&self, period: bool) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("150 prompts with Claude models")) + .child(ListBulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )) + .when(period, |this| { + this.child(ListBulletItem::new( + "Try it out for 14 days for free, no credit card required", + )) + }) + } + + pub fn pro_plan(&self, price: bool) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("500 prompts with Claude models")) + .child(ListBulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )) + .when(price, |this| { + this.child(ListBulletItem::new("$20 USD per month")) + }) + } +} diff --git a/crates/ai_onboarding/src/young_account_banner.rs b/crates/ai_onboarding/src/young_account_banner.rs index a43625a60e..54f563e4aa 100644 --- a/crates/ai_onboarding/src/young_account_banner.rs +++ b/crates/ai_onboarding/src/young_account_banner.rs @@ -15,6 +15,7 @@ impl RenderOnce for YoungAccountBanner { .child(YOUNG_ACCOUNT_DISCLAIMER); div() + .max_w_full() .my_1() .child(Banner::new().severity(ui::Severity::Warning).child(label)) } diff --git a/crates/anthropic/src/anthropic.rs 
b/crates/anthropic/src/anthropic.rs index c73f606045..3ff1666755 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -36,11 +36,18 @@ pub enum AnthropicModelMode { pub enum Model { #[serde(rename = "claude-opus-4", alias = "claude-opus-4-latest")] ClaudeOpus4, + #[serde(rename = "claude-opus-4-1", alias = "claude-opus-4-1-latest")] + ClaudeOpus4_1, #[serde( rename = "claude-opus-4-thinking", alias = "claude-opus-4-thinking-latest" )] ClaudeOpus4Thinking, + #[serde( + rename = "claude-opus-4-1-thinking", + alias = "claude-opus-4-1-thinking-latest" + )] + ClaudeOpus4_1Thinking, #[default] #[serde(rename = "claude-sonnet-4", alias = "claude-sonnet-4-latest")] ClaudeSonnet4, @@ -91,10 +98,18 @@ impl Model { } pub fn from_id(id: &str) -> Result { + if id.starts_with("claude-opus-4-1-thinking") { + return Ok(Self::ClaudeOpus4_1Thinking); + } + if id.starts_with("claude-opus-4-thinking") { return Ok(Self::ClaudeOpus4Thinking); } + if id.starts_with("claude-opus-4-1") { + return Ok(Self::ClaudeOpus4_1); + } + if id.starts_with("claude-opus-4") { return Ok(Self::ClaudeOpus4); } @@ -141,7 +156,9 @@ impl Model { pub fn id(&self) -> &str { match self { Self::ClaudeOpus4 => "claude-opus-4-latest", + Self::ClaudeOpus4_1 => "claude-opus-4-1-latest", Self::ClaudeOpus4Thinking => "claude-opus-4-thinking-latest", + Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-thinking-latest", Self::ClaudeSonnet4 => "claude-sonnet-4-latest", Self::ClaudeSonnet4Thinking => "claude-sonnet-4-thinking-latest", Self::Claude3_5Sonnet => "claude-3-5-sonnet-latest", @@ -159,6 +176,7 @@ impl Model { pub fn request_id(&self) -> &str { match self { Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking => "claude-opus-4-20250514", + Self::ClaudeOpus4_1 | Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-20250805", Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => "claude-sonnet-4-20250514", Self::Claude3_5Sonnet => "claude-3-5-sonnet-latest", Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => "claude-3-7-sonnet-latest", @@ -173,7 +191,9 @@ impl Model { pub fn display_name(&self) -> &str { match self { Self::ClaudeOpus4 => "Claude Opus 4", + Self::ClaudeOpus4_1 => "Claude Opus 4.1", Self::ClaudeOpus4Thinking => "Claude Opus 4 Thinking", + Self::ClaudeOpus4_1Thinking => "Claude Opus 4.1 Thinking", Self::ClaudeSonnet4 => "Claude Sonnet 4", Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking", Self::Claude3_7Sonnet => "Claude 3.7 Sonnet", @@ -192,7 +212,9 @@ impl Model { pub fn cache_configuration(&self) -> Option { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -215,7 +237,9 @@ impl Model { pub fn max_token_count(&self) -> u64 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -232,7 +256,9 @@ impl Model { pub fn max_output_tokens(&self) -> u64 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -249,7 +275,9 @@ impl Model { pub fn default_temperature(&self) -> f32 { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Sonnet @@ -269,6 
+297,7 @@ impl Model { pub fn mode(&self) -> AnthropicModelMode { match self { Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeSonnet4 | Self::Claude3_5Sonnet | Self::Claude3_7Sonnet @@ -277,6 +306,7 @@ impl Model { | Self::Claude3Sonnet | Self::Claude3Haiku => AnthropicModelMode::Default, Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4Thinking | Self::Claude3_7SonnetThinking => AnthropicModelMode::Thinking { budget_tokens: Some(4_096), diff --git a/crates/assistant_context/Cargo.toml b/crates/assistant_context/Cargo.toml index f35dc43340..8f5ff98790 100644 --- a/crates/assistant_context/Cargo.toml +++ b/crates/assistant_context/Cargo.toml @@ -19,6 +19,7 @@ assistant_slash_commands.workspace = true chrono.workspace = true client.workspace = true clock.workspace = true +cloud_llm_client.workspace = true collections.workspace = true context_server.workspace = true fs.workspace = true @@ -48,7 +49,6 @@ util.workspace = true uuid.workspace = true workspace-hack.workspace = true workspace.workspace = true -zed_llm_client.workspace = true [dev-dependencies] indoc.workspace = true diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_context/src/assistant_context.rs index 136468e084..557f9592e4 100644 --- a/crates/assistant_context/src/assistant_context.rs +++ b/crates/assistant_context/src/assistant_context.rs @@ -2,15 +2,16 @@ mod assistant_context_tests; mod context_store; -use agent_settings::AgentSettings; +use agent_settings::{AgentSettings, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Context as _, Result, bail}; use assistant_slash_command::{ SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection, SlashCommandResult, SlashCommandWorkingSet, }; use assistant_slash_commands::FileCommandMetadata; -use client::{self, Client, proto, telemetry::Telemetry}; +use client::{self, Client, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry}; use clock::ReplicaId; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use collections::{HashMap, HashSet}; use fs::{Fs, RenameOptions}; use futures::{FutureExt, StreamExt, future::Shared}; @@ -46,7 +47,6 @@ use text::{BufferSnapshot, ToPoint}; use ui::IconName; use util::{ResultExt, TryFutureExt, post_inc}; use uuid::Uuid; -use zed_llm_client::CompletionIntent; pub use crate::context_store::*; @@ -2080,7 +2080,18 @@ impl AssistantContext { }); match event { - LanguageModelCompletionEvent::StatusUpdate { .. } => {} + LanguageModelCompletionEvent::StatusUpdate(status_update) => { + match status_update { + CompletionRequestStatus::UsageUpdated { amount, limit } => { + this.update_model_request_usage( + amount as u32, + limit, + cx, + ); + } + _ => {} + } + } LanguageModelCompletionEvent::StartMessage { .. } => {} LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; @@ -2677,10 +2688,7 @@ impl AssistantContext { let mut request = self.to_completion_request(Some(&model.model), cx); request.messages.push(LanguageModelRequestMessage { role: Role::User, - content: vec![ - "Generate a concise 3-7 word title for this conversation, omitting punctuation. 
Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`" - .into(), - ], + content: vec![SUMMARIZE_THREAD_PROMPT.into()], cache: false, }); @@ -2956,6 +2964,21 @@ impl AssistantContext { summary.text = custom_summary; cx.emit(ContextEvent::SummaryChanged); } + + fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) { + let Some(project) = &self.project else { + return; + }; + project.read(cx).user_store().update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }); + } } #[derive(Debug, Default)] diff --git a/crates/assistant_context/src/assistant_context_tests.rs b/crates/assistant_context/src/assistant_context_tests.rs index f139d525d3..efcad8ed96 100644 --- a/crates/assistant_context/src/assistant_context_tests.rs +++ b/crates/assistant_context/src/assistant_context_tests.rs @@ -1210,8 +1210,8 @@ async fn test_summarization(cx: &mut TestAppContext) { }); cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief"); - fake_model.stream_last_completion_response(" Introduction"); + fake_model.send_last_completion_stream_text_chunk("Brief"); + fake_model.send_last_completion_stream_text_chunk(" Introduction"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -1274,7 +1274,7 @@ async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { }); cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary"); + fake_model.send_last_completion_stream_text_chunk("A successful summary"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -1356,7 +1356,7 @@ fn setup_context_editor_with_fake_model( fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response"); + fake_model.send_last_completion_stream_text_chunk("Assistant response"); fake_model.end_last_completion_stream(); cx.run_until_parked(); } diff --git a/crates/assistant_context/src/context_store.rs b/crates/assistant_context/src/context_store.rs index 3400913eb8..3090a7b234 100644 --- a/crates/assistant_context/src/context_store.rs +++ b/crates/assistant_context/src/context_store.rs @@ -767,6 +767,11 @@ impl ContextStore { fn reload(&mut self, cx: &mut Context) -> Task> { let fs = self.fs.clone(); cx.spawn(async move |this, cx| { + pub static ZED_STATELESS: LazyLock = + LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + if *ZED_STATELESS { + return Ok(()); + } fs.create_dir(contexts_dir()).await?; let mut paths = fs.read_dir(contexts_dir()).await?; diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index 672c048872..025aba060d 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -630,6 +630,11 @@ impl ActionLog { false } }); + if tracked_buffer.unreviewed_edits.is_empty() { + if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status { + tracked_buffer.status = TrackedBufferStatus::Modified; + } + } tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx); } } @@ -775,6 +780,9 @@ impl ActionLog { .retain(|_buffer, tracked_buffer| match tracked_buffer.status { TrackedBufferStatus::Deleted => false, _ => { + if let TrackedBufferStatus::Created { .. 
} = &mut tracked_buffer.status { + tracked_buffer.status = TrackedBufferStatus::Modified; + } tracked_buffer.unreviewed_edits.clear(); tracked_buffer.diff_base = tracked_buffer.snapshot.as_rope().clone(); tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx); @@ -2075,6 +2083,134 @@ mod tests { assert_eq!(content, "ai content\nuser added this line"); } + #[gpui::test] + async fn test_reject_after_accepting_hunk_on_created_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/new_file", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path.clone(), cx)) + .await + .unwrap(); + + // AI creates file with initial content + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v1", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User accepts the single hunk + action_log.update(cx, |log, cx| { + log.keep_edits_in_range(buffer.clone(), Anchor::MIN..Anchor::MAX, cx) + }); + cx.run_until_parked(); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + + // AI modifies the file + cx.update(|cx| { + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v2", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User rejects the hunk + action_log + .update(cx, |log, cx| { + log.reject_edits_in_ranges(buffer.clone(), vec![Anchor::MIN..Anchor::MAX], cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await,); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "ai content v1" + ); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + + #[gpui::test] + async fn test_reject_edits_on_previously_accepted_created_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/new_file", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path.clone(), cx)) + .await + .unwrap(); + + // AI creates file with initial content + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v1", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + + // User clicks "Accept All" + action_log.update(cx, |log, cx| log.keep_all_edits(cx)); + 
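The action_log changes above downgrade a tracked buffer from Created to Modified once its initial hunks have been kept (individually or via keep_all_edits), so a later reject restores the last accepted content instead of deleting the file, which is what the surrounding tests exercise. Below is a minimal sketch of that state transition, using hypothetical TrackedFile/Status types and plain strings in place of the real buffer, anchor, and diff machinery.

```rust
// Illustrative sketch only: models why a Created buffer is downgraded to
// Modified once its initial hunks are accepted, so a later "Reject All"
// restores the accepted snapshot instead of deleting the file.
// `TrackedFile`, `Status`, and the method names are hypothetical; the real
// ActionLog tracks buffers, anchors, and diffs rather than plain strings.

#[derive(Debug, PartialEq)]
enum Status {
    Created,
    Modified,
}

struct TrackedFile {
    status: Status,
    accepted_base: String, // content the user has accepted so far
    current: String,       // content after the latest AI edit
}

impl TrackedFile {
    fn keep_all_edits(&mut self) {
        // Accepting every pending hunk makes the current text the new base.
        self.accepted_base = self.current.clone();
        // A file whose creation has been accepted behaves like an edited file
        // from now on; rejecting later edits must not delete it.
        if self.status == Status::Created {
            self.status = Status::Modified;
        }
    }

    fn reject_all_edits(&mut self) -> Option<String> {
        match self.status {
            // Creation itself was never accepted: rejecting removes the file.
            Status::Created => None,
            // Otherwise roll back to the last accepted snapshot.
            Status::Modified => {
                self.current = self.accepted_base.clone();
                Some(self.current.clone())
            }
        }
    }
}

fn main() {
    let mut file = TrackedFile {
        status: Status::Created,
        accepted_base: String::new(),
        current: "ai content v1".into(),
    };
    file.keep_all_edits(); // user clicks "Accept All"
    file.current = "ai content v2".into(); // AI edits the file again
    assert_eq!(file.reject_all_edits().as_deref(), Some("ai content v1"));
    println!("reject restored: {}", file.current);
}
```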
cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); // Hunks are cleared + + // AI modifies file again + cx.update(|cx| { + buffer.update(cx, |buffer, cx| buffer.set_text("ai content v2", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + cx.run_until_parked(); + assert_ne!(unreviewed_hunks(&action_log, cx), vec![]); + + // User clicks "Reject All" + action_log + .update(cx, |log, cx| log.reject_all_edits(cx)) + .await; + cx.run_until_parked(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "ai content v1" + ); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + #[gpui::test(iterations = 100)] async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 554b3f3f3c..22cbaac3f8 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -216,7 +216,12 @@ pub trait Tool: 'static + Send + Sync { /// Returns true if the tool needs the users's confirmation /// before having permission to run. - fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool; + fn needs_confirmation( + &self, + input: &serde_json::Value, + project: &Entity, + cx: &App, + ) -> bool; /// Returns true if the tool may perform edits. fn may_perform_edits(&self) -> bool; diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 9a6ec49914..c0a358917b 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -375,7 +375,12 @@ mod tests { false } - fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + fn needs_confirmation( + &self, + _input: &serde_json::Value, + _project: &Entity, + _cx: &App, + ) -> bool { true } diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 146800e094..d4b8fa3afc 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -21,9 +21,11 @@ assistant_tool.workspace = true buffer_diff.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true component.workspace = true derive_more.workspace = true +diffy = "0.4.2" editor.workspace = true feature_flags.workspace = true futures.workspace = true @@ -63,8 +65,6 @@ web_search.workspace = true which.workspace = true workspace-hack.workspace = true workspace.workspace = true -zed_llm_client.workspace = true -diffy = "0.4.2" [dev-dependencies] lsp = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs index 1922b5677a..e34ae9ff93 100644 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ b/crates/assistant_tools/src/copy_path_tool.rs @@ -44,7 +44,7 @@ impl Tool for CopyPathTool { "copy_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs 
index 224e8357e5..11d969d234 100644 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ b/crates/assistant_tools/src/create_directory_tool.rs @@ -37,7 +37,7 @@ impl Tool for CreateDirectoryTool { include_str!("./create_directory_tool/description.md").into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index b13f9863c9..9e69c18b65 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -33,7 +33,7 @@ impl Tool for DeletePathTool { "delete_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 84595a37b7..12ab97f820 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -46,7 +46,7 @@ impl Tool for DiagnosticsTool { "diagnostics".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index 0184dff36c..715d106a26 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -7,6 +7,7 @@ mod streaming_fuzzy_matcher; use crate::{Template, Templates}; use anyhow::Result; use assistant_tool::ActionLog; +use cloud_llm_client::CompletionIntent; use create_file_parser::{CreateFileParser, CreateFileParserEvent}; pub use edit_parser::EditFormat; use edit_parser::{EditParser, EditParserEvent, EditParserMetrics}; @@ -29,7 +30,6 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task:: use streaming_diff::{CharOperation, StreamingDiff}; use streaming_fuzzy_matcher::StreamingFuzzyMatcher; use util::debug_panic; -use zed_llm_client::CompletionIntent; #[derive(Serialize)] struct CreateFilePromptTemplate { @@ -962,7 +962,7 @@ mod tests { ); cx.run_until_parked(); - model.stream_last_completion_response("a"); + model.send_last_completion_stream_text_chunk("a"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( @@ -974,7 +974,7 @@ mod tests { None ); - model.stream_last_completion_response("bc"); + model.send_last_completion_stream_text_chunk("bc"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -996,7 +996,7 @@ mod tests { }) ); - model.stream_last_completion_response("abX"); + model.send_last_completion_stream_text_chunk("abX"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); assert_eq!( @@ -1011,7 +1011,7 @@ mod tests { }) ); - model.stream_last_completion_response("cY"); + model.send_last_completion_stream_text_chunk("cY"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); assert_eq!( @@ -1026,8 +1026,8 @@ mod tests { }) ); - model.stream_last_completion_response(""); - model.stream_last_completion_response("hall"); + model.send_last_completion_stream_text_chunk(""); + model.send_last_completion_stream_text_chunk("hall"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); 
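The edit_agent test hunks here (and the assistant_context tests earlier) switch from stream_last_completion_response to send_last_completion_stream_text_chunk: the fake model is driven by pushing text chunks and then explicitly ending the stream, with assertions interleaved between chunks. Below is a self-contained sketch of that pattern, where FakeModel and CompletionStream are hypothetical channel-backed stand-ins for the real gpui-integrated FakeLanguageModel.

```rust
// A minimal, synchronous sketch of the test pattern used in these hunks: the
// fake model exposes explicit "send a text chunk" and "end the stream" calls
// so a test can interleave chunks with assertions. `FakeModel` and
// `CompletionStream` are hypothetical; the real helper is async and runs on
// gpui's test executor.

use std::sync::mpsc;

struct FakeModel {
    tx: mpsc::Sender<Option<String>>,
}

struct CompletionStream {
    rx: mpsc::Receiver<Option<String>>,
}

impl FakeModel {
    fn start_completion() -> (FakeModel, CompletionStream) {
        let (tx, rx) = mpsc::channel();
        (FakeModel { tx }, CompletionStream { rx })
    }

    /// Mirrors `send_last_completion_stream_text_chunk`: push one text chunk.
    fn send_text_chunk(&self, chunk: &str) {
        self.tx.send(Some(chunk.to_string())).unwrap();
    }

    /// Mirrors `end_last_completion_stream`: signal that the stream is done.
    fn end_stream(&self) {
        self.tx.send(None).unwrap();
    }
}

impl CompletionStream {
    fn collect_text(&self) -> String {
        let mut text = String::new();
        // Stops on the explicit end-of-stream marker (or a dropped sender).
        while let Ok(Some(chunk)) = self.rx.recv() {
            text.push_str(&chunk);
        }
        text
    }
}

fn main() {
    let (model, stream) = FakeModel::start_completion();
    model.send_text_chunk("Brief");
    model.send_text_chunk(" Introduction");
    model.end_stream();
    assert_eq!(stream.collect_text(), "Brief Introduction");
    println!("streamed summary assembled chunk by chunk");
}
```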
assert_eq!( @@ -1042,8 +1042,8 @@ mod tests { }) ); - model.stream_last_completion_response("ucinated old"); - model.stream_last_completion_response(""); + model.send_last_completion_stream_text_chunk("ucinated old"); + model.send_last_completion_stream_text_chunk(""); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1061,8 +1061,8 @@ mod tests { }) ); - model.stream_last_completion_response("hallucinated new"); + model.send_last_completion_stream_text_chunk("hallucinated new"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( @@ -1077,7 +1077,7 @@ mod tests { }) ); - model.stream_last_completion_response("\nghi\nj"); + model.send_last_completion_stream_text_chunk("\nghi\nj"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1099,8 +1099,8 @@ mod tests { }) ); - model.stream_last_completion_response("kl"); - model.stream_last_completion_response(""); + model.send_last_completion_stream_text_chunk("kl"); + model.send_last_completion_stream_text_chunk(""); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1122,7 +1122,7 @@ mod tests { }) ); - model.stream_last_completion_response("GHI"); + model.send_last_completion_stream_text_chunk("GHI"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1367,7 +1367,9 @@ mod tests { cx.background_spawn(async move { for chunk in chunks { executor.simulate_random_delay().await; - model.as_fake().stream_last_completion_response(chunk); + model + .as_fake() + .send_last_completion_stream_text_chunk(chunk); } model.as_fake().end_last_completion_stream(); }) diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index eda7eee0e3..9a8e762455 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -1658,23 +1658,24 @@ impl EditAgentTest { } async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> Result { + const MAX_RETRIES: usize = 20; let mut attempt = 0; + loop { attempt += 1; - match request().await { - Ok(result) => return Ok(result), - Err(err) => match err.downcast::() { - Ok(err) => match &err { + let response = request().await; + + if attempt >= MAX_RETRIES { + return response; + } + + let retry_delay = match &response { + Ok(_) => None, + Err(err) => match err.downcast_ref::() { + Some(err) => match &err { LanguageModelCompletionError::RateLimitExceeded { retry_after, .. } | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => { - let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); - // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time. - let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); - eprintln!( - "Attempt #{attempt}: {err}. 
Retry after {retry_after:?} + jitter of {jitter:?}" - ); - Timer::after(retry_after + jitter).await; - continue; + Some(retry_after.unwrap_or(Duration::from_secs(5))) } LanguageModelCompletionError::UpstreamProviderError { status, @@ -1687,23 +1688,31 @@ async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE ) || status.as_u16() == 529; - if !should_retry { - return Err(err.into()); + if should_retry { + // Use server-provided retry_after if available, otherwise use default + Some(retry_after.unwrap_or(Duration::from_secs(5))) + } else { + None } - - // Use server-provided retry_after if available, otherwise use default - let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); - let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); - eprintln!( - "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}" - ); - Timer::after(retry_after + jitter).await; - continue; } - _ => return Err(err.into()), + LanguageModelCompletionError::ApiReadResponseError { .. } + | LanguageModelCompletionError::ApiInternalServerError { .. } + | LanguageModelCompletionError::HttpSend { .. } => { + // Exponential backoff for transient I/O and internal server errors + Some(Duration::from_secs(2_u64.pow((attempt - 1) as u32).min(30))) + } + _ => None, }, - Err(err) => return Err(err), + _ => None, }, + }; + + if let Some(retry_after) = retry_delay { + let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); + eprintln!("Attempt #{attempt}: Retry after {retry_after:?} + jitter of {jitter:?}"); + Timer::after(retry_after + jitter).await; + } else { + return response; } } } diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 6413677bd9..dce9f49abd 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -25,6 +25,7 @@ use language::{ }; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; +use paths; use project::{ Project, ProjectPath, lsp_store::{FormatTrigger, LspFormatTarget}, @@ -126,8 +127,47 @@ impl Tool for EditFileTool { "edit_file".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { - false + fn needs_confirmation( + &self, + input: &serde_json::Value, + project: &Entity, + cx: &App, + ) -> bool { + if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { + return false; + } + + let Ok(input) = serde_json::from_value::(input.clone()) else { + // If it's not valid JSON, it's going to error and confirming won't do anything. + return false; + }; + + // If any path component matches the local settings folder, then this could affect + // the editor in ways beyond the project source, so prompt. + let local_settings_folder = paths::local_settings_folder_relative_path(); + let path = Path::new(&input.path); + if path + .components() + .any(|component| component.as_os_str() == local_settings_folder.as_os_str()) + { + return true; + } + + // It's also possible that the global config dir is configured to be inside the project, + // so check for that edge case too. 
+ if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { + if canonical_path.starts_with(paths::config_dir()) { + return true; + } + } + + // Check if path is inside the global config directory + // First check if it's already inside project - if not, try to canonicalize + let project_path = project.read(cx).find_project_path(&input.path, cx); + + // If the path is inside the project, and it's not one of the above edge cases, + // then no confirmation is necessary. Otherwise, confirmation is necessary. + project_path.is_none() } fn may_perform_edits(&self) -> bool { @@ -148,7 +188,25 @@ impl Tool for EditFileTool { fn ui_text(&self, input: &serde_json::Value) -> String { match serde_json::from_value::(input.clone()) { - Ok(input) => input.display_description, + Ok(input) => { + let path = Path::new(&input.path); + let mut description = input.display_description.clone(); + + // Add context about why confirmation may be needed + let local_settings_folder = paths::local_settings_folder_relative_path(); + if path + .components() + .any(|c| c.as_os_str() == local_settings_folder.as_os_str()) + { + description.push_str(" (local settings)"); + } else if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { + if canonical_path.starts_with(paths::config_dir()) { + description.push_str(" (global settings)"); + } + } + + description + } Err(_) => "Editing file".to_string(), } } @@ -1175,19 +1233,20 @@ async fn build_buffer_diff( #[cfg(test)] mod tests { use super::*; + use ::fs::Fs; use client::TelemetrySettings; - use fs::{FakeFs, Fs}; use gpui::{TestAppContext, UpdateGlobal}; use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; + use std::fs; use util::path; #[gpui::test] async fn test_edit_nonexistent_file(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -1277,7 +1336,7 @@ mod tests { ) -> anyhow::Result { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree( "/root", json!({ @@ -1384,6 +1443,21 @@ mod tests { cx.set_global(settings_store); language::init(cx); TelemetrySettings::register(cx); + agent_settings::AgentSettings::register(cx); + Project::init_settings(cx); + }); + } + + fn init_test_with_config(cx: &mut TestAppContext, data_dir: &Path) { + cx.update(|cx| { + // Set custom data directory (config will be under data_dir/config) + paths::set_custom_data_dir(data_dir.to_str().unwrap()); + + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + TelemetrySettings::register(cx); + agent_settings::AgentSettings::register(cx); Project::init_settings(cx); }); } @@ -1392,7 +1466,7 @@ mod tests { async fn test_format_on_save(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; @@ -1503,7 +1577,7 @@ mod tests { // Stream the unformatted content cx.executor().run_until_parked(); - model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); 
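The needs_confirmation rewrite above makes EditFileTool prompt whenever an edit could reach editor configuration or leave the project: any path component equal to the local settings folder, any path that canonicalizes into the global config directory, and any path that does not resolve to a project worktree, with always_allow_tool_actions short-circuiting every check. Below is a runnable sketch of that policy over plain paths, where is_inside_project, the folder name, and the config-directory argument are hypothetical stand-ins for Project::find_project_path, paths::local_settings_folder_relative_path, and paths::config_dir.

```rust
// Sketch of the confirmation policy implemented above, written against plain
// paths so it runs standalone. The real check also canonicalizes the input
// path before the config-directory comparison and reads the setting from
// AgentSettings; both are simplified to parameters here.

use std::path::{Component, Path, PathBuf};

fn needs_confirmation(
    path: &Path,
    local_settings_folder: &str, // e.g. ".zed"
    global_config_dir: &Path,    // e.g. ~/.config/zed
    is_inside_project: impl Fn(&Path) -> bool,
    always_allow_tool_actions: bool,
) -> bool {
    // 1. The user opted out of confirmations entirely.
    if always_allow_tool_actions {
        return false;
    }
    // 2. Any component matching the local settings folder can change editor
    //    behavior beyond the project source, so always prompt.
    if path
        .components()
        .any(|c| matches!(c, Component::Normal(os) if os == local_settings_folder))
    {
        return true;
    }
    // 3. The global config dir may be nested inside the project; prompt too.
    if path.starts_with(global_config_dir) {
        return true;
    }
    // 4. Otherwise, only paths that resolve outside every worktree prompt.
    !is_inside_project(path)
}

// Hypothetical stand-in for Project::find_project_path returning Some(_).
fn in_project(p: &Path) -> bool {
    p.starts_with("project") || p.starts_with("/project")
}

fn main() {
    let config_dir = PathBuf::from("/home/user/.config/zed");

    // Local settings folder anywhere in the path prompts.
    assert!(needs_confirmation(
        Path::new("project/.zed/settings.json"),
        ".zed",
        &config_dir,
        in_project,
        false
    ));
    // Paths outside every worktree prompt.
    assert!(needs_confirmation(Path::new("/etc/hosts"), ".zed", &config_dir, in_project, false));
    // Ordinary project files do not prompt.
    assert!(!needs_confirmation(Path::new("project/src/main.rs"), ".zed", &config_dir, in_project, false));
    // always_allow_tool_actions bypasses all of the above.
    assert!(!needs_confirmation(Path::new("project/.zed/settings.json"), ".zed", &config_dir, in_project, true));
}
```

Putting the always_allow_tool_actions bypass first mirrors the tests in this file, which verify that the setting suppresses every other rule before re-enabling them.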
model.end_last_completion_stream(); edit_task.await @@ -1567,7 +1641,7 @@ mod tests { // Stream the unformatted content cx.executor().run_until_parked(); - model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); model.end_last_completion_stream(); edit_task.await @@ -1591,7 +1665,7 @@ mod tests { async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) { init_test(cx); - let fs = FakeFs::new(cx.executor()); + let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; // Create a simple file with trailing whitespace @@ -1646,7 +1720,9 @@ mod tests { // Stream the content with trailing whitespace cx.executor().run_until_parked(); - model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); model.end_last_completion_stream(); edit_task.await @@ -1703,7 +1779,9 @@ mod tests { // Stream the content with trailing whitespace cx.executor().run_until_parked(); - model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); model.end_last_completion_stream(); edit_task.await @@ -1723,4 +1801,641 @@ mod tests { "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" ); } + + #[gpui::test] + async fn test_needs_confirmation(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({})).await; + + // Test 1: Path with .zed component should require confirmation + let input_with_zed = json!({ + "display_description": "Edit settings", + "path": ".zed/settings.json", + "mode": "edit" + }); + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_with_zed, &project, cx), + "Path with .zed component should require confirmation" + ); + }); + + // Test 2: Absolute path should require confirmation + let input_absolute = json!({ + "display_description": "Edit file", + "path": "/etc/hosts", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_absolute, &project, cx), + "Absolute path should require confirmation" + ); + }); + + // Test 3: Relative path without .zed should not require confirmation + let input_relative = json!({ + "display_description": "Edit file", + "path": "root/src/main.rs", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_relative, &project, cx), + "Relative path without .zed should not require confirmation" + ); + }); + + // Test 4: Path with .zed in the middle should require confirmation + let input_zed_middle = json!({ + "display_description": "Edit settings", + "path": "root/.zed/tasks.json", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_zed_middle, &project, cx), + "Path with .zed in any component should require confirmation" + ); + }); + + // Test 5: When always_allow_tool_actions is enabled, no confirmation needed + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = true; + agent_settings::AgentSettings::override_global(settings, cx); + + assert!( + !tool.needs_confirmation(&input_with_zed, 
&project, cx), + "When always_allow_tool_actions is true, no confirmation should be needed" + ); + assert!( + !tool.needs_confirmation(&input_absolute, &project, cx), + "When always_allow_tool_actions is true, no confirmation should be needed for absolute paths" + ); + }); + } + + #[gpui::test] + async fn test_ui_text_shows_correct_context(cx: &mut TestAppContext) { + // Set up a custom config directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + + // Test ui_text shows context for various paths + let test_cases = vec![ + ( + json!({ + "display_description": "Update config", + "path": ".zed/settings.json", + "mode": "edit" + }), + "Update config (local settings)", + ".zed path should show local settings context", + ), + ( + json!({ + "display_description": "Fix bug", + "path": "src/.zed/local.json", + "mode": "edit" + }), + "Fix bug (local settings)", + "Nested .zed path should show local settings context", + ), + ( + json!({ + "display_description": "Update readme", + "path": "README.md", + "mode": "edit" + }), + "Update readme", + "Normal path should not show additional context", + ), + ( + json!({ + "display_description": "Edit config", + "path": "config.zed", + "mode": "edit" + }), + "Edit config", + ".zed as extension should not show context", + ), + ]; + + for (input, expected_text, description) in test_cases { + cx.update(|_cx| { + let ui_text = tool.ui_text(&input); + assert_eq!(ui_text, expected_text, "Failed for case: {}", description); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_outside_project(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create a project in /project directory + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test file outside project requires confirmation + let input_outside = json!({ + "display_description": "Edit file", + "path": "/outside/file.txt", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_outside, &project, cx), + "File outside project should require confirmation" + ); + }); + + // Test file inside project doesn't require confirmation + let input_inside = json!({ + "display_description": "Edit file", + "path": "project/file.txt", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_inside, &project, cx), + "File inside project should not require confirmation" + ); + }); + } + + #[gpui::test] + async fn test_needs_confirmation_config_paths(cx: &mut TestAppContext) { + // Set up a custom data directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/home/user/myproject", json!({})).await; + let project = Project::test(fs.clone(), [path!("/home/user/myproject").as_ref()], cx).await; + + // Get the actual local settings folder name + let local_settings_folder = paths::local_settings_folder_relative_path(); + + // Test various config path patterns + let test_cases = vec![ + ( + format!("{}/settings.json", local_settings_folder.display()), + true, + "Top-level local settings file".to_string(), + ), + ( + format!( + "myproject/{}/settings.json", + local_settings_folder.display() + ), + true, + "Local settings in project 
path".to_string(), + ), + ( + format!("src/{}/config.toml", local_settings_folder.display()), + true, + "Local settings in subdirectory".to_string(), + ), + ( + ".zed.backup/file.txt".to_string(), + true, + ".zed.backup is outside project".to_string(), + ), + ( + "my.zed/file.txt".to_string(), + true, + "my.zed is outside project".to_string(), + ), + ( + "myproject/src/file.zed".to_string(), + false, + ".zed as file extension".to_string(), + ), + ( + "myproject/normal/path/file.rs".to_string(), + false, + "Normal file without config paths".to_string(), + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_global_config(cx: &mut TestAppContext) { + // Set up a custom data directory for testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create test files in the global config directory + let global_config_dir = paths::config_dir(); + fs::create_dir_all(&global_config_dir).unwrap(); + let global_settings_path = global_config_dir.join("settings.json"); + fs::write(&global_settings_path, "{}").unwrap(); + + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test global config paths + let test_cases = vec![ + ( + global_settings_path.to_str().unwrap().to_string(), + true, + "Global settings file should require confirmation", + ), + ( + global_config_dir + .join("keymap.json") + .to_str() + .unwrap() + .to_string(), + true, + "Global keymap file should require confirmation", + ), + ( + "project/normal_file.rs".to_string(), + false, + "Normal project file should not require confirmation", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {}", + description + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + + // Create multiple worktree directories + fs.insert_tree( + "/workspace/frontend", + json!({ + "src": { + "main.js": "console.log('frontend');" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/backend", + json!({ + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/shared", + json!({ + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + + // Create project with multiple worktrees + let project = Project::test( + fs.clone(), + [ + path!("/workspace/frontend").as_ref(), + path!("/workspace/backend").as_ref(), + path!("/workspace/shared").as_ref(), + ], + cx, + ) + .await; + + // Test files in different worktrees + let test_cases = vec![ + ("frontend/src/main.js", false, "File in first worktree"), + ("backend/src/main.rs", false, "File in second worktree"), + ( + "shared/.zed/settings.json", + true, + ".zed file in third worktree", + ), + ("/etc/hosts", true, "Absolute 
path outside all worktrees"), + ( + "../outside/file.txt", + true, + "Relative path outside worktrees", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".zed": { + "settings.json": "{}" + }, + "src": { + ".zed": { + "local.json": "{}" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test edge cases + let test_cases = vec![ + // Empty path - find_project_path returns Some for empty paths + ("", false, "Empty path is treated as project root"), + // Root directory + ("/", true, "Root directory should be outside project"), + // Parent directory references - find_project_path resolves these + ( + "project/../other", + false, + "Path with .. is resolved by find_project_path", + ), + ( + "project/./src/file.rs", + false, + "Path with . should work normally", + ), + // Windows-style paths (if on Windows) + #[cfg(target_os = "windows")] + ("C:\\Windows\\System32\\hosts", true, "Windows system path"), + #[cfg(target_os = "windows")] + ("project\\src\\main.rs", false, "Windows-style project path"), + ]; + + for (path, should_confirm, description) in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert_eq!( + tool.needs_confirmation(&input, &project, cx), + should_confirm, + "Failed for case: {} - path: {}", + description, + path + ); + }); + } + } + + #[gpui::test] + async fn test_ui_text_with_all_path_types(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + + // Test UI text for various scenarios + let test_cases = vec![ + ( + json!({ + "display_description": "Update config", + "path": ".zed/settings.json", + "mode": "edit" + }), + "Update config (local settings)", + ".zed path should show local settings context", + ), + ( + json!({ + "display_description": "Fix bug", + "path": "src/.zed/local.json", + "mode": "edit" + }), + "Fix bug (local settings)", + "Nested .zed path should show local settings context", + ), + ( + json!({ + "display_description": "Update readme", + "path": "README.md", + "mode": "edit" + }), + "Update readme", + "Normal path should not show additional context", + ), + ( + json!({ + "display_description": "Edit config", + "path": "config.zed", + "mode": "edit" + }), + "Edit config", + ".zed as extension should not show context", + ), + ]; + + for (input, expected_text, description) in test_cases { + cx.update(|_cx| { + let ui_text = tool.ui_text(&input); + assert_eq!(ui_text, expected_text, "Failed for case: {}", description); + }); + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) { + init_test(cx); + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "existing.txt": "content", + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Test 
different EditFileMode values + let modes = vec![ + EditFileMode::Edit, + EditFileMode::Create, + EditFileMode::Overwrite, + ]; + + for mode in modes { + // Test .zed path with different modes + let input_zed = json!({ + "display_description": "Edit settings", + "path": "project/.zed/settings.json", + "mode": mode + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_zed, &project, cx), + ".zed path should require confirmation regardless of mode: {:?}", + mode + ); + }); + + // Test outside path with different modes + let input_outside = json!({ + "display_description": "Edit file", + "path": "/outside/file.txt", + "mode": mode + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input_outside, &project, cx), + "Outside path should require confirmation regardless of mode: {:?}", + mode + ); + }); + + // Test normal path with different modes + let input_normal = json!({ + "display_description": "Edit file", + "path": "project/normal.txt", + "mode": mode + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input_normal, &project, cx), + "Normal path should not require confirmation regardless of mode: {:?}", + mode + ); + }); + } + } + + #[gpui::test] + async fn test_always_allow_tool_actions_bypasses_all_checks(cx: &mut TestAppContext) { + // Set up with custom directories for deterministic testing + let temp_dir = tempfile::tempdir().unwrap(); + init_test_with_config(cx, temp_dir.path()); + + let tool = Arc::new(EditFileTool); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + + // Enable always_allow_tool_actions + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = true; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + // Test that all paths that normally require confirmation are bypassed + let global_settings_path = paths::config_dir().join("settings.json"); + fs::create_dir_all(paths::config_dir()).unwrap(); + fs::write(&global_settings_path, "{}").unwrap(); + + let test_cases = vec![ + ".zed/settings.json", + "project/.zed/config.toml", + global_settings_path.to_str().unwrap(), + "/etc/hosts", + "/absolute/path/file.txt", + "../outside/project.txt", + ]; + + for path in test_cases { + let input = json!({ + "display_description": "Edit file", + "path": path, + "mode": "edit" + }); + cx.update(|cx| { + assert!( + !tool.needs_confirmation(&input, &project, cx), + "Path {} should not require confirmation when always_allow_tool_actions is true", + path + ); + }); + } + + // Disable always_allow_tool_actions and verify confirmation is required again + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = false; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + // Verify .zed path requires confirmation again + let input = json!({ + "display_description": "Edit file", + "path": ".zed/settings.json", + "mode": "edit" + }); + cx.update(|cx| { + assert!( + tool.needs_confirmation(&input, &project, cx), + ".zed path should require confirmation when always_allow_tool_actions is false" + ); + }); + } } diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index 54d49359ba..a31ec39268 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -116,7 
+116,7 @@ impl Tool for FetchTool { "fetch".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index fd0e44e42c..affc019417 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -55,7 +55,7 @@ impl Tool for FindPathTool { "find_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index 053273d71b..43c3d1d990 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -57,7 +57,7 @@ impl Tool for GrepTool { "grep".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index 723416e2ce..b1980615d6 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -45,7 +45,7 @@ impl Tool for ListDirectoryTool { "list_directory".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/move_path_tool.rs b/crates/assistant_tools/src/move_path_tool.rs index 27ae10151d..c1cbbf848d 100644 --- a/crates/assistant_tools/src/move_path_tool.rs +++ b/crates/assistant_tools/src/move_path_tool.rs @@ -42,7 +42,7 @@ impl Tool for MovePathTool { "move_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index b6b1cf90a4..b51b91d3d5 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -33,7 +33,7 @@ impl Tool for NowTool { "now".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/open_tool.rs b/crates/assistant_tools/src/open_tool.rs index 97a4769e19..8fddbb0431 100644 --- a/crates/assistant_tools/src/open_tool.rs +++ b/crates/assistant_tools/src/open_tool.rs @@ -23,7 +23,7 @@ impl Tool for OpenTool { "open".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } fn may_perform_edits(&self) -> bool { diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs index 7567926dca..03487e5419 100644 --- a/crates/assistant_tools/src/project_notifications_tool.rs +++ b/crates/assistant_tools/src/project_notifications_tool.rs @@ -19,7 +19,7 @@ impl Tool for ProjectNotificationsTool { "project_notifications".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn 
needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } fn may_perform_edits(&self) -> bool { diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index dc504e2dc4..ee38273cc0 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -54,7 +54,7 @@ impl Tool for ReadFileTool { "read_file".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 03e76f6a5b..58833c5208 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -77,7 +77,7 @@ impl Tool for TerminalTool { Self::NAME.to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { true } diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index 422204f97d..443c2930be 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -24,7 +24,7 @@ impl Tool for ThinkingTool { "thinking".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index 24bc8e9cba..d4a12f22c5 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -6,6 +6,7 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; +use cloud_llm_client::{WebSearchResponse, WebSearchResult}; use futures::{Future, FutureExt, TryFutureExt}; use gpui::{ AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, @@ -17,7 +18,6 @@ use serde::{Deserialize, Serialize}; use ui::{IconName, Tooltip, prelude::*}; use web_search::WebSearchRegistry; use workspace::Workspace; -use zed_llm_client::{WebSearchResponse, WebSearchResult}; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct WebSearchToolInput { @@ -32,7 +32,7 @@ impl Tool for WebSearchTool { "web_search".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { false } diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index 960aaf8e08..d857a3eb2f 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -18,6 +18,6 @@ collections.workspace = true derive_more.workspace = true gpui.workspace = true parking_lot.workspace = true -rodio = { version = "0.20.0", default-features = false, features = ["wav"] } +rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/audio/src/assets.rs b/crates/audio/src/assets.rs index 02da79dc24..fd5c935d87 100644 --- a/crates/audio/src/assets.rs +++ b/crates/audio/src/assets.rs @@ -3,12 +3,9 @@ use std::{io::Cursor, sync::Arc}; use anyhow::{Context as _, Result}; use collections::HashMap; use gpui::{App, 
AssetSource, Global}; -use rodio::{ - Decoder, Source, - source::{Buffered, SamplesConverter}, -}; +use rodio::{Decoder, Source, source::Buffered}; -type Sound = Buffered>>, f32>>; +type Sound = Buffered>>>; pub struct SoundRegistry { cache: Arc>>, @@ -48,7 +45,7 @@ impl SoundRegistry { .with_context(|| format!("No asset available for path {path}"))?? .into_owned(); let cursor = Cursor::new(bytes); - let source = Decoder::new(cursor)?.convert_samples::().buffered(); + let source = Decoder::new(cursor)?.buffered(); self.cache.lock().insert(name.to_string(), source.clone()); diff --git a/crates/audio/src/audio.rs b/crates/audio/src/audio.rs index e7b9a59e8f..44baa16aa2 100644 --- a/crates/audio/src/audio.rs +++ b/crates/audio/src/audio.rs @@ -1,7 +1,7 @@ use assets::SoundRegistry; use derive_more::{Deref, DerefMut}; use gpui::{App, AssetSource, BorrowAppContext, Global}; -use rodio::{OutputStream, OutputStreamHandle}; +use rodio::{OutputStream, OutputStreamBuilder}; use util::ResultExt; mod assets; @@ -37,8 +37,7 @@ impl Sound { #[derive(Default)] pub struct Audio { - _output_stream: Option, - output_handle: Option, + output_handle: Option, } #[derive(Deref, DerefMut)] @@ -51,11 +50,9 @@ impl Audio { Self::default() } - fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> { + fn ensure_output_exists(&mut self) -> Option<&OutputStream> { if self.output_handle.is_none() { - let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip(); - self.output_handle = output_handle; - self._output_stream = _output_stream; + self.output_handle = OutputStreamBuilder::open_default_stream().log_err(); } self.output_handle.as_ref() @@ -69,7 +66,7 @@ impl Audio { cx.update_global::(|this, cx| { let output_handle = this.ensure_output_exists()?; let source = SoundRegistry::global(cx).get(sound.file()).log_err()?; - output_handle.play_raw(source).log_err()?; + output_handle.mixer().add(source); Some(()) }); } @@ -80,7 +77,6 @@ impl Audio { } cx.update_global::(|this, _| { - this._output_stream.take(); this.output_handle.take(); }); } diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index b6eeafa2d6..69d2ffb845 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -32,11 +32,18 @@ pub enum Model { ClaudeSonnet4Thinking, #[serde(rename = "claude-opus-4", alias = "claude-opus-4-latest")] ClaudeOpus4, + #[serde(rename = "claude-opus-4-1", alias = "claude-opus-4-1-latest")] + ClaudeOpus4_1, #[serde( rename = "claude-opus-4-thinking", alias = "claude-opus-4-thinking-latest" )] ClaudeOpus4Thinking, + #[serde( + rename = "claude-opus-4-1-thinking", + alias = "claude-opus-4-1-thinking-latest" + )] + ClaudeOpus4_1Thinking, #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")] Claude3_5SonnetV2, #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")] @@ -147,7 +154,9 @@ impl Model { Model::ClaudeSonnet4 => "claude-4-sonnet", Model::ClaudeSonnet4Thinking => "claude-4-sonnet-thinking", Model::ClaudeOpus4 => "claude-4-opus", + Model::ClaudeOpus4_1 => "claude-4-opus-1", Model::ClaudeOpus4Thinking => "claude-4-opus-thinking", + Model::ClaudeOpus4_1Thinking => "claude-4-opus-1-thinking", Model::Claude3_5SonnetV2 => "claude-3-5-sonnet-v2", Model::Claude3_5Sonnet => "claude-3-5-sonnet", Model::Claude3Opus => "claude-3-opus", @@ -208,6 +217,9 @@ impl Model { Model::ClaudeOpus4 | Model::ClaudeOpus4Thinking => { "anthropic.claude-opus-4-20250514-v1:0" } + Model::ClaudeOpus4_1 | Model::ClaudeOpus4_1Thinking 
=> { + "anthropic.claude-opus-4-1-20250805-v1:0" + } Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0", Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0", Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0", @@ -266,7 +278,9 @@ impl Model { Self::ClaudeSonnet4 => "Claude Sonnet 4", Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking", Self::ClaudeOpus4 => "Claude Opus 4", + Self::ClaudeOpus4_1 => "Claude Opus 4.1", Self::ClaudeOpus4Thinking => "Claude Opus 4 Thinking", + Self::ClaudeOpus4_1Thinking => "Claude Opus 4.1 Thinking", Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2", Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", Self::Claude3Opus => "Claude 3 Opus", @@ -330,8 +344,10 @@ impl Model { | Self::Claude3_7Sonnet | Self::ClaudeSonnet4 | Self::ClaudeOpus4 + | Self::ClaudeOpus4_1 | Self::ClaudeSonnet4Thinking - | Self::ClaudeOpus4Thinking => 200_000, + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1Thinking => 200_000, Self::AmazonNovaPremier => 1_000_000, Self::PalmyraWriterX5 => 1_000_000, Self::PalmyraWriterX4 => 128_000, @@ -348,7 +364,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Model::ClaudeOpus4Thinking => 128_000, + | Model::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Model::ClaudeOpus4_1Thinking => 128_000, Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192, Self::Custom { max_output_tokens, .. @@ -366,6 +384,8 @@ impl Model { | Self::Claude3_7Sonnet | Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => 1.0, Self::Custom { @@ -387,6 +407,8 @@ impl Model { | Self::Claude3_7SonnetThinking | Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::Claude3_5Haiku => true, @@ -420,7 +442,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Self::ClaudeOpus4Thinking => true, + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking => true, // Custom models - check if they have cache configuration Self::Custom { @@ -440,7 +464,9 @@ impl Model { | Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking | Self::ClaudeOpus4 - | Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration { + | Self::ClaudeOpus4Thinking + | Self::ClaudeOpus4_1 + | Self::ClaudeOpus4_1Thinking => Some(BedrockModelCacheConfiguration { max_cache_anchors: 4, min_total_token: 1024, }), @@ -467,9 +493,11 @@ impl Model { Model::ClaudeSonnet4Thinking => BedrockModelMode::Thinking { budget_tokens: Some(4096), }, - Model::ClaudeOpus4Thinking => BedrockModelMode::Thinking { - budget_tokens: Some(4096), - }, + Model::ClaudeOpus4Thinking | Model::ClaudeOpus4_1Thinking => { + BedrockModelMode::Thinking { + budget_tokens: Some(4096), + } + } _ => BedrockModelMode::Default, } } @@ -518,6 +546,8 @@ impl Model { | Model::ClaudeSonnet4Thinking | Model::ClaudeOpus4 | Model::ClaudeOpus4Thinking + | Model::ClaudeOpus4_1 + | Model::ClaudeOpus4_1Thinking | Model::Claude3Haiku | Model::Claude3Opus | Model::Claude3Sonnet diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 866e3ccd90..4ac37ffd14 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -13,7 +13,7 @@ use std::{ ops::{ControlFlow, Range}, sync::Arc, }; -use sum_tree::{Bias, 
SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; use time::OffsetDateTime; use util::{ResultExt as _, TryFutureExt, post_inc}; @@ -331,7 +331,9 @@ impl ChannelChat { .update(&mut cx, |chat, cx| { if let Some(first_id) = chat.first_loaded_message_id() { if first_id <= message_id { - let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(&()); + let mut cursor = chat + .messages + .cursor::>(&()); let message_id = ChannelMessageId::Saved(message_id); cursor.seek(&message_id, Bias::Left); return ControlFlow::Break( @@ -587,7 +589,9 @@ impl ChannelChat { .map(|m| m.nonce) .collect::>(); - let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>(&()); + let mut old_cursor = self + .messages + .cursor::>(&()); let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left); let start_ix = old_cursor.start().1.0; let removed_messages = old_cursor.slice(&last_message.id, Bias::Right); diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index b7ba811421..4ad156b9fb 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -126,7 +126,7 @@ impl ChannelMembership { proto::channel_member::Kind::Member => 0, proto::channel_member::Kind::Invitee => 1, }, - username_order: self.user.github_login.as_str(), + username_order: &self.user.github_login, } } } diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index f8f5de3c39..c92226eeeb 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -259,20 +259,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) { assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx); }); - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![5]); - server.respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 5, - github_login: "nathansobo".into(), - avatar_url: "http://avatar.com/nathansobo".into(), - name: None, - }], - }, - ); - // Join a channel and populate its existing messages. 
let channel = channel_store.update(cx, |store, cx| { let channel_id = store.ordered_channels().next().unwrap().1.id; @@ -334,7 +320,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "a".into()), + ("user-5".into(), "a".into()), ("maxbrunsfeld".into(), "b".into()) ] ); @@ -437,7 +423,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "y".into()), + ("user-5".into(), "y".into()), ("maxbrunsfeld".into(), "z".into()) ] ); diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index b741f515fd..365625b445 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -17,11 +17,12 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup [dependencies] anyhow.workspace = true -async-recursion = "0.3" async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] } base64.workspace = true chrono = { workspace = true, features = ["serde"] } clock.workspace = true +cloud_api_client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true credentials_provider.workspace = true derive_more.workspace = true @@ -33,8 +34,8 @@ http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" log.workspace = true -paths.workspace = true parking_lot.workspace = true +paths.workspace = true postage.workspace = true rand.workspace = true regex.workspace = true @@ -46,19 +47,18 @@ serde_json.workspace = true settings.workspace = true sha2.workspace = true smol.workspace = true +telemetry.workspace = true telemetry_events.workspace = true text.workspace = true thiserror.workspace = true time.workspace = true tiny_http.workspace = true tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] } +tokio.workspace = true url.workspace = true util.workspace = true -worktree.workspace = true -telemetry.workspace = true -tokio.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true +worktree.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 81bb95b514..b4894cddcf 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -6,22 +6,21 @@ pub mod telemetry; pub mod user; pub mod zed_urls; -use anyhow::{Context as _, Result, anyhow, bail}; -use async_recursion::async_recursion; +use anyhow::{Context as _, Result, anyhow}; use async_tungstenite::tungstenite::{ client::IntoClientRequest, error::Error as WebsocketError, http::{HeaderValue, Request, StatusCode}, }; -use chrono::{DateTime, Utc}; use clock::SystemClock; +use cloud_api_client::CloudApiClient; use credentials_provider::CredentialsProvider; use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, channel::oneshot, future::BoxFuture, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; -use http_client::{AsyncBody, HttpClient, HttpClientWithUrl}; +use http_client::{HttpClient, HttpClientWithUrl, http}; use parking_lot::RwLock; use postage::watch; use proxy::connect_proxy_stream; @@ -31,7 +30,6 @@ use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use 
settings::{Settings, SettingsSources}; -use std::pin::Pin; use std::{ any::TypeId, convert::TryFrom, @@ -45,6 +43,7 @@ use std::{ }, time::{Duration, Instant}, }; +use std::{cmp, pin::Pin}; use telemetry::Telemetry; use thiserror::Error; use tokio::net::TcpStream; @@ -78,7 +77,7 @@ pub static ZED_ALWAYS_ACTIVE: LazyLock = LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty())); pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500); -pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10); +pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30); pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20); actions!( @@ -151,7 +150,6 @@ impl Settings for ProxySettings { pub fn init_settings(cx: &mut App) { TelemetrySettings::register(cx); - DisableAiSettings::register(cx); ClientSettings::register(cx); ProxySettings::register(cx); } @@ -162,20 +160,8 @@ pub fn init(client: &Arc, cx: &mut App) { let client = client.clone(); move |_: &SignIn, cx| { if let Some(client) = client.upgrade() { - cx.spawn( - async move |cx| match client.authenticate_and_connect(true, &cx).await { - ConnectionResult::Timeout => { - log::error!("Initial authentication timed out"); - } - ConnectionResult::ConnectionReset => { - log::error!("Initial authentication connection reset"); - } - ConnectionResult::Result(r) => { - r.log_err(); - } - }, - ) - .detach(); + cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await) + .detach_and_log_err(cx); } } }); @@ -213,6 +199,7 @@ pub struct Client { id: AtomicU64, peer: Arc, http: Arc, + cloud_client: Arc, telemetry: Arc, credentials_provider: ClientCredentialsProvider, state: RwLock, @@ -283,6 +270,8 @@ pub enum Status { SignedOut, UpgradeRequired, Authenticating, + Authenticated, + AuthenticationError, Connecting, ConnectionError, Connected { @@ -549,33 +538,6 @@ impl settings::Settings for TelemetrySettings { } } -/// Whether to disable all AI features in Zed. 
-/// -/// Default: false -#[derive(Copy, Clone, Debug)] -pub struct DisableAiSettings { - pub disable_ai: bool, -} - -impl settings::Settings for DisableAiSettings { - const KEY: Option<&'static str> = Some("disable_ai"); - - type FileContent = Option; - - fn load(sources: SettingsSources, _: &mut App) -> Result { - Ok(Self { - disable_ai: sources - .user - .or(sources.server) - .copied() - .flatten() - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?), - }) - } - - fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} -} - impl Client { pub fn new( clock: Arc, @@ -586,6 +548,7 @@ impl Client { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), + cloud_client: Arc::new(CloudApiClient::new(http.clone())), http, credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), @@ -618,6 +581,10 @@ impl Client { self.http.clone() } + pub fn cloud_client(&self) -> Arc { + self.cloud_client.clone() + } + pub fn set_id(&self, id: u64) -> &Self { self.id.store(id, Ordering::SeqCst); self @@ -704,7 +671,7 @@ impl Client { let mut delay = INITIAL_RECONNECTION_DELAY; loop { - match client.authenticate_and_connect(true, &cx).await { + match client.connect(true, &cx).await { ConnectionResult::Timeout => { log::error!("client connect attempt timed out") } @@ -720,18 +687,20 @@ impl Client { } } - if matches!(*client.status().borrow(), Status::ConnectionError) { + if matches!( + *client.status().borrow(), + Status::AuthenticationError | Status::ConnectionError + ) { client.set_status( Status::ReconnectionError { next_reconnection: Instant::now() + delay, }, &cx, ); - cx.background_executor().timer(delay).await; - delay = delay - .mul_f32(rng.gen_range(0.5..=2.5)) - .max(INITIAL_RECONNECTION_DELAY) - .min(MAX_RECONNECTION_DELAY); + let jitter = + Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64)); + cx.background_executor().timer(delay + jitter).await; + delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY); } else { break; } @@ -875,17 +844,127 @@ impl Client { .is_some() } - #[async_recursion(?Send)] - pub async fn authenticate_and_connect( + pub async fn sign_in( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result { + if self.status().borrow().is_signed_out() { + self.set_status(Status::Authenticating, cx); + } else { + self.set_status(Status::Reauthenticating, cx); + } + + let mut credentials = None; + + let old_credentials = self.state.read().credentials.clone(); + if let Some(old_credentials) = old_credentials { + if self.validate_credentials(&old_credentials, cx).await? { + credentials = Some(old_credentials); + } + } + + if credentials.is_none() && try_provider { + if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await { + if self.validate_credentials(&stored_credentials, cx).await? { + credentials = Some(stored_credentials); + } else { + self.credentials_provider + .delete_credentials(cx) + .await + .log_err(); + } + } + } + + if credentials.is_none() { + let mut status_rx = self.status(); + let _ = status_rx.next().await; + futures::select_biased! 
{ + authenticate = self.authenticate(cx).fuse() => { + match authenticate { + Ok(creds) => { + if IMPERSONATE_LOGIN.is_none() { + self.credentials_provider + .write_credentials(creds.user_id, creds.access_token.clone(), cx) + .await + .log_err(); + } + + credentials = Some(creds); + }, + Err(err) => { + self.set_status(Status::AuthenticationError, cx); + return Err(err); + } + } + } + _ = status_rx.next().fuse() => { + return Err(anyhow!("authentication canceled")); + } + } + } + + let credentials = credentials.unwrap(); + self.set_id(credentials.user_id); + self.cloud_client + .set_credentials(credentials.user_id as u32, credentials.access_token.clone()); + self.state.write().credentials = Some(credentials.clone()); + self.set_status(Status::Authenticated, cx); + + Ok(credentials) + } + + async fn validate_credentials( + self: &Arc, + credentials: &Credentials, + cx: &AsyncApp, + ) -> Result { + match self + .cloud_client + .validate_credentials(credentials.user_id as u32, &credentials.access_token) + .await + { + Ok(valid) => Ok(valid), + Err(err) => { + self.set_status(Status::AuthenticationError, cx); + Err(anyhow!("failed to validate credentials: {}", err)) + } + } + } + + /// Performs a sign-in and also connects to Collab. + /// + /// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls + /// to `sign_in` when we're ready to remove auto-connection to Collab. + pub async fn sign_in_with_optional_connect( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result<()> { + let credentials = self.sign_in(try_provider, cx).await?; + + let connect_result = match self.connect_with_credentials(credentials, cx).await { + ConnectionResult::Timeout => Err(anyhow!("connection timed out")), + ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")), + ConnectionResult::Result(result) => result.context("client auth and connect"), + }; + connect_result.log_err(); + + Ok(()) + } + + pub async fn connect( self: &Arc, try_provider: bool, cx: &AsyncApp, ) -> ConnectionResult<()> { let was_disconnected = match *self.status().borrow() { - Status::SignedOut => true, + Status::SignedOut | Status::Authenticated => true, Status::ConnectionError | Status::ConnectionLost | Status::Authenticating { .. } + | Status::AuthenticationError | Status::Reauthenticating { .. } | Status::ReconnectionError { .. } => false, Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { @@ -898,39 +977,10 @@ impl Client { ); } }; - if was_disconnected { - self.set_status(Status::Authenticating, cx); - } else { - self.set_status(Status::Reauthenticating, cx) - } - - let mut read_from_provider = false; - let mut credentials = self.state.read().credentials.clone(); - if credentials.is_none() && try_provider { - credentials = self.credentials_provider.read_credentials(cx).await; - read_from_provider = credentials.is_some(); - } - - if credentials.is_none() { - let mut status_rx = self.status(); - let _ = status_rx.next().await; - futures::select_biased! 
{ - authenticate = self.authenticate(cx).fuse() => { - match authenticate { - Ok(creds) => credentials = Some(creds), - Err(err) => { - self.set_status(Status::ConnectionError, cx); - return ConnectionResult::Result(Err(err)); - } - } - } - _ = status_rx.next().fuse() => { - return ConnectionResult::Result(Err(anyhow!("authentication canceled"))); - } - } - } - let credentials = credentials.unwrap(); - self.set_id(credentials.user_id); + let credentials = match self.sign_in(try_provider, cx).await { + Ok(credentials) => credentials, + Err(err) => return ConnectionResult::Result(Err(err)), + }; if was_disconnected { self.set_status(Status::Connecting, cx); @@ -938,17 +988,20 @@ impl Client { self.set_status(Status::Reconnecting, cx); } + self.connect_with_credentials(credentials, cx).await + } + + async fn connect_with_credentials( + self: &Arc, + credentials: Credentials, + cx: &AsyncApp, + ) -> ConnectionResult<()> { let mut timeout = futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT)); futures::select_biased! { connection = self.establish_connection(&credentials, cx).fuse() => { match connection { Ok(conn) => { - self.state.write().credentials = Some(credentials.clone()); - if !read_from_provider && IMPERSONATE_LOGIN.is_none() { - self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); - } - futures::select_biased! { result = self.set_connection(conn, cx).fuse() => { match result.context("client auth and connect") { @@ -966,15 +1019,8 @@ impl Client { } } Err(EstablishConnectionError::Unauthorized) => { - self.state.write().credentials.take(); - if read_from_provider { - self.credentials_provider.delete_credentials(cx).await.log_err(); - self.set_status(Status::SignedOut, cx); - self.authenticate_and_connect(false, cx).await - } else { - self.set_status(Status::ConnectionError, cx); - ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) - } + self.set_status(Status::ConnectionError, cx); + ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) } Err(EstablishConnectionError::UpgradeRequired) => { self.set_status(Status::UpgradeRequired, cx); @@ -1138,7 +1184,7 @@ impl Client { .to_str() .map_err(EstablishConnectionError::other)? .to_string(); - Url::parse(&collab_url).with_context(|| format!("parsing colab rpc url {collab_url}")) + Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}")) } } @@ -1158,6 +1204,7 @@ impl Client { let http = self.http.clone(); let proxy = http.proxy().cloned(); + let user_agent = http.user_agent().cloned(); let credentials = credentials.clone(); let rpc_url = self.rpc_url(http, release_channel); let system_id = self.telemetry.system_id(); @@ -1209,7 +1256,7 @@ impl Client { // We then modify the request to add our desired headers. 
let request_headers = request.headers_mut(); request_headers.insert( - "Authorization", + http::header::AUTHORIZATION, HeaderValue::from_str(&credentials.authorization_header())?, ); request_headers.insert( @@ -1221,6 +1268,9 @@ impl Client { "x-zed-release-channel", HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?, ); + if let Some(user_agent) = user_agent { + request_headers.insert(http::header::USER_AGENT, user_agent); + } if let Some(system_id) = system_id { request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?); } @@ -1365,96 +1415,31 @@ impl Client { self: &Arc, http: Arc, login: String, - mut api_token: String, + api_token: String, ) -> Result { - #[derive(Deserialize)] - struct AuthenticatedUserResponse { - user: User, + #[derive(Serialize)] + struct ImpersonateUserBody { + github_login: String, } #[derive(Deserialize)] - struct User { - id: u64, + struct ImpersonateUserResponse { + user_id: u64, + access_token: String, } - let github_user = { - #[derive(Deserialize)] - struct GithubUser { - id: i32, - login: String, - created_at: DateTime, - } - - let request = { - let mut request_builder = - Request::get(&format!("https://api.github.com/users/{login}")); - if let Ok(github_token) = std::env::var("GITHUB_TOKEN") { - request_builder = - request_builder.header("Authorization", format!("Bearer {}", github_token)); - } - - request_builder.body(AsyncBody::empty())? - }; - - let mut response = http - .send(request) - .await - .context("error fetching GitHub user")?; - - let mut body = Vec::new(); - response - .body_mut() - .read_to_end(&mut body) - .await - .context("error reading GitHub user")?; - - if !response.status().is_success() { - let text = String::from_utf8_lossy(body.as_slice()); - bail!( - "status error {}, response: {text:?}", - response.status().as_u16() - ); - } - - serde_json::from_slice::(body.as_slice()).map_err(|err| { - log::error!("Error deserializing: {:?}", err); - log::error!( - "GitHub API response text: {:?}", - String::from_utf8_lossy(body.as_slice()) - ); - anyhow!("error deserializing GitHub user") - })? - }; - - let query_params = [ - ("github_login", &github_user.login), - ("github_user_id", &github_user.id.to_string()), - ( - "github_user_created_at", - &github_user.created_at.to_rfc3339(), - ), - ]; - - // Use the collab server's admin API to retrieve the ID - // of the impersonated user. - let mut url = self.rpc_url(http.clone(), None).await?; - url.set_path("/user"); - url.set_query(Some( - &query_params - .iter() - .map(|(key, value)| { - format!( - "{}={}", - key, - url::form_urlencoded::byte_serialize(value.as_bytes()).collect::() - ) - }) - .collect::>() - .join("&"), - )); - let request: http_client::Request = Request::get(url.as_str()) - .header("Authorization", format!("token {api_token}")) - .body("".into())?; + let url = self + .http + .build_zed_cloud_url("/internal/users/impersonate", &[])?; + let request = Request::post(url.as_str()) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {api_token}")) + .body( + serde_json::to_string(&ImpersonateUserBody { + github_login: login, + })? 
+ .into(), + )?; let mut response = http.send(request).await?; let mut body = String::new(); @@ -1465,18 +1450,17 @@ impl Client { response.status().as_u16(), body, ); - let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; + let response: ImpersonateUserResponse = serde_json::from_str(&body)?; - // Use the admin API token to authenticate as the impersonated user. - api_token.insert_str(0, "ADMIN_TOKEN:"); Ok(Credentials { - user_id: response.user.id, - access_token: api_token, + user_id: response.user_id, + access_token: response.access_token, }) } pub async fn sign_out(self: &Arc, cx: &AsyncApp) { self.state.write().credentials = None; + self.cloud_client.clear_credentials(); self.disconnect(cx); if self.has_credentials(cx).await { @@ -1705,7 +1689,7 @@ pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> { #[cfg(test)] mod tests { use super::*; - use crate::test::FakeServer; + use crate::test::{FakeServer, parse_authorization_header}; use clock::FakeSystemClock; use gpui::{AppContext as _, BackgroundExecutor, TestAppContext}; @@ -1756,6 +1740,46 @@ mod tests { assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } + #[gpui::test(iterations = 10)] + async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) { + init_test(cx); + let http_client = FakeHttpClient::with_200_response(); + let client = + cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx)); + let server = FakeServer::for_client(42, &client, cx).await; + let mut status = client.status(); + assert!(matches!( + status.next().await, + Some(Status::Connected { .. }) + )); + assert_eq!(server.auth_count(), 1); + + // Simulate an auth failure during reconnection. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(503) + .body("".into()) + .unwrap()) + }); + server.disconnect(); + while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} + + // Restore the ability to authenticate. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + }); + cx.executor().advance_clock(Duration::from_secs(10)); + while !matches!(status.next().await, Some(Status::Connected { .. 
})) {} + assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting + } + #[gpui::test(iterations = 10)] async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx); @@ -1786,7 +1810,7 @@ mod tests { }); let auth_and_connect = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert!(matches!(status.next().await, Some(Status::Connecting))); @@ -1831,6 +1855,75 @@ mod tests { )); } + #[gpui::test(iterations = 10)] + async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) { + init_test(cx); + let auth_count = Arc::new(Mutex::new(0)); + let http_client = FakeHttpClient::create(|_request| async move { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + }); + let client = + cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx)); + client.override_authenticate({ + let auth_count = auth_count.clone(); + move |cx| { + let auth_count = auth_count.clone(); + cx.background_spawn(async move { + *auth_count.lock() += 1; + Ok(Credentials { + user_id: 1, + access_token: auth_count.lock().to_string(), + }) + }) + } + }); + + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(credentials.access_token, "1"); + + // If credentials are still valid, signing in doesn't trigger authentication. + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(credentials.access_token, "1"); + + // If the server is unavailable, signing in doesn't trigger authentication. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(503) + .body("".into()) + .unwrap()) + }); + client.sign_in(false, &cx.to_async()).await.unwrap_err(); + assert_eq!(*auth_count.lock(), 1); + + // If credentials became invalid, signing in triggers authentication. 
+ http_client + .as_fake() + .replace_handler(|_, request| async move { + let credentials = parse_authorization_header(&request).unwrap(); + if credentials.access_token == "2" { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + } else { + Ok(http_client::Response::builder() + .status(401) + .body("".into()) + .unwrap()) + } + }); + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 2); + assert_eq!(credentials.access_token, "2"); + } + #[gpui::test(iterations = 10)] async fn test_authenticating_more_than_once( cx: &mut TestAppContext, @@ -1863,7 +1956,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - move |cx| async move { client.authenticate_and_connect(false, &cx).await } + move |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 1); @@ -1871,7 +1964,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 2); diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 4983fda5ef..43a1a0b7a4 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -74,6 +74,12 @@ static ZED_CLIENT_CHECKSUM_SEED: LazyLock>> = LazyLock::new(|| { }) }); +pub static MINIDUMP_ENDPOINT: LazyLock> = LazyLock::new(|| { + option_env!("ZED_MINIDUMP_ENDPOINT") + .map(|s| s.to_owned()) + .or_else(|| env::var("ZED_MINIDUMP_ENDPOINT").ok()) +}); + static DOTNET_PROJECT_FILES_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"^(global\.json|Directory\.Build\.props|.*\.(csproj|fsproj|vbproj|sln))$").unwrap() }); @@ -358,13 +364,13 @@ impl Telemetry { worktree_id: WorktreeId, updated_entries_set: &UpdatedEntriesSet, ) { - let Some(project_type_names) = self.detect_project_types(worktree_id, updated_entries_set) + let Some(project_types) = self.detect_project_types(worktree_id, updated_entries_set) else { return; }; - for project_type_name in project_type_names { - telemetry::event!("Project Opened", project_type = project_type_name); + for project_type in project_types { + telemetry::event!("Project Opened", project_type = project_type); } } diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 6ce79fa9c5..439fb100d2 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,8 +1,11 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; use chrono::Duration; +use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; +use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; +use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; use rpc::{ ConnectionId, Peer, Receipt, TypedEnvelope, @@ -39,6 +42,44 @@ impl FakeServer { executor: cx.executor(), }; + client.http_client().as_fake().replace_handler({ + let state = server.state.clone(); + move |old_handler, req| { + let state = state.clone(); + let old_handler = old_handler.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != 
Some(Credentials { + user_id: client_user_id, + access_token: state.lock().access_token.to_string(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + client_user_id as i32, + format!("user-{client_user_id}"), + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => old_handler(req).await, + } + } + } + }); client .override_authenticate({ let state = Arc::downgrade(&server.state); @@ -105,7 +146,7 @@ impl FakeServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -223,3 +264,54 @@ impl Drop for FakeServer { self.disconnect(); } } + +pub fn parse_authorization_header(req: &Request) -> Option { + let mut auth_header = req + .headers() + .get(http::header::AUTHORIZATION)? + .to_str() + .ok()? + .split_whitespace(); + let user_id = auth_header.next()?.parse().ok()?; + let access_token = auth_header.next()?; + Some(Credentials { + user_id, + access_token: access_token.to_string(), + }) +} + +pub fn make_get_authenticated_user_response( + user_id: i32, + github_login: String, +) -> GetAuthenticatedUserResponse { + GetAuthenticatedUserResponse { + user: AuthenticatedUser { + id: user_id, + metrics_id: format!("metrics-id-{user_id}"), + avatar_url: "".to_string(), + github_login, + name: None, + is_staff: false, + accepted_tos_at: None, + }, + feature_flags: vec![], + plan: PlanInfo { + plan: Plan::ZedPro, + subscription_period: None, + usage: CurrentUsage { + model_requests: UsageData { + used: 0, + limit: UsageLimit::Limited(500), + }, + edit_predictions: UsageData { + used: 250, + limit: UsageLimit::Unlimited, + }, + }, + trial_started_at: None, + is_usage_based_billing_enabled: false, + is_account_too_young: false, + has_overdue_invoices: false, + }, + } +} diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index f5213fbcb6..3c125a0882 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,6 +1,11 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; +use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; +use cloud_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, +}; use collections::{HashMap, HashSet, hash_map::Entry}; use derive_more::Deref; use feature_flags::FeatureFlagAppExt; @@ -16,11 +21,7 @@ use std::{ sync::{Arc, Weak}, }; use text::ReplicaId; -use util::{TryFutureExt as _, maybe}; -use zed_llm_client::{ - EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, - MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; +use util::{ResultExt, TryFutureExt as _}; pub type UserId = u64; @@ -55,7 +56,7 @@ pub struct ParticipantIndex(pub u32); #[derive(Default, Debug)] pub struct User { pub id: UserId, - pub github_login: String, + pub github_login: SharedString, pub avatar_uri: SharedUri, pub name: Option, } @@ -107,19 +108,14 @@ pub enum ContactRequestStatus { pub struct UserStore { users: HashMap>, - by_github_login: HashMap, + by_github_login: HashMap, participant_indices: HashMap, update_contacts_tx: mpsc::UnboundedSender, - current_plan: Option, - subscription_period: 
Option<(DateTime, DateTime)>, - trial_started_at: Option>, model_request_usage: Option, edit_prediction_usage: Option, - is_usage_based_billing_enabled: Option, - account_too_young: Option, - has_overdue_invoices: Option, + plan_info: Option, current_user: watch::Receiver>>, - accepted_tos_at: Option>>, + accepted_tos_at: Option>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -145,6 +141,7 @@ pub enum Event { ShowContacts, ParticipantIndicesChanged, PrivateUserInfoUpdated, + PlanUpdated, } #[derive(Clone, Copy)] @@ -188,14 +185,9 @@ impl UserStore { users: Default::default(), by_github_login: Default::default(), current_user: current_user_rx, - current_plan: None, - subscription_period: None, - trial_started_at: None, + plan_info: None, model_request_usage: None, edit_prediction_usage: None, - is_usage_based_billing_enabled: None, - account_too_young: None, - has_overdue_invoices: None, accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), @@ -225,53 +217,30 @@ impl UserStore { return Ok(()); }; match status { - Status::Connected { .. } => { + Status::Authenticated | Status::Connected { .. } => { if let Some(user_id) = client.user_id() { - let fetch_user = if let Ok(fetch_user) = - this.update(cx, |this, cx| this.get_user(user_id, cx).log_err()) - { - fetch_user - } else { - break; - }; - let fetch_private_user_info = - client.request(proto::GetPrivateUserInfo {}).log_err(); - let (user, info) = - futures::join!(fetch_user, fetch_private_user_info); - + let response = client.cloud_client().get_authenticated_user().await; + let mut current_user = None; cx.update(|cx| { - if let Some(info) = info { - let staff = - info.staff && !*feature_flags::ZED_DISABLE_STAFF; - cx.update_flags(staff, info.flags); - client.telemetry.set_authenticated_user_info( - Some(info.metrics_id.clone()), - staff, - ); - + if let Some(response) = response.log_err() { + let user = Arc::new(User { + id: user_id, + github_login: response.user.github_login.clone().into(), + avatar_uri: response.user.avatar_url.clone().into(), + name: response.user.name.clone(), + }); + current_user = Some(user.clone()); this.update(cx, |this, cx| { - let accepted_tos_at = { - #[cfg(debug_assertions)] - if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() - { - None - } else { - info.accepted_tos_at - } - - #[cfg(not(debug_assertions))] - info.accepted_tos_at - }; - - this.set_current_user_accepted_tos_at(accepted_tos_at); - cx.emit(Event::PrivateUserInfoUpdated); + this.by_github_login + .insert(user.github_login.clone(), user_id); + this.users.insert(user_id, user); + this.update_authenticated_user(response, cx) }) } else { anyhow::Ok(()) } })??; - - current_user_tx.send(user).await.ok(); + current_user_tx.send(current_user).await.ok(); this.update(cx, |_, cx| cx.notify())?; } @@ -352,59 +321,22 @@ impl UserStore { async fn handle_update_plan( this: Entity, - message: TypedEnvelope, + _message: TypedEnvelope, mut cx: AsyncApp, ) -> Result<()> { + let client = this + .read_with(&cx, |this, _| this.client.upgrade())? 
+ .context("client was dropped")?; + + let response = client + .cloud_client() + .get_authenticated_user() + .await + .context("failed to fetch authenticated user")?; + this.update(&mut cx, |this, cx| { - this.current_plan = Some(message.payload.plan()); - this.subscription_period = maybe!({ - let period = message.payload.subscription_period?; - let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?; - let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?; - - Some((started_at, ended_at)) - }); - this.trial_started_at = message - .payload - .trial_started_at - .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0)); - this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled; - this.account_too_young = message.payload.account_too_young; - this.has_overdue_invoices = message.payload.has_overdue_invoices; - - if let Some(usage) = message.payload.usage { - // limits are always present even though they are wrapped in Option - this.model_request_usage = usage - .model_requests_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(ModelRequestUsage); - this.edit_prediction_usage = usage - .edit_predictions_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(EditPredictionUsage); - } - - cx.notify(); - })?; - Ok(()) - } - - pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { - self.model_request_usage = Some(usage); - cx.notify(); - } - - pub fn update_edit_prediction_usage( - &mut self, - usage: EditPredictionUsage, - cx: &mut Context, - ) { - self.edit_prediction_usage = Some(usage); - cx.notify(); + this.update_authenticated_user(response, cx); + }) } fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { @@ -763,57 +695,131 @@ impl UserStore { self.current_user.borrow().clone() } - pub fn current_plan(&self) -> Option { + pub fn plan(&self) -> Option { #[cfg(debug_assertions)] - if let Ok(plan) = std::env::var("ZED_SIMULATE_ZED_PRO_PLAN").as_ref() { + if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { return match plan.as_str() { - "free" => Some(proto::Plan::Free), - "trial" => Some(proto::Plan::ZedProTrial), - "pro" => Some(proto::Plan::ZedPro), - _ => None, + "free" => Some(cloud_llm_client::Plan::ZedFree), + "trial" => Some(cloud_llm_client::Plan::ZedProTrial), + "pro" => Some(cloud_llm_client::Plan::ZedPro), + _ => { + panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); + } }; } - self.current_plan + self.plan_info.as_ref().map(|info| info.plan) } pub fn subscription_period(&self) -> Option<(DateTime, DateTime)> { - self.subscription_period + self.plan_info + .as_ref() + .and_then(|plan| plan.subscription_period) + .map(|subscription_period| { + ( + subscription_period.started_at.0, + subscription_period.ended_at.0, + ) + }) } pub fn trial_started_at(&self) -> Option> { - self.trial_started_at + self.plan_info + .as_ref() + .and_then(|plan| plan.trial_started_at) + .map(|trial_started_at| trial_started_at.0) } - pub fn usage_based_billing_enabled(&self) -> Option { - self.is_usage_based_billing_enabled + /// Returns whether the user's account is too new to use the service. 
+ pub fn account_too_young(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_account_too_young) + .unwrap_or_default() + } + + /// Returns whether the current user has overdue invoices and usage should be blocked. + pub fn has_overdue_invoices(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.has_overdue_invoices) + .unwrap_or_default() + } + + pub fn is_usage_based_billing_enabled(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_usage_based_billing_enabled) + .unwrap_or_default() } pub fn model_request_usage(&self) -> Option { self.model_request_usage } + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + pub fn edit_prediction_usage(&self) -> Option { self.edit_prediction_usage } + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + + fn update_authenticated_user( + &mut self, + response: GetAuthenticatedUserResponse, + cx: &mut Context, + ) { + let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF; + cx.update_flags(staff, response.feature_flags); + if let Some(client) = self.client.upgrade() { + client + .telemetry + .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff); + } + + let accepted_tos_at = { + #[cfg(debug_assertions)] + if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() { + None + } else { + response.user.accepted_tos_at + } + + #[cfg(not(debug_assertions))] + response.user.accepted_tos_at + }; + + self.accepted_tos_at = Some(accepted_tos_at); + self.model_request_usage = Some(ModelRequestUsage(RequestUsage { + limit: response.plan.usage.model_requests.limit, + amount: response.plan.usage.model_requests.used as i32, + })); + self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { + limit: response.plan.usage.edit_predictions.limit, + amount: response.plan.usage.edit_predictions.used as i32, + })); + self.plan_info = Some(response.plan); + cx.emit(Event::PrivateUserInfoUpdated); + } + pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } - /// Returns whether the user's account is too new to use the service. - pub fn account_too_young(&self) -> bool { - self.account_too_young.unwrap_or(false) - } - - /// Returns whether the current user has overdue invoices and usage should be blocked. 
- pub fn has_overdue_invoices(&self) -> bool { - self.has_overdue_invoices.unwrap_or(false) - } - - pub fn current_user_has_accepted_terms(&self) -> Option { + pub fn has_accepted_terms_of_service(&self) -> bool { self.accepted_tos_at - .map(|accepted_tos_at| accepted_tos_at.is_some()) + .map_or(false, |accepted_tos_at| accepted_tos_at.is_some()) } pub fn accept_terms_of_service(&self, cx: &Context) -> Task> { @@ -825,23 +831,18 @@ impl UserStore { cx.spawn(async move |this, cx| -> anyhow::Result<()> { let client = client.upgrade().context("client not found")?; let response = client - .request(proto::AcceptTermsOfService {}) + .cloud_client() + .accept_terms_of_service() .await .context("error accepting tos")?; this.update(cx, |this, cx| { - this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)); + this.accepted_tos_at = Some(response.user.accepted_tos_at); cx.emit(Event::PrivateUserInfoUpdated); })?; Ok(()) }) } - fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option) { - self.accepted_tos_at = Some( - accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)), - ); - } - fn load_users( &self, request: impl RequestMessage, @@ -900,7 +901,7 @@ impl UserStore { let mut missing_user_ids = Vec::new(); for id in user_ids { if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) { - ret.insert(id, github_login.into()); + ret.insert(id, github_login); } else { missing_user_ids.push(id) } @@ -921,7 +922,7 @@ impl User { fn new(message: proto::User) -> Arc { Arc::new(User { id: message.id, - github_login: message.github_login, + github_login: message.github_login.into(), avatar_uri: message.avatar_url.into(), name: message.name, }) diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml new file mode 100644 index 0000000000..d56aa94c6e --- /dev/null +++ b/crates/cloud_api_client/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "cloud_api_client" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_api_client.rs" + +[dependencies] +anyhow.workspace = true +cloud_api_types.workspace = true +futures.workspace = true +http_client.workspace = true +parking_lot.workspace = true +serde_json.workspace = true +workspace-hack.workspace = true diff --git a/crates/cloud_api_client/LICENSE-APACHE b/crates/cloud_api_client/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_api_client/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs new file mode 100644 index 0000000000..edac051a0e --- /dev/null +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -0,0 +1,188 @@ +use std::sync::Arc; + +use anyhow::{Context, Result, anyhow}; +pub use cloud_api_types::*; +use futures::AsyncReadExt as _; +use http_client::http::request; +use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode}; +use parking_lot::RwLock; + +struct Credentials { + user_id: u32, + access_token: String, +} + +pub struct CloudApiClient { + credentials: RwLock>, + http_client: Arc, +} + +impl CloudApiClient { + pub fn new(http_client: Arc) -> Self { + Self { + credentials: RwLock::new(None), + http_client, + } + } + + pub fn has_credentials(&self) -> bool { + self.credentials.read().is_some() + } + + pub fn set_credentials(&self, user_id: 
u32, access_token: String) { + *self.credentials.write() = Some(Credentials { + user_id, + access_token, + }); + } + + pub fn clear_credentials(&self) { + *self.credentials.write() = None; + } + + fn build_request( + &self, + req: request::Builder, + body: impl Into, + ) -> Result> { + let credentials = self.credentials.read(); + let credentials = credentials.as_ref().context("no credentials provided")?; + build_request(req, body, credentials) + } + + pub async fn get_authenticated_user(&self) -> Result { + let request = self.build_request( + Request::builder().method(Method::GET).uri( + self.http_client + .build_zed_cloud_url("/client/users/me", &[])? + .as_ref(), + ), + AsyncBody::default(), + )?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn accept_terms_of_service(&self) -> Result { + let request = self.build_request( + Request::builder().method(Method::POST).uri( + self.http_client + .build_zed_cloud_url("/client/terms_of_service/accept", &[])? + .as_ref(), + ), + AsyncBody::default(), + )?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to accept terms of service.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn create_llm_token( + &self, + system_id: Option, + ) -> Result { + let mut request_builder = Request::builder().method(Method::POST).uri( + self.http_client + .build_zed_cloud_url("/client/llm_tokens", &[])? + .as_ref(), + ); + + if let Some(system_id) = system_id { + request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id); + } + + let request = self.build_request(request_builder, AsyncBody::default())?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to create LLM token.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result { + let request = build_request( + Request::builder().method(Method::GET).uri( + self.http_client + .build_zed_cloud_url("/client/users/me", &[])? 
+ .as_ref(), + ), + AsyncBody::default(), + &Credentials { + user_id, + access_token: access_token.into(), + }, + )?; + + let mut response = self.http_client.send(request).await?; + + if response.status().is_success() { + Ok(true) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + if response.status() == StatusCode::UNAUTHORIZED { + return Ok(false); + } else { + return Err(anyhow!( + "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", + response.status() + )); + } + } + } +} + +fn build_request( + req: request::Builder, + body: impl Into, + credentials: &Credentials, +) -> Result> { + Ok(req + .header("Content-Type", "application/json") + .header( + "Authorization", + format!("{} {}", credentials.user_id, credentials.access_token), + ) + .body(body.into())?) +} diff --git a/crates/cloud_api_types/Cargo.toml b/crates/cloud_api_types/Cargo.toml new file mode 100644 index 0000000000..868797df3b --- /dev/null +++ b/crates/cloud_api_types/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "cloud_api_types" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_api_types.rs" + +[dependencies] +chrono.workspace = true +cloud_llm_client.workspace = true +serde.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true +serde_json.workspace = true diff --git a/crates/cloud_api_types/LICENSE-APACHE b/crates/cloud_api_types/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_api_types/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs new file mode 100644 index 0000000000..b38b38cde1 --- /dev/null +++ b/crates/cloud_api_types/src/cloud_api_types.rs @@ -0,0 +1,55 @@ +mod timestamp; + +use serde::{Deserialize, Serialize}; + +pub use crate::timestamp::Timestamp; + +pub const ZED_SYSTEM_ID_HEADER_NAME: &str = "x-zed-system-id"; + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct GetAuthenticatedUserResponse { + pub user: AuthenticatedUser, + pub feature_flags: Vec, + pub plan: PlanInfo, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct AuthenticatedUser { + pub id: i32, + pub metrics_id: String, + pub avatar_url: String, + pub github_login: String, + pub name: Option, + pub is_staff: bool, + pub accepted_tos_at: Option, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct PlanInfo { + pub plan: cloud_llm_client::Plan, + pub subscription_period: Option, + pub usage: cloud_llm_client::CurrentUsage, + pub trial_started_at: Option, + pub is_usage_based_billing_enabled: bool, + pub is_account_too_young: bool, + pub has_overdue_invoices: bool, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +pub struct SubscriptionPeriod { + pub started_at: Timestamp, + pub ended_at: Timestamp, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct AcceptTermsOfServiceResponse { + pub user: AuthenticatedUser, +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct LlmToken(pub String); + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct CreateLlmTokenResponse { + pub token: LlmToken, +} diff --git a/crates/cloud_api_types/src/timestamp.rs b/crates/cloud_api_types/src/timestamp.rs new file mode 100644 index 
0000000000..1f055d58ef --- /dev/null +++ b/crates/cloud_api_types/src/timestamp.rs @@ -0,0 +1,166 @@ +use chrono::{DateTime, NaiveDateTime, SecondsFormat, Utc}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A timestamp with a serialized representation in RFC 3339 format. +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub struct Timestamp(pub DateTime); + +impl Timestamp { + pub fn new(datetime: DateTime) -> Self { + Self(datetime) + } +} + +impl From> for Timestamp { + fn from(value: DateTime) -> Self { + Self(value) + } +} + +impl From for Timestamp { + fn from(value: NaiveDateTime) -> Self { + Self(value.and_utc()) + } +} + +impl Serialize for Timestamp { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let rfc3339_string = self.0.to_rfc3339_opts(SecondsFormat::Millis, true); + serializer.serialize_str(&rfc3339_string) + } +} + +impl<'de> Deserialize<'de> for Timestamp { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = String::deserialize(deserializer)?; + let datetime = DateTime::parse_from_rfc3339(&value) + .map_err(serde::de::Error::custom)? + .to_utc(); + Ok(Self(datetime)) + } +} + +#[cfg(test)] +mod tests { + use chrono::NaiveDate; + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_timestamp_serialization() { + let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + let timestamp = Timestamp::new(datetime); + + let json = serde_json::to_string(×tamp).unwrap(); + assert_eq!(json, "\"2023-12-25T14:30:45.123Z\""); + } + + #[test] + fn test_timestamp_deserialization() { + let json = "\"2023-12-25T14:30:45.123Z\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + let expected = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_roundtrip() { + let original = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + let timestamp = Timestamp::new(original); + let json = serde_json::to_string(×tamp).unwrap(); + let deserialized: Timestamp = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.0, original); + } + + #[test] + fn test_timestamp_from_datetime_utc() { + let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + let timestamp = Timestamp::from(datetime); + assert_eq!(timestamp.0, datetime); + } + + #[test] + fn test_timestamp_from_naive_datetime() { + let naive_dt = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_milli_opt(14, 30, 45, 123) + .unwrap(); + + let timestamp = Timestamp::from(naive_dt); + let expected = naive_dt.and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_serialization_with_microseconds() { + // Test that microseconds are truncated to milliseconds + let datetime = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_micro_opt(14, 30, 45, 123456) + .unwrap() + .and_utc(); + + let timestamp = Timestamp::new(datetime); + let json = serde_json::to_string(×tamp).unwrap(); + + // Should be truncated to milliseconds + assert_eq!(json, "\"2023-12-25T14:30:45.123Z\""); + } + + #[test] + fn test_timestamp_deserialization_without_milliseconds() { + let json = "\"2023-12-25T14:30:45Z\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + let expected = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_opt(14, 
30, 45) + .unwrap() + .and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_deserialization_with_timezone() { + let json = "\"2023-12-25T14:30:45.123+05:30\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + // Should be converted to UTC + let expected = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_milli_opt(9, 0, 45, 123) // 14:30:45 + 5:30 = 20:00:45, but we want UTC so subtract 5:30 + .unwrap() + .and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_deserialization_with_invalid_format() { + let json = "\"invalid-date\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } +} diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml new file mode 100644 index 0000000000..6f090d3c6e --- /dev/null +++ b/crates/cloud_llm_client/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "cloud_llm_client" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_llm_client.rs" + +[dependencies] +anyhow.workspace = true +serde = { workspace = true, features = ["derive", "rc"] } +serde_json.workspace = true +strum = { workspace = true, features = ["derive"] } +uuid = { workspace = true, features = ["serde"] } +workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/cloud_llm_client/LICENSE-APACHE b/crates/cloud_llm_client/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_llm_client/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs new file mode 100644 index 0000000000..e78957ec49 --- /dev/null +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -0,0 +1,386 @@ +use std::str::FromStr; +use std::sync::Arc; + +use anyhow::Context as _; +use serde::{Deserialize, Serialize}; +use strum::{Display, EnumIter, EnumString}; +use uuid::Uuid; + +/// The name of the header used to indicate which version of Zed the client is running. +pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version"; + +/// The name of the header used to indicate when a request failed due to an +/// expired LLM token. +/// +/// The client may use this as a signal to refresh the token. +pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; + +/// The name of the header used to indicate what plan the user is currently on. +pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan"; + +/// The name of the header used to indicate the usage limit for model requests. +pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit"; + +/// The name of the header used to indicate the usage amount for model requests. +pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount"; + +/// The name of the header used to indicate the usage limit for edit predictions. +pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit"; + +/// The name of the header used to indicate the usage amount for edit predictions. +pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount"; + +/// The name of the header used to indicate the resource for which the subscription limit has been reached. 
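Note (illustrative only, not part of the diff): the serde derives on GetAuthenticatedUserResponse and PlanInfo introduced earlier in this diff imply a JSON wire shape along the following lines. The field values and flag names below are invented; a sketch of a test that would exercise that shape, assuming it lives alongside the cloud_api_types definitions:

#[cfg(test)]
mod get_authenticated_user_shape {
    use super::*;

    #[test]
    fn deserializes_an_example_payload() {
        // Hypothetical payload: values are made up for illustration.
        let json = r#"{
            "user": {
                "id": 1,
                "metrics_id": "11111111-2222-3333-4444-555555555555",
                "avatar_url": "https://example.com/avatar.png",
                "github_login": "octocat",
                "name": null,
                "is_staff": false,
                "accepted_tos_at": "2023-12-25T14:30:45.123Z"
            },
            "feature_flags": ["example-flag"],
            "plan": {
                "plan": "zed_pro",
                "subscription_period": {
                    "started_at": "2023-12-01T00:00:00.000Z",
                    "ended_at": "2024-01-01T00:00:00.000Z"
                },
                "usage": {
                    "model_requests": { "used": 10, "limit": { "limited": 500 } },
                    "edit_predictions": { "used": 5, "limit": "unlimited" }
                },
                "trial_started_at": null,
                "is_usage_based_billing_enabled": false,
                "is_account_too_young": false,
                "has_overdue_invoices": false
            }
        }"#;

        let response: GetAuthenticatedUserResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.plan.plan, cloud_llm_client::Plan::ZedPro);
        assert_eq!(response.user.github_login, "octocat");
    }
}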
+pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource"; + +pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests"; +pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions"; + +/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached. +pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached"; + +/// The name of the header used to indicate the the minimum required Zed version. +/// +/// This can be used to force a Zed upgrade in order to continue communicating +/// with the LLM service. +pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version"; + +/// The name of the header used by the client to indicate to the server that it supports receiving status messages. +pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = + "x-zed-client-supports-status-messages"; + +/// The name of the header used by the server to indicate to the client that it supports sending status messages. +pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = + "x-zed-server-supports-status-messages"; + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum UsageLimit { + Limited(i32), + Unlimited, +} + +impl FromStr for UsageLimit { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result { + match value { + "unlimited" => Ok(Self::Unlimited), + limit => limit + .parse::() + .map(Self::Limited) + .context("failed to parse limit"), + } + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Plan { + #[default] + #[serde(alias = "Free")] + ZedFree, + #[serde(alias = "ZedPro")] + ZedPro, + #[serde(alias = "ZedProTrial")] + ZedProTrial, +} + +impl Plan { + pub fn as_str(&self) -> &'static str { + match self { + Plan::ZedFree => "zed_free", + Plan::ZedPro => "zed_pro", + Plan::ZedProTrial => "zed_pro_trial", + } + } + + pub fn model_requests_limit(&self) -> UsageLimit { + match self { + Plan::ZedPro => UsageLimit::Limited(500), + Plan::ZedProTrial => UsageLimit::Limited(150), + Plan::ZedFree => UsageLimit::Limited(50), + } + } + + pub fn edit_predictions_limit(&self) -> UsageLimit { + match self { + Plan::ZedPro => UsageLimit::Unlimited, + Plan::ZedProTrial => UsageLimit::Unlimited, + Plan::ZedFree => UsageLimit::Limited(2_000), + } + } +} + +impl FromStr for Plan { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result { + match value { + "zed_free" => Ok(Plan::ZedFree), + "zed_pro" => Ok(Plan::ZedPro), + "zed_pro_trial" => Ok(Plan::ZedProTrial), + plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")), + } + } +} + +#[derive( + Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum LanguageModelProvider { + Anthropic, + OpenAi, + Google, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsBody { + #[serde(skip_serializing_if = "Option::is_none", default)] + pub outline: Option, + pub input_events: String, + pub input_excerpt: String, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub speculated_output: Option, + /// Whether the user provided consent for sampling this interaction. 
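Note (illustrative only, not part of the diff): a hedged sketch of how a client could interpret the plan and usage headers defined above, using the FromStr impls on Plan and UsageLimit from this file. The header values in the comments are examples, and the function itself is hypothetical:

use std::str::FromStr;

fn parse_usage_headers(
    plan: &str,   // value of CURRENT_PLAN_HEADER_NAME, e.g. "zed_pro_trial"
    limit: &str,  // value of MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, e.g. "500" or "unlimited"
    amount: &str, // value of MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, e.g. "42"
) -> anyhow::Result<(Plan, UsageLimit, i32)> {
    // Both FromStr impls return anyhow errors, so `?` propagates them directly.
    let plan = Plan::from_str(plan)?;
    let limit = UsageLimit::from_str(limit)?;
    let amount: i32 = amount.parse()?;
    Ok((plan, limit, amount))
}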
+ #[serde(default, alias = "data_collection_permission")] + pub can_collect_data: bool, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub diagnostic_groups: Option>, + /// Info about the git repository state, only present when can_collect_data is true. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub git_info: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsGitInfo { + /// SHA of git HEAD commit at time of prediction. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub head_sha: Option, + /// URL of the remote called `origin`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub remote_origin_url: Option, + /// URL of the remote called `upstream`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub remote_upstream_url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsResponse { + pub request_id: Uuid, + pub output_excerpt: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcceptEditPredictionBody { + pub request_id: Uuid, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionMode { + Normal, + Max, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionIntent { + UserPrompt, + ToolResults, + ThreadSummarization, + ThreadContextSummarization, + CreateFile, + EditFile, + InlineAssist, + TerminalInlineAssist, + GenerateGitCommitMessage, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionBody { + #[serde(skip_serializing_if = "Option::is_none", default)] + pub thread_id: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub prompt_id: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub intent: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub mode: Option, + pub provider: LanguageModelProvider, + pub model: String, + pub provider_request: serde_json::Value, +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionRequestStatus { + Queued { + position: usize, + }, + Started, + Failed { + code: String, + message: String, + request_id: Uuid, + /// Retry duration in seconds. 
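Note (illustrative only, not part of the diff): a hedged sketch of constructing the CompletionBody defined above. The model name, intent, mode, and provider_request payload are assumptions chosen for illustration; provider_request carries the provider-specific request as raw JSON:

fn example_completion_body() -> CompletionBody {
    CompletionBody {
        thread_id: None,
        prompt_id: None,
        intent: Some(CompletionIntent::UserPrompt),
        mode: Some(CompletionMode::Normal),
        provider: LanguageModelProvider::Anthropic,
        // Hypothetical model identifier and provider payload.
        model: "claude-sonnet-4".to_string(),
        provider_request: serde_json::json!({ "max_tokens": 1024 }),
    }
}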
+ retry_after: Option, + }, + UsageUpdated { + amount: usize, + limit: UsageLimit, + }, + ToolUseLimitReached, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionEvent { + Status(CompletionRequestStatus), + Event(T), +} + +impl CompletionEvent { + pub fn into_status(self) -> Option { + match self { + Self::Status(status) => Some(status), + Self::Event(_) => None, + } + } + + pub fn into_event(self) -> Option { + match self { + Self::Event(event) => Some(event), + Self::Status(_) => None, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct WebSearchBody { + pub query: String, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct WebSearchResponse { + pub results: Vec, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct WebSearchResult { + pub title: String, + pub url: String, + pub text: String, +} + +#[derive(Serialize, Deserialize)] +pub struct CountTokensBody { + pub provider: LanguageModelProvider, + pub model: String, + pub provider_request: serde_json::Value, +} + +#[derive(Serialize, Deserialize)] +pub struct CountTokensResponse { + pub tokens: usize, +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelId(pub Arc); + +impl std::fmt::Display for LanguageModelId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LanguageModel { + pub provider: LanguageModelProvider, + pub id: LanguageModelId, + pub display_name: String, + pub max_token_count: usize, + pub max_token_count_in_max_mode: Option, + pub max_output_tokens: usize, + pub supports_tools: bool, + pub supports_images: bool, + pub supports_thinking: bool, + pub supports_max_mode: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ListModelsResponse { + pub models: Vec, + pub default_model: LanguageModelId, + pub default_fast_model: LanguageModelId, + pub recommended_models: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GetSubscriptionResponse { + pub plan: Plan, + pub usage: Option, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct CurrentUsage { + pub model_requests: UsageData, + pub edit_predictions: UsageData, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct UsageData { + pub used: u32, + pub limit: UsageLimit, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::*; + + #[test] + fn test_plan_deserialize_snake_case() { + let plan = serde_json::from_value::(json!("zed_free")).unwrap(); + assert_eq!(plan, Plan::ZedFree); + + let plan = serde_json::from_value::(json!("zed_pro")).unwrap(); + assert_eq!(plan, Plan::ZedPro); + + let plan = serde_json::from_value::(json!("zed_pro_trial")).unwrap(); + assert_eq!(plan, Plan::ZedProTrial); + } + + #[test] + fn test_plan_deserialize_aliases() { + let plan = serde_json::from_value::(json!("Free")).unwrap(); + assert_eq!(plan, Plan::ZedFree); + + let plan = serde_json::from_value::(json!("ZedPro")).unwrap(); + assert_eq!(plan, Plan::ZedPro); + + let plan = serde_json::from_value::(json!("ZedProTrial")).unwrap(); + assert_eq!(plan, Plan::ZedProTrial); + } + + #[test] + fn test_usage_limit_from_str() { + let limit = UsageLimit::from_str("unlimited").unwrap(); + assert!(matches!(limit, UsageLimit::Unlimited)); + + let limit = UsageLimit::from_str(&0.to_string()).unwrap(); + assert!(matches!(limit, UsageLimit::Limited(0))); + + let 
limit = UsageLimit::from_str(&50.to_string()).unwrap(); + assert!(matches!(limit, UsageLimit::Limited(50))); + + for value in ["not_a_number", "50xyz"] { + let limit = UsageLimit::from_str(value); + assert!(limit.is_err()); + } + } +} diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index d3b5048283..9af95317e6 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -23,13 +23,14 @@ async-stripe.workspace = true async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } -aws-sdk-s3 = { version = "1.15.0" } aws-sdk-kinesis = "1.51.0" +aws-sdk-s3 = { version = "1.15.0" } axum = { version = "0.6", features = ["json", "headers", "ws"] } axum-extra = { version = "0.4", features = ["erased-json"] } base64.workspace = true chrono.workspace = true clock.workspace = true +cloud_llm_client.workspace = true collections.workspace = true dashmap.workspace = true derive_more.workspace = true @@ -75,7 +76,6 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re util.workspace = true uuid.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true [dev-dependencies] agent_settings.workspace = true diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index ca840493ad..73d473ab76 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -173,6 +173,7 @@ CREATE TABLE "language_servers" ( "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "name" VARCHAR NOT NULL, + "capabilities" TEXT NOT NULL, PRIMARY KEY (project_id, id) ); diff --git a/crates/collab/migrations/20250804080620_language_server_capabilities.sql b/crates/collab/migrations/20250804080620_language_server_capabilities.sql new file mode 100644 index 0000000000..f74f094ed2 --- /dev/null +++ b/crates/collab/migrations/20250804080620_language_server_capabilities.sql @@ -0,0 +1,5 @@ +ALTER TABLE language_servers + ADD COLUMN capabilities TEXT NOT NULL DEFAULT '{}'; + +ALTER TABLE language_servers + ALTER COLUMN capabilities DROP DEFAULT; diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 3b0f5396a7..6cf3f68f54 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -100,13 +100,11 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() - .route("/user", get(update_or_create_authenticated_user)) .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) .route("/users/:id/update_plan", post(update_plan)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) - .merge(billing::router()) .merge(contributors::router()) .layer( ServiceBuilder::new() @@ -146,48 +144,6 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR Ok::<_, Error>(next.run(req).await) } -#[derive(Debug, Deserialize)] -struct AuthenticatedUserParams { - github_user_id: i32, - github_login: String, - github_email: Option, - github_name: Option, - github_user_created_at: chrono::DateTime, -} - -#[derive(Debug, Serialize)] -struct AuthenticatedUserResponse { - user: User, - metrics_id: String, - feature_flags: Vec, -} - -async fn update_or_create_authenticated_user( - Query(params): Query, - Extension(app): Extension>, 
-) -> Result> { - let initial_channel_id = app.config.auto_join_channel_id; - - let user = app - .db - .update_or_create_user_by_github_account( - ¶ms.github_login, - params.github_user_id, - params.github_email.as_deref(), - params.github_name.as_deref(), - params.github_user_created_at, - initial_channel_id, - ) - .await?; - let metrics_id = app.db.get_user_metrics_id(user.id).await?; - let feature_flags = app.db.get_user_flags(user.id).await?; - Ok(Json(AuthenticatedUserResponse { - user, - metrics_id, - feature_flags, - })) -} - #[derive(Debug, Deserialize)] struct LookUpUserParams { identifier: String, @@ -354,9 +310,9 @@ async fn refresh_llm_tokens( #[derive(Debug, Serialize, Deserialize)] struct UpdatePlanBody { - pub plan: zed_llm_client::Plan, + pub plan: cloud_llm_client::Plan, pub subscription_period: SubscriptionPeriod, - pub usage: zed_llm_client::CurrentUsage, + pub usage: cloud_llm_client::CurrentUsage, pub trial_started_at: Option>, pub is_usage_based_billing_enabled: bool, pub is_account_too_young: bool, @@ -378,9 +334,9 @@ async fn update_plan( extract::Json(body): extract::Json, ) -> Result> { let plan = match body.plan { - zed_llm_client::Plan::ZedFree => proto::Plan::Free, - zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro, - zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, + cloud_llm_client::Plan::ZedFree => proto::Plan::Free, + cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro, + cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, }; let update_user_plan = proto::UpdateUserPlan { @@ -412,15 +368,15 @@ async fn update_plan( Ok(Json(UpdatePlanResponse {})) } -fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit { +fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit { proto::UsageLimit { variant: Some(match limit { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 9a27e22f87..a0325d14c4 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,544 +1,10 @@ -use anyhow::{Context as _, bail}; -use axum::{Extension, Json, Router, extract, routing::post}; -use chrono::{DateTime, Utc}; -use collections::{HashMap, HashSet}; -use reqwest::StatusCode; -use sea_orm::ActiveValue; -use serde::{Deserialize, Serialize}; -use std::{sync::Arc, time::Duration}; -use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus}; -use util::{ResultExt, maybe}; -use zed_llm_client::LanguageModelProvider; +use std::sync::Arc; +use stripe::SubscriptionStatus; -use crate::db::billing_subscription::{ - StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, -}; -use crate::llm::db::subscription_usage_meter::{self, CompletionMode}; -use crate::rpc::{ResultExt as _, Server}; -use crate::stripe_client::{ - StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription, - StripeSubscriptionId, -}; -use crate::{AppState, Error, Result}; -use crate::{db::UserId, llm::db::LlmDatabase}; -use crate::{ - db::{ - CreateBillingCustomerParams, CreateBillingSubscriptionParams, - CreateProcessedStripeEventParams, 
UpdateBillingCustomerParams, - UpdateBillingSubscriptionParams, billing_customer, - }, - stripe_billing::StripeBilling, -}; - -pub fn router() -> Router { - Router::new().route( - "/billing/subscriptions/sync", - post(sync_billing_subscription), - ) -} - -#[derive(Debug, Deserialize)] -struct SyncBillingSubscriptionBody { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct SyncBillingSubscriptionResponse { - stripe_customer_id: String, -} - -async fn sync_billing_subscription( - Extension(app): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let Some(stripe_client) = app.stripe_client.clone() else { - log::error!("failed to retrieve Stripe client"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; - - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? - .context("user not found")?; - - let billing_customer = app - .db - .get_billing_customer_by_user_id(user.id) - .await? - .context("billing customer not found")?; - let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - let subscriptions = stripe_client - .list_subscriptions_for_customer(&stripe_customer_id) - .await?; - - for subscription in subscriptions { - let subscription_id = subscription.id.clone(); - - sync_subscription(&app, &stripe_client, subscription) - .await - .with_context(|| { - format!( - "failed to sync subscription {subscription_id} for user {}", - user.id, - ) - })?; - } - - Ok(Json(SyncBillingSubscriptionResponse { - stripe_customer_id: billing_customer.stripe_customer_id.clone(), - })) -} - -/// The amount of time we wait in between each poll of Stripe events. -/// -/// This value should strike a balance between: -/// 1. Being short enough that we update quickly when something in Stripe changes -/// 2. Being long enough that we don't eat into our rate limits. -/// -/// As a point of reference, the Sequin folks say they have this at **500ms**: -/// -/// > We poll the Stripe /events endpoint every 500ms per account -/// > -/// > — https://blog.sequinstream.com/events-not-webhooks/ -const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5); - -/// The maximum number of events to return per page. -/// -/// We set this to 100 (the max) so we have to make fewer requests to Stripe. -/// -/// > Limit can range between 1 and 100, and the default is 10. -const EVENTS_LIMIT_PER_PAGE: u64 = 100; - -/// The number of pages consisting entirely of already-processed events that we -/// will see before we stop retrieving events. -/// -/// This is used to prevent over-fetching the Stripe events API for events we've -/// already seen and processed. -const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4; - -/// Polls the Stripe events API periodically to reconcile the records in our -/// database with the data in Stripe. 
-pub fn poll_stripe_events_periodically(app: Arc, rpc_server: Arc) { - let Some(real_stripe_client) = app.real_stripe_client.clone() else { - log::warn!("failed to retrieve Stripe client"); - return; - }; - let Some(stripe_client) = app.stripe_client.clone() else { - log::warn!("failed to retrieve Stripe client"); - return; - }; - - let executor = app.executor.clone(); - executor.spawn_detached({ - let executor = executor.clone(); - async move { - loop { - poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client) - .await - .log_err(); - - executor.sleep(POLL_EVENTS_INTERVAL).await; - } - } - }); -} - -async fn poll_stripe_events( - app: &Arc, - rpc_server: &Arc, - stripe_client: &Arc, - real_stripe_client: &stripe::Client, -) -> anyhow::Result<()> { - fn event_type_to_string(event_type: EventType) -> String { - // Calling `to_string` on `stripe::EventType` members gives us a quoted string, - // so we need to unquote it. - event_type.to_string().trim_matches('"').to_string() - } - - let event_types = [ - EventType::CustomerCreated, - EventType::CustomerUpdated, - EventType::CustomerSubscriptionCreated, - EventType::CustomerSubscriptionUpdated, - EventType::CustomerSubscriptionPaused, - EventType::CustomerSubscriptionResumed, - EventType::CustomerSubscriptionDeleted, - ] - .into_iter() - .map(event_type_to_string) - .collect::>(); - - let mut pages_of_already_processed_events = 0; - let mut unprocessed_events = Vec::new(); - - log::info!( - "Stripe events: starting retrieval for {}", - event_types.join(", ") - ); - let mut params = ListEvents::new(); - params.types = Some(event_types.clone()); - params.limit = Some(EVENTS_LIMIT_PER_PAGE); - - let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms) - .await? - .paginate(params); - - loop { - let processed_event_ids = { - let event_ids = event_pages - .page - .data - .iter() - .map(|event| event.id.as_str()) - .collect::>(); - app.db - .get_processed_stripe_events_by_event_ids(&event_ids) - .await? - .into_iter() - .map(|event| event.stripe_event_id) - .collect::>() - }; - - let mut processed_events_in_page = 0; - let events_in_page = event_pages.page.data.len(); - for event in &event_pages.page.data { - if processed_event_ids.contains(&event.id.to_string()) { - processed_events_in_page += 1; - log::debug!("Stripe events: already processed '{}', skipping", event.id); - } else { - unprocessed_events.push(event.clone()); - } - } - - if processed_events_in_page == events_in_page { - pages_of_already_processed_events += 1; - } - - if event_pages.page.has_more { - if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP - { - log::info!( - "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events" - ); - break; - } else { - log::info!("Stripe events: retrieving next page"); - event_pages = event_pages.next(&real_stripe_client).await?; - } - } else { - break; - } - } - - log::info!("Stripe events: unprocessed {}", unprocessed_events.len()); - - // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred. 
- unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id))); - - for event in unprocessed_events { - let event_id = event.id.clone(); - let processed_event_params = CreateProcessedStripeEventParams { - stripe_event_id: event.id.to_string(), - stripe_event_type: event_type_to_string(event.type_), - stripe_event_created_timestamp: event.created, - }; - - // If the event has happened too far in the past, we don't want to - // process it and risk overwriting other more-recent updates. - // - // 1 day was chosen arbitrarily. This could be made longer or shorter. - let one_day = Duration::from_secs(24 * 60 * 60); - let a_day_ago = Utc::now() - one_day; - if a_day_ago.timestamp() > event.created { - log::info!( - "Stripe events: event '{}' is more than {one_day:?} old, marking as processed", - event_id - ); - app.db - .create_processed_stripe_event(&processed_event_params) - .await?; - - continue; - } - - let process_result = match event.type_ { - EventType::CustomerCreated | EventType::CustomerUpdated => { - handle_customer_event(app, real_stripe_client, event).await - } - EventType::CustomerSubscriptionCreated - | EventType::CustomerSubscriptionUpdated - | EventType::CustomerSubscriptionPaused - | EventType::CustomerSubscriptionResumed - | EventType::CustomerSubscriptionDeleted => { - handle_customer_subscription_event(app, rpc_server, stripe_client, event).await - } - _ => Ok(()), - }; - - if let Some(()) = process_result - .with_context(|| format!("failed to process event {event_id} successfully")) - .log_err() - { - app.db - .create_processed_stripe_event(&processed_event_params) - .await?; - } - } - - Ok(()) -} - -async fn handle_customer_event( - app: &Arc, - _stripe_client: &stripe::Client, - event: stripe::Event, -) -> anyhow::Result<()> { - let EventObject::Customer(customer) = event.data.object else { - bail!("unexpected event payload for {}", event.id); - }; - - log::info!("handling Stripe {} event: {}", event.type_, event.id); - - let Some(email) = customer.email else { - log::info!("Stripe customer has no email: skipping"); - return Ok(()); - }; - - let Some(user) = app.db.get_user_by_email(&email).await? else { - log::info!("no user found for email: skipping"); - return Ok(()); - }; - - if let Some(existing_customer) = app - .db - .get_billing_customer_by_stripe_customer_id(&customer.id) - .await? - { - app.db - .update_billing_customer( - existing_customer.id, - &UpdateBillingCustomerParams { - // For now we just leave the information as-is, as it is not - // likely to change. - ..Default::default() - }, - ) - .await?; - } else { - app.db - .create_billing_customer(&CreateBillingCustomerParams { - user_id: user.id, - stripe_customer_id: customer.id.to_string(), - }) - .await?; - } - - Ok(()) -} - -async fn sync_subscription( - app: &Arc, - stripe_client: &Arc, - subscription: StripeSubscription, -) -> anyhow::Result { - let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing { - stripe_billing - .determine_subscription_kind(&subscription) - .await - } else { - None - }; - - let billing_customer = - find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer) - .await? 
- .context("billing customer not found")?; - - if let Some(SubscriptionKind::ZedProTrial) = subscription_kind { - if subscription.status == SubscriptionStatus::Trialing { - let current_period_start = - DateTime::from_timestamp(subscription.current_period_start, 0) - .context("No trial subscription period start")?; - - app.db - .update_billing_customer( - billing_customer.id, - &UpdateBillingCustomerParams { - trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())), - ..Default::default() - }, - ) - .await?; - } - } - - let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled - && subscription - .cancellation_details - .as_ref() - .and_then(|details| details.reason) - .map_or(false, |reason| { - reason == StripeCancellationDetailsReason::PaymentFailed - }); - - if was_canceled_due_to_payment_failure { - app.db - .update_billing_customer( - billing_customer.id, - &UpdateBillingCustomerParams { - has_overdue_invoices: ActiveValue::set(true), - ..Default::default() - }, - ) - .await?; - } - - if let Some(existing_subscription) = app - .db - .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref()) - .await? - { - app.db - .update_billing_subscription( - existing_subscription.id, - &UpdateBillingSubscriptionParams { - billing_customer_id: ActiveValue::set(billing_customer.id), - kind: ActiveValue::set(subscription_kind), - stripe_subscription_id: ActiveValue::set(subscription.id.to_string()), - stripe_subscription_status: ActiveValue::set(subscription.status.into()), - stripe_cancel_at: ActiveValue::set( - subscription - .cancel_at - .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0)) - .map(|time| time.naive_utc()), - ), - stripe_cancellation_reason: ActiveValue::set( - subscription - .cancellation_details - .and_then(|details| details.reason) - .map(|reason| reason.into()), - ), - stripe_current_period_start: ActiveValue::set(Some( - subscription.current_period_start, - )), - stripe_current_period_end: ActiveValue::set(Some( - subscription.current_period_end, - )), - }, - ) - .await?; - } else { - if let Some(existing_subscription) = app - .db - .get_active_billing_subscription(billing_customer.user_id) - .await? - { - if existing_subscription.kind == Some(SubscriptionKind::ZedFree) - && subscription_kind == Some(SubscriptionKind::ZedProTrial) - { - let stripe_subscription_id = StripeSubscriptionId( - existing_subscription.stripe_subscription_id.clone().into(), - ); - - stripe_client - .cancel_subscription(&stripe_subscription_id) - .await?; - } else { - // If the user already has an active billing subscription, ignore the - // event and return an `Ok` to signal that it was processed - // successfully. - // - // There is the possibility that this could cause us to not create a - // subscription in the following scenario: - // - // 1. User has an active subscription A - // 2. User cancels subscription A - // 3. User creates a new subscription B - // 4. We process the new subscription B before the cancellation of subscription A - // 5. User ends up with no subscriptions - // - // In theory this situation shouldn't arise as we try to process the events in the order they occur. 
- - log::info!( - "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}", - user_id = billing_customer.user_id, - subscription_id = subscription.id - ); - return Ok(billing_customer); - } - } - - app.db - .create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - kind: subscription_kind, - stripe_subscription_id: subscription.id.to_string(), - stripe_subscription_status: subscription.status.into(), - stripe_cancellation_reason: subscription - .cancellation_details - .and_then(|details| details.reason) - .map(|reason| reason.into()), - stripe_current_period_start: Some(subscription.current_period_start), - stripe_current_period_end: Some(subscription.current_period_end), - }) - .await?; - } - - if let Some(stripe_billing) = app.stripe_billing.as_ref() { - if subscription.status == SubscriptionStatus::Canceled - || subscription.status == SubscriptionStatus::Paused - { - let already_has_active_billing_subscription = app - .db - .has_active_billing_subscription(billing_customer.user_id) - .await?; - if !already_has_active_billing_subscription { - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - stripe_billing - .subscribe_to_zed_free(stripe_customer_id) - .await?; - } - } - } - - Ok(billing_customer) -} - -async fn handle_customer_subscription_event( - app: &Arc, - rpc_server: &Arc, - stripe_client: &Arc, - event: stripe::Event, -) -> anyhow::Result<()> { - let EventObject::Subscription(subscription) = event.data.object else { - bail!("unexpected event payload for {}", event.id); - }; - - log::info!("handling Stripe {} event: {}", event.type_, event.id); - - let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?; - - // When the user's subscription changes, push down any changes to their plan. - rpc_server - .update_plan_for_user_legacy(billing_customer.user_id) - .await - .trace_err(); - - // When the user's subscription changes, we want to refresh their LLM tokens - // to either grant/revoke access. - rpc_server - .refresh_llm_tokens_for_user(billing_customer.user_id) - .await; - - Ok(()) -} +use crate::AppState; +use crate::db::billing_subscription::StripeSubscriptionStatus; +use crate::db::{CreateBillingCustomerParams, billing_customer}; +use crate::stripe_client::{StripeClient, StripeCustomerId}; impl From for StripeSubscriptionStatus { fn from(value: SubscriptionStatus) -> Self { @@ -555,16 +21,6 @@ impl From for StripeSubscriptionStatus { } } -impl From for StripeCancellationReason { - fn from(value: CancellationDetailsReason) -> Self { - match value { - CancellationDetailsReason::CancellationRequested => Self::CancellationRequested, - CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed, - CancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} - /// Finds or creates a billing customer using the provided customer. 
pub async fn find_or_create_billing_customer( app: &Arc, @@ -601,186 +57,3 @@ pub async fn find_or_create_billing_customer( Ok(Some(billing_customer)) } - -const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); - -pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc) { - let Some(stripe_billing) = app.stripe_billing.clone() else { - log::warn!("failed to retrieve Stripe billing object"); - return; - }; - let Some(llm_db) = app.llm_db.clone() else { - log::warn!("failed to retrieve LLM database"); - return; - }; - - let executor = app.executor.clone(); - executor.spawn_detached({ - let executor = executor.clone(); - async move { - loop { - sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing) - .await - .context("failed to sync LLM request usage to Stripe") - .trace_err(); - executor - .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL) - .await; - } - } - }); -} - -async fn sync_model_request_usage_with_stripe( - app: &Arc, - llm_db: &Arc, - stripe_billing: &Arc, -) -> anyhow::Result<()> { - log::info!("Stripe usage sync: Starting"); - let started_at = Utc::now(); - - let staff_users = app.db.get_staff_users().await?; - let staff_user_ids = staff_users - .iter() - .map(|user| user.id) - .collect::>(); - - let usage_meters = llm_db - .get_current_subscription_usage_meters(Utc::now()) - .await?; - let mut usage_meters_by_user_id = - HashMap::>::default(); - for (usage_meter, usage) in usage_meters { - let meters = usage_meters_by_user_id.entry(usage.user_id).or_default(); - meters.push(usage_meter); - } - - log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions"); - let get_zed_pro_subscriptions_started_at = Utc::now(); - let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?; - log::info!( - "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}", - billing_subscriptions.len(), - Utc::now() - get_zed_pro_subscriptions_started_at - ); - - let claude_sonnet_4 = stripe_billing - .find_price_by_lookup_key("claude-sonnet-4-requests") - .await?; - let claude_sonnet_4_max = stripe_billing - .find_price_by_lookup_key("claude-sonnet-4-requests-max") - .await?; - let claude_opus_4 = stripe_billing - .find_price_by_lookup_key("claude-opus-4-requests") - .await?; - let claude_opus_4_max = stripe_billing - .find_price_by_lookup_key("claude-opus-4-requests-max") - .await?; - let claude_3_5_sonnet = stripe_billing - .find_price_by_lookup_key("claude-3-5-sonnet-requests") - .await?; - let claude_3_7_sonnet = stripe_billing - .find_price_by_lookup_key("claude-3-7-sonnet-requests") - .await?; - let claude_3_7_sonnet_max = stripe_billing - .find_price_by_lookup_key("claude-3-7-sonnet-requests-max") - .await?; - - let model_mode_combinations = [ - ("claude-opus-4", CompletionMode::Max), - ("claude-opus-4", CompletionMode::Normal), - ("claude-sonnet-4", CompletionMode::Max), - ("claude-sonnet-4", CompletionMode::Normal), - ("claude-3-7-sonnet", CompletionMode::Max), - ("claude-3-7-sonnet", CompletionMode::Normal), - ("claude-3-5-sonnet", CompletionMode::Normal), - ]; - - let billing_subscription_count = billing_subscriptions.len(); - - log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions"); - - for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions { - maybe!(async { - if staff_user_ids.contains(&user_id) { - return anyhow::Ok(()); - } - - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - let 
stripe_subscription_id = - StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into()); - - let usage_meters = usage_meters_by_user_id.get(&user_id); - - for (model, mode) in &model_mode_combinations { - let Ok(model) = - llm_db.model(LanguageModelProvider::Anthropic, model) - else { - log::warn!("Failed to load model for user {user_id}: {model}"); - continue; - }; - - let (price, meter_event_name) = match model.name.as_str() { - "claude-opus-4" => match mode { - CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"), - CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"), - }, - "claude-sonnet-4" => match mode { - CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"), - CompletionMode::Max => { - (&claude_sonnet_4_max, "claude_sonnet_4/requests/max") - } - }, - "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), - "claude-3-7-sonnet" => match mode { - CompletionMode::Normal => { - (&claude_3_7_sonnet, "claude_3_7_sonnet/requests") - } - CompletionMode::Max => { - (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") - } - }, - model_name => { - bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") - } - }; - - let model_requests = usage_meters - .and_then(|usage_meters| { - usage_meters - .iter() - .find(|meter| meter.model_id == model.id && meter.mode == *mode) - }) - .map(|usage_meter| usage_meter.requests) - .unwrap_or(0); - - if model_requests > 0 { - stripe_billing - .subscribe_to_price(&stripe_subscription_id, price) - .await?; - } - - stripe_billing - .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests) - .await - .with_context(|| { - format!( - "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}", - ) - })?; - } - - Ok(()) - }) - .await - .log_err(); - } - - log::info!( - "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}", - Utc::now() - started_at - ); - - Ok(()) -} diff --git a/crates/collab/src/api/contributors.rs b/crates/collab/src/api/contributors.rs index 9296c1d428..8cfef0ad7e 100644 --- a/crates/collab/src/api/contributors.rs +++ b/crates/collab/src/api/contributors.rs @@ -8,7 +8,6 @@ use axum::{ use chrono::{NaiveDateTime, SecondsFormat}; use serde::{Deserialize, Serialize}; -use crate::api::AuthenticatedUserParams; use crate::db::ContributorSelector; use crate::{AppState, Result}; @@ -104,9 +103,18 @@ impl RenovateBot { } } +#[derive(Debug, Deserialize)] +struct AddContributorBody { + github_user_id: i32, + github_login: String, + github_email: Option, + github_name: Option, + github_user_created_at: chrono::DateTime, +} + async fn add_contributor( Extension(app): Extension>, - extract::Json(params): extract::Json, + extract::Json(params): extract::Json, ) -> Result<()> { let initial_channel_id = app.config.auto_join_channel_id; app.db diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index bc7dd152b0..2f34a843a8 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -580,7 +580,7 @@ fn for_snowflake( }, serde_json::to_value(e).unwrap(), ), - Event::InlineCompletion(e) => ( + Event::EditPrediction(e) => ( format!( "Edit Prediction {}", if e.suggestion_accepted { @@ -591,7 +591,7 @@ fn for_snowflake( ), serde_json::to_value(e).unwrap(), ), - Event::InlineCompletionRating(e) => ( + Event::EditPredictionRating(e) => ( "Edit Prediction Rated".to_string(), 
serde_json::to_value(e).unwrap(), ), diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 8cd1e3ea83..2c22ca2069 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -529,11 +529,17 @@ pub struct RejoinedProject { pub worktrees: Vec, pub updated_repositories: Vec, pub removed_repositories: Vec, - pub language_servers: Vec, + pub language_servers: Vec, } impl RejoinedProject { pub fn to_proto(&self) -> proto::RejoinedProject { + let (language_servers, language_server_capabilities) = self + .language_servers + .clone() + .into_iter() + .map(|server| (server.server, server.capabilities)) + .unzip(); proto::RejoinedProject { id: self.id.to_proto(), worktrees: self @@ -551,7 +557,8 @@ impl RejoinedProject { .iter() .map(|collaborator| collaborator.to_proto()) .collect(), - language_servers: self.language_servers.clone(), + language_servers, + language_server_capabilities, } } } @@ -598,7 +605,7 @@ pub struct Project { pub collaborators: Vec, pub worktrees: BTreeMap, pub repositories: Vec, - pub language_servers: Vec, + pub language_servers: Vec, } pub struct ProjectCollaborator { @@ -623,6 +630,12 @@ impl ProjectCollaborator { } } +#[derive(Debug, Clone)] +pub struct LanguageServer { + pub server: proto::LanguageServer, + pub capabilities: String, +} + #[derive(Debug)] pub struct LeftProject { pub id: ProjectId, diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 9f82e3dbc4..8361d6b4d0 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -85,19 +85,6 @@ impl Database { .await } - /// Returns the billing subscription with the specified ID. - pub async fn get_billing_subscription_by_id( - &self, - id: BillingSubscriptionId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_subscription::Entity::find_by_id(id) - .one(&*tx) - .await?) - }) - .await - } - /// Returns the billing subscription with the specified Stripe subscription ID. pub async fn get_billing_subscription_by_stripe_subscription_id( &self, @@ -143,119 +130,6 @@ impl Database { .await } - /// Returns all of the billing subscriptions for the user with the specified ID. - /// - /// Note that this returns the subscriptions regardless of their status. - /// If you're wanting to check if a use has an active billing subscription, - /// use `get_active_billing_subscriptions` instead. 
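The RejoinedProject::to_proto hunk above splits each LanguageServer record into two index-aligned proto fields via Iterator::unzip. A standalone illustration of that pattern with simplified types (these are not the actual proto definitions):

struct ServerRecord {
    server: String,       // stands in for proto::LanguageServer
    capabilities: String, // the capabilities string carried alongside it
}

fn split(records: Vec<ServerRecord>) -> (Vec<String>, Vec<String>) {
    // unzip keeps the two Vecs aligned, so servers[i] pairs with capabilities[i].
    records
        .into_iter()
        .map(|record| (record.server, record.capabilities))
        .unzip()
}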
- pub async fn get_billing_subscriptions( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - let subscriptions = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .filter(billing_customer::Column::UserId.eq(user_id)) - .order_by_asc(billing_subscription::Column::Id) - .all(&*tx) - .await?; - - Ok(subscriptions) - }) - .await - } - - pub async fn get_active_billing_subscriptions( - &self, - user_ids: HashSet, - ) -> Result> { - self.transaction(|tx| { - let user_ids = user_ids.clone(); - async move { - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .filter(billing_customer::Column::UserId.is_in(user_ids)) - .filter( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .filter(billing_subscription::Column::Kind.is_null()) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; - - let mut subscriptions = HashMap::default(); - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? { - subscriptions.insert(customer.user_id, (customer, subscription)); - } - } - Ok(subscriptions) - } - }) - .await - } - - pub async fn get_active_zed_pro_billing_subscriptions( - &self, - ) -> Result> { - self.transaction(|tx| async move { - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .filter( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro)) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; - - let mut subscriptions = HashMap::default(); - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? { - subscriptions.insert(customer.user_id, (customer, subscription)); - } - } - Ok(subscriptions) - }) - .await - } - - pub async fn get_active_zed_pro_billing_subscriptions_for_users( - &self, - user_ids: HashSet, - ) -> Result> { - self.transaction(|tx| { - let user_ids = user_ids.clone(); - async move { - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .filter(billing_customer::Column::UserId.is_in(user_ids)) - .filter( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro)) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; - - let mut subscriptions = HashMap::default(); - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? { - subscriptions.insert(customer.user_id, (customer, subscription)); - } - } - Ok(subscriptions) - } - }) - .await - } - /// Returns whether the user has an active billing subscription. pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result { Ok(self.count_active_billing_subscriptions(user_id).await? > 0) diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index a288a4e7eb..2e6b4719d1 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -786,6 +786,32 @@ impl Database { }) .collect()) } + + /// Update language server capabilities for a given id. 
+ pub async fn update_server_capabilities( + &self, + project_id: ProjectId, + server_id: u64, + new_capabilities: String, + ) -> Result<()> { + self.transaction(|tx| { + let new_capabilities = new_capabilities.clone(); + async move { + Ok( + language_server::Entity::update(language_server::ActiveModel { + project_id: ActiveValue::unchanged(project_id), + id: ActiveValue::unchanged(server_id as i64), + capabilities: ActiveValue::set(new_capabilities), + ..Default::default() + }) + .exec(&*tx) + .await?, + ) + } + }) + .await?; + Ok(()) + } } fn operation_to_storage( diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index ba22a7b4e3..31635575a8 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -692,6 +692,7 @@ impl Database { project_id: ActiveValue::set(project_id), id: ActiveValue::set(server.id as i64), name: ActiveValue::set(server.name.clone()), + capabilities: ActiveValue::set(update.capabilities.clone()), }) .on_conflict( OnConflict::columns([ @@ -1054,10 +1055,13 @@ impl Database { repositories, language_servers: language_servers .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - worktree_id: None, + .map(|language_server| LanguageServer { + server: proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + worktree_id: None, + }, + capabilities: language_server.capabilities, }) .collect(), }; diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index cb805786dd..c63d7133be 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -804,10 +804,13 @@ impl Database { .all(tx) .await? 
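The capabilities column added in these hunks is a plain String, so the natural encoding is serde-serialized lsp::ServerCapabilities. A hedged sketch of that round trip; the serialization format is an assumption on my part, since the hunks only show the string being stored via update_server_capabilities and forwarded to rejoining clients:

use lsp::ServerCapabilities;

fn encode_capabilities(capabilities: &ServerCapabilities) -> anyhow::Result<String> {
    // What a host might store through update_server_capabilities.
    Ok(serde_json::to_string(capabilities)?)
}

fn decode_capabilities(capabilities: &str) -> anyhow::Result<ServerCapabilities> {
    // What a guest might do with the string replicated in the join/rejoin messages.
    Ok(serde_json::from_str(capabilities)?)
}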
.into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - worktree_id: None, + .map(|language_server| LanguageServer { + server: proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + worktree_id: None, + }, + capabilities: language_server.capabilities, }) .collect::>(); diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 43198f9859..522973dbc9 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -95,7 +95,7 @@ pub enum SubscriptionKind { ZedFree, } -impl From for zed_llm_client::Plan { +impl From for cloud_llm_client::Plan { fn from(value: SubscriptionKind) -> Self { match value { SubscriptionKind::ZedPro => Self::ZedPro, diff --git a/crates/collab/src/db/tables/language_server.rs b/crates/collab/src/db/tables/language_server.rs index 9ff8c75fc6..34c7514d91 100644 --- a/crates/collab/src/db/tables/language_server.rs +++ b/crates/collab/src/db/tables/language_server.rs @@ -9,6 +9,7 @@ pub struct Model { #[sea_orm(primary_key)] pub id: i64, pub name: String, + pub capabilities: String, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 9404e2670c..6c2f9dc82a 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -1,4 +1,3 @@ -mod billing_subscription_tests; mod buffer_tests; mod channel_tests; mod contributor_tests; diff --git a/crates/collab/src/db/tests/billing_subscription_tests.rs b/crates/collab/src/db/tests/billing_subscription_tests.rs deleted file mode 100644 index fb5f8552a3..0000000000 --- a/crates/collab/src/db/tests/billing_subscription_tests.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::sync::Arc; - -use crate::db::billing_subscription::StripeSubscriptionStatus; -use crate::db::tests::new_test_user; -use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams}; -use crate::test_both_dbs; - -use super::Database; - -test_both_dbs!( - test_get_active_billing_subscriptions, - test_get_active_billing_subscriptions_postgres, - test_get_active_billing_subscriptions_sqlite -); - -async fn test_get_active_billing_subscriptions(db: &Arc) { - // A user with no subscription has no active billing subscriptions. - { - let user_id = new_test_user(db, "no-subscription-user@example.com").await; - let subscription_count = db - .count_active_billing_subscriptions(user_id) - .await - .unwrap(); - - assert_eq!(subscription_count, 0); - } - - // A user with an active subscription has one active billing subscription. 
- { - let user_id = new_test_user(db, "active-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_active_user".into(), - }) - .await - .unwrap(); - assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string()); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - kind: None, - stripe_subscription_id: "sub_active_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::Active, - stripe_cancellation_reason: None, - stripe_current_period_start: None, - stripe_current_period_end: None, - }) - .await - .unwrap(); - - let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap(); - assert_eq!(subscriptions.len(), 1); - - let subscription = &subscriptions[0]; - assert_eq!( - subscription.stripe_subscription_id, - "sub_active_user".to_string() - ); - assert_eq!( - subscription.stripe_subscription_status, - StripeSubscriptionStatus::Active - ); - } - - // A user with a past-due subscription has no active billing subscriptions. - { - let user_id = new_test_user(db, "past-due-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_past_due_user".into(), - }) - .await - .unwrap(); - assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string()); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - kind: None, - stripe_subscription_id: "sub_past_due_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::PastDue, - stripe_cancellation_reason: None, - stripe_current_period_start: None, - stripe_current_period_end: None, - }) - .await - .unwrap(); - - let subscription_count = db - .count_active_billing_subscriptions(user_id) - .await - .unwrap(); - assert_eq!(subscription_count, 0); - } -} diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 6a6efca0de..18ad624dab 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -6,11 +6,11 @@ mod tables; #[cfg(test)] mod tests; +use cloud_llm_client::LanguageModelProvider; use collections::HashMap; pub use ids::*; pub use seed::*; pub use tables::*; -use zed_llm_client::LanguageModelProvider; #[cfg(test)] pub use tests::TestLlmDb; diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 3565366fdd..0087218b3f 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -1,6 +1,5 @@ use super::*; pub mod providers; -pub mod subscription_usage_meters; pub mod subscription_usages; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/subscription_usage_meters.rs b/crates/collab/src/llm/db/queries/subscription_usage_meters.rs deleted file mode 100644 index c0ce5d679b..0000000000 --- a/crates/collab/src/llm/db/queries/subscription_usage_meters.rs +++ /dev/null @@ -1,72 +0,0 @@ -use crate::db::UserId; -use crate::llm::db::queries::subscription_usages::convert_chrono_to_time; - -use super::*; - -impl LlmDatabase { - /// Returns all current subscription usage meters as of the given timestamp. 
- pub async fn get_current_subscription_usage_meters( - &self, - now: DateTimeUtc, - ) -> Result> { - let now = convert_chrono_to_time(now)?; - - self.transaction(|tx| async move { - let result = subscription_usage_meter::Entity::find() - .inner_join(subscription_usage::Entity) - .filter( - subscription_usage::Column::PeriodStartAt - .lte(now) - .and(subscription_usage::Column::PeriodEndAt.gte(now)), - ) - .select_also(subscription_usage::Entity) - .all(&*tx) - .await?; - - let result = result - .into_iter() - .filter_map(|(meter, usage)| { - let usage = usage?; - Some((meter, usage)) - }) - .collect(); - - Ok(result) - }) - .await - } - - /// Returns all current subscription usage meters for the given user as of the given timestamp. - pub async fn get_current_subscription_usage_meters_for_user( - &self, - user_id: UserId, - now: DateTimeUtc, - ) -> Result> { - let now = convert_chrono_to_time(now)?; - - self.transaction(|tx| async move { - let result = subscription_usage_meter::Entity::find() - .inner_join(subscription_usage::Entity) - .filter(subscription_usage::Column::UserId.eq(user_id)) - .filter( - subscription_usage::Column::PeriodStartAt - .lte(now) - .and(subscription_usage::Column::PeriodEndAt.gte(now)), - ) - .select_also(subscription_usage::Entity) - .all(&*tx) - .await?; - - let result = result - .into_iter() - .filter_map(|(meter, usage)| { - let usage = usage?; - Some((meter, usage)) - }) - .collect(); - - Ok(result) - }) - .await - } -} diff --git a/crates/collab/src/llm/db/queries/subscription_usages.rs b/crates/collab/src/llm/db/queries/subscription_usages.rs index ee1ebf59b8..8a51979075 100644 --- a/crates/collab/src/llm/db/queries/subscription_usages.rs +++ b/crates/collab/src/llm/db/queries/subscription_usages.rs @@ -1,28 +1,7 @@ -use time::PrimitiveDateTime; - use crate::db::UserId; use super::*; -pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result { - use chrono::{Datelike as _, Timelike as _}; - - let date = time::Date::from_calendar_date( - datetime.year(), - time::Month::try_from(datetime.month() as u8).unwrap(), - datetime.day() as u8, - )?; - - let time = time::Time::from_hms_nano( - datetime.hour() as u8, - datetime.minute() as u8, - datetime.second() as u8, - datetime.nanosecond(), - )?; - - Ok(PrimitiveDateTime::new(date, time)) -} - impl LlmDatabase { pub async fn get_subscription_usage_for_period( &self, diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs index 7d52964b93..f4e1de40ec 100644 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ b/crates/collab/src/llm/db/tests/provider_tests.rs @@ -1,5 +1,5 @@ +use cloud_llm_client::LanguageModelProvider; use pretty_assertions::assert_eq; -use zed_llm_client::LanguageModelProvider; use crate::llm::db::LlmDatabase; use crate::test_llm_db; diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index d4566ffcb4..da01c7f3be 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -4,12 +4,12 @@ use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEA use crate::{Config, db::billing_preference}; use anyhow::{Context as _, Result}; use chrono::{NaiveDateTime, Utc}; +use cloud_llm_client::Plan; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::time::Duration; use thiserror::Error; use uuid::Uuid; -use zed_llm_client::Plan; #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = 
"camelCase")] diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 6a78049b3f..20641cb232 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -7,8 +7,8 @@ use axum::{ routing::get, }; +use collab::ServiceMode; use collab::api::CloudflareIpCountryHeader; -use collab::api::billing::sync_llm_request_usage_with_stripe_periodically; use collab::llm::db::LlmDatabase; use collab::migrations::run_database_migrations; use collab::user_backfiller::spawn_user_backfiller; @@ -16,7 +16,6 @@ use collab::{ AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, rpc::ResultExt, }; -use collab::{ServiceMode, api::billing::poll_stripe_events_periodically}; use db::Database; use std::{ env::args, @@ -31,7 +30,7 @@ use tower_http::trace::TraceLayer; use tracing_subscriber::{ Layer, filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; const VERSION: &str = env!("CARGO_PKG_VERSION"); const REVISION: Option<&'static str> = option_env!("GITHUB_SHA"); @@ -120,8 +119,6 @@ async fn main() -> Result<()> { let rpc_server = collab::rpc::Server::new(epoch, state.clone()); rpc_server.start().await?; - poll_stripe_events_periodically(state.clone(), rpc_server.clone()); - app = app .merge(collab::api::routes(rpc_server.clone())) .merge(collab::rpc::routes(rpc_server.clone())); @@ -133,29 +130,6 @@ async fn main() -> Result<()> { fetch_extensions_from_blob_store_periodically(state.clone()); spawn_user_backfiller(state.clone()); - let llm_db = maybe!(async { - let database_url = state - .config - .llm_database_url - .as_ref() - .context("missing LLM_DATABASE_URL")?; - let max_connections = state - .config - .llm_database_max_connections - .context("missing LLM_DATABASE_MAX_CONNECTIONS")?; - - let mut db_options = db::ConnectOptions::new(database_url); - db_options.max_connections(max_connections); - LlmDatabase::new(db_options, state.executor.clone()).await - }) - .await - .trace_err(); - - if let Some(mut llm_db) = llm_db { - llm_db.initialize().await?; - sync_llm_request_usage_with_stripe_periodically(state.clone()); - } - app = app .merge(collab::api::events::router()) .merge(collab::api::extensions::router()) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 0735b08e89..dfa5859f51 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -23,6 +23,7 @@ use anyhow::{Context as _, anyhow, bail}; use async_tungstenite::tungstenite::{ Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame, }; +use axum::headers::UserAgent; use axum::{ Extension, Router, TypedHeader, body::Body, @@ -41,7 +42,7 @@ use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; use reqwest_client::ReqwestClient; -use rpc::proto::split_repository_update; +use rpc::proto::{MultiLspQuery, split_repository_update}; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; use futures::{ @@ -314,7 +315,7 @@ impl Server { .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_find_search_candidates_request) + .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) 
.add_request_handler(forward_read_only_project_request::) @@ -339,9 +340,6 @@ impl Server { .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) - .add_request_handler( - forward_read_only_project_request::, - ) .add_request_handler(forward_read_only_project_request::) .add_request_handler( forward_mutating_project_request::, @@ -373,7 +371,7 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) + .add_request_handler(multi_lsp_query) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) @@ -433,6 +431,8 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) + .add_request_handler(forward_mutating_project_request::) + .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_read_only_project_request::) @@ -663,7 +663,6 @@ impl Server { let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0; let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; let queue_duration_ms = total_duration_ms - processing_duration_ms; - let payload_type = M::NAME; match result { Err(error) => { @@ -672,7 +671,6 @@ impl Server { total_duration_ms, processing_duration_ms, queue_duration_ms, - payload_type, "error handling message" ) } @@ -748,6 +746,8 @@ impl Server { address: String, principal: Principal, zed_version: ZedVersion, + release_channel: Option, + user_agent: Option, geoip_country_code: Option, system_id: Option, send_connection_id: Option>, @@ -760,9 +760,17 @@ impl Server { user_id=field::Empty, login=field::Empty, impersonator=field::Empty, + user_agent=field::Empty, geoip_country_code=field::Empty ); principal.update_span(&span); + if let Some(user_agent) = user_agent { + span.record("user_agent", user_agent); + } + if let Some(release_channel) = release_channel { + span.record("release_channel", release_channel); + } + if let Some(country_code) = geoip_country_code.as_ref() { span.record("geoip_country_code", country_code); } @@ -771,12 +779,11 @@ impl Server { async move { if *teardown.borrow() { tracing::error!("server is tearing down"); - return + return; } - let (connection_id, handle_io, mut incoming_rx) = this - .peer - .add_connection(connection, { + let (connection_id, handle_io, mut incoming_rx) = + this.peer.add_connection(connection, { let executor = executor.clone(); move |duration| executor.sleep(duration) }); @@ -793,10 +800,14 @@ impl Server { } }; - let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new( - supermaven_admin_api_key.to_string(), - http_client.clone(), - ))); + let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map( + |supermaven_admin_api_key| { + Arc::new(SupermavenAdminApi::new( + supermaven_admin_api_key.to_string(), + http_client.clone(), + )) + }, + ); let session = Session { principal: principal.clone(), @@ 
-811,7 +822,15 @@ impl Server { supermaven_client, }; - if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await { + if let Err(error) = this + .send_initial_client_update( + connection_id, + zed_version, + send_connection_id, + &session, + ) + .await + { tracing::error!(?error, "failed to send initial client update"); return; } @@ -828,14 +847,22 @@ impl Server { // // This arrangement ensures we will attempt to process earlier messages first, but fall // back to processing messages arrived later in the spirit of making progress. + const MAX_CONCURRENT_HANDLERS: usize = 256; let mut foreground_message_handlers = FuturesUnordered::new(); - let concurrent_handlers = Arc::new(Semaphore::new(256)); + let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS)); + let get_concurrent_handlers = { + let concurrent_handlers = concurrent_handlers.clone(); + move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits() + }; loop { let next_message = async { let permit = concurrent_handlers.clone().acquire_owned().await.unwrap(); let message = incoming_rx.next().await; - (permit, message) - }.fuse(); + // Cache the concurrent_handlers here, so that we know what the + // queue looks like as each handler starts + (permit, message, get_concurrent_handlers()) + } + .fuse(); futures::pin_mut!(next_message); futures::select_biased! { _ = teardown.changed().fuse() => return, @@ -847,15 +874,20 @@ impl Server { } _ = foreground_message_handlers.next() => {} next_message = next_message => { - let (permit, message) = next_message; + let (permit, message, concurrent_handlers) = next_message; if let Some(message) = message { let type_name = message.payload_type_name(); // note: we copy all the fields from the parent span so we can query them in the logs. // (https://github.com/tokio-rs/tracing/issues/2670). - let span = tracing::info_span!("receive message", %connection_id, %address, type_name, + let span = tracing::info_span!("receive message", + %connection_id, + %address, + type_name, + concurrent_handlers, user_id=field::Empty, login=field::Empty, impersonator=field::Empty, + multi_lsp_query_request=field::Empty, ); principal.update_span(&span); let span_enter = span.enter(); @@ -885,12 +917,13 @@ impl Server { } drop(foreground_message_handlers); - tracing::info!("signing out"); + let concurrent_handlers = get_concurrent_handlers(); + tracing::info!(concurrent_handlers, "signing out"); if let Err(error) = connection_lost(session, teardown, executor).await { tracing::error!(?error, "error signing out"); } - - }.instrument(span) + } + .instrument(span) } async fn send_initial_client_update( @@ -1152,6 +1185,35 @@ impl Header for AppVersionHeader { } } +#[derive(Debug)] +pub struct ReleaseChannelHeader(String); + +impl Header for ReleaseChannelHeader { + fn name() -> &'static HeaderName { + static ZED_RELEASE_CHANNEL: OnceLock = OnceLock::new(); + ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel")) + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + Ok(Self( + values + .next() + .ok_or_else(axum::headers::Error::invalid)? + .to_str() + .map_err(|_| axum::headers::Error::invalid())? 
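The connection-loop changes above derive an in-flight handler count from the semaphore's remaining permits and record it on the tracing span. A self-contained sketch of that bookkeeping, assuming tokio's Semaphore (which matches the acquire_owned/available_permits calls in the hunk):

use std::sync::Arc;
use tokio::sync::Semaphore;

const MAX_CONCURRENT_HANDLERS: usize = 256;

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
    // Handlers currently running = total permits minus the permits still available.
    let in_flight = {
        let semaphore = semaphore.clone();
        move || MAX_CONCURRENT_HANDLERS - semaphore.available_permits()
    };

    let permit = semaphore.clone().acquire_owned().await.unwrap();
    assert_eq!(in_flight(), 1); // one handler holds a permit
    drop(permit);               // releasing the permit frees the slot
    assert_eq!(in_flight(), 0);
}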
+ .to_owned(), + )) + } + + fn encode>(&self, values: &mut E) { + values.extend([self.0.parse().unwrap()]); + } +} + pub fn routes(server: Arc) -> Router<(), Body> { Router::new() .route("/rpc", get(handle_websocket_request)) @@ -1167,9 +1229,11 @@ pub fn routes(server: Arc) -> Router<(), Body> { pub async fn handle_websocket_request( TypedHeader(ProtocolVersion(protocol_version)): TypedHeader, app_version_header: Option>, + release_channel_header: Option>, ConnectInfo(socket_address): ConnectInfo, Extension(server): Extension>, Extension(principal): Extension, + user_agent: Option>, country_code_header: Option>, system_id_header: Option>, ws: WebSocketUpgrade, @@ -1190,6 +1254,8 @@ pub async fn handle_websocket_request( .into_response(); }; + let release_channel = release_channel_header.map(|header| header.0.0); + if !version.can_collaborate() { return ( StatusCode::UPGRADE_REQUIRED, @@ -1225,6 +1291,8 @@ pub async fn handle_websocket_request( socket_address, principal, version, + release_channel, + user_agent.map(|header| header.to_string()), country_code_header.map(|header| header.to_string()), system_id_header.map(|header| header.to_string()), None, @@ -1956,12 +2024,19 @@ async fn join_project( } // First, we send the metadata associated with each worktree. + let (language_servers, language_server_capabilities) = project + .language_servers + .clone() + .into_iter() + .map(|server| (server.server, server.capabilities)) + .unzip(); response.send(proto::JoinProjectResponse { project_id: project.id.0 as u64, worktrees: worktrees.clone(), replica_id: replica_id.0 as u32, collaborators: collaborators.clone(), - language_servers: project.language_servers.clone(), + language_servers, + language_server_capabilities, role: project.role.into(), })?; @@ -2020,8 +2095,8 @@ async fn join_project( session.connection_id, proto::UpdateLanguageServer { project_id: project_id.to_proto(), - server_name: Some(language_server.name.clone()), - language_server_id: language_server.id, + server_name: Some(language_server.server.name.clone()), + language_server_id: language_server.server.id, variant: Some( proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( proto::LspDiskBasedDiagnosticsUpdated {}, @@ -2233,9 +2308,17 @@ async fn update_language_server( session: Session, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = session - .db() - .await + let db = session.db().await; + + if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant + { + if let Some(capabilities) = update.capabilities.clone() { + db.update_server_capabilities(project_id, request.language_server_id, capabilities) + .await?; + } + } + + let project_connection_ids = db .project_connection_ids(project_id, session.connection_id, true) .await?; broadcast( @@ -2274,25 +2357,6 @@ where Ok(()) } -async fn forward_find_search_candidates_request( - request: proto::FindSearchCandidates, - response: Response, - session: Session, -) -> Result<()> { - let project_id = ProjectId::from_proto(request.remote_entity_id()); - let host_connection_id = session - .db() - .await - .host_for_read_only_project_request(project_id, session.connection_id) - .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; - response.send(payload)?; - Ok(()) -} - /// forward a project request to the host. These requests are disallowed /// for guests. 
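The ReleaseChannelHeader introduced above plugs into axum's typed-header machinery. A small usage sketch written as if from inside the same module (the tuple field is private); it assumes the headers crate's HeaderMapExt is available through axum::headers, which is an assumption rather than something shown in this diff:

use axum::headers::HeaderMapExt;
use axum::http::HeaderMap;

fn demo() {
    let mut headers = HeaderMap::new();
    headers.typed_insert(ReleaseChannelHeader("nightly".to_string()));

    // Decoding goes through the same Header impl the server relies on when it
    // extracts Option<TypedHeader<ReleaseChannelHeader>> in handle_websocket_request.
    let channel = headers.typed_get::<ReleaseChannelHeader>();
    assert_eq!(channel.map(|header| header.0), Some("nightly".to_string()));
}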
async fn forward_mutating_project_request( @@ -2318,6 +2382,16 @@ where Ok(()) } +async fn multi_lsp_query( + request: MultiLspQuery, + response: Response, + session: Session, +) -> Result<()> { + tracing::Span::current().record("multi_lsp_query_request", request.request_str()); + tracing::info!("multi_lsp_query message received"); + forward_mutating_project_request(request, response, session).await +} + /// Notify other participants that a new buffer has been created async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, @@ -2857,12 +2931,12 @@ async fn make_update_user_plan_message( } fn model_requests_limit( - plan: zed_llm_client::Plan, + plan: cloud_llm_client::Plan, feature_flags: &Vec, -) -> zed_llm_client::UsageLimit { +) -> cloud_llm_client::UsageLimit { match plan.model_requests_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == zed_llm_client::Plan::ZedProTrial + cloud_llm_client::UsageLimit::Limited(limit) => { + let limit = if plan == cloud_llm_client::Plan::ZedProTrial && feature_flags .iter() .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) @@ -2872,9 +2946,9 @@ fn model_requests_limit( limit }; - zed_llm_client::UsageLimit::Limited(limit) + cloud_llm_client::UsageLimit::Limited(limit) } - zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited, + cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited, } } @@ -2884,21 +2958,21 @@ fn subscription_usage_to_proto( feature_flags: &Vec, ) -> proto::SubscriptionUsage { let plan = match plan { - proto::Plan::Free => zed_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + proto::Plan::Free => cloud_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, }; proto::SubscriptionUsage { model_requests_usage_amount: usage.model_requests as u32, model_requests_usage_limit: Some(proto::UsageLimit { variant: Some(match model_requests_limit(plan, feature_flags) { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2906,12 +2980,12 @@ fn subscription_usage_to_proto( edit_predictions_usage_amount: usage.edit_predictions as u32, edit_predictions_usage_limit: Some(proto::UsageLimit { variant: Some(match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2924,21 +2998,21 @@ fn make_default_subscription_usage( feature_flags: &Vec, ) -> proto::SubscriptionUsage { let plan = match plan { - proto::Plan::Free => zed_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + proto::Plan::Free => cloud_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => 
cloud_llm_client::Plan::ZedProTrial, }; proto::SubscriptionUsage { model_requests_usage_amount: 0, model_requests_usage_limit: Some(proto::UsageLimit { variant: Some(match model_requests_limit(plan, feature_flags) { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), @@ -2946,12 +3020,12 @@ fn make_default_subscription_usage( edit_predictions_usage_amount: 0, edit_predictions_usage_limit: Some(proto::UsageLimit { variant: Some(match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { + cloud_llm_client::UsageLimit::Limited(limit) => { proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { limit: limit as u32, }) } - zed_llm_client::UsageLimit::Unlimited => { + cloud_llm_client::UsageLimit::Unlimited => { proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } }), diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 850b716a9f..ef5bef3e7e 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,21 +1,15 @@ use std::sync::Arc; use anyhow::anyhow; -use chrono::Utc; use collections::HashMap; use stripe::SubscriptionStatus; use tokio::sync::RwLock; -use uuid::Uuid; use crate::Result; -use crate::db::billing_subscription::SubscriptionKind; use crate::stripe_client::{ - RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams, - StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams, - StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, - UpdateSubscriptionParams, + RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems, + StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId, + StripeSubscription, }; pub struct StripeBilling { @@ -94,30 +88,6 @@ impl StripeBilling { .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}"))) } - pub async fn determine_subscription_kind( - &self, - subscription: &StripeSubscription, - ) -> Option { - let zed_pro_price_id = self.zed_pro_price_id().await.ok()?; - let zed_free_price_id = self.zed_free_price_id().await.ok()?; - - subscription.items.iter().find_map(|item| { - let price = item.price.as_ref()?; - - if price.id == zed_pro_price_id { - Some(if subscription.status == SubscriptionStatus::Trialing { - SubscriptionKind::ZedProTrial - } else { - SubscriptionKind::ZedPro - }) - } else if price.id == zed_free_price_id { - Some(SubscriptionKind::ZedFree) - } else { - None - } - }) - } - /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does /// not already exist. 
/// @@ -150,65 +120,6 @@ impl StripeBilling { Ok(customer_id) } - pub async fn subscribe_to_price( - &self, - subscription_id: &StripeSubscriptionId, - price: &StripePrice, - ) -> Result<()> { - let subscription = self.client.get_subscription(subscription_id).await?; - - if subscription_contains_price(&subscription, &price.id) { - return Ok(()); - } - - const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100; - - let price_per_unit = price.unit_amount.unwrap_or_default(); - let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit; - - self.client - .update_subscription( - subscription_id, - UpdateSubscriptionParams { - items: Some(vec![UpdateSubscriptionItems { - price: Some(price.id.clone()), - }]), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel - }, - }), - }, - ) - .await?; - - Ok(()) - } - - pub async fn bill_model_request_usage( - &self, - customer_id: &StripeCustomerId, - event_name: &str, - requests: i32, - ) -> Result<()> { - let timestamp = Utc::now().timestamp(); - let idempotency_key = Uuid::new_v4(); - - self.client - .create_meter_event(StripeCreateMeterEventParams { - identifier: &format!("model_requests/{}", idempotency_key), - event_name, - payload: StripeCreateMeterEventPayload { - value: requests as u64, - stripe_customer_id: customer_id, - }, - timestamp: Some(timestamp), - }) - .await?; - - Ok(()) - } - pub async fn subscribe_to_zed_free( &self, customer_id: StripeCustomerId, @@ -243,14 +154,3 @@ impl StripeBilling { Ok(subscription) } } - -fn subscription_contains_price( - subscription: &StripeSubscription, - price_id: &StripePriceId, -) -> bool { - subscription.items.iter().any(|item| { - item.price - .as_ref() - .map_or(false, |price| price.id == *price_id) - }) -} diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 19e410de5b..8d5d076780 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -38,12 +38,12 @@ fn room_participants(room: &Entity, cx: &mut TestAppContext) -> RoomPartic let mut remote = room .remote_participants() .values() - .map(|participant| participant.user.github_login.clone()) + .map(|participant| participant.user.github_login.clone().to_string()) .collect::>(); let mut pending = room .pending_participants() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect::>(); remote.sort(); pending.sort(); diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 73ab2b8167..8754b53f6e 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -24,10 +24,7 @@ use language::{ }; use project::{ ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT, - lsp_store::{ - lsp_ext_command::{ExpandedMacro, LspExtExpandMacro}, - rust_analyzer_ext::RUST_ANALYZER_NAME, - }, + lsp_store::lsp_ext_command::{ExpandedMacro, LspExtExpandMacro}, project_settings::{InlineBlameSettings, ProjectSettings}, }; use recent_projects::disconnected_overlay::DisconnectedOverlay; @@ -296,19 +293,28 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + trigger_characters: Some(vec![".".to_string()]), + resolve_provider: Some(true), + 
..lsp::CompletionOptions::default() + }), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - completion_provider: Some(lsp::CompletionOptions { - trigger_characters: Some(vec![".".to_string()]), - resolve_provider: Some(true), - ..Default::default() - }), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -566,11 +572,14 @@ async fn test_collaborating_with_code_actions( cx_b.update(editor::init); - // Set up a fake language server. client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a .language_registry() .register_fake_lsp("Rust", FakeLspAdapter::default()); + client_b.language_registry().add(rust_lang()); + client_b + .language_registry() + .register_fake_lsp("Rust", FakeLspAdapter::default()); client_a .fs() @@ -775,19 +784,27 @@ async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut T cx_b.update(editor::init); - // Set up a fake language server. + let capabilities = lsp::ServerCapabilities { + rename_provider: Some(lsp::OneOf::Right(lsp::RenameOptions { + prepare_provider: Some(true), + work_done_progress_options: Default::default(), + })), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - rename_provider: Some(lsp::OneOf::Right(lsp::RenameOptions { - prepare_provider: Some(true), - work_done_progress_options: Default::default(), - })), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -818,6 +835,8 @@ async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut T .downcast::() .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); // Move cursor to a location that can be renamed. 
let prepare_rename = editor_b.update_in(cx_b, |editor, window, cx| { @@ -1055,7 +1074,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_a.read_with(cx_a, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1072,7 +1091,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); }); executor.advance_clock(SERVER_PROGRESS_THROTTLE_TIMEOUT); @@ -1089,7 +1108,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_a.read_with(cx_a, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1099,7 +1118,7 @@ async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut Tes project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "the-language-server"); + assert_eq!(status.name.0, "the-language-server"); assert_eq!(status.pending_work.len(), 1); assert_eq!( status.pending_work["the-token"].message.as_ref().unwrap(), @@ -1422,18 +1441,27 @@ async fn test_on_input_format_from_guest_to_host( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + document_on_type_formatting_provider: Some(lsp::DocumentOnTypeFormattingOptions { + first_trigger_character: ":".to_string(), + more_trigger_character: Some(vec![">".to_string()]), + }), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - document_on_type_formatting_provider: Some(lsp::DocumentOnTypeFormattingOptions { - first_trigger_character: ":".to_string(), - more_trigger_character: Some(vec![">".to_string()]), - }), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -1588,16 +1616,24 @@ async fn test_mutual_editor_inlay_hint_cache_update( }); }); + let capabilities = lsp::ServerCapabilities { + inlay_hint_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - inlay_hint_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + 
client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -1830,16 +1866,24 @@ async fn test_inlay_hint_refresh_is_forwarded( }); }); + let capabilities = lsp::ServerCapabilities { + inlay_hint_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - inlay_hint_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() }, ); @@ -2004,15 +2048,23 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo }); }); + let capabilities = lsp::ServerCapabilities { + color_provider: Some(lsp::ColorProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - color_provider: Some(lsp::ColorProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, ..FakeLspAdapter::default() }, ); @@ -2063,6 +2115,8 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let requests_made = Arc::new(AtomicUsize::new(0)); let closure_requests_made = Arc::clone(&requests_made); @@ -2264,24 +2318,32 @@ async fn test_lsp_pull_diagnostics( cx_a.update(editor::init); cx_b.update(editor::init); + let capabilities = lsp::ServerCapabilities { + diagnostic_provider: Some(lsp::DiagnosticServerCapabilities::Options( + lsp::DiagnosticOptions { + identifier: Some("test-pulls".to_string()), + inter_file_dependencies: true, + workspace_diagnostics: true, + work_done_progress_options: lsp::WorkDoneProgressOptions { + work_done_progress: None, + }, + }, + )), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - diagnostic_provider: Some(lsp::DiagnosticServerCapabilities::Options( - lsp::DiagnosticOptions { - identifier: Some("test-pulls".to_string()), - inter_file_dependencies: true, - workspace_diagnostics: true, - work_done_progress_options: lsp::WorkDoneProgressOptions { - work_done_progress: None, - }, - }, - )), - ..lsp::ServerCapabilities::default() - }, + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + 
capabilities, ..FakeLspAdapter::default() }, ); @@ -2334,6 +2396,8 @@ async fn test_lsp_pull_diagnostics( .unwrap(); let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let expected_push_diagnostic_main_message = "pushed main diagnostic"; let expected_push_diagnostic_lib_message = "pushed lib diagnostic"; let expected_pull_diagnostic_main_message = "pulled main diagnostic"; @@ -2689,6 +2753,7 @@ async fn test_lsp_pull_diagnostics( .unwrap() .downcast::() .unwrap(); + cx_b.run_until_parked(); pull_diagnostics_handle.next().await.unwrap(); assert_eq!( @@ -3718,11 +3783,18 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes cx_b.update(editor::init); client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: RUST_ANALYZER_NAME, + name: "rust-analyzer", + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: "rust-analyzer", ..FakeLspAdapter::default() }, ); diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 9795c27574..5a2c40b890 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -842,7 +842,7 @@ async fn test_client_disconnecting_from_room( // Allow user A to reconnect to the server. server.allow_connections(); - executor.advance_clock(RECEIVE_TIMEOUT); + executor.advance_clock(RECONNECT_TIMEOUT); // Call user B again from client A. active_call_a @@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections( client_b1.disconnect(&cx_b1.to_async()); executor.advance_clock(RECEIVE_TIMEOUT); client_b1 - .authenticate_and_connect(false, &cx_b1.to_async()) + .connect(false, &cx_b1.to_async()) .await .into_response() .unwrap(); @@ -1358,7 +1358,7 @@ async fn test_calls_on_multiple_connections( // User A reconnects automatically, then calls user B again. server.allow_connections(); - executor.advance_clock(RECEIVE_TIMEOUT); + executor.advance_clock(RECONNECT_TIMEOUT); active_call_a .update(cx_a, |call, cx| { call.invite(client_b1.user_id().unwrap(), None, cx) @@ -1667,7 +1667,7 @@ async fn test_project_reconnect( // Client A reconnects. Their project is re-shared, and client B re-joins it. server.allow_connections(); client_a - .authenticate_and_connect(false, &cx_a.to_async()) + .connect(false, &cx_a.to_async()) .await .into_response() .unwrap(); @@ -1796,7 +1796,7 @@ async fn test_project_reconnect( // Client B reconnects. They re-join the room and the remaining shared project. 
server.allow_connections(); client_b - .authenticate_and_connect(false, &cx_b.to_async()) + .connect(false, &cx_b.to_async()) .await .into_response() .unwrap(); @@ -1881,7 +1881,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_a.user_id().unwrap(), - github_login: "user_a".to_string(), + github_login: "user_a".into(), avatar_uri: "avatar_a".into(), name: None, }), @@ -1900,7 +1900,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_b.user_id().unwrap(), - github_login: "user_b".to_string(), + github_login: "user_b".into(), avatar_uri: "avatar_b".into(), name: None, }), @@ -4778,10 +4778,27 @@ async fn test_definition( .await; let active_call_a = cx_a.read(ActiveCall::global); - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); + let capabilities = lsp::ServerCapabilities { + definition_provider: Some(OneOf::Left(true)), + type_definition_provider: Some(lsp::TypeDefinitionProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); client_a .fs() @@ -4827,13 +4844,19 @@ async fn test_definition( ))) }, ); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let definitions_1 = project_b .update(cx_b, |p, cx| p.definitions(&buffer_b, 23, cx)) .await .unwrap(); cx_b.read(|cx| { - assert_eq!(definitions_1.len(), 1); + assert_eq!( + definitions_1.len(), + 1, + "Unexpected definitions: {definitions_1:?}" + ); assert_eq!(project_b.read(cx).worktrees(cx).count(), 2); let target_buffer = definitions_1[0].target.buffer.read(cx); assert_eq!( @@ -4901,7 +4924,11 @@ async fn test_definition( .await .unwrap(); cx_b.read(|cx| { - assert_eq!(type_definitions.len(), 1); + assert_eq!( + type_definitions.len(), + 1, + "Unexpected type definitions: {type_definitions:?}" + ); let target_buffer = type_definitions[0].target.buffer.read(cx); assert_eq!(target_buffer.text(), "type T2 = usize;"); assert_eq!( @@ -4925,16 +4952,26 @@ async fn test_references( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + references_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { name: "my-fake-lsp-adapter", - capabilities: lsp::ServerCapabilities { - references_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: "my-fake-lsp-adapter", + capabilities: capabilities, + ..FakeLspAdapter::default() }, ); @@ -4989,6 +5026,8 @@ async fn test_references( } } }); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let references = project_b.update(cx_b, |p, cx| p.references(&buffer_b, 7, cx)); @@ -4996,7 +5035,7 @@ 
async fn test_references( executor.run_until_parked(); project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "my-fake-lsp-adapter"); + assert_eq!(status.name.0, "my-fake-lsp-adapter"); assert_eq!( status.pending_work.values().next().unwrap().message, Some("Finding references...".into()) @@ -5054,7 +5093,7 @@ async fn test_references( executor.run_until_parked(); project_b.read_with(cx_b, |project, cx| { let status = project.language_server_statuses(cx).next().unwrap().1; - assert_eq!(status.name, "my-fake-lsp-adapter"); + assert_eq!(status.name.0, "my-fake-lsp-adapter"); assert_eq!( status.pending_work.values().next().unwrap().message, Some("Finding references...".into()) @@ -5204,10 +5243,26 @@ async fn test_document_highlights( ) .await; - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); client_a.language_registry().add(rust_lang()); + let capabilities = lsp::ServerCapabilities { + document_highlight_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); let (project_a, worktree_id) = client_a.build_local_project(path!("/root-1"), cx_a).await; let project_id = active_call_a @@ -5256,6 +5311,8 @@ async fn test_document_highlights( ])) }, ); + cx_a.run_until_parked(); + cx_b.run_until_parked(); let highlights = project_b .update(cx_b, |p, cx| p.document_highlights(&buffer_b, 34, cx)) @@ -5306,30 +5363,49 @@ async fn test_lsp_hover( client_a.language_registry().add(rust_lang()); let language_server_names = ["rust-analyzer", "CrabLang-ls"]; + let capabilities_1 = lsp::ServerCapabilities { + hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; + let capabilities_2 = lsp::ServerCapabilities { + hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }; let mut language_servers = [ client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: "rust-analyzer", - capabilities: lsp::ServerCapabilities { - hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + name: language_server_names[0], + capabilities: capabilities_1.clone(), ..FakeLspAdapter::default() }, ), client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: "CrabLang-ls", - capabilities: lsp::ServerCapabilities { - hover_provider: Some(lsp::HoverProviderCapability::Simple(true)), - ..lsp::ServerCapabilities::default() - }, + name: language_server_names[1], + capabilities: capabilities_2.clone(), ..FakeLspAdapter::default() }, ), ]; + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: language_server_names[0], + capabilities: capabilities_1, + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: language_server_names[1], + capabilities: capabilities_2, + ..FakeLspAdapter::default() + }, + ); let (project_a, 
worktree_id) = client_a.build_local_project(path!("/root-1"), cx_a).await; let project_id = active_call_a @@ -5423,6 +5499,8 @@ async fn test_lsp_hover( unexpected => panic!("Unexpected server name: {unexpected}"), } } + cx_a.run_until_parked(); + cx_b.run_until_parked(); // Request hover information as the guest. let mut hovers = project_b @@ -5605,10 +5683,26 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( .await; let active_call_a = cx_a.read(ActiveCall::global); + let capabilities = lsp::ServerCapabilities { + definition_provider: Some(OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }; client_a.language_registry().add(rust_lang()); - let mut fake_language_servers = client_a - .language_registry() - .register_fake_lsp("Rust", Default::default()); + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); client_a .fs() @@ -5649,6 +5743,8 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( let definitions; let buffer_b2; if rng.r#gen() { + cx_a.run_until_parked(); + cx_b.run_until_parked(); definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx)); (buffer_b2, _) = project_b .update(cx_b, |p, cx| { @@ -5663,11 +5759,17 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( }) .await .unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx)); } let definitions = definitions.await.unwrap(); - assert_eq!(definitions.len(), 1); + assert_eq!( + definitions.len(), + 1, + "Unexpected definitions: {definitions:?}" + ); assert_eq!(definitions[0].target.buffer, buffer_b2); } @@ -5738,7 +5840,7 @@ async fn test_contacts( server.allow_connections(); client_c - .authenticate_and_connect(false, &cx_c.to_async()) + .connect(false, &cx_c.to_async()) .await .into_response() .unwrap(); @@ -6079,7 +6181,7 @@ async fn test_contacts( .iter() .map(|contact| { ( - contact.user.github_login.clone(), + contact.user.github_login.clone().to_string(), if contact.online { "online" } else { "offline" }, if contact.busy { "busy" } else { "free" }, ) @@ -6269,7 +6371,7 @@ async fn test_contact_requests( client.disconnect(&cx.to_async()); client.clear_contacts(cx).await; client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs index 4e64b5526b..9bf906694e 100644 --- a/crates/collab/src/tests/notification_tests.rs +++ b/crates/collab/src/tests/notification_tests.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use gpui::{BackgroundExecutor, TestAppContext}; use notifications::NotificationEvent; use parking_lot::Mutex; +use pretty_assertions::assert_eq; use rpc::{Notification, proto}; use crate::tests::TestServer; @@ -17,6 +18,9 @@ async fn test_notifications( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + // Wait for authentication/connection to Collab to be established. 
+ executor.run_until_parked(); + let notification_events_a = Arc::new(Mutex::new(Vec::new())); let notification_events_b = Arc::new(Mutex::new(Vec::new())); client_a.notification_store().update(cx_a, |_, cx| { diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index 5c5bcd5832..bb84bedfcf 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -1,14 +1,9 @@ use std::sync::Arc; -use chrono::{Duration, Utc}; use pretty_assertions::assert_eq; use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{ - FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, - StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, - StripeSubscriptionItemId, UpdateSubscriptionItems, -}; +use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring}; fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) { let stripe_client = Arc::new(FakeStripeClient::new()); @@ -21,24 +16,6 @@ fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) { async fn test_initialize() { let (stripe_billing, stripe_client) = make_stripe_billing(); - // Add test meters - let meter1 = StripeMeter { - id: StripeMeterId("meter_1".into()), - event_name: "event_1".to_string(), - }; - let meter2 = StripeMeter { - id: StripeMeterId("meter_2".into()), - event_name: "event_2".to_string(), - }; - stripe_client - .meters - .lock() - .insert(meter1.id.clone(), meter1); - stripe_client - .meters - .lock() - .insert(meter2.id.clone(), meter2); - // Add test prices let price1 = StripePrice { id: StripePriceId("price_1".into()), @@ -144,217 +121,3 @@ async fn test_find_or_create_customer_by_email() { assert_eq!(customer.email.as_deref(), Some(email)); } } - -#[gpui::test] -async fn test_subscribe_to_price() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let price = StripePrice { - id: StripePriceId("price_test".into()), - unit_amount: Some(2000), - lookup_key: Some("test-price".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(price.id.clone(), price.clone()); - - let now = Utc::now(); - let subscription = StripeSubscription { - id: StripeSubscriptionId("sub_test".into()), - customer: StripeCustomerId("cus_test".into()), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![], - cancel_at: None, - cancellation_details: None, - }; - stripe_client - .subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - stripe_billing - .subscribe_to_price(&subscription.id, &price) - .await - .unwrap(); - - let update_subscription_calls = stripe_client - .update_subscription_calls - .lock() - .iter() - .map(|(id, params)| (id.clone(), params.clone())) - .collect::<Vec<_>>(); - assert_eq!(update_subscription_calls.len(), 1); - assert_eq!(update_subscription_calls[0].0, subscription.id); - assert_eq!( - update_subscription_calls[0].1.items, - Some(vec![UpdateSubscriptionItems { - price: Some(price.id.clone()) - }]) - ); - - // Subscribing to a price that is already on the subscription is a no-op.
- { - let now = Utc::now(); - let subscription = StripeSubscription { - id: StripeSubscriptionId("sub_test".into()), - customer: StripeCustomerId("cus_test".into()), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client - .subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - stripe_billing - .subscribe_to_price(&subscription.id, &price) - .await - .unwrap(); - - assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1); - } -} - -#[gpui::test] -async fn test_subscribe_to_zed_free() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let zed_pro_price = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(0), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(zed_pro_price.id.clone(), zed_pro_price.clone()); - let zed_free_price = StripePrice { - id: StripePriceId("price_2".into()), - unit_amount: Some(0), - lookup_key: Some("zed-free".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(zed_free_price.id.clone(), zed_free_price.clone()); - - stripe_billing.initialize().await.unwrap(); - - // Customer is subscribed to Zed Free when not already subscribed to a plan. - { - let customer_id = StripeCustomerId("cus_no_plan".into()); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price)); - } - - // Customer is not subscribed to Zed Free when they already have an active subscription. - { - let customer_id = StripeCustomerId("cus_active_subscription".into()); - - let now = Utc::now(); - let existing_subscription = StripeSubscription { - id: StripeSubscriptionId("sub_existing_active".into()), - customer: customer_id.clone(), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(zed_pro_price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client.subscriptions.lock().insert( - existing_subscription.id.clone(), - existing_subscription.clone(), - ); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription, existing_subscription); - } - - // Customer is not subscribed to Zed Free when they already have a trial subscription. 
- { - let customer_id = StripeCustomerId("cus_trial_subscription".into()); - - let now = Utc::now(); - let existing_subscription = StripeSubscription { - id: StripeSubscriptionId("sub_existing_trial".into()), - customer: customer_id.clone(), - status: stripe::SubscriptionStatus::Trialing, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(14)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(zed_pro_price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client.subscriptions.lock().insert( - existing_subscription.id.clone(), - existing_subscription.clone(), - ); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription, existing_subscription); - } -} - -#[gpui::test] -async fn test_bill_model_request_usage() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let customer_id = StripeCustomerId("cus_test".into()); - - stripe_billing - .bill_model_request_usage(&customer_id, "some_model/requests", 73) - .await - .unwrap(); - - let create_meter_event_calls = stripe_client - .create_meter_event_calls - .lock() - .iter() - .cloned() - .collect::>(); - assert_eq!(create_meter_event_calls.len(), 1); - assert!( - create_meter_event_calls[0] - .identifier - .starts_with("model_requests/") - ); - assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id); - assert_eq!( - create_meter_event_calls[0].event_name.as_ref(), - "some_model/requests" - ); - assert_eq!(create_meter_event_calls[0].value, 73); -} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index ab84e02b19..f5a0e8ea81 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -8,6 +8,7 @@ use crate::{ use anyhow::anyhow; use call::ActiveCall; use channel::{ChannelBuffer, ChannelStore}; +use client::test::{make_get_authenticated_user_response, parse_authorization_header}; use client::{ self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore, proto::PeerId, @@ -20,7 +21,7 @@ use fs::FakeFs; use futures::{StreamExt as _, channel::oneshot}; use git::GitHostingProviderRegistry; use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext}; -use http_client::FakeHttpClient; +use http_client::{FakeHttpClient, Method}; use language::LanguageRegistry; use node_runtime::NodeRuntime; use notifications::NotificationStore; @@ -161,6 +162,8 @@ impl TestServer { } pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + const ACCESS_TOKEN: &str = "the-token"; + let fs = FakeFs::new(cx.executor()); cx.update(|cx| { @@ -175,7 +178,7 @@ impl TestServer { }); let clock = Arc::new(FakeSystemClock::new()); - let http = FakeHttpClient::with_404_response(); + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await { user.id @@ -197,6 +200,47 @@ impl TestServer { .expect("creating user failed") .user_id }; + + let http = FakeHttpClient::create({ + let name = name.to_string(); + move |req| { + let name = name.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: user_id.to_proto(), + access_token: ACCESS_TOKEN.into(), + }) + { + return 
Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + user_id.0, name, + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } + } + } + }); + let client_name = name.to_string(); let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx)); let server = self.server.clone(); @@ -208,11 +252,10 @@ impl TestServer { .unwrap() .set_id(user_id.to_proto()) .override_authenticate(move |cx| { - let access_token = "the-token".to_string(); cx.spawn(async move |_| { Ok(Credentials { user_id: user_id.to_proto(), - access_token, + access_token: ACCESS_TOKEN.into(), }) }) }) @@ -221,7 +264,7 @@ impl TestServer { credentials, &Credentials { user_id: user_id.0 as u64, - access_token: "the-token".into() + access_token: ACCESS_TOKEN.into(), } ); @@ -254,6 +297,8 @@ impl TestServer { client_name, Principal::User(user), ZedVersion(SemanticVersion::new(1, 0, 0)), + Some("test".to_string()), + None, None, None, Some(connection_id_tx), @@ -318,7 +363,7 @@ impl TestServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -691,17 +736,17 @@ impl TestClient { current: store .contacts() .iter() - .map(|contact| contact.user.github_login.clone()) + .map(|contact| contact.user.github_login.clone().to_string()) .collect(), outgoing_requests: store .outgoing_contact_requests() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect(), incoming_requests: store .incoming_contact_requests() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect(), }) } diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 3e2d813f1b..51d9f003f8 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -103,28 +103,16 @@ impl ChatPanel { }); cx.new(|cx| { - let entity = cx.entity().downgrade(); - let message_list = ListState::new( - 0, - gpui::ListAlignment::Bottom, - px(1000.), - move |ix, window, cx| { - if let Some(entity) = entity.upgrade() { - entity.update(cx, |this: &mut Self, cx| { - this.render_message(ix, window, cx).into_any_element() - }) - } else { - div().into_any() - } - }, - ); + let message_list = ListState::new(0, gpui::ListAlignment::Bottom, px(1000.)); - message_list.set_scroll_handler(cx.listener(|this, event: &ListScrollEvent, _, cx| { - if event.visible_range.start < MESSAGE_LOADING_THRESHOLD { - this.load_more_messages(cx); - } - this.is_scrolled_to_bottom = !event.is_scrolled; - })); + message_list.set_scroll_handler(cx.listener( + |this: &mut Self, event: &ListScrollEvent, _, cx| { + if event.visible_range.start < MESSAGE_LOADING_THRESHOLD { + this.load_more_messages(cx); + } + this.is_scrolled_to_bottom = !event.is_scrolled; + }, + )); let local_offset = chrono::Local::now().offset().local_minus_utc(); let mut this = Self { @@ -399,7 +387,7 @@ impl ChatPanel { ix: usize, window: &mut Window, cx: &mut Context, - ) -> impl IntoElement { + ) -> AnyElement { let active_chat = &self.active_chat.as_ref().unwrap().0; let (message, is_continuation_from_previous, is_admin) = active_chat.update(cx, |active_chat, cx| { @@ -582,6 +570,7 @@ impl ChatPanel { 
self.render_popover_buttons(message_id, can_delete_message, can_edit_message, cx) .mt_neg_2p5(), ) + .into_any_element() } fn has_open_menu(&self, message_id: Option) -> bool { @@ -979,7 +968,13 @@ impl Render for ChatPanel { ) .child(div().flex_grow().px_2().map(|this| { if self.active_chat.is_some() { - this.child(list(self.message_list.clone()).size_full()) + this.child( + list( + self.message_list.clone(), + cx.processor(Self::render_message), + ) + .size_full(), + ) } else { this.child( div() @@ -1162,7 +1157,7 @@ impl Panel for ChatPanel { } fn icon(&self, _window: &Window, cx: &App) -> Option { - self.enabled(cx).then(|| ui::IconName::MessageBubbles) + self.enabled(cx).then(|| ui::IconName::Chat) } fn icon_tooltip(&self, _: &Window, _: &App) -> Option<&'static str> { diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 4d5973481e..bb7c2ba1cd 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -324,20 +324,6 @@ impl CollabPanel { ) .detach(); - let entity = cx.entity().downgrade(); - let list_state = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - if let Some(entity) = entity.upgrade() { - entity.update(cx, |this, cx| this.render_list_entry(ix, window, cx)) - } else { - div().into_any() - } - }, - ); - let mut this = Self { width: None, focus_handle: cx.focus_handle(), @@ -345,7 +331,7 @@ impl CollabPanel { fs: workspace.app_state().fs.clone(), pending_serialization: Task::ready(None), context_menu: None, - list_state, + list_state: ListState::new(0, gpui::ListAlignment::Top, px(1000.)), channel_name_editor, filter_editor, entries: Vec::default(), @@ -940,7 +926,7 @@ impl CollabPanel { room.read(cx).local_participant().role == proto::ChannelRole::Admin }); - ListItem::new(SharedString::from(user.github_login.clone())) + ListItem::new(user.github_login.clone()) .start_slot(Avatar::new(user.avatar_uri.clone())) .child(Label::new(user.github_login.clone())) .toggle_state(is_selected) @@ -1124,7 +1110,7 @@ impl CollabPanel { .relative() .gap_1() .child(render_tree_branch(false, false, window, cx)) - .child(IconButton::new(0, IconName::MessageBubbles)) + .child(IconButton::new(0, IconName::Chat)) .children(has_messages_notification.then(|| { div() .w_1p5() @@ -2331,7 +2317,7 @@ impl CollabPanel { let client = this.client.clone(); cx.spawn_in(window, async move |_, cx| { client - .authenticate_and_connect(true, &cx) + .connect(true, &cx) .await .into_response() .notify_async_err(cx); @@ -2431,7 +2417,13 @@ impl CollabPanel { }); v_flex() .size_full() - .child(list(self.list_state.clone()).size_full()) + .child( + list( + self.list_state.clone(), + cx.processor(Self::render_list_entry), + ) + .size_full(), + ) .child( v_flex() .child(div().mx_2().border_primary(cx).border_t_1()) @@ -2583,7 +2575,7 @@ impl CollabPanel { ) -> impl IntoElement { let online = contact.online; let busy = contact.busy || calling; - let github_login = SharedString::from(contact.user.github_login.clone()); + let github_login = contact.user.github_login.clone(); let item = ListItem::new(github_login.clone()) .indent_level(1) .indent_step_size(px(20.)) @@ -2605,7 +2597,7 @@ impl CollabPanel { let contact = contact.clone(); move |this, event: &ClickEvent, window, cx| { this.deploy_contact_context_menu( - event.down.position, + event.position(), contact.clone(), window, cx, @@ -2662,7 +2654,7 @@ impl CollabPanel { is_selected: bool, cx: &mut Context, ) -> impl IntoElement { - let github_login = 
SharedString::from(user.github_login.clone()); + let github_login = user.github_login.clone(); let user_id = user.id; let is_response_pending = self.user_store.read(cx).is_contact_request_pending(user); let color = if is_response_pending { @@ -2923,7 +2915,7 @@ impl CollabPanel { .gap_1() .px_1() .child( - IconButton::new("channel_chat", IconName::MessageBubbles) + IconButton::new("channel_chat", IconName::Chat) .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) @@ -2939,7 +2931,7 @@ impl CollabPanel { .visible_on_hover(""), ) .child( - IconButton::new("channel_notes", IconName::File) + IconButton::new("channel_notes", IconName::FileText) .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index fba8f66c2d..3a280ff667 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -118,16 +118,7 @@ impl NotificationPanel { }) .detach(); - let entity = cx.entity().downgrade(); - let notification_list = - ListState::new(0, ListAlignment::Top, px(1000.), move |ix, window, cx| { - entity - .upgrade() - .and_then(|entity| { - entity.update(cx, |this, cx| this.render_notification(ix, window, cx)) - }) - .unwrap_or_else(|| div().into_any()) - }); + let notification_list = ListState::new(0, ListAlignment::Top, px(1000.)); notification_list.set_scroll_handler(cx.listener( |this, event: &ListScrollEvent, _, cx| { if event.count.saturating_sub(event.visible_range.end) < LOADING_THRESHOLD { @@ -634,13 +625,13 @@ impl Render for NotificationPanel { .child(Icon::new(IconName::Envelope)), ) .map(|this| { - if self.client.user_id().is_none() { + if !self.client.status().borrow().is_connected() { this.child( v_flex() .gap_2() .p_4() .child( - Button::new("sign_in_prompt_button", "Sign in") + Button::new("connect_prompt_button", "Connect") .icon_color(Color::Muted) .icon(IconName::Github) .icon_position(IconPosition::Start) @@ -652,10 +643,7 @@ impl Render for NotificationPanel { let client = client.clone(); window .spawn(cx, async move |cx| { - match client - .authenticate_and_connect(true, &cx) - .await - { + match client.connect(true, &cx).await { util::ConnectionResult::Timeout => { log::error!("Connection timeout"); } @@ -673,7 +661,7 @@ impl Render for NotificationPanel { ) .child( div().flex().w_full().items_center().child( - Label::new("Sign in to view notifications.") + Label::new("Connect to view notifications.") .color(Color::Muted) .size(LabelSize::Small), ), @@ -690,7 +678,16 @@ impl Render for NotificationPanel { ), ) } else { - this.child(list(self.notification_list.clone()).size_full()) + this.child( + list( + self.notification_list.clone(), + cx.processor(|this, ix, window, cx| { + this.render_notification(ix, window, cx) + .unwrap_or_else(|| div().into_any()) + }), + ) + .size_full(), + ) } }) } diff --git a/crates/command_palette/src/command_palette.rs b/crates/command_palette/src/command_palette.rs index dfaede0dc4..b8800ff912 100644 --- a/crates/command_palette/src/command_palette.rs +++ b/crates/command_palette/src/command_palette.rs @@ -136,7 +136,10 @@ impl Focusable for CommandPalette { impl Render for CommandPalette { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("CommandPalette") + .w(rems(34.)) + .child(self.picker.clone()) } } diff --git 
a/crates/component/src/component.rs b/crates/component/src/component.rs index 02840cc3cb..0c05ba4a97 100644 --- a/crates/component/src/component.rs +++ b/crates/component/src/component.rs @@ -318,8 +318,10 @@ pub enum ComponentScope { Notification, #[strum(serialize = "Overlays & Layering")] Overlays, + Onboarding, Status, Typography, + Utilities, #[strum(serialize = "Version Control")] VersionControl, } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 6b24d9b136..65283afa87 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; -use futures::{FutureExt, StreamExt, channel::oneshot, select}; +use futures::{FutureExt, StreamExt, channel::oneshot, future, select}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use parking_lot::Mutex; use postage::barrier; @@ -10,15 +10,19 @@ use smol::channel; use std::{ fmt, path::PathBuf, + pin::pin, sync::{ Arc, atomic::{AtomicI32, Ordering::SeqCst}, }, time::{Duration, Instant}, }; -use util::TryFutureExt; +use util::{ResultExt, TryFutureExt}; -use crate::transport::{StdioTransport, Transport}; +use crate::{ + transport::{StdioTransport, Transport}, + types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled}, +}; const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); @@ -32,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603; type ResponseHandler = Box)>; type NotificationHandler = Box; +type RequestHandler = Box; #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -78,6 +83,15 @@ pub struct Request<'a, T> { pub params: T, } +#[derive(Serialize, Deserialize)] +pub struct AnyRequest<'a> { + pub jsonrpc: &'a str, + pub id: RequestId, + pub method: &'a str, + #[serde(skip_serializing_if = "is_null_value")] + pub params: Option<&'a RawValue>, +} + #[derive(Serialize, Deserialize)] struct AnyResponse<'a> { jsonrpc: &'a str, @@ -144,6 +158,7 @@ impl Client { pub fn stdio( server_id: ContextServerId, binary: ModelContextServerBinary, + working_directory: &Option, cx: AsyncApp, ) -> Result { log::info!( @@ -158,7 +173,7 @@ impl Client { .map(|name| name.to_string_lossy().to_string()) .unwrap_or_else(String::new); - let transport = Arc::new(StdioTransport::new(binary, &cx)?); + let transport = Arc::new(StdioTransport::new(binary, working_directory, &cx)?); Self::new(server_id, server_name.into(), transport, cx) } @@ -176,15 +191,23 @@ impl Client { Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); + let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default())); let receive_input_task = cx.spawn({ let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); + let request_handlers = request_handlers.clone(); let transport = transport.clone(); async move |cx| { - Self::handle_input(transport, notification_handlers, response_handlers, cx) - .log_err() - .await + Self::handle_input( + transport, + notification_handlers, + request_handlers, + response_handlers, + cx, + ) + .log_err() + .await } }); let receive_err_task = cx.spawn({ @@ -230,13 +253,24 @@ impl Client { async fn handle_input( transport: Arc, notification_handlers: Arc>>, + request_handlers: Arc>>, response_handlers: Arc>>>, cx: &mut 
AsyncApp, ) -> anyhow::Result<()> { let mut receiver = transport.receive(); while let Some(message) = receiver.next().await { - if let Ok(response) = serde_json::from_str::(&message) { + log::trace!("recv: {}", &message); + if let Ok(request) = serde_json::from_str::(&message) { + let mut request_handlers = request_handlers.lock(); + if let Some(handler) = request_handlers.get_mut(request.method) { + handler( + request.id, + request.params.unwrap_or(RawValue::NULL), + cx.clone(), + ); + } + } else if let Ok(response) = serde_json::from_str::(&message) { if let Some(handlers) = response_handlers.lock().as_mut() { if let Some(handler) = handlers.remove(&response.id) { handler(Ok(message.to_string())); @@ -247,6 +281,8 @@ impl Client { if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { handler(notification.params.unwrap_or(Value::Null), cx.clone()); } + } else { + log::error!("Unhandled JSON from context_server: {}", message); } } @@ -294,6 +330,17 @@ impl Client { &self, method: &str, params: impl Serialize, + ) -> Result { + self.request_with(method, params, None, Some(REQUEST_TIMEOUT)) + .await + } + + pub async fn request_with( + &self, + method: &str, + params: impl Serialize, + cancel_rx: Option>, + timeout: Option, ) -> Result { let id = self.next_id.fetch_add(1, SeqCst); let request = serde_json::to_string(&Request { @@ -329,7 +376,23 @@ impl Client { handle_response?; send?; - let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); + let mut timeout_fut = pin!( + match timeout { + Some(timeout) => future::Either::Left(executor.timer(timeout)), + None => future::Either::Right(future::pending()), + } + .fuse() + ); + let mut cancel_fut = pin!( + match cancel_rx { + Some(rx) => future::Either::Left(async { + rx.await.log_err(); + }), + None => future::Either::Right(future::pending()), + } + .fuse() + ); + select! 
{ response = rx.fuse() => { let elapsed = started.elapsed(); @@ -348,8 +411,18 @@ impl Client { Err(_) => anyhow::bail!("cancelled") } } - _ = timeout => { - log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); + _ = cancel_fut => { + self.notify( + Cancelled::METHOD, + ClientNotification::Cancelled(CancelledParams { + request_id: RequestId::Int(id), + reason: None + }) + ).log_err(); + anyhow::bail!(RequestCanceled) + } + _ = timeout_fut => { + log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout.unwrap()); anyhow::bail!("Context server request timeout"); } } @@ -368,14 +441,23 @@ impl Client { Ok(()) } - #[allow(unused)] - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { - self.notification_handlers - .lock() - .insert(method, Box::new(f)); + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { + self.notification_handlers.lock().insert(method, f); + } +} + +#[derive(Debug)] +pub struct RequestCanceled; + +impl std::error::Error for RequestCanceled {} + +impl std::fmt::Display for RequestCanceled { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Context server request was canceled") } } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index f2517feb27..34fa29678d 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -53,7 +53,7 @@ impl std::fmt::Debug for ContextServerCommand { } enum ContextServerTransport { - Stdio(ContextServerCommand), + Stdio(ContextServerCommand, Option), Custom(Arc), } @@ -64,11 +64,18 @@ pub struct ContextServer { } impl ContextServer { - pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self { + pub fn stdio( + id: ContextServerId, + command: ContextServerCommand, + working_directory: Option>, + ) -> Self { Self { id, client: RwLock::new(None), - configuration: ContextServerTransport::Stdio(command), + configuration: ContextServerTransport::Stdio( + command, + working_directory.map(|directory| directory.to_path_buf()), + ), } } @@ -88,15 +95,36 @@ impl ContextServer { self.client.read().clone() } - pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - let client = match &self.configuration { - ContextServerTransport::Stdio(command) => Client::stdio( + pub async fn start(&self, cx: &AsyncApp) -> Result<()> { + self.initialize(self.new_client(cx)?).await + } + + /// Starts the context server, making sure handlers are registered before initialization happens + pub async fn start_with_handlers( + &self, + notification_handlers: Vec<( + &'static str, + Box, + )>, + cx: &AsyncApp, + ) -> Result<()> { + let client = self.new_client(cx)?; + for (method, handler) in notification_handlers { + client.on_notification(method, handler); + } + self.initialize(client).await + } + + fn new_client(&self, cx: &AsyncApp) -> Result { + Ok(match &self.configuration { + ContextServerTransport::Stdio(command, working_directory) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { executable: Path::new(&command.path).to_path_buf(), args: command.args.clone(), env: command.env.clone(), }, + working_directory, cx.clone(), )?, ContextServerTransport::Custom(transport) => Client::new( @@ -105,8 +133,7 @@ impl ContextServer { transport.clone(), cx.clone(), )?, - }; - self.initialize(client).await + }) } async 
fn initialize(&self, client: Client) -> Result<()> { diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 9295ad979c..0e85fb2129 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -9,6 +9,8 @@ use futures::{ }; use gpui::{App, AppContext, AsyncApp, Task}; use net::async_net::{UnixListener, UnixStream}; +use schemars::JsonSchema; +use serde::de::DeserializeOwned; use serde_json::{json, value::RawValue}; use smol::stream::StreamExt; use std::{ @@ -20,16 +22,32 @@ use util::ResultExt; use crate::{ client::{CspResult, RequestId, Response}, - types::Request, + types::{ + CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations, + ToolResponseContent, + requests::{CallTool, ListTools}, + }, }; pub struct McpServer { socket_path: PathBuf, - handlers: Rc>>, + tools: Rc>>, + handlers: Rc>>, _server_task: Task<()>, } -type McpHandler = Box>, &App) -> Task>; +struct RegisteredTool { + tool: Tool, + handler: ToolHandler, +} + +type ToolHandler = Box< + dyn Fn( + Option, + &mut AsyncApp, + ) -> Task>>, +>; +type RequestHandler = Box>, &App) -> Task>; impl McpServer { pub fn new(cx: &AsyncApp) -> Task> { @@ -43,12 +61,14 @@ impl McpServer { cx.spawn(async move |cx| { let (temp_dir, socket_path, listener) = task.await?; + let tools = Rc::new(RefCell::new(HashMap::default())); let handlers = Rc::new(RefCell::new(HashMap::default())); let server_task = cx.spawn({ + let tools = tools.clone(); let handlers = handlers.clone(); async move |cx| { while let Ok((stream, _)) = listener.accept().await { - Self::serve_connection(stream, handlers.clone(), cx); + Self::serve_connection(stream, tools.clone(), handlers.clone(), cx); } drop(temp_dir) } @@ -56,11 +76,60 @@ impl McpServer { Ok(Self { socket_path, _server_task: server_task, - handlers: handlers.clone(), + tools, + handlers: handlers, }) }) } + pub fn add_tool(&mut self, tool: T) { + let mut settings = schemars::generate::SchemaSettings::draft07(); + settings.inline_subschemas = true; + let mut generator = settings.into_generator(); + + let output_schema = generator.root_schema_for::(); + let unit_schema = generator.root_schema_for::(); + + let registered_tool = RegisteredTool { + tool: Tool { + name: T::NAME.into(), + description: Some(tool.description().into()), + input_schema: generator.root_schema_for::().into(), + output_schema: if output_schema == unit_schema { + None + } else { + Some(output_schema.into()) + }, + annotations: Some(tool.annotations()), + }, + handler: Box::new({ + let tool = tool.clone(); + move |input_value, cx| { + let input = match input_value { + Some(input) => serde_json::from_value(input), + None => serde_json::from_value(serde_json::Value::Null), + }; + + let tool = tool.clone(); + match input { + Ok(input) => cx.spawn(async move |cx| { + let output = tool.run(input, cx).await?; + + Ok(ToolResponse { + content: output.content, + structured_content: serde_json::to_value(output.structured_content) + .unwrap_or_default(), + }) + }), + Err(err) => Task::ready(Err(err.into())), + } + } + }), + }; + + self.tools.borrow_mut().insert(T::NAME, registered_tool); + } + pub fn handle_request( &mut self, f: impl Fn(R::Params, &App) -> Task> + 'static, @@ -120,7 +189,8 @@ impl McpServer { fn serve_connection( stream: UnixStream, - handlers: Rc>>, + tools: Rc>>, + handlers: Rc>>, cx: &mut AsyncApp, ) { let (read, write) = smol::io::split(stream); @@ -135,7 +205,13 @@ impl McpServer { let Some(request_id) = request.id.clone() 
else { continue; }; - if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { + + if request.method == CallTool::METHOD { + Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx) + .await; + } else if request.method == ListTools::METHOD { + Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx); + } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { let outgoing_tx = outgoing_tx.clone(); if let Some(task) = cx @@ -149,25 +225,126 @@ impl McpServer { .detach(); } } else { - outgoing_tx - .unbounded_send( - serde_json::to_string(&Response::<()> { - jsonrpc: "2.0", - id: request.id.unwrap(), - value: CspResult::Error(Some(crate::client::Error { - message: format!("unhandled method {}", request.method), - code: -32601, - })), - }) - .unwrap(), - ) - .ok(); + Self::send_err( + request_id, + format!("unhandled method {}", request.method), + &outgoing_tx, + ); } } }) .detach(); } + fn handle_list_tools( + request_id: RequestId, + tools: &Rc>>, + outgoing_tx: &UnboundedSender, + ) { + let response = ListToolsResponse { + tools: tools.borrow().values().map(|t| t.tool.clone()).collect(), + next_cursor: None, + meta: None, + }; + + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Ok(Some(response)), + }) + .unwrap_or_default(), + ) + .ok(); + } + + async fn handle_call_tool( + request_id: RequestId, + params: Option>, + tools: &Rc>>, + outgoing_tx: &UnboundedSender, + cx: &mut AsyncApp, + ) { + let result: Result = match params.as_ref() { + Some(params) => serde_json::from_str(params.get()), + None => serde_json::from_value(serde_json::Value::Null), + }; + + match result { + Ok(params) => { + if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) { + let outgoing_tx = outgoing_tx.clone(); + + let task = (tool.handler)(params.arguments, cx); + cx.spawn(async move |_| { + let response = match task.await { + Ok(result) => CallToolResponse { + content: result.content, + is_error: Some(false), + meta: None, + structured_content: if result.structured_content.is_null() { + None + } else { + Some(result.structured_content) + }, + }, + Err(err) => CallToolResponse { + content: vec![ToolResponseContent::Text { + text: err.to_string(), + }], + is_error: Some(true), + meta: None, + structured_content: None, + }, + }; + + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Ok(Some(response)), + }) + .unwrap_or_default(), + ) + .ok(); + }) + .detach(); + } else { + Self::send_err( + request_id, + format!("Tool not found: {}", params.name), + &outgoing_tx, + ); + } + } + Err(err) => { + Self::send_err(request_id, err.to_string(), &outgoing_tx); + } + } + } + + fn send_err( + request_id: RequestId, + message: impl Into, + outgoing_tx: &UnboundedSender, + ) { + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response::<()> { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Error(Some(crate::client::Error { + message: message.into(), + code: -32601, + })), + }) + .unwrap(), + ) + .ok(); + } + async fn handle_io( mut outgoing_rx: UnboundedReceiver, incoming_tx: UnboundedSender, @@ -216,7 +393,37 @@ impl McpServer { } } -#[derive(Serialize, Deserialize)] +pub trait McpServerTool { + type Input: DeserializeOwned + JsonSchema; + type Output: Serialize + JsonSchema; + + const NAME: &'static str; + + fn description(&self) -> &'static str; + + fn annotations(&self) -> ToolAnnotations { + 
ToolAnnotations { + title: None, + read_only_hint: None, + destructive_hint: None, + idempotent_hint: None, + open_world_hint: None, + } + } + + fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>; +} + +pub struct ToolResponse<T> { + pub content: Vec<ToolResponseContent>, + pub structured_content: T, +} + +#[derive(Debug, Serialize, Deserialize)] struct RawRequest { #[serde(skip_serializing_if = "Option::is_none")] id: Option<RequestId>, diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index d8bbac60d6..5355f20f62 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -5,7 +5,12 @@ //! read/write messages and the types from types.rs for serialization/deserialization //! of messages. +use std::time::Duration; + use anyhow::Result; +use futures::channel::oneshot; +use gpui::AsyncApp; +use serde_json::Value; use crate::client::Client; use crate::types::{self, Notification, Request}; @@ -95,7 +100,26 @@ impl InitializedContextServerProtocol { self.inner.request(T::METHOD, params).await } + pub async fn request_with<T: Request>( + &self, + params: T::Params, + cancel_rx: Option<oneshot::Receiver<()>>, + timeout: Option<Duration>, + ) -> Result<T::Response> { + self.inner + .request_with(T::METHOD, params, cancel_rx, timeout) + .await + } + pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> { self.inner.notify(T::METHOD, params) } + + pub fn on_notification( + &self, + method: &'static str, + f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>, + ) { + self.inner.on_notification(method, f); + } } diff --git a/crates/context_server/src/transport/stdio_transport.rs b/crates/context_server/src/transport/stdio_transport.rs index 56d0240fa5..443b8c16f1 100644 --- a/crates/context_server/src/transport/stdio_transport.rs +++ b/crates/context_server/src/transport/stdio_transport.rs @@ -1,3 +1,4 @@ +use std::path::PathBuf; use std::pin::Pin; use anyhow::{Context as _, Result}; @@ -22,7 +23,11 @@ pub struct StdioTransport { } impl StdioTransport { - pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> { + pub fn new( + binary: ModelContextServerBinary, + working_directory: &Option<PathBuf>, + cx: &AsyncApp, + ) -> Result<Self> { let mut command = util::command::new_smol_command(&binary.executable); command .args(&binary.args) @@ -32,6 +37,10 @@ impl StdioTransport { .stderr(std::process::Stdio::piped()) .kill_on_drop(true); + if let Some(working_directory) = working_directory { + command.current_dir(working_directory); + } + let mut server = command.spawn().with_context(|| { format!( "failed to spawn command.
(path={:?}, args={:?})", diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 4a6fdcabd3..5fa2420a3d 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -3,6 +3,8 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; +use crate::client::RequestId; + pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const VERSION_2024_11_05: &str = "2024-11-05"; @@ -100,6 +102,7 @@ pub mod notifications { notification!("notifications/initialized", Initialized, ()); notification!("notifications/progress", Progress, ProgressParams); notification!("notifications/message", Message, MessageParams); + notification!("notifications/cancelled", Cancelled, CancelledParams); notification!( "notifications/resources/updated", ResourcesUpdated, @@ -492,18 +495,20 @@ pub struct RootsCapabilities { pub list_changed: Option<bool>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Tool { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] pub description: Option<String>, pub input_schema: serde_json::Value, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub output_schema: Option<serde_json::Value>, #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option<ToolAnnotations>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ToolAnnotations { /// A human-readable title for the tool. @@ -617,11 +622,15 @@ pub enum ClientNotification { Initialized, Progress(ProgressParams), RootsListChanged, - Cancelled { - request_id: String, - #[serde(skip_serializing_if = "Option::is_none")] - reason: Option<String>, - }, + Cancelled(CancelledParams), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CancelledParams { + pub request_id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option<String>, +} #[derive(Debug, Serialize, Deserialize)] @@ -673,6 +682,20 @@ pub struct CallToolResponse { pub is_error: Option<bool>, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option<HashMap<String, serde_json::Value>>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub structured_content: Option<serde_json::Value>, +} + +impl CallToolResponse { + pub fn text_contents(&self) -> String { + let mut text = String::new(); + for chunk in &self.content { + if let ToolResponseContent::Text { text: chunk } = chunk { + text.push_str(&chunk) + }; + } + text + } } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 234875d420..0fc119f311 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -34,7 +34,7 @@ fs.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true language.workspace = true log.workspace = true lsp.workspace = true @@ -46,6 +46,7 @@ project.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +sum_tree.workspace = true task.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index e11242cb15..49ae2b9d9c 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -6,7 +6,6 @@ mod sign_in; use crate::sign_in::initiate_sign_in_within_workspace; use ::fs::Fs; use anyhow::{Context as _, Result,
anyhow}; -use client::DisableAiSettings; use collections::{HashMap, HashSet}; use command_palette_hooks::CommandPaletteFilter; use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared}; @@ -24,6 +23,7 @@ use language::{ use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName}; use node_runtime::NodeRuntime; use parking_lot::Mutex; +use project::DisableAiSettings; use request::StatusNotification; use serde_json::json; use settings::Settings; @@ -39,6 +39,7 @@ use std::{ path::{Path, PathBuf}, sync::Arc, }; +use sum_tree::Dimensions; use util::{ResultExt, fs::remove_matching}; use workspace::Workspace; @@ -85,45 +86,13 @@ pub fn init( move |cx| Copilot::start(new_server_id, fs, node_runtime, cx) }); Copilot::set_global(copilot.clone(), cx); - cx.observe(&copilot, |handle, cx| { - let copilot_action_types = [ - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - ]; - let copilot_auth_action_types = [TypeId::of::()]; - let copilot_no_auth_action_types = [TypeId::of::()]; - let status = handle.read(cx).status(); - - let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; - let filter = CommandPaletteFilter::global_mut(cx); - - if is_ai_disabled { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.hide_action_types(&copilot_no_auth_action_types); - } else { - match status { - Status::Disabled => { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.hide_action_types(&copilot_no_auth_action_types); - } - Status::Authorized => { - filter.hide_action_types(&copilot_no_auth_action_types); - filter.show_action_types( - copilot_action_types - .iter() - .chain(&copilot_auth_action_types), - ); - } - _ => { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.show_action_types(copilot_no_auth_action_types.iter()); - } - } + cx.observe(&copilot, |copilot, cx| { + copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); + }) + .detach(); + cx.observe_global::(|cx| { + if let Some(copilot) = Copilot::global(cx) { + copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); } }) .detach(); @@ -271,7 +240,7 @@ impl RegisteredBuffer { let new_snapshot = new_snapshot.clone(); async move { new_snapshot - .edits_since::<(PointUtf16, usize)>(&old_version) + .edits_since::>(&old_version) .map(|edit| { let edit_start = edit.new.start.0; let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); @@ -1131,6 +1100,44 @@ impl Copilot { cx.notify(); } } + + fn update_action_visibilities(&self, cx: &mut App) { + let signed_in_actions = [ + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + let auth_actions = [TypeId::of::()]; + let no_auth_actions = [TypeId::of::()]; + let status = self.status(); + + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let filter = CommandPaletteFilter::global_mut(cx); + + if is_ai_disabled { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.hide_action_types(&no_auth_actions); + } else { + match status { + Status::Disabled => { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.hide_action_types(&no_auth_actions); + } + Status::Authorized => { + filter.hide_action_types(&no_auth_actions); + 
filter.show_action_types(signed_in_actions.iter().chain(&auth_actions)); + } + _ => { + filter.hide_action_types(&signed_in_actions); + filter.hide_action_types(&auth_actions); + filter.show_action_types(no_auth_actions.iter()); + } + } + } + } } fn id_for_language(language: Option<&Arc>) -> String { diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 8dc04622f9..2a7225c4e3 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -1,7 +1,7 @@ use crate::{Completion, Copilot}; use anyhow::Result; +use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use gpui::{App, Context, Entity, EntityId, Task}; -use inline_completion::{Direction, EditPredictionProvider, InlineCompletion}; use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings}; use project::Project; use settings::Settings; @@ -210,7 +210,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { buffer: &Entity, cursor_position: language::Anchor, cx: &mut Context, - ) -> Option { + ) -> Option { let buffer_id = buffer.entity_id(); let buffer = buffer.read(cx); let completion = self.active_completion()?; @@ -241,7 +241,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { None } else { let position = cursor_position.bias_right(buffer); - Some(InlineCompletion { + Some(EditPrediction { id: None, edits: vec![(position..position, completion_text.into())], edit_preview: None, @@ -343,7 +343,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { assert!(editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); // Since we have both, the copilot suggestion is not shown inline assert_eq!(editor.text(cx), "one.\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.\ntwo\nthree\n"); @@ -355,7 +355,7 @@ mod tests { .unwrap() .detach(); assert!(!editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.completion_a\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.completion_a\ntwo\nthree\n"); }); @@ -389,7 +389,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Since only the copilot is available, it's shown inline assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.\ntwo\nthree\n"); @@ -400,7 +400,7 @@ mod tests { executor.run_until_parked(); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); }); @@ -418,25 +418,25 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); // Canceling should remove the active Copilot suggestion. 
editor.cancel(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.c\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); // After canceling, tabbing shouldn't insert the previously shown suggestion. editor.tab(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.c \ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c \ntwo\nthree\n"); // When undoing the previously active suggestion is shown again. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n"); }); @@ -444,25 +444,25 @@ mod tests { // If an edit occurs outside of this editor, the suggestion is still correctly interpolated. cx.update_buffer(|buffer, cx| buffer.edit([(5..5, "o")], None, cx)); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); // AcceptEditPrediction when there is an active suggestion inserts it. editor.accept_edit_prediction(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.copilot2\ntwo\nthree\n"); // When undoing the previously active suggestion is shown again. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); // Hide suggestion. editor.cancel(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.co\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.co\ntwo\nthree\n"); }); @@ -471,7 +471,7 @@ mod tests { // we won't make it visible. cx.update_buffer(|buffer, cx| buffer.edit([(6..6, "p")], None, cx)); cx.update_editor(|editor, _, cx| { - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one.cop\ntwo\nthree\n"); assert_eq!(editor.text(cx), "one.cop\ntwo\nthree\n"); }); @@ -498,19 +498,19 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); assert_eq!(editor.text(cx), "fn foo() {\n \n}"); // Tabbing inside of leading whitespace inserts indentation without accepting the suggestion. editor.tab(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "fn foo() {\n \n}"); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); // Using AcceptEditPrediction again accepts the suggestion. 
editor.accept_edit_prediction(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "fn foo() {\n let x = 4;\n}"); assert_eq!(editor.display_text(cx), "fn foo() {\n let x = 4;\n}"); }); @@ -575,17 +575,17 @@ mod tests { ); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Accepting the first word of the suggestion should only accept the first word and still show the rest. - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.copilot\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); // Accepting next word should accept the non-word and copilot suggestion should be gone - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.copilot1\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.copilot1\ntwo\nthree\n"); }); @@ -617,11 +617,11 @@ mod tests { ); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); // Accepting the first word (non-word) of the suggestion should only accept the first word and still show the rest. - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. \ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -629,8 +629,8 @@ mod tests { ); // Accepting next word should accept the next word and copilot suggestion should still exist - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. copilot\ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -638,8 +638,8 @@ mod tests { ); // Accepting the whitespace should accept the non-word/whitespaces with newline and copilot suggestion should be gone - editor.accept_partial_inline_completion(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one.123. 
copilot\n 456\ntwo\nthree\n"); assert_eq!( editor.display_text(cx), @@ -692,29 +692,29 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntw\nthree\n"); editor.backspace(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\nt\nthree\n"); editor.backspace(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\n\nthree\n"); // Deleting across the original suggestion range invalidates it. editor.backspace(&Default::default(), window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\nthree\n"); assert_eq!(editor.text(cx), "one\nthree\n"); // Undoing the deletion restores the suggestion. editor.undo(&Default::default(), window, cx); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\n\nthree\n"); }); @@ -775,7 +775,7 @@ mod tests { }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); _ = editor.update(cx, |editor, _, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2 + a\n\n\n\nc = 3\nd = 4\n" @@ -797,7 +797,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([Point::new(4, 5)..Point::new(4, 5)]) }); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4\n" @@ -806,7 +806,7 @@ mod tests { // Type a character, ensuring we don't even try to interpolate the previous suggestion. editor.handle_input(" ", window, cx); - assert!(!editor.has_active_inline_completion()); + assert!(!editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4 \n" @@ -817,7 +817,7 @@ mod tests { // Ensure the new suggestion is displayed when the debounce timeout expires. 
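// Several of the checks in this test lean on interpolation: a pending prediction
// keeps being shown while the text typed since it was produced is still a prefix
// of it, and is dropped otherwise. A rough sketch of that rule (not the editor's
// actual implementation, which works on anchored edits rather than plain strings):
fn sketch_prediction_still_applies(predicted: &str, typed_since_request: &str) -> bool {
    predicted.starts_with(typed_since_request)
}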
executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); _ = editor.update(cx, |editor, _, cx| { - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!( editor.display_text(cx), "\n\na = 1\nb = 2\n\n\n\nc = 3\nd = 4 + c\n" @@ -880,7 +880,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntw\nthree\n"); }); @@ -907,7 +907,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(!editor.context_menu_visible()); - assert!(editor.has_active_inline_completion()); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.display_text(cx), "one\ntwo.foo()\nthree\n"); assert_eq!(editor.text(cx), "one\ntwo\nthree\n"); }); @@ -934,7 +934,7 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(editor.context_menu_visible()); - assert!(!editor.has_active_inline_completion(),); + assert!(!editor.has_active_edit_prediction(),); assert_eq!(editor.text(cx), "one\ntwo.\nthree\n"); }); } @@ -1023,7 +1023,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |selections| { selections.select_ranges([Point::new(0, 0)..Point::new(0, 0)]) }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); @@ -1033,7 +1033,7 @@ mod tests { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([Point::new(5, 0)..Point::new(5, 0)]) }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); diff --git a/crates/crashes/Cargo.toml b/crates/crashes/Cargo.toml new file mode 100644 index 0000000000..641a97765a --- /dev/null +++ b/crates/crashes/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "crashes" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "GPL-3.0-or-later" + +[dependencies] +crash-handler.workspace = true +log.workspace = true +minidumper.workspace = true +paths.workspace = true +smol.workspace = true +workspace-hack.workspace = true + +[lints] +workspace = true + +[lib] +path = "src/crashes.rs" diff --git a/crates/inline_completion_button/LICENSE-GPL b/crates/crashes/LICENSE-GPL similarity index 100% rename from crates/inline_completion_button/LICENSE-GPL rename to crates/crashes/LICENSE-GPL diff --git a/crates/crashes/src/crashes.rs b/crates/crashes/src/crashes.rs new file mode 100644 index 0000000000..cfb4b57d5d --- /dev/null +++ b/crates/crashes/src/crashes.rs @@ -0,0 +1,172 @@ +use crash_handler::CrashHandler; +use log::info; +use minidumper::{Client, LoopAction, MinidumpBinary}; + +use std::{ + env, + fs::File, + io, + path::{Path, PathBuf}, + process::{self, Command}, + sync::{ + OnceLock, + atomic::{AtomicBool, Ordering}, + }, + thread, + time::Duration, +}; + +// set once the crash handler has initialized and the client has connected to it +pub static CRASH_HANDLER: AtomicBool = AtomicBool::new(false); +// set when the first minidump request is made to avoid generating duplicate crash reports +pub static REQUESTED_MINIDUMP: AtomicBool = 
AtomicBool::new(false); +const CRASH_HANDLER_TIMEOUT: Duration = Duration::from_secs(60); + +pub async fn init(id: String) { + let exe = env::current_exe().expect("unable to find ourselves"); + let zed_pid = process::id(); + // TODO: we should be able to get away with using 1 crash-handler process per machine, + // but for now we append the PID of the current process which makes it unique per remote + // server or interactive zed instance. This solves an issue where occasionally the socket + // used by the crash handler isn't destroyed correctly which causes it to stay on the file + // system and block further attempts to initialize crash handlers with that socket path. + let socket_name = paths::temp_dir().join(format!("zed-crash-handler-{zed_pid}")); + #[allow(unused)] + let server_pid = Command::new(exe) + .arg("--crash-handler") + .arg(&socket_name) + .spawn() + .expect("unable to spawn server process") + .id(); + info!("spawning crash handler process"); + + let mut elapsed = Duration::ZERO; + let retry_frequency = Duration::from_millis(100); + let mut maybe_client = None; + while maybe_client.is_none() { + if let Ok(client) = Client::with_name(socket_name.as_path()) { + maybe_client = Some(client); + info!("connected to crash handler process after {elapsed:?}"); + break; + } + elapsed += retry_frequency; + smol::Timer::after(retry_frequency).await; + } + let client = maybe_client.unwrap(); + client.send_message(1, id).unwrap(); // set session id on the server + + let client = std::sync::Arc::new(client); + let handler = crash_handler::CrashHandler::attach(unsafe { + let client = client.clone(); + crash_handler::make_crash_event(move |crash_context: &crash_handler::CrashContext| { + // only request a minidump once + let res = if REQUESTED_MINIDUMP + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + client.send_message(2, "mistakes were made").unwrap(); + client.ping().unwrap(); + client.request_dump(crash_context).is_ok() + } else { + true + }; + crash_handler::CrashEventResult::Handled(res) + }) + }) + .expect("failed to attach signal handler"); + + #[cfg(target_os = "linux")] + { + handler.set_ptracer(Some(server_pid)); + } + CRASH_HANDLER.store(true, Ordering::Release); + std::mem::forget(handler); + info!("crash handler registered"); + + loop { + client.ping().ok(); + smol::Timer::after(Duration::from_secs(10)).await; + } +} + +pub struct CrashServer { + session_id: OnceLock, +} + +impl minidumper::ServerHandler for CrashServer { + fn create_minidump_file(&self) -> Result<(File, PathBuf), io::Error> { + let err_message = "Need to send a message with the ID upon starting the crash handler"; + let dump_path = paths::logs_dir() + .join(self.session_id.get().expect(err_message)) + .with_extension("dmp"); + let file = File::create(&dump_path)?; + Ok((file, dump_path)) + } + + fn on_minidump_created(&self, result: Result) -> LoopAction { + match result { + Ok(mut md_bin) => { + use io::Write; + let _ = md_bin.file.flush(); + info!("wrote minidump to disk {:?}", md_bin.path); + } + Err(e) => { + info!("failed to write minidump: {:#}", e); + } + } + LoopAction::Exit + } + + fn on_message(&self, kind: u32, buffer: Vec) { + let message = String::from_utf8(buffer).expect("invalid utf-8"); + info!("kind: {kind}, message: {message}",); + if kind == 1 { + self.session_id + .set(message) + .expect("session id already initialized"); + } + } + + fn on_client_disconnected(&self, clients: usize) -> LoopAction { + info!("client disconnected, {clients} remaining"); + 
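// How the two halves of this file meet: `init` (above) re-spawns the current
// executable with `--crash-handler <socket>` and then talks to it over the
// minidumper socket (message kind 1 carries the session id, kind 2 signals that a
// crash is in progress before the dump request). A sketch of how the spawned side
// might route that flag into `crash_server`; the real argument handling lives in
// the zed binary and is not shown in this diff:
fn sketch_crash_handler_entry() {
    let mut args = std::env::args().skip(1);
    if args.next().as_deref() == Some("--crash-handler") {
        if let Some(socket) = args.next() {
            // Runs the minidump server loop defined below.
            crash_server(std::path::Path::new(&socket));
        }
    }
}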
if clients == 0 { + LoopAction::Exit + } else { + LoopAction::Continue + } + } +} + +pub fn handle_panic() { + // wait 500ms for the crash handler process to start up + // if it's still not there just write panic info and no minidump + let retry_frequency = Duration::from_millis(100); + for _ in 0..5 { + if CRASH_HANDLER.load(Ordering::Acquire) { + log::error!("triggering a crash to generate a minidump..."); + #[cfg(target_os = "linux")] + CrashHandler.simulate_signal(crash_handler::Signal::Trap as u32); + #[cfg(not(target_os = "linux"))] + CrashHandler.simulate_exception(None); + break; + } + thread::sleep(retry_frequency); + } +} + +pub fn crash_server(socket: &Path) { + let Ok(mut server) = minidumper::Server::with_name(socket) else { + log::info!("Couldn't create socket, there may already be a running crash server"); + return; + }; + let ab = AtomicBool::new(false); + server + .run( + Box::new(CrashServer { + session_id: OnceLock::new(), + }), + &ab, + Some(CRASH_HANDLER_TIMEOUT), + ) + .expect("failed to run server"); +} diff --git a/crates/dap/src/client.rs b/crates/dap/src/client.rs index 86a15b2d8a..7b791450ec 100644 --- a/crates/dap/src/client.rs +++ b/crates/dap/src/client.rs @@ -295,7 +295,7 @@ mod tests { request: dap_types::StartDebuggingRequestArgumentsRequest::Launch, }, }, - Box::new(|_| panic!("Did not expect to hit this code path")), + Box::new(|_| {}), &mut cx.to_async(), ) .await diff --git a/crates/dap/src/transport.rs b/crates/dap/src/transport.rs index 6dadf1cf35..f9fbbfc842 100644 --- a/crates/dap/src/transport.rs +++ b/crates/dap/src/transport.rs @@ -883,6 +883,7 @@ impl FakeTransport { break Err(anyhow!("exit in response to request")); } }; + let success = response.success; let message = serde_json::to_string(&Message::Response(response)).unwrap(); @@ -893,6 +894,25 @@ impl FakeTransport { ) .await .unwrap(); + + if request.command == dap_types::requests::Initialize::COMMAND + && success + { + let message = serde_json::to_string(&Message::Event(Box::new( + dap_types::messages::Events::Initialized(Some( + Default::default(), + )), + ))) + .unwrap(); + writer + .write_all( + TransportDelegate::build_rpc_message(message) + .as_bytes(), + ) + .await + .unwrap(); + } + writer.flush().await.unwrap(); } } diff --git a/crates/dap_adapters/src/python.rs b/crates/dap_adapters/src/python.rs index aa64fea6ed..461ce6fbb3 100644 --- a/crates/dap_adapters/src/python.rs +++ b/crates/dap_adapters/src/python.rs @@ -1,38 +1,36 @@ use crate::*; use anyhow::Context as _; use dap::{DebugRequest, StartDebuggingRequestArguments, adapters::DebugTaskDefinition}; +use fs::RemoveOptions; +use futures::{StreamExt, TryStreamExt}; +use gpui::http_client::AsyncBody; use gpui::{AsyncApp, SharedString}; use json_dotpath::DotPaths; use language::LanguageName; use paths::debug_adapters_dir; use serde_json::Value; +use smol::fs::File; +use smol::io::AsyncReadExt; use smol::lock::OnceCell; +use std::ffi::OsString; use std::net::Ipv4Addr; +use std::str::FromStr; use std::{ collections::HashMap, ffi::OsStr, path::{Path, PathBuf}, }; +use util::{ResultExt, maybe}; #[derive(Default)] pub(crate) struct PythonDebugAdapter { - python_venv_base: OnceCell, String>>, + debugpy_whl_base_path: OnceCell, String>>, } impl PythonDebugAdapter { const ADAPTER_NAME: &'static str = "Debugpy"; const DEBUG_ADAPTER_NAME: DebugAdapterName = DebugAdapterName(SharedString::new_static(Self::ADAPTER_NAME)); - const PYTHON_ADAPTER_IN_VENV: &'static str = if cfg!(target_os = "windows") { - "Scripts/python3" - } else { - 
"bin/python3" - }; - const ADAPTER_PATH: &'static str = if cfg!(target_os = "windows") { - "debugpy-venv/Scripts/debugpy-adapter" - } else { - "debugpy-venv/bin/debugpy-adapter" - }; const LANGUAGE_NAME: &'static str = "Python"; @@ -41,7 +39,6 @@ impl PythonDebugAdapter { port: u16, user_installed_path: Option<&Path>, user_args: Option>, - installed_in_venv: bool, ) -> Result> { let mut args = if let Some(user_installed_path) = user_installed_path { log::debug!( @@ -49,13 +46,11 @@ impl PythonDebugAdapter { user_installed_path.display() ); vec![user_installed_path.to_string_lossy().to_string()] - } else if installed_in_venv { - log::debug!("Using venv-installed debugpy"); - vec!["-m".to_string(), "debugpy.adapter".to_string()] } else { let adapter_path = paths::debug_adapters_dir().join(Self::DEBUG_ADAPTER_NAME.as_ref()); let path = adapter_path - .join(Self::ADAPTER_PATH) + .join("debugpy") + .join("adapter") .to_string_lossy() .into_owned(); log::debug!("Using pip debugpy adapter from: {path}"); @@ -96,68 +91,145 @@ impl PythonDebugAdapter { }) } - async fn ensure_venv(delegate: &dyn DapDelegate) -> Result> { - let python_path = Self::find_base_python(delegate) + async fn fetch_wheel(delegate: &Arc) -> Result, String> { + let system_python = Self::system_python_name(delegate) .await - .context("Could not find Python installation for DebugPy")?; - let work_dir = debug_adapters_dir().join(Self::ADAPTER_NAME); - let mut path = work_dir.clone(); - path.push("debugpy-venv"); - if !path.exists() { - util::command::new_smol_command(python_path) - .arg("-m") - .arg("venv") - .arg("debugpy-venv") - .current_dir(work_dir) - .spawn()? - .output() - .await?; + .ok_or_else(|| String::from("Could not find a Python installation"))?; + let command: &OsStr = system_python.as_ref(); + let download_dir = debug_adapters_dir().join(Self::ADAPTER_NAME).join("wheels"); + std::fs::create_dir_all(&download_dir).map_err(|e| e.to_string())?; + let installation_succeeded = util::command::new_smol_command(command) + .args([ + "-m", + "pip", + "download", + "debugpy", + "--only-binary=:all:", + "-d", + download_dir.to_string_lossy().as_ref(), + ]) + .output() + .await + .map_err(|e| format!("{e}"))? + .status + .success(); + if !installation_succeeded { + return Err("debugpy installation failed".into()); } - Ok(path.into()) + let wheel_path = std::fs::read_dir(&download_dir) + .map_err(|e| e.to_string())? + .find_map(|entry| { + entry.ok().filter(|e| { + e.file_type().is_ok_and(|typ| typ.is_file()) + && Path::new(&e.file_name()).extension() == Some("whl".as_ref()) + }) + }) + .ok_or_else(|| String::from("Did not find a .whl in {download_dir}"))?; + + util::archive::extract_zip( + &debug_adapters_dir().join(Self::ADAPTER_NAME), + File::open(&wheel_path.path()) + .await + .map_err(|e| e.to_string())?, + ) + .await + .map_err(|e| e.to_string())?; + + Ok(Arc::from(wheel_path.path())) } - // Find "baseline", user python version from which we'll create our own venv. 
- async fn find_base_python(delegate: &dyn DapDelegate) -> Option { - for path in ["python3", "python"] { - if let Some(path) = delegate.which(path.as_ref()).await { - return Some(path); + async fn maybe_fetch_new_wheel(delegate: &Arc) { + let latest_release = delegate + .http_client() + .get( + "https://pypi.org/pypi/debugpy/json", + AsyncBody::empty(), + false, + ) + .await + .log_err(); + maybe!(async move { + let response = latest_release.filter(|response| response.status().is_success())?; + + let mut output = String::new(); + response + .into_body() + .read_to_string(&mut output) + .await + .ok()?; + let as_json = serde_json::Value::from_str(&output).ok()?; + let latest_version = as_json.get("info").and_then(|info| { + info.get("version") + .and_then(|version| version.as_str()) + .map(ToOwned::to_owned) + })?; + let dist_info_dirname: OsString = format!("debugpy-{latest_version}.dist-info").into(); + let is_up_to_date = delegate + .fs() + .read_dir(&debug_adapters_dir().join(Self::ADAPTER_NAME)) + .await + .ok()? + .into_stream() + .any(async |entry| { + entry.is_ok_and(|e| e.file_name().is_some_and(|name| name == dist_info_dirname)) + }) + .await; + + if !is_up_to_date { + delegate + .fs() + .remove_dir( + &debug_adapters_dir().join(Self::ADAPTER_NAME), + RemoveOptions { + recursive: true, + ignore_if_not_exists: true, + }, + ) + .await + .ok()?; + Self::fetch_wheel(delegate).await.ok()?; } - } - None + Some(()) + }) + .await; } - async fn base_venv(&self, delegate: &dyn DapDelegate) -> Result, String> { - const BINARY_DIR: &str = if cfg!(target_os = "windows") { - "Scripts" - } else { - "bin" - }; - self.python_venv_base - .get_or_init(move || async move { - let venv_base = Self::ensure_venv(delegate) - .await - .map_err(|e| format!("{e}"))?; - let pip_path = venv_base.join(BINARY_DIR).join("pip3"); - let installation_succeeded = util::command::new_smol_command(pip_path.as_path()) - .arg("install") - .arg("debugpy") - .arg("-U") - .output() - .await - .map_err(|e| format!("{e}"))? 
- .status - .success(); - if !installation_succeeded { - return Err("debugpy installation failed".into()); - } - - Ok(venv_base) + async fn fetch_debugpy_whl( + &self, + delegate: &Arc, + ) -> Result, String> { + self.debugpy_whl_base_path + .get_or_init(|| async move { + Self::maybe_fetch_new_wheel(delegate).await; + Ok(Arc::from( + debug_adapters_dir() + .join(Self::ADAPTER_NAME) + .join("debugpy") + .join("adapter") + .as_ref(), + )) }) .await .clone() } + async fn system_python_name(delegate: &Arc) -> Option { + const BINARY_NAMES: [&str; 3] = ["python3", "python", "py"]; + let mut name = None; + + for cmd in BINARY_NAMES { + name = delegate + .which(OsStr::new(cmd)) + .await + .map(|path| path.to_string_lossy().to_string()); + if name.is_some() { + break; + } + } + name + } + async fn get_installed_binary( &self, delegate: &Arc, @@ -165,27 +237,14 @@ impl PythonDebugAdapter { user_installed_path: Option, user_args: Option>, python_from_toolchain: Option, - installed_in_venv: bool, ) -> Result { - const BINARY_NAMES: [&str; 3] = ["python3", "python", "py"]; let tcp_connection = config.tcp_connection.clone().unwrap_or_default(); let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?; let python_path = if let Some(toolchain) = python_from_toolchain { Some(toolchain) } else { - let mut name = None; - - for cmd in BINARY_NAMES { - name = delegate - .which(OsStr::new(cmd)) - .await - .map(|path| path.to_string_lossy().to_string()); - if name.is_some() { - break; - } - } - name + Self::system_python_name(delegate).await }; let python_command = python_path.context("failed to find binary path for Python")?; @@ -196,7 +255,6 @@ impl PythonDebugAdapter { port, user_installed_path.as_deref(), user_args, - installed_in_venv, ) .await?; @@ -618,63 +676,53 @@ impl DebugAdapter for PythonDebugAdapter { local_path.display() ); return self - .get_installed_binary( - delegate, - &config, - Some(local_path.clone()), - user_args, - None, - false, - ) + .get_installed_binary(delegate, &config, Some(local_path.clone()), user_args, None) .await; } + let base_path = config + .config + .get("cwd") + .and_then(|cwd| { + cwd.as_str() + .map(Path::new)? + .strip_prefix(delegate.worktree_root_path()) + .ok() + }) + .unwrap_or_else(|| "".as_ref()) + .into(); let toolchain = delegate .toolchain_store() .active_toolchain( delegate.worktree_id(), - Arc::from("".as_ref()), + base_path, language::LanguageName::new(Self::LANGUAGE_NAME), cx, ) .await; - if let Some(toolchain) = &toolchain { - if let Some(path) = Path::new(&toolchain.path.to_string()).parent() { - let debugpy_path = path.join("debugpy"); - if delegate.fs().is_file(&debugpy_path).await { - log::debug!( - "Found debugpy in toolchain environment: {}", - debugpy_path.display() - ); - return self - .get_installed_binary( - delegate, - &config, - None, - user_args, - Some(toolchain.path.to_string()), - true, - ) - .await; - } - } - } - let toolchain = self - .base_venv(&**delegate) + let debugpy_path = self + .fetch_debugpy_whl(delegate) .await - .map_err(|e| anyhow::anyhow!(e))? 
- .join(Self::PYTHON_ADAPTER_IN_VENV); + .map_err(|e| anyhow::anyhow!("{e}"))?; + if let Some(toolchain) = &toolchain { + log::debug!( + "Found debugpy in toolchain environment: {}", + debugpy_path.display() + ); + return self + .get_installed_binary( + delegate, + &config, + None, + user_args, + Some(toolchain.path.to_string()), + ) + .await; + } - self.get_installed_binary( - delegate, - &config, - None, - user_args, - Some(toolchain.to_string_lossy().into_owned()), - false, - ) - .await + self.get_installed_binary(delegate, &config, None, user_args, None) + .await } fn label_for_child_session(&self, args: &StartDebuggingRequestArguments) -> Option { @@ -689,6 +737,8 @@ impl DebugAdapter for PythonDebugAdapter { #[cfg(test)] mod tests { + use util::path; + use super::*; use std::{net::Ipv4Addr, path::PathBuf}; @@ -699,30 +749,24 @@ mod tests { // Case 1: User-defined debugpy path (highest precedence) let user_path = PathBuf::from("/custom/path/to/debugpy/src/debugpy/adapter"); - let user_args = PythonDebugAdapter::generate_debugpy_arguments( - &host, - port, - Some(&user_path), - None, - false, - ) - .await - .unwrap(); - - // Case 2: Venv-installed debugpy (uses -m debugpy.adapter) - let venv_args = - PythonDebugAdapter::generate_debugpy_arguments(&host, port, None, None, true) + let user_args = + PythonDebugAdapter::generate_debugpy_arguments(&host, port, Some(&user_path), None) .await .unwrap(); + // Case 2: Venv-installed debugpy (uses -m debugpy.adapter) + let venv_args = PythonDebugAdapter::generate_debugpy_arguments(&host, port, None, None) + .await + .unwrap(); + assert_eq!(user_args[0], "/custom/path/to/debugpy/src/debugpy/adapter"); assert_eq!(user_args[1], "--host=127.0.0.1"); assert_eq!(user_args[2], "--port=5678"); - assert_eq!(venv_args[0], "-m"); - assert_eq!(venv_args[1], "debugpy.adapter"); - assert_eq!(venv_args[2], "--host=127.0.0.1"); - assert_eq!(venv_args[3], "--port=5678"); + let expected_suffix = path!("debug_adapters/Debugpy/debugpy/adapter"); + assert!(venv_args[0].ends_with(expected_suffix)); + assert_eq!(venv_args[1], "--host=127.0.0.1"); + assert_eq!(venv_args[2], "--port=5678"); // The same cases, with arguments overridden by the user let user_args = PythonDebugAdapter::generate_debugpy_arguments( @@ -730,7 +774,6 @@ mod tests { port, Some(&user_path), Some(vec!["foo".into()]), - false, ) .await .unwrap(); @@ -739,7 +782,6 @@ mod tests { port, None, Some(vec!["foo".into()]), - true, ) .await .unwrap(); @@ -747,9 +789,8 @@ mod tests { assert!(user_args[0].ends_with("src/debugpy/adapter")); assert_eq!(user_args[1], "foo"); - assert_eq!(venv_args[0], "-m"); - assert_eq!(venv_args[1], "debugpy.adapter"); - assert_eq!(venv_args[2], "foo"); + assert!(venv_args[0].ends_with(expected_suffix)); + assert_eq!(venv_args[1], "foo"); // Note: Case 3 (GitHub-downloaded debugpy) is not tested since this requires mocking the Github API. 
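// The non-user-path case above resolves the adapter entry point out of the
// extracted debugpy wheel, so the full invocation becomes
// `<python> .../debug_adapters/Debugpy/debugpy/adapter --host=... --port=...`.
// A small sketch of that argument construction, reusing the same
// `debug_adapters_dir` helper the adapter code uses (illustrative only; the test
// above deliberately checks just the path suffix):
fn sketch_wheel_adapter_args(port: u16) -> Vec<String> {
    let adapter = debug_adapters_dir()
        .join("Debugpy")
        .join("debugpy")
        .join("adapter");
    vec![
        adapter.to_string_lossy().into_owned(),
        "--host=127.0.0.1".to_string(),
        format!("--port={port}"),
    ]
}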
} diff --git a/crates/debugger_ui/src/debugger_ui.rs b/crates/debugger_ui/src/debugger_ui.rs index 9eac59af83..5f5dfd1a1e 100644 --- a/crates/debugger_ui/src/debugger_ui.rs +++ b/crates/debugger_ui/src/debugger_ui.rs @@ -299,59 +299,76 @@ pub fn init(cx: &mut App) { else { return; }; + + let session = active_session + .read(cx) + .running_state + .read(cx) + .session() + .read(cx); + + if session.is_terminated() { + return; + } + let editor = cx.entity().downgrade(); - window.on_action(TypeId::of::(), { - let editor = editor.clone(); - let active_session = active_session.clone(); - move |_, phase, _, cx| { - if phase != DispatchPhase::Bubble { - return; - } - maybe!({ - let (buffer, position, _) = editor - .update(cx, |editor, cx| { - let cursor_point: language::Point = - editor.selections.newest(cx).head(); - editor - .buffer() - .read(cx) - .point_to_buffer_point(cursor_point, cx) - }) - .ok()??; + window.on_action_when( + session.any_stopped_thread(), + TypeId::of::(), + { + let editor = editor.clone(); + let active_session = active_session.clone(); + move |_, phase, _, cx| { + if phase != DispatchPhase::Bubble { + return; + } + maybe!({ + let (buffer, position, _) = editor + .update(cx, |editor, cx| { + let cursor_point: language::Point = + editor.selections.newest(cx).head(); - let path = + editor + .buffer() + .read(cx) + .point_to_buffer_point(cursor_point, cx) + }) + .ok()??; + + let path = debugger::breakpoint_store::BreakpointStore::abs_path_from_buffer( &buffer, cx, )?; - let source_breakpoint = SourceBreakpoint { - row: position.row, - path, - message: None, - condition: None, - hit_condition: None, - state: debugger::breakpoint_store::BreakpointState::Enabled, - }; + let source_breakpoint = SourceBreakpoint { + row: position.row, + path, + message: None, + condition: None, + hit_condition: None, + state: debugger::breakpoint_store::BreakpointState::Enabled, + }; - active_session.update(cx, |session, cx| { - session.running_state().update(cx, |state, cx| { - if let Some(thread_id) = state.selected_thread_id() { - state.session().update(cx, |session, cx| { - session.run_to_position( - source_breakpoint, - thread_id, - cx, - ); - }) - } + active_session.update(cx, |session, cx| { + session.running_state().update(cx, |state, cx| { + if let Some(thread_id) = state.selected_thread_id() { + state.session().update(cx, |session, cx| { + session.run_to_position( + source_breakpoint, + thread_id, + cx, + ); + }) + } + }); }); - }); - Some(()) - }); - } - }); + Some(()) + }); + } + }, + ); window.on_action( TypeId::of::(), diff --git a/crates/debugger_ui/src/new_process_modal.rs b/crates/debugger_ui/src/new_process_modal.rs index 42f77ab056..2695941bc0 100644 --- a/crates/debugger_ui/src/new_process_modal.rs +++ b/crates/debugger_ui/src/new_process_modal.rs @@ -1015,15 +1015,13 @@ impl DebugDelegate { let language_names = languages.language_names(); let language = dap_registry .adapter_language(&scenario.adapter) - .map(|language| TaskSourceKind::Language { - name: language.into(), - }); + .map(|language| TaskSourceKind::Language { name: language.0 }); let language = language.or_else(|| { scenario.label.split_whitespace().find_map(|word| { language_names .iter() - .find(|name| name.eq_ignore_ascii_case(word)) + .find(|name| name.as_ref().eq_ignore_ascii_case(word)) .map(|name| TaskSourceKind::Language { name: name.to_owned().into(), }) diff --git a/crates/debugger_ui/src/session/running/breakpoint_list.rs b/crates/debugger_ui/src/session/running/breakpoint_list.rs index 
6ac4b1c878..a6defbbf35 100644 --- a/crates/debugger_ui/src/session/running/breakpoint_list.rs +++ b/crates/debugger_ui/src/session/running/breakpoint_list.rs @@ -29,7 +29,6 @@ use ui::{ Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, Toggleable, Tooltip, Window, div, h_flex, px, v_flex, }; -use util::ResultExt; use workspace::Workspace; use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint}; @@ -56,8 +55,6 @@ pub(crate) struct BreakpointList { scrollbar_state: ScrollbarState, breakpoints: Vec, session: Option>, - hide_scrollbar_task: Option>, - show_scrollbar: bool, focus_handle: FocusHandle, scroll_handle: UniformListScrollHandle, selected_ix: Option, @@ -103,8 +100,6 @@ impl BreakpointList { worktree_store, scrollbar_state, breakpoints: Default::default(), - hide_scrollbar_task: None, - show_scrollbar: false, workspace, session, focus_handle, @@ -565,21 +560,6 @@ impl BreakpointList { Ok(()) } - fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context) { - const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); - self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| { - cx.background_executor() - .timer(SCROLLBAR_SHOW_INTERVAL) - .await; - panel - .update(cx, |panel, cx| { - panel.show_scrollbar = false; - cx.notify(); - }) - .log_err(); - })) - } - fn render_list(&mut self, cx: &mut Context) -> impl IntoElement { let selected_ix = self.selected_ix; let focus_handle = self.focus_handle.clone(); @@ -614,43 +594,39 @@ impl BreakpointList { .flex_grow() } - fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { - if !(self.show_scrollbar || self.scrollbar_state.is_dragging()) { - return None; - } - Some( - div() - .occlude() - .id("breakpoint-list-vertical-scrollbar") - .on_mouse_move(cx.listener(|_, _, _, cx| { - cx.notify(); - cx.stop_propagation() - })) - .on_hover(|_, _, cx| { + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .occlude() + .id("breakpoint-list-vertical-scrollbar") + .on_mouse_move(cx.listener(|_, _, _, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { cx.stop_propagation(); - }) - .on_any_mouse_down(|_, _, cx| { - cx.stop_propagation(); - }) - .on_mouse_up( - MouseButton::Left, - cx.listener(|_, _, _, cx| { - cx.stop_propagation(); - }), - ) - .on_scroll_wheel(cx.listener(|_, _, _, cx| { - cx.notify(); - })) - .h_full() - .absolute() - .right_1() - .top_1() - .bottom_0() - .w(px(12.)) - .cursor_default() - .children(Scrollbar::vertical(self.scrollbar_state.clone())), - ) + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) } + pub(crate) fn render_control_strip(&self) -> AnyElement { let selection_kind = self.selection_kind(); let focus_handle = self.focus_handle.clone(); @@ -819,15 +795,6 @@ impl Render for BreakpointList { .id("breakpoint-list") .key_context("BreakpointList") .track_focus(&self.focus_handle) - .on_hover(cx.listener(|this, hovered, window, cx| { - if *hovered { - this.show_scrollbar = true; - this.hide_scrollbar_task.take(); - cx.notify(); - } else if !this.focus_handle.contains_focused(window, cx) { - this.hide_scrollbar(window, cx); - } - })) .on_action(cx.listener(Self::select_next)) .on_action(cx.listener(Self::select_previous)) .on_action(cx.listener(Self::select_first)) @@ -844,7 +811,7 @@ impl Render for BreakpointList { v_flex() .size_full() .child(self.render_list(cx)) - .children(self.render_vertical_scrollbar(cx)), + .child(self.render_vertical_scrollbar(cx)), ) .when_some(self.strip_mode, |this, _| { this.child(Divider::horizontal()).child( diff --git a/crates/debugger_ui/src/session/running/loaded_source_list.rs b/crates/debugger_ui/src/session/running/loaded_source_list.rs index dd5487e042..6b376bb892 100644 --- a/crates/debugger_ui/src/session/running/loaded_source_list.rs +++ b/crates/debugger_ui/src/session/running/loaded_source_list.rs @@ -13,22 +13,8 @@ pub(crate) struct LoadedSourceList { impl LoadedSourceList { pub fn new(session: Entity, cx: &mut Context) -> Self { - let weak_entity = cx.weak_entity(); let focus_handle = cx.focus_handle(); - - let list = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, _window, cx| { - weak_entity - .upgrade() - .map(|loaded_sources| { - loaded_sources.update(cx, |this, cx| this.render_entry(ix, cx)) - }) - .unwrap_or(div().into_any()) - }, - ); + let list = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let _subscription = cx.subscribe(&session, |this, _, event, cx| match event { SessionEvent::Stopped(_) | SessionEvent::LoadedSources => { @@ -98,6 +84,12 @@ impl Render for LoadedSourceList { .track_focus(&self.focus_handle) .size_full() .p_1() - .child(list(self.list.clone()).size_full()) + .child( + list( + self.list.clone(), + cx.processor(|this, ix, _window, cx| this.render_entry(ix, cx)), + ) + .size_full(), + ) } } diff --git a/crates/debugger_ui/src/session/running/memory_view.rs b/crates/debugger_ui/src/session/running/memory_view.rs index 7b62a1d55d..75b8938371 100644 --- a/crates/debugger_ui/src/session/running/memory_view.rs +++ 
b/crates/debugger_ui/src/session/running/memory_view.rs @@ -23,7 +23,6 @@ use ui::{ ParentElement, Pixels, PopoverMenuHandle, Render, Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, TextSize, Tooltip, Window, div, h_flex, px, v_flex, }; -use util::ResultExt; use workspace::Workspace; use crate::{ToggleDataBreakpoint, session::running::stack_frame_list::StackFrameList}; @@ -34,9 +33,7 @@ pub(crate) struct MemoryView { workspace: WeakEntity, scroll_handle: UniformListScrollHandle, scroll_state: ScrollbarState, - show_scrollbar: bool, stack_frame_list: WeakEntity, - hide_scrollbar_task: Option>, focus_handle: FocusHandle, view_state: ViewState, query_editor: Entity, @@ -150,8 +147,6 @@ impl MemoryView { scroll_state, scroll_handle, stack_frame_list, - show_scrollbar: false, - hide_scrollbar_task: None, focus_handle: cx.focus_handle(), view_state, query_editor, @@ -168,61 +163,42 @@ impl MemoryView { .detach(); this } - fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context) { - const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); - self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| { - cx.background_executor() - .timer(SCROLLBAR_SHOW_INTERVAL) - .await; - panel - .update(cx, |panel, cx| { - panel.show_scrollbar = false; - cx.notify(); - }) - .log_err(); - })) - } - fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { - if !(self.show_scrollbar || self.scroll_state.is_dragging()) { - return None; - } - Some( - div() - .occlude() - .id("memory-view-vertical-scrollbar") - .on_drag_move(cx.listener(|this, evt, _, cx| { - let did_handle = this.handle_scroll_drag(evt); - cx.notify(); - if did_handle { - cx.stop_propagation() - } - })) - .on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty)) - .on_hover(|_, _, cx| { + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .occlude() + .id("memory-view-vertical-scrollbar") + .on_drag_move(cx.listener(|this, evt, _, cx| { + let did_handle = this.handle_scroll_drag(evt); + cx.notify(); + if did_handle { + cx.stop_propagation() + } + })) + .on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty)) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { cx.stop_propagation(); - }) - .on_any_mouse_down(|_, _, cx| { - cx.stop_propagation(); - }) - .on_mouse_up( - MouseButton::Left, - cx.listener(|_, _, _, cx| { - cx.stop_propagation(); - }), - ) - .on_scroll_wheel(cx.listener(|_, _, _, cx| { - cx.notify(); - })) - .h_full() - .absolute() - .right_1() - .top_1() - .bottom_0() - .w(px(12.)) - .cursor_default() - .children(Scrollbar::vertical(self.scroll_state.clone())), - ) + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scroll_state.clone()).map(|s| s.auto_hide(cx))) } fn render_memory(&self, cx: &mut Context) -> UniformList { @@ -920,15 +896,6 @@ impl Render for MemoryView { .on_action(cx.listener(Self::page_up)) .size_full() .track_focus(&self.focus_handle) - .on_hover(cx.listener(|this, hovered, window, cx| { - if *hovered { - this.show_scrollbar = true; - this.hide_scrollbar_task.take(); - cx.notify(); - } else if !this.focus_handle.contains_focused(window, cx) { - this.hide_scrollbar(window, cx); - } - })) .child( h_flex() .w_full() @@ -978,7 +945,7 @@ impl Render for MemoryView { ) .with_priority(1) })) - .children(self.render_vertical_scrollbar(cx)), + .child(self.render_vertical_scrollbar(cx)), ) } } diff --git a/crates/debugger_ui/src/session/running/stack_frame_list.rs b/crates/debugger_ui/src/session/running/stack_frame_list.rs index da3674c8e2..2149502f4a 100644 --- a/crates/debugger_ui/src/session/running/stack_frame_list.rs +++ b/crates/debugger_ui/src/session/running/stack_frame_list.rs @@ -70,13 +70,7 @@ impl StackFrameList { _ => {} }); - let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.), { - let this = cx.weak_entity(); - move |ix, _window, cx| { - this.update(cx, |this, cx| this.render_entry(ix, cx)) - .unwrap_or(div().into_any()) - } - }); + let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let scrollbar_state = ScrollbarState::new(list_state.clone()); let mut this = Self { @@ -708,11 +702,14 @@ impl StackFrameList { self.activate_selected_entry(window, cx); } - fn render_list(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - div() - .p_1() - .size_full() - .child(list(self.list_state.clone()).size_full()) + fn render_list(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + div().p_1().size_full().child( + list( + self.list_state.clone(), + cx.processor(|this, ix, _window, cx| this.render_entry(ix, cx)), + ) + .size_full(), + ) } } diff --git a/crates/debugger_ui/src/session/running/variable_list.rs b/crates/debugger_ui/src/session/running/variable_list.rs index 906e482687..efbc72e8cf 100644 --- a/crates/debugger_ui/src/session/running/variable_list.rs +++ b/crates/debugger_ui/src/session/running/variable_list.rs @@ -1107,7 +1107,7 @@ impl VariableList { let variable_value = value.clone(); this.on_click(cx.listener( move |this, click: &ClickEvent, window, cx| { - if click.down.click_count < 2 { + if 
click.click_count() < 2 { return; } let editor = Self::create_variable_editor( diff --git a/crates/debugger_ui/src/tests/debugger_panel.rs b/crates/debugger_ui/src/tests/debugger_panel.rs index 505df09cfb..6180831ea9 100644 --- a/crates/debugger_ui/src/tests/debugger_panel.rs +++ b/crates/debugger_ui/src/tests/debugger_panel.rs @@ -918,7 +918,7 @@ async fn test_debug_panel_item_thread_status_reset_on_failure( .unwrap(); let client = session.update(cx, |session, _| session.adapter_client().unwrap()); - const THREAD_ID_NUM: u64 = 1; + const THREAD_ID_NUM: i64 = 1; client.on_request::(move |_, _| { Ok(dap::ThreadsResponse { diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 1364aaf853..1bb84488e8 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -873,7 +873,7 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S editor.splice_inlays( &[], - vec![Inlay::inline_completion( + vec![Inlay::edit_prediction( post_inc(&mut next_inlay_id), snapshot.buffer_snapshot.anchor_before(position), format!("Test inlay {next_inlay_id}"), diff --git a/crates/docs_preprocessor/Cargo.toml b/crates/docs_preprocessor/Cargo.toml index a0df669abe..e46ceb18db 100644 --- a/crates/docs_preprocessor/Cargo.toml +++ b/crates/docs_preprocessor/Cargo.toml @@ -7,17 +7,19 @@ license = "GPL-3.0-or-later" [dependencies] anyhow.workspace = true -clap.workspace = true -mdbook = "0.4.40" +command_palette.workspace = true +gpui.workspace = true +# We are specifically pinning this version of mdbook, as later versions introduce issues with double-nested subdirectories. +# Ask @maxdeviant about this before bumping. +mdbook = "= 0.4.40" +regex.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true -regex.workspace = true util.workspace = true workspace-hack.workspace = true zed.workspace = true -gpui.workspace = true -command_palette.workspace = true +zlog.workspace = true [lints] workspace = true diff --git a/crates/docs_preprocessor/src/main.rs b/crates/docs_preprocessor/src/main.rs index 8eeeb6f0c5..1448f4cb52 100644 --- a/crates/docs_preprocessor/src/main.rs +++ b/crates/docs_preprocessor/src/main.rs @@ -1,14 +1,15 @@ -use anyhow::Result; -use clap::{Arg, ArgMatches, Command}; +use anyhow::{Context, Result}; use mdbook::BookItem; use mdbook::book::{Book, Chapter}; use mdbook::preprocess::CmdPreprocessor; use regex::Regex; use settings::KeymapFile; -use std::collections::HashSet; +use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; use std::io::{self, Read}; use std::process; use std::sync::LazyLock; +use util::paths::PathExt; static KEYMAP_MACOS: LazyLock = LazyLock::new(|| { load_keymap("keymaps/default-macos.json").expect("Failed to load MacOS keymap") @@ -20,60 +21,68 @@ static KEYMAP_LINUX: LazyLock = LazyLock::new(|| { static ALL_ACTIONS: LazyLock> = LazyLock::new(dump_all_gpui_actions); -pub fn make_app() -> Command { - Command::new("zed-docs-preprocessor") - .about("Preprocesses Zed Docs content to provide rich action & keybinding support and more") - .subcommand( - Command::new("supports") - .arg(Arg::new("renderer").required(true)) - .about("Check whether a renderer is supported by this preprocessor"), - ) -} +const FRONT_MATTER_COMMENT: &'static str = ""; fn main() -> Result<()> { - let matches = make_app().get_matches(); + zlog::init(); + zlog::init_output_stderr(); // call a zed:: function so everything in `zed` crate is 
linked and // all actions in the actual app are registered zed::stdout_is_a_pty(); + let args = std::env::args().skip(1).collect::>(); - if let Some(sub_args) = matches.subcommand_matches("supports") { - handle_supports(sub_args); - } else { - handle_preprocessing()?; + match args.get(0).map(String::as_str) { + Some("supports") => { + let renderer = args.get(1).expect("Required argument"); + let supported = renderer != "not-supported"; + if supported { + process::exit(0); + } else { + process::exit(1); + } + } + Some("postprocess") => handle_postprocessing()?, + _ => handle_preprocessing()?, } Ok(()) } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum Error { +enum PreprocessorError { ActionNotFound { action_name: String }, DeprecatedActionUsed { used: String, should_be: String }, + InvalidFrontmatterLine(String), } -impl Error { +impl PreprocessorError { fn new_for_not_found_action(action_name: String) -> Self { for action in &*ALL_ACTIONS { for alias in action.deprecated_aliases { if alias == &action_name { - return Error::DeprecatedActionUsed { + return PreprocessorError::DeprecatedActionUsed { used: action_name.clone(), should_be: action.name.to_string(), }; } } } - Error::ActionNotFound { + PreprocessorError::ActionNotFound { action_name: action_name.to_string(), } } } -impl std::fmt::Display for Error { +impl std::fmt::Display for PreprocessorError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Error::ActionNotFound { action_name } => write!(f, "Action not found: {}", action_name), - Error::DeprecatedActionUsed { used, should_be } => write!( + PreprocessorError::InvalidFrontmatterLine(line) => { + write!(f, "Invalid frontmatter line: {}", line) + } + PreprocessorError::ActionNotFound { action_name } => { + write!(f, "Action not found: {}", action_name) + } + PreprocessorError::DeprecatedActionUsed { used, should_be } => write!( f, "Deprecated action used: {} should be {}", used, should_be @@ -89,8 +98,9 @@ fn handle_preprocessing() -> Result<()> { let (_ctx, mut book) = CmdPreprocessor::parse_input(input.as_bytes())?; - let mut errors = HashSet::::new(); + let mut errors = HashSet::::new(); + handle_frontmatter(&mut book, &mut errors); template_and_validate_keybindings(&mut book, &mut errors); template_and_validate_actions(&mut book, &mut errors); @@ -108,19 +118,41 @@ fn handle_preprocessing() -> Result<()> { Ok(()) } -fn handle_supports(sub_args: &ArgMatches) -> ! 
{ - let renderer = sub_args - .get_one::("renderer") - .expect("Required argument"); - let supported = renderer != "not-supported"; - if supported { - process::exit(0); - } else { - process::exit(1); - } +fn handle_frontmatter(book: &mut Book, errors: &mut HashSet) { + let frontmatter_regex = Regex::new(r"(?s)^\s*---(.*?)---").unwrap(); + for_each_chapter_mut(book, |chapter| { + let new_content = frontmatter_regex.replace(&chapter.content, |caps: ®ex::Captures| { + let frontmatter = caps[1].trim(); + let frontmatter = frontmatter.trim_matches(&[' ', '-', '\n']); + let mut metadata = HashMap::::default(); + for line in frontmatter.lines() { + let Some((name, value)) = line.split_once(':') else { + errors.insert(PreprocessorError::InvalidFrontmatterLine(format!( + "{}: {}", + chapter_breadcrumbs(&chapter), + line + ))); + continue; + }; + let name = name.trim(); + let value = value.trim(); + metadata.insert(name.to_string(), value.to_string()); + } + FRONT_MATTER_COMMENT.replace( + "{}", + &serde_json::to_string(&metadata).expect("Failed to serialize metadata"), + ) + }); + match new_content { + Cow::Owned(content) => { + chapter.content = content; + } + Cow::Borrowed(_) => {} + } + }); } -fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { +fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { let regex = Regex::new(r"\{#kb (.*?)\}").unwrap(); for_each_chapter_mut(book, |chapter| { @@ -128,7 +160,9 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet) { +fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet) { let regex = Regex::new(r"\{#action (.*?)\}").unwrap(); for_each_chapter_mut(book, |chapter| { @@ -152,7 +186,9 @@ fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet) { .replace_all(&chapter.content, |caps: ®ex::Captures| { let name = caps[1].trim(); let Some(action) = find_action_by_name(name) else { - errors.insert(Error::new_for_not_found_action(name.to_string())); + errors.insert(PreprocessorError::new_for_not_found_action( + name.to_string(), + )); return String::new(); }; format!("{}", &action.human_name) @@ -217,6 +253,13 @@ fn name_for_action(action_as_str: String) -> String { .unwrap_or(action_as_str) } +fn chapter_breadcrumbs(chapter: &Chapter) -> String { + let mut breadcrumbs = Vec::with_capacity(chapter.parent_names.len() + 1); + breadcrumbs.extend(chapter.parent_names.iter().map(String::as_str)); + breadcrumbs.push(chapter.name.as_str()); + format!("[{:?}] {}", chapter.source_path, breadcrumbs.join(" > ")) +} + fn load_keymap(asset_path: &str) -> Result { let content = util::asset_str::(asset_path); KeymapFile::parse(content.as_ref()) @@ -254,3 +297,126 @@ fn dump_all_gpui_actions() -> Vec { return actions; } + +fn handle_postprocessing() -> Result<()> { + let logger = zlog::scoped!("render"); + let mut ctx = mdbook::renderer::RenderContext::from_json(io::stdin())?; + let output = ctx + .config + .get_mut("output") + .expect("has output") + .as_table_mut() + .expect("output is table"); + let zed_html = output.remove("zed-html").expect("zed-html output defined"); + let default_description = zed_html + .get("default-description") + .expect("Default description not found") + .as_str() + .expect("Default description not a string") + .to_string(); + let default_title = zed_html + .get("default-title") + .expect("Default title not found") + .as_str() + .expect("Default title not a string") + .to_string(); + + output.insert("html".to_string(), zed_html); + 
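// For reference, the frontmatter this pipeline consumes is a plain `key: value`
// block at the top of a chapter, e.g. (values are illustrative):
//
//   ---
//   title: Remote Development
//   description: Connect Zed to a remote machine over SSH.
//   ---
//
// `handle_frontmatter` folds those pairs into the FRONT_MATTER_COMMENT marker as
// JSON during preprocessing; this postprocessing pass reads them back out of the
// rendered HTML to fill in the page title and meta description, falling back to
// the defaults read above. A sketch of the per-line parse used on the frontmatter:
fn sketch_parse_frontmatter_line(line: &str) -> Option<(String, String)> {
    let (name, value) = line.split_once(':')?;
    Some((name.trim().to_string(), value.trim().to_string()))
}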
mdbook::Renderer::render(&mdbook::renderer::HtmlHandlebars::new(), &ctx)?; + let ignore_list = ["toc.html"]; + + let root_dir = ctx.destination.clone(); + let mut files = Vec::with_capacity(128); + let mut queue = Vec::with_capacity(64); + queue.push(root_dir.clone()); + while let Some(dir) = queue.pop() { + for entry in std::fs::read_dir(&dir).context(dir.to_sanitized_string())? { + let Ok(entry) = entry else { + continue; + }; + let file_type = entry.file_type().context("Failed to determine file type")?; + if file_type.is_dir() { + queue.push(entry.path()); + } + if file_type.is_file() + && matches!( + entry.path().extension().and_then(std::ffi::OsStr::to_str), + Some("html") + ) + { + if ignore_list.contains(&&*entry.file_name().to_string_lossy()) { + zlog::info!(logger => "Ignoring {}", entry.path().to_string_lossy()); + } else { + files.push(entry.path()); + } + } + } + } + + zlog::info!(logger => "Processing {} `.html` files", files.len()); + let meta_regex = Regex::new(&FRONT_MATTER_COMMENT.replace("{}", "(.*)")).unwrap(); + for file in files { + let contents = std::fs::read_to_string(&file)?; + let mut meta_description = None; + let mut meta_title = None; + let contents = meta_regex.replace(&contents, |caps: ®ex::Captures| { + let metadata: HashMap = serde_json::from_str(&caps[1]).with_context(|| format!("JSON Metadata: {:?}", &caps[1])).expect("Failed to deserialize metadata"); + for (kind, content) in metadata { + match kind.as_str() { + "description" => { + meta_description = Some(content); + } + "title" => { + meta_title = Some(content); + } + _ => { + zlog::warn!(logger => "Unrecognized frontmatter key: {} in {:?}", kind, pretty_path(&file, &root_dir)); + } + } + } + String::new() + }); + let meta_description = meta_description.as_ref().unwrap_or_else(|| { + zlog::warn!(logger => "No meta description found for {:?}", pretty_path(&file, &root_dir)); + &default_description + }); + let page_title = extract_title_from_page(&contents, pretty_path(&file, &root_dir)); + let meta_title = meta_title.as_ref().unwrap_or_else(|| { + zlog::debug!(logger => "No meta title found for {:?}", pretty_path(&file, &root_dir)); + &default_title + }); + let meta_title = format!("{} | {}", page_title, meta_title); + zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir)); + let contents = contents.replace("#description#", meta_description); + let contents = TITLE_REGEX + .replace(&contents, |_: ®ex::Captures| { + format!("{}", meta_title) + }) + .to_string(); + // let contents = contents.replace("#title#", &meta_title); + std::fs::write(file, contents)?; + } + return Ok(()); + + fn pretty_path<'a>( + path: &'a std::path::PathBuf, + root: &'a std::path::PathBuf, + ) -> &'a std::path::Path { + &path.strip_prefix(&root).unwrap_or(&path) + } + const TITLE_REGEX: std::cell::LazyCell = + std::cell::LazyCell::new(|| Regex::new(r"\s*(.*?)\s*").unwrap()); + fn extract_title_from_page(contents: &str, pretty_path: &std::path::Path) -> String { + let title_tag_contents = &TITLE_REGEX + .captures(&contents) + .with_context(|| format!("Failed to find title in {:?}", pretty_path)) + .expect("Page has element")[1]; + let title = title_tag_contents + .trim() + .strip_suffix("- Zed") + .unwrap_or(title_tag_contents) + .trim() + .to_string(); + title + } +} diff --git a/crates/inline_completion/Cargo.toml b/crates/edit_prediction/Cargo.toml similarity index 82% rename from crates/inline_completion/Cargo.toml rename to crates/edit_prediction/Cargo.toml index 3a90875def..81c1e5dec2 100644 --- 
a/crates/inline_completion/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "inline_completion" +name = "edit_prediction" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,7 +9,7 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/inline_completion.rs" +path = "src/edit_prediction.rs" [dependencies] client.workspace = true diff --git a/crates/edit_prediction/LICENSE-GPL b/crates/edit_prediction/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/edit_prediction/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/inline_completion/src/inline_completion.rs b/crates/edit_prediction/src/edit_prediction.rs similarity index 95% rename from crates/inline_completion/src/inline_completion.rs rename to crates/edit_prediction/src/edit_prediction.rs index c8f35bf16a..fd4e9bb21d 100644 --- a/crates/inline_completion/src/inline_completion.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -7,7 +7,7 @@ use project::Project; // TODO: Find a better home for `Direction`. // -// This should live in an ancestor crate of `editor` and `inline_completion`, +// This should live in an ancestor crate of `editor` and `edit_prediction`, // but at time of writing there isn't an obvious spot. #[derive(Copy, Clone, PartialEq, Eq)] pub enum Direction { @@ -16,7 +16,7 @@ pub enum Direction { } #[derive(Clone)] -pub struct InlineCompletion { +pub struct EditPrediction { /// The ID of the completion, if it has one. pub id: Option<SharedString>, pub edits: Vec<(Range<language::Anchor>, String)>, @@ -102,10 +102,10 @@ pub trait EditPredictionProvider: 'static + Sized { buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut Context<Self>, - ) -> Option<InlineCompletion>; + ) -> Option<EditPrediction>; } -pub trait InlineCompletionProviderHandle { +pub trait EditPredictionProviderHandle { fn name(&self) -> &'static str; fn display_name(&self) -> &'static str; fn is_enabled( @@ -143,10 +143,10 @@ pub trait InlineCompletionProviderHandle { buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut App, - ) -> Option<InlineCompletion>; + ) -> Option<EditPrediction>; } -impl<T> InlineCompletionProviderHandle for Entity<T> +impl<T> EditPredictionProviderHandle for Entity<T> where T: EditPredictionProvider, { @@ -233,7 +233,7 @@ where buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut App, - ) -> Option<InlineCompletion> { + ) -> Option<EditPrediction> { self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx)) } } diff --git a/crates/inline_completion_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml similarity index 86% rename from crates/inline_completion_button/Cargo.toml rename to crates/edit_prediction_button/Cargo.toml index c2a619d500..07447280fa 100644 --- a/crates/inline_completion_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "inline_completion_button" +name = "edit_prediction_button" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,21 +9,23 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/inline_completion_button.rs" +path = "src/edit_prediction_button.rs" doctest = false [dependencies] anyhow.workspace = true client.workspace = true +cloud_llm_client.workspace = true copilot.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true gpui.workspace = true indoc.workspace = 
true -inline_completion.workspace = true +edit_prediction.workspace = true language.workspace = true paths.workspace = true +project.workspace = true regex.workspace = true settings.workspace = true supermaven.workspace = true @@ -32,7 +34,6 @@ ui.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true zeta.workspace = true [dev-dependencies] diff --git a/crates/edit_prediction_button/LICENSE-GPL b/crates/edit_prediction_button/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/edit_prediction_button/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs similarity index 95% rename from crates/inline_completion_button/src/inline_completion_button.rs rename to crates/edit_prediction_button/src/edit_prediction_button.rs index 2615a8beef..9ab94a4095 100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -1,11 +1,8 @@ use anyhow::Result; -use client::{DisableAiSettings, UserStore, zed_urls}; +use client::{UserStore, zed_urls}; +use cloud_llm_client::UsageLimit; use copilot::{Copilot, Status}; -use editor::{ - Editor, SelectionEffects, - actions::{ShowEditPrediction, ToggleEditPrediction}, - scroll::Autoscroll, -}; +use editor::{Editor, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll}; use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag}; use fs::Fs; use gpui::{ @@ -18,6 +15,7 @@ use language::{ EditPredictionsMode, File, Language, language_settings::{self, AllLanguageSettings, EditPredictionProvider, all_language_settings}, }; +use project::DisableAiSettings; use regex::Regex; use settings::{Settings, SettingsStore, update_settings_file}; use std::{ @@ -34,13 +32,12 @@ use workspace::{ notifications::NotificationId, }; use zed_actions::OpenBrowser; -use zed_llm_client::UsageLimit; use zeta::RateCompletions; actions!( edit_prediction, [ - /// Toggles the inline completion menu. + /// Toggles the edit prediction menu. 
ToggleMenu ] ); @@ -50,14 +47,14 @@ const PRIVACY_DOCS: &str = "https://zed.dev/docs/ai/privacy-and-security"; struct CopilotErrorToast; -pub struct InlineCompletionButton { +pub struct EditPredictionButton { editor_subscription: Option<(Subscription, usize)>, editor_enabled: Option<bool>, editor_show_predictions: bool, editor_focus_handle: Option<FocusHandle>, language: Option<Arc<Language>>, file: Option<Arc<dyn File>>, - edit_prediction_provider: Option<Arc<dyn inline_completion::InlineCompletionProviderHandle>>, + edit_prediction_provider: Option<Arc<dyn edit_prediction::EditPredictionProviderHandle>>, fs: Arc<dyn Fs>, user_store: Entity<UserStore>, popover_menu_handle: PopoverMenuHandle<ContextMenu>, @@ -70,7 +67,7 @@ enum SupermavenButtonStatus { Initializing, } -impl Render for InlineCompletionButton { +impl Render for EditPredictionButton { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { // Return empty div if AI is disabled if DisableAiSettings::get_global(cx).disable_ai { @@ -246,12 +243,15 @@ impl Render for InlineCompletionButton { }; if zeta::should_show_upsell_modal(&self.user_store, cx) { - let tooltip_meta = - match self.user_store.read(cx).current_user_has_accepted_terms() { - Some(true) => "Choose a Plan", - Some(false) => "Accept the Terms of Service", - None => "Sign In", - }; + let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { + if self.user_store.read(cx).has_accepted_terms_of_service() { + "Choose a Plan" + } else { + "Accept the Terms of Service" + } + } else { + "Sign In" + }; return div().child( IconButton::new("zed-predict-pending-button", zeta_icon) @@ -365,7 +365,7 @@ impl Render for InlineCompletionButton { } } -impl InlineCompletionButton { +impl EditPredictionButton { pub fn new( fs: Arc<dyn Fs>, user_store: Entity<UserStore>, @@ -387,9 +387,9 @@ impl InlineCompletionButton { language: None, file: None, edit_prediction_provider: None, + user_store, popover_menu_handle, fs, - user_store, } } @@ -437,9 +437,13 @@ impl InlineCompletionButton { if let Some(editor_focus_handle) = self.editor_focus_handle.clone() { let entry = ContextMenuEntry::new("This Buffer") .toggleable(IconPosition::Start, self.editor_show_predictions) - .action(Box::new(ToggleEditPrediction)) + .action(Box::new(editor::actions::ToggleEditPrediction)) .handler(move |window, cx| { - editor_focus_handle.dispatch_action(&ToggleEditPrediction, window, cx); + editor_focus_handle.dispatch_action( + &editor::actions::ToggleEditPrediction, + window, + cx, + ); }); match language_state.clone() { @@ -466,7 +470,7 @@ impl InlineCompletionButton { IconPosition::Start, None, move |_, cx| { - toggle_show_inline_completions_for_language(language.clone(), fs.clone(), cx) + toggle_show_edit_predictions_for_language(language.clone(), fs.clone(), cx) }, ); } @@ -474,10 +478,13 @@ impl InlineCompletionButton { let settings = AllLanguageSettings::get_global(cx); let globally_enabled = settings.show_edit_predictions(None, cx); - menu = menu.toggleable_entry("All Files", globally_enabled, IconPosition::Start, None, { - let fs = fs.clone(); - move |_, cx| toggle_inline_completions_globally(fs.clone(), cx) - }); + let entry = ContextMenuEntry::new("All Files") + .toggleable(IconPosition::Start, globally_enabled) + .action(workspace::ToggleEditPrediction.boxed_clone()) + .handler(|window, cx| { + window.dispatch_action(workspace::ToggleEditPrediction.boxed_clone(), cx) + }); + menu = menu.item(entry); let provider = settings.edit_predictions.provider; let 
current_mode = settings.edit_predictions_mode(); @@ -831,7 +838,7 @@ impl InlineCompletionButton { } } -impl StatusItemView for InlineCompletionButton { +impl StatusItemView for EditPredictionButton { fn set_active_pane_item( &mut self, item: Option<&dyn ItemHandle>, @@ -901,7 +908,7 @@ async fn open_disabled_globs_setting_in_editor( let settings = cx.global::<SettingsStore>(); - // Ensure that we always have "inline_completions { "disabled_globs": [] }" + // Ensure that we always have "edit_predictions { "disabled_globs": [] }" let edits = settings.edits_for_update::<AllLanguageSettings>(&text, |file| { file.edit_predictions .get_or_insert_with(Default::default) @@ -939,13 +946,6 @@ async fn open_disabled_globs_setting_in_editor( anyhow::Ok(()) } -fn toggle_inline_completions_globally(fs: Arc<dyn Fs>, cx: &mut App) { - let show_edit_predictions = all_language_settings(None, cx).show_edit_predictions(None, cx); - update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { - file.defaults.show_edit_predictions = Some(!show_edit_predictions) - }); -} - fn set_completion_provider(fs: Arc<dyn Fs>, cx: &mut App, provider: EditPredictionProvider) { update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { file.features @@ -954,7 +954,7 @@ fn set_completion_provider(fs: Arc<dyn Fs>, cx: &mut App, provider: EditPredicti }); } -fn toggle_show_inline_completions_for_language( +fn toggle_show_edit_predictions_for_language( language: Arc<Language>, fs: Arc<dyn Fs>, cx: &mut App, diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 4d6939567e..339f98ae8b 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -22,6 +22,7 @@ test-support = [ "theme/test-support", "util/test-support", "workspace/test-support", + "tree-sitter-c", "tree-sitter-rust", "tree-sitter-typescript", "tree-sitter-html", @@ -47,7 +48,7 @@ fs.workspace = true git.workspace = true gpui.workspace = true indoc.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true itertools.workspace = true language.workspace = true linkify.workspace = true @@ -76,6 +77,7 @@ telemetry.workspace = true text.workspace = true time.workspace = true theme.workspace = true +tree-sitter-c = { workspace = true, optional = true } tree-sitter-html = { workspace = true, optional = true } tree-sitter-rust = { workspace = true, optional = true } tree-sitter-typescript = { workspace = true, optional = true } @@ -106,10 +108,12 @@ settings = { workspace = true, features = ["test-support"] } tempfile.workspace = true text = { workspace = true, features = ["test-support"] } theme = { workspace = true, features = ["test-support"] } +tree-sitter-c.workspace = true tree-sitter-html.workspace = true tree-sitter-rust.workspace = true tree-sitter-typescript.workspace = true tree-sitter-yaml.workspace = true +tree-sitter-bash.workspace = true unindent.workspace = true util = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index f80a6afbbb..3a3a57ca64 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -315,9 +315,8 @@ actions!( [ /// Accepts the full edit prediction. AcceptEditPrediction, - /// Accepts a partial Copilot suggestion. - AcceptPartialCopilotSuggestion, /// Accepts a partial edit prediction. 
+ #[action(deprecated_aliases = ["editor::AcceptPartialCopilotSuggestion"])] AcceptPartialEditPrediction, /// Adds a cursor above the current selection. AddSelectionAbove, @@ -365,6 +364,8 @@ actions!( ConvertToLowerCase, /// Toggles the case of selected text. ConvertToOppositeCase, + /// Converts selected text to sentence case. + ConvertToSentenceCase, /// Converts selected text to snake_case. ConvertToSnakeCase, /// Converts selected text to Title Case. diff --git a/crates/editor/src/clangd_ext.rs b/crates/editor/src/clangd_ext.rs index b745bf8c37..3239fdc653 100644 --- a/crates/editor/src/clangd_ext.rs +++ b/crates/editor/src/clangd_ext.rs @@ -29,16 +29,14 @@ pub fn switch_source_header( return; }; - let server_lookup = - find_specific_language_server_in_selection(editor, cx, is_c_language, CLANGD_SERVER_NAME); + let Some((_, _, server_to_query, buffer)) = + find_specific_language_server_in_selection(editor, cx, is_c_language, CLANGD_SERVER_NAME) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((_, _, server_to_query, buffer)) = - server_lookup.await - else { - return Ok(()); - }; let source_file = buffer.read_with(cx, |buffer, _| { buffer.file().map(|file| file.path()).map(|path| path.to_string_lossy().to_string()).unwrap_or_else(|| "Unknown".to_string()) })?; diff --git a/crates/editor/src/code_completion_tests.rs b/crates/editor/src/code_completion_tests.rs index 4f9822b597..fd8db29584 100644 --- a/crates/editor/src/code_completion_tests.rs +++ b/crates/editor/src/code_completion_tests.rs @@ -94,7 +94,7 @@ async fn test_fuzzy_score(cx: &mut TestAppContext) { filter_and_sort_matches("set_text", &completions, SnippetSortOrder::Top, cx).await; assert_eq!(matches[0].string, "set_text"); assert_eq!(matches[1].string, "set_text_style_refinement"); - assert_eq!(matches[2].string, "set_context_menu_options"); + assert_eq!(matches[2].string, "set_placeholder_text"); } // fuzzy filter text over label, sort_text and sort_kind @@ -216,6 +216,28 @@ async fn test_sort_positions(cx: &mut TestAppContext) { assert_eq!(matches[0].string, "rounded-full"); } +#[gpui::test] +async fn test_fuzzy_over_sort_positions(cx: &mut TestAppContext) { + let completions = vec![ + CompletionBuilder::variable("lsp_document_colors", None, "7fffffff"), // 0.29 fuzzy score + CompletionBuilder::function( + "language_servers_running_disk_based_diagnostics", + None, + "7fffffff", + ), // 0.168 fuzzy score + CompletionBuilder::function("code_lens", None, "7fffffff"), // 3.2 fuzzy score + CompletionBuilder::variable("lsp_code_lens", None, "7fffffff"), // 3.2 fuzzy score + CompletionBuilder::function("fetch_code_lens", None, "7fffffff"), // 3.2 fuzzy score + ]; + + let matches = + filter_and_sort_matches("lens", &completions, SnippetSortOrder::default(), cx).await; + + assert_eq!(matches[0].string, "code_lens"); + assert_eq!(matches[1].string, "lsp_code_lens"); + assert_eq!(matches[2].string, "fetch_code_lens"); +} + async fn test_for_each_prefix<F>( target: &str, completions: &Vec<Completion>, diff --git a/crates/editor/src/code_context_menus.rs b/crates/editor/src/code_context_menus.rs index 52446ceafc..4ae2a14ca7 100644 --- a/crates/editor/src/code_context_menus.rs +++ b/crates/editor/src/code_context_menus.rs @@ -844,7 +844,7 @@ impl CompletionsMenu { .with_sizing_behavior(ListSizingBehavior::Infer) .w(rems(34.)); - Popover::new().child(div().child(list)).into_any_element() + 
Popover::new().child(list).into_any_element() } fn render_aside( @@ -1057,9 +1057,9 @@ impl CompletionsMenu { enum MatchTier<'a> { WordStartMatch { sort_exact: Reverse<i32>, - sort_positions: Vec<usize>, sort_snippet: Reverse<i32>, sort_score: Reverse<OrderedFloat<f64>>, + sort_positions: Vec<usize>, sort_text: Option<&'a str>, sort_kind: usize, sort_label: &'a str, @@ -1137,9 +1137,9 @@ impl CompletionsMenu { MatchTier::WordStartMatch { sort_exact, - sort_positions, sort_snippet, sort_score, + sort_positions, sort_text, sort_kind, sort_label, diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index 5425d5a8b9..a16e516a70 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -635,7 +635,7 @@ pub(crate) struct Highlights<'a> { } #[derive(Clone, Copy, Debug)] -pub struct InlineCompletionStyles { +pub struct EditPredictionStyles { pub insertion: HighlightStyle, pub whitespace: HighlightStyle, } @@ -643,7 +643,7 @@ pub struct InlineCompletionStyles { #[derive(Default, Debug, Clone, Copy)] pub struct HighlightStyles { pub inlay_hint: Option<HighlightStyle>, - pub inline_completion: Option<InlineCompletionStyles>, + pub edit_prediction: Option<EditPredictionStyles>, } #[derive(Clone)] @@ -958,7 +958,7 @@ impl DisplaySnapshot { language_aware, HighlightStyles { inlay_hint: Some(editor_style.inlay_hints_style), - inline_completion: Some(editor_style.inline_completion_styles), + edit_prediction: Some(editor_style.edit_prediction_styles), }, ) .flat_map(|chunk| { @@ -2036,7 +2036,7 @@ pub mod tests { map.update(cx, |map, cx| { map.splice_inlays( &[], - vec![Inlay::inline_completion( + vec![Inlay::edit_prediction( 0, buffer_snapshot.anchor_after(0), "\n", diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index 85495a2611..e25c02432d 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -22,7 +22,7 @@ use std::{ atomic::{AtomicUsize, Ordering::SeqCst}, }, }; -use sum_tree::{Bias, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Dimensions, SumTree, Summary, TreeMap}; use text::{BufferId, Edit}; use ui::ElementId; @@ -416,7 +416,7 @@ struct TransformSummary { } pub struct BlockChunks<'a> { - transforms: sum_tree::Cursor<'a, Transform, (BlockRow, WrapRow)>, + transforms: sum_tree::Cursor<'a, Transform, Dimensions<BlockRow, WrapRow>>, input_chunks: wrap_map::WrapChunks<'a>, input_chunk: Chunk<'a>, output_row: u32, @@ -426,7 +426,7 @@ pub struct BlockChunks<'a> { #[derive(Clone)] pub struct BlockRows<'a> { - transforms: sum_tree::Cursor<'a, Transform, (BlockRow, WrapRow)>, + transforms: sum_tree::Cursor<'a, Transform, Dimensions<BlockRow, WrapRow>>, input_rows: wrap_map::WrapRows<'a>, output_row: BlockRow, started: bool, @@ -970,7 +970,7 @@ impl BlockMapReader<'_> { .unwrap_or(self.wrap_snapshot.max_point().row() + 1), ); - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); cursor.seek(&start_wrap_row, Bias::Left); while let Some(transform) = cursor.item() { if cursor.start().0 > end_wrap_row { @@ -1292,7 +1292,7 @@ impl BlockSnapshot { ) -> BlockChunks<'a> { let max_output_row = cmp::min(rows.end, self.transforms.summary().output_rows); - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&BlockRow(rows.start), Bias::Right); let 
transform_output_start = cursor.start().0.0; let transform_input_start = cursor.start().1.0; @@ -1324,9 +1324,9 @@ impl BlockSnapshot { } pub(super) fn row_infos(&self, start_row: BlockRow) -> BlockRows<'_> { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&start_row, Bias::Right); - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = if cursor .item() .map_or(false, |transform| transform.block.is_none()) @@ -1441,14 +1441,14 @@ impl BlockSnapshot { } pub fn longest_row_in_range(&self, range: Range<BlockRow>) -> BlockRow { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&range.start, Bias::Right); let mut longest_row = range.start; let mut longest_row_chars = 0; if let Some(transform) = cursor.item() { if transform.block.is_none() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = range.start.0 - output_start.0; let wrap_start_row = input_start.0 + overshoot; let wrap_end_row = cmp::min( @@ -1474,7 +1474,7 @@ impl BlockSnapshot { if let Some(transform) = cursor.item() { if transform.block.is_none() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = range.end.0 - output_start.0; let wrap_start_row = input_start.0; let wrap_end_row = input_start.0 + overshoot; @@ -1492,10 +1492,10 @@ impl BlockSnapshot { } pub(super) fn line_len(&self, row: BlockRow) -> u32 { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&BlockRow(row.0), Bias::Right); if let Some(transform) = cursor.item() { - let (output_start, input_start) = cursor.start(); + let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = row.0 - output_start.0; if transform.block.is_some() { 0 @@ -1510,13 +1510,13 @@ impl BlockSnapshot { } pub(super) fn is_block_line(&self, row: BlockRow) -> bool { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&row, Bias::Right); cursor.item().map_or(false, |t| t.block.is_some()) } pub(super) fn is_folded_buffer_header(&self, row: BlockRow) -> bool { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&row, Bias::Right); let Some(transform) = cursor.item() else { return false; @@ -1528,7 +1528,7 @@ impl BlockSnapshot { let wrap_point = self .wrap_snapshot .make_wrap_point(Point::new(row.0, 0), Bias::Left); - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); cursor.item().map_or(false, |transform| { transform @@ -1539,7 +1539,7 @@ impl BlockSnapshot { } pub fn clip_point(&self, point: BlockPoint, bias: Bias) -> BlockPoint { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&BlockRow(point.row), Bias::Right); let 
max_input_row = WrapRow(self.transforms.summary().input_rows); @@ -1549,8 +1549,8 @@ impl BlockSnapshot { loop { if let Some(transform) = cursor.item() { - let (output_start_row, input_start_row) = cursor.start(); - let (output_end_row, input_end_row) = cursor.end(); + let Dimensions(output_start_row, input_start_row, _) = cursor.start(); + let Dimensions(output_end_row, input_end_row, _) = cursor.end(); let output_start = Point::new(output_start_row.0, 0); let input_start = Point::new(input_start_row.0, 0); let input_end = Point::new(input_end_row.0, 0); @@ -1599,13 +1599,13 @@ impl BlockSnapshot { } pub fn to_block_point(&self, wrap_point: WrapPoint) -> BlockPoint { - let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<WrapRow, BlockRow>>(&()); cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); if let Some(transform) = cursor.item() { if transform.block.is_some() { BlockPoint::new(cursor.start().1.0, 0) } else { - let (input_start_row, output_start_row) = cursor.start(); + let Dimensions(input_start_row, output_start_row, _) = cursor.start(); let input_start = Point::new(input_start_row.0, 0); let output_start = Point::new(output_start_row.0, 0); let input_overshoot = wrap_point.0 - input_start; @@ -1617,7 +1617,7 @@ impl BlockSnapshot { } pub fn to_wrap_point(&self, block_point: BlockPoint, bias: Bias) -> WrapPoint { - let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(&()); cursor.seek(&BlockRow(block_point.row), Bias::Right); if let Some(transform) = cursor.item() { match transform.block.as_ref() { diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index 829d34ff58..c4e53a0f43 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -17,7 +17,7 @@ use std::{ sync::Arc, usize, }; -use sum_tree::{Bias, Cursor, FilterCursor, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Cursor, Dimensions, FilterCursor, SumTree, Summary, TreeMap}; use ui::IntoElement as _; use util::post_inc; @@ -98,7 +98,9 @@ impl FoldPoint { } pub fn to_inlay_point(self, snapshot: &FoldSnapshot) -> InlayPoint { - let mut cursor = snapshot.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().0.0; InlayPoint(cursor.start().1.0 + overshoot) @@ -107,7 +109,7 @@ impl FoldPoint { pub fn to_offset(self, snapshot: &FoldSnapshot) -> FoldOffset { let mut cursor = snapshot .transforms - .cursor::<(FoldPoint, TransformSummary)>(&()); + .cursor::<Dimensions<FoldPoint, TransformSummary>>(&()); cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().1.output.lines; let mut offset = cursor.start().1.output.len; @@ -567,8 +569,9 @@ impl FoldMap { let mut old_transforms = self .snapshot .transforms - .cursor::<(InlayOffset, FoldOffset)>(&()); - let mut new_transforms = new_transforms.cursor::<(InlayOffset, FoldOffset)>(&()); + .cursor::<Dimensions<InlayOffset, FoldOffset>>(&()); + let mut new_transforms = + new_transforms.cursor::<Dimensions<InlayOffset, FoldOffset>>(&()); for mut edit in inlay_edits { old_transforms.seek(&edit.old.start, Bias::Left); @@ -651,7 +654,9 @@ impl FoldSnapshot { pub fn text_summary_for_range(&self, range: Range<FoldPoint>) -> TextSummary { let mut summary = 
TextSummary::default(); - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); cursor.seek(&range.start, Bias::Right); if let Some(transform) = cursor.item() { let start_in_transform = range.start.0 - cursor.start().0.0; @@ -700,7 +705,9 @@ impl FoldSnapshot { } pub fn to_fold_point(&self, point: InlayPoint, bias: Bias) -> FoldPoint { - let mut cursor = self.transforms.cursor::<(InlayPoint, FoldPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayPoint, FoldPoint>>(&()); cursor.seek(&point, Bias::Right); if cursor.item().map_or(false, |t| t.is_fold()) { if bias == Bias::Left || point == cursor.start().0 { @@ -734,7 +741,9 @@ impl FoldSnapshot { } let fold_point = FoldPoint::new(start_row, 0); - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); cursor.seek(&fold_point, Bias::Left); let overshoot = fold_point.0 - cursor.start().0.0; @@ -816,7 +825,9 @@ impl FoldSnapshot { language_aware: bool, highlights: Highlights<'a>, ) -> FoldChunks<'a> { - let mut transform_cursor = self.transforms.cursor::<(FoldOffset, InlayOffset)>(&()); + let mut transform_cursor = self + .transforms + .cursor::<Dimensions<FoldOffset, InlayOffset>>(&()); transform_cursor.seek(&range.start, Bias::Right); let inlay_start = { @@ -871,7 +882,9 @@ impl FoldSnapshot { } pub fn clip_point(&self, point: FoldPoint, bias: Bias) -> FoldPoint { - let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<FoldPoint, InlayPoint>>(&()); cursor.seek(&point, Bias::Right); if let Some(transform) = cursor.item() { let transform_start = cursor.start().0.0; @@ -1196,7 +1209,7 @@ impl<'a> sum_tree::Dimension<'a, FoldSummary> for usize { #[derive(Clone)] pub struct FoldRows<'a> { - cursor: Cursor<'a, Transform, (FoldPoint, InlayPoint)>, + cursor: Cursor<'a, Transform, Dimensions<FoldPoint, InlayPoint>>, input_rows: InlayBufferRows<'a>, fold_point: FoldPoint, } @@ -1313,7 +1326,7 @@ impl DerefMut for ChunkRendererContext<'_, '_> { } pub struct FoldChunks<'a> { - transform_cursor: Cursor<'a, Transform, (FoldOffset, InlayOffset)>, + transform_cursor: Cursor<'a, Transform, Dimensions<FoldOffset, InlayOffset>>, inlay_chunks: InlayChunks<'a>, inlay_chunk: Option<(InlayOffset, InlayChunk<'a>)>, inlay_offset: InlayOffset, @@ -1448,7 +1461,7 @@ impl FoldOffset { pub fn to_point(self, snapshot: &FoldSnapshot) -> FoldPoint { let mut cursor = snapshot .transforms - .cursor::<(FoldOffset, TransformSummary)>(&()); + .cursor::<Dimensions<FoldOffset, TransformSummary>>(&()); cursor.seek(&self, Bias::Right); let overshoot = if cursor.item().map_or(true, |t| t.is_fold()) { Point::new(0, (self.0 - cursor.start().0.0) as u32) @@ -1462,7 +1475,9 @@ impl FoldOffset { #[cfg(test)] pub fn to_inlay_offset(self, snapshot: &FoldSnapshot) -> InlayOffset { - let mut cursor = snapshot.transforms.cursor::<(FoldOffset, InlayOffset)>(&()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<FoldOffset, InlayOffset>>(&()); cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().0.0; InlayOffset(cursor.start().1.0 + overshoot) diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index a36d18ff6d..fd49c262c6 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ 
b/crates/editor/src/display_map/inlay_map.rs @@ -10,7 +10,7 @@ use std::{ ops::{Add, AddAssign, Range, Sub, SubAssign}, sync::Arc, }; -use sum_tree::{Bias, Cursor, SumTree}; +use sum_tree::{Bias, Cursor, Dimensions, SumTree}; use text::{Patch, Rope}; use ui::{ActiveTheme, IntoElement as _, ParentElement as _, Styled as _, div}; @@ -81,9 +81,9 @@ impl Inlay { } } - pub fn inline_completion<T: Into<Rope>>(id: usize, position: Anchor, text: T) -> Self { + pub fn edit_prediction<T: Into<Rope>>(id: usize, position: Anchor, text: T) -> Self { Self { - id: InlayId::InlineCompletion(id), + id: InlayId::EditPrediction(id), position, text: text.into(), color: None, @@ -235,14 +235,14 @@ impl<'a> sum_tree::Dimension<'a, TransformSummary> for Point { #[derive(Clone)] pub struct InlayBufferRows<'a> { - transforms: Cursor<'a, Transform, (InlayPoint, Point)>, + transforms: Cursor<'a, Transform, Dimensions<InlayPoint, Point>>, buffer_rows: MultiBufferRows<'a>, inlay_row: u32, max_buffer_row: MultiBufferRow, } pub struct InlayChunks<'a> { - transforms: Cursor<'a, Transform, (InlayOffset, usize)>, + transforms: Cursor<'a, Transform, Dimensions<InlayOffset, usize>>, buffer_chunks: CustomHighlightsChunks<'a>, buffer_chunk: Option<Chunk<'a>>, inlay_chunks: Option<text::Chunks<'a>>, @@ -340,15 +340,13 @@ impl<'a> Iterator for InlayChunks<'a> { let mut renderer = None; let mut highlight_style = match inlay.id { - InlayId::InlineCompletion(_) => { - self.highlight_styles.inline_completion.map(|s| { - if inlay.text.chars().all(|c| c.is_whitespace()) { - s.whitespace - } else { - s.insertion - } - }) - } + InlayId::EditPrediction(_) => self.highlight_styles.edit_prediction.map(|s| { + if inlay.text.chars().all(|c| c.is_whitespace()) { + s.whitespace + } else { + s.insertion + } + }), InlayId::Hint(_) => self.highlight_styles.inlay_hint, InlayId::DebuggerValue(_) => self.highlight_styles.inlay_hint, InlayId::Color(_) => { @@ -553,7 +551,9 @@ impl InlayMap { } else { let mut inlay_edits = Patch::default(); let mut new_transforms = SumTree::default(); - let mut cursor = snapshot.transforms.cursor::<(usize, InlayOffset)>(&()); + let mut cursor = snapshot + .transforms + .cursor::<Dimensions<usize, InlayOffset>>(&()); let mut buffer_edits_iter = buffer_edits.iter().peekable(); while let Some(buffer_edit) = buffer_edits_iter.next() { new_transforms.append(cursor.slice(&buffer_edit.old.start, Bias::Left), &()); @@ -740,7 +740,7 @@ impl InlayMap { text.clone(), ) } else { - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(next_inlay_id), snapshot.buffer.anchor_at(position, bias), text.clone(), @@ -772,20 +772,20 @@ impl InlaySnapshot { pub fn to_point(&self, offset: InlayOffset) -> InlayPoint { let mut cursor = self .transforms - .cursor::<(InlayOffset, (InlayPoint, usize))>(&()); + .cursor::<Dimensions<InlayOffset, InlayPoint, usize>>(&()); cursor.seek(&offset, Bias::Right); let overshoot = offset.0 - cursor.start().0.0; match cursor.item() { Some(Transform::Isomorphic(_)) => { - let buffer_offset_start = cursor.start().1.1; + let buffer_offset_start = cursor.start().2; let buffer_offset_end = buffer_offset_start + overshoot; let buffer_start = self.buffer.offset_to_point(buffer_offset_start); let buffer_end = self.buffer.offset_to_point(buffer_offset_end); - InlayPoint(cursor.start().1.0.0 + (buffer_end - buffer_start)) + InlayPoint(cursor.start().1.0 + (buffer_end - buffer_start)) } Some(Transform::Inlay(inlay)) => { let overshoot = inlay.text.offset_to_point(overshoot); - InlayPoint(cursor.start().1.0.0 + 
overshoot) + InlayPoint(cursor.start().1.0 + overshoot) } None => self.max_point(), } @@ -802,26 +802,26 @@ impl InlaySnapshot { pub fn to_offset(&self, point: InlayPoint) -> InlayOffset { let mut cursor = self .transforms - .cursor::<(InlayPoint, (InlayOffset, Point))>(&()); + .cursor::<Dimensions<InlayPoint, InlayOffset, Point>>(&()); cursor.seek(&point, Bias::Right); let overshoot = point.0 - cursor.start().0.0; match cursor.item() { Some(Transform::Isomorphic(_)) => { - let buffer_point_start = cursor.start().1.1; + let buffer_point_start = cursor.start().2; let buffer_point_end = buffer_point_start + overshoot; let buffer_offset_start = self.buffer.point_to_offset(buffer_point_start); let buffer_offset_end = self.buffer.point_to_offset(buffer_point_end); - InlayOffset(cursor.start().1.0.0 + (buffer_offset_end - buffer_offset_start)) + InlayOffset(cursor.start().1.0 + (buffer_offset_end - buffer_offset_start)) } Some(Transform::Inlay(inlay)) => { let overshoot = inlay.text.point_to_offset(overshoot); - InlayOffset(cursor.start().1.0.0 + overshoot) + InlayOffset(cursor.start().1.0 + overshoot) } None => self.len(), } } pub fn to_buffer_point(&self, point: InlayPoint) -> Point { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); cursor.seek(&point, Bias::Right); match cursor.item() { Some(Transform::Isomorphic(_)) => { @@ -833,7 +833,9 @@ impl InlaySnapshot { } } pub fn to_buffer_offset(&self, offset: InlayOffset) -> usize { - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); cursor.seek(&offset, Bias::Right); match cursor.item() { Some(Transform::Isomorphic(_)) => { @@ -846,7 +848,9 @@ impl InlaySnapshot { } pub fn to_inlay_offset(&self, offset: usize) -> InlayOffset { - let mut cursor = self.transforms.cursor::<(usize, InlayOffset)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<usize, InlayOffset>>(&()); cursor.seek(&offset, Bias::Left); loop { match cursor.item() { @@ -879,7 +883,7 @@ impl InlaySnapshot { } } pub fn to_inlay_point(&self, point: Point) -> InlayPoint { - let mut cursor = self.transforms.cursor::<(Point, InlayPoint)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<Point, InlayPoint>>(&()); cursor.seek(&point, Bias::Left); loop { match cursor.item() { @@ -913,7 +917,7 @@ impl InlaySnapshot { } pub fn clip_point(&self, mut point: InlayPoint, mut bias: Bias) -> InlayPoint { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); cursor.seek(&point, Bias::Left); loop { match cursor.item() { @@ -1010,7 +1014,9 @@ impl InlaySnapshot { pub fn text_summary_for_range(&self, range: Range<InlayOffset>) -> TextSummary { let mut summary = TextSummary::default(); - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); cursor.seek(&range.start, Bias::Right); let overshoot = range.start.0 - cursor.start().0.0; @@ -1058,7 +1064,7 @@ impl InlaySnapshot { } pub fn row_infos(&self, row: u32) -> InlayBufferRows<'_> { - let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); + let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(&()); let inlay_point = InlayPoint::new(row, 0); cursor.seek(&inlay_point, Bias::Left); @@ -1100,7 
+1106,9 @@ impl InlaySnapshot { language_aware: bool, highlights: Highlights<'a>, ) -> InlayChunks<'a> { - let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<InlayOffset, usize>>(&()); cursor.seek(&range.start, Bias::Right); let buffer_range = self.to_buffer_offset(range.start)..self.to_buffer_offset(range.end); @@ -1389,7 +1397,7 @@ mod tests { buffer.read(cx).snapshot(cx).anchor_before(3), "|123|", ), - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut next_inlay_id), buffer.read(cx).snapshot(cx).anchor_after(3), "|456|", @@ -1609,7 +1617,7 @@ mod tests { buffer.read(cx).snapshot(cx).anchor_before(4), "|456|", ), - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut next_inlay_id), buffer.read(cx).snapshot(cx).anchor_before(7), "\n|567|\n", diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index d55577826e..269f8f0c40 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -9,7 +9,7 @@ use multi_buffer::{MultiBufferSnapshot, RowInfo}; use smol::future::yield_now; use std::sync::LazyLock; use std::{cmp, collections::VecDeque, mem, ops::Range, time::Duration}; -use sum_tree::{Bias, Cursor, SumTree}; +use sum_tree::{Bias, Cursor, Dimensions, SumTree}; use text::Patch; pub use super::tab_map::TextSummary; @@ -55,7 +55,7 @@ pub struct WrapChunks<'a> { input_chunk: Chunk<'a>, output_position: WrapPoint, max_output_row: u32, - transforms: Cursor<'a, Transform, (WrapPoint, TabPoint)>, + transforms: Cursor<'a, Transform, Dimensions<WrapPoint, TabPoint>>, snapshot: &'a WrapSnapshot, } @@ -66,7 +66,7 @@ pub struct WrapRows<'a> { output_row: u32, soft_wrapped: bool, max_output_row: u32, - transforms: Cursor<'a, Transform, (WrapPoint, TabPoint)>, + transforms: Cursor<'a, Transform, Dimensions<WrapPoint, TabPoint>>, } impl WrapRows<'_> { @@ -598,7 +598,9 @@ impl WrapSnapshot { ) -> WrapChunks<'a> { let output_start = WrapPoint::new(rows.start, 0); let output_end = WrapPoint::new(rows.end, 0); - let mut transforms = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut transforms = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); transforms.seek(&output_start, Bias::Right); let mut input_start = TabPoint(transforms.start().1.0); if transforms.item().map_or(false, |t| t.is_isomorphic()) { @@ -626,7 +628,9 @@ impl WrapSnapshot { } pub fn line_len(&self, row: u32) -> u32 { - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Left); if cursor .item() @@ -651,7 +655,9 @@ impl WrapSnapshot { let start = WrapPoint::new(rows.start, 0); let end = WrapPoint::new(rows.end, 0); - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&start, Bias::Right); if let Some(transform) = cursor.item() { let start_in_transform = start.0 - cursor.start().0.0; @@ -721,7 +727,9 @@ impl WrapSnapshot { } pub fn row_infos(&self, start_row: u32) -> WrapRows<'_> { - let mut transforms = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut transforms = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); transforms.seek(&WrapPoint::new(start_row, 0), Bias::Left); let mut input_row = 
transforms.start().1.row(); if transforms.item().map_or(false, |t| t.is_isomorphic()) { @@ -741,7 +749,9 @@ impl WrapSnapshot { } pub fn to_tab_point(&self, point: WrapPoint) -> TabPoint { - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&point, Bias::Right); let mut tab_point = cursor.start().1.0; if cursor.item().map_or(false, |t| t.is_isomorphic()) { @@ -759,7 +769,9 @@ impl WrapSnapshot { } pub fn tab_point_to_wrap_point(&self, point: TabPoint) -> WrapPoint { - let mut cursor = self.transforms.cursor::<(TabPoint, WrapPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<TabPoint, WrapPoint>>(&()); cursor.seek(&point, Bias::Right); WrapPoint(cursor.start().1.0 + (point.0 - cursor.start().0.0)) } @@ -784,7 +796,9 @@ impl WrapSnapshot { *point.column_mut() = 0; - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&point, Bias::Right); if cursor.item().is_none() { cursor.prev(); @@ -804,7 +818,9 @@ impl WrapSnapshot { pub fn next_row_boundary(&self, mut point: WrapPoint) -> Option<u32> { point.0 += Point::new(1, 0); - let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); + let mut cursor = self + .transforms + .cursor::<Dimensions<WrapPoint, TabPoint>>(&()); cursor.seek(&point, Bias::Right); while let Some(transform) = cursor.item() { if transform.is_isomorphic() && cursor.start().1.column() == 0 { diff --git a/crates/editor/src/inline_completion_tests.rs b/crates/editor/src/edit_prediction_tests.rs similarity index 81% rename from crates/editor/src/inline_completion_tests.rs rename to crates/editor/src/edit_prediction_tests.rs index 5ac34c94f5..527dfb8832 100644 --- a/crates/editor/src/inline_completion_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -1,26 +1,26 @@ +use edit_prediction::EditPredictionProvider; use gpui::{Entity, prelude::*}; use indoc::indoc; -use inline_completion::EditPredictionProvider; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; use project::Project; use std::ops::Range; use text::{Point, ToOffset}; use crate::{ - InlineCompletion, editor_tests::init_test, test::editor_test_context::EditorTestContext, + EditPrediction, editor_tests::init_test, test::editor_test_context::EditorTestContext, }; #[gpui::test] -async fn test_inline_completion_insert(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let absolute_zero_celsius = ˇ;"); propose_edits(&provider, vec![(28..28, "-273.15")], &mut cx); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_edit_completion(&mut cx, |_, edits| { assert_eq!(edits.len(), 1); @@ -33,16 +33,16 @@ async fn test_inline_completion_insert(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_modification(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) { init_test(cx, |_| 
{}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let pi = ˇ\"foo\";"); propose_edits(&provider, vec![(9..14, "3.14159")], &mut cx); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_edit_completion(&mut cx, |_, edits| { assert_eq!(edits.len(), 1); @@ -55,11 +55,11 @@ async fn test_inline_completion_modification(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_jump_button(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 2+ lines above the proposed edit @@ -77,7 +77,7 @@ async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), Point::new(4, 3)); }); @@ -107,7 +107,7 @@ async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), Point::new(1, 3)); }); @@ -124,11 +124,11 @@ async fn test_inline_completion_jump_button(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext) { +async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 3+ lines above the proposed edit @@ -148,7 +148,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext &mut cx, ); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), edit_location); }); @@ -176,7 +176,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext line "}); cx.editor(|editor, _, _| { - assert!(editor.active_inline_completion.is_none()); + assert!(editor.active_edit_prediction.is_none()); }); // Cursor is 3+ lines below the proposed edit @@ -196,7 +196,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext &mut cx, ); - cx.update_editor(|editor, window, cx| 
editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); assert_editor_active_move_completion(&mut cx, |snapshot, move_target| { assert_eq!(move_target.to_point(&snapshot), edit_location); }); @@ -224,7 +224,7 @@ async fn test_inline_completion_invalidation_range(cx: &mut gpui::TestAppContext line ˇ5 "}); cx.editor(|editor, _, _| { - assert!(editor.active_inline_completion.is_none()); + assert!(editor.active_edit_prediction.is_none()); }); } @@ -234,11 +234,11 @@ fn assert_editor_active_edit_completion( ) { cx.editor(|editor, _, cx| { let completion_state = editor - .active_inline_completion + .active_edit_prediction .as_ref() .expect("editor has no active completion"); - if let InlineCompletion::Edit { edits, .. } = &completion_state.completion { + if let EditPrediction::Edit { edits, .. } = &completion_state.completion { assert(editor.buffer().read(cx).snapshot(cx), edits); } else { panic!("expected edit completion"); @@ -252,11 +252,11 @@ fn assert_editor_active_move_completion( ) { cx.editor(|editor, _, cx| { let completion_state = editor - .active_inline_completion + .active_edit_prediction .as_ref() .expect("editor has no active completion"); - if let InlineCompletion::Move { target, .. } = &completion_state.completion { + if let EditPrediction::Move { target, .. } = &completion_state.completion { assert(editor.buffer().read(cx).snapshot(cx), *target); } else { panic!("expected move completion"); @@ -271,7 +271,7 @@ fn accept_completion(cx: &mut EditorTestContext) { } fn propose_edits<T: ToOffset>( - provider: &Entity<FakeInlineCompletionProvider>, + provider: &Entity<FakeEditPredictionProvider>, edits: Vec<(Range<T>, &str)>, cx: &mut EditorTestContext, ) { @@ -283,7 +283,7 @@ fn propose_edits<T: ToOffset>( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_inline_completion(Some(inline_completion::InlineCompletion { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction { id: None, edits: edits.collect(), edit_preview: None, @@ -293,7 +293,7 @@ fn propose_edits<T: ToOffset>( } fn assign_editor_completion_provider( - provider: Entity<FakeInlineCompletionProvider>, + provider: Entity<FakeEditPredictionProvider>, cx: &mut EditorTestContext, ) { cx.update_editor(|editor, window, cx| { @@ -302,20 +302,17 @@ fn assign_editor_completion_provider( } #[derive(Default, Clone)] -pub struct FakeInlineCompletionProvider { - pub completion: Option<inline_completion::InlineCompletion>, +pub struct FakeEditPredictionProvider { + pub completion: Option<edit_prediction::EditPrediction>, } -impl FakeInlineCompletionProvider { - pub fn set_inline_completion( - &mut self, - completion: Option<inline_completion::InlineCompletion>, - ) { +impl FakeEditPredictionProvider { + pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) { self.completion = completion; } } -impl EditPredictionProvider for FakeInlineCompletionProvider { +impl EditPredictionProvider for FakeEditPredictionProvider { fn name() -> &'static str { "fake-completion-provider" } @@ -355,7 +352,7 @@ impl EditPredictionProvider for FakeInlineCompletionProvider { &mut self, _buffer: gpui::Entity<language::Buffer>, _cursor_position: language::Anchor, - _direction: inline_completion::Direction, + _direction: edit_prediction::Direction, _cx: &mut gpui::Context<Self>, ) { } @@ -369,7 +366,7 @@ impl EditPredictionProvider for FakeInlineCompletionProvider { _buffer: 
&gpui::Entity<language::Buffer>, _cursor_position: language::Anchor, _cx: &mut gpui::Context<Self>, - ) -> Option<inline_completion::InlineCompletion> { + ) -> Option<edit_prediction::EditPrediction> { self.completion.clone() } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index a695c8fd0c..156fda1b37 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -43,50 +43,65 @@ pub mod tasks; #[cfg(test)] mod code_completion_tests; #[cfg(test)] -mod editor_tests; +mod edit_prediction_tests; #[cfg(test)] -mod inline_completion_tests; +mod editor_tests; mod signature_help; #[cfg(any(test, feature = "test-support"))] pub mod test; pub(crate) use actions::*; -pub use actions::{AcceptEditPrediction, OpenExcerpts, OpenExcerptsSplit}; +pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; +pub use edit_prediction::Direction; +pub use editor_settings::{ + CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, + ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap, ShowScrollbar, +}; +pub use editor_settings_controls::*; +pub use element::{ + CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition, +}; +pub use git::blame::BlameRenderer; +pub use hover_popover::hover_markdown_style; +pub use items::MAX_TAB_TITLE_LEN; +pub use lsp::CompletionContext; +pub use lsp_ext::lsp_tasks; +pub use multi_buffer::{ + Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, PathKey, + RowInfo, ToOffset, ToPoint, +}; +pub use proposed_changes_editor::{ + ProposedChangeLocation, ProposedChangesEditor, ProposedChangesEditorToolbar, +}; +pub use text::Bias; + +use ::git::{ + Restore, + blame::{BlameEntry, ParsedCommitMessage}, +}; use aho_corasick::AhoCorasick; use anyhow::{Context as _, Result, anyhow}; use blink_manager::BlinkManager; use buffer_diff::DiffHunkStatus; use client::{Collaborator, ParticipantIndex}; use clock::{AGENT_REPLICA_ID, ReplicaId}; +use code_context_menus::{ + AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, + CompletionsMenu, ContextMenuOrigin, +}; use collections::{BTreeMap, HashMap, HashSet, VecDeque}; use convert_case::{Case, Casing}; use dap::TelemetrySpawnLocation; use display_map::*; -pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; -pub use editor_settings::{ - CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode, - ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowScrollbar, -}; +use edit_prediction::{EditPredictionProvider, EditPredictionProviderHandle}; use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings}; -pub use editor_settings_controls::*; use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line}; -pub use element::{ - CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition, -}; use futures::{ FutureExt, StreamExt as _, future::{self, Shared, join}, stream::FuturesUnordered, }; use fuzzy::{StringMatch, StringMatchCandidate}; -use lsp_colors::LspColorData; - -use ::git::blame::BlameEntry; -use ::git::{Restore, blame::ParsedCommitMessage}; -use code_context_menus::{ - AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, - CompletionsMenu, ContextMenuOrigin, -}; use git::blame::{GitBlame, GlobalBlameRenderer}; use gpui::{ Action, Animation, AnimationExt, AnyElement, App, 
AppContext, AsyncWindowContext, @@ -100,32 +115,42 @@ use gpui::{ }; use highlight_matching_bracket::refresh_matching_bracket_highlights; use hover_links::{HoverLink, HoveredLinkState, InlayHighlight, find_file}; -pub use hover_popover::hover_markdown_style; use hover_popover::{HoverState, hide_hover}; use indent_guides::ActiveIndentGuidesState; use inlay_hint_cache::{InlayHintCache, InlaySplice, InvalidationStrategy}; -pub use inline_completion::Direction; -use inline_completion::{EditPredictionProvider, InlineCompletionProviderHandle}; -pub use items::MAX_TAB_TITLE_LEN; use itertools::Itertools; use language::{ - AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, Capability, CharKind, - CodeLabel, CursorShape, DiagnosticEntry, DiffOptions, EditPredictionsMode, EditPreview, - HighlightedText, IndentKind, IndentSize, Language, OffsetRangeExt, Point, Selection, - SelectionGoal, TextObject, TransactionId, TreeSitterOptions, WordsQuery, + AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, BufferRow, + BufferSnapshot, Capability, CharClassifier, CharKind, CodeLabel, CursorShape, DiagnosticEntry, + DiffOptions, EditPredictionsMode, EditPreview, HighlightedText, IndentKind, IndentSize, + Language, OffsetRangeExt, Point, Runnable, RunnableRange, Selection, SelectionGoal, TextObject, + TransactionId, TreeSitterOptions, WordsQuery, language_settings::{ self, InlayHintSettings, LspInsertMode, RewrapBehavior, WordsCompletionMode, all_language_settings, language_settings, }, - point_from_lsp, text_diff_with_options, + point_from_lsp, point_to_lsp, text_diff_with_options, }; -use language::{BufferRow, CharClassifier, Runnable, RunnableRange, point_to_lsp}; use linked_editing_ranges::refresh_linked_ranges; +use lsp::{ + CodeActionKind, CompletionItemKind, CompletionTriggerKind, InsertTextFormat, InsertTextMode, + LanguageServerId, +}; +use lsp_colors::LspColorData; use markdown::Markdown; use mouse_context_menu::MouseContextMenu; +use movement::TextLayoutDetails; +use multi_buffer::{ + ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow, + MultiOrSingleBufferOffsetRange, ToOffsetUtf16, +}; +use parking_lot::Mutex; use persistence::DB; use project::{ - BreakpointWithPosition, CompletionResponse, ProjectPath, + BreakpointWithPosition, CodeAction, Completion, CompletionIntent, CompletionResponse, + CompletionSource, DisableAiSettings, DocumentHighlight, InlayHint, Location, LocationLink, + PrepareRenameResponse, Project, ProjectItem, ProjectPath, ProjectTransaction, TaskSourceKind, + debugger::breakpoint_store::Breakpoint, debugger::{ breakpoint_store::{ BreakpointEditAction, BreakpointSessionState, BreakpointState, BreakpointStore, @@ -134,44 +159,12 @@ use project::{ session::{Session, SessionEvent}, }, git_store::{GitStoreEvent, RepositoryEvent}, - project_settings::{DiagnosticSeverity, GoToDiagnosticSeverityFilter}, -}; - -pub use git::blame::BlameRenderer; -pub use proposed_changes_editor::{ - ProposedChangeLocation, ProposedChangesEditor, ProposedChangesEditorToolbar, -}; -use std::{cell::OnceCell, iter::Peekable, ops::Not}; -use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; - -pub use lsp::CompletionContext; -use lsp::{ - CodeActionKind, CompletionItemKind, CompletionTriggerKind, InsertTextFormat, InsertTextMode, - LanguageServerId, LanguageServerName, -}; - -use language::BufferSnapshot; -pub use lsp_ext::lsp_tasks; -use movement::TextLayoutDetails; -pub use multi_buffer::{ - Anchor, AnchorRangeExt, 
ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, PathKey, - RowInfo, ToOffset, ToPoint, -}; -use multi_buffer::{ - ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow, - MultiOrSingleBufferOffsetRange, ToOffsetUtf16, -}; -use parking_lot::Mutex; -use project::{ - CodeAction, Completion, CompletionIntent, CompletionSource, DocumentHighlight, InlayHint, - Location, LocationLink, PrepareRenameResponse, Project, ProjectItem, ProjectTransaction, - TaskSourceKind, - debugger::breakpoint_store::Breakpoint, lsp_store::{CompletionDocumentation, FormatTrigger, LspFormatTarget, OpenLspBufferHandle}, + project_settings::{DiagnosticSeverity, GoToDiagnosticSeverityFilter}, project_settings::{GitGutterSetting, ProjectSettings}, }; -use rand::prelude::*; -use rpc::{ErrorExt, proto::*}; +use rand::{seq::SliceRandom, thread_rng}; +use rpc::{ErrorCode, ErrorExt, proto::PeerId}; use scroll::{Autoscroll, OngoingScroll, ScrollAnchor, ScrollManager, ScrollbarAutoHide}; use selections_collection::{ MutableSelectionsCollection, SelectionsCollection, resolve_selections, @@ -180,21 +173,24 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsLocation, SettingsStore, update_settings_file}; use smallvec::{SmallVec, smallvec}; use snippet::Snippet; -use std::sync::Arc; use std::{ any::TypeId, borrow::Cow, + cell::OnceCell, cell::RefCell, cmp::{self, Ordering, Reverse}, + iter::Peekable, mem, num::NonZeroU32, + ops::Not, ops::{ControlFlow, Deref, DerefMut, Range, RangeInclusive}, path::{Path, PathBuf}, rc::Rc, + sync::Arc, time::{Duration, Instant}, }; -pub use sum_tree::Bias; use sum_tree::TreeMap; +use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; use text::{BufferId, FromAnchor, OffsetUtf16, Rope}; use theme::{ ActiveTheme, PlayerColor, StatusColors, SyntaxTheme, Theme, ThemeSettings, @@ -213,14 +209,11 @@ use workspace::{ notifications::{DetachAndPromptErr, NotificationId, NotifyTaskExt}, searchable::SearchEvent, }; -use zed_actions; use crate::{ code_context_menus::CompletionsMenuSource, - hover_links::{find_url, find_url_from_range}, -}; -use crate::{ editor_settings::MultiCursorModifier, + hover_links::{find_url, find_url_from_range}, signature_help::{SignatureHelpHiddenBy, SignatureHelpState}, }; @@ -275,7 +268,7 @@ impl InlineValueCache { #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum InlayId { - InlineCompletion(usize), + EditPrediction(usize), DebuggerValue(usize), // LSP Hint(usize), @@ -285,7 +278,7 @@ pub enum InlayId { impl InlayId { fn id(&self) -> usize { match self { - Self::InlineCompletion(id) => *id, + Self::EditPrediction(id) => *id, Self::DebuggerValue(id) => *id, Self::Hint(id) => *id, Self::Color(id) => *id, @@ -554,7 +547,7 @@ pub struct EditorStyle { pub syntax: Arc<SyntaxTheme>, pub status: StatusColors, pub inlay_hints_style: HighlightStyle, - pub inline_completion_styles: InlineCompletionStyles, + pub edit_prediction_styles: EditPredictionStyles, pub unnecessary_code_fade: f32, pub show_underlines: bool, } @@ -573,7 +566,7 @@ impl Default for EditorStyle { // style and retrieve them directly from the theme. 
status: StatusColors::dark(), inlay_hints_style: HighlightStyle::default(), - inline_completion_styles: InlineCompletionStyles { + edit_prediction_styles: EditPredictionStyles { insertion: HighlightStyle::default(), whitespace: HighlightStyle::default(), }, @@ -595,8 +588,8 @@ pub fn make_inlay_hints_style(cx: &mut App) -> HighlightStyle { } } -pub fn make_suggestion_styles(cx: &mut App) -> InlineCompletionStyles { - InlineCompletionStyles { +pub fn make_suggestion_styles(cx: &mut App) -> EditPredictionStyles { + EditPredictionStyles { insertion: HighlightStyle { color: Some(cx.theme().status().predictive), ..HighlightStyle::default() @@ -616,7 +609,7 @@ pub(crate) enum EditDisplayMode { Inline, } -enum InlineCompletion { +enum EditPrediction { Edit { edits: Vec<(Range<Anchor>, String)>, edit_preview: Option<EditPreview>, @@ -629,9 +622,9 @@ enum InlineCompletion { }, } -struct InlineCompletionState { +struct EditPredictionState { inlay_ids: Vec<InlayId>, - completion: InlineCompletion, + completion: EditPrediction, completion_id: Option<SharedString>, invalidation_range: Range<Anchor>, } @@ -644,7 +637,7 @@ enum EditPredictionSettings { }, } -enum InlineCompletionHighlight {} +enum EditPredictionHighlight {} #[derive(Debug, Clone)] struct InlineDiagnostic { @@ -655,7 +648,7 @@ struct InlineDiagnostic { severity: lsp::DiagnosticSeverity, } -pub enum MenuInlineCompletionsPolicy { +pub enum MenuEditPredictionsPolicy { Never, ByProvider, } @@ -1094,15 +1087,15 @@ pub struct Editor { pending_mouse_down: Option<Rc<RefCell<Option<MouseDownEvent>>>>, gutter_hovered: bool, hovered_link_state: Option<HoveredLinkState>, - edit_prediction_provider: Option<RegisteredInlineCompletionProvider>, + edit_prediction_provider: Option<RegisteredEditPredictionProvider>, code_action_providers: Vec<Rc<dyn CodeActionProvider>>, - active_inline_completion: Option<InlineCompletionState>, + active_edit_prediction: Option<EditPredictionState>, /// Used to prevent flickering as the user types while the menu is open - stale_inline_completion_in_menu: Option<InlineCompletionState>, + stale_edit_prediction_in_menu: Option<EditPredictionState>, edit_prediction_settings: EditPredictionSettings, - inline_completions_hidden_for_vim_mode: bool, - show_inline_completions_override: Option<bool>, - menu_inline_completions_policy: MenuInlineCompletionsPolicy, + edit_predictions_hidden_for_vim_mode: bool, + show_edit_predictions_override: Option<bool>, + menu_edit_predictions_policy: MenuEditPredictionsPolicy, edit_prediction_preview: EditPredictionPreview, edit_prediction_indent_conflict: bool, edit_prediction_requires_modifier_in_indent_conflict: bool, @@ -1305,6 +1298,7 @@ impl Default for SelectionHistoryMode { /// /// Similarly, you might want to disable scrolling if you don't want the viewport to /// move. +#[derive(Clone)] pub struct SelectionEffects { nav_history: Option<bool>, completions: bool, @@ -1516,8 +1510,8 @@ pub struct RenameState { struct InvalidationStack<T>(Vec<T>); -struct RegisteredInlineCompletionProvider { - provider: Arc<dyn InlineCompletionProviderHandle>, +struct RegisteredEditPredictionProvider { + provider: Arc<dyn EditPredictionProviderHandle>, _subscription: Subscription, } @@ -1774,7 +1768,7 @@ impl Editor { ) -> Self { debug_assert!( display_map.is_none() || mode.is_minimap(), - "Providing a display map for a new editor is only intended for the minimap and might have unindended side effects otherwise!" 
+ "Providing a display map for a new editor is only intended for the minimap and might have unintended side effects otherwise!" ); let full_mode = mode.is_full(); @@ -1870,7 +1864,6 @@ impl Editor { editor.tasks_update_task = Some(editor.refresh_runnables(window, cx)); } - editor.update_lsp_data(true, None, window, cx); } project::Event::SnippetEdit(id, snippet_edits) => { if let Some(buffer) = editor.buffer.read(cx).buffer(*id) { @@ -1892,6 +1885,11 @@ impl Editor { } } } + project::Event::LanguageServerBufferRegistered { buffer_id, .. } => { + if editor.buffer().read(cx).buffer(*buffer_id).is_some() { + editor.update_lsp_data(false, Some(*buffer_id), window, cx); + } + } _ => {} }, )); @@ -2102,8 +2100,8 @@ impl Editor { pending_mouse_down: None, hovered_link_state: None, edit_prediction_provider: None, - active_inline_completion: None, - stale_inline_completion_in_menu: None, + active_edit_prediction: None, + stale_edit_prediction_in_menu: None, edit_prediction_preview: EditPredictionPreview::Inactive { released_too_fast: false, }, @@ -2122,9 +2120,9 @@ impl Editor { hovered_cursors: HashMap::default(), next_editor_action_id: EditorActionId::default(), editor_actions: Rc::default(), - inline_completions_hidden_for_vim_mode: false, - show_inline_completions_override: None, - menu_inline_completions_policy: MenuInlineCompletionsPolicy::ByProvider, + edit_predictions_hidden_for_vim_mode: false, + show_edit_predictions_override: None, + menu_edit_predictions_policy: MenuEditPredictionsPolicy::ByProvider, edit_prediction_settings: EditPredictionSettings::Disabled, edit_prediction_indent_conflict: false, edit_prediction_requires_modifier_in_indent_conflict: true, @@ -2356,7 +2354,7 @@ impl Editor { } pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext { - self.key_context_internal(self.has_active_inline_completion(), window, cx) + self.key_context_internal(self.has_active_edit_prediction(), window, cx) } fn key_context_internal( @@ -2723,17 +2721,16 @@ impl Editor { ) where T: EditPredictionProvider, { - self.edit_prediction_provider = - provider.map(|provider| RegisteredInlineCompletionProvider { - _subscription: cx.observe_in(&provider, window, |this, _, window, cx| { - if this.focus_handle.is_focused(window) { - this.update_visible_inline_completion(window, cx); - } - }), - provider: Arc::new(provider), - }); + self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionProvider { + _subscription: cx.observe_in(&provider, window, |this, _, window, cx| { + if this.focus_handle.is_focused(window) { + this.update_visible_edit_prediction(window, cx); + } + }), + provider: Arc::new(provider), + }); self.update_edit_prediction_settings(cx); - self.refresh_inline_completion(false, false, window, cx); + self.refresh_edit_prediction(false, false, window, cx); } pub fn placeholder_text(&self) -> Option<&str> { @@ -2804,24 +2801,24 @@ impl Editor { self.input_enabled = input_enabled; } - pub fn set_inline_completions_hidden_for_vim_mode( + pub fn set_edit_predictions_hidden_for_vim_mode( &mut self, hidden: bool, window: &mut Window, cx: &mut Context<Self>, ) { - if hidden != self.inline_completions_hidden_for_vim_mode { - self.inline_completions_hidden_for_vim_mode = hidden; + if hidden != self.edit_predictions_hidden_for_vim_mode { + self.edit_predictions_hidden_for_vim_mode = hidden; if hidden { - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); } else { - self.refresh_inline_completion(true, false, window, cx); 
+ self.refresh_edit_prediction(true, false, window, cx); } } } - pub fn set_menu_inline_completions_policy(&mut self, value: MenuInlineCompletionsPolicy) { - self.menu_inline_completions_policy = value; + pub fn set_menu_edit_predictions_policy(&mut self, value: MenuEditPredictionsPolicy) { + self.menu_edit_predictions_policy = value; } pub fn set_autoindent(&mut self, autoindent: bool) { @@ -2858,7 +2855,7 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.show_inline_completions_override.is_some() { + if self.show_edit_predictions_override.is_some() { self.set_show_edit_predictions(None, window, cx); } else { let show_edit_predictions = !self.edit_predictions_enabled(); @@ -2872,17 +2869,17 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - self.show_inline_completions_override = show_edit_predictions; + self.show_edit_predictions_override = show_edit_predictions; self.update_edit_prediction_settings(cx); if let Some(false) = show_edit_predictions { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); } else { - self.refresh_inline_completion(false, true, window, cx); + self.refresh_edit_prediction(false, true, window, cx); } } - fn inline_completions_disabled_in_scope( + fn edit_predictions_disabled_in_scope( &self, buffer: &Entity<Buffer>, buffer_position: language::Anchor, @@ -2944,10 +2941,12 @@ impl Editor { } } + let selection_anchors = self.selections.disjoint_anchors(); + if self.focus_handle.is_focused(window) && self.leader_id.is_none() { self.buffer.update(cx, |buffer, cx| { buffer.set_active_selections( - &self.selections.disjoint_anchors(), + &selection_anchors, self.selections.line_mode, self.cursor_shape, cx, @@ -2964,9 +2963,8 @@ impl Editor { self.select_next_state = None; self.select_prev_state = None; self.select_syntax_node_history.try_clear(); - self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), buffer); - self.snippet_stack - .invalidate(&self.selections.disjoint_anchors(), buffer); + self.invalidate_autoclose_regions(&selection_anchors, buffer); + self.snippet_stack.invalidate(&selection_anchors, buffer); self.take_rename(false, window, cx); let newest_selection = self.selections.newest_anchor(); @@ -3048,7 +3046,7 @@ impl Editor { self.refresh_document_highlights(cx); self.refresh_selected_text_highlights(false, window, cx); refresh_matching_bracket_highlights(self, window, cx); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); self.edit_prediction_requires_modifier_in_indent_conflict = true; linked_editing_ranges::refresh_linked_ranges(self, window, cx); self.inline_blame_popover.take(); @@ -3838,7 +3836,7 @@ impl Editor { return true; } - if is_user_requested && self.discard_inline_completion(true, cx) { + if is_user_requested && self.discard_edit_prediction(true, cx) { return true; } @@ -4047,7 +4045,8 @@ impl Editor { // then don't insert that closing bracket again; just move the selection // past the closing bracket. 
let should_skip = selection.end == region.range.end.to_point(&snapshot) - && text.as_ref() == region.pair.end.as_str(); + && text.as_ref() == region.pair.end.as_str() + && snapshot.contains_str_at(region.range.end, text.as_ref()); if should_skip { let anchor = snapshot.anchor_after(selection.end); new_selections @@ -4243,7 +4242,7 @@ impl Editor { ); } - let had_active_inline_completion = this.has_active_inline_completion(); + let had_active_edit_prediction = this.has_active_edit_prediction(); this.change_selections( SelectionEffects::scroll(Autoscroll::fit()).completions(false), window, @@ -4268,7 +4267,7 @@ impl Editor { } let trigger_in_words = - this.show_edit_predictions_in_menu() || !had_active_inline_completion; + this.show_edit_predictions_in_menu() || !had_active_edit_prediction; if this.hard_wrap.is_some() { let latest: Range<Point> = this.selections.newest(cx).range(); if latest.is_empty() @@ -4290,7 +4289,7 @@ impl Editor { } this.trigger_completion_on_input(&text, trigger_in_words, window, cx); linked_editing_ranges::refresh_linked_ranges(this, window, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); jsx_tag_auto_close::handle_from(this, initial_buffer_versions, window, cx); }); } @@ -4625,7 +4624,7 @@ impl Editor { .collect(); this.change_selections(Default::default(), window, cx, |s| s.select(new_selections)); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -4973,13 +4972,17 @@ impl Editor { }) } - /// Remove any autoclose regions that no longer contain their selection. + /// Remove any autoclose regions that no longer contain their selection or have invalid anchors in ranges. fn invalidate_autoclose_regions( &mut self, mut selections: &[Selection<Anchor>], buffer: &MultiBufferSnapshot, ) { self.autoclose_regions.retain(|state| { + if !state.range.start.is_valid(buffer) || !state.range.end.is_valid(buffer) { + return false; + } + let mut i = 0; while let Some(selection) = selections.get(i) { if selection.end.cmp(&state.range.start, buffer).is_lt() { @@ -5669,9 +5672,9 @@ impl Editor { crate::hover_popover::hide_hover(editor, cx); if editor.show_edit_predictions_in_menu() { - editor.update_visible_inline_completion(window, cx); + editor.update_visible_edit_prediction(window, cx); } else { - editor.discard_inline_completion(false, cx); + editor.discard_edit_prediction(false, cx); } cx.notify(); @@ -5682,10 +5685,10 @@ impl Editor { if editor.completion_tasks.len() <= 1 { // If there are no more completion tasks and the last menu was empty, we should hide it. let was_hidden = editor.hide_context_menu(window, cx).is_none(); - // If it was already hidden and we don't show inline completions in the menu, we should - // also show the inline-completion when available. + // If it was already hidden and we don't show edit predictions in the menu, + // we should also show the edit prediction when available. 
if was_hidden && editor.show_edit_predictions_in_menu() { - editor.update_visible_inline_completion(window, cx); + editor.update_visible_edit_prediction(window, cx); } } }) @@ -5779,7 +5782,7 @@ impl Editor { let entries = completions_menu.entries.borrow(); let mat = entries.get(item_ix.unwrap_or(completions_menu.selected_item))?; if self.show_edit_predictions_in_menu() { - self.discard_inline_completion(true, cx); + self.discard_edit_prediction(true, cx); } mat.candidate_id }; @@ -5891,18 +5894,20 @@ impl Editor { text: new_text[common_prefix_len..].into(), }); - self.transact(window, cx, |this, window, cx| { + self.transact(window, cx, |editor, window, cx| { if let Some(mut snippet) = snippet { snippet.text = new_text.to_string(); - this.insert_snippet(&ranges, snippet, window, cx).log_err(); + editor + .insert_snippet(&ranges, snippet, window, cx) + .log_err(); } else { - this.buffer.update(cx, |buffer, cx| { + editor.buffer.update(cx, |multi_buffer, cx| { let auto_indent = match completion.insert_text_mode { Some(InsertTextMode::AS_IS) => None, - _ => this.autoindent_mode.clone(), + _ => editor.autoindent_mode.clone(), }; let edits = ranges.into_iter().map(|range| (range, new_text.as_str())); - buffer.edit(edits, auto_indent, cx); + multi_buffer.edit(edits, auto_indent, cx); }); } for (buffer, edits) in linked_edits { @@ -5921,8 +5926,9 @@ impl Editor { }) } - this.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); + self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), &snapshot); let show_new_completions_on_confirm = completion .confirm @@ -5980,7 +5986,7 @@ impl Editor { let deployed_from = action.deployed_from.clone(); let action = action.clone(); self.completion_tasks.clear(); - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); let multibuffer_point = match &action.deployed_from { Some(CodeActionSource::Indicator(row)) | Some(CodeActionSource::RunMenu(row)) => { @@ -6400,7 +6406,6 @@ impl Editor { IconButton::new("inline_code_actions", ui::IconName::BoltFilled) .icon_size(icon_size) .shape(ui::IconButtonShape::Square) - .style(ButtonStyle::Transparent) .icon_color(ui::Color::Hidden) .toggle_state(is_active) .when(show_tooltip, |this| { @@ -6986,20 +6991,24 @@ impl Editor { } } - pub fn refresh_inline_completion( + pub fn refresh_edit_prediction( &mut self, debounce: bool, user_requested: bool, window: &mut Window, cx: &mut Context<Self>, ) -> Option<()> { + if DisableAiSettings::get_global(cx).disable_ai { + return None; + } + let provider = self.edit_prediction_provider()?; let cursor = self.selections.newest_anchor().head(); let (buffer, cursor_buffer_position) = self.buffer.read(cx).text_anchor_for_position(cursor, cx)?; if !self.edit_predictions_enabled_in_buffer(&buffer, cursor_buffer_position, cx) { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } @@ -7008,11 +7017,11 @@ impl Editor { || !self.is_focused(window) || buffer.read(cx).is_empty()) { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); provider.refresh( self.project.clone(), buffer, @@ -7048,8 +7057,9 @@ impl Editor { } pub fn update_edit_prediction_settings(&mut self, cx: &mut Context<Self>) { - if self.edit_prediction_provider.is_none() { + if self.edit_prediction_provider.is_none() || 
DisableAiSettings::get_global(cx).disable_ai { self.edit_prediction_settings = EditPredictionSettings::Disabled; + self.discard_edit_prediction(false, cx); } else { let selection = self.selections.newest_anchor(); let cursor = selection.head(); @@ -7070,8 +7080,8 @@ impl Editor { cx: &App, ) -> EditPredictionSettings { if !self.mode.is_full() - || !self.show_inline_completions_override.unwrap_or(true) - || self.inline_completions_disabled_in_scope(buffer, buffer_position, cx) + || !self.show_edit_predictions_override.unwrap_or(true) + || self.edit_predictions_disabled_in_scope(buffer, buffer_position, cx) { return EditPredictionSettings::Disabled; } @@ -7085,8 +7095,8 @@ impl Editor { }; let by_provider = matches!( - self.menu_inline_completions_policy, - MenuInlineCompletionsPolicy::ByProvider + self.menu_edit_predictions_policy, + MenuEditPredictionsPolicy::ByProvider ); let show_in_menu = by_provider @@ -7156,7 +7166,7 @@ impl Editor { .unwrap_or(false) } - fn cycle_inline_completion( + fn cycle_edit_prediction( &mut self, direction: Direction, window: &mut Window, @@ -7166,28 +7176,28 @@ impl Editor { let cursor = self.selections.newest_anchor().head(); let (buffer, cursor_buffer_position) = self.buffer.read(cx).text_anchor_for_position(cursor, cx)?; - if self.inline_completions_hidden_for_vim_mode || !self.should_show_edit_predictions() { + if self.edit_predictions_hidden_for_vim_mode || !self.should_show_edit_predictions() { return None; } provider.cycle(buffer, cursor_buffer_position, direction, cx); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); Some(()) } - pub fn show_inline_completion( + pub fn show_edit_prediction( &mut self, _: &ShowEditPrediction, window: &mut Window, cx: &mut Context<Self>, ) { - if !self.has_active_inline_completion() { - self.refresh_inline_completion(false, true, window, cx); + if !self.has_active_edit_prediction() { + self.refresh_edit_prediction(false, true, window, cx); return; } - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); } pub fn display_cursor_names( @@ -7219,11 +7229,11 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.has_active_inline_completion() { - self.cycle_inline_completion(Direction::Next, window, cx); + if self.has_active_edit_prediction() { + self.cycle_edit_prediction(Direction::Next, window, cx); } else { let is_copilot_disabled = self - .refresh_inline_completion(false, true, window, cx) + .refresh_edit_prediction(false, true, window, cx) .is_none(); if is_copilot_disabled { cx.propagate(); @@ -7237,11 +7247,11 @@ impl Editor { window: &mut Window, cx: &mut Context<Self>, ) { - if self.has_active_inline_completion() { - self.cycle_inline_completion(Direction::Prev, window, cx); + if self.has_active_edit_prediction() { + self.cycle_edit_prediction(Direction::Prev, window, cx); } else { let is_copilot_disabled = self - .refresh_inline_completion(false, true, window, cx) + .refresh_edit_prediction(false, true, window, cx) .is_none(); if is_copilot_disabled { cx.propagate(); @@ -7259,18 +7269,14 @@ impl Editor { self.hide_context_menu(window, cx); } - let Some(active_inline_completion) = self.active_inline_completion.as_ref() else { + let Some(active_edit_prediction) = self.active_edit_prediction.as_ref() else { return; }; - self.report_inline_completion_event( - active_inline_completion.completion_id.clone(), - true, - cx, - ); + 
self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx); - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target = *target; if let Some(position_map) = &self.last_position_map { @@ -7312,7 +7318,7 @@ impl Editor { } } } - InlineCompletion::Edit { edits, .. } => { + EditPrediction::Edit { edits, .. } => { if let Some(provider) = self.edit_prediction_provider() { provider.accept(cx); } @@ -7340,9 +7346,9 @@ impl Editor { } } - self.update_visible_inline_completion(window, cx); - if self.active_inline_completion.is_none() { - self.refresh_inline_completion(true, true, window, cx); + self.update_visible_edit_prediction(window, cx); + if self.active_edit_prediction.is_none() { + self.refresh_edit_prediction(true, true, window, cx); } cx.notify(); @@ -7352,27 +7358,23 @@ impl Editor { self.edit_prediction_requires_modifier_in_indent_conflict = false; } - pub fn accept_partial_inline_completion( + pub fn accept_partial_edit_prediction( &mut self, _: &AcceptPartialEditPrediction, window: &mut Window, cx: &mut Context<Self>, ) { - let Some(active_inline_completion) = self.active_inline_completion.as_ref() else { + let Some(active_edit_prediction) = self.active_edit_prediction.as_ref() else { return; }; if self.selections.count() != 1 { return; } - self.report_inline_completion_event( - active_inline_completion.completion_id.clone(), - true, - cx, - ); + self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx); - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target = *target; self.change_selections( SelectionEffects::scroll(Autoscroll::newest()), @@ -7383,7 +7385,7 @@ impl Editor { }, ); } - InlineCompletion::Edit { edits, .. } => { + EditPrediction::Edit { edits, .. } => { // Find an insertion that starts at the cursor position. 
let snapshot = self.buffer.read(cx).snapshot(cx); let cursor_offset = self.selections.newest::<usize>(cx).head(); @@ -7417,7 +7419,7 @@ impl Editor { self.insert_with_autoindent_mode(&partial_completion, None, window, cx); - self.refresh_inline_completion(true, true, window, cx); + self.refresh_edit_prediction(true, true, window, cx); cx.notify(); } else { self.accept_edit_prediction(&Default::default(), window, cx); @@ -7426,28 +7428,28 @@ impl Editor { } } - fn discard_inline_completion( + fn discard_edit_prediction( &mut self, - should_report_inline_completion_event: bool, + should_report_edit_prediction_event: bool, cx: &mut Context<Self>, ) -> bool { - if should_report_inline_completion_event { + if should_report_edit_prediction_event { let completion_id = self - .active_inline_completion + .active_edit_prediction .as_ref() .and_then(|active_completion| active_completion.completion_id.clone()); - self.report_inline_completion_event(completion_id, false, cx); + self.report_edit_prediction_event(completion_id, false, cx); } if let Some(provider) = self.edit_prediction_provider() { provider.discard(cx); } - self.take_active_inline_completion(cx) + self.take_active_edit_prediction(cx) } - fn report_inline_completion_event(&self, id: Option<SharedString>, accepted: bool, cx: &App) { + fn report_edit_prediction_event(&self, id: Option<SharedString>, accepted: bool, cx: &App) { let Some(provider) = self.edit_prediction_provider() else { return; }; @@ -7478,18 +7480,18 @@ impl Editor { ); } - pub fn has_active_inline_completion(&self) -> bool { - self.active_inline_completion.is_some() + pub fn has_active_edit_prediction(&self) -> bool { + self.active_edit_prediction.is_some() } - fn take_active_inline_completion(&mut self, cx: &mut Context<Self>) -> bool { - let Some(active_inline_completion) = self.active_inline_completion.take() else { + fn take_active_edit_prediction(&mut self, cx: &mut Context<Self>) -> bool { + let Some(active_edit_prediction) = self.active_edit_prediction.take() else { return false; }; - self.splice_inlays(&active_inline_completion.inlay_ids, Default::default(), cx); - self.clear_highlights::<InlineCompletionHighlight>(cx); - self.stale_inline_completion_in_menu = Some(active_inline_completion); + self.splice_inlays(&active_edit_prediction.inlay_ids, Default::default(), cx); + self.clear_highlights::<EditPredictionHighlight>(cx); + self.stale_edit_prediction_in_menu = Some(active_edit_prediction); true } @@ -7634,7 +7636,7 @@ impl Editor { since: Instant::now(), }; - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); cx.notify(); } } else if let EditPredictionPreview::Active { @@ -7657,16 +7659,20 @@ impl Editor { released_too_fast: since.elapsed() < Duration::from_millis(200), }; self.clear_row_highlights::<EditPredictionPreview>(); - self.update_visible_inline_completion(window, cx); + self.update_visible_edit_prediction(window, cx); cx.notify(); } } - fn update_visible_inline_completion( + fn update_visible_edit_prediction( &mut self, _window: &mut Window, cx: &mut Context<Self>, ) -> Option<()> { + if DisableAiSettings::get_global(cx).disable_ai { + return None; + } + let selection = self.selections.newest_anchor(); let cursor = selection.head(); let multibuffer = self.buffer.read(cx).snapshot(cx); @@ -7676,12 +7682,12 @@ impl Editor { let show_in_menu = self.show_edit_predictions_in_menu(); let completions_menu_has_precedence = !show_in_menu && (self.context_menu.borrow().is_some() - || 
(!self.completion_tasks.is_empty() && !self.has_active_inline_completion())); + || (!self.completion_tasks.is_empty() && !self.has_active_edit_prediction())); if completions_menu_has_precedence || !offset_selection.is_empty() || self - .active_inline_completion + .active_edit_prediction .as_ref() .map_or(false, |completion| { let invalidation_range = completion.invalidation_range.to_offset(&multibuffer); @@ -7689,11 +7695,11 @@ impl Editor { !invalidation_range.contains(&offset_selection.head()) }) { - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); return None; } - self.take_active_inline_completion(cx); + self.take_active_edit_prediction(cx); let Some(provider) = self.edit_prediction_provider() else { self.edit_prediction_settings = EditPredictionSettings::Disabled; return None; @@ -7719,8 +7725,8 @@ impl Editor { } } - let inline_completion = provider.suggest(&buffer, cursor_buffer_position, cx)?; - let edits = inline_completion + let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?; + let edits = edit_prediction .edits .into_iter() .flat_map(|(range, new_text)| { @@ -7755,15 +7761,15 @@ impl Editor { None }; let is_move = - move_invalidation_row_range.is_some() || self.inline_completions_hidden_for_vim_mode; + move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode; let completion = if is_move { invalidation_row_range = move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row); let target = first_edit_start; - InlineCompletion::Move { target, snapshot } + EditPrediction::Move { target, snapshot } } else { let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true) - && !self.inline_completions_hidden_for_vim_mode; + && !self.edit_predictions_hidden_for_vim_mode; if show_completions_in_buffer { if edits @@ -7772,7 +7778,7 @@ impl Editor { { let mut inlays = Vec::new(); for (range, new_text) in &edits { - let inlay = Inlay::inline_completion( + let inlay = Inlay::edit_prediction( post_inc(&mut self.next_inlay_id), range.start, new_text.as_str(), @@ -7784,7 +7790,7 @@ impl Editor { self.splice_inlays(&[], inlays, cx); } else { let background_color = cx.theme().status().deleted_background; - self.highlight_text::<InlineCompletionHighlight>( + self.highlight_text::<EditPredictionHighlight>( edits.iter().map(|(range, _)| range.clone()).collect(), HighlightStyle { background_color: Some(background_color), @@ -7807,9 +7813,9 @@ impl Editor { EditDisplayMode::DiffPopover }; - InlineCompletion::Edit { + EditPrediction::Edit { edits, - edit_preview: inline_completion.edit_preview, + edit_preview: edit_prediction.edit_preview, display_mode, snapshot, } @@ -7822,11 +7828,11 @@ impl Editor { multibuffer.line_len(MultiBufferRow(invalidation_row_range.end)), )); - self.stale_inline_completion_in_menu = None; - self.active_inline_completion = Some(InlineCompletionState { + self.stale_edit_prediction_in_menu = None; + self.active_edit_prediction = Some(EditPredictionState { inlay_ids, completion, - completion_id: inline_completion.id, + completion_id: edit_prediction.id, invalidation_range, }); @@ -7835,7 +7841,7 @@ impl Editor { Some(()) } - pub fn edit_prediction_provider(&self) -> Option<Arc<dyn InlineCompletionProviderHandle>> { + pub fn edit_prediction_provider(&self) -> Option<Arc<dyn EditPredictionProviderHandle>> { Some(self.edit_prediction_provider.as_ref()?.provider.clone()) } @@ -8177,7 +8183,7 @@ impl Editor { editor.set_breakpoint_context_menu( row, Some(position), - 
event.down.position, + event.position(), window, cx, ); @@ -8235,8 +8241,7 @@ impl Editor { return; }; - // Try to find a closest, enclosing node using tree-sitter that has a - // task + // Try to find a closest, enclosing node using tree-sitter that has a task let Some((buffer, buffer_row, tasks)) = self .find_enclosing_node_task(cx) // Or find the task that's closest in row-distance. @@ -8336,26 +8341,33 @@ impl Editor { let color = Color::Muted; let position = breakpoint.as_ref().map(|(anchor, _, _)| *anchor); - IconButton::new(("run_indicator", row.0 as usize), ui::IconName::Play) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(color) - .toggle_state(is_active) - .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { - let quick_launch = e.down.button == MouseButton::Left; - window.focus(&editor.focus_handle(cx)); - editor.toggle_code_actions( - &ToggleCodeActions { - deployed_from: Some(CodeActionSource::RunMenu(row)), - quick_launch, - }, - window, - cx, - ); - })) - .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { - editor.set_breakpoint_context_menu(row, position, event.down.position, window, cx); - })) + IconButton::new( + ("run_indicator", row.0 as usize), + ui::IconName::PlayOutlined, + ) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(color) + .toggle_state(is_active) + .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { + let quick_launch = match e { + ClickEvent::Keyboard(_) => true, + ClickEvent::Mouse(e) => e.down.button == MouseButton::Left, + }; + + window.focus(&editor.focus_handle(cx)); + editor.toggle_code_actions( + &ToggleCodeActions { + deployed_from: Some(CodeActionSource::RunMenu(row)), + quick_launch, + }, + window, + cx, + ); + })) + .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { + editor.set_breakpoint_context_menu(row, position, event.position(), window, cx); + })) } pub fn context_menu_visible(&self) -> bool { @@ -8402,14 +8414,14 @@ impl Editor { if self.mode().is_minimap() { return None; } - let active_inline_completion = self.active_inline_completion.as_ref()?; + let active_edit_prediction = self.active_edit_prediction.as_ref()?; if self.edit_prediction_visible_in_cursor_popover(true) { return None; } - match &active_inline_completion.completion { - InlineCompletion::Move { target, .. } => { + match &active_edit_prediction.completion { + EditPrediction::Move { target, .. } => { let target_display_point = target.to_display_point(editor_snapshot); if self.edit_prediction_requires_modifier() { @@ -8446,11 +8458,11 @@ impl Editor { ) } } - InlineCompletion::Edit { + EditPrediction::Edit { display_mode: EditDisplayMode::Inline, .. } => None, - InlineCompletion::Edit { + EditPrediction::Edit { display_mode: EditDisplayMode::TabAccept, edits, .. 
@@ -8471,7 +8483,7 @@ impl Editor { cx, ) } - InlineCompletion::Edit { + EditPrediction::Edit { edits, edit_preview, display_mode: EditDisplayMode::DiffPopover, @@ -8788,7 +8800,7 @@ impl Editor { } let highlighted_edits = - crate::inline_completion_edit_text(&snapshot, edits, edit_preview.as_ref()?, false, cx); + crate::edit_prediction_edit_text(&snapshot, edits, edit_preview.as_ref()?, false, cx); let styled_text = highlighted_edits.to_styled_text(&style.text); let line_count = highlighted_edits.text.lines().count(); @@ -9118,7 +9130,7 @@ impl Editor { .child(Icon::new(IconName::ZedPredict)) } - let completion = match &self.active_inline_completion { + let completion = match &self.active_edit_prediction { Some(prediction) => { if !self.has_visible_completions_menu() { const RADIUS: Pixels = px(6.); @@ -9136,7 +9148,7 @@ impl Editor { .rounded_tl(px(0.)) .overflow_hidden() .child(div().px_1p5().child(match &prediction.completion { - InlineCompletion::Move { target, snapshot } => { + EditPrediction::Move { target, snapshot } => { use text::ToPoint as _; if target.text_anchor.to_point(&snapshot).row > cursor_point.row { @@ -9145,7 +9157,7 @@ impl Editor { Icon::new(IconName::ZedPredictUp) } } - InlineCompletion::Edit { .. } => Icon::new(IconName::ZedPredict), + EditPrediction::Edit { .. } => Icon::new(IconName::ZedPredict), })) .child( h_flex() @@ -9204,7 +9216,7 @@ impl Editor { )? } - None if is_refreshing => match &self.stale_inline_completion_in_menu { + None if is_refreshing => match &self.stale_edit_prediction_in_menu { Some(stale_completion) => self.render_edit_prediction_cursor_popover_preview( stale_completion, cursor_point, @@ -9234,7 +9246,7 @@ impl Editor { completion.into_any_element() }; - let has_completion = self.active_inline_completion.is_some(); + let has_completion = self.active_edit_prediction.is_some(); let is_platform_style_mac = PlatformStyle::platform() == PlatformStyle::Mac; Some( @@ -9293,7 +9305,7 @@ impl Editor { fn render_edit_prediction_cursor_popover_preview( &self, - completion: &InlineCompletionState, + completion: &EditPredictionState, cursor_point: Point, style: &EditorStyle, cx: &mut Context<Editor>, @@ -9321,7 +9333,7 @@ impl Editor { } match &completion.completion { - InlineCompletion::Move { + EditPrediction::Move { target, snapshot, .. } => Some( h_flex() @@ -9338,7 +9350,7 @@ impl Editor { .child(Label::new("Jump to Edit")), ), - InlineCompletion::Edit { + EditPrediction::Edit { edits, edit_preview, snapshot, @@ -9346,7 +9358,7 @@ impl Editor { } => { let first_edit_row = edits.first()?.0.start.text_anchor.to_point(&snapshot).row; - let (highlighted_edits, has_more_lines) = crate::inline_completion_edit_text( + let (highlighted_edits, has_more_lines) = crate::edit_prediction_edit_text( &snapshot, &edits, edit_preview.as_ref()?, @@ -9424,8 +9436,8 @@ impl Editor { cx.notify(); self.completion_tasks.clear(); let context_menu = self.context_menu.borrow_mut().take(); - self.stale_inline_completion_in_menu.take(); - self.update_visible_inline_completion(window, cx); + self.stale_edit_prediction_in_menu.take(); + self.update_visible_edit_prediction(window, cx); if let Some(CodeContextMenu::Completions(_)) = &context_menu { if let Some(completion_provider) = &self.completion_provider { completion_provider.selection_changed(None, window, cx); @@ -9563,27 +9575,46 @@ impl Editor { // Check whether the just-entered snippet ends with an auto-closable bracket. 
if self.autoclose_regions.is_empty() { let snapshot = self.buffer.read(cx).snapshot(cx); - for selection in &mut self.selections.all::<Point>(cx) { + let mut all_selections = self.selections.all::<Point>(cx); + for selection in &mut all_selections { let selection_head = selection.head(); let Some(scope) = snapshot.language_scope_at(selection_head) else { continue; }; let mut bracket_pair = None; - let next_chars = snapshot.chars_at(selection_head).collect::<String>(); - let prev_chars = snapshot - .reversed_chars_at(selection_head) - .collect::<String>(); - for (pair, enabled) in scope.brackets() { - if enabled - && pair.close - && prev_chars.starts_with(pair.start.as_str()) - && next_chars.starts_with(pair.end.as_str()) - { - bracket_pair = Some(pair.clone()); - break; + let max_lookup_length = scope + .brackets() + .map(|(pair, _)| { + pair.start + .as_str() + .chars() + .count() + .max(pair.end.as_str().chars().count()) + }) + .max(); + if let Some(max_lookup_length) = max_lookup_length { + let next_text = snapshot + .chars_at(selection_head) + .take(max_lookup_length) + .collect::<String>(); + let prev_text = snapshot + .reversed_chars_at(selection_head) + .take(max_lookup_length) + .collect::<String>(); + + for (pair, enabled) in scope.brackets() { + if enabled + && pair.close + && prev_text.starts_with(pair.start.as_str()) + && next_text.starts_with(pair.end.as_str()) + { + bracket_pair = Some(pair.clone()); + break; + } } } + if let Some(pair) = bracket_pair { let snapshot_settings = snapshot.language_settings_at(selection_head, cx); let autoclose_enabled = @@ -9764,7 +9795,7 @@ impl Editor { this.edit(edits, None, cx); }) } - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); linked_editing_ranges::refresh_linked_ranges(this, window, cx); }); } @@ -9783,7 +9814,7 @@ impl Editor { }) }); this.insert("", window, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -9916,7 +9947,7 @@ impl Editor { self.transact(window, cx, |this, window, cx| { this.buffer.update(cx, |b, cx| b.edit(edits, None, cx)); this.change_selections(Default::default(), window, cx, |s| s.select(selections)); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -10878,17 +10909,6 @@ impl Editor { }); } - pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context<Self>) { - self.manipulate_text(window, cx, |text| { - let has_upper_case_characters = text.chars().any(|c| c.is_uppercase()); - if has_upper_case_characters { - text.to_lowercase() - } else { - text.to_uppercase() - } - }) - } - fn manipulate_immutable_lines<Fn>( &mut self, window: &mut Window, @@ -11144,6 +11164,26 @@ impl Editor { }) } + pub fn convert_to_sentence_case( + &mut self, + _: &ConvertToSentenceCase, + window: &mut Window, + cx: &mut Context<Self>, + ) { + self.manipulate_text(window, cx, |text| text.to_case(Case::Sentence)) + } + + pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context<Self>) { + self.manipulate_text(window, cx, |text| { + let has_upper_case_characters = text.chars().any(|c| c.is_uppercase()); + if has_upper_case_characters { + text.to_lowercase() + } else { + text.to_uppercase() + } + }) + } + pub fn convert_to_rot13( &mut self, _: &ConvertToRot13, @@ -12236,7 +12276,7 @@ impl Editor { } self.request_autoscroll(Autoscroll::fit(), cx); self.unmark_text(window, 
cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); cx.emit(EditorEvent::Edited { transaction_id }); cx.emit(EditorEvent::TransactionUndone { transaction_id }); } @@ -12266,7 +12306,7 @@ impl Editor { } self.request_autoscroll(Autoscroll::fit(), cx); self.unmark_text(window, cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); cx.emit(EditorEvent::Edited { transaction_id }); } } @@ -15253,7 +15293,7 @@ impl Editor { ]) }); self.activate_diagnostics(buffer_id, next_diagnostic, window, cx); - self.refresh_inline_completion(false, true, window, cx); + self.refresh_edit_prediction(false, true, window, cx); } pub fn go_to_next_hunk(&mut self, _: &GoToHunk, window: &mut Window, cx: &mut Context<Self>) { @@ -15814,7 +15854,7 @@ impl Editor { let language_server_name = project .language_server_statuses(cx) .find(|(id, _)| server_id == *id) - .map(|(_, status)| LanguageServerName::from(status.name.as_str())); + .map(|(_, status)| status.name.clone()); language_server_name.map(|language_server_name| { project.open_local_buffer_via_lsp( lsp_location.uri.clone(), @@ -16217,7 +16257,7 @@ impl Editor { font_weight: Some(FontWeight::BOLD), ..make_inlay_hints_style(cx.app) }, - inline_completion_styles: make_suggestion_styles( + edit_prediction_styles: make_suggestion_styles( cx.app, ), ..EditorStyle::default() @@ -16968,7 +17008,7 @@ impl Editor { now: Instant, window: &mut Window, cx: &mut Context<Self>, - ) { + ) -> Option<TransactionId> { self.end_selection(window, cx); if let Some(tx_id) = self .buffer @@ -16978,7 +17018,10 @@ impl Editor { .insert_transaction(tx_id, self.selections.disjoint_anchors()); cx.emit(EditorEvent::TransactionBegun { transaction_id: tx_id, - }) + }); + Some(tx_id) + } else { + None } } @@ -17006,6 +17049,17 @@ impl Editor { } } + pub fn modify_transaction_selection_history( + &mut self, + transaction_id: TransactionId, + modify: impl FnOnce(&mut (Arc<[Selection<Anchor>]>, Option<Arc<[Selection<Anchor>]>>)), + ) -> bool { + self.selection_history + .transaction_mut(transaction_id) + .map(modify) + .is_some() + } + pub fn set_mark(&mut self, _: &actions::SetMark, window: &mut Window, cx: &mut Context<Self>) { if self.selection_mark_mode { self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { @@ -18977,7 +19031,7 @@ impl Editor { (selection.range(), uuid.to_string()) }); this.edit(edits, cx); - this.refresh_inline_completion(true, false, window, cx); + this.refresh_edit_prediction(true, false, window, cx); }); } @@ -19830,8 +19884,8 @@ impl Editor { self.refresh_selected_text_highlights(true, window, cx); self.refresh_single_line_folds(window, cx); refresh_matching_bracket_highlights(self, window, cx); - if self.has_active_inline_completion() { - self.update_visible_inline_completion(window, cx); + if self.has_active_edit_prediction() { + self.update_visible_edit_prediction(window, cx); } if let Some(project) = self.project.as_ref() { if let Some(edited_buffer) = edited_buffer { @@ -20033,7 +20087,7 @@ impl Editor { } self.tasks_update_task = Some(self.refresh_runnables(window, cx)); self.update_edit_prediction_settings(cx); - self.refresh_inline_completion(true, false, window, cx); + self.refresh_edit_prediction(true, false, window, cx); self.refresh_inline_values(cx); self.refresh_inlay_hints( InlayHintRefreshReason::SettingsChange(inlay_hint_settings( @@ -20665,7 +20719,7 @@ impl Editor { { self.hide_context_menu(window, 
cx); } - self.discard_inline_completion(false, cx); + self.discard_edit_prediction(false, cx); cx.emit(EditorEvent::Blurred); cx.notify(); } @@ -21078,13 +21132,6 @@ fn process_completion_for_edit( .is_le(), "replace_range should start before or at cursor position" ); - debug_assert!( - insert_range - .end - .cmp(&cursor_position, &buffer_snapshot) - .is_le(), - "insert_range should end before or at cursor position" - ); let should_replace = match intent { CompletionIntent::CompleteWithInsert => false, @@ -21812,11 +21859,11 @@ impl CodeActionProvider for Entity<Project> { cx: &mut App, ) -> Task<Result<Vec<CodeAction>>> { self.update(cx, |project, cx| { - let code_lens = project.code_lens(buffer, range.clone(), cx); + let code_lens_actions = project.code_lens_actions(buffer, range.clone(), cx); let code_actions = project.code_actions(buffer, range, None, cx); cx.background_spawn(async move { - let (code_lens, code_actions) = join(code_lens, code_actions).await; - Ok(code_lens + let (code_lens_actions, code_actions) = join(code_lens_actions, code_actions).await; + Ok(code_lens_actions .context("code lens fetch")? .into_iter() .chain(code_actions.context("code action fetch")?) @@ -22145,7 +22192,6 @@ impl SemanticsProvider for Entity<Project> { } fn supports_inlay_hints(&self, buffer: &Entity<Buffer>, cx: &mut App) -> bool { - // TODO: make this work for remote projects self.update(cx, |project, cx| { if project .active_debug_session(cx) @@ -22734,7 +22780,7 @@ impl Render for Editor { syntax: cx.theme().syntax().clone(), status: cx.theme().status().clone(), inlay_hints_style: make_inlay_hints_style(cx), - inline_completion_styles: make_suggestion_styles(cx), + edit_prediction_styles: make_suggestion_styles(cx), unnecessary_code_fade: ThemeSettings::get_global(cx).unnecessary_code_fade, show_underlines: self.diagnostics_enabled(), }, @@ -23129,7 +23175,7 @@ impl InvalidationRegion for SnippetState { } } -fn inline_completion_edit_text( +fn edit_prediction_edit_text( current_snapshot: &BufferSnapshot, edits: &[(Range<Anchor>, String)], edit_preview: &EditPreview, diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 0d69c067ee..1cb3565733 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -2,7 +2,7 @@ use super::*; use crate::{ JoinLines, code_context_menus::CodeContextMenu, - inline_completion_tests::FakeInlineCompletionProvider, + edit_prediction_tests::FakeEditPredictionProvider, linked_editing_ranges::LinkedEditingRanges, scroll::scroll_amount::ScrollAmount, test::{ @@ -4724,6 +4724,23 @@ async fn test_toggle_case(cx: &mut TestAppContext) { "}); } +#[gpui::test] +async fn test_convert_to_sentence_case(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + cx.set_state(indoc! {" + «implement-windows-supportˇ» + "}); + cx.update_editor(|e, window, cx| { + e.convert_to_sentence_case(&ConvertToSentenceCase, window, cx) + }); + cx.assert_editor_state(indoc! 
{" + «Implement windows supportˇ» + "}); +} + #[gpui::test] async fn test_manipulate_text(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -7234,12 +7251,12 @@ async fn test_undo_format_scrolls_to_last_edit_pos(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext) { +async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeInlineCompletionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionProvider::default()); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(provider.clone()), window, cx); }); @@ -7262,7 +7279,7 @@ async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_inline_completion(Some(inline_completion::InlineCompletion { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction { id: None, edits: vec![(edit_position..edit_position, "X".into())], edit_preview: None, @@ -7270,7 +7287,7 @@ async fn test_undo_inline_completion_scrolls_to_edit_pos(cx: &mut TestAppContext }) }); - cx.update_editor(|editor, window, cx| editor.update_visible_inline_completion(window, cx)); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); cx.update_editor(|editor, window, cx| { editor.accept_edit_prediction(&crate::AcceptEditPrediction, window, cx) }); @@ -8595,6 +8612,7 @@ async fn test_autoclose_with_embedded_language(cx: &mut TestAppContext) { cx.language_registry().add(html_language.clone()); cx.language_registry().add(javascript_language.clone()); + cx.executor().run_until_parked(); cx.update_buffer(|buffer, cx| { buffer.set_language(Some(html_language), cx); @@ -10055,8 +10073,14 @@ async fn test_autosave_with_dirty_buffers(cx: &mut TestAppContext) { ); } -#[gpui::test] -async fn test_range_format_during_save(cx: &mut TestAppContext) { +async fn setup_range_format_test( + cx: &mut TestAppContext, +) -> ( + Entity<Project>, + Entity<Editor>, + &mut gpui::VisualTestContext, + lsp::FakeLanguageServer, +) { init_test(cx, |_| {}); let fs = FakeFs::new(cx.executor()); @@ -10071,9 +10095,9 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { FakeLspAdapter { capabilities: lsp::ServerCapabilities { document_range_formatting_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() + ..lsp::ServerCapabilities::default() }, - ..Default::default() + ..FakeLspAdapter::default() }, ); @@ -10088,14 +10112,22 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { let (editor, cx) = cx.add_window_view(|window, cx| { build_editor_with_project(project.clone(), buffer, window, cx) }); + + cx.executor().start_waiting(); + let fake_server = fake_servers.next().await.unwrap(); + + (project, editor, cx, fake_server) +} + +#[gpui::test] +async fn test_range_format_on_save_success(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; + editor.update_in(cx, |editor, window, cx| { editor.set_text("one\ntwo\nthree\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); - cx.executor().start_waiting(); - let fake_server = fake_servers.next().await.unwrap(); - let save = editor .update_in(cx, |editor, window, cx| { editor.save( @@ -10130,13 +10162,18 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { "one, two\nthree\n" ); 
assert!(!cx.read(|cx| editor.is_dirty(cx))); +} + +#[gpui::test] +async fn test_range_format_on_save_timeout(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; editor.update_in(cx, |editor, window, cx| { editor.set_text("one\ntwo\nthree\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); - // Ensure we can still save even if formatting hangs. + // Test that save still works when formatting hangs fake_server.set_request_handler::<lsp::request::RangeFormatting, _, _>( move |params, _| async move { assert_eq!( @@ -10168,8 +10205,13 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { "one\ntwo\nthree\n" ); assert!(!cx.read(|cx| editor.is_dirty(cx))); +} - // For non-dirty buffer, no formatting request should be sent +#[gpui::test] +async fn test_range_format_not_called_for_clean_buffer(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; + + // Buffer starts clean, no formatting should be requested let save = editor .update_in(cx, |editor, window, cx| { editor.save( @@ -10190,6 +10232,12 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { .next(); cx.executor().start_waiting(); save.await; + cx.run_until_parked(); +} + +#[gpui::test] +async fn test_range_format_respects_language_tab_size_override(cx: &mut TestAppContext) { + let (project, editor, cx, fake_server) = setup_range_format_test(cx).await; // Set Rust language override and assert overridden tabsize is sent to language server update_test_language_settings(cx, |settings| { @@ -10203,7 +10251,7 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) { }); editor.update_in(cx, |editor, window, cx| { - editor.set_text("somehting_new\n", window, cx) + editor.set_text("something_new\n", window, cx) }); assert!(cx.read(|cx| editor.is_dirty(cx))); let save = editor @@ -13353,6 +13401,178 @@ async fn test_as_is_completions(cx: &mut TestAppContext) { cx.assert_editor_state("fn a() {}\n unsafeˇ"); } +#[gpui::test] +async fn test_panic_during_c_completions(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + let language = + Arc::try_unwrap(languages::language("c", tree_sitter_c::LANGUAGE.into())).unwrap(); + let mut cx = EditorLspTestContext::new( + language, + lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + ..lsp::CompletionOptions::default() + }), + ..lsp::ServerCapabilities::default() + }, + cx, + ) + .await; + + cx.set_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +ˇ", + ); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("#", window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("i", window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("n", window, cx); + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#inˇ", + ); + + cx.lsp + .set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { + is_incomplete: false, + item_defaults: None, + items: vec![lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::SNIPPET), + label_details: 
Some(lsp::CompletionItemLabelDetails { + detail: Some("header".to_string()), + description: None, + }), + label: " include".to_string(), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range { + start: lsp::Position { + line: 8, + character: 1, + }, + end: lsp::Position { + line: 8, + character: 1, + }, + }, + new_text: "include \"$0\"".to_string(), + })), + sort_text: Some("40b67681include".to_string()), + insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), + filter_text: Some("include".to_string()), + insert_text: Some("include \"$0\"".to_string()), + ..lsp::CompletionItem::default() + }], + }))) + }); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.confirm_completion(&ConfirmCompletion::default(), window, cx) + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + "#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include \"ˇ\"", + ); + + cx.lsp + .set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { + is_incomplete: true, + item_defaults: None, + items: vec![lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::FILE), + label: "AGL/".to_string(), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range { + start: lsp::Position { + line: 8, + character: 10, + }, + end: lsp::Position { + line: 8, + character: 11, + }, + }, + new_text: "AGL/".to_string(), + })), + sort_text: Some("40b67681AGL/".to_string()), + insert_text_format: Some(lsp::InsertTextFormat::PLAIN_TEXT), + filter_text: Some("AGL/".to_string()), + insert_text: Some("AGL/".to_string()), + ..lsp::CompletionItem::default() + }], + }))) + }); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.confirm_completion(&ConfirmCompletion::default(), window, cx) + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + r##"#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include "AGL/ˇ"##, + ); + + cx.update_editor(|editor, window, cx| { + editor.handle_input("\"", window, cx); + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + r##"#ifndef BAR_H +#define BAR_H + +#include <stdbool.h> + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include "AGL/"ˇ"##, + ); +} + #[gpui::test] async fn test_no_duplicated_completion_requests(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -16864,7 +17084,7 @@ async fn test_multibuffer_reverts(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_mutlibuffer_in_navigation_history(cx: &mut TestAppContext) { +async fn test_multibuffer_in_navigation_history(cx: &mut TestAppContext) { init_test(cx, |_| {}); let cols = 4; @@ -20332,7 +20552,7 @@ async fn test_multi_buffer_navigation_with_folded_buffers(cx: &mut TestAppContex } #[gpui::test] -async fn test_inline_completion_text(cx: &mut TestAppContext) { +async fn test_edit_prediction_text(cx: &mut TestAppContext) { init_test(cx, |_| {}); // Simple insertion @@ -20431,7 +20651,7 @@ async fn test_inline_completion_text(cx: &mut TestAppContext) { } #[gpui::test] -async 
fn test_inline_completion_text_with_deletions(cx: &mut TestAppContext) { +async fn test_edit_prediction_text_with_deletions(cx: &mut TestAppContext) { init_test(cx, |_| {}); // Deletion @@ -20521,7 +20741,7 @@ async fn assert_highlighted_edits( .await; cx.update(|_window, cx| { - let highlighted_edits = inline_completion_edit_text( + let highlighted_edits = edit_prediction_edit_text( &snapshot.as_singleton().unwrap().2, &edits, &edit_preview, @@ -21293,16 +21513,32 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex }, ); - let (buffer, _handle) = project - .update(cx, |p, cx| { - p.open_local_buffer_with_lsp(path!("/dir/a.ts"), cx) + let editor = workspace + .update(cx, |workspace, window, cx| { + workspace.open_abs_path( + PathBuf::from(path!("/dir/a.ts")), + OpenOptions::default(), + window, + cx, + ) }) + .unwrap() .await + .unwrap() + .downcast::<Editor>() .unwrap(); cx.executor().run_until_parked(); let fake_server = fake_language_servers.next().await.unwrap(); + let buffer = editor.update(cx, |editor, cx| { + editor + .buffer() + .read(cx) + .as_singleton() + .expect("have opened a single file by path") + }); + let buffer_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); let anchor = buffer_snapshot.anchor_at(0, text::Bias::Left); drop(buffer_snapshot); @@ -21360,7 +21596,7 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex assert_eq!( actions.len(), 1, - "Should have only one valid action for the 0..0 range" + "Should have only one valid action for the 0..0 range, got: {actions:#?}" ); let action = actions[0].clone(); let apply = project.update(cx, |project, cx| { @@ -21406,7 +21642,7 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex .into_iter() .collect(), ), - ..Default::default() + ..lsp::WorkspaceEdit::default() }, }, ) @@ -21429,6 +21665,38 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex buffer.undo(cx); assert_eq!(buffer.text(), "a"); }); + + let actions_after_edits = cx + .update_window(*workspace, |_, window, cx| { + project.code_actions(&buffer, anchor..anchor, window, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!( + actions, actions_after_edits, + "For the same selection, same code lens actions should be returned" + ); + + let _responses = + fake_server.set_request_handler::<lsp::request::CodeLensRequest, _, _>(|_, _| async move { + panic!("No more code lens requests are expected"); + }); + editor.update_in(cx, |editor, window, cx| { + editor.select_all(&SelectAll, window, cx); + }); + cx.executor().run_until_parked(); + let new_actions = cx + .update_window(*workspace, |_, window, cx| { + project.code_actions(&buffer, anchor..anchor, window, cx) + }) + .unwrap() + .await + .unwrap(); + assert_eq!( + actions, new_actions, + "Code lens are queried for the same range and should get the same set back, but without additional LSP queries now" + ); } #[gpui::test] @@ -22568,6 +22836,435 @@ async fn test_indent_on_newline_for_python(cx: &mut TestAppContext) { "}); } +#[gpui::test] +async fn test_tab_in_leading_whitespace_auto_indents_for_bash(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // test cursor move to start of each line on tab + // for `if`, `elif`, `else`, `while`, `for`, `case` and `function` + 
cx.set_state(indoc! {" + function main() { + ˇ for item in $items; do + ˇ while [ -n \"$item\" ]; do + ˇ if [ \"$value\" -gt 10 ]; then + ˇ continue + ˇ elif [ \"$value\" -lt 0 ]; then + ˇ break + ˇ else + ˇ echo \"$item\" + ˇ fi + ˇ done + ˇ done + ˇ} + "}); + cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); + cx.assert_editor_state(indoc! {" + function main() { + ˇfor item in $items; do + ˇwhile [ -n \"$item\" ]; do + ˇif [ \"$value\" -gt 10 ]; then + ˇcontinue + ˇelif [ \"$value\" -lt 0 ]; then + ˇbreak + ˇelse + ˇecho \"$item\" + ˇfi + ˇdone + ˇdone + ˇ} + "}); + // test relative indent is preserved when tab + cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); + cx.assert_editor_state(indoc! {" + function main() { + ˇfor item in $items; do + ˇwhile [ -n \"$item\" ]; do + ˇif [ \"$value\" -gt 10 ]; then + ˇcontinue + ˇelif [ \"$value\" -lt 0 ]; then + ˇbreak + ˇelse + ˇecho \"$item\" + ˇfi + ˇdone + ˇdone + ˇ} + "}); + + // test cursor move to start of each line on tab + // for `case` statement with patterns + cx.set_state(indoc! {" + function handle() { + ˇ case \"$1\" in + ˇ start) + ˇ echo \"a\" + ˇ ;; + ˇ stop) + ˇ echo \"b\" + ˇ ;; + ˇ *) + ˇ echo \"c\" + ˇ ;; + ˇ esac + ˇ} + "}); + cx.update_editor(|e, window, cx| e.tab(&Tab, window, cx)); + cx.assert_editor_state(indoc! {" + function handle() { + ˇcase \"$1\" in + ˇstart) + ˇecho \"a\" + ˇ;; + ˇstop) + ˇecho \"b\" + ˇ;; + ˇ*) + ˇecho \"c\" + ˇ;; + ˇesac + ˇ} + "}); +} + +#[gpui::test] +async fn test_indent_after_input_for_bash(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // test indents on comment insert + cx.set_state(indoc! {" + function main() { + ˇ for item in $items; do + ˇ while [ -n \"$item\" ]; do + ˇ if [ \"$value\" -gt 10 ]; then + ˇ continue + ˇ elif [ \"$value\" -lt 0 ]; then + ˇ break + ˇ else + ˇ echo \"$item\" + ˇ fi + ˇ done + ˇ done + ˇ} + "}); + cx.update_editor(|e, window, cx| e.handle_input("#", window, cx)); + cx.assert_editor_state(indoc! {" + function main() { + #ˇ for item in $items; do + #ˇ while [ -n \"$item\" ]; do + #ˇ if [ \"$value\" -gt 10 ]; then + #ˇ continue + #ˇ elif [ \"$value\" -lt 0 ]; then + #ˇ break + #ˇ else + #ˇ echo \"$item\" + #ˇ fi + #ˇ done + #ˇ done + #ˇ} + "}); +} + +#[gpui::test] +async fn test_outdent_after_input_for_bash(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // test `else` auto outdents when typed inside `if` block + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("else", window, cx); + }); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + elseˇ + "}); + + // test `elif` auto outdents when typed inside `if` block + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("elif", window, cx); + }); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + elifˇ + "}); + + // test `fi` auto outdents when typed inside `else` block + cx.set_state(indoc! 
{" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + else + echo \"bar baz\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("fi", window, cx); + }); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"foo bar\" + else + echo \"bar baz\" + fiˇ + "}); + + // test `done` auto outdents when typed inside `while` block + cx.set_state(indoc! {" + while read line; do + echo \"$line\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("done", window, cx); + }); + cx.assert_editor_state(indoc! {" + while read line; do + echo \"$line\" + doneˇ + "}); + + // test `done` auto outdents when typed inside `for` block + cx.set_state(indoc! {" + for file in *.txt; do + cat \"$file\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("done", window, cx); + }); + cx.assert_editor_state(indoc! {" + for file in *.txt; do + cat \"$file\" + doneˇ + "}); + + // test `esac` auto outdents when typed inside `case` block + cx.set_state(indoc! {" + case \"$1\" in + start) + echo \"foo bar\" + ;; + stop) + echo \"bar baz\" + ;; + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("esac", window, cx); + }); + cx.assert_editor_state(indoc! {" + case \"$1\" in + start) + echo \"foo bar\" + ;; + stop) + echo \"bar baz\" + ;; + esacˇ + "}); + + // test `*)` auto outdents when typed inside `case` block + cx.set_state(indoc! {" + case \"$1\" in + start) + echo \"foo bar\" + ;; + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("*)", window, cx); + }); + cx.assert_editor_state(indoc! {" + case \"$1\" in + start) + echo \"foo bar\" + ;; + *)ˇ + "}); + + // test `fi` outdents to correct level with nested if blocks + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"outer if\" + if [ \"$2\" = \"debug\" ]; then + echo \"inner if\" + ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.handle_input("fi", window, cx); + }); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + echo \"outer if\" + if [ \"$2\" = \"debug\" ]; then + echo \"inner if\" + fiˇ + "}); +} + +#[gpui::test] +async fn test_indent_on_newline_for_bash(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + update_test_language_settings(cx, |settings| { + settings.defaults.extend_comment_on_newline = Some(false); + }); + let mut cx = EditorTestContext::new(cx).await; + let language = languages::language("bash", tree_sitter_bash::LANGUAGE.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // test correct indent after newline on comment + cx.set_state(indoc! {" + # COMMENT:ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.assert_editor_state(indoc! {" + # COMMENT: + ˇ + "}); + + // test correct indent after newline after `then` + cx.set_state(indoc! {" + + if [ \"$1\" = \"test\" ]; thenˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + + if [ \"$1\" = \"test\" ]; then + ˇ + "}); + + // test correct indent after newline after `else` + cx.set_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + elseˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + else + ˇ + "}); + + // test correct indent after newline after `elif` + cx.set_state(indoc! 
{" + if [ \"$1\" = \"test\" ]; then + elifˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + if [ \"$1\" = \"test\" ]; then + elif + ˇ + "}); + + // test correct indent after newline after `do` + cx.set_state(indoc! {" + for file in *.txt; doˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + for file in *.txt; do + ˇ + "}); + + // test correct indent after newline after case pattern + cx.set_state(indoc! {" + case \"$1\" in + start)ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + case \"$1\" in + start) + ˇ + "}); + + // test correct indent after newline after case pattern + cx.set_state(indoc! {" + case \"$1\" in + start) + ;; + *)ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + case \"$1\" in + start) + ;; + *) + ˇ + "}); + + // test correct indent after newline after function opening brace + cx.set_state(indoc! {" + function test() {ˇ} + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + function test() { + ˇ + } + "}); + + // test no extra indent after semicolon on same line + cx.set_state(indoc! {" + echo \"test\";ˇ + "}); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.run_until_parked(); + cx.assert_editor_state(indoc! {" + echo \"test\"; + ˇ + "}); +} + fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> { let point = DisplayPoint::new(DisplayRow(row as u32), column as u32); point..point diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index d2ee9d6b0a..e1647215bc 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -3,11 +3,11 @@ use crate::{ CodeActionSource, ColumnarMode, ConflictsOurs, ConflictsOursMarker, ConflictsOuter, ConflictsTheirs, ConflictsTheirsMarker, ContextMenuPlacement, CursorShape, CustomBlockId, DisplayDiffHunk, DisplayPoint, DisplayRow, DocumentHighlightRead, DocumentHighlightWrite, - EditDisplayMode, Editor, EditorMode, EditorSettings, EditorSnapshot, EditorStyle, - FILE_HEADER_HEIGHT, FocusedBlock, GutterDimensions, HalfPageDown, HalfPageUp, HandleInput, - HoveredCursor, InlayHintRefreshReason, InlineCompletion, JumpData, LineDown, LineHighlight, - LineUp, MAX_LINE_LEN, MINIMAP_FONT_SIZE, MULTI_BUFFER_EXCERPT_HEADER_HEIGHT, OpenExcerpts, - PageDown, PageUp, PhantomBreakpointIndicator, Point, RowExt, RowRangeExt, SelectPhase, + EditDisplayMode, EditPrediction, Editor, EditorMode, EditorSettings, EditorSnapshot, + EditorStyle, FILE_HEADER_HEIGHT, FocusedBlock, GutterDimensions, HalfPageDown, HalfPageUp, + HandleInput, HoveredCursor, InlayHintRefreshReason, JumpData, LineDown, LineHighlight, LineUp, + MAX_LINE_LEN, MINIMAP_FONT_SIZE, MULTI_BUFFER_EXCERPT_HEADER_HEIGHT, OpenExcerpts, PageDown, + PageUp, PhantomBreakpointIndicator, Point, RowExt, RowRangeExt, SelectPhase, SelectedTextHighlight, Selection, SelectionDragState, SoftWrap, StickyHeaderExcerpt, ToPoint, ToggleFold, ToggleFoldAll, code_context_menus::{CodeActionsMenu, MENU_ASIDE_MAX_WIDTH, MENU_ASIDE_MIN_WIDTH, MENU_GAP}, @@ -43,11 +43,11 @@ use gpui::{ Bounds, 
ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges, Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, Keystroke, Length, - ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, - ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, - Size, StatefulInteractiveElement, Style, Styled, TextRun, TextStyleRefinement, WeakEntity, - Window, anchored, deferred, div, fill, linear_color_stop, linear_gradient, outline, point, px, - quad, relative, size, solid_background, transparent_black, + ModifiersChangedEvent, MouseButton, MouseClickEvent, MouseDownEvent, MouseMoveEvent, + MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, + ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled, TextRun, + TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill, linear_color_stop, + linear_gradient, outline, point, px, quad, relative, size, solid_background, transparent_black, }; use itertools::Itertools; use language::language_settings::{ @@ -230,7 +230,6 @@ impl EditorElement { register_action(editor, window, Editor::sort_lines_case_insensitive); register_action(editor, window, Editor::reverse_lines); register_action(editor, window, Editor::shuffle_lines); - register_action(editor, window, Editor::toggle_case); register_action(editor, window, Editor::convert_indentation_to_spaces); register_action(editor, window, Editor::convert_indentation_to_tabs); register_action(editor, window, Editor::convert_to_upper_case); @@ -241,6 +240,8 @@ impl EditorElement { register_action(editor, window, Editor::convert_to_upper_camel_case); register_action(editor, window, Editor::convert_to_lower_camel_case); register_action(editor, window, Editor::convert_to_opposite_case); + register_action(editor, window, Editor::convert_to_sentence_case); + register_action(editor, window, Editor::toggle_case); register_action(editor, window, Editor::convert_to_rot13); register_action(editor, window, Editor::convert_to_rot47); register_action(editor, window, Editor::delete_to_previous_word_start); @@ -553,7 +554,7 @@ impl EditorElement { register_action(editor, window, Editor::signature_help_next); register_action(editor, window, Editor::next_edit_prediction); register_action(editor, window, Editor::previous_edit_prediction); - register_action(editor, window, Editor::show_inline_completion); + register_action(editor, window, Editor::show_edit_prediction); register_action(editor, window, Editor::context_menu_first); register_action(editor, window, Editor::context_menu_prev); register_action(editor, window, Editor::context_menu_next); @@ -561,7 +562,7 @@ impl EditorElement { register_action(editor, window, Editor::display_cursor_names); register_action(editor, window, Editor::unique_lines_case_insensitive); register_action(editor, window, Editor::unique_lines_case_sensitive); - register_action(editor, window, Editor::accept_partial_inline_completion); + register_action(editor, window, Editor::accept_partial_edit_prediction); register_action(editor, window, Editor::accept_edit_prediction); register_action(editor, window, Editor::restore_file); register_action(editor, window, Editor::git_restore); @@ -948,8 +949,12 @@ impl EditorElement { let hovered_link_modifier = Editor::multi_cursor_modifier(false, &event.modifiers(), cx); - if !pending_nonempty_selections && 
hovered_link_modifier && text_hitbox.is_hovered(window) { - let point = position_map.point_for_position(event.up.position); + if let Some(mouse_position) = event.mouse_position() + && !pending_nonempty_selections + && hovered_link_modifier + && text_hitbox.is_hovered(window) + { + let point = position_map.point_for_position(mouse_position); editor.handle_click_hovered_link(point, event.modifiers(), window, cx); editor.selection_drag_state = SelectionDragState::None; @@ -2092,7 +2097,7 @@ impl EditorElement { row_block_types: &HashMap<DisplayRow, bool>, content_origin: gpui::Point<Pixels>, scroll_pixel_position: gpui::Point<Pixels>, - inline_completion_popover_origin: Option<gpui::Point<Pixels>>, + edit_prediction_popover_origin: Option<gpui::Point<Pixels>>, start_row: DisplayRow, end_row: DisplayRow, line_height: Pixels, @@ -2209,12 +2214,13 @@ impl EditorElement { cmp::max(padded_line, min_start) }; - let behind_inline_completion_popover = inline_completion_popover_origin - .as_ref() - .map_or(false, |inline_completion_popover_origin| { - (pos_y..pos_y + line_height).contains(&inline_completion_popover_origin.y) - }); - let opacity = if behind_inline_completion_popover { + let behind_edit_prediction_popover = edit_prediction_popover_origin.as_ref().map_or( + false, + |edit_prediction_popover_origin| { + (pos_y..pos_y + line_height).contains(&edit_prediction_popover_origin.y) + }, + ); + let opacity = if behind_edit_prediction_popover { 0.5 } else { 1.0 @@ -2426,9 +2432,9 @@ impl EditorElement { let mut padding = INLINE_BLAME_PADDING_EM_WIDTHS; - if let Some(inline_completion) = editor.active_inline_completion.as_ref() { - match &inline_completion.completion { - InlineCompletion::Edit { + if let Some(edit_prediction) = editor.active_edit_prediction.as_ref() { + match &edit_prediction.completion { + EditPrediction::Edit { display_mode: EditDisplayMode::TabAccept, .. } => padding += INLINE_ACCEPT_SUGGESTION_EM_WIDTHS, @@ -3733,7 +3739,7 @@ impl EditorElement { move |editor, e: &ClickEvent, window, cx| { editor.open_excerpts_common( Some(jump_data.clone()), - e.down.modifiers.secondary(), + e.modifiers().secondary(), window, cx, ); @@ -4085,8 +4091,7 @@ impl EditorElement { { let editor = self.editor.read(cx); - if editor - .edit_prediction_visible_in_cursor_popover(editor.has_active_inline_completion()) + if editor.edit_prediction_visible_in_cursor_popover(editor.has_active_edit_prediction()) { height_above_menu += editor.edit_prediction_cursor_popover_height() + POPOVER_Y_PADDING; @@ -6675,14 +6680,14 @@ impl EditorElement { } } - fn paint_inline_completion_popover( + fn paint_edit_prediction_popover( &mut self, layout: &mut EditorLayout, window: &mut Window, cx: &mut App, ) { - if let Some(inline_completion_popover) = layout.inline_completion_popover.as_mut() { - inline_completion_popover.paint(window, cx); + if let Some(edit_prediction_popover) = layout.edit_prediction_popover.as_mut() { + edit_prediction_popover.paint(window, cx); } } @@ -6881,10 +6886,10 @@ impl EditorElement { // Fire click handlers during the bubble phase. 
DispatchPhase::Bubble => editor.update(cx, |editor, cx| { if let Some(mouse_down) = captured_mouse_down.take() { - let event = ClickEvent { + let event = ClickEvent::Mouse(MouseClickEvent { down: mouse_down, up: event.clone(), - }; + }); Self::click(editor, &event, &position_map, window, cx); } }), @@ -7943,17 +7948,11 @@ impl Element for EditorElement { right: right_margin, }; - // Offset the content_bounds from the text_bounds by the gutter margin (which - // is roughly half a character wide) to make hit testing work more like how we want. - let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO); - - let editor_content_width = editor_width - content_offset.x; - snapshot = self.editor.update(cx, |editor, cx| { editor.last_bounds = Some(bounds); editor.gutter_dimensions = gutter_dimensions; editor.set_visible_line_count(bounds.size.height / line_height, window, cx); - editor.set_visible_column_count(editor_content_width / em_advance); + editor.set_visible_column_count(editor_width / em_advance); if matches!( editor.mode, @@ -7965,10 +7964,10 @@ impl Element for EditorElement { let wrap_width = match editor.soft_wrap_mode(cx) { SoftWrap::GitDiff => None, SoftWrap::None => Some(wrap_width_for(MAX_LINE_LEN as u32 / 2)), - SoftWrap::EditorWidth => Some(editor_content_width), + SoftWrap::EditorWidth => Some(editor_width), SoftWrap::Column(column) => Some(wrap_width_for(column)), SoftWrap::Bounded(column) => { - Some(editor_content_width.min(wrap_width_for(column))) + Some(editor_width.min(wrap_width_for(column))) } }; @@ -7993,13 +7992,12 @@ impl Element for EditorElement { HitboxBehavior::Normal, ); + // Offset the content_bounds from the text_bounds by the gutter margin (which + // is roughly half a character wide) to make hit testing work more like how we want. 
+ let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO); let content_origin = text_hitbox.origin + content_offset; - let editor_text_bounds = - Bounds::from_corners(content_origin, bounds.bottom_right()); - - let height_in_lines = editor_text_bounds.size.height / line_height; - + let height_in_lines = bounds.size.height / line_height; let max_row = snapshot.max_point().row().as_f32(); // The max scroll position for the top of the window @@ -8383,7 +8381,6 @@ impl Element for EditorElement { glyph_grid_cell, size(longest_line_width, max_row.as_f32() * line_height), longest_line_blame_width, - editor_width, EditorSettings::get_global(cx), ); @@ -8455,7 +8452,7 @@ impl Element for EditorElement { MultiBufferRow(end_anchor.to_point(&snapshot.buffer_snapshot).row); let scroll_max = point( - ((scroll_width - editor_content_width) / em_advance).max(0.0), + ((scroll_width - editor_width) / em_advance).max(0.0), max_scroll_top, ); @@ -8467,7 +8464,7 @@ impl Element for EditorElement { if needs_horizontal_autoscroll.0 && let Some(new_scroll_position) = editor.autoscroll_horizontally( start_row, - editor_content_width, + editor_width, scroll_width, em_advance, &line_layouts, @@ -8508,7 +8505,7 @@ impl Element for EditorElement { ) }); - let (inline_completion_popover, inline_completion_popover_origin) = self + let (edit_prediction_popover, edit_prediction_popover_origin) = self .editor .update(cx, |editor, cx| { editor.render_edit_prediction_popover( @@ -8537,7 +8534,7 @@ impl Element for EditorElement { &row_block_types, content_origin, scroll_pixel_position, - inline_completion_popover_origin, + edit_prediction_popover_origin, start_row, end_row, line_height, @@ -8926,7 +8923,7 @@ impl Element for EditorElement { cursors, visible_cursors, selections, - inline_completion_popover, + edit_prediction_popover, diff_hunk_controls, mouse_context_menu, test_indicators, @@ -9008,7 +9005,7 @@ impl Element for EditorElement { self.paint_minimap(layout, window, cx); self.paint_scrollbars(layout, window, cx); - self.paint_inline_completion_popover(layout, window, cx); + self.paint_edit_prediction_popover(layout, window, cx); self.paint_mouse_context_menu(layout, window, cx); }); }) @@ -9048,7 +9045,6 @@ impl ScrollbarLayoutInformation { glyph_grid_cell: Size<Pixels>, document_size: Size<Pixels>, longest_line_blame_width: Pixels, - editor_width: Pixels, settings: &EditorSettings, ) -> Self { let vertical_overscroll = match settings.scroll_beyond_last_line { @@ -9059,19 +9055,11 @@ impl ScrollbarLayoutInformation { } }; - let right_margin = if document_size.width + longest_line_blame_width >= editor_width { - glyph_grid_cell.width - } else { - px(0.0) - }; - - let overscroll = size(right_margin + longest_line_blame_width, vertical_overscroll); - - let scroll_range = document_size + overscroll; + let overscroll = size(longest_line_blame_width, vertical_overscroll); ScrollbarLayoutInformation { editor_bounds, - scroll_range, + scroll_range: document_size + overscroll, glyph_grid_cell, } } @@ -9118,7 +9106,7 @@ pub struct EditorLayout { expand_toggles: Vec<Option<(AnyElement, gpui::Point<Pixels>)>>, diff_hunk_controls: Vec<AnyElement>, crease_trailers: Vec<Option<CreaseTrailerLayout>>, - inline_completion_popover: Option<AnyElement>, + edit_prediction_popover: Option<AnyElement>, mouse_context_menu: Option<AnyElement>, tab_invisible: ShapedLine, space_invisible: ShapedLine, @@ -9176,7 +9164,7 @@ struct EditorScrollbars { impl EditorScrollbars { pub fn from_scrollbar_axes( - settings_visibility: 
ScrollbarAxes, + show_scrollbar: ScrollbarAxes, layout_information: &ScrollbarLayoutInformation, content_offset: gpui::Point<Pixels>, scroll_position: gpui::Point<f32>, @@ -9214,22 +9202,13 @@ impl EditorScrollbars { }; let mut create_scrollbar_layout = |axis| { - settings_visibility - .along(axis) + let viewport_size = viewport_size.along(axis); + let scroll_range = scroll_range.along(axis); + + // We always want a vertical scrollbar track for scrollbar diagnostic visibility. + (show_scrollbar.along(axis) + && (axis == ScrollbarAxis::Vertical || scroll_range > viewport_size)) .then(|| { - ( - viewport_size.along(axis) - content_offset.along(axis), - scroll_range.along(axis), - ) - }) - .filter(|(viewport_size, scroll_range)| { - // The scrollbar should only be rendered if the content does - // not entirely fit into the editor - // However, this only applies to the horizontal scrollbar, as information about the - // vertical scrollbar layout is always needed for scrollbar diagnostics. - axis != ScrollbarAxis::Horizontal || viewport_size < scroll_range - }) - .map(|(viewport_size, scroll_range)| { ScrollbarLayout::new( window.insert_hitbox(scrollbar_bounds_for(axis), HitboxBehavior::Normal), viewport_size, diff --git a/crates/editor/src/linked_editing_ranges.rs b/crates/editor/src/linked_editing_ranges.rs index 7c2672fc0d..a185de33ca 100644 --- a/crates/editor/src/linked_editing_ranges.rs +++ b/crates/editor/src/linked_editing_ranges.rs @@ -95,7 +95,7 @@ pub(super) fn refresh_linked_ranges( let snapshot = buffer.read(cx).snapshot(); let buffer_id = buffer.read(cx).remote_id(); - let linked_edits_task = project.linked_edit(buffer, *start, cx); + let linked_edits_task = project.linked_edits(buffer, *start, cx); let highlights = move || async move { let edits = linked_edits_task.await.log_err()?; // Find the range containing our current selection. 
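
Note on the scrollbar layout change above: the previous filter chain in `EditorScrollbars::from_scrollbar_axes` is collapsed into a single predicate, so the vertical track is laid out whenever it is enabled (keeping scrollbar diagnostics visible even without overflow), while the horizontal track only appears when the scroll range exceeds the viewport. Below is a minimal standalone sketch of that predicate, not part of the patch; `ScrollbarAxis`, `ScrollbarAxes`, and `should_layout_scrollbar` are simplified stand-ins (plain `f32` lengths instead of `Pixels`), not the editor's actual types or API.

// Illustrative sketch only; simplified stand-ins for the editor's types.
#[derive(Clone, Copy, PartialEq)]
enum ScrollbarAxis {
    Horizontal,
    Vertical,
}

struct ScrollbarAxes {
    horizontal: bool,
    vertical: bool,
}

impl ScrollbarAxes {
    fn along(&self, axis: ScrollbarAxis) -> bool {
        match axis {
            ScrollbarAxis::Horizontal => self.horizontal,
            ScrollbarAxis::Vertical => self.vertical,
        }
    }
}

// A scrollbar track is laid out for `axis` when it is enabled and either the
// axis is vertical (always kept so diagnostics stay visible) or the content
// overflows the viewport along that axis.
fn should_layout_scrollbar(
    show_scrollbar: &ScrollbarAxes,
    axis: ScrollbarAxis,
    viewport_size: f32,
    scroll_range: f32,
) -> bool {
    show_scrollbar.along(axis)
        && (axis == ScrollbarAxis::Vertical || scroll_range > viewport_size)
}

fn main() {
    let show = ScrollbarAxes { horizontal: true, vertical: true };
    // The vertical track is laid out even when nothing overflows.
    assert!(should_layout_scrollbar(&show, ScrollbarAxis::Vertical, 100.0, 50.0));
    // The horizontal track only appears once the scroll range exceeds the viewport.
    assert!(!should_layout_scrollbar(&show, ScrollbarAxis::Horizontal, 100.0, 50.0));
    assert!(should_layout_scrollbar(&show, ScrollbarAxis::Horizontal, 100.0, 150.0));
}
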
diff --git a/crates/editor/src/lsp_colors.rs b/crates/editor/src/lsp_colors.rs index ce07dd43fe..08cf9078f2 100644 --- a/crates/editor/src/lsp_colors.rs +++ b/crates/editor/src/lsp_colors.rs @@ -6,7 +6,7 @@ use gpui::{Hsla, Rgba}; use itertools::Itertools; use language::point_from_lsp; use multi_buffer::Anchor; -use project::{DocumentColor, lsp_store::ColorFetchStrategy}; +use project::{DocumentColor, lsp_store::LspFetchStrategy}; use settings::Settings as _; use text::{Bias, BufferId, OffsetRangeExt as _}; use ui::{App, Context, Window}; @@ -180,9 +180,9 @@ impl Editor { .filter_map(|buffer| { let buffer_id = buffer.read(cx).remote_id(); let fetch_strategy = if ignore_cache { - ColorFetchStrategy::IgnoreCache + LspFetchStrategy::IgnoreCache } else { - ColorFetchStrategy::UseCache { + LspFetchStrategy::UseCache { known_cache_version: self.colors.as_ref().and_then(|colors| { Some(colors.buffer_colors.get(&buffer_id)?.cache_version_used) }), diff --git a/crates/editor/src/lsp_ext.rs b/crates/editor/src/lsp_ext.rs index 8d078f304c..6161afbbc0 100644 --- a/crates/editor/src/lsp_ext.rs +++ b/crates/editor/src/lsp_ext.rs @@ -3,9 +3,8 @@ use std::time::Duration; use crate::Editor; use collections::HashMap; -use futures::stream::FuturesUnordered; use gpui::AsyncApp; -use gpui::{App, AppContext as _, Entity, Task}; +use gpui::{App, Entity, Task}; use itertools::Itertools; use language::Buffer; use language::Language; @@ -18,7 +17,6 @@ use project::Project; use project::TaskSourceKind; use project::lsp_store::lsp_ext_command::GetLspRunnables; use smol::future::FutureExt as _; -use smol::stream::StreamExt; use task::ResolvedTask; use task::TaskContext; use text::BufferId; @@ -29,52 +27,32 @@ pub(crate) fn find_specific_language_server_in_selection<F>( editor: &Editor, cx: &mut App, filter_language: F, - language_server_name: &str, -) -> Task<Option<(Anchor, Arc<Language>, LanguageServerId, Entity<Buffer>)>> + language_server_name: LanguageServerName, +) -> Option<(Anchor, Arc<Language>, LanguageServerId, Entity<Buffer>)> where F: Fn(&Language) -> bool, { - let Some(project) = &editor.project else { - return Task::ready(None); - }; - - let applicable_buffers = editor + let project = editor.project.clone()?; + editor .selections .disjoint_anchors() .iter() .filter_map(|selection| Some((selection.head(), selection.head().buffer_id?))) .unique_by(|(_, buffer_id)| *buffer_id) - .filter_map(|(trigger_anchor, buffer_id)| { + .find_map(|(trigger_anchor, buffer_id)| { let buffer = editor.buffer().read(cx).buffer(buffer_id)?; let language = buffer.read(cx).language_at(trigger_anchor.text_anchor)?; if filter_language(&language) { - Some((trigger_anchor, buffer, language)) + let server_id = buffer.update(cx, |buffer, cx| { + project + .read(cx) + .language_server_id_for_name(buffer, &language_server_name, cx) + })?; + Some((trigger_anchor, language, server_id, buffer)) } else { None } }) - .collect::<Vec<_>>(); - - let applicable_buffer_tasks = applicable_buffers - .into_iter() - .map(|(trigger_anchor, buffer, language)| { - let task = buffer.update(cx, |buffer, cx| { - project.update(cx, |project, cx| { - project.language_server_id_for_name(buffer, language_server_name, cx) - }) - }); - (trigger_anchor, buffer, language, task) - }) - .collect::<Vec<_>>(); - cx.background_spawn(async move { - for (trigger_anchor, buffer, language, task) in applicable_buffer_tasks { - if let Some(server_id) = task.await { - return Some((trigger_anchor, language, server_id, buffer)); - } - } - - None - }) } async fn 
lsp_task_context( @@ -116,9 +94,9 @@ pub fn lsp_tasks( for_position: Option<text::Anchor>, cx: &mut App, ) -> Task<Vec<(TaskSourceKind, Vec<(Option<LocationLink>, ResolvedTask)>)>> { - let mut lsp_task_sources = task_sources + let lsp_task_sources = task_sources .iter() - .map(|(name, buffer_ids)| { + .filter_map(|(name, buffer_ids)| { let buffers = buffer_ids .iter() .filter(|&&buffer_id| match for_position { @@ -127,61 +105,63 @@ pub fn lsp_tasks( }) .filter_map(|&buffer_id| project.read(cx).buffer_for_id(buffer_id, cx)) .collect::<Vec<_>>(); - language_server_for_buffers(project.clone(), name.clone(), buffers, cx) + + let server_id = buffers.iter().find_map(|buffer| { + project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), name, cx) + }) + }); + server_id.zip(Some(buffers)) }) - .collect::<FuturesUnordered<_>>(); + .collect::<Vec<_>>(); cx.spawn(async move |cx| { cx.spawn(async move |cx| { let mut lsp_tasks = HashMap::default(); - while let Some(server_to_query) = lsp_task_sources.next().await { - if let Some((server_id, buffers)) = server_to_query { - let mut new_lsp_tasks = Vec::new(); - for buffer in buffers { - let source_kind = match buffer.update(cx, |buffer, _| { - buffer.language().map(|language| language.name()) - }) { - Ok(Some(language_name)) => TaskSourceKind::Lsp { - server: server_id, - language_name: SharedString::from(language_name), - }, - Ok(None) => continue, - Err(_) => return Vec::new(), - }; - let id_base = source_kind.to_id_base(); - let lsp_buffer_context = lsp_task_context(&project, &buffer, cx) - .await - .unwrap_or_default(); + for (server_id, buffers) in lsp_task_sources { + let mut new_lsp_tasks = Vec::new(); + for buffer in buffers { + let source_kind = match buffer.update(cx, |buffer, _| { + buffer.language().map(|language| language.name()) + }) { + Ok(Some(language_name)) => TaskSourceKind::Lsp { + server: server_id, + language_name: SharedString::from(language_name), + }, + Ok(None) => continue, + Err(_) => return Vec::new(), + }; + let id_base = source_kind.to_id_base(); + let lsp_buffer_context = lsp_task_context(&project, &buffer, cx) + .await + .unwrap_or_default(); - if let Ok(runnables_task) = project.update(cx, |project, cx| { - let buffer_id = buffer.read(cx).remote_id(); - project.request_lsp( - buffer, - LanguageServerToQuery::Other(server_id), - GetLspRunnables { - buffer_id, - position: for_position, + if let Ok(runnables_task) = project.update(cx, |project, cx| { + let buffer_id = buffer.read(cx).remote_id(); + project.request_lsp( + buffer, + LanguageServerToQuery::Other(server_id), + GetLspRunnables { + buffer_id, + position: for_position, + }, + cx, + ) + }) { + if let Some(new_runnables) = runnables_task.await.log_err() { + new_lsp_tasks.extend(new_runnables.runnables.into_iter().filter_map( + |(location, runnable)| { + let resolved_task = + runnable.resolve_task(&id_base, &lsp_buffer_context)?; + Some((location, resolved_task)) }, - cx, - ) - }) { - if let Some(new_runnables) = runnables_task.await.log_err() { - new_lsp_tasks.extend( - new_runnables.runnables.into_iter().filter_map( - |(location, runnable)| { - let resolved_task = runnable - .resolve_task(&id_base, &lsp_buffer_context)?; - Some((location, resolved_task)) - }, - ), - ); - } + )); } - lsp_tasks - .entry(source_kind) - .or_insert_with(Vec::new) - .append(&mut new_lsp_tasks); } + lsp_tasks + .entry(source_kind) + .or_insert_with(Vec::new) + .append(&mut new_lsp_tasks); } } lsp_tasks.into_iter().collect() @@ -198,27 +178,3 @@ pub fn 
lsp_tasks( .await }) } - -fn language_server_for_buffers( - project: Entity<Project>, - name: LanguageServerName, - candidates: Vec<Entity<Buffer>>, - cx: &mut App, -) -> Task<Option<(LanguageServerId, Vec<Entity<Buffer>>)>> { - cx.spawn(async move |cx| { - for buffer in &candidates { - let server_id = buffer - .update(cx, |buffer, cx| { - project.update(cx, |project, cx| { - project.language_server_id_for_name(buffer, &name.0, cx) - }) - }) - .ok()? - .await; - if let Some(server_id) = server_id { - return Some((server_id, candidates)); - } - } - None - }) -} diff --git a/crates/editor/src/mouse_context_menu.rs b/crates/editor/src/mouse_context_menu.rs index cbb6791a2f..9d5145dec1 100644 --- a/crates/editor/src/mouse_context_menu.rs +++ b/crates/editor/src/mouse_context_menu.rs @@ -1,8 +1,8 @@ use crate::{ Copy, CopyAndTrim, CopyPermalinkToLine, Cut, DisplayPoint, DisplaySnapshot, Editor, EvaluateSelectedText, FindAllReferences, GoToDeclaration, GoToDefinition, GoToImplementation, - GoToTypeDefinition, Paste, Rename, RevealInFileManager, SelectMode, SelectionEffects, - SelectionExt, ToDisplayPoint, ToggleCodeActions, + GoToTypeDefinition, Paste, Rename, RevealInFileManager, RunToCursor, SelectMode, + SelectionEffects, SelectionExt, ToDisplayPoint, ToggleCodeActions, actions::{Format, FormatSelections}, selections_collection::SelectionsCollection, }; @@ -200,15 +200,21 @@ pub fn deploy_context_menu( }); let evaluate_selection = window.is_action_available(&EvaluateSelectedText, cx); + let run_to_cursor = window.is_action_available(&RunToCursor, cx); ui::ContextMenu::build(window, cx, |menu, _window, _cx| { let builder = menu .on_blur_subscription(Subscription::new(|| {})) - .when(evaluate_selection && has_selections, |builder| { - builder - .action("Evaluate Selection", Box::new(EvaluateSelectedText)) - .separator() + .when(run_to_cursor, |builder| { + builder.action("Run to Cursor", Box::new(RunToCursor)) }) + .when(evaluate_selection && has_selections, |builder| { + builder.action("Evaluate Selection", Box::new(EvaluateSelectedText)) + }) + .when( + run_to_cursor || (evaluate_selection && has_selections), + |builder| builder.separator(), + ) .action("Go to Definition", Box::new(GoToDefinition)) .action("Go to Declaration", Box::new(GoToDeclaration)) .action("Go to Type Definition", Box::new(GoToTypeDefinition)) diff --git a/crates/editor/src/movement.rs b/crates/editor/src/movement.rs index b9b7cb2e58..a8850984a1 100644 --- a/crates/editor/src/movement.rs +++ b/crates/editor/src/movement.rs @@ -907,12 +907,12 @@ mod tests { let inlays = (0..buffer_snapshot.len()) .flat_map(|offset| { [ - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut id), buffer_snapshot.anchor_at(offset, Bias::Left), "test", ), - Inlay::inline_completion( + Inlay::edit_prediction( post_inc(&mut id), buffer_snapshot.anchor_at(offset, Bias::Right), "test", diff --git a/crates/editor/src/rust_analyzer_ext.rs b/crates/editor/src/rust_analyzer_ext.rs index da0f11036f..2b8150de67 100644 --- a/crates/editor/src/rust_analyzer_ext.rs +++ b/crates/editor/src/rust_analyzer_ext.rs @@ -57,21 +57,21 @@ pub fn go_to_parent_module( return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); + let Some((trigger_anchor, _, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let lsp_store = 
project.read(cx).lsp_store(); let upstream_client = lsp_store.read(cx).upstream_client(); cx.spawn_in(window, async move |editor, cx| { - let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else { - return anyhow::Ok(()); - }; - let location_links = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id())?; @@ -121,7 +121,7 @@ pub fn go_to_parent_module( ) })? .await?; - Ok(()) + anyhow::Ok(()) }) .detach_and_log_err(cx); } @@ -139,21 +139,19 @@ pub fn expand_macro_recursively( return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); - + let Some((trigger_anchor, rust_language, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((trigger_anchor, rust_language, server_to_query, buffer)) = server_lookup.await - else { - return Ok(()); - }; - let macro_expansion = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id())?; let request = proto::LspExtExpandMacro { @@ -231,20 +229,20 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); + let Some((trigger_anchor, _, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else { - return Ok(()); - }; - let docs_urls = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id())?; let request = proto::LspExtOpenDocs { diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index d5db7f71a4..a0214c76a1 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -19,8 +19,8 @@ path = "src/explorer.rs" [dependencies] agent.workspace = true -agent_ui.workspace = true agent_settings.workspace = true +agent_ui.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true @@ -29,6 +29,7 @@ buffer_diff.workspace = true chrono.workspace = true clap.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true debug_adapter_extension.workspace = true dirs.workspace = true @@ -68,4 +69,3 @@ util.workspace = true uuid.workspace = true watch.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index a02b4a7f0b..d638ac171f 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -18,7 +18,7 @@ use collections::{HashMap, HashSet}; use extension::ExtensionHostProxy; use futures::future; use gpui::http_client::read_proxy_from_env; -use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal}; +use gpui::{App, AppContext, Application, AsyncApp, Entity, UpdateGlobal}; use 
gpui_tokio::Tokio; use language::LanguageRegistry; use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel}; @@ -337,7 +337,8 @@ pub struct AgentAppState { } pub fn init(cx: &mut App) -> Arc<AgentAppState> { - release_channel::init(SemanticVersion::default(), cx); + let app_version = AppVersion::global(cx); + release_channel::init(app_version, cx); gpui_tokio::init(cx); let mut settings_store = SettingsStore::new(cx); @@ -350,7 +351,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> { // Set User-Agent so we can download language servers from GitHub let user_agent = format!( "Zed/{} ({}; {})", - AppVersion::global(cx), + app_version, std::env::consts::OS, std::env::consts::ARCH ); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 7ce3b1fdf1..23c8814916 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -15,11 +15,11 @@ use agent_settings::AgentProfileId; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; +use cloud_llm_client::CompletionIntent; use collections::HashMap; use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; use gpui::{App, AppContext, AsyncApp, Entity}; use language_model::{LanguageModel, Role, StopReason}; -use zed_llm_client::CompletionIntent; pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); diff --git a/crates/extension/Cargo.toml b/crates/extension/Cargo.toml index 4fc7da2dca..42189f20b3 100644 --- a/crates/extension/Cargo.toml +++ b/crates/extension/Cargo.toml @@ -32,7 +32,11 @@ serde.workspace = true serde_json.workspace = true task.workspace = true toml.workspace = true +url.workspace = true util.workspace = true wasm-encoder.workspace = true wasmparser.workspace = true workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/extension/src/capabilities.rs b/crates/extension/src/capabilities.rs new file mode 100644 index 0000000000..b8afc4ec06 --- /dev/null +++ b/crates/extension/src/capabilities.rs @@ -0,0 +1,20 @@ +mod download_file_capability; +mod npm_install_package_capability; +mod process_exec_capability; + +pub use download_file_capability::*; +pub use npm_install_package_capability::*; +pub use process_exec_capability::*; + +use serde::{Deserialize, Serialize}; + +/// A capability for an extension. +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum ExtensionCapability { + #[serde(rename = "process:exec")] + ProcessExec(ProcessExecCapability), + DownloadFile(DownloadFileCapability), + #[serde(rename = "npm:install")] + NpmInstallPackage(NpmInstallPackageCapability), +} diff --git a/crates/extension/src/capabilities/download_file_capability.rs b/crates/extension/src/capabilities/download_file_capability.rs new file mode 100644 index 0000000000..a76755b593 --- /dev/null +++ b/crates/extension/src/capabilities/download_file_capability.rs @@ -0,0 +1,121 @@ +use serde::{Deserialize, Serialize}; +use url::Url; + +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct DownloadFileCapability { + pub host: String, + pub path: Vec<String>, +} + +impl DownloadFileCapability { + /// Returns whether the capability allows downloading a file from the given URL. 
+ pub fn allows(&self, url: &Url) -> bool { + let Some(desired_host) = url.host_str() else { + return false; + }; + + let Some(desired_path) = url.path_segments() else { + return false; + }; + let desired_path = desired_path.collect::<Vec<_>>(); + + if self.host != desired_host && self.host != "*" { + return false; + } + + for (ix, path_segment) in self.path.iter().enumerate() { + if path_segment == "**" { + return true; + } + + if ix >= desired_path.len() { + return false; + } + + if path_segment != "*" && path_segment != desired_path[ix] { + return false; + } + } + + if self.path.len() < desired_path.len() { + return false; + } + + true + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_allows() { + let capability = DownloadFileCapability { + host: "*".to_string(), + path: vec!["**".to_string()], + }; + assert_eq!( + capability.allows(&"https://example.com/some/path".parse().unwrap()), + true + ); + + let capability = DownloadFileCapability { + host: "github.com".to_string(), + path: vec!["**".to_string()], + }; + assert_eq!( + capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()), + true + ); + assert_eq!( + capability.allows( + &"https://fake-github.com/some-owner/some-repo" + .parse() + .unwrap() + ), + false + ); + + let capability = DownloadFileCapability { + host: "github.com".to_string(), + path: vec!["specific-owner".to_string(), "*".to_string()], + }; + assert_eq!( + capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()), + false + ); + assert_eq!( + capability.allows( + &"https://github.com/specific-owner/some-repo" + .parse() + .unwrap() + ), + true + ); + + let capability = DownloadFileCapability { + host: "github.com".to_string(), + path: vec!["specific-owner".to_string(), "*".to_string()], + }; + assert_eq!( + capability.allows( + &"https://github.com/some-owner/some-repo/extra" + .parse() + .unwrap() + ), + false + ); + assert_eq!( + capability.allows( + &"https://github.com/specific-owner/some-repo/extra" + .parse() + .unwrap() + ), + false + ); + } +} diff --git a/crates/extension/src/capabilities/npm_install_package_capability.rs b/crates/extension/src/capabilities/npm_install_package_capability.rs new file mode 100644 index 0000000000..287645fc75 --- /dev/null +++ b/crates/extension/src/capabilities/npm_install_package_capability.rs @@ -0,0 +1,39 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct NpmInstallPackageCapability { + pub package: String, +} + +impl NpmInstallPackageCapability { + /// Returns whether the capability allows installing the given NPM package. 
+ pub fn allows(&self, package: &str) -> bool { + self.package == "*" || self.package == package + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_allows() { + let capability = NpmInstallPackageCapability { + package: "*".to_string(), + }; + assert_eq!(capability.allows("package"), true); + + let capability = NpmInstallPackageCapability { + package: "react".to_string(), + }; + assert_eq!(capability.allows("react"), true); + + let capability = NpmInstallPackageCapability { + package: "react".to_string(), + }; + assert_eq!(capability.allows("malicious-package"), false); + } +} diff --git a/crates/extension/src/capabilities/process_exec_capability.rs b/crates/extension/src/capabilities/process_exec_capability.rs new file mode 100644 index 0000000000..053a7b212b --- /dev/null +++ b/crates/extension/src/capabilities/process_exec_capability.rs @@ -0,0 +1,116 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct ProcessExecCapability { + /// The command to execute. + pub command: String, + /// The arguments to pass to the command. Use `*` for a single wildcard argument. + /// If the last element is `**`, then any trailing arguments are allowed. + pub args: Vec<String>, +} + +impl ProcessExecCapability { + /// Returns whether the capability allows the given command and arguments. + pub fn allows( + &self, + desired_command: &str, + desired_args: &[impl AsRef<str> + std::fmt::Debug], + ) -> bool { + if self.command != desired_command && self.command != "*" { + return false; + } + + for (ix, arg) in self.args.iter().enumerate() { + if arg == "**" { + return true; + } + + if ix >= desired_args.len() { + return false; + } + + if arg != "*" && arg != desired_args[ix].as_ref() { + return false; + } + } + + if self.args.len() < desired_args.len() { + return false; + } + + true + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_allows_with_exact_match() { + let capability = ProcessExecCapability { + command: "ls".to_string(), + args: vec!["-la".to_string()], + }; + + assert_eq!(capability.allows("ls", &["-la"]), true); + assert_eq!(capability.allows("ls", &["-l"]), false); + assert_eq!(capability.allows("pwd", &[] as &[&str]), false); + } + + #[test] + fn test_allows_with_wildcard_arg() { + let capability = ProcessExecCapability { + command: "git".to_string(), + args: vec!["*".to_string()], + }; + + assert_eq!(capability.allows("git", &["status"]), true); + assert_eq!(capability.allows("git", &["commit"]), true); + // Too many args. + assert_eq!(capability.allows("git", &["status", "-s"]), false); + // Wrong command. + assert_eq!(capability.allows("npm", &["install"]), false); + } + + #[test] + fn test_allows_with_double_wildcard() { + let capability = ProcessExecCapability { + command: "cargo".to_string(), + args: vec!["test".to_string(), "**".to_string()], + }; + + assert_eq!(capability.allows("cargo", &["test"]), true); + assert_eq!(capability.allows("cargo", &["test", "--all"]), true); + assert_eq!( + capability.allows("cargo", &["test", "--all", "--no-fail-fast"]), + true + ); + // Wrong first arg. 
+ assert_eq!(capability.allows("cargo", &["build"]), false); + } + + #[test] + fn test_allows_with_mixed_wildcards() { + let capability = ProcessExecCapability { + command: "docker".to_string(), + args: vec!["run".to_string(), "*".to_string(), "**".to_string()], + }; + + assert_eq!(capability.allows("docker", &["run", "nginx"]), true); + assert_eq!(capability.allows("docker", &["run"]), false); + assert_eq!( + capability.allows("docker", &["run", "ubuntu", "bash"]), + true + ); + assert_eq!( + capability.allows("docker", &["run", "alpine", "sh", "-c", "echo hello"]), + true + ); + // Wrong first arg. + assert_eq!(capability.allows("docker", &["ps"]), false); + } +} diff --git a/crates/extension/src/extension.rs b/crates/extension/src/extension.rs index 8b150e19b9..35f7f41938 100644 --- a/crates/extension/src/extension.rs +++ b/crates/extension/src/extension.rs @@ -1,3 +1,4 @@ +mod capabilities; pub mod extension_builder; mod extension_events; mod extension_host_proxy; @@ -16,6 +17,7 @@ use language::LanguageName; use semantic_version::SemanticVersion; use task::{SpawnInTerminal, ZedDebugConfig}; +pub use crate::capabilities::*; pub use crate::extension_events::*; pub use crate::extension_host_proxy::*; pub use crate::extension_manifest::*; diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 0a14923c0c..5852b3e3fc 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -12,6 +12,8 @@ use std::{ sync::Arc, }; +use crate::ExtensionCapability; + /// This is the old version of the extension manifest, from when it was `extension.json`. #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct OldExtensionManifest { @@ -100,24 +102,8 @@ impl ExtensionManifest { desired_args: &[impl AsRef<str> + std::fmt::Debug], ) -> Result<()> { let is_allowed = self.capabilities.iter().any(|capability| match capability { - ExtensionCapability::ProcessExec { command, args } if command == desired_command => { - for (ix, arg) in args.iter().enumerate() { - if arg == "**" { - return true; - } - - if ix >= desired_args.len() { - return false; - } - - if arg != "*" && arg != desired_args[ix].as_ref() { - return false; - } - } - if args.len() < desired_args.len() { - return false; - } - true + ExtensionCapability::ProcessExec(capability) => { + capability.allows(desired_command, desired_args) } _ => false, }); @@ -148,20 +134,6 @@ pub fn build_debug_adapter_schema_path( }) } -/// A capability for an extension. -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -#[serde(tag = "kind")] -pub enum ExtensionCapability { - #[serde(rename = "process:exec")] - ProcessExec { - /// The command to execute. - command: String, - /// The arguments to pass to the command. Use `*` for a single wildcard argument. - /// If the last element is `**`, then any trailing arguments are allowed. 
-        args: Vec<String>,
-    },
-}
-
 #[derive(Clone, Default, PartialEq, Eq, Debug, Deserialize, Serialize)]
 pub struct LibManifestEntry {
     pub kind: Option<ExtensionLibraryKind>,
@@ -191,7 +163,7 @@ pub struct LanguageServerManifestEntry {
     #[serde(default)]
     languages: Vec<LanguageName>,
     #[serde(default)]
-    pub language_ids: HashMap<String, String>,
+    pub language_ids: HashMap<LanguageName, String>,
     #[serde(default)]
     pub code_action_kinds: Option<Vec<lsp::CodeActionKind>>,
 }
@@ -309,6 +281,10 @@ fn manifest_from_old_manifest(
 
 #[cfg(test)]
 mod tests {
+    use pretty_assertions::assert_eq;
+
+    use crate::ProcessExecCapability;
+
     use super::*;
 
     fn extension_manifest() -> ExtensionManifest {
@@ -360,12 +336,12 @@ mod tests {
     }
 
     #[test]
-    fn test_allow_exact_match() {
+    fn test_allow_exec_exact_match() {
         let manifest = ExtensionManifest {
-            capabilities: vec![ExtensionCapability::ProcessExec {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
                 command: "ls".to_string(),
                 args: vec!["-la".to_string()],
-            }],
+            })],
             ..extension_manifest()
         };
 
@@ -375,12 +351,12 @@ mod tests {
     }
 
     #[test]
-    fn test_allow_wildcard_arg() {
+    fn test_allow_exec_wildcard_arg() {
         let manifest = ExtensionManifest {
-            capabilities: vec![ExtensionCapability::ProcessExec {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
                 command: "git".to_string(),
                 args: vec!["*".to_string()],
-            }],
+            })],
             ..extension_manifest()
         };
 
@@ -391,12 +367,12 @@ mod tests {
     }
 
     #[test]
-    fn test_allow_double_wildcard() {
+    fn test_allow_exec_double_wildcard() {
         let manifest = ExtensionManifest {
-            capabilities: vec![ExtensionCapability::ProcessExec {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
                 command: "cargo".to_string(),
                 args: vec!["test".to_string(), "**".to_string()],
-            }],
+            })],
             ..extension_manifest()
         };
 
@@ -411,12 +387,12 @@ mod tests {
     }
 
     #[test]
-    fn test_allow_mixed_wildcards() {
+    fn test_allow_exec_mixed_wildcards() {
         let manifest = ExtensionManifest {
-            capabilities: vec![ExtensionCapability::ProcessExec {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
                 command: "docker".to_string(),
                 args: vec!["run".to_string(), "*".to_string(), "**".to_string()],
-            }],
+            })],
             ..extension_manifest()
         };
 
diff --git a/crates/extension_host/benches/extension_compilation_benchmark.rs b/crates/extension_host/benches/extension_compilation_benchmark.rs
index 9d867af041..a4fa9bfeff 100644
--- a/crates/extension_host/benches/extension_compilation_benchmark.rs
+++ b/crates/extension_host/benches/extension_compilation_benchmark.rs
@@ -134,10 +134,12 @@ fn manifest() -> ExtensionManifest {
         slash_commands: BTreeMap::default(),
         indexed_docs_providers: BTreeMap::default(),
         snippets: None,
-        capabilities: vec![ExtensionCapability::ProcessExec {
-            command: "echo".into(),
-            args: vec!["hello!".into()],
-        }],
+        capabilities: vec![ExtensionCapability::ProcessExec(
+            extension::ProcessExecCapability {
+                command: "echo".into(),
+                args: vec!["hello!".into()],
+            },
+        )],
         debug_adapters: Default::default(),
         debug_locators: Default::default(),
     }
diff --git a/crates/extension_host/src/capability_granter.rs b/crates/extension_host/src/capability_granter.rs
new file mode 100644
index 0000000000..c77e5ecba1
--- /dev/null
+++ b/crates/extension_host/src/capability_granter.rs
@@ -0,0 +1,153 @@
+use std::sync::Arc;
+
+use anyhow::{Result, bail};
+use extension::{ExtensionCapability, ExtensionManifest};
+use url::Url;
+
+pub struct CapabilityGranter {
+    granted_capabilities: Vec<ExtensionCapability>,
+    manifest: Arc<ExtensionManifest>,
+}
+
+impl CapabilityGranter {
+    pub fn new(
+        granted_capabilities: Vec<ExtensionCapability>,
+        manifest: Arc<ExtensionManifest>,
+    ) -> Self {
+        Self {
+            granted_capabilities,
+            manifest,
+        }
+    }
+
+    pub fn grant_exec(
+        &self,
+        desired_command: &str,
+        desired_args: &[impl AsRef<str> + std::fmt::Debug],
+    ) -> Result<()> {
+        self.manifest.allow_exec(desired_command, desired_args)?;
+
+        let is_allowed = self
+            .granted_capabilities
+            .iter()
+            .any(|capability| match capability {
+                ExtensionCapability::ProcessExec(capability) => {
+                    capability.allows(desired_command, desired_args)
+                }
+                _ => false,
+            });
+
+        if !is_allowed {
+            bail!(
+                "capability for process:exec {desired_command} {desired_args:?} is not granted by the extension host",
+            );
+        }
+
+        Ok(())
+    }
+
+    pub fn grant_download_file(&self, desired_url: &Url) -> Result<()> {
+        let is_allowed = self
+            .granted_capabilities
+            .iter()
+            .any(|capability| match capability {
+                ExtensionCapability::DownloadFile(capability) => capability.allows(desired_url),
+                _ => false,
+            });
+
+        if !is_allowed {
+            bail!(
+                "capability for download_file {desired_url} is not granted by the extension host",
+            );
+        }
+
+        Ok(())
+    }
+
+    pub fn grant_npm_install_package(&self, package_name: &str) -> Result<()> {
+        let is_allowed = self
+            .granted_capabilities
+            .iter()
+            .any(|capability| match capability {
+                ExtensionCapability::NpmInstallPackage(capability) => {
+                    capability.allows(package_name)
+                }
+                _ => false,
+            });
+
+        if !is_allowed {
+            bail!("capability for npm:install {package_name} is not granted by the extension host",);
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::BTreeMap;
+
+    use extension::{ProcessExecCapability, SchemaVersion};
+
+    use super::*;
+
+    fn extension_manifest() -> ExtensionManifest {
+        ExtensionManifest {
+            id: "test".into(),
+            name: "Test".to_string(),
+            version: "1.0.0".into(),
+            schema_version: SchemaVersion::ZERO,
+            description: None,
+            repository: None,
+            authors: vec![],
+            lib: Default::default(),
+            themes: vec![],
+            icon_themes: vec![],
+            languages: vec![],
+            grammars: BTreeMap::default(),
+            language_servers: BTreeMap::default(),
+            context_servers: BTreeMap::default(),
+            slash_commands: BTreeMap::default(),
+            indexed_docs_providers: BTreeMap::default(),
+            snippets: None,
+            capabilities: vec![],
+            debug_adapters: Default::default(),
+            debug_locators: Default::default(),
+        }
+    }
+
+    #[test]
+    fn test_grant_exec() {
+        let manifest = Arc::new(ExtensionManifest {
+            capabilities: vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
+                command: "ls".to_string(),
+                args: vec!["-la".to_string()],
+            })],
+            ..extension_manifest()
+        });
+
+        // It returns an error when the extension host has no granted capabilities.
+        let granter = CapabilityGranter::new(Vec::new(), manifest.clone());
+        assert!(granter.grant_exec("ls", &["-la"]).is_err());
+
+        // It succeeds when the extension host has the exact capability.
+        let granter = CapabilityGranter::new(
+            vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
+                command: "ls".to_string(),
+                args: vec!["-la".to_string()],
+            })],
+            manifest.clone(),
+        );
+        assert!(granter.grant_exec("ls", &["-la"]).is_ok());
+
+        // It succeeds when the extension host has a wildcard capability.
+        let granter = CapabilityGranter::new(
+            vec![ExtensionCapability::ProcessExec(ProcessExecCapability {
+                command: "*".to_string(),
+                args: vec!["**".to_string()],
+            })],
+            manifest.clone(),
+        );
+        assert!(granter.grant_exec("ls", &["-la"]).is_ok());
+    }
+}
diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs
index fd64d3fa59..dc38c244f1 100644
--- a/crates/extension_host/src/extension_host.rs
+++ b/crates/extension_host/src/extension_host.rs
@@ -1,3 +1,4 @@
+mod capability_granter;
 pub mod extension_settings;
 pub mod headless_host;
 pub mod wasm_host;
diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs
index 891ab91852..c31774c20d 100644
--- a/crates/extension_host/src/extension_store_test.rs
+++ b/crates/extension_host/src/extension_store_test.rs
@@ -10,7 +10,7 @@ use fs::{FakeFs, Fs, RealFs};
 use futures::{AsyncReadExt, StreamExt, io::BufReader};
 use gpui::{AppContext as _, SemanticVersion, TestAppContext};
 use http_client::{FakeHttpClient, Response};
-use language::{BinaryStatus, LanguageMatcher, LanguageRegistry};
+use language::{BinaryStatus, LanguageMatcher, LanguageName, LanguageRegistry};
 use language_extension::LspAccess;
 use lsp::LanguageServerName;
 use node_runtime::NodeRuntime;
@@ -306,7 +306,11 @@ async fn test_extension_store(cx: &mut TestAppContext) {
 
     assert_eq!(
         language_registry.language_names(),
-        ["ERB", "Plain Text", "Ruby"]
+        [
+            LanguageName::new("ERB"),
+            LanguageName::new("Plain Text"),
+            LanguageName::new("Ruby"),
+        ]
     );
     assert_eq!(
         theme_registry.list_names(),
@@ -458,7 +462,11 @@ async fn test_extension_store(cx: &mut TestAppContext) {
 
     assert_eq!(
         language_registry.language_names(),
-        ["ERB", "Plain Text", "Ruby"]
+        [
+            LanguageName::new("ERB"),
+            LanguageName::new("Plain Text"),
+            LanguageName::new("Ruby"),
+        ]
     );
     assert_eq!(
         language_registry.grammar_names(),
@@ -513,7 +521,10 @@ async fn test_extension_store(cx: &mut TestAppContext) {
             assert_eq!(actual_language.hidden, expected_language.hidden);
         }
 
-        assert_eq!(language_registry.language_names(), ["Plain Text"]);
+        assert_eq!(
+            language_registry.language_names(),
+            [LanguageName::new("Plain Text")]
+        );
         assert_eq!(language_registry.grammar_names(), []);
     });
 }
diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs
index 3e0f06fa38..d990b670f4 100644
--- a/crates/extension_host/src/wasm_host.rs
+++ b/crates/extension_host/src/wasm_host.rs
@@ -1,13 +1,15 @@
 pub mod wit;
 
 use crate::ExtensionManifest;
+use crate::capability_granter::CapabilityGranter;
 use anyhow::{Context as _, Result, anyhow, bail};
 use async_trait::async_trait;
 use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
 use extension::{
     CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
-    DebugTaskDefinition, ExtensionHostProxy, KeyValueStoreDelegate, ProjectDelegate, SlashCommand,
-    SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
+    DebugTaskDefinition, DownloadFileCapability, ExtensionCapability, ExtensionHostProxy,
+    KeyValueStoreDelegate, NpmInstallPackageCapability, ProcessExecCapability, ProjectDelegate,
+    SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
 };
 use fs::{Fs, normalize_path};
 use futures::future::LocalBoxFuture;
@@ -50,6 +52,8 @@ pub struct WasmHost {
     pub(crate) proxy: Arc<ExtensionHostProxy>,
     fs: Arc<dyn Fs>,
     pub work_dir: PathBuf,
+    /// The capabilities granted to extensions
running on the host. + pub(crate) granted_capabilities: Vec<ExtensionCapability>, _main_thread_message_task: Task<()>, main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>, } @@ -102,7 +106,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_initialization_options( @@ -127,7 +131,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_workspace_configuration( @@ -150,7 +154,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_additional_initialization_options( @@ -175,7 +179,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn language_server_additional_workspace_configuration( @@ -200,7 +204,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn labels_for_completions( @@ -226,7 +230,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn labels_for_symbols( @@ -252,7 +256,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn complete_slash_command_argument( @@ -271,7 +275,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn run_slash_command( @@ -297,7 +301,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn context_server_command( @@ -316,7 +320,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn context_server_configuration( @@ -343,7 +347,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> { @@ -358,7 +362,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn index_docs( @@ -384,7 +388,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn get_dap_binary( @@ -406,7 +410,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_request_kind( &self, @@ -423,7 +427,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_config_to_scenario(&self, config: ZedDebugConfig) -> Result<DebugScenario> { @@ -437,7 +441,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn dap_locator_create_scenario( @@ -461,7 +465,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? } async fn run_dap_locator( &self, @@ -477,7 +481,7 @@ impl extension::Extension for WasmExtension { } .boxed() }) - .await + .await? 
} } @@ -486,6 +490,7 @@ pub struct WasmState { pub table: ResourceTable, ctx: wasi::WasiCtx, pub host: Arc<WasmHost>, + pub(crate) capability_granter: CapabilityGranter, } type MainThreadCall = Box<dyn Send + for<'a> FnOnce(&'a mut AsyncApp) -> LocalBoxFuture<'a, ()>>; @@ -571,6 +576,19 @@ impl WasmHost { node_runtime, proxy, release_channel: ReleaseChannel::global(cx), + granted_capabilities: vec![ + ExtensionCapability::ProcessExec(ProcessExecCapability { + command: "*".to_string(), + args: vec!["**".to_string()], + }), + ExtensionCapability::DownloadFile(DownloadFileCapability { + host: "*".to_string(), + path: vec!["**".to_string()], + }), + ExtensionCapability::NpmInstallPackage(NpmInstallPackageCapability { + package: "*".to_string(), + }), + ], _main_thread_message_task: task, main_thread_message_tx: tx, }) @@ -597,6 +615,10 @@ impl WasmHost { manifest: manifest.clone(), table: ResourceTable::new(), host: this.clone(), + capability_granter: CapabilityGranter::new( + this.granted_capabilities.clone(), + manifest.clone(), + ), }, ); // Store will yield after 1 tick, and get a new deadline of 1 tick after each yield. @@ -739,7 +761,7 @@ impl WasmExtension { .with_context(|| format!("failed to load wasm extension {}", manifest.id)) } - pub async fn call<T, Fn>(&self, f: Fn) -> T + pub async fn call<T, Fn>(&self, f: Fn) -> Result<T> where T: 'static + Send, Fn: 'static @@ -755,8 +777,19 @@ impl WasmExtension { } .boxed() })) - .expect("wasm extension channel should not be closed yet"); - return_rx.await.expect("wasm extension channel") + .map_err(|_| { + anyhow!( + "wasm extension channel should not be closed yet, extension {} (id {})", + self.manifest.name, + self.manifest.id, + ) + })?; + return_rx.await.with_context(|| { + format!( + "wasm extension channel, extension {} (id {})", + self.manifest.name, self.manifest.id, + ) + }) } } @@ -777,8 +810,19 @@ impl WasmState { } .boxed_local() })) - .expect("main thread message channel should not be closed yet"); - async move { return_rx.await.expect("main thread message channel") } + .unwrap_or_else(|_| { + panic!( + "main thread message channel should not be closed yet, extension {} (id {})", + self.manifest.name, self.manifest.id, + ) + }); + let name = self.manifest.name.clone(); + let id = self.manifest.id.clone(); + async move { + return_rx.await.unwrap_or_else(|_| { + panic!("main thread message channel, extension {name} (id {id})") + }) + } } fn work_dir(&self) -> PathBuf { diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs index d25328ee7f..767b9033ad 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs @@ -30,6 +30,7 @@ use std::{ sync::{Arc, OnceLock}, }; use task::{SpawnInTerminal, ZedDebugConfig}; +use url::Url; use util::{archive::extract_zip, fs::make_file_executable, maybe}; use wasmtime::component::{Linker, Resource}; @@ -744,6 +745,9 @@ impl nodejs::Host for WasmState { package_name: String, version: String, ) -> wasmtime::Result<Result<(), String>> { + self.capability_granter + .grant_npm_install_package(&package_name)?; + self.host .node_runtime .npm_install_packages(&self.work_dir(), &[(&package_name, &version)]) @@ -847,7 +851,8 @@ impl process::Host for WasmState { command: process::Command, ) -> wasmtime::Result<Result<process::Output, String>> { maybe!(async { - self.manifest.allow_exec(&command.command, &command.args)?; + self.capability_granter + 
.grant_exec(&command.command, &command.args)?; let output = util::command::new_smol_command(command.command.as_str()) .args(&command.args) @@ -1010,6 +1015,9 @@ impl ExtensionImports for WasmState { file_type: DownloadedFileType, ) -> wasmtime::Result<Result<(), String>> { maybe!(async { + let parsed_url = Url::parse(&url)?; + self.capability_granter.grant_download_file(&parsed_url)?; + let path = PathBuf::from(path); let extension_work_dir = self.host.work_dir.join(self.manifest.id.as_ref()); diff --git a/crates/file_finder/src/file_finder.rs b/crates/file_finder/src/file_finder.rs index a4d61dd56f..e5ac70bb58 100644 --- a/crates/file_finder/src/file_finder.rs +++ b/crates/file_finder/src/file_finder.rs @@ -1404,14 +1404,21 @@ impl PickerDelegate for FileFinderDelegate { } else { let path_position = PathWithPosition::parse_str(&raw_query); + #[cfg(windows)] + let raw_query = raw_query.trim().to_owned().replace("/", "\\"); + #[cfg(not(windows))] + let raw_query = raw_query.trim().to_owned(); + + let file_query_end = if path_position.path.to_str().unwrap_or(&raw_query) == raw_query { + None + } else { + // Safe to unwrap as we won't get here when the unwrap in if fails + Some(path_position.path.to_str().unwrap().len()) + }; + let query = FileSearchQuery { - raw_query: raw_query.trim().to_owned(), - file_query_end: if path_position.path.to_str().unwrap_or(raw_query) == raw_query { - None - } else { - // Safe to unwrap as we won't get here when the unwrap in if fails - Some(path_position.path.to_str().unwrap().len()) - }, + raw_query, + file_query_end, path_position, }; diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 8a4f7c03bb..04ba656232 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -10,7 +10,7 @@ use git::{ }, status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus}, }; -use gpui::{AsyncApp, BackgroundExecutor}; +use gpui::{AsyncApp, BackgroundExecutor, SharedString}; use ignore::gitignore::GitignoreBuilder; use rope::Rope; use smol::future::FutureExt as _; @@ -398,6 +398,18 @@ impl GitRepository for FakeGitRepository { }) } + fn stash_paths( + &self, + _paths: Vec<RepoPath>, + _env: Arc<HashMap<String, String>>, + ) -> BoxFuture<Result<()>> { + unimplemented!() + } + + fn stash_pop(&self, _env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>> { + unimplemented!() + } + fn commit( &self, _message: gpui::SharedString, @@ -479,4 +491,8 @@ impl GitRepository for FakeGitRepository { ) -> BoxFuture<'_, Result<String>> { unimplemented!() } + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>> { + unimplemented!() + } } diff --git a/crates/fuzzy/src/matcher.rs b/crates/fuzzy/src/matcher.rs index aff6390534..e649d47dd6 100644 --- a/crates/fuzzy/src/matcher.rs +++ b/crates/fuzzy/src/matcher.rs @@ -208,8 +208,15 @@ impl<'a> Matcher<'a> { return 1.0; } - let path_len = prefix.len() + path.len(); + let limit = self.last_positions[query_idx]; + let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1); + let safe_limit = limit.min(max_valid_index); + if path_idx > safe_limit { + return 0.0; + } + + let path_len = prefix.len() + path.len(); if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] { return memoized; } @@ -218,16 +225,13 @@ impl<'a> Matcher<'a> { let mut best_position = 0; let query_char = self.lowercase_query[query_idx]; - let limit = self.last_positions[query_idx]; - - let max_valid_index = (prefix.len() + 
path_lowercased.len()).saturating_sub(1); - let safe_limit = limit.min(max_valid_index); let mut last_slash = 0; + for j in path_idx..=safe_limit { let extra_lowercase_chars_count = extra_lowercase_chars .iter() - .take_while(|(i, _)| i < &&j) + .take_while(|&(&i, _)| i < j) .map(|(_, increment)| increment) .sum::<usize>(); let j_regular = j - extra_lowercase_chars_count; @@ -236,10 +240,9 @@ impl<'a> Matcher<'a> { lowercase_prefix[j] } else { let path_index = j - prefix.len(); - if path_index < path_lowercased.len() { - path_lowercased[path_index] - } else { - continue; + match path_lowercased.get(path_index) { + Some(&char) => char, + None => continue, } }; let is_path_sep = path_char == MAIN_SEPARATOR; @@ -255,18 +258,16 @@ impl<'a> Matcher<'a> { #[cfg(target_os = "windows")] let need_to_score = query_char == path_char || (is_path_sep && query_char == '_'); if need_to_score { - let curr = if j_regular < prefix.len() { - prefix[j_regular] - } else { - path[j_regular - prefix.len()] + let curr = match prefix.get(j_regular) { + Some(&curr) => curr, + None => path[j_regular - prefix.len()], }; let mut char_score = 1.0; if j > path_idx { - let last = if j_regular - 1 < prefix.len() { - prefix[j_regular - 1] - } else { - path[j_regular - 1 - prefix.len()] + let last = match prefix.get(j_regular - 1) { + Some(&last) => last, + None => path[j_regular - 1 - prefix.len()], }; if last == MAIN_SEPARATOR { diff --git a/crates/git/src/git.rs b/crates/git/src/git.rs index 3714086dd0..553361e673 100644 --- a/crates/git/src/git.rs +++ b/crates/git/src/git.rs @@ -55,6 +55,10 @@ actions!( StageAll, /// Unstages all changes in the repository. UnstageAll, + /// Stashes all changes in the repository, including untracked files. + StashAll, + /// Pops the most recent stash. + StashPop, /// Restores all tracked files to their last committed state. RestoreTrackedFiles, /// Moves all untracked files to trash. diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index 9cc3442392..b536bed710 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -395,6 +395,14 @@ pub trait GitRepository: Send + Sync { env: Arc<HashMap<String, String>>, ) -> BoxFuture<'_, Result<()>>; + fn stash_paths( + &self, + paths: Vec<RepoPath>, + env: Arc<HashMap<String, String>>, + ) -> BoxFuture<Result<()>>; + + fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>>; + fn push( &self, branch_name: String, @@ -455,6 +463,8 @@ pub trait GitRepository: Send + Sync { base_checkpoint: GitRepositoryCheckpoint, target_checkpoint: GitRepositoryCheckpoint, ) -> BoxFuture<'_, Result<String>>; + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>>; } pub enum DiffType { @@ -1189,6 +1199,55 @@ impl GitRepository for RealGitRepository { .boxed() } + fn stash_paths( + &self, + paths: Vec<RepoPath>, + env: Arc<HashMap<String, String>>, + ) -> BoxFuture<Result<()>> { + let working_directory = self.working_directory(); + self.executor + .spawn(async move { + let mut cmd = new_smol_command("git"); + cmd.current_dir(&working_directory?) 
+ .envs(env.iter()) + .args(["stash", "push", "--quiet"]) + .arg("--include-untracked"); + + cmd.args(paths.iter().map(|p| p.as_ref())); + + let output = cmd.output().await?; + + anyhow::ensure!( + output.status.success(), + "Failed to stash:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + Ok(()) + }) + .boxed() + } + + fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<Result<()>> { + let working_directory = self.working_directory(); + self.executor + .spawn(async move { + let mut cmd = new_smol_command("git"); + cmd.current_dir(&working_directory?) + .envs(env.iter()) + .args(["stash", "pop"]); + + let output = cmd.output().await?; + + anyhow::ensure!( + output.status.success(), + "Failed to stash pop:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + Ok(()) + }) + .boxed() + } + fn commit( &self, message: SharedString, @@ -1550,6 +1609,37 @@ impl GitRepository for RealGitRepository { }) .boxed() } + + fn default_branch(&self) -> BoxFuture<'_, Result<Option<SharedString>>> { + let working_directory = self.working_directory(); + let git_binary_path = self.git_binary_path.clone(); + + let executor = self.executor.clone(); + self.executor + .spawn(async move { + let working_directory = working_directory?; + let git = GitBinary::new(git_binary_path, working_directory, executor); + + if let Ok(output) = git + .run(&["symbolic-ref", "refs/remotes/upstream/HEAD"]) + .await + { + let output = output + .strip_prefix("refs/remotes/upstream/") + .map(|s| SharedString::from(s.to_owned())); + return Ok(output); + } + + let output = git + .run(&["symbolic-ref", "refs/remotes/origin/HEAD"]) + .await?; + + Ok(output + .strip_prefix("refs/remotes/origin/") + .map(|s| SharedString::from(s.to_owned()))) + }) + .boxed() + } } fn git_status_args(path_prefixes: &[RepoPath]) -> Vec<OsString> { diff --git a/crates/git_hosting_providers/src/providers/github.rs b/crates/git_hosting_providers/src/providers/github.rs index 649b2f30ae..30f8d058a7 100644 --- a/crates/git_hosting_providers/src/providers/github.rs +++ b/crates/git_hosting_providers/src/providers/github.rs @@ -159,7 +159,11 @@ impl GitHostingProvider for Github { } let mut path_segments = url.path_segments()?; - let owner = path_segments.next()?; + let mut owner = path_segments.next()?; + if owner.is_empty() { + owner = path_segments.next()?; + } + let repo = path_segments.next()?.trim_end_matches(".git"); Some(ParsedGitRemote { @@ -244,6 +248,22 @@ mod tests { use super::*; + #[test] + fn test_remote_url_with_root_slash() { + let remote_url = "git@github.com:/zed-industries/zed"; + let parsed_remote = Github::public_instance() + .parse_remote_url(remote_url) + .unwrap(); + + assert_eq!( + parsed_remote, + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed".into(), + } + ); + } + #[test] fn test_invalid_self_hosted_remote_url() { let remote_url = "git@github.com:zed-industries/zed.git"; diff --git a/crates/git_ui/Cargo.toml b/crates/git_ui/Cargo.toml index 2fb80b7e73..35f7a60354 100644 --- a/crates/git_ui/Cargo.toml +++ b/crates/git_ui/Cargo.toml @@ -23,7 +23,7 @@ askpass.workspace = true buffer_diff.workspace = true call.workspace = true chrono.workspace = true -client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true component.workspace = true @@ -62,7 +62,6 @@ watch.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true [target.'cfg(windows)'.dependencies] 
windows.workspace = true @@ -71,6 +70,7 @@ windows.workspace = true ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } +indoc.workspace = true pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 9eac3ce5af..b74fa649b0 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -13,7 +13,7 @@ use project::git_store::Repository; use std::sync::Arc; use time::OffsetDateTime; use time_format::format_local_timestamp; -use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*}; +use ui::{HighlightedLabel, ListItem, ListItemSpacing, Tooltip, prelude::*}; use util::ResultExt; use workspace::notifications::DetachAndPromptErr; use workspace::{ModalView, Workspace}; @@ -90,11 +90,21 @@ impl BranchList { let all_branches_request = repository .clone() .map(|repository| repository.update(cx, |repository, _| repository.branches())); + let default_branch_request = repository + .clone() + .map(|repository| repository.update(cx, |repository, _| repository.default_branch())); cx.spawn_in(window, async move |this, cx| { let mut all_branches = all_branches_request .context("No active repository")? .await??; + let default_branch = default_branch_request + .context("No active repository")? + .await + .map(Result::ok) + .ok() + .flatten() + .flatten(); let all_branches = cx .background_spawn(async move { @@ -124,6 +134,7 @@ impl BranchList { this.update_in(cx, |this, window, cx| { this.picker.update(cx, |picker, cx| { + picker.delegate.default_branch = default_branch; picker.delegate.all_branches = Some(all_branches); picker.refresh(window, cx); }) @@ -169,6 +180,7 @@ impl Focusable for BranchList { impl Render for BranchList { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { v_flex() + .key_context("GitBranchSelector") .w(self.width) .on_modifiers_changed(cx.listener(Self::handle_modifiers_changed)) .child(self.picker.clone()) @@ -192,6 +204,7 @@ struct BranchEntry { pub struct BranchListDelegate { matches: Vec<BranchEntry>, all_branches: Option<Vec<Branch>>, + default_branch: Option<SharedString>, repo: Option<Entity<Repository>>, style: BranchListStyle, selected_index: usize, @@ -206,6 +219,7 @@ impl BranchListDelegate { repo, style, all_branches: None, + default_branch: None, selected_index: 0, last_query: Default::default(), modifiers: Default::default(), @@ -214,6 +228,7 @@ impl BranchListDelegate { fn create_branch( &self, + from_branch: Option<SharedString>, new_branch_name: SharedString, window: &mut Window, cx: &mut Context<Picker<Self>>, @@ -223,6 +238,11 @@ impl BranchListDelegate { }; let new_branch_name = new_branch_name.to_string().replace(' ', "-"); cx.spawn(async move |_, cx| { + if let Some(based_branch) = from_branch { + repo.update(cx, |repo, _| repo.change_branch(based_branch.to_string()))? + .await??; + } + repo.update(cx, |repo, _| { repo.create_branch(new_branch_name.to_string()) })? 
@@ -353,12 +373,22 @@ impl PickerDelegate for BranchListDelegate { }) } - fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) { + fn confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) { let Some(entry) = self.matches.get(self.selected_index()) else { return; }; if entry.is_new { - self.create_branch(entry.branch.name().to_owned().into(), window, cx); + let from_branch = if secondary { + self.default_branch.clone() + } else { + None + }; + self.create_branch( + from_branch, + entry.branch.name().to_owned().into(), + window, + cx, + ); return; } @@ -439,6 +469,28 @@ impl PickerDelegate for BranchListDelegate { }) .unwrap_or_else(|| (None, None)); + let icon = if let Some(default_branch) = self.default_branch.clone() + && entry.is_new + { + Some( + IconButton::new("branch-from-default", IconName::GitBranchSmall) + .on_click(cx.listener(move |this, _, window, cx| { + this.delegate.set_selected_index(ix, window, cx); + this.delegate.confirm(true, window, cx); + })) + .tooltip(move |window, cx| { + Tooltip::for_action( + format!("Create branch based off default: {default_branch}"), + &menu::SecondaryConfirm, + window, + cx, + ) + }), + ) + } else { + None + }; + let branch_name = if entry.is_new { h_flex() .gap_1() @@ -504,7 +556,8 @@ impl PickerDelegate for BranchListDelegate { .color(Color::Muted) })) }), - ), + ) + .end_slot::<IconButton>(icon), ) } diff --git a/crates/git_ui/src/commit_modal.rs b/crates/git_ui/src/commit_modal.rs index b99f628806..5dfa800ae5 100644 --- a/crates/git_ui/src/commit_modal.rs +++ b/crates/git_ui/src/commit_modal.rs @@ -1,9 +1,9 @@ use crate::branch_picker::{self, BranchList}; use crate::git_panel::{GitPanel, commit_message_editor}; -use client::DisableAiSettings; use git::repository::CommitOptions; use git::{Amend, Commit, GenerateCommitMessage, Signoff}; use panel::{panel_button, panel_editor_style}; +use project::DisableAiSettings; use settings::Settings; use ui::{ ContextMenu, KeybindingHint, PopoverMenu, PopoverMenuHandle, SplitButton, Tooltip, prelude::*, @@ -295,11 +295,13 @@ impl CommitModal { IconPosition::Start, Some(Box::new(Amend)), { - let git_panel = git_panel_entity.clone(); - move |window, cx| { - git_panel.update(cx, |git_panel, cx| { - git_panel.toggle_amend_pending(&Amend, window, cx); - }) + let git_panel = git_panel_entity.downgrade(); + move |_, cx| { + git_panel + .update(cx, |git_panel, cx| { + git_panel.toggle_amend_pending(cx); + }) + .ok(); } }, ) diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 061833a6c7..44222b8299 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -12,7 +12,6 @@ use crate::{ use agent_settings::AgentSettings; use anyhow::Context as _; use askpass::AskPassDelegate; -use client::DisableAiSettings; use db::kvp::KEY_VALUE_STORE; use editor::{ Editor, EditorElement, EditorMode, EditorSettings, MultiBuffer, ShowScrollbar, @@ -27,7 +26,10 @@ use git::repository::{ }; use git::status::StageStatus; use git::{Amend, Signoff, ToggleStaged, repository::RepoPath, status::FileStatus}; -use git::{ExpandCommitEditor, RestoreTrackedFiles, StageAll, TrashUntrackedFiles, UnstageAll}; +use git::{ + ExpandCommitEditor, RestoreTrackedFiles, StageAll, StashAll, StashPop, TrashUntrackedFiles, + UnstageAll, +}; use gpui::{ Action, Animation, AnimationExt as _, AsyncApp, AsyncWindowContext, Axis, ClickEvent, Corner, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, KeyContext, @@ -48,10 
+50,9 @@ use panel::{ PanelHeader, panel_button, panel_editor_container, panel_editor_style, panel_filled_button, panel_icon_button, }; -use project::git_store::{RepositoryEvent, RepositoryId}; use project::{ - Fs, Project, ProjectPath, - git_store::{GitStoreEvent, Repository}, + DisableAiSettings, Fs, Project, ProjectPath, + git_store::{GitStoreEvent, Repository, RepositoryEvent, RepositoryId}, }; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -68,12 +69,12 @@ use ui::{ use util::{ResultExt, TryFutureExt, maybe}; use workspace::SERIALIZATION_THROTTLE_TIME; +use cloud_llm_client::CompletionIntent; use workspace::{ Workspace, dock::{DockPosition, Panel, PanelEvent}, notifications::{DetachAndPromptErr, ErrorMessagePrompt, NotificationId}, }; -use zed_llm_client::CompletionIntent; actions!( git_panel, @@ -140,6 +141,13 @@ fn git_panel_context_menu( UnstageAll.boxed_clone(), ) .separator() + .action_disabled_when( + !(state.has_new_changes || state.has_tracked_changes), + "Stash All", + StashAll.boxed_clone(), + ) + .action("Stash Pop", StashPop.boxed_clone()) + .separator() .action("Open Diff", project_diff::Diff.boxed_clone()) .separator() .action_disabled_when( @@ -380,6 +388,9 @@ pub(crate) fn commit_message_editor( window: &mut Window, cx: &mut Context<Editor>, ) -> Editor { + project.update(cx, |this, cx| { + this.mark_buffer_as_non_searchable(commit_message_buffer.read(cx).remote_id(), cx); + }); let buffer = cx.new(|cx| MultiBuffer::singleton(commit_message_buffer, cx)); let max_lines = if in_panel { MAX_PANEL_EDITOR_LINES } else { 18 }; let mut commit_editor = Editor::new( @@ -1412,6 +1423,52 @@ impl GitPanel { self.tracked_staged_count + self.new_staged_count + self.conflicted_staged_count } + pub fn stash_pop(&mut self, _: &StashPop, _window: &mut Window, cx: &mut Context<Self>) { + let Some(active_repository) = self.active_repository.clone() else { + return; + }; + + cx.spawn({ + async move |this, cx| { + let stash_task = active_repository + .update(cx, |repo, cx| repo.stash_pop(cx))? + .await; + this.update(cx, |this, cx| { + stash_task + .map_err(|e| { + this.show_error_toast("stash pop", e, cx); + }) + .ok(); + cx.notify(); + }) + } + }) + .detach(); + } + + pub fn stash_all(&mut self, _: &StashAll, _window: &mut Window, cx: &mut Context<Self>) { + let Some(active_repository) = self.active_repository.clone() else { + return; + }; + + cx.spawn({ + async move |this, cx| { + let stash_task = active_repository + .update(cx, |repo, cx| repo.stash_all(cx))? + .await; + this.update(cx, |this, cx| { + stash_task + .map_err(|e| { + this.show_error_toast("stash", e, cx); + }) + .ok(); + cx.notify(); + }) + } + }) + .detach(); + } + pub fn commit_message_buffer(&self, cx: &App) -> Entity<Buffer> { self.commit_editor .read(cx) @@ -2357,7 +2414,7 @@ impl GitPanel { .committer_name .clone() .or_else(|| participant.user.name.clone()) - .unwrap_or_else(|| participant.user.github_login.clone()); + .unwrap_or_else(|| participant.user.github_login.clone().to_string()); new_co_authors.push((name.clone(), email.clone())) } } @@ -2377,7 +2434,7 @@ impl GitPanel { .name .clone() .or_else(|| user.name.clone()) - .unwrap_or_else(|| user.github_login.clone()); + .unwrap_or_else(|| user.github_login.clone().to_string()); Some((name, email)) } @@ -2842,7 +2899,9 @@ impl GitPanel { let status_toast = StatusToast::new(message, cx, move |this, _cx| { use remote_output::SuccessStyle::*; match style { - Toast { .. } => this, + Toast { .. 
} => { + this.icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) + } ToastWithLog { output } => this .icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) .action("View Log", move |window, cx| { @@ -2855,9 +2914,9 @@ impl GitPanel { }) .ok(); }), - PushPrLink { link } => this + PushPrLink { text, link } => this .icon(ToastIcon::new(IconName::GitBranchSmall).color(Color::Muted)) - .action("Open Pull Request", move |_, cx| cx.open_url(&link)), + .action(text, move |_, cx| cx.open_url(&link)), } }); workspace.toggle_status_toast(status_toast, cx) @@ -3054,6 +3113,7 @@ impl GitPanel { ), ) .menu({ + let git_panel = cx.entity(); let has_previous_commit = self.head_commit(cx).is_some(); let amend = self.amend_pending(); let signoff = self.signoff_enabled; @@ -3070,7 +3130,16 @@ impl GitPanel { amend, IconPosition::Start, Some(Box::new(Amend)), - move |window, cx| window.dispatch_action(Box::new(Amend), cx), + { + let git_panel = git_panel.downgrade(); + move |_, cx| { + git_panel + .update(cx, |git_panel, cx| { + git_panel.toggle_amend_pending(cx); + }) + .ok(); + } + }, ) }) .toggleable_entry( @@ -3441,9 +3510,11 @@ impl GitPanel { .truncate(), ), ) - .child(panel_button("Cancel").size(ButtonSize::Default).on_click( - cx.listener(|this, _, window, cx| this.toggle_amend_pending(&Amend, window, cx)), - )) + .child( + panel_button("Cancel") + .size(ButtonSize::Default) + .on_click(cx.listener(|this, _, _, cx| this.set_amend_pending(false, cx))), + ) } fn render_previous_commit(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> { @@ -4204,17 +4275,8 @@ impl GitPanel { pub fn set_amend_pending(&mut self, value: bool, cx: &mut Context<Self>) { self.amend_pending = value; - cx.notify(); - } - - pub fn toggle_amend_pending( - &mut self, - _: &Amend, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - self.set_amend_pending(!self.amend_pending, cx); self.serialize(cx); + cx.notify(); } pub fn signoff_enabled(&self) -> bool { @@ -4308,6 +4370,13 @@ impl GitPanel { anchor: path, }); } + + pub(crate) fn toggle_amend_pending(&mut self, cx: &mut Context<Self>) { + self.set_amend_pending(!self.amend_pending, cx); + if self.amend_pending { + self.load_last_commit_message_if_empty(cx); + } + } } fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn LanguageModel>> { @@ -4352,7 +4421,6 @@ impl Render for GitPanel { .on_action(cx.listener(Self::stage_range)) .on_action(cx.listener(GitPanel::commit)) .on_action(cx.listener(GitPanel::amend)) - .on_action(cx.listener(GitPanel::toggle_amend_pending)) .on_action(cx.listener(GitPanel::toggle_signoff_enabled)) .on_action(cx.listener(Self::stage_all)) .on_action(cx.listener(Self::unstage_all)) @@ -4362,6 +4430,8 @@ impl Render for GitPanel { .on_action(cx.listener(Self::revert_selected)) .on_action(cx.listener(Self::clean_all)) .on_action(cx.listener(Self::generate_commit_message_action)) + .on_action(cx.listener(Self::stash_all)) + .on_action(cx.listener(Self::stash_pop)) }) .on_action(cx.listener(Self::select_first)) .on_action(cx.listener(Self::select_next)) @@ -5045,7 +5115,6 @@ mod tests { language::init(cx); editor::init(cx); Project::init_settings(cx); - client::DisableAiSettings::register(cx); crate::init(cx); }); } diff --git a/crates/git_ui/src/git_ui.rs b/crates/git_ui/src/git_ui.rs index 2d7fba13c5..0163175eda 100644 --- a/crates/git_ui/src/git_ui.rs +++ b/crates/git_ui/src/git_ui.rs @@ -114,6 +114,22 @@ pub fn init(cx: &mut App) { }); }); } + workspace.register_action(|workspace, action: 
&git::StashAll, window, cx| { + let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { + return; + }; + panel.update(cx, |panel, cx| { + panel.stash_all(action, window, cx); + }); + }); + workspace.register_action(|workspace, action: &git::StashPop, window, cx| { + let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { + return; + }; + panel.update(cx, |panel, cx| { + panel.stash_pop(action, window, cx); + }); + }); workspace.register_action(|workspace, action: &git::StageAll, window, cx| { let Some(panel) = workspace.panel::<git_panel::GitPanel>(cx) else { return; diff --git a/crates/git_ui/src/remote_output.rs b/crates/git_ui/src/remote_output.rs index 03fbf4f917..8437bf0d0d 100644 --- a/crates/git_ui/src/remote_output.rs +++ b/crates/git_ui/src/remote_output.rs @@ -24,7 +24,7 @@ impl RemoteAction { pub enum SuccessStyle { Toast, ToastWithLog { output: RemoteCommandOutput }, - PushPrLink { link: String }, + PushPrLink { text: String, link: String }, } pub struct SuccessMessage { @@ -37,7 +37,7 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ RemoteAction::Fetch(remote) => { if output.stderr.is_empty() { SuccessMessage { - message: "Already up to date".into(), + message: "Fetch: Already up to date".into(), style: SuccessStyle::Toast, } } else { @@ -68,10 +68,9 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ Ok(files_changed) }; - - if output.stderr.starts_with("Everything up to date") { + if output.stdout.ends_with("Already up to date.\n") { SuccessMessage { - message: output.stderr.trim().to_owned(), + message: "Pull: Already up to date".into(), style: SuccessStyle::Toast, } } else if output.stdout.starts_with("Updating") { @@ -119,48 +118,42 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ } } RemoteAction::Push(branch_name, remote_ref) => { - if output.stderr.contains("* [new branch]") { - let pr_hints = [ - // GitHub - "Create a pull request", - // Bitbucket - "Create pull request", - // GitLab - "create a merge request", - ]; - let style = if pr_hints - .iter() - .any(|indicator| output.stderr.contains(indicator)) - { - let finder = LinkFinder::new(); - let first_link = finder - .links(&output.stderr) - .filter(|link| *link.kind() == LinkKind::Url) - .map(|link| link.start()..link.end()) - .next(); - if let Some(link) = first_link { - let link = output.stderr[link].to_string(); - SuccessStyle::PushPrLink { link } - } else { - SuccessStyle::ToastWithLog { output } - } - } else { - SuccessStyle::ToastWithLog { output } - }; - SuccessMessage { - message: format!("Published {} to {}", branch_name, remote_ref.name), - style, - } - } else if output.stderr.starts_with("Everything up to date") { - SuccessMessage { - message: output.stderr.trim().to_owned(), - style: SuccessStyle::Toast, - } + let message = if output.stderr.ends_with("Everything up-to-date\n") { + "Push: Everything is up-to-date".to_string() } else { - SuccessMessage { - message: format!("Pushed {} to {}", branch_name, remote_ref.name), - style: SuccessStyle::ToastWithLog { output }, - } + format!("Pushed {} to {}", branch_name, remote_ref.name) + }; + + let style = if output.stderr.ends_with("Everything up-to-date\n") { + Some(SuccessStyle::Toast) + } else if output.stderr.contains("\nremote: ") { + let pr_hints = [ + ("Create a pull request", "Create Pull Request"), // GitHub + ("Create pull request", "Create Pull Request"), // Bitbucket + ("create a merge request", "Create Merge Request"), // 
GitLab + ("View merge request", "View Merge Request"), // GitLab + ]; + pr_hints + .iter() + .find(|(indicator, _)| output.stderr.contains(indicator)) + .and_then(|(_, mapped)| { + let finder = LinkFinder::new(); + finder + .links(&output.stderr) + .filter(|link| *link.kind() == LinkKind::Url) + .map(|link| link.start()..link.end()) + .next() + .map(|link| SuccessStyle::PushPrLink { + text: mapped.to_string(), + link: output.stderr[link].to_string(), + }) + }) + } else { + None + }; + SuccessMessage { + message, + style: style.unwrap_or(SuccessStyle::ToastWithLog { output }), } } } @@ -169,6 +162,7 @@ pub fn format_output(action: &RemoteAction, output: RemoteCommandOutput) -> Succ #[cfg(test)] mod tests { use super::*; + use indoc::indoc; #[test] fn test_push_new_branch_pull_request() { @@ -181,8 +175,7 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from( - " + stderr: indoc! { " Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) remote: remote: Create a pull request for 'test' on GitHub by visiting: @@ -190,13 +183,14 @@ mod tests { remote: To example.com:test/test.git * [new branch] test -> test - ", - ), + "} + .to_string(), }; let msg = format_output(&action, output); - if let SuccessStyle::PushPrLink { link } = &msg.style { + if let SuccessStyle::PushPrLink { text: hint, link } = &msg.style { + assert_eq!(hint, "Create Pull Request"); assert_eq!(link, "https://example.com/test/test/pull/new/test"); } else { panic!("Expected PushPrLink variant"); @@ -214,7 +208,7 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from(" + stderr: indoc! {" Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) remote: remote: To create a merge request for test, visit: @@ -222,12 +216,14 @@ mod tests { remote: To example.com:test/test.git * [new branch] test -> test - "), - }; + "} + .to_string() + }; let msg = format_output(&action, output); - if let SuccessStyle::PushPrLink { link } = &msg.style { + if let SuccessStyle::PushPrLink { text, link } = &msg.style { + assert_eq!(text, "Create Merge Request"); assert_eq!( link, "https://example.com/test/test/-/merge_requests/new?merge_request%5Bsource_branch%5D=test" @@ -237,6 +233,39 @@ mod tests { } } + #[test] + fn test_push_branch_existing_merge_request() { + let action = RemoteAction::Push( + SharedString::new("test_branch"), + Remote { + name: SharedString::new("test_remote"), + }, + ); + + let output = RemoteCommandOutput { + stdout: String::new(), + stderr: indoc! {" + Total 0 (delta 0), reused 0 (delta 0), pack-reused 0 (from 0) + remote: + remote: View merge request for test: + remote: https://example.com/test/test/-/merge_requests/99999 + remote: + To example.com:test/test.git + + 80bd3c83be...e03d499d2e test -> test + "} + .to_string(), + }; + + let msg = format_output(&action, output); + + if let SuccessStyle::PushPrLink { text, link } = &msg.style { + assert_eq!(text, "View Merge Request"); + assert_eq!(link, "https://example.com/test/test/-/merge_requests/99999"); + } else { + panic!("Expected PushPrLink variant"); + } + } + #[test] fn test_push_new_branch_no_link() { let action = RemoteAction::Push( @@ -248,12 +277,12 @@ mod tests { let output = RemoteCommandOutput { stdout: String::new(), - stderr: String::from( - " + stderr: indoc! 
{ " To http://example.com/test/test.git * [new branch] test -> test ", - ), + } + .to_string(), }; let msg = format_output(&action, output); @@ -261,10 +290,7 @@ mod tests { if let SuccessStyle::ToastWithLog { output } = &msg.style { assert_eq!( output.stderr, - " - To http://example.com/test/test.git - * [new branch] test -> test - " + "To http://example.com/test/test.git\n * [new branch] test -> test\n" ); } else { panic!("Expected ToastWithLog variant"); diff --git a/crates/git_ui/src/repository_selector.rs b/crates/git_ui/src/repository_selector.rs index b5865e9a85..db080ab0b4 100644 --- a/crates/git_ui/src/repository_selector.rs +++ b/crates/git_ui/src/repository_selector.rs @@ -109,7 +109,10 @@ impl Focusable for RepositorySelector { impl Render for RepositorySelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement { - div().w(self.width).child(self.picker.clone()) + div() + .key_context("GitRepositorySelector") + .w(self.width) + .child(self.picker.clone()) } } diff --git a/crates/git_ui/src/text_diff_view.rs b/crates/git_ui/src/text_diff_view.rs index e7386cf7bd..005c1e18b4 100644 --- a/crates/git_ui/src/text_diff_view.rs +++ b/crates/git_ui/src/text_diff_view.rs @@ -12,6 +12,7 @@ use language::{self, Buffer, Point}; use project::Project; use std::{ any::{Any, TypeId}, + cmp, ops::Range, pin::pin, sync::Arc, @@ -45,38 +46,60 @@ impl TextDiffView { ) -> Option<Task<Result<Entity<Self>>>> { let source_editor = diff_data.editor.clone(); - let source_editor_buffer_and_range = source_editor.update(cx, |editor, cx| { + let selection_data = source_editor.update(cx, |editor, cx| { let multibuffer = editor.buffer().read(cx); let source_buffer = multibuffer.as_singleton()?.clone(); let selections = editor.selections.all::<Point>(cx); let buffer_snapshot = source_buffer.read(cx); let first_selection = selections.first()?; - let selection_range = if first_selection.is_empty() { - Point::new(0, 0)..buffer_snapshot.max_point() - } else { - first_selection.start..first_selection.end - }; + let max_point = buffer_snapshot.max_point(); - Some((source_buffer, selection_range)) + if first_selection.is_empty() { + let full_range = Point::new(0, 0)..max_point; + return Some((source_buffer, full_range)); + } + + let start = first_selection.start; + let end = first_selection.end; + let expanded_start = Point::new(start.row, 0); + + let expanded_end = if end.column > 0 { + let next_row = end.row + 1; + cmp::min(max_point, Point::new(next_row, 0)) + } else { + end + }; + Some((source_buffer, expanded_start..expanded_end)) }); - let Some((source_buffer, selected_range)) = source_editor_buffer_and_range else { + let Some((source_buffer, expanded_selection_range)) = selection_data else { log::warn!("There should always be at least one selection in Zed. 
This is a bug."); return None; }; - let clipboard_text = diff_data.clipboard_text.clone(); - - let workspace = workspace.weak_handle(); - - let diff_buffer = cx.new(|cx| { - let source_buffer_snapshot = source_buffer.read(cx).snapshot(); - let diff = BufferDiff::new(&source_buffer_snapshot.text, cx); - diff + source_editor.update(cx, |source_editor, cx| { + source_editor.change_selections(Default::default(), window, cx, |s| { + s.select_ranges(vec![ + expanded_selection_range.start..expanded_selection_range.end, + ]); + }) }); - let clipboard_buffer = - build_clipboard_buffer(clipboard_text, &source_buffer, selected_range.clone(), cx); + let source_buffer_snapshot = source_buffer.read(cx).snapshot(); + let mut clipboard_text = diff_data.clipboard_text.clone(); + + if !clipboard_text.ends_with("\n") { + clipboard_text.push_str("\n"); + } + + let workspace = workspace.weak_handle(); + let diff_buffer = cx.new(|cx| BufferDiff::new(&source_buffer_snapshot.text, cx)); + let clipboard_buffer = build_clipboard_buffer( + clipboard_text, + &source_buffer, + expanded_selection_range.clone(), + cx, + ); let task = window.spawn(cx, async move |cx| { let project = workspace.update(cx, |workspace, _| workspace.project().clone())?; @@ -89,7 +112,7 @@ impl TextDiffView { clipboard_buffer, source_editor, source_buffer, - selected_range, + expanded_selection_range, diff_buffer, project, window, @@ -208,9 +231,9 @@ impl TextDiffView { } fn build_clipboard_buffer( - clipboard_text: String, + text: String, source_buffer: &Entity<Buffer>, - selected_range: Range<Point>, + replacement_range: Range<Point>, cx: &mut App, ) -> Entity<Buffer> { let source_buffer_snapshot = source_buffer.read(cx).snapshot(); @@ -219,9 +242,9 @@ fn build_clipboard_buffer( let language = source_buffer.read(cx).language().cloned(); buffer.set_language(language, cx); - let range_start = source_buffer_snapshot.point_to_offset(selected_range.start); - let range_end = source_buffer_snapshot.point_to_offset(selected_range.end); - buffer.edit([(range_start..range_end, clipboard_text)], None, cx); + let range_start = source_buffer_snapshot.point_to_offset(replacement_range.start); + let range_end = source_buffer_snapshot.point_to_offset(replacement_range.end); + buffer.edit([(range_start..range_end, text)], None, cx); buffer }) @@ -293,7 +316,7 @@ impl Item for TextDiffView { } fn telemetry_event_text(&self) -> Option<&'static str> { - Some("Diff View Opened") + Some("Selection Diff View Opened") } fn deactivated(&mut self, window: &mut Window, cx: &mut Context<Self>) { @@ -395,21 +418,13 @@ pub fn selection_location_text(editor: &Editor, cx: &App) -> Option<String> { let buffer_snapshot = buffer.snapshot(cx); let first_selection = editor.selections.disjoint.first()?; - let (start_row, start_column, end_row, end_column) = - if first_selection.start == first_selection.end { - let max_point = buffer_snapshot.max_point(); - (0, 0, max_point.row, max_point.column) - } else { - let selection_start = first_selection.start.to_point(&buffer_snapshot); - let selection_end = first_selection.end.to_point(&buffer_snapshot); + let selection_start = first_selection.start.to_point(&buffer_snapshot); + let selection_end = first_selection.end.to_point(&buffer_snapshot); - ( - selection_start.row, - selection_start.column, - selection_end.row, - selection_end.column, - ) - }; + let start_row = selection_start.row; + let start_column = selection_start.column; + let end_row = selection_end.row; + let end_column = selection_end.column; let range_text = if 
start_row == end_row { format!("L{}:{}-{}", start_row + 1, start_column + 1, end_column + 1) @@ -435,14 +450,13 @@ impl Render for TextDiffView { #[cfg(test)] mod tests { use super::*; - - use editor::{actions, test::editor_test_context::assert_state_with_diff}; + use editor::test::editor_test_context::assert_state_with_diff; use gpui::{TestAppContext, VisualContext}; use project::{FakeFs, Project}; use serde_json::json; use settings::{Settings, SettingsStore}; use unindent::unindent; - use util::path; + use util::{path, test::marked_text_ranges}; fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { @@ -457,52 +471,236 @@ mod tests { } #[gpui::test] - async fn test_diffing_clipboard_against_specific_selection(cx: &mut TestAppContext) { - base_test(true, cx).await; + async fn test_diffing_clipboard_against_empty_selection_uses_full_buffer_selection( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "def process_incoming_inventory(items, warehouse_id):\n pass\n", + "def process_outgoing_inventory(items, warehouse_id):\n passˇ\n", + &unindent( + " + - def process_incoming_inventory(items, warehouse_id): + + ˇdef process_outgoing_inventory(items, warehouse_id): + pass + ", + ), + "Clipboard ↔ text.txt @ L1:1-L3:1", + &format!("Clipboard ↔ {} @ L1:1-L3:1", path!("test/text.txt")), + cx, + ) + .await; } #[gpui::test] - async fn test_diffing_clipboard_against_empty_selection_uses_full_buffer( + async fn test_diffing_clipboard_against_multiline_selection_expands_to_full_lines( cx: &mut TestAppContext, ) { - base_test(false, cx).await; + base_test( + path!("/test"), + path!("/test/text.txt"), + "def process_incoming_inventory(items, warehouse_id):\n pass\n", + "«def process_outgoing_inventory(items, warehouse_id):\n passˇ»\n", + &unindent( + " + - def process_incoming_inventory(items, warehouse_id): + + ˇdef process_outgoing_inventory(items, warehouse_id): + pass + ", + ), + "Clipboard ↔ text.txt @ L1:1-L3:1", + &format!("Clipboard ↔ {} @ L1:1-L3:1", path!("test/text.txt")), + cx, + ) + .await; } - async fn base_test(select_all_text: bool, cx: &mut TestAppContext) { + #[gpui::test] + async fn test_diffing_clipboard_against_single_line_selection(cx: &mut TestAppContext) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "a", + "«bbˇ»", + &unindent( + " + - a + + ˇbb", + ), + "Clipboard ↔ text.txt @ L1:1-3", + &format!("Clipboard ↔ {} @ L1:1-3", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_with_leading_whitespace_against_line(cx: &mut TestAppContext) { + base_test( + path!("/test"), + path!("/test/text.txt"), + " a", + "«bbˇ»", + &unindent( + " + - a + + ˇbb", + ), + "Clipboard ↔ text.txt @ L1:1-3", + &format!("Clipboard ↔ {} @ L1:1-3", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_against_line_with_leading_whitespace(cx: &mut TestAppContext) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "a", + " «bbˇ»", + &unindent( + " + - a + + ˇ bb", + ), + "Clipboard ↔ text.txt @ L1:1-7", + &format!("Clipboard ↔ {} @ L1:1-7", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_against_line_with_leading_whitespace_included_in_selection( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "a", + "« bbˇ»", + &unindent( + " + - a + + ˇ bb", + ), + "Clipboard ↔ text.txt @ L1:1-7", + &format!("Clipboard ↔ {} @ L1:1-7", path!("test/text.txt")), + 
cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_with_leading_whitespace_against_line_with_leading_whitespace( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + " a", + " «bbˇ»", + &unindent( + " + - a + + ˇ bb", + ), + "Clipboard ↔ text.txt @ L1:1-7", + &format!("Clipboard ↔ {} @ L1:1-7", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_with_leading_whitespace_against_line_with_leading_whitespace_included_in_selection( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + " a", + "« bbˇ»", + &unindent( + " + - a + + ˇ bb", + ), + "Clipboard ↔ text.txt @ L1:1-7", + &format!("Clipboard ↔ {} @ L1:1-7", path!("test/text.txt")), + cx, + ) + .await; + } + + #[gpui::test] + async fn test_diffing_clipboard_against_partial_selection_expands_to_include_trailing_characters( + cx: &mut TestAppContext, + ) { + base_test( + path!("/test"), + path!("/test/text.txt"), + "a", + "«bˇ»b", + &unindent( + " + - a + + ˇbb", + ), + "Clipboard ↔ text.txt @ L1:1-3", + &format!("Clipboard ↔ {} @ L1:1-3", path!("test/text.txt")), + cx, + ) + .await; + } + + async fn base_test( + project_root: &str, + file_path: &str, + clipboard_text: &str, + editor_text: &str, + expected_diff: &str, + expected_tab_title: &str, + expected_tab_tooltip: &str, + cx: &mut TestAppContext, + ) { init_test(cx); + let file_name = std::path::Path::new(file_path) + .file_name() + .unwrap() + .to_str() + .unwrap(); + let fs = FakeFs::new(cx.executor()); fs.insert_tree( - path!("/test"), + project_root, json!({ - "a": { - "b": { - "text.txt": "new line 1\nline 2\nnew line 3\nline 4" - } - } + file_name: editor_text }), ) .await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project = Project::test(fs, [project_root.as_ref()], cx).await; let (workspace, mut cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/test/a/b/text.txt"), cx) - }) + .update(cx, |project, cx| project.open_local_buffer(file_path, cx)) .await .unwrap(); let editor = cx.new_window_entity(|window, cx| { let mut editor = Editor::for_buffer(buffer, None, window, cx); - editor.set_text("new line 1\nline 2\nnew line 3\nline 4\n", window, cx); - - if select_all_text { - editor.select_all(&actions::SelectAll, window, cx); - } + let (unmarked_text, selection_ranges) = marked_text_ranges(editor_text, false); + editor.set_text(unmarked_text, window, cx); + editor.change_selections(Default::default(), window, cx, |s| { + s.select_ranges(selection_ranges) + }); editor }); @@ -511,7 +709,7 @@ mod tests { .update_in(cx, |workspace, window, cx| { TextDiffView::open( &DiffClipboardWithSelectionData { - clipboard_text: "old line 1\nline 2\nold line 3\nline 4\n".to_string(), + clipboard_text: clipboard_text.to_string(), editor, }, workspace, @@ -528,26 +726,14 @@ mod tests { assert_state_with_diff( &diff_view.read_with(cx, |diff_view, _| diff_view.diff_editor.clone()), &mut cx, - &unindent( - " - - old line 1 - + ˇnew line 1 - line 2 - - old line 3 - + new line 3 - line 4 - ", - ), + expected_diff, ); diff_view.read_with(cx, |diff_view, cx| { - assert_eq!( - diff_view.tab_content_text(0, cx), - "Clipboard ↔ text.txt @ L1:1-L5:1" - ); + assert_eq!(diff_view.tab_content_text(0, cx), expected_tab_title); assert_eq!( diff_view.tab_tooltip_text(cx).unwrap(), - format!("Clipboard ↔ {}", 
path!("test/a/b/text.txt @ L1:1-L5:1")) + expected_tab_tooltip ); }); } diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 29e81269e3..2bf49fa7d8 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -216,10 +216,6 @@ xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf x11-clipboard = { version = "0.9.3", optional = true } [target.'cfg(target_os = "windows")'.dependencies] -blade-util.workspace = true -bytemuck = "1" -blade-graphics.workspace = true -blade-macros.workspace = true flume = "0.11" rand.workspace = true windows.workspace = true @@ -240,7 +236,6 @@ util = { workspace = true, features = ["test-support"] } [target.'cfg(target_os = "windows")'.build-dependencies] embed-resource = "3.0" -naga.workspace = true [target.'cfg(target_os = "macos")'.build-dependencies] bindgen = "0.71" @@ -287,6 +282,10 @@ path = "examples/shadow.rs" name = "svg" path = "examples/svg/svg.rs" +[[example]] +name = "tab_stop" +path = "examples/tab_stop.rs" + [[example]] name = "text" path = "examples/text.rs" diff --git a/crates/gpui/build.rs b/crates/gpui/build.rs index aed4397440..93a1c15c41 100644 --- a/crates/gpui/build.rs +++ b/crates/gpui/build.rs @@ -9,7 +9,10 @@ fn main() { let target = env::var("CARGO_CFG_TARGET_OS"); println!("cargo::rustc-check-cfg=cfg(gles)"); - #[cfg(any(not(target_os = "macos"), feature = "macos-blade"))] + #[cfg(any( + not(any(target_os = "macos", target_os = "windows")), + all(target_os = "macos", feature = "macos-blade") + ))] check_wgsl_shaders(); match target.as_deref() { @@ -17,21 +20,18 @@ fn main() { #[cfg(target_os = "macos")] macos::build(); } - #[cfg(all(target_os = "windows", feature = "windows-manifest"))] Ok("windows") => { - let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); - let rc_file = std::path::Path::new("resources/windows/gpui.rc"); - println!("cargo:rerun-if-changed={}", manifest.display()); - println!("cargo:rerun-if-changed={}", rc_file.display()); - embed_resource::compile(rc_file, embed_resource::NONE) - .manifest_required() - .unwrap(); + #[cfg(target_os = "windows")] + windows::build(); } _ => (), }; } -#[allow(dead_code)] +#[cfg(any( + not(any(target_os = "macos", target_os = "windows")), + all(target_os = "macos", feature = "macos-blade") +))] fn check_wgsl_shaders() { use std::path::PathBuf; use std::process; @@ -128,6 +128,7 @@ mod macos { "AtlasTile".into(), "PathRasterizationInputIndex".into(), "PathVertex_ScaledPixels".into(), + "PathRasterizationVertex".into(), "ShadowInputIndex".into(), "Shadow".into(), "QuadInputIndex".into(), @@ -242,3 +243,215 @@ mod macos { } } } + +#[cfg(target_os = "windows")] +mod windows { + use std::{ + fs, + io::Write, + path::{Path, PathBuf}, + process::{self, Command}, + }; + + pub(super) fn build() { + // Compile HLSL shaders + #[cfg(not(debug_assertions))] + compile_shaders(); + + // Embed the Windows manifest and resource file + #[cfg(feature = "windows-manifest")] + embed_resource(); + } + + #[cfg(feature = "windows-manifest")] + fn embed_resource() { + let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); + let rc_file = std::path::Path::new("resources/windows/gpui.rc"); + println!("cargo:rerun-if-changed={}", manifest.display()); + println!("cargo:rerun-if-changed={}", rc_file.display()); + embed_resource::compile(rc_file, embed_resource::NONE) + .manifest_required() + .unwrap(); + } + + /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. 
+ fn compile_shaders() { + let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("src/platform/windows/shaders.hlsl"); + let out_dir = std::env::var("OUT_DIR").unwrap(); + + println!("cargo:rerun-if-changed={}", shader_path.display()); + + // Check if fxc.exe is available + let fxc_path = find_fxc_compiler(); + + // Define all modules + let modules = [ + "quad", + "shadow", + "path_rasterization", + "path_sprite", + "underline", + "monochrome_sprite", + "polychrome_sprite", + ]; + + let rust_binding_path = format!("{}/shaders_bytes.rs", out_dir); + if Path::new(&rust_binding_path).exists() { + fs::remove_file(&rust_binding_path) + .expect("Failed to remove existing Rust binding file"); + } + for module in modules { + compile_shader_for_module( + module, + &out_dir, + &fxc_path, + shader_path.to_str().unwrap(), + &rust_binding_path, + ); + } + + { + let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("src/platform/windows/color_text_raster.hlsl"); + compile_shader_for_module( + "emoji_rasterization", + &out_dir, + &fxc_path, + shader_path.to_str().unwrap(), + &rust_binding_path, + ); + } + } + + /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. + fn find_fxc_compiler() -> String { + // Check environment variable + if let Ok(path) = std::env::var("GPUI_FXC_PATH") { + if Path::new(&path).exists() { + return path; + } + } + + // Try to find in PATH + // NOTE: This has to be `where.exe` on Windows, not `where`, it must be ended with `.exe` + if let Ok(output) = std::process::Command::new("where.exe") + .arg("fxc.exe") + .output() + { + if output.status.success() { + let path = String::from_utf8_lossy(&output.stdout); + return path.trim().to_string(); + } + } + + // Check the default path + if Path::new(r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe") + .exists() + { + return r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe" + .to_string(); + } + + panic!("Failed to find fxc.exe"); + } + + fn compile_shader_for_module( + module: &str, + out_dir: &str, + fxc_path: &str, + shader_path: &str, + rust_binding_path: &str, + ) { + // Compile vertex shader + let output_file = format!("{}/{}_vs.h", out_dir, module); + let const_name = format!("{}_VERTEX_BYTES", module.to_uppercase()); + compile_shader_impl( + fxc_path, + &format!("{module}_vertex"), + &output_file, + &const_name, + shader_path, + "vs_4_1", + ); + generate_rust_binding(&const_name, &output_file, &rust_binding_path); + + // Compile fragment shader + let output_file = format!("{}/{}_ps.h", out_dir, module); + let const_name = format!("{}_FRAGMENT_BYTES", module.to_uppercase()); + compile_shader_impl( + fxc_path, + &format!("{module}_fragment"), + &output_file, + &const_name, + shader_path, + "ps_4_1", + ); + generate_rust_binding(&const_name, &output_file, &rust_binding_path); + } + + fn compile_shader_impl( + fxc_path: &str, + entry_point: &str, + output_path: &str, + var_name: &str, + shader_path: &str, + target: &str, + ) { + let output = Command::new(fxc_path) + .args([ + "/T", + target, + "/E", + entry_point, + "/Fh", + output_path, + "/Vn", + var_name, + "/O3", + shader_path, + ]) + .output(); + + match output { + Ok(result) => { + if result.status.success() { + return; + } + eprintln!( + "Shader compilation failed for {}:\n{}", + entry_point, + String::from_utf8_lossy(&result.stderr) + ); + process::exit(1); + } + Err(e) => { + eprintln!("Failed to run fxc for {}: {}", 
entry_point, e); + process::exit(1); + } + } + } + + fn generate_rust_binding(const_name: &str, head_file: &str, output_path: &str) { + let header_content = fs::read_to_string(head_file).expect("Failed to read header file"); + let const_definition = { + let global_var_start = header_content.find("const BYTE").unwrap(); + let global_var = &header_content[global_var_start..]; + let equal = global_var.find('=').unwrap(); + global_var[equal + 1..].trim() + }; + let rust_binding = format!( + "const {}: &[u8] = &{}\n", + const_name, + const_definition.replace('{', "[").replace('}', "]") + ); + let mut options = fs::OpenOptions::new() + .create(true) + .append(true) + .open(output_path) + .expect("Failed to open Rust binding file"); + options + .write_all(rust_binding.as_bytes()) + .expect("Failed to write Rust binding file"); + } +} diff --git a/crates/gpui/examples/painting.rs b/crates/gpui/examples/painting.rs index ff4b64cbda..668aed2377 100644 --- a/crates/gpui/examples/painting.rs +++ b/crates/gpui/examples/painting.rs @@ -1,11 +1,12 @@ use gpui::{ Application, Background, Bounds, ColorSpace, Context, MouseDownEvent, Path, PathBuilder, PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowOptions, canvas, - div, linear_color_stop, linear_gradient, point, prelude::*, px, rgb, size, + div, linear_color_stop, linear_gradient, point, prelude::*, px, quad, rgb, size, }; struct PaintingViewer { default_lines: Vec<(Path<Pixels>, Background)>, + background_quads: Vec<(Bounds<Pixels>, Background)>, lines: Vec<Vec<Point<Pixels>>>, start: Point<Pixels>, dashed: bool, @@ -16,12 +17,148 @@ impl PaintingViewer { fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self { let mut lines = vec![]; + // Black squares beneath transparent paths. + let background_quads = vec![ + ( + Bounds { + origin: point(px(70.), px(70.)), + size: size(px(40.), px(40.)), + }, + gpui::black().into(), + ), + ( + Bounds { + origin: point(px(170.), px(70.)), + size: size(px(40.), px(40.)), + }, + gpui::black().into(), + ), + ( + Bounds { + origin: point(px(270.), px(70.)), + size: size(px(40.), px(40.)), + }, + gpui::black().into(), + ), + ( + Bounds { + origin: point(px(370.), px(70.)), + size: size(px(40.), px(40.)), + }, + gpui::black().into(), + ), + ( + Bounds { + origin: point(px(450.), px(50.)), + size: size(px(80.), px(80.)), + }, + gpui::black().into(), + ), + ]; + + // 50% opaque red path that extends across black quad. + let mut builder = PathBuilder::fill(); + builder.move_to(point(px(50.), px(50.))); + builder.line_to(point(px(130.), px(50.))); + builder.line_to(point(px(130.), px(130.))); + builder.line_to(point(px(50.), px(130.))); + builder.close(); + let path = builder.build().unwrap(); + let mut red = rgb(0xFF0000); + red.a = 0.5; + lines.push((path, red.into())); + + // 50% opaque blue path that extends across black quad. + let mut builder = PathBuilder::fill(); + builder.move_to(point(px(150.), px(50.))); + builder.line_to(point(px(230.), px(50.))); + builder.line_to(point(px(230.), px(130.))); + builder.line_to(point(px(150.), px(130.))); + builder.close(); + let path = builder.build().unwrap(); + let mut blue = rgb(0x0000FF); + blue.a = 0.5; + lines.push((path, blue.into())); + + // 50% opaque green path that extends across black quad. 
+ let mut builder = PathBuilder::fill(); + builder.move_to(point(px(250.), px(50.))); + builder.line_to(point(px(330.), px(50.))); + builder.line_to(point(px(330.), px(130.))); + builder.line_to(point(px(250.), px(130.))); + builder.close(); + let path = builder.build().unwrap(); + let mut green = rgb(0x00FF00); + green.a = 0.5; + lines.push((path, green.into())); + + // 50% opaque black path that extends across black quad. + let mut builder = PathBuilder::fill(); + builder.move_to(point(px(350.), px(50.))); + builder.line_to(point(px(430.), px(50.))); + builder.line_to(point(px(430.), px(130.))); + builder.line_to(point(px(350.), px(130.))); + builder.close(); + let path = builder.build().unwrap(); + let mut black = rgb(0x000000); + black.a = 0.5; + lines.push((path, black.into())); + + // Two 50% opaque red circles overlapping - center should be darker red + let mut builder = PathBuilder::fill(); + let center = point(px(530.), px(85.)); + let radius = px(30.); + builder.move_to(point(center.x + radius, center.y)); + builder.arc_to( + point(radius, radius), + px(0.), + false, + false, + point(center.x - radius, center.y), + ); + builder.arc_to( + point(radius, radius), + px(0.), + false, + false, + point(center.x + radius, center.y), + ); + builder.close(); + let path = builder.build().unwrap(); + let mut red1 = rgb(0xFF0000); + red1.a = 0.5; + lines.push((path, red1.into())); + + let mut builder = PathBuilder::fill(); + let center = point(px(570.), px(85.)); + let radius = px(30.); + builder.move_to(point(center.x + radius, center.y)); + builder.arc_to( + point(radius, radius), + px(0.), + false, + false, + point(center.x - radius, center.y), + ); + builder.arc_to( + point(radius, radius), + px(0.), + false, + false, + point(center.x + radius, center.y), + ); + builder.close(); + let path = builder.build().unwrap(); + let mut red2 = rgb(0xFF0000); + red2.a = 0.5; + lines.push((path, red2.into())); + // draw a Rust logo let mut builder = lyon::path::Path::svg_builder(); lyon::extra::rust_logo::build_logo_path(&mut builder); // move down the Path let mut builder: PathBuilder = builder.into(); - builder.translate(point(px(10.), px(100.))); + builder.translate(point(px(10.), px(200.))); builder.scale(0.9); let path = builder.build().unwrap(); lines.push((path, gpui::black().into())); @@ -30,10 +167,10 @@ impl PaintingViewer { let mut builder = PathBuilder::fill(); builder.add_polygon( &[ - point(px(150.), px(200.)), - point(px(200.), px(125.)), - point(px(200.), px(175.)), - point(px(250.), px(100.)), + point(px(150.), px(300.)), + point(px(200.), px(225.)), + point(px(200.), px(275.)), + point(px(250.), px(200.)), ], false, ); @@ -42,17 +179,17 @@ impl PaintingViewer { // draw a ⭐ let mut builder = PathBuilder::fill(); - builder.move_to(point(px(350.), px(100.))); - builder.line_to(point(px(370.), px(160.))); - builder.line_to(point(px(430.), px(160.))); - builder.line_to(point(px(380.), px(200.))); - builder.line_to(point(px(400.), px(260.))); - builder.line_to(point(px(350.), px(220.))); - builder.line_to(point(px(300.), px(260.))); - builder.line_to(point(px(320.), px(200.))); - builder.line_to(point(px(270.), px(160.))); - builder.line_to(point(px(330.), px(160.))); - builder.line_to(point(px(350.), px(100.))); + builder.move_to(point(px(350.), px(200.))); + builder.line_to(point(px(370.), px(260.))); + builder.line_to(point(px(430.), px(260.))); + builder.line_to(point(px(380.), px(300.))); + builder.line_to(point(px(400.), px(360.))); + builder.line_to(point(px(350.), px(320.))); + 
builder.line_to(point(px(300.), px(360.))); + builder.line_to(point(px(320.), px(300.))); + builder.line_to(point(px(270.), px(260.))); + builder.line_to(point(px(330.), px(260.))); + builder.line_to(point(px(350.), px(200.))); let path = builder.build().unwrap(); lines.push(( path, @@ -66,7 +203,7 @@ impl PaintingViewer { // draw linear gradient let square_bounds = Bounds { - origin: point(px(450.), px(100.)), + origin: point(px(450.), px(200.)), size: size(px(200.), px(80.)), }; let height = square_bounds.size.height; @@ -96,31 +233,31 @@ impl PaintingViewer { // draw a pie chart let center = point(px(96.), px(96.)); - let pie_center = point(px(775.), px(155.)); + let pie_center = point(px(775.), px(255.)); let segments = [ ( - point(px(871.), px(155.)), - point(px(747.), px(63.)), + point(px(871.), px(255.)), + point(px(747.), px(163.)), rgb(0x1374e9), ), ( - point(px(747.), px(63.)), - point(px(679.), px(163.)), + point(px(747.), px(163.)), + point(px(679.), px(263.)), rgb(0xe13527), ), ( - point(px(679.), px(163.)), - point(px(754.), px(249.)), + point(px(679.), px(263.)), + point(px(754.), px(349.)), rgb(0x0751ce), ), ( - point(px(754.), px(249.)), - point(px(854.), px(210.)), + point(px(754.), px(349.)), + point(px(854.), px(310.)), rgb(0x209742), ), ( - point(px(854.), px(210.)), - point(px(871.), px(155.)), + point(px(854.), px(310.)), + point(px(871.), px(255.)), rgb(0xfbc10a), ), ]; @@ -140,11 +277,11 @@ impl PaintingViewer { .with_line_width(1.) .with_line_join(lyon::path::LineJoin::Bevel); let mut builder = PathBuilder::stroke(px(1.)).with_style(PathStyle::Stroke(options)); - builder.move_to(point(px(40.), px(320.))); + builder.move_to(point(px(40.), px(420.))); for i in 1..50 { builder.line_to(point( px(40.0 + i as f32 * 10.0), - px(320.0 + (i as f32 * 10.0).sin() * 40.0), + px(420.0 + (i as f32 * 10.0).sin() * 40.0), )); } let path = builder.build().unwrap(); @@ -152,6 +289,7 @@ impl PaintingViewer { Self { default_lines: lines.clone(), + background_quads, lines: vec![], start: point(px(0.), px(0.)), dashed: false, @@ -185,6 +323,7 @@ fn button( impl Render for PaintingViewer { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let default_lines = self.default_lines.clone(); + let background_quads = self.background_quads.clone(); let lines = self.lines.clone(); let dashed = self.dashed; @@ -221,6 +360,19 @@ impl Render for PaintingViewer { canvas( move |_, _, _| {}, move |_, _, window, _| { + // First draw background quads + for (bounds, color) in background_quads.iter() { + window.paint_quad(quad( + *bounds, + px(0.), + *color, + px(0.), + gpui::transparent_black(), + Default::default(), + )); + } + + // Then draw the default paths on top for (path, color) in default_lines { window.paint_path(path, color); } @@ -303,6 +455,10 @@ fn main() { |window, cx| cx.new(|cx| PaintingViewer::new(window, cx)), ) .unwrap(); + cx.on_window_closed(|cx| { + cx.quit(); + }) + .detach(); cx.activate(true); }); } diff --git a/crates/gpui/examples/paths_bench.rs b/crates/gpui/examples/paths_bench.rs new file mode 100644 index 0000000000..a801889ae8 --- /dev/null +++ b/crates/gpui/examples/paths_bench.rs @@ -0,0 +1,92 @@ +use gpui::{ + Application, Background, Bounds, ColorSpace, Context, Path, PathBuilder, Pixels, Render, + TitlebarOptions, Window, WindowBounds, WindowOptions, canvas, div, linear_color_stop, + linear_gradient, point, prelude::*, px, rgb, size, +}; + +const DEFAULT_WINDOW_WIDTH: Pixels = px(1024.0); +const DEFAULT_WINDOW_HEIGHT: Pixels = 
px(768.0); + +struct PaintingViewer { + default_lines: Vec<(Path<Pixels>, Background)>, + _painting: bool, +} + +impl PaintingViewer { + fn new(_window: &mut Window, _cx: &mut Context<Self>) -> Self { + let mut lines = vec![]; + + // draw a lightening bolt ⚡ + for _ in 0..2000 { + // draw a ⭐ + let mut builder = PathBuilder::fill(); + builder.move_to(point(px(350.), px(100.))); + builder.line_to(point(px(370.), px(160.))); + builder.line_to(point(px(430.), px(160.))); + builder.line_to(point(px(380.), px(200.))); + builder.line_to(point(px(400.), px(260.))); + builder.line_to(point(px(350.), px(220.))); + builder.line_to(point(px(300.), px(260.))); + builder.line_to(point(px(320.), px(200.))); + builder.line_to(point(px(270.), px(160.))); + builder.line_to(point(px(330.), px(160.))); + builder.line_to(point(px(350.), px(100.))); + let path = builder.build().unwrap(); + lines.push(( + path, + linear_gradient( + 180., + linear_color_stop(rgb(0xFACC15), 0.7), + linear_color_stop(rgb(0xD56D0C), 1.), + ) + .color_space(ColorSpace::Oklab), + )); + } + + Self { + default_lines: lines, + _painting: false, + } + } +} + +impl Render for PaintingViewer { + fn render(&mut self, window: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + window.request_animation_frame(); + let lines = self.default_lines.clone(); + div().size_full().child( + canvas( + move |_, _, _| {}, + move |_, _, window, _| { + for (path, color) in lines { + window.paint_path(path, color); + } + }, + ) + .size_full(), + ) + } +} + +fn main() { + Application::new().run(|cx| { + cx.open_window( + WindowOptions { + titlebar: Some(TitlebarOptions { + title: Some("Vulkan".into()), + ..Default::default() + }), + focus: true, + window_bounds: Some(WindowBounds::Windowed(Bounds::centered( + None, + size(DEFAULT_WINDOW_WIDTH, DEFAULT_WINDOW_HEIGHT), + cx, + ))), + ..Default::default() + }, + |window, cx| cx.new(|cx| PaintingViewer::new(window, cx)), + ) + .unwrap(); + cx.activate(true); + }); +} diff --git a/crates/gpui/examples/tab_stop.rs b/crates/gpui/examples/tab_stop.rs index 9c58b52a5e..8dbcbeccb7 100644 --- a/crates/gpui/examples/tab_stop.rs +++ b/crates/gpui/examples/tab_stop.rs @@ -6,6 +6,7 @@ use gpui::{ actions!(example, [Tab, TabPrev]); struct Example { + focus_handle: FocusHandle, items: Vec<FocusHandle>, message: SharedString, } @@ -20,8 +21,11 @@ impl Example { cx.focus_handle().tab_index(2).tab_stop(true), ]; - window.focus(items.first().unwrap()); + let focus_handle = cx.focus_handle(); + window.focus(&focus_handle); + Self { + focus_handle, items, message: SharedString::from("Press `Tab`, `Shift-Tab` to switch focus."), } @@ -40,6 +44,10 @@ impl Example { impl Render for Example { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + fn tab_stop_style<T: Styled>(this: T) -> T { + this.border_3().border_color(gpui::blue()) + } + fn button(id: impl Into<ElementId>) -> Stateful<Div> { div() .id(id) @@ -52,12 +60,13 @@ impl Render for Example { .border_color(gpui::black()) .bg(gpui::black()) .text_color(gpui::white()) - .focus(|this| this.border_color(gpui::blue())) + .focus(tab_stop_style) .shadow_sm() } div() .id("app") + .track_focus(&self.focus_handle) .on_action(cx.listener(Self::on_tab)) .on_action(cx.listener(Self::on_tab_prev)) .size_full() @@ -86,7 +95,7 @@ impl Render for Example { .border_color(gpui::black()) .when( item_handle.tab_stop && item_handle.is_focused(window), - |this| this.border_color(gpui::blue()), + tab_stop_style, ) .map(|this| match item_handle.tab_stop { 
true => this @@ -102,8 +111,24 @@ impl Render for Example { .flex_row() .gap_3() .items_center() - .child(button("el1").tab_index(4).child("Button 1")) - .child(button("el2").tab_index(5).child("Button 2")), + .child( + button("el1") + .tab_index(4) + .child("Button 1") + .on_click(cx.listener(|this, _, _, cx| { + this.message = "You have clicked Button 1.".into(); + cx.notify(); + })), + ) + .child( + button("el2") + .tab_index(5) + .child("Button 2") + .on_click(cx.listener(|this, _, _, cx| { + this.message = "You have clicked Button 2.".into(); + cx.notify(); + })), + ), ) } } diff --git a/crates/gpui/examples/text.rs b/crates/gpui/examples/text.rs index 19214aebde..1166bb2795 100644 --- a/crates/gpui/examples/text.rs +++ b/crates/gpui/examples/text.rs @@ -198,7 +198,7 @@ impl RenderOnce for CharacterGrid { "χ", "ψ", "∂", "а", "в", "Ж", "ж", "З", "з", "К", "к", "л", "м", "Н", "н", "Р", "р", "У", "у", "ф", "ч", "ь", "ы", "Э", "э", "Я", "я", "ij", "öẋ", ".,", "⣝⣑", "~", "*", "_", "^", "`", "'", "(", "{", "«", "#", "&", "@", "$", "¢", "%", "|", "?", "¶", "µ", - "❮", "<=", "!=", "==", "--", "++", "=>", "->", + "❮", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎", ]; let columns = 11; diff --git a/crates/gpui/examples/window_shadow.rs b/crates/gpui/examples/window_shadow.rs index 06dde91133..469017da79 100644 --- a/crates/gpui/examples/window_shadow.rs +++ b/crates/gpui/examples/window_shadow.rs @@ -165,8 +165,8 @@ impl Render for WindowShadow { }, ) .on_click(|e, window, _| { - if e.down.button == MouseButton::Right { - window.show_window_menu(e.up.position); + if e.is_right_click() { + window.show_window_menu(e.position()); } }) .text_color(black()) diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index 759d33563e..ded7bae316 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -2023,6 +2023,10 @@ impl HttpClient for NullHttpClient { .boxed() } + fn user_agent(&self) -> Option<&http_client::http::HeaderValue> { + None + } + fn proxy(&self) -> Option<&Url> { None } diff --git a/crates/gpui/src/color.rs b/crates/gpui/src/color.rs index a16c8f46be..639c84c101 100644 --- a/crates/gpui/src/color.rs +++ b/crates/gpui/src/color.rs @@ -35,6 +35,7 @@ pub(crate) fn swap_rgba_pa_to_bgra(color: &mut [u8]) { /// An RGBA color #[derive(PartialEq, Clone, Copy, Default)] +#[repr(C)] pub struct Rgba { /// The red component of the color, in the range 0.0 to 1.0 pub r: f32, diff --git a/crates/gpui/src/elements/div.rs b/crates/gpui/src/elements/div.rs index 4655c92409..09afbff929 100644 --- a/crates/gpui/src/elements/div.rs +++ b/crates/gpui/src/elements/div.rs @@ -19,10 +19,10 @@ use crate::{ Action, AnyDrag, AnyElement, AnyTooltip, AnyView, App, Bounds, ClickEvent, DispatchPhase, Element, ElementId, Entity, FocusHandle, Global, GlobalElementId, Hitbox, HitboxBehavior, HitboxId, InspectorElementId, IntoElement, IsZero, KeyContext, KeyDownEvent, KeyUpEvent, - LayoutId, ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, - Overflow, ParentElement, Pixels, Point, Render, ScrollWheelEvent, SharedString, Size, Style, - StyleRefinement, Styled, Task, TooltipId, Visibility, Window, WindowControlArea, point, px, - size, + KeyboardButton, KeyboardClickEvent, LayoutId, ModifiersChangedEvent, MouseButton, + MouseClickEvent, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Overflow, ParentElement, Pixels, + Point, Render, ScrollWheelEvent, SharedString, Size, Style, StyleRefinement, Styled, Task, + TooltipId, Visibility, Window, 
WindowControlArea, point, px, size, }; use collections::HashMap; use refineable::Refineable; @@ -484,10 +484,9 @@ impl Interactivity { where Self: Sized, { - self.click_listeners - .push(Box::new(move |event, window, cx| { - listener(event, window, cx) - })); + self.click_listeners.push(Rc::new(move |event, window, cx| { + listener(event, window, cx) + })); } /// On drag initiation, this callback will be used to create a new view to render the dragged value for a @@ -1156,7 +1155,7 @@ pub(crate) type MouseMoveListener = pub(crate) type ScrollWheelListener = Box<dyn Fn(&ScrollWheelEvent, DispatchPhase, &Hitbox, &mut Window, &mut App) + 'static>; -pub(crate) type ClickListener = Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>; +pub(crate) type ClickListener = Rc<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>; pub(crate) type DragListener = Box<dyn Fn(&dyn Any, Point<Pixels>, &mut Window, &mut App) -> AnyView + 'static>; @@ -1334,7 +1333,6 @@ impl Element for Div { } else if let Some(scroll_handle) = self.interactivity.tracked_scroll_handle.as_ref() { let mut state = scroll_handle.0.borrow_mut(); state.child_bounds = Vec::with_capacity(request_layout.child_layout_ids.len()); - state.bounds = bounds; for child_layout_id in &request_layout.child_layout_ids { let child_bounds = window.layout_bounds(*child_layout_id); child_min = child_min.min(&child_bounds.origin); @@ -1706,6 +1704,7 @@ impl Interactivity { if let Some(mut scroll_handle_state) = tracked_scroll_handle { scroll_handle_state.max_offset = scroll_max; + scroll_handle_state.bounds = bounds; } *scroll_offset @@ -1950,6 +1949,12 @@ impl Interactivity { window: &mut Window, cx: &mut App, ) { + let is_focused = self + .tracked_focus_handle + .as_ref() + .map(|handle| handle.is_focused(window)) + .unwrap_or(false); + // If this element can be focused, register a mouse down listener // that will automatically transfer focus when hitting the element. // This behavior can be suppressed by using `cx.prevent_default()`. @@ -2113,6 +2118,39 @@ impl Interactivity { } }); + if is_focused { + // Press enter, space to trigger click, when the element is focused. + window.on_key_event({ + let click_listeners = click_listeners.clone(); + let hitbox = hitbox.clone(); + move |event: &KeyUpEvent, phase, window, cx| { + if phase.bubble() && !window.default_prevented() { + let stroke = &event.keystroke; + let keyboard_button = if stroke.key.eq("enter") { + Some(KeyboardButton::Enter) + } else if stroke.key.eq("space") { + Some(KeyboardButton::Space) + } else { + None + }; + + if let Some(button) = keyboard_button + && !stroke.modifiers.modified() + { + let click_event = ClickEvent::Keyboard(KeyboardClickEvent { + button, + bounds: hitbox.bounds, + }); + + for listener in &click_listeners { + listener(&click_event, window, cx); + } + } + } + } + }); + } + window.on_mouse_event({ let mut captured_mouse_down = None; let hitbox = hitbox.clone(); @@ -2138,10 +2176,10 @@ impl Interactivity { // Fire click handlers during the bubble phase. 
DispatchPhase::Bubble => { if let Some(mouse_down) = captured_mouse_down.take() { - let mouse_click = ClickEvent { + let mouse_click = ClickEvent::Mouse(MouseClickEvent { down: mouse_down, up: event.clone(), - }; + }); for listener in &click_listeners { listener(&mouse_click, window, cx); } @@ -3007,11 +3045,6 @@ impl ScrollHandle { self.0.borrow().bounds } - /// Set the bounds into which this child is painted - pub(super) fn set_bounds(&self, bounds: Bounds<Pixels>) { - self.0.borrow_mut().bounds = bounds; - } - /// Get the bounds for a specific child. pub fn bounds_for_item(&self, ix: usize) -> Option<Bounds<Pixels>> { self.0.borrow().child_bounds.get(ix).cloned() diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs index 328a6a4cc1..39f38bdc69 100644 --- a/crates/gpui/src/elements/list.rs +++ b/crates/gpui/src/elements/list.rs @@ -16,12 +16,18 @@ use crate::{ use collections::VecDeque; use refineable::Refineable as _; use std::{cell::RefCell, ops::Range, rc::Rc}; -use sum_tree::{Bias, SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; + +type RenderItemFn = dyn FnMut(usize, &mut Window, &mut App) -> AnyElement + 'static; /// Construct a new list element -pub fn list(state: ListState) -> List { +pub fn list( + state: ListState, + render_item: impl FnMut(usize, &mut Window, &mut App) -> AnyElement + 'static, +) -> List { List { state, + render_item: Box::new(render_item), style: StyleRefinement::default(), sizing_behavior: ListSizingBehavior::default(), } @@ -30,6 +36,7 @@ pub fn list(state: ListState) -> List { /// A list element pub struct List { state: ListState, + render_item: Box<RenderItemFn>, style: StyleRefinement, sizing_behavior: ListSizingBehavior, } @@ -55,7 +62,6 @@ impl std::fmt::Debug for ListState { struct StateInner { last_layout_bounds: Option<Bounds<Pixels>>, last_padding: Option<Edges<Pixels>>, - render_item: Box<dyn FnMut(usize, &mut Window, &mut App) -> AnyElement>, items: SumTree<ListItem>, logical_scroll_top: Option<ListOffset>, alignment: ListAlignment, @@ -186,19 +192,10 @@ impl ListState { /// above and below the visible area. Elements within this area will /// be measured even though they are not visible. This can help ensure /// that the list doesn't flicker or pop in when scrolling. - pub fn new<R>( - item_count: usize, - alignment: ListAlignment, - overdraw: Pixels, - render_item: R, - ) -> Self - where - R: 'static + FnMut(usize, &mut Window, &mut App) -> AnyElement, - { + pub fn new(item_count: usize, alignment: ListAlignment, overdraw: Pixels) -> Self { let this = Self(Rc::new(RefCell::new(StateInner { last_layout_bounds: None, last_padding: None, - render_item: Box::new(render_item), items: SumTree::default(), logical_scroll_top: None, alignment, @@ -371,14 +368,14 @@ impl ListState { return None; } - let mut cursor = state.items.cursor::<(Count, Height)>(&()); + let mut cursor = state.items.cursor::<Dimensions<Count, Height>>(&()); cursor.seek(&Count(scroll_top.item_ix), Bias::Right); let scroll_top = cursor.start().1.0 + scroll_top.offset_in_item; cursor.seek_forward(&Count(ix), Bias::Right); if let Some(&ListItem::Measured { size, .. 
}) = cursor.item() { - let &(Count(count), Height(top)) = cursor.start(); + let &Dimensions(Count(count), Height(top), _) = cursor.start(); if count == ix { let top = bounds.top() + top - scroll_top; return Some(Bounds::from_corners( @@ -532,6 +529,7 @@ impl StateInner { available_width: Option<Pixels>, available_height: Pixels, padding: &Edges<Pixels>, + render_item: &mut RenderItemFn, window: &mut Window, cx: &mut App, ) -> LayoutItemsResponse { @@ -566,7 +564,7 @@ impl StateInner { // If we're within the visible area or the height wasn't cached, render and measure the item's element if visible_height < available_height || size.is_none() { let item_index = scroll_top.item_ix + ix; - let mut element = (self.render_item)(item_index, window, cx); + let mut element = render_item(item_index, window, cx); let element_size = element.layout_as_root(available_item_space, window, cx); size = Some(element_size); if visible_height < available_height { @@ -601,7 +599,7 @@ impl StateInner { cursor.prev(); if let Some(item) = cursor.item() { let item_index = cursor.start().0; - let mut element = (self.render_item)(item_index, window, cx); + let mut element = render_item(item_index, window, cx); let element_size = element.layout_as_root(available_item_space, window, cx); let focus_handle = item.focus_handle(); rendered_height += element_size.height; @@ -650,7 +648,7 @@ impl StateInner { let size = if let ListItem::Measured { size, .. } = item { *size } else { - let mut element = (self.render_item)(cursor.start().0, window, cx); + let mut element = render_item(cursor.start().0, window, cx); element.layout_as_root(available_item_space, window, cx) }; @@ -683,7 +681,7 @@ impl StateInner { while let Some(item) = cursor.item() { if item.contains_focused(window, cx) { let item_index = cursor.start().0; - let mut element = (self.render_item)(cursor.start().0, window, cx); + let mut element = render_item(cursor.start().0, window, cx); let size = element.layout_as_root(available_item_space, window, cx); item_layouts.push_back(ItemLayout { index: item_index, @@ -708,6 +706,7 @@ impl StateInner { bounds: Bounds<Pixels>, padding: Edges<Pixels>, autoscroll: bool, + render_item: &mut RenderItemFn, window: &mut Window, cx: &mut App, ) -> Result<LayoutItemsResponse, ListOffset> { @@ -716,6 +715,7 @@ impl StateInner { Some(bounds.size.width), bounds.size.height, &padding, + render_item, window, cx, ); @@ -753,8 +753,7 @@ impl StateInner { let Some(item) = cursor.item() else { break }; let size = item.size().unwrap_or_else(|| { - let mut item = - (self.render_item)(cursor.start().0, window, cx); + let mut item = render_item(cursor.start().0, window, cx); let item_available_size = size( bounds.size.width.into(), AvailableSpace::MinContent, @@ -876,8 +875,14 @@ impl Element for List { window.rem_size(), ); - let layout_response = - state.layout_items(None, available_height, &padding, window, cx); + let layout_response = state.layout_items( + None, + available_height, + &padding, + &mut self.render_item, + window, + cx, + ); let max_element_width = layout_response.max_item_width; let summary = state.items.summary(); @@ -951,15 +956,16 @@ impl Element for List { let padding = style .padding .to_pixels(bounds.size.into(), window.rem_size()); - let layout = match state.prepaint_items(bounds, padding, true, window, cx) { - Ok(layout) => layout, - Err(autoscroll_request) => { - state.logical_scroll_top = Some(autoscroll_request); - state - .prepaint_items(bounds, padding, false, window, cx) - .unwrap() - } - }; + let layout = + 
match state.prepaint_items(bounds, padding, true, &mut self.render_item, window, cx) { + Ok(layout) => layout, + Err(autoscroll_request) => { + state.logical_scroll_top = Some(autoscroll_request); + state + .prepaint_items(bounds, padding, false, &mut self.render_item, window, cx) + .unwrap() + } + }; state.last_layout_bounds = Some(bounds); state.last_padding = Some(padding); @@ -1108,9 +1114,7 @@ mod test { let cx = cx.add_empty_window(); - let state = ListState::new(5, crate::ListAlignment::Top, px(10.), |_, _, _| { - div().h(px(10.)).w_full().into_any() - }); + let state = ListState::new(5, crate::ListAlignment::Top, px(10.)); // Ensure that the list is scrolled to the top state.scroll_to(gpui::ListOffset { @@ -1121,7 +1125,11 @@ mod test { struct TestView(ListState); impl Render for TestView { fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - list(self.0.clone()).w_full().h_full() + list(self.0.clone(), |_, _, _| { + div().h(px(10.)).w_full().into_any() + }) + .w_full() + .h_full() } } @@ -1154,14 +1162,16 @@ mod test { let cx = cx.add_empty_window(); - let state = ListState::new(5, crate::ListAlignment::Top, px(10.), |_, _, _| { - div().h(px(20.)).w_full().into_any() - }); + let state = ListState::new(5, crate::ListAlignment::Top, px(10.)); struct TestView(ListState); impl Render for TestView { fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - list(self.0.clone()).w_full().h_full() + list(self.0.clone(), |_, _, _| { + div().h(px(20.)).w_full().into_any() + }) + .w_full() + .h_full() } } diff --git a/crates/gpui/src/elements/uniform_list.rs b/crates/gpui/src/elements/uniform_list.rs index 52e2015c20..cdf90d4eb8 100644 --- a/crates/gpui/src/elements/uniform_list.rs +++ b/crates/gpui/src/elements/uniform_list.rs @@ -88,15 +88,24 @@ pub enum ScrollStrategy { /// May not be possible if there's not enough list items above the item scrolled to: /// in this case, the element will be placed at the closest possible position. Center, - /// Scrolls the element to be at the given item index from the top of the viewport. - ToPosition(usize), +} + +#[derive(Clone, Copy, Debug)] +#[allow(missing_docs)] +pub struct DeferredScrollToItem { + /// The item index to scroll to + pub item_index: usize, + /// The scroll strategy to use + pub strategy: ScrollStrategy, + /// The offset in number of items + pub offset: usize, } #[derive(Clone, Debug, Default)] #[allow(missing_docs)] pub struct UniformListScrollState { pub base_handle: ScrollHandle, - pub deferred_scroll_to_item: Option<(usize, ScrollStrategy)>, + pub deferred_scroll_to_item: Option<DeferredScrollToItem>, /// Size of the item, captured during last layout. pub last_item_size: Option<ItemSize>, /// Whether the list was vertically flipped during last layout. @@ -126,7 +135,24 @@ impl UniformListScrollHandle { /// Scroll the list to the given item index. pub fn scroll_to_item(&self, ix: usize, strategy: ScrollStrategy) { - self.0.borrow_mut().deferred_scroll_to_item = Some((ix, strategy)); + self.0.borrow_mut().deferred_scroll_to_item = Some(DeferredScrollToItem { + item_index: ix, + strategy, + offset: 0, + }); + } + + /// Scroll the list to the given item index with an offset. + /// + /// For ScrollStrategy::Top, the item will be placed at the offset position from the top. + /// + /// For ScrollStrategy::Center, the item will be centered between offset and the last visible item. 
+ pub fn scroll_to_item_with_offset(&self, ix: usize, strategy: ScrollStrategy, offset: usize) { + self.0.borrow_mut().deferred_scroll_to_item = Some(DeferredScrollToItem { + item_index: ix, + strategy, + offset, + }); } /// Check if the list is flipped vertically. @@ -139,7 +165,8 @@ impl UniformListScrollHandle { pub fn logical_scroll_top_index(&self) -> usize { let this = self.0.borrow(); this.deferred_scroll_to_item - .map(|(ix, _)| ix) + .as_ref() + .map(|deferred| deferred.item_index) .unwrap_or_else(|| this.base_handle.logical_scroll_top().0) } @@ -295,9 +322,8 @@ impl Element for UniformList { bounds.bottom_right() - point(border.right + padding.right, border.bottom), ); - let y_flipped = if let Some(scroll_handle) = self.scroll_handle.as_mut() { - let mut scroll_state = scroll_handle.0.borrow_mut(); - scroll_state.base_handle.set_bounds(bounds); + let y_flipped = if let Some(scroll_handle) = &self.scroll_handle { + let scroll_state = scroll_handle.0.borrow(); scroll_state.y_flipped } else { false @@ -321,7 +347,8 @@ impl Element for UniformList { scroll_offset.x = Pixels::ZERO; } - if let Some((mut ix, scroll_strategy)) = shared_scroll_to_item { + if let Some(deferred_scroll) = shared_scroll_to_item { + let mut ix = deferred_scroll.item_index; if y_flipped { ix = self.item_count.saturating_sub(ix + 1); } @@ -330,23 +357,28 @@ impl Element for UniformList { let item_top = item_height * ix + padding.top; let item_bottom = item_top + item_height; let scroll_top = -updated_scroll_offset.y; + let offset_pixels = item_height * deferred_scroll.offset; let mut scrolled_to_top = false; - if item_top < scroll_top + padding.top { + + if item_top < scroll_top + padding.top + offset_pixels { scrolled_to_top = true; - updated_scroll_offset.y = -(item_top) + padding.top; + updated_scroll_offset.y = -(item_top) + padding.top + offset_pixels; } else if item_bottom > scroll_top + list_height - padding.bottom { scrolled_to_top = true; updated_scroll_offset.y = -(item_bottom - list_height) - padding.bottom; } - match scroll_strategy { + match deferred_scroll.strategy { ScrollStrategy::Top => {} ScrollStrategy::Center => { if scrolled_to_top { let item_center = item_top + item_height / 2.0; - let target_scroll_top = item_center - list_height / 2.0; - if item_top < scroll_top + let viewport_height = list_height - offset_pixels; + let viewport_center = offset_pixels + viewport_height / 2.0; + let target_scroll_top = item_center - viewport_center; + + if item_top < scroll_top + offset_pixels || item_bottom > scroll_top + list_height { updated_scroll_offset.y = -target_scroll_top @@ -356,15 +388,6 @@ impl Element for UniformList { } } } - ScrollStrategy::ToPosition(sticky_index) => { - let target_y_in_viewport = item_height * sticky_index; - let target_scroll_top = item_top - target_y_in_viewport; - let max_scroll_top = - (content_height - list_height).max(Pixels::ZERO); - let new_scroll_top = - target_scroll_top.clamp(Pixels::ZERO, max_scroll_top); - updated_scroll_offset.y = -new_scroll_top; - } } scroll_offset = *updated_scroll_offset } diff --git a/crates/gpui/src/geometry.rs b/crates/gpui/src/geometry.rs index 74be6344f9..3d2d9cd9db 100644 --- a/crates/gpui/src/geometry.rs +++ b/crates/gpui/src/geometry.rs @@ -3522,7 +3522,7 @@ impl Serialize for Length { /// # Returns /// /// A `DefiniteLength` representing the relative length as a fraction of the parent's size. 
-pub fn relative(fraction: f32) -> DefiniteLength { +pub const fn relative(fraction: f32) -> DefiniteLength { DefiniteLength::Fraction(fraction) } diff --git a/crates/gpui/src/interactive.rs b/crates/gpui/src/interactive.rs index edd807da11..46af946e69 100644 --- a/crates/gpui/src/interactive.rs +++ b/crates/gpui/src/interactive.rs @@ -1,6 +1,6 @@ use crate::{ - Capslock, Context, Empty, IntoElement, Keystroke, Modifiers, Pixels, Point, Render, Window, - point, seal::Sealed, + Bounds, Capslock, Context, Empty, IntoElement, Keystroke, Modifiers, Pixels, Point, Render, + Window, point, seal::Sealed, }; use smallvec::SmallVec; use std::{any::Any, fmt::Debug, ops::Deref, path::PathBuf}; @@ -141,7 +141,7 @@ impl MouseEvent for MouseUpEvent {} /// A click event, generated when a mouse button is pressed and released. #[derive(Clone, Debug, Default)] -pub struct ClickEvent { +pub struct MouseClickEvent { /// The mouse event when the button was pressed. pub down: MouseDownEvent, @@ -149,18 +149,119 @@ pub struct ClickEvent { pub up: MouseUpEvent, } +/// A click event that was generated by a keyboard button being pressed and released. +#[derive(Clone, Debug)] +pub struct KeyboardClickEvent { + /// The keyboard button that was pressed to trigger the click. + pub button: KeyboardButton, + + /// The bounds of the element that was clicked. + pub bounds: Bounds<Pixels>, +} + +/// A click event, generated when a mouse button or keyboard button is pressed and released. +#[derive(Clone, Debug)] +pub enum ClickEvent { + /// A click event trigger by a mouse button being pressed and released. + Mouse(MouseClickEvent), + /// A click event trigger by a keyboard button being pressed and released. + Keyboard(KeyboardClickEvent), +} + impl ClickEvent { - /// Returns the modifiers that were held down during both the - /// mouse down and mouse up events + /// Returns the modifiers that were held during the click event + /// + /// `Keyboard`: The keyboard click events never have modifiers. + /// `Mouse`: Modifiers that were held during the mouse key up event. pub fn modifiers(&self) -> Modifiers { - Modifiers { - control: self.up.modifiers.control && self.down.modifiers.control, - alt: self.up.modifiers.alt && self.down.modifiers.alt, - shift: self.up.modifiers.shift && self.down.modifiers.shift, - platform: self.up.modifiers.platform && self.down.modifiers.platform, - function: self.up.modifiers.function && self.down.modifiers.function, + match self { + // Click events are only generated from keyboard events _without any modifiers_, so we know the modifiers are always Default + ClickEvent::Keyboard(_) => Modifiers::default(), + // Click events on the web only reflect the modifiers for the keyup event, + // tested via observing the behavior of the `ClickEvent.shiftKey` field in Chrome 138 + // under various combinations of modifiers and keyUp / keyDown events. + ClickEvent::Mouse(event) => event.up.modifiers, } } + + /// Returns the position of the click event + /// + /// `Keyboard`: The bottom left corner of the clicked hitbox + /// `Mouse`: The position of the mouse when the button was released. + pub fn position(&self) -> Point<Pixels> { + match self { + ClickEvent::Keyboard(event) => event.bounds.bottom_left(), + ClickEvent::Mouse(event) => event.up.position, + } + } + + /// Returns the mouse position of the click event + /// + /// `Keyboard`: None + /// `Mouse`: The position of the mouse when the button was released. 
+ pub fn mouse_position(&self) -> Option<Point<Pixels>> { + match self { + ClickEvent::Keyboard(_) => None, + ClickEvent::Mouse(event) => Some(event.up.position), + } + } + + /// Returns if this was a right click + /// + /// `Keyboard`: false + /// `Mouse`: Whether the right button was pressed and released + pub fn is_right_click(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => false, + ClickEvent::Mouse(event) => { + event.down.button == MouseButton::Right && event.up.button == MouseButton::Right + } + } + } + + /// Returns whether the click was a standard click + /// + /// `Keyboard`: Always true + /// `Mouse`: Left button pressed and released + pub fn standard_click(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => true, + ClickEvent::Mouse(event) => { + event.down.button == MouseButton::Left && event.up.button == MouseButton::Left + } + } + } + + /// Returns whether the click focused the element + /// + /// `Keyboard`: false, keyboard clicks only work if an element is already focused + /// `Mouse`: Whether this was the first focusing click + pub fn first_focus(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => false, + ClickEvent::Mouse(event) => event.down.first_mouse, + } + } + + /// Returns the click count of the click event + /// + /// `Keyboard`: Always 1 + /// `Mouse`: Count of clicks from MouseUpEvent + pub fn click_count(&self) -> usize { + match self { + ClickEvent::Keyboard(_) => 1, + ClickEvent::Mouse(event) => event.up.click_count, + } + } +} + +/// An enum representing the keyboard button that was pressed for a click event. +#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)] +pub enum KeyboardButton { + /// Enter key was clicked + Enter, + /// Space key was clicked + Space, } /// An enum representing the mouse button that was pressed. 
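A minimal usage sketch (not part of the patch) for the ClickEvent enum introduced in the hunk above, assuming a gpui element with a hypothetical id "save-button"; the handler body is illustrative only, while the listener signature and the ClickEvent methods shown (is_right_click, standard_click, click_count, position, modifiers) are the ones added in this diff. Because Enter/Space activation now produces ClickEvent::Keyboard, one listener handles mouse and keyboard clicks uniformly.

div()
    .id("save-button")
    .on_click(|event: &ClickEvent, _window, _cx| {
        // Keyboard clicks are never right clicks, so this early return only
        // affects mouse input.
        if event.is_right_click() {
            return;
        }
        // standard_click() is true for left mouse clicks and for all keyboard
        // clicks; position() falls back to the hitbox's bottom-left corner
        // when the click came from Enter or Space.
        if event.standard_click() && event.click_count() == 1 {
            let _click_position = event.position();
            let _modifiers = event.modifiers();
        }
    })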
diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index 6f227f1d07..b495d70dfd 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -13,8 +13,7 @@ mod mac; any(target_os = "linux", target_os = "freebsd"), any(feature = "x11", feature = "wayland") ), - target_os = "windows", - feature = "macos-blade" + all(target_os = "macos", feature = "macos-blade") ))] mod blade; @@ -448,6 +447,8 @@ impl Tiling { #[derive(Debug, Copy, Clone, Eq, PartialEq, Default)] pub(crate) struct RequestFrameOptions { pub(crate) require_presentation: bool, + /// Force refresh of all rendering states when true + pub(crate) force_render: bool, } pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle { @@ -809,7 +810,6 @@ pub(crate) struct AtlasTextureId { pub(crate) enum AtlasTextureKind { Monochrome = 0, Polychrome = 1, - Path = 2, } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] diff --git a/crates/gpui/src/platform/blade/blade_atlas.rs b/crates/gpui/src/platform/blade/blade_atlas.rs index 78ba52056a..74500ebf83 100644 --- a/crates/gpui/src/platform/blade/blade_atlas.rs +++ b/crates/gpui/src/platform/blade/blade_atlas.rs @@ -10,8 +10,6 @@ use etagere::BucketedAtlasAllocator; use parking_lot::Mutex; use std::{borrow::Cow, ops, sync::Arc}; -pub(crate) const PATH_TEXTURE_FORMAT: gpu::TextureFormat = gpu::TextureFormat::R16Float; - pub(crate) struct BladeAtlas(Mutex<BladeAtlasState>); struct PendingUpload { @@ -27,7 +25,6 @@ struct BladeAtlasState { tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, initializations: Vec<AtlasTextureId>, uploads: Vec<PendingUpload>, - path_sample_count: u32, } #[cfg(gles)] @@ -41,13 +38,11 @@ impl BladeAtlasState { } pub struct BladeTextureInfo { - pub size: gpu::Extent, pub raw_view: gpu::TextureView, - pub msaa_view: Option<gpu::TextureView>, } impl BladeAtlas { - pub(crate) fn new(gpu: &Arc<gpu::Context>, path_sample_count: u32) -> Self { + pub(crate) fn new(gpu: &Arc<gpu::Context>) -> Self { BladeAtlas(Mutex::new(BladeAtlasState { gpu: Arc::clone(gpu), upload_belt: BufferBelt::new(BufferBeltDescriptor { @@ -59,7 +54,6 @@ impl BladeAtlas { tiles_by_key: Default::default(), initializations: Vec::new(), uploads: Vec::new(), - path_sample_count, })) } @@ -67,27 +61,6 @@ impl BladeAtlas { self.0.lock().destroy(); } - pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) { - let mut lock = self.0.lock(); - let textures = &mut lock.storage[texture_kind]; - for texture in textures.iter_mut() { - texture.clear(); - } - } - - /// Allocate a rectangle and make it available for rendering immediately (without waiting for `before_frame`) - pub fn allocate_for_rendering( - &self, - size: Size<DevicePixels>, - texture_kind: AtlasTextureKind, - gpu_encoder: &mut gpu::CommandEncoder, - ) -> AtlasTile { - let mut lock = self.0.lock(); - let tile = lock.allocate(size, texture_kind); - lock.flush_initializations(gpu_encoder); - tile - } - pub fn before_frame(&self, gpu_encoder: &mut gpu::CommandEncoder) { let mut lock = self.0.lock(); lock.flush(gpu_encoder); @@ -101,15 +74,8 @@ impl BladeAtlas { pub fn get_texture_info(&self, id: AtlasTextureId) -> BladeTextureInfo { let lock = self.0.lock(); let texture = &lock.storage[id]; - let size = texture.allocator.size(); BladeTextureInfo { - size: gpu::Extent { - width: size.width as u32, - height: size.height as u32, - depth: 1, - }, raw_view: texture.raw_view, - msaa_view: texture.msaa_view, } } } @@ -200,48 +166,8 @@ impl BladeAtlasState { format = 
gpu::TextureFormat::Bgra8UnormSrgb; usage = gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE; } - AtlasTextureKind::Path => { - format = PATH_TEXTURE_FORMAT; - usage = gpu::TextureUsage::COPY - | gpu::TextureUsage::RESOURCE - | gpu::TextureUsage::TARGET; - } } - // We currently only enable MSAA for path textures. - let (msaa, msaa_view) = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path { - let msaa = self.gpu.create_texture(gpu::TextureDesc { - name: "msaa path texture", - format, - size: gpu::Extent { - width: size.width.into(), - height: size.height.into(), - depth: 1, - }, - array_layer_count: 1, - mip_level_count: 1, - sample_count: self.path_sample_count, - dimension: gpu::TextureDimension::D2, - usage: gpu::TextureUsage::TARGET, - external: None, - }); - - ( - Some(msaa), - Some(self.gpu.create_texture_view( - msaa, - gpu::TextureViewDesc { - name: "msaa texture view", - format, - dimension: gpu::ViewDimension::D2, - subresources: &Default::default(), - }, - )), - ) - } else { - (None, None) - }; - let raw = self.gpu.create_texture(gpu::TextureDesc { name: "atlas", format, @@ -279,8 +205,6 @@ impl BladeAtlasState { format, raw, raw_view, - msaa, - msaa_view, live_atlas_keys: 0, }; @@ -340,7 +264,6 @@ impl BladeAtlasState { struct BladeAtlasStorage { monochrome_textures: AtlasTextureList<BladeAtlasTexture>, polychrome_textures: AtlasTextureList<BladeAtlasTexture>, - path_textures: AtlasTextureList<BladeAtlasTexture>, } impl ops::Index<AtlasTextureKind> for BladeAtlasStorage { @@ -349,7 +272,6 @@ impl ops::Index<AtlasTextureKind> for BladeAtlasStorage { match kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, - crate::AtlasTextureKind::Path => &self.path_textures, } } } @@ -359,7 +281,6 @@ impl ops::IndexMut<AtlasTextureKind> for BladeAtlasStorage { match kind { crate::AtlasTextureKind::Monochrome => &mut self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &mut self.polychrome_textures, - crate::AtlasTextureKind::Path => &mut self.path_textures, } } } @@ -370,7 +291,6 @@ impl ops::Index<AtlasTextureId> for BladeAtlasStorage { let textures = match id.kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, - crate::AtlasTextureKind::Path => &self.path_textures, }; textures[id.index as usize].as_ref().unwrap() } @@ -384,9 +304,6 @@ impl BladeAtlasStorage { for mut texture in self.polychrome_textures.drain().flatten() { texture.destroy(gpu); } - for mut texture in self.path_textures.drain().flatten() { - texture.destroy(gpu); - } } } @@ -395,17 +312,11 @@ struct BladeAtlasTexture { allocator: BucketedAtlasAllocator, raw: gpu::Texture, raw_view: gpu::TextureView, - msaa: Option<gpu::Texture>, - msaa_view: Option<gpu::TextureView>, format: gpu::TextureFormat, live_atlas_keys: u32, } impl BladeAtlasTexture { - fn clear(&mut self) { - self.allocator.clear(); - } - fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { let allocation = self.allocator.allocate(size.into())?; let tile = AtlasTile { @@ -424,12 +335,6 @@ impl BladeAtlasTexture { fn destroy(&mut self, gpu: &gpu::Context) { gpu.destroy_texture(self.raw); gpu.destroy_texture_view(self.raw_view); - if let Some(msaa) = self.msaa { - gpu.destroy_texture(msaa); - } - if let Some(msaa_view) = self.msaa_view { - gpu.destroy_texture_view(msaa_view); - } } fn bytes_per_pixel(&self) -> u8 { diff --git 
a/crates/gpui/src/platform/blade/blade_renderer.rs b/crates/gpui/src/platform/blade/blade_renderer.rs index cac47434ae..46d3c16c72 100644 --- a/crates/gpui/src/platform/blade/blade_renderer.rs +++ b/crates/gpui/src/platform/blade/blade_renderer.rs @@ -1,24 +1,19 @@ // Doing `if let` gives you nice scoping with passes/encoders #![allow(irrefutable_let_patterns)] -use super::{BladeAtlas, BladeContext, PATH_TEXTURE_FORMAT}; +use super::{BladeAtlas, BladeContext}; use crate::{ - AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels, GpuSpecs, - MonochromeSprite, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch, Quad, - ScaledPixels, Scene, Shadow, Size, Underline, + Background, Bounds, DevicePixels, GpuSpecs, MonochromeSprite, Path, Point, PolychromeSprite, + PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Underline, }; use blade_graphics as gpu; use blade_util::{BufferBelt, BufferBeltDescriptor}; use bytemuck::{Pod, Zeroable}; -use collections::HashMap; #[cfg(target_os = "macos")] use media::core_video::CVMetalTextureCache; -use std::{mem, sync::Arc}; +use std::sync::Arc; const MAX_FRAME_TIME_MS: u32 = 10000; -// Use 4x MSAA, all devices support it. -// https://developer.apple.com/documentation/metal/mtldevice/1433355-supportstexturesamplecount -const DEFAULT_PATH_SAMPLE_COUNT: u32 = 4; #[repr(C)] #[derive(Clone, Copy, Pod, Zeroable)] @@ -114,8 +109,15 @@ struct ShaderSurfacesData { #[repr(C)] struct PathSprite { bounds: Bounds<ScaledPixels>, +} + +#[derive(Clone, Debug)] +#[repr(C)] +struct PathRasterizationVertex { + xy_position: Point<ScaledPixels>, + st_position: Point<f32>, color: Background, - tile: AtlasTile, + bounds: Bounds<ScaledPixels>, } struct BladePipelines { @@ -144,10 +146,7 @@ impl BladePipelines { shader.check_struct_size::<SurfaceParams>(); shader.check_struct_size::<Quad>(); shader.check_struct_size::<Shadow>(); - assert_eq!( - mem::size_of::<PathVertex<ScaledPixels>>(), - shader.get_struct_size("PathVertex") as usize, - ); + shader.check_struct_size::<PathRasterizationVertex>(); shader.check_struct_size::<PathSprite>(); shader.check_struct_size::<Underline>(); shader.check_struct_size::<MonochromeSprite>(); @@ -205,9 +204,16 @@ impl BladePipelines { }, depth_stencil: None, fragment: Some(shader.at("fs_path_rasterization")), + // The original implementation was using ADDITIVE blende mode, + // I don't know why + // color_targets: &[gpu::ColorTargetState { + // format: PATH_TEXTURE_FORMAT, + // blend: Some(gpu::BlendState::ADDITIVE), + // write_mask: gpu::ColorWrites::default(), + // }], color_targets: &[gpu::ColorTargetState { - format: PATH_TEXTURE_FORMAT, - blend: Some(gpu::BlendState::ADDITIVE), + format: surface_info.format, + blend: Some(gpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING), write_mask: gpu::ColorWrites::default(), }], multisample_state: gpu::MultisampleState { @@ -226,7 +232,14 @@ impl BladePipelines { }, depth_stencil: None, fragment: Some(shader.at("fs_path")), - color_targets, + color_targets: &[gpu::ColorTargetState { + format: surface_info.format, + blend: Some(gpu::BlendState { + color: gpu::BlendComponent::OVER, + alpha: gpu::BlendComponent::ADDITIVE, + }), + write_mask: gpu::ColorWrites::default(), + }], multisample_state: gpu::MultisampleState::default(), }), underlines: gpu.create_render_pipeline(gpu::RenderPipelineDesc { @@ -317,12 +330,15 @@ pub struct BladeRenderer { last_sync_point: Option<gpu::SyncPoint>, pipelines: BladePipelines, instance_belt: BufferBelt, - path_tiles: HashMap<PathId, AtlasTile>, 
atlas: Arc<BladeAtlas>, atlas_sampler: gpu::Sampler, #[cfg(target_os = "macos")] core_video_texture_cache: CVMetalTextureCache, path_sample_count: u32, + path_intermediate_texture: gpu::Texture, + path_intermediate_texture_view: gpu::TextureView, + path_intermediate_msaa_texture: Option<gpu::Texture>, + path_intermediate_msaa_texture_view: Option<gpu::TextureView>, } impl BladeRenderer { @@ -352,21 +368,43 @@ impl BladeRenderer { let path_sample_count = std::env::var("ZED_PATH_SAMPLE_COUNT") .ok() .and_then(|v| v.parse().ok()) - .unwrap_or(DEFAULT_PATH_SAMPLE_COUNT); + .or_else(|| { + [4, 2, 1] + .into_iter() + .find(|count| context.gpu.supports_texture_sample_count(*count)) + }) + .unwrap_or(1); let pipelines = BladePipelines::new(&context.gpu, surface.info(), path_sample_count); let instance_belt = BufferBelt::new(BufferBeltDescriptor { memory: gpu::Memory::Shared, min_chunk_size: 0x1000, alignment: 0x40, // Vulkan `minStorageBufferOffsetAlignment` on Intel Xe }); - let atlas = Arc::new(BladeAtlas::new(&context.gpu, path_sample_count)); + let atlas = Arc::new(BladeAtlas::new(&context.gpu)); let atlas_sampler = context.gpu.create_sampler(gpu::SamplerDesc { - name: "atlas", + name: "path rasterization sampler", mag_filter: gpu::FilterMode::Linear, min_filter: gpu::FilterMode::Linear, ..Default::default() }); + let (path_intermediate_texture, path_intermediate_texture_view) = + create_path_intermediate_texture( + &context.gpu, + surface.info().format, + config.size.width, + config.size.height, + ); + let (path_intermediate_msaa_texture, path_intermediate_msaa_texture_view) = + create_msaa_texture_if_needed( + &context.gpu, + surface.info().format, + config.size.width, + config.size.height, + path_sample_count, + ) + .unzip(); + #[cfg(target_os = "macos")] let core_video_texture_cache = unsafe { CVMetalTextureCache::new( @@ -383,12 +421,15 @@ impl BladeRenderer { last_sync_point: None, pipelines, instance_belt, - path_tiles: HashMap::default(), atlas, atlas_sampler, #[cfg(target_os = "macos")] core_video_texture_cache, path_sample_count, + path_intermediate_texture, + path_intermediate_texture_view, + path_intermediate_msaa_texture, + path_intermediate_msaa_texture_view, }) } @@ -441,6 +482,35 @@ impl BladeRenderer { self.surface_config.size = gpu_size; self.gpu .reconfigure_surface(&mut self.surface, self.surface_config); + self.gpu.destroy_texture(self.path_intermediate_texture); + self.gpu + .destroy_texture_view(self.path_intermediate_texture_view); + if let Some(msaa_texture) = self.path_intermediate_msaa_texture { + self.gpu.destroy_texture(msaa_texture); + } + if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { + self.gpu.destroy_texture_view(msaa_view); + } + let (path_intermediate_texture, path_intermediate_texture_view) = + create_path_intermediate_texture( + &self.gpu, + self.surface.info().format, + gpu_size.width, + gpu_size.height, + ); + self.path_intermediate_texture = path_intermediate_texture; + self.path_intermediate_texture_view = path_intermediate_texture_view; + let (path_intermediate_msaa_texture, path_intermediate_msaa_texture_view) = + create_msaa_texture_if_needed( + &self.gpu, + self.surface.info().format, + gpu_size.width, + gpu_size.height, + self.path_sample_count, + ) + .unzip(); + self.path_intermediate_msaa_texture = path_intermediate_msaa_texture; + self.path_intermediate_msaa_texture_view = path_intermediate_msaa_texture_view; } } @@ -491,76 +561,63 @@ impl BladeRenderer { } #[profiling::function] - fn rasterize_paths(&mut self, paths: 
&[Path<ScaledPixels>]) { - self.path_tiles.clear(); - let mut vertices_by_texture_id = HashMap::default(); - - for path in paths { - let clipped_bounds = path - .bounds - .intersect(&path.content_mask.bounds) - .map_origin(|origin| origin.floor()) - .map_size(|size| size.ceil()); - let tile = self.atlas.allocate_for_rendering( - clipped_bounds.size.map(Into::into), - AtlasTextureKind::Path, - &mut self.command_encoder, - ); - vertices_by_texture_id - .entry(tile.texture_id) - .or_insert(Vec::new()) - .extend(path.vertices.iter().map(|vertex| PathVertex { - xy_position: vertex.xy_position - clipped_bounds.origin - + tile.bounds.origin.map(Into::into), - st_position: vertex.st_position, - content_mask: ContentMask { - bounds: tile.bounds.map(Into::into), - }, - })); - self.path_tiles.insert(path.id, tile); + fn draw_paths_to_intermediate( + &mut self, + paths: &[Path<ScaledPixels>], + width: f32, + height: f32, + ) { + self.command_encoder + .init_texture(self.path_intermediate_texture); + if let Some(msaa_texture) = self.path_intermediate_msaa_texture { + self.command_encoder.init_texture(msaa_texture); } - for (texture_id, vertices) in vertices_by_texture_id { - let tex_info = self.atlas.get_texture_info(texture_id); + let target = if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { + gpu::RenderTarget { + view: msaa_view, + init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), + finish_op: gpu::FinishOp::ResolveTo(self.path_intermediate_texture_view), + } + } else { + gpu::RenderTarget { + view: self.path_intermediate_texture_view, + init_op: gpu::InitOp::Clear(gpu::TextureColor::TransparentBlack), + finish_op: gpu::FinishOp::Store, + } + }; + if let mut pass = self.command_encoder.render( + "rasterize paths", + gpu::RenderTargetSet { + colors: &[target], + depth_stencil: None, + }, + ) { let globals = GlobalParams { - viewport_size: [tex_info.size.width as f32, tex_info.size.height as f32], + viewport_size: [width, height], premultiplied_alpha: 0, pad: 0, }; + let mut encoder = pass.with(&self.pipelines.path_rasterization); - let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) }; - let frame_view = tex_info.raw_view; - let color_target = if let Some(msaa_view) = tex_info.msaa_view { - gpu::RenderTarget { - view: msaa_view, - init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack), - finish_op: gpu::FinishOp::ResolveTo(frame_view), - } - } else { - gpu::RenderTarget { - view: frame_view, - init_op: gpu::InitOp::Clear(gpu::TextureColor::OpaqueBlack), - finish_op: gpu::FinishOp::Store, - } - }; - - if let mut pass = self.command_encoder.render( - "paths", - gpu::RenderTargetSet { - colors: &[color_target], - depth_stencil: None, - }, - ) { - let mut encoder = pass.with(&self.pipelines.path_rasterization); - encoder.bind( - 0, - &ShaderPathRasterizationData { - globals, - b_path_vertices: vertex_buf, - }, - ); - encoder.draw(0, vertices.len() as u32, 0, 1); + let mut vertices = Vec::new(); + for path in paths { + vertices.extend(path.vertices.iter().map(|v| PathRasterizationVertex { + xy_position: v.xy_position, + st_position: v.st_position, + color: path.color, + bounds: path.clipped_bounds(), + })); } + let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) }; + encoder.bind( + 0, + &ShaderPathRasterizationData { + globals, + b_path_vertices: vertex_buf, + }, + ); + encoder.draw(0, vertices.len() as u32, 0, 1); } } @@ -572,12 +629,20 @@ impl BladeRenderer { self.gpu.destroy_command_encoder(&mut 
self.command_encoder); self.pipelines.destroy(&self.gpu); self.gpu.destroy_surface(&mut self.surface); + self.gpu.destroy_texture(self.path_intermediate_texture); + self.gpu + .destroy_texture_view(self.path_intermediate_texture_view); + if let Some(msaa_texture) = self.path_intermediate_msaa_texture { + self.gpu.destroy_texture(msaa_texture); + } + if let Some(msaa_view) = self.path_intermediate_msaa_texture_view { + self.gpu.destroy_texture_view(msaa_view); + } } pub fn draw(&mut self, scene: &Scene) { self.command_encoder.start(); self.atlas.before_frame(&mut self.command_encoder); - self.rasterize_paths(scene.paths()); let frame = { profiling::scope!("acquire frame"); @@ -597,7 +662,7 @@ impl BladeRenderer { pad: 0, }; - if let mut pass = self.command_encoder.render( + let mut pass = self.command_encoder.render( "main", gpu::RenderTargetSet { colors: &[gpu::RenderTarget { @@ -607,209 +672,235 @@ impl BladeRenderer { }], depth_stencil: None, }, - ) { - profiling::scope!("render pass"); - for batch in scene.batches() { - match batch { - PrimitiveBatch::Quads(quads) => { - let instance_buf = - unsafe { self.instance_belt.alloc_typed(quads, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.quads); - encoder.bind( - 0, - &ShaderQuadsData { - globals, - b_quads: instance_buf, - }, - ); - encoder.draw(0, 4, 0, quads.len() as u32); - } - PrimitiveBatch::Shadows(shadows) => { - let instance_buf = - unsafe { self.instance_belt.alloc_typed(shadows, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.shadows); - encoder.bind( - 0, - &ShaderShadowsData { - globals, - b_shadows: instance_buf, - }, - ); - encoder.draw(0, 4, 0, shadows.len() as u32); - } - PrimitiveBatch::Paths(paths) => { - let mut encoder = pass.with(&self.pipelines.paths); - // todo(linux): group by texture ID - for path in paths { - let tile = &self.path_tiles[&path.id]; - let tex_info = self.atlas.get_texture_info(tile.texture_id); - let origin = path.bounds.intersect(&path.content_mask.bounds).origin; - let sprites = [PathSprite { - bounds: Bounds { - origin: origin.map(|p| p.floor()), - size: tile.bounds.size.map(Into::into), - }, - color: path.color, - tile: (*tile).clone(), - }]; + ); - let instance_buf = - unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; - encoder.bind( - 0, - &ShaderPathsData { - globals, - t_sprite: tex_info.raw_view, - s_sprite: self.atlas_sampler, - b_path_sprites: instance_buf, - }, - ); - encoder.draw(0, 4, 0, sprites.len() as u32); + profiling::scope!("render pass"); + for batch in scene.batches() { + match batch { + PrimitiveBatch::Quads(quads) => { + let instance_buf = unsafe { self.instance_belt.alloc_typed(quads, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.quads); + encoder.bind( + 0, + &ShaderQuadsData { + globals, + b_quads: instance_buf, + }, + ); + encoder.draw(0, 4, 0, quads.len() as u32); + } + PrimitiveBatch::Shadows(shadows) => { + let instance_buf = + unsafe { self.instance_belt.alloc_typed(shadows, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.shadows); + encoder.bind( + 0, + &ShaderShadowsData { + globals, + b_shadows: instance_buf, + }, + ); + encoder.draw(0, 4, 0, shadows.len() as u32); + } + PrimitiveBatch::Paths(paths) => { + let Some(first_path) = paths.first() else { + continue; + }; + drop(pass); + self.draw_paths_to_intermediate( + paths, + self.surface_config.size.width as f32, + self.surface_config.size.height as f32, + ); + pass = self.command_encoder.render( + "main", + gpu::RenderTargetSet { + colors: 
&[gpu::RenderTarget { + view: frame.texture_view(), + init_op: gpu::InitOp::Load, + finish_op: gpu::FinishOp::Store, + }], + depth_stencil: None, + }, + ); + let mut encoder = pass.with(&self.pipelines.paths); + // When copying paths from the intermediate texture to the drawable, + // each pixel must only be copied once, in case of transparent paths. + // + // If all paths have the same draw order, then their bounds are all + // disjoint, so we can copy each path's bounds individually. If this + // batch combines different draw orders, we perform a single copy + // for a minimal spanning rect. + let sprites = if paths.last().unwrap().order == first_path.order { + paths + .iter() + .map(|path| PathSprite { + bounds: path.clipped_bounds(), + }) + .collect() + } else { + let mut bounds = first_path.clipped_bounds(); + for path in paths.iter().skip(1) { + bounds = bounds.union(&path.clipped_bounds()); } - } - PrimitiveBatch::Underlines(underlines) => { - let instance_buf = - unsafe { self.instance_belt.alloc_typed(underlines, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.underlines); - encoder.bind( - 0, - &ShaderUnderlinesData { - globals, - b_underlines: instance_buf, - }, - ); - encoder.draw(0, 4, 0, underlines.len() as u32); - } - PrimitiveBatch::MonochromeSprites { - texture_id, - sprites, - } => { - let tex_info = self.atlas.get_texture_info(texture_id); - let instance_buf = - unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.mono_sprites); - encoder.bind( - 0, - &ShaderMonoSpritesData { - globals, - t_sprite: tex_info.raw_view, - s_sprite: self.atlas_sampler, - b_mono_sprites: instance_buf, - }, - ); - encoder.draw(0, 4, 0, sprites.len() as u32); - } - PrimitiveBatch::PolychromeSprites { - texture_id, - sprites, - } => { - let tex_info = self.atlas.get_texture_info(texture_id); - let instance_buf = - unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; - let mut encoder = pass.with(&self.pipelines.poly_sprites); - encoder.bind( - 0, - &ShaderPolySpritesData { - globals, - t_sprite: tex_info.raw_view, - s_sprite: self.atlas_sampler, - b_poly_sprites: instance_buf, - }, - ); - encoder.draw(0, 4, 0, sprites.len() as u32); - } - PrimitiveBatch::Surfaces(surfaces) => { - let mut _encoder = pass.with(&self.pipelines.surfaces); + vec![PathSprite { bounds }] + }; + let instance_buf = + unsafe { self.instance_belt.alloc_typed(&sprites, &self.gpu) }; + encoder.bind( + 0, + &ShaderPathsData { + globals, + t_sprite: self.path_intermediate_texture_view, + s_sprite: self.atlas_sampler, + b_path_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); + } + PrimitiveBatch::Underlines(underlines) => { + let instance_buf = + unsafe { self.instance_belt.alloc_typed(underlines, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.underlines); + encoder.bind( + 0, + &ShaderUnderlinesData { + globals, + b_underlines: instance_buf, + }, + ); + encoder.draw(0, 4, 0, underlines.len() as u32); + } + PrimitiveBatch::MonochromeSprites { + texture_id, + sprites, + } => { + let tex_info = self.atlas.get_texture_info(texture_id); + let instance_buf = + unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.mono_sprites); + encoder.bind( + 0, + &ShaderMonoSpritesData { + globals, + t_sprite: tex_info.raw_view, + s_sprite: self.atlas_sampler, + b_mono_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); + } + 
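                    // A minimal sketch of the sprite-batching rule described in the Paths arm
                    // above, assuming `clipped_bounds()` returns the path bounds intersected
                    // with its content mask as elsewhere in this patch: paths that share a
                    // draw order get one sprite each (their bounds are disjoint), while a
                    // mixed-order batch collapses to a single spanning sprite.
                    fn path_sprites_for_batch(paths: &[Path<ScaledPixels>]) -> Vec<PathSprite> {
                        let Some(first) = paths.first() else {
                            return Vec::new();
                        };
                        if paths.last().unwrap().order == first.order {
                            // Disjoint bounds: copy each path's clipped bounds individually.
                            paths
                                .iter()
                                .map(|path| PathSprite {
                                    bounds: path.clipped_bounds(),
                                })
                                .collect()
                        } else {
                            // Mixed draw orders: copy once, over the union of all clipped bounds.
                            let mut bounds = first.clipped_bounds();
                            for path in &paths[1..] {
                                bounds = bounds.union(&path.clipped_bounds());
                            }
                            vec![PathSprite { bounds }]
                        }
                    }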
PrimitiveBatch::PolychromeSprites { + texture_id, + sprites, + } => { + let tex_info = self.atlas.get_texture_info(texture_id); + let instance_buf = + unsafe { self.instance_belt.alloc_typed(sprites, &self.gpu) }; + let mut encoder = pass.with(&self.pipelines.poly_sprites); + encoder.bind( + 0, + &ShaderPolySpritesData { + globals, + t_sprite: tex_info.raw_view, + s_sprite: self.atlas_sampler, + b_poly_sprites: instance_buf, + }, + ); + encoder.draw(0, 4, 0, sprites.len() as u32); + } + PrimitiveBatch::Surfaces(surfaces) => { + let mut _encoder = pass.with(&self.pipelines.surfaces); - for surface in surfaces { - #[cfg(not(target_os = "macos"))] - { - let _ = surface; - continue; - }; + for surface in surfaces { + #[cfg(not(target_os = "macos"))] + { + let _ = surface; + continue; + }; - #[cfg(target_os = "macos")] - { - let (t_y, t_cb_cr) = unsafe { - use core_foundation::base::TCFType as _; - use std::ptr; + #[cfg(target_os = "macos")] + { + let (t_y, t_cb_cr) = unsafe { + use core_foundation::base::TCFType as _; + use std::ptr; - assert_eq!( + assert_eq!( surface.image_buffer.get_pixel_format(), core_video::pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange ); - let y_texture = self - .core_video_texture_cache - .create_texture_from_image( - surface.image_buffer.as_concrete_TypeRef(), - ptr::null(), - metal::MTLPixelFormat::R8Unorm, - surface.image_buffer.get_width_of_plane(0), - surface.image_buffer.get_height_of_plane(0), - 0, - ) - .unwrap(); - let cb_cr_texture = self - .core_video_texture_cache - .create_texture_from_image( - surface.image_buffer.as_concrete_TypeRef(), - ptr::null(), - metal::MTLPixelFormat::RG8Unorm, - surface.image_buffer.get_width_of_plane(1), - surface.image_buffer.get_height_of_plane(1), - 1, - ) - .unwrap(); - ( - gpu::TextureView::from_metal_texture( - &objc2::rc::Retained::retain( - foreign_types::ForeignTypeRef::as_ptr( - y_texture.as_texture_ref(), - ) - as *mut objc2::runtime::ProtocolObject< - dyn objc2_metal::MTLTexture, - >, - ) - .unwrap(), - gpu::TexelAspects::COLOR, - ), - gpu::TextureView::from_metal_texture( - &objc2::rc::Retained::retain( - foreign_types::ForeignTypeRef::as_ptr( - cb_cr_texture.as_texture_ref(), - ) - as *mut objc2::runtime::ProtocolObject< - dyn objc2_metal::MTLTexture, - >, - ) - .unwrap(), - gpu::TexelAspects::COLOR, - ), + let y_texture = self + .core_video_texture_cache + .create_texture_from_image( + surface.image_buffer.as_concrete_TypeRef(), + ptr::null(), + metal::MTLPixelFormat::R8Unorm, + surface.image_buffer.get_width_of_plane(0), + surface.image_buffer.get_height_of_plane(0), + 0, ) - }; + .unwrap(); + let cb_cr_texture = self + .core_video_texture_cache + .create_texture_from_image( + surface.image_buffer.as_concrete_TypeRef(), + ptr::null(), + metal::MTLPixelFormat::RG8Unorm, + surface.image_buffer.get_width_of_plane(1), + surface.image_buffer.get_height_of_plane(1), + 1, + ) + .unwrap(); + ( + gpu::TextureView::from_metal_texture( + &objc2::rc::Retained::retain( + foreign_types::ForeignTypeRef::as_ptr( + y_texture.as_texture_ref(), + ) + as *mut objc2::runtime::ProtocolObject< + dyn objc2_metal::MTLTexture, + >, + ) + .unwrap(), + gpu::TexelAspects::COLOR, + ), + gpu::TextureView::from_metal_texture( + &objc2::rc::Retained::retain( + foreign_types::ForeignTypeRef::as_ptr( + cb_cr_texture.as_texture_ref(), + ) + as *mut objc2::runtime::ProtocolObject< + dyn objc2_metal::MTLTexture, + >, + ) + .unwrap(), + gpu::TexelAspects::COLOR, + ), + ) + }; - _encoder.bind( - 0, - &ShaderSurfacesData { - globals, - 
surface_locals: SurfaceParams { - bounds: surface.bounds.into(), - content_mask: surface.content_mask.bounds.into(), - }, - t_y, - t_cb_cr, - s_surface: self.atlas_sampler, + _encoder.bind( + 0, + &ShaderSurfacesData { + globals, + surface_locals: SurfaceParams { + bounds: surface.bounds.into(), + content_mask: surface.content_mask.bounds.into(), }, - ); + t_y, + t_cb_cr, + s_surface: self.atlas_sampler, + }, + ); - _encoder.draw(0, 4, 0, 1); - } + _encoder.draw(0, 4, 0, 1); } } } } } + drop(pass); self.command_encoder.present(frame); let sync_point = self.gpu.submit(&mut self.command_encoder); @@ -817,9 +908,79 @@ impl BladeRenderer { profiling::scope!("finish"); self.instance_belt.flush(&sync_point); self.atlas.after_frame(&sync_point); - self.atlas.clear_textures(AtlasTextureKind::Path); self.wait_for_gpu(); self.last_sync_point = Some(sync_point); } } + +fn create_path_intermediate_texture( + gpu: &gpu::Context, + format: gpu::TextureFormat, + width: u32, + height: u32, +) -> (gpu::Texture, gpu::TextureView) { + let texture = gpu.create_texture(gpu::TextureDesc { + name: "path intermediate", + format, + size: gpu::Extent { + width, + height, + depth: 1, + }, + array_layer_count: 1, + mip_level_count: 1, + sample_count: 1, + dimension: gpu::TextureDimension::D2, + usage: gpu::TextureUsage::COPY | gpu::TextureUsage::RESOURCE | gpu::TextureUsage::TARGET, + external: None, + }); + let texture_view = gpu.create_texture_view( + texture, + gpu::TextureViewDesc { + name: "path intermediate view", + format, + dimension: gpu::ViewDimension::D2, + subresources: &Default::default(), + }, + ); + (texture, texture_view) +} + +fn create_msaa_texture_if_needed( + gpu: &gpu::Context, + format: gpu::TextureFormat, + width: u32, + height: u32, + sample_count: u32, +) -> Option<(gpu::Texture, gpu::TextureView)> { + if sample_count <= 1 { + return None; + } + let texture_msaa = gpu.create_texture(gpu::TextureDesc { + name: "path intermediate msaa", + format, + size: gpu::Extent { + width, + height, + depth: 1, + }, + array_layer_count: 1, + mip_level_count: 1, + sample_count, + dimension: gpu::TextureDimension::D2, + usage: gpu::TextureUsage::TARGET, + external: None, + }); + let texture_view_msaa = gpu.create_texture_view( + texture_msaa, + gpu::TextureViewDesc { + name: "path intermediate msaa view", + format, + dimension: gpu::ViewDimension::D2, + subresources: &Default::default(), + }, + ); + + Some((texture_msaa, texture_view_msaa)) +} diff --git a/crates/gpui/src/platform/blade/shaders.wgsl b/crates/gpui/src/platform/blade/shaders.wgsl index 0b34a0eea3..b1ffb1812e 100644 --- a/crates/gpui/src/platform/blade/shaders.wgsl +++ b/crates/gpui/src/platform/blade/shaders.wgsl @@ -924,16 +924,19 @@ fn fs_shadow(input: ShadowVarying) -> @location(0) vec4<f32> { // --- path rasterization --- // -struct PathVertex { +struct PathRasterizationVertex { xy_position: vec2<f32>, st_position: vec2<f32>, - content_mask: Bounds, + color: Background, + bounds: Bounds, } -var<storage, read> b_path_vertices: array<PathVertex>; + +var<storage, read> b_path_vertices: array<PathRasterizationVertex>; struct PathRasterizationVarying { @builtin(position) position: vec4<f32>, @location(0) st_position: vec2<f32>, + @location(1) vertex_id: u32, //TODO: use `clip_distance` once Naga supports it @location(3) clip_distances: vec4<f32>, } @@ -945,40 +948,54 @@ fn vs_path_rasterization(@builtin(vertex_index) vertex_id: u32) -> PathRasteriza var out = PathRasterizationVarying(); out.position = to_device_position_impl(v.xy_position); 
out.st_position = v.st_position; - out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.content_mask); + out.vertex_id = vertex_id; + out.clip_distances = distance_from_clip_rect_impl(v.xy_position, v.bounds); return out; } @fragment -fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) f32 { +fn fs_path_rasterization(input: PathRasterizationVarying) -> @location(0) vec4<f32> { let dx = dpdx(input.st_position); let dy = dpdy(input.st_position); if (any(input.clip_distances < vec4<f32>(0.0))) { - return 0.0; + return vec4<f32>(0.0); } - let gradient = 2.0 * input.st_position.xx * vec2<f32>(dx.x, dy.x) - vec2<f32>(dx.y, dy.y); - let f = input.st_position.x * input.st_position.x - input.st_position.y; - let distance = f / length(gradient); - return saturate(0.5 - distance); + let v = b_path_vertices[input.vertex_id]; + let background = v.color; + let bounds = v.bounds; + + var alpha: f32; + if (length(vec2<f32>(dx.x, dy.x)) < 0.001) { + // If the gradient is too small, return a solid color. + alpha = 1.0; + } else { + let gradient = 2.0 * input.st_position.xx * vec2<f32>(dx.x, dy.x) - vec2<f32>(dx.y, dy.y); + let f = input.st_position.x * input.st_position.x - input.st_position.y; + let distance = f / length(gradient); + alpha = saturate(0.5 - distance); + } + let gradient_color = prepare_gradient_color( + background.tag, + background.color_space, + background.solid, + background.colors, + ); + let color = gradient_color(background, input.position.xy, bounds, + gradient_color.solid, gradient_color.color0, gradient_color.color1); + return vec4<f32>(color.rgb * color.a * alpha, color.a * alpha); } // --- paths --- // struct PathSprite { bounds: Bounds, - color: Background, - tile: AtlasTile, } var<storage, read> b_path_sprites: array<PathSprite>; struct PathVarying { @builtin(position) position: vec4<f32>, - @location(0) tile_position: vec2<f32>, - @location(1) @interpolate(flat) instance_id: u32, - @location(2) @interpolate(flat) color_solid: vec4<f32>, - @location(3) @interpolate(flat) color0: vec4<f32>, - @location(4) @interpolate(flat) color1: vec4<f32>, + @location(0) texture_coords: vec2<f32>, } @vertex @@ -986,33 +1003,22 @@ fn vs_path(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) insta let unit_vertex = vec2<f32>(f32(vertex_id & 1u), 0.5 * f32(vertex_id & 2u)); let sprite = b_path_sprites[instance_id]; // Don't apply content mask because it was already accounted for when rasterizing the path. 
+ let device_position = to_device_position(unit_vertex, sprite.bounds); + // For screen-space intermediate texture, convert screen position to texture coordinates + let screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size; + let texture_coords = screen_position / globals.viewport_size; var out = PathVarying(); - out.position = to_device_position(unit_vertex, sprite.bounds); - out.tile_position = to_tile_position(unit_vertex, sprite.tile); - out.instance_id = instance_id; + out.position = device_position; + out.texture_coords = texture_coords; - let gradient = prepare_gradient_color( - sprite.color.tag, - sprite.color.color_space, - sprite.color.solid, - sprite.color.colors - ); - out.color_solid = gradient.solid; - out.color0 = gradient.color0; - out.color1 = gradient.color1; return out; } @fragment fn fs_path(input: PathVarying) -> @location(0) vec4<f32> { - let sample = textureSample(t_sprite, s_sprite, input.tile_position).r; - let mask = 1.0 - abs(1.0 - sample % 2.0); - let sprite = b_path_sprites[input.instance_id]; - let background = sprite.color; - let color = gradient_color(background, input.position.xy, sprite.bounds, - input.color_solid, input.color0, input.color1); - return blend_color(color, mask); + let sample = textureSample(t_sprite, s_sprite, input.texture_coords); + return sample; } // --- underlines --- // diff --git a/crates/gpui/src/platform/keystroke.rs b/crates/gpui/src/platform/keystroke.rs index 8b6e72d150..24601eefd6 100644 --- a/crates/gpui/src/platform/keystroke.rs +++ b/crates/gpui/src/platform/keystroke.rs @@ -417,17 +417,6 @@ impl Modifiers { self.control || self.alt || self.shift || self.platform || self.function } - /// Returns the XOR of two modifier sets - pub fn xor(&self, other: &Modifiers) -> Modifiers { - Modifiers { - control: self.control ^ other.control, - alt: self.alt ^ other.alt, - shift: self.shift ^ other.shift, - platform: self.platform ^ other.platform, - function: self.function ^ other.function, - } - } - /// Whether the semantically 'secondary' modifier key is pressed. /// /// On macOS, this is the command key. @@ -545,11 +534,62 @@ impl Modifiers { /// Checks if this [`Modifiers`] is a subset of another [`Modifiers`]. 
pub fn is_subset_of(&self, other: &Modifiers) -> bool { - (other.control || !self.control) - && (other.alt || !self.alt) - && (other.shift || !self.shift) - && (other.platform || !self.platform) - && (other.function || !self.function) + (*other & *self) == *self + } +} + +impl std::ops::BitOr for Modifiers { + type Output = Self; + + fn bitor(mut self, other: Self) -> Self::Output { + self |= other; + self + } +} + +impl std::ops::BitOrAssign for Modifiers { + fn bitor_assign(&mut self, other: Self) { + self.control |= other.control; + self.alt |= other.alt; + self.shift |= other.shift; + self.platform |= other.platform; + self.function |= other.function; + } +} + +impl std::ops::BitXor for Modifiers { + type Output = Self; + fn bitxor(mut self, rhs: Self) -> Self::Output { + self ^= rhs; + self + } +} + +impl std::ops::BitXorAssign for Modifiers { + fn bitxor_assign(&mut self, other: Self) { + self.control ^= other.control; + self.alt ^= other.alt; + self.shift ^= other.shift; + self.platform ^= other.platform; + self.function ^= other.function; + } +} + +impl std::ops::BitAnd for Modifiers { + type Output = Self; + fn bitand(mut self, rhs: Self) -> Self::Output { + self &= rhs; + self + } +} + +impl std::ops::BitAndAssign for Modifiers { + fn bitand_assign(&mut self, other: Self) { + self.control &= other.control; + self.alt &= other.alt; + self.shift &= other.shift; + self.platform &= other.platform; + self.function &= other.function; } } diff --git a/crates/gpui/src/platform/linux/platform.rs b/crates/gpui/src/platform/linux/platform.rs index d65118e994..fe6a36baa8 100644 --- a/crates/gpui/src/platform/linux/platform.rs +++ b/crates/gpui/src/platform/linux/platform.rs @@ -845,9 +845,15 @@ impl crate::Keystroke { { if key.is_ascii_graphic() { key_utf8.to_lowercase() - // map ctrl-a to a - } else if key_utf32 <= 0x1f { - ((key_utf32 as u8 + 0x60) as char).to_string() + // map ctrl-a to `a` + // ctrl-0..9 may emit control codes like ctrl-[, but + // we don't want to map them to `[` + } else if key_utf32 <= 0x1f + && !name.chars().next().is_some_and(|c| c.is_ascii_digit()) + { + ((key_utf32 as u8 + 0x40) as char) + .to_ascii_lowercase() + .to_string() } else { name } diff --git a/crates/gpui/src/platform/linux/wayland/window.rs b/crates/gpui/src/platform/linux/wayland/window.rs index 255ae9c372..2b2207e22c 100644 --- a/crates/gpui/src/platform/linux/wayland/window.rs +++ b/crates/gpui/src/platform/linux/wayland/window.rs @@ -111,7 +111,7 @@ pub struct WaylandWindowState { resize_throttle: bool, in_progress_window_controls: Option<WindowControls>, window_controls: WindowControls, - inset: Option<Pixels>, + client_inset: Option<Pixels>, } #[derive(Clone)] @@ -186,7 +186,7 @@ impl WaylandWindowState { hovered: false, in_progress_window_controls: None, window_controls: WindowControls::default(), - inset: None, + client_inset: None, }) } @@ -211,6 +211,13 @@ impl WaylandWindowState { self.display = current_output; scale } + + pub fn inset(&self) -> Pixels { + match self.decorations { + WindowDecorations::Server => px(0.0), + WindowDecorations::Client => self.client_inset.unwrap_or(px(0.0)), + } + } } pub(crate) struct WaylandWindow(pub WaylandWindowStatePtr); @@ -380,7 +387,7 @@ impl WaylandWindowStatePtr { configure.size = if got_unmaximized { Some(state.window_bounds.size) } else { - compute_outer_size(state.inset, configure.size, state.tiling) + compute_outer_size(state.inset(), configure.size, state.tiling) }; if let Some(size) = configure.size { state.window_bounds = Bounds { @@ -400,7 
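// A quick sketch of how the new `Modifiers` bitwise operators above make the subset
// check read: one modifier set is a subset of another exactly when ANDing the two
// changes nothing. Assumes `Modifiers` derives `Default`; illustrative only.
let ctrl = Modifiers {
    control: true,
    ..Default::default()
};
let ctrl_shift = Modifiers {
    control: true,
    shift: true,
    ..Default::default()
};
assert!(ctrl.is_subset_of(&ctrl_shift)); // ctrl & ctrl_shift == ctrl
assert!(!ctrl_shift.is_subset_of(&ctrl)); // ctrl_shift & ctrl == ctrl, not ctrl_shift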
+407,7 @@ impl WaylandWindowStatePtr { let window_geometry = inset_by_tiling( state.bounds.map_origin(|_| px(0.0)), - state.inset.unwrap_or(px(0.0)), + state.inset(), state.tiling, ) .map(|v| v.0 as i32) @@ -818,7 +825,7 @@ impl PlatformWindow for WaylandWindow { } else if state.maximized { WindowBounds::Maximized(state.window_bounds) } else { - let inset = state.inset.unwrap_or(px(0.)); + let inset = state.inset(); drop(state); WindowBounds::Windowed(self.bounds().inset(inset)) } @@ -1073,8 +1080,8 @@ impl PlatformWindow for WaylandWindow { fn set_client_inset(&self, inset: Pixels) { let mut state = self.borrow_mut(); - if Some(inset) != state.inset { - state.inset = Some(inset); + if Some(inset) != state.client_inset { + state.client_inset = Some(inset); update_window(state); } } @@ -1094,9 +1101,7 @@ fn update_window(mut state: RefMut<WaylandWindowState>) { state.renderer.update_transparency(!opaque); let mut opaque_area = state.window_bounds.map(|v| v.0 as i32); - if let Some(inset) = state.inset { - opaque_area.inset(inset.0 as i32); - } + opaque_area.inset(state.inset().0 as i32); let region = state .globals @@ -1169,12 +1174,10 @@ impl ResizeEdge { /// updating to account for the client decorations. But that's not the area we want to render /// to, due to our intrusize CSD. So, here we calculate the 'actual' size, by adding back in the insets fn compute_outer_size( - inset: Option<Pixels>, + inset: Pixels, new_size: Option<Size<Pixels>>, tiling: Tiling, ) -> Option<Size<Pixels>> { - let Some(inset) = inset else { return new_size }; - new_size.map(|mut new_size| { if !tiling.top { new_size.height += inset; diff --git a/crates/gpui/src/platform/linux/x11/client.rs b/crates/gpui/src/platform/linux/x11/client.rs index d1cb7d00cc..573e4addf7 100644 --- a/crates/gpui/src/platform/linux/x11/client.rs +++ b/crates/gpui/src/platform/linux/x11/client.rs @@ -1004,12 +1004,13 @@ impl X11Client { let mut keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code); let keysym = state.xkb.key_get_one_sym(code); - // should be called after key_get_one_sym - state.xkb.update_key(code, xkbc::KeyDirection::Down); - if keysym.is_modifier_key() { return Some(()); } + + // should be called after key_get_one_sym + state.xkb.update_key(code, xkbc::KeyDirection::Down); + if let Some(mut compose_state) = state.compose_state.take() { compose_state.feed(keysym); match compose_state.status() { @@ -1067,12 +1068,13 @@ impl X11Client { let keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code); let keysym = state.xkb.key_get_one_sym(code); - // should be called after key_get_one_sym - state.xkb.update_key(code, xkbc::KeyDirection::Up); - if keysym.is_modifier_key() { return Some(()); } + + // should be called after key_get_one_sym + state.xkb.update_key(code, xkbc::KeyDirection::Up); + keystroke }; drop(state); @@ -1793,6 +1795,7 @@ impl X11ClientState { drop(state); window.refresh(RequestFrameOptions { require_presentation: expose_event_received, + force_render: false, }); } xcb_connection diff --git a/crates/gpui/src/platform/mac/metal_atlas.rs b/crates/gpui/src/platform/mac/metal_atlas.rs index 366f2dcc3c..5d2d8e63e0 100644 --- a/crates/gpui/src/platform/mac/metal_atlas.rs +++ b/crates/gpui/src/platform/mac/metal_atlas.rs @@ -13,53 +13,25 @@ use std::borrow::Cow; pub(crate) struct MetalAtlas(Mutex<MetalAtlasState>); impl MetalAtlas { - pub(crate) fn new(device: Device, path_sample_count: u32) -> Self { + pub(crate) fn new(device: Device) -> Self { MetalAtlas(Mutex::new(MetalAtlasState { 
device: AssertSend(device), monochrome_textures: Default::default(), polychrome_textures: Default::default(), - path_textures: Default::default(), tiles_by_key: Default::default(), - path_sample_count, })) } pub(crate) fn metal_texture(&self, id: AtlasTextureId) -> metal::Texture { self.0.lock().texture(id).metal_texture.clone() } - - pub(crate) fn msaa_texture(&self, id: AtlasTextureId) -> Option<metal::Texture> { - self.0.lock().texture(id).msaa_texture.clone() - } - - pub(crate) fn allocate( - &self, - size: Size<DevicePixels>, - texture_kind: AtlasTextureKind, - ) -> Option<AtlasTile> { - self.0.lock().allocate(size, texture_kind) - } - - pub(crate) fn clear_textures(&self, texture_kind: AtlasTextureKind) { - let mut lock = self.0.lock(); - let textures = match texture_kind { - AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, - AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, - AtlasTextureKind::Path => &mut lock.path_textures, - }; - for texture in textures.iter_mut() { - texture.clear(); - } - } } struct MetalAtlasState { device: AssertSend<Device>, monochrome_textures: AtlasTextureList<MetalAtlasTexture>, polychrome_textures: AtlasTextureList<MetalAtlasTexture>, - path_textures: AtlasTextureList<MetalAtlasTexture>, tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, - path_sample_count: u32, } impl PlatformAtlas for MetalAtlas { @@ -94,7 +66,6 @@ impl PlatformAtlas for MetalAtlas { let textures = match id.kind { AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, - AtlasTextureKind::Path => &mut lock.polychrome_textures, }; let Some(texture_slot) = textures @@ -128,7 +99,6 @@ impl MetalAtlasState { let textures = match texture_kind { AtlasTextureKind::Monochrome => &mut self.monochrome_textures, AtlasTextureKind::Polychrome => &mut self.polychrome_textures, - AtlasTextureKind::Path => &mut self.path_textures, }; if let Some(tile) = textures @@ -173,31 +143,14 @@ impl MetalAtlasState { pixel_format = metal::MTLPixelFormat::BGRA8Unorm; usage = metal::MTLTextureUsage::ShaderRead; } - AtlasTextureKind::Path => { - pixel_format = metal::MTLPixelFormat::R16Float; - usage = metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead; - } } texture_descriptor.set_pixel_format(pixel_format); texture_descriptor.set_usage(usage); let metal_texture = self.device.new_texture(&texture_descriptor); - // We currently only enable MSAA for path textures. 
- let msaa_texture = if self.path_sample_count > 1 && kind == AtlasTextureKind::Path { - let mut descriptor = texture_descriptor.clone(); - descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); - descriptor.set_storage_mode(metal::MTLStorageMode::Private); - descriptor.set_sample_count(self.path_sample_count as _); - let msaa_texture = self.device.new_texture(&descriptor); - Some(msaa_texture) - } else { - None - }; - let texture_list = match kind { AtlasTextureKind::Monochrome => &mut self.monochrome_textures, AtlasTextureKind::Polychrome => &mut self.polychrome_textures, - AtlasTextureKind::Path => &mut self.path_textures, }; let index = texture_list.free_list.pop(); @@ -209,7 +162,6 @@ impl MetalAtlasState { }, allocator: etagere::BucketedAtlasAllocator::new(size.into()), metal_texture: AssertSend(metal_texture), - msaa_texture: AssertSend(msaa_texture), live_atlas_keys: 0, }; @@ -226,7 +178,6 @@ impl MetalAtlasState { let textures = match id.kind { crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, - crate::AtlasTextureKind::Path => &self.path_textures, }; textures[id.index as usize].as_ref().unwrap() } @@ -236,15 +187,10 @@ struct MetalAtlasTexture { id: AtlasTextureId, allocator: BucketedAtlasAllocator, metal_texture: AssertSend<metal::Texture>, - msaa_texture: AssertSend<Option<metal::Texture>>, live_atlas_keys: u32, } impl MetalAtlasTexture { - fn clear(&mut self) { - self.allocator.clear(); - } - fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { let allocation = self.allocator.allocate(size.into())?; let tile = AtlasTile { diff --git a/crates/gpui/src/platform/mac/metal_renderer.rs b/crates/gpui/src/platform/mac/metal_renderer.rs index 3cdc2dd2cf..629654014d 100644 --- a/crates/gpui/src/platform/mac/metal_renderer.rs +++ b/crates/gpui/src/platform/mac/metal_renderer.rs @@ -1,27 +1,30 @@ use super::metal_atlas::MetalAtlas; use crate::{ - AtlasTextureId, AtlasTextureKind, AtlasTile, Background, Bounds, ContentMask, DevicePixels, - MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite, PrimitiveBatch, - Quad, ScaledPixels, Scene, Shadow, Size, Surface, Underline, point, size, + AtlasTextureId, Background, Bounds, ContentMask, DevicePixels, MonochromeSprite, PaintSurface, + Path, Point, PolychromeSprite, PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, + Surface, Underline, point, size, }; -use anyhow::{Context as _, Result}; +use anyhow::Result; use block::ConcreteBlock; use cocoa::{ base::{NO, YES}, foundation::{NSSize, NSUInteger}, quartzcore::AutoresizingMask, }; -use collections::HashMap; + use core_foundation::base::TCFType; use core_video::{ metal_texture::CVMetalTextureGetTexture, metal_texture_cache::CVMetalTextureCache, pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange, }; use foreign_types::{ForeignType, ForeignTypeRef}; -use metal::{CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange}; +use metal::{ + CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange, + RenderPassColorAttachmentDescriptorRef, +}; use objc::{self, msg_send, sel, sel_impl}; use parking_lot::Mutex; -use smallvec::SmallVec; + use std::{cell::Cell, ffi::c_void, mem, ptr, sync::Arc}; // Exported to metal @@ -111,6 +114,17 @@ pub(crate) struct MetalRenderer { instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>, sprite_atlas: Arc<MetalAtlas>, core_video_texture_cache: 
core_video::metal_texture_cache::CVMetalTextureCache, + path_intermediate_texture: Option<metal::Texture>, + path_intermediate_msaa_texture: Option<metal::Texture>, + path_sample_count: u32, +} + +#[repr(C)] +pub struct PathRasterizationVertex { + pub xy_position: Point<ScaledPixels>, + pub st_position: Point<f32>, + pub color: Background, + pub bounds: Bounds<ScaledPixels>, } impl MetalRenderer { @@ -175,10 +189,10 @@ impl MetalRenderer { "paths_rasterization", "path_rasterization_vertex", "path_rasterization_fragment", - MTLPixelFormat::R16Float, + MTLPixelFormat::BGRA8Unorm, PATH_SAMPLE_COUNT, ); - let path_sprites_pipeline_state = build_pipeline_state( + let path_sprites_pipeline_state = build_path_sprite_pipeline_state( &device, &library, "path_sprites", @@ -236,7 +250,7 @@ impl MetalRenderer { ); let command_queue = device.new_command_queue(); - let sprite_atlas = Arc::new(MetalAtlas::new(device.clone(), PATH_SAMPLE_COUNT)); + let sprite_atlas = Arc::new(MetalAtlas::new(device.clone())); let core_video_texture_cache = CVMetalTextureCache::new(None, device.clone(), None).unwrap(); @@ -257,6 +271,9 @@ impl MetalRenderer { instance_buffer_pool, sprite_atlas, core_video_texture_cache, + path_intermediate_texture: None, + path_intermediate_msaa_texture: None, + path_sample_count: PATH_SAMPLE_COUNT, } } @@ -289,6 +306,31 @@ impl MetalRenderer { setDrawableSize: size ]; } + let device_pixels_size = Size { + width: DevicePixels(size.width as i32), + height: DevicePixels(size.height as i32), + }; + self.update_path_intermediate_textures(device_pixels_size); + } + + fn update_path_intermediate_textures(&mut self, size: Size<DevicePixels>) { + let texture_descriptor = metal::TextureDescriptor::new(); + texture_descriptor.set_width(size.width.0 as u64); + texture_descriptor.set_height(size.height.0 as u64); + texture_descriptor.set_pixel_format(metal::MTLPixelFormat::BGRA8Unorm); + texture_descriptor + .set_usage(metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead); + self.path_intermediate_texture = Some(self.device.new_texture(&texture_descriptor)); + + if self.path_sample_count > 1 { + let mut msaa_descriptor = texture_descriptor.clone(); + msaa_descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); + msaa_descriptor.set_storage_mode(metal::MTLStorageMode::Private); + msaa_descriptor.set_sample_count(self.path_sample_count as _); + self.path_intermediate_msaa_texture = Some(self.device.new_texture(&msaa_descriptor)); + } else { + self.path_intermediate_msaa_texture = None; + } } pub fn update_transparency(&self, _transparent: bool) { @@ -374,38 +416,18 @@ impl MetalRenderer { ) -> Result<metal::CommandBuffer> { let command_queue = self.command_queue.clone(); let command_buffer = command_queue.new_command_buffer(); + let alpha = if self.layer.is_opaque() { 1. } else { 0. }; let mut instance_offset = 0; - let path_tiles = self - .rasterize_paths( - scene.paths(), - instance_buffer, - &mut instance_offset, - command_buffer, - ) - .with_context(|| format!("rasterizing {} paths", scene.paths().len()))?; - - let render_pass_descriptor = metal::RenderPassDescriptor::new(); - let color_attachment = render_pass_descriptor - .color_attachments() - .object_at(0) - .unwrap(); - - color_attachment.set_texture(Some(drawable.texture())); - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_store_action(metal::MTLStoreAction::Store); - let alpha = if self.layer.is_opaque() { 1. } else { 0. 
}; - color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha)); - let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); - - command_encoder.set_viewport(metal::MTLViewport { - originX: 0.0, - originY: 0.0, - width: i32::from(viewport_size.width) as f64, - height: i32::from(viewport_size.height) as f64, - znear: 0.0, - zfar: 1.0, - }); + let mut command_encoder = new_command_encoder( + command_buffer, + drawable, + viewport_size, + |color_attachment| { + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., alpha)); + }, + ); for batch in scene.batches() { let ok = match batch { @@ -414,29 +436,53 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), PrimitiveBatch::Quads(quads) => self.draw_quads( quads, instance_buffer, &mut instance_offset, viewport_size, - command_encoder, - ), - PrimitiveBatch::Paths(paths) => self.draw_paths( - paths, - &path_tiles, - instance_buffer, - &mut instance_offset, - viewport_size, - command_encoder, + &command_encoder, ), + PrimitiveBatch::Paths(paths) => { + command_encoder.end_encoding(); + + let did_draw = self.draw_paths_to_intermediate( + paths, + instance_buffer, + &mut instance_offset, + viewport_size, + command_buffer, + ); + + command_encoder = new_command_encoder( + command_buffer, + drawable, + viewport_size, + |color_attachment| { + color_attachment.set_load_action(metal::MTLLoadAction::Load); + }, + ); + + if did_draw { + self.draw_paths_from_intermediate( + paths, + instance_buffer, + &mut instance_offset, + viewport_size, + &command_encoder, + ) + } else { + false + } + } PrimitiveBatch::Underlines(underlines) => self.draw_underlines( underlines, instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), PrimitiveBatch::MonochromeSprites { texture_id, @@ -447,7 +493,7 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), PrimitiveBatch::PolychromeSprites { texture_id, @@ -458,17 +504,16 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces( surfaces, instance_buffer, &mut instance_offset, viewport_size, - command_encoder, + &command_encoder, ), }; - if !ok { command_encoder.end_encoding(); anyhow::bail!( @@ -493,104 +538,90 @@ impl MetalRenderer { Ok(command_buffer.to_owned()) } - fn rasterize_paths( + fn draw_paths_to_intermediate( &self, paths: &[Path<ScaledPixels>], instance_buffer: &mut InstanceBuffer, instance_offset: &mut usize, + viewport_size: Size<DevicePixels>, command_buffer: &metal::CommandBufferRef, - ) -> Option<HashMap<PathId, AtlasTile>> { - self.sprite_atlas.clear_textures(AtlasTextureKind::Path); + ) -> bool { + if paths.is_empty() { + return true; + } + let Some(intermediate_texture) = &self.path_intermediate_texture else { + return false; + }; - let mut tiles = HashMap::default(); - let mut vertices_by_texture_id = HashMap::default(); + let render_pass_descriptor = metal::RenderPassDescriptor::new(); + let color_attachment = render_pass_descriptor + .color_attachments() + .object_at(0) + .unwrap(); + color_attachment.set_load_action(metal::MTLLoadAction::Clear); + color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 0.)); + + if let Some(msaa_texture) = 
&self.path_intermediate_msaa_texture { + color_attachment.set_texture(Some(msaa_texture)); + color_attachment.set_resolve_texture(Some(intermediate_texture)); + color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve); + } else { + color_attachment.set_texture(Some(intermediate_texture)); + color_attachment.set_store_action(metal::MTLStoreAction::Store); + } + + let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); + command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state); + + align_offset(instance_offset); + let mut vertices = Vec::new(); for path in paths { - let clipped_bounds = path.bounds.intersect(&path.content_mask.bounds); - - let tile = self - .sprite_atlas - .allocate(clipped_bounds.size.map(Into::into), AtlasTextureKind::Path)?; - vertices_by_texture_id - .entry(tile.texture_id) - .or_insert(Vec::new()) - .extend(path.vertices.iter().map(|vertex| PathVertex { - xy_position: vertex.xy_position - clipped_bounds.origin - + tile.bounds.origin.map(Into::into), - st_position: vertex.st_position, - content_mask: ContentMask { - bounds: tile.bounds.map(Into::into), - }, - })); - tiles.insert(path.id, tile); + vertices.extend(path.vertices.iter().map(|v| PathRasterizationVertex { + xy_position: v.xy_position, + st_position: v.st_position, + color: path.color, + bounds: path.bounds.intersect(&path.content_mask.bounds), + })); } - - for (texture_id, vertices) in vertices_by_texture_id { - align_offset(instance_offset); - let vertices_bytes_len = mem::size_of_val(vertices.as_slice()); - let next_offset = *instance_offset + vertices_bytes_len; - if next_offset > instance_buffer.size { - return None; - } - - let render_pass_descriptor = metal::RenderPassDescriptor::new(); - let color_attachment = render_pass_descriptor - .color_attachments() - .object_at(0) - .unwrap(); - - let texture = self.sprite_atlas.metal_texture(texture_id); - let msaa_texture = self.sprite_atlas.msaa_texture(texture_id); - - if let Some(msaa_texture) = msaa_texture { - color_attachment.set_texture(Some(&msaa_texture)); - color_attachment.set_resolve_texture(Some(&texture)); - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_store_action(metal::MTLStoreAction::MultisampleResolve); - } else { - color_attachment.set_texture(Some(&texture)); - color_attachment.set_load_action(metal::MTLLoadAction::Clear); - color_attachment.set_store_action(metal::MTLStoreAction::Store); - } - color_attachment.set_clear_color(metal::MTLClearColor::new(0., 0., 0., 1.)); - - let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); - command_encoder.set_render_pipeline_state(&self.paths_rasterization_pipeline_state); - command_encoder.set_vertex_buffer( - PathRasterizationInputIndex::Vertices as u64, - Some(&instance_buffer.metal_buffer), - *instance_offset as u64, - ); - let texture_size = Size { - width: DevicePixels::from(texture.width()), - height: DevicePixels::from(texture.height()), - }; - command_encoder.set_vertex_bytes( - PathRasterizationInputIndex::AtlasTextureSize as u64, - mem::size_of_val(&texture_size) as u64, - &texture_size as *const Size<DevicePixels> as *const _, - ); - - let buffer_contents = unsafe { - (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) - }; - unsafe { - ptr::copy_nonoverlapping( - vertices.as_ptr() as *const u8, - buffer_contents, - vertices_bytes_len, - ); - } - - command_encoder.draw_primitives( - 
metal::MTLPrimitiveType::Triangle, - 0, - vertices.len() as u64, - ); + let vertices_bytes_len = mem::size_of_val(vertices.as_slice()); + let next_offset = *instance_offset + vertices_bytes_len; + if next_offset > instance_buffer.size { command_encoder.end_encoding(); - *instance_offset = next_offset; + return false; } + command_encoder.set_vertex_buffer( + PathRasterizationInputIndex::Vertices as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + command_encoder.set_vertex_bytes( + PathRasterizationInputIndex::ViewportSize as u64, + mem::size_of_val(&viewport_size) as u64, + &viewport_size as *const Size<DevicePixels> as *const _, + ); + command_encoder.set_fragment_buffer( + PathRasterizationInputIndex::Vertices as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + let buffer_contents = + unsafe { (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) }; + unsafe { + ptr::copy_nonoverlapping( + vertices.as_ptr() as *const u8, + buffer_contents, + vertices_bytes_len, + ); + } + command_encoder.draw_primitives( + metal::MTLPrimitiveType::Triangle, + 0, + vertices.len() as u64, + ); + *instance_offset = next_offset; - Some(tiles) + command_encoder.end_encoding(); + true } fn draw_shadows( @@ -715,18 +746,21 @@ impl MetalRenderer { true } - fn draw_paths( + fn draw_paths_from_intermediate( &self, paths: &[Path<ScaledPixels>], - tiles_by_path_id: &HashMap<PathId, AtlasTile>, instance_buffer: &mut InstanceBuffer, instance_offset: &mut usize, viewport_size: Size<DevicePixels>, command_encoder: &metal::RenderCommandEncoderRef, ) -> bool { - if paths.is_empty() { + let Some(ref first_path) = paths.first() else { return true; - } + }; + + let Some(ref intermediate_texture) = self.path_intermediate_texture else { + return false; + }; command_encoder.set_render_pipeline_state(&self.path_sprites_pipeline_state); command_encoder.set_vertex_buffer( @@ -740,88 +774,65 @@ impl MetalRenderer { &viewport_size as *const Size<DevicePixels> as *const _, ); - let mut prev_texture_id = None; - let mut sprites = SmallVec::<[_; 1]>::new(); - let mut paths_and_tiles = paths - .iter() - .map(|path| (path, tiles_by_path_id.get(&path.id).unwrap())) - .peekable(); + command_encoder.set_fragment_texture( + SpriteInputIndex::AtlasTexture as u64, + Some(intermediate_texture), + ); - loop { - if let Some((path, tile)) = paths_and_tiles.peek() { - if prev_texture_id.map_or(true, |texture_id| texture_id == tile.texture_id) { - prev_texture_id = Some(tile.texture_id); - let origin = path.bounds.intersect(&path.content_mask.bounds).origin; - sprites.push(PathSprite { - bounds: Bounds { - origin: origin.map(|p| p.floor()), - size: tile.bounds.size.map(Into::into), - }, - color: path.color, - tile: (*tile).clone(), - }); - paths_and_tiles.next(); - continue; - } - } - - if sprites.is_empty() { - break; - } else { - align_offset(instance_offset); - let texture_id = prev_texture_id.take().unwrap(); - let texture: metal::Texture = self.sprite_atlas.metal_texture(texture_id); - let texture_size = size( - DevicePixels(texture.width() as i32), - DevicePixels(texture.height() as i32), - ); - - command_encoder.set_vertex_buffer( - SpriteInputIndex::Sprites as u64, - Some(&instance_buffer.metal_buffer), - *instance_offset as u64, - ); - command_encoder.set_vertex_bytes( - SpriteInputIndex::AtlasTextureSize as u64, - mem::size_of_val(&texture_size) as u64, - &texture_size as *const Size<DevicePixels> as *const _, - ); - command_encoder.set_fragment_buffer( - 
SpriteInputIndex::Sprites as u64, - Some(&instance_buffer.metal_buffer), - *instance_offset as u64, - ); - command_encoder - .set_fragment_texture(SpriteInputIndex::AtlasTexture as u64, Some(&texture)); - - let sprite_bytes_len = mem::size_of_val(sprites.as_slice()); - let next_offset = *instance_offset + sprite_bytes_len; - if next_offset > instance_buffer.size { - return false; - } - - let buffer_contents = unsafe { - (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) - }; - - unsafe { - ptr::copy_nonoverlapping( - sprites.as_ptr() as *const u8, - buffer_contents, - sprite_bytes_len, - ); - } - - command_encoder.draw_primitives_instanced( - metal::MTLPrimitiveType::Triangle, - 0, - 6, - sprites.len() as u64, - ); - *instance_offset = next_offset; - sprites.clear(); + // When copying paths from the intermediate texture to the drawable, + // each pixel must only be copied once, in case of transparent paths. + // + // If all paths have the same draw order, then their bounds are all + // disjoint, so we can copy each path's bounds individually. If this + // batch combines different draw orders, we perform a single copy + // for a minimal spanning rect. + let sprites; + if paths.last().unwrap().order == first_path.order { + sprites = paths + .iter() + .map(|path| PathSprite { + bounds: path.clipped_bounds(), + }) + .collect(); + } else { + let mut bounds = first_path.clipped_bounds(); + for path in paths.iter().skip(1) { + bounds = bounds.union(&path.clipped_bounds()); } + sprites = vec![PathSprite { bounds }]; } + + align_offset(instance_offset); + let sprite_bytes_len = mem::size_of_val(sprites.as_slice()); + let next_offset = *instance_offset + sprite_bytes_len; + if next_offset > instance_buffer.size { + return false; + } + + command_encoder.set_vertex_buffer( + SpriteInputIndex::Sprites as u64, + Some(&instance_buffer.metal_buffer), + *instance_offset as u64, + ); + + let buffer_contents = + unsafe { (instance_buffer.metal_buffer.contents() as *mut u8).add(*instance_offset) }; + unsafe { + ptr::copy_nonoverlapping( + sprites.as_ptr() as *const u8, + buffer_contents, + sprite_bytes_len, + ); + } + + command_encoder.draw_primitives_instanced( + metal::MTLPrimitiveType::Triangle, + 0, + 6, + sprites.len() as u64, + ); + *instance_offset = next_offset; + true } @@ -1136,6 +1147,33 @@ impl MetalRenderer { } } +fn new_command_encoder<'a>( + command_buffer: &'a metal::CommandBufferRef, + drawable: &'a metal::MetalDrawableRef, + viewport_size: Size<DevicePixels>, + configure_color_attachment: impl Fn(&RenderPassColorAttachmentDescriptorRef), +) -> &'a metal::RenderCommandEncoderRef { + let render_pass_descriptor = metal::RenderPassDescriptor::new(); + let color_attachment = render_pass_descriptor + .color_attachments() + .object_at(0) + .unwrap(); + color_attachment.set_texture(Some(drawable.texture())); + color_attachment.set_store_action(metal::MTLStoreAction::Store); + configure_color_attachment(color_attachment); + + let command_encoder = command_buffer.new_render_command_encoder(render_pass_descriptor); + command_encoder.set_viewport(metal::MTLViewport { + originX: 0.0, + originY: 0.0, + width: i32::from(viewport_size.width) as f64, + height: i32::from(viewport_size.height) as f64, + znear: 0.0, + zfar: 1.0, + }); + command_encoder +} + fn build_pipeline_state( device: &metal::DeviceRef, library: &metal::LibraryRef, @@ -1170,6 +1208,40 @@ fn build_pipeline_state( .expect("could not create render pipeline state") } +fn build_path_sprite_pipeline_state( + device: 
&metal::DeviceRef, + library: &metal::LibraryRef, + label: &str, + vertex_fn_name: &str, + fragment_fn_name: &str, + pixel_format: metal::MTLPixelFormat, +) -> metal::RenderPipelineState { + let vertex_fn = library + .get_function(vertex_fn_name, None) + .expect("error locating vertex function"); + let fragment_fn = library + .get_function(fragment_fn_name, None) + .expect("error locating fragment function"); + + let descriptor = metal::RenderPipelineDescriptor::new(); + descriptor.set_label(label); + descriptor.set_vertex_function(Some(vertex_fn.as_ref())); + descriptor.set_fragment_function(Some(fragment_fn.as_ref())); + let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); + color_attachment.set_pixel_format(pixel_format); + color_attachment.set_blending_enabled(true); + color_attachment.set_rgb_blend_operation(metal::MTLBlendOperation::Add); + color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add); + color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); + color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One); + + device + .new_render_pipeline_state(&descriptor) + .expect("could not create render pipeline state") +} + fn build_path_rasterization_pipeline_state( device: &metal::DeviceRef, library: &metal::LibraryRef, @@ -1192,7 +1264,7 @@ fn build_path_rasterization_pipeline_state( descriptor.set_fragment_function(Some(fragment_fn.as_ref())); if path_sample_count > 1 { descriptor.set_raster_sample_count(path_sample_count as _); - descriptor.set_alpha_to_coverage_enabled(true); + descriptor.set_alpha_to_coverage_enabled(false); } let color_attachment = descriptor.color_attachments().object_at(0).unwrap(); color_attachment.set_pixel_format(pixel_format); @@ -1201,8 +1273,8 @@ fn build_path_rasterization_pipeline_state( color_attachment.set_alpha_blend_operation(metal::MTLBlendOperation::Add); color_attachment.set_source_rgb_blend_factor(metal::MTLBlendFactor::One); color_attachment.set_source_alpha_blend_factor(metal::MTLBlendFactor::One); - color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::One); - color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::One); + color_attachment.set_destination_rgb_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); + color_attachment.set_destination_alpha_blend_factor(metal::MTLBlendFactor::OneMinusSourceAlpha); device .new_render_pipeline_state(&descriptor) @@ -1257,15 +1329,13 @@ enum SurfaceInputIndex { #[repr(C)] enum PathRasterizationInputIndex { Vertices = 0, - AtlasTextureSize = 1, + ViewportSize = 1, } #[derive(Clone, Debug, Eq, PartialEq)] #[repr(C)] pub struct PathSprite { pub bounds: Bounds<ScaledPixels>, - pub color: Background, - pub tile: AtlasTile, } #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/crates/gpui/src/platform/mac/shaders.metal b/crates/gpui/src/platform/mac/shaders.metal index 64ebb1e22b..f9d5bdbf4c 100644 --- a/crates/gpui/src/platform/mac/shaders.metal +++ b/crates/gpui/src/platform/mac/shaders.metal @@ -701,107 +701,117 @@ fragment float4 polychrome_sprite_fragment( struct PathRasterizationVertexOutput { float4 position [[position]]; float2 st_position; + uint vertex_id [[flat]]; float clip_rect_distance [[clip_distance]][4]; }; struct PathRasterizationFragmentInput { float4 position [[position]]; float2 st_position; 
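  // Flat-interpolated vertex index, added so the fragment stage can fetch this
  // path's color and bounds back out of the vertex buffer
  // (see path_rasterization_fragment below).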
+ uint vertex_id [[flat]]; }; vertex PathRasterizationVertexOutput path_rasterization_vertex( - uint vertex_id [[vertex_id]], - constant PathVertex_ScaledPixels *vertices - [[buffer(PathRasterizationInputIndex_Vertices)]], - constant Size_DevicePixels *atlas_size - [[buffer(PathRasterizationInputIndex_AtlasTextureSize)]]) { - PathVertex_ScaledPixels v = vertices[vertex_id]; + uint vertex_id [[vertex_id]], + constant PathRasterizationVertex *vertices [[buffer(PathRasterizationInputIndex_Vertices)]], + constant Size_DevicePixels *atlas_size [[buffer(PathRasterizationInputIndex_ViewportSize)]] +) { + PathRasterizationVertex v = vertices[vertex_id]; float2 vertex_position = float2(v.xy_position.x, v.xy_position.y); - float2 viewport_size = float2(atlas_size->width, atlas_size->height); + float4 position = float4( + vertex_position * float2(2. / atlas_size->width, -2. / atlas_size->height) + float2(-1., 1.), + 0., + 1. + ); return PathRasterizationVertexOutput{ - float4(vertex_position / viewport_size * float2(2., -2.) + - float2(-1., 1.), - 0., 1.), + position, float2(v.st_position.x, v.st_position.y), - {v.xy_position.x - v.content_mask.bounds.origin.x, - v.content_mask.bounds.origin.x + v.content_mask.bounds.size.width - - v.xy_position.x, - v.xy_position.y - v.content_mask.bounds.origin.y, - v.content_mask.bounds.origin.y + v.content_mask.bounds.size.height - - v.xy_position.y}}; + vertex_id, + { + v.xy_position.x - v.bounds.origin.x, + v.bounds.origin.x + v.bounds.size.width - v.xy_position.x, + v.xy_position.y - v.bounds.origin.y, + v.bounds.origin.y + v.bounds.size.height - v.xy_position.y + } + }; } -fragment float4 path_rasterization_fragment(PathRasterizationFragmentInput input - [[stage_in]]) { +fragment float4 path_rasterization_fragment( + PathRasterizationFragmentInput input [[stage_in]], + constant PathRasterizationVertex *vertices [[buffer(PathRasterizationInputIndex_Vertices)]] +) { float2 dx = dfdx(input.st_position); float2 dy = dfdy(input.st_position); - float2 gradient = float2((2. * input.st_position.x) * dx.x - dx.y, - (2. * input.st_position.x) * dy.x - dy.y); - float f = (input.st_position.x * input.st_position.x) - input.st_position.y; - float distance = f / length(gradient); - float alpha = saturate(0.5 - distance); - return float4(alpha, 0., 0., 1.); + + PathRasterizationVertex v = vertices[input.vertex_id]; + Background background = v.color; + Bounds_ScaledPixels path_bounds = v.bounds; + float alpha; + if (length(float2(dx.x, dy.x)) < 0.001) { + alpha = 1.0; + } else { + float2 gradient = float2( + (2. * input.st_position.x) * dx.x - dx.y, + (2. 
* input.st_position.x) * dy.x - dy.y + ); + float f = (input.st_position.x * input.st_position.x) - input.st_position.y; + float distance = f / length(gradient); + alpha = saturate(0.5 - distance); + } + + GradientColor gradient_color = prepare_fill_color( + background.tag, + background.color_space, + background.solid, + background.colors[0].color, + background.colors[1].color + ); + + float4 color = fill_color( + background, + input.position.xy, + path_bounds, + gradient_color.solid, + gradient_color.color0, + gradient_color.color1 + ); + return float4(color.rgb * color.a * alpha, alpha * color.a); } struct PathSpriteVertexOutput { float4 position [[position]]; - float2 tile_position; - uint sprite_id [[flat]]; - float4 solid_color [[flat]]; - float4 color0 [[flat]]; - float4 color1 [[flat]]; + float2 texture_coords; }; vertex PathSpriteVertexOutput path_sprite_vertex( - uint unit_vertex_id [[vertex_id]], uint sprite_id [[instance_id]], - constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]], - constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], - constant Size_DevicePixels *viewport_size - [[buffer(SpriteInputIndex_ViewportSize)]], - constant Size_DevicePixels *atlas_size - [[buffer(SpriteInputIndex_AtlasTextureSize)]]) { - + uint unit_vertex_id [[vertex_id]], + uint sprite_id [[instance_id]], + constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]], + constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], + constant Size_DevicePixels *viewport_size [[buffer(SpriteInputIndex_ViewportSize)]] +) { float2 unit_vertex = unit_vertices[unit_vertex_id]; PathSprite sprite = sprites[sprite_id]; // Don't apply content mask because it was already accounted for when // rasterizing the path. float4 device_position = to_device_position(unit_vertex, sprite.bounds, viewport_size); - float2 tile_position = to_tile_position(unit_vertex, sprite.tile, atlas_size); - GradientColor gradient = prepare_fill_color( - sprite.color.tag, - sprite.color.color_space, - sprite.color.solid, - sprite.color.colors[0].color, - sprite.color.colors[1].color - ); + float2 screen_position = float2(sprite.bounds.origin.x, sprite.bounds.origin.y) + unit_vertex * float2(sprite.bounds.size.width, sprite.bounds.size.height); + float2 texture_coords = screen_position / float2(viewport_size->width, viewport_size->height); return PathSpriteVertexOutput{ device_position, - tile_position, - sprite_id, - gradient.solid, - gradient.color0, - gradient.color1 + texture_coords }; } fragment float4 path_sprite_fragment( - PathSpriteVertexOutput input [[stage_in]], - constant PathSprite *sprites [[buffer(SpriteInputIndex_Sprites)]], - texture2d<float> atlas_texture [[texture(SpriteInputIndex_AtlasTexture)]]) { - constexpr sampler atlas_texture_sampler(mag_filter::linear, - min_filter::linear); - float4 sample = - atlas_texture.sample(atlas_texture_sampler, input.tile_position); - float mask = 1. - abs(1. 
- fmod(sample.r, 2.)); - PathSprite sprite = sprites[input.sprite_id]; - Background background = sprite.color; - float4 color = fill_color(background, input.position.xy, sprite.bounds, - input.solid_color, input.color0, input.color1); - color.a *= mask; - return color; + PathSpriteVertexOutput input [[stage_in]], + texture2d<float> intermediate_texture [[texture(SpriteInputIndex_AtlasTexture)]] +) { + constexpr sampler intermediate_texture_sampler(mag_filter::linear, min_filter::linear); + return intermediate_texture.sample(intermediate_texture_sampler, input.texture_coords); } struct SurfaceVertexOutput { diff --git a/crates/gpui/src/platform/test/window.rs b/crates/gpui/src/platform/test/window.rs index 1b88415d3b..e15bd7aeec 100644 --- a/crates/gpui/src/platform/test/window.rs +++ b/crates/gpui/src/platform/test/window.rs @@ -341,7 +341,7 @@ impl PlatformAtlas for TestAtlas { crate::AtlasTile { texture_id: AtlasTextureId { index: texture_id, - kind: crate::AtlasTextureKind::Path, + kind: crate::AtlasTextureKind::Monochrome, }, tile_id: TileId(tile_id), padding: 0, diff --git a/crates/gpui/src/platform/windows.rs b/crates/gpui/src/platform/windows.rs index 4bdf42080d..5268d3ccba 100644 --- a/crates/gpui/src/platform/windows.rs +++ b/crates/gpui/src/platform/windows.rs @@ -1,6 +1,8 @@ mod clipboard; mod destination_list; mod direct_write; +mod directx_atlas; +mod directx_renderer; mod dispatcher; mod display; mod events; @@ -14,6 +16,8 @@ mod wrapper; pub(crate) use clipboard::*; pub(crate) use destination_list::*; pub(crate) use direct_write::*; +pub(crate) use directx_atlas::*; +pub(crate) use directx_renderer::*; pub(crate) use dispatcher::*; pub(crate) use display::*; pub(crate) use events::*; diff --git a/crates/gpui/src/platform/windows/color_text_raster.hlsl b/crates/gpui/src/platform/windows/color_text_raster.hlsl new file mode 100644 index 0000000000..ccc5fa26f0 --- /dev/null +++ b/crates/gpui/src/platform/windows/color_text_raster.hlsl @@ -0,0 +1,39 @@ +struct RasterVertexOutput { + float4 position : SV_Position; + float2 texcoord : TEXCOORD0; +}; + +RasterVertexOutput emoji_rasterization_vertex(uint vertexID : SV_VERTEXID) +{ + RasterVertexOutput output; + output.texcoord = float2((vertexID << 1) & 2, vertexID & 2); + output.position = float4(output.texcoord * 2.0f - 1.0f, 0.0f, 1.0f); + output.position.y = -output.position.y; + + return output; +} + +struct PixelInput { + float4 position: SV_Position; + float2 texcoord : TEXCOORD0; +}; + +struct Bounds { + int2 origin; + int2 size; +}; + +Texture2D<float4> t_layer : register(t0); +SamplerState s_layer : register(s0); + +cbuffer GlyphLayerTextureParams : register(b0) { + Bounds bounds; + float4 run_color; +}; + +float4 emoji_rasterization_fragment(PixelInput input): SV_Target { + float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb; + float alpha = (sampled.r + sampled.g + sampled.b) / 3; + + return float4(run_color.rgb, alpha); +} diff --git a/crates/gpui/src/platform/windows/direct_write.rs b/crates/gpui/src/platform/windows/direct_write.rs index ada306c15c..587cb7b4a6 100644 --- a/crates/gpui/src/platform/windows/direct_write.rs +++ b/crates/gpui/src/platform/windows/direct_write.rs @@ -10,10 +10,11 @@ use windows::{ Foundation::*, Globalization::GetUserDefaultLocaleName, Graphics::{ - Direct2D::{Common::*, *}, + Direct3D::D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + Direct3D11::*, DirectWrite::*, Dxgi::Common::*, - Gdi::LOGFONTW, + Gdi::{IsRectEmpty, LOGFONTW}, Imaging::*, }, 
System::SystemServices::LOCALE_NAME_MAX_LENGTH, @@ -40,16 +41,21 @@ struct DirectWriteComponent { locale: String, factory: IDWriteFactory5, bitmap_factory: AgileReference<IWICImagingFactory>, - d2d1_factory: ID2D1Factory, in_memory_loader: IDWriteInMemoryFontFileLoader, builder: IDWriteFontSetBuilder1, text_renderer: Arc<TextRendererWrapper>, - render_context: GlyphRenderContext, + + render_params: IDWriteRenderingParams3, + gpu_state: GPUState, } -struct GlyphRenderContext { - params: IDWriteRenderingParams3, - dc_target: ID2D1DeviceContext4, +struct GPUState { + device: ID3D11Device, + device_context: ID3D11DeviceContext, + sampler: [Option<ID3D11SamplerState>; 1], + blend_state: ID3D11BlendState, + vertex_shader: ID3D11VertexShader, + pixel_shader: ID3D11PixelShader, } struct DirectWriteState { @@ -70,12 +76,11 @@ struct FontIdentifier { } impl DirectWriteComponent { - pub fn new(bitmap_factory: &IWICImagingFactory) -> Result<Self> { + pub fn new(bitmap_factory: &IWICImagingFactory, gpu_context: &DirectXDevices) -> Result<Self> { + // todo: ideally this would not be a large unsafe block but smaller isolated ones for easier auditing unsafe { let factory: IDWriteFactory5 = DWriteCreateFactory(DWRITE_FACTORY_TYPE_SHARED)?; let bitmap_factory = AgileReference::new(bitmap_factory)?; - let d2d1_factory: ID2D1Factory = - D2D1CreateFactory(D2D1_FACTORY_TYPE_MULTI_THREADED, None)?; // The `IDWriteInMemoryFontFileLoader` here is supported starting from // Windows 10 Creators Update, which consequently requires the entire // `DirectWriteTextSystem` to run on `win10 1703`+. @@ -86,60 +91,132 @@ impl DirectWriteComponent { GetUserDefaultLocaleName(&mut locale_vec); let locale = String::from_utf16_lossy(&locale_vec); let text_renderer = Arc::new(TextRendererWrapper::new(&locale)); - let render_context = GlyphRenderContext::new(&factory, &d2d1_factory)?; + + let render_params = { + let default_params: IDWriteRenderingParams3 = + factory.CreateRenderingParams()?.cast()?; + let gamma = default_params.GetGamma(); + let enhanced_contrast = default_params.GetEnhancedContrast(); + let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); + let cleartype_level = default_params.GetClearTypeLevel(); + let grid_fit_mode = default_params.GetGridFitMode(); + + factory.CreateCustomRenderingParams( + gamma, + enhanced_contrast, + gray_contrast, + cleartype_level, + DWRITE_PIXEL_GEOMETRY_RGB, + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + grid_fit_mode, + )? 
+ }; + + let gpu_state = GPUState::new(gpu_context)?; Ok(DirectWriteComponent { locale, factory, bitmap_factory, - d2d1_factory, in_memory_loader, builder, text_renderer, - render_context, + render_params, + gpu_state, }) } } } -impl GlyphRenderContext { - pub fn new(factory: &IDWriteFactory5, d2d1_factory: &ID2D1Factory) -> Result<Self> { - unsafe { - let default_params: IDWriteRenderingParams3 = - factory.CreateRenderingParams()?.cast()?; - let gamma = default_params.GetGamma(); - let enhanced_contrast = default_params.GetEnhancedContrast(); - let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); - let cleartype_level = default_params.GetClearTypeLevel(); - let grid_fit_mode = default_params.GetGridFitMode(); +impl GPUState { + fn new(gpu_context: &DirectXDevices) -> Result<Self> { + let device = gpu_context.device.clone(); + let device_context = gpu_context.device_context.clone(); - let params = factory.CreateCustomRenderingParams( - gamma, - enhanced_contrast, - gray_contrast, - cleartype_level, - DWRITE_PIXEL_GEOMETRY_RGB, - DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, - grid_fit_mode, - )?; - let dc_target = { - let target = d2d1_factory.CreateDCRenderTarget(&get_render_target_property( - DXGI_FORMAT_B8G8R8A8_UNORM, - D2D1_ALPHA_MODE_PREMULTIPLIED, - ))?; - let target = target.cast::<ID2D1DeviceContext4>()?; - target.SetTextRenderingParams(¶ms); - target + let blend_state = { + let mut blend_state = None; + let desc = D3D11_BLEND_DESC { + AlphaToCoverageEnable: false.into(), + IndependentBlendEnable: false.into(), + RenderTarget: [ + D3D11_RENDER_TARGET_BLEND_DESC { + BlendEnable: true.into(), + SrcBlend: D3D11_BLEND_SRC_ALPHA, + DestBlend: D3D11_BLEND_INV_SRC_ALPHA, + BlendOp: D3D11_BLEND_OP_ADD, + SrcBlendAlpha: D3D11_BLEND_SRC_ALPHA, + DestBlendAlpha: D3D11_BLEND_INV_SRC_ALPHA, + BlendOpAlpha: D3D11_BLEND_OP_ADD, + RenderTargetWriteMask: D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8, + }, + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + ], }; + unsafe { device.CreateBlendState(&desc, Some(&mut blend_state)) }?; + blend_state.unwrap() + }; - Ok(Self { params, dc_target }) - } + let sampler = { + let mut sampler = None; + let desc = D3D11_SAMPLER_DESC { + Filter: D3D11_FILTER_MIN_MAG_MIP_POINT, + AddressU: D3D11_TEXTURE_ADDRESS_BORDER, + AddressV: D3D11_TEXTURE_ADDRESS_BORDER, + AddressW: D3D11_TEXTURE_ADDRESS_BORDER, + MipLODBias: 0.0, + MaxAnisotropy: 1, + ComparisonFunc: D3D11_COMPARISON_ALWAYS, + BorderColor: [0.0, 0.0, 0.0, 0.0], + MinLOD: 0.0, + MaxLOD: 0.0, + }; + unsafe { device.CreateSamplerState(&desc, Some(&mut sampler)) }?; + [sampler] + }; + + let vertex_shader = { + let source = shader_resources::RawShaderBytes::new( + shader_resources::ShaderModule::EmojiRasterization, + shader_resources::ShaderTarget::Vertex, + )?; + let mut shader = None; + unsafe { device.CreateVertexShader(source.as_bytes(), None, Some(&mut shader)) }?; + shader.unwrap() + }; + + let pixel_shader = { + let source = shader_resources::RawShaderBytes::new( + shader_resources::ShaderModule::EmojiRasterization, + shader_resources::ShaderTarget::Fragment, + )?; + let mut shader = None; + unsafe { device.CreatePixelShader(source.as_bytes(), None, Some(&mut shader)) }?; + shader.unwrap() + }; + + Ok(Self { + device, + device_context, + sampler, + blend_state, + vertex_shader, + pixel_shader, + }) } } impl DirectWriteTextSystem { - pub(crate) fn new(bitmap_factory: &IWICImagingFactory) -> 
Result<Self> { - let components = DirectWriteComponent::new(bitmap_factory)?; + pub(crate) fn new( + gpu_context: &DirectXDevices, + bitmap_factory: &IWICImagingFactory, + ) -> Result<Self> { + let components = DirectWriteComponent::new(bitmap_factory, gpu_context)?; let system_font_collection = unsafe { let mut result = std::mem::zeroed(); components @@ -648,15 +725,13 @@ impl DirectWriteState { } } - fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> { - let render_target = &self.components.render_context.dc_target; - unsafe { - render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); - render_target.SetDpi(96.0 * params.scale_factor, 96.0 * params.scale_factor); - } + fn create_glyph_run_analysis( + &self, + params: &RenderGlyphParams, + ) -> Result<IDWriteGlyphRunAnalysis> { let font = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16]; - let advance = [0.0f32]; + let advance = [0.0]; let offset = [DWRITE_GLYPH_OFFSET::default()]; let glyph_run = DWRITE_GLYPH_RUN { fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, @@ -668,44 +743,87 @@ impl DirectWriteState { isSideways: BOOL(0), bidiLevel: 0, }; - let bounds = unsafe { - render_target.GetGlyphRunWorldBounds( - Vector2 { X: 0.0, Y: 0.0 }, - &glyph_run, - DWRITE_MEASURING_MODE_NATURAL, - )? + let transform = DWRITE_MATRIX { + m11: params.scale_factor, + m12: 0.0, + m21: 0.0, + m22: params.scale_factor, + dx: 0.0, + dy: 0.0, }; - // todo(windows) - // This is a walkaround, deleted when figured out. - let y_offset; - let extra_height; - if params.is_emoji { - y_offset = 0; - extra_height = 0; - } else { - // make some room for scaler. - y_offset = -1; - extra_height = 2; + let subpixel_shift = params + .subpixel_variant + .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); + let baseline_origin_x = subpixel_shift.x / params.scale_factor; + let baseline_origin_y = subpixel_shift.y / params.scale_factor; + + let mut rendering_mode = DWRITE_RENDERING_MODE1::default(); + let mut grid_fit_mode = DWRITE_GRID_FIT_MODE::default(); + unsafe { + font.font_face.GetRecommendedRenderingMode( + params.font_size.0, + // The DPI here appears to have the same effect as `Some(&transform)` + 1.0, + 1.0, + Some(&transform), + false, + DWRITE_OUTLINE_THRESHOLD_ANTIALIASED, + DWRITE_MEASURING_MODE_NATURAL, + &self.components.render_params, + &mut rendering_mode, + &mut grid_fit_mode, + )?; } - if bounds.right < bounds.left { + let glyph_analysis = unsafe { + self.components.factory.CreateGlyphRunAnalysis( + &glyph_run, + Some(&transform), + rendering_mode, + DWRITE_MEASURING_MODE_NATURAL, + grid_fit_mode, + // We use ClearType rather than grayscale for monochrome glyphs because it provides better quality + DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + baseline_origin_x, + baseline_origin_y, + ) + }?; + Ok(glyph_analysis) + } + + fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> { + let glyph_analysis = self.create_glyph_run_analysis(params)?; + + let bounds = unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1)? }; + // Some glyphs cannot be drawn with ClearType, such as bitmap fonts. In that case + // GetAlphaTextureBounds() supposedly returns an empty RECT, but I haven't tested that yet.
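+        // If the ClearType bounds are non-empty we can use them directly; otherwise we retry below with the aliased texture bounds.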
+ if !unsafe { IsRectEmpty(&bounds) }.as_bool() { Ok(Bounds { - origin: point(0.into(), 0.into()), - size: size(0.into(), 0.into()), + origin: point(bounds.left.into(), bounds.top.into()), + size: size( + (bounds.right - bounds.left).into(), + (bounds.bottom - bounds.top).into(), + ), }) } else { - Ok(Bounds { - origin: point( - ((bounds.left * params.scale_factor).ceil() as i32).into(), - ((bounds.top * params.scale_factor).ceil() as i32 + y_offset).into(), - ), - size: size( - (((bounds.right - bounds.left) * params.scale_factor).ceil() as i32).into(), - (((bounds.bottom - bounds.top) * params.scale_factor).ceil() as i32 - + extra_height) - .into(), - ), - }) + // If it's empty, retry with grayscale AA. + let bounds = + unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1)? }; + + if bounds.right < bounds.left { + Ok(Bounds { + origin: point(0.into(), 0.into()), + size: size(0.into(), 0.into()), + }) + } else { + Ok(Bounds { + origin: point(bounds.left.into(), bounds.top.into()), + size: size( + (bounds.right - bounds.left).into(), + (bounds.bottom - bounds.top).into(), + ), + }) + } } } @@ -731,7 +849,95 @@ impl DirectWriteState { anyhow::bail!("glyph bounds are empty"); } - let font_info = &self.fonts[params.font_id.0]; + let bitmap_data = if params.is_emoji { + if let Ok(color) = self.rasterize_color(&params, glyph_bounds) { + color + } else { + let monochrome = self.rasterize_monochrome(params, glyph_bounds)?; + monochrome + .into_iter() + .flat_map(|pixel| [0, 0, 0, pixel]) + .collect::<Vec<_>>() + } + } else { + self.rasterize_monochrome(params, glyph_bounds)? + }; + + Ok((glyph_bounds.size, bitmap_data)) + } + + fn rasterize_monochrome( + &self, + params: &RenderGlyphParams, + glyph_bounds: Bounds<DevicePixels>, + ) -> Result<Vec<u8>> { + let mut bitmap_data = + vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize * 3]; + + let glyph_analysis = self.create_glyph_run_analysis(params)?; + unsafe { + glyph_analysis.CreateAlphaTexture( + // We use ClearType rather than grayscale for monochrome glyphs because it provides better quality + DWRITE_TEXTURE_CLEARTYPE_3x1, + &RECT { + left: glyph_bounds.origin.x.0, + top: glyph_bounds.origin.y.0, + right: glyph_bounds.size.width.0 + glyph_bounds.origin.x.0, + bottom: glyph_bounds.size.height.0 + glyph_bounds.origin.y.0, + }, + &mut bitmap_data, + )?; + } + + let bitmap_factory = self.components.bitmap_factory.resolve()?; + let bitmap = unsafe { + bitmap_factory.CreateBitmapFromMemory( + glyph_bounds.size.width.0 as u32, + glyph_bounds.size.height.0 as u32, + &GUID_WICPixelFormat24bppRGB, + glyph_bounds.size.width.0 as u32 * 3, + &bitmap_data, + ) + }?; + + let grayscale_bitmap = + unsafe { WICConvertBitmapSource(&GUID_WICPixelFormat8bppGray, &bitmap) }?; + + let mut bitmap_data = + vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize]; + unsafe { + grayscale_bitmap.CopyPixels( + std::ptr::null() as _, + glyph_bounds.size.width.0 as u32, + &mut bitmap_data, + ) + }?; + + Ok(bitmap_data) + } + + fn rasterize_color( + &self, + params: &RenderGlyphParams, + glyph_bounds: Bounds<DevicePixels>, + ) -> Result<Vec<u8>> { + let bitmap_size = glyph_bounds.size; + let subpixel_shift = params + .subpixel_variant + .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); + let baseline_origin_x = subpixel_shift.x / params.scale_factor; + let baseline_origin_y = subpixel_shift.y / params.scale_factor; + + let transform = DWRITE_MATRIX { + m11: params.scale_factor, + m12: 0.0, + m21: 0.0, + m22:
params.scale_factor, + dx: 0.0, + dy: 0.0, + }; + + let font = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16]; let advance = [glyph_bounds.size.width.0 as f32]; let offset = [DWRITE_GLYPH_OFFSET { @@ -739,7 +945,7 @@ impl DirectWriteState { ascenderOffset: glyph_bounds.origin.y.0 as f32 / params.scale_factor, }]; let glyph_run = DWRITE_GLYPH_RUN { - fontFace: unsafe { std::mem::transmute_copy(&font_info.font_face) }, + fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, fontEmSize: params.font_size.0, glyphCount: 1, glyphIndices: glyph_id.as_ptr(), @@ -749,160 +955,254 @@ impl DirectWriteState { bidiLevel: 0, }; - // Add an extra pixel when the subpixel variant isn't zero to make room for anti-aliasing. - let mut bitmap_size = glyph_bounds.size; - if params.subpixel_variant.x > 0 { - bitmap_size.width += DevicePixels(1); - } - if params.subpixel_variant.y > 0 { - bitmap_size.height += DevicePixels(1); - } - let bitmap_size = bitmap_size; + // todo: support formats other than COLR + let color_enumerator = unsafe { + self.components.factory.TranslateColorGlyphRun( + Vector2::new(baseline_origin_x, baseline_origin_y), + &glyph_run, + None, + DWRITE_GLYPH_IMAGE_FORMATS_COLR, + DWRITE_MEASURING_MODE_NATURAL, + Some(&transform), + 0, + ) + }?; - let total_bytes; - let bitmap_format; - let render_target_property; - let bitmap_width; - let bitmap_height; - let bitmap_stride; - let bitmap_dpi; - if params.is_emoji { - total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize * 4; - bitmap_format = &GUID_WICPixelFormat32bppPBGRA; - render_target_property = get_render_target_property( - DXGI_FORMAT_B8G8R8A8_UNORM, - D2D1_ALPHA_MODE_PREMULTIPLIED, - ); - bitmap_width = bitmap_size.width.0 as u32; - bitmap_height = bitmap_size.height.0 as u32; - bitmap_stride = bitmap_size.width.0 as u32 * 4; - bitmap_dpi = 96.0; - } else { - total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize; - bitmap_format = &GUID_WICPixelFormat8bppAlpha; - render_target_property = - get_render_target_property(DXGI_FORMAT_A8_UNORM, D2D1_ALPHA_MODE_STRAIGHT); - bitmap_width = bitmap_size.width.0 as u32 * 2; - bitmap_height = bitmap_size.height.0 as u32 * 2; - bitmap_stride = bitmap_size.width.0 as u32; - bitmap_dpi = 192.0; + let mut glyph_layers = Vec::new(); + loop { + let color_run = unsafe { color_enumerator.GetCurrentRun() }?; + let color_run = unsafe { &*color_run }; + let image_format = color_run.glyphImageFormat & !DWRITE_GLYPH_IMAGE_FORMATS_TRUETYPE; + if image_format == DWRITE_GLYPH_IMAGE_FORMATS_COLR { + let color_analysis = unsafe { + self.components.factory.CreateGlyphRunAnalysis( + &color_run.Base.glyphRun as *const _, + Some(&transform), + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + DWRITE_MEASURING_MODE_NATURAL, + DWRITE_GRID_FIT_MODE_DEFAULT, + DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + baseline_origin_x, + baseline_origin_y, + ) + }?; + + let color_bounds = + unsafe { color_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1) }?; + + let color_size = size( + color_bounds.right - color_bounds.left, + color_bounds.bottom - color_bounds.top, + ); + if color_size.width > 0 && color_size.height > 0 { + let mut alpha_data = + vec![0u8; (color_size.width * color_size.height * 3) as usize]; + unsafe { + color_analysis.CreateAlphaTexture( + DWRITE_TEXTURE_CLEARTYPE_3x1, + &color_bounds, + &mut alpha_data, + ) + }?; + + let run_color = { + let run_color = color_run.Base.runColor; + Rgba { + r: run_color.r, + g: run_color.g, + b: run_color.b, + 
a: run_color.a, + } + }; + let bounds = bounds(point(color_bounds.left, color_bounds.top), color_size); + let alpha_data = alpha_data + .chunks_exact(3) + .flat_map(|chunk| [chunk[0], chunk[1], chunk[2], 255]) + .collect::<Vec<_>>(); + glyph_layers.push(GlyphLayerTexture::new( + &self.components.gpu_state, + run_color, + bounds, + &alpha_data, + )?); + } + } + + let has_next = unsafe { color_enumerator.MoveNext() } + .map(|e| e.as_bool()) + .unwrap_or(false); + if !has_next { + break; + } } - let bitmap_factory = self.components.bitmap_factory.resolve()?; - unsafe { - let bitmap = bitmap_factory.CreateBitmap( - bitmap_width, - bitmap_height, - bitmap_format, - WICBitmapCacheOnLoad, - )?; - let render_target = self - .components - .d2d1_factory - .CreateWicBitmapRenderTarget(&bitmap, &render_target_property)?; - let brush = render_target.CreateSolidColorBrush(&BRUSH_COLOR, None)?; - let subpixel_shift = params - .subpixel_variant - .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); - let baseline_origin = Vector2 { - X: subpixel_shift.x / params.scale_factor, - Y: subpixel_shift.y / params.scale_factor, + let gpu_state = &self.components.gpu_state; + let params_buffer = { + let desc = D3D11_BUFFER_DESC { + ByteWidth: std::mem::size_of::<GlyphLayerTextureParams>() as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + StructureByteStride: 0, }; - // This `cast()` action here should never fail since we are running on Win10+, and - // ID2D1DeviceContext4 requires Win8+ - let render_target = render_target.cast::<ID2D1DeviceContext4>().unwrap(); - render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); - render_target.SetDpi( - bitmap_dpi * params.scale_factor, - bitmap_dpi * params.scale_factor, - ); - render_target.SetTextRenderingParams(&self.components.render_context.params); - render_target.BeginDraw(); + let mut buffer = None; + unsafe { + gpu_state + .device + .CreateBuffer(&desc, None, Some(&mut buffer)) + }?; + [buffer] + }; - if params.is_emoji { - // WARN: only DWRITE_GLYPH_IMAGE_FORMATS_COLR has been tested - let enumerator = self.components.factory.TranslateColorGlyphRun( - baseline_origin, - &glyph_run as _, - None, - DWRITE_GLYPH_IMAGE_FORMATS_COLR - | DWRITE_GLYPH_IMAGE_FORMATS_SVG - | DWRITE_GLYPH_IMAGE_FORMATS_PNG - | DWRITE_GLYPH_IMAGE_FORMATS_JPEG - | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8, - DWRITE_MEASURING_MODE_NATURAL, - None, + let render_target_texture = { + let mut texture = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: bitmap_size.width.0 as u32, + Height: bitmap_size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture)) + }?; + texture.unwrap() + }; + + let render_target_view = { + let desc = D3D11_RENDER_TARGET_VIEW_DESC { + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + ViewDimension: D3D11_RTV_DIMENSION_TEXTURE2D, + Anonymous: D3D11_RENDER_TARGET_VIEW_DESC_0 { + Texture2D: D3D11_TEX2D_RTV { MipSlice: 0 }, + }, + }; + let mut rtv = None; + unsafe { + gpu_state.device.CreateRenderTargetView( + &render_target_texture, + Some(&desc), + Some(&mut rtv), + ) + }?; + [rtv] + }; + + let staging_texture = { + let mut texture = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: 
bitmap_size.width.0 as u32, + Height: bitmap_size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_STAGING, + BindFlags: 0, + CPUAccessFlags: D3D11_CPU_ACCESS_READ.0 as u32, + MiscFlags: 0, + }; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture)) + }?; + texture.unwrap() + }; + + let device_context = &gpu_state.device_context; + unsafe { device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP) }; + unsafe { device_context.VSSetShader(&gpu_state.vertex_shader, None) }; + unsafe { device_context.PSSetShader(&gpu_state.pixel_shader, None) }; + unsafe { device_context.VSSetConstantBuffers(0, Some(¶ms_buffer)) }; + unsafe { device_context.PSSetConstantBuffers(0, Some(¶ms_buffer)) }; + unsafe { device_context.OMSetRenderTargets(Some(&render_target_view), None) }; + unsafe { device_context.PSSetSamplers(0, Some(&gpu_state.sampler)) }; + unsafe { device_context.OMSetBlendState(&gpu_state.blend_state, None, 0xffffffff) }; + + for layer in glyph_layers { + let params = GlyphLayerTextureParams { + run_color: layer.run_color, + bounds: layer.bounds, + }; + unsafe { + let mut dest = std::mem::zeroed(); + gpu_state.device_context.Map( + params_buffer[0].as_ref().unwrap(), 0, + D3D11_MAP_WRITE_DISCARD, + 0, + Some(&mut dest), )?; - while enumerator.MoveNext().is_ok() { - let Ok(color_glyph) = enumerator.GetCurrentRun() else { - break; - }; - let color_glyph = &*color_glyph; - let brush_color = translate_color(&color_glyph.Base.runColor); - brush.SetColor(&brush_color); - match color_glyph.glyphImageFormat { - DWRITE_GLYPH_IMAGE_FORMATS_PNG - | DWRITE_GLYPH_IMAGE_FORMATS_JPEG - | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8 => render_target - .DrawColorBitmapGlyphRun( - color_glyph.glyphImageFormat, - baseline_origin, - &color_glyph.Base.glyphRun, - color_glyph.measuringMode, - D2D1_COLOR_BITMAP_GLYPH_SNAP_OPTION_DEFAULT, - ), - DWRITE_GLYPH_IMAGE_FORMATS_SVG => render_target.DrawSvgGlyphRun( - baseline_origin, - &color_glyph.Base.glyphRun, - &brush, - None, - color_glyph.Base.paletteIndex as u32, - color_glyph.measuringMode, - ), - _ => render_target.DrawGlyphRun( - baseline_origin, - &color_glyph.Base.glyphRun, - Some(color_glyph.Base.glyphRunDescription as *const _), - &brush, - color_glyph.measuringMode, - ), - } - } - } else { - render_target.DrawGlyphRun( - baseline_origin, - &glyph_run, - None, - &brush, - DWRITE_MEASURING_MODE_NATURAL, - ); - } - render_target.EndDraw(None, None)?; + std::ptr::copy_nonoverlapping(¶ms as *const _, dest.pData as *mut _, 1); + gpu_state + .device_context + .Unmap(params_buffer[0].as_ref().unwrap(), 0); + }; - let mut raw_data = vec![0u8; total_bytes]; - if params.is_emoji { - bitmap.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; - // Convert from BGRA with premultiplied alpha to BGRA with straight alpha. 
- for pixel in raw_data.chunks_exact_mut(4) { - let a = pixel[3] as f32 / 255.; - pixel[0] = (pixel[0] as f32 / a) as u8; - pixel[1] = (pixel[1] as f32 / a) as u8; - pixel[2] = (pixel[2] as f32 / a) as u8; - } - } else { - let scaler = bitmap_factory.CreateBitmapScaler()?; - scaler.Initialize( - &bitmap, - bitmap_size.width.0 as u32, - bitmap_size.height.0 as u32, - WICBitmapInterpolationModeHighQualityCubic, - )?; - scaler.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; - } - Ok((bitmap_size, raw_data)) + let texture = [Some(layer.texture_view)]; + unsafe { device_context.PSSetShaderResources(0, Some(&texture)) }; + + let viewport = [D3D11_VIEWPORT { + TopLeftX: layer.bounds.origin.x as f32, + TopLeftY: layer.bounds.origin.y as f32, + Width: layer.bounds.size.width as f32, + Height: layer.bounds.size.height as f32, + MinDepth: 0.0, + MaxDepth: 1.0, + }]; + unsafe { device_context.RSSetViewports(Some(&viewport)) }; + + unsafe { device_context.Draw(4, 0) }; } + + unsafe { device_context.CopyResource(&staging_texture, &render_target_texture) }; + + let mapped_data = { + let mut mapped_data = D3D11_MAPPED_SUBRESOURCE::default(); + unsafe { + device_context.Map( + &staging_texture, + 0, + D3D11_MAP_READ, + 0, + Some(&mut mapped_data), + ) + }?; + mapped_data + }; + let mut rasterized = + vec![0u8; (bitmap_size.width.0 as u32 * bitmap_size.height.0 as u32 * 4) as usize]; + + for y in 0..bitmap_size.height.0 as usize { + let width = bitmap_size.width.0 as usize; + unsafe { + std::ptr::copy_nonoverlapping::<u8>( + (mapped_data.pData as *const u8).byte_add(mapped_data.RowPitch as usize * y), + rasterized + .as_mut_ptr() + .byte_add(width * y * std::mem::size_of::<u32>()), + width * std::mem::size_of::<u32>(), + ) + }; + } + + Ok(rasterized) } fn get_typographic_bounds(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Bounds<f32>> { @@ -976,6 +1276,84 @@ impl Drop for DirectWriteState { } } +struct GlyphLayerTexture { + run_color: Rgba, + bounds: Bounds<i32>, + texture_view: ID3D11ShaderResourceView, + // holding on to the texture to not RAII drop it + _texture: ID3D11Texture2D, +} + +impl GlyphLayerTexture { + pub fn new( + gpu_state: &GPUState, + run_color: Rgba, + bounds: Bounds<i32>, + alpha_data: &[u8], + ) -> Result<Self> { + let texture_size = bounds.size; + + let desc = D3D11_TEXTURE2D_DESC { + Width: texture_size.width as u32, + Height: texture_size.height as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_R8G8B8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + }; + + let texture = { + let mut texture: Option<ID3D11Texture2D> = None; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture))? + }; + texture.unwrap() + }; + let texture_view = { + let mut view: Option<ID3D11ShaderResourceView> = None; + unsafe { + gpu_state + .device + .CreateShaderResourceView(&texture, None, Some(&mut view))? 
+ }; + view.unwrap() + }; + + unsafe { + gpu_state.device_context.UpdateSubresource( + &texture, + 0, + None, + alpha_data.as_ptr() as _, + (texture_size.width * 4) as u32, + 0, + ) + }; + + Ok(GlyphLayerTexture { + run_color, + bounds, + texture_view, + _texture: texture, + }) + } +} + +#[repr(C)] +struct GlyphLayerTextureParams { + bounds: Bounds<i32>, + run_color: Rgba, +} + struct TextRendererWrapper(pub IDWriteTextRenderer); impl TextRendererWrapper { @@ -1470,16 +1848,6 @@ fn get_name(string: IDWriteLocalizedStrings, locale: &str) -> Result<String> { Ok(String::from_utf16_lossy(&name_vec[..name_length])) } -#[inline] -fn translate_color(color: &DWRITE_COLOR_F) -> D2D1_COLOR_F { - D2D1_COLOR_F { - r: color.r, - g: color.g, - b: color.b, - a: color.a, - } -} - fn get_system_ui_font_name() -> SharedString { unsafe { let mut info: LOGFONTW = std::mem::zeroed(); @@ -1504,24 +1872,6 @@ fn get_system_ui_font_name() -> SharedString { } } -#[inline] -fn get_render_target_property( - pixel_format: DXGI_FORMAT, - alpha_mode: D2D1_ALPHA_MODE, -) -> D2D1_RENDER_TARGET_PROPERTIES { - D2D1_RENDER_TARGET_PROPERTIES { - r#type: D2D1_RENDER_TARGET_TYPE_DEFAULT, - pixelFormat: D2D1_PIXEL_FORMAT { - format: pixel_format, - alphaMode: alpha_mode, - }, - dpiX: 96.0, - dpiY: 96.0, - usage: D2D1_RENDER_TARGET_USAGE_NONE, - minLevel: D2D1_FEATURE_LEVEL_DEFAULT, - } -} - // One would think that with newer DirectWrite method: IDWriteFontFace4::GetGlyphImageFormats // but that doesn't seem to work for some glyphs, say ❤ fn is_color_glyph( @@ -1561,12 +1911,6 @@ fn is_color_glyph( } const DEFAULT_LOCALE_NAME: PCWSTR = windows::core::w!("en-US"); -const BRUSH_COLOR: D2D1_COLOR_F = D2D1_COLOR_F { - r: 1.0, - g: 1.0, - b: 1.0, - a: 1.0, -}; #[cfg(test)] mod tests { diff --git a/crates/gpui/src/platform/windows/directx_atlas.rs b/crates/gpui/src/platform/windows/directx_atlas.rs new file mode 100644 index 0000000000..6bced4c11d --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_atlas.rs @@ -0,0 +1,309 @@ +use collections::FxHashMap; +use etagere::BucketedAtlasAllocator; +use parking_lot::Mutex; +use windows::Win32::Graphics::{ + Direct3D11::{ + D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_CPU_ACCESS_WRITE, D3D11_TEXTURE2D_DESC, + D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView, + ID3D11Texture2D, + }, + Dxgi::Common::*, +}; + +use crate::{ + AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, DevicePixels, PlatformAtlas, + Point, Size, platform::AtlasTextureList, +}; + +pub(crate) struct DirectXAtlas(Mutex<DirectXAtlasState>); + +struct DirectXAtlasState { + device: ID3D11Device, + device_context: ID3D11DeviceContext, + monochrome_textures: AtlasTextureList<DirectXAtlasTexture>, + polychrome_textures: AtlasTextureList<DirectXAtlasTexture>, + tiles_by_key: FxHashMap<AtlasKey, AtlasTile>, +} + +struct DirectXAtlasTexture { + id: AtlasTextureId, + bytes_per_pixel: u32, + allocator: BucketedAtlasAllocator, + texture: ID3D11Texture2D, + view: [Option<ID3D11ShaderResourceView>; 1], + live_atlas_keys: u32, +} + +impl DirectXAtlas { + pub(crate) fn new(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Self { + DirectXAtlas(Mutex::new(DirectXAtlasState { + device: device.clone(), + device_context: device_context.clone(), + monochrome_textures: Default::default(), + polychrome_textures: Default::default(), + tiles_by_key: Default::default(), + })) + } + + pub(crate) fn get_texture_view( + &self, + id: AtlasTextureId, + ) -> [Option<ID3D11ShaderResourceView>; 
1] { + let lock = self.0.lock(); + let tex = lock.texture(id); + tex.view.clone() + } + + pub(crate) fn handle_device_lost( + &self, + device: &ID3D11Device, + device_context: &ID3D11DeviceContext, + ) { + let mut lock = self.0.lock(); + lock.device = device.clone(); + lock.device_context = device_context.clone(); + lock.monochrome_textures = AtlasTextureList::default(); + lock.polychrome_textures = AtlasTextureList::default(); + lock.tiles_by_key.clear(); + } +} + +impl PlatformAtlas for DirectXAtlas { + fn get_or_insert_with<'a>( + &self, + key: &AtlasKey, + build: &mut dyn FnMut() -> anyhow::Result< + Option<(Size<DevicePixels>, std::borrow::Cow<'a, [u8]>)>, + >, + ) -> anyhow::Result<Option<AtlasTile>> { + let mut lock = self.0.lock(); + if let Some(tile) = lock.tiles_by_key.get(key) { + Ok(Some(tile.clone())) + } else { + let Some((size, bytes)) = build()? else { + return Ok(None); + }; + let tile = lock + .allocate(size, key.texture_kind()) + .ok_or_else(|| anyhow::anyhow!("failed to allocate"))?; + let texture = lock.texture(tile.texture_id); + texture.upload(&lock.device_context, tile.bounds, &bytes); + lock.tiles_by_key.insert(key.clone(), tile.clone()); + Ok(Some(tile)) + } + } + + fn remove(&self, key: &AtlasKey) { + let mut lock = self.0.lock(); + + let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else { + return; + }; + + let textures = match id.kind { + AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, + AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, + }; + + let Some(texture_slot) = textures.textures.get_mut(id.index as usize) else { + return; + }; + + if let Some(mut texture) = texture_slot.take() { + texture.decrement_ref_count(); + if texture.is_unreferenced() { + textures.free_list.push(texture.id.index as usize); + lock.tiles_by_key.remove(key); + } else { + *texture_slot = Some(texture); + } + } + } +} + +impl DirectXAtlasState { + fn allocate( + &mut self, + size: Size<DevicePixels>, + texture_kind: AtlasTextureKind, + ) -> Option<AtlasTile> { + { + let textures = match texture_kind { + AtlasTextureKind::Monochrome => &mut self.monochrome_textures, + AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + }; + + if let Some(tile) = textures + .iter_mut() + .rev() + .find_map(|texture| texture.allocate(size)) + { + return Some(tile); + } + } + + let texture = self.push_texture(size, texture_kind)?; + texture.allocate(size) + } + + fn push_texture( + &mut self, + min_size: Size<DevicePixels>, + kind: AtlasTextureKind, + ) -> Option<&mut DirectXAtlasTexture> { + const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size { + width: DevicePixels(1024), + height: DevicePixels(1024), + }; + // Max texture size for DirectX. 
See: + // https://learn.microsoft.com/en-us/windows/win32/direct3d11/overviews-direct3d-11-resources-limits + const MAX_ATLAS_SIZE: Size<DevicePixels> = Size { + width: DevicePixels(16384), + height: DevicePixels(16384), + }; + let size = min_size.min(&MAX_ATLAS_SIZE).max(&DEFAULT_ATLAS_SIZE); + let pixel_format; + let bind_flag; + let bytes_per_pixel; + match kind { + AtlasTextureKind::Monochrome => { + pixel_format = DXGI_FORMAT_R8_UNORM; + bind_flag = D3D11_BIND_SHADER_RESOURCE; + bytes_per_pixel = 1; + } + AtlasTextureKind::Polychrome => { + pixel_format = DXGI_FORMAT_B8G8R8A8_UNORM; + bind_flag = D3D11_BIND_SHADER_RESOURCE; + bytes_per_pixel = 4; + } + } + let texture_desc = D3D11_TEXTURE2D_DESC { + Width: size.width.0 as u32, + Height: size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: pixel_format, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: bind_flag.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + }; + let mut texture: Option<ID3D11Texture2D> = None; + unsafe { + // This only returns None if the device is lost, which we will recreate later. + // So it's ok to return None here. + self.device + .CreateTexture2D(&texture_desc, None, Some(&mut texture)) + .ok()?; + } + let texture = texture.unwrap(); + + let texture_list = match kind { + AtlasTextureKind::Monochrome => &mut self.monochrome_textures, + AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + }; + let index = texture_list.free_list.pop(); + let view = unsafe { + let mut view = None; + self.device + .CreateShaderResourceView(&texture, None, Some(&mut view)) + .ok()?; + [view] + }; + let atlas_texture = DirectXAtlasTexture { + id: AtlasTextureId { + index: index.unwrap_or(texture_list.textures.len()) as u32, + kind, + }, + bytes_per_pixel, + allocator: etagere::BucketedAtlasAllocator::new(size.into()), + texture, + view, + live_atlas_keys: 0, + }; + if let Some(ix) = index { + texture_list.textures[ix] = Some(atlas_texture); + texture_list.textures.get_mut(ix).unwrap().as_mut() + } else { + texture_list.textures.push(Some(atlas_texture)); + texture_list.textures.last_mut().unwrap().as_mut() + } + } + + fn texture(&self, id: AtlasTextureId) -> &DirectXAtlasTexture { + let textures = match id.kind { + crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, + crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + }; + textures[id.index as usize].as_ref().unwrap() + } +} + +impl DirectXAtlasTexture { + fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> { + let allocation = self.allocator.allocate(size.into())?; + let tile = AtlasTile { + texture_id: self.id, + tile_id: allocation.id.into(), + bounds: Bounds { + origin: allocation.rectangle.min.into(), + size, + }, + padding: 0, + }; + self.live_atlas_keys += 1; + Some(tile) + } + + fn upload( + &self, + device_context: &ID3D11DeviceContext, + bounds: Bounds<DevicePixels>, + bytes: &[u8], + ) { + unsafe { + device_context.UpdateSubresource( + &self.texture, + 0, + Some(&D3D11_BOX { + left: bounds.left().0 as u32, + top: bounds.top().0 as u32, + front: 0, + right: bounds.right().0 as u32, + bottom: bounds.bottom().0 as u32, + back: 1, + }), + bytes.as_ptr() as _, + bounds.size.width.to_bytes(self.bytes_per_pixel as u8), + 0, + ); + } + } + + fn decrement_ref_count(&mut self) { + self.live_atlas_keys -= 1; + } + + fn is_unreferenced(&mut self) -> bool { + self.live_atlas_keys == 0 + } +} + +impl From<Size<DevicePixels>> for 
etagere::Size { + fn from(size: Size<DevicePixels>) -> Self { + etagere::Size::new(size.width.into(), size.height.into()) + } +} + +impl From<etagere::Point> for Point<DevicePixels> { + fn from(value: etagere::Point) -> Self { + Point { + x: DevicePixels::from(value.x), + y: DevicePixels::from(value.y), + } + } +} diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs new file mode 100644 index 0000000000..ac285b79ac --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_renderer.rs @@ -0,0 +1,1807 @@ +use std::{mem::ManuallyDrop, sync::Arc}; + +use ::util::ResultExt; +use anyhow::{Context, Result}; +use windows::{ + Win32::{ + Foundation::{HMODULE, HWND}, + Graphics::{ + Direct3D::*, + Direct3D11::*, + DirectComposition::*, + Dxgi::{Common::*, *}, + }, + }, + core::Interface, +}; + +use crate::{ + platform::windows::directx_renderer::shader_resources::{ + RawShaderBytes, ShaderModule, ShaderTarget, + }, + *, +}; + +pub(crate) const DISABLE_DIRECT_COMPOSITION: &str = "GPUI_DISABLE_DIRECT_COMPOSITION"; +const RENDER_TARGET_FORMAT: DXGI_FORMAT = DXGI_FORMAT_B8G8R8A8_UNORM; +// This configuration is used for MSAA rendering on paths only, and it's guaranteed to be supported by DirectX 11. +const PATH_MULTISAMPLE_COUNT: u32 = 4; + +pub(crate) struct DirectXRenderer { + hwnd: HWND, + atlas: Arc<DirectXAtlas>, + devices: ManuallyDrop<DirectXDevices>, + resources: ManuallyDrop<DirectXResources>, + globals: DirectXGlobalElements, + pipelines: DirectXRenderPipelines, + direct_composition: Option<DirectComposition>, +} + +/// Direct3D objects +#[derive(Clone)] +pub(crate) struct DirectXDevices { + adapter: IDXGIAdapter1, + dxgi_factory: IDXGIFactory6, + pub(crate) device: ID3D11Device, + pub(crate) device_context: ID3D11DeviceContext, + dxgi_device: Option<IDXGIDevice>, +} + +struct DirectXResources { + // Direct3D rendering objects + swap_chain: IDXGISwapChain1, + render_target: ManuallyDrop<ID3D11Texture2D>, + render_target_view: [Option<ID3D11RenderTargetView>; 1], + + // Path intermediate textures (with MSAA) + path_intermediate_texture: ID3D11Texture2D, + path_intermediate_srv: [Option<ID3D11ShaderResourceView>; 1], + path_intermediate_msaa_texture: ID3D11Texture2D, + path_intermediate_msaa_view: [Option<ID3D11RenderTargetView>; 1], + + // Cached window size and viewport + width: u32, + height: u32, + viewport: [D3D11_VIEWPORT; 1], +} + +struct DirectXRenderPipelines { + shadow_pipeline: PipelineState<Shadow>, + quad_pipeline: PipelineState<Quad>, + path_rasterization_pipeline: PipelineState<PathRasterizationSprite>, + path_sprite_pipeline: PipelineState<PathSprite>, + underline_pipeline: PipelineState<Underline>, + mono_sprites: PipelineState<MonochromeSprite>, + poly_sprites: PipelineState<PolychromeSprite>, +} + +struct DirectXGlobalElements { + global_params_buffer: [Option<ID3D11Buffer>; 1], + sampler: [Option<ID3D11SamplerState>; 1], +} + +struct DirectComposition { + comp_device: IDCompositionDevice, + comp_target: IDCompositionTarget, + comp_visual: IDCompositionVisual, +} + +impl DirectXDevices { + pub(crate) fn new(disable_direct_composition: bool) -> Result<ManuallyDrop<Self>> { + let debug_layer_available = check_debug_layer_available(); + let dxgi_factory = + get_dxgi_factory(debug_layer_available).context("Creating DXGI factory")?; + let adapter = + get_adapter(&dxgi_factory, debug_layer_available).context("Getting DXGI adapter")?; + let (device, device_context) = { + let mut device: Option<ID3D11Device> = 
None; + let mut context: Option<ID3D11DeviceContext> = None; + let mut feature_level = D3D_FEATURE_LEVEL::default(); + get_device( + &adapter, + Some(&mut device), + Some(&mut context), + Some(&mut feature_level), + debug_layer_available, + ) + .context("Creating Direct3D device")?; + match feature_level { + D3D_FEATURE_LEVEL_11_1 => { + log::info!("Created device with Direct3D 11.1 feature level.") + } + D3D_FEATURE_LEVEL_11_0 => { + log::info!("Created device with Direct3D 11.0 feature level.") + } + D3D_FEATURE_LEVEL_10_1 => { + log::info!("Created device with Direct3D 10.1 feature level.") + } + _ => unreachable!(), + } + (device.unwrap(), context.unwrap()) + }; + let dxgi_device = if disable_direct_composition { + None + } else { + Some(device.cast().context("Creating DXGI device")?) + }; + + Ok(ManuallyDrop::new(Self { + adapter, + dxgi_factory, + dxgi_device, + device, + device_context, + })) + } +} + +impl DirectXRenderer { + pub(crate) fn new(hwnd: HWND, disable_direct_composition: bool) -> Result<Self> { + if disable_direct_composition { + log::info!("Direct Composition is disabled."); + } + + let devices = + DirectXDevices::new(disable_direct_composition).context("Creating DirectX devices")?; + let atlas = Arc::new(DirectXAtlas::new(&devices.device, &devices.device_context)); + + let resources = DirectXResources::new(&devices, 1, 1, hwnd, disable_direct_composition) + .context("Creating DirectX resources")?; + let globals = DirectXGlobalElements::new(&devices.device) + .context("Creating DirectX global elements")?; + let pipelines = DirectXRenderPipelines::new(&devices.device) + .context("Creating DirectX render pipelines")?; + + let direct_composition = if disable_direct_composition { + None + } else { + let composition = DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), hwnd) + .context("Creating DirectComposition")?; + composition + .set_swap_chain(&resources.swap_chain) + .context("Setting swap chain for DirectComposition")?; + Some(composition) + }; + + Ok(DirectXRenderer { + hwnd, + atlas, + devices, + resources, + globals, + pipelines, + direct_composition, + }) + } + + pub(crate) fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> { + self.atlas.clone() + } + + fn pre_draw(&self) -> Result<()> { + update_buffer( + &self.devices.device_context, + self.globals.global_params_buffer[0].as_ref().unwrap(), + &[GlobalParams { + viewport_size: [ + self.resources.viewport[0].Width, + self.resources.viewport[0].Height, + ], + _pad: 0, + }], + )?; + unsafe { + self.devices.device_context.ClearRenderTargetView( + self.resources.render_target_view[0].as_ref().unwrap(), + &[0.0; 4], + ); + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + self.devices + .device_context + .RSSetViewports(Some(&self.resources.viewport)); + } + Ok(()) + } + + fn present(&mut self) -> Result<()> { + unsafe { + let result = self.resources.swap_chain.Present(1, DXGI_PRESENT(0)); + // Presenting the swap chain can fail if the DirectX device was removed or reset. + if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET { + let reason = self.devices.device.GetDeviceRemovedReason(); + log::error!( + "DirectX device removed or reset when drawing. Reason: {:?}", + reason + ); + self.handle_device_lost()?; + } else { + result.ok()?; + } + } + Ok(()) + } + + fn handle_device_lost(&mut self) -> Result<()> { + // Here we wait a bit to ensure the the system has time to recover from the device lost state. 
+ // If we don't wait, the final drawing result will be blank. + std::thread::sleep(std::time::Duration::from_millis(300)); + let disable_direct_composition = self.direct_composition.is_none(); + + unsafe { + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device) + .context("Failed to report live objects after device lost") + .log_err(); + + ManuallyDrop::drop(&mut self.resources); + self.devices.device_context.OMSetRenderTargets(None, None); + self.devices.device_context.ClearState(); + self.devices.device_context.Flush(); + + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device) + .context("Failed to report live objects after device lost") + .log_err(); + + drop(self.direct_composition.take()); + ManuallyDrop::drop(&mut self.devices); + } + + let devices = DirectXDevices::new(disable_direct_composition) + .context("Recreating DirectX devices")?; + let resources = DirectXResources::new( + &devices, + self.resources.width, + self.resources.height, + self.hwnd, + disable_direct_composition, + )?; + let globals = DirectXGlobalElements::new(&devices.device)?; + let pipelines = DirectXRenderPipelines::new(&devices.device)?; + + let direct_composition = if disable_direct_composition { + None + } else { + let composition = + DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), self.hwnd)?; + composition.set_swap_chain(&resources.swap_chain)?; + Some(composition) + }; + + self.atlas + .handle_device_lost(&devices.device, &devices.device_context); + self.devices = devices; + self.resources = resources; + self.globals = globals; + self.pipelines = pipelines; + self.direct_composition = direct_composition; + + unsafe { + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + Ok(()) + } + + pub(crate) fn draw(&mut self, scene: &Scene) -> Result<()> { + self.pre_draw()?; + for batch in scene.batches() { + match batch { + PrimitiveBatch::Shadows(shadows) => self.draw_shadows(shadows), + PrimitiveBatch::Quads(quads) => self.draw_quads(quads), + PrimitiveBatch::Paths(paths) => { + self.draw_paths_to_intermediate(paths)?; + self.draw_paths_from_intermediate(paths) + } + PrimitiveBatch::Underlines(underlines) => self.draw_underlines(underlines), + PrimitiveBatch::MonochromeSprites { + texture_id, + sprites, + } => self.draw_monochrome_sprites(texture_id, sprites), + PrimitiveBatch::PolychromeSprites { + texture_id, + sprites, + } => self.draw_polychrome_sprites(texture_id, sprites), + PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces(surfaces), + }.context(format!("scene too large: {} paths, {} shadows, {} quads, {} underlines, {} mono, {} poly, {} surfaces", + scene.paths.len(), + scene.shadows.len(), + scene.quads.len(), + scene.underlines.len(), + scene.monochrome_sprites.len(), + scene.polychrome_sprites.len(), + scene.surfaces.len(),))?; + } + self.present() + } + + pub(crate) fn resize(&mut self, new_size: Size<DevicePixels>) -> Result<()> { + let width = new_size.width.0.max(1) as u32; + let height = new_size.height.0.max(1) as u32; + if self.resources.width == width && self.resources.height == height { + return Ok(()); + } + unsafe { + // Clear the render target before resizing + self.devices.device_context.OMSetRenderTargets(None, None); + ManuallyDrop::drop(&mut self.resources.render_target); + drop(self.resources.render_target_view[0].take().unwrap()); + + let result = self.resources.swap_chain.ResizeBuffers( + BUFFER_COUNT as u32, + width, + height, + RENDER_TARGET_FORMAT, + 
DXGI_SWAP_CHAIN_FLAG(0), + ); + // Resizing the swap chain requires a call to the underlying DXGI adapter, which can return the device removed error. + // The app might have moved to a monitor that's attached to a different graphics device. + // When a graphics device is removed or reset, the desktop resolution often changes, resulting in a window size change. + match result { + Ok(_) => {} + Err(e) => { + if e.code() == DXGI_ERROR_DEVICE_REMOVED || e.code() == DXGI_ERROR_DEVICE_RESET + { + let reason = self.devices.device.GetDeviceRemovedReason(); + log::error!( + "DirectX device removed or reset when resizing. Reason: {:?}", + reason + ); + self.resources.width = width; + self.resources.height = height; + self.handle_device_lost()?; + return Ok(()); + } else { + log::error!("Failed to resize swap chain: {:?}", e); + return Err(e.into()); + } + } + } + + self.resources + .recreate_resources(&self.devices, width, height)?; + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + Ok(()) + } + + fn draw_shadows(&mut self, shadows: &[Shadow]) -> Result<()> { + if shadows.is_empty() { + return Ok(()); + } + self.pipelines.shadow_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + shadows, + )?; + self.pipelines.shadow_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + shadows.len() as u32, + ) + } + + fn draw_quads(&mut self, quads: &[Quad]) -> Result<()> { + if quads.is_empty() { + return Ok(()); + } + self.pipelines.quad_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + quads, + )?; + self.pipelines.quad_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + quads.len() as u32, + ) + } + + fn draw_paths_to_intermediate(&mut self, paths: &[Path<ScaledPixels>]) -> Result<()> { + if paths.is_empty() { + return Ok(()); + } + + // Clear intermediate MSAA texture + unsafe { + self.devices.device_context.ClearRenderTargetView( + self.resources.path_intermediate_msaa_view[0] + .as_ref() + .unwrap(), + &[0.0; 4], + ); + // Set intermediate MSAA texture as render target + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.path_intermediate_msaa_view), None); + } + + // Collect all vertices and sprites for a single draw call + let mut vertices = Vec::new(); + + for path in paths { + vertices.extend(path.vertices.iter().map(|v| PathRasterizationSprite { + xy_position: v.xy_position, + st_position: v.st_position, + color: path.color, + bounds: path.clipped_bounds(), + })); + } + + self.pipelines.path_rasterization_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + &vertices, + )?; + self.pipelines.path_rasterization_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST, + vertices.len() as u32, + 1, + )?; + + // Resolve MSAA to non-MSAA intermediate texture + unsafe { + self.devices.device_context.ResolveSubresource( + &self.resources.path_intermediate_texture, + 0, + &self.resources.path_intermediate_msaa_texture, + 0, + RENDER_TARGET_FORMAT, + ); + // Restore main render target + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + + Ok(()) + } + + fn 
draw_paths_from_intermediate(&mut self, paths: &[Path<ScaledPixels>]) -> Result<()> { + let Some(first_path) = paths.first() else { + return Ok(()); + }; + + // When copying paths from the intermediate texture to the drawable, + // each pixel must only be copied once, in case of transparent paths. + // + // If all paths have the same draw order, then their bounds are all + // disjoint, so we can copy each path's bounds individually. If this + // batch combines different draw orders, we perform a single copy + // for a minimal spanning rect. + let sprites = if paths.last().unwrap().order == first_path.order { + paths + .iter() + .map(|path| PathSprite { + bounds: path.clipped_bounds(), + }) + .collect::<Vec<_>>() + } else { + let mut bounds = first_path.clipped_bounds(); + for path in paths.iter().skip(1) { + bounds = bounds.union(&path.clipped_bounds()); + } + vec![PathSprite { bounds }] + }; + + self.pipelines.path_sprite_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + &sprites, + )?; + + // Draw the sprites with the path texture + self.pipelines.path_sprite_pipeline.draw_with_texture( + &self.devices.device_context, + &self.resources.path_intermediate_srv, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_underlines(&mut self, underlines: &[Underline]) -> Result<()> { + if underlines.is_empty() { + return Ok(()); + } + self.pipelines.underline_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + underlines, + )?; + self.pipelines.underline_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + underlines.len() as u32, + ) + } + + fn draw_monochrome_sprites( + &mut self, + texture_id: AtlasTextureId, + sprites: &[MonochromeSprite], + ) -> Result<()> { + if sprites.is_empty() { + return Ok(()); + } + self.pipelines.mono_sprites.update_buffer( + &self.devices.device, + &self.devices.device_context, + sprites, + )?; + let texture_view = self.atlas.get_texture_view(texture_id); + self.pipelines.mono_sprites.draw_with_texture( + &self.devices.device_context, + &texture_view, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_polychrome_sprites( + &mut self, + texture_id: AtlasTextureId, + sprites: &[PolychromeSprite], + ) -> Result<()> { + if sprites.is_empty() { + return Ok(()); + } + self.pipelines.poly_sprites.update_buffer( + &self.devices.device, + &self.devices.device_context, + sprites, + )?; + let texture_view = self.atlas.get_texture_view(texture_id); + self.pipelines.poly_sprites.draw_with_texture( + &self.devices.device_context, + &texture_view, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_surfaces(&mut self, surfaces: &[PaintSurface]) -> Result<()> { + if surfaces.is_empty() { + return Ok(()); + } + Ok(()) + } + + pub(crate) fn gpu_specs(&self) -> Result<GpuSpecs> { + let desc = unsafe { self.devices.adapter.GetDesc1() }?; + let is_software_emulated = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE.0 as u32) != 0; + let device_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + let driver_name = match desc.VendorId { + 0x10DE => "NVIDIA Corporation".to_string(), + 0x1002 => "AMD Corporation".to_string(), + 0x8086 => "Intel 
Corporation".to_string(), + id => format!("Unknown Vendor (ID: {:#X})", id), + }; + let driver_version = match desc.VendorId { + 0x10DE => nvidia::get_driver_version(), + 0x1002 => amd::get_driver_version(), + // For Intel and other vendors, we use the DXGI API to get the driver version. + _ => dxgi::get_driver_version(&self.devices.adapter), + } + .context("Failed to get gpu driver info") + .log_err() + .unwrap_or("Unknown Driver".to_string()); + Ok(GpuSpecs { + is_software_emulated, + device_name, + driver_name, + driver_info: driver_version, + }) + } +} + +impl DirectXResources { + pub fn new( + devices: &DirectXDevices, + width: u32, + height: u32, + hwnd: HWND, + disable_direct_composition: bool, + ) -> Result<ManuallyDrop<Self>> { + let swap_chain = if disable_direct_composition { + create_swap_chain(&devices.dxgi_factory, &devices.device, hwnd, width, height)? + } else { + create_swap_chain_for_composition( + &devices.dxgi_factory, + &devices.device, + width, + height, + )? + }; + + let ( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + ) = create_resources(devices, &swap_chain, width, height)?; + set_rasterizer_state(&devices.device, &devices.device_context)?; + + Ok(ManuallyDrop::new(Self { + swap_chain, + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + path_intermediate_srv, + viewport, + width, + height, + })) + } + + #[inline] + fn recreate_resources( + &mut self, + devices: &DirectXDevices, + width: u32, + height: u32, + ) -> Result<()> { + let ( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + ) = create_resources(devices, &self.swap_chain, width, height)?; + self.render_target = render_target; + self.render_target_view = render_target_view; + self.path_intermediate_texture = path_intermediate_texture; + self.path_intermediate_msaa_texture = path_intermediate_msaa_texture; + self.path_intermediate_msaa_view = path_intermediate_msaa_view; + self.path_intermediate_srv = path_intermediate_srv; + self.viewport = viewport; + self.width = width; + self.height = height; + Ok(()) + } +} + +impl DirectXRenderPipelines { + pub fn new(device: &ID3D11Device) -> Result<Self> { + let shadow_pipeline = PipelineState::new( + device, + "shadow_pipeline", + ShaderModule::Shadow, + 4, + create_blend_state(device)?, + )?; + let quad_pipeline = PipelineState::new( + device, + "quad_pipeline", + ShaderModule::Quad, + 64, + create_blend_state(device)?, + )?; + let path_rasterization_pipeline = PipelineState::new( + device, + "path_rasterization_pipeline", + ShaderModule::PathRasterization, + 32, + create_blend_state_for_path_rasterization(device)?, + )?; + let path_sprite_pipeline = PipelineState::new( + device, + "path_sprite_pipeline", + ShaderModule::PathSprite, + 4, + create_blend_state_for_path_sprite(device)?, + )?; + let underline_pipeline = PipelineState::new( + device, + "underline_pipeline", + ShaderModule::Underline, + 4, + create_blend_state(device)?, + )?; + let mono_sprites = PipelineState::new( + device, + "monochrome_sprite_pipeline", + ShaderModule::MonochromeSprite, + 512, + create_blend_state(device)?, + )?; + let poly_sprites = PipelineState::new( + device, + "polychrome_sprite_pipeline", + ShaderModule::PolychromeSprite, + 16, + 
create_blend_state(device)?, + )?; + + Ok(Self { + shadow_pipeline, + quad_pipeline, + path_rasterization_pipeline, + path_sprite_pipeline, + underline_pipeline, + mono_sprites, + poly_sprites, + }) + } +} + +impl DirectComposition { + pub fn new(dxgi_device: &IDXGIDevice, hwnd: HWND) -> Result<Self> { + let comp_device = get_comp_device(&dxgi_device)?; + let comp_target = unsafe { comp_device.CreateTargetForHwnd(hwnd, true) }?; + let comp_visual = unsafe { comp_device.CreateVisual() }?; + + Ok(Self { + comp_device, + comp_target, + comp_visual, + }) + } + + pub fn set_swap_chain(&self, swap_chain: &IDXGISwapChain1) -> Result<()> { + unsafe { + self.comp_visual.SetContent(swap_chain)?; + self.comp_target.SetRoot(&self.comp_visual)?; + self.comp_device.Commit()?; + } + Ok(()) + } +} + +impl DirectXGlobalElements { + pub fn new(device: &ID3D11Device) -> Result<Self> { + let global_params_buffer = unsafe { + let desc = D3D11_BUFFER_DESC { + ByteWidth: std::mem::size_of::<GlobalParams>() as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + ..Default::default() + }; + let mut buffer = None; + device.CreateBuffer(&desc, None, Some(&mut buffer))?; + [buffer] + }; + + let sampler = unsafe { + let desc = D3D11_SAMPLER_DESC { + Filter: D3D11_FILTER_MIN_MAG_MIP_LINEAR, + AddressU: D3D11_TEXTURE_ADDRESS_WRAP, + AddressV: D3D11_TEXTURE_ADDRESS_WRAP, + AddressW: D3D11_TEXTURE_ADDRESS_WRAP, + MipLODBias: 0.0, + MaxAnisotropy: 1, + ComparisonFunc: D3D11_COMPARISON_ALWAYS, + BorderColor: [0.0; 4], + MinLOD: 0.0, + MaxLOD: D3D11_FLOAT32_MAX, + }; + let mut output = None; + device.CreateSamplerState(&desc, Some(&mut output))?; + [output] + }; + + Ok(Self { + global_params_buffer, + sampler, + }) + } +} + +#[derive(Debug, Default)] +#[repr(C)] +struct GlobalParams { + viewport_size: [f32; 2], + _pad: u64, +} + +struct PipelineState<T> { + label: &'static str, + vertex: ID3D11VertexShader, + fragment: ID3D11PixelShader, + buffer: ID3D11Buffer, + buffer_size: usize, + view: [Option<ID3D11ShaderResourceView>; 1], + blend_state: ID3D11BlendState, + _marker: std::marker::PhantomData<T>, +} + +impl<T> PipelineState<T> { + fn new( + device: &ID3D11Device, + label: &'static str, + shader_module: ShaderModule, + buffer_size: usize, + blend_state: ID3D11BlendState, + ) -> Result<Self> { + let vertex = { + let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Vertex)?; + create_vertex_shader(device, raw_shader.as_bytes())? + }; + let fragment = { + let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Fragment)?; + create_fragment_shader(device, raw_shader.as_bytes())? 
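+            // Note: in debug builds, RawShaderBytes compiles the HLSL source at runtime with
+            // D3DCompileFromFile; release builds embed the precompiled bytecode from OUT_DIR
+            // (see the shader_resources module below).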
+ }; + let buffer = create_buffer(device, std::mem::size_of::<T>(), buffer_size)?; + let view = create_buffer_view(device, &buffer)?; + + Ok(PipelineState { + label, + vertex, + fragment, + buffer, + buffer_size, + view, + blend_state, + _marker: std::marker::PhantomData, + }) + } + + fn update_buffer( + &mut self, + device: &ID3D11Device, + device_context: &ID3D11DeviceContext, + data: &[T], + ) -> Result<()> { + if self.buffer_size < data.len() { + let new_buffer_size = data.len().next_power_of_two(); + log::info!( + "Updating {} buffer size from {} to {}", + self.label, + self.buffer_size, + new_buffer_size + ); + let buffer = create_buffer(device, std::mem::size_of::<T>(), new_buffer_size)?; + let view = create_buffer_view(device, &buffer)?; + self.buffer = buffer; + self.view = view; + self.buffer_size = new_buffer_size; + } + update_buffer(device_context, &self.buffer, data) + } + + fn draw( + &self, + device_context: &ID3D11DeviceContext, + viewport: &[D3D11_VIEWPORT], + global_params: &[Option<ID3D11Buffer>], + topology: D3D_PRIMITIVE_TOPOLOGY, + vertex_count: u32, + instance_count: u32, + ) -> Result<()> { + set_pipeline_state( + device_context, + &self.view, + topology, + viewport, + &self.vertex, + &self.fragment, + global_params, + &self.blend_state, + ); + unsafe { + device_context.DrawInstanced(vertex_count, instance_count, 0, 0); + } + Ok(()) + } + + fn draw_with_texture( + &self, + device_context: &ID3D11DeviceContext, + texture: &[Option<ID3D11ShaderResourceView>], + viewport: &[D3D11_VIEWPORT], + global_params: &[Option<ID3D11Buffer>], + sampler: &[Option<ID3D11SamplerState>], + instance_count: u32, + ) -> Result<()> { + set_pipeline_state( + device_context, + &self.view, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + viewport, + &self.vertex, + &self.fragment, + global_params, + &self.blend_state, + ); + unsafe { + device_context.PSSetSamplers(0, Some(sampler)); + device_context.VSSetShaderResources(0, Some(texture)); + device_context.PSSetShaderResources(0, Some(texture)); + + device_context.DrawInstanced(4, instance_count, 0, 0); + } + Ok(()) + } +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct PathRasterizationSprite { + xy_position: Point<ScaledPixels>, + st_position: Point<f32>, + color: Background, + bounds: Bounds<ScaledPixels>, +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct PathSprite { + bounds: Bounds<ScaledPixels>, +} + +impl Drop for DirectXRenderer { + fn drop(&mut self) { + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device).ok(); + unsafe { + ManuallyDrop::drop(&mut self.devices); + ManuallyDrop::drop(&mut self.resources); + } + } +} + +impl Drop for DirectXResources { + fn drop(&mut self) { + unsafe { + ManuallyDrop::drop(&mut self.render_target); + } + } +} + +#[inline] +fn check_debug_layer_available() -> bool { + #[cfg(debug_assertions)] + { + unsafe { DXGIGetDebugInterface1::<IDXGIInfoQueue>(0) } + .log_err() + .is_some() + } + #[cfg(not(debug_assertions))] + { + false + } +} + +#[inline] +fn get_dxgi_factory(debug_layer_available: bool) -> Result<IDXGIFactory6> { + let factory_flag = if debug_layer_available { + DXGI_CREATE_FACTORY_DEBUG + } else { + #[cfg(debug_assertions)] + log::warn!( + "Failed to get DXGI debug interface. DirectX debugging features will be disabled." + ); + DXGI_CREATE_FACTORY_FLAGS::default() + }; + unsafe { Ok(CreateDXGIFactory2(factory_flag)?) } +} + +fn get_adapter(dxgi_factory: &IDXGIFactory6, debug_layer_available: bool) -> Result<IDXGIAdapter1> { + for adapter_index in 0.. 
{ + let adapter: IDXGIAdapter1 = unsafe { + dxgi_factory + .EnumAdapterByGpuPreference(adapter_index, DXGI_GPU_PREFERENCE_MINIMUM_POWER) + }?; + if let Ok(desc) = unsafe { adapter.GetDesc1() } { + let gpu_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + log::info!("Using GPU: {}", gpu_name); + } + // Check to see whether the adapter supports Direct3D 11, but don't + // create the actual device yet. + if get_device(&adapter, None, None, None, debug_layer_available) + .log_err() + .is_some() + { + return Ok(adapter); + } + } + + unreachable!() +} + +fn get_device( + adapter: &IDXGIAdapter1, + device: Option<*mut Option<ID3D11Device>>, + context: Option<*mut Option<ID3D11DeviceContext>>, + feature_level: Option<*mut D3D_FEATURE_LEVEL>, + debug_layer_available: bool, +) -> Result<()> { + let device_flags = if debug_layer_available { + D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG + } else { + D3D11_CREATE_DEVICE_BGRA_SUPPORT + }; + unsafe { + D3D11CreateDevice( + adapter, + D3D_DRIVER_TYPE_UNKNOWN, + HMODULE::default(), + device_flags, + // 4x MSAA is required for Direct3D Feature Level 10.1 or better + Some(&[ + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_10_1, + ]), + D3D11_SDK_VERSION, + device, + feature_level, + context, + )?; + } + Ok(()) +} + +#[inline] +fn get_comp_device(dxgi_device: &IDXGIDevice) -> Result<IDCompositionDevice> { + Ok(unsafe { DCompositionCreateDevice(dxgi_device)? }) +} + +fn create_swap_chain_for_composition( + dxgi_factory: &IDXGIFactory6, + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<IDXGISwapChain1> { + let desc = DXGI_SWAP_CHAIN_DESC1 { + Width: width, + Height: height, + Format: RENDER_TARGET_FORMAT, + Stereo: false.into(), + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, + BufferCount: BUFFER_COUNT as u32, + // Composition SwapChains only support the DXGI_SCALING_STRETCH Scaling. + Scaling: DXGI_SCALING_STRETCH, + SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, + AlphaMode: DXGI_ALPHA_MODE_PREMULTIPLIED, + Flags: 0, + }; + Ok(unsafe { dxgi_factory.CreateSwapChainForComposition(device, &desc, None)? 
}) +} + +fn create_swap_chain( + dxgi_factory: &IDXGIFactory6, + device: &ID3D11Device, + hwnd: HWND, + width: u32, + height: u32, +) -> Result<IDXGISwapChain1> { + use windows::Win32::Graphics::Dxgi::DXGI_MWA_NO_ALT_ENTER; + + let desc = DXGI_SWAP_CHAIN_DESC1 { + Width: width, + Height: height, + Format: RENDER_TARGET_FORMAT, + Stereo: false.into(), + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, + BufferCount: BUFFER_COUNT as u32, + Scaling: DXGI_SCALING_NONE, + SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, + AlphaMode: DXGI_ALPHA_MODE_IGNORE, + Flags: 0, + }; + let swap_chain = + unsafe { dxgi_factory.CreateSwapChainForHwnd(device, hwnd, &desc, None, None) }?; + unsafe { dxgi_factory.MakeWindowAssociation(hwnd, DXGI_MWA_NO_ALT_ENTER) }?; + Ok(swap_chain) +} + +#[inline] +fn create_resources( + devices: &DirectXDevices, + swap_chain: &IDXGISwapChain1, + width: u32, + height: u32, +) -> Result<( + ManuallyDrop<ID3D11Texture2D>, + [Option<ID3D11RenderTargetView>; 1], + ID3D11Texture2D, + [Option<ID3D11ShaderResourceView>; 1], + ID3D11Texture2D, + [Option<ID3D11RenderTargetView>; 1], + [D3D11_VIEWPORT; 1], +)> { + let (render_target, render_target_view) = + create_render_target_and_its_view(&swap_chain, &devices.device)?; + let (path_intermediate_texture, path_intermediate_srv) = + create_path_intermediate_texture(&devices.device, width, height)?; + let (path_intermediate_msaa_texture, path_intermediate_msaa_view) = + create_path_intermediate_msaa_texture_and_view(&devices.device, width, height)?; + let viewport = set_viewport(&devices.device_context, width as f32, height as f32); + Ok(( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + )) +} + +#[inline] +fn create_render_target_and_its_view( + swap_chain: &IDXGISwapChain1, + device: &ID3D11Device, +) -> Result<( + ManuallyDrop<ID3D11Texture2D>, + [Option<ID3D11RenderTargetView>; 1], +)> { + let render_target: ID3D11Texture2D = unsafe { swap_chain.GetBuffer(0) }?; + let mut render_target_view = None; + unsafe { device.CreateRenderTargetView(&render_target, None, Some(&mut render_target_view))? }; + Ok(( + ManuallyDrop::new(render_target), + [Some(render_target_view.unwrap())], + )) +} + +#[inline] +fn create_path_intermediate_texture( + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<(ID3D11Texture2D, [Option<ID3D11ShaderResourceView>; 1])> { + let texture = unsafe { + let mut output = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: width, + Height: height, + MipLevels: 1, + ArraySize: 1, + Format: RENDER_TARGET_FORMAT, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: (D3D11_BIND_RENDER_TARGET.0 | D3D11_BIND_SHADER_RESOURCE.0) as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + device.CreateTexture2D(&desc, None, Some(&mut output))?; + output.unwrap() + }; + + let mut shader_resource_view = None; + unsafe { device.CreateShaderResourceView(&texture, None, Some(&mut shader_resource_view))? 
}; + + Ok((texture, [Some(shader_resource_view.unwrap())])) +} + +#[inline] +fn create_path_intermediate_msaa_texture_and_view( + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<(ID3D11Texture2D, [Option<ID3D11RenderTargetView>; 1])> { + let msaa_texture = unsafe { + let mut output = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: width, + Height: height, + MipLevels: 1, + ArraySize: 1, + Format: RENDER_TARGET_FORMAT, + SampleDesc: DXGI_SAMPLE_DESC { + Count: PATH_MULTISAMPLE_COUNT, + Quality: D3D11_STANDARD_MULTISAMPLE_PATTERN.0 as u32, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + device.CreateTexture2D(&desc, None, Some(&mut output))?; + output.unwrap() + }; + let mut msaa_view = None; + unsafe { device.CreateRenderTargetView(&msaa_texture, None, Some(&mut msaa_view))? }; + Ok((msaa_texture, [Some(msaa_view.unwrap())])) +} + +#[inline] +fn set_viewport( + device_context: &ID3D11DeviceContext, + width: f32, + height: f32, +) -> [D3D11_VIEWPORT; 1] { + let viewport = [D3D11_VIEWPORT { + TopLeftX: 0.0, + TopLeftY: 0.0, + Width: width, + Height: height, + MinDepth: 0.0, + MaxDepth: 1.0, + }]; + unsafe { device_context.RSSetViewports(Some(&viewport)) }; + viewport +} + +#[inline] +fn set_rasterizer_state(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Result<()> { + let desc = D3D11_RASTERIZER_DESC { + FillMode: D3D11_FILL_SOLID, + CullMode: D3D11_CULL_NONE, + FrontCounterClockwise: false.into(), + DepthBias: 0, + DepthBiasClamp: 0.0, + SlopeScaledDepthBias: 0.0, + DepthClipEnable: true.into(), + ScissorEnable: false.into(), + MultisampleEnable: true.into(), + AntialiasedLineEnable: false.into(), + }; + let rasterizer_state = unsafe { + let mut state = None; + device.CreateRasterizerState(&desc, Some(&mut state))?; + state.unwrap() + }; + unsafe { device_context.RSSetState(&rasterizer_state) }; + Ok(()) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_blend_desc +#[inline] +fn create_blend_state(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_SRC_ALPHA; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_blend_state_for_path_rasterization(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. 
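+    // Unlike create_blend_state above, both the color and alpha channels below use
+    // ONE / INV_SRC_ALPHA, i.e. premultiplied-alpha "over" blending (presumably so path
+    // coverage accumulates correctly in the intermediate texture).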
+ let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_blend_state_for_path_sprite(device: &ID3D11Device) -> Result<ID3D11BlendState> { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_vertex_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11VertexShader> { + unsafe { + let mut shader = None; + device.CreateVertexShader(bytes, None, Some(&mut shader))?; + Ok(shader.unwrap()) + } +} + +#[inline] +fn create_fragment_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11PixelShader> { + unsafe { + let mut shader = None; + device.CreatePixelShader(bytes, None, Some(&mut shader))?; + Ok(shader.unwrap()) + } +} + +#[inline] +fn create_buffer( + device: &ID3D11Device, + element_size: usize, + buffer_size: usize, +) -> Result<ID3D11Buffer> { + let desc = D3D11_BUFFER_DESC { + ByteWidth: (element_size * buffer_size) as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: D3D11_RESOURCE_MISC_BUFFER_STRUCTURED.0 as u32, + StructureByteStride: element_size as u32, + }; + let mut buffer = None; + unsafe { device.CreateBuffer(&desc, None, Some(&mut buffer)) }?; + Ok(buffer.unwrap()) +} + +#[inline] +fn create_buffer_view( + device: &ID3D11Device, + buffer: &ID3D11Buffer, +) -> Result<[Option<ID3D11ShaderResourceView>; 1]> { + let mut view = None; + unsafe { device.CreateShaderResourceView(buffer, None, Some(&mut view)) }?; + Ok([view]) +} + +#[inline] +fn update_buffer<T>( + device_context: &ID3D11DeviceContext, + buffer: &ID3D11Buffer, + data: &[T], +) -> Result<()> { + unsafe { + let mut dest = std::mem::zeroed(); + device_context.Map(buffer, 0, D3D11_MAP_WRITE_DISCARD, 0, Some(&mut dest))?; + std::ptr::copy_nonoverlapping(data.as_ptr(), dest.pData as _, data.len()); + device_context.Unmap(buffer, 0); + } + Ok(()) +} + +#[inline] +fn set_pipeline_state( + device_context: &ID3D11DeviceContext, + buffer_view: &[Option<ID3D11ShaderResourceView>], + topology: D3D_PRIMITIVE_TOPOLOGY, + viewport: &[D3D11_VIEWPORT], + vertex_shader: &ID3D11VertexShader, + fragment_shader: &ID3D11PixelShader, + global_params: &[Option<ID3D11Buffer>], + blend_state: 
&ID3D11BlendState, +) { + unsafe { + device_context.VSSetShaderResources(1, Some(buffer_view)); + device_context.PSSetShaderResources(1, Some(buffer_view)); + device_context.IASetPrimitiveTopology(topology); + device_context.RSSetViewports(Some(viewport)); + device_context.VSSetShader(vertex_shader, None); + device_context.PSSetShader(fragment_shader, None); + device_context.VSSetConstantBuffers(0, Some(global_params)); + device_context.PSSetConstantBuffers(0, Some(global_params)); + device_context.OMSetBlendState(blend_state, None, 0xFFFFFFFF); + } +} + +#[cfg(debug_assertions)] +fn report_live_objects(device: &ID3D11Device) -> Result<()> { + let debug_device: ID3D11Debug = device.cast()?; + unsafe { + debug_device.ReportLiveDeviceObjects(D3D11_RLDO_DETAIL)?; + } + Ok(()) +} + +const BUFFER_COUNT: usize = 3; + +pub(crate) mod shader_resources { + use anyhow::Result; + + #[cfg(debug_assertions)] + use windows::{ + Win32::Graphics::Direct3D::{ + Fxc::{D3DCOMPILE_DEBUG, D3DCOMPILE_SKIP_OPTIMIZATION, D3DCompileFromFile}, + ID3DBlob, + }, + core::{HSTRING, PCSTR}, + }; + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(crate) enum ShaderModule { + Quad, + Shadow, + Underline, + PathRasterization, + PathSprite, + MonochromeSprite, + PolychromeSprite, + EmojiRasterization, + } + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(crate) enum ShaderTarget { + Vertex, + Fragment, + } + + pub(crate) struct RawShaderBytes<'t> { + inner: &'t [u8], + + #[cfg(debug_assertions)] + _blob: ID3DBlob, + } + + impl<'t> RawShaderBytes<'t> { + pub(crate) fn new(module: ShaderModule, target: ShaderTarget) -> Result<Self> { + #[cfg(not(debug_assertions))] + { + Ok(Self::from_bytes(module, target)) + } + #[cfg(debug_assertions)] + { + let blob = build_shader_blob(module, target)?; + let inner = unsafe { + std::slice::from_raw_parts( + blob.GetBufferPointer() as *const u8, + blob.GetBufferSize(), + ) + }; + Ok(Self { inner, _blob: blob }) + } + } + + pub(crate) fn as_bytes(&'t self) -> &'t [u8] { + self.inner + } + + #[cfg(not(debug_assertions))] + fn from_bytes(module: ShaderModule, target: ShaderTarget) -> Self { + let bytes = match module { + ShaderModule::Quad => match target { + ShaderTarget::Vertex => QUAD_VERTEX_BYTES, + ShaderTarget::Fragment => QUAD_FRAGMENT_BYTES, + }, + ShaderModule::Shadow => match target { + ShaderTarget::Vertex => SHADOW_VERTEX_BYTES, + ShaderTarget::Fragment => SHADOW_FRAGMENT_BYTES, + }, + ShaderModule::Underline => match target { + ShaderTarget::Vertex => UNDERLINE_VERTEX_BYTES, + ShaderTarget::Fragment => UNDERLINE_FRAGMENT_BYTES, + }, + ShaderModule::PathRasterization => match target { + ShaderTarget::Vertex => PATH_RASTERIZATION_VERTEX_BYTES, + ShaderTarget::Fragment => PATH_RASTERIZATION_FRAGMENT_BYTES, + }, + ShaderModule::PathSprite => match target { + ShaderTarget::Vertex => PATH_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => PATH_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::MonochromeSprite => match target { + ShaderTarget::Vertex => MONOCHROME_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => MONOCHROME_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::PolychromeSprite => match target { + ShaderTarget::Vertex => POLYCHROME_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => POLYCHROME_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::EmojiRasterization => match target { + ShaderTarget::Vertex => EMOJI_RASTERIZATION_VERTEX_BYTES, + ShaderTarget::Fragment => EMOJI_RASTERIZATION_FRAGMENT_BYTES, + }, + }; + Self { inner: bytes } + } + } + + #[cfg(debug_assertions)] + pub(super) fn 
build_shader_blob(entry: ShaderModule, target: ShaderTarget) -> Result<ID3DBlob> { + unsafe { + let shader_name = if matches!(entry, ShaderModule::EmojiRasterization) { + "color_text_raster.hlsl" + } else { + "shaders.hlsl" + }; + + let entry = format!( + "{}_{}\0", + entry.as_str(), + match target { + ShaderTarget::Vertex => "vertex", + ShaderTarget::Fragment => "fragment", + } + ); + let target = match target { + ShaderTarget::Vertex => "vs_4_1\0", + ShaderTarget::Fragment => "ps_4_1\0", + }; + + let mut compile_blob = None; + let mut error_blob = None; + let shader_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(&format!("src/platform/windows/{}", shader_name)) + .canonicalize()?; + + let entry_point = PCSTR::from_raw(entry.as_ptr()); + let target_cstr = PCSTR::from_raw(target.as_ptr()); + + let ret = D3DCompileFromFile( + &HSTRING::from(shader_path.to_str().unwrap()), + None, + None, + entry_point, + target_cstr, + D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION, + 0, + &mut compile_blob, + Some(&mut error_blob), + ); + if ret.is_err() { + let Some(error_blob) = error_blob else { + return Err(anyhow::anyhow!("{ret:?}")); + }; + + let error_string = + std::ffi::CStr::from_ptr(error_blob.GetBufferPointer() as *const i8) + .to_string_lossy(); + log::error!("Shader compile error: {}", error_string); + return Err(anyhow::anyhow!("Compile error: {}", error_string)); + } + Ok(compile_blob.unwrap()) + } + } + + #[cfg(not(debug_assertions))] + include!(concat!(env!("OUT_DIR"), "/shaders_bytes.rs")); + + #[cfg(debug_assertions)] + impl ShaderModule { + pub fn as_str(&self) -> &str { + match self { + ShaderModule::Quad => "quad", + ShaderModule::Shadow => "shadow", + ShaderModule::Underline => "underline", + ShaderModule::PathRasterization => "path_rasterization", + ShaderModule::PathSprite => "path_sprite", + ShaderModule::MonochromeSprite => "monochrome_sprite", + ShaderModule::PolychromeSprite => "polychrome_sprite", + ShaderModule::EmojiRasterization => "emoji_rasterization", + } + } + } +} + +mod nvidia { + use std::{ + ffi::CStr, + os::raw::{c_char, c_int, c_uint}, + }; + + use anyhow::{Context, Result}; + use windows::{ + Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, + core::s, + }; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 + const NVAPI_SHORT_STRING_MAX: usize = 64; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L235 + #[allow(non_camel_case_types)] + type NvAPI_ShortString = [c_char; NVAPI_SHORT_STRING_MAX]; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L447 + #[allow(non_camel_case_types)] + type NvAPI_SYS_GetDriverAndBranchVersion_t = unsafe extern "C" fn( + driver_version: *mut c_uint, + build_branch_string: *mut NvAPI_ShortString, + ) -> c_int; + + pub(super) fn get_driver_version() -> Result<String> { + unsafe { + // Try to load the NVIDIA driver DLL + #[cfg(target_pointer_width = "64")] + let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?; + #[cfg(target_pointer_width = "32")] + let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?; + + let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface")) + .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?; + let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr); + + // 
https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_interface.h#L41 + let nvapi_get_driver_version_ptr = nvapi_query(0x2926aaad); + if nvapi_get_driver_version_ptr.is_null() { + anyhow::bail!("Failed to get NVIDIA driver version function pointer"); + } + let nvapi_get_driver_version: NvAPI_SYS_GetDriverAndBranchVersion_t = + std::mem::transmute(nvapi_get_driver_version_ptr); + + let mut driver_version: c_uint = 0; + let mut build_branch_string: NvAPI_ShortString = [0; NVAPI_SHORT_STRING_MAX]; + let result = nvapi_get_driver_version( + &mut driver_version as *mut c_uint, + &mut build_branch_string as *mut NvAPI_ShortString, + ); + + if result != 0 { + anyhow::bail!( + "Failed to get NVIDIA driver version, error code: {}", + result + ); + } + let major = driver_version / 100; + let minor = driver_version % 100; + let branch_string = CStr::from_ptr(build_branch_string.as_ptr()); + Ok(format!( + "{}.{} {}", + major, + minor, + branch_string.to_string_lossy() + )) + } + } +} + +mod amd { + use std::os::raw::{c_char, c_int, c_void}; + + use anyhow::{Context, Result}; + use windows::{ + Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, + core::s, + }; + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 + const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L204 + // This is an opaque type, using struct to represent it properly for FFI + #[repr(C)] + struct AGSContext { + _private: [u8; 0], + } + + #[repr(C)] + pub struct AGSGPUInfo { + pub driver_version: *const c_char, + pub radeon_software_version: *const c_char, + pub num_devices: c_int, + pub devices: *mut c_void, + } + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L429 + #[allow(non_camel_case_types)] + type agsInitialize_t = unsafe extern "C" fn( + version: c_int, + config: *const c_void, + context: *mut *mut AGSContext, + gpu_info: *mut AGSGPUInfo, + ) -> c_int; + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L436 + #[allow(non_camel_case_types)] + type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int; + + pub(super) fn get_driver_version() -> Result<String> { + unsafe { + #[cfg(target_pointer_width = "64")] + let amd_dll = + LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?; + #[cfg(target_pointer_width = "32")] + let amd_dll = + LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?; + + let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize")) + .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?; + let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize")) + .ok_or_else(|| anyhow::anyhow!("Failed to get agsDeInitialize address"))?; + + let ags_initialize: agsInitialize_t = std::mem::transmute(ags_initialize_addr); + let ags_deinitialize: agsDeInitialize_t = std::mem::transmute(ags_deinitialize_addr); + + let mut context: *mut AGSContext = std::ptr::null_mut(); + let mut gpu_info: AGSGPUInfo = AGSGPUInfo { + driver_version: std::ptr::null(), + radeon_software_version: std::ptr::null(), + num_devices: 0, + devices: std::ptr::null_mut(), + }; + + let result = ags_initialize( + AGS_CURRENT_VERSION, + 
std::ptr::null(), + &mut context, + &mut gpu_info, + ); + if result != 0 { + anyhow::bail!("Failed to initialize AMD AGS, error code: {}", result); + } + + // Vulkan actually returns this as the driver version + let software_version = if !gpu_info.radeon_software_version.is_null() { + std::ffi::CStr::from_ptr(gpu_info.radeon_software_version) + .to_string_lossy() + .into_owned() + } else { + "Unknown Radeon Software Version".to_string() + }; + + let driver_version = if !gpu_info.driver_version.is_null() { + std::ffi::CStr::from_ptr(gpu_info.driver_version) + .to_string_lossy() + .into_owned() + } else { + "Unknown Radeon Driver Version".to_string() + }; + + ags_deinitialize(context); + Ok(format!("{} ({})", software_version, driver_version)) + } + } +} + +mod dxgi { + use windows::{ + Win32::Graphics::Dxgi::{IDXGIAdapter1, IDXGIDevice}, + core::Interface, + }; + + pub(super) fn get_driver_version(adapter: &IDXGIAdapter1) -> anyhow::Result<String> { + let number = unsafe { adapter.CheckInterfaceSupport(&IDXGIDevice::IID as _) }?; + Ok(format!( + "{}.{}.{}.{}", + number >> 48, + (number >> 32) & 0xFFFF, + (number >> 16) & 0xFFFF, + number & 0xFFFF + )) + } +} diff --git a/crates/gpui/src/platform/windows/events.rs b/crates/gpui/src/platform/windows/events.rs index 839fd10375..4ab257d27a 100644 --- a/crates/gpui/src/platform/windows/events.rs +++ b/crates/gpui/src/platform/windows/events.rs @@ -23,1027 +23,894 @@ pub(crate) const WM_GPUI_CURSOR_STYLE_CHANGED: u32 = WM_USER + 1; pub(crate) const WM_GPUI_CLOSE_ONE_WINDOW: u32 = WM_USER + 2; pub(crate) const WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD: u32 = WM_USER + 3; pub(crate) const WM_GPUI_DOCK_MENU_ACTION: u32 = WM_USER + 4; +pub(crate) const WM_GPUI_FORCE_UPDATE_WINDOW: u32 = WM_USER + 5; const SIZE_MOVE_LOOP_TIMER_ID: usize = 1; const AUTO_HIDE_TASKBAR_THICKNESS_PX: i32 = 1; -pub(crate) fn handle_msg( - handle: HWND, - msg: u32, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> LRESULT { - let handled = match msg { - WM_ACTIVATE => handle_activate_msg(wparam, state_ptr), - WM_CREATE => handle_create_msg(handle, state_ptr), - WM_MOVE => handle_move_msg(handle, lparam, state_ptr), - WM_SIZE => handle_size_msg(wparam, lparam, state_ptr), - WM_GETMINMAXINFO => handle_get_min_max_info_msg(lparam, state_ptr), - WM_ENTERSIZEMOVE | WM_ENTERMENULOOP => handle_size_move_loop(handle), - WM_EXITSIZEMOVE | WM_EXITMENULOOP => handle_size_move_loop_exit(handle), - WM_TIMER => handle_timer_msg(handle, wparam, state_ptr), - WM_NCCALCSIZE => handle_calc_client_size(handle, wparam, lparam, state_ptr), - WM_DPICHANGED => handle_dpi_changed_msg(handle, wparam, lparam, state_ptr), - WM_DISPLAYCHANGE => handle_display_change_msg(handle, state_ptr), - WM_NCHITTEST => handle_hit_test_msg(handle, msg, wparam, lparam, state_ptr), - WM_PAINT => handle_paint_msg(handle, state_ptr), - WM_CLOSE => handle_close_msg(handle, state_ptr), - WM_DESTROY => handle_destroy_msg(handle, state_ptr), - WM_MOUSEMOVE => handle_mouse_move_msg(handle, lparam, wparam, state_ptr), - WM_MOUSELEAVE | WM_NCMOUSELEAVE => handle_mouse_leave_msg(state_ptr), - WM_NCMOUSEMOVE => handle_nc_mouse_move_msg(handle, lparam, state_ptr), - WM_NCLBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, MouseButton::Left, wparam, lparam, state_ptr) +impl WindowsWindowInner { + pub(crate) fn handle_msg( + self: &Rc<Self>, + handle: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, + ) -> LRESULT { + let handled = match msg { + WM_ACTIVATE => self.handle_activate_msg(wparam), +
WM_CREATE => self.handle_create_msg(handle), + WM_DEVICECHANGE => self.handle_device_change_msg(handle, wparam), + WM_MOVE => self.handle_move_msg(handle, lparam), + WM_SIZE => self.handle_size_msg(wparam, lparam), + WM_GETMINMAXINFO => self.handle_get_min_max_info_msg(lparam), + WM_ENTERSIZEMOVE | WM_ENTERMENULOOP => self.handle_size_move_loop(handle), + WM_EXITSIZEMOVE | WM_EXITMENULOOP => self.handle_size_move_loop_exit(handle), + WM_TIMER => self.handle_timer_msg(handle, wparam), + WM_NCCALCSIZE => self.handle_calc_client_size(handle, wparam, lparam), + WM_DPICHANGED => self.handle_dpi_changed_msg(handle, wparam, lparam), + WM_DISPLAYCHANGE => self.handle_display_change_msg(handle), + WM_NCHITTEST => self.handle_hit_test_msg(handle, msg, wparam, lparam), + WM_PAINT => self.handle_paint_msg(handle), + WM_CLOSE => self.handle_close_msg(), + WM_DESTROY => self.handle_destroy_msg(handle), + WM_MOUSEMOVE => self.handle_mouse_move_msg(handle, lparam, wparam), + WM_MOUSELEAVE | WM_NCMOUSELEAVE => self.handle_mouse_leave_msg(), + WM_NCMOUSEMOVE => self.handle_nc_mouse_move_msg(handle, lparam), + WM_NCLBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Left, wparam, lparam) + } + WM_NCRBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Right, wparam, lparam) + } + WM_NCMBUTTONDOWN => { + self.handle_nc_mouse_down_msg(handle, MouseButton::Middle, wparam, lparam) + } + WM_NCLBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Left, wparam, lparam) + } + WM_NCRBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Right, wparam, lparam) + } + WM_NCMBUTTONUP => { + self.handle_nc_mouse_up_msg(handle, MouseButton::Middle, wparam, lparam) + } + WM_LBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Left, lparam), + WM_RBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Right, lparam), + WM_MBUTTONDOWN => self.handle_mouse_down_msg(handle, MouseButton::Middle, lparam), + WM_XBUTTONDOWN => { + self.handle_xbutton_msg(handle, wparam, lparam, Self::handle_mouse_down_msg) + } + WM_LBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Left, lparam), + WM_RBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Right, lparam), + WM_MBUTTONUP => self.handle_mouse_up_msg(handle, MouseButton::Middle, lparam), + WM_XBUTTONUP => { + self.handle_xbutton_msg(handle, wparam, lparam, Self::handle_mouse_up_msg) + } + WM_MOUSEWHEEL => self.handle_mouse_wheel_msg(handle, wparam, lparam), + WM_MOUSEHWHEEL => self.handle_mouse_horizontal_wheel_msg(handle, wparam, lparam), + WM_SYSKEYDOWN => self.handle_syskeydown_msg(handle, wparam, lparam), + WM_SYSKEYUP => self.handle_syskeyup_msg(handle, wparam, lparam), + WM_SYSCOMMAND => self.handle_system_command(wparam), + WM_KEYDOWN => self.handle_keydown_msg(handle, wparam, lparam), + WM_KEYUP => self.handle_keyup_msg(handle, wparam, lparam), + WM_CHAR => self.handle_char_msg(wparam), + WM_DEADCHAR => self.handle_dead_char_msg(wparam), + WM_IME_STARTCOMPOSITION => self.handle_ime_position(handle), + WM_IME_COMPOSITION => self.handle_ime_composition(handle, lparam), + WM_SETCURSOR => self.handle_set_cursor(handle, lparam), + WM_SETTINGCHANGE => self.handle_system_settings_changed(handle, wparam, lparam), + WM_INPUTLANGCHANGE => self.handle_input_language_changed(lparam), + WM_GPUI_CURSOR_STYLE_CHANGED => self.handle_cursor_changed(lparam), + WM_GPUI_FORCE_UPDATE_WINDOW => self.draw_window(handle, true), + _ => None, + }; + if let Some(n) = handled { + LRESULT(n) + } else { + unsafe { 
DefWindowProcW(handle, msg, wparam, lparam) } } - WM_NCRBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, MouseButton::Right, wparam, lparam, state_ptr) - } - WM_NCMBUTTONDOWN => { - handle_nc_mouse_down_msg(handle, MouseButton::Middle, wparam, lparam, state_ptr) - } - WM_NCLBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Left, wparam, lparam, state_ptr) - } - WM_NCRBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Right, wparam, lparam, state_ptr) - } - WM_NCMBUTTONUP => { - handle_nc_mouse_up_msg(handle, MouseButton::Middle, wparam, lparam, state_ptr) - } - WM_LBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Left, lparam, state_ptr), - WM_RBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Right, lparam, state_ptr), - WM_MBUTTONDOWN => handle_mouse_down_msg(handle, MouseButton::Middle, lparam, state_ptr), - WM_XBUTTONDOWN => { - handle_xbutton_msg(handle, wparam, lparam, handle_mouse_down_msg, state_ptr) - } - WM_LBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Left, lparam, state_ptr), - WM_RBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Right, lparam, state_ptr), - WM_MBUTTONUP => handle_mouse_up_msg(handle, MouseButton::Middle, lparam, state_ptr), - WM_XBUTTONUP => handle_xbutton_msg(handle, wparam, lparam, handle_mouse_up_msg, state_ptr), - WM_MOUSEWHEEL => handle_mouse_wheel_msg(handle, wparam, lparam, state_ptr), - WM_MOUSEHWHEEL => handle_mouse_horizontal_wheel_msg(handle, wparam, lparam, state_ptr), - WM_SYSKEYDOWN => handle_syskeydown_msg(handle, wparam, lparam, state_ptr), - WM_SYSKEYUP => handle_syskeyup_msg(handle, wparam, lparam, state_ptr), - WM_SYSCOMMAND => handle_system_command(wparam, state_ptr), - WM_KEYDOWN => handle_keydown_msg(handle, wparam, lparam, state_ptr), - WM_KEYUP => handle_keyup_msg(handle, wparam, lparam, state_ptr), - WM_CHAR => handle_char_msg(wparam, state_ptr), - WM_DEADCHAR => handle_dead_char_msg(wparam, state_ptr), - WM_IME_STARTCOMPOSITION => handle_ime_position(handle, state_ptr), - WM_IME_COMPOSITION => handle_ime_composition(handle, lparam, state_ptr), - WM_SETCURSOR => handle_set_cursor(handle, lparam, state_ptr), - WM_SETTINGCHANGE => handle_system_settings_changed(handle, wparam, lparam, state_ptr), - WM_INPUTLANGCHANGE => handle_input_language_changed(lparam, state_ptr), - WM_GPUI_CURSOR_STYLE_CHANGED => handle_cursor_changed(lparam, state_ptr), - _ => None, - }; - if let Some(n) = handled { - LRESULT(n) - } else { - unsafe { DefWindowProcW(handle, msg, wparam, lparam) } - } -} - -fn handle_move_msg( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let origin = logical_point( - lparam.signed_loword() as f32, - lparam.signed_hiword() as f32, - lock.scale_factor, - ); - lock.origin = origin; - let size = lock.logical_size; - let center_x = origin.x.0 + size.width.0 / 2.; - let center_y = origin.y.0 + size.height.0 / 2.; - let monitor_bounds = lock.display.bounds(); - if center_x < monitor_bounds.left().0 - || center_x > monitor_bounds.right().0 - || center_y < monitor_bounds.top().0 - || center_y > monitor_bounds.bottom().0 - { - // center of the window may have moved to another monitor - let monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; - // minimize the window can trigger this event too, in this case, - // monitor is invalid, we do nothing. 
- if !monitor.is_invalid() && lock.display.handle != monitor { - // we will get the same monitor if we only have one - lock.display = WindowsDisplay::new_with_handle(monitor); - } - } - if let Some(mut callback) = lock.callbacks.moved.take() { - drop(lock); - callback(); - state_ptr.state.borrow_mut().callbacks.moved = Some(callback); - } - Some(0) -} - -fn handle_get_min_max_info_msg( - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let lock = state_ptr.state.borrow(); - let min_size = lock.min_size?; - let scale_factor = lock.scale_factor; - let boarder_offset = lock.border_offset; - drop(lock); - unsafe { - let minmax_info = &mut *(lparam.0 as *mut MINMAXINFO); - minmax_info.ptMinTrackSize.x = - min_size.width.scale(scale_factor).0 as i32 + boarder_offset.width_offset; - minmax_info.ptMinTrackSize.y = - min_size.height.scale(scale_factor).0 as i32 + boarder_offset.height_offset; - } - Some(0) -} - -fn handle_size_msg( - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - - // Don't resize the renderer when the window is minimized, but record that it was minimized so - // that on restore the swap chain can be recreated via `update_drawable_size_even_if_unchanged`. - if wparam.0 == SIZE_MINIMIZED as usize { - lock.restore_from_minimized = lock.callbacks.request_frame.take(); - return Some(0); } - let width = lparam.loword().max(1) as i32; - let height = lparam.hiword().max(1) as i32; - let new_size = size(DevicePixels(width), DevicePixels(height)); - let scale_factor = lock.scale_factor; - if lock.restore_from_minimized.is_some() { - lock.renderer - .update_drawable_size_even_if_unchanged(new_size); - lock.callbacks.request_frame = lock.restore_from_minimized.take(); - } else { - lock.renderer.update_drawable_size(new_size); - } - let new_size = new_size.to_pixels(scale_factor); - lock.logical_size = new_size; - if let Some(mut callback) = lock.callbacks.resize.take() { - drop(lock); - callback(new_size, scale_factor); - state_ptr.state.borrow_mut().callbacks.resize = Some(callback); - } - Some(0) -} - -fn handle_size_move_loop(handle: HWND) -> Option<isize> { - unsafe { - let ret = SetTimer( - Some(handle), - SIZE_MOVE_LOOP_TIMER_ID, - USER_TIMER_MINIMUM, - None, + fn handle_move_msg(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let origin = logical_point( + lparam.signed_loword() as f32, + lparam.signed_hiword() as f32, + lock.scale_factor, ); - if ret == 0 { - log::error!( - "unable to create timer: {}", - std::io::Error::last_os_error() - ); - } - } - None -} - -fn handle_size_move_loop_exit(handle: HWND) -> Option<isize> { - unsafe { - KillTimer(Some(handle), SIZE_MOVE_LOOP_TIMER_ID).log_err(); - } - None -} - -fn handle_timer_msg( - handle: HWND, - wparam: WPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID { - for runnable in state_ptr.main_receiver.drain() { - runnable.run(); - } - handle_paint_msg(handle, state_ptr) - } else { - None - } -} - -fn handle_paint_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut request_frame) = lock.callbacks.request_frame.take() { - drop(lock); - request_frame(Default::default()); - state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame); - } - unsafe { ValidateRect(Some(handle), None).ok().log_err() }; - 
Some(0) -} - -fn handle_close_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let output = if let Some(mut callback) = lock.callbacks.should_close.take() { - drop(lock); - let should_close = callback(); - state_ptr.state.borrow_mut().callbacks.should_close = Some(callback); - if should_close { None } else { Some(0) } - } else { - None - }; - - // Workaround as window close animation is not played with `WS_EX_LAYERED` enabled. - if output.is_none() { - unsafe { - let current_style = get_window_long(handle, GWL_EXSTYLE); - set_window_long( - handle, - GWL_EXSTYLE, - current_style & !WS_EX_LAYERED.0 as isize, - ); - } - } - - output -} - -fn handle_destroy_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let callback = { - let mut lock = state_ptr.state.borrow_mut(); - lock.callbacks.close.take() - }; - if let Some(callback) = callback { - callback(); - } - unsafe { - PostThreadMessageW( - state_ptr.main_thread_id_win32, - WM_GPUI_CLOSE_ONE_WINDOW, - WPARAM(state_ptr.validation_number), - LPARAM(handle.0 as isize), - ) - .log_err(); - } - Some(0) -} - -fn handle_mouse_move_msg( - handle: HWND, - lparam: LPARAM, - wparam: WPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - start_tracking_mouse(handle, &state_ptr, TME_LEAVE); - - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - drop(lock); - - let pressed_button = match MODIFIERKEYS_FLAGS(wparam.loword() as u32) { - flags if flags.contains(MK_LBUTTON) => Some(MouseButton::Left), - flags if flags.contains(MK_RBUTTON) => Some(MouseButton::Right), - flags if flags.contains(MK_MBUTTON) => Some(MouseButton::Middle), - flags if flags.contains(MK_XBUTTON1) => { - Some(MouseButton::Navigate(NavigationDirection::Back)) - } - flags if flags.contains(MK_XBUTTON2) => { - Some(MouseButton::Navigate(NavigationDirection::Forward)) - } - _ => None, - }; - let x = lparam.signed_loword() as f32; - let y = lparam.signed_hiword() as f32; - let input = PlatformInput::MouseMove(MouseMoveEvent { - position: logical_point(x, y, scale_factor), - pressed_button, - modifiers: current_modifiers(), - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_mouse_leave_msg(state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - lock.hovered = false; - if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { - drop(lock); - callback(false); - state_ptr.state.borrow_mut().callbacks.hovered_status_change = Some(callback); - } - - Some(0) -} - -fn handle_syskeydown_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyDown(KeyDownEvent { - keystroke, - is_held: lparam.0 & (0x1 << 30) > 0, - }) - })?; - let mut func = lock.callbacks.input.take()?; - drop(lock); - - let handled = !func(input).propagate; - - let mut lock = state_ptr.state.borrow_mut(); - lock.callbacks.input = Some(func); - - if handled { - lock.system_key_handled = true; - Some(0) - } else { - // we need to call `DefWindowProcW`, or we will lose the system-wide `Alt+F4`, `Alt+{other keys}` - // 
shortcuts. - None - } -} - -fn handle_syskeyup_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyUp(KeyUpEvent { keystroke }) - })?; - let mut func = lock.callbacks.input.take()?; - drop(lock); - func(input); - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - // Always return 0 to indicate that the message was handled, so we could properly handle `ModifiersChanged` event. - Some(0) -} - -// It's a known bug that you can't trigger `ctrl-shift-0`. See: -// https://superuser.com/questions/1455762/ctrl-shift-number-key-combination-has-stopped-working-for-a-few-numbers -fn handle_keydown_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyDown(KeyDownEvent { - keystroke, - is_held: lparam.0 & (0x1 << 30) > 0, - }) - }) else { - return Some(1); - }; - drop(lock); - - let is_composing = with_input_handler(&state_ptr, |input_handler| { - input_handler.marked_text_range() - }) - .flatten() - .is_some(); - if is_composing { - translate_message(handle, wparam, lparam); - return Some(0); - } - - let Some(mut func) = state_ptr.state.borrow_mut().callbacks.input.take() else { - return Some(1); - }; - - let handled = !func(input).propagate; - - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { - Some(0) - } else { - translate_message(handle, wparam, lparam); - Some(1) - } -} - -fn handle_keyup_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { - PlatformInput::KeyUp(KeyUpEvent { keystroke }) - }) else { - return Some(1); - }; - - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - drop(lock); - - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_char_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let input = parse_char_message(wparam, &state_ptr)?; - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, &input); - }); - - Some(0) -} - -fn handle_dead_char_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let ch = char::from_u32(wparam.0 as u32)?.to_string(); - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_and_mark_text_in_range(None, &ch, None); - }); - None -} - -fn handle_mouse_down_msg( - handle: HWND, - button: MouseButton, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - unsafe { SetCapture(handle) }; - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let x = lparam.signed_loword(); - let y = lparam.signed_hiword(); - let physical_point = point(DevicePixels(x as i32), DevicePixels(y as i32)); - let click_count = lock.click_state.update(button, physical_point); - let scale_factor = lock.scale_factor; - drop(lock); - - let input = 
PlatformInput::MouseDown(MouseDownEvent { - button, - position: logical_point(x as f32, y as f32, scale_factor), - modifiers: current_modifiers(), - click_count, - first_mouse: false, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_mouse_up_msg( - _handle: HWND, - button: MouseButton, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - unsafe { ReleaseCapture().log_err() }; - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let x = lparam.signed_loword() as f32; - let y = lparam.signed_hiword() as f32; - let click_count = lock.click_state.current_count; - let scale_factor = lock.scale_factor; - drop(lock); - - let input = PlatformInput::MouseUp(MouseUpEvent { - button, - position: logical_point(x, y, scale_factor), - modifiers: current_modifiers(), - click_count, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_xbutton_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - handler: impl Fn(HWND, MouseButton, LPARAM, Rc<WindowsWindowStatePtr>) -> Option<isize>, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let nav_dir = match wparam.hiword() { - XBUTTON1 => NavigationDirection::Back, - XBUTTON2 => NavigationDirection::Forward, - _ => return Some(1), - }; - handler(handle, MouseButton::Navigate(nav_dir), lparam, state_ptr) -} - -fn handle_mouse_wheel_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let modifiers = current_modifiers(); - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - let wheel_scroll_amount = match modifiers.shift { - true => lock.system_settings.mouse_wheel_settings.wheel_scroll_chars, - false => lock.system_settings.mouse_wheel_settings.wheel_scroll_lines, - }; - drop(lock); - - let wheel_distance = - (wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_amount as f32; - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::ScrollWheel(ScrollWheelEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - delta: ScrollDelta::Lines(match modifiers.shift { - true => Point { - x: wheel_distance, - y: 0.0, - }, - false => Point { - y: wheel_distance, - x: 0.0, - }, - }), - modifiers, - touch_phase: TouchPhase::Moved, - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn handle_mouse_horizontal_wheel_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - let Some(mut func) = lock.callbacks.input.take() else { - return Some(1); - }; - let scale_factor = lock.scale_factor; - let wheel_scroll_chars = lock.system_settings.mouse_wheel_settings.wheel_scroll_chars; - drop(lock); - - let wheel_distance = - (-wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_chars as f32; - let mut cursor_point = POINT { 
- x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let event = PlatformInput::ScrollWheel(ScrollWheelEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - delta: ScrollDelta::Lines(Point { - x: wheel_distance, - y: 0.0, - }), - modifiers: current_modifiers(), - touch_phase: TouchPhase::Moved, - }); - let handled = !func(event).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { Some(1) } -} - -fn retrieve_caret_position(state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<POINT> { - with_input_handler_and_scale_factor(state_ptr, |input_handler, scale_factor| { - let caret_range = input_handler.selected_text_range(false)?; - let caret_position = input_handler.bounds_for_range(caret_range.range)?; - Some(POINT { - // logical to physical - x: (caret_position.origin.x.0 * scale_factor) as i32, - y: (caret_position.origin.y.0 * scale_factor) as i32 - + ((caret_position.size.height.0 * scale_factor) as i32 / 2), - }) - }) -} - -fn handle_ime_position(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - unsafe { - let ctx = ImmGetContext(handle); - - let Some(caret_position) = retrieve_caret_position(&state_ptr) else { - return Some(0); - }; + lock.origin = origin; + let size = lock.logical_size; + let center_x = origin.x.0 + size.width.0 / 2.; + let center_y = origin.y.0 + size.height.0 / 2.; + let monitor_bounds = lock.display.bounds(); + if center_x < monitor_bounds.left().0 + || center_x > monitor_bounds.right().0 + || center_y < monitor_bounds.top().0 + || center_y > monitor_bounds.bottom().0 { - let config = COMPOSITIONFORM { - dwStyle: CFS_POINT, - ptCurrentPos: caret_position, - ..Default::default() - }; - ImmSetCompositionWindow(ctx, &config as _).ok().log_err(); - } - { - let config = CANDIDATEFORM { - dwStyle: CFS_CANDIDATEPOS, - ptCurrentPos: caret_position, - ..Default::default() - }; - ImmSetCandidateWindow(ctx, &config as _).ok().log_err(); - } - ImmReleaseContext(handle, ctx).ok().log_err(); - Some(0) - } -} - -fn handle_ime_composition( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let ctx = unsafe { ImmGetContext(handle) }; - let result = handle_ime_composition_inner(ctx, lparam, state_ptr); - unsafe { ImmReleaseContext(handle, ctx).ok().log_err() }; - result -} - -fn handle_ime_composition_inner( - ctx: HIMC, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let lparam = lparam.0 as u32; - if lparam == 0 { - // Japanese IME may send this message with lparam = 0, which indicates that - // there is no composition string. 
- with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, ""); - })?; - Some(0) - } else { - if lparam & GCS_COMPSTR.0 > 0 { - let comp_string = parse_ime_composition_string(ctx, GCS_COMPSTR)?; - let caret_pos = (!comp_string.is_empty() && lparam & GCS_CURSORPOS.0 > 0).then(|| { - let pos = retrieve_composition_cursor_position(ctx); - pos..pos - }); - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_and_mark_text_in_range(None, &comp_string, caret_pos); - })?; - } - if lparam & GCS_RESULTSTR.0 > 0 { - let comp_result = parse_ime_composition_string(ctx, GCS_RESULTSTR)?; - with_input_handler(&state_ptr, |input_handler| { - input_handler.replace_text_in_range(None, &comp_result); - })?; - return Some(0); - } - - // currently, we don't care other stuff - None - } -} - -/// SEE: https://learn.microsoft.com/en-us/windows/win32/winmsg/wm-nccalcsize -fn handle_calc_client_size( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if !state_ptr.hide_title_bar || state_ptr.state.borrow().is_fullscreen() || wparam.0 == 0 { - return None; - } - - let is_maximized = state_ptr.state.borrow().is_maximized(); - let insets = get_client_area_insets(handle, is_maximized, state_ptr.windows_version); - // wparam is TRUE so lparam points to an NCCALCSIZE_PARAMS structure - let mut params = lparam.0 as *mut NCCALCSIZE_PARAMS; - let mut requested_client_rect = unsafe { &mut ((*params).rgrc) }; - - requested_client_rect[0].left += insets.left; - requested_client_rect[0].top += insets.top; - requested_client_rect[0].right -= insets.right; - requested_client_rect[0].bottom -= insets.bottom; - - // Fix auto hide taskbar not showing. This solution is based on the approach - // used by Chrome. However, it may result in one row of pixels being obscured - // in our client area. But as Chrome says, "there seems to be no better solution." - if is_maximized { - if let Some(ref taskbar_position) = state_ptr - .state - .borrow() - .system_settings - .auto_hide_taskbar_position - { - // Fot the auto-hide taskbar, adjust in by 1 pixel on taskbar edge, - // so the window isn't treated as a "fullscreen app", which would cause - // the taskbar to disappear. - match taskbar_position { - AutoHideTaskbarPosition::Left => { - requested_client_rect[0].left += AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Top => { - requested_client_rect[0].top += AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Right => { - requested_client_rect[0].right -= AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Bottom => { - requested_client_rect[0].bottom -= AUTO_HIDE_TASKBAR_THICKNESS_PX - } + // center of the window may have moved to another monitor + let monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; + // minimize the window can trigger this event too, in this case, + // monitor is invalid, we do nothing. 
+ if !monitor.is_invalid() && lock.display.handle != monitor { + // we will get the same monitor if we only have one + lock.display = WindowsDisplay::new_with_handle(monitor); } } - } - - Some(0) -} - -fn handle_activate_msg(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let activated = wparam.loword() > 0; - let this = state_ptr.clone(); - state_ptr - .executor - .spawn(async move { - let mut lock = this.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.active_status_change.take() { - drop(lock); - func(activated); - this.state.borrow_mut().callbacks.active_status_change = Some(func); - } - }) - .detach(); - - None -} - -fn handle_create_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - if state_ptr.hide_title_bar { - notify_frame_changed(handle); - Some(0) - } else { - None - } -} - -fn handle_dpi_changed_msg( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let new_dpi = wparam.loword() as f32; - let mut lock = state_ptr.state.borrow_mut(); - lock.scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32; - lock.border_offset.update(handle).log_err(); - drop(lock); - - let rect = unsafe { &*(lparam.0 as *const RECT) }; - let width = rect.right - rect.left; - let height = rect.bottom - rect.top; - // this will emit `WM_SIZE` and `WM_MOVE` right here - // even before this function returns - // the new size is handled in `WM_SIZE` - unsafe { - SetWindowPos( - handle, - None, - rect.left, - rect.top, - width, - height, - SWP_NOZORDER | SWP_NOACTIVATE, - ) - .context("unable to set window position after dpi has changed") - .log_err(); - } - - Some(0) -} - -/// The following conditions will trigger this event: -/// 1. The monitor on which the window is located goes offline or changes resolution. -/// 2. Another monitor goes offline, is plugged in, or changes resolution. -/// -/// In either case, the window will only receive information from the monitor on which -/// it is located. -/// -/// For example, in the case of condition 2, where the monitor on which the window is -/// located has actually changed nothing, it will still receive this event. -fn handle_display_change_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - // NOTE: - // Even the `lParam` holds the resolution of the screen, we just ignore it. - // Because WM_DPICHANGED, WM_MOVE, WM_SIZE will come first, window reposition and resize - // are handled there. - // So we only care about if monitor is disconnected. - let previous_monitor = state_ptr.state.borrow().display; - if WindowsDisplay::is_connected(previous_monitor.handle) { - // we are fine, other display changed - return None; - } - // display disconnected - // in this case, the OS will move our window to another monitor, and minimize it. 
- // we deminimize the window and query the monitor after moving - unsafe { - let _ = ShowWindow(handle, SW_SHOWNORMAL); - }; - let new_monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; - // all monitors disconnected - if new_monitor.is_invalid() { - log::error!("No monitor detected!"); - return None; - } - let new_display = WindowsDisplay::new_with_handle(new_monitor); - state_ptr.state.borrow_mut().display = new_display; - Some(0) -} - -fn handle_hit_test_msg( - handle: HWND, - msg: u32, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if !state_ptr.is_movable || state_ptr.state.borrow().is_fullscreen() { - return None; - } - - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut callback) = lock.callbacks.hit_test_window_control.take() { - drop(lock); - let area = callback(); - state_ptr - .state - .borrow_mut() - .callbacks - .hit_test_window_control = Some(callback); - if let Some(area) = area { - return match area { - WindowControlArea::Drag => Some(HTCAPTION as _), - WindowControlArea::Close => Some(HTCLOSE as _), - WindowControlArea::Max => Some(HTMAXBUTTON as _), - WindowControlArea::Min => Some(HTMINBUTTON as _), - }; + if let Some(mut callback) = lock.callbacks.moved.take() { + drop(lock); + callback(); + self.state.borrow_mut().callbacks.moved = Some(callback); } - } else { - drop(lock); + Some(0) } - if !state_ptr.hide_title_bar { - // If the OS draws the title bar, we don't need to handle hit test messages. - return None; - } - - // default handler for resize areas - let hit = unsafe { DefWindowProcW(handle, msg, wparam, lparam) }; - if matches!( - hit.0 as u32, - HTNOWHERE - | HTRIGHT - | HTLEFT - | HTTOPLEFT - | HTTOP - | HTTOPRIGHT - | HTBOTTOMRIGHT - | HTBOTTOM - | HTBOTTOMLEFT - ) { - return Some(hit.0); - } - - if state_ptr.state.borrow().is_fullscreen() { - return Some(HTCLIENT as _); - } - - let dpi = unsafe { GetDpiForWindow(handle) }; - let frame_y = unsafe { GetSystemMetricsForDpi(SM_CYFRAME, dpi) }; - - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - if !state_ptr.state.borrow().is_maximized() && cursor_point.y >= 0 && cursor_point.y <= frame_y - { - return Some(HTTOP as _); - } - - Some(HTCLIENT as _) -} - -fn handle_nc_mouse_move_msg( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - start_tracking_mouse(handle, &state_ptr, TME_LEAVE | TME_NONCLIENT); - - let mut lock = state_ptr.state.borrow_mut(); - let mut func = lock.callbacks.input.take()?; - let scale_factor = lock.scale_factor; - drop(lock); - - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), - }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::MouseMove(MouseMoveEvent { - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), - pressed_button: None, - modifiers: current_modifiers(), - }); - let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); - - if handled { Some(0) } else { None } -} - -fn handle_nc_mouse_down_msg( - handle: HWND, - button: MouseButton, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.input.take() { + fn 
handle_get_min_max_info_msg(&self, lparam: LPARAM) -> Option<isize> { + let lock = self.state.borrow(); + let min_size = lock.min_size?; let scale_factor = lock.scale_factor; - let mut cursor_point = POINT { - x: lparam.signed_loword().into(), - y: lparam.signed_hiword().into(), + let boarder_offset = lock.border_offset; + drop(lock); + unsafe { + let minmax_info = &mut *(lparam.0 as *mut MINMAXINFO); + minmax_info.ptMinTrackSize.x = + min_size.width.scale(scale_factor).0 as i32 + boarder_offset.width_offset; + minmax_info.ptMinTrackSize.y = + min_size.height.scale(scale_factor).0 as i32 + boarder_offset.height_offset; + } + Some(0) + } + + fn handle_size_msg(&self, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + + // Don't resize the renderer when the window is minimized, but record that it was minimized so + // that on restore the swap chain can be recreated via `update_drawable_size_even_if_unchanged`. + if wparam.0 == SIZE_MINIMIZED as usize { + lock.restore_from_minimized = lock.callbacks.request_frame.take(); + return Some(0); + } + + let width = lparam.loword().max(1) as i32; + let height = lparam.hiword().max(1) as i32; + let new_size = size(DevicePixels(width), DevicePixels(height)); + + let scale_factor = lock.scale_factor; + let mut should_resize_renderer = false; + if lock.restore_from_minimized.is_some() { + lock.callbacks.request_frame = lock.restore_from_minimized.take(); + } else { + should_resize_renderer = true; + } + drop(lock); + + self.handle_size_change(new_size, scale_factor, should_resize_renderer); + Some(0) + } + + fn handle_size_change( + &self, + device_size: Size<DevicePixels>, + scale_factor: f32, + should_resize_renderer: bool, + ) { + let new_logical_size = device_size.to_pixels(scale_factor); + let mut lock = self.state.borrow_mut(); + lock.logical_size = new_logical_size; + if should_resize_renderer { + lock.renderer.resize(device_size).log_err(); + } + if let Some(mut callback) = lock.callbacks.resize.take() { + drop(lock); + callback(new_logical_size, scale_factor); + self.state.borrow_mut().callbacks.resize = Some(callback); + } + } + + fn handle_size_move_loop(&self, handle: HWND) -> Option<isize> { + unsafe { + let ret = SetTimer( + Some(handle), + SIZE_MOVE_LOOP_TIMER_ID, + USER_TIMER_MINIMUM, + None, + ); + if ret == 0 { + log::error!( + "unable to create timer: {}", + std::io::Error::last_os_error() + ); + } + } + None + } + + fn handle_size_move_loop_exit(&self, handle: HWND) -> Option<isize> { + unsafe { + KillTimer(Some(handle), SIZE_MOVE_LOOP_TIMER_ID).log_err(); + } + None + } + + fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> { + if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID { + for runnable in self.main_receiver.drain() { + runnable.run(); + } + self.handle_paint_msg(handle) + } else { + None + } + } + + fn handle_paint_msg(&self, handle: HWND) -> Option<isize> { + self.draw_window(handle, false) + } + + fn handle_close_msg(&self) -> Option<isize> { + let mut callback = self.state.borrow_mut().callbacks.should_close.take()?; + let should_close = callback(); + self.state.borrow_mut().callbacks.should_close = Some(callback); + if should_close { None } else { Some(0) } + } + + fn handle_destroy_msg(&self, handle: HWND) -> Option<isize> { + let callback = { + let mut lock = self.state.borrow_mut(); + lock.callbacks.close.take() }; - unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let physical_point = point(DevicePixels(cursor_point.x), 
DevicePixels(cursor_point.y)); + if let Some(callback) = callback { + callback(); + } + unsafe { + PostThreadMessageW( + self.main_thread_id_win32, + WM_GPUI_CLOSE_ONE_WINDOW, + WPARAM(self.validation_number), + LPARAM(handle.0 as isize), + ) + .log_err(); + } + Some(0) + } + + fn handle_mouse_move_msg(&self, handle: HWND, lparam: LPARAM, wparam: WPARAM) -> Option<isize> { + self.start_tracking_mouse(handle, TME_LEAVE); + + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + drop(lock); + + let pressed_button = match MODIFIERKEYS_FLAGS(wparam.loword() as u32) { + flags if flags.contains(MK_LBUTTON) => Some(MouseButton::Left), + flags if flags.contains(MK_RBUTTON) => Some(MouseButton::Right), + flags if flags.contains(MK_MBUTTON) => Some(MouseButton::Middle), + flags if flags.contains(MK_XBUTTON1) => { + Some(MouseButton::Navigate(NavigationDirection::Back)) + } + flags if flags.contains(MK_XBUTTON2) => { + Some(MouseButton::Navigate(NavigationDirection::Forward)) + } + _ => None, + }; + let x = lparam.signed_loword() as f32; + let y = lparam.signed_hiword() as f32; + let input = PlatformInput::MouseMove(MouseMoveEvent { + position: logical_point(x, y, scale_factor), + pressed_button, + modifiers: current_modifiers(), + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_mouse_leave_msg(&self) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + lock.hovered = false; + if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { + drop(lock); + callback(false); + self.state.borrow_mut().callbacks.hovered_status_change = Some(callback); + } + + Some(0) + } + + fn handle_syskeydown_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyDown(KeyDownEvent { + keystroke, + is_held: lparam.0 & (0x1 << 30) > 0, + }) + })?; + let mut func = lock.callbacks.input.take()?; + drop(lock); + + let handled = !func(input).propagate; + + let mut lock = self.state.borrow_mut(); + lock.callbacks.input = Some(func); + + if handled { + lock.system_key_handled = true; + Some(0) + } else { + // we need to call `DefWindowProcW`, or we will lose the system-wide `Alt+F4`, `Alt+{other keys}` + // shortcuts. + None + } + } + + fn handle_syskeyup_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let input = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyUp(KeyUpEvent { keystroke }) + })?; + let mut func = lock.callbacks.input.take()?; + drop(lock); + func(input); + self.state.borrow_mut().callbacks.input = Some(func); + + // Always return 0 to indicate that the message was handled, so we could properly handle `ModifiersChanged` event. + Some(0) + } + + // It's a known bug that you can't trigger `ctrl-shift-0`. 
See: + // https://superuser.com/questions/1455762/ctrl-shift-number-key-combination-has-stopped-working-for-a-few-numbers + fn handle_keydown_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyDown(KeyDownEvent { + keystroke, + is_held: lparam.0 & (0x1 << 30) > 0, + }) + }) else { + return Some(1); + }; + drop(lock); + + let is_composing = self + .with_input_handler(|input_handler| input_handler.marked_text_range()) + .flatten() + .is_some(); + if is_composing { + translate_message(handle, wparam, lparam); + return Some(0); + } + + let Some(mut func) = self.state.borrow_mut().callbacks.input.take() else { + return Some(1); + }; + + let handled = !func(input).propagate; + + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + Some(0) + } else { + translate_message(handle, wparam, lparam); + Some(1) + } + } + + fn handle_keyup_msg(&self, handle: HWND, wparam: WPARAM, lparam: LPARAM) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(input) = handle_key_event(handle, wparam, lparam, &mut lock, |keystroke| { + PlatformInput::KeyUp(KeyUpEvent { keystroke }) + }) else { + return Some(1); + }; + + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + drop(lock); + + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_char_msg(&self, wparam: WPARAM) -> Option<isize> { + let input = self.parse_char_message(wparam)?; + self.with_input_handler(|input_handler| { + input_handler.replace_text_in_range(None, &input); + }); + + Some(0) + } + + fn handle_dead_char_msg(&self, wparam: WPARAM) -> Option<isize> { + let ch = char::from_u32(wparam.0 as u32)?.to_string(); + self.with_input_handler(|input_handler| { + input_handler.replace_and_mark_text_in_range(None, &ch, None); + }); + None + } + + fn handle_mouse_down_msg( + &self, + handle: HWND, + button: MouseButton, + lparam: LPARAM, + ) -> Option<isize> { + unsafe { SetCapture(handle) }; + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let x = lparam.signed_loword(); + let y = lparam.signed_hiword(); + let physical_point = point(DevicePixels(x as i32), DevicePixels(y as i32)); let click_count = lock.click_state.update(button, physical_point); + let scale_factor = lock.scale_factor; drop(lock); let input = PlatformInput::MouseDown(MouseDownEvent { button, - position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + position: logical_point(x as f32, y as f32, scale_factor), modifiers: current_modifiers(), click_count, first_mouse: false, }); - let result = func(input.clone()); - let handled = !result.propagate || result.default_prevented; - state_ptr.state.borrow_mut().callbacks.input = Some(func); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); - if handled { - return Some(0); - } - } else { - drop(lock); - }; + if handled { Some(0) } else { Some(1) } + } - // Since these are handled in handle_nc_mouse_up_msg we must prevent the default window proc - if button == MouseButton::Left { - match wparam.0 as u32 { - HTMINBUTTON => state_ptr.state.borrow_mut().nc_button_pressed = Some(HTMINBUTTON), - HTMAXBUTTON => state_ptr.state.borrow_mut().nc_button_pressed = 
Some(HTMAXBUTTON), - HTCLOSE => state_ptr.state.borrow_mut().nc_button_pressed = Some(HTCLOSE), - _ => return None, + fn handle_mouse_up_msg( + &self, + _handle: HWND, + button: MouseButton, + lparam: LPARAM, + ) -> Option<isize> { + unsafe { ReleaseCapture().log_err() }; + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); }; + let x = lparam.signed_loword() as f32; + let y = lparam.signed_hiword() as f32; + let click_count = lock.click_state.current_count; + let scale_factor = lock.scale_factor; + drop(lock); + + let input = PlatformInput::MouseUp(MouseUpEvent { + button, + position: logical_point(x, y, scale_factor), + modifiers: current_modifiers(), + click_count, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_xbutton_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + handler: impl Fn(&Self, HWND, MouseButton, LPARAM) -> Option<isize>, + ) -> Option<isize> { + let nav_dir = match wparam.hiword() { + XBUTTON1 => NavigationDirection::Back, + XBUTTON2 => NavigationDirection::Forward, + _ => return Some(1), + }; + handler(self, handle, MouseButton::Navigate(nav_dir), lparam) + } + + fn handle_mouse_wheel_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let modifiers = current_modifiers(); + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + let wheel_scroll_amount = match modifiers.shift { + true => lock.system_settings.mouse_wheel_settings.wheel_scroll_chars, + false => lock.system_settings.mouse_wheel_settings.wheel_scroll_lines, + }; + drop(lock); + + let wheel_distance = + (wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_amount as f32; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let input = PlatformInput::ScrollWheel(ScrollWheelEvent { + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + delta: ScrollDelta::Lines(match modifiers.shift { + true => Point { + x: wheel_distance, + y: 0.0, + }, + false => Point { + y: wheel_distance, + x: 0.0, + }, + }), + modifiers, + touch_phase: TouchPhase::Moved, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn handle_mouse_horizontal_wheel_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + let Some(mut func) = lock.callbacks.input.take() else { + return Some(1); + }; + let scale_factor = lock.scale_factor; + let wheel_scroll_chars = lock.system_settings.mouse_wheel_settings.wheel_scroll_chars; + drop(lock); + + let wheel_distance = + (-wparam.signed_hiword() as f32 / WHEEL_DELTA as f32) * wheel_scroll_chars as f32; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let event = PlatformInput::ScrollWheel(ScrollWheelEvent { + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + delta: ScrollDelta::Lines(Point { + x: wheel_distance, + y: 0.0, + }), + modifiers: 
current_modifiers(), + touch_phase: TouchPhase::Moved, + }); + let handled = !func(event).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { Some(0) } else { Some(1) } + } + + fn retrieve_caret_position(&self) -> Option<POINT> { + self.with_input_handler_and_scale_factor(|input_handler, scale_factor| { + let caret_range = input_handler.selected_text_range(false)?; + let caret_position = input_handler.bounds_for_range(caret_range.range)?; + Some(POINT { + // logical to physical + x: (caret_position.origin.x.0 * scale_factor) as i32, + y: (caret_position.origin.y.0 * scale_factor) as i32 + + ((caret_position.size.height.0 * scale_factor) as i32 / 2), + }) + }) + } + + fn handle_ime_position(&self, handle: HWND) -> Option<isize> { + unsafe { + let ctx = ImmGetContext(handle); + + let Some(caret_position) = self.retrieve_caret_position() else { + return Some(0); + }; + { + let config = COMPOSITIONFORM { + dwStyle: CFS_POINT, + ptCurrentPos: caret_position, + ..Default::default() + }; + ImmSetCompositionWindow(ctx, &config as _).ok().log_err(); + } + { + let config = CANDIDATEFORM { + dwStyle: CFS_CANDIDATEPOS, + ptCurrentPos: caret_position, + ..Default::default() + }; + ImmSetCandidateWindow(ctx, &config as _).ok().log_err(); + } + ImmReleaseContext(handle, ctx).ok().log_err(); + Some(0) + } + } + + fn handle_ime_composition(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + let ctx = unsafe { ImmGetContext(handle) }; + let result = self.handle_ime_composition_inner(ctx, lparam); + unsafe { ImmReleaseContext(handle, ctx).ok().log_err() }; + result + } + + fn handle_ime_composition_inner(&self, ctx: HIMC, lparam: LPARAM) -> Option<isize> { + let lparam = lparam.0 as u32; + if lparam == 0 { + // Japanese IME may send this message with lparam = 0, which indicates that + // there is no composition string. 
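Editor's aside on the scroll handling re-added above (`handle_mouse_wheel_msg` and `handle_mouse_horizontal_wheel_msg`), not part of the patch: the wheel distance is the raw hardware delta divided by the Win32 constant `WHEEL_DELTA` (120) and scaled by the user's scroll setting. A minimal sketch of that arithmetic, assuming the documented default of three lines per notch:

```rust
// Standalone sketch of the wheel-distance arithmetic used above.
// WHEEL_DELTA is the Win32 constant 120; `wheel_scroll_lines` mirrors the
// user's "lines to scroll per notch" setting, which defaults to 3.
fn wheel_lines(raw_delta: i16, wheel_scroll_lines: u32) -> f32 {
    const WHEEL_DELTA: f32 = 120.0;
    (raw_delta as f32 / WHEEL_DELTA) * wheel_scroll_lines as f32
}

fn main() {
    assert_eq!(wheel_lines(120, 3), 3.0); // one detent on a standard wheel
    assert_eq!(wheel_lines(-240, 3), -6.0); // two detents in the other direction
    // High-resolution wheels report fractions of WHEEL_DELTA, so the result
    // can be a fraction of a line, which is why the handler keeps it as f32.
}
```

Holding shift swaps `wheel_scroll_lines` for `wheel_scroll_chars` and flips the scroll axis, as the handler above shows.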
+ self.with_input_handler(|input_handler| { + input_handler.replace_text_in_range(None, ""); + })?; + Some(0) + } else { + if lparam & GCS_COMPSTR.0 > 0 { + let comp_string = parse_ime_composition_string(ctx, GCS_COMPSTR)?; + let caret_pos = + (!comp_string.is_empty() && lparam & GCS_CURSORPOS.0 > 0).then(|| { + let pos = retrieve_composition_cursor_position(ctx); + pos..pos + }); + self.with_input_handler(|input_handler| { + input_handler.replace_and_mark_text_in_range(None, &comp_string, caret_pos); + })?; + } + if lparam & GCS_RESULTSTR.0 > 0 { + let comp_result = parse_ime_composition_string(ctx, GCS_RESULTSTR)?; + self.with_input_handler(|input_handler| { + input_handler.replace_text_in_range(None, &comp_result); + })?; + return Some(0); + } + + // currently, we don't care about other stuff + None + } + } + + /// SEE: https://learn.microsoft.com/en-us/windows/win32/winmsg/wm-nccalcsize + fn handle_calc_client_size( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if !self.hide_title_bar || self.state.borrow().is_fullscreen() || wparam.0 == 0 { + return None; + } + + let is_maximized = self.state.borrow().is_maximized(); + let insets = get_client_area_insets(handle, is_maximized, self.windows_version); + // wparam is TRUE so lparam points to an NCCALCSIZE_PARAMS structure + let mut params = lparam.0 as *mut NCCALCSIZE_PARAMS; + let mut requested_client_rect = unsafe { &mut ((*params).rgrc) }; + + requested_client_rect[0].left += insets.left; + requested_client_rect[0].top += insets.top; + requested_client_rect[0].right -= insets.right; + requested_client_rect[0].bottom -= insets.bottom; + + // Fix auto hide taskbar not showing. This solution is based on the approach + // used by Chrome. However, it may result in one row of pixels being obscured + // in our client area. But as Chrome says, "there seems to be no better solution." + if is_maximized { + if let Some(ref taskbar_position) = self + .state + .borrow() + .system_settings + .auto_hide_taskbar_position + { + // For the auto-hide taskbar, adjust in by 1 pixel on the taskbar edge, + // so the window isn't treated as a "fullscreen app", which would cause + // the taskbar to disappear.
+ match taskbar_position { + AutoHideTaskbarPosition::Left => { + requested_client_rect[0].left += AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Top => { + requested_client_rect[0].top += AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Right => { + requested_client_rect[0].right -= AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Bottom => { + requested_client_rect[0].bottom -= AUTO_HIDE_TASKBAR_THICKNESS_PX + } + } + } + } + Some(0) - } else { + } + + fn handle_activate_msg(self: &Rc<Self>, wparam: WPARAM) -> Option<isize> { + let activated = wparam.loword() > 0; + let this = self.clone(); + self.executor + .spawn(async move { + let mut lock = this.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.active_status_change.take() { + drop(lock); + func(activated); + this.state.borrow_mut().callbacks.active_status_change = Some(func); + } + }) + .detach(); + None } -} -fn handle_nc_mouse_up_msg( - handle: HWND, - button: MouseButton, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut func) = lock.callbacks.input.take() { + fn handle_create_msg(&self, handle: HWND) -> Option<isize> { + if self.hide_title_bar { + notify_frame_changed(handle); + Some(0) + } else { + None + } + } + + fn handle_dpi_changed_msg( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let new_dpi = wparam.loword() as f32; + let mut lock = self.state.borrow_mut(); + let is_maximized = lock.is_maximized(); + let new_scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32; + lock.scale_factor = new_scale_factor; + lock.border_offset.update(handle).log_err(); + drop(lock); + + let rect = unsafe { &*(lparam.0 as *const RECT) }; + let width = rect.right - rect.left; + let height = rect.bottom - rect.top; + // this will emit `WM_SIZE` and `WM_MOVE` right here + // even before this function returns + // the new size is handled in `WM_SIZE` + unsafe { + SetWindowPos( + handle, + None, + rect.left, + rect.top, + width, + height, + SWP_NOZORDER | SWP_NOACTIVATE, + ) + .context("unable to set window position after dpi has changed") + .log_err(); + } + + // When maximized, SetWindowPos doesn't send WM_SIZE, so we need to manually + // update the size and call the resize callback + if is_maximized { + let device_size = size(DevicePixels(width), DevicePixels(height)); + self.handle_size_change(device_size, new_scale_factor, true); + } + + Some(0) + } + + /// The following conditions will trigger this event: + /// 1. The monitor on which the window is located goes offline or changes resolution. + /// 2. Another monitor goes offline, is plugged in, or changes resolution. + /// + /// In either case, the window will only receive information from the monitor on which + /// it is located. + /// + /// For example, in the case of condition 2, where the monitor on which the window is + /// located has actually changed nothing, it will still receive this event. + fn handle_display_change_msg(&self, handle: HWND) -> Option<isize> { + // NOTE: + // Even the `lParam` holds the resolution of the screen, we just ignore it. + // Because WM_DPICHANGED, WM_MOVE, WM_SIZE will come first, window reposition and resize + // are handled there. + // So we only care about if monitor is disconnected. 
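Editor's aside on `handle_dpi_changed_msg` and `handle_size_change` above, not part of the patch: the scale factor is the new DPI divided by `USER_DEFAULT_SCREEN_DPI` (96), and the logical size is the physical pixel size divided by that factor. A minimal sketch of the conversion, with an assumed 150% display for illustration:

```rust
// Standalone sketch of the DPI / scale-factor arithmetic used above.
// USER_DEFAULT_SCREEN_DPI is the Win32 constant 96.
fn scale_factor(new_dpi: u32) -> f32 {
    const USER_DEFAULT_SCREEN_DPI: f32 = 96.0;
    new_dpi as f32 / USER_DEFAULT_SCREEN_DPI
}

// Logical (DPI-independent) size from a physical pixel size; the inverse of
// the "logical to physical" multiplication used for the IME caret position.
fn logical_size(physical_width: i32, physical_height: i32, scale: f32) -> (f32, f32) {
    (physical_width as f32 / scale, physical_height as f32 / scale)
}

fn main() {
    let scale = scale_factor(144); // a 150% display reports 144 DPI
    assert_eq!(scale, 1.5);
    // A 2880x1620 physical client area is 1920x1080 logical pixels at 150%.
    assert_eq!(logical_size(2880, 1620, scale), (1920.0, 1080.0));
}
```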
+ let previous_monitor = self.state.borrow().display; + if WindowsDisplay::is_connected(previous_monitor.handle) { + // we are fine, other display changed + return None; + } + // display disconnected + // in this case, the OS will move our window to another monitor, and minimize it. + // we deminimize the window and query the monitor after moving + unsafe { + let _ = ShowWindow(handle, SW_SHOWNORMAL); + }; + let new_monitor = unsafe { MonitorFromWindow(handle, MONITOR_DEFAULTTONULL) }; + // all monitors disconnected + if new_monitor.is_invalid() { + log::error!("No monitor detected!"); + return None; + } + let new_display = WindowsDisplay::new_with_handle(new_monitor); + self.state.borrow_mut().display = new_display; + Some(0) + } + + fn handle_hit_test_msg( + &self, + handle: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if !self.is_movable || self.state.borrow().is_fullscreen() { + return None; + } + + let mut lock = self.state.borrow_mut(); + if let Some(mut callback) = lock.callbacks.hit_test_window_control.take() { + drop(lock); + let area = callback(); + self.state.borrow_mut().callbacks.hit_test_window_control = Some(callback); + if let Some(area) = area { + return match area { + WindowControlArea::Drag => Some(HTCAPTION as _), + WindowControlArea::Close => Some(HTCLOSE as _), + WindowControlArea::Max => Some(HTMAXBUTTON as _), + WindowControlArea::Min => Some(HTMINBUTTON as _), + }; + } + } else { + drop(lock); + } + + if !self.hide_title_bar { + // If the OS draws the title bar, we don't need to handle hit test messages. + return None; + } + + // default handler for resize areas + let hit = unsafe { DefWindowProcW(handle, msg, wparam, lparam) }; + if matches!( + hit.0 as u32, + HTNOWHERE + | HTRIGHT + | HTLEFT + | HTTOPLEFT + | HTTOP + | HTTOPRIGHT + | HTBOTTOMRIGHT + | HTBOTTOM + | HTBOTTOMLEFT + ) { + return Some(hit.0); + } + + if self.state.borrow().is_fullscreen() { + return Some(HTCLIENT as _); + } + + let dpi = unsafe { GetDpiForWindow(handle) }; + let frame_y = unsafe { GetSystemMetricsForDpi(SM_CYFRAME, dpi) }; + + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + if !self.state.borrow().is_maximized() && cursor_point.y >= 0 && cursor_point.y <= frame_y { + return Some(HTTOP as _); + } + + Some(HTCLIENT as _) + } + + fn handle_nc_mouse_move_msg(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + self.start_tracking_mouse(handle, TME_LEAVE | TME_NONCLIENT); + + let mut lock = self.state.borrow_mut(); + let mut func = lock.callbacks.input.take()?; let scale_factor = lock.scale_factor; drop(lock); @@ -1052,206 +919,355 @@ fn handle_nc_mouse_up_msg( y: lparam.signed_hiword().into(), }; unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; - let input = PlatformInput::MouseUp(MouseUpEvent { - button, + let input = PlatformInput::MouseMove(MouseMoveEvent { position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + pressed_button: None, modifiers: current_modifiers(), - click_count: 1, }); let handled = !func(input).propagate; - state_ptr.state.borrow_mut().callbacks.input = Some(func); + self.state.borrow_mut().callbacks.input = Some(func); - if handled { - return Some(0); - } - } else { - drop(lock); + if handled { Some(0) } else { None } } - let last_pressed = state_ptr.state.borrow_mut().nc_button_pressed.take(); - if button == MouseButton::Left - && let 
Some(last_pressed) = last_pressed - { - let handled = match (wparam.0 as u32, last_pressed) { - (HTMINBUTTON, HTMINBUTTON) => { - unsafe { ShowWindowAsync(handle, SW_MINIMIZE).ok().log_err() }; - true + fn handle_nc_mouse_down_msg( + &self, + handle: HWND, + button: MouseButton, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.input.take() { + let scale_factor = lock.scale_factor; + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let physical_point = point(DevicePixels(cursor_point.x), DevicePixels(cursor_point.y)); + let click_count = lock.click_state.update(button, physical_point); + drop(lock); + + let input = PlatformInput::MouseDown(MouseDownEvent { + button, + position: logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + modifiers: current_modifiers(), + click_count, + first_mouse: false, + }); + let result = func(input.clone()); + let handled = !result.propagate || result.default_prevented; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + return Some(0); } - (HTMAXBUTTON, HTMAXBUTTON) => { - if state_ptr.state.borrow().is_maximized() { - unsafe { ShowWindowAsync(handle, SW_NORMAL).ok().log_err() }; - } else { - unsafe { ShowWindowAsync(handle, SW_MAXIMIZE).ok().log_err() }; - } - true - } - (HTCLOSE, HTCLOSE) => { - unsafe { - PostMessageW(Some(handle), WM_CLOSE, WPARAM::default(), LPARAM::default()) - .log_err() - }; - true - } - _ => false, + } else { + drop(lock); }; - if handled { - return Some(0); - } - } - None -} - -fn handle_cursor_changed(lparam: LPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - let mut state = state_ptr.state.borrow_mut(); - let had_cursor = state.current_cursor.is_some(); - - state.current_cursor = if lparam.0 == 0 { - None - } else { - Some(HCURSOR(lparam.0 as _)) - }; - - if had_cursor != state.current_cursor.is_some() { - unsafe { SetCursor(state.current_cursor) }; - } - - Some(0) -} - -fn handle_set_cursor( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if unsafe { !IsWindowEnabled(handle).as_bool() } - || matches!( - lparam.loword() as u32, - HTLEFT - | HTRIGHT - | HTTOP - | HTTOPLEFT - | HTTOPRIGHT - | HTBOTTOM - | HTBOTTOMLEFT - | HTBOTTOMRIGHT - ) - { - return None; - } - unsafe { - SetCursor(state_ptr.state.borrow().current_cursor); - }; - Some(1) -} - -fn handle_system_settings_changed( - handle: HWND, - wparam: WPARAM, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - if wparam.0 != 0 { - let mut lock = state_ptr.state.borrow_mut(); - let display = lock.display; - lock.system_settings.update(display, wparam.0); - lock.click_state.system_update(wparam.0); - lock.border_offset.update(handle).log_err(); - } else { - handle_system_theme_changed(handle, lparam, state_ptr)?; - }; - // Force to trigger WM_NCCALCSIZE event to ensure that we handle auto hide - // taskbar correctly. 
- notify_frame_changed(handle); - - Some(0) -} - -fn handle_system_command(wparam: WPARAM, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> { - if wparam.0 == SC_KEYMENU as usize { - let mut lock = state_ptr.state.borrow_mut(); - if lock.system_key_handled { - lock.system_key_handled = false; - return Some(0); - } - } - None -} - -fn handle_system_theme_changed( - handle: HWND, - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - // lParam is a pointer to a string that indicates the area containing the system parameter - // that was changed. - let parameter = PCWSTR::from_raw(lparam.0 as _); - if unsafe { !parameter.is_null() && !parameter.is_empty() } { - if let Some(parameter_string) = unsafe { parameter.to_string() }.log_err() { - log::info!("System settings changed: {}", parameter_string); - match parameter_string.as_str() { - "ImmersiveColorSet" => { - let new_appearance = system_appearance() - .context("unable to get system appearance when handling ImmersiveColorSet") - .log_err()?; - let mut lock = state_ptr.state.borrow_mut(); - if new_appearance != lock.appearance { - lock.appearance = new_appearance; - let mut callback = lock.callbacks.appearance_changed.take()?; - drop(lock); - callback(); - state_ptr.state.borrow_mut().callbacks.appearance_changed = Some(callback); - configure_dwm_dark_mode(handle, new_appearance); - } - } - _ => {} - } - } - } - Some(0) -} - -fn handle_input_language_changed( - lparam: LPARAM, - state_ptr: Rc<WindowsWindowStatePtr>, -) -> Option<isize> { - let thread = state_ptr.main_thread_id_win32; - let validation = state_ptr.validation_number; - unsafe { - PostThreadMessageW(thread, WM_INPUTLANGCHANGE, WPARAM(validation), lparam).log_err(); - } - Some(0) -} - -#[inline] -fn parse_char_message(wparam: WPARAM, state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<String> { - let code_point = wparam.loword(); - let mut lock = state_ptr.state.borrow_mut(); - // https://www.unicode.org/versions/Unicode16.0.0/core-spec/chapter-3/#G2630 - match code_point { - 0xD800..=0xDBFF => { - // High surrogate, wait for low surrogate - lock.pending_surrogate = Some(code_point); + // Since these are handled in handle_nc_mouse_up_msg we must prevent the default window proc + if button == MouseButton::Left { + match wparam.0 as u32 { + HTMINBUTTON => self.state.borrow_mut().nc_button_pressed = Some(HTMINBUTTON), + HTMAXBUTTON => self.state.borrow_mut().nc_button_pressed = Some(HTMAXBUTTON), + HTCLOSE => self.state.borrow_mut().nc_button_pressed = Some(HTCLOSE), + _ => return None, + }; + Some(0) + } else { None } - 0xDC00..=0xDFFF => { - if let Some(high_surrogate) = lock.pending_surrogate.take() { - // Low surrogate, combine with pending high surrogate - String::from_utf16(&[high_surrogate, code_point]).ok() - } else { - // Invalid low surrogate without a preceding high surrogate - log::warn!( - "Received low surrogate without a preceding high surrogate: {code_point:x}" - ); - None + } + + fn handle_nc_mouse_up_msg( + &self, + handle: HWND, + button: MouseButton, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + let mut lock = self.state.borrow_mut(); + if let Some(mut func) = lock.callbacks.input.take() { + let scale_factor = lock.scale_factor; + drop(lock); + + let mut cursor_point = POINT { + x: lparam.signed_loword().into(), + y: lparam.signed_hiword().into(), + }; + unsafe { ScreenToClient(handle, &mut cursor_point).ok().log_err() }; + let input = PlatformInput::MouseUp(MouseUpEvent { + button, + position: 
logical_point(cursor_point.x as f32, cursor_point.y as f32, scale_factor), + modifiers: current_modifiers(), + click_count: 1, + }); + let handled = !func(input).propagate; + self.state.borrow_mut().callbacks.input = Some(func); + + if handled { + return Some(0); + } + } else { + drop(lock); + } + + let last_pressed = self.state.borrow_mut().nc_button_pressed.take(); + if button == MouseButton::Left + && let Some(last_pressed) = last_pressed + { + let handled = match (wparam.0 as u32, last_pressed) { + (HTMINBUTTON, HTMINBUTTON) => { + unsafe { ShowWindowAsync(handle, SW_MINIMIZE).ok().log_err() }; + true + } + (HTMAXBUTTON, HTMAXBUTTON) => { + if self.state.borrow().is_maximized() { + unsafe { ShowWindowAsync(handle, SW_NORMAL).ok().log_err() }; + } else { + unsafe { ShowWindowAsync(handle, SW_MAXIMIZE).ok().log_err() }; + } + true + } + (HTCLOSE, HTCLOSE) => { + unsafe { + PostMessageW(Some(handle), WM_CLOSE, WPARAM::default(), LPARAM::default()) + .log_err() + }; + true + } + _ => false, + }; + if handled { + return Some(0); } } - _ => { - lock.pending_surrogate = None; - char::from_u32(code_point as u32) - .filter(|c| !c.is_control()) - .map(|c| c.to_string()) + + None + } + + fn handle_cursor_changed(&self, lparam: LPARAM) -> Option<isize> { + let mut state = self.state.borrow_mut(); + let had_cursor = state.current_cursor.is_some(); + + state.current_cursor = if lparam.0 == 0 { + None + } else { + Some(HCURSOR(lparam.0 as _)) + }; + + if had_cursor != state.current_cursor.is_some() { + unsafe { SetCursor(state.current_cursor) }; } + + Some(0) + } + + fn handle_set_cursor(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + if unsafe { !IsWindowEnabled(handle).as_bool() } + || matches!( + lparam.loword() as u32, + HTLEFT + | HTRIGHT + | HTTOP + | HTTOPLEFT + | HTTOPRIGHT + | HTBOTTOM + | HTBOTTOMLEFT + | HTBOTTOMRIGHT + ) + { + return None; + } + unsafe { + SetCursor(self.state.borrow().current_cursor); + }; + Some(1) + } + + fn handle_system_settings_changed( + &self, + handle: HWND, + wparam: WPARAM, + lparam: LPARAM, + ) -> Option<isize> { + if wparam.0 != 0 { + let mut lock = self.state.borrow_mut(); + let display = lock.display; + lock.system_settings.update(display, wparam.0); + lock.click_state.system_update(wparam.0); + lock.border_offset.update(handle).log_err(); + } else { + self.handle_system_theme_changed(handle, lparam)?; + }; + // Force to trigger WM_NCCALCSIZE event to ensure that we handle auto hide + // taskbar correctly. + notify_frame_changed(handle); + + Some(0) + } + + fn handle_system_command(&self, wparam: WPARAM) -> Option<isize> { + if wparam.0 == SC_KEYMENU as usize { + let mut lock = self.state.borrow_mut(); + if lock.system_key_handled { + lock.system_key_handled = false; + return Some(0); + } + } + None + } + + fn handle_system_theme_changed(&self, handle: HWND, lparam: LPARAM) -> Option<isize> { + // lParam is a pointer to a string that indicates the area containing the system parameter + // that was changed. 
+ let parameter = PCWSTR::from_raw(lparam.0 as _); + if unsafe { !parameter.is_null() && !parameter.is_empty() } { + if let Some(parameter_string) = unsafe { parameter.to_string() }.log_err() { + log::info!("System settings changed: {}", parameter_string); + match parameter_string.as_str() { + "ImmersiveColorSet" => { + let new_appearance = system_appearance() + .context( + "unable to get system appearance when handling ImmersiveColorSet", + ) + .log_err()?; + let mut lock = self.state.borrow_mut(); + if new_appearance != lock.appearance { + lock.appearance = new_appearance; + let mut callback = lock.callbacks.appearance_changed.take()?; + drop(lock); + callback(); + self.state.borrow_mut().callbacks.appearance_changed = Some(callback); + configure_dwm_dark_mode(handle, new_appearance); + } + } + _ => {} + } + } + } + Some(0) + } + + fn handle_input_language_changed(&self, lparam: LPARAM) -> Option<isize> { + let thread = self.main_thread_id_win32; + let validation = self.validation_number; + unsafe { + PostThreadMessageW(thread, WM_INPUTLANGCHANGE, WPARAM(validation), lparam).log_err(); + } + Some(0) + } + + fn handle_device_change_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> { + if wparam.0 == DBT_DEVNODES_CHANGED as usize { + // The reason for sending this message is to actually trigger a redraw of the window. + unsafe { + PostMessageW( + Some(handle), + WM_GPUI_FORCE_UPDATE_WINDOW, + WPARAM(0), + LPARAM(0), + ) + .log_err(); + } + // If the GPU device is lost, this redraw will take care of recreating the device context. + // The WM_GPUI_FORCE_UPDATE_WINDOW message will take care of redrawing the window, after + // the device context has been recreated. + self.draw_window(handle, true) + } else { + // Other device change messages are not handled. 
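Editor's aside on the surrogate handling in `parse_char_message` (removed above as a free function and re-added as a method just below), not part of the patch: characters outside the Basic Multilingual Plane arrive as two `WM_CHAR` messages carrying a UTF-16 surrogate pair, so the first code unit is stashed in `pending_surrogate` and decoded when the second arrives. A worked example:

```rust
// Standalone sketch of the surrogate-pair decoding performed by parse_char_message.
fn main() {
    // U+1F600 GRINNING FACE is encoded in UTF-16 as the pair D83D DE00.
    let high: u16 = 0xD83D; // delivered by the first WM_CHAR message
    let low: u16 = 0xDE00; // delivered by the second WM_CHAR message
    assert_eq!(String::from_utf16(&[high, low]).unwrap(), "😀");

    // A lone low surrogate cannot be decoded, which is why the handler logs a
    // warning and produces no text in that case.
    assert!(String::from_utf16(&[low]).is_err());
}
```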
+ None + } + } + + #[inline] + fn draw_window(&self, handle: HWND, force_render: bool) -> Option<isize> { + let mut request_frame = self.state.borrow_mut().callbacks.request_frame.take()?; + request_frame(RequestFrameOptions { + require_presentation: false, + force_render, + }); + self.state.borrow_mut().callbacks.request_frame = Some(request_frame); + unsafe { ValidateRect(Some(handle), None).ok().log_err() }; + Some(0) + } + + #[inline] + fn parse_char_message(&self, wparam: WPARAM) -> Option<String> { + let code_point = wparam.loword(); + let mut lock = self.state.borrow_mut(); + // https://www.unicode.org/versions/Unicode16.0.0/core-spec/chapter-3/#G2630 + match code_point { + 0xD800..=0xDBFF => { + // High surrogate, wait for low surrogate + lock.pending_surrogate = Some(code_point); + None + } + 0xDC00..=0xDFFF => { + if let Some(high_surrogate) = lock.pending_surrogate.take() { + // Low surrogate, combine with pending high surrogate + String::from_utf16(&[high_surrogate, code_point]).ok() + } else { + // Invalid low surrogate without a preceding high surrogate + log::warn!( + "Received low surrogate without a preceding high surrogate: {code_point:x}" + ); + None + } + } + _ => { + lock.pending_surrogate = None; + char::from_u32(code_point as u32) + .filter(|c| !c.is_control()) + .map(|c| c.to_string()) + } + } + } + + fn start_tracking_mouse(&self, handle: HWND, flags: TRACKMOUSEEVENT_FLAGS) { + let mut lock = self.state.borrow_mut(); + if !lock.hovered { + lock.hovered = true; + unsafe { + TrackMouseEvent(&mut TRACKMOUSEEVENT { + cbSize: std::mem::size_of::<TRACKMOUSEEVENT>() as u32, + dwFlags: flags, + hwndTrack: handle, + dwHoverTime: HOVER_DEFAULT, + }) + .log_err() + }; + if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { + drop(lock); + callback(true); + self.state.borrow_mut().callbacks.hovered_status_change = Some(callback); + } + } + } + + fn with_input_handler<F, R>(&self, f: F) -> Option<R> + where + F: FnOnce(&mut PlatformInputHandler) -> R, + { + let mut input_handler = self.state.borrow_mut().input_handler.take()?; + let result = f(&mut input_handler); + self.state.borrow_mut().input_handler = Some(input_handler); + Some(result) + } + + fn with_input_handler_and_scale_factor<F, R>(&self, f: F) -> Option<R> + where + F: FnOnce(&mut PlatformInputHandler, f32) -> Option<R>, + { + let mut lock = self.state.borrow_mut(); + let mut input_handler = lock.input_handler.take()?; + let scale_factor = lock.scale_factor; + drop(lock); + let result = f(&mut input_handler, scale_factor); + self.state.borrow_mut().input_handler = Some(input_handler); + result } } @@ -1521,54 +1537,3 @@ fn notify_frame_changed(handle: HWND) { .log_err(); } } - -fn start_tracking_mouse( - handle: HWND, - state_ptr: &Rc<WindowsWindowStatePtr>, - flags: TRACKMOUSEEVENT_FLAGS, -) { - let mut lock = state_ptr.state.borrow_mut(); - if !lock.hovered { - lock.hovered = true; - unsafe { - TrackMouseEvent(&mut TRACKMOUSEEVENT { - cbSize: std::mem::size_of::<TRACKMOUSEEVENT>() as u32, - dwFlags: flags, - hwndTrack: handle, - dwHoverTime: HOVER_DEFAULT, - }) - .log_err() - }; - if let Some(mut callback) = lock.callbacks.hovered_status_change.take() { - drop(lock); - callback(true); - state_ptr.state.borrow_mut().callbacks.hovered_status_change = Some(callback); - } - } -} - -fn with_input_handler<F, R>(state_ptr: &Rc<WindowsWindowStatePtr>, f: F) -> Option<R> -where - F: FnOnce(&mut PlatformInputHandler) -> R, -{ - let mut input_handler = state_ptr.state.borrow_mut().input_handler.take()?; - 
let result = f(&mut input_handler); - state_ptr.state.borrow_mut().input_handler = Some(input_handler); - Some(result) -} - -fn with_input_handler_and_scale_factor<F, R>( - state_ptr: &Rc<WindowsWindowStatePtr>, - f: F, -) -> Option<R> -where - F: FnOnce(&mut PlatformInputHandler, f32) -> Option<R>, -{ - let mut lock = state_ptr.state.borrow_mut(); - let mut input_handler = lock.input_handler.take()?; - let scale_factor = lock.scale_factor; - drop(lock); - let result = f(&mut input_handler, scale_factor); - state_ptr.state.borrow_mut().input_handler = Some(input_handler); - result -} diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index 401ecdeffe..01b043a755 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -28,13 +28,12 @@ use windows::{ core::*, }; -use crate::{platform::blade::BladeContext, *}; +use crate::*; pub(crate) struct WindowsPlatform { state: RefCell<WindowsPlatformState>, raw_window_handles: RwLock<SmallVec<[HWND; 4]>>, // The below members will never change throughout the entire lifecycle of the app. - gpu_context: BladeContext, icon: HICON, main_receiver: flume::Receiver<Runnable>, background_executor: BackgroundExecutor, @@ -45,6 +44,7 @@ pub(crate) struct WindowsPlatform { drop_target_helper: IDropTargetHelper, validation_number: usize, main_thread_id_win32: u32, + disable_direct_composition: bool, } pub(crate) struct WindowsPlatformState { @@ -94,14 +94,18 @@ impl WindowsPlatform { main_thread_id_win32, validation_number, )); + let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION) + .is_ok_and(|value| value == "true" || value == "1"); let background_executor = BackgroundExecutor::new(dispatcher.clone()); let foreground_executor = ForegroundExecutor::new(dispatcher); + let directx_devices = DirectXDevices::new(disable_direct_composition) + .context("Unable to init directx devices.")?; let bitmap_factory = ManuallyDrop::new(unsafe { CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER) .context("Error creating bitmap factory.")? 
}); let text_system = Arc::new( - DirectWriteTextSystem::new(&bitmap_factory) + DirectWriteTextSystem::new(&directx_devices, &bitmap_factory) .context("Error creating DirectWriteTextSystem")?, ); let drop_target_helper: IDropTargetHelper = unsafe { @@ -111,18 +115,17 @@ impl WindowsPlatform { let icon = load_icon().unwrap_or_default(); let state = RefCell::new(WindowsPlatformState::new()); let raw_window_handles = RwLock::new(SmallVec::new()); - let gpu_context = BladeContext::new().context("Unable to init GPU context")?; let windows_version = WindowsVersion::new().context("Error retrieve windows version")?; Ok(Self { state, raw_window_handles, - gpu_context, icon, main_receiver, background_executor, foreground_executor, text_system, + disable_direct_composition, windows_version, bitmap_factory, drop_target_helper, @@ -141,12 +144,12 @@ impl WindowsPlatform { } } - pub fn try_get_windows_inner_from_hwnd(&self, hwnd: HWND) -> Option<Rc<WindowsWindowStatePtr>> { + pub fn window_from_hwnd(&self, hwnd: HWND) -> Option<Rc<WindowsWindowInner>> { self.raw_window_handles .read() .iter() .find(|entry| *entry == &hwnd) - .and_then(|hwnd| try_get_window_inner(*hwnd)) + .and_then(|hwnd| window_from_hwnd(*hwnd)) } #[inline] @@ -187,6 +190,7 @@ impl WindowsPlatform { validation_number: self.validation_number, main_receiver: self.main_receiver.clone(), main_thread_id_win32: self.main_thread_id_win32, + disable_direct_composition: self.disable_direct_composition, } } @@ -343,27 +347,11 @@ impl Platform for WindowsPlatform { fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>) { on_finish_launching(); - let vsync_event = unsafe { Owned::new(CreateEventW(None, false, false, None).unwrap()) }; - begin_vsync(*vsync_event); - 'a: loop { - let wait_result = unsafe { - MsgWaitForMultipleObjects(Some(&[*vsync_event]), false, INFINITE, QS_ALLINPUT) - }; - - match wait_result { - // compositor clock ticked so we should draw a frame - WAIT_EVENT(0) => self.redraw_all(), - // Windows thread messages are posted - WAIT_EVENT(1) => { - if self.handle_events() { - break 'a; - } - } - _ => { - log::error!("Something went wrong while waiting {:?}", wait_result); - break; - } + loop { + if self.handle_events() { + break; } + self.redraw_all(); } if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit { @@ -446,7 +434,7 @@ impl Platform for WindowsPlatform { fn active_window(&self) -> Option<AnyWindowHandle> { let active_window_hwnd = unsafe { GetActiveWindow() }; - self.try_get_windows_inner_from_hwnd(active_window_hwnd) + self.window_from_hwnd(active_window_hwnd) .map(|inner| inner.handle) } @@ -455,12 +443,7 @@ impl Platform for WindowsPlatform { handle: AnyWindowHandle, options: WindowParams, ) -> Result<Box<dyn PlatformWindow>> { - let window = WindowsWindow::new( - handle, - options, - self.generate_creation_info(), - &self.gpu_context, - )?; + let window = WindowsWindow::new(handle, options, self.generate_creation_info())?; let handle = window.get_raw_handle(); self.raw_window_handles.write().push(handle); @@ -739,6 +722,7 @@ pub(crate) struct WindowCreationInfo { pub(crate) validation_number: usize, pub(crate) main_receiver: flume::Receiver<Runnable>, pub(crate) main_thread_id_win32: u32, + pub(crate) disable_direct_composition: bool, } fn open_target(target: &str) { @@ -846,16 +830,6 @@ fn file_save_dialog(directory: PathBuf, window: Option<HWND>) -> Result<Option<P Ok(Some(PathBuf::from(file_path_string))) } -fn begin_vsync(vsync_event: HANDLE) { - let event: SafeHandle = 
vsync_event.into(); - std::thread::spawn(move || unsafe { - loop { - windows::Win32::Graphics::Dwm::DwmFlush().log_err(); - SetEvent(*event).log_err(); - } - }); -} - fn load_icon() -> Result<HICON> { let module = unsafe { GetModuleHandleW(None).context("unable to get module handle")? }; let handle = unsafe { diff --git a/crates/gpui/src/platform/windows/shaders.hlsl b/crates/gpui/src/platform/windows/shaders.hlsl new file mode 100644 index 0000000000..25830e4b6c --- /dev/null +++ b/crates/gpui/src/platform/windows/shaders.hlsl @@ -0,0 +1,1159 @@ +cbuffer GlobalParams: register(b0) { + float2 global_viewport_size; + uint2 _pad; +}; + +Texture2D<float4> t_sprite: register(t0); +SamplerState s_sprite: register(s0); + +struct Bounds { + float2 origin; + float2 size; +}; + +struct Corners { + float top_left; + float top_right; + float bottom_right; + float bottom_left; +}; + +struct Edges { + float top; + float right; + float bottom; + float left; +}; + +struct Hsla { + float h; + float s; + float l; + float a; +}; + +struct LinearColorStop { + Hsla color; + float percentage; +}; + +struct Background { + // 0u is Solid + // 1u is LinearGradient + // 2u is PatternSlash + uint tag; + // 0u is sRGB linear color + // 1u is Oklab color + uint color_space; + Hsla solid; + float gradient_angle_or_pattern_height; + LinearColorStop colors[2]; + uint pad; +}; + +struct GradientColor { + float4 solid; + float4 color0; + float4 color1; +}; + +struct AtlasTextureId { + uint index; + uint kind; +}; + +struct AtlasBounds { + int2 origin; + int2 size; +}; + +struct AtlasTile { + AtlasTextureId texture_id; + uint tile_id; + uint padding; + AtlasBounds bounds; +}; + +struct TransformationMatrix { + float2x2 rotation_scale; + float2 translation; +}; + +static const float M_PI_F = 3.141592653f; +static const float3 GRAYSCALE_FACTORS = float3(0.2126f, 0.7152f, 0.0722f); + +float4 to_device_position_impl(float2 position) { + float2 device_position = position / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); + return float4(device_position, 0., 1.); +} + +float4 to_device_position(float2 unit_vertex, Bounds bounds) { + float2 position = unit_vertex * bounds.size + bounds.origin; + return to_device_position_impl(position); +} + +float4 distance_from_clip_rect_impl(float2 position, Bounds clip_bounds) { + float2 tl = position - clip_bounds.origin; + float2 br = clip_bounds.origin + clip_bounds.size - position; + return float4(tl.x, br.x, tl.y, br.y); +} + +float4 distance_from_clip_rect(float2 unit_vertex, Bounds bounds, Bounds clip_bounds) { + float2 position = unit_vertex * bounds.size + bounds.origin; + return distance_from_clip_rect_impl(position, clip_bounds); +} + +// Convert linear RGB to sRGB +float3 linear_to_srgb(float3 color) { + return pow(color, float3(2.2, 2.2, 2.2)); +} + +// Convert sRGB to linear RGB +float3 srgb_to_linear(float3 color) { + return pow(color, float3(1.0 / 2.2, 1.0 / 2.2, 1.0 / 2.2)); +} + +/// Hsla to linear RGBA conversion. 
+float4 hsla_to_rgba(Hsla hsla) { + float h = hsla.h * 6.0; // Now, it's an angle but scaled in [0, 6) range + float s = hsla.s; + float l = hsla.l; + float a = hsla.a; + + float c = (1.0 - abs(2.0 * l - 1.0)) * s; + float x = c * (1.0 - abs(fmod(h, 2.0) - 1.0)); + float m = l - c / 2.0; + + float r = 0.0; + float g = 0.0; + float b = 0.0; + + if (h >= 0.0 && h < 1.0) { + r = c; + g = x; + b = 0.0; + } else if (h >= 1.0 && h < 2.0) { + r = x; + g = c; + b = 0.0; + } else if (h >= 2.0 && h < 3.0) { + r = 0.0; + g = c; + b = x; + } else if (h >= 3.0 && h < 4.0) { + r = 0.0; + g = x; + b = c; + } else if (h >= 4.0 && h < 5.0) { + r = x; + g = 0.0; + b = c; + } else { + r = c; + g = 0.0; + b = x; + } + + float4 rgba; + rgba.x = (r + m); + rgba.y = (g + m); + rgba.z = (b + m); + rgba.w = a; + return rgba; +} + +// Converts a sRGB color to the Oklab color space. +// Reference: https://bottosson.github.io/posts/oklab/#converting-from-linear-srgb-to-oklab +float4 srgb_to_oklab(float4 color) { + // Convert non-linear sRGB to linear sRGB + color = float4(srgb_to_linear(color.rgb), color.a); + + float l = 0.4122214708 * color.r + 0.5363325363 * color.g + 0.0514459929 * color.b; + float m = 0.2119034982 * color.r + 0.6806995451 * color.g + 0.1073969566 * color.b; + float s = 0.0883024619 * color.r + 0.2817188376 * color.g + 0.6299787005 * color.b; + + float l_ = pow(l, 1.0/3.0); + float m_ = pow(m, 1.0/3.0); + float s_ = pow(s, 1.0/3.0); + + return float4( + 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, + 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, + 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, + color.a + ); +} + +// Converts an Oklab color to the sRGB color space. +float4 oklab_to_srgb(float4 color) { + float l_ = color.r + 0.3963377774 * color.g + 0.2158037573 * color.b; + float m_ = color.r - 0.1055613458 * color.g - 0.0638541728 * color.b; + float s_ = color.r - 0.0894841775 * color.g - 1.2914855480 * color.b; + + float l = l_ * l_ * l_; + float m = m_ * m_ * m_; + float s = s_ * s_ * s_; + + float3 linear_rgb = float3( + 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s, + -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s, + -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s + ); + + // Convert linear sRGB to non-linear sRGB + return float4(linear_to_srgb(linear_rgb), color.a); +} + +// This approximates the error function, needed for the gaussian integral +float2 erf(float2 x) { + float2 s = sign(x); + float2 a = abs(x); + x = 1. + (0.278393 + (0.230389 + 0.078108 * (a * a)) * a) * a; + x *= x; + return s - s / (x * x); +} + +float blur_along_x(float x, float y, float sigma, float corner, float2 half_size) { + float delta = min(half_size.y - corner - abs(y), 0.); + float curved = half_size.x - corner + sqrt(max(0., corner * corner - delta * delta)); + float2 integral = 0.5 + 0.5 * erf((x + float2(-curved, curved)) * (sqrt(0.5) / sigma)); + return integral.y - integral.x; +} + +// A standard gaussian function, used for weighting samples +float gaussian(float x, float sigma) { + return exp(-(x * x) / (2. * sigma * sigma)) / (sqrt(2. 
* M_PI_F) * sigma); +} + +float4 over(float4 below, float4 above) { + float4 result; + float alpha = above.a + below.a * (1.0 - above.a); + result.rgb = (above.rgb * above.a + below.rgb * below.a * (1.0 - above.a)) / alpha; + result.a = alpha; + return result; +} + +float2 to_tile_position(float2 unit_vertex, AtlasTile tile) { + float2 atlas_size; + t_sprite.GetDimensions(atlas_size.x, atlas_size.y); + return (float2(tile.bounds.origin) + unit_vertex * float2(tile.bounds.size)) / atlas_size; +} + +// Selects corner radius based on quadrant. +float pick_corner_radius(float2 center_to_point, Corners corner_radii) { + if (center_to_point.x < 0.) { + if (center_to_point.y < 0.) { + return corner_radii.top_left; + } else { + return corner_radii.bottom_left; + } + } else { + if (center_to_point.y < 0.) { + return corner_radii.top_right; + } else { + return corner_radii.bottom_right; + } + } +} + +float4 to_device_position_transformed(float2 unit_vertex, Bounds bounds, + TransformationMatrix transformation) { + float2 position = unit_vertex * bounds.size + bounds.origin; + float2 transformed = mul(position, transformation.rotation_scale) + transformation.translation; + float2 device_position = transformed / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); + return float4(device_position, 0.0, 1.0); +} + +// Implementation of quad signed distance field +float quad_sdf_impl(float2 corner_center_to_point, float corner_radius) { + if (corner_radius == 0.0) { + // Fast path for unrounded corners + return max(corner_center_to_point.x, corner_center_to_point.y); + } else { + // Signed distance of the point from a quad that is inset by corner_radius + // It is negative inside this quad, and positive outside + float signed_distance_to_inset_quad = + // 0 inside the inset quad, and positive outside + length(max(float2(0.0, 0.0), corner_center_to_point)) + + // 0 outside the inset quad, and negative inside + min(0.0, max(corner_center_to_point.x, corner_center_to_point.y)); + + return signed_distance_to_inset_quad - corner_radius; + } +} + +float quad_sdf(float2 pt, Bounds bounds, Corners corner_radii) { + float2 half_size = bounds.size / 2.; + float2 center = bounds.origin + half_size; + float2 center_to_point = pt - center; + float corner_radius = pick_corner_radius(center_to_point, corner_radii); + float2 corner_to_point = abs(center_to_point) - half_size; + float2 corner_center_to_point = corner_to_point + corner_radius; + return quad_sdf_impl(corner_center_to_point, corner_radius); +} + +GradientColor prepare_gradient_color(uint tag, uint color_space, Hsla solid, LinearColorStop colors[2]) { + GradientColor output; + if (tag == 0 || tag == 2) { + output.solid = hsla_to_rgba(solid); + } else if (tag == 1) { + output.color0 = hsla_to_rgba(colors[0].color); + output.color1 = hsla_to_rgba(colors[1].color); + + // Prepare color space in vertex for avoid conversion + // in fragment shader for performance reasons + if (color_space == 1) { + // Oklab + output.color0 = srgb_to_oklab(output.color0); + output.color1 = srgb_to_oklab(output.color1); + } + } + + return output; +} + +float2x2 rotate2d(float angle) { + float s = sin(angle); + float c = cos(angle); + return float2x2(c, -s, s, c); +} + +float4 gradient_color(Background background, + float2 position, + Bounds bounds, + float4 solid_color, float4 color0, float4 color1) { + float4 color; + + switch (background.tag) { + case 0: + color = solid_color; + break; + case 1: { + // -90 degrees to match the CSS gradient angle. 
+ float gradient_angle = background.gradient_angle_or_pattern_height;
+ float radians = (fmod(gradient_angle, 360.0) - 90.0) * (M_PI_F / 180.0);
+ float2 direction = float2(cos(radians), sin(radians));
+
+ // Expand the short side to be the same as the long side
+ if (bounds.size.x > bounds.size.y) {
+ direction.y *= bounds.size.y / bounds.size.x;
+ } else {
+ direction.x *= bounds.size.x / bounds.size.y;
+ }
+
+ // Get the t value for the linear gradient with the color stop percentages.
+ float2 half_size = bounds.size * 0.5;
+ float2 center = bounds.origin + half_size;
+ float2 center_to_point = position - center;
+ float t = dot(center_to_point, direction) / length(direction);
+ // Check the direction to determine whether to use x or y
+ if (abs(direction.x) > abs(direction.y)) {
+ t = (t + half_size.x) / bounds.size.x;
+ } else {
+ t = (t + half_size.y) / bounds.size.y;
+ }
+
+ // Adjust t based on the stop percentages
+ t = (t - background.colors[0].percentage)
+ / (background.colors[1].percentage
+ - background.colors[0].percentage);
+ t = clamp(t, 0.0, 1.0);
+
+ switch (background.color_space) {
+ case 0:
+ color = lerp(color0, color1, t);
+ break;
+ case 1: {
+ float4 oklab_color = lerp(color0, color1, t);
+ color = oklab_to_srgb(oklab_color);
+ break;
+ }
+ }
+ break;
+ }
+ case 2: {
+ float gradient_angle_or_pattern_height = background.gradient_angle_or_pattern_height;
+ float pattern_width = (gradient_angle_or_pattern_height / 65535.0f) / 255.0f;
+ float pattern_interval = fmod(gradient_angle_or_pattern_height, 65535.0f) / 255.0f;
+ float pattern_height = pattern_width + pattern_interval;
+ float stripe_angle = M_PI_F / 4.0;
+ float pattern_period = pattern_height * sin(stripe_angle);
+ float2x2 rotation = rotate2d(stripe_angle);
+ float2 relative_position = position - bounds.origin;
+ float2 rotated_point = mul(rotation, relative_position);
+ float pattern = fmod(rotated_point.x, pattern_period);
+ float distance = min(pattern, pattern_period - pattern) - pattern_period * (pattern_width / pattern_height) / 2.0f;
+ color = solid_color;
+ color.a *= saturate(0.5 - distance);
+ break;
+ }
+ }
+
+ return color;
+}
+
+// Returns the dash velocity of a corner given the dash velocity of the two
+// sides, by returning the slower velocity (larger dashes).
+//
+// Since 0 is used for dash velocity when the border width is 0 (instead of
+// +inf), this returns the other dash velocity in that case.
+//
+// An alternative to this might be to appropriately interpolate the dash
+// velocity around the corner, but that seems overcomplicated.
+float corner_dash_velocity(float dv1, float dv2) {
+ if (dv1 == 0.0) {
+ return dv2;
+ } else if (dv2 == 0.0) {
+ return dv1;
+ } else {
+ return min(dv1, dv2);
+ }
+}
+
+// Returns alpha used to render antialiased dashes.
+// `t` is within the dash when `fmod(t, period) < length`.
+float dash_alpha( + float t, float period, float length, float dash_velocity, + float antialias_threshold +) { + float half_period = period / 2.0; + float half_length = length / 2.0; + // Value in [-half_period, half_period] + // The dash is in [-half_length, half_length] + float centered = fmod(t + half_period - half_length, period) - half_period; + // Signed distance for the dash, negative values are inside the dash + float signed_distance = abs(centered) - half_length; + // Antialiased alpha based on the signed distance + return saturate(antialias_threshold - signed_distance / dash_velocity); +} + +// This approximates distance to the nearest point to a quarter ellipse in a way +// that is sufficient for anti-aliasing when the ellipse is not very eccentric. +// The components of `point` are expected to be positive. +// +// Negative on the outside and positive on the inside. +float quarter_ellipse_sdf(float2 pt, float2 radii) { + // Scale the space to treat the ellipse like a unit circle + float2 circle_vec = pt / radii; + float unit_circle_sdf = length(circle_vec) - 1.0; + // Approximate up-scaling of the length by using the average of the radii. + // + // TODO: A better solution would be to use the gradient of the implicit + // function for an ellipse to approximate a scaling factor. + return unit_circle_sdf * (radii.x + radii.y) * -0.5; +} + +/* +** +** Quads +** +*/ + +struct Quad { + uint order; + uint border_style; + Bounds bounds; + Bounds content_mask; + Background background; + Hsla border_color; + Corners corner_radii; + Edges border_widths; +}; + +struct QuadVertexOutput { + nointerpolation uint quad_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 border_color: COLOR0; + nointerpolation float4 background_solid: COLOR1; + nointerpolation float4 background_color0: COLOR2; + nointerpolation float4 background_color1: COLOR3; + float4 clip_distance: SV_ClipDistance; +}; + +struct QuadFragmentInput { + nointerpolation uint quad_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 border_color: COLOR0; + nointerpolation float4 background_solid: COLOR1; + nointerpolation float4 background_color0: COLOR2; + nointerpolation float4 background_color1: COLOR3; +}; + +StructuredBuffer<Quad> quads: register(t1); + +QuadVertexOutput quad_vertex(uint vertex_id: SV_VertexID, uint quad_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Quad quad = quads[quad_id]; + float4 device_position = to_device_position(unit_vertex, quad.bounds); + + GradientColor gradient = prepare_gradient_color( + quad.background.tag, + quad.background.color_space, + quad.background.solid, + quad.background.colors + ); + float4 clip_distance = distance_from_clip_rect(unit_vertex, quad.bounds, quad.content_mask); + float4 border_color = hsla_to_rgba(quad.border_color); + + QuadVertexOutput output; + output.position = device_position; + output.border_color = border_color; + output.quad_id = quad_id; + output.background_solid = gradient.solid; + output.background_color0 = gradient.color0; + output.background_color1 = gradient.color1; + output.clip_distance = clip_distance; + return output; +} + +float4 quad_fragment(QuadFragmentInput input): SV_Target { + Quad quad = quads[input.quad_id]; + float4 background_color = gradient_color(quad.background, input.position.xy, quad.bounds, + input.background_solid, input.background_color0, input.background_color1); + + bool unrounded = quad.corner_radii.top_left == 0.0 && + 
quad.corner_radii.top_right == 0.0 && + quad.corner_radii.bottom_left == 0.0 && + quad.corner_radii.bottom_right == 0.0; + + // Fast path when the quad is not rounded and doesn't have any border + if (quad.border_widths.top == 0.0 && + quad.border_widths.left == 0.0 && + quad.border_widths.right == 0.0 && + quad.border_widths.bottom == 0.0 && + unrounded) { + return background_color; + } + + float2 size = quad.bounds.size; + float2 half_size = size / 2.; + float2 the_point = input.position.xy - quad.bounds.origin; + float2 center_to_point = the_point - half_size; + + // Signed distance field threshold for inclusion of pixels. 0.5 is the + // minimum distance between the center of the pixel and the edge. + const float antialias_threshold = 0.5; + + // Radius of the nearest corner + float corner_radius = pick_corner_radius(center_to_point, quad.corner_radii); + + float2 border = float2( + center_to_point.x < 0.0 ? quad.border_widths.left : quad.border_widths.right, + center_to_point.y < 0.0 ? quad.border_widths.top : quad.border_widths.bottom + ); + + // 0-width borders are reduced so that `inner_sdf >= antialias_threshold`. + // The purpose of this is to not draw antialiasing pixels in this case. + float2 reduced_border = float2( + border.x == 0.0 ? -antialias_threshold : border.x, + border.y == 0.0 ? -antialias_threshold : border.y + ); + + // Vector from the corner of the quad bounds to the point, after mirroring + // the point into the bottom right quadrant. Both components are <= 0. + float2 corner_to_point = abs(center_to_point) - half_size; + + // Vector from the point to the center of the rounded corner's circle, also + // mirrored into bottom right quadrant. + float2 corner_center_to_point = corner_to_point + corner_radius; + + // Whether the nearest point on the border is rounded + bool is_near_rounded_corner = + corner_center_to_point.x >= 0.0 && + corner_center_to_point.y >= 0.0; + + // Vector from straight border inner corner to point. + // + // 0-width borders are turned into width -1 so that inner_sdf is > 1.0 near + // the border. Without this, antialiasing pixels would be drawn. + float2 straight_border_inner_corner_to_point = corner_to_point + reduced_border; + + // Whether the point is beyond the inner edge of the straight border + bool is_beyond_inner_straight_border = + straight_border_inner_corner_to_point.x > 0.0 || + straight_border_inner_corner_to_point.y > 0.0; + + // Whether the point is far enough inside the quad, such that the pixels are + // not affected by the straight border. + bool is_within_inner_straight_border = + straight_border_inner_corner_to_point.x < -antialias_threshold && + straight_border_inner_corner_to_point.y < -antialias_threshold; + + // Fast path for points that must be part of the background + if (is_within_inner_straight_border && !is_near_rounded_corner) { + return background_color; + } + + // Signed distance of the point to the outside edge of the quad's border + float outer_sdf = quad_sdf_impl(corner_center_to_point, corner_radius); + + // Approximate signed distance of the point to the inside edge of the quad's + // border. It is negative outside this edge (within the border), and + // positive inside. + // + // This is not always an accurate signed distance: + // * The rounded portions with varying border width use an approximation of + // nearest-point-on-ellipse. + // * When it is quickly known to be outside the edge, -1.0 is used. 
+ float inner_sdf = 0.0; + if (corner_center_to_point.x <= 0.0 || corner_center_to_point.y <= 0.0) { + // Fast paths for straight borders + inner_sdf = -max(straight_border_inner_corner_to_point.x, + straight_border_inner_corner_to_point.y); + } else if (is_beyond_inner_straight_border) { + // Fast path for points that must be outside the inner edge + inner_sdf = -1.0; + } else if (reduced_border.x == reduced_border.y) { + // Fast path for circular inner edge. + inner_sdf = -(outer_sdf + reduced_border.x); + } else { + float2 ellipse_radii = max(float2(0.0, 0.0), float2(corner_radius, corner_radius) - reduced_border); + inner_sdf = quarter_ellipse_sdf(corner_center_to_point, ellipse_radii); + } + + // Negative when inside the border + float border_sdf = max(inner_sdf, outer_sdf); + + float4 color = background_color; + if (border_sdf < antialias_threshold) { + float4 border_color = input.border_color; + // Dashed border logic when border_style == 1 + if (quad.border_style == 1) { + // Position along the perimeter in "dash space", where each dash + // period has length 1 + float t = 0.0; + + // Total number of dash periods, so that the dash spacing can be + // adjusted to evenly divide it + float max_t = 0.0; + + // Border width is proportional to dash size. This is the behavior + // used by browsers, but also avoids dashes from different segments + // overlapping when dash size is smaller than the border width. + // + // Dash pattern: (2 * border width) dash, (1 * border width) gap + const float dash_length_per_width = 2.0; + const float dash_gap_per_width = 1.0; + const float dash_period_per_width = dash_length_per_width + dash_gap_per_width; + + // Since the dash size is determined by border width, the density of + // dashes varies. Multiplying a pixel distance by this returns a + // position in dash space - it has units (dash period / pixels). So + // a dash velocity of (1 / 10) is 1 dash every 10 pixels. + float dash_velocity = 0.0; + + // Dividing this by the border width gives the dash velocity + const float dv_numerator = 1.0 / dash_period_per_width; + + if (unrounded) { + // When corners aren't rounded, the dashes are separately laid + // out on each straight line, rather than around the whole + // perimeter. This way each line starts and ends with a dash. + bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; + float border_width = is_horizontal ? border.x : border.y; + dash_velocity = dv_numerator / border_width; + t = is_horizontal ? the_point.x : the_point.y; + t *= dash_velocity; + max_t = is_horizontal ? size.x : size.y; + max_t *= dash_velocity; + } else { + // When corners are rounded, the dashes are laid out clockwise + // around the whole perimeter. + + float r_tr = quad.corner_radii.top_right; + float r_br = quad.corner_radii.bottom_right; + float r_bl = quad.corner_radii.bottom_left; + float r_tl = quad.corner_radii.top_left; + + float w_t = quad.border_widths.top; + float w_r = quad.border_widths.right; + float w_b = quad.border_widths.bottom; + float w_l = quad.border_widths.left; + + // Straight side dash velocities + float dv_t = w_t <= 0.0 ? 0.0 : dv_numerator / w_t; + float dv_r = w_r <= 0.0 ? 0.0 : dv_numerator / w_r; + float dv_b = w_b <= 0.0 ? 0.0 : dv_numerator / w_b; + float dv_l = w_l <= 0.0 ? 
0.0 : dv_numerator / w_l;
+
+ // Straight side lengths in dash space
+ float s_t = (size.x - r_tl - r_tr) * dv_t;
+ float s_r = (size.y - r_tr - r_br) * dv_r;
+ float s_b = (size.x - r_br - r_bl) * dv_b;
+ float s_l = (size.y - r_bl - r_tl) * dv_l;
+
+ float corner_dash_velocity_tr = corner_dash_velocity(dv_t, dv_r);
+ float corner_dash_velocity_br = corner_dash_velocity(dv_b, dv_r);
+ float corner_dash_velocity_bl = corner_dash_velocity(dv_b, dv_l);
+ float corner_dash_velocity_tl = corner_dash_velocity(dv_t, dv_l);
+
+ // Corner lengths in dash space
+ float c_tr = r_tr * (M_PI_F / 2.0) * corner_dash_velocity_tr;
+ float c_br = r_br * (M_PI_F / 2.0) * corner_dash_velocity_br;
+ float c_bl = r_bl * (M_PI_F / 2.0) * corner_dash_velocity_bl;
+ float c_tl = r_tl * (M_PI_F / 2.0) * corner_dash_velocity_tl;
+
+ // Cumulative dash space up to each segment
+ float upto_tr = s_t;
+ float upto_r = upto_tr + c_tr;
+ float upto_br = upto_r + s_r;
+ float upto_b = upto_br + c_br;
+ float upto_bl = upto_b + s_b;
+ float upto_l = upto_bl + c_bl;
+ float upto_tl = upto_l + s_l;
+ max_t = upto_tl + c_tl;
+
+ if (is_near_rounded_corner) {
+ float radians = atan2(corner_center_to_point.y, corner_center_to_point.x);
+ float corner_t = radians * corner_radius;
+
+ if (center_to_point.x >= 0.0) {
+ if (center_to_point.y < 0.0) {
+ dash_velocity = corner_dash_velocity_tr;
+ // Subtracted because radians is pi/2 to 0 when
+ // going clockwise around the top right corner,
+ // since the y axis has been flipped
+ t = upto_r - corner_t * dash_velocity;
+ } else {
+ dash_velocity = corner_dash_velocity_br;
+ // Added because radians is 0 to pi/2 when going
+ // clockwise around the bottom-right corner
+ t = upto_br + corner_t * dash_velocity;
+ }
+ } else {
+ if (center_to_point.y >= 0.0) {
+ dash_velocity = corner_dash_velocity_bl;
+ // Subtracted because radians is pi/2 to 0 when
+ // going clockwise around the bottom-left corner,
+ // since the x axis has been flipped
+ t = upto_l - corner_t * dash_velocity;
+ } else {
+ dash_velocity = corner_dash_velocity_tl;
+ // Added because radians is 0 to pi/2 when going
+ // clockwise around the top-left corner, since both
+ // axes were flipped
+ t = upto_tl + corner_t * dash_velocity;
+ }
+ }
+ } else {
+ // Straight borders
+ bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y;
+ if (is_horizontal) {
+ if (center_to_point.y < 0.0) {
+ dash_velocity = dv_t;
+ t = (the_point.x - r_tl) * dash_velocity;
+ } else {
+ dash_velocity = dv_b;
+ t = upto_bl - (the_point.x - r_bl) * dash_velocity;
+ }
+ } else {
+ if (center_to_point.x < 0.0) {
+ dash_velocity = dv_l;
+ t = upto_tl - (the_point.y - r_tl) * dash_velocity;
+ } else {
+ dash_velocity = dv_r;
+ t = upto_r + (the_point.y - r_tr) * dash_velocity;
+ }
+ }
+ }
+ }
+ float dash_length = dash_length_per_width / dash_period_per_width;
+ float desired_dash_gap = dash_gap_per_width / dash_period_per_width;
+
+ // Straight borders should start and end with a dash, so max_t is
+ // reduced to cause this.
+ max_t -= unrounded ? dash_length : 0.0;
+ if (max_t >= 1.0) {
+ // Adjust dash gap to evenly divide max_t
+ float dash_count = floor(max_t);
+ float dash_period = max_t / dash_count;
+ border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold);
+ } else if (unrounded) {
+ // When there isn't enough space for the full gap between the
+ // two start / end dashes of a straight border, reduce gap to
+ // make them fit.
+ float dash_gap = max_t - dash_length; + if (dash_gap > 0.0) { + float dash_period = dash_length + dash_gap; + border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); + } + } + } + + // Blend the border on top of the background and then linearly interpolate + // between the two as we slide inside the background. + float4 blended_border = over(background_color, border_color); + color = lerp(background_color, blended_border, + saturate(antialias_threshold - inner_sdf)); + } + + return color * float4(1.0, 1.0, 1.0, saturate(antialias_threshold - outer_sdf)); +} + +/* +** +** Shadows +** +*/ + +struct Shadow { + uint order; + float blur_radius; + Bounds bounds; + Corners corner_radii; + Bounds content_mask; + Hsla color; +}; + +struct ShadowVertexOutput { + nointerpolation uint shadow_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct ShadowFragmentInput { + nointerpolation uint shadow_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; +}; + +StructuredBuffer<Shadow> shadows: register(t1); + +ShadowVertexOutput shadow_vertex(uint vertex_id: SV_VertexID, uint shadow_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Shadow shadow = shadows[shadow_id]; + + float margin = 3.0 * shadow.blur_radius; + Bounds bounds = shadow.bounds; + bounds.origin -= margin; + bounds.size += 2.0 * margin; + + float4 device_position = to_device_position(unit_vertex, bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, bounds, shadow.content_mask); + float4 color = hsla_to_rgba(shadow.color); + + ShadowVertexOutput output; + output.position = device_position; + output.color = color; + output.shadow_id = shadow_id; + output.clip_distance = clip_distance; + + return output; +} + +float4 shadow_fragment(ShadowFragmentInput input): SV_TARGET { + Shadow shadow = shadows[input.shadow_id]; + + float2 half_size = shadow.bounds.size / 2.; + float2 center = shadow.bounds.origin + half_size; + float2 point0 = input.position.xy - center; + float corner_radius = pick_corner_radius(point0, shadow.corner_radii); + + // The signal is only non-zero in a limited range, so don't waste samples + float low = point0.y - half_size.y; + float high = point0.y + half_size.y; + float start = clamp(-3. * shadow.blur_radius, low, high); + float end = clamp(3. 
* shadow.blur_radius, low, high); + + // Accumulate samples (we can get away with surprisingly few samples) + float step = (end - start) / 4.; + float y = start + step * 0.5; + float alpha = 0.; + for (int i = 0; i < 4; i++) { + alpha += blur_along_x(point0.x, point0.y - y, shadow.blur_radius, + corner_radius, half_size) * + gaussian(y, shadow.blur_radius) * step; + y += step; + } + + return input.color * float4(1., 1., 1., alpha); +} + +/* +** +** Path Rasterization +** +*/ + +struct PathRasterizationSprite { + float2 xy_position; + float2 st_position; + Background color; + Bounds bounds; +}; + +StructuredBuffer<PathRasterizationSprite> path_rasterization_sprites: register(t1); + +struct PathVertexOutput { + float4 position: SV_Position; + float2 st_position: TEXCOORD0; + nointerpolation uint vertex_id: TEXCOORD1; + float4 clip_distance: SV_ClipDistance; +}; + +struct PathFragmentInput { + float4 position: SV_Position; + float2 st_position: TEXCOORD0; + nointerpolation uint vertex_id: TEXCOORD1; +}; + +PathVertexOutput path_rasterization_vertex(uint vertex_id: SV_VertexID) { + PathRasterizationSprite sprite = path_rasterization_sprites[vertex_id]; + + PathVertexOutput output; + output.position = to_device_position_impl(sprite.xy_position); + output.st_position = sprite.st_position; + output.vertex_id = vertex_id; + output.clip_distance = distance_from_clip_rect_impl(sprite.xy_position, sprite.bounds); + + return output; +} + +float4 path_rasterization_fragment(PathFragmentInput input): SV_Target { + float2 dx = ddx(input.st_position); + float2 dy = ddy(input.st_position); + PathRasterizationSprite sprite = path_rasterization_sprites[input.vertex_id]; + + Background background = sprite.color; + Bounds bounds = sprite.bounds; + + float alpha; + if (length(float2(dx.x, dy.x))) { + alpha = 1.0; + } else { + float2 gradient = 2.0 * input.st_position.xx * float2(dx.x, dy.x) - float2(dx.y, dy.y); + float f = input.st_position.x * input.st_position.x - input.st_position.y; + float distance = f / length(gradient); + alpha = saturate(0.5 - distance); + } + + GradientColor gradient = prepare_gradient_color( + background.tag, background.color_space, background.solid, background.colors); + + float4 color = gradient_color(background, input.position.xy, bounds, + gradient.solid, gradient.color0, gradient.color1); + return float4(color.rgb * color.a * alpha, alpha * color.a); +} + +/* +** +** Path Sprites +** +*/ + +struct PathSprite { + Bounds bounds; +}; + +struct PathSpriteVertexOutput { + float4 position: SV_Position; + float2 texture_coords: TEXCOORD0; +}; + +StructuredBuffer<PathSprite> path_sprites: register(t1); + +PathSpriteVertexOutput path_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + PathSprite sprite = path_sprites[sprite_id]; + + // Don't apply content mask because it was already accounted for when rasterizing the path + float4 device_position = to_device_position(unit_vertex, sprite.bounds); + + float2 screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size; + float2 texture_coords = screen_position / global_viewport_size; + + PathSpriteVertexOutput output; + output.position = device_position; + output.texture_coords = texture_coords; + return output; +} + +float4 path_sprite_fragment(PathSpriteVertexOutput input): SV_Target { + return t_sprite.Sample(s_sprite, input.texture_coords); +} + +/* +** +** Underlines +** +*/ + +struct Underline { + uint order; + 
uint pad; + Bounds bounds; + Bounds content_mask; + Hsla color; + float thickness; + uint wavy; +}; + +struct UnderlineVertexOutput { + nointerpolation uint underline_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct UnderlineFragmentInput { + nointerpolation uint underline_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; +}; + +StructuredBuffer<Underline> underlines: register(t1); + +UnderlineVertexOutput underline_vertex(uint vertex_id: SV_VertexID, uint underline_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Underline underline = underlines[underline_id]; + float4 device_position = to_device_position(unit_vertex, underline.bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, underline.bounds, + underline.content_mask); + float4 color = hsla_to_rgba(underline.color); + + UnderlineVertexOutput output; + output.position = device_position; + output.color = color; + output.underline_id = underline_id; + output.clip_distance = clip_distance; + return output; +} + +float4 underline_fragment(UnderlineFragmentInput input): SV_Target { + Underline underline = underlines[input.underline_id]; + if (underline.wavy) { + float half_thickness = underline.thickness * 0.5; + float2 origin = underline.bounds.origin; + float2 st = ((input.position.xy - origin) / underline.bounds.size.y) - float2(0., 0.5); + float frequency = (M_PI_F * (3. * underline.thickness)) / 8.; + float amplitude = 1. / (2. * underline.thickness); + float sine = sin(st.x * frequency) * amplitude; + float dSine = cos(st.x * frequency) * amplitude * frequency; + float distance = (st.y - sine) / sqrt(1. 
+ dSine * dSine); + float distance_in_pixels = distance * underline.bounds.size.y; + float distance_from_top_border = distance_in_pixels - half_thickness; + float distance_from_bottom_border = distance_in_pixels + half_thickness; + float alpha = saturate( + 0.5 - max(-distance_from_bottom_border, distance_from_top_border)); + return input.color * float4(1., 1., 1., alpha); + } else { + return input.color; + } +} + +/* +** +** Monochrome sprites +** +*/ + +struct MonochromeSprite { + uint order; + uint pad; + Bounds bounds; + Bounds content_mask; + Hsla color; + AtlasTile tile; + TransformationMatrix transformation; +}; + +struct MonochromeSpriteVertexOutput { + float4 position: SV_Position; + float2 tile_position: POSITION; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct MonochromeSpriteFragmentInput { + float4 position: SV_Position; + float2 tile_position: POSITION; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +StructuredBuffer<MonochromeSprite> mono_sprites: register(t1); + +MonochromeSpriteVertexOutput monochrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + MonochromeSprite sprite = mono_sprites[sprite_id]; + float4 device_position = + to_device_position_transformed(unit_vertex, sprite.bounds, sprite.transformation); + float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, sprite.content_mask); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile); + float4 color = hsla_to_rgba(sprite.color); + + MonochromeSpriteVertexOutput output; + output.position = device_position; + output.tile_position = tile_position; + output.color = color; + output.clip_distance = clip_distance; + return output; +} + +float4 monochrome_sprite_fragment(MonochromeSpriteFragmentInput input): SV_Target { + float sample = t_sprite.Sample(s_sprite, input.tile_position).r; + return float4(input.color.rgb, input.color.a * sample); +} + +/* +** +** Polychrome sprites +** +*/ + +struct PolychromeSprite { + uint order; + uint pad; + uint grayscale; + float opacity; + Bounds bounds; + Bounds content_mask; + Corners corner_radii; + AtlasTile tile; +}; + +struct PolychromeSpriteVertexOutput { + nointerpolation uint sprite_id: TEXCOORD0; + float4 position: SV_Position; + float2 tile_position: POSITION; + float4 clip_distance: SV_ClipDistance; +}; + +struct PolychromeSpriteFragmentInput { + nointerpolation uint sprite_id: TEXCOORD0; + float4 position: SV_Position; + float2 tile_position: POSITION; +}; + +StructuredBuffer<PolychromeSprite> poly_sprites: register(t1); + +PolychromeSpriteVertexOutput polychrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + PolychromeSprite sprite = poly_sprites[sprite_id]; + float4 device_position = to_device_position(unit_vertex, sprite.bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, + sprite.content_mask); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile); + + PolychromeSpriteVertexOutput output; + output.position = device_position; + output.tile_position = tile_position; + output.sprite_id = sprite_id; + output.clip_distance = clip_distance; + return output; +} + +float4 polychrome_sprite_fragment(PolychromeSpriteFragmentInput input): SV_Target { + PolychromeSprite sprite = 
poly_sprites[input.sprite_id]; + float4 sample = t_sprite.Sample(s_sprite, input.tile_position); + float distance = quad_sdf(input.position.xy, sprite.bounds, sprite.corner_radii); + + float4 color = sample; + if ((sprite.grayscale & 0xFFu) != 0u) { + float3 grayscale = dot(color.rgb, GRAYSCALE_FACTORS); + color = float4(grayscale, sample.a); + } + color.a *= sprite.opacity * saturate(0.5 - distance); + return color; +} diff --git a/crates/gpui/src/platform/windows/window.rs b/crates/gpui/src/platform/windows/window.rs index 5703a82815..32a6da2391 100644 --- a/crates/gpui/src/platform/windows/window.rs +++ b/crates/gpui/src/platform/windows/window.rs @@ -26,10 +26,9 @@ use windows::{ core::*, }; -use crate::platform::blade::{BladeContext, BladeRenderer}; use crate::*; -pub(crate) struct WindowsWindow(pub Rc<WindowsWindowStatePtr>); +pub(crate) struct WindowsWindow(pub Rc<WindowsWindowInner>); pub struct WindowsWindowState { pub origin: Point<Pixels>, @@ -49,7 +48,7 @@ pub struct WindowsWindowState { pub system_key_handled: bool, pub hovered: bool, - pub renderer: BladeRenderer, + pub renderer: DirectXRenderer, pub click_state: ClickState, pub system_settings: WindowsSystemSettings, @@ -62,9 +61,9 @@ pub struct WindowsWindowState { hwnd: HWND, } -pub(crate) struct WindowsWindowStatePtr { +pub(crate) struct WindowsWindowInner { hwnd: HWND, - this: Weak<Self>, + pub(super) this: Weak<Self>, drop_target_helper: IDropTargetHelper, pub(crate) state: RefCell<WindowsWindowState>, pub(crate) handle: AnyWindowHandle, @@ -80,21 +79,23 @@ pub(crate) struct WindowsWindowStatePtr { impl WindowsWindowState { fn new( hwnd: HWND, - transparent: bool, - cs: &CREATESTRUCTW, + window_params: &CREATESTRUCTW, current_cursor: Option<HCURSOR>, display: WindowsDisplay, - gpu_context: &BladeContext, min_size: Option<Size<Pixels>>, appearance: WindowAppearance, + disable_direct_composition: bool, ) -> Result<Self> { let scale_factor = { let monitor_dpi = unsafe { GetDpiForWindow(hwnd) } as f32; monitor_dpi / USER_DEFAULT_SCREEN_DPI as f32 }; - let origin = logical_point(cs.x as f32, cs.y as f32, scale_factor); + let origin = logical_point(window_params.x as f32, window_params.y as f32, scale_factor); let logical_size = { - let physical_size = size(DevicePixels(cs.cx), DevicePixels(cs.cy)); + let physical_size = size( + DevicePixels(window_params.cx), + DevicePixels(window_params.cy), + ); physical_size.to_pixels(scale_factor) }; let fullscreen_restore_bounds = Bounds { @@ -103,7 +104,8 @@ impl WindowsWindowState { }; let border_offset = WindowBorderOffset::default(); let restore_from_minimized = None; - let renderer = windows_renderer::init(gpu_context, hwnd, transparent)?; + let renderer = DirectXRenderer::new(hwnd, disable_direct_composition) + .context("Creating DirectX renderer")?; let callbacks = Callbacks::default(); let input_handler = None; let pending_surrogate = None; @@ -202,17 +204,16 @@ impl WindowsWindowState { } } -impl WindowsWindowStatePtr { +impl WindowsWindowInner { fn new(context: &WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result<Rc<Self>> { let state = RefCell::new(WindowsWindowState::new( hwnd, - context.transparent, cs, context.current_cursor, context.display, - context.gpu_context, context.min_size, context.appearance, + context.disable_direct_composition, )?); Ok(Rc::new_cyclic(|this| Self { @@ -232,13 +233,13 @@ impl WindowsWindowStatePtr { } fn toggle_fullscreen(&self) { - let Some(state_ptr) = self.this.upgrade() else { + let Some(this) = self.this.upgrade() else { 
log::error!("Unable to toggle fullscreen: window has been dropped"); return; }; self.executor .spawn(async move { - let mut lock = state_ptr.state.borrow_mut(); + let mut lock = this.state.borrow_mut(); let StyleAndBounds { style, x, @@ -250,10 +251,9 @@ impl WindowsWindowStatePtr { } else { let (window_bounds, _) = lock.calculate_window_bounds(); lock.fullscreen_restore_bounds = window_bounds; - let style = - WINDOW_STYLE(unsafe { get_window_long(state_ptr.hwnd, GWL_STYLE) } as _); + let style = WINDOW_STYLE(unsafe { get_window_long(this.hwnd, GWL_STYLE) } as _); let mut rc = RECT::default(); - unsafe { GetWindowRect(state_ptr.hwnd, &mut rc) }.log_err(); + unsafe { GetWindowRect(this.hwnd, &mut rc) }.log_err(); let _ = lock.fullscreen.insert(StyleAndBounds { style, x: rc.left, @@ -277,10 +277,10 @@ impl WindowsWindowStatePtr { } }; drop(lock); - unsafe { set_window_long(state_ptr.hwnd, GWL_STYLE, style.0 as isize) }; + unsafe { set_window_long(this.hwnd, GWL_STYLE, style.0 as isize) }; unsafe { SetWindowPos( - state_ptr.hwnd, + this.hwnd, None, x, y, @@ -329,12 +329,11 @@ pub(crate) struct Callbacks { pub(crate) appearance_changed: Option<Box<dyn FnMut()>>, } -struct WindowCreateContext<'a> { - inner: Option<Result<Rc<WindowsWindowStatePtr>>>, +struct WindowCreateContext { + inner: Option<Result<Rc<WindowsWindowInner>>>, handle: AnyWindowHandle, hide_title_bar: bool, display: WindowsDisplay, - transparent: bool, is_movable: bool, min_size: Option<Size<Pixels>>, executor: ForegroundExecutor, @@ -343,9 +342,9 @@ struct WindowCreateContext<'a> { drop_target_helper: IDropTargetHelper, validation_number: usize, main_receiver: flume::Receiver<Runnable>, - gpu_context: &'a BladeContext, main_thread_id_win32: u32, appearance: WindowAppearance, + disable_direct_composition: bool, } impl WindowsWindow { @@ -353,7 +352,6 @@ impl WindowsWindow { handle: AnyWindowHandle, params: WindowParams, creation_info: WindowCreationInfo, - gpu_context: &BladeContext, ) -> Result<Self> { let WindowCreationInfo { icon, @@ -364,14 +362,15 @@ impl WindowsWindow { validation_number, main_receiver, main_thread_id_win32, + disable_direct_composition, } = creation_info; - let classname = register_wnd_class(icon); + register_window_class(icon); let hide_title_bar = params .titlebar .as_ref() .map(|titlebar| titlebar.appears_transparent) .unwrap_or(true); - let windowname = HSTRING::from( + let window_name = HSTRING::from( params .titlebar .as_ref() @@ -379,14 +378,18 @@ impl WindowsWindow { .map(|title| title.as_ref()) .unwrap_or(""), ); - let (dwexstyle, mut dwstyle) = if params.kind == WindowKind::PopUp { - (WS_EX_TOOLWINDOW | WS_EX_LAYERED, WINDOW_STYLE(0x0)) + + let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp { + (WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0)) } else { ( - WS_EX_APPWINDOW | WS_EX_LAYERED, + WS_EX_APPWINDOW, WS_THICKFRAME | WS_SYSMENU | WS_MAXIMIZEBOX | WS_MINIMIZEBOX, ) }; + if !disable_direct_composition { + dwexstyle |= WS_EX_NOREDIRECTIONBITMAP; + } let hinstance = get_module_handle(); let display = if let Some(display_id) = params.display_id { @@ -401,7 +404,6 @@ impl WindowsWindow { handle, hide_title_bar, display, - transparent: true, is_movable: params.is_movable, min_size: params.window_min_size, executor, @@ -410,16 +412,15 @@ impl WindowsWindow { drop_target_helper, validation_number, main_receiver, - gpu_context, main_thread_id_win32, appearance, + disable_direct_composition, }; - let lpparam = Some(&context as *const _ as *const _); let creation_result = unsafe { CreateWindowExW( 
dwexstyle, - classname, - &windowname, + WINDOW_CLASS_NAME, + &window_name, dwstyle, CW_USEDEFAULT, CW_USEDEFAULT, @@ -428,41 +429,35 @@ impl WindowsWindow { None, None, Some(hinstance.into()), - lpparam, + Some(&context as *const _ as *const _), ) }; - // We should call `?` on state_ptr first, then call `?` on hwnd. - // Or, we will lose the error info reported by `WindowsWindowState::new` - let state_ptr = context.inner.take().unwrap()?; + + // Failure to create a `WindowsWindowState` can cause window creation to fail, + // so check the inner result first. + let this = context.inner.take().unwrap()?; let hwnd = creation_result?; - register_drag_drop(state_ptr.clone())?; + + register_drag_drop(&this)?; configure_dwm_dark_mode(hwnd, appearance); - state_ptr.state.borrow_mut().border_offset.update(hwnd)?; + this.state.borrow_mut().border_offset.update(hwnd)?; let placement = retrieve_window_placement( hwnd, display, params.bounds, - state_ptr.state.borrow().scale_factor, - state_ptr.state.borrow().border_offset, + this.state.borrow().scale_factor, + this.state.borrow().border_offset, )?; if params.show { unsafe { SetWindowPlacement(hwnd, &placement)? }; } else { - state_ptr.state.borrow_mut().initial_placement = Some(WindowOpenStatus { + this.state.borrow_mut().initial_placement = Some(WindowOpenStatus { placement, state: WindowOpenState::Windowed, }); } - // The render pipeline will perform compositing on the GPU when the - // swapchain is configured correctly (see downstream of - // update_transparency). - // The following configuration is a one-time setup to ensure that the - // window is going to be composited with per-pixel alpha, but the render - // pipeline is responsible for effectively calling UpdateLayeredWindow - // at the appropriate time. - unsafe { SetLayeredWindowAttributes(hwnd, COLORREF(0), 255, LWA_ALPHA)? 
}; - Ok(Self(state_ptr)) + Ok(Self(this)) } } @@ -485,7 +480,6 @@ impl rwh::HasDisplayHandle for WindowsWindow { impl Drop for WindowsWindow { fn drop(&mut self) { - self.0.state.borrow_mut().renderer.destroy(); // clone this `Rc` to prevent early release of the pointer let this = self.0.clone(); self.0 @@ -683,6 +677,36 @@ impl PlatformWindow for WindowsWindow { this.set_window_placement().log_err(); unsafe { SetActiveWindow(hwnd).log_err() }; unsafe { SetFocus(Some(hwnd)).log_err() }; + + // premium ragebait by windows, this is needed because the window + // must have received an input event to be able to set itself to foreground + // so let's just simulate user input as that seems to be the most reliable way + // some more info: https://gist.github.com/Aetopia/1581b40f00cc0cadc93a0e8ccb65dc8c + // bonus: this bug also doesn't manifest if you have vs attached to the process + let inputs = [ + INPUT { + r#type: INPUT_KEYBOARD, + Anonymous: INPUT_0 { + ki: KEYBDINPUT { + wVk: VK_MENU, + dwFlags: KEYBD_EVENT_FLAGS(0), + ..Default::default() + }, + }, + }, + INPUT { + r#type: INPUT_KEYBOARD, + Anonymous: INPUT_0 { + ki: KEYBDINPUT { + wVk: VK_MENU, + dwFlags: KEYEVENTF_KEYUP, + ..Default::default() + }, + }, + }, + ]; + unsafe { SendInput(&inputs, std::mem::size_of::<INPUT>() as i32) }; + // todo(windows) // crate `windows 0.56` reports true as Err unsafe { SetForegroundWindow(hwnd).as_bool() }; @@ -705,24 +729,21 @@ impl PlatformWindow for WindowsWindow { } fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) { - let mut window_state = self.0.state.borrow_mut(); - window_state - .renderer - .update_transparency(background_appearance != WindowBackgroundAppearance::Opaque); + let hwnd = self.0.hwnd; match background_appearance { WindowBackgroundAppearance::Opaque => { // ACCENT_DISABLED - set_window_composition_attribute(window_state.hwnd, None, 0); + set_window_composition_attribute(hwnd, None, 0); } WindowBackgroundAppearance::Transparent => { // Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background - set_window_composition_attribute(window_state.hwnd, None, 2); + set_window_composition_attribute(hwnd, None, 2); } WindowBackgroundAppearance::Blurred => { // Enable acrylic blur // ACCENT_ENABLE_ACRYLICBLURBEHIND - set_window_composition_attribute(window_state.hwnd, Some((0, 0, 0, 0)), 4); + set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4); } } } @@ -794,11 +815,11 @@ impl PlatformWindow for WindowsWindow { } fn draw(&self, scene: &Scene) { - self.0.state.borrow_mut().renderer.draw(scene) + self.0.state.borrow_mut().renderer.draw(scene).log_err(); } fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> { - self.0.state.borrow().renderer.sprite_atlas().clone() + self.0.state.borrow().renderer.sprite_atlas() } fn get_raw_handle(&self) -> HWND { @@ -806,16 +827,16 @@ impl PlatformWindow for WindowsWindow { } fn gpu_specs(&self) -> Option<GpuSpecs> { - Some(self.0.state.borrow().renderer.gpu_specs()) + self.0.state.borrow().renderer.gpu_specs().log_err() } fn update_ime_position(&self, _bounds: Bounds<ScaledPixels>) { - // todo(windows) + // There is no such thing on Windows. 
} } #[implement(IDropTarget)] -struct WindowsDragDropHandler(pub Rc<WindowsWindowStatePtr>); +struct WindowsDragDropHandler(pub Rc<WindowsWindowInner>); impl WindowsDragDropHandler { fn handle_drag_drop(&self, input: PlatformInput) { @@ -1096,15 +1117,15 @@ enum WindowOpenState { Windowed, } -fn register_wnd_class(icon_handle: HICON) -> PCWSTR { - const CLASS_NAME: PCWSTR = w!("Zed::Window"); +const WINDOW_CLASS_NAME: PCWSTR = w!("Zed::Window"); +fn register_window_class(icon_handle: HICON) { static ONCE: Once = Once::new(); ONCE.call_once(|| { let wc = WNDCLASSW { - lpfnWndProc: Some(wnd_proc), + lpfnWndProc: Some(window_procedure), hIcon: icon_handle, - lpszClassName: PCWSTR(CLASS_NAME.as_ptr()), + lpszClassName: PCWSTR(WINDOW_CLASS_NAME.as_ptr()), style: CS_HREDRAW | CS_VREDRAW, hInstance: get_module_handle().into(), hbrBackground: unsafe { CreateSolidBrush(COLORREF(0x00000000)) }, @@ -1112,54 +1133,58 @@ fn register_wnd_class(icon_handle: HICON) -> PCWSTR { }; unsafe { RegisterClassW(&wc) }; }); - - CLASS_NAME } -unsafe extern "system" fn wnd_proc( +unsafe extern "system" fn window_procedure( hwnd: HWND, msg: u32, wparam: WPARAM, lparam: LPARAM, ) -> LRESULT { if msg == WM_NCCREATE { - let cs = lparam.0 as *const CREATESTRUCTW; - let cs = unsafe { &*cs }; - let ctx = cs.lpCreateParams as *mut WindowCreateContext; - let ctx = unsafe { &mut *ctx }; - let creation_result = WindowsWindowStatePtr::new(ctx, hwnd, cs); - if creation_result.is_err() { - ctx.inner = Some(creation_result); - return LRESULT(0); - } - let weak = Box::new(Rc::downgrade(creation_result.as_ref().unwrap())); - unsafe { set_window_long(hwnd, GWLP_USERDATA, Box::into_raw(weak) as isize) }; - ctx.inner = Some(creation_result); - return unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) }; + let window_params = lparam.0 as *const CREATESTRUCTW; + let window_params = unsafe { &*window_params }; + let window_creation_context = window_params.lpCreateParams as *mut WindowCreateContext; + let window_creation_context = unsafe { &mut *window_creation_context }; + return match WindowsWindowInner::new(window_creation_context, hwnd, window_params) { + Ok(window_state) => { + let weak = Box::new(Rc::downgrade(&window_state)); + unsafe { set_window_long(hwnd, GWLP_USERDATA, Box::into_raw(weak) as isize) }; + window_creation_context.inner = Some(Ok(window_state)); + unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) } + } + Err(error) => { + window_creation_context.inner = Some(Err(error)); + LRESULT(0) + } + }; } - let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowStatePtr>; + + let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowInner>; if ptr.is_null() { return unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) }; } let inner = unsafe { &*ptr }; - let r = if let Some(state) = inner.upgrade() { - handle_msg(hwnd, msg, wparam, lparam, state) + let result = if let Some(inner) = inner.upgrade() { + inner.handle_msg(hwnd, msg, wparam, lparam) } else { unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) } }; + if msg == WM_NCDESTROY { unsafe { set_window_long(hwnd, GWLP_USERDATA, 0) }; unsafe { drop(Box::from_raw(ptr)) }; } - r + + result } -pub(crate) fn try_get_window_inner(hwnd: HWND) -> Option<Rc<WindowsWindowStatePtr>> { +pub(crate) fn window_from_hwnd(hwnd: HWND) -> Option<Rc<WindowsWindowInner>> { if hwnd.is_invalid() { return None; } - let ptr = unsafe { get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowStatePtr>; + let ptr = unsafe { 
get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak<WindowsWindowInner>; if !ptr.is_null() { let inner = unsafe { &*ptr }; inner.upgrade() @@ -1182,9 +1207,9 @@ fn get_module_handle() -> HMODULE { } } -fn register_drag_drop(state_ptr: Rc<WindowsWindowStatePtr>) -> Result<()> { - let window_handle = state_ptr.hwnd; - let handler = WindowsDragDropHandler(state_ptr); +fn register_drag_drop(window: &Rc<WindowsWindowInner>) -> Result<()> { + let window_handle = window.hwnd; + let handler = WindowsDragDropHandler(window.clone()); // The lifetime of `IDropTarget` is handled by Windows, it won't release until // we call `RevokeDragDrop`. // So, it's safe to drop it here. @@ -1306,52 +1331,6 @@ fn set_window_composition_attribute(hwnd: HWND, color: Option<Color>, state: u32 } } -mod windows_renderer { - use crate::platform::blade::{BladeContext, BladeRenderer, BladeSurfaceConfig}; - use raw_window_handle as rwh; - use std::num::NonZeroIsize; - use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::GWLP_HINSTANCE}; - - use crate::{get_window_long, show_error}; - - pub(super) fn init( - context: &BladeContext, - hwnd: HWND, - transparent: bool, - ) -> anyhow::Result<BladeRenderer> { - let raw = RawWindow { hwnd }; - let config = BladeSurfaceConfig { - size: Default::default(), - transparent, - }; - BladeRenderer::new(context, &raw, config) - .inspect_err(|err| show_error("Failed to initialize BladeRenderer", err.to_string())) - } - - struct RawWindow { - hwnd: HWND, - } - - impl rwh::HasWindowHandle for RawWindow { - fn window_handle(&self) -> Result<rwh::WindowHandle<'_>, rwh::HandleError> { - Ok(unsafe { - let hwnd = NonZeroIsize::new_unchecked(self.hwnd.0 as isize); - let mut handle = rwh::Win32WindowHandle::new(hwnd); - let hinstance = get_window_long(self.hwnd, GWLP_HINSTANCE); - handle.hinstance = NonZeroIsize::new(hinstance); - rwh::WindowHandle::borrow_raw(handle.into()) - }) - } - } - - impl rwh::HasDisplayHandle for RawWindow { - fn display_handle(&self) -> Result<rwh::DisplayHandle<'_>, rwh::HandleError> { - let handle = rwh::WindowsDisplayHandle::new(); - Ok(unsafe { rwh::DisplayHandle::borrow_raw(handle.into()) }) - } - } -} - #[cfg(test)] mod tests { use super::ClickState; diff --git a/crates/gpui/src/scene.rs b/crates/gpui/src/scene.rs index 4eaef64afa..c527dfe750 100644 --- a/crates/gpui/src/scene.rs +++ b/crates/gpui/src/scene.rs @@ -8,7 +8,12 @@ use crate::{ AtlasTextureId, AtlasTile, Background, Bounds, ContentMask, Corners, Edges, Hsla, Pixels, Point, Radians, ScaledPixels, Size, bounds_tree::BoundsTree, point, }; -use std::{fmt::Debug, iter::Peekable, ops::Range, slice}; +use std::{ + fmt::Debug, + iter::Peekable, + ops::{Add, Range, Sub}, + slice, +}; #[allow(non_camel_case_types, unused)] pub(crate) type PathVertex_ScaledPixels = PathVertex<ScaledPixels>; @@ -43,17 +48,6 @@ impl Scene { self.surfaces.clear(); } - #[cfg_attr( - all( - any(target_os = "linux", target_os = "freebsd"), - not(any(feature = "x11", feature = "wayland")) - ), - allow(dead_code) - )] - pub fn paths(&self) -> &[Path<ScaledPixels>] { - &self.paths - } - pub fn len(&self) -> usize { self.paint_operations.len() } @@ -681,7 +675,7 @@ pub(crate) struct PathId(pub(crate) usize); #[derive(Clone, Debug)] pub struct Path<P: Clone + Debug + Default + PartialEq> { pub(crate) id: PathId, - order: DrawOrder, + pub(crate) order: DrawOrder, pub(crate) bounds: Bounds<P>, pub(crate) content_mask: ContentMask<P>, pub(crate) vertices: Vec<PathVertex<P>>, @@ -804,6 +798,16 @@ impl Path<Pixels> { } } +impl<T> 
Path<T> +where + T: Clone + Debug + Default + PartialEq + PartialOrd + Add<T, Output = T> + Sub<Output = T>, +{ + #[allow(unused)] + pub(crate) fn clipped_bounds(&self) -> Bounds<T> { + self.bounds.intersect(&self.content_mask.bounds) + } +} + impl From<Path<ScaledPixels>> for Primitive { fn from(path: Path<ScaledPixels>) -> Self { Primitive::Path(path) diff --git a/crates/gpui/src/tab_stop.rs b/crates/gpui/src/tab_stop.rs index 2ec3f560e8..7dde42efed 100644 --- a/crates/gpui/src/tab_stop.rs +++ b/crates/gpui/src/tab_stop.rs @@ -5,7 +5,7 @@ use crate::{FocusHandle, FocusId}; /// Used to manage the `Tab` event to switch between focus handles. #[derive(Default)] pub(crate) struct TabHandles { - handles: Vec<FocusHandle>, + pub(crate) handles: Vec<FocusHandle>, } impl TabHandles { @@ -32,20 +32,18 @@ impl TabHandles { self.handles.clear(); } - fn current_index(&self, focused_id: Option<&FocusId>) -> usize { - self.handles - .iter() - .position(|h| Some(&h.id) == focused_id) - .unwrap_or_default() + fn current_index(&self, focused_id: Option<&FocusId>) -> Option<usize> { + self.handles.iter().position(|h| Some(&h.id) == focused_id) } pub(crate) fn next(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { - let ix = self.current_index(focused_id); - - let mut next_ix = ix + 1; - if next_ix + 1 > self.handles.len() { - next_ix = 0; - } + let next_ix = self + .current_index(focused_id) + .and_then(|ix| { + let next_ix = ix + 1; + (next_ix < self.handles.len()).then_some(next_ix) + }) + .unwrap_or_default(); if let Some(next_handle) = self.handles.get(next_ix) { Some(next_handle.clone()) @@ -55,7 +53,7 @@ impl TabHandles { } pub(crate) fn prev(&self, focused_id: Option<&FocusId>) -> Option<FocusHandle> { - let ix = self.current_index(focused_id); + let ix = self.current_index(focused_id).unwrap_or_default(); let prev_ix; if ix == 0 { prev_ix = self.handles.len().saturating_sub(1); @@ -108,8 +106,14 @@ mod tests { ] ); - // next - assert_eq!(tab.next(None), Some(tab.handles[1].clone())); + // Select first tab index if no handle is currently focused. + assert_eq!(tab.next(None), Some(tab.handles[0].clone())); + // Select last tab index if no handle is currently focused. + assert_eq!( + tab.prev(None), + Some(tab.handles[tab.handles.len() - 1].clone()) + ); + assert_eq!( tab.next(Some(&tab.handles[0].id)), Some(tab.handles[1].clone()) diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 963d2bb45c..40d3845ff9 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -79,11 +79,13 @@ pub enum DispatchPhase { impl DispatchPhase { /// Returns true if this represents the "bubble" phase. + #[inline] pub fn bubble(self) -> bool { self == DispatchPhase::Bubble } /// Returns true if this represents the "capture" phase. 
+ #[inline] pub fn capture(self) -> bool { self == DispatchPhase::Capture } @@ -702,6 +704,7 @@ pub(crate) struct PaintIndex { input_handlers_index: usize, cursor_styles_index: usize, accessed_element_states_index: usize, + tab_handle_index: usize, line_layout_index: LineLayoutIndex, } @@ -1019,7 +1022,7 @@ impl Window { || (active.get() && last_input_timestamp.get().elapsed() < Duration::from_secs(1)); - if invalidator.is_dirty() { + if invalidator.is_dirty() || request_frame_options.force_render { measure("frame duration", || { handle .update(&mut cx, |_, window, cx| { @@ -2208,6 +2211,7 @@ impl Window { input_handlers_index: self.next_frame.input_handlers.len(), cursor_styles_index: self.next_frame.cursor_styles.len(), accessed_element_states_index: self.next_frame.accessed_element_states.len(), + tab_handle_index: self.next_frame.tab_handles.handles.len(), line_layout_index: self.text_system.layout_index(), } } @@ -2237,6 +2241,12 @@ impl Window { .iter() .map(|(id, type_id)| (GlobalElementId(id.0.clone()), *type_id)), ); + self.next_frame.tab_handles.handles.extend( + self.rendered_frame.tab_handles.handles + [range.start.tab_handle_index..range.end.tab_handle_index] + .iter() + .cloned(), + ); self.text_system .reuse_layouts(range.start.line_layout_index..range.end.line_layout_index); @@ -4238,6 +4248,25 @@ impl Window { .on_action(action_type, Rc::new(listener)); } + /// Register an action listener on the window for the next frame if the condition is true. + /// The type of action is determined by the first parameter of the given listener. + /// When the next frame is rendered the listener will be cleared. + /// + /// This is a fairly low-level method, so prefer using action handlers on elements unless you have + /// a specific need to register a global listener. + pub fn on_action_when( + &mut self, + condition: bool, + action_type: TypeId, + listener: impl Fn(&dyn Any, DispatchPhase, &mut Window, &mut App) + 'static, + ) { + if condition { + self.next_frame + .dispatch_tree + .on_action(action_type, Rc::new(listener)); + } + } + /// Read information about the GPU backing this window. /// Currently returns None on Mac and Windows. pub fn gpu_specs(&self) -> Option<GpuSpecs> { @@ -4691,6 +4720,8 @@ pub enum ElementId { Path(Arc<std::path::Path>), /// A code location. CodeLocation(core::panic::Location<'static>), + /// A labeled child of an element. + NamedChild(Box<ElementId>, SharedString), } impl ElementId { @@ -4711,6 +4742,7 @@ impl Display for ElementId { ElementId::Uuid(uuid) => write!(f, "{}", uuid)?, ElementId::Path(path) => write!(f, "{}", path.display())?, ElementId::CodeLocation(location) => write!(f, "{}", location)?, + ElementId::NamedChild(id, name) => write!(f, "{}-{}", id, name)?, } Ok(()) @@ -4801,6 +4833,12 @@ impl From<(&'static str, u32)> for ElementId { } } +impl<T: Into<SharedString>> From<(ElementId, T)> for ElementId { + fn from((id, name): (ElementId, T)) -> Self { + ElementId::NamedChild(Box::new(id), name.into()) + } +} + /// A rectangle to be rendered in the window at the given position and size. /// Passed as an argument [`Window::paint_quad`]. 
#[derive(Clone)] diff --git a/crates/http_client/Cargo.toml b/crates/http_client/Cargo.toml index 2045708ff2..f63bff295e 100644 --- a/crates/http_client/Cargo.toml +++ b/crates/http_client/Cargo.toml @@ -23,6 +23,8 @@ futures.workspace = true http.workspace = true http-body.workspace = true log.workspace = true +parking_lot.workspace = true +reqwest.workspace = true serde.workspace = true serde_json.workspace = true url.workspace = true diff --git a/crates/http_client/src/async_body.rs b/crates/http_client/src/async_body.rs index 88972d279c..473849f3cd 100644 --- a/crates/http_client/src/async_body.rs +++ b/crates/http_client/src/async_body.rs @@ -88,6 +88,17 @@ impl From<&'static str> for AsyncBody { } } +impl TryFrom<reqwest::Body> for AsyncBody { + type Error = anyhow::Error; + + fn try_from(value: reqwest::Body) -> Result<Self, Self::Error> { + value + .as_bytes() + .ok_or_else(|| anyhow::anyhow!("Underlying data is a stream")) + .map(|bytes| Self::from_bytes(Bytes::copy_from_slice(bytes))) + } +} + impl<T: Into<Self>> From<Option<T>> for AsyncBody { fn from(body: Option<T>) -> Self { match body { diff --git a/crates/http_client/src/github.rs b/crates/http_client/src/github.rs index a038915e2f..a19c13b0ff 100644 --- a/crates/http_client/src/github.rs +++ b/crates/http_client/src/github.rs @@ -8,6 +8,7 @@ use url::Url; pub struct GitHubLspBinaryVersion { pub name: String, pub url: String, + pub digest: Option<String>, } #[derive(Deserialize, Debug)] @@ -24,6 +25,7 @@ pub struct GithubRelease { pub struct GithubReleaseAsset { pub name: String, pub browser_download_url: String, + pub digest: Option<String>, } pub async fn latest_github_release( diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index eebab86e21..a7f75b0962 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -4,16 +4,18 @@ pub mod github; pub use anyhow::{Result, anyhow}; pub use async_body::{AsyncBody, Inner}; use derive_more::Deref; +use http::HeaderValue; pub use http::{self, Method, Request, Response, StatusCode, Uri}; -use futures::future::BoxFuture; +use futures::{ + FutureExt as _, + future::{self, BoxFuture}, +}; use http::request::Builder; +use parking_lot::Mutex; #[cfg(feature = "test-support")] use std::fmt; -use std::{ - any::type_name, - sync::{Arc, Mutex}, -}; +use std::{any::type_name, sync::Arc}; pub use url::Url; #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] @@ -39,6 +41,8 @@ impl HttpRequestExt for http::request::Builder { pub trait HttpClient: 'static + Send + Sync { fn type_name(&self) -> &'static str; + fn user_agent(&self) -> Option<&HeaderValue>; + fn send( &self, req: http::Request<AsyncBody>, @@ -83,6 +87,19 @@ pub trait HttpClient: 'static + Send + Sync { } fn proxy(&self) -> Option<&Url>; + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::<Self>()) + } + + fn send_multipart_form<'a>( + &'a self, + _url: &str, + _request: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + future::ready(Err(anyhow!("not implemented"))).boxed() + } } /// An [`HttpClient`] that may have a proxy. 
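For context, a minimal sketch of a call site for the new `send_multipart_form` trait method added above. The helper name, endpoint, and form fields below are illustrative assumptions and not part of this diff; note that the default trait implementation just returns an error, so only clients that override it (for example a reqwest-backed client) will actually perform the upload.

```rust
use anyhow::ensure;
use http_client::HttpClient; // assumed re-export from the http_client crate

// Hypothetical helper for illustration only.
async fn upload_crash_report(client: &dyn HttpClient, endpoint: &str) -> anyhow::Result<()> {
    // Build a multipart form; the field names here are illustrative.
    let form = reqwest::multipart::Form::new()
        .text("channel", "nightly")
        .text("version", "0.0.0");

    // Falls back to Err("not implemented") unless the concrete client overrides it.
    let response = client.send_multipart_form(endpoint, form).await?;
    ensure!(
        response.status().is_success(),
        "multipart upload failed with status {}",
        response.status()
    );
    Ok(())
}
```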
@@ -118,21 +135,8 @@ impl HttpClient for HttpClientWithProxy { self.client.send(req) } - fn proxy(&self) -> Option<&Url> { - self.proxy.as_ref() - } - - fn type_name(&self) -> &'static str { - self.client.type_name() - } -} - -impl HttpClient for Arc<HttpClientWithProxy> { - fn send( - &self, - req: Request<AsyncBody>, - ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - self.client.send(req) + fn user_agent(&self) -> Option<&HeaderValue> { + self.client.user_agent() } fn proxy(&self) -> Option<&Url> { @@ -142,6 +146,19 @@ impl HttpClient for Arc<HttpClientWithProxy> { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } + + fn send_multipart_form<'a>( + &'a self, + url: &str, + form: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + self.client.send_multipart_form(url, form) + } } /// An [`HttpClient`] that has a base URL. @@ -188,20 +205,13 @@ impl HttpClientWithUrl { /// Returns the base URL. pub fn base_url(&self) -> String { - self.base_url - .lock() - .map_or_else(|_| Default::default(), |url| url.clone()) + self.base_url.lock().clone() } /// Sets the base URL. pub fn set_base_url(&self, base_url: impl Into<String>) { let base_url = base_url.into(); - self.base_url - .lock() - .map(|mut url| { - *url = base_url; - }) - .ok(); + *self.base_url.lock() = base_url; } /// Builds a URL using the given path. @@ -225,6 +235,22 @@ impl HttpClientWithUrl { )?) } + /// Builds a Zed Cloud URL using the given path. + pub fn build_zed_cloud_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> { + let base_url = self.base_url(); + let base_api_url = match base_url.as_ref() { + "https://zed.dev" => "https://cloud.zed.dev", + "https://staging.zed.dev" => "https://cloud.zed.dev", + "http://localhost:3000" => "http://localhost:8787", + other => other, + }; + + Ok(Url::parse_with_params( + &format!("{}{}", base_api_url, path), + query, + )?) + } + /// Builds a Zed LLM URL using the given path. 
pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> { let base_url = self.base_url(); @@ -242,23 +268,6 @@ impl HttpClientWithUrl { } } -impl HttpClient for Arc<HttpClientWithUrl> { - fn send( - &self, - req: Request<AsyncBody>, - ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - self.client.send(req) - } - - fn proxy(&self) -> Option<&Url> { - self.client.proxy.as_ref() - } - - fn type_name(&self) -> &'static str { - self.client.type_name() - } -} - impl HttpClient for HttpClientWithUrl { fn send( &self, @@ -267,6 +276,10 @@ impl HttpClient for HttpClientWithUrl { self.client.send(req) } + fn user_agent(&self) -> Option<&HeaderValue> { + self.client.user_agent() + } + fn proxy(&self) -> Option<&Url> { self.client.proxy.as_ref() } @@ -274,6 +287,19 @@ impl HttpClient for HttpClientWithUrl { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } + + fn send_multipart_form<'a>( + &'a self, + url: &str, + request: reqwest::multipart::Form, + ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> { + self.client.send_multipart_form(url, request) + } } pub fn read_proxy_from_env() -> Option<Url> { @@ -314,6 +340,10 @@ impl HttpClient for BlockedHttpClient { }) } + fn user_agent(&self) -> Option<&HeaderValue> { + None + } + fn proxy(&self) -> Option<&Url> { None } @@ -321,10 +351,15 @@ impl HttpClient for BlockedHttpClient { fn type_name(&self) -> &'static str { type_name::<Self>() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::<Self>()) + } } #[cfg(feature = "test-support")] -type FakeHttpHandler = Box< +type FakeHttpHandler = Arc< dyn Fn(Request<AsyncBody>) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> + Send + Sync @@ -333,7 +368,8 @@ type FakeHttpHandler = Box< #[cfg(feature = "test-support")] pub struct FakeHttpClient { - handler: FakeHttpHandler, + handler: Mutex<Option<FakeHttpHandler>>, + user_agent: HeaderValue, } #[cfg(feature = "test-support")] @@ -347,7 +383,8 @@ impl FakeHttpClient { base_url: Mutex::new("http://test.example".into()), client: HttpClientWithProxy { client: Arc::new(Self { - handler: Box::new(move |req| Box::pin(handler(req))), + handler: Mutex::new(Some(Arc::new(move |req| Box::pin(handler(req))))), + user_agent: HeaderValue::from_static(type_name::<Self>()), }), proxy: None, }, @@ -371,6 +408,18 @@ impl FakeHttpClient { .unwrap()) }) } + + pub fn replace_handler<Fut, F>(&self, new_handler: F) + where + Fut: futures::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send + 'static, + F: Fn(FakeHttpHandler, Request<AsyncBody>) -> Fut + Send + Sync + 'static, + { + let mut handler = self.handler.lock(); + let old_handler = handler.take().unwrap(); + *handler = Some(Arc::new(move |req| { + Box::pin(new_handler(old_handler.clone(), req)) + })); + } } #[cfg(feature = "test-support")] @@ -386,10 +435,14 @@ impl HttpClient for FakeHttpClient { &self, req: Request<AsyncBody>, ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { - let future = (self.handler)(req); + let future = (self.handler.lock().as_ref().unwrap())(req); future } + fn user_agent(&self) -> Option<&HeaderValue> { + Some(&self.user_agent) + } + fn proxy(&self) -> Option<&Url> { None } @@ -397,4 +450,8 @@ impl HttpClient for FakeHttpClient { fn type_name(&self) -> &'static str { type_name::<Self>() } + + fn as_fake(&self) -> &FakeHttpClient { + 
self + } } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index e7066ae151..a94d89bdc8 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -38,7 +38,6 @@ pub enum IconName { ArrowUpFromLine, ArrowUpRight, ArrowUpRightAlt, - AtSign, AudioOff, AudioOn, Backspace, @@ -48,15 +47,13 @@ pub enum IconName { BellRing, Binary, Blocks, - Bolt, + BoltOutlined, BoltFilled, - BoltFilledAlt, Book, BookCopy, - BookPlus, - Brain, BugOff, CaseSensitive, + Chat, Check, CheckDouble, ChevronDown, @@ -71,6 +68,7 @@ pub enum IconName { CircleHelp, Close, Cloud, + CloudDownload, Code, Cog, Command, @@ -106,6 +104,12 @@ pub enum IconName { Disconnected, DocumentText, Download, + EditorAtom, + EditorCursor, + EditorEmacs, + EditorJetBrains, + EditorSublime, + EditorVsCode, Ellipsis, EllipsisVertical, Envelope, @@ -177,14 +181,9 @@ pub enum IconName { Maximize, Menu, MenuAlt, - MessageBubbles, Mic, MicMute, - Microscope, Minimize, - NewFromSummary, - NewTextThread, - NewThread, Option, PageDown, PageUp, @@ -195,9 +194,7 @@ pub enum IconName { PersonCircle, PhoneIncoming, Pin, - Play, - PlayAlt, - PlayBug, + PlayOutlined, PlayFilled, Plus, PocketKnife, @@ -214,7 +211,6 @@ pub enum IconName { ReplyArrowRight, Rerun, Return, - Reveal, RotateCcw, RotateCw, Route, @@ -228,6 +224,7 @@ pub enum IconName { Server, Settings, SettingsAlt, + ShieldCheck, Shift, Slash, SlashSquare, @@ -238,7 +235,6 @@ pub enum IconName { Sparkle, SparkleAlt, SparkleFilled, - Spinner, Split, SplitAlt, SquareDot, @@ -248,7 +244,6 @@ pub enum IconName { StarFilled, Stop, StopFilled, - Strikethrough, Supermaven, SupermavenDisabled, SupermavenError, @@ -258,6 +253,9 @@ pub enum IconName { Terminal, TerminalAlt, TextSnippet, + TextThread, + Thread, + ThreadFromSummary, ThumbsDown, ThumbsUp, TodoComplete, @@ -277,7 +275,6 @@ pub enum IconName { ToolTerminal, ToolWeb, Trash, - TrashAlt, Triangle, TriangleRight, Undo, diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 1df33286ee..b9933dfcec 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -161,12 +161,11 @@ pub struct CachedLspAdapter { pub name: LanguageServerName, pub disk_based_diagnostic_sources: Vec<String>, pub disk_based_diagnostics_progress_token: Option<String>, - language_ids: HashMap<String, String>, + language_ids: HashMap<LanguageName, String>, pub adapter: Arc<dyn LspAdapter>, pub reinstall_attempt_count: AtomicU64, cached_binary: futures::lock::Mutex<Option<LanguageServerBinary>>, manifest_name: OnceLock<Option<ManifestName>>, - attach_kind: OnceLock<Attach>, } impl Debug for CachedLspAdapter { @@ -202,7 +201,6 @@ impl CachedLspAdapter { adapter, cached_binary: Default::default(), reinstall_attempt_count: AtomicU64::new(0), - attach_kind: Default::default(), manifest_name: Default::default(), }) } @@ -279,38 +277,25 @@ impl CachedLspAdapter { pub fn language_id(&self, language_name: &LanguageName) -> String { self.language_ids - .get(language_name.as_ref()) + .get(language_name) .cloned() .unwrap_or_else(|| language_name.lsp_id()) } + pub fn manifest_name(&self) -> Option<ManifestName> { self.manifest_name .get_or_init(|| self.adapter.manifest_name()) .clone() } - pub fn attach_kind(&self) -> Attach { - *self.attach_kind.get_or_init(|| self.adapter.attach_kind()) - } } +/// Determines what gets sent out as a workspace folders content #[derive(Clone, Copy, Debug, PartialEq)] -pub enum Attach { - /// Create a single language server instance per subproject root. 
- InstancePerRoot, - /// Use one shared language server instance for all subprojects within a project. - Shared, -} - -impl Attach { - pub fn root_path( - &self, - root_subproject_path: (WorktreeId, Arc<Path>), - ) -> (WorktreeId, Arc<Path>) { - match self { - Attach::InstancePerRoot => root_subproject_path, - Attach::Shared => (root_subproject_path.0, Arc::from(Path::new(""))), - } - } +pub enum WorkspaceFoldersContent { + /// Send out a single entry with the root of the workspace. + WorktreeRoot, + /// Send out a list of subproject roots. + SubprojectRoots, } /// [`LspAdapterDelegate`] allows [`LspAdapter]` implementations to interface with the application @@ -589,8 +574,8 @@ pub trait LspAdapter: 'static + Send + Sync { None } - fn language_ids(&self) -> HashMap<String, String> { - Default::default() + fn language_ids(&self) -> HashMap<LanguageName, String> { + HashMap::default() } /// Support custom initialize params. @@ -602,8 +587,11 @@ pub trait LspAdapter: 'static + Send + Sync { Ok(original) } - fn attach_kind(&self) -> Attach { - Attach::Shared + /// Determines whether a language server supports workspace folders. + /// + /// And does not trip over itself in the process. + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::SubprojectRoots } fn manifest_name(&self) -> Option<ManifestName> { @@ -2365,9 +2353,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); @@ -2378,9 +2366,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); @@ -2391,9 +2379,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs index ab3c0f9b37..ea988e8098 100644 --- a/crates/language/src/language_registry.rs +++ b/crates/language/src/language_registry.rs @@ -411,30 +411,6 @@ impl LanguageRegistry { cached } - pub fn get_or_register_lsp_adapter( - &self, - language_name: LanguageName, - server_name: LanguageServerName, - build_adapter: impl FnOnce() -> Arc<dyn LspAdapter> + 'static, - ) -> Arc<CachedLspAdapter> { - let registered = self - .state - .write() - .lsp_adapters - .entry(language_name.clone()) - .or_default() - .iter() - .find(|cached_adapter| cached_adapter.name == server_name) - .cloned(); - - if let Some(found) = registered { - found - } else { - let adapter = build_adapter(); - self.register_lsp_adapter(language_name, adapter) - } - } - /// Register a fake language server and adapter /// The returned channel receives a new instance of the language server every time it is started #[cfg(any(feature = "test-support", test))] @@ -571,15 +547,15 @@ impl LanguageRegistry { self.state.read().language_settings.clone() } - pub fn language_names(&self) -> Vec<String> { + pub fn language_names(&self) -> Vec<LanguageName> { let state = self.state.read(); let mut result = state .available_languages .iter() - .filter_map(|l| l.loaded.not().then_some(l.name.to_string())) - .chain(state.languages.iter().map(|l| l.config.name.to_string())) + .filter_map(|l| 
l.loaded.not().then_some(l.name.clone())) + .chain(state.languages.iter().map(|l| l.config.name.clone())) .collect::<Vec<_>>(); - result.sort_unstable_by_key(|language_name| language_name.to_lowercase()); + result.sort_unstable_by_key(|language_name| language_name.as_ref().to_lowercase()); result } diff --git a/crates/language/src/syntax_map.rs b/crates/language/src/syntax_map.rs index f441114a90..c56ffed066 100644 --- a/crates/language/src/syntax_map.rs +++ b/crates/language/src/syntax_map.rs @@ -17,7 +17,7 @@ use std::{ sync::Arc, }; use streaming_iterator::StreamingIterator; -use sum_tree::{Bias, SeekTarget, SumTree}; +use sum_tree::{Bias, Dimensions, SeekTarget, SumTree}; use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint}; use tree_sitter::{Node, Query, QueryCapture, QueryCaptures, QueryCursor, QueryMatches, Tree}; @@ -285,7 +285,7 @@ impl SyntaxSnapshot { pub fn interpolate(&mut self, text: &BufferSnapshot) { let edits = text - .anchored_edits_since::<(usize, Point)>(&self.interpolated_version) + .anchored_edits_since::<Dimensions<usize, Point>>(&self.interpolated_version) .collect::<Vec<_>>(); self.interpolated_version = text.version().clone(); @@ -333,7 +333,8 @@ impl SyntaxSnapshot { }; let Some(layer) = cursor.item() else { break }; - let (start_byte, start_point) = layer.range.start.summary::<(usize, Point)>(text); + let Dimensions(start_byte, start_point, _) = + layer.range.start.summary::<Dimensions<usize, Point>>(text); // Ignore edits that end before the start of this layer, and don't consider them // for any subsequent layers at this same depth. @@ -562,8 +563,8 @@ impl SyntaxSnapshot { } let Some(step) = step else { break }; - let (step_start_byte, step_start_point) = - step.range.start.summary::<(usize, Point)>(text); + let Dimensions(step_start_byte, step_start_point, _) = + step.range.start.summary::<Dimensions<usize, Point>>(text); let step_end_byte = step.range.end.to_offset(text); let mut old_layer = cursor.item(); diff --git a/crates/language_extension/src/extension_lsp_adapter.rs b/crates/language_extension/src/extension_lsp_adapter.rs index 58fbe6cda2..98b6fd4b5a 100644 --- a/crates/language_extension/src/extension_lsp_adapter.rs +++ b/crates/language_extension/src/extension_lsp_adapter.rs @@ -242,7 +242,7 @@ impl LspAdapter for ExtensionLspAdapter { ])) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { // TODO: The language IDs can be provided via the language server options // in `extension.toml now but we're leaving these existing usages in place temporarily // to avoid any compatibility issues between Zed and the extension versions. 
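For context, a small sketch (not part of this diff) of the typed lookup that `CachedLspAdapter::language_id` now performs with the `HashMap<LanguageName, String>` returned by `language_ids()`: overrides are keyed by `LanguageName`, with `LanguageName::lsp_id()` as the fallback. The `resolve_language_id` function and the PHP entry are illustrative only.

```rust
use language::LanguageName; // assumed path for the LanguageName type
use std::collections::HashMap;

// Mirrors the lookup in `CachedLspAdapter::language_id` earlier in this diff.
fn resolve_language_id(overrides: &HashMap<LanguageName, String>, name: &LanguageName) -> String {
    overrides
        .get(name)
        .cloned()
        // Fall back to the language's default LSP identifier when no override exists.
        .unwrap_or_else(|| name.lsp_id())
}

fn main() {
    let overrides = HashMap::from([(LanguageName::new("PHP"), "php".to_string())]);
    assert_eq!(
        resolve_language_id(&overrides, &LanguageName::new("PHP")),
        "php"
    );
}
```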
@@ -250,7 +250,7 @@ impl LspAdapter for ExtensionLspAdapter { // We can remove once the following extension versions no longer see any use: // - php@0.0.1 if self.extension.manifest().id.as_ref() == "php" { - return HashMap::from_iter([("PHP".into(), "php".into())]); + return HashMap::from_iter([(LanguageName::new("PHP"), "php".into())]); } self.extension diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index b718c530f5..841be60b0e 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,6 +20,7 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true base64.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true @@ -37,7 +38,6 @@ telemetry_events.workspace = true thiserror.workspace = true util.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index d54db7554a..a9c7d5c034 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -92,7 +92,12 @@ pub struct ToolUseRequest { pub struct FakeLanguageModel { provider_id: LanguageModelProviderId, provider_name: LanguageModelProviderName, - current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>, + current_completion_txs: Mutex< + Vec<( + LanguageModelRequest, + mpsc::UnboundedSender<LanguageModelCompletionEvent>, + )>, + >, } impl Default for FakeLanguageModel { @@ -118,10 +123,21 @@ impl FakeLanguageModel { self.current_completion_txs.lock().len() } - pub fn stream_completion_response( + pub fn send_completion_stream_text_chunk( &self, request: &LanguageModelRequest, chunk: impl Into<String>, + ) { + self.send_completion_stream_event( + request, + LanguageModelCompletionEvent::Text(chunk.into()), + ); + } + + pub fn send_completion_stream_event( + &self, + request: &LanguageModelRequest, + event: impl Into<LanguageModelCompletionEvent>, ) { let current_completion_txs = self.current_completion_txs.lock(); let tx = current_completion_txs @@ -129,7 +145,7 @@ impl FakeLanguageModel { .find(|(req, _)| req == request) .map(|(_, tx)| tx) .unwrap(); - tx.unbounded_send(chunk.into()).unwrap(); + tx.unbounded_send(event.into()).unwrap(); } pub fn end_completion_stream(&self, request: &LanguageModelRequest) { @@ -138,8 +154,15 @@ impl FakeLanguageModel { .retain(|(req, _)| req != request); } - pub fn stream_last_completion_response(&self, chunk: impl Into<String>) { - self.stream_completion_response(self.pending_completions().last().unwrap(), chunk); + pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into<String>) { + self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk); + } + + pub fn send_last_completion_stream_event( + &self, + event: impl Into<LanguageModelCompletionEvent>, + ) { + self.send_completion_stream_event(self.pending_completions().last().unwrap(), event); } pub fn end_last_completion_stream(&self) { @@ -201,12 +224,7 @@ impl LanguageModel for FakeLanguageModel { > { let (tx, rx) = mpsc::unbounded(); self.current_completion_txs.lock().push((request, tx)); - async move { - Ok(rx - .map(|text| Ok(LanguageModelCompletionEvent::Text(text))) - .boxed()) - } - .boxed() + async move { Ok(rx.map(Ok).boxed()) }.boxed() } fn 
as_fake(&self) -> &Self { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 54640419b6..1637d2de8a 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -11,6 +11,7 @@ pub mod fake_provider; use anthropic::{AnthropicError, parse_prompt_too_long}; use anyhow::{Result, anyhow}; use client::Client; +use cloud_llm_client::{CompletionMode, CompletionRequestStatus}; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; @@ -26,7 +27,6 @@ use std::time::Duration; use std::{fmt, io}; use thiserror::Error; use util::serde::is_default; -use zed_llm_client::{CompletionMode, CompletionRequestStatus}; pub use crate::model::*; pub use crate::rate_limiter::*; diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 72b7132c60..8ae5893410 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use anyhow::Result; use client::Client; +use cloud_llm_client::Plan; use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, }; -use proto::{Plan, TypedEnvelope}; +use proto::TypedEnvelope; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -30,7 +31,7 @@ pub struct ModelRequestLimitReachedError { impl fmt::Display for ModelRequestLimitReachedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let message = match self.plan { - Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.", + Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.", Plan::ZedPro => { "Model request limit reached. Upgrade to usage-based billing for more requests." 
} @@ -64,9 +65,14 @@ impl LlmApiToken { mut lock: RwLockWriteGuard<'_, Option<String>>, client: &Arc<Client>, ) -> Result<String> { - let response = client.request(proto::GetLlmToken {}).await?; - *lock = Some(response.token.clone()); - Ok(response.token.clone()) + let system_id = client + .telemetry() + .system_id() + .map(|system_id| system_id.to_string()); + + let response = client.cloud_client().create_llm_token(system_id).await?; + *lock = Some(response.token.0.clone()); + Ok(response.token.0.clone()) } } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 6f3d420ad5..dc485e9937 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,10 +1,9 @@ use std::io::{Cursor, Write}; use std::sync::Arc; -use crate::role::Role; -use crate::{LanguageModelToolUse, LanguageModelToolUseId}; use anyhow::Result; use base64::write::EncoderWriter; +use cloud_llm_client::{CompletionIntent, CompletionMode}; use gpui::{ App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task, point, px, size, @@ -12,7 +11,9 @@ use gpui::{ use image::codecs::png::PngEncoder; use serde::{Deserialize, Serialize}; use util::ResultExt; -use zed_llm_client::{CompletionIntent, CompletionMode}; + +use crate::role::Role; +use crate::{LanguageModelToolUse, LanguageModelToolUseId}; #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct LanguageModelImage { diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 574579aaa7..b5bfb870f6 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -16,18 +16,17 @@ ai_onboarding.workspace = true anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true aws-config = { workspace = true, features = ["behavior-version-latest"] } -aws-credential-types = { workspace = true, features = [ - "hardcoded-credentials", -] } +aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] } aws_http_client.workspace = true bedrock.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true component.workspace = true -credentials_provider.workspace = true convert_case.workspace = true copilot.workspace = true +credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } editor.workspace = true futures.workspace = true @@ -35,6 +34,7 @@ google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true gpui_tokio.workspace = true http_client.workspace = true +language.workspace = true language_model.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true @@ -43,10 +43,7 @@ mistral = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } -vercel = { workspace = true, features = ["schemars"] } -x_ai = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true -proto.workspace = true release_channel.workspace = true schemars.workspace = true serde.workspace = true @@ -61,9 +58,9 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } ui.workspace = true ui_input.workspace = true util.workspace = true +vercel = { workspace = true, features = ["schemars"] } workspace-hack.workspace = true -zed_llm_client.workspace = true 
-language.workspace = true +x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] editor = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 959cbccf39..ef21e85f71 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1012,7 +1012,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Anthropic, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:")) .child( List::new() .child( diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index a86b3e78f5..6df96c5c56 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -1251,7 +1251,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(ConfigurationView::save_credentials)) - .child(Label::new("To use Zed's assistant with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials.")) + .child(Label::new("To use Zed's agent with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials.")) .child(Label::new("But, to access models on AWS, you need to:").mt_1()) .child( List::new() diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index fac8810714..40dd120761 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -3,6 +3,13 @@ use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use client::{Client, ModelRequestUsage, UserStore, zed_urls}; +use cloud_llm_client::{ + CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, + CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, + EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, + SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, + TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, +}; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; @@ -20,7 +27,6 @@ use language_model::{ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; -use proto::Plan; use release_channel::AppVersion; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -33,13 +39,6 @@ use std::time::Duration; use thiserror::Error; use ui::{TintColor, prelude::*}; use util::{ResultExt as _, maybe}; -use zed_llm_client::{ - CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, - CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, - ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, - SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, - TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, -}; use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, 
into_anthropic}; use crate::provider::google::{GoogleEventMapper, into_google}; @@ -120,10 +119,10 @@ pub struct State { user_store: Entity<UserStore>, status: client::Status, accept_terms_of_service_task: Option<Task<Result<()>>>, - models: Vec<Arc<zed_llm_client::LanguageModel>>, - default_model: Option<Arc<zed_llm_client::LanguageModel>>, - default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>, - recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>, + models: Vec<Arc<cloud_llm_client::LanguageModel>>, + default_model: Option<Arc<cloud_llm_client::LanguageModel>>, + default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>, + recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>, _fetch_models_task: Task<()>, _settings_subscription: Subscription, _llm_token_subscription: Subscription, @@ -137,11 +136,11 @@ impl State { cx: &mut Context<Self>, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - + let mut current_user = user_store.read(cx).watch_current_user(); Self { client: client.clone(), llm_api_token: LlmApiToken::default(), - user_store, + user_store: user_store.clone(), status, accept_terms_of_service_task: None, models: Vec::new(), @@ -153,21 +152,14 @@ impl State { let (client, llm_api_token) = this .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; - loop { - let status = this.read_with(cx, |this, _cx| this.status)?; - if matches!(status, client::Status::Connected { .. }) { - break; - } - - cx.background_executor() - .timer(Duration::from_millis(100)) - .await; + while current_user.borrow().is_none() { + current_user.next().await; } - let response = Self::fetch_models(client, llm_api_token).await?; - this.update(cx, |this, cx| { - this.update_models(response, cx); - }) + let response = + Self::fetch_models(client.clone(), llm_api_token.clone()).await?; + this.update(cx, |this, cx| this.update_models(response, cx))?; + anyhow::Ok(()) }) .await .context("failed to fetch Zed models") @@ -194,26 +186,20 @@ impl State { } } - fn is_signed_out(&self) -> bool { - self.status.is_signed_out() + fn is_signed_out(&self, cx: &App) -> bool { + self.user_store.read(cx).current_user().is_none() } fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { let client = self.client.clone(); cx.spawn(async move |state, cx| { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.sign_in_with_optional_connect(true, &cx).await?; state.update(cx, |_, cx| cx.notify()) }) } fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false) + self.user_store.read(cx).has_accepted_terms_of_service() } fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) { @@ -238,8 +224,8 @@ impl State { // Right now we represent thinking variants of models as separate models on the client, // so we need to insert variants for any model that supports thinking. 
if model.supports_thinking { - models.push(Arc::new(zed_llm_client::LanguageModel { - id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), + models.push(Arc::new(cloud_llm_client::LanguageModel { + id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()), display_name: format!("{} Thinking", model.display_name), ..model })); @@ -328,7 +314,7 @@ impl CloudLanguageModelProvider { fn create_language_model( &self, - model: Arc<zed_llm_client::LanguageModel>, + model: Arc<cloud_llm_client::LanguageModel>, llm_api_token: LlmApiToken, ) -> Arc<dyn LanguageModel> { Arc::new(CloudLanguageModel { @@ -398,7 +384,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn is_authenticated(&self, cx: &App) -> bool { let state = self.state.read(cx); - !state.is_signed_out() && state.has_accepted_terms_of_service(cx) + !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx) } fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> { @@ -518,7 +504,7 @@ fn render_accept_terms( pub struct CloudLanguageModel { id: LanguageModelId, - model: Arc<zed_llm_client::LanguageModel>, + model: Arc<cloud_llm_client::LanguageModel>, llm_api_token: LlmApiToken, client: Arc<Client>, request_limiter: RateLimiter, @@ -611,13 +597,8 @@ impl CloudLanguageModel { .headers() .get(CURRENT_PLAN_HEADER_NAME) .and_then(|plan| plan.to_str().ok()) - .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok()) + .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) { - let plan = match plan { - zed_llm_client::Plan::ZedFree => Plan::Free, - zed_llm_client::Plan::ZedPro => Plan::ZedPro, - zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial, - }; return Err(anyhow!(ModelRequestLimitReachedError { plan })); } } @@ -729,7 +710,7 @@ impl LanguageModel for CloudLanguageModel { } fn upstream_provider_id(&self) -> LanguageModelProviderId { - use zed_llm_client::LanguageModelProvider::*; + use cloud_llm_client::LanguageModelProvider::*; match self.model.provider { Anthropic => language_model::ANTHROPIC_PROVIDER_ID, OpenAi => language_model::OPEN_AI_PROVIDER_ID, @@ -738,7 +719,7 @@ impl LanguageModel for CloudLanguageModel { } fn upstream_provider_name(&self) -> LanguageModelProviderName { - use zed_llm_client::LanguageModelProvider::*; + use cloud_llm_client::LanguageModelProvider::*; match self.model.provider { Anthropic => language_model::ANTHROPIC_PROVIDER_NAME, OpenAi => language_model::OPEN_AI_PROVIDER_NAME, @@ -772,11 +753,11 @@ impl LanguageModel for CloudLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic - | zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::Anthropic + | cloud_llm_client::LanguageModelProvider::OpenAi => { LanguageModelToolSchemaFormat::JsonSchema } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { LanguageModelToolSchemaFormat::JsonSchemaSubset } } @@ -795,15 +776,15 @@ impl LanguageModel for CloudLanguageModel { fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> { match &self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => { + cloud_llm_client::LanguageModelProvider::Anthropic => { Some(LanguageModelCacheConfiguration { min_total_token: 2_048, should_speculate: true, max_cache_anchors: 4, }) } - zed_llm_client::LanguageModelProvider::OpenAi - | 
zed_llm_client::LanguageModelProvider::Google => None, + cloud_llm_client::LanguageModelProvider::OpenAi + | cloud_llm_client::LanguageModelProvider::Google => None, } } @@ -813,15 +794,17 @@ impl LanguageModel for CloudLanguageModel { cx: &App, ) -> BoxFuture<'static, Result<u64>> { match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx), - zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::Anthropic => { + count_anthropic_tokens(request, cx) + } + cloud_llm_client::LanguageModelProvider::OpenAi => { let model = match open_ai::Model::from_id(&self.model.id.0) { Ok(model) => model, Err(err) => return async move { Err(anyhow!(err)) }.boxed(), }; count_open_ai_tokens(request, model, cx) } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); let model_id = self.model.id.to_string(); @@ -832,7 +815,7 @@ impl LanguageModel for CloudLanguageModel { let token = llm_api_token.acquire(&client).await?; let request_body = CountTokensBody { - provider: zed_llm_client::LanguageModelProvider::Google, + provider: cloud_llm_client::LanguageModelProvider::Google, model: model_id, provider_request: serde_json::to_value(&google_ai::CountTokensRequest { generate_content_request, @@ -893,7 +876,7 @@ impl LanguageModel for CloudLanguageModel { let app_version = cx.update(|cx| AppVersion::global(cx)).ok(); let thinking_allowed = request.thinking_allowed; match self.model.provider { - zed_llm_client::LanguageModelProvider::Anthropic => { + cloud_llm_client::LanguageModelProvider::Anthropic => { let request = into_anthropic( request, self.model.id.to_string(), @@ -924,7 +907,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::Anthropic, + provider: cloud_llm_client::LanguageModelProvider::Anthropic, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -948,7 +931,7 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } - zed_llm_client::LanguageModelProvider::OpenAi => { + cloud_llm_client::LanguageModelProvider::OpenAi => { let client = self.client.clone(); let model = match open_ai::Model::from_id(&self.model.id.0) { Ok(model) => model, @@ -976,7 +959,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::OpenAi, + provider: cloud_llm_client::LanguageModelProvider::OpenAi, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -996,7 +979,7 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } - zed_llm_client::LanguageModelProvider::Google => { + cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let request = into_google(request, self.model.id.to_string(), GoogleModelMode::Default); @@ -1016,7 +999,7 @@ impl LanguageModel for CloudLanguageModel { prompt_id, intent, mode, - provider: zed_llm_client::LanguageModelProvider::Google, + provider: cloud_llm_client::LanguageModelProvider::Google, model: request.model.model_id.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -1040,15 +1023,8 @@ impl LanguageModel for CloudLanguageModel { } } -#[derive(Serialize, 
Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CloudCompletionEvent<T> { - Status(CompletionRequestStatus), - Event(T), -} - fn map_cloud_completion_events<T, F>( - stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>, + stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>, mut map_callback: F, ) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> where @@ -1063,10 +1039,10 @@ where Err(error) => { vec![Err(LanguageModelCompletionError::from(error))] } - Ok(CloudCompletionEvent::Status(event)) => { + Ok(CompletionEvent::Status(event)) => { vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))] } - Ok(CloudCompletionEvent::Event(event)) => map_callback(event), + Ok(CompletionEvent::Event(event)) => map_callback(event), }) }) .boxed() @@ -1074,9 +1050,9 @@ where fn usage_updated_event<T>( usage: Option<ModelRequestUsage>, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { futures::stream::iter(usage.map(|usage| { - Ok(CloudCompletionEvent::Status( + Ok(CompletionEvent::Status( CompletionRequestStatus::UsageUpdated { amount: usage.amount as usize, limit: usage.limit, @@ -1087,9 +1063,9 @@ fn usage_updated_event<T>( fn tool_use_limit_reached_event<T>( tool_use_limit_reached: bool, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { futures::stream::iter(tool_use_limit_reached.then(|| { - Ok(CloudCompletionEvent::Status( + Ok(CompletionEvent::Status( CompletionRequestStatus::ToolUseLimitReached, )) })) @@ -1098,7 +1074,7 @@ fn tool_use_limit_reached_event<T>( fn response_lines<T: DeserializeOwned>( response: Response<AsyncBody>, includes_status_messages: bool, -) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { +) -> impl Stream<Item = Result<CompletionEvent<T>>> { futures::stream::try_unfold( (String::new(), BufReader::new(response.into_body())), move |(mut line, mut body)| async move { @@ -1106,9 +1082,9 @@ fn response_lines<T: DeserializeOwned>( Ok(0) => Ok(None), Ok(_) => { let event = if includes_status_messages { - serde_json::from_str::<CloudCompletionEvent<T>>(&line)? + serde_json::from_str::<CompletionEvent<T>>(&line)? } else { - CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?) + CompletionEvent::Event(serde_json::from_str::<T>(&line)?) }; line.clear(); @@ -1123,7 +1099,7 @@ fn response_lines<T: DeserializeOwned>( #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, - plan: Option<proto::Plan>, + plan: Option<Plan>, subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>, eligible_for_trial: bool, has_accepted_terms_of_service: bool, @@ -1137,15 +1113,15 @@ impl RenderOnce for ZedAiConfiguration { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let young_account_banner = YoungAccountBanner; - let is_pro = self.plan == Some(proto::Plan::ZedPro); + let is_pro = self.plan == Some(Plan::ZedPro); let subscription_text = match (self.plan, self.subscription_period) { - (Some(proto::Plan::ZedPro), Some(_)) => { + (Some(Plan::ZedPro), Some(_)) => { "You have access to Zed's hosted models through your Pro subscription." } - (Some(proto::Plan::ZedProTrial), Some(_)) => { + (Some(Plan::ZedProTrial), Some(_)) => { "You have access to Zed's hosted models through your Pro trial." 
} - (Some(proto::Plan::Free), Some(_)) => { + (Some(Plan::ZedFree), Some(_)) => { "You have basic access to Zed's hosted models through the Free plan." } _ => { @@ -1159,19 +1135,20 @@ impl RenderOnce for ZedAiConfiguration { let manage_subscription_buttons = if is_pro { Button::new("manage_settings", "Manage Subscription") + .full_width() .style(ButtonStyle::Tinted(TintColor::Accent)) .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))) .into_any_element() } else if self.plan.is_none() || self.eligible_for_trial { Button::new("start_trial", "Start 14-day Free Pro Trial") - .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) .full_width() + .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx))) .into_any_element() } else { Button::new("upgrade", "Upgrade to Pro") - .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) .full_width() + .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))) .into_any_element() }; @@ -1269,8 +1246,8 @@ impl Render for ConfigurationView { let user_store = state.user_store.read(cx); ZedAiConfiguration { - is_connected: !state.is_signed_out(), - plan: user_store.current_plan(), + is_connected: !state.is_signed_out(cx), + plan: user_store.plan(), subscription_period: user_store.subscription_period(), eligible_for_trial: user_store.trial_started_at().is_none(), has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), @@ -1283,14 +1260,22 @@ impl Render for ConfigurationView { } impl Component for ZedAiConfiguration { + fn name() -> &'static str { + "AI Configuration Content" + } + + fn sort_name() -> &'static str { + "AI Configuration Content" + } + fn scope() -> ComponentScope { - ComponentScope::Agent + ComponentScope::Onboarding } fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { fn configuration( is_connected: bool, - plan: Option<proto::Plan>, + plan: Option<Plan>, eligible_for_trial: bool, account_too_young: bool, has_accepted_terms_of_service: bool, @@ -1334,15 +1319,15 @@ impl Component for ZedAiConfiguration { ), single_example( "Free Plan", - configuration(true, Some(proto::Plan::Free), true, false, true), + configuration(true, Some(Plan::ZedFree), true, false, true), ), single_example( "Zed Pro Trial Plan", - configuration(true, Some(proto::Plan::ZedProTrial), true, false, true), + configuration(true, Some(Plan::ZedProTrial), true, false, true), ), single_example( "Zed Pro Plan", - configuration(true, Some(proto::Plan::ZedPro), true, false, true), + configuration(true, Some(Plan::ZedPro), true, false, true), ), ]) .into_any_element(), diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index d9a84f1eb7..73f73a9a31 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -3,6 +3,7 @@ use std::str::FromStr as _; use std::sync::Arc; use anyhow::{Result, anyhow}; +use cloud_llm_client::CompletionIntent; use collections::HashMap; use copilot::copilot_chat::{ ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl, @@ -30,7 +31,6 @@ use settings::SettingsStore; use std::time::Duration; use ui::prelude::*; use util::debug_panic; -use zed_llm_client::CompletionIntent; use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; @@ -706,7 +706,8 @@ impl Render for ConfigurationView { 
.child(svg().size_8().path(IconName::CopilotError.path())) } _ => { - const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription."; + const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription."; + v_flex().gap_2().child(Label::new(LABEL)).child( Button::new("sign_in", "Sign in to use GitHub Copilot") .icon_color(Color::Muted) diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index bd8a09970a..b287e8181a 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -880,7 +880,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Google AI, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Google AI, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 01600f3646..9792b4f27b 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -744,7 +744,7 @@ impl Render for ConfigurationView { Button::new("retry_lmstudio_models", "Connect") .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon(IconName::Play) + .icon(IconName::PlayOutlined) .on_click(cx.listener(move |this, _, _window, cx| { this.retry_connection(cx) })), diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index fb385308fa..02e53cb99a 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -807,7 +807,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with Mistral, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index dc81e8be18..c845c97b09 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -192,12 +192,16 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { IconName::AiOllama } - fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> { - self.provided_models(cx).into_iter().next() + fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> { + // We shouldn't try to select default model, because it might lead to a load call for an unloaded model. + // In a constrained environment where user might not have enough resources it'll be a bad UX to select something + // to load by default. + None } - fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> { - self.default_model(cx) + fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> { + // See explanation for default_model. 
+ None } fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> { @@ -627,7 +631,7 @@ impl Render for ConfigurationView { } }) .child( - Button::new("view-models", "All Models") + Button::new("view-models", "View All Models") .style(ButtonStyle::Subtle) .icon(IconName::ArrowUpRight) .icon_size(IconSize::XSmall) @@ -654,7 +658,7 @@ impl Render for ConfigurationView { Button::new("retry_ollama_models", "Connect") .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon(IconName::Play) + .icon(IconName::PlayOutlined) .on_click(cx.listener(move |this, _, _, cx| { this.retry_connection(cx) })), diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 6c4d4c9b3e..5185e979b7 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -780,7 +780,7 @@ impl Render for ConfigurationView { let api_key_section = if self.should_render_editor(cx) { v_flex() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( @@ -868,7 +868,7 @@ impl Render for ConfigurationView { .icon_size(IconSize::XSmall) .icon_color(Color::Muted) .on_click(move |_, _window, cx| { - cx.open_url("https://zed.dev/docs/ai/configuration#openai-api-compatible") + cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible") }), ); diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 64add5483d..38bd7cee06 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -466,7 +466,7 @@ impl Render for ConfigurationView { let api_key_section = if self.should_render_editor(cx) { v_flex() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with an OpenAI compatible provider, you need to add an API key.")) + .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key.")) .child( div() .pt(DynamicSpacing::Base04.rems(cx)) diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 5a6acc4329..3a492086f1 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -855,7 +855,7 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's assistant with OpenRouter, you need to add an API key. Follow these steps:")) + .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. 
Follow these steps:")) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_selector/src/language_selector.rs b/crates/language_selector/src/language_selector.rs index 4c03430553..f6e2d75015 100644 --- a/crates/language_selector/src/language_selector.rs +++ b/crates/language_selector/src/language_selector.rs @@ -86,7 +86,10 @@ impl LanguageSelector { impl Render for LanguageSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("LanguageSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } @@ -121,13 +124,13 @@ impl LanguageSelectorDelegate { .into_iter() .filter_map(|name| { language_registry - .available_language_for_name(&name)? + .available_language_for_name(name.as_ref())? .hidden() .not() .then_some(name) }) .enumerate() - .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, &name)) + .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, name.as_ref())) .collect::<Vec<_>>(); Self { diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index d1a90d7dbb..606f3a3f0e 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -253,8 +253,8 @@ impl LogStore { let copilot_subscription = Copilot::global(cx).map(|copilot| { let copilot = &copilot; - cx.subscribe(copilot, |this, copilot, inline_completion_event, cx| { - if let copilot::Event::CopilotLanguageServerStarted = inline_completion_event { + cx.subscribe(copilot, |this, copilot, edit_prediction_event, cx| { + if let copilot::Event::CopilotLanguageServerStarted = edit_prediction_event { if let Some(server) = copilot.read(cx).language_server() { let server_id = server.server_id(); let weak_this = cx.weak_entity(); @@ -867,7 +867,7 @@ impl LspLogView { BINARY = server.binary(), WORKSPACE_FOLDERS = server .workspace_folders() - .iter() + .into_iter() .filter_map(|path| path .to_file_path() .ok() diff --git a/crates/language_tools/src/lsp_tool.rs b/crates/language_tools/src/lsp_tool.rs index 9e95ed4673..50547253a9 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_tool.rs @@ -1015,7 +1015,7 @@ impl Render for LspTool { .anchor(Corner::BottomLeft) .with_handle(self.popover_menu_handle.clone()) .trigger_with_tooltip( - IconButton::new("zed-lsp-tool-button", IconName::BoltFilledAlt) + IconButton::new("zed-lsp-tool-button", IconName::BoltOutlined) .when_some(indicator, IconButton::indicator) .icon_size(IconSize::Small) .indicator_border_color(Some(cx.theme().colors().status_bar_background)), diff --git a/crates/languages/Cargo.toml b/crates/languages/Cargo.toml index 2e8f007cff..8e25818070 100644 --- a/crates/languages/Cargo.toml +++ b/crates/languages/Cargo.toml @@ -36,11 +36,13 @@ load-grammars = [ [dependencies] anyhow.workspace = true async-compression.workspace = true +async-fs.workspace = true async-tar.workspace = true async-trait.workspace = true chrono.workspace = true collections.workspace = true dap.workspace = true +feature_flags.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true @@ -61,6 +63,7 @@ regex.workspace = true rope.workspace = true rust-embed.workspace = true schemars.workspace = true +sha2.workspace = true serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true @@ -68,6 +71,7 @@ settings.workspace = true smol.workspace = true snippet_provider.workspace = true 
task.workspace = true +tempfile.workspace = true toml.workspace = true tree-sitter = { workspace = true, optional = true } tree-sitter-bash = { workspace = true, optional = true } diff --git a/crates/languages/src/bash/config.toml b/crates/languages/src/bash/config.toml index db9a2749e7..8ff4802aee 100644 --- a/crates/languages/src/bash/config.toml +++ b/crates/languages/src/bash/config.toml @@ -18,17 +18,20 @@ brackets = [ { start = "in", end = "esac", close = false, newline = true, not_in = ["comment", "string"] }, ] -### WARN: the following is not working when you insert an `elif` just before an else -### example: (^ is cursor after hitting enter) -### ``` -### if true; then -### foo -### elif -### ^ -### else -### bar -### fi -### ``` -increase_indent_pattern = "(^|\\s+|;)(do|then|in|else|elif)\\b.*$" -decrease_indent_pattern = "(^|\\s+|;)(fi|done|esac|else|elif)\\b.*$" -# make sure to test each line mode & block mode +auto_indent_using_last_non_empty_line = false +increase_indent_pattern = "^\\s*(\\b(else|elif)\\b|([^#]+\\b(do|then|in)\\b)|([\\w\\*]+\\)))\\s*$" +decrease_indent_patterns = [ + { pattern = "^\\s*elif\\b.*", valid_after = ["if", "elif"] }, + { pattern = "^\\s*else\\b.*", valid_after = ["if", "elif", "for", "while"] }, + { pattern = "^\\s*fi\\b.*", valid_after = ["if", "elif", "else"] }, + { pattern = "^\\s*done\\b.*", valid_after = ["for", "while"] }, + { pattern = "^\\s*esac\\b.*", valid_after = ["case"] }, + { pattern = "^\\s*[\\w\\*]+\\)\\s*$", valid_after = ["case_item"] }, +] + +# We can't handle elif purely through decrease_indent_patterns, because +# a tree-sitter bug emits an ERROR node when matching a bare `if`. +# +# As a workaround, elif is always outdented here, even when the +# context is wrong (for example, an elif that follows an else). +decrease_indent_pattern = "(^|\\s+|;)(elif)\\b.*$" diff --git a/crates/languages/src/bash/indents.scm b/crates/languages/src/bash/indents.scm index acdcddabfe..468fc595e5 100644 --- a/crates/languages/src/bash/indents.scm +++ b/crates/languages/src/bash/indents.scm @@ -1,12 +1,12 @@ -(function_definition - "function"?
- body: ( - _ - "{" @start - "}" @end - )) @indent +(_ "[" "]" @end) @indent +(_ "{" "}" @end) @indent +(_ "(" ")" @end) @indent -(array - "(" @start - ")" @end - ) @indent +(function_definition) @start.function +(if_statement) @start.if +(elif_clause) @start.elif +(else_clause) @start.else +(for_statement) @start.for +(while_statement) @start.while +(case_statement) @start.case +(case_item) @start.case_item diff --git a/crates/languages/src/c.rs b/crates/languages/src/c.rs index c06c35ee69..a55d8ff998 100644 --- a/crates/languages/src/c.rs +++ b/crates/languages/src/c.rs @@ -2,14 +2,16 @@ use anyhow::{Context as _, Result, bail}; use async_trait::async_trait; use futures::StreamExt; use gpui::{App, AsyncApp}; -use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; +use http_client::github::{AssetKind, GitHubLspBinaryVersion, latest_github_release}; pub use language::*; use lsp::{InitializeParams, LanguageServerBinary, LanguageServerName}; use project::lsp_store::clangd_ext; use serde_json::json; use smol::fs; use std::{any::Any, env::consts, path::PathBuf, sync::Arc}; -use util::{ResultExt, archive::extract_zip, fs::remove_matching, maybe, merge_json_value_into}; +use util::{ResultExt, fs::remove_matching, maybe, merge_json_value_into}; + +use crate::github_download::{GithubBinaryMetadata, download_server_binary}; pub struct CLspAdapter; @@ -58,6 +60,7 @@ impl super::LspAdapter for CLspAdapter { let version = GitHubLspBinaryVersion { name: release.tag_name, url: asset.browser_download_url.clone(), + digest: asset.digest.clone(), }; Ok(Box::new(version) as Box<_>) } @@ -68,32 +71,67 @@ impl super::LspAdapter for CLspAdapter { container_dir: PathBuf, delegate: &dyn LspAdapterDelegate, ) -> Result<LanguageServerBinary> { - let version = version.downcast::<GitHubLspBinaryVersion>().unwrap(); - let version_dir = container_dir.join(format!("clangd_{}", version.name)); + let GitHubLspBinaryVersion { name, url, digest } = + &*version.downcast::<GitHubLspBinaryVersion>().unwrap(); + let version_dir = container_dir.join(format!("clangd_{name}")); let binary_path = version_dir.join("bin/clangd"); - if fs::metadata(&binary_path).await.is_err() { - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .context("error downloading release")?; - anyhow::ensure!( - response.status().is_success(), - "download failed with status {}", - response.status().to_string() - ); - extract_zip(&container_dir, response.body_mut()) - .await - .with_context(|| format!("unzipping clangd archive to {container_dir:?}"))?; - remove_matching(&container_dir, |entry| entry != version_dir).await; - } - - Ok(LanguageServerBinary { - path: binary_path, + let binary = LanguageServerBinary { + path: binary_path.clone(), env: None, - arguments: Vec::new(), - }) + arguments: Default::default(), + }; + + let metadata_path = version_dir.join("metadata"); + let metadata = GithubBinaryMetadata::read_from_file(&metadata_path) + .await + .ok(); + if let Some(metadata) = metadata { + let validity_check = async || { + delegate + .try_exec(LanguageServerBinary { + path: binary_path.clone(), + arguments: vec!["--version".into()], + env: None, + }) + .await + .inspect_err(|err| { + log::warn!("Unable to run {binary_path:?} asset, redownloading: {err}",) + }) + }; + if let (Some(actual_digest), Some(expected_digest)) = (&metadata.digest, digest) { + if actual_digest == expected_digest { + if validity_check().await.is_ok() { + return Ok(binary); + } + } else { + log::info!( + 
"SHA-256 mismatch for {binary_path:?} asset, downloading new asset. Expected: {expected_digest}, Got: {actual_digest}" + ); + } + } else if validity_check().await.is_ok() { + return Ok(binary); + } + } + download_server_binary( + delegate, + url, + digest.as_deref(), + &container_dir, + AssetKind::Zip, + ) + .await?; + remove_matching(&container_dir, |entry| entry != version_dir).await; + GithubBinaryMetadata::write_to_file( + &GithubBinaryMetadata { + metadata_version: 1, + digest: digest.clone(), + }, + &metadata_path, + ) + .await?; + + Ok(binary) } async fn cached_server_binary( diff --git a/crates/languages/src/css.rs b/crates/languages/src/css.rs index f2a94809a0..7725e079be 100644 --- a/crates/languages/src/css.rs +++ b/crates/languages/src/css.rs @@ -5,7 +5,7 @@ use gpui::AsyncApp; use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; -use project::Fs; +use project::{Fs, lsp_store::language_server_settings}; use serde_json::json; use smol::fs; use std::{ @@ -14,7 +14,7 @@ use std::{ path::{Path, PathBuf}, sync::Arc, }; -use util::{ResultExt, maybe}; +use util::{ResultExt, maybe, merge_json_value_into}; const SERVER_PATH: &str = "node_modules/vscode-langservers-extracted/bin/vscode-css-language-server"; @@ -134,6 +134,37 @@ impl LspAdapter for CssLspAdapter { "provideFormatter": true }))) } + + async fn workspace_configuration( + self: Arc<Self>, + _: &dyn Fs, + delegate: &Arc<dyn LspAdapterDelegate>, + _: Arc<dyn LanguageToolchainStore>, + cx: &mut AsyncApp, + ) -> Result<serde_json::Value> { + let mut default_config = json!({ + "css": { + "lint": {} + }, + "less": { + "lint": {} + }, + "scss": { + "lint": {} + } + }); + + let project_options = cx.update(|cx| { + language_server_settings(delegate.as_ref(), &self.name(), cx) + .and_then(|s| s.settings.clone()) + })?; + + if let Some(override_options) = project_options { + merge_json_value_into(override_options, &mut default_config); + } + + Ok(default_config) + } } async fn get_cached_server_binary( diff --git a/crates/languages/src/github_download.rs b/crates/languages/src/github_download.rs new file mode 100644 index 0000000000..a3cd0a964b --- /dev/null +++ b/crates/languages/src/github_download.rs @@ -0,0 +1,190 @@ +use std::{path::Path, pin::Pin, task::Poll}; + +use anyhow::{Context, Result}; +use async_compression::futures::bufread::GzipDecoder; +use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader}; +use http_client::github::AssetKind; +use language::LspAdapterDelegate; +use sha2::{Digest, Sha256}; + +#[derive(serde::Deserialize, serde::Serialize, Debug)] +pub(crate) struct GithubBinaryMetadata { + pub(crate) metadata_version: u64, + pub(crate) digest: Option<String>, +} + +impl GithubBinaryMetadata { + pub(crate) async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> { + let metadata_content = async_fs::read_to_string(metadata_path) + .await + .with_context(|| format!("reading metadata file at {metadata_path:?}"))?; + let metadata: GithubBinaryMetadata = serde_json::from_str(&metadata_content) + .with_context(|| format!("parsing metadata file at {metadata_path:?}"))?; + Ok(metadata) + } + + pub(crate) async fn write_to_file(&self, metadata_path: &Path) -> Result<()> { + let metadata_content = serde_json::to_string(self) + .with_context(|| format!("serializing metadata for {metadata_path:?}"))?; + async_fs::write(metadata_path, metadata_content.as_bytes()) + .await + 
.with_context(|| format!("writing metadata file at {metadata_path:?}"))?; + Ok(()) + } +} + +pub(crate) async fn download_server_binary( + delegate: &dyn LspAdapterDelegate, + url: &str, + digest: Option<&str>, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<(), anyhow::Error> { + log::info!("downloading github artifact from {url}"); + let mut response = delegate + .http_client() + .get(url, Default::default(), true) + .await + .with_context(|| format!("downloading release from {url}"))?; + let body = response.body_mut(); + match digest { + Some(expected_sha_256) => { + let temp_asset_file = tempfile::NamedTempFile::new() + .with_context(|| format!("creating a temporary file for {url}"))?; + let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts(); + let mut writer = HashingWriter { + writer: async_fs::File::from(temp_asset_file), + hasher: Sha256::new(), + }; + futures::io::copy(&mut BufReader::new(body), &mut writer) + .await + .with_context(|| { + format!("saving archive contents into the temporary file for {url}",) + })?; + let asset_sha_256 = format!("{:x}", writer.hasher.finalize()); + anyhow::ensure!( + asset_sha_256 == expected_sha_256, + "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}", + ); + writer + .writer + .seek(std::io::SeekFrom::Start(0)) + .await + .with_context(|| format!("seeking temporary file {destination_path:?}",))?; + stream_file_archive(&mut writer.writer, url, destination_path, asset_kind) + .await + .with_context(|| { + format!("extracting downloaded asset for {url} into {destination_path:?}",) + })?; + } + None => stream_response_archive(body, url, destination_path, asset_kind) + .await + .with_context(|| { + format!("extracting response for asset {url} into {destination_path:?}",) + })?, + } + Ok(()) +} + +async fn stream_response_archive( + response: impl AsyncRead + Unpin, + url: &str, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<()> { + match asset_kind { + AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?, + AssetKind::Gz => extract_gz(destination_path, url, response).await?, + AssetKind::Zip => { + util::archive::extract_zip(&destination_path, response).await?; + } + }; + Ok(()) +} + +async fn stream_file_archive( + file_archive: impl AsyncRead + AsyncSeek + Unpin, + url: &str, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<()> { + match asset_kind { + AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?, + AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?, + #[cfg(not(windows))] + AssetKind::Zip => { + util::archive::extract_seekable_zip(&destination_path, file_archive).await?; + } + #[cfg(windows)] + AssetKind::Zip => { + util::archive::extract_zip(&destination_path, file_archive).await?; + } + }; + Ok(()) +} + +async fn extract_tar_gz( + destination_path: &Path, + url: &str, + from: impl AsyncRead + Unpin, +) -> Result<(), anyhow::Error> { + let decompressed_bytes = GzipDecoder::new(BufReader::new(from)); + let archive = async_tar::Archive::new(decompressed_bytes); + archive + .unpack(&destination_path) + .await + .with_context(|| format!("extracting {url} to {destination_path:?}"))?; + Ok(()) +} + +async fn extract_gz( + destination_path: &Path, + url: &str, + from: impl AsyncRead + Unpin, +) -> Result<(), anyhow::Error> { + let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from)); + let mut file = smol::fs::File::create(&destination_path) + .await + .with_context(|| { 
+ format!("creating a file {destination_path:?} for a download from {url}") + })?; + futures::io::copy(&mut decompressed_bytes, &mut file) + .await + .with_context(|| format!("extracting {url} to {destination_path:?}"))?; + Ok(()) +} + +struct HashingWriter<W: AsyncWrite + Unpin> { + writer: W, + hasher: Sha256, +} + +impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll<std::result::Result<usize, std::io::Error>> { + match Pin::new(&mut self.writer).poll_write(cx, buf) { + Poll::Ready(Ok(n)) => { + self.hasher.update(&buf[..n]); + Poll::Ready(Ok(n)) + } + other => other, + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.writer).poll_flush(cx) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), std::io::Error>> { + Pin::new(&mut self.writer).poll_close(cx) + } +} diff --git a/crates/languages/src/go/runnables.scm b/crates/languages/src/go/runnables.scm index 49e112b860..6418cd04d8 100644 --- a/crates/languages/src/go/runnables.scm +++ b/crates/languages/src/go/runnables.scm @@ -69,7 +69,7 @@ ( ( (function_declaration name: (_) @run @_name - (#match? @_name "^Benchmark.+")) + (#match? @_name "^Benchmark.*")) ) @_ (#set! tag go-benchmark) ) diff --git a/crates/languages/src/json.rs b/crates/languages/src/json.rs index 15818730b8..ca82bb2431 100644 --- a/crates/languages/src/json.rs +++ b/crates/languages/src/json.rs @@ -8,8 +8,8 @@ use futures::StreamExt; use gpui::{App, AsyncApp, Task}; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; use language::{ - ContextProvider, LanguageRegistry, LanguageToolchainStore, LocalFile as _, LspAdapter, - LspAdapterDelegate, + ContextProvider, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile as _, + LspAdapter, LspAdapterDelegate, }; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; @@ -269,7 +269,15 @@ impl JsonLspAdapter { .await; let config = cx.update(|cx| { - Self::get_workspace_config(self.languages.language_names().clone(), adapter_schemas, cx) + Self::get_workspace_config( + self.languages + .language_names() + .into_iter() + .map(|name| name.to_string()) + .collect(), + adapter_schemas, + cx, + ) })?; writer.replace(config.clone()); return Ok(config); @@ -408,10 +416,10 @@ impl LspAdapter for JsonLspAdapter { Ok(config) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { [ - ("JSON".into(), "json".into()), - ("JSONC".into(), "jsonc".into()), + (LanguageName::new("JSON"), "json".into()), + (LanguageName::new("JSONC"), "jsonc".into()), ] .into_iter() .collect() @@ -509,6 +517,7 @@ impl LspAdapter for NodeVersionAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: release.tag_name, url: asset.browser_download_url.clone(), + digest: asset.digest.clone(), })) } diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index a224111002..195ba79e1d 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -1,4 +1,5 @@ use anyhow::Context as _; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use gpui::{App, UpdateGlobal}; use node_runtime::NodeRuntime; use python::PyprojectTomlManifestProvider; @@ -11,11 +12,12 @@ use util::{ResultExt, asset_str}; pub use language::*; -use 
crate::json::JsonTaskProvider; +use crate::{json::JsonTaskProvider, python::BasedPyrightLspAdapter}; mod bash; mod c; mod css; +mod github_download; mod go; mod json; mod package_json; @@ -52,6 +54,12 @@ pub static LANGUAGE_GIT_COMMIT: std::sync::LazyLock<Arc<Language>> = )) }); +struct BasedPyrightFeatureFlag; + +impl FeatureFlag for BasedPyrightFeatureFlag { + const NAME: &'static str = "basedpyright"; +} + pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { #[cfg(feature = "load-grammars")] languages.register_native_grammars([ @@ -88,6 +96,7 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { let py_lsp_adapter = Arc::new(python::PyLspAdapter::new()); let python_context_provider = Arc::new(python::PythonContextProvider); let python_lsp_adapter = Arc::new(python::PythonLspAdapter::new(node.clone())); + let basedpyright_lsp_adapter = Arc::new(BasedPyrightLspAdapter::new()); let python_toolchain_provider = Arc::new(python::PythonToolchainProvider::default()); let rust_context_provider = Arc::new(rust::RustContextProvider); let rust_lsp_adapter = Arc::new(rust::RustLspAdapter); @@ -228,6 +237,20 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) { ); } + let mut basedpyright_lsp_adapter = Some(basedpyright_lsp_adapter); + cx.observe_flag::<BasedPyrightFeatureFlag, _>({ + let languages = languages.clone(); + move |enabled, _| { + if enabled { + if let Some(adapter) = basedpyright_lsp_adapter.take() { + languages + .register_available_lsp_adapter(adapter.name(), move || adapter.clone()); + } + } + } + }) + .detach(); + // Register globally available language servers. // // This will allow users to add support for a built-in language server (e.g., Tailwind) diff --git a/crates/languages/src/python.rs b/crates/languages/src/python.rs index dc6996d399..0524c02fd5 100644 --- a/crates/languages/src/python.rs +++ b/crates/languages/src/python.rs @@ -4,13 +4,13 @@ use async_trait::async_trait; use collections::HashMap; use gpui::{App, Task}; use gpui::{AsyncApp, SharedString}; -use language::Toolchain; use language::ToolchainList; use language::ToolchainLister; use language::language_settings::language_settings; use language::{ContextLocation, LanguageToolchainStore}; use language::{ContextProvider, LspAdapter, LspAdapterDelegate}; use language::{LanguageName, ManifestName, ManifestProvider, ManifestQuery}; +use language::{Toolchain, WorkspaceFoldersContent}; use lsp::LanguageServerBinary; use lsp::LanguageServerName; use node_runtime::NodeRuntime; @@ -400,6 +400,9 @@ impl LspAdapter for PythonLspAdapter { fn manifest_name(&self) -> Option<ManifestName> { Some(SharedString::new_static("pyproject.toml").into()) } + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot + } } async fn get_cached_server_binary( @@ -1282,6 +1285,350 @@ impl LspAdapter for PyLspAdapter { fn manifest_name(&self) -> Option<ManifestName> { Some(SharedString::new_static("pyproject.toml").into()) } + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot + } +} + +pub(crate) struct BasedPyrightLspAdapter { + python_venv_base: OnceCell<Result<Arc<Path>, String>>, +} + +impl BasedPyrightLspAdapter { + const SERVER_NAME: LanguageServerName = LanguageServerName::new_static("basedpyright"); + const BINARY_NAME: &'static str = "basedpyright-langserver"; + + pub(crate) fn new() -> Self { + Self { + python_venv_base: OnceCell::new(), + } + } + + 
async fn ensure_venv(delegate: &dyn LspAdapterDelegate) -> Result<Arc<Path>> { + let python_path = Self::find_base_python(delegate) + .await + .context("Could not find Python installation for basedpyright")?; + let work_dir = delegate + .language_server_download_dir(&Self::SERVER_NAME) + .await + .context("Could not get working directory for basedpyright")?; + let mut path = PathBuf::from(work_dir.as_ref()); + path.push("basedpyright-venv"); + if !path.exists() { + util::command::new_smol_command(python_path) + .arg("-m") + .arg("venv") + .arg("basedpyright-venv") + .current_dir(work_dir) + .spawn()? + .output() + .await?; + } + + Ok(path.into()) + } + + // Find "baseline", user python version from which we'll create our own venv. + async fn find_base_python(delegate: &dyn LspAdapterDelegate) -> Option<PathBuf> { + for path in ["python3", "python"] { + if let Some(path) = delegate.which(path.as_ref()).await { + return Some(path); + } + } + None + } + + async fn base_venv(&self, delegate: &dyn LspAdapterDelegate) -> Result<Arc<Path>, String> { + self.python_venv_base + .get_or_init(move || async move { + Self::ensure_venv(delegate) + .await + .map_err(|e| format!("{e}")) + }) + .await + .clone() + } +} + +#[async_trait(?Send)] +impl LspAdapter for BasedPyrightLspAdapter { + fn name(&self) -> LanguageServerName { + Self::SERVER_NAME.clone() + } + + async fn initialization_options( + self: Arc<Self>, + _: &dyn Fs, + _: &Arc<dyn LspAdapterDelegate>, + ) -> Result<Option<Value>> { + // Provide minimal initialization options + // Virtual environment configuration will be handled through workspace configuration + Ok(Some(json!({ + "python": { + "analysis": { + "autoSearchPaths": true, + "useLibraryCodeForTypes": true, + "autoImportCompletions": true + } + } + }))) + } + + async fn check_if_user_installed( + &self, + delegate: &dyn LspAdapterDelegate, + toolchains: Arc<dyn LanguageToolchainStore>, + cx: &AsyncApp, + ) -> Option<LanguageServerBinary> { + if let Some(bin) = delegate.which(Self::BINARY_NAME.as_ref()).await { + let env = delegate.shell_env().await; + Some(LanguageServerBinary { + path: bin, + env: Some(env), + arguments: vec!["--stdio".into()], + }) + } else { + let venv = toolchains + .active_toolchain( + delegate.worktree_id(), + Arc::from("".as_ref()), + LanguageName::new("Python"), + &mut cx.clone(), + ) + .await?; + let path = Path::new(venv.path.as_ref()) + .parent()? + .join(Self::BINARY_NAME); + path.exists().then(|| LanguageServerBinary { + path, + arguments: vec!["--stdio".into()], + env: None, + }) + } + } + + async fn fetch_latest_server_version( + &self, + _: &dyn LspAdapterDelegate, + ) -> Result<Box<dyn 'static + Any + Send>> { + Ok(Box::new(()) as Box<_>) + } + + async fn fetch_server_binary( + &self, + _latest_version: Box<dyn 'static + Send + Any>, + _container_dir: PathBuf, + delegate: &dyn LspAdapterDelegate, + ) -> Result<LanguageServerBinary> { + let venv = self.base_venv(delegate).await.map_err(|e| anyhow!(e))?; + let pip_path = venv.join(BINARY_DIR).join("pip3"); + ensure!( + util::command::new_smol_command(pip_path.as_path()) + .arg("install") + .arg("basedpyright") + .arg("-U") + .output() + .await? 
+ .status + .success(), + "basedpyright installation failed" + ); + let pylsp = venv.join(BINARY_DIR).join(Self::BINARY_NAME); + Ok(LanguageServerBinary { + path: pylsp, + env: None, + arguments: vec!["--stdio".into()], + }) + } + + async fn cached_server_binary( + &self, + _container_dir: PathBuf, + delegate: &dyn LspAdapterDelegate, + ) -> Option<LanguageServerBinary> { + let venv = self.base_venv(delegate).await.ok()?; + let pylsp = venv.join(BINARY_DIR).join(Self::BINARY_NAME); + Some(LanguageServerBinary { + path: pylsp, + env: None, + arguments: vec!["--stdio".into()], + }) + } + + async fn process_completions(&self, items: &mut [lsp::CompletionItem]) { + // Pyright assigns each completion item a `sortText` of the form `XX.YYYY.name`. + // Where `XX` is the sorting category, `YYYY` is based on most recent usage, + // and `name` is the symbol name itself. + // + // Because the symbol name is included, there generally are not ties when + // sorting by the `sortText`, so the symbol's fuzzy match score is not taken + // into account. Here, we remove the symbol name from the sortText in order + // to allow our own fuzzy score to be used to break ties. + // + // see https://github.com/microsoft/pyright/blob/95ef4e103b9b2f129c9320427e51b73ea7cf78bd/packages/pyright-internal/src/languageService/completionProvider.ts#LL2873 + for item in items { + let Some(sort_text) = &mut item.sort_text else { + continue; + }; + let mut parts = sort_text.split('.'); + let Some(first) = parts.next() else { continue }; + let Some(second) = parts.next() else { continue }; + let Some(_) = parts.next() else { continue }; + sort_text.replace_range(first.len() + second.len() + 1.., ""); + } + } + + async fn label_for_completion( + &self, + item: &lsp::CompletionItem, + language: &Arc<language::Language>, + ) -> Option<language::CodeLabel> { + let label = &item.label; + let grammar = language.grammar()?; + let highlight_id = match item.kind? 
{ + lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method")?, + lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function")?, + lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type")?, + lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant")?, + _ => return None, + }; + let filter_range = item + .filter_text + .as_deref() + .and_then(|filter| label.find(filter).map(|ix| ix..ix + filter.len())) + .unwrap_or(0..label.len()); + Some(language::CodeLabel { + text: label.clone(), + runs: vec![(0..label.len(), highlight_id)], + filter_range, + }) + } + + async fn label_for_symbol( + &self, + name: &str, + kind: lsp::SymbolKind, + language: &Arc<language::Language>, + ) -> Option<language::CodeLabel> { + let (text, filter_range, display_range) = match kind { + lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => { + let text = format!("def {}():\n", name); + let filter_range = 4..4 + name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + lsp::SymbolKind::CLASS => { + let text = format!("class {}:", name); + let filter_range = 6..6 + name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + lsp::SymbolKind::CONSTANT => { + let text = format!("{} = 0", name); + let filter_range = 0..name.len(); + let display_range = 0..filter_range.end; + (text, filter_range, display_range) + } + _ => return None, + }; + + Some(language::CodeLabel { + runs: language.highlight_text(&text.as_str().into(), display_range.clone()), + text: text[display_range].to_string(), + filter_range, + }) + } + + async fn workspace_configuration( + self: Arc<Self>, + _: &dyn Fs, + adapter: &Arc<dyn LspAdapterDelegate>, + toolchains: Arc<dyn LanguageToolchainStore>, + cx: &mut AsyncApp, + ) -> Result<Value> { + let toolchain = toolchains + .active_toolchain( + adapter.worktree_id(), + Arc::from("".as_ref()), + LanguageName::new("Python"), + cx, + ) + .await; + cx.update(move |cx| { + let mut user_settings = + language_server_settings(adapter.as_ref(), &Self::SERVER_NAME, cx) + .and_then(|s| s.settings.clone()) + .unwrap_or_default(); + + // If we have a detected toolchain, configure Pyright to use it + if let Some(toolchain) = toolchain { + if user_settings.is_null() { + user_settings = Value::Object(serde_json::Map::default()); + } + let object = user_settings.as_object_mut().unwrap(); + + let interpreter_path = toolchain.path.to_string(); + + // Detect if this is a virtual environment + if let Some(interpreter_dir) = Path::new(&interpreter_path).parent() { + if let Some(venv_dir) = interpreter_dir.parent() { + // Check if this looks like a virtual environment + if venv_dir.join("pyvenv.cfg").exists() + || venv_dir.join("bin/activate").exists() + || venv_dir.join("Scripts/activate.bat").exists() + { + // Set venvPath and venv at the root level + // This matches the format of a pyrightconfig.json file + if let Some(parent) = venv_dir.parent() { + // Use relative path if the venv is inside the workspace + let venv_path = if parent == adapter.worktree_root_path() { + ".".to_string() + } else { + parent.to_string_lossy().into_owned() + }; + object.insert("venvPath".to_string(), Value::String(venv_path)); + } + + if let Some(venv_name) = venv_dir.file_name() { + object.insert( + "venv".to_owned(), + Value::String(venv_name.to_string_lossy().into_owned()), + ); + } + } + } + } + + // Always set the python interpreter path + // Get or create the python section + let 
python = object + .entry("python") + .or_insert(Value::Object(serde_json::Map::default())) + .as_object_mut() + .unwrap(); + + // Set both pythonPath and defaultInterpreterPath for compatibility + python.insert( + "pythonPath".to_owned(), + Value::String(interpreter_path.clone()), + ); + python.insert( + "defaultInterpreterPath".to_owned(), + Value::String(interpreter_path), + ); + } + + user_settings + }) + } + + fn manifest_name(&self) -> Option<ManifestName> { + Some(SharedString::new_static("pyproject.toml").into()) + } + + fn workspace_folders_content(&self) -> WorkspaceFoldersContent { + WorkspaceFoldersContent::WorktreeRoot + } } #[cfg(test)] diff --git a/crates/languages/src/rust.rs b/crates/languages/src/rust.rs index 3f83c9c000..6545bf64a2 100644 --- a/crates/languages/src/rust.rs +++ b/crates/languages/src/rust.rs @@ -1,8 +1,7 @@ use anyhow::{Context as _, Result}; -use async_compression::futures::bufread::GzipDecoder; use async_trait::async_trait; use collections::HashMap; -use futures::{StreamExt, io::BufReader}; +use futures::StreamExt; use gpui::{App, AppContext, AsyncApp, SharedString, Task}; use http_client::github::AssetKind; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; @@ -23,14 +22,11 @@ use std::{ sync::{Arc, LazyLock}, }; use task::{TaskTemplate, TaskTemplates, TaskVariables, VariableName}; -use util::archive::extract_zip; +use util::fs::make_file_executable; use util::merge_json_value_into; -use util::{ - ResultExt, - fs::{make_file_executable, remove_matching}, - maybe, -}; +use util::{ResultExt, maybe}; +use crate::github_download::{GithubBinaryMetadata, download_server_binary}; use crate::language_settings::language_settings; pub struct RustLspAdapter; @@ -163,7 +159,6 @@ impl LspAdapter for RustLspAdapter { ) .await?; let asset_name = Self::build_asset_name(); - let asset = release .assets .iter() @@ -172,6 +167,7 @@ impl LspAdapter for RustLspAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: release.tag_name, url: asset.browser_download_url.clone(), + digest: asset.digest.clone(), })) } @@ -181,58 +177,76 @@ impl LspAdapter for RustLspAdapter { container_dir: PathBuf, delegate: &dyn LspAdapterDelegate, ) -> Result<LanguageServerBinary> { - let version = version.downcast::<GitHubLspBinaryVersion>().unwrap(); - let destination_path = container_dir.join(format!("rust-analyzer-{}", version.name)); + let GitHubLspBinaryVersion { name, url, digest } = + &*version.downcast::<GitHubLspBinaryVersion>().unwrap(); + let expected_digest = digest + .as_ref() + .and_then(|digest| digest.strip_prefix("sha256:")); + let destination_path = container_dir.join(format!("rust-analyzer-{name}")); let server_path = match Self::GITHUB_ASSET_KIND { AssetKind::TarGz | AssetKind::Gz => destination_path.clone(), // Tar and gzip extract in place. 
AssetKind::Zip => destination_path.clone().join("rust-analyzer.exe"), // zip contains a .exe }; - if fs::metadata(&server_path).await.is_err() { - remove_matching(&container_dir, |entry| entry != destination_path).await; + let binary = LanguageServerBinary { + path: server_path.clone(), + env: None, + arguments: Default::default(), + }; - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .with_context(|| format!("downloading release from {}", version.url))?; - match Self::GITHUB_ASSET_KIND { - AssetKind::TarGz => { - let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); - let archive = async_tar::Archive::new(decompressed_bytes); - archive.unpack(&destination_path).await.with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Gz => { - let mut decompressed_bytes = - GzipDecoder::new(BufReader::new(response.body_mut())); - let mut file = - fs::File::create(&destination_path).await.with_context(|| { - format!( - "creating a file {:?} for a download from {}", - destination_path, version.url, - ) - })?; - futures::io::copy(&mut decompressed_bytes, &mut file) - .await - .with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Zip => { - extract_zip(&destination_path, response.body_mut()) - .await - .with_context(|| { - format!("unzipping {} to {:?}", version.url, destination_path) - })?; - } + let metadata_path = destination_path.with_extension("metadata"); + let metadata = GithubBinaryMetadata::read_from_file(&metadata_path) + .await + .ok(); + if let Some(metadata) = metadata { + let validity_check = async || { + delegate + .try_exec(LanguageServerBinary { + path: server_path.clone(), + arguments: vec!["--version".into()], + env: None, + }) + .await + .inspect_err(|err| { + log::warn!("Unable to run {server_path:?} asset, redownloading: {err}",) + }) }; - - // todo("windows") - make_file_executable(&server_path).await?; + if let (Some(actual_digest), Some(expected_digest)) = + (&metadata.digest, expected_digest) + { + if actual_digest == expected_digest { + if validity_check().await.is_ok() { + return Ok(binary); + } + } else { + log::info!( + "SHA-256 mismatch for {destination_path:?} asset, downloading new asset. 
Expected: {expected_digest}, Got: {actual_digest}" + ); + } + } else if validity_check().await.is_ok() { + return Ok(binary); + } } + _ = fs::remove_dir_all(&destination_path).await; + download_server_binary( + delegate, + url, + expected_digest, + &destination_path, + Self::GITHUB_ASSET_KIND, + ) + .await?; + make_file_executable(&server_path).await?; + GithubBinaryMetadata::write_to_file( + &GithubBinaryMetadata { + metadata_version: 1, + digest: expected_digest.map(ToString::to_string), + }, + &metadata_path, + ) + .await?; + Ok(LanguageServerBinary { path: server_path, env: None, @@ -1025,7 +1039,11 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option<LanguageServ let mut last = None; let mut entries = fs::read_dir(&container_dir).await?; while let Some(entry) = entries.next().await { - last = Some(entry?.path()); + let path = entry?.path(); + if path.extension().is_some_and(|ext| ext == "metadata") { + continue; + } + last = Some(path); } anyhow::Ok(LanguageServerBinary { diff --git a/crates/languages/src/tailwind.rs b/crates/languages/src/tailwind.rs index cb4e939083..a7edbb148c 100644 --- a/crates/languages/src/tailwind.rs +++ b/crates/languages/src/tailwind.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use collections::HashMap; use futures::StreamExt; use gpui::AsyncApp; -use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; @@ -168,20 +168,20 @@ impl LspAdapter for TailwindLspAdapter { })) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("Astro".to_string(), "astro".to_string()), - ("HTML".to_string(), "html".to_string()), - ("CSS".to_string(), "css".to_string()), - ("JavaScript".to_string(), "javascript".to_string()), - ("TSX".to_string(), "typescriptreact".to_string()), - ("Svelte".to_string(), "svelte".to_string()), - ("Elixir".to_string(), "phoenix-heex".to_string()), - ("HEEX".to_string(), "phoenix-heex".to_string()), - ("ERB".to_string(), "erb".to_string()), - ("HTML/ERB".to_string(), "erb".to_string()), - ("PHP".to_string(), "php".to_string()), - ("Vue.js".to_string(), "vue".to_string()), + (LanguageName::new("Astro"), "astro".to_string()), + (LanguageName::new("HTML"), "html".to_string()), + (LanguageName::new("CSS"), "css".to_string()), + (LanguageName::new("JavaScript"), "javascript".to_string()), + (LanguageName::new("TSX"), "typescriptreact".to_string()), + (LanguageName::new("Svelte"), "svelte".to_string()), + (LanguageName::new("Elixir"), "phoenix-heex".to_string()), + (LanguageName::new("HEEX"), "phoenix-heex".to_string()), + (LanguageName::new("ERB"), "erb".to_string()), + (LanguageName::new("HTML/ERB"), "erb".to_string()), + (LanguageName::new("PHP"), "php".to_string()), + (LanguageName::new("Vue.js"), "vue".to_string()), ]) } } diff --git a/crates/languages/src/typescript.rs b/crates/languages/src/typescript.rs index 34b9c3224e..f976b62614 100644 --- a/crates/languages/src/typescript.rs +++ b/crates/languages/src/typescript.rs @@ -1,6 +1,4 @@ use anyhow::{Context as _, Result}; -use async_compression::futures::bufread::GzipDecoder; -use async_tar::Archive; use async_trait::async_trait; use chrono::{DateTime, Local}; use collections::HashMap; @@ -8,13 +6,14 @@ use futures::future::join_all; use gpui::{App, 
AppContext, AsyncApp, Task}; use http_client::github::{AssetKind, GitHubLspBinaryVersion, build_asset_url}; use language::{ - ContextLocation, ContextProvider, File, LanguageToolchainStore, LspAdapter, LspAdapterDelegate, + ContextLocation, ContextProvider, File, LanguageName, LanguageToolchainStore, LspAdapter, + LspAdapterDelegate, }; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; -use smol::{fs, io::BufReader, lock::RwLock, stream::StreamExt}; +use smol::{fs, lock::RwLock, stream::StreamExt}; use std::{ any::Any, borrow::Cow, @@ -23,11 +22,10 @@ use std::{ sync::Arc, }; use task::{TaskTemplate, TaskTemplates, VariableName}; -use util::archive::extract_zip; use util::merge_json_value_into; use util::{ResultExt, fs::remove_matching, maybe}; -use crate::{PackageJson, PackageJsonData}; +use crate::{PackageJson, PackageJsonData, github_download::download_server_binary}; #[derive(Debug)] pub(crate) struct TypeScriptContextProvider { @@ -512,7 +510,7 @@ fn eslint_server_binary_arguments(server_path: &Path) -> Vec<OsString> { fn replace_test_name_parameters(test_name: &str) -> String { let pattern = regex::Regex::new(r"(%|\$)[0-9a-zA-Z]+").unwrap(); - pattern.replace_all(test_name, "(.+?)").to_string() + regex::escape(&pattern.replace_all(test_name, "(.+?)")) } pub struct TypeScriptLspAdapter { @@ -741,11 +739,11 @@ impl LspAdapter for TypeScriptLspAdapter { })) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("TypeScript".into(), "typescript".into()), - ("JavaScript".into(), "javascript".into()), - ("TSX".into(), "typescriptreact".into()), + (LanguageName::new("TypeScript"), "typescript".into()), + (LanguageName::new("JavaScript"), "javascript".into()), + (LanguageName::new("TSX"), "typescriptreact".into()), ]) } } @@ -896,6 +894,7 @@ impl LspAdapter for EsLintLspAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: Self::CURRENT_VERSION.into(), + digest: None, url, })) } @@ -913,43 +912,14 @@ impl LspAdapter for EsLintLspAdapter { if fs::metadata(&server_path).await.is_err() { remove_matching(&container_dir, |entry| entry != destination_path).await; - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .context("downloading release")?; - match Self::GITHUB_ASSET_KIND { - AssetKind::TarGz => { - let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); - let archive = Archive::new(decompressed_bytes); - archive.unpack(&destination_path).await.with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Gz => { - let mut decompressed_bytes = - GzipDecoder::new(BufReader::new(response.body_mut())); - let mut file = - fs::File::create(&destination_path).await.with_context(|| { - format!( - "creating a file {:?} for a download from {}", - destination_path, version.url, - ) - })?; - futures::io::copy(&mut decompressed_bytes, &mut file) - .await - .with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Zip => { - extract_zip(&destination_path, response.body_mut()) - .await - .with_context(|| { - format!("unzipping {} to {:?}", version.url, destination_path) - })?; - } - } + download_server_binary( + delegate, + &version.url, + None, + &destination_path, + Self::GITHUB_ASSET_KIND, + ) + .await?; let mut 
dir = fs::read_dir(&destination_path).await?; let first = dir.next().await.context("missing first file")??; diff --git a/crates/languages/src/typescript/runnables.scm b/crates/languages/src/typescript/runnables.scm index 85702cf99d..6bfc536329 100644 --- a/crates/languages/src/typescript/runnables.scm +++ b/crates/languages/src/typescript/runnables.scm @@ -1,4 +1,4 @@ -; Add support for (node:test, bun:test and Jest) runnable +; Add support for (node:test, bun:test, Jest and Deno.test) runnable ; Function expression that has `it`, `test` or `describe` as the function name ( (call_expression @@ -44,3 +44,42 @@ (#set! tag js-test) ) + +; Add support for Deno.test with string names +( + (call_expression + function: (member_expression + object: (identifier) @_namespace + property: (property_identifier) @_method + ) + (#eq? @_namespace "Deno") + (#eq? @_method "test") + arguments: ( + arguments . [ + (string (string_fragment) @run @DENO_TEST_NAME) + (identifier) @run @DENO_TEST_NAME + ] + ) + ) @_js-test + + (#set! tag js-test) +) + +; Add support for Deno.test with named function expressions +( + (call_expression + function: (member_expression + object: (identifier) @_namespace + property: (property_identifier) @_method + ) + (#eq? @_namespace "Deno") + (#eq? @_method "test") + arguments: ( + arguments . (function_expression + name: (identifier) @run @DENO_TEST_NAME + ) + ) + ) @_js-test + + (#set! tag js-test) +) diff --git a/crates/languages/src/vtsls.rs b/crates/languages/src/vtsls.rs index ca07673d5f..33751f733e 100644 --- a/crates/languages/src/vtsls.rs +++ b/crates/languages/src/vtsls.rs @@ -2,7 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use collections::HashMap; use gpui::AsyncApp; -use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; @@ -273,11 +273,11 @@ impl LspAdapter for VtslsLspAdapter { Ok(default_workspace_configuration) } - fn language_ids(&self) -> HashMap<String, String> { + fn language_ids(&self) -> HashMap<LanguageName, String> { HashMap::from_iter([ - ("TypeScript".into(), "typescript".into()), - ("JavaScript".into(), "javascript".into()), - ("TSX".into(), "typescriptreact".into()), + (LanguageName::new("TypeScript"), "typescript".into()), + (LanguageName::new("JavaScript"), "javascript".into()), + (LanguageName::new("TSX"), "typescriptreact".into()), ]) } } diff --git a/crates/languages/src/yaml/config.toml b/crates/languages/src/yaml/config.toml index 4dfb890c54..e54bceda1a 100644 --- a/crates/languages/src/yaml/config.toml +++ b/crates/languages/src/yaml/config.toml @@ -1,6 +1,6 @@ name = "YAML" grammar = "yaml" -path_suffixes = ["yml", "yaml"] +path_suffixes = ["yml", "yaml", "pixi.lock"] line_comments = ["# "] autoclose_before = ",]}" brackets = [ diff --git a/crates/languages/src/yaml/outline.scm b/crates/languages/src/yaml/outline.scm index 7ab007835f..c5a7f8e5d4 100644 --- a/crates/languages/src/yaml/outline.scm +++ b/crates/languages/src/yaml/outline.scm @@ -1 +1,9 @@ -(block_mapping_pair key: (flow_node (plain_scalar (string_scalar) @name))) @item +(block_mapping_pair + key: + (flow_node + (plain_scalar + (string_scalar) @name)) + value: + (flow_node + (plain_scalar + (string_scalar) @context))?) 
@item diff --git a/crates/livekit_client/Cargo.toml b/crates/livekit_client/Cargo.toml index a0c11d46e6..821fd5d390 100644 --- a/crates/livekit_client/Cargo.toml +++ b/crates/livekit_client/Cargo.toml @@ -40,8 +40,8 @@ util.workspace = true workspace-hack.workspace = true [target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies] -libwebrtc = { rev = "d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4", git = "https://github.com/zed-industries/livekit-rust-sdks" } -livekit = { rev = "d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ +libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" } +livekit = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ "__rustls-tls" ] } diff --git a/crates/livekit_client/src/livekit_client/playback.rs b/crates/livekit_client/src/livekit_client/playback.rs index c62b8853b4..f14e156125 100644 --- a/crates/livekit_client/src/livekit_client/playback.rs +++ b/crates/livekit_client/src/livekit_client/playback.rs @@ -1,6 +1,7 @@ use anyhow::{Context as _, Result}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait as _}; +use cpal::{Data, FromSample, I24, SampleFormat, SizedSample}; use futures::channel::mpsc::UnboundedSender; use futures::{Stream, StreamExt as _}; use gpui::{ @@ -258,9 +259,15 @@ impl AudioStack { let stream = device .build_input_stream_raw( &config.config(), - cpal::SampleFormat::I16, + config.sample_format(), move |data, _: &_| { - let mut data = data.as_slice::<i16>().unwrap(); + let data = + Self::get_sample_data(config.sample_format(), data).log_err(); + let Some(data) = data else { + return; + }; + let mut data = data.as_slice(); + while data.len() > 0 { let remainder = (buf.capacity() - buf.len()).min(data.len()); buf.extend_from_slice(&data[..remainder]); @@ -313,6 +320,33 @@ impl AudioStack { drop(end_on_drop_tx) } } + + fn get_sample_data(sample_format: SampleFormat, data: &Data) -> Result<Vec<i16>> { + match sample_format { + SampleFormat::I8 => Ok(Self::convert_sample_data::<i8, i16>(data)), + SampleFormat::I16 => Ok(data.as_slice::<i16>().unwrap().to_vec()), + SampleFormat::I24 => Ok(Self::convert_sample_data::<I24, i16>(data)), + SampleFormat::I32 => Ok(Self::convert_sample_data::<i32, i16>(data)), + SampleFormat::I64 => Ok(Self::convert_sample_data::<i64, i16>(data)), + SampleFormat::U8 => Ok(Self::convert_sample_data::<u8, i16>(data)), + SampleFormat::U16 => Ok(Self::convert_sample_data::<u16, i16>(data)), + SampleFormat::U32 => Ok(Self::convert_sample_data::<u32, i16>(data)), + SampleFormat::U64 => Ok(Self::convert_sample_data::<u64, i16>(data)), + SampleFormat::F32 => Ok(Self::convert_sample_data::<f32, i16>(data)), + SampleFormat::F64 => Ok(Self::convert_sample_data::<f64, i16>(data)), + _ => anyhow::bail!("Unsupported sample format"), + } + } + + fn convert_sample_data<TSource: SizedSample, TDest: SizedSample + FromSample<TSource>>( + data: &Data, + ) -> Vec<TDest> { + data.as_slice::<TSource>() + .unwrap() + .iter() + .map(|e| e.to_sample::<TDest>()) + .collect() + } } use super::LocalVideoTrack; diff --git a/crates/livekit_client/src/mock_client/participant.rs b/crates/livekit_client/src/mock_client/participant.rs index 991d10bd50..033808cbb5 100644 --- a/crates/livekit_client/src/mock_client/participant.rs +++ b/crates/livekit_client/src/mock_client/participant.rs @@ -5,7 
+5,9 @@ use crate::{ }; use anyhow::Result; use collections::HashMap; -use gpui::{AsyncApp, ScreenCaptureSource, ScreenCaptureStream, TestScreenCaptureStream}; +use gpui::{ + AsyncApp, DevicePixels, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata, size, +}; #[derive(Clone, Debug)] pub struct LocalParticipant { @@ -119,3 +121,16 @@ impl RemoteParticipant { self.identity.clone() } } + +struct TestScreenCaptureStream; + +impl ScreenCaptureStream for TestScreenCaptureStream { + fn metadata(&self) -> Result<SourceMetadata> { + Ok(SourceMetadata { + id: 0, + is_main: None, + label: None, + resolution: size(DevicePixels(1), DevicePixels(1)), + }) + } +} diff --git a/crates/lmstudio/src/lmstudio.rs b/crates/lmstudio/src/lmstudio.rs index a5477994ff..43c78115cd 100644 --- a/crates/lmstudio/src/lmstudio.rs +++ b/crates/lmstudio/src/lmstudio.rs @@ -1,4 +1,4 @@ -use anyhow::{Context as _, Result}; +use anyhow::{Context as _, Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http}; use serde::{Deserialize, Serialize}; @@ -275,11 +275,16 @@ impl Capabilities { } } +#[derive(Serialize, Deserialize, Debug)] +pub struct LmStudioError { + pub message: String, +} + #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum ResponseStreamResult { Ok(ResponseStreamEvent), - Err { error: String }, + Err { error: LmStudioError }, } #[derive(Serialize, Deserialize, Debug)] @@ -392,7 +397,6 @@ pub async fn stream_chat_completion( let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); - Ok(reader .lines() .filter_map(|line| async move { @@ -402,18 +406,16 @@ pub async fn stream_chat_completion( if line == "[DONE]" { None } else { - let result = serde_json::from_str(&line) - .context("Unable to parse chat completions response"); - if let Err(ref e) = result { - eprintln!("Error parsing line: {e}\nLine content: '{line}'"); + match serde_json::from_str(line) { + Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)), + Ok(ResponseStreamResult::Err { error, .. }) => { + Some(Err(anyhow!(error.message))) + } + Err(error) => Some(Err(anyhow!(error))), } - Some(result) } } - Err(e) => { - eprintln!("Error reading line: {e}"); - Some(Err(e.into())) - } + Err(error) => Some(Err(anyhow!(error))), } }) .boxed()) diff --git a/crates/lsp/src/input_handler.rs b/crates/lsp/src/input_handler.rs index db3f1190fc..001ebf1fc9 100644 --- a/crates/lsp/src/input_handler.rs +++ b/crates/lsp/src/input_handler.rs @@ -13,14 +13,15 @@ use parking_lot::Mutex; use smol::io::BufReader; use crate::{ - AnyNotification, AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, RequestId, ResponseHandler, + AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, NotificationOrRequest, RequestId, + ResponseHandler, }; const HEADER_DELIMITER: &[u8; 4] = b"\r\n\r\n"; /// Handler for stdout of language server. 
pub struct LspStdoutHandler { pub(super) loop_handle: Task<Result<()>>, - pub(super) notifications_channel: UnboundedReceiver<AnyNotification>, + pub(super) incoming_messages: UnboundedReceiver<NotificationOrRequest>, } async fn read_headers<Stdout>(reader: &mut BufReader<Stdout>, buffer: &mut Vec<u8>) -> Result<()> @@ -54,13 +55,13 @@ impl LspStdoutHandler { let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers)); Self { loop_handle, - notifications_channel, + incoming_messages: notifications_channel, } } async fn handler<Input>( stdout: Input, - notifications_sender: UnboundedSender<AnyNotification>, + notifications_sender: UnboundedSender<NotificationOrRequest>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>, io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>, ) -> anyhow::Result<()> @@ -96,7 +97,7 @@ impl LspStdoutHandler { } } - if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) { + if let Ok(msg) = serde_json::from_slice::<NotificationOrRequest>(&buffer) { notifications_sender.unbounded_send(msg)?; } else if let Ok(AnyResponse { id, error, result, .. diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index a820aaf748..a92787cd3e 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -29,7 +29,7 @@ use std::{ ffi::{OsStr, OsString}, fmt, io::Write, - ops::{Deref, DerefMut}, + ops::DerefMut, path::PathBuf, pin::Pin, sync::{ @@ -100,7 +100,7 @@ pub struct LanguageServer { io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>, output_done_rx: Mutex<Option<barrier::Receiver>>, server: Arc<Mutex<Option<Child>>>, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, root_uri: Url, } @@ -242,7 +242,7 @@ struct Notification<'a, T> { /// Language server RPC notification message before it is deserialized into a concrete type. 
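
The rename from `AnyNotification` to `NotificationOrRequest` reflects the JSON-RPC rule the surrounding change relies on: a server-to-client message that carries an `id` is a request and must receive a response (the new wrapper replies with error code `-32601` when nothing handles it), while a message without an `id` is a notification and can be dropped silently. A rough, self-contained sketch of that branching with plain serde types; the struct and function names here are stand-ins, not the crate's own:

```rust
use serde::{Deserialize, Serialize};
use serde_json::Value;

// Illustrative stand-in for NotificationOrRequest: the presence of `id`
// is what distinguishes a request from a notification.
#[derive(Deserialize)]
struct IncomingMessage {
    #[serde(default)]
    id: Option<Value>, // JSON-RPC ids may be numbers or strings
    method: String,
}

#[derive(Serialize)]
struct ErrorBody {
    code: i64,
    message: String,
}

#[derive(Serialize)]
struct ErrorReply {
    jsonrpc: &'static str,
    id: Value,
    error: ErrorBody,
}

// Unhandled requests get a "method not found" style error back;
// unhandled notifications produce no reply at all.
fn reply_if_unhandled_request(raw: &str) -> Option<String> {
    let msg: IncomingMessage = serde_json::from_str(raw).ok()?;
    let id = msg.id?;
    let reply = ErrorReply {
        jsonrpc: "2.0",
        id,
        error: ErrorBody {
            code: -32601,
            message: format!("Unrecognized method `{}`", msg.method),
        },
    };
    serde_json::to_string(&reply).ok()
}

fn main() {
    let request = r#"{"jsonrpc":"2.0","id":7,"method":"workspace/unknown","params":{}}"#;
    let notification = r#"{"jsonrpc":"2.0","method":"$/progress","params":{}}"#;
    println!("{:?}", reply_if_unhandled_request(request)); // Some(error reply)
    println!("{:?}", reply_if_unhandled_request(notification)); // None
}
```

Answering unknown requests instead of ignoring them keeps well-behaved language servers from waiting forever on a response that will never come.
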
#[derive(Debug, Clone, Deserialize)] -struct AnyNotification { +struct NotificationOrRequest { #[serde(default)] id: Option<RequestId>, method: String, @@ -252,7 +252,10 @@ struct AnyNotification { #[derive(Debug, Serialize, Deserialize)] struct Error { + code: i64, message: String, + #[serde(default)] + data: Option<serde_json::Value>, } pub trait LspRequestFuture<O>: Future<Output = ConnectionResult<O>> { @@ -307,7 +310,7 @@ impl LanguageServer { binary: LanguageServerBinary, root_path: &Path, code_action_kinds: Option<Vec<CodeActionKind>>, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, cx: &mut AsyncApp, ) -> Result<Self> { let working_dir = if root_path.is_dir() { @@ -364,6 +367,7 @@ impl LanguageServer { notification.method, serde_json::to_string_pretty(¬ification.params).unwrap(), ); + false }, ); @@ -381,7 +385,7 @@ impl LanguageServer { code_action_kinds: Option<Vec<CodeActionKind>>, binary: LanguageServerBinary, root_uri: Url, - workspace_folders: Arc<Mutex<BTreeSet<Url>>>, + workspace_folders: Option<Arc<Mutex<BTreeSet<Url>>>>, cx: &mut AsyncApp, on_unhandled_notification: F, ) -> Self @@ -389,7 +393,7 @@ impl LanguageServer { Stdin: AsyncWrite + Unpin + Send + 'static, Stdout: AsyncRead + Unpin + Send + 'static, Stderr: AsyncRead + Unpin + Send + 'static, - F: FnMut(AnyNotification) + 'static + Send + Sync + Clone, + F: Fn(&NotificationOrRequest) -> bool + 'static + Send + Sync + Clone, { let (outbound_tx, outbound_rx) = channel::unbounded::<String>(); let (output_done_tx, output_done_rx) = barrier::channel(); @@ -400,14 +404,34 @@ impl LanguageServer { let io_handlers = Arc::new(Mutex::new(HashMap::default())); let stdout_input_task = cx.spawn({ - let on_unhandled_notification = on_unhandled_notification.clone(); + let unhandled_notification_wrapper = { + let response_channel = outbound_tx.clone(); + async move |msg: NotificationOrRequest| { + let did_handle = on_unhandled_notification(&msg); + if !did_handle && let Some(message_id) = msg.id { + let response = AnyResponse { + jsonrpc: JSON_RPC_VERSION, + id: message_id, + error: Some(Error { + code: -32601, + message: format!("Unrecognized method `{}`", msg.method), + data: None, + }), + result: None, + }; + if let Ok(response) = serde_json::to_string(&response) { + response_channel.send(response).await.ok(); + } + } + } + }; let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); let io_handlers = io_handlers.clone(); async move |cx| { - Self::handle_input( + Self::handle_incoming_messages( stdout, - on_unhandled_notification, + unhandled_notification_wrapper, notification_handlers, response_handlers, io_handlers, @@ -421,19 +445,19 @@ impl LanguageServer { .map(|stderr| { let io_handlers = io_handlers.clone(); let stderr_captures = stderr_capture.clone(); - cx.spawn(async move |_| { + cx.background_spawn(async move { Self::handle_stderr(stderr, io_handlers, stderr_captures) .log_err() .await }) }) .unwrap_or_else(|| Task::ready(None)); - let input_task = cx.spawn(async move |_| { + let input_task = cx.background_spawn(async move { let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); stdout.or(stderr) }); let output_task = cx.background_spawn({ - Self::handle_output( + Self::handle_outgoing_messages( stdin, outbound_rx, output_done_tx, @@ -479,9 +503,9 @@ impl LanguageServer { self.code_action_kinds.clone() } - async fn handle_input<Stdout, F>( + async fn handle_incoming_messages<Stdout>( 
stdout: Stdout, - mut on_unhandled_notification: F, + on_unhandled_notification: impl AsyncFn(NotificationOrRequest) + 'static + Send, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>, io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>, @@ -489,7 +513,6 @@ impl LanguageServer { ) -> anyhow::Result<()> where Stdout: AsyncRead + Unpin + Send + 'static, - F: FnMut(AnyNotification) + 'static + Send, { use smol::stream::StreamExt; let stdout = BufReader::new(stdout); @@ -506,15 +529,19 @@ impl LanguageServer { cx.background_executor().clone(), ); - while let Some(msg) = input_handler.notifications_channel.next().await { - { + while let Some(msg) = input_handler.incoming_messages.next().await { + let unhandled_message = { let mut notification_handlers = notification_handlers.lock(); if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) { handler(msg.id, msg.params.unwrap_or(Value::Null), cx); + None } else { - drop(notification_handlers); - on_unhandled_notification(msg); + Some(msg) } + }; + + if let Some(msg) = unhandled_message { + on_unhandled_notification(msg).await; } // Don't starve the main thread when receiving lots of notifications at once. @@ -558,7 +585,7 @@ impl LanguageServer { } } - async fn handle_output<Stdin>( + async fn handle_outgoing_messages<Stdin>( stdin: Stdin, outbound_rx: channel::Receiver<String>, output_done_tx: barrier::Sender, @@ -595,16 +622,26 @@ impl LanguageServer { } pub fn default_initialize_params(&self, pull_diagnostics: bool, cx: &App) -> InitializeParams { - let workspace_folders = self - .workspace_folders - .lock() - .iter() - .cloned() - .map(|uri| WorkspaceFolder { - name: Default::default(), - uri, - }) - .collect::<Vec<_>>(); + let workspace_folders = self.workspace_folders.as_ref().map_or_else( + || { + vec![WorkspaceFolder { + name: Default::default(), + uri: self.root_uri.clone(), + }] + }, + |folders| { + folders + .lock() + .iter() + .cloned() + .map(|uri| WorkspaceFolder { + name: Default::default(), + uri, + }) + .collect() + }, + ); + #[allow(deprecated)] InitializeParams { process_id: None, @@ -710,6 +747,10 @@ impl LanguageServer { InsertTextMode::ADJUST_INDENTATION, ], }), + documentation_format: Some(vec![ + MarkupKind::Markdown, + MarkupKind::PlainText, + ]), ..Default::default() }), insert_text_mode: Some(InsertTextMode::ADJUST_INDENTATION), @@ -836,7 +877,7 @@ impl LanguageServer { configuration: Arc<DidChangeConfigurationParams>, cx: &App, ) -> Task<Result<Arc<Self>>> { - cx.spawn(async move |_| { + cx.background_spawn(async move { let response = self .request::<request::Initialize>(params) .await @@ -877,39 +918,41 @@ impl LanguageServer { let server = self.server.clone(); let name = self.name.clone(); + let server_id = self.server_id; let mut timer = self.executor.timer(SERVER_SHUTDOWN_TIMEOUT).fuse(); - Some( - async move { - log::debug!("language server shutdown started"); + Some(async move { + log::debug!("language server shutdown started"); - select! { - request_result = shutdown_request.fuse() => { - match request_result { - ConnectionResult::Timeout => { - log::warn!("timeout waiting for language server {name} to shutdown"); - }, - ConnectionResult::ConnectionReset => {}, - ConnectionResult::Result(r) => r?, - } + select! 
{ + request_result = shutdown_request.fuse() => { + match request_result { + ConnectionResult::Timeout => { + log::warn!("timeout waiting for language server {name} (id {server_id}) to shutdown"); + }, + ConnectionResult::ConnectionReset => { + log::warn!("language server {name} (id {server_id}) closed the shutdown request connection"); + }, + ConnectionResult::Result(Err(e)) => { + log::error!("Shutdown request failure, server {name} (id {server_id}): {e:#}"); + }, + ConnectionResult::Result(Ok(())) => {} } - - _ = timer => { - log::info!("timeout waiting for language server {name} to shutdown"); - }, } - response_handlers.lock().take(); - Self::notify_internal::<notification::Exit>(&outbound_tx, &()).ok(); - outbound_tx.close(); - output_done.recv().await; - server.lock().take().map(|mut child| child.kill()); - log::debug!("language server shutdown finished"); - - drop(tasks); - anyhow::Ok(()) + _ = timer => { + log::info!("timeout waiting for language server {name} (id {server_id}) to shutdown"); + }, } - .log_err(), - ) + + response_handlers.lock().take(); + Self::notify_internal::<notification::Exit>(&outbound_tx, &()).ok(); + outbound_tx.close(); + output_done.recv().await; + server.lock().take().map(|mut child| child.kill()); + drop(tasks); + log::debug!("language server shutdown finished"); + Some(()) + }) } else { None } @@ -1024,7 +1067,9 @@ impl LanguageServer { jsonrpc: JSON_RPC_VERSION, id, value: LspResult::Error(Some(Error { + code: lsp_types::error_codes::REQUEST_FAILED, message: error.to_string(), + data: None, })), }, }; @@ -1045,7 +1090,9 @@ impl LanguageServer { id, result: None, error: Some(Error { + code: -32700, // Parse error message: error.to_string(), + data: None, }), }; if let Some(response) = serde_json::to_string(&response).log_err() { @@ -1313,7 +1360,10 @@ impl LanguageServer { return; } - let is_new_folder = self.workspace_folders.lock().insert(uri.clone()); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let is_new_folder = workspace_folders.lock().insert(uri.clone()); if is_new_folder { let params = DidChangeWorkspaceFoldersParams { event: WorkspaceFoldersChangeEvent { @@ -1343,7 +1393,10 @@ impl LanguageServer { { return; } - let was_removed = self.workspace_folders.lock().remove(&uri); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let was_removed = workspace_folders.lock().remove(&uri); if was_removed { let params = DidChangeWorkspaceFoldersParams { event: WorkspaceFoldersChangeEvent { @@ -1358,7 +1411,10 @@ impl LanguageServer { } } pub fn set_workspace_folders(&self, folders: BTreeSet<Url>) { - let mut workspace_folders = self.workspace_folders.lock(); + let Some(workspace_folders) = self.workspace_folders.as_ref() else { + return; + }; + let mut workspace_folders = workspace_folders.lock(); let old_workspace_folders = std::mem::take(&mut *workspace_folders); let added: Vec<_> = folders @@ -1387,8 +1443,11 @@ impl LanguageServer { } } - pub fn workspace_folders(&self) -> impl Deref<Target = BTreeSet<Url>> + '_ { - self.workspace_folders.lock() + pub fn workspace_folders(&self) -> BTreeSet<Url> { + self.workspace_folders.as_ref().map_or_else( + || BTreeSet::from_iter([self.root_uri.clone()]), + |folders| folders.lock().clone(), + ) } pub fn register_buffer( @@ -1533,9 +1592,9 @@ impl FakeLanguageServer { None, binary.clone(), root, - workspace_folders.clone(), + Some(workspace_folders.clone()), cx, - |_| {}, + |_| false, ); server.process_name = process_name; let fake = 
FakeLanguageServer { @@ -1552,15 +1611,16 @@ impl FakeLanguageServer { None, binary, Self::root_path(), - workspace_folders, + Some(workspace_folders), cx, move |msg| { notifications_tx .try_send(( msg.method.to_string(), - msg.params.unwrap_or(Value::Null).to_string(), + msg.params.as_ref().unwrap_or(&Value::Null).to_string(), )) .ok(); + true }, ); server.process_name = name.as_str().into(); @@ -1838,7 +1898,7 @@ mod tests { #[gpui::test] fn test_deserialize_string_digit_id() { let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = serde_json::from_str::<AnyNotification>(json) + let notification = serde_json::from_str::<NotificationOrRequest>(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Str("2".to_string()); assert_eq!(notification.id, Some(expected_id)); @@ -1847,7 +1907,7 @@ mod tests { #[gpui::test] fn test_deserialize_string_id() { let json = r#"{"jsonrpc":"2.0","id":"anythingAtAll","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = serde_json::from_str::<AnyNotification>(json) + let notification = serde_json::from_str::<NotificationOrRequest>(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Str("anythingAtAll".to_string()); assert_eq!(notification.id, Some(expected_id)); @@ -1856,7 +1916,7 @@ mod tests { #[gpui::test] fn test_deserialize_int_id() { let json = r#"{"jsonrpc":"2.0","id":2,"method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = serde_json::from_str::<AnyNotification>(json) + let notification = serde_json::from_str::<NotificationOrRequest>(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Int(2); assert_eq!(notification.id, Some(expected_id)); diff --git a/crates/markdown_preview/src/markdown_preview_view.rs b/crates/markdown_preview/src/markdown_preview_view.rs index 03cfd7ee82..a0c8819991 100644 --- a/crates/markdown_preview/src/markdown_preview_view.rs +++ b/crates/markdown_preview/src/markdown_preview_view.rs @@ -18,6 +18,7 @@ use workspace::item::{Item, ItemHandle}; use workspace::{Pane, Workspace}; use crate::markdown_elements::ParsedMarkdownElement; +use crate::markdown_renderer::CheckboxClickedEvent; use crate::{ MovePageDown, MovePageUp, OpenFollowingPreview, OpenPreview, OpenPreviewToTheSide, markdown_elements::ParsedMarkdown, @@ -203,114 +204,7 @@ impl MarkdownPreviewView { cx: &mut Context<Workspace>, ) -> Entity<Self> { cx.new(|cx| { - let view = cx.entity().downgrade(); - - let list_state = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - if let Some(view) = view.upgrade() { - view.update(cx, |this: &mut Self, cx| { - let Some(contents) = &this.contents else { - return div().into_any(); - }; - - let mut render_cx = - RenderContext::new(Some(this.workspace.clone()), window, cx) - .with_checkbox_clicked_callback({ - let view = view.clone(); - move |checked, source_range, window, cx| { - view.update(cx, |view, cx| { - if let Some(editor) = view - .active_editor - .as_ref() - .map(|s| s.editor.clone()) - { - editor.update(cx, |editor, cx| { - let task_marker = - if checked { "[x]" } else { "[ ]" }; - - editor.edit( - vec![(source_range, task_marker)], 
- cx, - ); - }); - view.parse_markdown_from_active_editor( - false, window, cx, - ); - cx.notify(); - } - }) - } - }); - - let block = contents.children.get(ix).unwrap(); - let rendered_block = render_markdown_block(block, &mut render_cx); - - let should_apply_padding = Self::should_apply_padding_between( - block, - contents.children.get(ix + 1), - ); - - div() - .id(ix) - .when(should_apply_padding, |this| { - this.pb(render_cx.scaled_rems(0.75)) - }) - .group("markdown-block") - .on_click(cx.listener( - move |this, event: &ClickEvent, window, cx| { - if event.down.click_count == 2 { - if let Some(source_range) = this - .contents - .as_ref() - .and_then(|c| c.children.get(ix)) - .and_then(|block| block.source_range()) - { - this.move_cursor_to_block( - window, - cx, - source_range.start..source_range.start, - ); - } - } - }, - )) - .map(move |container| { - let indicator = div() - .h_full() - .w(px(4.0)) - .when(ix == this.selected_block, |this| { - this.bg(cx.theme().colors().border) - }) - .group_hover("markdown-block", |s| { - if ix == this.selected_block { - s - } else { - s.bg(cx.theme().colors().border_variant) - } - }) - .rounded_xs(); - - container.child( - div() - .relative() - .child( - div() - .pl(render_cx.scaled_rems(1.0)) - .child(rendered_block), - ) - .child(indicator.absolute().left_0().top_0()), - ) - }) - .into_any() - }) - } else { - div().into_any() - } - }, - ); + let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let mut this = Self { selected_block: 0, @@ -607,10 +501,107 @@ impl Render for MarkdownPreviewView { .p_4() .text_size(buffer_size) .line_height(relative(buffer_line_height.value())) - .child( - div() - .flex_grow() - .map(|this| this.child(list(self.list_state.clone()).size_full())), - ) + .child(div().flex_grow().map(|this| { + this.child( + list( + self.list_state.clone(), + cx.processor(|this, ix, window, cx| { + let Some(contents) = &this.contents else { + return div().into_any(); + }; + + let mut render_cx = + RenderContext::new(Some(this.workspace.clone()), window, cx) + .with_checkbox_clicked_callback(cx.listener( + move |this, e: &CheckboxClickedEvent, window, cx| { + if let Some(editor) = this + .active_editor + .as_ref() + .map(|s| s.editor.clone()) + { + editor.update(cx, |editor, cx| { + let task_marker = + if e.checked() { "[x]" } else { "[ ]" }; + + editor.edit( + vec![(e.source_range(), task_marker)], + cx, + ); + }); + this.parse_markdown_from_active_editor( + false, window, cx, + ); + cx.notify(); + } + }, + )); + + let block = contents.children.get(ix).unwrap(); + let rendered_block = render_markdown_block(block, &mut render_cx); + + let should_apply_padding = Self::should_apply_padding_between( + block, + contents.children.get(ix + 1), + ); + + div() + .id(ix) + .when(should_apply_padding, |this| { + this.pb(render_cx.scaled_rems(0.75)) + }) + .group("markdown-block") + .on_click(cx.listener( + move |this, event: &ClickEvent, window, cx| { + if event.click_count() == 2 { + if let Some(source_range) = this + .contents + .as_ref() + .and_then(|c| c.children.get(ix)) + .and_then(|block: &ParsedMarkdownElement| { + block.source_range() + }) + { + this.move_cursor_to_block( + window, + cx, + source_range.start..source_range.start, + ); + } + } + }, + )) + .map(move |container| { + let indicator = div() + .h_full() + .w(px(4.0)) + .when(ix == this.selected_block, |this| { + this.bg(cx.theme().colors().border) + }) + .group_hover("markdown-block", |s| { + if ix == this.selected_block { + s + } else { + 
s.bg(cx.theme().colors().border_variant) + } + }) + .rounded_xs(); + + container.child( + div() + .relative() + .child( + div() + .pl(render_cx.scaled_rems(1.0)) + .child(rendered_block), + ) + .child(indicator.absolute().left_0().top_0()), + ) + }) + .into_any() + }), + ) + .size_full(), + ) + })) } } diff --git a/crates/markdown_preview/src/markdown_renderer.rs b/crates/markdown_preview/src/markdown_renderer.rs index 80bed8a6e8..37d2ca2110 100644 --- a/crates/markdown_preview/src/markdown_renderer.rs +++ b/crates/markdown_preview/src/markdown_renderer.rs @@ -26,7 +26,22 @@ use ui::{ }; use workspace::{OpenOptions, OpenVisible, Workspace}; -type CheckboxClickedCallback = Arc<Box<dyn Fn(bool, Range<usize>, &mut Window, &mut App)>>; +pub struct CheckboxClickedEvent { + pub checked: bool, + pub source_range: Range<usize>, +} + +impl CheckboxClickedEvent { + pub fn source_range(&self) -> Range<usize> { + self.source_range.clone() + } + + pub fn checked(&self) -> bool { + self.checked + } +} + +type CheckboxClickedCallback = Arc<Box<dyn Fn(&CheckboxClickedEvent, &mut Window, &mut App)>>; #[derive(Clone)] pub struct RenderContext { @@ -80,7 +95,7 @@ impl RenderContext { pub fn with_checkbox_clicked_callback( mut self, - callback: impl Fn(bool, Range<usize>, &mut Window, &mut App) + 'static, + callback: impl Fn(&CheckboxClickedEvent, &mut Window, &mut App) + 'static, ) -> Self { self.checkbox_clicked_callback = Some(Arc::new(Box::new(callback))); self @@ -229,7 +244,14 @@ fn render_markdown_list_item( }; if window.modifiers().secondary() { - callback(checked, range.clone(), window, cx); + callback( + &CheckboxClickedEvent { + checked, + source_range: range.clone(), + }, + window, + cx, + ); } } }) diff --git a/crates/migrator/src/migrations/m_2025_01_29/keymap.rs b/crates/migrator/src/migrations/m_2025_01_29/keymap.rs index c32da88229..646af8f63d 100644 --- a/crates/migrator/src/migrations/m_2025_01_29/keymap.rs +++ b/crates/migrator/src/migrations/m_2025_01_29/keymap.rs @@ -242,22 +242,22 @@ static STRING_REPLACE: LazyLock<HashMap<&str, &str>> = LazyLock::new(|| { "inline_completion::ToggleMenu", "edit_prediction::ToggleMenu", ), - ("editor::NextInlineCompletion", "editor::NextEditPrediction"), + ("editor::NextEditPrediction", "editor::NextEditPrediction"), ( - "editor::PreviousInlineCompletion", + "editor::PreviousEditPrediction", "editor::PreviousEditPrediction", ), ( - "editor::AcceptPartialInlineCompletion", + "editor::AcceptPartialEditPrediction", "editor::AcceptPartialEditPrediction", ), - ("editor::ShowInlineCompletion", "editor::ShowEditPrediction"), + ("editor::ShowEditPrediction", "editor::ShowEditPrediction"), ( - "editor::AcceptInlineCompletion", + "editor::AcceptEditPrediction", "editor::AcceptEditPrediction", ), ( - "editor::ToggleInlineCompletions", + "editor::ToggleEditPredictions", "editor::ToggleEditPrediction", ), ]) diff --git a/crates/multi_buffer/src/anchor.rs b/crates/multi_buffer/src/anchor.rs index 9e28295c56..1305328d38 100644 --- a/crates/multi_buffer/src/anchor.rs +++ b/crates/multi_buffer/src/anchor.rs @@ -167,10 +167,10 @@ impl Anchor { if *self == Anchor::min() || *self == Anchor::max() { true } else if let Some(excerpt) = snapshot.excerpt(self.excerpt_id) { - excerpt.contains(self) - && (self.text_anchor == excerpt.range.context.start - || self.text_anchor == excerpt.range.context.end - || self.text_anchor.is_valid(&excerpt.buffer)) + (self.text_anchor == excerpt.range.context.start + || self.text_anchor == excerpt.range.context.end + || 
self.text_anchor.is_valid(&excerpt.buffer)) + && excerpt.contains(self) } else { false } diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index f0913e30fb..eb12e6929c 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -43,7 +43,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use sum_tree::{Bias, Cursor, Dimension, SumTree, Summary, TreeMap}; +use sum_tree::{Bias, Cursor, Dimension, Dimensions, SumTree, Summary, TreeMap}; use text::{ BufferId, Edit, LineIndent, TextSummary, locator::Locator, @@ -474,7 +474,7 @@ pub struct MultiBufferRows<'a> { pub struct MultiBufferChunks<'a> { excerpts: Cursor<'a, Excerpt, ExcerptOffset>, - diff_transforms: Cursor<'a, DiffTransform, (usize, ExcerptOffset)>, + diff_transforms: Cursor<'a, DiffTransform, Dimensions<usize, ExcerptOffset>>, diffs: &'a TreeMap<BufferId, BufferDiffSnapshot>, diff_base_chunks: Option<(BufferId, BufferChunks<'a>)>, buffer_chunk: Option<Chunk<'a>>, @@ -2120,10 +2120,10 @@ impl MultiBuffer { let buffers = self.buffers.borrow(); let mut excerpts = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<Point>)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<Point>>>(&()); let mut diff_transforms = snapshot .diff_transforms - .cursor::<(ExcerptDimension<Point>, OutputDimension<Point>)>(&()); + .cursor::<Dimensions<ExcerptDimension<Point>, OutputDimension<Point>>>(&()); diff_transforms.next(); let locators = buffers .get(&buffer_id) @@ -2281,7 +2281,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::new(); let mut excerpt_ids = ids.iter().copied().peekable(); let mut removed_buffer_ids = Vec::new(); @@ -2492,7 +2492,7 @@ impl MultiBuffer { for locator in &buffer_state.excerpts { let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); cursor.seek_forward(&Some(locator), Bias::Left); if let Some(excerpt) = cursor.item() { if excerpt.locator == *locator { @@ -2845,7 +2845,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::<Edit<ExcerptOffset>>::new(); let prefix = cursor.slice(&Some(locator), Bias::Left); @@ -2921,7 +2921,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let mut edits = Vec::<Edit<ExcerptOffset>>::new(); for locator in &locators { @@ -3067,7 +3067,7 @@ impl MultiBuffer { let mut new_excerpts = SumTree::default(); let mut cursor = snapshot .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); for (locator, buffer, buffer_edited) in excerpts_to_edit { new_excerpts.append(cursor.slice(&Some(locator), Bias::Left), &()); @@ -3135,7 +3135,7 @@ impl MultiBuffer { let mut excerpts = snapshot.excerpts.cursor::<ExcerptOffset>(&()); let mut old_diff_transforms = snapshot .diff_transforms - .cursor::<(ExcerptOffset, usize)>(&()); + .cursor::<Dimensions<ExcerptOffset, usize>>(&()); let mut 
new_diff_transforms = SumTree::default(); let mut old_expanded_hunks = HashSet::default(); let mut output_edits = Vec::new(); @@ -3260,7 +3260,7 @@ impl MultiBuffer { &self, edit: &Edit<TypedOffset<Excerpt>>, excerpts: &mut Cursor<Excerpt, TypedOffset<Excerpt>>, - old_diff_transforms: &mut Cursor<DiffTransform, (TypedOffset<Excerpt>, usize)>, + old_diff_transforms: &mut Cursor<DiffTransform, Dimensions<TypedOffset<Excerpt>, usize>>, new_diff_transforms: &mut SumTree<DiffTransform>, end_of_current_insert: &mut Option<(TypedOffset<Excerpt>, DiffTransformHunkInfo)>, old_expanded_hunks: &mut HashSet<DiffTransformHunkInfo>, @@ -4713,7 +4713,9 @@ impl MultiBufferSnapshot { O: ToOffset, { let range = range.start.to_offset(self)..range.end.to_offset(self); - let mut cursor = self.diff_transforms.cursor::<(usize, ExcerptOffset)>(&()); + let mut cursor = self + .diff_transforms + .cursor::<Dimensions<usize, ExcerptOffset>>(&()); cursor.seek(&range.start, Bias::Right); let Some(first_transform) = cursor.item() else { @@ -4867,7 +4869,10 @@ impl MultiBufferSnapshot { &self, anchor: &Anchor, excerpt_position: D, - diff_transforms: &mut Cursor<DiffTransform, (ExcerptDimension<D>, OutputDimension<D>)>, + diff_transforms: &mut Cursor< + DiffTransform, + Dimensions<ExcerptDimension<D>, OutputDimension<D>>, + >, ) -> D where D: TextDimension + Ord + Sub<D, Output = D>, @@ -4927,7 +4932,7 @@ impl MultiBufferSnapshot { fn excerpt_offset_for_anchor(&self, anchor: &Anchor) -> ExcerptOffset { let mut cursor = self .excerpts - .cursor::<(Option<&Locator>, ExcerptOffset)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptOffset>>(&()); let locator = self.excerpt_locator_for_id(anchor.excerpt_id); cursor.seek(&Some(locator), Bias::Left); @@ -4971,7 +4976,7 @@ impl MultiBufferSnapshot { let mut cursor = self.excerpts.cursor::<ExcerptSummary>(&()); let mut diff_transforms_cursor = self .diff_transforms - .cursor::<(ExcerptDimension<D>, OutputDimension<D>)>(&()); + .cursor::<Dimensions<ExcerptDimension<D>, OutputDimension<D>>>(&()); diff_transforms_cursor.next(); let mut summaries = Vec::new(); @@ -5201,7 +5206,9 @@ impl MultiBufferSnapshot { // Find the given position in the diff transforms. Determine the corresponding // offset in the excerpts, and whether the position is within a deleted hunk. 
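
The cursors in these hunks track two coordinate spaces at once (for example an output offset alongside an `ExcerptOffset`), which is what the switch from tuple dimensions to the explicit `Dimensions<_, _>` type is about. As a rough illustration of the mapping the comment above describes, independent of the `sum_tree` API: each transform advances the output coordinate, but only real buffer content advances the excerpt coordinate, so converting an output offset means accumulating both sums until the containing transform is found. A simplified, self-contained sketch under that assumption, with invented types:

```rust
// Invented stand-in for DiffTransform: a run that is either buffer content
// (advances both coordinates) or a deleted hunk rendered inline
// (advances only the output coordinate).
enum Transform {
    BufferContent { len: usize },
    DeletedHunk { len: usize },
}

// Map an offset in the displayed (output) text back to an offset in the
// underlying excerpts, reporting whether it landed inside a deleted hunk.
fn output_to_excerpt(transforms: &[Transform], output_offset: usize) -> (usize, bool) {
    let mut output_start = 0;
    let mut excerpt_start = 0;
    for transform in transforms {
        match transform {
            Transform::BufferContent { len } => {
                if output_offset < output_start + len {
                    return (excerpt_start + (output_offset - output_start), false);
                }
                output_start += len;
                excerpt_start += len;
            }
            Transform::DeletedHunk { len } => {
                if output_offset < output_start + len {
                    // Inside deleted text: clamp to the excerpt position of the hunk.
                    return (excerpt_start, true);
                }
                output_start += len;
            }
        }
    }
    (excerpt_start, false)
}

fn main() {
    let transforms = [
        Transform::BufferContent { len: 10 },
        Transform::DeletedHunk { len: 4 },
        Transform::BufferContent { len: 10 },
    ];
    assert_eq!(output_to_excerpt(&transforms, 5), (5, false));
    assert_eq!(output_to_excerpt(&transforms, 12), (10, true));
    assert_eq!(output_to_excerpt(&transforms, 18), (14, false));
    println!("mappings check out");
}
```

The real implementation does the same accumulation via sum-tree summaries so it can seek in logarithmic time instead of scanning every transform.
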
- let mut diff_transforms = self.diff_transforms.cursor::<(usize, ExcerptOffset)>(&()); + let mut diff_transforms = self + .diff_transforms + .cursor::<Dimensions<usize, ExcerptOffset>>(&()); diff_transforms.seek(&offset, Bias::Right); if offset == diff_transforms.start().0 && bias == Bias::Left { @@ -5250,7 +5257,7 @@ impl MultiBufferSnapshot { let mut excerpts = self .excerpts - .cursor::<(ExcerptOffset, Option<ExcerptId>)>(&()); + .cursor::<Dimensions<ExcerptOffset, Option<ExcerptId>>>(&()); excerpts.seek(&excerpt_offset, Bias::Right); if excerpts.item().is_none() && excerpt_offset == excerpts.start().0 && bias == Bias::Left { excerpts.prev(); @@ -5341,7 +5348,7 @@ impl MultiBufferSnapshot { let start_locator = self.excerpt_locator_for_id(id); let mut excerpts = self .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<usize>)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<usize>>>(&()); excerpts.seek(&Some(start_locator), Bias::Left); excerpts.prev(); @@ -6242,14 +6249,14 @@ impl MultiBufferSnapshot { pub fn range_for_excerpt(&self, excerpt_id: ExcerptId) -> Option<Range<Point>> { let mut cursor = self .excerpts - .cursor::<(Option<&Locator>, ExcerptDimension<Point>)>(&()); + .cursor::<Dimensions<Option<&Locator>, ExcerptDimension<Point>>>(&()); let locator = self.excerpt_locator_for_id(excerpt_id); if cursor.seek(&Some(locator), Bias::Left) { let start = cursor.start().1.clone(); let end = cursor.end().1; let mut diff_transforms = self .diff_transforms - .cursor::<(ExcerptDimension<Point>, OutputDimension<Point>)>(&()); + .cursor::<Dimensions<ExcerptDimension<Point>, OutputDimension<Point>>>(&()); diff_transforms.seek(&start, Bias::Left); let overshoot = start.0 - diff_transforms.start().0.0; let start = diff_transforms.start().1.0 + overshoot; diff --git a/crates/notifications/src/notification_store.rs b/crates/notifications/src/notification_store.rs index 0329a53cc7..29653748e4 100644 --- a/crates/notifications/src/notification_store.rs +++ b/crates/notifications/src/notification_store.rs @@ -6,7 +6,7 @@ use db::smol::stream::StreamExt; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task}; use rpc::{Notification, TypedEnvelope, proto}; use std::{ops::Range, sync::Arc}; -use sum_tree::{Bias, SumTree}; +use sum_tree::{Bias, Dimensions, SumTree}; use time::OffsetDateTime; use util::ResultExt; @@ -360,7 +360,9 @@ impl NotificationStore { is_new: bool, cx: &mut Context<NotificationStore>, ) { - let mut cursor = self.notifications.cursor::<(NotificationId, Count)>(&()); + let mut cursor = self + .notifications + .cursor::<Dimensions<NotificationId, Count>>(&()); let mut new_notifications = SumTree::default(); let mut old_range = 0..0; diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 62c32b4161..64cd1cc0cb 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -58,7 +58,7 @@ fn get_max_tokens(name: &str) -> u64 { "magistral" => 40000, "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r" | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder" - | "devstral" => 128000, + | "devstral" | "gpt-oss" => 128000, _ => DEFAULT_TOKENS, } .clamp(1, MAXIMUM_TOKENS) diff --git a/crates/onboarding/Cargo.toml b/crates/onboarding/Cargo.toml index 693e39d4ca..436c714cf3 100644 --- a/crates/onboarding/Cargo.toml +++ b/crates/onboarding/Cargo.toml @@ -15,14 +15,33 @@ path = "src/onboarding.rs" default = [] [dependencies] 
+ai_onboarding.workspace = true anyhow.workspace = true +client.workspace = true command_palette_hooks.workspace = true +component.workspace = true db.workspace = true +documented.workspace = true +editor.workspace = true feature_flags.workspace = true fs.workspace = true +fuzzy.workspace = true gpui.workspace = true +itertools.workspace = true +language.workspace = true +language_model.workspace = true +menu.workspace = true +notifications.workspace = true +picker.workspace = true +project.workspace = true +schemars.workspace = true +serde.workspace = true settings.workspace = true theme.workspace = true ui.workspace = true -workspace.workspace = true +util.workspace = true +vim_mode_setting.workspace = true workspace-hack.workspace = true +workspace.workspace = true +zed_actions.workspace = true +zlog.workspace = true diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs new file mode 100644 index 0000000000..098907870b --- /dev/null +++ b/crates/onboarding/src/ai_setup_page.rs @@ -0,0 +1,420 @@ +use std::sync::Arc; + +use ai_onboarding::{AiUpsellCard, SignInStatus}; +use client::UserStore; +use fs::Fs; +use gpui::{ + Action, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, + Window, prelude::*, +}; +use itertools; +use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry}; +use project::DisableAiSettings; +use settings::{Settings, update_settings_file}; +use ui::{ + Badge, ButtonLike, Divider, Modal, ModalFooter, ModalHeader, Section, SwitchField, ToggleState, + prelude::*, tooltip_container, +}; +use util::ResultExt; +use workspace::{ModalView, Workspace}; +use zed_actions::agent::OpenSettings; + +const FEATURED_PROVIDERS: [&'static str; 4] = ["anthropic", "google", "openai", "ollama"]; + +fn render_llm_provider_section( + tab_index: &mut isize, + workspace: WeakEntity<Workspace>, + disabled: bool, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + v_flex() + .gap_4() + .child( + v_flex() + .child(Label::new("Or use other LLM providers").size(LabelSize::Large)) + .child( + Label::new("Bring your API keys to use the available providers with Zed's UI for free.") + .color(Color::Muted), + ), + ) + .child(render_llm_provider_card(tab_index, workspace, disabled, window, cx)) +} + +fn render_privacy_card(tab_index: &mut isize, disabled: bool, cx: &mut App) -> impl IntoElement { + let privacy_badge = || { + Badge::new("Privacy") + .icon(IconName::ShieldCheck) + .tooltip(move |_, cx| cx.new(|_| AiPrivacyTooltip::new()).into()) + }; + + v_flex() + .relative() + .pt_2() + .pb_2p5() + .pl_3() + .pr_2() + .border_1() + .border_dashed() + .border_color(cx.theme().colors().border.opacity(0.5)) + .bg(cx.theme().colors().surface_background.opacity(0.3)) + .rounded_lg() + .overflow_hidden() + .map(|this| { + if disabled { + this.child( + h_flex() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1() + .child(Label::new("AI is disabled across Zed")) + .child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ), + ) + .child(privacy_badge()), + ) + .child( + Label::new("Re-enable it any time in Settings.") + .size(LabelSize::Small) + .color(Color::Muted), + ) + } else { + this.child( + h_flex() + .gap_2() + .justify_between() + .child(Label::new("We don't train models using your data")) + .child( + h_flex().gap_1().child(privacy_badge()).child( + Button::new("learn_more", "Learn More") + .style(ButtonStyle::Outlined) + 
.label_size(LabelSize::Small) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(|_, _, cx| { + cx.open_url("https://zed.dev/docs/ai/privacy-and-security"); + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ), + ), + ) + .child( + Label::new( + "Feel confident in the security and privacy of your projects using Zed.", + ) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } + }) +} + +fn render_llm_provider_card( + tab_index: &mut isize, + workspace: WeakEntity<Workspace>, + disabled: bool, + _: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let registry = LanguageModelRegistry::read_global(cx); + + v_flex() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().surface_background.opacity(0.5)) + .rounded_lg() + .overflow_hidden() + .children(itertools::intersperse_with( + FEATURED_PROVIDERS + .into_iter() + .flat_map(|provider_name| { + registry.provider(&LanguageModelProviderId::new(provider_name)) + }) + .enumerate() + .map(|(index, provider)| { + let group_name = SharedString::new(format!("onboarding-hover-group-{}", index)); + let is_authenticated = provider.is_authenticated(cx); + + ButtonLike::new(("onboarding-ai-setup-buttons", index)) + .size(ButtonSize::Large) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .group(&group_name) + .px_0p5() + .w_full() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1() + .child( + Icon::new(provider.icon()) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(provider.name().0)), + ) + .child( + h_flex() + .gap_1() + .when(!is_authenticated, |el| { + el.visible_on_hover(group_name.clone()) + .child( + Icon::new(IconName::Settings) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child( + Label::new("Configure") + .color(Color::Muted) + .size(LabelSize::Small), + ) + }) + .when(is_authenticated && !disabled, |el| { + el.child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ) + .child( + Label::new("Configured") + .color(Color::Muted) + .size(LabelSize::Small), + ) + }), + ), + ) + .on_click({ + let workspace = workspace.clone(); + move |_, window, cx| { + workspace + .update(cx, |workspace, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + let modal = AiConfigurationModal::new( + provider.clone(), + window, + cx, + ); + window.focus(&modal.focus_handle(cx)); + modal + }); + }) + .log_err(); + } + }) + .into_any_element() + }), + || Divider::horizontal().into_any_element(), + )) + .child(Divider::horizontal()) + .child( + Button::new("agent_settings", "Add Many Others") + .size(ButtonSize::Large) + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(|_event, window, cx| { + window.dispatch_action(OpenSettings.boxed_clone(), cx) + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) +} + +pub(crate) fn render_ai_setup_page( + workspace: WeakEntity<Workspace>, + user_store: Entity<UserStore>, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let mut tab_index = 0; + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + + v_flex() + .gap_2() + .child( + SwitchField::new( + "enable_ai", + "Enable AI features", + None, + if is_ai_disabled { + ToggleState::Unselected + } else { + ToggleState::Selected + }, + |&toggle_state, _, cx| { + let fs = <dyn Fs>::global(cx); + update_settings_file::<DisableAiSettings>( + fs, 
+ cx, + move |ai_settings: &mut Option<bool>, _| { + *ai_settings = match toggle_state { + ToggleState::Indeterminate => None, + ToggleState::Unselected => Some(true), + ToggleState::Selected => Some(false), + }; + }, + ); + }, + ) + .tab_index({ + tab_index += 1; + tab_index - 1 + }), + ) + .child(render_privacy_card(&mut tab_index, is_ai_disabled, cx)) + .child( + v_flex() + .mt_2() + .gap_6() + .child(AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + user_plan: user_store.read(cx).plan(), + tab_index: Some({ + tab_index += 1; + tab_index - 1 + }), + }) + .child(render_llm_provider_section( + &mut tab_index, + workspace, + is_ai_disabled, + window, + cx, + )) + .when(is_ai_disabled, |this| { + this.child( + div() + .id("backdrop") + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().editor_background) + .opacity(0.8) + .block_mouse_except_scroll(), + ) + }), + ) +} + +struct AiConfigurationModal { + focus_handle: FocusHandle, + selected_provider: Arc<dyn LanguageModelProvider>, + configuration_view: AnyView, +} + +impl AiConfigurationModal { + fn new( + selected_provider: Arc<dyn LanguageModelProvider>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let focus_handle = cx.focus_handle(); + let configuration_view = selected_provider.configuration_view(window, cx); + + Self { + focus_handle, + configuration_view, + selected_provider, + } + } +} + +impl ModalView for AiConfigurationModal {} + +impl EventEmitter<DismissEvent> for AiConfigurationModal {} + +impl Focusable for AiConfigurationModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for AiConfigurationModal { + fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + v_flex() + .w(rems(34.)) + .elevation_3(cx) + .track_focus(&self.focus_handle) + .child( + Modal::new("onboarding-ai-setup-modal", None) + .header( + ModalHeader::new() + .icon( + Icon::new(self.selected_provider.icon()) + .color(Color::Muted) + .size(IconSize::Small), + ) + .headline(self.selected_provider.name().0), + ) + .section(Section::new().child(self.configuration_view.clone())) + .footer( + ModalFooter::new().end_slot( + h_flex() + .gap_1() + .child( + Button::new("onboarding-closing-cancel", "Cancel") + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + .child(Button::new("save-btn", "Done").on_click(cx.listener( + |_, _, window, cx| { + window.dispatch_action(menu::Confirm.boxed_clone(), cx); + cx.emit(DismissEvent); + }, + ))), + ), + ), + ) + } +} + +pub struct AiPrivacyTooltip {} + +impl AiPrivacyTooltip { + pub fn new() -> Self { + Self {} + } +} + +impl Render for AiPrivacyTooltip { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + const DESCRIPTION: &'static str = "One of Zed's most important principles is transparency. This is why we are and value open-source so much. 
And it wouldn't be any different with AI."; + + tooltip_container(window, cx, move |this, _, _| { + this.child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::ShieldCheck) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(Label::new("Privacy Principle")), + ) + .child( + div().max_w_64().child( + Label::new(DESCRIPTION) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + }) + } +} diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs new file mode 100644 index 0000000000..a4e4028051 --- /dev/null +++ b/crates/onboarding/src/basics_page.rs @@ -0,0 +1,361 @@ +use std::sync::Arc; + +use client::TelemetrySettings; +use fs::Fs; +use gpui::{App, IntoElement}; +use settings::{BaseKeymap, Settings, update_settings_file}; +use theme::{ + Appearance, SystemAppearance, ThemeMode, ThemeName, ThemeRegistry, ThemeSelection, + ThemeSettings, +}; +use ui::{ + ParentElement as _, StatefulInteractiveElement, SwitchField, ToggleButtonGroup, + ToggleButtonSimple, ToggleButtonWithIcon, prelude::*, rems_from_px, +}; +use vim_mode_setting::VimModeSetting; + +use crate::theme_preview::{ThemePreviewStyle, ThemePreviewTile}; + +fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let theme_selection = ThemeSettings::get_global(cx).theme_selection.clone(); + let system_appearance = theme::SystemAppearance::global(cx); + let theme_selection = theme_selection.unwrap_or_else(|| ThemeSelection::Dynamic { + mode: match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }, + light: ThemeName("One Light".into()), + dark: ThemeName("One Dark".into()), + }); + + let theme_mode = theme_selection + .mode() + .unwrap_or_else(|| match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }); + + return v_flex() + .gap_2() + .child( + h_flex().justify_between().child(Label::new("Theme")).child( + ToggleButtonGroup::single_row( + "theme-selector-onboarding-dark-light", + [ThemeMode::Light, ThemeMode::Dark, ThemeMode::System].map(|mode| { + const MODE_NAMES: [SharedString; 3] = [ + SharedString::new_static("Light"), + SharedString::new_static("Dark"), + SharedString::new_static("System"), + ]; + ToggleButtonSimple::new( + MODE_NAMES[mode as usize].clone(), + move |_, _, cx| { + write_mode_change(mode, cx); + }, + ) + }), + ) + .tab_index(tab_index) + .selected_index(theme_mode as usize) + .style(ui::ToggleButtonGroupStyle::Outlined) + .button_width(rems_from_px(64.)), + ), + ) + .child( + h_flex() + .gap_4() + .justify_between() + .children(render_theme_previews(tab_index, &theme_selection, cx)), + ); + + fn render_theme_previews( + tab_index: &mut isize, + theme_selection: &ThemeSelection, + cx: &mut App, + ) -> [impl IntoElement; 3] { + let system_appearance = SystemAppearance::global(cx); + let theme_registry = ThemeRegistry::global(cx); + + let theme_seed = 0xBEEF as f32; + let theme_mode = theme_selection + .mode() + .unwrap_or_else(|| match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }); + let appearance = match theme_mode { + ThemeMode::Light => Appearance::Light, + ThemeMode::Dark => Appearance::Dark, + ThemeMode::System => *system_appearance, + }; + let current_theme_name = theme_selection.theme(appearance); + + const LIGHT_THEMES: [&'static str; 3] = ["One Light", "Ayu Light", "Gruvbox Light"]; + const DARK_THEMES: [&'static str; 3] = ["One Dark", "Ayu Dark", "Gruvbox 
Dark"]; + const FAMILY_NAMES: [SharedString; 3] = [ + SharedString::new_static("One"), + SharedString::new_static("Ayu"), + SharedString::new_static("Gruvbox"), + ]; + + let theme_names = match appearance { + Appearance::Light => LIGHT_THEMES, + Appearance::Dark => DARK_THEMES, + }; + + let themes = theme_names.map(|theme| theme_registry.get(theme).unwrap()); + + let theme_previews = [0, 1, 2].map(|index| { + let theme = &themes[index]; + let is_selected = theme.name == current_theme_name; + let name = theme.name.clone(); + let colors = cx.theme().colors(); + + v_flex() + .w_full() + .items_center() + .gap_1() + .child( + h_flex() + .id(name.clone()) + .relative() + .w_full() + .border_2() + .border_color(colors.border_transparent) + .rounded(ThemePreviewTile::ROOT_RADIUS) + .map(|this| { + if is_selected { + this.border_color(colors.border_selected) + } else { + this.opacity(0.8).hover(|s| s.border_color(colors.border)) + } + }) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .focus(|mut style| { + style.border_color = Some(colors.border_focused); + style + }) + .on_click({ + let theme_name = theme.name.clone(); + move |_, _, cx| { + write_theme_change(theme_name.clone(), theme_mode, cx); + } + }) + .map(|this| { + if theme_mode == ThemeMode::System { + let (light, dark) = ( + theme_registry.get(LIGHT_THEMES[index]).unwrap(), + theme_registry.get(DARK_THEMES[index]).unwrap(), + ); + this.child( + ThemePreviewTile::new(light, theme_seed) + .style(ThemePreviewStyle::SideBySide(dark)), + ) + } else { + this.child( + ThemePreviewTile::new(theme.clone(), theme_seed) + .style(ThemePreviewStyle::Bordered), + ) + } + }), + ) + .child( + Label::new(FAMILY_NAMES[index].clone()) + .color(Color::Muted) + .size(LabelSize::Small), + ) + }); + + theme_previews + } + + fn write_mode_change(mode: ThemeMode, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + update_settings_file::<ThemeSettings>(fs, cx, move |settings, _cx| { + settings.set_mode(mode); + }); + } + + fn write_theme_change(theme: impl Into<Arc<str>>, theme_mode: ThemeMode, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + let theme = theme.into(); + update_settings_file::<ThemeSettings>(fs, cx, move |settings, cx| { + if theme_mode == ThemeMode::System { + settings.theme = Some(ThemeSelection::Dynamic { + mode: ThemeMode::System, + light: ThemeName(theme.clone()), + dark: ThemeName(theme.clone()), + }); + } else { + let appearance = *SystemAppearance::global(cx); + settings.set_theme(theme.clone(), appearance); + } + }); + } +} + +fn render_telemetry_section(tab_index: &mut isize, cx: &App) -> impl IntoElement { + let fs = <dyn Fs>::global(cx); + + v_flex() + .gap_4() + .child(Label::new("Telemetry").size(LabelSize::Large)) + .child(SwitchField::new( + "onboarding-telemetry-metrics", + "Help Improve Zed", + Some("Sending anonymous usage data helps us build the right features and create the best experience.".into()), + if TelemetrySettings::get_global(cx).metrics { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = fs.clone(); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::<TelemetrySettings>( + fs.clone(), + cx, + move |setting, _| setting.metrics = Some(enabled), + ); + }}, + ).tab_index({ + *tab_index += 1; + *tab_index + })) + .child(SwitchField::new( + "onboarding-telemetry-crash-reports", + "Help Fix Zed", + Some("Send crash 
reports so we can fix critical issues fast.".into()), + if TelemetrySettings::get_global(cx).diagnostics { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = fs.clone(); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::<TelemetrySettings>( + fs.clone(), + cx, + move |setting, _| setting.diagnostics = Some(enabled), + ); + } + } + ).tab_index({ + *tab_index += 1; + *tab_index + })) +} + +fn render_base_keymap_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let base_keymap = match BaseKeymap::get_global(cx) { + BaseKeymap::VSCode => Some(0), + BaseKeymap::JetBrains => Some(1), + BaseKeymap::SublimeText => Some(2), + BaseKeymap::Atom => Some(3), + BaseKeymap::Emacs => Some(4), + BaseKeymap::Cursor => Some(5), + BaseKeymap::TextMate | BaseKeymap::None => None, + }; + + return v_flex().gap_2().child(Label::new("Base Keymap")).child( + ToggleButtonGroup::two_rows( + "base_keymap_selection", + [ + ToggleButtonWithIcon::new("VS Code", IconName::EditorVsCode, |_, _, cx| { + write_keymap_base(BaseKeymap::VSCode, cx); + }), + ToggleButtonWithIcon::new("Jetbrains", IconName::EditorJetBrains, |_, _, cx| { + write_keymap_base(BaseKeymap::JetBrains, cx); + }), + ToggleButtonWithIcon::new("Sublime Text", IconName::EditorSublime, |_, _, cx| { + write_keymap_base(BaseKeymap::SublimeText, cx); + }), + ], + [ + ToggleButtonWithIcon::new("Atom", IconName::EditorAtom, |_, _, cx| { + write_keymap_base(BaseKeymap::Atom, cx); + }), + ToggleButtonWithIcon::new("Emacs", IconName::EditorEmacs, |_, _, cx| { + write_keymap_base(BaseKeymap::Emacs, cx); + }), + ToggleButtonWithIcon::new("Cursor (Beta)", IconName::EditorCursor, |_, _, cx| { + write_keymap_base(BaseKeymap::Cursor, cx); + }), + ], + ) + .when_some(base_keymap, |this, base_keymap| { + this.selected_index(base_keymap) + }) + .tab_index(tab_index) + .button_width(rems_from_px(216.)) + .size(ui::ToggleButtonGroupSize::Medium) + .style(ui::ToggleButtonGroupStyle::Outlined), + ); + + fn write_keymap_base(keymap_base: BaseKeymap, cx: &App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<BaseKeymap>(fs, cx, move |setting, _| { + *setting = Some(keymap_base); + }); + } +} + +fn render_vim_mode_switch(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { + let toggle_state = if VimModeSetting::get_global(cx).0 { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }; + SwitchField::new( + "onboarding-vim-mode", + "Vim Mode", + Some( + "Coming from Neovim? Zed's first-class implementation of Vim Mode has got your back." 
+ .into(), + ), + toggle_state, + { + let fs = <dyn Fs>::global(cx); + move |&selection, _, cx| { + update_settings_file::<VimModeSetting>(fs.clone(), cx, move |setting, _| { + *setting = match selection { + ToggleState::Selected => Some(true), + ToggleState::Unselected => Some(false), + ToggleState::Indeterminate => None, + } + }); + } + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) +} + +pub(crate) fn render_basics_page(cx: &mut App) -> impl IntoElement { + let mut tab_index = 0; + v_flex() + .gap_6() + .child(render_theme_section(&mut tab_index, cx)) + .child(render_base_keymap_section(&mut tab_index, cx)) + .child(render_vim_mode_switch(&mut tab_index, cx)) + .child(render_telemetry_section(&mut tab_index, cx)) +} diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs new file mode 100644 index 0000000000..a8f0265b6b --- /dev/null +++ b/crates/onboarding/src/editing_page.rs @@ -0,0 +1,713 @@ +use std::sync::Arc; + +use editor::{EditorSettings, ShowMinimap}; +use fs::Fs; +use fuzzy::{StringMatch, StringMatchCandidate}; +use gpui::{ + Action, AnyElement, App, Context, FontFeatures, IntoElement, Pixels, SharedString, Task, Window, +}; +use language::language_settings::{AllLanguageSettings, FormatOnSave}; +use picker::{Picker, PickerDelegate}; +use project::project_settings::ProjectSettings; +use settings::{Settings as _, update_settings_file}; +use theme::{FontFamilyCache, FontFamilyName, ThemeSettings}; +use ui::{ + ButtonLike, ListItem, ListItemSpacing, NumericStepper, PopoverMenu, SwitchField, + ToggleButtonGroup, ToggleButtonGroupStyle, ToggleButtonSimple, ToggleState, Tooltip, + prelude::*, +}; + +use crate::{ImportCursorSettings, ImportVsCodeSettings, SettingsImportState}; + +fn read_show_mini_map(cx: &App) -> ShowMinimap { + editor::EditorSettings::get_global(cx).minimap.show +} + +fn write_show_mini_map(show: ShowMinimap, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + // This is used to speed up the UI + // the UI reads the current values to get what toggle state to show on buttons + // there's a slight delay if we just call update_settings_file so we manually set + // the value here then call update_settings file to get around the delay + let mut curr_settings = EditorSettings::get_global(cx).clone(); + curr_settings.minimap.show = show; + EditorSettings::override_global(curr_settings, cx); + + update_settings_file::<EditorSettings>(fs, cx, move |editor_settings, _| { + editor_settings.minimap.get_or_insert_default().show = Some(show); + }); +} + +fn read_inlay_hints(cx: &App) -> bool { + AllLanguageSettings::get_global(cx) + .defaults + .inlay_hints + .enabled +} + +fn write_inlay_hints(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + let mut curr_settings = AllLanguageSettings::get_global(cx).clone(); + curr_settings.defaults.inlay_hints.enabled = enabled; + AllLanguageSettings::override_global(curr_settings, cx); + + update_settings_file::<AllLanguageSettings>(fs, cx, move |all_language_settings, cx| { + all_language_settings + .defaults + .inlay_hints + .get_or_insert_with(|| { + AllLanguageSettings::get_global(cx) + .clone() + .defaults + .inlay_hints + }) + .enabled = enabled; + }); +} + +fn read_git_blame(cx: &App) -> bool { + ProjectSettings::get_global(cx).git.inline_blame_enabled() +} + +fn set_git_blame(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + let mut curr_settings = ProjectSettings::get_global(cx).clone(); + curr_settings + .git + .inline_blame + 
.get_or_insert_default() + .enabled = enabled; + ProjectSettings::override_global(curr_settings, cx); + + update_settings_file::<ProjectSettings>(fs, cx, move |project_settings, _| { + project_settings + .git + .inline_blame + .get_or_insert_default() + .enabled = enabled; + }); +} + +fn write_ui_font_family(font: SharedString, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.ui_font_family = Some(FontFamilyName(font.into())); + }); +} + +fn write_ui_font_size(size: Pixels, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.ui_font_size = Some(size.into()); + }); +} + +fn write_buffer_font_size(size: Pixels, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.buffer_font_size = Some(size.into()); + }); +} + +fn write_buffer_font_family(font_family: SharedString, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + theme_settings.buffer_font_family = Some(FontFamilyName(font_family.into())); + }); +} + +fn read_font_ligatures(cx: &App) -> bool { + ThemeSettings::get_global(cx) + .buffer_font + .features + .is_calt_enabled() + .unwrap_or(true) +} + +fn write_font_ligatures(enabled: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + let bit = if enabled { 1 } else { 0 }; + + update_settings_file::<ThemeSettings>(fs, cx, move |theme_settings, _| { + let mut features = theme_settings + .buffer_font_features + .as_mut() + .map(|features| features.tag_value_list().to_vec()) + .unwrap_or_default(); + + if let Some(calt_index) = features.iter().position(|(tag, _)| tag == "calt") { + features[calt_index].1 = bit; + } else { + features.push(("calt".into(), bit)); + } + + theme_settings.buffer_font_features = Some(FontFeatures(Arc::new(features))); + }); +} + +fn read_format_on_save(cx: &App) -> bool { + match AllLanguageSettings::get_global(cx).defaults.format_on_save { + FormatOnSave::On | FormatOnSave::List(_) => true, + FormatOnSave::Off => false, + } +} + +fn write_format_on_save(format_on_save: bool, cx: &mut App) { + let fs = <dyn Fs>::global(cx); + + update_settings_file::<AllLanguageSettings>(fs, cx, move |language_settings, _| { + language_settings.defaults.format_on_save = Some(match format_on_save { + true => FormatOnSave::On, + false => FormatOnSave::Off, + }); + }); +} + +fn render_setting_import_button( + tab_index: isize, + label: SharedString, + icon_name: IconName, + action: &dyn Action, + imported: bool, +) -> impl IntoElement { + let action = action.boxed_clone(); + h_flex().w_full().child( + ButtonLike::new(label.clone()) + .full_width() + .style(ButtonStyle::Outlined) + .size(ButtonSize::Large) + .tab_index(tab_index) + .child( + h_flex() + .w_full() + .justify_between() + .child( + h_flex() + .gap_1p5() + .px_1() + .child( + Icon::new(icon_name) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(label)), + ) + .when(imported, |this| { + this.child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ) + .child(Label::new("Imported").size(LabelSize::Small)), + ) + }), + ) + .on_click(move |_, window, cx| window.dispatch_action(action.boxed_clone(), cx)), + ) +} + +fn render_import_settings_section(tab_index: &mut isize, cx: &App) -> impl 
IntoElement { + let import_state = SettingsImportState::global(cx); + let imports: [(SharedString, IconName, &dyn Action, bool); 2] = [ + ( + "VS Code".into(), + IconName::EditorVsCode, + &ImportVsCodeSettings { skip_prompt: false }, + import_state.vscode, + ), + ( + "Cursor".into(), + IconName::EditorCursor, + &ImportCursorSettings { skip_prompt: false }, + import_state.cursor, + ), + ]; + + let [vscode, cursor] = imports.map(|(label, icon_name, action, imported)| { + *tab_index += 1; + render_setting_import_button(*tab_index - 1, label, icon_name, action, imported) + }); + + v_flex() + .gap_4() + .child( + v_flex() + .child(Label::new("Import Settings").size(LabelSize::Large)) + .child( + Label::new("Automatically pull your settings from other editors.") + .color(Color::Muted), + ), + ) + .child(h_flex().w_full().gap_4().child(vscode).child(cursor)) +} + +fn render_font_customization_section( + tab_index: &mut isize, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let theme_settings = ThemeSettings::get_global(cx); + let ui_font_size = theme_settings.ui_font_size(cx); + let ui_font_family = theme_settings.ui_font.family.clone(); + let buffer_font_family = theme_settings.buffer_font.family.clone(); + let buffer_font_size = theme_settings.buffer_font_size(cx); + + let ui_font_picker = + cx.new(|cx| font_picker(ui_font_family.clone(), write_ui_font_family, window, cx)); + + let buffer_font_picker = cx.new(|cx| { + font_picker( + buffer_font_family.clone(), + write_buffer_font_family, + window, + cx, + ) + }); + + let ui_font_handle = ui::PopoverMenuHandle::default(); + let buffer_font_handle = ui::PopoverMenuHandle::default(); + + h_flex() + .w_full() + .gap_4() + .child( + v_flex() + .w_full() + .gap_1() + .child(Label::new("UI Font")) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + PopoverMenu::new("ui-font-picker") + .menu({ + let ui_font_picker = ui_font_picker.clone(); + move |_window, _cx| Some(ui_font_picker.clone()) + }) + .trigger( + ButtonLike::new("ui-font-family-button") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .full_width() + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .w_full() + .justify_between() + .child(Label::new(ui_font_family)) + .child( + Icon::new(IconName::ChevronUpDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + ), + ) + .full_width(true) + .anchor(gpui::Corner::TopLeft) + .offset(gpui::Point { + x: px(0.0), + y: px(4.0), + }) + .with_handle(ui_font_handle), + ) + .child( + NumericStepper::new( + "ui-font-size", + ui_font_size.to_string(), + move |_, _, cx| { + write_ui_font_size(ui_font_size - px(1.), cx); + }, + move |_, _, cx| { + write_ui_font_size(ui_font_size + px(1.), cx); + }, + ) + .style(ui::NumericStepperStyle::Outlined) + .tab_index({ + *tab_index += 2; + *tab_index - 2 + }), + ), + ), + ) + .child( + v_flex() + .w_full() + .gap_1() + .child(Label::new("Editor Font")) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + PopoverMenu::new("buffer-font-picker") + .menu({ + let buffer_font_picker = buffer_font_picker.clone(); + move |_window, _cx| Some(buffer_font_picker.clone()) + }) + .trigger( + ButtonLike::new("buffer-font-family-button") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .full_width() + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .child( + h_flex() + .w_full() + .justify_between() + .child(Label::new(buffer_font_family)) + .child( + Icon::new(IconName::ChevronUpDown) + 
.color(Color::Muted) + .size(IconSize::XSmall), + ), + ), + ) + .full_width(true) + .anchor(gpui::Corner::TopLeft) + .offset(gpui::Point { + x: px(0.0), + y: px(4.0), + }) + .with_handle(buffer_font_handle), + ) + .child( + NumericStepper::new( + "buffer-font-size", + buffer_font_size.to_string(), + move |_, _, cx| { + write_buffer_font_size(buffer_font_size - px(1.), cx); + }, + move |_, _, cx| { + write_buffer_font_size(buffer_font_size + px(1.), cx); + }, + ) + .style(ui::NumericStepperStyle::Outlined) + .tab_index({ + *tab_index += 2; + *tab_index - 2 + }), + ), + ), + ) +} + +type FontPicker = Picker<FontPickerDelegate>; + +pub struct FontPickerDelegate { + fonts: Vec<SharedString>, + filtered_fonts: Vec<StringMatch>, + selected_index: usize, + current_font: SharedString, + on_font_changed: Arc<dyn Fn(SharedString, &mut App) + 'static>, +} + +impl FontPickerDelegate { + fn new( + current_font: SharedString, + on_font_changed: impl Fn(SharedString, &mut App) + 'static, + cx: &mut Context<FontPicker>, + ) -> Self { + let font_family_cache = FontFamilyCache::global(cx); + + let fonts: Vec<SharedString> = font_family_cache + .list_font_families(cx) + .into_iter() + .collect(); + + let selected_index = fonts + .iter() + .position(|font| *font == current_font) + .unwrap_or(0); + + Self { + fonts: fonts.clone(), + filtered_fonts: fonts + .iter() + .enumerate() + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect(), + selected_index, + current_font, + on_font_changed: Arc::new(on_font_changed), + } + } +} + +impl PickerDelegate for FontPickerDelegate { + type ListItem = AnyElement; + + fn match_count(&self) -> usize { + self.filtered_fonts.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<FontPicker>) { + self.selected_index = ix.min(self.filtered_fonts.len().saturating_sub(1)); + cx.notify(); + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> { + "Search fonts…".into() + } + + fn update_matches( + &mut self, + query: String, + _window: &mut Window, + cx: &mut Context<FontPicker>, + ) -> Task<()> { + let fonts = self.fonts.clone(); + let current_font = self.current_font.clone(); + + let matches: Vec<StringMatch> = if query.is_empty() { + fonts + .iter() + .enumerate() + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect() + } else { + let _candidates: Vec<StringMatchCandidate> = fonts + .iter() + .enumerate() + .map(|(id, font)| StringMatchCandidate::new(id, font.as_ref())) + .collect(); + + fonts + .iter() + .enumerate() + .filter(|(_, font)| font.to_lowercase().contains(&query.to_lowercase())) + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect() + }; + + let selected_index = if query.is_empty() { + fonts + .iter() + .position(|font| *font == current_font) + .unwrap_or(0) + } else { + matches + .iter() + .position(|m| fonts[m.candidate_id] == current_font) + .unwrap_or(0) + }; + + self.filtered_fonts = matches; + self.selected_index = selected_index; + cx.notify(); + + Task::ready(()) + } + + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<FontPicker>) { + if let Some(font_match) = self.filtered_fonts.get(self.selected_index) { + let font = 
font_match.string.clone(); + (self.on_font_changed)(font.into(), cx); + } + } + + fn dismissed(&mut self, _window: &mut Window, _cx: &mut Context<FontPicker>) {} + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + _cx: &mut Context<FontPicker>, + ) -> Option<Self::ListItem> { + let font_match = self.filtered_fonts.get(ix)?; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .child(Label::new(font_match.string.clone())) + .into_any_element(), + ) + } +} + +fn font_picker( + current_font: SharedString, + on_font_changed: impl Fn(SharedString, &mut App) + 'static, + window: &mut Window, + cx: &mut Context<FontPicker>, +) -> FontPicker { + let delegate = FontPickerDelegate::new(current_font, on_font_changed, cx); + + Picker::list(delegate, window, cx) + .show_scrollbar(true) + .width(rems_from_px(210.)) + .max_height(Some(rems(20.).into())) +} + +fn render_popular_settings_section( + tab_index: &mut isize, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + const LIGATURE_TOOLTIP: &'static str = "Ligatures are when a font creates a special character out of combining two characters into one. For example, with ligatures turned on, =/= would become ≠."; + + v_flex() + .gap_5() + .child(Label::new("Popular Settings").size(LabelSize::Large).mt_8()) + .child(render_font_customization_section(tab_index, window, cx)) + .child( + SwitchField::new( + "onboarding-font-ligatures", + "Font Ligatures", + Some("Combine text characters into their associated symbols.".into()), + if read_font_ligatures(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_font_ligatures(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }) + .tooltip(Tooltip::text(LIGATURE_TOOLTIP)), + ) + .child( + SwitchField::new( + "onboarding-format-on-save", + "Format on Save", + Some("Format code automatically when saving.".into()), + if read_format_on_save(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_format_on_save(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + SwitchField::new( + "onboarding-enable-inlay-hints", + "Inlay Hints", + Some("See parameter names for function and method calls inline.".into()), + if read_inlay_hints(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_inlay_hints(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + SwitchField::new( + "onboarding-git-blame-switch", + "Git Blame", + Some("See who committed each line on a given file.".into()), + if read_git_blame(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + set_git_blame(toggle_state == &ToggleState::Selected, cx); + }, + ) + .tab_index({ + *tab_index += 1; + *tab_index - 1 + }), + ) + .child( + h_flex() + .items_start() + .justify_between() + .child( + v_flex().child(Label::new("Mini Map")).child( + Label::new("See a high-level overview of your source code.") + .color(Color::Muted), + ), + ) + .child( + ToggleButtonGroup::single_row( + "onboarding-show-mini-map", + [ + ToggleButtonSimple::new("Auto", |_, _, cx| { + write_show_mini_map(ShowMinimap::Auto, cx); + }), + ToggleButtonSimple::new("Always", 
|_, _, cx| { + write_show_mini_map(ShowMinimap::Always, cx); + }), + ToggleButtonSimple::new("Never", |_, _, cx| { + write_show_mini_map(ShowMinimap::Never, cx); + }), + ], + ) + .selected_index(match read_show_mini_map(cx) { + ShowMinimap::Auto => 0, + ShowMinimap::Always => 1, + ShowMinimap::Never => 2, + }) + .tab_index(tab_index) + .style(ToggleButtonGroupStyle::Outlined) + .button_width(ui::rems_from_px(64.)), + ), + ) +} + +pub(crate) fn render_editing_page(window: &mut Window, cx: &mut App) -> impl IntoElement { + let mut tab_index = 0; + v_flex() + .gap_4() + .child(render_import_settings_section(&mut tab_index, cx)) + .child(render_popular_settings_section(&mut tab_index, window, cx)) +} diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index 1ce236f941..c4d2b6847c 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -1,33 +1,61 @@ +use crate::welcome::{ShowWelcome, WelcomePage}; +use client::{Client, UserStore}; use command_palette_hooks::CommandPaletteFilter; use db::kvp::KEY_VALUE_STORE; use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; use fs::Fs; use gpui::{ - AnyElement, App, AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, - IntoElement, Render, SharedString, Subscription, Task, WeakEntity, Window, actions, + Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter, + FocusHandle, Focusable, Global, IntoElement, KeyContext, Render, SharedString, Subscription, + Task, WeakEntity, Window, actions, }; -use settings::{Settings, SettingsStore, update_settings_file}; +use notifications::status_toast::{StatusToast, ToastIcon}; +use schemars::JsonSchema; +use serde::Deserialize; +use settings::{SettingsStore, VsCodeSettingsSource}; use std::sync::Arc; -use theme::{ThemeMode, ThemeSettings}; use ui::{ - ButtonCommon as _, ButtonSize, ButtonStyle, Clickable as _, Color, Divider, FluentBuilder, - Headline, InteractiveElement, KeyBinding, Label, LabelCommon, ParentElement as _, - StatefulInteractiveElement, Styled, ToggleButton, Toggleable as _, Vector, VectorName, div, - h_flex, rems, v_container, v_flex, + Avatar, ButtonLike, FluentBuilder, Headline, KeyBinding, ParentElement as _, + StatefulInteractiveElement, Vector, VectorName, prelude::*, rems_from_px, }; use workspace::{ AppState, Workspace, WorkspaceId, dock::DockPosition, item::{Item, ItemEvent}, - open_new, with_active_or_new_workspace, + notifications::NotifyResultExt as _, + open_new, register_serializable_item, with_active_or_new_workspace, }; +mod ai_setup_page; +mod basics_page; +mod editing_page; +mod theme_preview; +mod welcome; + pub struct OnBoardingFeatureFlag {} impl FeatureFlag for OnBoardingFeatureFlag { const NAME: &'static str = "onboarding"; } +/// Imports settings from Visual Studio Code. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportVsCodeSettings { + #[serde(default)] + pub skip_prompt: bool, +} + +/// Imports settings from Cursor editor. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportCursorSettings { + #[serde(default)] + pub skip_prompt: bool, +} + pub const FIRST_OPEN: &str = "first_open"; actions!( @@ -38,6 +66,20 @@ actions!( ] ); +actions!( + onboarding, + [ + /// Activates the Basics page. 
+ ActivateBasicsPage, + /// Activates the Editing page. + ActivateEditingPage, + /// Activates the AI Setup page. + ActivateAISetupPage, + /// Finish the onboarding process. + Finish, + ] +); + pub fn init(cx: &mut App) { cx.on_action(|_: &OpenOnboarding, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { @@ -52,7 +94,7 @@ pub fn init(cx: &mut App) { if let Some(existing) = existing { workspace.activate_item(&existing, true, true, window, cx); } else { - let settings_page = Onboarding::new(workspace.weak_handle(), cx); + let settings_page = Onboarding::new(workspace, cx); workspace.add_item_to_active_pane( Box::new(settings_page), None, @@ -65,12 +107,86 @@ pub fn init(cx: &mut App) { .detach(); }); }); + + cx.on_action(|_: &ShowWelcome, cx| { + with_active_or_new_workspace(cx, |workspace, window, cx| { + workspace + .with_local_workspace(window, cx, |workspace, window, cx| { + let existing = workspace + .active_pane() + .read(cx) + .items() + .find_map(|item| item.downcast::<WelcomePage>()); + + if let Some(existing) = existing { + workspace.activate_item(&existing, true, true, window, cx); + } else { + let settings_page = WelcomePage::new(window, cx); + workspace.add_item_to_active_pane( + Box::new(settings_page), + None, + true, + window, + cx, + ) + } + }) + .detach(); + }); + }); + + cx.observe_new(|workspace: &mut Workspace, _window, _cx| { + workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { + let fs = <dyn Fs>::global(cx); + let action = *action; + + let workspace = cx.weak_entity(); + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + workspace, + VsCodeSettingsSource::VsCode, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + + workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { + let fs = <dyn Fs>::global(cx); + let action = *action; + + let workspace = cx.weak_entity(); + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + workspace, + VsCodeSettingsSource::Cursor, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + }) + .detach(); + cx.observe_new::<Workspace>(|_, window, cx| { let Some(window) = window else { return; }; - let onboarding_actions = [std::any::TypeId::of::<OpenOnboarding>()]; + let onboarding_actions = [ + std::any::TypeId::of::<OpenOnboarding>(), + std::any::TypeId::of::<ShowWelcome>(), + ]; CommandPaletteFilter::update_global(cx, |filter, _cx| { filter.hide_action_types(&onboarding_actions); @@ -90,6 +206,7 @@ pub fn init(cx: &mut App) { .detach(); }) .detach(); + register_serializable_item::<Onboarding>(cx); } pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyhow::Result<()>> { @@ -100,7 +217,7 @@ pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyh |workspace, window, cx| { { workspace.toggle_dock(DockPosition::Left, window, cx); - let onboarding_page = Onboarding::new(workspace.weak_handle(), cx); + let onboarding_page = Onboarding::new(workspace, cx); workspace.add_item_to_center(Box::new(onboarding_page.clone()), window, cx); window.focus(&onboarding_page.focus_handle(cx)); @@ -114,23 +231,6 @@ pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyh ) } -fn read_theme_selection(cx: &App) -> ThemeMode { - let settings = ThemeSettings::get_global(cx); - settings - .theme_selection - .as_ref() - .and_then(|selection| selection.mode()) - .unwrap_or_default() -} - 
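// Reviewer note — illustrative sketch only, not part of the patch. It shows how the new
// `ImportVsCodeSettings` / `ImportCursorSettings` actions registered above could also be
// dispatched programmatically; the helper name is hypothetical, and the types are the ones
// already imported in this file. Passing `skip_prompt: true` bypasses the overwrite-
// confirmation prompt handled by `handle_import_vscode_settings` further down.
fn import_vscode_settings_silently(window: &mut Window, cx: &mut App) {
    // Same dispatch pattern used by the workspace actions registered in this diff.
    window.dispatch_action(ImportVsCodeSettings { skip_prompt: true }.boxed_clone(), cx);
}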
-fn write_theme_selection(theme_mode: ThemeMode, cx: &App) { - let fs = <dyn Fs>::global(cx); - - update_settings_file::<ThemeSettings>(fs, cx, move |settings, _| { - settings.set_mode(theme_mode); - }); -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum SelectedPage { Basics, @@ -142,129 +242,233 @@ struct Onboarding { workspace: WeakEntity<Workspace>, focus_handle: FocusHandle, selected_page: SelectedPage, + user_store: Entity<UserStore>, _settings_subscription: Subscription, } impl Onboarding { - fn new(workspace: WeakEntity<Workspace>, cx: &mut App) -> Entity<Self> { + fn new(workspace: &Workspace, cx: &mut App) -> Entity<Self> { cx.new(|cx| Self { - workspace, + workspace: workspace.weak_handle(), focus_handle: cx.focus_handle(), selected_page: SelectedPage::Basics, + user_store: workspace.user_store().clone(), _settings_subscription: cx.observe_global::<SettingsStore>(move |_, cx| cx.notify()), }) } - fn render_page_nav( + fn set_page(&mut self, page: SelectedPage, cx: &mut Context<Self>) { + self.selected_page = page; + cx.notify(); + cx.emit(ItemEvent::UpdateTab); + } + + fn render_nav_buttons( &mut self, - page: SelectedPage, - _: &mut Window, + window: &mut Window, cx: &mut Context<Self>, - ) -> impl IntoElement { - let text = match page { - SelectedPage::Basics => "Basics", - SelectedPage::Editing => "Editing", - SelectedPage::AiSetup => "AI Setup", - }; - let binding = match page { - SelectedPage::Basics => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-1").unwrap()], cx) - } - SelectedPage::Editing => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-2").unwrap()], cx) - } - SelectedPage::AiSetup => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-3").unwrap()], cx) - } - }; - let selected = self.selected_page == page; - h_flex() - .id(text) - .rounded_sm() - .child(text) - .child(binding) - .h_8() - .gap_2() - .px_2() - .py_0p5() - .w_full() + ) -> [impl IntoElement; 3] { + let pages = [ + SelectedPage::Basics, + SelectedPage::Editing, + SelectedPage::AiSetup, + ]; + + let text = ["Basics", "Editing", "AI Setup"]; + + let actions: [&dyn Action; 3] = [ + &ActivateBasicsPage, + &ActivateEditingPage, + &ActivateAISetupPage, + ]; + + let mut binding = actions.map(|action| { + KeyBinding::for_action_in(action, &self.focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))) + }); + + pages.map(|page| { + let i = page as usize; + let selected = self.selected_page == page; + h_flex() + .id(text[i]) + .relative() + .w_full() + .gap_2() + .px_2() + .py_0p5() + .justify_between() + .rounded_sm() + .when(selected, |this| { + this.child( + div() + .h_4() + .w_px() + .bg(cx.theme().colors().text_accent) + .absolute() + .left_0(), + ) + }) + .hover(|style| style.bg(cx.theme().colors().element_hover)) + .child(Label::new(text[i]).map(|this| { + if selected { + this.color(Color::Default) + } else { + this.color(Color::Muted) + } + })) + .child(binding[i].take().map_or( + gpui::Empty.into_any_element(), + IntoElement::into_any_element, + )) + .on_click(cx.listener(move |this, _, _, cx| { + this.set_page(page, cx); + })) + }) + } + + fn render_nav(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let ai_setup_page = matches!(self.selected_page, SelectedPage::AiSetup); + + v_flex() + .h_full() + .w(rems_from_px(220.)) + .flex_shrink_0() + .gap_4() .justify_between() - .map(|this| { - if selected { - this.bg(Color::Selected.color(cx)) - .border_l_1() - .border_color(Color::Accent.color(cx)) + .child( + v_flex() + .gap_6() + .child( + h_flex() 
+ .px_2() + .gap_4() + .child(Vector::square(VectorName::ZedLogo, rems(2.5))) + .child( + v_flex() + .child( + Headline::new("Welcome to Zed").size(HeadlineSize::Small), + ) + .child( + Label::new("The editor for what's next") + .color(Color::Muted) + .size(LabelSize::Small) + .italic(), + ), + ), + ) + .child( + v_flex() + .gap_4() + .child( + v_flex() + .py_4() + .border_y_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .gap_1() + .children(self.render_nav_buttons(window, cx)), + ) + .map(|this| { + let keybinding = KeyBinding::for_action_in( + &Finish, + &self.focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))); + if ai_setup_page { + this.child( + ButtonLike::new("start_building") + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .child( + h_flex() + .ml_1() + .w_full() + .justify_between() + .child(Label::new("Start Building")) + .child(keybinding.map_or_else( + || { + Icon::new(IconName::Check) + .size(IconSize::Small) + .into_any_element() + }, + IntoElement::into_any_element, + )), + ) + .on_click(|_, window, cx| { + window.dispatch_action(Finish.boxed_clone(), cx); + }), + ) + } else { + this.child( + ButtonLike::new("skip_all") + .size(ButtonSize::Medium) + .child( + h_flex() + .ml_1() + .w_full() + .justify_between() + .child(Label::new("Skip All")) + .child(keybinding.map_or_else( + || gpui::Empty.into_any_element(), + IntoElement::into_any_element, + )), + ) + .on_click(|_, window, cx| { + window.dispatch_action(Finish.boxed_clone(), cx); + }), + ) + } + }), + ), + ) + .child( + if let Some(user) = self.user_store.read(cx).current_user() { + h_flex() + .pl_1p5() + .gap_2() + .child(Avatar::new(user.avatar_uri.clone())) + .child(Label::new(user.github_login.clone())) + .into_any_element() } else { - this.text_color(Color::Muted.color(cx)) - } - }) - .hover(|style| { - if selected { - style.bg(Color::Selected.color(cx).opacity(0.6)) - } else { - style.bg(Color::Selected.color(cx).opacity(0.3)) - } - }) - .on_click(cx.listener(move |this, _, _, cx| { - this.selected_page = page; - cx.notify(); - })) + Button::new("sign_in", "Sign In") + .full_width() + .style(ButtonStyle::Outlined) + .on_click(|_, window, cx| { + let client = Client::global(cx); + window + .spawn(cx, async move |cx| { + client + .sign_in_with_optional_connect(true, &cx) + .await + .notify_async_err(cx); + }) + .detach(); + }) + .into_any_element() + }, + ) } fn render_page(&mut self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement { match self.selected_page { - SelectedPage::Basics => self.render_basics_page(window, cx).into_any_element(), - SelectedPage::Editing => self.render_editing_page(window, cx).into_any_element(), - SelectedPage::AiSetup => self.render_ai_setup_page(window, cx).into_any_element(), + SelectedPage::Basics => crate::basics_page::render_basics_page(cx).into_any_element(), + SelectedPage::Editing => { + crate::editing_page::render_editing_page(window, cx).into_any_element() + } + SelectedPage::AiSetup => crate::ai_setup_page::render_ai_setup_page( + self.workspace.clone(), + self.user_store.clone(), + window, + cx, + ) + .into_any_element(), } } - fn render_basics_page(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - let theme_mode = read_theme_selection(cx); - - v_container().child( - h_flex() - .items_center() - .justify_between() - .child(Label::new("Theme")) - .child( - h_flex() - .rounded_md() - .child( - ToggleButton::new("light", "Light") - .style(ButtonStyle::Filled) - .size(ButtonSize::Large) 
- .toggle_state(theme_mode == ThemeMode::Light) - .on_click(|_, _, cx| write_theme_selection(ThemeMode::Light, cx)) - .first(), - ) - .child( - ToggleButton::new("dark", "Dark") - .style(ButtonStyle::Filled) - .size(ButtonSize::Large) - .toggle_state(theme_mode == ThemeMode::Dark) - .on_click(|_, _, cx| write_theme_selection(ThemeMode::Dark, cx)) - .last(), - ) - .child( - ToggleButton::new("system", "System") - .style(ButtonStyle::Filled) - .size(ButtonSize::Large) - .toggle_state(theme_mode == ThemeMode::System) - .on_click(|_, _, cx| write_theme_selection(ThemeMode::System, cx)) - .middle(), - ), - ), - ) - } - - fn render_editing_page(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - // div().child("editing page") - "Right" - } - - fn render_ai_setup_page(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { - div().child("ai setup page") + fn on_finish(_: &Finish, _: &mut Window, cx: &mut App) { + go_to_welcome_page(cx); } } @@ -272,45 +476,54 @@ impl Render for Onboarding { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { h_flex() .image_cache(gpui::retain_all("onboarding-page")) - .key_context("onboarding-page") - .px_24() - .py_12() - .items_start() + .key_context({ + let mut ctx = KeyContext::new_with_defaults(); + ctx.add("Onboarding"); + ctx.add("menu"); + ctx + }) + .track_focus(&self.focus_handle) + .size_full() + .bg(cx.theme().colors().editor_background) + .on_action(Self::on_finish) + .on_action(cx.listener(|this, _: &ActivateBasicsPage, _, cx| { + this.set_page(SelectedPage::Basics, cx); + })) + .on_action(cx.listener(|this, _: &ActivateEditingPage, _, cx| { + this.set_page(SelectedPage::Editing, cx); + })) + .on_action(cx.listener(|this, _: &ActivateAISetupPage, _, cx| { + this.set_page(SelectedPage::AiSetup, cx); + })) + .on_action(cx.listener(|_, _: &menu::SelectNext, window, cx| { + window.focus_next(); + cx.notify(); + })) + .on_action(cx.listener(|_, _: &menu::SelectPrevious, window, cx| { + window.focus_prev(); + cx.notify(); + })) .child( - v_flex() - .w_1_3() - .h_full() + h_flex() + .max_w(rems_from_px(1100.)) + .size_full() + .m_auto() + .py_20() + .px_12() + .items_start() + .gap_12() + .child(self.render_nav(window, cx)) .child( - h_flex() - .pt_0p5() - .child(Vector::square(VectorName::ZedLogo, rems(2.))) - .child( - v_flex() - .left_1() - .items_center() - .child(Headline::new("Welcome to Zed")) - .child( - Label::new("The editor for what's next") - .color(Color::Muted) - .italic(), - ), - ), - ) - .p_1() - .child(Divider::horizontal_dashed()) - .child( - v_flex().gap_1().children([ - self.render_page_nav(SelectedPage::Basics, window, cx) - .into_element(), - self.render_page_nav(SelectedPage::Editing, window, cx) - .into_element(), - self.render_page_nav(SelectedPage::AiSetup, window, cx) - .into_element(), - ]), + v_flex() + .max_w_full() + .min_w_0() + .pl_12() + .border_l_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .size_full() + .child(self.render_page(window, cx)), ), ) - // .child(Divider::vertical_dashed()) - .child(div().w_2_3().h_full().child(self.render_page(window, cx))) } } @@ -343,10 +556,279 @@ impl Item for Onboarding { _: &mut Window, cx: &mut Context<Self>, ) -> Option<Entity<Self>> { - Some(Onboarding::new(self.workspace.clone(), cx)) + self.workspace + .update(cx, |workspace, cx| Onboarding::new(workspace, cx)) + .ok() } fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { f(*event) } } + +fn 
go_to_welcome_page(cx: &mut App) { + with_active_or_new_workspace(cx, |workspace, window, cx| { + let Some((onboarding_id, onboarding_idx)) = workspace + .active_pane() + .read(cx) + .items() + .enumerate() + .find_map(|(idx, item)| { + let _ = item.downcast::<Onboarding>()?; + Some((item.item_id(), idx)) + }) + else { + return; + }; + + workspace.active_pane().update(cx, |pane, cx| { + // Get the index here to get around the borrow checker + let idx = pane.items().enumerate().find_map(|(idx, item)| { + let _ = item.downcast::<WelcomePage>()?; + Some(idx) + }); + + if let Some(idx) = idx { + pane.activate_item(idx, true, true, window, cx); + } else { + let item = Box::new(WelcomePage::new(window, cx)); + pane.add_item(item, true, true, Some(onboarding_idx), window, cx); + } + + pane.remove_item(onboarding_id, false, false, window, cx); + }); + }); +} + +pub async fn handle_import_vscode_settings( + workspace: WeakEntity<Workspace>, + source: VsCodeSettingsSource, + skip_prompt: bool, + fs: Arc<dyn Fs>, + cx: &mut AsyncWindowContext, +) { + use util::truncate_and_remove_front; + + let vscode_settings = + match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { + Ok(vscode_settings) => vscode_settings, + Err(err) => { + zlog::error!("{err}"); + let _ = cx.prompt( + gpui::PromptLevel::Info, + &format!("Could not find or load a {source} settings file"), + None, + &["Ok"], + ); + return; + } + }; + + if !skip_prompt { + let prompt = cx.prompt( + gpui::PromptLevel::Warning, + &format!( + "Importing {} settings may overwrite your existing settings. \ + Will import settings from {}", + vscode_settings.source, + truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), + ), + None, + &["Ok", "Cancel"], + ); + let result = cx.spawn(async move |_| prompt.await.ok()).await; + if result != Some(0) { + return; + } + }; + + let Ok(result_channel) = cx.update(|_, cx| { + let source = vscode_settings.source; + let path = vscode_settings.path.clone(); + let result_channel = cx + .global::<SettingsStore>() + .import_vscode_settings(fs, vscode_settings); + zlog::info!("Imported {source} settings from {}", path.display()); + result_channel + }) else { + return; + }; + + let result = result_channel.await; + workspace + .update_in(cx, |workspace, _, cx| match result { + Ok(_) => { + let confirmation_toast = StatusToast::new( + format!("Your {} settings were successfully imported.", source), + cx, + |this, _| { + this.icon(ToastIcon::new(IconName::Check).color(Color::Success)) + .dismiss_button(true) + }, + ); + SettingsImportState::update(cx, |state, _| match source { + VsCodeSettingsSource::VsCode => { + state.vscode = true; + } + VsCodeSettingsSource::Cursor => { + state.cursor = true; + } + }); + workspace.toggle_status_toast(confirmation_toast, cx); + } + Err(_) => { + let error_toast = StatusToast::new( + "Failed to import settings. 
See log for details", + cx, + |this, _| { + this.icon(ToastIcon::new(IconName::X).color(Color::Error)) + .action("Open Log", |window, cx| { + window.dispatch_action(workspace::OpenLog.boxed_clone(), cx) + }) + .dismiss_button(true) + }, + ); + workspace.toggle_status_toast(error_toast, cx); + } + }) + .ok(); +} + +#[derive(Default, Copy, Clone)] +pub struct SettingsImportState { + pub cursor: bool, + pub vscode: bool, +} + +impl Global for SettingsImportState {} + +impl SettingsImportState { + pub fn global(cx: &App) -> Self { + cx.try_global().cloned().unwrap_or_default() + } + pub fn update<R>(cx: &mut App, f: impl FnOnce(&mut Self, &mut App) -> R) -> R { + cx.update_default_global(f) + } +} + +impl workspace::SerializableItem for Onboarding { + fn serialized_item_kind() -> &'static str { + "OnboardingPage" + } + + fn cleanup( + workspace_id: workspace::WorkspaceId, + alive_items: Vec<workspace::ItemId>, + _window: &mut Window, + cx: &mut App, + ) -> gpui::Task<gpui::Result<()>> { + workspace::delete_unloaded_items( + alive_items, + workspace_id, + "onboarding_pages", + &persistence::ONBOARDING_PAGES, + cx, + ) + } + + fn deserialize( + _project: Entity<project::Project>, + workspace: WeakEntity<Workspace>, + workspace_id: workspace::WorkspaceId, + item_id: workspace::ItemId, + window: &mut Window, + cx: &mut App, + ) -> gpui::Task<gpui::Result<Entity<Self>>> { + window.spawn(cx, async move |cx| { + if let Some(page_number) = + persistence::ONBOARDING_PAGES.get_onboarding_page(item_id, workspace_id)? + { + let page = match page_number { + 0 => Some(SelectedPage::Basics), + 1 => Some(SelectedPage::Editing), + 2 => Some(SelectedPage::AiSetup), + _ => None, + }; + workspace.update(cx, |workspace, cx| { + let onboarding_page = Onboarding::new(workspace, cx); + if let Some(page) = page { + zlog::info!("Onboarding page {page:?} loaded"); + onboarding_page.update(cx, |onboarding_page, cx| { + onboarding_page.set_page(page, cx); + }) + } + onboarding_page + }) + } else { + Err(anyhow::anyhow!("No onboarding page to deserialize")) + } + }) + } + + fn serialize( + &mut self, + workspace: &mut Workspace, + item_id: workspace::ItemId, + _closing: bool, + _window: &mut Window, + cx: &mut ui::Context<Self>, + ) -> Option<gpui::Task<gpui::Result<()>>> { + let workspace_id = workspace.database_id()?; + let page_number = self.selected_page as u16; + Some(cx.background_spawn(async move { + persistence::ONBOARDING_PAGES + .save_onboarding_page(item_id, workspace_id, page_number) + .await + })) + } + + fn should_serialize(&self, event: &Self::Event) -> bool { + event == &ItemEvent::UpdateTab + } +} + +mod persistence { + use db::{define_connection, query, sqlez_macros::sql}; + use workspace::WorkspaceDb; + + define_connection! { + pub static ref ONBOARDING_PAGES: OnboardingPagesDb<WorkspaceDb> = + &[ + sql!( + CREATE TABLE onboarding_pages ( + workspace_id INTEGER, + item_id INTEGER UNIQUE, + page_number INTEGER, + + PRIMARY KEY(workspace_id, item_id), + FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) + ON DELETE CASCADE + ) STRICT; + ), + ]; + } + + impl OnboardingPagesDb { + query! { + pub async fn save_onboarding_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId, + page_number: u16 + ) -> Result<()> { + INSERT OR REPLACE INTO onboarding_pages(item_id, workspace_id, page_number) + VALUES (?, ?, ?) + } + } + + query! 
{ + pub fn get_onboarding_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId + ) -> Result<Option<u16>> { + SELECT page_number + FROM onboarding_pages + WHERE item_id = ? AND workspace_id = ? + } + } + } +} diff --git a/crates/onboarding/src/theme_preview.rs b/crates/onboarding/src/theme_preview.rs new file mode 100644 index 0000000000..81eb14ec4b --- /dev/null +++ b/crates/onboarding/src/theme_preview.rs @@ -0,0 +1,378 @@ +#![allow(unused, dead_code)] +use gpui::{Hsla, Length}; +use std::sync::Arc; +use theme::{Theme, ThemeColors, ThemeRegistry}; +use ui::{ + IntoElement, RenderOnce, component_prelude::Documented, prelude::*, utils::inner_corner_radius, +}; + +#[derive(Clone, PartialEq)] +pub enum ThemePreviewStyle { + Bordered, + Borderless, + SideBySide(Arc<Theme>), +} + +/// Shows a preview of a theme as an abstract illustration +/// of a thumbnail-sized editor. +#[derive(IntoElement, RegisterComponent, Documented)] +pub struct ThemePreviewTile { + theme: Arc<Theme>, + seed: f32, + style: ThemePreviewStyle, +} + +impl ThemePreviewTile { + pub const SKELETON_HEIGHT_DEFAULT: Pixels = px(2.); + pub const SIDEBAR_SKELETON_ITEM_COUNT: usize = 8; + pub const SIDEBAR_WIDTH_DEFAULT: DefiniteLength = relative(0.25); + pub const ROOT_RADIUS: Pixels = px(8.0); + pub const ROOT_BORDER: Pixels = px(2.0); + pub const ROOT_PADDING: Pixels = px(2.0); + pub const CHILD_BORDER: Pixels = px(1.0); + pub const CHILD_RADIUS: std::cell::LazyCell<Pixels> = std::cell::LazyCell::new(|| { + inner_corner_radius( + Self::ROOT_RADIUS, + Self::ROOT_BORDER, + Self::ROOT_PADDING, + Self::CHILD_BORDER, + ) + }); + + pub fn new(theme: Arc<Theme>, seed: f32) -> Self { + Self { + theme, + seed, + style: ThemePreviewStyle::Bordered, + } + } + + pub fn style(mut self, style: ThemePreviewStyle) -> Self { + self.style = style; + self + } + + pub fn item_skeleton(w: Length, h: Length, bg: Hsla) -> impl IntoElement { + div().w(w).h(h).rounded_full().bg(bg) + } + + pub fn render_sidebar_skeleton_items( + seed: f32, + colors: &ThemeColors, + skeleton_height: impl Into<Length> + Clone, + ) -> [impl IntoElement; Self::SIDEBAR_SKELETON_ITEM_COUNT] { + let skeleton_height = skeleton_height.into(); + std::array::from_fn(|index| { + let width = { + let value = (seed * 1000.0 + index as f32 * 10.0).sin() * 0.5 + 0.5; + 0.5 + value * 0.45 + }; + Self::item_skeleton( + relative(width).into(), + skeleton_height, + colors.text.alpha(0.45), + ) + }) + } + + pub fn render_pseudo_code_skeleton( + seed: f32, + theme: Arc<Theme>, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + let colors = theme.colors(); + let syntax = theme.syntax(); + + let keyword_color = syntax.get("keyword").color; + let function_color = syntax.get("function").color; + let string_color = syntax.get("string").color; + let comment_color = syntax.get("comment").color; + let variable_color = syntax.get("variable").color; + let type_color = syntax.get("type").color; + let punctuation_color = syntax.get("punctuation").color; + + let syntax_colors = [ + keyword_color, + function_color, + string_color, + variable_color, + type_color, + punctuation_color, + comment_color, + ]; + + let skeleton_height = skeleton_height.into(); + + let line_width = |line_idx: usize, block_idx: usize| -> f32 { + let val = + (seed * 100.0 + line_idx as f32 * 20.0 + block_idx as f32 * 5.0).sin() * 0.5 + 0.5; + 0.05 + val * 0.2 + }; + + let indentation = |line_idx: usize| -> f32 { + let step = line_idx % 6; + if step < 3 { + step as f32 * 0.1 + } else { + (5 - 
step) as f32 * 0.1 + } + }; + + let pick_color = |line_idx: usize, block_idx: usize| -> Hsla { + let idx = ((seed * 10.0 + line_idx as f32 * 7.0 + block_idx as f32 * 3.0).sin() * 3.5) + .abs() as usize + % syntax_colors.len(); + syntax_colors[idx].unwrap_or(colors.text) + }; + + let line_count = 13; + + let lines = (0..line_count) + .map(|line_idx| { + let block_count = (((seed * 30.0 + line_idx as f32 * 12.0).sin() * 0.5 + 0.5) * 3.0) + .round() as usize + + 2; + + let indent = indentation(line_idx); + + let blocks = (0..block_count) + .map(|block_idx| { + let width = line_width(line_idx, block_idx); + let color = pick_color(line_idx, block_idx); + Self::item_skeleton(relative(width).into(), skeleton_height, color) + }) + .collect::<Vec<_>>(); + + h_flex().gap(px(2.)).ml(relative(indent)).children(blocks) + }) + .collect::<Vec<_>>(); + + v_flex().size_full().p_1().gap_1p5().children(lines) + } + + pub fn render_sidebar( + seed: f32, + colors: &ThemeColors, + width: impl Into<Length> + Clone, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + div() + .h_full() + .w(width) + .border_r(px(1.)) + .border_color(colors.border_transparent) + .bg(colors.panel_background) + .child(v_flex().p_2().size_full().gap_1().children( + Self::render_sidebar_skeleton_items(seed, colors, skeleton_height.into()), + )) + } + + pub fn render_pane( + seed: f32, + theme: Arc<Theme>, + skeleton_height: impl Into<Length>, + ) -> impl IntoElement { + v_flex().h_full().flex_grow().child( + div() + .size_full() + .overflow_hidden() + .bg(theme.colors().editor_background) + .p_2() + .child(Self::render_pseudo_code_skeleton( + seed, + theme, + skeleton_height.into(), + )), + ) + } + + pub fn render_editor( + seed: f32, + theme: Arc<Theme>, + sidebar_width: impl Into<Length> + Clone, + skeleton_height: impl Into<Length> + Clone, + ) -> impl IntoElement { + div() + .size_full() + .flex() + .bg(theme.colors().background.alpha(1.00)) + .child(Self::render_sidebar( + seed, + theme.colors(), + sidebar_width, + skeleton_height.clone(), + )) + .child(Self::render_pane(seed, theme, skeleton_height.clone())) + } + + fn render_borderless(seed: f32, theme: Arc<Theme>) -> impl IntoElement { + return Self::render_editor( + seed, + theme, + Self::SIDEBAR_WIDTH_DEFAULT, + Self::SKELETON_HEIGHT_DEFAULT, + ); + } + + fn render_border(seed: f32, theme: Arc<Theme>) -> impl IntoElement { + div() + .size_full() + .p(Self::ROOT_PADDING) + .rounded(Self::ROOT_RADIUS) + .child( + div() + .size_full() + .rounded(*Self::CHILD_RADIUS) + .border(Self::CHILD_BORDER) + .border_color(theme.colors().border) + .child(Self::render_editor( + seed, + theme.clone(), + Self::SIDEBAR_WIDTH_DEFAULT, + Self::SKELETON_HEIGHT_DEFAULT, + )), + ) + } + + fn render_side_by_side( + seed: f32, + theme: Arc<Theme>, + other_theme: Arc<Theme>, + border_color: Hsla, + ) -> impl IntoElement { + let sidebar_width = relative(0.20); + + return div() + .size_full() + .p(Self::ROOT_PADDING) + .rounded(Self::ROOT_RADIUS) + .child( + h_flex() + .size_full() + .relative() + .rounded(*Self::CHILD_RADIUS) + .border(Self::CHILD_BORDER) + .border_color(border_color) + .overflow_hidden() + .child(div().size_full().child(Self::render_editor( + seed, + theme.clone(), + sidebar_width, + Self::SKELETON_HEIGHT_DEFAULT, + ))) + .child( + div() + .size_full() + .absolute() + .left_1_2() + .bg(other_theme.colors().editor_background) + .child(Self::render_editor( + seed, + other_theme, + sidebar_width, + Self::SKELETON_HEIGHT_DEFAULT, + )), + ), + ) + .into_any_element(); + } +} + 
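// Reviewer note — illustrative sketch only, not part of the patch. It shows how the helpers
// above compose: building a side-by-side preview tile from two registered themes, the same
// way `preview()` below does for the component gallery. The function name and theme names
// are assumptions; any themes present in the registry would work.
fn sketch_side_by_side_tile(cx: &mut App) -> Option<ThemePreviewTile> {
    let registry = ThemeRegistry::global(cx);
    // `get` returns a Result; bail out with `None` if either theme isn't installed.
    let dark = registry.get("One Dark").ok()?;
    let light = registry.get("One Light").ok()?;
    Some(ThemePreviewTile::new(dark, 0.42).style(ThemePreviewStyle::SideBySide(light)))
}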
+impl RenderOnce for ThemePreviewTile { + fn render(self, _window: &mut ui::Window, _cx: &mut ui::App) -> impl IntoElement { + match self.style { + ThemePreviewStyle::Bordered => { + Self::render_border(self.seed, self.theme).into_any_element() + } + ThemePreviewStyle::Borderless => { + Self::render_borderless(self.seed, self.theme).into_any_element() + } + ThemePreviewStyle::SideBySide(other_theme) => Self::render_side_by_side( + self.seed, + self.theme, + other_theme, + _cx.theme().colors().border, + ) + .into_any_element(), + } + } +} + +impl Component for ThemePreviewTile { + fn scope() -> ComponentScope { + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "Theme Preview Tile" + } + + fn sort_name() -> &'static str { + "Theme Preview Tile" + } + + fn description() -> Option<&'static str> { + Some(Self::DOCS) + } + + fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> { + let theme_registry = ThemeRegistry::global(cx); + + let one_dark = theme_registry.get("One Dark"); + let one_light = theme_registry.get("One Light"); + let gruvbox_dark = theme_registry.get("Gruvbox Dark"); + let gruvbox_light = theme_registry.get("Gruvbox Light"); + + let themes_to_preview = vec![ + one_dark.clone().ok(), + one_light.clone().ok(), + gruvbox_dark.clone().ok(), + gruvbox_light.clone().ok(), + ] + .into_iter() + .flatten() + .collect::<Vec<_>>(); + + Some( + v_flex() + .gap_6() + .p_4() + .children({ + if let Some(one_dark) = one_dark.ok() { + vec![example_group(vec![single_example( + "Default", + div() + .w(px(240.)) + .h(px(180.)) + .child(ThemePreviewTile::new(one_dark.clone(), 0.42)) + .into_any_element(), + )])] + } else { + vec![] + } + }) + .child( + example_group(vec![single_example( + "Default Themes", + h_flex() + .gap_4() + .children( + themes_to_preview + .iter() + .enumerate() + .map(|(_, theme)| { + div() + .w(px(200.)) + .h(px(140.)) + .child(ThemePreviewTile::new(theme.clone(), 0.42)) + }) + .collect::<Vec<_>>(), + ) + .into_any_element(), + )]) + .grow(), + ) + .into_any_element(), + ) + } +} diff --git a/crates/onboarding/src/welcome.rs b/crates/onboarding/src/welcome.rs new file mode 100644 index 0000000000..d4d6c3f701 --- /dev/null +++ b/crates/onboarding/src/welcome.rs @@ -0,0 +1,355 @@ +use gpui::{ + Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, + NoAction, ParentElement, Render, Styled, Window, actions, +}; +use menu::{SelectNext, SelectPrevious}; +use ui::{ButtonLike, Divider, DividerColor, KeyBinding, Vector, VectorName, prelude::*}; +use workspace::{ + NewFile, Open, WorkspaceId, + item::{Item, ItemEvent}, + with_active_or_new_workspace, +}; +use zed_actions::{Extensions, OpenSettings, agent, command_palette}; + +use crate::{Onboarding, OpenOnboarding}; + +actions!( + zed, + [ + /// Show the Zed welcome screen + ShowWelcome + ] +); + +const CONTENT: (Section<4>, Section<3>) = ( + Section { + title: "Get Started", + entries: [ + SectionEntry { + icon: IconName::Plus, + title: "New File", + action: &NewFile, + }, + SectionEntry { + icon: IconName::FolderOpen, + title: "Open Project", + action: &Open, + }, + SectionEntry { + icon: IconName::CloudDownload, + title: "Clone a Repo", + // TODO: use proper action + action: &NoAction, + }, + SectionEntry { + icon: IconName::ListCollapse, + title: "Open Command Palette", + action: &command_palette::Toggle, + }, + ], + }, + Section { + title: "Configure", + entries: [ + SectionEntry { + icon: IconName::Settings, + title: "Open Settings", + action: &OpenSettings, + 
}, + SectionEntry { + icon: IconName::ZedAssistant, + title: "View AI Settings", + action: &agent::OpenSettings, + }, + SectionEntry { + icon: IconName::Blocks, + title: "Explore Extensions", + action: &Extensions { + category_filter: None, + id: None, + }, + }, + ], + }, +); + +struct Section<const COLS: usize> { + title: &'static str, + entries: [SectionEntry; COLS], +} + +impl<const COLS: usize> Section<COLS> { + fn render( + self, + index_offset: usize, + focus: &FocusHandle, + window: &mut Window, + cx: &mut App, + ) -> impl IntoElement { + v_flex() + .min_w_full() + .child( + h_flex() + .px_1() + .mb_2() + .gap_2() + .child( + Label::new(self.title.to_ascii_uppercase()) + .buffer_font(cx) + .color(Color::Muted) + .size(LabelSize::XSmall), + ) + .child(Divider::horizontal().color(DividerColor::BorderVariant)), + ) + .children( + self.entries + .iter() + .enumerate() + .map(|(index, entry)| entry.render(index_offset + index, &focus, window, cx)), + ) + } +} + +struct SectionEntry { + icon: IconName, + title: &'static str, + action: &'static dyn Action, +} + +impl SectionEntry { + fn render( + &self, + button_index: usize, + focus: &FocusHandle, + window: &Window, + cx: &App, + ) -> impl IntoElement { + ButtonLike::new(("onboarding-button-id", button_index)) + .tab_index(button_index as isize) + .full_width() + .size(ButtonSize::Medium) + .child( + h_flex() + .w_full() + .justify_between() + .child( + h_flex() + .gap_2() + .child( + Icon::new(self.icon) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(self.title)), + ) + .children( + KeyBinding::for_action_in(self.action, focus, window, cx) + .map(|s| s.size(rems_from_px(12.))), + ), + ) + .on_click(|_, window, cx| window.dispatch_action(self.action.boxed_clone(), cx)) + } +} + +pub struct WelcomePage { + focus_handle: FocusHandle, +} + +impl WelcomePage { + fn select_next(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context<Self>) { + window.focus_next(); + cx.notify(); + } + + fn select_previous(&mut self, _: &SelectPrevious, window: &mut Window, cx: &mut Context<Self>) { + window.focus_prev(); + cx.notify(); + } +} + +impl Render for WelcomePage { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let (first_section, second_section) = CONTENT; + let first_section_entries = first_section.entries.len(); + let last_index = first_section_entries + second_section.entries.len(); + + h_flex() + .size_full() + .justify_center() + .overflow_hidden() + .bg(cx.theme().colors().editor_background) + .key_context("Welcome") + .track_focus(&self.focus_handle(cx)) + .on_action(cx.listener(Self::select_previous)) + .on_action(cx.listener(Self::select_next)) + .child( + h_flex() + .px_12() + .py_40() + .size_full() + .relative() + .max_w(px(1100.)) + .child( + div() + .size_full() + .max_w_128() + .mx_auto() + .child( + h_flex() + .w_full() + .justify_center() + .gap_4() + .child(Vector::square(VectorName::ZedLogo, rems(2.))) + .child( + div().child(Headline::new("Welcome to Zed")).child( + Label::new("The editor for what's next") + .size(LabelSize::Small) + .color(Color::Muted) + .italic(), + ), + ), + ) + .child( + v_flex() + .mt_10() + .gap_6() + .child(first_section.render( + Default::default(), + &self.focus_handle, + window, + cx, + )) + .child(second_section.render( + first_section_entries, + &self.focus_handle, + window, + cx, + )) + .child( + h_flex() + .w_full() + .pt_4() + .justify_center() + // We call this a hack + .rounded_b_xs() + .border_t_1() + 
.border_color(cx.theme().colors().border.opacity(0.6)) + .border_dashed() + .child( + Button::new("welcome-exit", "Return to Setup") + .tab_index(last_index as isize) + .full_width() + .label_size(LabelSize::XSmall) + .on_click(|_, window, cx| { + window.dispatch_action( + OpenOnboarding.boxed_clone(), + cx, + ); + + with_active_or_new_workspace(cx, |workspace, window, cx| { + let Some((welcome_id, welcome_idx)) = workspace + .active_pane() + .read(cx) + .items() + .enumerate() + .find_map(|(idx, item)| { + let _ = item.downcast::<WelcomePage>()?; + Some((item.item_id(), idx)) + }) + else { + return; + }; + + workspace.active_pane().update(cx, |pane, cx| { + // Get the index here to get around the borrow checker + let idx = pane.items().enumerate().find_map( + |(idx, item)| { + let _ = + item.downcast::<Onboarding>()?; + Some(idx) + }, + ); + + if let Some(idx) = idx { + pane.activate_item( + idx, true, true, window, cx, + ); + } else { + let item = + Box::new(Onboarding::new(workspace, cx)); + pane.add_item( + item, + true, + true, + Some(welcome_idx), + window, + cx, + ); + } + + pane.remove_item( + welcome_id, + false, + false, + window, + cx, + ); + }); + }); + }), + ), + ), + ), + ), + ) + } +} + +impl WelcomePage { + pub fn new(window: &mut Window, cx: &mut App) -> Entity<Self> { + cx.new(|cx| { + let focus_handle = cx.focus_handle(); + cx.on_focus(&focus_handle, window, |_, _, cx| cx.notify()) + .detach(); + + WelcomePage { focus_handle } + }) + } +} + +impl EventEmitter<ItemEvent> for WelcomePage {} + +impl Focusable for WelcomePage { + fn focus_handle(&self, _: &App) -> gpui::FocusHandle { + self.focus_handle.clone() + } +} + +impl Item for WelcomePage { + type Event = ItemEvent; + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Welcome".into() + } + + fn telemetry_event_text(&self) -> Option<&'static str> { + Some("New Welcome Page Opened") + } + + fn show_toolbar(&self) -> bool { + false + } + + fn clone_on_split( + &self, + _workspace_id: Option<WorkspaceId>, + _: &mut Window, + _: &mut Context<Self>, + ) -> Option<Entity<Self>> { + None + } + + fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { + f(*event) + } +} diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index 50c6c2dcce..1cda3897ec 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -1041,7 +1041,7 @@ impl OutlinePanel { fn open_excerpts( &mut self, - action: &editor::OpenExcerpts, + action: &editor::actions::OpenExcerpts, window: &mut Window, cx: &mut Context<Self>, ) { @@ -1057,7 +1057,7 @@ impl OutlinePanel { fn open_excerpts_split( &mut self, - action: &editor::OpenExcerptsSplit, + action: &editor::actions::OpenExcerptsSplit, window: &mut Window, cx: &mut Context<Self>, ) { @@ -2570,11 +2570,11 @@ impl OutlinePanel { .on_click({ let clicked_entry = rendered_entry.clone(); cx.listener(move |outline_panel, event: &gpui::ClickEvent, window, cx| { - if event.down.button == MouseButton::Right || event.down.first_mouse { + if event.is_right_click() || event.first_focus() { return; } - let change_focus = event.down.click_count > 1; + let change_focus = event.click_count() > 1; outline_panel.toggle_expanded(&clicked_entry, window, cx); outline_panel.scroll_editor_to_entry( @@ -5958,7 +5958,7 @@ mod tests { }); outline_panel.update_in(cx, |outline_panel, window, cx| { - outline_panel.open_excerpts(&editor::OpenExcerpts, window, cx); + 
outline_panel.open_excerpts(&editor::actions::OpenExcerpts, window, cx); }); cx.executor() .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs index 2f3b188980..47a0f12c06 100644 --- a/crates/paths/src/paths.rs +++ b/crates/paths/src/paths.rs @@ -35,6 +35,7 @@ pub fn remote_server_dir_relative() -> &'static Path { /// Sets a custom directory for all user data, overriding the default data directory. /// This function must be called before any other path operations that depend on the data directory. +/// The directory's path will be canonicalized to an absolute path by a blocking FS operation. /// The directory will be created if it doesn't exist. /// /// # Arguments @@ -50,13 +51,20 @@ pub fn remote_server_dir_relative() -> &'static Path { /// /// Panics if: /// * Called after the data directory has been initialized (e.g., via `data_dir` or `config_dir`) +/// * The directory's path cannot be canonicalized to an absolute path /// * The directory cannot be created pub fn set_custom_data_dir(dir: &str) -> &'static PathBuf { if CURRENT_DATA_DIR.get().is_some() || CONFIG_DIR.get().is_some() { panic!("set_custom_data_dir called after data_dir or config_dir was initialized"); } CUSTOM_DATA_DIR.get_or_init(|| { - let path = PathBuf::from(dir); + let mut path = PathBuf::from(dir); + if path.is_relative() { + let abs_path = path + .canonicalize() + .expect("failed to canonicalize custom data directory's path to an absolute path"); + path = PathBuf::from(util::paths::SanitizedPath::from(abs_path)) + } std::fs::create_dir_all(&path).expect("failed to create custom data directory"); path }) diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 692bdd5bd7..34af5fed02 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -292,7 +292,7 @@ impl<D: PickerDelegate> Picker<D> { window: &mut Window, cx: &mut Context<Self>, ) -> Self { - let element_container = Self::create_element_container(container, cx); + let element_container = Self::create_element_container(container); let scrollbar_state = match &element_container { ElementContainer::UniformList(scroll_handle) => { ScrollbarState::new(scroll_handle.clone()) @@ -323,31 +323,13 @@ impl<D: PickerDelegate> Picker<D> { this } - fn create_element_container( - container: ContainerKind, - cx: &mut Context<Self>, - ) -> ElementContainer { + fn create_element_container(container: ContainerKind) -> ElementContainer { match container { ContainerKind::UniformList => { ElementContainer::UniformList(UniformListScrollHandle::new()) } ContainerKind::List => { - let entity = cx.entity().downgrade(); - ElementContainer::List(ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - entity - .upgrade() - .map(|entity| { - entity.update(cx, |this, cx| { - this.render_element(window, cx, ix).into_any_element() - }) - }) - .unwrap_or_else(|| div().into_any_element()) - }, - )) + ElementContainer::List(ListState::new(0, gpui::ListAlignment::Top, px(1000.))) } } } @@ -786,11 +768,16 @@ impl<D: PickerDelegate> Picker<D> { .py_1() .track_scroll(scroll_handle.clone()) .into_any_element(), - ElementContainer::List(state) => list(state.clone()) - .with_sizing_behavior(sizing_behavior) - .flex_grow() - .py_2() - .into_any_element(), + ElementContainer::List(state) => list( + state.clone(), + cx.processor(|this, ix, window, cx| { + this.render_element(window, cx, ix).into_any_element() + }), + ) + 
.with_sizing_behavior(sizing_behavior) + .flex_grow() + .py_2() + .into_any_element(), } } diff --git a/crates/picker/src/popover_menu.rs b/crates/picker/src/popover_menu.rs index dd1d9c2865..d05308ee71 100644 --- a/crates/picker/src/popover_menu.rs +++ b/crates/picker/src/popover_menu.rs @@ -80,6 +80,7 @@ where { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let picker = self.picker.clone(); + PopoverMenu::new("popover-menu") .menu(move |_window, _cx| Some(picker.clone())) .trigger_with_tooltip(self.trigger, self.tooltip) diff --git a/crates/prettier/src/prettier_server.js b/crates/prettier/src/prettier_server.js index 6799b4aceb..b3d8a660a4 100644 --- a/crates/prettier/src/prettier_server.js +++ b/crates/prettier/src/prettier_server.js @@ -152,6 +152,10 @@ async function handleMessage(message, prettier) { throw new Error(`Message method is undefined: ${JSON.stringify(message)}`); } else if (method == "initialized") { return; + } else if (method === "shutdown") { + sendResponse({ result: {} }); + } else if (method == "exit") { + process.exit(0); } if (id === undefined) { diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index ceec0c0a52..c96ab4e8f3 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -13,6 +13,7 @@ use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; use crate::{ + Project, project_settings::{ContextServerSettings, ProjectSettings}, worktree_store::WorktreeStore, }; @@ -144,6 +145,7 @@ pub struct ContextServerStore { context_server_settings: HashMap<Arc<str>, ContextServerSettings>, servers: HashMap<ContextServerId, ContextServerState>, worktree_store: Entity<WorktreeStore>, + project: WeakEntity<Project>, registry: Entity<ContextServerDescriptorRegistry>, update_servers_task: Option<Task<Result<()>>>, context_server_factory: Option<ContextServerFactory>, @@ -161,12 +163,17 @@ pub enum Event { impl EventEmitter<Event> for ContextServerStore {} impl ContextServerStore { - pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self { + pub fn new( + worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, + cx: &mut Context<Self>, + ) -> Self { Self::new_internal( true, None, ContextServerDescriptorRegistry::default_global(cx), worktree_store, + weak_project, cx, ) } @@ -184,9 +191,10 @@ impl ContextServerStore { pub fn test( registry: Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { - Self::new_internal(false, None, registry, worktree_store, cx) + Self::new_internal(false, None, registry, worktree_store, weak_project, cx) } #[cfg(any(test, feature = "test-support"))] @@ -194,6 +202,7 @@ impl ContextServerStore { context_server_factory: ContextServerFactory, registry: Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { Self::new_internal( @@ -201,6 +210,7 @@ impl ContextServerStore { Some(context_server_factory), registry, worktree_store, + weak_project, cx, ) } @@ -210,6 +220,7 @@ impl ContextServerStore { context_server_factory: Option<ContextServerFactory>, registry: Entity<ContextServerDescriptorRegistry>, worktree_store: Entity<WorktreeStore>, + weak_project: WeakEntity<Project>, cx: &mut Context<Self>, ) -> Self { let subscriptions = if maintain_server_loop { @@ -235,6 +246,7 
@@ impl ContextServerStore { context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx) .clone(), worktree_store, + project: weak_project, registry, needs_server_update: false, servers: HashMap::default(), @@ -360,7 +372,7 @@ impl ContextServerStore { let configuration = state.configuration(); self.stop_server(&state.server().id(), cx)?; - let new_server = self.create_context_server(id.clone(), configuration.clone())?; + let new_server = self.create_context_server(id.clone(), configuration.clone(), cx); self.run_server(new_server, configuration, cx); } Ok(()) @@ -449,14 +461,33 @@ impl ContextServerStore { &self, id: ContextServerId, configuration: Arc<ContextServerConfiguration>, - ) -> Result<Arc<ContextServer>> { + cx: &mut Context<Self>, + ) -> Arc<ContextServer> { + let root_path = self + .project + .read_with(cx, |project, cx| project.active_project_directory(cx)) + .ok() + .flatten() + .or_else(|| { + self.worktree_store.read_with(cx, |store, cx| { + store.visible_worktrees(cx).fold(None, |acc, item| { + if acc.is_none() { + item.read(cx).root_dir() + } else { + acc + } + }) + }) + }); + if let Some(factory) = self.context_server_factory.as_ref() { - Ok(factory(id, configuration)) + factory(id, configuration) } else { - Ok(Arc::new(ContextServer::stdio( + Arc::new(ContextServer::stdio( id, configuration.command().clone(), - ))) + root_path, + )) } } @@ -553,7 +584,7 @@ impl ContextServerStore { let mut servers_to_remove = HashSet::default(); let mut servers_to_stop = HashSet::default(); - this.update(cx, |this, _cx| { + this.update(cx, |this, cx| { for server_id in this.servers.keys() { // All servers that are not in desired_servers should be removed from the store. // This can happen if the user removed a server from the context server settings. 
@@ -572,14 +603,10 @@ impl ContextServerStore { let existing_config = state.as_ref().map(|state| state.configuration()); if existing_config.as_deref() != Some(&config) || is_stopped { let config = Arc::new(config); - if let Some(server) = this - .create_context_server(id.clone(), config.clone()) - .log_err() - { - servers_to_start.push((server, config)); - if this.servers.contains_key(&id) { - servers_to_stop.insert(id); - } + let server = this.create_context_server(id.clone(), config.clone(), cx); + servers_to_start.push((server, config)); + if this.servers.contains_key(&id) { + servers_to_stop.insert(id); } } } @@ -630,7 +657,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_1_id = ContextServerId(SERVER_1_ID.into()); @@ -705,7 +737,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_1_id = ContextServerId(SERVER_1_ID.into()); @@ -758,7 +795,12 @@ mod tests { let registry = cx.new(|_| ContextServerDescriptorRegistry::new()); let store = cx.new(|cx| { - ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) + ContextServerStore::test( + registry.clone(), + project.read(cx).worktree_store(), + project.downgrade(), + cx, + ) }); let server_id = ContextServerId(SERVER_1_ID.into()); @@ -842,6 +884,7 @@ mod tests { }), registry.clone(), project.read(cx).worktree_store(), + project.downgrade(), cx, ) }); @@ -1074,6 +1117,7 @@ mod tests { }), registry.clone(), project.read(cx).worktree_store(), + project.downgrade(), cx, ) }); diff --git a/crates/project/src/debugger/dap_command.rs b/crates/project/src/debugger/dap_command.rs index 1cb611680c..3be3192369 100644 --- a/crates/project/src/debugger/dap_command.rs +++ b/crates/project/src/debugger/dap_command.rs @@ -107,7 +107,7 @@ impl<T: DapCommand> DapCommand for Arc<T> { #[derive(Debug, Hash, PartialEq, Eq)] pub struct StepCommand { - pub thread_id: u64, + pub thread_id: i64, pub granularity: Option<SteppingGranularity>, pub single_thread: Option<bool>, } @@ -483,7 +483,7 @@ impl DapCommand for ContinueCommand { #[derive(Debug, Hash, PartialEq, Eq)] pub(crate) struct PauseCommand { - pub thread_id: u64, + pub thread_id: i64, } impl LocalDapCommand for PauseCommand { @@ -612,7 +612,7 @@ impl DapCommand for DisconnectCommand { #[derive(Debug, Hash, PartialEq, Eq)] pub(crate) struct TerminateThreadsCommand { - pub thread_ids: Option<Vec<u64>>, + pub thread_ids: Option<Vec<i64>>, } impl LocalDapCommand for TerminateThreadsCommand { @@ -1182,7 +1182,7 @@ impl DapCommand for LoadedSourcesCommand { #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub(crate) struct StackTraceCommand { - pub thread_id: u64, + pub thread_id: i64, pub start_frame: Option<u64>, pub levels: Option<u64>, } diff --git a/crates/project/src/debugger/dap_store.rs b/crates/project/src/debugger/dap_store.rs index d494088b13..6f834b5dc0 100644 --- a/crates/project/src/debugger/dap_store.rs +++ b/crates/project/src/debugger/dap_store.rs @@ -920,12 +920,22 @@ impl dap::adapters::DapDelegate for DapAdapterDelegate { 
self.console.unbounded_send(msg).ok(); } + #[cfg(not(target_os = "windows"))] async fn which(&self, command: &OsStr) -> Option<PathBuf> { let worktree_abs_path = self.worktree.abs_path(); let shell_path = self.shell_env().await.get("PATH").cloned(); which::which_in(command, shell_path.as_ref(), worktree_abs_path).ok() } + #[cfg(target_os = "windows")] + async fn which(&self, command: &OsStr) -> Option<PathBuf> { + // On Windows, `PATH` is handled differently from Unix. Windows generally expects users to modify the `PATH` themselves, + // and every program loads it directly from the system at startup. + // There's also no concept of a default shell on Windows, and you can't really retrieve one, so trying to get shell environment variables + // from a specific directory doesn’t make sense on Windows. + which::which(command).ok() + } + async fn shell_env(&self) -> HashMap<String, String> { let task = self.load_shell_env_task.clone(); task.await.unwrap_or_default() diff --git a/crates/project/src/debugger/locators/cargo.rs b/crates/project/src/debugger/locators/cargo.rs index 7d70371380..fa265dae58 100644 --- a/crates/project/src/debugger/locators/cargo.rs +++ b/crates/project/src/debugger/locators/cargo.rs @@ -128,7 +128,7 @@ impl DapLocator for CargoLocator { .chain(Some("--message-format=json".to_owned())) .collect(), ); - let mut child = Command::new(program) + let mut child = util::command::new_smol_command(program) .args(args) .envs(build_config.env.iter().map(|(k, v)| (k.clone(), v.clone()))) .current_dir(cwd) diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index 1e296ac2ac..f60a7becf7 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -61,15 +61,10 @@ use worktree::Worktree; #[derive(Debug, Copy, Clone, Hash, PartialEq, PartialOrd, Ord, Eq)] #[repr(transparent)] -pub struct ThreadId(pub u64); +pub struct ThreadId(pub i64); -impl ThreadId { - pub const MIN: ThreadId = ThreadId(u64::MIN); - pub const MAX: ThreadId = ThreadId(u64::MAX); -} - -impl From<u64> for ThreadId { - fn from(id: u64) -> Self { +impl From<i64> for ThreadId { + fn from(id: i64) -> Self { Self(id) } } diff --git a/crates/project/src/debugger/test.rs b/crates/project/src/debugger/test.rs index 3b9425e369..53b88323e6 100644 --- a/crates/project/src/debugger/test.rs +++ b/crates/project/src/debugger/test.rs @@ -1,7 +1,7 @@ use std::{path::Path, sync::Arc}; use dap::client::DebugAdapterClient; -use gpui::{App, AppContext, Subscription}; +use gpui::{App, Subscription}; use super::session::{Session, SessionStateEvent}; @@ -19,14 +19,6 @@ pub fn intercept_debug_sessions<T: Fn(&Arc<DebugAdapterClient>) + 'static>( let client = session.adapter_client().unwrap(); register_default_handlers(session, &client, cx); configure(&client); - cx.background_spawn(async move { - client - .fake_event(dap::messages::Events::Initialized( - Some(Default::default()), - )) - .await - }) - .detach(); } }) .detach(); diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index eb16446daf..01fc987816 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -14,9 +14,10 @@ use collections::HashMap; pub use conflict_set::{ConflictRegion, ConflictSet, ConflictSetSnapshot, ConflictSetUpdate}; use fs::Fs; use futures::{ - FutureExt, StreamExt as _, + FutureExt, StreamExt, channel::{mpsc, oneshot}, future::{self, Shared}, + stream::FuturesOrdered, }; use git::{ BuildPermalinkParams, GitHostingProviderRegistry, 
WORK_DIRECTORY_REPO_PATH, @@ -63,8 +64,8 @@ use sum_tree::{Edit, SumTree, TreeSet}; use text::{Bias, BufferId}; use util::{ResultExt, debug_panic, post_inc}; use worktree::{ - File, PathKey, PathProgress, PathSummary, PathTarget, UpdatedGitRepositoriesSet, - UpdatedGitRepository, Worktree, + File, PathChange, PathKey, PathProgress, PathSummary, PathTarget, ProjectEntryId, + UpdatedGitRepositoriesSet, UpdatedGitRepository, Worktree, }; pub struct GitStore { @@ -245,6 +246,8 @@ pub struct RepositorySnapshot { pub head_commit: Option<CommitDetails>, pub scan_id: u64, pub merge: MergeDetails, + pub remote_origin_url: Option<String>, + pub remote_upstream_url: Option<String>, } type JobId = u64; @@ -419,6 +422,8 @@ impl GitStore { client.add_entity_request_handler(Self::handle_fetch); client.add_entity_request_handler(Self::handle_stage); client.add_entity_request_handler(Self::handle_unstage); + client.add_entity_request_handler(Self::handle_stash); + client.add_entity_request_handler(Self::handle_stash_pop); client.add_entity_request_handler(Self::handle_commit); client.add_entity_request_handler(Self::handle_reset); client.add_entity_request_handler(Self::handle_show); @@ -1083,27 +1088,26 @@ impl GitStore { match event { WorktreeStoreEvent::WorktreeUpdatedEntries(worktree_id, updated_entries) => { - let mut paths_by_git_repo = HashMap::<_, Vec<_>>::default(); - for (relative_path, _, _) in updated_entries.iter() { - let Some((repo, repo_path)) = self.repository_and_path_for_project_path( - &(*worktree_id, relative_path.clone()).into(), - cx, - ) else { - continue; - }; - paths_by_git_repo.entry(repo).or_default().push(repo_path) - } - - for (repo, paths) in paths_by_git_repo { - repo.update(cx, |repo, cx| { - repo.paths_changed( - paths, - downstream - .as_ref() - .map(|downstream| downstream.updates_tx.clone()), - cx, - ); - }); + if let Some(worktree) = self + .worktree_store + .read(cx) + .worktree_for_id(*worktree_id, cx) + { + let paths_by_git_repo = + self.process_updated_entries(&worktree, updated_entries, cx); + let downstream = downstream + .as_ref() + .map(|downstream| downstream.updates_tx.clone()); + cx.spawn(async move |_, cx| { + let paths_by_git_repo = paths_by_git_repo.await; + for (repo, paths) in paths_by_git_repo { + repo.update(cx, |repo, cx| { + repo.paths_changed(paths, downstream.clone(), cx); + }) + .ok(); + } + }) + .detach(); } } WorktreeStoreEvent::WorktreeUpdatedGitRepositories(worktree_id, changed_repos) => { @@ -1696,6 +1700,48 @@ impl GitStore { Ok(proto::Ack {}) } + async fn handle_stash( + this: Entity<Self>, + envelope: TypedEnvelope<proto::Stash>, + mut cx: AsyncApp, + ) -> Result<proto::Ack> { + let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); + let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; + + let entries = envelope + .payload + .paths + .into_iter() + .map(PathBuf::from) + .map(RepoPath::new) + .collect(); + + repository_handle + .update(&mut cx, |repository_handle, cx| { + repository_handle.stash_entries(entries, cx) + })? + .await?; + + Ok(proto::Ack {}) + } + + async fn handle_stash_pop( + this: Entity<Self>, + envelope: TypedEnvelope<proto::StashPop>, + mut cx: AsyncApp, + ) -> Result<proto::Ack> { + let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); + let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; + + repository_handle + .update(&mut cx, |repository_handle, cx| { + repository_handle.stash_pop(cx) + })? 
+            .await?;
+
+        Ok(proto::Ack {})
+    }
+
     async fn handle_set_index_text(
         this: Entity<Self>,
         envelope: TypedEnvelope<proto::SetIndexText>,
@@ -2191,6 +2237,80 @@ impl GitStore {
             .map(|(id, repo)| (*id, repo.read(cx).snapshot.clone()))
             .collect()
     }
+
+    fn process_updated_entries(
+        &self,
+        worktree: &Entity<Worktree>,
+        updated_entries: &[(Arc<Path>, ProjectEntryId, PathChange)],
+        cx: &mut App,
+    ) -> Task<HashMap<Entity<Repository>, Vec<RepoPath>>> {
+        let mut repo_paths = self
+            .repositories
+            .values()
+            .map(|repo| (repo.read(cx).work_directory_abs_path.clone(), repo.clone()))
+            .collect::<Vec<_>>();
+        let mut entries: Vec<_> = updated_entries
+            .iter()
+            .map(|(path, _, _)| path.clone())
+            .collect();
+        entries.sort();
+        let worktree = worktree.read(cx);
+
+        let entries = entries
+            .into_iter()
+            .filter_map(|path| worktree.absolutize(&path).ok())
+            .collect::<Arc<[_]>>();
+
+        let executor = cx.background_executor().clone();
+        cx.background_executor().spawn(async move {
+            repo_paths.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
+            let mut paths_by_git_repo = HashMap::<_, Vec<_>>::default();
+            let mut tasks = FuturesOrdered::new();
+            for (repo_path, repo) in repo_paths.into_iter().rev() {
+                let entries = entries.clone();
+                let task = executor.spawn(async move {
+                    // Find all repository paths that belong to this repo
+                    let mut ix = entries.partition_point(|path| path < &*repo_path);
+                    if ix == entries.len() {
+                        return None;
+                    };
+
+                    let mut paths = vec![];
+                    // All paths prefixed by a given repo will constitute a contiguous range.
+                    while let Some(path) = entries.get(ix)
+                        && let Some(repo_path) =
+                            RepositorySnapshot::abs_path_to_repo_path_inner(&repo_path, &path)
+                    {
+                        paths.push((repo_path, ix));
+                        ix += 1;
+                    }
+                    Some((repo, paths))
+                });
+                tasks.push_back(task);
+            }
+
+            // Now, let's filter out the "duplicate" entries that were processed by multiple distinct repos.
+            let mut path_was_used = vec![false; entries.len()];
+            let tasks = tasks.collect::<Vec<_>>().await;
+            // Process tasks from the back: iterating backwards allows us to see more-specific paths first.
+            // We always want to assign a path to its innermost repository.
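+            // Since repo_paths was sorted and then reversed, the deepest work directories are visited
+            // first, so the innermost repository claims each path before any ancestor repository can.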
+ for t in tasks { + let Some((repo, paths)) = t else { + continue; + }; + let entry = paths_by_git_repo.entry(repo).or_default(); + for (repo_path, ix) in paths { + if path_was_used[ix] { + continue; + } + path_was_used[ix] = true; + entry.push(repo_path); + } + } + + paths_by_git_repo + }) + } } impl BufferGitState { @@ -2555,6 +2675,8 @@ impl RepositorySnapshot { head_commit: None, scan_id: 0, merge: Default::default(), + remote_origin_url: None, + remote_upstream_url: None, } } @@ -2660,8 +2782,16 @@ impl RepositorySnapshot { } pub fn abs_path_to_repo_path(&self, abs_path: &Path) -> Option<RepoPath> { + Self::abs_path_to_repo_path_inner(&self.work_directory_abs_path, abs_path) + } + + #[inline] + fn abs_path_to_repo_path_inner( + work_directory_abs_path: &Path, + abs_path: &Path, + ) -> Option<RepoPath> { abs_path - .strip_prefix(&self.work_directory_abs_path) + .strip_prefix(&work_directory_abs_path) .map(RepoPath::from) .ok() } @@ -3458,6 +3588,82 @@ impl Repository { self.unstage_entries(to_unstage, cx) } + pub fn stash_all(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> { + let to_stash = self + .cached_status() + .map(|entry| entry.repo_path.clone()) + .collect(); + + self.stash_entries(to_stash, cx) + } + + pub fn stash_entries( + &mut self, + entries: Vec<RepoPath>, + cx: &mut Context<Self>, + ) -> Task<anyhow::Result<()>> { + let id = self.id; + + cx.spawn(async move |this, cx| { + this.update(cx, |this, _| { + this.send_job(None, move |git_repo, _cx| async move { + match git_repo { + RepositoryState::Local { + backend, + environment, + .. + } => backend.stash_paths(entries, environment).await, + RepositoryState::Remote { project_id, client } => { + client + .request(proto::Stash { + project_id: project_id.0, + repository_id: id.to_proto(), + paths: entries + .into_iter() + .map(|repo_path| repo_path.as_ref().to_proto()) + .collect(), + }) + .await + .context("sending stash request")?; + Ok(()) + } + } + }) + })? + .await??; + Ok(()) + }) + } + + pub fn stash_pop(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> { + let id = self.id; + cx.spawn(async move |this, cx| { + this.update(cx, |this, _| { + this.send_job(None, move |git_repo, _cx| async move { + match git_repo { + RepositoryState::Local { + backend, + environment, + .. + } => backend.stash_pop(environment).await, + RepositoryState::Remote { project_id, client } => { + client + .request(proto::StashPop { + project_id: project_id.0, + repository_id: id.to_proto(), + }) + .await + .context("sending stash pop request")?; + Ok(()) + } + } + }) + })? + .await??; + Ok(()) + }) + } + pub fn commit( &mut self, message: SharedString, @@ -3823,6 +4029,25 @@ impl Repository { }) } + pub fn default_branch(&mut self) -> oneshot::Receiver<Result<Option<SharedString>>> { + let id = self.id; + self.send_job(None, move |repo, _| async move { + match repo { + RepositoryState::Local { backend, .. 
} => backend.default_branch().await, + RepositoryState::Remote { project_id, client } => { + let response = client + .request(proto::GetDefaultBranch { + project_id: project_id.0, + repository_id: id.to_proto(), + }) + .await?; + + anyhow::Ok(response.branch.map(SharedString::from)) + } + } + }) + } + pub fn diff(&mut self, diff_type: DiffType, _cx: &App) -> oneshot::Receiver<Result<String>> { let id = self.id; self.send_job(None, move |repo, _cx| async move { @@ -4597,6 +4822,10 @@ async fn compute_snapshot( None => None, }; + // Used by edit prediction data collection + let remote_origin_url = backend.remote_url("origin"); + let remote_upstream_url = backend.remote_url("upstream"); + let snapshot = RepositorySnapshot { id, statuses_by_path, @@ -4605,6 +4834,8 @@ async fn compute_snapshot( branch, head_commit, merge: merge_details, + remote_origin_url, + remote_upstream_url, }; Ok((snapshot, events)) diff --git a/crates/project/src/git_store/git_traversal.rs b/crates/project/src/git_store/git_traversal.rs index cd173d5714..777042cb02 100644 --- a/crates/project/src/git_store/git_traversal.rs +++ b/crates/project/src/git_store/git_traversal.rs @@ -1,6 +1,6 @@ use collections::HashMap; -use git::status::GitSummary; -use std::{ops::Deref, path::Path}; +use git::{repository::RepoPath, status::GitSummary}; +use std::{collections::BTreeMap, ops::Deref, path::Path}; use sum_tree::Cursor; use text::Bias; use worktree::{Entry, PathProgress, PathTarget, Traversal}; @@ -11,7 +11,7 @@ use super::{RepositoryId, RepositorySnapshot, StatusEntry}; pub struct GitTraversal<'a> { traversal: Traversal<'a>, current_entry_summary: Option<GitSummary>, - repo_snapshots: &'a HashMap<RepositoryId, RepositorySnapshot>, + repo_root_to_snapshot: BTreeMap<&'a Path, &'a RepositorySnapshot>, repo_location: Option<(RepositoryId, Cursor<'a, StatusEntry, PathProgress<'a>>)>, } @@ -20,16 +20,46 @@ impl<'a> GitTraversal<'a> { repo_snapshots: &'a HashMap<RepositoryId, RepositorySnapshot>, traversal: Traversal<'a>, ) -> GitTraversal<'a> { + let repo_root_to_snapshot = repo_snapshots + .values() + .map(|snapshot| (&*snapshot.work_directory_abs_path, snapshot)) + .collect(); let mut this = GitTraversal { traversal, - repo_snapshots, current_entry_summary: None, repo_location: None, + repo_root_to_snapshot, }; this.synchronize_statuses(true); this } + fn repo_root_for_path(&self, path: &Path) -> Option<(&'a RepositorySnapshot, RepoPath)> { + // We might need to perform a range search multiple times, as there may be a nested repository inbetween + // the target and our path. 
E.g: + // /our_root_repo/ + // .git/ + // other_repo/ + // .git/ + // our_query.txt + let mut query = path.ancestors(); + while let Some(query) = query.next() { + let (_, snapshot) = self + .repo_root_to_snapshot + .range(Path::new("")..=query) + .last()?; + + let stripped = snapshot + .abs_path_to_repo_path(path) + .map(|repo_path| (*snapshot, repo_path)); + if stripped.is_some() { + return stripped; + } + } + + None + } + fn synchronize_statuses(&mut self, reset: bool) { self.current_entry_summary = None; @@ -42,15 +72,7 @@ impl<'a> GitTraversal<'a> { return; }; - let Some((repo, repo_path)) = self - .repo_snapshots - .values() - .filter_map(|repo_snapshot| { - let repo_path = repo_snapshot.abs_path_to_repo_path(&abs_path)?; - Some((repo_snapshot, repo_path)) - }) - .max_by_key(|(repo, _)| repo.work_directory_abs_path.clone()) - else { + let Some((repo, repo_path)) = self.repo_root_for_path(&abs_path) else { self.repo_location = None; return; }; diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index a2f6de44c9..c458b6b300 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -2154,6 +2154,16 @@ impl LspCommand for GetHover { } } +impl GetCompletions { + pub fn can_resolve_completions(capabilities: &lsp::ServerCapabilities) -> bool { + capabilities + .completion_provider + .as_ref() + .and_then(|options| options.resolve_provider) + .unwrap_or(false) + } +} + #[async_trait(?Send)] impl LspCommand for GetCompletions { type Response = CoreCompletionResponse; @@ -2269,7 +2279,7 @@ impl LspCommand for GetCompletions { // the range based on the syntax tree. None => { if self.position != clipped_position { - log::info!("completion out of expected range"); + log::info!("completion out of expected range "); return false; } @@ -2483,7 +2493,9 @@ pub(crate) fn parse_completion_text_edit( let start = snapshot.clip_point_utf16(range.start, Bias::Left); let end = snapshot.clip_point_utf16(range.end, Bias::Left); if start != range.start.0 || end != range.end.0 { - log::info!("completion out of expected range"); + log::info!( + "completion out of expected range, start: {start:?}, end: {end:?}, range: {range:?}" + ); return None; } snapshot.anchor_before(start)..snapshot.anchor_after(end) @@ -2760,6 +2772,23 @@ impl GetCodeActions { } } +impl OnTypeFormatting { + pub fn supports_on_type_formatting(trigger: &str, capabilities: &ServerCapabilities) -> bool { + let Some(on_type_formatting_options) = &capabilities.document_on_type_formatting_provider + else { + return false; + }; + on_type_formatting_options + .first_trigger_character + .contains(trigger) + || on_type_formatting_options + .more_trigger_character + .iter() + .flatten() + .any(|chars| chars.contains(trigger)) + } +} + #[async_trait(?Send)] impl LspCommand for OnTypeFormatting { type Response = Option<Transaction>; @@ -2771,20 +2800,7 @@ impl LspCommand for OnTypeFormatting { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(on_type_formatting_options) = &capabilities - .server_capabilities - .document_on_type_formatting_provider - else { - return false; - }; - on_type_formatting_options - .first_trigger_character - .contains(&self.trigger) - || on_type_formatting_options - .more_trigger_character - .iter() - .flatten() - .any(|chars| chars.contains(&self.trigger)) + Self::supports_on_type_formatting(&self.trigger, &capabilities.server_capabilities) } fn to_lsp( @@ -3268,6 +3284,16 @@ impl InlayHints { }) .unwrap_or(false) } + + pub 
fn check_capabilities(capabilities: &ServerCapabilities) -> bool { + capabilities + .inlay_hint_provider + .as_ref() + .is_some_and(|inlay_hint_provider| match inlay_hint_provider { + lsp::OneOf::Left(enabled) => *enabled, + lsp::OneOf::Right(_) => true, + }) + } } #[async_trait(?Send)] @@ -3281,17 +3307,7 @@ impl LspCommand for InlayHints { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(inlay_hint_provider) = &capabilities.server_capabilities.inlay_hint_provider - else { - return false; - }; - match inlay_hint_provider { - lsp::OneOf::Left(enabled) => *enabled, - lsp::OneOf::Right(inlay_hint_capabilities) => match inlay_hint_capabilities { - lsp::InlayHintServerCapabilities::Options(_) => true, - lsp::InlayHintServerCapabilities::RegistrationOptions(_) => false, - }, - } + Self::check_capabilities(&capabilities.server_capabilities) } fn to_lsp( @@ -3578,6 +3594,18 @@ impl LspCommand for GetCodeLens { } } +impl LinkedEditingRange { + pub fn check_server_capabilities(capabilities: ServerCapabilities) -> bool { + let Some(linked_editing_options) = capabilities.linked_editing_range_provider else { + return false; + }; + if let LinkedEditingRangeServerCapabilities::Simple(false) = linked_editing_options { + return false; + } + true + } +} + #[async_trait(?Send)] impl LspCommand for LinkedEditingRange { type Response = Vec<Range<Anchor>>; @@ -3589,16 +3617,7 @@ impl LspCommand for LinkedEditingRange { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(linked_editing_options) = &capabilities - .server_capabilities - .linked_editing_range_provider - else { - return false; - }; - if let LinkedEditingRangeServerCapabilities::Simple(false) = linked_editing_options { - return false; - } - true + Self::check_server_capabilities(capabilities.server_capabilities) } fn to_lsp( @@ -4216,8 +4235,9 @@ impl LspCommand for GetDocumentColor { server_capabilities .server_capabilities .color_provider + .as_ref() .is_some_and(|capability| match capability { - lsp::ColorProviderCapability::Simple(supported) => supported, + lsp::ColorProviderCapability::Simple(supported) => *supported, lsp::ColorProviderCapability::ColorProvider(..) => true, lsp::ColorProviderCapability::Options(..) 
=> true, }) diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 0cd375e0c5..4489f9f043 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -46,6 +46,7 @@ use language::{ DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, Diff, File as _, Language, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile, LspAdapter, LspAdapterDelegate, Patch, PointUtf16, TextBufferSnapshot, ToOffset, ToPointUtf16, Transaction, Unclipped, + WorkspaceFoldersContent, language_settings::{ FormatOnSave, Formatter, LanguageSettings, SelectedFormatter, language_settings, }, @@ -57,12 +58,12 @@ use language::{ range_from_lsp, range_to_lsp, }; use lsp::{ - CodeActionKind, CompletionContext, DiagnosticSeverity, DiagnosticTag, - DidChangeWatchedFilesRegistrationOptions, Edit, FileOperationFilter, FileOperationPatternKind, - FileOperationRegistrationOptions, FileRename, FileSystemWatcher, LanguageServer, - LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerId, LanguageServerName, - LanguageServerSelector, LspRequestFuture, MessageActionItem, MessageType, OneOf, - RenameFilesParams, SymbolKind, TextEdit, WillRenameFiles, WorkDoneProgressCancelParams, + AdapterServerCapabilities, CodeActionKind, CompletionContext, DiagnosticSeverity, + DiagnosticTag, DidChangeWatchedFilesRegistrationOptions, Edit, FileOperationFilter, + FileOperationPatternKind, FileOperationRegistrationOptions, FileRename, FileSystemWatcher, + LanguageServer, LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerId, + LanguageServerName, LanguageServerSelector, LspRequestFuture, MessageActionItem, MessageType, + OneOf, RenameFilesParams, SymbolKind, TextEdit, WillRenameFiles, WorkDoneProgressCancelParams, WorkspaceFolder, notification::DidRenameFiles, }; use node_runtime::read_package_installed_version; @@ -95,6 +96,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; +use sum_tree::Dimensions; use text::{Anchor, BufferId, LineEnding, OffsetRangeExt}; use url::Url; use util::{ @@ -217,6 +219,7 @@ impl LocalLspStore { let binary = self.get_language_server_binary(adapter.clone(), delegate.clone(), true, cx); let pending_workspace_folders: Arc<Mutex<BTreeSet<Url>>> = Default::default(); + let pending_server = cx.spawn({ let adapter = adapter.clone(); let server_name = adapter.name.clone(); @@ -242,14 +245,18 @@ impl LocalLspStore { return Ok(server); } + let code_action_kinds = adapter.code_action_kinds(); lsp::LanguageServer::new( stderr_capture, server_id, server_name, binary, &root_path, - adapter.code_action_kinds(), - pending_workspace_folders, + code_action_kinds, + Some(pending_workspace_folders).filter(|_| { + adapter.adapter.workspace_folders_content() + == WorkspaceFoldersContent::SubprojectRoots + }), cx, ) } @@ -418,7 +425,7 @@ impl LocalLspStore { if settings.as_ref().is_some_and(|b| b.path.is_some()) { let settings = settings.unwrap(); - return cx.spawn(async move |_| { + return cx.background_spawn(async move { let mut env = delegate.shell_env().await; env.extend(settings.env.unwrap_or_default()); @@ -575,8 +582,7 @@ impl LocalLspStore { }; let root = server.workspace_folders(); Ok(Some( - root.iter() - .cloned() + root.into_iter() .map(|uri| WorkspaceFolder { uri, name: Default::default(), @@ -616,7 +622,7 @@ impl LocalLspStore { .on_request::<lsp::request::RegisterCapability, _, _>({ let this = this.clone(); move |params, cx| { - let this = this.clone(); + let lsp_store = this.clone(); let mut cx = cx.clone(); async move { for reg in 
params.registrations { @@ -624,7 +630,7 @@ impl LocalLspStore { "workspace/didChangeWatchedFiles" => { if let Some(options) = reg.register_options { let options = serde_json::from_value(options)?; - this.update(&mut cx, |this, cx| { + lsp_store.update(&mut cx, |this, cx| { this.as_local_mut()?.on_lsp_did_change_watched_files( server_id, ®.id, options, cx, ); @@ -633,8 +639,9 @@ impl LocalLspStore { } } "textDocument/rangeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -653,14 +660,16 @@ impl LocalLspStore { server.update_capabilities(|capabilities| { capabilities.document_range_formatting_provider = Some(provider); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; } "textDocument/onTypeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -677,15 +686,17 @@ impl LocalLspStore { capabilities .document_on_type_formatting_provider = Some(options); - }) + }); + notify_server_capabilities_updated(&server, cx); } } anyhow::Ok(()) })??; } "textDocument/formatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -704,7 +715,8 @@ impl LocalLspStore { server.update_capabilities(|capabilities| { capabilities.document_formatting_provider = Some(provider); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; @@ -713,8 +725,9 @@ impl LocalLspStore { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. } "textDocument/rename" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { let options = reg .register_options @@ -731,7 +744,8 @@ impl LocalLspStore { server.update_capabilities(|capabilities| { capabilities.rename_provider = Some(options); - }) + }); + notify_server_capabilities_updated(&server, cx); } anyhow::Ok(()) })??; @@ -749,14 +763,15 @@ impl LocalLspStore { .on_request::<lsp::request::UnregisterCapability, _, _>({ let this = this.clone(); move |params, cx| { - let this = this.clone(); + let lsp_store = this.clone(); let mut cx = cx.clone(); async move { for unreg in params.unregisterations.iter() { match unreg.method.as_str() { "workspace/didChangeWatchedFiles" => { - this.update(&mut cx, |this, cx| { - this.as_local_mut()? + lsp_store.update(&mut cx, |lsp_store, cx| { + lsp_store + .as_local_mut()? .on_lsp_unregister_did_change_watched_files( server_id, &unreg.id, cx, ); @@ -767,44 +782,52 @@ impl LocalLspStore { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. 
} "textDocument/rename" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.rename_provider = None - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/rangeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_range_formatting_provider = None - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/onTypeFormatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_on_type_formatting_provider = None; - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } "textDocument/formatting" => { - this.read_with(&mut cx, |this, _| { - if let Some(server) = this.language_server_for_id(server_id) + lsp_store.update(&mut cx, |lsp_store, cx| { + if let Some(server) = + lsp_store.language_server_for_id(server_id) { server.update_capabilities(|capabilities| { capabilities.document_formatting_provider = None; - }) + }); + notify_server_capabilities_updated(&server, cx); } })?; } @@ -2420,36 +2443,11 @@ impl LocalLspStore { let server_id = server_node.server_id_or_init( |LaunchDisposition { server_name, - attach, path, settings, }| { - let server_id = match attach { - language::Attach::InstancePerRoot => { - // todo: handle instance per root proper. 
- if let Some(server_ids) = self - .language_server_ids - .get(&(worktree_id, server_name.clone())) - { - server_ids.iter().cloned().next().unwrap() - } else { - let language_name = language.name(); - let adapter = self.languages - .lsp_adapters(&language_name) - .into_iter() - .find(|adapter| &adapter.name() == server_name) - .expect("To find LSP adapter"); - let server_id = self.start_language_server( - &worktree, - delegate.clone(), - adapter, - settings, - cx, - ); - server_id - } - } - language::Attach::Shared => { + let server_id = + { let uri = Url::from_file_path( worktree.read(cx).abs_path().join(&path.path), ); @@ -2484,20 +2482,8 @@ impl LocalLspStore { } else { unreachable!("Language server ID should be available, as it's registered on demand") } - } + }; - let lsp_store = self.weak.clone(); - let server_name = server_node.name(); - let buffer_abs_path = abs_path.to_string_lossy().to_string(); - cx.defer(move |cx| { - lsp_store.update(cx, |_, cx| cx.emit(LspStoreEvent::LanguageServerUpdate { - language_server_id: server_id, - name: server_name, - message: proto::update_language_server::Variant::RegisteredForBuffer(proto::RegisteredForBuffer { - buffer_abs_path, - }) - })).ok(); - }); server_id }, )?; @@ -2533,11 +2519,13 @@ impl LocalLspStore { snapshot: initial_snapshot.clone(), }; + let mut registered = false; self.buffer_snapshots .entry(buffer_id) .or_default() .entry(server.server_id()) .or_insert_with(|| { + registered = true; server.register_buffer( uri.clone(), adapter.language_id(&language.name()), @@ -2552,15 +2540,18 @@ impl LocalLspStore { .entry(buffer_id) .or_default() .insert(server.server_id()); - cx.emit(LspStoreEvent::LanguageServerUpdate { - language_server_id: server.server_id(), - name: None, - message: proto::update_language_server::Variant::RegisteredForBuffer( - proto::RegisteredForBuffer { - buffer_abs_path: abs_path.to_string_lossy().to_string(), - }, - ), - }); + if registered { + cx.emit(LspStoreEvent::LanguageServerUpdate { + language_server_id: server.server_id(), + name: None, + message: proto::update_language_server::Variant::RegisteredForBuffer( + proto::RegisteredForBuffer { + buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), + }, + ), + }); + } } } @@ -3512,6 +3503,20 @@ impl LocalLspStore { } } +fn notify_server_capabilities_updated(server: &LanguageServer, cx: &mut Context<LspStore>) { + if let Some(capabilities) = serde_json::to_string(&server.capabilities()).ok() { + cx.emit(LspStoreEvent::LanguageServerUpdate { + language_server_id: server.server_id(), + name: Some(server.name()), + message: proto::update_language_server::Variant::MetadataUpdated( + proto::ServerMetadataUpdated { + capabilities: Some(capabilities), + }, + ), + }); + } +} + #[derive(Debug)] pub struct FormattableBuffer { handle: Entity<Buffer>, @@ -3551,7 +3556,9 @@ pub struct LspStore { _maintain_buffer_languages: Task<()>, diagnostic_summaries: HashMap<WorktreeId, HashMap<Arc<Path>, HashMap<LanguageServerId, DiagnosticSummary>>>, - lsp_data: HashMap<BufferId, DocumentColorData>, + pub(super) lsp_server_capabilities: HashMap<LanguageServerId, lsp::ServerCapabilities>, + lsp_document_colors: HashMap<BufferId, DocumentColorData>, + lsp_code_lens: HashMap<BufferId, CodeLensData>, } #[derive(Debug, Default, Clone)] @@ -3561,6 +3568,7 @@ pub struct DocumentColors { } type DocumentColorTask = Shared<Task<std::result::Result<DocumentColors, Arc<anyhow::Error>>>>; +type CodeLensTask = Shared<Task<std::result::Result<Vec<CodeAction>, 
Arc<anyhow::Error>>>>; #[derive(Debug, Default)] struct DocumentColorData { @@ -3570,8 +3578,15 @@ struct DocumentColorData { colors_update: Option<(Global, DocumentColorTask)>, } +#[derive(Debug, Default)] +struct CodeLensData { + lens_for_version: Global, + lens: HashMap<LanguageServerId, Vec<CodeAction>>, + update: Option<(Global, CodeLensTask)>, +} + #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum ColorFetchStrategy { +pub enum LspFetchStrategy { IgnoreCache, UseCache { known_cache_version: Option<usize> }, } @@ -3613,7 +3628,7 @@ pub enum LspStoreEvent { #[derive(Clone, Debug, Serialize)] pub struct LanguageServerStatus { - pub name: String, + pub name: LanguageServerName, pub pending_work: BTreeMap<String, LanguageServerProgress>, pub has_pending_diagnostic_updates: bool, progress_tokens: HashSet<String>, @@ -3656,7 +3671,6 @@ impl LspStore { client.add_entity_request_handler(Self::handle_apply_additional_edits_for_completion); client.add_entity_request_handler(Self::handle_register_buffer_with_language_servers); client.add_entity_request_handler(Self::handle_rename_project_entry); - client.add_entity_request_handler(Self::handle_language_server_id_for_name); client.add_entity_request_handler(Self::handle_pull_workspace_diagnostics); client.add_entity_request_handler(Self::handle_lsp_command::<GetCodeActions>); client.add_entity_request_handler(Self::handle_lsp_command::<GetCompletions>); @@ -3804,7 +3818,9 @@ impl LspStore { language_server_statuses: Default::default(), nonce: StdRng::from_entropy().r#gen(), diagnostic_summaries: HashMap::default(), - lsp_data: HashMap::default(), + lsp_server_capabilities: HashMap::default(), + lsp_document_colors: HashMap::default(), + lsp_code_lens: HashMap::default(), active_entry: None, _maintain_workspace_config, _maintain_buffer_languages: Self::maintain_buffer_languages(languages, cx), @@ -3819,6 +3835,9 @@ impl LspStore { request: R, cx: &mut Context<LspStore>, ) -> Task<anyhow::Result<<R as LspCommand>::Response>> { + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(R::Response::default())); + } let message = request.to_proto(upstream_project_id, buffer.read(cx)); cx.spawn(async move |this, cx| { let response = client.request(message).await?; @@ -3861,7 +3880,9 @@ impl LspStore { language_server_statuses: Default::default(), nonce: StdRng::from_entropy().r#gen(), diagnostic_summaries: HashMap::default(), - lsp_data: HashMap::default(), + lsp_server_capabilities: HashMap::default(), + lsp_document_colors: HashMap::default(), + lsp_code_lens: HashMap::default(), active_entry: None, toolchain_store, _maintain_workspace_config, @@ -4162,7 +4183,8 @@ impl LspStore { *refcount }; if refcount == 0 { - lsp_store.lsp_data.remove(&buffer_id); + lsp_store.lsp_document_colors.remove(&buffer_id); + lsp_store.lsp_code_lens.remove(&buffer_id); let local = lsp_store.as_local_mut().unwrap(); local.registered_buffers.remove(&buffer_id); local.buffers_opened_in_servers.remove(&buffer_id); @@ -4434,20 +4456,73 @@ impl LspStore { } } - pub fn request_lsp<R: LspCommand>( + // TODO: remove MultiLspQuery: instead, the proto handler should pick appropriate server(s) + // Then, use `send_lsp_proto_request` or analogue for most of the LSP proto requests and inline this check inside + fn is_capable_for_proto_request<R>( + &self, + buffer: &Entity<Buffer>, + request: &R, + cx: &Context<Self>, + ) -> bool + where + R: LspCommand, + { + self.check_if_capable_for_proto_request( + buffer, + |capabilities| { + 
request.check_capabilities(AdapterServerCapabilities { + server_capabilities: capabilities.clone(), + code_action_kinds: None, + }) + }, + cx, + ) + } + + fn check_if_capable_for_proto_request<F>( + &self, + buffer: &Entity<Buffer>, + check: F, + cx: &Context<Self>, + ) -> bool + where + F: Fn(&lsp::ServerCapabilities) -> bool, + { + let Some(language) = buffer.read(cx).language().cloned() else { + return false; + }; + let relevant_language_servers = self + .languages + .lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::<HashSet<_>>(); + self.language_server_statuses + .iter() + .filter_map(|(server_id, server_status)| { + relevant_language_servers + .contains(&server_status.name) + .then_some(server_id) + }) + .filter_map(|server_id| self.lsp_server_capabilities.get(&server_id)) + .any(check) + } + + pub fn request_lsp<R>( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, server: LanguageServerToQuery, request: R, cx: &mut Context<Self>, ) -> Task<Result<R::Response>> where + R: LspCommand, <R::LspRequest as lsp::request::Request>::Result: Send, <R::LspRequest as lsp::request::Request>::Params: Send, { if let Some((upstream_client, upstream_project_id)) = self.upstream_client() { return self.send_lsp_proto_request( - buffer_handle, + buffer, upstream_client, upstream_project_id, request, @@ -4455,7 +4530,7 @@ impl LspStore { ); } - let Some(language_server) = buffer_handle.update(cx, |buffer, cx| match server { + let Some(language_server) = buffer.update(cx, |buffer, cx| match server { LanguageServerToQuery::FirstCapable => self.as_local().and_then(|local| { local .language_servers_for_buffer(buffer, cx) @@ -4475,8 +4550,7 @@ impl LspStore { return Task::ready(Ok(Default::default())); }; - let buffer = buffer_handle.read(cx); - let file = File::from_dyn(buffer.file()).and_then(File::as_local); + let file = File::from_dyn(buffer.read(cx).file()).and_then(File::as_local); let Some(file) = file else { return Task::ready(Ok(Default::default())); @@ -4484,7 +4558,7 @@ impl LspStore { let lsp_params = match request.to_lsp_params_or_response( &file.abs_path(cx), - buffer, + buffer.read(cx), &language_server, cx, ) { @@ -4560,7 +4634,7 @@ impl LspStore { .response_from_lsp( response, this.upgrade().context("no app context")?, - buffer_handle, + buffer, language_server.server_id(), cx.clone(), ) @@ -4630,7 +4704,8 @@ impl LspStore { ) }) { let buffer = buffer_handle.read(cx); - if !local.registered_buffers.contains_key(&buffer.remote_id()) { + let buffer_id = buffer.remote_id(); + if !local.registered_buffers.contains_key(&buffer_id) { continue; } if let Some((file, language)) = File::from_dyn(buffer.file()) @@ -4688,35 +4763,11 @@ impl LspStore { let server_id = node.server_id_or_init( |LaunchDisposition { server_name, - attach, + path, settings, - }| match attach { - language::Attach::InstancePerRoot => { - // todo: handle instance per root proper. 
- if let Some(server_ids) = local - .language_server_ids - .get(&(worktree_id, server_name.clone())) - { - server_ids.iter().cloned().next().unwrap() - } else { - let adapter = local - .languages - .lsp_adapters(&language) - .into_iter() - .find(|adapter| &adapter.name() == server_name) - .expect("To find LSP adapter"); - let server_id = local.start_language_server( - &worktree, - delegate.clone(), - adapter, - settings, - cx, - ); - server_id - } - } - language::Attach::Shared => { + }| + { let uri = Url::from_file_path( worktree.read(cx).abs_path().join(&path.path), ); @@ -4745,7 +4796,6 @@ impl LspStore { } server_id } - }, ); if let Some(language_server_id) = server_id { @@ -4756,6 +4806,7 @@ impl LspStore { proto::update_language_server::Variant::RegisteredForBuffer( proto::RegisteredForBuffer { buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), }, ), }); @@ -4931,19 +4982,24 @@ impl LspStore { pub fn resolve_inlay_hint( &self, - hint: InlayHint, - buffer_handle: Entity<Buffer>, + mut hint: InlayHint, + buffer: Entity<Buffer>, server_id: LanguageServerId, cx: &mut Context<Self>, ) -> Task<anyhow::Result<InlayHint>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + if !self.check_if_capable_for_proto_request(&buffer, InlayHints::can_resolve_inlays, cx) + { + hint.resolve_state = ResolveState::Resolved; + return Task::ready(Ok(hint)); + } let request = proto::ResolveInlayHint { project_id, - buffer_id: buffer_handle.read(cx).remote_id().into(), + buffer_id: buffer.read(cx).remote_id().into(), language_server_id: server_id.0 as u64, hint: Some(InlayHints::project_to_proto_hint(hint.clone())), }; - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let response = upstream_client .request(request) .await @@ -4955,7 +5011,7 @@ impl LspStore { } }) } else { - let Some(lang_server) = buffer_handle.update(cx, |buffer, cx| { + let Some(lang_server) = buffer.update(cx, |buffer, cx| { self.language_server_for_local_buffer(buffer, server_id, cx) .map(|(_, server)| server.clone()) }) else { @@ -4964,7 +5020,7 @@ impl LspStore { if !InlayHints::can_resolve_inlays(&lang_server.capabilities()) { return Task::ready(Ok(hint)); } - let buffer_snapshot = buffer_handle.read(cx).snapshot(); + let buffer_snapshot = buffer.read(cx).snapshot(); cx.spawn(async move |_, cx| { let resolve_task = lang_server.request::<lsp::request::InlayHintResolveRequest>( InlayHints::project_to_lsp_hint(hint, &buffer_snapshot), @@ -4975,7 +5031,7 @@ impl LspStore { .context("inlay hint resolve LSP request")?; let resolved_hint = InlayHints::lsp_to_project_hint( resolved_hint, - &buffer_handle, + &buffer, server_id, ResolveState::Resolved, false, @@ -5086,7 +5142,7 @@ impl LspStore { } } - pub(crate) fn linked_edit( + pub(crate) fn linked_edits( &mut self, buffer: &Entity<Buffer>, position: Anchor, @@ -5101,10 +5157,7 @@ impl LspStore { local .language_servers_for_buffer(buffer, cx) .filter(|(_, server)| { - server - .capabilities() - .linked_editing_range_provider - .is_some() + LinkedEditingRange::check_server_capabilities(server.capabilities()) }) .filter(|(adapter, _)| { scope @@ -5131,7 +5184,7 @@ impl LspStore { }) == Some(true) }) else { - return Task::ready(Ok(vec![])); + return Task::ready(Ok(Vec::new())); }; self.request_lsp( @@ -5150,6 +5203,15 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Option<Transaction>>> { if let Some((client, project_id)) = self.upstream_client() { + if !self.check_if_capable_for_proto_request( + 
&buffer, + |capabilities| { + OnTypeFormatting::supports_on_type_formatting(&trigger, capabilities) + }, + cx, + ) { + return Task::ready(Ok(None)); + } let request = proto::OnTypeFormatting { project_id, buffer_id: buffer.read(cx).remote_id().into(), @@ -5157,7 +5219,7 @@ impl LspStore { trigger, version: serialize_version(&buffer.read(cx).version()), }; - cx.spawn(async move |_, _| { + cx.background_spawn(async move { client .request(request) .await? @@ -5261,6 +5323,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetDefinitions { position }; + if !self.is_capable_for_proto_request(buffer_handle, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { buffer_id: buffer_handle.read(cx).remote_id().into(), version: serialize_version(&buffer_handle.read(cx).version()), @@ -5269,7 +5335,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDefinition( - GetDefinitions { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer_handle.read(cx)), )), }); let buffer = buffer_handle.clone(); @@ -5316,7 +5382,7 @@ impl LspStore { GetDefinitions { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(definitions_task .await .into_iter() @@ -5334,6 +5400,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetDeclarations { position }; + if !self.is_capable_for_proto_request(buffer_handle, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { buffer_id: buffer_handle.read(cx).remote_id().into(), version: serialize_version(&buffer_handle.read(cx).version()), @@ -5342,7 +5412,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDeclaration( - GetDeclarations { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer_handle.read(cx)), )), }); let buffer = buffer_handle.clone(); @@ -5389,7 +5459,7 @@ impl LspStore { GetDeclarations { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(declarations_task .await .into_iter() @@ -5402,23 +5472,27 @@ impl LspStore { pub fn type_definitions( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetTypeDefinitions { position }; + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetTypeDefinition( - GetTypeDefinitions { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + 
let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5457,12 +5531,12 @@ impl LspStore { }) } else { let type_definitions_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetTypeDefinitions { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(type_definitions_task .await .into_iter() @@ -5475,23 +5549,27 @@ impl LspStore { pub fn implementations( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<LocationLink>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetImplementations { position }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetImplementation( - GetImplementations { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5530,12 +5608,12 @@ impl LspStore { }) } else { let implementations_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetImplementations { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(implementations_task .await .into_iter() @@ -5548,23 +5626,27 @@ impl LspStore { pub fn references( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, position: PointUtf16, cx: &mut Context<Self>, ) -> Task<Result<Vec<Location>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetReferences { position }; + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetReferences( - GetReferences { position }.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5603,12 +5685,12 @@ impl LspStore { }) } else { let references_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetReferences { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(references_task .await .into_iter() @@ -5621,28 +5703,31 @@ impl LspStore { pub fn code_actions( &mut self, - 
buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, range: Range<Anchor>, kinds: Option<Vec<CodeActionKind>>, cx: &mut Context<Self>, ) -> Task<Result<Vec<CodeAction>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCodeActions { + range: range.clone(), + kinds: kinds.clone(), + }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetCodeActions( - GetCodeActions { - range: range.clone(), - kinds: kinds.clone(), - } - .to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { return Ok(Vec::new()); @@ -5684,7 +5769,7 @@ impl LspStore { }) } else { let all_actions_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(range.start), GetCodeActions { range: range.clone(), @@ -5692,7 +5777,7 @@ impl LspStore { }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(all_actions_task .await .into_iter() @@ -5702,69 +5787,172 @@ impl LspStore { } } - pub fn code_lens( + pub fn code_lens_actions( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, cx: &mut Context<Self>, - ) -> Task<Result<Vec<CodeAction>>> { + ) -> CodeLensTask { + let version_queried_for = buffer.read(cx).version(); + let buffer_id = buffer.read(cx).remote_id(); + + if let Some(cached_data) = self.lsp_code_lens.get(&buffer_id) { + if !version_queried_for.changed_since(&cached_data.lens_for_version) { + let has_different_servers = self.as_local().is_some_and(|local| { + local + .buffers_opened_in_servers + .get(&buffer_id) + .cloned() + .unwrap_or_default() + != cached_data.lens.keys().copied().collect() + }); + if !has_different_servers { + return Task::ready(Ok(cached_data.lens.values().flatten().cloned().collect())) + .shared(); + } + } + } + + let lsp_data = self.lsp_code_lens.entry(buffer_id).or_default(); + if let Some((updating_for, running_update)) = &lsp_data.update { + if !version_queried_for.changed_since(&updating_for) { + return running_update.clone(); + } + } + let buffer = buffer.clone(); + let query_version_queried_for = version_queried_for.clone(); + let new_task = cx + .spawn(async move |lsp_store, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + let fetched_lens = lsp_store + .update(cx, |lsp_store, cx| lsp_store.fetch_code_lens(&buffer, cx)) + .map_err(Arc::new)? 
+ .await + .context("fetching code lens") + .map_err(Arc::new); + let fetched_lens = match fetched_lens { + Ok(fetched_lens) => fetched_lens, + Err(e) => { + lsp_store + .update(cx, |lsp_store, _| { + lsp_store.lsp_code_lens.entry(buffer_id).or_default().update = None; + }) + .ok(); + return Err(e); + } + }; + + lsp_store + .update(cx, |lsp_store, _| { + let lsp_data = lsp_store.lsp_code_lens.entry(buffer_id).or_default(); + if lsp_data.lens_for_version == query_version_queried_for { + lsp_data.lens.extend(fetched_lens.clone()); + } else if !lsp_data + .lens_for_version + .changed_since(&query_version_queried_for) + { + lsp_data.lens_for_version = query_version_queried_for; + lsp_data.lens = fetched_lens.clone(); + } + lsp_data.update = None; + lsp_data.lens.values().flatten().cloned().collect() + }) + .map_err(Arc::new) + }) + .shared(); + lsp_data.update = Some((version_queried_for, new_task.clone())); + new_task + } + + fn fetch_code_lens( + &mut self, + buffer: &Entity<Buffer>, + cx: &mut Context<Self>, + ) -> Task<Result<HashMap<LanguageServerId, Vec<CodeAction>>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCodeLens; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(HashMap::default())); + } let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + buffer_id: buffer.read(cx).remote_id().into(), + version: serialize_version(&buffer.read(cx).version()), project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetCodeLens( - GetCodeLens.to_proto(project_id, buffer_handle.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); - let buffer = buffer_handle.clone(); - cx.spawn(async move |weak_project, cx| { - let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + let buffer = buffer.clone(); + cx.spawn(async move |weak_lsp_store, cx| { + let Some(lsp_store) = weak_lsp_store.upgrade() else { + return Ok(HashMap::default()); }; let responses = request_task.await?.responses; - let code_lens = join_all( + let code_lens_actions = join_all( responses .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetCodeLensResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } + .filter_map(|lsp_response| { + let response = match lsp_response.response? { + proto::lsp_response::Response::GetCodeLensResponse(response) => { + Some(response) + } + unexpected => { + debug_panic!("Unexpected response: {unexpected:?}"); + None + } + }?; + let server_id = LanguageServerId::from_proto(lsp_response.server_id); + Some((server_id, response)) }) - .map(|code_lens_response| { - GetCodeLens.response_from_proto( - code_lens_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) + .map(|(server_id, code_lens_response)| { + let lsp_store = lsp_store.clone(); + let buffer = buffer.clone(); + let cx = cx.clone(); + async move { + ( + server_id, + GetCodeLens + .response_from_proto( + code_lens_response, + lsp_store, + buffer, + cx, + ) + .await, + ) + } }), ) .await; - Ok(code_lens + let mut has_errors = false; + let code_lens_actions = code_lens_actions .into_iter() - .collect::<Result<Vec<Vec<_>>>>()? 
- .into_iter() - .flatten() - .collect()) + .filter_map(|(server_id, code_lens)| match code_lens { + Ok(code_lens) => Some((server_id, code_lens)), + Err(e) => { + has_errors = true; + log::error!("{e:#}"); + None + } + }) + .collect::<HashMap<_, _>>(); + anyhow::ensure!( + !has_errors || !code_lens_actions.is_empty(), + "Failed to fetch code lens" + ); + Ok(code_lens_actions) }) } else { - let code_lens_task = - self.request_multiple_lsp_locally(buffer_handle, None::<usize>, GetCodeLens, cx); - cx.spawn(async move |_, _| { - Ok(code_lens_task - .await - .into_iter() - .flat_map(|(_, code_lens)| code_lens) - .collect()) - }) + let code_lens_actions_task = + self.request_multiple_lsp_locally(buffer, None::<usize>, GetCodeLens, cx); + cx.background_spawn( + async move { Ok(code_lens_actions_task.await.into_iter().collect()) }, + ) } } @@ -5779,11 +5967,15 @@ impl LspStore { let language_registry = self.languages.clone(); if let Some((upstream_client, project_id)) = self.upstream_client() { + let request = GetCompletions { position, context }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(Vec::new())); + } let task = self.send_lsp_proto_request( buffer.clone(), upstream_client, project_id, - GetCompletions { position, context }, + request, cx, ); let language = buffer.read(cx).language().cloned(); @@ -5921,11 +6113,17 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Result<bool>> { let client = self.upstream_client(); - let buffer_id = buffer.read(cx).remote_id(); let buffer_snapshot = buffer.read(cx).snapshot(); - cx.spawn(async move |this, cx| { + if !self.check_if_capable_for_proto_request( + &buffer, + GetCompletions::can_resolve_completions, + cx, + ) { + return Task::ready(Ok(false)); + } + cx.spawn(async move |lsp_store, cx| { let mut did_resolve = false; if let Some((client, project_id)) = client { for completion_index in completion_indices { @@ -5962,7 +6160,7 @@ impl LspStore { completion.source.server_id() }; if let Some(server_id) = server_id { - let server_and_adapter = this + let server_and_adapter = lsp_store .read_with(cx, |lsp_store, _| { let server = lsp_store.language_server_for_id(server_id)?; let adapter = @@ -5977,7 +6175,6 @@ impl LspStore { let resolved = Self::resolve_completion_local( server, - &buffer_snapshot, completions.clone(), completion_index, ) @@ -6010,18 +6207,11 @@ impl LspStore { async fn resolve_completion_local( server: Arc<lsp::LanguageServer>, - snapshot: &BufferSnapshot, completions: Rc<RefCell<Box<[Completion]>>>, completion_index: usize, ) -> Result<()> { let server_id = server.server_id(); - let can_resolve = server - .capabilities() - .completion_provider - .as_ref() - .and_then(|options| options.resolve_provider) - .unwrap_or(false); - if !can_resolve { + if !GetCompletions::can_resolve_completions(&server.capabilities()) { return Ok(()); } @@ -6055,26 +6245,8 @@ impl LspStore { .into_response() .context("resolve completion")?; - if let Some(text_edit) = resolved_completion.text_edit.as_ref() { - // Technically we don't have to parse the whole `text_edit`, since the only - // language server we currently use that does update `text_edit` in `completionItem/resolve` - // is `typescript-language-server` and they only update `text_edit.new_text`. - // But we should not rely on that. 
- let edit = parse_completion_text_edit(text_edit, snapshot); - - if let Some(mut parsed_edit) = edit { - LineEnding::normalize(&mut parsed_edit.new_text); - - let mut completions = completions.borrow_mut(); - let completion = &mut completions[completion_index]; - - completion.new_text = parsed_edit.new_text; - completion.replace_range = parsed_edit.replace_range; - if let CompletionSource::Lsp { insert_range, .. } = &mut completion.source { - *insert_range = parsed_edit.insert_range; - } - } - } + // We must not use any data such as sortText, filterText, insertText and textEdit to edit `Completion`, since they are not supposed to change during resolve. + // See: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_completion let mut completions = completions.borrow_mut(); let completion = &mut completions[completion_index]; @@ -6324,12 +6496,10 @@ impl LspStore { }) else { return Task::ready(Ok(None)); }; - let snapshot = buffer_handle.read(&cx).snapshot(); cx.spawn(async move |this, cx| { Self::resolve_completion_local( server.clone(), - &snapshot, completions.clone(), completion_index, ) @@ -6392,16 +6562,24 @@ impl LspStore { pub fn pull_diagnostics( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, cx: &mut Context<Self>, ) -> Task<Result<Vec<LspPullDiagnostics>>> { - let buffer = buffer_handle.read(cx); - let buffer_id = buffer.remote_id(); + let buffer_id = buffer.read(cx).remote_id(); if let Some((client, upstream_project_id)) = self.upstream_client() { + if !self.is_capable_for_proto_request( + &buffer, + &GetDocumentDiagnostics { + previous_result_id: None, + }, + cx, + ) { + return Task::ready(Ok(Vec::new())); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer_id.to_proto(), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), project_id: upstream_project_id, strategy: Some(proto::multi_lsp_query::Strategy::All( proto::AllLanguageServers {}, @@ -6410,7 +6588,7 @@ impl LspStore { proto::GetDocumentDiagnostics { project_id: upstream_project_id, buffer_id: buffer_id.to_proto(), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), }, )), }); @@ -6432,7 +6610,7 @@ impl LspStore { .collect()) }) } else { - let server_ids = buffer_handle.update(cx, |buffer, cx| { + let server_ids = buffer.update(cx, |buffer, cx| { self.language_servers_for_local_buffer(buffer, cx) .map(|(_, server)| server.server_id()) .collect::<Vec<_>>() @@ -6442,7 +6620,7 @@ impl LspStore { .map(|server_id| { let result_id = self.result_id(server_id, buffer_id, cx); self.request_lsp( - buffer_handle.clone(), + buffer.clone(), LanguageServerToQuery::Other(server_id), GetDocumentDiagnostics { previous_result_id: result_id, @@ -6464,34 +6642,36 @@ impl LspStore { pub fn inlay_hints( &mut self, - buffer_handle: Entity<Buffer>, + buffer: Entity<Buffer>, range: Range<Anchor>, cx: &mut Context<Self>, ) -> Task<anyhow::Result<Vec<InlayHint>>> { - let buffer = buffer_handle.read(cx); let range_start = range.start; let range_end = range.end; - let buffer_id = buffer.remote_id().into(); - let lsp_request = InlayHints { range }; + let buffer_id = buffer.read(cx).remote_id().into(); + let request = InlayHints { range }; if let Some((client, project_id)) = self.upstream_client() { - let request = proto::InlayHints { + if !self.is_capable_for_proto_request(&buffer, &request, cx) { + return
Task::ready(Ok(Vec::new())); + } + let proto_request = proto::InlayHints { project_id, buffer_id, start: Some(serialize_anchor(&range_start)), end: Some(serialize_anchor(&range_end)), - version: serialize_version(&buffer_handle.read(cx).version()), + version: serialize_version(&buffer.read(cx).version()), }; cx.spawn(async move |project, cx| { let response = client - .request(request) + .request(proto_request) .await .context("inlay hints proto request")?; LspCommand::response_from_proto( - lsp_request, + request, response, project.upgrade().context("No project")?, - buffer_handle.clone(), + buffer.clone(), cx.clone(), ) .await @@ -6499,13 +6679,13 @@ impl LspStore { }) } else { let lsp_request_task = self.request_lsp( - buffer_handle.clone(), + buffer.clone(), LanguageServerToQuery::FirstCapable, - lsp_request, + request, cx, ); cx.spawn(async move |_, cx| { - buffer_handle + buffer .update(cx, |buffer, _| { buffer.wait_for_edits(vec![range_start.timestamp, range_end.timestamp]) })? @@ -6597,7 +6777,7 @@ impl LspStore { pub fn document_colors( &mut self, - fetch_strategy: ColorFetchStrategy, + fetch_strategy: LspFetchStrategy, buffer: Entity<Buffer>, cx: &mut Context<Self>, ) -> Option<DocumentColorTask> { @@ -6605,11 +6785,11 @@ impl LspStore { let buffer_id = buffer.read(cx).remote_id(); match fetch_strategy { - ColorFetchStrategy::IgnoreCache => {} - ColorFetchStrategy::UseCache { + LspFetchStrategy::IgnoreCache => {} + LspFetchStrategy::UseCache { known_cache_version, } => { - if let Some(cached_data) = self.lsp_data.get(&buffer_id) { + if let Some(cached_data) = self.lsp_document_colors.get(&buffer_id) { if !version_queried_for.changed_since(&cached_data.colors_for_version) { let has_different_servers = self.as_local().is_some_and(|local| { local @@ -6642,7 +6822,7 @@ impl LspStore { } } - let lsp_data = self.lsp_data.entry(buffer_id).or_default(); + let lsp_data = self.lsp_document_colors.entry(buffer_id).or_default(); if let Some((updating_for, running_update)) = &lsp_data.colors_update { if !version_queried_for.changed_since(&updating_for) { return Some(running_update.clone()); @@ -6656,14 +6836,14 @@ impl LspStore { .await; let fetched_colors = lsp_store .update(cx, |lsp_store, cx| { - lsp_store.fetch_document_colors_for_buffer(buffer.clone(), cx) + lsp_store.fetch_document_colors_for_buffer(&buffer, cx) })? 
.await .context("fetching document colors") .map_err(Arc::new); let fetched_colors = match fetched_colors { Ok(fetched_colors) => { - if fetch_strategy != ColorFetchStrategy::IgnoreCache + if fetch_strategy != LspFetchStrategy::IgnoreCache && Some(true) == buffer .update(cx, |buffer, _| { @@ -6679,7 +6859,7 @@ impl LspStore { lsp_store .update(cx, |lsp_store, _| { lsp_store - .lsp_data + .lsp_document_colors .entry(buffer_id) .or_default() .colors_update = None; @@ -6691,7 +6871,7 @@ impl LspStore { lsp_store .update(cx, |lsp_store, _| { - let lsp_data = lsp_store.lsp_data.entry(buffer_id).or_default(); + let lsp_data = lsp_store.lsp_document_colors.entry(buffer_id).or_default(); if lsp_data.colors_for_version == query_version_queried_for { lsp_data.colors.extend(fetched_colors.clone()); @@ -6725,10 +6905,15 @@ impl LspStore { fn fetch_document_colors_for_buffer( &mut self, - buffer: Entity<Buffer>, + buffer: &Entity<Buffer>, cx: &mut Context<Self>, ) -> Task<anyhow::Result<HashMap<LanguageServerId, HashSet<DocumentColor>>>> { if let Some((client, project_id)) = self.upstream_client() { + let request = GetDocumentColor {}; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(HashMap::default())); + } + let request_task = client.request(proto::MultiLspQuery { project_id, buffer_id: buffer.read(cx).remote_id().to_proto(), @@ -6737,9 +6922,10 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetDocumentColor( - GetDocumentColor {}.to_proto(project_id, buffer.read(cx)), + request.to_proto(project_id, buffer.read(cx)), )), }); + let buffer = buffer.clone(); cx.spawn(async move |project, cx| { let Some(project) = project.upgrade() else { return Ok(HashMap::default()); @@ -6764,7 +6950,7 @@ impl LspStore { } }) .map(|(server_id, color_response)| { - let response = GetDocumentColor {}.response_from_proto( + let response = request.response_from_proto( color_response, project.clone(), buffer.clone(), @@ -6785,8 +6971,8 @@ impl LspStore { }) } else { let document_colors_task = - self.request_multiple_lsp_locally(&buffer, None::<usize>, GetDocumentColor, cx); - cx.spawn(async move |_, _| { + self.request_multiple_lsp_locally(buffer, None::<usize>, GetDocumentColor, cx); + cx.background_spawn(async move { Ok(document_colors_task .await .into_iter() @@ -6811,6 +6997,10 @@ impl LspStore { let position = position.to_point_utf16(buffer.read(cx)); if let Some((client, upstream_project_id)) = self.upstream_client() { + let request = GetSignatureHelp { position }; + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Vec::new()); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer.read(cx).remote_id().into(), version: serialize_version(&buffer.read(cx).version()), @@ -6819,7 +7009,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetSignatureHelp( - GetSignatureHelp { position }.to_proto(upstream_project_id, buffer.read(cx)), + request.to_proto(upstream_project_id, buffer.read(cx)), )), }); let buffer = buffer.clone(); @@ -6865,7 +7055,7 @@ impl LspStore { GetSignatureHelp { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { all_actions_task .await .into_iter() @@ -6882,6 +7072,10 @@ impl LspStore { cx: &mut Context<Self>, ) -> Task<Vec<Hover>> { if let Some((client, upstream_project_id)) = self.upstream_client() { + let request = GetHover { position }; + if 
!self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Vec::new()); + } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer.read(cx).remote_id().into(), version: serialize_version(&buffer.read(cx).version()), @@ -6890,7 +7084,7 @@ impl LspStore { proto::AllLanguageServers {}, )), request: Some(proto::multi_lsp_query::Request::GetHover( - GetHover { position }.to_proto(upstream_project_id, buffer.read(cx)), + request.to_proto(upstream_project_id, buffer.read(cx)), )), }); let buffer = buffer.clone(); @@ -6942,7 +7136,7 @@ impl LspStore { GetHover { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { all_actions_task .await .into_iter() @@ -7210,7 +7404,9 @@ impl LspStore { let build_incremental_change = || { buffer - .edits_since::<(PointUtf16, usize)>(previous_snapshot.snapshot.version()) + .edits_since::<Dimensions<PointUtf16, usize>>( + previous_snapshot.snapshot.version(), + ) .map(|edit| { let edit_start = edit.new.start.0; let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); @@ -7325,21 +7521,23 @@ impl LspStore { } pub(crate) async fn refresh_workspace_configurations( - this: &WeakEntity<Self>, + lsp_store: &WeakEntity<Self>, fs: Arc<dyn Fs>, cx: &mut AsyncApp, ) { maybe!(async move { - let servers = this - .update(cx, |this, cx| { - let Some(local) = this.as_local() else { + let mut refreshed_servers = HashSet::default(); + let servers = lsp_store + .update(cx, |lsp_store, cx| { + let toolchain_store = lsp_store.toolchain_store(cx); + let Some(local) = lsp_store.as_local() else { return Vec::default(); }; local .language_server_ids .iter() .flat_map(|((worktree_id, _), server_ids)| { - let worktree = this + let worktree = lsp_store .worktree_store .read(cx) .worktree_for_id(*worktree_id, cx); @@ -7355,43 +7553,54 @@ impl LspStore { ) }); - server_ids.iter().filter_map(move |server_id| { + let fs = fs.clone(); + let toolchain_store = toolchain_store.clone(); + server_ids.iter().filter_map(|server_id| { + let delegate = delegate.clone()? as Arc<dyn LspAdapterDelegate>; let states = local.language_servers.get(server_id)?; match states { LanguageServerState::Starting { .. } => None, LanguageServerState::Running { adapter, server, .. - } => Some(( - adapter.adapter.clone(), - server.clone(), - delegate.clone()? 
as Arc<dyn LspAdapterDelegate>, - )), + } => { + let fs = fs.clone(); + let toolchain_store = toolchain_store.clone(); + let adapter = adapter.clone(); + let server = server.clone(); + refreshed_servers.insert(server.name()); + Some(cx.spawn(async move |_, cx| { + let settings = + LocalLspStore::workspace_configuration_for_adapter( + adapter.adapter.clone(), + fs.as_ref(), + &delegate, + toolchain_store, + cx, + ) + .await + .ok()?; + server + .notify::<lsp::notification::DidChangeConfiguration>( + &lsp::DidChangeConfigurationParams { settings }, + ) + .ok()?; + Some(()) + })) + } } - }) + }).collect::<Vec<_>>() }) .collect::<Vec<_>>() }) .ok()?; - let toolchain_store = this.update(cx, |this, cx| this.toolchain_store(cx)).ok()?; - for (adapter, server, delegate) in servers { - let settings = LocalLspStore::workspace_configuration_for_adapter( - adapter, - fs.as_ref(), - &delegate, - toolchain_store.clone(), - cx, - ) - .await - .ok()?; - - server - .notify::<lsp::notification::DidChangeConfiguration>( - &lsp::DidChangeConfigurationParams { settings }, - ) - .ok(); - } + log::info!("Refreshing workspace configurations for servers {refreshed_servers:?}"); + // TODO: this asynchronous job runs concurrently with extension (de)registration and may take enough time for a certain extension + // to stop and unregister its language server wrapper. + // This is racy: an extension might have already removed all `local.language_servers` state, but here we `.clone()` and hold onto it anyway. + // This now causes errors in the logs; we should find a way to remove such servers from processing everywhere. + let _: Vec<Option<()>> = join_all(servers).await; Some(()) }) .await; @@ -7480,16 +7689,20 @@ impl LspStore { self.downstream_client = Some((downstream_client.clone(), project_id)); for (server_id, status) in &self.language_server_statuses { - downstream_client - .send(proto::StartLanguageServer { - project_id, - server: Some(proto::LanguageServer { - id: server_id.0 as u64, - name: status.name.clone(), - worktree_id: None, - }), - }) - .log_err(); + if let Some(server) = self.language_server_for_id(*server_id) { + downstream_client + .send(proto::StartLanguageServer { + project_id, + server: Some(proto::LanguageServer { + id: server_id.to_proto(), + name: status.name.to_string(), + worktree_id: None, + }), + capabilities: serde_json::to_string(&server.capabilities()) + .expect("serializing server LSP capabilities"), + }) + .log_err(); + } } } @@ -7516,7 +7729,7 @@ impl LspStore { ( LanguageServerId(server.id as usize), LanguageServerStatus { - name: server.name, + name: LanguageServerName::from_proto(server.name), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -7932,7 +8145,7 @@ impl LspStore { }) .collect::<FuturesUnordered<_>>(); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let mut responses = Vec::with_capacity(response_results.len()); while let Some((server_id, response_result)) = response_results.next().await { if let Some(response) = response_result.log_err() { @@ -8531,34 +8744,6 @@ impl LspStore { Ok(proto::Ack {}) } - async fn handle_language_server_id_for_name( - lsp_store: Entity<Self>, - envelope: TypedEnvelope<proto::LanguageServerIdForName>, - mut cx: AsyncApp, - ) -> Result<proto::LanguageServerIdForNameResponse> { - let name = &envelope.payload.name; - let buffer_id = BufferId::new(envelope.payload.buffer_id)?; - lsp_store - .update(&mut cx, |lsp_store, cx| { - let buffer =
lsp_store.buffer_store.read(cx).get_existing(buffer_id)?; - let server_id = buffer.update(cx, |buffer, cx| { - lsp_store - .language_servers_for_local_buffer(buffer, cx) - .find_map(|(adapter, server)| { - if adapter.name.0.as_ref() == name { - Some(server.server_id()) - } else { - None - } - }) - }); - Ok(server_id) - })? - .map(|server_id| proto::LanguageServerIdForNameResponse { - server_id: server_id.map(|id| id.to_proto()), - }) - } - async fn handle_rename_project_entry( this: Entity<Self>, envelope: TypedEnvelope<proto::RenameProjectEntry>, @@ -8665,18 +8850,29 @@ impl LspStore { } async fn handle_start_language_server( - this: Entity<Self>, + lsp_store: Entity<Self>, envelope: TypedEnvelope<proto::StartLanguageServer>, mut cx: AsyncApp, ) -> Result<()> { let server = envelope.payload.server.context("invalid server")?; - - this.update(&mut cx, |this, cx| { + let server_capabilities = + serde_json::from_str::<lsp::ServerCapabilities>(&envelope.payload.capabilities) + .with_context(|| { + format!( + "incorrect server capabilities {}", + envelope.payload.capabilities + ) + })?; + lsp_store.update(&mut cx, |lsp_store, cx| { let server_id = LanguageServerId(server.id as usize); - this.language_server_statuses.insert( + let server_name = LanguageServerName::from_proto(server.name.clone()); + lsp_store + .lsp_server_capabilities + .insert(server_id, server_capabilities); + lsp_store.language_server_statuses.insert( server_id, LanguageServerStatus { - name: server.name.clone(), + name: server_name.clone(), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -8684,7 +8880,7 @@ impl LspStore { ); cx.emit(LspStoreEvent::LanguageServerAdded( server_id, - LanguageServerName(server.name.into()), + server_name, server.worktree_id.map(WorktreeId::from_proto), )); cx.notify(); @@ -8745,7 +8941,8 @@ impl LspStore { } non_lsp @ proto::update_language_server::Variant::StatusUpdate(_) - | non_lsp @ proto::update_language_server::Variant::RegisteredForBuffer(_) => { + | non_lsp @ proto::update_language_server::Variant::RegisteredForBuffer(_) + | non_lsp @ proto::update_language_server::Variant::MetadataUpdated(_) => { cx.emit(LspStoreEvent::LanguageServerUpdate { language_server_id, name: envelope @@ -10192,7 +10389,7 @@ impl LspStore { let name = self .language_server_statuses .remove(&server_id) - .map(|status| LanguageServerName::from(status.name.as_str())) + .map(|status| status.name.clone()) .or_else(|| { if let Some(LanguageServerState::Running { adapter, .. 
}) = server_state.as_ref() { Some(adapter.name()) @@ -10685,7 +10882,7 @@ impl LspStore { self.language_server_statuses.insert( server_id, LanguageServerStatus { - name: language_server.name().to_string(), + name: language_server.name(), pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), @@ -10699,18 +10896,23 @@ impl LspStore { )); cx.emit(LspStoreEvent::RefreshInlayHints); + let server_capabilities = language_server.capabilities(); if let Some((downstream_client, project_id)) = self.downstream_client.as_ref() { downstream_client .send(proto::StartLanguageServer { project_id: *project_id, server: Some(proto::LanguageServer { - id: server_id.0 as u64, + id: server_id.to_proto(), name: language_server.name().to_string(), worktree_id: Some(key.0.to_proto()), }), + capabilities: serde_json::to_string(&server_capabilities) + .expect("serializing server LSP capabilities"), }) .log_err(); } + self.lsp_server_capabilities + .insert(server_id, server_capabilities); // Tell the language server about every open buffer in the worktree that matches the language. // Also check for buffers in worktrees that reused this server @@ -10758,10 +10960,11 @@ impl LspStore { let local = self.as_local_mut().unwrap(); - if local.registered_buffers.contains_key(&buffer.remote_id()) { + let buffer_id = buffer.remote_id(); + if local.registered_buffers.contains_key(&buffer_id) { let versions = local .buffer_snapshots - .entry(buffer.remote_id()) + .entry(buffer_id) .or_default() .entry(server_id) .and_modify(|_| { @@ -10787,10 +10990,10 @@ impl LspStore { version, initial_snapshot.text(), ); - buffer_paths_registered.push(file.abs_path(cx)); + buffer_paths_registered.push((buffer_id, file.abs_path(cx))); local .buffers_opened_in_servers - .entry(buffer.remote_id()) + .entry(buffer_id) .or_default() .insert(server_id); } @@ -10814,13 +11017,14 @@ impl LspStore { } }); - for abs_path in buffer_paths_registered { + for (buffer_id, abs_path) in buffer_paths_registered { cx.emit(LspStoreEvent::LanguageServerUpdate { language_server_id: server_id, name: Some(adapter.name()), message: proto::update_language_server::Variant::RegisteredForBuffer( proto::RegisteredForBuffer { buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), }, ), }); @@ -11278,9 +11482,13 @@ impl LspStore { } fn cleanup_lsp_data(&mut self, for_server: LanguageServerId) { - for buffer_lsp_data in self.lsp_data.values_mut() { - buffer_lsp_data.colors.remove(&for_server); - buffer_lsp_data.cache_version += 1; + self.lsp_server_capabilities.remove(&for_server); + for buffer_colors in self.lsp_document_colors.values_mut() { + buffer_colors.colors.remove(&for_server); + buffer_colors.cache_version += 1; + } + for buffer_lens in self.lsp_code_lens.values_mut() { + buffer_lens.lens.remove(&for_server); } if let Some(local) = self.as_local_mut() { local.buffer_pull_diagnostics_result_ids.remove(&for_server); diff --git a/crates/project/src/lsp_store/clangd_ext.rs b/crates/project/src/lsp_store/clangd_ext.rs index 6a09bb99b4..cf3507352e 100644 --- a/crates/project/src/lsp_store/clangd_ext.rs +++ b/crates/project/src/lsp_store/clangd_ext.rs @@ -3,12 +3,12 @@ use std::sync::Arc; use ::serde::{Deserialize, Serialize}; use gpui::WeakEntity; use language::{CachedLspAdapter, Diagnostic, DiagnosticSourceKind}; -use lsp::LanguageServer; +use lsp::{LanguageServer, LanguageServerName}; use util::ResultExt as _; use crate::LspStore; -pub const CLANGD_SERVER_NAME: &str = "clangd"; 
+pub const CLANGD_SERVER_NAME: LanguageServerName = LanguageServerName::new_static("clangd"); const INACTIVE_REGION_MESSAGE: &str = "inactive region"; const INACTIVE_DIAGNOSTIC_SEVERITY: lsp::DiagnosticSeverity = lsp::DiagnosticSeverity::INFORMATION; @@ -34,7 +34,7 @@ pub fn is_inactive_region(diag: &Diagnostic) -> bool { && diag .source .as_ref() - .is_some_and(|v| v == CLANGD_SERVER_NAME) + .is_some_and(|v| v == &CLANGD_SERVER_NAME.0) } pub fn is_lsp_inactive_region(diag: &lsp::Diagnostic) -> bool { @@ -43,7 +43,7 @@ pub fn is_lsp_inactive_region(diag: &lsp::Diagnostic) -> bool { && diag .source .as_ref() - .is_some_and(|v| v == CLANGD_SERVER_NAME) + .is_some_and(|v| v == &CLANGD_SERVER_NAME.0) } pub fn register_notifications( @@ -51,7 +51,7 @@ pub fn register_notifications( language_server: &LanguageServer, adapter: Arc<CachedLspAdapter>, ) { - if language_server.name().0 != CLANGD_SERVER_NAME { + if language_server.name() != CLANGD_SERVER_NAME { return; } let server_id = language_server.server_id(); diff --git a/crates/project/src/lsp_store/rust_analyzer_ext.rs b/crates/project/src/lsp_store/rust_analyzer_ext.rs index d78715d385..6c425717a8 100644 --- a/crates/project/src/lsp_store/rust_analyzer_ext.rs +++ b/crates/project/src/lsp_store/rust_analyzer_ext.rs @@ -2,12 +2,12 @@ use ::serde::{Deserialize, Serialize}; use anyhow::Context as _; use gpui::{App, Entity, Task, WeakEntity}; use language::ServerHealth; -use lsp::LanguageServer; +use lsp::{LanguageServer, LanguageServerName}; use rpc::proto; use crate::{LspStore, LspStoreEvent, Project, ProjectPath, lsp_store}; -pub const RUST_ANALYZER_NAME: &str = "rust-analyzer"; +pub const RUST_ANALYZER_NAME: LanguageServerName = LanguageServerName::new_static("rust-analyzer"); pub const CARGO_DIAGNOSTICS_SOURCE_NAME: &str = "rustc"; /// Experimental: Informs the end user about the state of the server @@ -97,13 +97,9 @@ pub fn cancel_flycheck( cx.spawn(async move |cx| { let buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? - .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? else { return Ok(()); }; @@ -148,13 +144,9 @@ pub fn run_flycheck( cx.spawn(async move |cx| { let buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? - .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? else { return Ok(()); }; @@ -204,13 +196,9 @@ pub fn clear_flycheck( cx.spawn(async move |cx| { let buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? - .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? 
else { return Ok(()); }; diff --git a/crates/project/src/manifest_tree/server_tree.rs b/crates/project/src/manifest_tree/server_tree.rs index 0283f06eec..81cb1c450c 100644 --- a/crates/project/src/manifest_tree/server_tree.rs +++ b/crates/project/src/manifest_tree/server_tree.rs @@ -13,10 +13,10 @@ use std::{ sync::{Arc, Weak}, }; -use collections::{HashMap, IndexMap}; +use collections::IndexMap; use gpui::{App, AppContext as _, Entity, Subscription}; use language::{ - Attach, CachedLspAdapter, LanguageName, LanguageRegistry, ManifestDelegate, + CachedLspAdapter, LanguageName, LanguageRegistry, ManifestDelegate, language_settings::AllLanguageSettings, }; use lsp::LanguageServerName; @@ -38,7 +38,6 @@ pub(crate) struct ServersForWorktree { pub struct LanguageServerTree { manifest_tree: Entity<ManifestTree>, pub(crate) instances: BTreeMap<WorktreeId, ServersForWorktree>, - attach_kind_cache: HashMap<LanguageServerName, Attach>, languages: Arc<LanguageRegistry>, _subscriptions: Subscription, } @@ -53,7 +52,6 @@ pub struct LanguageServerTreeNode(Weak<InnerTreeNode>); #[derive(Debug)] pub(crate) struct LaunchDisposition<'a> { pub(crate) server_name: &'a LanguageServerName, - pub(crate) attach: Attach, pub(crate) path: ProjectPath, pub(crate) settings: Arc<LspSettings>, } @@ -62,7 +60,6 @@ impl<'a> From<&'a InnerTreeNode> for LaunchDisposition<'a> { fn from(value: &'a InnerTreeNode) -> Self { LaunchDisposition { server_name: &value.name, - attach: value.attach, path: value.path.clone(), settings: value.settings.clone(), } @@ -105,7 +102,6 @@ impl From<Weak<InnerTreeNode>> for LanguageServerTreeNode { pub struct InnerTreeNode { id: OnceLock<LanguageServerId>, name: LanguageServerName, - attach: Attach, path: ProjectPath, settings: Arc<LspSettings>, } @@ -113,14 +109,12 @@ pub struct InnerTreeNode { impl InnerTreeNode { fn new( name: LanguageServerName, - attach: Attach, path: ProjectPath, settings: impl Into<Arc<LspSettings>>, ) -> Self { InnerTreeNode { id: Default::default(), name, - attach, path, settings: settings.into(), } @@ -130,8 +124,11 @@ impl InnerTreeNode { /// Determines how the list of adapters to query should be constructed. pub(crate) enum AdapterQuery<'a> { /// Search for roots of all adapters associated with a given language name. + /// Layman: Look for all project roots along the queried path that have any + /// language server associated with this language running. Language(&'a LanguageName), /// Search for roots of adapter with a given name. + /// Layman: Look for all project roots along the queried path that have this server running. 
Adapter(&'a LanguageServerName), } @@ -147,7 +144,7 @@ impl LanguageServerTree { }), manifest_tree, instances: Default::default(), - attach_kind_cache: Default::default(), + languages, }) } @@ -223,7 +220,6 @@ impl LanguageServerTree { .and_then(|name| roots.get(&name)) .cloned() .unwrap_or_else(|| root_path.clone()); - let attach = adapter.attach_kind(); let inner_node = self .instances @@ -237,7 +233,6 @@ impl LanguageServerTree { ( Arc::new(InnerTreeNode::new( adapter.name(), - attach, root_path.clone(), settings.clone(), )), @@ -379,7 +374,6 @@ pub(crate) struct ServerTreeRebase<'a> { impl<'tree> ServerTreeRebase<'tree> { fn new(new_tree: &'tree mut LanguageServerTree) -> Self { let old_contents = std::mem::take(&mut new_tree.instances); - new_tree.attach_kind_cache.clear(); let all_server_ids = old_contents .values() .flat_map(|nodes| { @@ -446,10 +440,7 @@ impl<'tree> ServerTreeRebase<'tree> { .get(&disposition.path.worktree_id) .and_then(|worktree_nodes| worktree_nodes.roots.get(&disposition.path.path)) .and_then(|roots| roots.get(&disposition.name)) - .filter(|(old_node, _)| { - disposition.attach == old_node.attach - && disposition.settings == old_node.settings - }) + .filter(|(old_node, _)| disposition.settings == old_node.settings) else { return Some(node); }; diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index f9c59d2e95..cca026ec87 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -97,7 +97,7 @@ use rpc::{ }; use search::{SearchInputKind, SearchQuery, SearchResult}; use search_history::SearchHistory; -use settings::{InvalidSettingsError, Settings, SettingsLocation, SettingsStore}; +use settings::{InvalidSettingsError, Settings, SettingsLocation, SettingsSources, SettingsStore}; use smol::channel::Receiver; use snippet::Snippet; use snippet_provider::SnippetProvider; @@ -113,7 +113,7 @@ use std::{ use task_store::TaskStore; use terminals::Terminals; -use text::{Anchor, BufferId, Point}; +use text::{Anchor, BufferId, OffsetRangeExt, Point}; use toolchain_store::EmptyToolchainStore; use util::{ ResultExt as _, @@ -277,6 +277,13 @@ pub enum Event { LanguageServerAdded(LanguageServerId, LanguageServerName, Option<WorktreeId>), LanguageServerRemoved(LanguageServerId), LanguageServerLog(LanguageServerId, LanguageServerLogType, String), + // [`lsp::notification::DidOpenTextDocument`] was sent to this server using the buffer data. + // Zed's buffer-related data is updated accordingly. + LanguageServerBufferRegistered { + server_id: LanguageServerId, + buffer_id: BufferId, + buffer_abs_path: PathBuf, + }, Toast { notification_id: SharedString, message: String, @@ -590,7 +597,7 @@ pub(crate) struct CoreCompletion { } /// A code action provided by a language server. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct CodeAction { /// The id of the language server that produced this code action. pub server_id: LanguageServerId, @@ -604,7 +611,7 @@ pub struct CodeAction { } /// An action sent back by a language server. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum LspAction { /// An action with the full data, may have a command or may not. /// May require resolving. @@ -942,10 +949,38 @@ pub enum PulledDiagnostics { }, } +/// Whether to disable all AI features in Zed. 
+/// +/// Default: false +#[derive(Copy, Clone, Debug)] +pub struct DisableAiSettings { + pub disable_ai: bool, +} + +impl settings::Settings for DisableAiSettings { + const KEY: Option<&'static str> = Some("disable_ai"); + + type FileContent = Option<bool>; + + fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> { + Ok(Self { + disable_ai: sources + .user + .or(sources.server) + .copied() + .flatten() + .unwrap_or(sources.default.ok_or_else(Self::missing_default)?), + }) + } + + fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} +} + impl Project { pub fn init_settings(cx: &mut App) { WorktreeSettings::register(cx); ProjectSettings::register(cx); + DisableAiSettings::register(cx); } pub fn init(client: &Arc<Client>, cx: &mut App) { @@ -998,8 +1033,9 @@ impl Project { cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); + let weak_self = cx.weak_entity(); let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx)); + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); let environment = cx.new(|_| ProjectEnvironment::new(env)); let manifest_tree = ManifestTree::new(worktree_store.clone(), cx); @@ -1167,8 +1203,9 @@ impl Project { cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); + let weak_self = cx.weak_entity(); let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx)); + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); let buffer_store = cx.new(|cx| { BufferStore::remote( @@ -1360,10 +1397,7 @@ impl Project { fs: Arc<dyn Fs>, cx: AsyncApp, ) -> Result<Entity<Self>> { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.connect(true, &cx).await.into_response()?; let subscriptions = [ EntitySubscription::Project(client.subscribe_to_entity::<Self>(remote_id)?), @@ -1428,8 +1462,6 @@ impl Project { let image_store = cx.new(|cx| { ImageStore::remote(worktree_store.clone(), client.clone().into(), remote_id, cx) })?; - let context_server_store = - cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx))?; let environment = cx.new(|_| ProjectEnvironment::new(None))?; @@ -1496,6 +1528,10 @@ impl Project { let snippets = SnippetProvider::new(fs.clone(), BTreeSet::from_iter([]), cx); + let weak_self = cx.weak_entity(); + let context_server_store = + cx.new(|cx| ContextServerStore::new(worktree_store.clone(), weak_self, cx)); + let mut worktrees = Vec::new(); for worktree in response.payload.worktrees { let worktree = @@ -2902,8 +2938,8 @@ impl Project { } LspStoreEvent::LanguageServerUpdate { language_server_id, - message, name, + message, } => { if self.is_local() { self.enqueue_buffer_ordered_message( @@ -2915,6 +2951,32 @@ impl Project { ) .ok(); } + + match message { + proto::update_language_server::Variant::MetadataUpdated(update) => { + if let Some(capabilities) = update + .capabilities + .as_ref() + .and_then(|capabilities| serde_json::from_str(capabilities).ok()) + { + self.lsp_store.update(cx, |lsp_store, _| { + lsp_store + .lsp_server_capabilities + .insert(*language_server_id, capabilities); + }); + } + } + proto::update_language_server::Variant::RegisteredForBuffer(update) => { + if let Some(buffer_id) = BufferId::new(update.buffer_id).ok() { + cx.emit(Event::LanguageServerBufferRegistered { + buffer_id, + server_id: *language_server_id, + buffer_abs_path: PathBuf::from(&update.buffer_abs_path), + }); + } + } + 
_ => (), + } } LspStoreEvent::Notification(message) => cx.emit(Event::Toast { notification_id: "lsp".into(), @@ -3368,7 +3430,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.definitions(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3386,7 +3448,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.declarations(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3404,7 +3466,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.type_definitions(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3422,7 +3484,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.implementations(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3440,27 +3502,13 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.references(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result }) } - fn document_highlights_impl( - &mut self, - buffer: &Entity<Buffer>, - position: PointUtf16, - cx: &mut Context<Self>, - ) -> Task<Result<Vec<DocumentHighlight>>> { - self.request_lsp( - buffer.clone(), - LanguageServerToQuery::FirstCapable, - GetDocumentHighlights { position }, - cx, - ) - } - pub fn document_highlights<T: ToPointUtf16>( &mut self, buffer: &Entity<Buffer>, @@ -3468,7 +3516,12 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<Vec<DocumentHighlight>>> { let position = position.to_point_utf16(buffer.read(cx)); - self.document_highlights_impl(buffer, position, cx) + self.request_lsp( + buffer.clone(), + LanguageServerToQuery::FirstCapable, + GetDocumentHighlights { position }, + cx, + ) } pub fn document_symbols( @@ -3569,14 +3622,14 @@ impl Project { .update(cx, |lsp_store, cx| lsp_store.hover(buffer, position, cx)) } - pub fn linked_edit( + pub fn linked_edits( &self, buffer: &Entity<Buffer>, position: Anchor, cx: &mut Context<Self>, ) -> Task<Result<Vec<Range<Anchor>>>> { self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store.linked_edit(buffer, position, cx) + lsp_store.linked_edits(buffer, position, cx) }) } @@ -3607,20 +3660,29 @@ impl Project { }) } - pub fn code_lens<T: Clone + ToOffset>( + pub fn code_lens_actions<T: Clone + ToOffset>( &mut self, - buffer_handle: &Entity<Buffer>, + buffer: &Entity<Buffer>, range: Range<T>, cx: &mut Context<Self>, ) -> Task<Result<Vec<CodeAction>>> { - let snapshot = buffer_handle.read(cx).snapshot(); - let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end); + let snapshot = buffer.read(cx).snapshot(); + let range = range.clone().to_owned().to_point(&snapshot); + let range_start = snapshot.anchor_before(range.start); + let range_end = if range.start == range.end { + range_start + } else { + snapshot.anchor_after(range.end) + }; + let range = range_start..range_end; let code_lens_actions = self .lsp_store - .update(cx, |lsp_store, cx| lsp_store.code_lens(buffer_handle, cx)); + .update(cx, |lsp_store, cx| lsp_store.code_lens_actions(buffer, cx)); cx.background_spawn(async move { - let mut code_lens_actions = code_lens_actions.await?; + let mut 
code_lens_actions = code_lens_actions + .await + .map_err(|e| anyhow!("code lens fetch failed: {e:#}"))?; code_lens_actions.retain(|code_lens_action| { range .start @@ -3659,19 +3721,6 @@ impl Project { }) } - fn prepare_rename_impl( - &mut self, - buffer: Entity<Buffer>, - position: PointUtf16, - cx: &mut Context<Self>, - ) -> Task<Result<PrepareRenameResponse>> { - self.request_lsp( - buffer, - LanguageServerToQuery::FirstCapable, - PrepareRename { position }, - cx, - ) - } pub fn prepare_rename<T: ToPointUtf16>( &mut self, buffer: Entity<Buffer>, @@ -3679,7 +3728,12 @@ impl Project { cx: &mut Context<Self>, ) -> Task<Result<PrepareRenameResponse>> { let position = position.to_point_utf16(buffer.read(cx)); - self.prepare_rename_impl(buffer, position, cx) + self.request_lsp( + buffer, + LanguageServerToQuery::FirstCapable, + PrepareRename { position }, + cx, + ) } pub fn perform_rename<T: ToPointUtf16>( @@ -3983,7 +4037,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.request_lsp(buffer_handle, server, request, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -4948,63 +5002,53 @@ impl Project { } pub fn any_language_server_supports_inlay_hints(&self, buffer: &Buffer, cx: &mut App) -> bool { - self.lsp_store.update(cx, |this, cx| { - this.language_servers_for_local_buffer(buffer, cx) - .any( - |(_, server)| match server.capabilities().inlay_hint_provider { - Some(lsp::OneOf::Left(enabled)) => enabled, - Some(lsp::OneOf::Right(_)) => true, - None => false, - }, - ) + let Some(language) = buffer.language().cloned() else { + return false; + }; + self.lsp_store.update(cx, |lsp_store, _| { + let relevant_language_servers = lsp_store + .languages + .lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::<HashSet<_>>(); + lsp_store + .language_server_statuses() + .filter_map(|(server_id, server_status)| { + relevant_language_servers + .contains(&server_status.name) + .then_some(server_id) + }) + .filter_map(|server_id| lsp_store.lsp_server_capabilities.get(&server_id)) + .any(InlayHints::check_capabilities) }) } pub fn language_server_id_for_name( &self, buffer: &Buffer, - name: &str, - cx: &mut App, - ) -> Task<Option<LanguageServerId>> { - if self.is_local() { - Task::ready(self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store - .language_servers_for_local_buffer(buffer, cx) - .find_map(|(adapter, server)| { - if adapter.name.0 == name { - Some(server.server_id()) - } else { - None - } - }) - })) - } else if let Some(project_id) = self.remote_id() { - let request = self.client.request(proto::LanguageServerIdForName { - project_id, - buffer_id: buffer.remote_id().to_proto(), - name: name.to_string(), - }); - cx.background_spawn(async move { - let response = request.await.log_err()?; - response.server_id.map(LanguageServerId::from_proto) - }) - } else if let Some(ssh_client) = self.ssh_client.as_ref() { - let request = - ssh_client - .read(cx) - .proto_client() - .request(proto::LanguageServerIdForName { - project_id: SSH_PROJECT_ID, - buffer_id: buffer.remote_id().to_proto(), - name: name.to_string(), - }); - cx.background_spawn(async move { - let response = request.await.log_err()?; - response.server_id.map(LanguageServerId::from_proto) - }) - } else { - Task::ready(None) + name: &LanguageServerName, + cx: &App, + ) -> Option<LanguageServerId> { + let language = buffer.language()?; + let relevant_language_servers = self + .languages + 
.lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::<HashSet<_>>(); + if !relevant_language_servers.contains(name) { + return None; } + self.language_server_statuses(cx) + .filter(|(_, server_status)| relevant_language_servers.contains(&server_status.name)) + .find_map(|(server_id, server_status)| { + if &server_status.name == name { + Some(server_id) + } else { + None + } + }) } pub fn has_language_servers_for(&self, buffer: &Buffer, cx: &mut App) -> bool { diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index 779cf95add..75ebc8339a 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -1100,7 +1100,7 @@ async fn test_reporting_fs_changes_to_language_servers(cx: &mut gpui::TestAppCon let fake_server = fake_servers.next().await.unwrap(); let (server_id, server_name) = lsp_store.read_with(cx, |lsp_store, _| { let (id, status) = lsp_store.language_server_statuses().next().unwrap(); - (id, LanguageServerName::from(status.name.as_str())) + (id, status.name.clone()) }); // Simulate jumping to a definition in a dependency outside of the worktree. @@ -1698,7 +1698,7 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC name: "the-language-server", disk_based_diagnostics_sources: vec!["disk".into()], disk_based_diagnostics_progress_token: Some(progress_token.into()), - ..Default::default() + ..FakeLspAdapter::default() }, ); @@ -1710,6 +1710,7 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC }) .await .unwrap(); + let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id()); // Simulate diagnostics starting to update. let fake_server = fake_servers.next().await.unwrap(); fake_server.start_progress(progress_token).await; @@ -1736,6 +1737,14 @@ async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppC ); assert_eq!(events.next().await.unwrap(), Event::RefreshInlayHints); fake_server.start_progress(progress_token).await; + assert_eq!( + events.next().await.unwrap(), + Event::LanguageServerBufferRegistered { + server_id: LanguageServerId(1), + buffer_id, + buffer_abs_path: PathBuf::from(path!("/dir/a.rs")), + } + ); assert_eq!( events.next().await.unwrap(), Event::DiskBasedDiagnosticsStarted { diff --git a/crates/project/src/terminals.rs b/crates/project/src/terminals.rs index 8cfbdff311..973d4e8811 100644 --- a/crates/project/src/terminals.rs +++ b/crates/project/src/terminals.rs @@ -213,17 +213,24 @@ impl Project { cx: &mut Context<Self>, ) -> Result<Entity<Terminal>> { let this = &mut *self; + let ssh_details = this.ssh_details(cx); let path: Option<Arc<Path>> = match &kind { TerminalKind::Shell(path) => path.as_ref().map(|path| Arc::from(path.as_ref())), TerminalKind::Task(spawn_task) => { if let Some(cwd) = &spawn_task.cwd { - Some(Arc::from(cwd.as_ref())) + if ssh_details.is_some() { + Some(Arc::from(cwd.as_ref())) + } else { + let cwd = cwd.to_string_lossy(); + let tilde_substituted = shellexpand::tilde(&cwd); + Some(Arc::from(Path::new(tilde_substituted.as_ref()))) + } } else { this.active_project_directory(cx) } } }; - let ssh_details = this.ssh_details(cx); + let is_ssh_terminal = ssh_details.is_some(); let mut settings_location = None; diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index b0073f294f..048b9e73d0 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -114,6 
+114,7 @@ pub struct ProjectPanel { mouse_down: bool, hover_expand_task: Option<Task<()>>, previous_drag_position: Option<Point<Pixels>>, + sticky_items_count: usize, } struct DragTargetEntry { @@ -572,6 +573,9 @@ impl ProjectPanel { if project_panel_settings.hide_root != new_settings.hide_root { this.update_visible_entries(None, cx); } + if project_panel_settings.sticky_scroll && !new_settings.sticky_scroll { + this.sticky_items_count = 0; + } project_panel_settings = new_settings; this.update_diagnostics(cx); cx.notify(); @@ -615,6 +619,7 @@ impl ProjectPanel { mouse_down: false, hover_expand_task: None, previous_drag_position: None, + sticky_items_count: 0, }; this.update_visible_entries(None, cx); @@ -2267,8 +2272,11 @@ impl ProjectPanel { fn autoscroll(&mut self, cx: &mut Context<Self>) { if let Some((_, _, index)) = self.selection.and_then(|s| self.index_for_selection(s)) { - self.scroll_handle - .scroll_to_item(index, ScrollStrategy::Center); + self.scroll_handle.scroll_to_item_with_offset( + index, + ScrollStrategy::Center, + self.sticky_items_count, + ); cx.notify(); } } @@ -2723,26 +2731,7 @@ impl ProjectPanel { } fn index_for_selection(&self, selection: SelectedEntry) -> Option<(usize, usize, usize)> { - let mut entry_index = 0; - let mut visible_entries_index = 0; - for (worktree_index, (worktree_id, worktree_entries, _)) in - self.visible_entries.iter().enumerate() - { - if *worktree_id == selection.worktree_id { - for entry in worktree_entries { - if entry.id == selection.entry_id { - return Some((worktree_index, entry_index, visible_entries_index)); - } else { - visible_entries_index += 1; - entry_index += 1; - } - } - break; - } else { - visible_entries_index += worktree_entries.len(); - } - } - None + self.index_for_entry(selection.entry_id, selection.worktree_id) } fn disjoint_entries(&self, cx: &App) -> BTreeSet<SelectedEntry> { @@ -3353,12 +3342,12 @@ impl ProjectPanel { entry_id: ProjectEntryId, worktree_id: WorktreeId, ) -> Option<(usize, usize, usize)> { - let mut worktree_ix = 0; let mut total_ix = 0; - for (current_worktree_id, visible_worktree_entries, _) in &self.visible_entries { + for (worktree_ix, (current_worktree_id, visible_worktree_entries, _)) in + self.visible_entries.iter().enumerate() + { if worktree_id != *current_worktree_id { total_ix += visible_worktree_entries.len(); - worktree_ix += 1; continue; } @@ -4168,13 +4157,12 @@ impl ProjectPanel { ) .on_click( cx.listener(move |this, event: &gpui::ClickEvent, window, cx| { - if event.down.button == MouseButton::Right - || event.down.first_mouse + if event.is_right_click() || event.first_focus() || show_editor { return; } - if event.down.button == MouseButton::Left { + if event.standard_click() { this.mouse_down = false; } cx.stop_propagation(); @@ -4214,7 +4202,7 @@ impl ProjectPanel { this.marked_entries.insert(clicked_entry); } } else if event.modifiers().secondary() { - if event.down.click_count > 1 { + if event.click_count() > 1 { this.split_entry(entry_id, cx); } else { this.selection = Some(selection); @@ -4226,10 +4214,7 @@ impl ProjectPanel { this.marked_entries.clear(); if is_sticky { if let Some((_, _, index)) = this.index_for_entry(entry_id, worktree_id) { - let strategy = sticky_index - .map(ScrollStrategy::ToPosition) - .unwrap_or(ScrollStrategy::Top); - this.scroll_handle.scroll_to_item(index, strategy); + this.scroll_handle.scroll_to_item_with_offset(index, ScrollStrategy::Top, sticky_index.unwrap_or(0)); cx.notify(); // move down by 1px so that clicked item // don't count as sticky 
anymore @@ -4251,7 +4236,7 @@ impl ProjectPanel { } } else { let preview_tabs_enabled = PreviewTabsSettings::get_global(cx).enabled; - let click_count = event.up.click_count; + let click_count = event.click_count(); let focus_opened_item = !preview_tabs_enabled || click_count > 1; let allow_preview = preview_tabs_enabled && click_count == 1; this.open_entry(entry_id, focus_opened_item, allow_preview, cx); @@ -5152,7 +5137,10 @@ impl Render for ProjectPanel { this.hide_scrollbar(window, cx); } })) - .on_click(cx.listener(|this, _event, _, cx| { + .on_click(cx.listener(|this, event, _, cx| { + if matches!(event, gpui::ClickEvent::Keyboard(_)) { + return; + } cx.stop_propagation(); this.selection = None; this.marked_entries.clear(); @@ -5193,7 +5181,7 @@ impl Render for ProjectPanel { .on_action(cx.listener(Self::paste)) .on_action(cx.listener(Self::duplicate)) .on_click(cx.listener(|this, event: &gpui::ClickEvent, window, cx| { - if event.up.click_count > 1 { + if event.click_count() > 1 { if let Some(entry_id) = this.last_worktree_root_id { let project = this.project.read(cx); @@ -5366,7 +5354,10 @@ impl Render for ProjectPanel { items }, |this, marker_entry, window, cx| { - this.render_sticky_entries(marker_entry, window, cx) + let sticky_entries = + this.render_sticky_entries(marker_entry, window, cx); + this.sticky_items_count = sticky_entries.len(); + sticky_entries }, ); list.with_decoration(if show_indent_guides { diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index 5330ee506a..353f19adb2 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -79,11 +79,16 @@ message OpenServerSettings { uint64 project_id = 1; } -message GetPanicFiles { +message GetCrashFiles { } -message GetPanicFilesResponse { - repeated string file_contents = 2; +message GetCrashFilesResponse { + repeated CrashReport crashes = 1; +} + +message CrashReport { + optional string panic_contents = 1; + optional bytes minidump_contents = 2; } message Extension { diff --git a/crates/proto/proto/call.proto b/crates/proto/proto/call.proto index 5212f3b43f..b5c882db56 100644 --- a/crates/proto/proto/call.proto +++ b/crates/proto/proto/call.proto @@ -71,6 +71,7 @@ message RejoinedProject { repeated WorktreeMetadata worktrees = 2; repeated Collaborator collaborators = 3; repeated LanguageServer language_servers = 4; + repeated string language_server_capabilities = 5; } message LeaveRoom {} @@ -199,6 +200,7 @@ message JoinProjectResponse { repeated WorktreeMetadata worktrees = 2; repeated Collaborator collaborators = 3; repeated LanguageServer language_servers = 4; + repeated string language_server_capabilities = 8; ChannelRole role = 6; reserved 7; } diff --git a/crates/proto/proto/debugger.proto b/crates/proto/proto/debugger.proto index 09abd4bf1c..c6f9c9f134 100644 --- a/crates/proto/proto/debugger.proto +++ b/crates/proto/proto/debugger.proto @@ -188,7 +188,7 @@ message DapSetVariableValueResponse { message DapPauseRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; } message DapDisconnectRequest { @@ -202,7 +202,7 @@ message DapDisconnectRequest { message DapTerminateThreadsRequest { uint64 project_id = 1; uint64 client_id = 2; - repeated uint64 thread_ids = 3; + repeated int64 thread_ids = 3; } message DapThreadsRequest { @@ -246,7 +246,7 @@ message IgnoreBreakpointState { message DapNextRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional bool single_thread = 4; 
optional SteppingGranularity granularity = 5; } @@ -254,7 +254,7 @@ message DapNextRequest { message DapStepInRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional uint64 target_id = 4; optional bool single_thread = 5; optional SteppingGranularity granularity = 6; @@ -263,7 +263,7 @@ message DapStepInRequest { message DapStepOutRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional bool single_thread = 4; optional SteppingGranularity granularity = 5; } @@ -271,7 +271,7 @@ message DapStepOutRequest { message DapStepBackRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional bool single_thread = 4; optional SteppingGranularity granularity = 5; } @@ -279,7 +279,7 @@ message DapStepBackRequest { message DapContinueRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional bool single_thread = 4; } @@ -311,7 +311,7 @@ message DapLoadedSourcesResponse { message DapStackTraceRequest { uint64 project_id = 1; uint64 client_id = 2; - uint64 thread_id = 3; + int64 thread_id = 3; optional uint64 start_frame = 4; optional uint64 stack_trace_levels = 5; } @@ -358,7 +358,7 @@ message DapVariable { } message DapThread { - uint64 id = 1; + int64 id = 1; string name = 2; } diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index 1d544b15ff..c32da9b110 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -286,6 +286,17 @@ message Unstage { repeated string paths = 4; } +message Stash { + uint64 project_id = 1; + uint64 repository_id = 2; + repeated string paths = 3; +} + +message StashPop { + uint64 project_id = 1; + uint64 repository_id = 2; +} + message Commit { uint64 project_id = 1; reserved 2; @@ -411,3 +422,12 @@ message BlameBufferResponse { reserved 1 to 4; } + +message GetDefaultBranch { + uint64 project_id = 1; + uint64 repository_id = 2; +} + +message GetDefaultBranchResponse { + optional string branch = 1; +} diff --git a/crates/proto/proto/lsp.proto b/crates/proto/proto/lsp.proto index e3c2f69c0b..164a7d3a50 100644 --- a/crates/proto/proto/lsp.proto +++ b/crates/proto/proto/lsp.proto @@ -518,6 +518,7 @@ message LanguageServer { message StartLanguageServer { uint64 project_id = 1; LanguageServer server = 2; + string capabilities = 3; } message UpdateDiagnosticSummary { @@ -545,6 +546,7 @@ message UpdateLanguageServer { LspDiskBasedDiagnosticsUpdated disk_based_diagnostics_updated = 7; StatusUpdate status_update = 9; RegisteredForBuffer registered_for_buffer = 10; + ServerMetadataUpdated metadata_updated = 11; } } @@ -597,6 +599,11 @@ enum ServerBinaryStatus { message RegisteredForBuffer { string buffer_abs_path = 1; + uint64 buffer_id = 2; +} + +message ServerMetadataUpdated { + optional string capabilities = 1; } message LanguageServerLog { @@ -811,16 +818,6 @@ message LspResponse { uint64 server_id = 7; } -message LanguageServerIdForName { - uint64 project_id = 1; - uint64 buffer_id = 2; - string name = 3; -} - -message LanguageServerIdForNameResponse { - optional uint64 server_id = 1; -} - message LspExtRunnables { uint64 project_id = 1; uint64 buffer_id = 2; diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 31f929ec90..bb97bd500a 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -294,9 +294,6 @@ message Envelope { GetPathMetadata get_path_metadata = 278; 
GetPathMetadataResponse get_path_metadata_response = 279; - GetPanicFiles get_panic_files = 280; - GetPanicFilesResponse get_panic_files_response = 281; - CancelLanguageServerWork cancel_language_server_work = 282; LspExtOpenDocs lsp_ext_open_docs = 283; @@ -365,9 +362,6 @@ message Envelope { GetDocumentSymbols get_document_symbols = 330; GetDocumentSymbolsResponse get_document_symbols_response = 331; - LanguageServerIdForName language_server_id_for_name = 332; - LanguageServerIdForNameResponse language_server_id_for_name_response = 333; - LoadCommitDiff load_commit_diff = 334; LoadCommitDiffResponse load_commit_diff_response = 335; @@ -396,8 +390,16 @@ message Envelope { GetDocumentColor get_document_color = 353; GetDocumentColorResponse get_document_color_response = 354; GetColorPresentation get_color_presentation = 355; - GetColorPresentationResponse get_color_presentation_response = 356; // current max + GetColorPresentationResponse get_color_presentation_response = 356; + Stash stash = 357; + StashPop stash_pop = 358; + + GetDefaultBranch get_default_branch = 359; + GetDefaultBranchResponse get_default_branch_response = 360; + + GetCrashFiles get_crash_files = 361; + GetCrashFilesResponse get_crash_files_response = 362; // current max } reserved 87 to 88; @@ -418,6 +420,8 @@ message Envelope { reserved 270; reserved 247 to 254; reserved 255 to 256; + reserved 280 to 281; + reserved 332 to 333; } message Hello { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 918ac9e935..9edb041b4b 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -99,8 +99,8 @@ messages!( (GetHoverResponse, Background), (GetNotifications, Foreground), (GetNotificationsResponse, Foreground), - (GetPanicFiles, Background), - (GetPanicFilesResponse, Background), + (GetCrashFiles, Background), + (GetCrashFilesResponse, Background), (GetPathMetadata, Background), (GetPathMetadataResponse, Background), (GetPermalinkToLine, Foreground), @@ -121,8 +121,6 @@ messages!( (GetImplementationResponse, Background), (GetLlmToken, Background), (GetLlmTokenResponse, Background), - (LanguageServerIdForName, Background), - (LanguageServerIdForNameResponse, Background), (OpenUnstagedDiff, Foreground), (OpenUnstagedDiffResponse, Foreground), (OpenUncommittedDiff, Foreground), @@ -261,6 +259,8 @@ messages!( (Unfollow, Foreground), (UnshareProject, Foreground), (Unstage, Background), + (Stash, Background), + (StashPop, Background), (UpdateBuffer, Foreground), (UpdateBufferFile, Foreground), (UpdateChannelBuffer, Foreground), @@ -313,7 +313,9 @@ messages!( (LogToDebugConsole, Background), (GetDocumentDiagnostics, Background), (GetDocumentDiagnosticsResponse, Background), - (PullWorkspaceDiagnostics, Background) + (PullWorkspaceDiagnostics, Background), + (GetDefaultBranch, Background), + (GetDefaultBranchResponse, Background), ); request_messages!( @@ -419,13 +421,14 @@ request_messages!( (TaskContextForLocation, TaskContext), (Test, Test), (Unstage, Ack), + (Stash, Ack), + (StashPop, Ack), (UpdateBuffer, Ack), (UpdateParticipantLocation, Ack), (UpdateProject, Ack), (UpdateWorktree, Ack), (UpdateRepository, Ack), (RemoveRepository, Ack), - (LanguageServerIdForName, LanguageServerIdForNameResponse), (LspExtExpandMacro, LspExtExpandMacroResponse), (LspExtOpenDocs, LspExtOpenDocsResponse), (LspExtRunnables, LspExtRunnablesResponse), @@ -456,7 +459,7 @@ request_messages!( (ActivateToolchain, Ack), (ActiveToolchain, ActiveToolchainResponse), (GetPathMetadata, GetPathMetadataResponse), - 
(GetPanicFiles, GetPanicFilesResponse), + (GetCrashFiles, GetCrashFilesResponse), (CancelLanguageServerWork, Ack), (SyncExtensions, SyncExtensionsResponse), (InstallExtension, Ack), @@ -479,7 +482,8 @@ request_messages!( (GetDebugAdapterBinary, DebugAdapterBinary), (RunDebugLocators, DebugRequest), (GetDocumentDiagnostics, GetDocumentDiagnosticsResponse), - (PullWorkspaceDiagnostics, Ack) + (PullWorkspaceDiagnostics, Ack), + (GetDefaultBranch, GetDefaultBranchResponse), ); entity_messages!( @@ -549,6 +553,8 @@ entity_messages!( TaskContextForLocation, UnshareProject, Unstage, + Stash, + StashPop, UpdateBuffer, UpdateBufferFile, UpdateDiagnosticSummary, @@ -579,7 +585,6 @@ entity_messages!( OpenServerSettings, GetPermalinkToLine, LanguageServerPromptRequest, - LanguageServerIdForName, GitGetBranches, UpdateGitBranch, ListToolchains, @@ -609,7 +614,8 @@ entity_messages!( GetDebugAdapterBinary, LogToDebugConsole, GetDocumentDiagnostics, - PullWorkspaceDiagnostics + PullWorkspaceDiagnostics, + GetDefaultBranch ); entity_messages!( @@ -778,6 +784,25 @@ pub fn split_repository_update( }]) } +impl MultiLspQuery { + pub fn request_str(&self) -> &str { + match self.request { + Some(multi_lsp_query::Request::GetHover(_)) => "GetHover", + Some(multi_lsp_query::Request::GetCodeActions(_)) => "GetCodeActions", + Some(multi_lsp_query::Request::GetSignatureHelp(_)) => "GetSignatureHelp", + Some(multi_lsp_query::Request::GetCodeLens(_)) => "GetCodeLens", + Some(multi_lsp_query::Request::GetDocumentDiagnostics(_)) => "GetDocumentDiagnostics", + Some(multi_lsp_query::Request::GetDocumentColor(_)) => "GetDocumentColor", + Some(multi_lsp_query::Request::GetDefinition(_)) => "GetDefinition", + Some(multi_lsp_query::Request::GetDeclaration(_)) => "GetDeclaration", + Some(multi_lsp_query::Request::GetTypeDefinition(_)) => "GetTypeDefinition", + Some(multi_lsp_query::Request::GetImplementation(_)) => "GetImplementation", + Some(multi_lsp_query::Request::GetReferences(_)) => "GetReferences", + None => "<unknown>", + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index 5dbde6496d..2093e96cae 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -141,6 +141,7 @@ impl Focusable for RecentProjects { impl Render for RecentProjects { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { v_flex() + .key_context("RecentProjects") .w(rems(self.rem_width)) .child(self.picker.clone()) .on_mouse_down_out(cx.listener(|this, _, window, cx| { diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index aa5103e62b..354434a7fc 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -953,7 +953,7 @@ impl RemoteServerProjects { ) .child(Label::new(project.paths.join(", "))) .on_click(cx.listener(move |this, e: &ClickEvent, window, cx| { - let secondary_confirm = e.down.modifiers.platform; + let secondary_confirm = e.modifiers().platform; callback(this, secondary_confirm, window, cx) })) .when(is_from_zed, |server_list_item| { @@ -963,7 +963,7 @@ impl RemoteServerProjects { .child({ let project = project.clone(); // Right-margin to offset it from the Scrollbar - IconButton::new("remove-remote-project", IconName::TrashAlt) + IconButton::new("remove-remote-project", IconName::Trash) .icon_size(IconSize::Small) 
.shape(IconButtonShape::Square) .size(ButtonSize::Large) diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index e31d3dcfd5..4306251e44 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -1742,7 +1742,7 @@ impl SshRemoteConnection { } }); - cx.spawn(async move |_| { + cx.background_spawn(async move { let result = futures::select! { result = stdin_task.fuse() => { result.context("stdin") diff --git a/crates/remote_server/Cargo.toml b/crates/remote_server/Cargo.toml index 443c47919f..c6a546f345 100644 --- a/crates/remote_server/Cargo.toml +++ b/crates/remote_server/Cargo.toml @@ -67,8 +67,11 @@ watch.workspace = true worktree.workspace = true [target.'cfg(not(windows))'.dependencies] +crashes.workspace = true +crash-handler.workspace = true fork.workspace = true libc.workspace = true +minidumper.workspace = true [dev-dependencies] assistant_tool.workspace = true diff --git a/crates/remote_server/src/main.rs b/crates/remote_server/src/main.rs index 98f635d856..03b0c3eda3 100644 --- a/crates/remote_server/src/main.rs +++ b/crates/remote_server/src/main.rs @@ -12,6 +12,10 @@ struct Cli { /// by having Zed act like netcat communicating over a Unix socket. #[arg(long, hide = true)] askpass: Option<String>, + /// Used for recording minidumps on crashes by having the server run a separate + /// process communicating over a socket. + #[arg(long, hide = true)] + crash_handler: Option<PathBuf>, /// Used for loading the environment from the project. #[arg(long, hide = true)] printenv: bool, @@ -58,6 +62,11 @@ fn main() { return; } + if let Some(socket) = &cli.crash_handler { + crashes::crash_server(socket.as_path()); + return; + } + if cli.printenv { util::shell_env::print_env(); return; diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs index 84ce08ff25..9bb5645dc7 100644 --- a/crates/remote_server/src/unix.rs +++ b/crates/remote_server/src/unix.rs @@ -17,6 +17,7 @@ use node_runtime::{NodeBinaryOptions, NodeRuntime}; use paths::logs_dir; use project::project_settings::ProjectSettings; +use proto::CrashReport; use release_channel::{AppVersion, RELEASE_CHANNEL, ReleaseChannel}; use remote::proxy::ProxyLaunchError; use remote::ssh_session::ChannelClient; @@ -33,6 +34,7 @@ use smol::io::AsyncReadExt; use smol::Async; use smol::{net::unix::UnixListener, stream::StreamExt as _}; +use std::collections::HashMap; use std::ffi::OsStr; use std::ops::ControlFlow; use std::str::FromStr; @@ -109,8 +111,9 @@ fn init_logging_server(log_file_path: PathBuf) -> Result<Receiver<Vec<u8>>> { Ok(rx) } -fn init_panic_hook() { - std::panic::set_hook(Box::new(|info| { +fn init_panic_hook(session_id: String) { + std::panic::set_hook(Box::new(move |info| { + crashes::handle_panic(); let payload = info .payload() .downcast_ref::<&str>() @@ -171,9 +174,11 @@ fn init_panic_hook() { architecture: env::consts::ARCH.into(), panicked_on: Utc::now().timestamp_millis(), backtrace, - system_id: None, // Set on SSH client - installation_id: None, // Set on SSH client - session_id: "".to_string(), // Set on SSH client + system_id: None, // Set on SSH client + installation_id: None, // Set on SSH client + + // used on this end to associate panics with minidumps, but will be replaced on the SSH client + session_id: session_id.clone(), }; if let Some(panic_data_json) = serde_json::to_string(&panic_data).log_err() { @@ -194,44 +199,69 @@ fn init_panic_hook() { })); } -fn handle_panic_requests(project: &Entity<HeadlessProject>, client: 
&Arc<ChannelClient>) { +fn handle_crash_files_requests(project: &Entity<HeadlessProject>, client: &Arc<ChannelClient>) { let client: AnyProtoClient = client.clone().into(); client.add_request_handler( project.downgrade(), - |_, _: TypedEnvelope<proto::GetPanicFiles>, _cx| async move { + |_, _: TypedEnvelope<proto::GetCrashFiles>, _cx| async move { + let mut crashes = Vec::new(); + let mut minidumps_by_session_id = HashMap::new(); let mut children = smol::fs::read_dir(paths::logs_dir()).await?; - let mut panic_files = Vec::new(); while let Some(child) = children.next().await { let child = child?; let child_path = child.path(); - if child_path.extension() != Some(OsStr::new("panic")) { - continue; + let extension = child_path.extension(); + if extension == Some(OsStr::new("panic")) { + let filename = if let Some(filename) = child_path.file_name() { + filename.to_string_lossy() + } else { + continue; + }; + + if !filename.starts_with("zed") { + continue; + } + + let file_contents = smol::fs::read_to_string(&child_path) + .await + .context("error reading panic file")?; + + crashes.push(proto::CrashReport { + panic_contents: Some(file_contents), + minidump_contents: None, + }); + } else if extension == Some(OsStr::new("dmp")) { + let session_id = child_path.file_stem().unwrap().to_string_lossy(); + minidumps_by_session_id + .insert(session_id.to_string(), smol::fs::read(&child_path).await?); } - let filename = if let Some(filename) = child_path.file_name() { - filename.to_string_lossy() - } else { - continue; - }; - - if !filename.starts_with("zed") { - continue; - } - - let file_contents = smol::fs::read_to_string(&child_path) - .await - .context("error reading panic file")?; - - panic_files.push(file_contents); // We've done what we can, delete the file - std::fs::remove_file(child_path) + smol::fs::remove_file(&child_path) + .await .context("error removing panic") .log_err(); } - anyhow::Ok(proto::GetPanicFilesResponse { - file_contents: panic_files, - }) + + for crash in &mut crashes { + let panic: telemetry_events::Panic = + serde_json::from_str(crash.panic_contents.as_ref().unwrap())?; + if let dump @ Some(_) = minidumps_by_session_id.remove(&panic.session_id) { + crash.minidump_contents = dump; + } + } + + crashes.extend( + minidumps_by_session_id + .into_values() + .map(|dmp| CrashReport { + panic_contents: None, + minidump_contents: Some(dmp), + }), + ); + + anyhow::Ok(proto::GetCrashFilesResponse { crashes }) }, ); } @@ -409,7 +439,12 @@ pub fn execute_run( ControlFlow::Continue(_) => {} } - init_panic_hook(); + let app = gpui::Application::headless(); + let id = std::process::id().to_string(); + app.background_executor() + .spawn(crashes::init(id.clone())) + .detach(); + init_panic_hook(id); let log_rx = init_logging_server(log_file)?; log::info!( "starting up. 
pid_file: {:?}, stdin_socket: {:?}, stdout_socket: {:?}, stderr_socket: {:?}", @@ -425,7 +460,7 @@ pub fn execute_run( let listeners = ServerListeners::new(stdin_socket, stdout_socket, stderr_socket)?; let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new()); - gpui::Application::headless().run(move |cx| { + app.run(move |cx| { settings::init(cx); let app_version = AppVersion::load(env!("ZED_PKG_VERSION")); release_channel::init(app_version, cx); @@ -486,7 +521,7 @@ pub fn execute_run( ) }); - handle_panic_requests(&project, &session); + handle_crash_files_requests(&project, &session); cx.background_spawn(async move { cleanup_old_binaries() }) .detach(); @@ -530,12 +565,15 @@ impl ServerPaths { pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> { init_logging_proxy(); - init_panic_hook(); - - log::info!("starting proxy process. PID: {}", std::process::id()); let server_paths = ServerPaths::new(&identifier)?; + let id = std::process::id().to_string(); + smol::spawn(crashes::init(id.clone())).detach(); + init_panic_hook(id); + + log::info!("starting proxy process. PID: {}", std::process::id()); + let server_pid = check_pid_file(&server_paths.pid_file)?; let server_running = server_pid.is_some(); if is_reconnecting { diff --git a/crates/repl/src/notebook/cell.rs b/crates/repl/src/notebook/cell.rs index 2ed68c17d1..18851417c0 100644 --- a/crates/repl/src/notebook/cell.rs +++ b/crates/repl/src/notebook/cell.rs @@ -38,7 +38,7 @@ pub enum CellControlType { impl CellControlType { fn icon_name(&self) -> IconName { match self { - CellControlType::RunCell => IconName::Play, + CellControlType::RunCell => IconName::PlayOutlined, CellControlType::RerunCell => IconName::ArrowCircle, CellControlType::ClearCell => IconName::ListX, CellControlType::CellOptions => IconName::Ellipsis, diff --git a/crates/repl/src/notebook/notebook_ui.rs b/crates/repl/src/notebook/notebook_ui.rs index d14f458fa9..2efa51e0cc 100644 --- a/crates/repl/src/notebook/notebook_ui.rs +++ b/crates/repl/src/notebook/notebook_ui.rs @@ -126,29 +126,7 @@ impl NotebookEditor { let cell_count = cell_order.len(); let this = cx.entity(); - let cell_list = ListState::new( - cell_count, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - notebook_handle - .upgrade() - .and_then(|notebook_handle| { - notebook_handle.update(cx, |notebook, cx| { - notebook - .cell_order - .get(ix) - .and_then(|cell_id| notebook.cell_map.get(cell_id)) - .map(|cell| { - notebook - .render_cell(ix, cell, window, cx) - .into_any_element() - }) - }) - }) - .unwrap_or_else(|| div().into_any()) - }, - ); + let cell_list = ListState::new(cell_count, gpui::ListAlignment::Top, px(1000.)); Self { project, @@ -343,7 +321,7 @@ impl NotebookEditor { .child( Self::render_notebook_control( "run-all-cells", - IconName::Play, + IconName::PlayOutlined, window, cx, ) @@ -544,7 +522,19 @@ impl Render for NotebookEditor { .flex_1() .size_full() .overflow_y_scroll() - .child(list(self.cell_list.clone()).size_full()), + .child(list( + self.cell_list.clone(), + cx.processor(|this, ix, window, cx| { + this.cell_order + .get(ix) + .and_then(|cell_id| this.cell_map.get(cell_id)) + .map(|cell| { + this.render_cell(ix, cell, window, cx).into_any_element() + }) + .unwrap_or_else(|| div().into_any()) + }), + )) + .size_full(), ) .child(self.render_notebook_controls(window, cx)) } diff --git a/crates/reqwest_client/src/reqwest_client.rs b/crates/reqwest_client/src/reqwest_client.rs index daff20ac4a..6461a0ae17 100644 --- 
a/crates/reqwest_client/src/reqwest_client.rs +++ b/crates/reqwest_client/src/reqwest_client.rs @@ -4,14 +4,13 @@ use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration use anyhow::anyhow; use bytes::{BufMut, Bytes, BytesMut}; -use futures::{AsyncRead, TryStreamExt as _}; +use futures::{AsyncRead, FutureExt as _, TryStreamExt as _}; use http_client::{RedirectPolicy, Url, http}; use regex::Regex; use reqwest::{ header::{HeaderMap, HeaderValue}, redirect, }; -use smol::future::FutureExt; const DEFAULT_CAPACITY: usize = 4096; static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new(); @@ -20,6 +19,7 @@ static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+") pub struct ReqwestClient { client: reqwest::Client, proxy: Option<Url>, + user_agent: Option<HeaderValue>, handle: tokio::runtime::Handle, } @@ -44,9 +44,11 @@ impl ReqwestClient { Ok(client.into()) } - pub fn proxy_and_user_agent(proxy: Option<Url>, agent: &str) -> anyhow::Result<Self> { + pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> { + let user_agent = HeaderValue::from_str(user_agent)?; + let mut map = HeaderMap::new(); - map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?); + map.insert(http::header::USER_AGENT, user_agent.clone()); let mut client = Self::builder().default_headers(map); let client_has_proxy; @@ -73,6 +75,7 @@ impl ReqwestClient { .build()?; let mut client: ReqwestClient = client.into(); client.proxy = client_has_proxy.then_some(proxy).flatten(); + client.user_agent = Some(user_agent); Ok(client) } } @@ -96,6 +99,7 @@ impl From<reqwest::Client> for ReqwestClient { client, handle, proxy: None, + user_agent: None, } } } @@ -216,6 +220,10 @@ impl http_client::HttpClient for ReqwestClient { type_name::<Self>() } + fn user_agent(&self) -> Option<&HeaderValue> { + self.user_agent.as_ref() + } + fn send( &self, req: http::Request<http_client::AsyncBody>, @@ -265,6 +273,26 @@ impl http_client::HttpClient for ReqwestClient { } .boxed() } + + fn send_multipart_form<'a>( + &'a self, + url: &str, + form: reqwest::multipart::Form, + ) -> futures::future::BoxFuture<'a, anyhow::Result<http_client::Response<http_client::AsyncBody>>> + { + let response = self.client.post(url).multipart(form).send(); + self.handle + .spawn(async move { + let response = response.await?; + let mut builder = http::response::Builder::new().status(response.status()); + for (k, v) in response.headers() { + builder = builder.header(k, v) + } + Ok(builder.body(response.bytes().await?.into())?) + }) + .map(|e| e?) 
+ .boxed() + } } #[cfg(test)] diff --git a/crates/rope/src/rope.rs b/crates/rope/src/rope.rs index 515cd71331..aa3ed5db57 100644 --- a/crates/rope/src/rope.rs +++ b/crates/rope/src/rope.rs @@ -12,7 +12,7 @@ use std::{ ops::{self, AddAssign, Range}, str, }; -use sum_tree::{Bias, Dimension, SumTree}; +use sum_tree::{Bias, Dimension, Dimensions, SumTree}; pub use chunk::ChunkSlice; pub use offset_utf16::OffsetUtf16; @@ -282,7 +282,7 @@ impl Rope { if offset >= self.summary().len { return self.summary().len_utf16; } - let mut cursor = self.chunks.cursor::<(usize, OffsetUtf16)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, OffsetUtf16>>(&()); cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 @@ -295,7 +295,7 @@ impl Rope { if offset >= self.summary().len_utf16 { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(OffsetUtf16, usize)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<OffsetUtf16, usize>>(&()); cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 @@ -308,7 +308,7 @@ impl Rope { if offset >= self.summary().len { return self.summary().lines; } - let mut cursor = self.chunks.cursor::<(usize, Point)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, Point>>(&()); cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 @@ -321,7 +321,7 @@ impl Rope { if offset >= self.summary().len { return self.summary().lines_utf16(); } - let mut cursor = self.chunks.cursor::<(usize, PointUtf16)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<usize, PointUtf16>>(&()); cursor.seek(&offset, Bias::Left); let overshoot = offset - cursor.start().0; cursor.start().1 @@ -334,7 +334,7 @@ impl Rope { if point >= self.summary().lines { return self.summary().lines_utf16(); } - let mut cursor = self.chunks.cursor::<(Point, PointUtf16)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<Point, PointUtf16>>(&()); cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 @@ -347,7 +347,7 @@ impl Rope { if point >= self.summary().lines { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(Point, usize)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<Point, usize>>(&()); cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 @@ -368,7 +368,7 @@ impl Rope { if point >= self.summary().lines_utf16() { return self.summary().len; } - let mut cursor = self.chunks.cursor::<(PointUtf16, usize)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<PointUtf16, usize>>(&()); cursor.seek(&point, Bias::Left); let overshoot = point - cursor.start().0; cursor.start().1 @@ -381,7 +381,7 @@ impl Rope { if point.0 >= self.summary().lines_utf16() { return self.summary().lines; } - let mut cursor = self.chunks.cursor::<(PointUtf16, Point)>(&()); + let mut cursor = self.chunks.cursor::<Dimensions<PointUtf16, Point>>(&()); cursor.seek(&point.0, Bias::Left); let overshoot = Unclipped(point.0 - cursor.start().0); cursor.start().1 @@ -1168,16 +1168,17 @@ pub trait TextDimension: fn add_assign(&mut self, other: &Self); } -impl<D1: TextDimension, D2: TextDimension> TextDimension for (D1, D2) { +impl<D1: TextDimension, D2: TextDimension> TextDimension for Dimensions<D1, D2, ()> { fn from_text_summary(summary: &TextSummary) -> Self { - ( + Dimensions( D1::from_text_summary(summary), D2::from_text_summary(summary), + (), ) } fn from_chunk(chunk: 
ChunkSlice) -> Self { - (D1::from_chunk(chunk), D2::from_chunk(chunk)) + Dimensions(D1::from_chunk(chunk), D2::from_chunk(chunk), ()) } fn add_assign(&mut self, other: &Self) { diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 80a104641f..8ddebfb269 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -422,8 +422,23 @@ impl Peer { receiver_id: ConnectionId, request: T, ) -> impl Future<Output = Result<T::Response>> { + let request_start_time = Instant::now(); + let elapsed_time = move || request_start_time.elapsed().as_millis(); + tracing::info!("start forwarding request"); self.request_internal(Some(sender_id), receiver_id, request) .map_ok(|envelope| envelope.payload) + .inspect_err(move |_| { + tracing::error!( + waiting_for_host_ms = elapsed_time(), + "error forwarding request" + ) + }) + .inspect_ok(move |_| { + tracing::info!( + waiting_for_host_ms = elapsed_time(), + "finished forwarding request" + ) + }) } fn request_internal<T: RequestMessage>( diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index be6a69c23b..ebec96dd7b 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -319,7 +319,7 @@ impl PickerDelegate for RulePickerDelegate { }) .into_any() } else { - IconButton::new("delete-rule", IconName::TrashAlt) + IconButton::new("delete-rule", IconName::Trash) .icon_color(Color::Muted) .icon_size(IconSize::Small) .shape(IconButtonShape::Square) @@ -1101,7 +1101,7 @@ impl RulesLibrary { inlay_hints_style: editor::make_inlay_hints_style( cx, ), - inline_completion_styles: + edit_prediction_styles: editor::make_suggestion_styles(cx), ..EditorStyle::default() }, @@ -1163,7 +1163,7 @@ impl RulesLibrary { }) .into_any() } else { - IconButton::new("delete-rule", IconName::TrashAlt) + IconButton::new("delete-rule", IconName::Trash) .icon_size(IconSize::Small) .tooltip(move |window, cx| { Tooltip::for_action( diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index c2590ec9b0..5d77a95027 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -228,16 +228,17 @@ impl Render for BufferSearchBar { if in_replace { key_context.add("in_replace"); } - let editor_border = if self.query_error.is_some() { + let query_border = if self.query_error.is_some() { Color::Error.color(cx) } else { cx.theme().colors().border }; + let replacement_border = cx.theme().colors().border; let container_width = window.viewport_size().width; let input_width = SearchInputWidth::calc_width(container_width); - let input_base_styles = || { + let input_base_styles = |border_color| { h_flex() .min_w_32() .w(input_width) @@ -246,7 +247,7 @@ impl Render for BufferSearchBar { .pr_1() .py_1() .border_1() - .border_color(editor_border) + .border_color(border_color) .rounded_lg() }; @@ -256,7 +257,7 @@ impl Render for BufferSearchBar { el.child(Label::new("Find in results").color(Color::Hint)) }) .child( - input_base_styles() + input_base_styles(query_border) .id("editor-scroll") .track_scroll(&self.editor_scroll_handle) .child(self.render_text_input(&self.query_editor, color_override, cx)) @@ -430,11 +431,13 @@ impl Render for BufferSearchBar { let replace_line = should_show_replace_input.then(|| { h_flex() .gap_2() - .child(input_base_styles().child(self.render_text_input( - &self.replacement_editor, - None, - cx, - ))) + .child( + input_base_styles(replacement_border).child(self.render_text_input( + &self.replacement_editor, + 
None, + cx, + )), + ) .child( h_flex() .min_w_64() @@ -700,7 +703,11 @@ impl BufferSearchBar { window: &mut Window, cx: &mut Context<Self>, ) -> Self { - let query_editor = cx.new(|cx| Editor::single_line(window, cx)); + let query_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_use_autoclose(false); + editor + }); cx.subscribe_in(&query_editor, window, Self::on_query_editor_event) .detach(); let replacement_editor = cx.new(|cx| Editor::single_line(window, cx)); @@ -771,6 +778,7 @@ impl BufferSearchBar { pub fn dismiss(&mut self, _: &Dismiss, window: &mut Window, cx: &mut Context<Self>) { self.dismissed = true; + self.query_error = None; for searchable_item in self.searchable_items_with_matches.keys() { if let Some(searchable_item) = WeakSearchableItemHandle::upgrade(searchable_item.as_ref(), cx) diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 57ca5e56b9..15c1099aec 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -195,6 +195,7 @@ pub struct ProjectSearch { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum InputPanel { Query, + Replacement, Exclude, Include, } @@ -354,8 +355,9 @@ impl ProjectSearch { while let Some(new_ranges) = new_ranges.next().await { project_search - .update(cx, |project_search, _| { + .update(cx, |project_search, cx| { project_search.match_ranges.extend(new_ranges); + cx.notify(); }) .ok()?; } @@ -1962,7 +1964,7 @@ impl Render for ProjectSearchBar { MultipleInputs, } - let input_base_styles = |base_style: BaseStyle| { + let input_base_styles = |base_style: BaseStyle, panel: InputPanel| { h_flex() .min_w_32() .map(|div| match base_style { @@ -1974,11 +1976,11 @@ impl Render for ProjectSearchBar { .pr_1() .py_1() .border_1() - .border_color(search.border_color_for(InputPanel::Query, cx)) + .border_color(search.border_color_for(panel, cx)) .rounded_lg() }; - let query_column = input_base_styles(BaseStyle::SingleInput) + let query_column = input_base_styles(BaseStyle::SingleInput, InputPanel::Query) .on_action(cx.listener(|this, action, window, cx| this.confirm(action, window, cx))) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) @@ -2167,7 +2169,7 @@ impl Render for ProjectSearchBar { .child(h_flex().min_w_64().child(mode_column).child(matches_column)); let replace_line = search.replace_enabled.then(|| { - let replace_column = input_base_styles(BaseStyle::SingleInput) + let replace_column = input_base_styles(BaseStyle::SingleInput, InputPanel::Replacement) .child(self.render_text_input(&search.replacement_editor, cx)); let focus_handle = search.replacement_editor.read(cx).focus_handle(cx); @@ -2241,7 +2243,7 @@ impl Render for ProjectSearchBar { .gap_2() .w(input_width) .child( - input_base_styles(BaseStyle::MultipleInputs) + input_base_styles(BaseStyle::MultipleInputs, InputPanel::Include) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) })) @@ -2251,7 +2253,7 @@ impl Render for ProjectSearchBar { .child(self.render_text_input(&search.included_files_editor, cx)), ) .child( - input_base_styles(BaseStyle::MultipleInputs) + input_base_styles(BaseStyle::MultipleInputs, InputPanel::Exclude) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) })) diff --git a/crates/semantic_index/src/project_index_debug_view.rs b/crates/semantic_index/src/project_index_debug_view.rs index 1b0d87fca0..8d6a49c45c 
100644 --- a/crates/semantic_index/src/project_index_debug_view.rs +++ b/crates/semantic_index/src/project_index_debug_view.rs @@ -115,21 +115,9 @@ impl ProjectIndexDebugView { .collect::<Vec<_>>(); this.update(cx, |this, cx| { - let view = cx.entity().downgrade(); this.selected_path = Some(PathState { path: file_path, - list_state: ListState::new( - chunks.len(), - gpui::ListAlignment::Top, - px(100.), - move |ix, _, cx| { - if let Some(view) = view.upgrade() { - view.update(cx, |view, cx| view.render_chunk(ix, cx)) - } else { - div().into_any() - } - }, - ), + list_state: ListState::new(chunks.len(), gpui::ListAlignment::Top, px(100.)), chunks, }); cx.notify(); @@ -219,7 +207,13 @@ impl Render for ProjectIndexDebugView { cx.notify(); })), ) - .child(list(selected_path.list_state.clone()).size_full()) + .child( + list( + selected_path.list_state.clone(), + cx.processor(|this, ix, _, cx| this.render_chunk(ix, cx)), + ) + .size_full(), + ) .size_full() .into_any_element() } else { diff --git a/crates/settings/src/settings.rs b/crates/settings/src/settings.rs index 4e6bd94d92..afd4ea0890 100644 --- a/crates/settings/src/settings.rs +++ b/crates/settings/src/settings.rs @@ -7,7 +7,7 @@ mod settings_json; mod settings_store; mod vscode_import; -use gpui::App; +use gpui::{App, Global}; use rust_embed::RustEmbed; use std::{borrow::Cow, fmt, str}; use util::asset_str; @@ -27,6 +27,11 @@ pub use settings_store::{ }; pub use vscode_import::{VsCodeSettings, VsCodeSettingsSource}; +#[derive(Clone, Debug, PartialEq)] +pub struct ActiveSettingsProfileName(pub String); + +impl Global for ActiveSettingsProfileName {} + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)] pub struct WorktreeId(usize); @@ -74,6 +79,7 @@ pub fn init(cx: &mut App) { .unwrap(); cx.set_global(settings); BaseKeymap::register(cx); + SettingsStore::observe_active_settings_profile_name(cx).detach(); } pub fn default_settings() -> Cow<'static, str> { diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 0d23385a68..bc42d2c886 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -2,7 +2,11 @@ use anyhow::{Context as _, Result}; use collections::{BTreeMap, HashMap, btree_map, hash_map}; use ec4rs::{ConfigParser, PropertiesSource, Section}; use fs::Fs; -use futures::{FutureExt, StreamExt, channel::mpsc, future::LocalBoxFuture}; +use futures::{ + FutureExt, StreamExt, + channel::{mpsc, oneshot}, + future::LocalBoxFuture, +}; use gpui::{App, AsyncApp, BorrowAppContext, Global, Task, UpdateGlobal}; use paths::{EDITORCONFIG_NAME, local_settings_file_relative_path, task_file_name}; @@ -26,8 +30,8 @@ use util::{ pub type EditorconfigProperties = ec4rs::Properties; use crate::{ - ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, WorktreeId, - parse_json_with_comments, update_value_in_json_text, + ActiveSettingsProfileName, ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, + WorktreeId, parse_json_with_comments, update_value_in_json_text, }; /// A value that can be defined as a user setting. @@ -122,6 +126,8 @@ pub struct SettingsSources<'a, T> { pub user: Option<&'a T>, /// The user settings for the current release channel. pub release_channel: Option<&'a T>, + /// The settings associated with an enabled settings profile + pub profile: Option<&'a T>, /// The server's settings. pub server: Option<&'a T>, /// The project settings, ordered from least specific to most specific. 
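A minimal sketch (not part of this patch) of the layering that the new `profile` field implies: once `SettingsSources::defaults()` chains `self.profile` after `release_channel` and before `server`, an active settings profile overrides plain user settings but still yields to server and project settings. The helper names `overlay` and `merge_layers` are hypothetical, and `serde_json` is assumed as a dependency.

use serde_json::{Map, Value};

/// Recursively overlay `patch` onto `base`; objects merge key by key, scalar values in `patch` win.
fn overlay(base: &mut Value, patch: &Value) {
    match (base, patch) {
        (Value::Object(base_map), Value::Object(patch_map)) => {
            for (key, patch_value) in patch_map {
                overlay(base_map.entry(key.clone()).or_insert(Value::Null), patch_value);
            }
        }
        (base_slot, patch_value) => *base_slot = patch_value.clone(),
    }
}

/// Merge layers from least to most specific, skipping layers that are not set.
fn merge_layers(layers: &[Option<&Value>]) -> Value {
    let mut merged = Value::Object(Map::new());
    for layer in layers.iter().flatten() {
        overlay(&mut merged, layer);
    }
    merged
}

fn main() {
    let user = serde_json::json!({ "buffer_font_size": 10.0 });
    let profile = serde_json::json!({ "buffer_font_size": 20.0 });
    // Precedence sketched here: user < release_channel < profile < server < project.
    let merged = merge_layers(&[Some(&user), None, Some(&profile), None, None]);
    assert_eq!(merged["buffer_font_size"], 20.0);
}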
@@ -141,6 +147,7 @@ impl<'a, T: Serialize> SettingsSources<'a, T> { .chain(self.extensions) .chain(self.user) .chain(self.release_channel) + .chain(self.profile) .chain(self.server) .chain(self.project.iter().copied()) } @@ -282,6 +289,14 @@ impl SettingsStore { } } + pub fn observe_active_settings_profile_name(cx: &mut App) -> gpui::Subscription { + cx.observe_global::<ActiveSettingsProfileName>(|cx| { + Self::update_global(cx, |store, cx| { + store.recompute_values(None, cx).log_err(); + }); + }) + } + pub fn update<C, R>(cx: &mut C, f: impl FnOnce(&mut Self, &mut C) -> R) -> R where C: BorrowAppContext, @@ -321,6 +336,17 @@ impl SettingsStore { .log_err(); } + let mut profile_value = None; + if let Some(active_profile) = cx.try_global::<ActiveSettingsProfileName>() { + if let Some(profiles) = self.raw_user_settings.get("profiles") { + if let Some(profile_settings) = profiles.get(&active_profile.0) { + profile_value = setting_value + .deserialize_setting(profile_settings) + .log_err(); + } + } + } + let server_value = self .raw_server_settings .as_ref() @@ -340,6 +366,7 @@ impl SettingsStore { extensions: extension_value.as_ref(), user: user_value.as_ref(), release_channel: release_channel_value.as_ref(), + profile: profile_value.as_ref(), server: server_value.as_ref(), project: &[], }, @@ -402,6 +429,16 @@ impl SettingsStore { &self.raw_user_settings } + /// Get the configured settings profile names. + pub fn configured_settings_profiles(&self) -> impl Iterator<Item = &str> { + self.raw_user_settings + .get("profiles") + .and_then(|v| v.as_object()) + .into_iter() + .flat_map(|obj| obj.keys()) + .map(|s| s.as_str()) + } + /// Access the raw JSON value of the global settings. pub fn raw_global_settings(&self) -> Option<&Value> { self.raw_global_settings.as_ref() @@ -498,41 +535,64 @@ impl SettingsStore { .ok(); } - pub fn import_vscode_settings(&self, fs: Arc<dyn Fs>, vscode_settings: VsCodeSettings) { + pub fn import_vscode_settings( + &self, + fs: Arc<dyn Fs>, + vscode_settings: VsCodeSettings, + ) -> oneshot::Receiver<Result<()>> { + let (tx, rx) = oneshot::channel::<Result<()>>(); self.setting_file_updates_tx .unbounded_send(Box::new(move |cx: AsyncApp| { async move { - let old_text = Self::load_settings(&fs).await?; - let new_text = cx.read_global(|store: &SettingsStore, _cx| { - store.get_vscode_edits(old_text, &vscode_settings) - })?; - let settings_path = paths::settings_file().as_path(); - if fs.is_file(settings_path).await { - let resolved_path = - fs.canonicalize(settings_path).await.with_context(|| { - format!("Failed to canonicalize settings path {:?}", settings_path) - })?; + let res = async move { + let old_text = Self::load_settings(&fs).await?; + let new_text = cx.read_global(|store: &SettingsStore, _cx| { + store.get_vscode_edits(old_text, &vscode_settings) + })?; + let settings_path = paths::settings_file().as_path(); + if fs.is_file(settings_path).await { + let resolved_path = + fs.canonicalize(settings_path).await.with_context(|| { + format!( + "Failed to canonicalize settings path {:?}", + settings_path + ) + })?; - fs.atomic_write(resolved_path.clone(), new_text) - .await - .with_context(|| { - format!("Failed to write settings to file {:?}", resolved_path) - })?; - } else { - fs.atomic_write(settings_path.to_path_buf(), new_text) - .await - .with_context(|| { - format!("Failed to write settings to file {:?}", settings_path) - })?; + fs.atomic_write(resolved_path.clone(), new_text) + .await + .with_context(|| { + format!("Failed to write settings to file {:?}", 
resolved_path) + })?; + } else { + fs.atomic_write(settings_path.to_path_buf(), new_text) + .await + .with_context(|| { + format!("Failed to write settings to file {:?}", settings_path) + })?; + } + + anyhow::Ok(()) } + .await; - anyhow::Ok(()) + let new_res = match &res { + Ok(_) => anyhow::Ok(()), + Err(e) => Err(anyhow::anyhow!("Failed to write settings to file {:?}", e)), + }; + + _ = tx.send(new_res); + res } .boxed_local() })) .ok(); - } + rx + } +} + +impl SettingsStore { /// Updates the value of a setting in a JSON file, returning the new text /// for that JSON file. pub fn new_text_for_update<T: Settings>( @@ -1001,18 +1061,18 @@ impl SettingsStore { const ZED_SETTINGS: &str = "ZedSettings"; let zed_settings_ref = add_new_subschema(&mut generator, ZED_SETTINGS, combined_schema); - // add `ZedReleaseStageSettings` which is the same as `ZedSettings` except that unknown - // fields are rejected. - let mut zed_release_stage_settings = zed_settings_ref.clone(); - zed_release_stage_settings.insert("unevaluatedProperties".to_string(), false.into()); - let zed_release_stage_settings_ref = add_new_subschema( + // add `ZedSettingsOverride` which is the same as `ZedSettings` except that unknown + // fields are rejected. This is used for release stage settings and profiles. + let mut zed_settings_override = zed_settings_ref.clone(); + zed_settings_override.insert("unevaluatedProperties".to_string(), false.into()); + let zed_settings_override_ref = add_new_subschema( &mut generator, - "ZedReleaseStageSettings", - zed_release_stage_settings.to_value(), + "ZedSettingsOverride", + zed_settings_override.to_value(), ); // Remove `"additionalProperties": false` added by `DefaultDenyUnknownFields` so that - // unknown fields can be handled by the root schema and `ZedReleaseStageSettings`. + // unknown fields can be handled by the root schema and `ZedSettingsOverride`. let mut definitions = generator.take_definitions(true); definitions .get_mut(ZED_SETTINGS) @@ -1032,15 +1092,20 @@ impl SettingsStore { "$schema": meta_schema, "title": "Zed Settings", "unevaluatedProperties": false, - // ZedSettings + settings overrides for each release stage + // ZedSettings + settings overrides for each release stage / profiles "allOf": [ zed_settings_ref, { "properties": { - "dev": zed_release_stage_settings_ref, - "nightly": zed_release_stage_settings_ref, - "stable": zed_release_stage_settings_ref, - "preview": zed_release_stage_settings_ref, + "dev": zed_settings_override_ref, + "nightly": zed_settings_override_ref, + "stable": zed_settings_override_ref, + "preview": zed_settings_override_ref, + "profiles": { + "type": "object", + "description": "Configures any number of settings profiles.", + "additionalProperties": zed_settings_override_ref + } } } ], @@ -1099,6 +1164,16 @@ impl SettingsStore { } } + let mut profile_settings = None; + if let Some(active_profile) = cx.try_global::<ActiveSettingsProfileName>() { + if let Some(profiles) = self.raw_user_settings.get("profiles") { + if let Some(profile_json) = profiles.get(&active_profile.0) { + profile_settings = + setting_value.deserialize_setting(profile_json).log_err(); + } + } + } + // If the global settings file changed, reload the global value for the field. 
if changed_local_path.is_none() { if let Some(value) = setting_value @@ -1109,6 +1184,7 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &[], }, @@ -1161,6 +1237,7 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &project_settings_stack.iter().collect::<Vec<_>>(), }, @@ -1286,6 +1363,9 @@ impl<T: Settings> AnySettingValue for SettingValue<T> { release_channel: values .release_channel .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), + profile: values + .profile + .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), server: values .server .map(|value| value.0.downcast_ref::<T::FileContent>().unwrap()), diff --git a/crates/settings_profile_selector/Cargo.toml b/crates/settings_profile_selector/Cargo.toml new file mode 100644 index 0000000000..189272e54b --- /dev/null +++ b/crates/settings_profile_selector/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "settings_profile_selector" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/settings_profile_selector.rs" +doctest = false + +[dependencies] +fuzzy.workspace = true +gpui.workspace = true +picker.workspace = true +settings.workspace = true +ui.workspace = true +workspace-hack.workspace = true +workspace.workspace = true +zed_actions.workspace = true + +[dev-dependencies] +client = { workspace = true, features = ["test-support"] } +editor = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +menu.workspace = true +project = { workspace = true, features = ["test-support"] } +serde_json.workspace = true +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_profile_selector/LICENSE-GPL b/crates/settings_profile_selector/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/settings_profile_selector/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/settings_profile_selector/src/settings_profile_selector.rs b/crates/settings_profile_selector/src/settings_profile_selector.rs new file mode 100644 index 0000000000..8a34c12051 --- /dev/null +++ b/crates/settings_profile_selector/src/settings_profile_selector.rs @@ -0,0 +1,581 @@ +use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; +use gpui::{ + App, Context, DismissEvent, Entity, EventEmitter, Focusable, Render, Task, WeakEntity, Window, +}; +use picker::{Picker, PickerDelegate}; +use settings::{ActiveSettingsProfileName, SettingsStore}; +use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*}; +use workspace::{ModalView, Workspace}; + +pub fn init(cx: &mut App) { + cx.on_action(|_: &zed_actions::settings_profile_selector::Toggle, cx| { + workspace::with_active_or_new_workspace(cx, |workspace, window, cx| { + toggle_settings_profile_selector(workspace, window, cx); + }); + }); +} + +fn toggle_settings_profile_selector( + workspace: &mut Workspace, + 
window: &mut Window, + cx: &mut Context<Workspace>, +) { + workspace.toggle_modal(window, cx, |window, cx| { + let delegate = SettingsProfileSelectorDelegate::new(cx.entity().downgrade(), window, cx); + SettingsProfileSelector::new(delegate, window, cx) + }); +} + +pub struct SettingsProfileSelector { + picker: Entity<Picker<SettingsProfileSelectorDelegate>>, +} + +impl ModalView for SettingsProfileSelector {} + +impl EventEmitter<DismissEvent> for SettingsProfileSelector {} + +impl Focusable for SettingsProfileSelector { + fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { + self.picker.focus_handle(cx) + } +} + +impl Render for SettingsProfileSelector { + fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement { + v_flex().w(rems(22.)).child(self.picker.clone()) + } +} + +impl SettingsProfileSelector { + pub fn new( + delegate: SettingsProfileSelectorDelegate, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); + Self { picker } + } +} + +pub struct SettingsProfileSelectorDelegate { + matches: Vec<StringMatch>, + profile_names: Vec<Option<String>>, + original_profile_name: Option<String>, + selected_profile_name: Option<String>, + selected_index: usize, + selection_completed: bool, + selector: WeakEntity<SettingsProfileSelector>, +} + +impl SettingsProfileSelectorDelegate { + fn new( + selector: WeakEntity<SettingsProfileSelector>, + _: &mut Window, + cx: &mut Context<SettingsProfileSelector>, + ) -> Self { + let settings_store = cx.global::<SettingsStore>(); + let mut profile_names: Vec<Option<String>> = settings_store + .configured_settings_profiles() + .map(|s| Some(s.to_string())) + .collect(); + profile_names.insert(0, None); + + let matches = profile_names + .iter() + .enumerate() + .map(|(ix, profile_name)| StringMatch { + candidate_id: ix, + score: 0.0, + positions: Default::default(), + string: display_name(profile_name), + }) + .collect(); + + let profile_name = cx + .try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()); + + let mut this = Self { + matches, + profile_names, + original_profile_name: profile_name.clone(), + selected_profile_name: None, + selected_index: 0, + selection_completed: false, + selector, + }; + + if let Some(profile_name) = profile_name { + this.select_if_matching(&profile_name); + } + + this + } + + fn select_if_matching(&mut self, profile_name: &str) { + self.selected_index = self + .matches + .iter() + .position(|mat| mat.string == profile_name) + .unwrap_or(self.selected_index); + } + + fn set_selected_profile( + &self, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Option<String> { + let mat = self.matches.get(self.selected_index)?; + let profile_name = self.profile_names.get(mat.candidate_id)?; + return Self::update_active_profile_name_global(profile_name.clone(), cx); + } + + fn update_active_profile_name_global( + profile_name: Option<String>, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Option<String> { + if let Some(profile_name) = profile_name { + cx.set_global(ActiveSettingsProfileName(profile_name.clone())); + return Some(profile_name.clone()); + } + + if cx.has_global::<ActiveSettingsProfileName>() { + cx.remove_global::<ActiveSettingsProfileName>(); + } + + None + } +} + +impl PickerDelegate for SettingsProfileSelectorDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _: &mut Window, _: &mut App) -> std::sync::Arc<str> { + "Select a settings 
profile...".into() + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + self.selected_index = ix; + self.selected_profile_name = self.set_selected_profile(cx); + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) -> Task<()> { + let background = cx.background_executor().clone(); + let candidates = self + .profile_names + .iter() + .enumerate() + .map(|(id, profile_name)| StringMatchCandidate::new(id, &display_name(profile_name))) + .collect::<Vec<_>>(); + + cx.spawn_in(window, async move |this, cx| { + let matches = if query.is_empty() { + candidates + .into_iter() + .enumerate() + .map(|(index, candidate)| StringMatch { + candidate_id: index, + string: candidate.string, + positions: Vec::new(), + score: 0.0, + }) + .collect() + } else { + match_strings( + &candidates, + &query, + false, + true, + 100, + &Default::default(), + background, + ) + .await + }; + + this.update_in(cx, |this, _, cx| { + this.delegate.matches = matches; + this.delegate.selected_index = this + .delegate + .selected_index + .min(this.delegate.matches.len().saturating_sub(1)); + this.delegate.selected_profile_name = this.delegate.set_selected_profile(cx); + }) + .ok(); + }) + } + + fn confirm( + &mut self, + _: bool, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + self.selection_completed = true; + self.selector + .update(cx, |_, cx| { + cx.emit(DismissEvent); + }) + .ok(); + } + + fn dismissed( + &mut self, + _: &mut Window, + cx: &mut Context<Picker<SettingsProfileSelectorDelegate>>, + ) { + if !self.selection_completed { + SettingsProfileSelectorDelegate::update_active_profile_name_global( + self.original_profile_name.clone(), + cx, + ); + } + self.selector.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + _: &mut Context<Picker<Self>>, + ) -> Option<Self::ListItem> { + let mat = &self.matches[ix]; + let profile_name = &self.profile_names[mat.candidate_id]; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .child(HighlightedLabel::new( + display_name(profile_name), + mat.positions.clone(), + )), + ) + } +} + +fn display_name(profile_name: &Option<String>) -> String { + profile_name.clone().unwrap_or("Disabled".into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use client; + use editor; + use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; + use language; + use menu::{Cancel, Confirm, SelectNext, SelectPrevious}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::Settings; + use theme::{self, ThemeSettings}; + use workspace::{self, AppState}; + use zed_actions::settings_profile_selector; + + async fn init_test( + profiles_json: serde_json::Value, + cx: &mut TestAppContext, + ) -> (Entity<Workspace>, &mut VisualTestContext) { + cx.update(|cx| { + let state = AppState::test(cx); + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + settings::init(cx); + theme::init(theme::LoadThemes::JustBase, cx); + ThemeSettings::register(cx); + client::init_settings(cx); + language::init(cx); + super::init(cx); + editor::init(cx); + workspace::init_settings(cx); + 
Project::init_settings(cx); + state + }); + + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + let settings_json = json!({ + "buffer_font_size": 10.0, + "profiles": profiles_json, + }); + + store + .set_user_settings(&settings_json.to_string(), cx) + .unwrap(); + }); + }); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, ["/test".as_ref()], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + cx.update(|_, cx| { + assert!(!cx.has_global::<ActiveSettingsProfileName>()); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + (workspace, cx) + } + + #[track_caller] + fn active_settings_profile_picker( + workspace: &Entity<Workspace>, + cx: &mut VisualTestContext, + ) -> Entity<Picker<SettingsProfileSelectorDelegate>> { + workspace.update(cx, |workspace, cx| { + workspace + .active_modal::<SettingsProfileSelector>(cx) + .expect("settings profile selector is not open") + .read(cx) + .picker + .clone() + }) + } + + #[gpui::test] + async fn test_settings_profile_selector_state(cx: &mut TestAppContext) { + let classroom_and_streaming_profile_name = "Classroom / Streaming".to_string(); + let demo_videos_profile_name = "Demo Videos".to_string(); + + let profiles_json = json!({ + classroom_and_streaming_profile_name.clone(): { + "buffer_font_size": 20.0, + }, + demo_videos_profile_name.clone(): { + "buffer_font_size": 15.0 + } + }); + let (workspace, cx) = init_test(profiles_json.clone(), cx).await; + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.matches.len(), 3); + assert_eq!(picker.delegate.matches[0].string, display_name(&None)); + assert_eq!( + picker.delegate.matches[1].string, + classroom_and_streaming_profile_name + ); + assert_eq!(picker.delegate.matches[2].string, demo_videos_profile_name); + assert_eq!(picker.delegate.matches.get(3), None); + + assert_eq!(picker.delegate.selected_index, 0); + assert_eq!(picker.delegate.selected_profile_name, None); + + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(Cancel); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + 
Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(Cancel); + + cx.update(|_, cx| { + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 0); + assert_eq!(picker.delegate.selected_profile_name, None); + + assert_eq!( + cx.try_global::<ActiveSettingsProfileName>() + .map(|p| p.0.clone()), + None + ); + + 
assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::<ActiveSettingsProfileName>(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + } +} diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index 02327045fd..a4c47081c6 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -30,7 +30,6 @@ menu.workspace = true notifications.workspace = true paths.workspace = true project.workspace = true -schemars.workspace = true search.workspace = true serde.workspace = true serde_json.workspace = true @@ -45,3 +44,10 @@ ui_input.workspace = true util.workspace = true workspace-hack.workspace = true workspace.workspace = true + +[dev-dependencies] +db = {"workspace"= true, "features" = ["test-support"]} +fs = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_ui/src/keybindings.rs b/crates/settings_ui/src/keybindings.rs index a0cbdb9680..60e527677a 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/settings_ui/src/keybindings.rs @@ -11,11 +11,10 @@ use editor::{CompletionProvider, Editor, EditorEvent}; use fs::Fs; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ - Action, Animation, AnimationExt, AppContext as _, AsyncApp, Axis, ClickEvent, Context, - DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, Global, IsZero, - KeyContext, Keystroke, Modifiers, ModifiersChangedEvent, MouseButton, Point, ScrollStrategy, - ScrollWheelEvent, Stateful, StyledText, Subscription, Task, TextStyleRefinement, WeakEntity, - actions, anchored, deferred, div, + Action, AppContext as _, AsyncApp, Axis, ClickEvent, Context, DismissEvent, Entity, + EventEmitter, FocusHandle, Focusable, Global, IsZero, KeyContext, Keystroke, MouseButton, + Point, ScrollStrategy, ScrollWheelEvent, Stateful, StyledText, Subscription, Task, + TextStyleRefinement, WeakEntity, actions, anchored, deferred, div, }; use language::{Language, LanguageConfig, ToOffset as _}; use notifications::status_toast::{StatusToast, ToastIcon}; @@ -35,7 +34,10 @@ use workspace::{ use crate::{ keybindings::persistence::KEYBINDING_EDITORS, - ui_components::table::{ColumnWidths, ResizeBehavior, Table, TableInteractionState}, + ui_components::{ + keystroke_input::{ClearKeystrokes, KeystrokeInput, StartRecording, StopRecording}, + table::{ColumnWidths, ResizeBehavior, Table, TableInteractionState}, + }, }; const NO_ACTION_ARGUMENTS_TEXT: SharedString = SharedString::new_static("<no arguments>"); @@ -72,18 +74,6 @@ actions!( ] ); -actions!( - keystroke_input, - [ - /// Starts recording keystrokes - StartRecording, - /// Stops recording keystrokes - StopRecording, - /// Clears the recorded keystrokes - ClearKeystrokes, - ] -); - pub fn init(cx: &mut App) { let keymap_event_channel = KeymapEventChannel::new(); cx.set_global(keymap_event_channel); @@ -384,6 +374,14 @@ impl Focusable for KeymapEditor { } } } +/// Helper function to check if two keystroke sequences match exactly +fn keystrokes_match_exactly(keystrokes1: &[Keystroke], keystrokes2: &[Keystroke]) -> bool { + keystrokes1.len() == keystrokes2.len() + && keystrokes1 + .iter() + .zip(keystrokes2) + .all(|(k1, k2)| k1.key == k2.key && k1.modifiers == k2.modifiers) +} impl 
KeymapEditor { fn new(workspace: WeakEntity<Workspace>, window: &mut Window, cx: &mut Context<Self>) -> Self { @@ -393,7 +391,7 @@ impl KeymapEditor { let keystroke_editor = cx.new(|cx| { let mut keystroke_editor = KeystrokeInput::new(None, window, cx); - keystroke_editor.search = true; + keystroke_editor.set_search(true); keystroke_editor }); @@ -559,31 +557,41 @@ impl KeymapEditor { .keystrokes() .is_some_and(|keystrokes| { if exact_match { - keystroke_query.len() == keystrokes.len() - && keystroke_query.iter().zip(keystrokes).all( - |(query, keystroke)| { - query.key == keystroke.key - && query.modifiers == keystroke.modifiers - }, - ) + keystrokes_match_exactly(&keystroke_query, keystrokes) + } else if keystroke_query.len() > keystrokes.len() { + return false; } else { - let key_press_query = - KeyPressIterator::new(keystroke_query.as_slice()); - let mut last_match_idx = 0; + for keystroke_offset in 0..keystrokes.len() { + let mut found_count = 0; + let mut query_cursor = 0; + let mut keystroke_cursor = keystroke_offset; + while query_cursor < keystroke_query.len() + && keystroke_cursor < keystrokes.len() + { + let query = &keystroke_query[query_cursor]; + let keystroke = &keystrokes[keystroke_cursor]; + let matches = + query.modifiers.is_subset_of(&keystroke.modifiers) + && ((query.key.is_empty() + || query.key == keystroke.key) + && query + .key_char + .as_ref() + .map_or(true, |q_kc| { + q_kc == &keystroke.key + })); + if matches { + found_count += 1; + query_cursor += 1; + } + keystroke_cursor += 1; + } - key_press_query.into_iter().all(|key| { - let key_presses = KeyPressIterator::new(keystrokes); - key_presses.into_iter().enumerate().any( - |(index, keystroke)| { - if last_match_idx > index || keystroke != key { - return false; - } - - last_match_idx = index; - true - }, - ) - }) + if found_count == keystroke_query.len() { + return true; + } + } + return false; } }) }); @@ -1232,11 +1240,14 @@ impl KeymapEditor { match self.search_mode { SearchMode::KeyStroke { .. 
} => { - window.focus(&self.keystroke_editor.read(cx).recording_focus_handle(cx)); + self.keystroke_editor.update(cx, |editor, cx| { + editor.start_recording(&StartRecording, window, cx); + }); } SearchMode::Normal => { self.keystroke_editor.update(cx, |editor, cx| { - editor.clear_keystrokes(&ClearKeystrokes, window, cx) + editor.stop_recording(&StopRecording, window, cx); + editor.clear_keystrokes(&ClearKeystrokes, window, cx); + }); window.focus(&self.filter_editor.focus_handle(cx)); } @@ -1671,7 +1682,7 @@ impl Render for KeymapEditor { move |window, cx| this.read(cx).render_no_matches_hint(window, cx) }) .column_widths([ - DefiniteLength::Absolute(AbsoluteLength::Pixels(px(40.))), + DefiniteLength::Absolute(AbsoluteLength::Pixels(px(36.))), DefiniteLength::Fraction(0.25), DefiniteLength::Fraction(0.20), DefiniteLength::Fraction(0.14), @@ -1746,6 +1757,7 @@ impl Render for KeymapEditor { }, ) .into_any_element(); + let keystrokes = binding.ui_key_binding().cloned().map_or( binding .keystroke_text() .into_any_element(), IntoElement::into_any_element, ); + let action_arguments = match binding.action().arguments.clone() { Some(arguments) => arguments.into_any_element(), @@ -1766,6 +1779,7 @@ impl Render for KeymapEditor { } } }; + let context = binding.context().cloned().map_or( gpui::Empty.into_any_element(), |context| { @@ -1790,11 +1804,13 @@ impl Render for KeymapEditor { .into_any_element() }, ); + let source = binding .keybind_source() .map(|source| source.name()) .unwrap_or_default() .into_any_element(); + Some([ icon.into_any_element(), action, @@ -1841,7 +1857,7 @@ impl Render for KeymapEditor { .on_click(cx.listener( move |this, event: &ClickEvent, window, cx| { this.select_index(row_index, None, window, cx); - if event.up.click_count == 2 { + if event.click_count() == 2 { this.open_edit_keybinding_modal( false, window, cx, ); @@ -2326,8 +2342,50 @@ impl KeybindingEditorModal { self.save_or_display_error(cx); } - fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) { - cx.emit(DismissEvent) + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) { + cx.emit(DismissEvent); + } + + fn get_matching_bindings_count(&self, cx: &Context<Self>) -> usize { + let current_keystrokes = self.keybind_editor.read(cx).keystrokes().to_vec(); + + if current_keystrokes.is_empty() { + return 0; + } + + self.keymap_editor + .read(cx) + .keybindings + .iter() + .enumerate() + .filter(|(idx, binding)| { + // Don't count the binding we're currently editing + if !self.creating && *idx == self.editing_keybind_idx { + return false; + } + + binding + .keystrokes() + .map(|keystrokes| keystrokes_match_exactly(keystrokes, &current_keystrokes)) + .unwrap_or(false) + }) + .count() + } + + fn show_matching_bindings(&mut self, _window: &mut Window, cx: &mut Context<Self>) { + let keystrokes = self.keybind_editor.read(cx).keystrokes().to_vec(); + + // Dismiss the modal + cx.emit(DismissEvent); + + // Update the keymap editor to show matching keystrokes + self.keymap_editor.update(cx, |editor, cx| { + editor.filter_state = FilterState::All; + editor.search_mode = SearchMode::KeyStroke { exact_match: true }; + editor.keystroke_editor.update(cx, |keystroke_editor, cx| { + keystroke_editor.set_keystrokes(keystrokes, cx); + }); + }); + } } @@ -2342,6 +2400,7 @@ fn remove_key_char(Keystroke { modifiers, key, ..
}: Keystroke) -> Keystroke { impl Render for KeybindingEditorModal { fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let theme = cx.theme().colors(); + let matching_bindings_count = self.get_matching_bindings_count(cx); v_flex() .w(rems(34.)) @@ -2413,19 +2472,37 @@ impl Render for KeybindingEditorModal { ), ) .footer( - ModalFooter::new().end_slot( - h_flex() - .gap_1() - .child( - Button::new("cancel", "Cancel") - .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), - ) - .child(Button::new("save-btn", "Save").on_click(cx.listener( - |this, _event, _window, cx| { - this.save_or_display_error(cx); - }, - ))), - ), + ModalFooter::new() + .start_slot( + div().when(matching_bindings_count > 0, |this| { + this.child( + Button::new("show_matching", format!( + "There {} {} {} with the same keystrokes. Click to view", + if matching_bindings_count == 1 { "is" } else { "are" }, + matching_bindings_count, + if matching_bindings_count == 1 { "binding" } else { "bindings" } + )) + .style(ButtonStyle::Transparent) + .color(Color::Accent) + .on_click(cx.listener(|this, _, window, cx| { + this.show_matching_bindings(window, cx); + })) + ) + }) + ) + .end_slot( + h_flex() + .gap_1() + .child( + Button::new("cancel", "Cancel") + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + .child(Button::new("save-btn", "Save").on_click(cx.listener( + |this, _event, _window, cx| { + this.save_or_display_error(cx); + }, + ))), + ), ), ) } @@ -2955,516 +3032,6 @@ async fn remove_keybinding( Ok(()) } -#[derive(PartialEq, Eq, Debug, Copy, Clone)] -enum CloseKeystrokeResult { - Partial, - Close, - None, -} - -#[derive(PartialEq, Eq, Debug, Clone)] -enum KeyPress<'a> { - Alt, - Control, - Function, - Shift, - Platform, - Key(&'a String), -} - -struct KeystrokeInput { - keystrokes: Vec<Keystroke>, - placeholder_keystrokes: Option<Vec<Keystroke>>, - outer_focus_handle: FocusHandle, - inner_focus_handle: FocusHandle, - intercept_subscription: Option<Subscription>, - _focus_subscriptions: [Subscription; 2], - search: bool, - /// Handles tripe escape to stop recording - close_keystrokes: Option<Vec<Keystroke>>, - close_keystrokes_start: Option<usize>, -} - -impl KeystrokeInput { - const KEYSTROKE_COUNT_MAX: usize = 3; - - fn new( - placeholder_keystrokes: Option<Vec<Keystroke>>, - window: &mut Window, - cx: &mut Context<Self>, - ) -> Self { - let outer_focus_handle = cx.focus_handle(); - let inner_focus_handle = cx.focus_handle(); - let _focus_subscriptions = [ - cx.on_focus_in(&inner_focus_handle, window, Self::on_inner_focus_in), - cx.on_focus_out(&inner_focus_handle, window, Self::on_inner_focus_out), - ]; - Self { - keystrokes: Vec::new(), - placeholder_keystrokes, - inner_focus_handle, - outer_focus_handle, - intercept_subscription: None, - _focus_subscriptions, - search: false, - close_keystrokes: None, - close_keystrokes_start: None, - } - } - - fn set_keystrokes(&mut self, keystrokes: Vec<Keystroke>, cx: &mut Context<Self>) { - self.keystrokes = keystrokes; - self.keystrokes_changed(cx); - } - - fn dummy(modifiers: Modifiers) -> Keystroke { - return Keystroke { - modifiers, - key: "".to_string(), - key_char: None, - }; - } - - fn keystrokes_changed(&self, cx: &mut Context<Self>) { - cx.emit(()); - cx.notify(); - } - - fn key_context() -> KeyContext { - let mut key_context = KeyContext::new_with_defaults(); - key_context.add("KeystrokeInput"); - key_context - } - - fn handle_possible_close_keystroke( - &mut self, - keystroke: &Keystroke, - window: &mut 
Window, - cx: &mut Context<Self>, - ) -> CloseKeystrokeResult { - let Some(keybind_for_close_action) = window - .highest_precedence_binding_for_action_in_context(&StopRecording, Self::key_context()) - else { - log::trace!("No keybinding to stop recording keystrokes in keystroke input"); - self.close_keystrokes.take(); - self.close_keystrokes_start.take(); - return CloseKeystrokeResult::None; - }; - let action_keystrokes = keybind_for_close_action.keystrokes(); - - if let Some(mut close_keystrokes) = self.close_keystrokes.take() { - let mut index = 0; - - while index < action_keystrokes.len() && index < close_keystrokes.len() { - if !close_keystrokes[index].should_match(&action_keystrokes[index]) { - break; - } - index += 1; - } - if index == close_keystrokes.len() { - if index >= action_keystrokes.len() { - self.close_keystrokes_start.take(); - return CloseKeystrokeResult::None; - } - if keystroke.should_match(&action_keystrokes[index]) { - if action_keystrokes.len() >= 1 && index == action_keystrokes.len() - 1 { - self.stop_recording(&StopRecording, window, cx); - return CloseKeystrokeResult::Close; - } else { - close_keystrokes.push(keystroke.clone()); - self.close_keystrokes = Some(close_keystrokes); - return CloseKeystrokeResult::Partial; - } - } else { - self.close_keystrokes_start.take(); - return CloseKeystrokeResult::None; - } - } - } else if let Some(first_action_keystroke) = action_keystrokes.first() - && keystroke.should_match(first_action_keystroke) - { - self.close_keystrokes = Some(vec![keystroke.clone()]); - return CloseKeystrokeResult::Partial; - } - self.close_keystrokes_start.take(); - return CloseKeystrokeResult::None; - } - - fn on_modifiers_changed( - &mut self, - event: &ModifiersChangedEvent, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - let keystrokes_len = self.keystrokes.len(); - - if let Some(last) = self.keystrokes.last_mut() - && last.key.is_empty() - && keystrokes_len <= Self::KEYSTROKE_COUNT_MAX - { - if self.search { - last.modifiers = last.modifiers.xor(&event.modifiers); - } else if !event.modifiers.modified() { - self.keystrokes.pop(); - } else { - last.modifiers = event.modifiers; - } - - self.keystrokes_changed(cx); - } else if keystrokes_len < Self::KEYSTROKE_COUNT_MAX { - self.keystrokes.push(Self::dummy(event.modifiers)); - self.keystrokes_changed(cx); - } - cx.stop_propagation(); - } - - fn handle_keystroke( - &mut self, - keystroke: &Keystroke, - window: &mut Window, - cx: &mut Context<Self>, - ) { - let close_keystroke_result = self.handle_possible_close_keystroke(keystroke, window, cx); - if close_keystroke_result != CloseKeystrokeResult::Close { - let key_len = self.keystrokes.len(); - if let Some(last) = self.keystrokes.last_mut() - && last.key.is_empty() - && key_len <= Self::KEYSTROKE_COUNT_MAX - { - if self.search { - last.key = keystroke.key.clone(); - if close_keystroke_result == CloseKeystrokeResult::Partial - && self.close_keystrokes_start.is_none() - { - self.close_keystrokes_start = Some(self.keystrokes.len() - 1); - } - self.keystrokes_changed(cx); - cx.stop_propagation(); - return; - } else { - self.keystrokes.pop(); - } - } - if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX { - if close_keystroke_result == CloseKeystrokeResult::Partial - && self.close_keystrokes_start.is_none() - { - self.close_keystrokes_start = Some(self.keystrokes.len()); - } - self.keystrokes.push(keystroke.clone()); - if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX { - self.keystrokes.push(Self::dummy(keystroke.modifiers)); - } - } else if 
close_keystroke_result != CloseKeystrokeResult::Partial { - self.clear_keystrokes(&ClearKeystrokes, window, cx); - } - } - self.keystrokes_changed(cx); - cx.stop_propagation(); - } - - fn on_inner_focus_in(&mut self, _window: &mut Window, cx: &mut Context<Self>) { - if self.intercept_subscription.is_none() { - let listener = cx.listener(|this, event: &gpui::KeystrokeEvent, window, cx| { - this.handle_keystroke(&event.keystroke, window, cx); - }); - self.intercept_subscription = Some(cx.intercept_keystrokes(listener)) - } - } - - fn on_inner_focus_out( - &mut self, - _event: gpui::FocusOutEvent, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - self.intercept_subscription.take(); - cx.notify(); - } - - fn keystrokes(&self) -> &[Keystroke] { - if let Some(placeholders) = self.placeholder_keystrokes.as_ref() - && self.keystrokes.is_empty() - { - return placeholders; - } - if !self.search - && self - .keystrokes - .last() - .map_or(false, |last| last.key.is_empty()) - { - return &self.keystrokes[..self.keystrokes.len() - 1]; - } - return &self.keystrokes; - } - - fn render_keystrokes(&self, is_recording: bool) -> impl Iterator<Item = Div> { - let keystrokes = if let Some(placeholders) = self.placeholder_keystrokes.as_ref() - && self.keystrokes.is_empty() - { - if is_recording { - &[] - } else { - placeholders.as_slice() - } - } else { - &self.keystrokes - }; - keystrokes.iter().map(move |keystroke| { - h_flex().children(ui::render_keystroke( - keystroke, - Some(Color::Default), - Some(rems(0.875).into()), - ui::PlatformStyle::platform(), - false, - )) - }) - } - - fn recording_focus_handle(&self, _cx: &App) -> FocusHandle { - self.inner_focus_handle.clone() - } - - fn start_recording(&mut self, _: &StartRecording, window: &mut Window, cx: &mut Context<Self>) { - if !self.outer_focus_handle.is_focused(window) { - return; - } - self.clear_keystrokes(&ClearKeystrokes, window, cx); - window.focus(&self.inner_focus_handle); - cx.notify(); - } - - fn stop_recording(&mut self, _: &StopRecording, window: &mut Window, cx: &mut Context<Self>) { - if !self.inner_focus_handle.is_focused(window) { - return; - } - window.focus(&self.outer_focus_handle); - if let Some(close_keystrokes_start) = self.close_keystrokes_start.take() - && close_keystrokes_start < self.keystrokes.len() - { - self.keystrokes.drain(close_keystrokes_start..); - } - self.close_keystrokes.take(); - cx.notify(); - } - - fn clear_keystrokes( - &mut self, - _: &ClearKeystrokes, - _window: &mut Window, - cx: &mut Context<Self>, - ) { - self.keystrokes.clear(); - self.keystrokes_changed(cx); - } -} - -impl EventEmitter<()> for KeystrokeInput {} - -impl Focusable for KeystrokeInput { - fn focus_handle(&self, _cx: &App) -> FocusHandle { - self.outer_focus_handle.clone() - } -} - -impl Render for KeystrokeInput { - fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { - let colors = cx.theme().colors(); - let is_focused = self.outer_focus_handle.contains_focused(window, cx); - let is_recording = self.inner_focus_handle.is_focused(window); - - let horizontal_padding = rems_from_px(64.); - - let recording_bg_color = colors - .editor_background - .blend(colors.text_accent.opacity(0.1)); - - let recording_pulse = |color: Color| { - Icon::new(IconName::Circle) - .size(IconSize::Small) - .color(Color::Error) - .with_animation( - "recording-pulse", - Animation::new(std::time::Duration::from_secs(2)) - .repeat() - .with_easing(gpui::pulsating_between(0.4, 0.8)), - { - let color = color.color(cx); - move |this, 
delta| this.color(Color::Custom(color.opacity(delta))) - }, - ) - }; - - let recording_indicator = h_flex() - .h_4() - .pr_1() - .gap_0p5() - .border_1() - .border_color(colors.border) - .bg(colors - .editor_background - .blend(colors.text_accent.opacity(0.1))) - .rounded_sm() - .child(recording_pulse(Color::Error)) - .child( - Label::new("REC") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Error), - ); - - let search_indicator = h_flex() - .h_4() - .pr_1() - .gap_0p5() - .border_1() - .border_color(colors.border) - .bg(colors - .editor_background - .blend(colors.text_accent.opacity(0.1))) - .rounded_sm() - .child(recording_pulse(Color::Accent)) - .child( - Label::new("SEARCH") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Accent), - ); - - let record_icon = if self.search { - IconName::MagnifyingGlass - } else { - IconName::PlayFilled - }; - - h_flex() - .id("keystroke-input") - .track_focus(&self.outer_focus_handle) - .py_2() - .px_3() - .gap_2() - .min_h_10() - .w_full() - .flex_1() - .justify_between() - .rounded_lg() - .overflow_hidden() - .map(|this| { - if is_recording { - this.bg(recording_bg_color) - } else { - this.bg(colors.editor_background) - } - }) - .border_1() - .border_color(colors.border_variant) - .when(is_focused, |parent| { - parent.border_color(colors.border_focused) - }) - .key_context(Self::key_context()) - .on_action(cx.listener(Self::start_recording)) - .on_action(cx.listener(Self::stop_recording)) - .child( - h_flex() - .w(horizontal_padding) - .gap_0p5() - .justify_start() - .flex_none() - .when(is_recording, |this| { - this.map(|this| { - if self.search { - this.child(search_indicator) - } else { - this.child(recording_indicator) - } - }) - }), - ) - .child( - h_flex() - .id("keystroke-input-inner") - .track_focus(&self.inner_focus_handle) - .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) - .size_full() - .when(!self.search, |this| { - this.focus(|mut style| { - style.border_color = Some(colors.border_focused); - style - }) - }) - .w_full() - .min_w_0() - .justify_center() - .flex_wrap() - .gap(ui::DynamicSpacing::Base04.rems(cx)) - .children(self.render_keystrokes(is_recording)), - ) - .child( - h_flex() - .w(horizontal_padding) - .gap_0p5() - .justify_end() - .flex_none() - .map(|this| { - if is_recording { - this.child( - IconButton::new("stop-record-btn", IconName::StopFilled) - .shape(ui::IconButtonShape::Square) - .map(|this| { - this.tooltip(Tooltip::for_action_title( - if self.search { - "Stop Searching" - } else { - "Stop Recording" - }, - &StopRecording, - )) - }) - .icon_color(Color::Error) - .on_click(cx.listener(|this, _event, window, cx| { - this.stop_recording(&StopRecording, window, cx); - })), - ) - } else { - this.child( - IconButton::new("record-btn", record_icon) - .shape(ui::IconButtonShape::Square) - .map(|this| { - this.tooltip(Tooltip::for_action_title( - if self.search { - "Start Searching" - } else { - "Start Recording" - }, - &StartRecording, - )) - }) - .when(!is_focused, |this| this.icon_color(Color::Muted)) - .on_click(cx.listener(|this, _event, window, cx| { - this.start_recording(&StartRecording, window, cx); - })), - ) - } - }) - .child( - IconButton::new("clear-btn", IconName::Delete) - .shape(ui::IconButtonShape::Square) - .tooltip(Tooltip::for_action_title( - "Clear Keystrokes", - &ClearKeystrokes, - )) - .when(!is_recording || !is_focused, |this| { - this.icon_color(Color::Muted) - }) - .on_click(cx.listener(|this, _event, window, cx| { - 
this.clear_keystrokes(&ClearKeystrokes, window, cx); - })), - ), - ) - } -} - fn collect_contexts_from_assets() -> Vec<SharedString> { let mut keymap_assets = vec![ util::asset_str::<SettingsAssets>(settings::DEFAULT_KEYMAP_PATH), @@ -3633,72 +3200,3 @@ mod persistence { } } } - -/// Iterator that yields KeyPress values from a slice of Keystrokes -struct KeyPressIterator<'a> { - keystrokes: &'a [Keystroke], - current_keystroke_index: usize, - current_key_press_index: usize, -} - -impl<'a> KeyPressIterator<'a> { - fn new(keystrokes: &'a [Keystroke]) -> Self { - Self { - keystrokes, - current_keystroke_index: 0, - current_key_press_index: 0, - } - } -} - -impl<'a> Iterator for KeyPressIterator<'a> { - type Item = KeyPress<'a>; - - fn next(&mut self) -> Option<Self::Item> { - loop { - let keystroke = self.keystrokes.get(self.current_keystroke_index)?; - - match self.current_key_press_index { - 0 => { - self.current_key_press_index = 1; - if keystroke.modifiers.platform { - return Some(KeyPress::Platform); - } - } - 1 => { - self.current_key_press_index = 2; - if keystroke.modifiers.alt { - return Some(KeyPress::Alt); - } - } - 2 => { - self.current_key_press_index = 3; - if keystroke.modifiers.control { - return Some(KeyPress::Control); - } - } - 3 => { - self.current_key_press_index = 4; - if keystroke.modifiers.shift { - return Some(KeyPress::Shift); - } - } - 4 => { - self.current_key_press_index = 5; - if keystroke.modifiers.function { - return Some(KeyPress::Function); - } - } - _ => { - self.current_keystroke_index += 1; - self.current_key_press_index = 0; - - if keystroke.key.is_empty() { - continue; - } - return Some(KeyPress::Key(&keystroke.key)); - } - } - } - } -} diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index 2f0abb4789..3022cc7142 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -1,20 +1,12 @@ mod appearance_settings_controls; use std::any::TypeId; -use std::sync::Arc; use command_palette_hooks::CommandPaletteFilter; use editor::EditorSettingsControls; use feature_flags::{FeatureFlag, FeatureFlagViewExt}; -use fs::Fs; -use gpui::{ - Action, App, AsyncWindowContext, Entity, EventEmitter, FocusHandle, Focusable, Task, actions, -}; -use schemars::JsonSchema; -use serde::Deserialize; -use settings::{SettingsStore, VsCodeSettingsSource}; +use gpui::{App, Entity, EventEmitter, FocusHandle, Focusable, actions}; use ui::prelude::*; -use util::truncate_and_remove_front; use workspace::item::{Item, ItemEvent}; use workspace::{Workspace, with_active_or_new_workspace}; @@ -29,23 +21,6 @@ impl FeatureFlag for SettingsUiFeatureFlag { const NAME: &'static str = "settings-ui"; } -/// Imports settings from Visual Studio Code. -#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportVsCodeSettings { - #[serde(default)] - pub skip_prompt: bool, -} - -/// Imports settings from Cursor editor. 
-#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportCursorSettings { - #[serde(default)] - pub skip_prompt: bool, -} actions!( zed, [ @@ -72,45 +47,11 @@ pub fn init(cx: &mut App) { }); }); - cx.observe_new(|workspace: &mut Workspace, window, cx| { + cx.observe_new(|_workspace: &mut Workspace, window, cx| { let Some(window) = window else { return; }; - workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { - let fs = <dyn Fs>::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::VsCode, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - - workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { - let fs = <dyn Fs>::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::Cursor, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - let settings_ui_actions = [TypeId::of::<OpenSettingsEditor>()]; CommandPaletteFilter::update_global(cx, |filter, _cx| { @@ -138,57 +79,6 @@ pub fn init(cx: &mut App) { keybindings::init(cx); } -async fn handle_import_vscode_settings( - source: VsCodeSettingsSource, - skip_prompt: bool, - fs: Arc<dyn Fs>, - cx: &mut AsyncWindowContext, -) { - let vscode_settings = - match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { - Ok(vscode_settings) => vscode_settings, - Err(err) => { - log::error!("{err}"); - let _ = cx.prompt( - gpui::PromptLevel::Info, - &format!("Could not find or load a {source} settings file"), - None, - &["Ok"], - ); - return; - } - }; - - let prompt = if skip_prompt { - Task::ready(Some(0)) - } else { - let prompt = cx.prompt( - gpui::PromptLevel::Warning, - &format!( - "Importing {} settings may overwrite your existing settings. 
\ - Will import settings from {}", - vscode_settings.source, - truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), - ), - None, - &["Ok", "Cancel"], - ); - cx.spawn(async move |_| prompt.await.ok()) - }; - if prompt.await != Some(0) { - return; - } - - cx.update(|_, cx| { - let source = vscode_settings.source; - let path = vscode_settings.path.clone(); - cx.global::<SettingsStore>() - .import_vscode_settings(fs, vscode_settings); - log::info!("Imported {source} settings from {}", path.display()); - }) - .ok(); -} - pub struct SettingsPage { focus_handle: FocusHandle, } diff --git a/crates/settings_ui/src/ui_components/keystroke_input.rs b/crates/settings_ui/src/ui_components/keystroke_input.rs new file mode 100644 index 0000000000..03d27d0ab9 --- /dev/null +++ b/crates/settings_ui/src/ui_components/keystroke_input.rs @@ -0,0 +1,1388 @@ +use gpui::{ + Animation, AnimationExt, Context, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, + Keystroke, Modifiers, ModifiersChangedEvent, Subscription, Task, actions, +}; +use ui::{ + ActiveTheme as _, Color, IconButton, IconButtonShape, IconName, IconSize, Label, LabelSize, + ParentElement as _, Render, Styled as _, Tooltip, Window, prelude::*, +}; + +actions!( + keystroke_input, + [ + /// Starts recording keystrokes + StartRecording, + /// Stops recording keystrokes + StopRecording, + /// Clears the recorded keystrokes + ClearKeystrokes, + ] +); + +const KEY_CONTEXT_VALUE: &'static str = "KeystrokeInput"; + +const CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT: std::time::Duration = + std::time::Duration::from_millis(300); + +enum CloseKeystrokeResult { + Partial, + Close, + None, +} + +impl PartialEq for CloseKeystrokeResult { + fn eq(&self, other: &Self) -> bool { + matches!( + (self, other), + (CloseKeystrokeResult::Partial, CloseKeystrokeResult::Partial) + | (CloseKeystrokeResult::Close, CloseKeystrokeResult::Close) + | (CloseKeystrokeResult::None, CloseKeystrokeResult::None) + ) + } +} + +pub struct KeystrokeInput { + keystrokes: Vec<Keystroke>, + placeholder_keystrokes: Option<Vec<Keystroke>>, + outer_focus_handle: FocusHandle, + inner_focus_handle: FocusHandle, + intercept_subscription: Option<Subscription>, + _focus_subscriptions: [Subscription; 2], + search: bool, + /// The sequence of close keystrokes being typed + close_keystrokes: Option<Vec<Keystroke>>, + close_keystrokes_start: Option<usize>, + previous_modifiers: Modifiers, + /// In order to support inputting keystrokes that end with a prefix of the + /// close keybind keystrokes, we clear the close keystroke capture info + /// on a timeout after a close keystroke is pressed + /// + /// e.g. 
if close binding is `esc esc esc` and user wants to search for + /// `ctrl-g esc`, after entering the `ctrl-g esc`, hitting `esc` twice would + /// stop recording because of the sequence of three escapes making it + /// impossible to search for anything ending in `esc` + clear_close_keystrokes_timer: Option<Task<()>>, + #[cfg(test)] + recording: bool, +} + +impl KeystrokeInput { + const KEYSTROKE_COUNT_MAX: usize = 3; + + pub fn new( + placeholder_keystrokes: Option<Vec<Keystroke>>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Self { + let outer_focus_handle = cx.focus_handle(); + let inner_focus_handle = cx.focus_handle(); + let _focus_subscriptions = [ + cx.on_focus_in(&inner_focus_handle, window, Self::on_inner_focus_in), + cx.on_focus_out(&inner_focus_handle, window, Self::on_inner_focus_out), + ]; + Self { + keystrokes: Vec::new(), + placeholder_keystrokes, + inner_focus_handle, + outer_focus_handle, + intercept_subscription: None, + _focus_subscriptions, + search: false, + close_keystrokes: None, + close_keystrokes_start: None, + previous_modifiers: Modifiers::default(), + clear_close_keystrokes_timer: None, + #[cfg(test)] + recording: false, + } + } + + pub fn set_keystrokes(&mut self, keystrokes: Vec<Keystroke>, cx: &mut Context<Self>) { + self.keystrokes = keystrokes; + self.keystrokes_changed(cx); + } + + pub fn set_search(&mut self, search: bool) { + self.search = search; + } + + pub fn keystrokes(&self) -> &[Keystroke] { + if let Some(placeholders) = self.placeholder_keystrokes.as_ref() + && self.keystrokes.is_empty() + { + return placeholders; + } + if !self.search + && self + .keystrokes + .last() + .map_or(false, |last| last.key.is_empty()) + { + return &self.keystrokes[..self.keystrokes.len() - 1]; + } + return &self.keystrokes; + } + + fn dummy(modifiers: Modifiers) -> Keystroke { + return Keystroke { + modifiers, + key: "".to_string(), + key_char: None, + }; + } + + fn keystrokes_changed(&self, cx: &mut Context<Self>) { + cx.emit(()); + cx.notify(); + } + + fn key_context() -> KeyContext { + let mut key_context = KeyContext::default(); + key_context.add(KEY_CONTEXT_VALUE); + key_context + } + + fn determine_stop_recording_binding(window: &mut Window) -> Option<gpui::KeyBinding> { + if cfg!(test) { + Some(gpui::KeyBinding::new( + "escape escape escape", + StopRecording, + Some(KEY_CONTEXT_VALUE), + )) + } else { + window.highest_precedence_binding_for_action_in_context( + &StopRecording, + Self::key_context(), + ) + } + } + + fn upsert_close_keystrokes_start(&mut self, start: usize, cx: &mut Context<Self>) { + if self.close_keystrokes_start.is_some() { + return; + } + self.close_keystrokes_start = Some(start); + self.update_clear_close_keystrokes_timer(cx); + } + + fn update_clear_close_keystrokes_timer(&mut self, cx: &mut Context<Self>) { + self.clear_close_keystrokes_timer = Some(cx.spawn(async |this, cx| { + cx.background_executor() + .timer(CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT) + .await; + this.update(cx, |this, _cx| { + this.end_close_keystrokes_capture(); + }) + .ok(); + })); + } + + /// Interrupt the capture of close keystrokes, but do not clear the close keystrokes + /// from the input + fn end_close_keystrokes_capture(&mut self) -> Option<usize> { + self.close_keystrokes.take(); + self.clear_close_keystrokes_timer.take(); + return self.close_keystrokes_start.take(); + } + + fn handle_possible_close_keystroke( + &mut self, + keystroke: &Keystroke, + window: &mut Window, + cx: &mut Context<Self>, + ) -> CloseKeystrokeResult { + let 
Some(keybind_for_close_action) = Self::determine_stop_recording_binding(window) else { + log::trace!("No keybinding to stop recording keystrokes in keystroke input"); + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + }; + let action_keystrokes = keybind_for_close_action.keystrokes(); + + if let Some(mut close_keystrokes) = self.close_keystrokes.take() { + let mut index = 0; + + while index < action_keystrokes.len() && index < close_keystrokes.len() { + if !close_keystrokes[index].should_match(&action_keystrokes[index]) { + break; + } + index += 1; + } + if index == close_keystrokes.len() { + if index >= action_keystrokes.len() { + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + if keystroke.should_match(&action_keystrokes[index]) { + close_keystrokes.push(keystroke.clone()); + if close_keystrokes.len() == action_keystrokes.len() { + return CloseKeystrokeResult::Close; + } else { + self.close_keystrokes = Some(close_keystrokes); + self.update_clear_close_keystrokes_timer(cx); + return CloseKeystrokeResult::Partial; + } + } else { + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + } + } else if let Some(first_action_keystroke) = action_keystrokes.first() + && keystroke.should_match(first_action_keystroke) + { + self.close_keystrokes = Some(vec![keystroke.clone()]); + return CloseKeystrokeResult::Partial; + } + self.end_close_keystrokes_capture(); + return CloseKeystrokeResult::None; + } + + fn on_modifiers_changed( + &mut self, + event: &ModifiersChangedEvent, + window: &mut Window, + cx: &mut Context<Self>, + ) { + cx.stop_propagation(); + let keystrokes_len = self.keystrokes.len(); + + if self.previous_modifiers.modified() + && event.modifiers.is_subset_of(&self.previous_modifiers) + { + self.previous_modifiers &= event.modifiers; + return; + } + self.keystrokes_changed(cx); + + if let Some(last) = self.keystrokes.last_mut() + && last.key.is_empty() + && keystrokes_len <= Self::KEYSTROKE_COUNT_MAX + { + if !self.search && !event.modifiers.modified() { + self.keystrokes.pop(); + return; + } + if self.search { + if self.previous_modifiers.modified() { + last.modifiers |= event.modifiers; + } else { + self.keystrokes.push(Self::dummy(event.modifiers)); + } + self.previous_modifiers |= event.modifiers; + } else { + last.modifiers = event.modifiers; + return; + } + } else if keystrokes_len < Self::KEYSTROKE_COUNT_MAX { + self.keystrokes.push(Self::dummy(event.modifiers)); + if self.search { + self.previous_modifiers |= event.modifiers; + } + } + if keystrokes_len >= Self::KEYSTROKE_COUNT_MAX { + self.clear_keystrokes(&ClearKeystrokes, window, cx); + } + } + + fn handle_keystroke( + &mut self, + keystroke: &Keystroke, + window: &mut Window, + cx: &mut Context<Self>, + ) { + cx.stop_propagation(); + + let close_keystroke_result = self.handle_possible_close_keystroke(keystroke, window, cx); + if close_keystroke_result == CloseKeystrokeResult::Close { + self.stop_recording(&StopRecording, window, cx); + return; + } + + let mut keystroke = keystroke.clone(); + if let Some(last) = self.keystrokes.last() + && last.key.is_empty() + && (!self.search || self.previous_modifiers.modified()) + { + let key = keystroke.key.clone(); + keystroke = last.clone(); + keystroke.key = key; + self.keystrokes.pop(); + } + + if close_keystroke_result == CloseKeystrokeResult::Partial { + self.upsert_close_keystrokes_start(self.keystrokes.len(), cx); + if self.keystrokes.len() >= Self::KEYSTROKE_COUNT_MAX { + return; + } + } + + if 
self.keystrokes.len() >= Self::KEYSTROKE_COUNT_MAX { + self.clear_keystrokes(&ClearKeystrokes, window, cx); + return; + } + + self.keystrokes.push(keystroke.clone()); + self.keystrokes_changed(cx); + + if self.search { + self.previous_modifiers = keystroke.modifiers; + return; + } + if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX && keystroke.modifiers.modified() { + self.keystrokes.push(Self::dummy(keystroke.modifiers)); + } + } + + fn on_inner_focus_in(&mut self, _window: &mut Window, cx: &mut Context<Self>) { + if self.intercept_subscription.is_none() { + let listener = cx.listener(|this, event: &gpui::KeystrokeEvent, window, cx| { + this.handle_keystroke(&event.keystroke, window, cx); + }); + self.intercept_subscription = Some(cx.intercept_keystrokes(listener)) + } + } + + fn on_inner_focus_out( + &mut self, + _event: gpui::FocusOutEvent, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.intercept_subscription.take(); + cx.notify(); + } + + fn render_keystrokes(&self, is_recording: bool) -> impl Iterator<Item = Div> { + let keystrokes = if let Some(placeholders) = self.placeholder_keystrokes.as_ref() + && self.keystrokes.is_empty() + { + if is_recording { + &[] + } else { + placeholders.as_slice() + } + } else { + &self.keystrokes + }; + keystrokes.iter().map(move |keystroke| { + h_flex().children(ui::render_keystroke( + keystroke, + Some(Color::Default), + Some(rems(0.875).into()), + ui::PlatformStyle::platform(), + false, + )) + }) + } + + pub fn start_recording( + &mut self, + _: &StartRecording, + window: &mut Window, + cx: &mut Context<Self>, + ) { + window.focus(&self.inner_focus_handle); + self.clear_keystrokes(&ClearKeystrokes, window, cx); + self.previous_modifiers = window.modifiers(); + #[cfg(test)] + { + self.recording = true; + } + cx.stop_propagation(); + } + + pub fn stop_recording( + &mut self, + _: &StopRecording, + window: &mut Window, + cx: &mut Context<Self>, + ) { + if !self.is_recording(window) { + return; + } + window.focus(&self.outer_focus_handle); + if let Some(close_keystrokes_start) = self.close_keystrokes_start.take() + && close_keystrokes_start < self.keystrokes.len() + { + self.keystrokes.drain(close_keystrokes_start..); + self.keystrokes_changed(cx); + } + self.end_close_keystrokes_capture(); + #[cfg(test)] + { + self.recording = false; + } + cx.notify(); + } + + pub fn clear_keystrokes( + &mut self, + _: &ClearKeystrokes, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + self.keystrokes.clear(); + self.keystrokes_changed(cx); + self.end_close_keystrokes_capture(); + } + + fn is_recording(&self, window: &Window) -> bool { + #[cfg(test)] + { + if true { + // in tests, we just need a simple bool that is toggled on start and stop recording + return self.recording; + } + } + // however, in the real world, checking if the inner focus handle is focused + // is a much more reliable check, as the intercept keystroke handlers are installed + // on focus of the inner focus handle, thereby ensuring our recording state does + // not get de-synced + return self.inner_focus_handle.is_focused(window); + } +} + +impl EventEmitter<()> for KeystrokeInput {} + +impl Focusable for KeystrokeInput { + fn focus_handle(&self, _cx: &gpui::App) -> FocusHandle { + self.outer_focus_handle.clone() + } +} + +impl Render for KeystrokeInput { + fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let colors = cx.theme().colors(); + let is_focused = self.outer_focus_handle.contains_focused(window, cx); + let is_recording = 
self.is_recording(window); + + let horizontal_padding = rems_from_px(64.); + + let recording_bg_color = colors + .editor_background + .blend(colors.text_accent.opacity(0.1)); + + let recording_pulse = |color: Color| { + Icon::new(IconName::Circle) + .size(IconSize::Small) + .color(Color::Error) + .with_animation( + "recording-pulse", + Animation::new(std::time::Duration::from_secs(2)) + .repeat() + .with_easing(gpui::pulsating_between(0.4, 0.8)), + { + let color = color.color(cx); + move |this, delta| this.color(Color::Custom(color.opacity(delta))) + }, + ) + }; + + let recording_indicator = h_flex() + .h_4() + .pr_1() + .gap_0p5() + .border_1() + .border_color(colors.border) + .bg(colors + .editor_background + .blend(colors.text_accent.opacity(0.1))) + .rounded_sm() + .child(recording_pulse(Color::Error)) + .child( + Label::new("REC") + .size(LabelSize::XSmall) + .weight(FontWeight::SEMIBOLD) + .color(Color::Error), + ); + + let search_indicator = h_flex() + .h_4() + .pr_1() + .gap_0p5() + .border_1() + .border_color(colors.border) + .bg(colors + .editor_background + .blend(colors.text_accent.opacity(0.1))) + .rounded_sm() + .child(recording_pulse(Color::Accent)) + .child( + Label::new("SEARCH") + .size(LabelSize::XSmall) + .weight(FontWeight::SEMIBOLD) + .color(Color::Accent), + ); + + let record_icon = if self.search { + IconName::MagnifyingGlass + } else { + IconName::PlayFilled + }; + + h_flex() + .id("keystroke-input") + .track_focus(&self.outer_focus_handle) + .py_2() + .px_3() + .gap_2() + .min_h_10() + .w_full() + .flex_1() + .justify_between() + .rounded_lg() + .overflow_hidden() + .map(|this| { + if is_recording { + this.bg(recording_bg_color) + } else { + this.bg(colors.editor_background) + } + }) + .border_1() + .border_color(colors.border_variant) + .when(is_focused, |parent| { + parent.border_color(colors.border_focused) + }) + .key_context(Self::key_context()) + .on_action(cx.listener(Self::start_recording)) + .on_action(cx.listener(Self::clear_keystrokes)) + .child( + h_flex() + .w(horizontal_padding) + .gap_0p5() + .justify_start() + .flex_none() + .when(is_recording, |this| { + this.map(|this| { + if self.search { + this.child(search_indicator) + } else { + this.child(recording_indicator) + } + }) + }), + ) + .child( + h_flex() + .id("keystroke-input-inner") + .track_focus(&self.inner_focus_handle) + .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) + .size_full() + .when(!self.search, |this| { + this.focus(|mut style| { + style.border_color = Some(colors.border_focused); + style + }) + }) + .w_full() + .min_w_0() + .justify_center() + .flex_wrap() + .gap(ui::DynamicSpacing::Base04.rems(cx)) + .children(self.render_keystrokes(is_recording)), + ) + .child( + h_flex() + .w(horizontal_padding) + .gap_0p5() + .justify_end() + .flex_none() + .map(|this| { + if is_recording { + this.child( + IconButton::new("stop-record-btn", IconName::StopFilled) + .shape(IconButtonShape::Square) + .map(|this| { + this.tooltip(Tooltip::for_action_title( + if self.search { + "Stop Searching" + } else { + "Stop Recording" + }, + &StopRecording, + )) + }) + .icon_color(Color::Error) + .on_click(cx.listener(|this, _event, window, cx| { + this.stop_recording(&StopRecording, window, cx); + })), + ) + } else { + this.child( + IconButton::new("record-btn", record_icon) + .shape(IconButtonShape::Square) + .map(|this| { + this.tooltip(Tooltip::for_action_title( + if self.search { + "Start Searching" + } else { + "Start Recording" + }, + &StartRecording, + )) + }) + .when(!is_focused, 
|this| this.icon_color(Color::Muted)) + .on_click(cx.listener(|this, _event, window, cx| { + this.start_recording(&StartRecording, window, cx); + })), + ) + } + }) + .child( + IconButton::new("clear-btn", IconName::Delete) + .shape(IconButtonShape::Square) + .tooltip(Tooltip::for_action_title( + "Clear Keystrokes", + &ClearKeystrokes, + )) + .when(!is_recording || !is_focused, |this| { + this.icon_color(Color::Muted) + }) + .on_click(cx.listener(|this, _event, window, cx| { + this.clear_keystrokes(&ClearKeystrokes, window, cx); + })), + ), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use fs::FakeFs; + use gpui::{Entity, TestAppContext, VisualTestContext}; + use itertools::Itertools as _; + use project::Project; + use settings::SettingsStore; + use workspace::Workspace; + + pub struct KeystrokeInputTestHelper { + input: Entity<KeystrokeInput>, + current_modifiers: Modifiers, + cx: VisualTestContext, + } + + impl KeystrokeInputTestHelper { + /// Creates a new test helper with default settings + pub fn new(mut cx: VisualTestContext) -> Self { + let input = cx.new_window_entity(|window, cx| KeystrokeInput::new(None, window, cx)); + + let mut helper = Self { + input, + current_modifiers: Modifiers::default(), + cx, + }; + + helper.start_recording(); + helper + } + + /// Sets search mode on the input + pub fn with_search_mode(&mut self, search: bool) -> &mut Self { + self.input.update(&mut self.cx, |input, _| { + input.set_search(search); + }); + self + } + + /// Sends a keystroke event based on string description + /// Examples: "a", "ctrl-a", "cmd-shift-z", "escape" + #[track_caller] + pub fn send_keystroke(&mut self, keystroke_input: &str) -> &mut Self { + self.expect_is_recording(true); + let keystroke_str = if keystroke_input.ends_with('-') { + format!("{}_", keystroke_input) + } else { + keystroke_input.to_string() + }; + + let mut keystroke = Keystroke::parse(&keystroke_str) + .unwrap_or_else(|_| panic!("Invalid keystroke: {}", keystroke_input)); + + // Remove the dummy key if we added it for modifier-only keystrokes + if keystroke_input.ends_with('-') && keystroke_str.ends_with("_") { + keystroke.key = "".to_string(); + } + + // Combine current modifiers with keystroke modifiers + keystroke.modifiers |= self.current_modifiers; + + self.update_input(|input, window, cx| { + input.handle_keystroke(&keystroke, window, cx); + }); + + // Don't update current_modifiers for keystrokes with actual keys + if keystroke.key.is_empty() { + self.current_modifiers = keystroke.modifiers; + } + self + } + + /// Sends a modifier change event based on string description + /// Examples: "+ctrl", "-ctrl", "+cmd+shift", "-all" + #[track_caller] + pub fn send_modifiers(&mut self, modifiers: &str) -> &mut Self { + self.expect_is_recording(true); + let new_modifiers = if modifiers == "-all" { + Modifiers::default() + } else { + self.parse_modifier_change(modifiers) + }; + + let event = ModifiersChangedEvent { + modifiers: new_modifiers, + capslock: gpui::Capslock::default(), + }; + + self.update_input(|input, window, cx| { + input.on_modifiers_changed(&event, window, cx); + }); + + self.current_modifiers = new_modifiers; + self + } + + /// Sends multiple events in sequence + /// Each event string is either a keystroke or modifier change + #[track_caller] + pub fn send_events(&mut self, events: &[&str]) -> &mut Self { + self.expect_is_recording(true); + for event in events { + if event.starts_with('+') || event.starts_with('-') { + self.send_modifiers(event); + } else { + self.send_keystroke(event); 
+ } + } + self + } + + #[track_caller] + fn expect_keystrokes_equal(actual: &[Keystroke], expected: &[&str]) { + let expected_keystrokes: Result<Vec<Keystroke>, _> = expected + .iter() + .map(|s| { + let keystroke_str = if s.ends_with('-') { + format!("{}_", s) + } else { + s.to_string() + }; + + let mut keystroke = Keystroke::parse(&keystroke_str)?; + + // Remove the dummy key if we added it for modifier-only keystrokes + if s.ends_with('-') && keystroke_str.ends_with("_") { + keystroke.key = "".to_string(); + } + + Ok(keystroke) + }) + .collect(); + + let expected_keystrokes = expected_keystrokes + .unwrap_or_else(|e: anyhow::Error| panic!("Invalid expected keystroke: {}", e)); + + assert_eq!( + actual.len(), + expected_keystrokes.len(), + "Keystroke count mismatch. Expected: {:?}, Actual: {:?}", + expected_keystrokes + .iter() + .map(|k| k.unparse()) + .collect::<Vec<_>>(), + actual.iter().map(|k| k.unparse()).collect::<Vec<_>>() + ); + + for (i, (actual, expected)) in actual.iter().zip(expected_keystrokes.iter()).enumerate() + { + assert_eq!( + actual.unparse(), + expected.unparse(), + "Keystroke {} mismatch. Expected: '{}', Actual: '{}'", + i, + expected.unparse(), + actual.unparse() + ); + } + } + + /// Verifies that the keystrokes match the expected strings + #[track_caller] + pub fn expect_keystrokes(&mut self, expected: &[&str]) -> &mut Self { + let actual = self + .input + .read_with(&mut self.cx, |input, _| input.keystrokes.clone()); + Self::expect_keystrokes_equal(&actual, expected); + self + } + + #[track_caller] + pub fn expect_close_keystrokes(&mut self, expected: &[&str]) -> &mut Self { + let actual = self + .input + .read_with(&mut self.cx, |input, _| input.close_keystrokes.clone()) + .unwrap_or_default(); + Self::expect_keystrokes_equal(&actual, expected); + self + } + + /// Verifies that there are no keystrokes + #[track_caller] + pub fn expect_empty(&mut self) -> &mut Self { + self.expect_keystrokes(&[]) + } + + /// Starts recording keystrokes + #[track_caller] + pub fn start_recording(&mut self) -> &mut Self { + self.expect_is_recording(false); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.start_recording(&StartRecording, window, cx); + }); + self + } + + /// Stops recording keystrokes + pub fn stop_recording(&mut self) -> &mut Self { + self.expect_is_recording(true); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.stop_recording(&StopRecording, window, cx); + }); + self + } + + /// Clears all keystrokes + #[track_caller] + pub fn clear_keystrokes(&mut self) -> &mut Self { + let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); + self.input.update_in(&mut self.cx, |input, window, cx| { + input.clear_keystrokes(&ClearKeystrokes, window, cx); + }); + KeystrokeUpdateTracker::finish(change_tracker, &self.cx); + self.current_modifiers = Default::default(); + self + } + + /// Verifies the recording state + #[track_caller] + pub fn expect_is_recording(&mut self, expected: bool) -> &mut Self { + let actual = self + .input + .update_in(&mut self.cx, |input, window, _| input.is_recording(window)); + assert_eq!( + actual, expected, + "Recording state mismatch. 
Expected: {}, Actual: {}", + expected, actual + ); + self + } + + pub async fn wait_for_close_keystroke_capture_end(&mut self) -> &mut Self { + let task = self.input.update_in(&mut self.cx, |input, _, _| { + input.clear_close_keystrokes_timer.take() + }); + let task = task.expect("No close keystroke capture end timer task"); + self.cx + .executor() + .advance_clock(CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT); + task.await; + self + } + + /// Parses modifier change strings like "+ctrl", "-shift", "+cmd+alt" + #[track_caller] + fn parse_modifier_change(&self, modifiers_str: &str) -> Modifiers { + let mut modifiers = self.current_modifiers; + + assert!(!modifiers_str.is_empty(), "Empty modifier string"); + + let value; + let split_char; + let remaining; + if let Some(to_add) = modifiers_str.strip_prefix('+') { + value = true; + split_char = '+'; + remaining = to_add; + } else { + let to_remove = modifiers_str + .strip_prefix('-') + .expect("Modifier string must start with '+' or '-'"); + value = false; + split_char = '-'; + remaining = to_remove; + } + + for modifier in remaining.split(split_char) { + match modifier { + "ctrl" | "control" => modifiers.control = value, + "alt" | "option" => modifiers.alt = value, + "shift" => modifiers.shift = value, + "cmd" | "command" | "platform" => modifiers.platform = value, + "fn" | "function" => modifiers.function = value, + _ => panic!("Unknown modifier: {}", modifier), + } + } + + modifiers + } + + #[track_caller] + fn update_input<R>( + &mut self, + cb: impl FnOnce(&mut KeystrokeInput, &mut Window, &mut Context<KeystrokeInput>) -> R, + ) -> R { + let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); + let result = self.input.update_in(&mut self.cx, cb); + KeystrokeUpdateTracker::finish(change_tracker, &self.cx); + return result; + } + } + + struct KeystrokeUpdateTracker { + initial_keystrokes: Vec<Keystroke>, + _subscription: Subscription, + input: Entity<KeystrokeInput>, + received_keystrokes_updated: bool, + } + + impl KeystrokeUpdateTracker { + fn new(input: Entity<KeystrokeInput>, cx: &mut VisualTestContext) -> Entity<Self> { + cx.new(|cx| Self { + initial_keystrokes: input.read_with(cx, |input, _| input.keystrokes.clone()), + _subscription: cx.subscribe(&input, |this: &mut Self, _, _, _| { + this.received_keystrokes_updated = true; + }), + input, + received_keystrokes_updated: false, + }) + } + #[track_caller] + fn finish(this: Entity<Self>, cx: &VisualTestContext) { + let (received_keystrokes_updated, initial_keystrokes_str, updated_keystrokes_str) = + this.read_with(cx, |this, cx| { + let updated_keystrokes = this + .input + .read_with(cx, |input, _| input.keystrokes.clone()); + let initial_keystrokes_str = keystrokes_str(&this.initial_keystrokes); + let updated_keystrokes_str = keystrokes_str(&updated_keystrokes); + ( + this.received_keystrokes_updated, + initial_keystrokes_str, + updated_keystrokes_str, + ) + }); + if received_keystrokes_updated { + assert_ne!( + initial_keystrokes_str, updated_keystrokes_str, + "Received keystrokes_updated event, expected different keystrokes" + ); + } else { + assert_eq!( + initial_keystrokes_str, updated_keystrokes_str, + "Received no keystrokes_updated event, expected same keystrokes" + ); + } + + fn keystrokes_str(ks: &[Keystroke]) -> String { + ks.iter().map(|ks| ks.unparse()).join(" ") + } + } + } + + async fn init_test(cx: &mut TestAppContext) -> KeystrokeInputTestHelper { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + 
theme::init(theme::LoadThemes::JustBase, cx); + language::init(cx); + project::Project::init_settings(cx); + workspace::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let workspace = + cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = VisualTestContext::from_window(*workspace, cx); + KeystrokeInputTestHelper::new(cx) + } + + #[gpui::test] + async fn test_basic_keystroke_input(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("a") + .clear_keystrokes() + .expect_empty(); + } + + #[gpui::test] + async fn test_modifier_handling(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-a"]); + } + + #[gpui::test] + async fn test_multiple_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("cmd-shift-z") + .expect_keystrokes(&["cmd-shift-z", "cmd-shift-"]); + } + + #[gpui::test] + async fn test_search_mode_behavior(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+cmd", "shift-f", "-cmd"]) + // In search mode, when completing a modifier-only keystroke with a key, + // only the original modifiers are preserved, not the keystroke's modifiers + .expect_keystrokes(&["cmd-f"]); + } + + #[gpui::test] + async fn test_keystroke_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_keystroke("a") + .send_keystroke("b") + .send_keystroke("c") + .expect_keystrokes(&["a", "b", "c"]) // At max limit + .send_keystroke("d") + .expect_empty(); // Should clear when exceeding limit + } + + #[gpui::test] + async fn test_modifier_release_all(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_search_new_modifiers_not_added_until_all_released(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a"]) + .send_events(&["+ctrl"]) + .expect_keystrokes(&["ctrl-shift-a", "ctrl-shift-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_no_effect_when_not_search(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl+shift", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_keystroke_limit_overflow_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["a", "b", "c", "d"]) // 4 keystrokes, exceeds limit of 3 + .expect_empty(); // Should clear when exceeding limit + } + + #[gpui::test] + async fn test_complex_modifier_sequences(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "+alt", "a", "-ctrl", "-shift", "-alt"]) + .expect_keystrokes(&["ctrl-shift-alt-a"]); + } + + #[gpui::test] + async fn test_modifier_only_keystrokes_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "-ctrl", "-shift"]) + .expect_keystrokes(&["ctrl-shift-"]); // Modifier-only sequences create modifier-only keystrokes + } + + #[gpui::test] + async fn test_modifier_only_keystrokes_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "+shift", 
"-ctrl", "-shift"]) + .expect_empty(); // Modifier-only sequences get filtered in non-search mode + } + + #[gpui::test] + async fn test_rapid_modifier_changes(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl", "+shift", "-shift", "+alt", "a", "-alt"]) + .expect_keystrokes(&["ctrl-", "shift-", "alt-a"]); + } + + #[gpui::test] + async fn test_clear_keystrokes_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "a", "-ctrl", "b"]) + .expect_keystrokes(&["ctrl-a", "b"]) + .clear_keystrokes() + .expect_empty(); + } + + #[gpui::test] + async fn test_non_search_mode_modifier_key_sequence(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "a"]) + .expect_keystrokes(&["ctrl-a", "ctrl-"]) + .send_events(&["-ctrl"]) + .expect_keystrokes(&["ctrl-a"]); // Non-search mode filters trailing empty keystrokes + } + + #[gpui::test] + async fn test_all_modifiers_at_once(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+shift+alt+cmd", "a", "-all"]) + .expect_keystrokes(&["ctrl-shift-alt-cmd-a"]); + } + + #[gpui::test] + async fn test_keystrokes_at_exact_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "b", "c"]) // exactly 3 keystrokes (at limit) + .expect_keystrokes(&["a", "b", "c"]) + .send_events(&["d"]) // should clear when exceeding + .expect_empty(); + } + + #[gpui::test] + async fn test_function_modifier_key(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+fn", "f1", "-fn"]) + .expect_keystrokes(&["fn-f1"]); + } + + #[gpui::test] + async fn test_start_stop_recording(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_events(&["a", "b"]) + .expect_keystrokes(&["a", "b"]) // start_recording clears existing keystrokes + .stop_recording() + .expect_is_recording(false) + .start_recording() + .send_events(&["c"]) + .expect_keystrokes(&["c"]); + } + + #[gpui::test] + async fn test_modifier_sequence_with_interruption(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "a", "-shift", "b", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a", "ctrl-b"]); + } + + #[gpui::test] + async fn test_empty_key_sequence_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&[]) // No events at all + .expect_empty(); + } + + #[gpui::test] + async fn test_modifier_sequence_completion_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "+shift", "-shift", "a", "-ctrl"]) + .expect_keystrokes(&["ctrl-shift-a"]); + } + + #[gpui::test] + async fn test_triple_escape_stops_recording_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]) // Triple escape removes final escape, stops recording + .expect_is_recording(false); + } + + #[gpui::test] + async fn test_triple_escape_stops_recording_non_search_mode(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]); // Triple escape stops recording but only removes final escape + } + + #[gpui::test] + async fn test_triple_escape_at_keystroke_limit(cx: &mut 
TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "b", "c", "escape", "escape", "escape"]) // 6 keystrokes total, exceeds limit + .expect_keystrokes(&["a", "b", "c"]); // Triple escape stops recording and removes escapes, leaves original keystrokes + } + + #[gpui::test] + async fn test_interrupted_escape_sequence(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "a", "escape"]) // Partial escape sequence interrupted by 'a' + .expect_keystrokes(&["escape", "escape", "a"]); // Escape sequence interrupted by 'a', no close triggered + } + + #[gpui::test] + async fn test_interrupted_escape_sequence_within_limit(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "a"]) // Partial escape sequence interrupted by 'a' (3 keystrokes, at limit) + .expect_keystrokes(&["escape", "escape", "a"]); // Should not trigger close, interruption resets escape detection + } + + #[gpui::test] + async fn test_partial_escape_sequence_no_close(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape"]) // Only 2 escapes, not enough to close + .expect_keystrokes(&["escape", "escape"]) + .expect_is_recording(true); // Should remain in keystrokes, no close triggered + } + + #[gpui::test] + async fn test_recording_state_after_triple_escape(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "escape", "escape"]) + .expect_keystrokes(&["a"]) // Triple escape stops recording, removes final escape + .expect_is_recording(false); + } + + #[gpui::test] + async fn test_triple_escape_mixed_with_other_keystrokes(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["a", "escape", "b", "escape", "escape"]) // Mixed sequence, should not trigger close + .expect_keystrokes(&["a", "escape", "b"]); // No complete triple escape sequence, stays at limit + } + + #[gpui::test] + async fn test_triple_escape_only(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["escape", "escape", "escape"]) // Pure triple escape sequence + .expect_empty(); + } + + #[gpui::test] + async fn test_end_close_keystroke_capture(cx: &mut TestAppContext) { + init_test(cx) + .await + .send_events(&["+ctrl", "g", "-ctrl", "escape"]) + .expect_keystrokes(&["ctrl-g", "escape"]) + .wait_for_close_keystroke_capture_end() + .await + .send_events(&["escape", "escape"]) + .expect_keystrokes(&["ctrl-g", "escape", "escape"]) + .expect_close_keystrokes(&["escape", "escape"]) + .send_keystroke("escape") + .expect_keystrokes(&["ctrl-g", "escape"]); + } + + #[gpui::test] + async fn test_search_previous_modifiers_are_sticky(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "j"]) + .expect_keystrokes(&["ctrl-alt-j"]); + } + + #[gpui::test] + async fn test_previous_modifiers_can_be_entered_separately(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl"]) + .expect_keystrokes(&["ctrl-"]) + .send_events(&["+alt", "-alt"]) + .expect_keystrokes(&["ctrl-", "alt-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_reset_on_key(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) + .expect_keystrokes(&["ctrl-shift-alt-"]) 
+ .send_keystroke("j") + .expect_keystrokes(&["ctrl-shift-alt-j"]) + .send_keystroke("i") + .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i"]) + .send_events(&["-shift-alt", "+cmd"]) + .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i", "cmd-"]); + } + + #[gpui::test] + async fn test_previous_modifiers_reset_on_release_all(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) + .expect_keystrokes(&["ctrl-shift-alt-"]) + .send_events(&["-all", "j"]) + .expect_keystrokes(&["ctrl-shift-alt-", "j"]); + } + + #[gpui::test] + async fn test_search_repeat_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(true) + .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) + .expect_keystrokes(&["ctrl-", "alt-", "shift-"]) + .send_events(&["+cmd"]) + .expect_empty(); + } + + #[gpui::test] + async fn test_not_search_repeat_modifiers(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) + .expect_empty(); + } +} diff --git a/crates/settings_ui/src/ui_components/mod.rs b/crates/settings_ui/src/ui_components/mod.rs index 13971b0a5d..5d6463a61a 100644 --- a/crates/settings_ui/src/ui_components/mod.rs +++ b/crates/settings_ui/src/ui_components/mod.rs @@ -1 +1,2 @@ +pub mod keystroke_input; pub mod table; diff --git a/crates/settings_ui/src/ui_components/table.rs b/crates/settings_ui/src/ui_components/table.rs index 69207f559b..2b3e815f36 100644 --- a/crates/settings_ui/src/ui_components/table.rs +++ b/crates/settings_ui/src/ui_components/table.rs @@ -2,9 +2,9 @@ use std::{ops::Range, rc::Rc, time::Duration}; use editor::{EditorSettings, ShowScrollbar, scroll::ScrollbarAutoHide}; use gpui::{ - AbsoluteLength, AppContext, Axis, Context, DefiniteLength, DragMoveEvent, Entity, FocusHandle, - Length, ListHorizontalSizingBehavior, ListSizingBehavior, MouseButton, Point, Stateful, Task, - UniformListScrollHandle, WeakEntity, transparent_black, uniform_list, + AbsoluteLength, AppContext, Axis, Context, DefiniteLength, DragMoveEvent, Entity, EntityId, + FocusHandle, Length, ListHorizontalSizingBehavior, ListSizingBehavior, MouseButton, Point, + Stateful, Task, UniformListScrollHandle, WeakEntity, transparent_black, uniform_list, }; use itertools::intersperse_with; @@ -13,10 +13,12 @@ use ui::{ ActiveTheme as _, AnyElement, App, Button, ButtonCommon as _, ButtonStyle, Color, Component, ComponentScope, Div, ElementId, FixedWidth as _, FluentBuilder as _, Indicator, InteractiveElement, IntoElement, ParentElement, Pixels, RegisterComponent, RenderOnce, - Scrollbar, ScrollbarState, StatefulInteractiveElement, Styled, StyledExt as _, + Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, StyledExt as _, StyledTypography, Window, div, example_group_with_title, h_flex, px, single_example, v_flex, }; +const RESIZE_COLUMN_WIDTH: f32 = 8.0; + #[derive(Debug)] struct DraggedColumn(usize); @@ -212,6 +214,7 @@ impl TableInteractionState { let mut column_ix = 0; let resizable_columns_slice = *resizable_columns; let mut resizable_columns = resizable_columns.into_iter(); + let dividers = intersperse_with(spacers, || { window.with_id(column_ix, |window| { let mut resize_divider = div() @@ -219,15 +222,15 @@ impl TableInteractionState { .id(column_ix) .relative() .top_0() - .w_0p5() + .w_px() .h_full() - .bg(cx.theme().colors().border.opacity(0.5)); + 
.bg(cx.theme().colors().border.opacity(0.8)); let mut resize_handle = div() .id("column-resize-handle") .absolute() .left_neg_0p5() - .w(px(5.0)) + .w(px(RESIZE_COLUMN_WIDTH)) .h_full(); if resizable_columns @@ -235,15 +238,17 @@ impl TableInteractionState { .is_some_and(ResizeBehavior::is_resizable) { let hovered = window.use_state(cx, |_window, _cx| false); + resize_divider = resize_divider.when(*hovered.read(cx), |div| { div.bg(cx.theme().colors().border_focused) }); + resize_handle = resize_handle .on_hover(move |&was_hovered, _, cx| hovered.write(cx, was_hovered)) .cursor_col_resize() .when_some(columns.clone(), |this, columns| { this.on_click(move |event, window, cx| { - if event.down.click_count >= 2 { + if event.click_count() >= 2 { columns.update(cx, |columns, _| { columns.on_double_click( column_ix, @@ -267,12 +272,11 @@ impl TableInteractionState { }) }); - div() + h_flex() .id("resize-handles") - .h_flex() .absolute() - .w_full() .inset_0() + .w_full() .children(dividers) .into_any_element() } @@ -478,6 +482,7 @@ impl ResizeBehavior { pub struct ColumnWidths<const COLS: usize> { widths: [DefiniteLength; COLS], + visible_widths: [DefiniteLength; COLS], cached_bounds_width: Pixels, initialized: bool, } @@ -486,6 +491,7 @@ impl<const COLS: usize> ColumnWidths<COLS> { pub fn new(_: &mut App) -> Self { Self { widths: [DefiniteLength::default(); COLS], + visible_widths: [DefiniteLength::default(); COLS], cached_bounds_width: Default::default(), initialized: false, } @@ -512,46 +518,105 @@ impl<const COLS: usize> ColumnWidths<COLS> { let rem_size = window.rem_size(); let initial_sizes = initial_sizes.map(|length| Self::get_fraction(&length, bounds_width, rem_size)); - let mut widths = self + let widths = self .widths .map(|length| Self::get_fraction(&length, bounds_width, rem_size)); - let diff = initial_sizes[double_click_position] - widths[double_click_position]; + let updated_widths = Self::reset_to_initial_size( + double_click_position, + widths, + initial_sizes, + resize_behavior, + ); + self.widths = updated_widths.map(DefiniteLength::Fraction); + self.visible_widths = self.widths; + } - if diff > 0.0 { - let diff_remaining = self.propagate_resize_diff_right( - diff, - double_click_position, - &mut widths, - resize_behavior, - ); + fn reset_to_initial_size( + col_idx: usize, + mut widths: [f32; COLS], + initial_sizes: [f32; COLS], + resize_behavior: &[ResizeBehavior; COLS], + ) -> [f32; COLS] { + // RESET: + // Part 1: + // Figure out if we should shrink/grow the selected column + // Get diff which represents the change in column we want to make initial size delta curr_size = diff + // + // Part 2: We need to decide which side column we should move and where + // + // If we want to grow our column we should check the left/right columns diff to see what side + // has a greater delta than their initial size. Likewise, if we shrink our column we should check + // the left/right column diffs to see what side has the smallest delta. 
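+        //
+        // In short: when the reset needs to grow the column back toward its initial
+        // size, reclaim space first from whichever side currently exceeds its initial
+        // widths by the most; when the reset shrinks the column, hand the freed space
+        // first to the side that is furthest below its initial widths.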
+ // + // Part 3: resize + // + // col_idx represents the column handle to the right of an active column + // + // If growing and right has the greater delta { + // shift col_idx to the right + // } else if growing and left has the greater delta { + // shift col_idx - 1 to the left + // } else if shrinking and the right has the greater delta { + // shift + // } { + // + // } + // } + // + // if we need to shrink, then if the right + // - if diff_remaining > 0.0 && double_click_position > 0 { - self.propagate_resize_diff_left( - -diff_remaining, - double_click_position - 1, + // DRAGGING + // we get diff which represents the change in the _drag handle_ position + // -diff => dragging left -> + // grow the column to the right of the handle as much as we can shrink columns to the left of the handle + // +diff => dragging right -> growing handles column + // grow the column to the left of the handle as much as we can shrink columns to the right of the handle + // + + let diff = initial_sizes[col_idx] - widths[col_idx]; + + let left_diff = + initial_sizes[..col_idx].iter().sum::<f32>() - widths[..col_idx].iter().sum::<f32>(); + let right_diff = initial_sizes[col_idx + 1..].iter().sum::<f32>() + - widths[col_idx + 1..].iter().sum::<f32>(); + + let go_left_first = if diff < 0.0 { + left_diff > right_diff + } else { + left_diff < right_diff + }; + + if !go_left_first { + let diff_remaining = + Self::propagate_resize_diff(diff, col_idx, &mut widths, resize_behavior, 1); + + if diff_remaining != 0.0 && col_idx > 0 { + Self::propagate_resize_diff( + diff_remaining, + col_idx, &mut widths, resize_behavior, + -1, ); } - } else if double_click_position > 0 { - let diff_remaining = self.propagate_resize_diff_left( - diff, - double_click_position, - &mut widths, - resize_behavior, - ); + } else { + let diff_remaining = + Self::propagate_resize_diff(diff, col_idx, &mut widths, resize_behavior, -1); - if diff_remaining < 0.0 { - self.propagate_resize_diff_right( - -diff_remaining, - double_click_position, + if diff_remaining != 0.0 { + Self::propagate_resize_diff( + diff_remaining, + col_idx, &mut widths, resize_behavior, + 1, ); } } - self.widths = widths.map(DefiniteLength::Fraction); + + widths } fn on_drag_move( @@ -569,98 +634,102 @@ impl<const COLS: usize> ColumnWidths<COLS> { let bounds_width = bounds.right() - bounds.left(); let col_idx = drag_event.drag(cx).0; + let column_handle_width = Self::get_fraction( + &DefiniteLength::Absolute(AbsoluteLength::Pixels(px(RESIZE_COLUMN_WIDTH))), + bounds_width, + rem_size, + ); + let mut widths = self .widths .map(|length| Self::get_fraction(&length, bounds_width, rem_size)); for length in widths[0..=col_idx].iter() { - col_position += length; + col_position += length + column_handle_width; } let mut total_length_ratio = col_position; for length in widths[col_idx + 1..].iter() { total_length_ratio += length; } + total_length_ratio += (COLS - 1 - col_idx) as f32 * column_handle_width; let drag_fraction = (drag_position.x - bounds.left()) / bounds_width; let drag_fraction = drag_fraction * total_length_ratio; - let diff = drag_fraction - col_position; + let diff = drag_fraction - col_position - column_handle_width / 2.0; - let is_dragging_right = diff > 0.0; + Self::drag_column_handle(diff, col_idx, &mut widths, resize_behavior); - if is_dragging_right { - self.propagate_resize_diff_right(diff, col_idx, &mut widths, resize_behavior); - } else { - // Resize behavior should be improved in the future by also seeking to the right column when there's not enough space 
- self.propagate_resize_diff_left(diff, col_idx, &mut widths, resize_behavior); - } - self.widths = widths.map(DefiniteLength::Fraction); + self.visible_widths = widths.map(DefiniteLength::Fraction); } - fn propagate_resize_diff_right( - &self, + fn drag_column_handle( diff: f32, col_idx: usize, widths: &mut [f32; COLS], resize_behavior: &[ResizeBehavior; COLS], - ) -> f32 { - let mut diff_remaining = diff; - let mut curr_column = col_idx + 1; - - while diff_remaining > 0.0 && curr_column < COLS { - let Some(min_size) = resize_behavior[curr_column - 1].min_size() else { - curr_column += 1; - continue; - }; - - let mut curr_width = widths[curr_column] - diff_remaining; - - diff_remaining = 0.0; - if min_size > curr_width { - diff_remaining += min_size - curr_width; - curr_width = min_size; - } - widths[curr_column] = curr_width; - curr_column += 1; + ) { + // if diff > 0.0 then go right + if diff > 0.0 { + Self::propagate_resize_diff(diff, col_idx, widths, resize_behavior, 1); + } else { + Self::propagate_resize_diff(-diff, col_idx + 1, widths, resize_behavior, -1); } - - widths[col_idx] = widths[col_idx] + (diff - diff_remaining); - return diff_remaining; } - fn propagate_resize_diff_left( - &mut self, + fn propagate_resize_diff( diff: f32, - mut curr_column: usize, + col_idx: usize, widths: &mut [f32; COLS], resize_behavior: &[ResizeBehavior; COLS], + direction: i8, ) -> f32 { let mut diff_remaining = diff; - let col_idx = curr_column; - while diff_remaining < 0.0 { + if resize_behavior[col_idx].min_size().is_none() { + return diff; + } + + let step_right; + let step_left; + if direction < 0 { + step_right = 0; + step_left = 1; + } else { + step_right = 1; + step_left = 0; + } + if col_idx == 0 && direction < 0 { + return diff; + } + let mut curr_column = col_idx + step_right - step_left; + + while diff_remaining != 0.0 && curr_column < COLS { let Some(min_size) = resize_behavior[curr_column].min_size() else { if curr_column == 0 { break; } - curr_column -= 1; + curr_column -= step_left; + curr_column += step_right; continue; }; - let mut curr_width = widths[curr_column] + diff_remaining; - - diff_remaining = 0.0; - if curr_width < min_size { - diff_remaining = curr_width - min_size; - curr_width = min_size - } - + let curr_width = widths[curr_column] - diff_remaining; widths[curr_column] = curr_width; + + if min_size > curr_width { + diff_remaining = min_size - curr_width; + widths[curr_column] = min_size; + } else { + diff_remaining = 0.0; + break; + } if curr_column == 0 { break; } - curr_column -= 1; + curr_column -= step_left; + curr_column += step_right; } - widths[col_idx + 1] = widths[col_idx + 1] - (diff - diff_remaining); + widths[col_idx] = widths[col_idx] + (diff - diff_remaining); return diff_remaining; } @@ -686,7 +755,7 @@ impl<const COLS: usize> TableWidths<COLS> { fn lengths(&self, cx: &App) -> [Length; COLS] { self.current .as_ref() - .map(|entity| entity.read(cx).widths.map(Length::Definite)) + .map(|entity| entity.read(cx).visible_widths.map(Length::Definite)) .unwrap_or(self.initial.map(Length::Definite)) } } @@ -799,6 +868,7 @@ impl<const COLS: usize> Table<COLS> { if !widths.initialized { widths.initialized = true; widths.widths = table_widths.initial; + widths.visible_widths = widths.widths; } }) } @@ -828,7 +898,6 @@ fn base_cell_style(width: Option<Length>) -> Div { .px_1p5() .when_some(width, |this, width| this.w(width)) .when(width.is_none(), |this| this.flex_1()) - .justify_start() .whitespace_nowrap() .text_ellipsis() .overflow_hidden() @@ -873,7 +942,7 @@ 
pub fn render_row<const COLS: usize>( .map(IntoElement::into_any_element) .into_iter() .zip(column_widths) - .map(|(cell, width)| base_cell_style_text(width, cx).px_1p5().py_1().child(cell)), + .map(|(cell, width)| base_cell_style_text(width, cx).px_1().py_0p5().child(cell)), ); let row = if let Some(map_row) = table_context.map_row { @@ -882,17 +951,30 @@ pub fn render_row<const COLS: usize>( row.into_any_element() }; - div().h_full().w_full().child(row).into_any_element() + div().size_full().child(row).into_any_element() } pub fn render_header<const COLS: usize>( headers: [impl IntoElement; COLS], table_context: TableRenderContext<COLS>, + columns_widths: Option<( + WeakEntity<ColumnWidths<COLS>>, + [ResizeBehavior; COLS], + [DefiniteLength; COLS], + )>, + entity_id: Option<EntityId>, cx: &mut App, ) -> impl IntoElement { let column_widths = table_context .column_widths .map_or([None; COLS], |widths| widths.map(Some)); + + let element_id = entity_id + .map(|entity| entity.to_string()) + .unwrap_or_default(); + + let shared_element_id: SharedString = format!("table-{}", element_id).into(); + div() .flex() .flex_row() @@ -902,12 +984,39 @@ pub fn render_header<const COLS: usize>( .p_2() .border_b_1() .border_color(cx.theme().colors().border) - .children( - headers - .into_iter() - .zip(column_widths) - .map(|(h, width)| base_cell_style_text(width, cx).child(h)), - ) + .children(headers.into_iter().enumerate().zip(column_widths).map( + |((header_idx, h), width)| { + base_cell_style_text(width, cx) + .child(h) + .id(ElementId::NamedInteger( + shared_element_id.clone(), + header_idx as u64, + )) + .when_some( + columns_widths.as_ref().cloned(), + |this, (column_widths, resizables, initial_sizes)| { + if resizables[header_idx].is_resizable() { + this.on_click(move |event, window, cx| { + if event.click_count() > 1 { + column_widths + .update(cx, |column, _| { + column.on_double_click( + header_idx, + &initial_sizes, + &resizables, + window, + ); + }) + .ok(); + } + }) + } else { + this + } + }, + ) + }, + )) } #[derive(Clone)] @@ -939,6 +1048,12 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { .and_then(|widths| Some((widths.current.as_ref()?, widths.resizable))) .map(|(curr, resize_behavior)| (curr.downgrade(), resize_behavior)); + let current_widths_with_initial_sizes = self + .col_widths + .as_ref() + .and_then(|widths| Some((widths.current.as_ref()?, widths.resizable, widths.initial))) + .map(|(curr, resize_behavior, initial)| (curr.downgrade(), resize_behavior, initial)); + let scroll_track_size = px(16.); let h_scroll_offset = if interaction_state .as_ref() @@ -958,7 +1073,13 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { .h_full() .v_flex() .when_some(self.headers.take(), |this, headers| { - this.child(render_header(headers, table_context.clone(), cx)) + this.child(render_header( + headers, + table_context.clone(), + current_widths_with_initial_sizes, + interaction_state.as_ref().map(Entity::entity_id), + cx, + )) }) .when_some(current_widths, { |this, (widths, resize_behavior)| { @@ -972,19 +1093,28 @@ impl<const COLS: usize> RenderOnce for Table<COLS> { .ok(); } }) - .on_children_prepainted(move |bounds, _, cx| { + .on_children_prepainted({ + let widths = widths.clone(); + move |bounds, _, cx| { + widths + .update(cx, |widths, _| { + // This works because all children x axis bounds are the same + widths.cached_bounds_width = + bounds[0].right() - bounds[0].left(); + }) + .ok(); + } + }) + .on_drop::<DraggedColumn>(move |_, _, cx| { widths .update(cx, |widths, _| { - // 
This works because all children x axis bounds are the same - widths.cached_bounds_width = bounds[0].right() - bounds[0].left(); + widths.widths = widths.visible_widths; }) .ok(); + // Finish the resize operation }) } }) - .on_drop::<DraggedColumn>(|_, _, _| { - // Finish the resize operation - }) .child( div() .flex_grow() @@ -1313,3 +1443,323 @@ impl Component for Table<3> { ) } } + +#[cfg(test)] +mod test { + use super::*; + + fn is_almost_eq(a: &[f32], b: &[f32]) -> bool { + a.len() == b.len() && a.iter().zip(b).all(|(x, y)| (x - y).abs() < 1e-6) + } + + fn cols_to_str<const COLS: usize>(cols: &[f32; COLS], total_size: f32) -> String { + cols.map(|f| "*".repeat(f32::round(f * total_size) as usize)) + .join("|") + } + + fn parse_resize_behavior<const COLS: usize>( + input: &str, + total_size: f32, + ) -> [ResizeBehavior; COLS] { + let mut resize_behavior = [ResizeBehavior::None; COLS]; + let mut max_index = 0; + for (index, col) in input.split('|').enumerate() { + if col.starts_with('X') || col.is_empty() { + resize_behavior[index] = ResizeBehavior::None; + } else if col.starts_with('*') { + resize_behavior[index] = ResizeBehavior::MinSize(col.len() as f32 / total_size); + } else { + panic!("invalid test input: unrecognized resize behavior: {}", col); + } + max_index = index; + } + + if max_index + 1 != COLS { + panic!("invalid test input: too many columns"); + } + resize_behavior + } + + mod reset_column_size { + use super::*; + + fn parse<const COLS: usize>(input: &str) -> ([f32; COLS], f32, Option<usize>) { + let mut widths = [f32::NAN; COLS]; + let mut column_index = None; + for (index, col) in input.split('|').enumerate() { + widths[index] = col.len() as f32; + if col.starts_with('X') { + column_index = Some(index); + } + } + + for w in widths { + assert!(w.is_finite(), "incorrect number of columns"); + } + let total = widths.iter().sum::<f32>(); + for width in &mut widths { + *width /= total; + } + (widths, total, column_index) + } + + #[track_caller] + fn check_reset_size<const COLS: usize>( + initial_sizes: &str, + widths: &str, + expected: &str, + resize_behavior: &str, + ) { + let (initial_sizes, total_1, None) = parse::<COLS>(initial_sizes) else { + panic!("invalid test input: initial sizes should not be marked"); + }; + let (widths, total_2, Some(column_index)) = parse::<COLS>(widths) else { + panic!("invalid test input: widths should be marked"); + }; + assert_eq!( + total_1, total_2, + "invalid test input: total width not the same {total_1}, {total_2}" + ); + let (expected, total_3, None) = parse::<COLS>(expected) else { + panic!("invalid test input: expected should not be marked: {expected:?}"); + }; + assert_eq!( + total_2, total_3, + "invalid test input: total width not the same" + ); + let resize_behavior = parse_resize_behavior::<COLS>(resize_behavior, total_1); + let result = ColumnWidths::reset_to_initial_size( + column_index, + widths, + initial_sizes, + &resize_behavior, + ); + let is_eq = is_almost_eq(&result, &expected); + if !is_eq { + let result_str = cols_to_str(&result, total_1); + let expected_str = cols_to_str(&expected, total_1); + panic!( + "resize failed\ncomputed: {result_str}\nexpected: {expected_str}\n\ncomputed values: {result:?}\nexpected values: {expected:?}\n:minimum widths: {resize_behavior:?}" + ); + } + } + + macro_rules! check_reset_size { + (columns: $cols:expr, starting: $initial:expr, snapshot: $current:expr, expected: $expected:expr, resizing: $resizing:expr $(,)?) 
=> { + check_reset_size::<$cols>($initial, $current, $expected, $resizing); + }; + ($name:ident, columns: $cols:expr, starting: $initial:expr, snapshot: $current:expr, expected: $expected:expr, minimums: $resizing:expr $(,)?) => { + #[test] + fn $name() { + check_reset_size::<$cols>($initial, $current, $expected, $resizing); + } + }; + } + + check_reset_size!( + basic_right, + columns: 5, + starting: "**|**|**|**|**", + snapshot: "**|**|X|***|**", + expected: "**|**|**|**|**", + minimums: "X|*|*|*|*", + ); + + check_reset_size!( + basic_left, + columns: 5, + starting: "**|**|**|**|**", + snapshot: "**|**|***|X|**", + expected: "**|**|**|**|**", + minimums: "X|*|*|*|**", + ); + + check_reset_size!( + squashed_left_reset_col2, + columns: 6, + starting: "*|***|**|**|****|*", + snapshot: "*|*|X|*|*|********", + expected: "*|*|**|*|*|*******", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + grow_cascading_right, + columns: 6, + starting: "*|***|****|**|***|*", + snapshot: "*|***|X|**|**|*****", + expected: "*|***|****|*|*|****", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + squashed_right_reset_col4, + columns: 6, + starting: "*|***|**|**|****|*", + snapshot: "*|********|*|*|X|*", + expected: "*|*****|*|*|****|*", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + reset_col6_right, + columns: 6, + starting: "*|***|**|***|***|**", + snapshot: "*|***|**|***|**|XXX", + expected: "*|***|**|***|***|**", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + reset_col6_left, + columns: 6, + starting: "*|***|**|***|***|**", + snapshot: "*|***|**|***|****|X", + expected: "*|***|**|***|***|**", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + last_column_grow_cascading, + columns: 6, + starting: "*|***|**|**|**|***", + snapshot: "*|*******|*|**|*|X", + expected: "*|******|*|*|*|***", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + goes_left_when_left_has_extreme_diff, + columns: 6, + starting: "*|***|****|**|**|***", + snapshot: "*|********|X|*|**|**", + expected: "*|*****|****|*|**|**", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + basic_shrink_right, + columns: 6, + starting: "**|**|**|**|**|**", + snapshot: "**|**|XXX|*|**|**", + expected: "**|**|**|**|**|**", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + shrink_should_go_left, + columns: 6, + starting: "*|***|**|*|*|*", + snapshot: "*|*|XXX|**|*|*", + expected: "*|**|**|**|*|*", + minimums: "X|*|*|*|*|*", + ); + + check_reset_size!( + shrink_should_go_right, + columns: 6, + starting: "*|***|**|**|**|*", + snapshot: "*|****|XXX|*|*|*", + expected: "*|****|**|**|*|*", + minimums: "X|*|*|*|*|*", + ); + } + + mod drag_handle { + use super::*; + + fn parse<const COLS: usize>(input: &str) -> ([f32; COLS], f32, Option<usize>) { + let mut widths = [f32::NAN; COLS]; + let column_index = input.replace("*", "").find("I"); + for (index, col) in input.replace("I", "|").split('|').enumerate() { + widths[index] = col.len() as f32; + } + + for w in widths { + assert!(w.is_finite(), "incorrect number of columns"); + } + let total = widths.iter().sum::<f32>(); + for width in &mut widths { + *width /= total; + } + (widths, total, column_index) + } + + #[track_caller] + fn check<const COLS: usize>( + distance: i32, + widths: &str, + expected: &str, + resize_behavior: &str, + ) { + let (mut widths, total_1, Some(column_index)) = parse::<COLS>(widths) else { + panic!("invalid test input: widths should be marked"); + }; + let (expected, total_2, None) = parse::<COLS>(expected) else { + panic!("invalid test 
input: expected should not be marked: {expected:?}"); + }; + assert_eq!( + total_1, total_2, + "invalid test input: total width not the same" + ); + let resize_behavior = parse_resize_behavior::<COLS>(resize_behavior, total_1); + + let distance = distance as f32 / total_1; + + let result = ColumnWidths::drag_column_handle( + distance, + column_index, + &mut widths, + &resize_behavior, + ); + + let is_eq = is_almost_eq(&widths, &expected); + if !is_eq { + let result_str = cols_to_str(&widths, total_1); + let expected_str = cols_to_str(&expected, total_1); + panic!( + "resize failed\ncomputed: {result_str}\nexpected: {expected_str}\n\ncomputed values: {result:?}\nexpected values: {expected:?}\n:minimum widths: {resize_behavior:?}" + ); + } + } + + macro_rules! check { + (columns: $cols:expr, distance: $dist:expr, snapshot: $current:expr, expected: $expected:expr, resizing: $resizing:expr $(,)?) => { + check!($cols, $dist, $snapshot, $expected, $resizing); + }; + ($name:ident, columns: $cols:expr, distance: $dist:expr, snapshot: $current:expr, expected: $expected:expr, minimums: $resizing:expr $(,)?) => { + #[test] + fn $name() { + check::<$cols>($dist, $current, $expected, $resizing); + } + }; + } + + check!( + basic_right_drag, + columns: 3, + distance: 1, + snapshot: "**|**I**", + expected: "**|***|*", + minimums: "X|*|*", + ); + + check!( + drag_left_against_mins, + columns: 5, + distance: -1, + snapshot: "*|*|*|*I*******", + expected: "*|*|*|*|*******", + minimums: "X|*|*|*|*", + ); + + check!( + drag_left, + columns: 5, + distance: -2, + snapshot: "*|*|*|*****I***", + expected: "*|*|*|***|*****", + minimums: "X|*|*|*|*", + ); + } +} diff --git a/crates/snippets_ui/src/snippets_ui.rs b/crates/snippets_ui/src/snippets_ui.rs index 1cc16c5576..a8710d1672 100644 --- a/crates/snippets_ui/src/snippets_ui.rs +++ b/crates/snippets_ui/src/snippets_ui.rs @@ -149,13 +149,12 @@ impl ScopeSelectorDelegate { scope_selector: WeakEntity<ScopeSelector>, language_registry: Arc<LanguageRegistry>, ) -> Self { - let candidates = Vec::from([GLOBAL_SCOPE_NAME.to_string()]).into_iter(); let languages = language_registry.language_names().into_iter(); - let candidates = candidates + let candidates = std::iter::once(LanguageName::new(GLOBAL_SCOPE_NAME)) .chain(languages) .enumerate() - .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, &name)) + .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, name.as_ref())) .collect::<Vec<_>>(); let mut existing_scopes = HashSet::new(); diff --git a/crates/sum_tree/src/sum_tree.rs b/crates/sum_tree/src/sum_tree.rs index 4f9e01ce20..3a12e3a681 100644 --- a/crates/sum_tree/src/sum_tree.rs +++ b/crates/sum_tree/src/sum_tree.rs @@ -41,16 +41,14 @@ pub trait Summary: Clone { fn add_summary(&mut self, summary: &Self, cx: &Self::Context); } -/// This type exists because we can't implement Summary for () without causing -/// type resolution errors -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub struct Unit; - -impl Summary for Unit { +/// Catch-all implementation for when you need something that implements [`Summary`] without a specific type. 
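+/// (This replaces the former `Unit` placeholder type.)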
+/// We implement it on a &'static, as that avoids blanket impl collisions with `impl<T: Summary> Dimension for T` +/// (as we also need unit type to be a fill-in dimension) +impl Summary for &'static () { type Context = (); fn zero(_: &()) -> Self { - Unit + &() } fn add_summary(&mut self, _: &Self, _: &()) {} @@ -103,37 +101,32 @@ impl<'a, T: Summary> Dimension<'a, T> for () { fn add_summary(&mut self, _: &'a T, _: &T::Context) {} } -impl<'a, T: Summary, D1: Dimension<'a, T>, D2: Dimension<'a, T>> Dimension<'a, T> for (D1, D2) { +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct Dimensions<D1, D2, D3 = ()>(pub D1, pub D2, pub D3); + +impl<'a, T: Summary, D1: Dimension<'a, T>, D2: Dimension<'a, T>, D3: Dimension<'a, T>> + Dimension<'a, T> for Dimensions<D1, D2, D3> +{ fn zero(cx: &T::Context) -> Self { - (D1::zero(cx), D2::zero(cx)) + Dimensions(D1::zero(cx), D2::zero(cx), D3::zero(cx)) } fn add_summary(&mut self, summary: &'a T, cx: &T::Context) { self.0.add_summary(summary, cx); self.1.add_summary(summary, cx); + self.2.add_summary(summary, cx); } } -impl<'a, S, D1, D2> SeekTarget<'a, S, (D1, D2)> for D1 -where - S: Summary, - D1: SeekTarget<'a, S, D1> + Dimension<'a, S>, - D2: Dimension<'a, S>, -{ - fn cmp(&self, cursor_location: &(D1, D2), cx: &S::Context) -> Ordering { - self.cmp(&cursor_location.0, cx) - } -} - -impl<'a, S, D1, D2, D3> SeekTarget<'a, S, ((D1, D2), D3)> for D1 +impl<'a, S, D1, D2, D3> SeekTarget<'a, S, Dimensions<D1, D2, D3>> for D1 where S: Summary, D1: SeekTarget<'a, S, D1> + Dimension<'a, S>, D2: Dimension<'a, S>, D3: Dimension<'a, S>, { - fn cmp(&self, cursor_location: &((D1, D2), D3), cx: &S::Context) -> Ordering { - self.cmp(&cursor_location.0.0, cx) + fn cmp(&self, cursor_location: &Dimensions<D1, D2, D3>, cx: &S::Context) -> Ordering { + self.cmp(&cursor_location.0, cx) } } diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml index d0451f34f2..4fc6a618ff 100644 --- a/crates/supermaven/Cargo.toml +++ b/crates/supermaven/Cargo.toml @@ -16,9 +16,9 @@ doctest = false anyhow.workspace = true client.workspace = true collections.workspace = true +edit_prediction.workspace = true futures.workspace = true gpui.workspace = true -inline_completion.workspace = true language.workspace = true log.workspace = true postage.workspace = true diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index c49272e66e..2660a03e6f 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -1,8 +1,8 @@ use crate::{Supermaven, SupermavenCompletionStateId}; use anyhow::Result; +use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use futures::StreamExt as _; use gpui::{App, Context, Entity, EntityId, Task}; -use inline_completion::{Direction, EditPredictionProvider, InlineCompletion}; use language::{Anchor, Buffer, BufferSnapshot}; use project::Project; use std::{ @@ -44,7 +44,7 @@ fn completion_from_diff( completion_text: &str, position: Anchor, delete_range: Range<Anchor>, -) -> InlineCompletion { +) -> EditPrediction { let buffer_text = snapshot .text_for_range(delete_range.clone()) .collect::<String>(); @@ -91,7 +91,7 @@ fn completion_from_diff( edits.push((edit_range, edit_text)); } - InlineCompletion { + EditPrediction { id: None, edits, edit_preview: None, @@ -182,7 +182,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { buffer: &Entity<Buffer>, 
cursor_position: Anchor, cx: &mut Context<Self>, - ) -> Option<InlineCompletion> { + ) -> Option<EditPrediction> { let completion_text = self .supermaven .read(cx) diff --git a/crates/tasks_ui/src/modal.rs b/crates/tasks_ui/src/modal.rs index 1510f613e3..c4b0931c35 100644 --- a/crates/tasks_ui/src/modal.rs +++ b/crates/tasks_ui/src/modal.rs @@ -500,7 +500,7 @@ impl PickerDelegate for TasksModalDelegate { .map(|icon| icon.color(Color::Muted).size(IconSize::Small)); let indicator = if matches!(source_kind, TaskSourceKind::Lsp { .. }) { Some(Indicator::icon( - Icon::new(IconName::Bolt).size(IconSize::Small), + Icon::new(IconName::BoltOutlined).size(IconSize::Small), )) } else { None diff --git a/crates/telemetry_events/src/telemetry_events.rs b/crates/telemetry_events/src/telemetry_events.rs index dfe167fcd4..735a1310ae 100644 --- a/crates/telemetry_events/src/telemetry_events.rs +++ b/crates/telemetry_events/src/telemetry_events.rs @@ -94,8 +94,8 @@ impl Display for AssistantPhase { pub enum Event { Flexible(FlexibleEvent), Editor(EditorEvent), - InlineCompletion(InlineCompletionEvent), - InlineCompletionRating(InlineCompletionRatingEvent), + EditPrediction(EditPredictionEvent), + EditPredictionRating(EditPredictionRatingEvent), Call(CallEvent), Assistant(AssistantEventData), Cpu(CpuEvent), @@ -132,7 +132,7 @@ pub struct EditorEvent { } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct InlineCompletionEvent { +pub struct EditPredictionEvent { /// Provider of the completion suggestion (e.g. copilot, supermaven) pub provider: String, pub suggestion_accepted: bool, @@ -140,14 +140,14 @@ pub struct InlineCompletionEvent { } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum InlineCompletionRating { +pub enum EditPredictionRating { Positive, Negative, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct InlineCompletionRatingEvent { - pub rating: InlineCompletionRating, +pub struct EditPredictionRatingEvent { + pub rating: EditPredictionRating, pub input_events: Arc<str>, pub input_excerpt: Arc<str>, pub output_excerpt: Arc<str>, diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 1cc1fbcf6f..2e6be5aaf4 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -430,6 +430,7 @@ impl TerminalView { fn settings_changed(&mut self, cx: &mut Context<Self>) { let settings = TerminalSettings::get_global(cx); + let breadcrumb_visibility_changed = self.show_breadcrumbs != settings.toolbar.breadcrumbs; self.show_breadcrumbs = settings.toolbar.breadcrumbs; let new_cursor_shape = settings.cursor_shape.unwrap_or_default(); @@ -441,6 +442,9 @@ impl TerminalView { }); } + if breadcrumb_visibility_changed { + cx.emit(ItemEvent::UpdateBreadcrumbs); + } cx.notify(); } @@ -1587,7 +1591,7 @@ impl Item for TerminalView { let (icon, icon_color, rerun_button) = match terminal.task() { Some(terminal_task) => match &terminal_task.status { TaskStatus::Running => ( - IconName::Play, + IconName::PlayOutlined, Color::Disabled, TerminalView::rerun_button(&terminal_task), ), diff --git a/crates/text/src/anchor.rs b/crates/text/src/anchor.rs index 5807d3aae0..c4778216e0 100644 --- a/crates/text/src/anchor.rs +++ b/crates/text/src/anchor.rs @@ -3,7 +3,7 @@ use crate::{ locator::Locator, }; use std::{cmp::Ordering, fmt::Debug, ops::Range}; -use sum_tree::Bias; +use sum_tree::{Bias, Dimensions}; /// A timestamped position in a buffer #[derive(Copy, Clone, Eq, PartialEq, 
Debug, Hash, Default)] @@ -99,8 +99,12 @@ impl Anchor { } else if self.buffer_id != Some(buffer.remote_id) { false } else { - let fragment_id = buffer.fragment_id_for_anchor(self); - let mut fragment_cursor = buffer.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let Some(fragment_id) = buffer.try_fragment_id_for_anchor(self) else { + return false; + }; + let mut fragment_cursor = buffer + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); fragment_cursor.seek(&Some(fragment_id), Bias::Left); fragment_cursor .item() diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index c1da0649da..68c7b2a2cd 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -37,7 +37,7 @@ use std::{ }; pub use subscription::*; pub use sum_tree::Bias; -use sum_tree::{FilterCursor, SumTree, TreeMap, TreeSet}; +use sum_tree::{Dimensions, FilterCursor, SumTree, TreeMap, TreeSet}; use undo_map::UndoMap; #[cfg(any(test, feature = "test-support"))] @@ -1071,7 +1071,9 @@ impl Buffer { let mut insertion_offset = 0; let mut new_ropes = RopeBuilder::new(self.visible_text.cursor(0), self.deleted_text.cursor(0)); - let mut old_fragments = self.fragments.cursor::<(VersionedFullOffset, usize)>(&cx); + let mut old_fragments = self + .fragments + .cursor::<Dimensions<VersionedFullOffset, usize>>(&cx); let mut new_fragments = old_fragments.slice(&VersionedFullOffset::Offset(ranges[0].start), Bias::Left); new_ropes.append(new_fragments.summary().text); @@ -1298,7 +1300,9 @@ impl Buffer { self.snapshot.undo_map.insert(undo); let mut edits = Patch::default(); - let mut old_fragments = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut old_fragments = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let mut new_fragments = SumTree::new(&None); let mut new_ropes = RopeBuilder::new(self.visible_text.cursor(0), self.deleted_text.cursor(0)); @@ -1561,7 +1565,9 @@ impl Buffer { D: TextDimension, { // get fragment ranges - let mut cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let offset_ranges = self .fragment_ids_for_edits(edit_ids.into_iter()) .into_iter() @@ -2232,7 +2238,9 @@ impl BufferSnapshot { { let anchors = anchors.into_iter(); let mut insertion_cursor = self.insertions.cursor::<InsertionFragmentKey>(&()); - let mut fragment_cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut fragment_cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); let mut text_cursor = self.visible_text.cursor(0); let mut position = D::zero(&()); @@ -2318,7 +2326,9 @@ impl BufferSnapshot { ); }; - let mut fragment_cursor = self.fragments.cursor::<(Option<&Locator>, usize)>(&None); + let mut fragment_cursor = self + .fragments + .cursor::<Dimensions<Option<&Locator>, usize>>(&None); fragment_cursor.seek(&Some(&insertion.fragment_id), Bias::Left); let fragment = fragment_cursor.item().unwrap(); let mut fragment_offset = fragment_cursor.start().1; @@ -2330,10 +2340,19 @@ impl BufferSnapshot { } fn fragment_id_for_anchor(&self, anchor: &Anchor) -> &Locator { + self.try_fragment_id_for_anchor(anchor).unwrap_or_else(|| { + panic!( + "invalid anchor {:?}. 
buffer id: {}, version: {:?}", + anchor, self.remote_id, self.version, + ) + }) + } + + fn try_fragment_id_for_anchor(&self, anchor: &Anchor) -> Option<&Locator> { if *anchor == Anchor::MIN { - Locator::min_ref() + Some(Locator::min_ref()) } else if *anchor == Anchor::MAX { - Locator::max_ref() + Some(Locator::max_ref()) } else { let anchor_key = InsertionFragmentKey { timestamp: anchor.timestamp, @@ -2354,20 +2373,12 @@ impl BufferSnapshot { insertion_cursor.prev(); } - let Some(insertion) = insertion_cursor.item().filter(|insertion| { - if cfg!(debug_assertions) { - insertion.timestamp == anchor.timestamp - } else { - true - } - }) else { - panic!( - "invalid anchor {:?}. buffer id: {}, version: {:?}", - anchor, self.remote_id, self.version - ); - }; - - &insertion.fragment_id + insertion_cursor + .item() + .filter(|insertion| { + !cfg!(debug_assertions) || insertion.timestamp == anchor.timestamp + }) + .map(|insertion| &insertion.fragment_id) } } @@ -2475,7 +2486,7 @@ impl BufferSnapshot { }; let mut cursor = self .fragments - .cursor::<(Option<&Locator>, FragmentTextSummary)>(&None); + .cursor::<Dimensions<Option<&Locator>, FragmentTextSummary>>(&None); let start_fragment_id = self.fragment_id_for_anchor(&range.start); cursor.seek(&Some(start_fragment_id), Bias::Left); diff --git a/crates/theme/src/icon_theme.rs b/crates/theme/src/icon_theme.rs index 09f5df06b0..10fd1e002d 100644 --- a/crates/theme/src/icon_theme.rs +++ b/crates/theme/src/icon_theme.rs @@ -152,6 +152,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ("javascript", &["cjs", "js", "mjs"]), ("json", &["json"]), ("julia", &["jl"]), + ("kdl", &["kdl"]), ("kotlin", &["kt"]), ("lock", &["lock"]), ("log", &["log"]), @@ -216,6 +217,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ "stylelintrc.yml", ], ), + ("surrealql", &["surql"]), ("svelte", &["svelte"]), ("swift", &["swift"]), ("tcl", &["tcl"]), @@ -314,6 +316,7 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("javascript", "icons/file_icons/javascript.svg"), ("json", "icons/file_icons/code.svg"), ("julia", "icons/file_icons/julia.svg"), + ("kdl", "icons/file_icons/kdl.svg"), ("kotlin", "icons/file_icons/kotlin.svg"), ("lock", "icons/file_icons/lock.svg"), ("log", "icons/file_icons/info.svg"), @@ -340,6 +343,7 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("solidity", "icons/file_icons/file.svg"), ("storage", "icons/file_icons/database.svg"), ("stylelint", "icons/file_icons/javascript.svg"), + ("surrealql", "icons/file_icons/surrealql.svg"), ("svelte", "icons/file_icons/html.svg"), ("swift", "icons/file_icons/swift.svg"), ("tcl", "icons/file_icons/tcl.svg"), diff --git a/crates/theme/src/settings.rs b/crates/theme/src/settings.rs index 1c4c90a475..20c837f287 100644 --- a/crates/theme/src/settings.rs +++ b/crates/theme/src/settings.rs @@ -438,7 +438,7 @@ fn default_font_fallbacks() -> Option<FontFallbacks> { impl ThemeSettingsContent { /// Sets the theme for the given appearance to the theme with the specified name. 
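+    /// Accepts any value convertible into an `Arc<str>`, e.g. `&str`, `String`, or an existing `Arc<str>`.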
- pub fn set_theme(&mut self, theme_name: String, appearance: Appearance) { + pub fn set_theme(&mut self, theme_name: impl Into<Arc<str>>, appearance: Appearance) { if let Some(selection) = self.theme.as_mut() { let theme_to_update = match selection { ThemeSelection::Static(theme) => theme, @@ -867,6 +867,7 @@ impl settings::Settings for ThemeSettings { .user .into_iter() .chain(sources.release_channel) + .chain(sources.profile) .chain(sources.server) { if let Some(value) = value.ui_density { diff --git a/crates/theme_selector/src/icon_theme_selector.rs b/crates/theme_selector/src/icon_theme_selector.rs index 1adfc4b5d8..2d0b9480d5 100644 --- a/crates/theme_selector/src/icon_theme_selector.rs +++ b/crates/theme_selector/src/icon_theme_selector.rs @@ -40,7 +40,10 @@ impl IconThemeSelector { impl Render for IconThemeSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("IconThemeSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } diff --git a/crates/theme_selector/src/theme_selector.rs b/crates/theme_selector/src/theme_selector.rs index 022daced7a..ba8bde243b 100644 --- a/crates/theme_selector/src/theme_selector.rs +++ b/crates/theme_selector/src/theme_selector.rs @@ -92,7 +92,10 @@ impl Focusable for ThemeSelector { impl Render for ThemeSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("ThemeSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } diff --git a/crates/title_bar/Cargo.toml b/crates/title_bar/Cargo.toml index 8e95c6f79f..cf178e2850 100644 --- a/crates/title_bar/Cargo.toml +++ b/crates/title_bar/Cargo.toml @@ -32,6 +32,7 @@ auto_update.workspace = true call.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true db.workspace = true gpui = { workspace = true, features = ["screen-capture"] } notifications.workspace = true diff --git a/crates/title_bar/src/collab.rs b/crates/title_bar/src/collab.rs index 056c981ccf..d026b4de14 100644 --- a/crates/title_bar/src/collab.rs +++ b/crates/title_bar/src/collab.rs @@ -11,8 +11,8 @@ use gpui::{App, Task, Window, actions}; use rpc::proto::{self}; use theme::ActiveTheme; use ui::{ - Avatar, AvatarAudioStatusIndicator, ContextMenu, ContextMenuItem, Divider, Facepile, - PopoverMenu, SplitButton, SplitButtonStyle, TintColor, Tooltip, prelude::*, + Avatar, AvatarAudioStatusIndicator, ContextMenu, ContextMenuItem, Divider, DividerColor, + Facepile, PopoverMenu, SplitButton, SplitButtonStyle, TintColor, Tooltip, prelude::*, }; use util::maybe; use workspace::notifications::DetachAndPromptErr; @@ -343,6 +343,24 @@ impl TitleBar { let mut children = Vec::new(); + children.push( + h_flex() + .gap_1() + .child( + IconButton::new("leave-call", IconName::Exit) + .style(ButtonStyle::Subtle) + .tooltip(Tooltip::text("Leave Call")) + .icon_size(IconSize::Small) + .on_click(move |_, _window, cx| { + ActiveCall::global(cx) + .update(cx, |call, cx| call.hang_up(cx)) + .detach_and_log_err(cx); + }), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .into_any_element(), + ); + if is_local && can_share_projects && !is_connecting_to_project { children.push( Button::new( @@ -369,32 +387,14 @@ impl TitleBar { ); } - children.push( - div() - .pr_2() - .child( - IconButton::new("leave-call", ui::IconName::Exit) - .style(ButtonStyle::Subtle) - 
.tooltip(Tooltip::text("Leave call")) - .icon_size(IconSize::Small) - .on_click(move |_, _window, cx| { - ActiveCall::global(cx) - .update(cx, |call, cx| call.hang_up(cx)) - .detach_and_log_err(cx); - }), - ) - .child(Divider::vertical()) - .into_any_element(), - ); - if can_use_microphone { children.push( IconButton::new( "mute-microphone", if is_muted { - ui::IconName::MicMute + IconName::MicMute } else { - ui::IconName::Mic + IconName::Mic }, ) .tooltip(move |window, cx| { @@ -429,9 +429,9 @@ impl TitleBar { IconButton::new( "mute-sound", if is_deafened { - ui::IconName::AudioOff + IconName::AudioOff } else { - ui::IconName::AudioOn + IconName::AudioOn }, ) .style(ButtonStyle::Subtle) @@ -462,7 +462,7 @@ impl TitleBar { ); if can_use_microphone && screen_sharing_supported { - let trigger = IconButton::new("screen-share", ui::IconName::Screen) + let trigger = IconButton::new("screen-share", IconName::Screen) .style(ButtonStyle::Subtle) .icon_size(IconSize::Small) .toggle_state(is_screen_sharing) @@ -498,7 +498,7 @@ impl TitleBar { trigger.render(window, cx), self.render_screen_list().into_any_element(), ) - .style(SplitButtonStyle::Outlined) + .style(SplitButtonStyle::Transparent) .into_any_element(), ); } @@ -513,11 +513,11 @@ impl TitleBar { .with_handle(self.screen_share_popover_handle.clone()) .trigger( ui::ButtonLike::new_rounded_right("screen-share-screen-list-trigger") - .layer(ui::ElevationIndex::ModalSurface) - .size(ui::ButtonSize::None) .child( - div() - .px_1() + h_flex() + .mx_neg_0p5() + .h_full() + .justify_center() .child(Icon::new(IconName::ChevronDownSmall).size(IconSize::XSmall)), ) .toggle_state(self.screen_share_popover_handle.is_deployed()), diff --git a/crates/title_bar/src/platform_title_bar.rs b/crates/title_bar/src/platform_title_bar.rs index 30b1b4c3f8..ef6ef93eed 100644 --- a/crates/title_bar/src/platform_title_bar.rs +++ b/crates/title_bar/src/platform_title_bar.rs @@ -106,14 +106,14 @@ impl Render for PlatformTitleBar { // Note: On Windows the title bar behavior is handled by the platform implementation. 
.when(self.platform_style == PlatformStyle::Mac, |this| { this.on_click(|event, window, _| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.titlebar_double_click(); } }) }) .when(self.platform_style == PlatformStyle::Linux, |this| { this.on_click(|event, window, _| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.zoom_window(); } }) diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 17c4c85b6d..a8b16d881f 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -21,6 +21,7 @@ use crate::application_menu::{ use auto_update::AutoUpdateStatus; use call::ActiveCall; use client::{Client, UserStore, zed_urls}; +use cloud_llm_client::Plan; use gpui::{ Action, AnyElement, App, Context, Corner, Element, Entity, Focusable, InteractiveElement, IntoElement, MouseButton, ParentElement, Render, StatefulInteractiveElement, Styled, @@ -28,7 +29,6 @@ use gpui::{ }; use onboarding_banner::OnboardingBanner; use project::Project; -use rpc::proto; use settings::Settings as _; use settings_ui::keybindings; use std::sync::Arc; @@ -179,24 +179,23 @@ impl Render for TitleBar { children.push(self.banner.clone().into_any_element()) } + let status = self.client.status(); + let status = &*status.borrow(); + let user = self.user_store.read(cx).current_user(); + children.push( h_flex() .gap_1() .pr_1() .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()) .children(self.render_call_controls(window, cx)) - .map(|el| { - let status = self.client.status(); - let status = &*status.borrow(); - if matches!(status, client::Status::Connected { .. }) { - el.child(self.render_user_menu_button(cx)) - } else { - el.children(self.render_connection_status(status, cx)) - .when(TitleBarSettings::get_global(cx).show_sign_in, |el| { - el.child(self.render_sign_in_button(cx)) - }) - .child(self.render_user_menu_button(cx)) - } + .children(self.render_connection_status(status, cx)) + .when( + user.is_none() && TitleBarSettings::get_global(cx).show_sign_in, + |el| el.child(self.render_sign_in_button(cx)), + ) + .when(user.is_some(), |parent| { + parent.child(self.render_user_menu_button(cx)) }) .into_any_element(), ); @@ -618,9 +617,8 @@ impl TitleBar { window .spawn(cx, async move |cx| { client - .authenticate_and_connect(true, &cx) + .sign_in_with_optional_connect(true, &cx) .await - .into_response() .notify_async_err(cx); }) .detach(); @@ -630,8 +628,8 @@ impl TitleBar { pub fn render_user_menu_button(&mut self, cx: &mut Context<Self>) -> impl Element { let user_store = self.user_store.read(cx); if let Some(user) = user_store.current_user() { - let has_subscription_period = self.user_store.read(cx).subscription_period().is_some(); - let plan = self.user_store.read(cx).current_plan().filter(|_| { + let has_subscription_period = user_store.subscription_period().is_some(); + let plan = user_store.plan().filter(|_| { // Since the user might be on the legacy free plan we filter based on whether we have a subscription period. 
has_subscription_period }); @@ -658,13 +656,9 @@ impl TitleBar { let user_login = user.github_login.clone(); let (plan_name, label_color, bg_color) = match plan { - None | Some(proto::Plan::Free) => { - ("Free", Color::Default, free_chip_bg) - } - Some(proto::Plan::ZedProTrial) => { - ("Pro Trial", Color::Accent, pro_chip_bg) - } - Some(proto::Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), + None | Some(Plan::ZedFree) => ("Free", Color::Default, free_chip_bg), + Some(Plan::ZedProTrial) => ("Pro Trial", Color::Accent, pro_chip_bg), + Some(Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), }; menu.custom_entry( @@ -688,6 +682,10 @@ impl TitleBar { ) .separator() .action("Settings", zed_actions::OpenSettings.boxed_clone()) + .action( + "Settings Profiles", + zed_actions::settings_profile_selector::Toggle.boxed_clone(), + ) .action("Key Bindings", Box::new(keybindings::OpenKeymapEditor)) .action( "Themes…", @@ -732,6 +730,10 @@ impl TitleBar { .menu(|window, cx| { ContextMenu::build(window, cx, |menu, _, _| { menu.action("Settings", zed_actions::OpenSettings.boxed_clone()) + .action( + "Settings Profiles", + zed_actions::settings_profile_selector::Toggle.boxed_clone(), + ) .action("Key Bindings", Box::new(keybindings::OpenKeymapEditor)) .action( "Themes…", diff --git a/crates/ui/src/components.rs b/crates/ui/src/components.rs index 9c2961c55f..486673e733 100644 --- a/crates/ui/src/components.rs +++ b/crates/ui/src/components.rs @@ -1,4 +1,5 @@ mod avatar; +mod badge; mod banner; mod button; mod callout; @@ -41,6 +42,7 @@ mod tooltip; mod stories; pub use avatar::*; +pub use badge::*; pub use banner::*; pub use button::*; pub use callout::*; diff --git a/crates/ui/src/components/badge.rs b/crates/ui/src/components/badge.rs new file mode 100644 index 0000000000..f36e03291c --- /dev/null +++ b/crates/ui/src/components/badge.rs @@ -0,0 +1,94 @@ +use std::rc::Rc; + +use crate::Divider; +use crate::DividerColor; +use crate::Tooltip; +use crate::component_prelude::*; +use crate::prelude::*; +use gpui::AnyView; +use gpui::{AnyElement, IntoElement, SharedString, Window}; + +#[derive(IntoElement, RegisterComponent)] +pub struct Badge { + label: SharedString, + icon: IconName, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, +} + +impl Badge { + pub fn new(label: impl Into<SharedString>) -> Self { + Self { + label: label.into(), + icon: IconName::Check, + tooltip: None, + } + } + + pub fn icon(mut self, icon: IconName) -> Self { + self.icon = icon; + self + } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } +} + +impl RenderOnce for Badge { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let tooltip = self.tooltip; + + h_flex() + .id(self.label.clone()) + .h_full() + .gap_1() + .pl_1() + .pr_2() + .border_1() + .border_color(cx.theme().colors().border.opacity(0.6)) + .bg(cx.theme().colors().element_background) + .rounded_sm() + .overflow_hidden() + .child( + Icon::new(self.icon) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .child(Label::new(self.label.clone()).size(LabelSize::Small).ml_1()) + .when_some(tooltip, |this, tooltip| { + this.tooltip(move |window, cx| tooltip(window, cx)) + }) + } +} + +impl Component for Badge { + fn scope() -> ComponentScope { + ComponentScope::DataDisplay + } + + fn description() -> Option<&'static str> { + Some( + "A compact, labeled component with 
optional icon for displaying status, categories, or metadata.", + ) + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .child(single_example( + "Basic Badge", + Badge::new("Default").into_any_element(), + )) + .child(single_example( + "With Tooltip", + Badge::new("Tooltip") + .tooltip(Tooltip::text("This is a tooltip.")) + .into_any_element(), + )) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/banner.rs b/crates/ui/src/components/banner.rs index b16ca795b4..d88905d466 100644 --- a/crates/ui/src/components/banner.rs +++ b/crates/ui/src/components/banner.rs @@ -131,7 +131,7 @@ impl RenderOnce for Banner { impl Component for Banner { fn scope() -> ComponentScope { - ComponentScope::Notification + ComponentScope::DataDisplay } fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { diff --git a/crates/ui/src/components/button/button.rs b/crates/ui/src/components/button/button.rs index cae5d0e2ca..19f782fb98 100644 --- a/crates/ui/src/components/button/button.rs +++ b/crates/ui/src/components/button/button.rs @@ -393,6 +393,11 @@ impl ButtonCommon for Button { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index 135ecdfe62..35c78fbb5d 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -1,7 +1,8 @@ use documented::Documented; use gpui::{ AnyElement, AnyView, ClickEvent, CursorStyle, DefiniteLength, Hsla, MouseButton, - MouseDownEvent, MouseUpEvent, Rems, relative, transparent_black, + MouseClickEvent, MouseDownEvent, MouseUpEvent, Rems, StyleRefinement, relative, + transparent_black, }; use smallvec::SmallVec; @@ -37,6 +38,8 @@ pub trait ButtonCommon: Clickable + Disableable { /// exceptions might a scroll bar, or a slider. 
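A minimal call-site sketch for the `Badge` component added above. Only `Badge::new`, `icon`, and `tooltip` come from the hunk; the surrounding `h_flex` container and the label/tooltip strings are assumptions for illustration, in the same style as the component preview.

// Illustrative only: a badge with an icon and a tooltip inside an assumed parent flex.
h_flex().gap_1().child(
    Badge::new("Beta")
        .icon(IconName::Check)
        .tooltip(Tooltip::text("This feature is still in beta.")),
)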
fn tooltip(self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self; + fn tab_index(self, tab_index: impl Into<isize>) -> Self; + fn layer(self, elevation: ElevationIndex) -> Self; } @@ -358,6 +361,7 @@ impl ButtonStyle { #[derive(Default, PartialEq, Clone, Copy)] pub enum ButtonSize { Large, + Medium, #[default] Default, Compact, @@ -368,6 +372,7 @@ impl ButtonSize { pub fn rems(self) -> Rems { match self { ButtonSize::Large => rems_from_px(32.), + ButtonSize::Medium => rems_from_px(28.), ButtonSize::Default => rems_from_px(22.), ButtonSize::Compact => rems_from_px(18.), ButtonSize::None => rems_from_px(16.), @@ -391,6 +396,7 @@ pub struct ButtonLike { pub(super) width: Option<DefiniteLength>, pub(super) height: Option<DefiniteLength>, pub(super) layer: Option<ElevationIndex>, + tab_index: Option<isize>, size: ButtonSize, rounding: Option<ButtonLikeRounding>, tooltip: Option<Box<dyn Fn(&mut Window, &mut App) -> AnyView>>, @@ -419,6 +425,7 @@ impl ButtonLike { on_click: None, on_right_click: None, layer: None, + tab_index: None, } } @@ -523,6 +530,11 @@ impl ButtonCommon for ButtonLike { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.tab_index = Some(tab_index.into()); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.layer = Some(elevation); self @@ -552,6 +564,7 @@ impl RenderOnce for ButtonLike { self.base .h_flex() .id(self.id.clone()) + .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)) .font_ui(cx) .group("") .flex_none() @@ -573,7 +586,7 @@ impl RenderOnce for ButtonLike { }) .gap(DynamicSpacing::Base04.rems(cx)) .map(|this| match self.size { - ButtonSize::Large => this.px(DynamicSpacing::Base06.rems(cx)), + ButtonSize::Large | ButtonSize::Medium => this.px(DynamicSpacing::Base06.rems(cx)), ButtonSize::Default | ButtonSize::Compact => { this.px(DynamicSpacing::Base04.rems(cx)) } @@ -589,8 +602,12 @@ impl RenderOnce for ButtonLike { } }) .when(!self.disabled, |this| { + let hovered_style = style.hovered(self.layer, cx); + let focus_color = + |refinement: StyleRefinement| refinement.bg(hovered_style.background); this.cursor(self.cursor_style) - .hover(|hover| hover.bg(style.hovered(self.layer, cx).background)) + .hover(focus_color) + .focus(focus_color) .active(|active| active.bg(style.active(cx).background)) }) .when_some( @@ -604,7 +621,7 @@ impl RenderOnce for ButtonLike { MouseButton::Right, move |event, window, cx| { cx.stop_propagation(); - let click_event = ClickEvent { + let click_event = ClickEvent::Mouse(MouseClickEvent { down: MouseDownEvent { button: MouseButton::Right, position: event.position, @@ -618,7 +635,7 @@ impl RenderOnce for ButtonLike { modifiers: event.modifiers, click_count: 1, }, - }; + }); (on_right_click)(&click_event, window, cx) }, ) diff --git a/crates/ui/src/components/button/icon_button.rs b/crates/ui/src/components/button/icon_button.rs index e5d13e09cd..8d8718a634 100644 --- a/crates/ui/src/components/button/icon_button.rs +++ b/crates/ui/src/components/button/icon_button.rs @@ -164,6 +164,11 @@ impl ButtonCommon for IconButton { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self diff --git a/crates/ui/src/components/button/split_button.rs b/crates/ui/src/components/button/split_button.rs index a7fa2106d1..14b9fd153c 100644 --- a/crates/ui/src/components/button/split_button.rs +++ 
b/crates/ui/src/components/button/split_button.rs @@ -12,6 +12,7 @@ use super::ButtonLike; pub enum SplitButtonStyle { Filled, Outlined, + Transparent, } /// /// A button with two parts: a primary action on the left and a secondary action on the right. @@ -44,10 +45,17 @@ impl SplitButton { impl RenderOnce for SplitButton { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let is_filled_or_outlined = matches!( + self.style, + SplitButtonStyle::Filled | SplitButtonStyle::Outlined + ); + h_flex() .rounded_sm() - .border_1() - .border_color(cx.theme().colors().border.opacity(0.5)) + .when(is_filled_or_outlined, |this| { + this.border_1() + .border_color(cx.theme().colors().border.opacity(0.8)) + }) .child(div().flex_grow().child(self.left)) .child( div() diff --git a/crates/ui/src/components/button/toggle_button.rs b/crates/ui/src/components/button/toggle_button.rs index eca23fe6f7..6fbf834667 100644 --- a/crates/ui/src/components/button/toggle_button.rs +++ b/crates/ui/src/components/button/toggle_button.rs @@ -1,6 +1,6 @@ use gpui::{AnyView, ClickEvent}; -use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, prelude::*}; +use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, TintColor, prelude::*}; /// The position of a [`ToggleButton`] within a group of buttons. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -121,6 +121,11 @@ impl ButtonCommon for ToggleButton { self } + fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.base = self.base.tab_index(tab_index); + self + } + fn layer(mut self, elevation: ElevationIndex) -> Self { self.base = self.base.layer(elevation); self @@ -290,3 +295,632 @@ impl Component for ToggleButton { ) } } + +pub struct ButtonConfiguration { + label: SharedString, + icon: Option<IconName>, + on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, +} + +mod private { + pub trait ToggleButtonStyle {} +} + +pub trait ButtonBuilder: 'static + private::ToggleButtonStyle { + fn into_configuration(self) -> ButtonConfiguration; +} + +pub struct ToggleButtonSimple { + label: SharedString, + on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, +} + +impl ToggleButtonSimple { + pub fn new( + label: impl Into<SharedString>, + on_click: impl Fn(&ClickEvent, &mut Window, &mut App) + 'static, + ) -> Self { + Self { + label: label.into(), + on_click: Box::new(on_click), + selected: false, + } + } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } +} + +impl private::ToggleButtonStyle for ToggleButtonSimple {} + +impl ButtonBuilder for ToggleButtonSimple { + fn into_configuration(self) -> ButtonConfiguration { + ButtonConfiguration { + label: self.label, + icon: None, + on_click: self.on_click, + selected: self.selected, + } + } +} + +pub struct ToggleButtonWithIcon { + label: SharedString, + icon: IconName, + on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, + selected: bool, +} + +impl ToggleButtonWithIcon { + pub fn new( + label: impl Into<SharedString>, + icon: IconName, + on_click: impl Fn(&ClickEvent, &mut Window, &mut App) + 'static, + ) -> Self { + Self { + label: label.into(), + icon, + on_click: Box::new(on_click), + selected: false, + } + } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } +} + +impl private::ToggleButtonStyle for ToggleButtonWithIcon {} + +impl ButtonBuilder for ToggleButtonWithIcon { + fn into_configuration(self) -> 
ButtonConfiguration { + ButtonConfiguration { + label: self.label, + icon: Some(self.icon), + on_click: self.on_click, + selected: self.selected, + } + } +} + +#[derive(Clone, Copy, PartialEq)] +pub enum ToggleButtonGroupStyle { + Transparent, + Filled, + Outlined, +} + +#[derive(Clone, Copy, PartialEq)] +pub enum ToggleButtonGroupSize { + Default, + Medium, +} + +#[derive(IntoElement)] +pub struct ToggleButtonGroup<T, const COLS: usize = 3, const ROWS: usize = 1> +where + T: ButtonBuilder, +{ + group_name: &'static str, + rows: [[T; COLS]; ROWS], + style: ToggleButtonGroupStyle, + size: ToggleButtonGroupSize, + button_width: Rems, + selected_index: usize, + tab_index: Option<isize>, +} + +impl<T: ButtonBuilder, const COLS: usize> ToggleButtonGroup<T, COLS> { + pub fn single_row(group_name: &'static str, buttons: [T; COLS]) -> Self { + Self { + group_name, + rows: [buttons], + style: ToggleButtonGroupStyle::Transparent, + size: ToggleButtonGroupSize::Default, + button_width: rems_from_px(100.), + selected_index: 0, + tab_index: None, + } + } +} + +impl<T: ButtonBuilder, const COLS: usize> ToggleButtonGroup<T, COLS, 2> { + pub fn two_rows(group_name: &'static str, first_row: [T; COLS], second_row: [T; COLS]) -> Self { + Self { + group_name, + rows: [first_row, second_row], + style: ToggleButtonGroupStyle::Transparent, + size: ToggleButtonGroupSize::Default, + button_width: rems_from_px(100.), + selected_index: 0, + tab_index: None, + } + } +} + +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> ToggleButtonGroup<T, COLS, ROWS> { + pub fn style(mut self, style: ToggleButtonGroupStyle) -> Self { + self.style = style; + self + } + + pub fn size(mut self, size: ToggleButtonGroupSize) -> Self { + self.size = size; + self + } + + pub fn button_width(mut self, button_width: Rems) -> Self { + self.button_width = button_width; + self + } + + pub fn selected_index(mut self, index: usize) -> Self { + self.selected_index = index; + self + } + + /// Sets the tab index for the toggle button group. + /// The tab index is set to the initial value provided, then the + /// value is incremented by the number of buttons in the group. 
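As a hypothetical caller-side sketch of the tab-index behavior described in the doc comment above (group name and labels are invented; `single_row`, `ToggleButtonSimple`, and `tab_index` are the items from this hunk):

// One counter threaded through consecutive groups gives every button a distinct tab index.
let mut tab_index: isize = 0;
let formatting = ToggleButtonGroup::single_row(
    "formatting",
    [
        ToggleButtonSimple::new("Bold", |_, _, _| {}),
        ToggleButtonSimple::new("Italic", |_, _, _| {}),
        ToggleButtonSimple::new("Underline", |_, _, _| {}),
    ],
)
.tab_index(&mut tab_index); // consumes indices 0..=2 and leaves tab_index == 3 for the next group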
+ pub fn tab_index(mut self, tab_index: &mut isize) -> Self { + self.tab_index = Some(*tab_index); + *tab_index += (COLS * ROWS) as isize; + self + } +} + +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> RenderOnce + for ToggleButtonGroup<T, COLS, ROWS> +{ + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let entries = + self.rows.into_iter().enumerate().map(|(row_index, row)| { + row.into_iter().enumerate().map(move |(col_index, button)| { + let ButtonConfiguration { + label, + icon, + on_click, + selected, + } = button.into_configuration(); + + let entry_index = row_index * COLS + col_index; + + ButtonLike::new((self.group_name, entry_index)) + .when_some(self.tab_index, |this, tab_index| { + this.tab_index(tab_index + entry_index as isize) + }) + .when(entry_index == self.selected_index || selected, |this| { + this.toggle_state(true) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + }) + .rounding(None) + .when(self.style == ToggleButtonGroupStyle::Filled, |button| { + button.style(ButtonStyle::Filled) + }) + .when(self.size == ToggleButtonGroupSize::Medium, |button| { + button.size(ButtonSize::Medium) + }) + .child( + h_flex() + .min_w(self.button_width) + .gap_1p5() + .px_3() + .py_1() + .justify_center() + .when_some(icon, |this, icon| { + this.py_2() + .child(Icon::new(icon).size(IconSize::XSmall).map(|this| { + if entry_index == self.selected_index || selected { + this.color(Color::Accent) + } else { + this.color(Color::Muted) + } + })) + }) + .child(Label::new(label).size(LabelSize::Small).when( + entry_index == self.selected_index || selected, + |this| this.color(Color::Accent), + )), + ) + .on_click(on_click) + .into_any_element() + }) + }); + + let border_color = cx.theme().colors().border.opacity(0.6); + let is_outlined_or_filled = self.style == ToggleButtonGroupStyle::Outlined + || self.style == ToggleButtonGroupStyle::Filled; + let is_transparent = self.style == ToggleButtonGroupStyle::Transparent; + + v_flex() + .rounded_md() + .overflow_hidden() + .map(|this| { + if is_transparent { + this.gap_px() + } else { + this.border_1().border_color(border_color) + } + }) + .children(entries.enumerate().map(|(row_index, row)| { + let last_row = row_index == ROWS - 1; + h_flex() + .when(!is_outlined_or_filled, |this| this.gap_px()) + .when(is_outlined_or_filled && !last_row, |this| { + this.border_b_1().border_color(border_color) + }) + .children(row.enumerate().map(|(item_index, item)| { + let last_item = item_index == COLS - 1; + div() + .when(is_outlined_or_filled && !last_item, |this| { + this.border_r_1().border_color(border_color) + }) + .child(item) + })) + })) + } +} + +fn register_toggle_button_group() { + component::register_component::<ToggleButtonGroup<ToggleButtonSimple>>(); +} + +component::__private::inventory::submit! 
{ + component::ComponentFn::new(register_toggle_button_group) +} + +impl<T: ButtonBuilder, const COLS: usize, const ROWS: usize> Component + for ToggleButtonGroup<T, COLS, ROWS> +{ + fn name() -> &'static str { + "ToggleButtonGroup" + } + + fn scope() -> ComponentScope { + ComponentScope::Input + } + + fn sort_name() -> &'static str { + "ButtonG" + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .children(vec![example_group_with_title( + "Transparent Variant", + vec![ + single_example( + "Single Row Group", + ToggleButtonGroup::single_row( + "single_row_test", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .into_any_element(), + ), + single_example( + "Single Row Group with icons", + ToggleButtonGroup::single_row( + "single_row_test_icon", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .into_any_element(), + ), + single_example( + "Multiple Row Group", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + [ + ToggleButtonSimple::new("Fourth", |_, _, _| {}), + ToggleButtonSimple::new("Fifth", |_, _, _| {}), + ToggleButtonSimple::new("Sixth", |_, _, _| {}), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .into_any_element(), + ), + single_example( + "Multiple Row Group with Icons", + ToggleButtonGroup::two_rows( + "multiple_row_test_icons", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + [ + ToggleButtonWithIcon::new( + "Fourth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Fifth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Sixth", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .into_any_element(), + ), + ], + )]) + .children(vec![example_group_with_title( + "Outlined Variant", + vec![ + single_example( + "Single Row Group", + ToggleButtonGroup::single_row( + "single_row_test_outline", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + ) + .selected_index(1) + .style(ToggleButtonGroupStyle::Outlined) + .into_any_element(), + ), + single_example( + "Single Row Group with icons", + ToggleButtonGroup::single_row( + "single_row_test_icon_outlined", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Outlined) + .into_any_element(), + ), + single_example( + "Multiple Row Group", + ToggleButtonGroup::two_rows( + 
"multiple_row_test", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + [ + ToggleButtonSimple::new("Fourth", |_, _, _| {}), + ToggleButtonSimple::new("Fifth", |_, _, _| {}), + ToggleButtonSimple::new("Sixth", |_, _, _| {}), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Outlined) + .into_any_element(), + ), + single_example( + "Multiple Row Group with Icons", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + [ + ToggleButtonWithIcon::new( + "Fourth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Fifth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Sixth", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Outlined) + .into_any_element(), + ), + ], + )]) + .children(vec![example_group_with_title( + "Filled Variant", + vec![ + single_example( + "Single Row Group", + ToggleButtonGroup::single_row( + "single_row_test_outline", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + ) + .selected_index(2) + .style(ToggleButtonGroupStyle::Filled) + .into_any_element(), + ), + single_example( + "Single Row Group with icons", + ToggleButtonGroup::single_row( + "single_row_test_icon_outlined", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Filled) + .into_any_element(), + ), + single_example( + "Multiple Row Group", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonSimple::new("First", |_, _, _| {}), + ToggleButtonSimple::new("Second", |_, _, _| {}), + ToggleButtonSimple::new("Third", |_, _, _| {}), + ], + [ + ToggleButtonSimple::new("Fourth", |_, _, _| {}), + ToggleButtonSimple::new("Fifth", |_, _, _| {}), + ToggleButtonSimple::new("Sixth", |_, _, _| {}), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Filled) + .into_any_element(), + ), + single_example( + "Multiple Row Group with Icons", + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonWithIcon::new( + "First", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Second", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Third", + IconName::AiZed, + |_, _, _| {}, + ), + ], + [ + ToggleButtonWithIcon::new( + "Fourth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Fifth", + IconName::AiZed, + |_, _, _| {}, + ), + ToggleButtonWithIcon::new( + "Sixth", + IconName::AiZed, + |_, _, _| {}, + ), + ], + ) + .selected_index(3) + .button_width(rems_from_px(100.)) + .style(ToggleButtonGroupStyle::Filled) + .into_any_element(), + ), + ], + )]) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/callout.rs 
b/crates/ui/src/components/callout.rs index d15fa122ed..9c1c9fb1a9 100644 --- a/crates/ui/src/components/callout.rs +++ b/crates/ui/src/components/callout.rs @@ -158,7 +158,7 @@ impl RenderOnce for Callout { impl Component for Callout { fn scope() -> ComponentScope { - ComponentScope::Notification + ComponentScope::DataDisplay } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/dropdown_menu.rs b/crates/ui/src/components/dropdown_menu.rs index 189fac930f..7ad9400f0d 100644 --- a/crates/ui/src/components/dropdown_menu.rs +++ b/crates/ui/src/components/dropdown_menu.rs @@ -8,6 +8,7 @@ use super::PopoverMenuHandle; pub enum DropdownStyle { #[default] Solid, + Outlined, Ghost, } @@ -147,6 +148,23 @@ impl Component for DropdownMenu { ), ], ), + example_group_with_title( + "Styles", + vec![ + single_example( + "Outlined", + DropdownMenu::new("outlined", "Outlined Dropdown", menu.clone()) + .style(DropdownStyle::Outlined) + .into_any_element(), + ), + single_example( + "Ghost", + DropdownMenu::new("ghost", "Ghost Dropdown", menu.clone()) + .style(DropdownStyle::Ghost) + .into_any_element(), + ), + ], + ), example_group_with_title( "States", vec![single_example( @@ -170,10 +188,13 @@ pub struct DropdownTriggerStyle { impl DropdownTriggerStyle { pub fn for_style(style: DropdownStyle, cx: &App) -> Self { let colors = cx.theme().colors(); + let bg = match style { DropdownStyle::Solid => colors.editor_background, + DropdownStyle::Outlined => colors.surface_background, DropdownStyle::Ghost => colors.ghost_element_background, }; + Self { bg } } } @@ -244,29 +265,36 @@ impl RenderOnce for DropdownMenuTrigger { let disabled = self.disabled; let style = DropdownTriggerStyle::for_style(self.style, cx); + let is_outlined = matches!(self.style, DropdownStyle::Outlined); h_flex() .id("dropdown-menu-trigger") - .justify_between() - .rounded_sm() - .bg(style.bg) + .min_w_20() .pl_2() .pr_1p5() .py_0p5() .gap_2() - .min_w_20() - .map(|el| { + .justify_between() + .rounded_sm() + .map(|this| { if self.full_width { - el.w_full() + this.w_full() } else { - el.flex_none().w_auto() + this.flex_none().w_auto() } }) - .map(|el| { + .when(is_outlined, |this| { + this.border_1() + .border_color(cx.theme().colors().border) + .overflow_hidden() + }) + .map(|this| { if disabled { - el.cursor_not_allowed() + this.cursor_not_allowed() + .bg(cx.theme().colors().element_disabled) } else { - el.cursor_pointer() + this.bg(style.bg) + .hover(|s| s.bg(cx.theme().colors().element_hover)) } }) .child(match self.label { diff --git a/crates/ui/src/components/image.rs b/crates/ui/src/components/image.rs index 2deba68d88..18f804abe9 100644 --- a/crates/ui/src/components/image.rs +++ b/crates/ui/src/components/image.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use gpui::Transformation; use gpui::{App, IntoElement, Rems, RenderOnce, Size, Styled, Window, svg}; use serde::{Deserialize, Serialize}; use strum::{EnumIter, EnumString, IntoStaticStr}; @@ -12,11 +13,13 @@ use crate::prelude::*; )] #[strum(serialize_all = "snake_case")] pub enum VectorName { + AiGrid, + CertifiedUserStamp, + DebuggerGrid, + Grid, + ProTrialStamp, ZedLogo, ZedXCopilot, - Grid, - AiGrid, - DebuggerGrid, } impl VectorName { @@ -37,6 +40,7 @@ pub struct Vector { path: Arc<str>, color: Color, size: Size<Rems>, + transformation: Transformation, } impl Vector { @@ -46,6 +50,7 @@ impl Vector { path: vector.path(), color: Color::default(), size: Size { width, height }, + transformation: Transformation::default(), } } @@ -66,6 +71,11 @@ impl Vector 
{ self.size = size; self } + + pub fn transform(mut self, transformation: Transformation) -> Self { + self.transformation = transformation; + self + } } impl RenderOnce for Vector { @@ -81,6 +91,7 @@ impl RenderOnce for Vector { .h(height) .path(self.path) .text_color(self.color.color(cx)) + .with_transformation(self.transformation) } } diff --git a/crates/ui/src/components/keybinding.rs b/crates/ui/src/components/keybinding.rs index 1d91492f26..5779093ccc 100644 --- a/crates/ui/src/components/keybinding.rs +++ b/crates/ui/src/components/keybinding.rs @@ -44,7 +44,7 @@ impl KeyBinding { pub fn for_action_in( action: &dyn Action, focus: &FocusHandle, - window: &mut Window, + window: &Window, cx: &App, ) -> Option<Self> { let key_binding = window.highest_precedence_binding_for_action_in(action, focus)?; diff --git a/crates/ui/src/components/list.rs b/crates/ui/src/components/list.rs index 88650b6ae8..6876f290ce 100644 --- a/crates/ui/src/components/list.rs +++ b/crates/ui/src/components/list.rs @@ -1,10 +1,12 @@ mod list; +mod list_bullet_item; mod list_header; mod list_item; mod list_separator; mod list_sub_header; pub use list::*; +pub use list_bullet_item::*; pub use list_header::*; pub use list_item::*; pub use list_separator::*; diff --git a/crates/ui/src/components/list/list_bullet_item.rs b/crates/ui/src/components/list/list_bullet_item.rs new file mode 100644 index 0000000000..6e079d9f11 --- /dev/null +++ b/crates/ui/src/components/list/list_bullet_item.rs @@ -0,0 +1,40 @@ +use crate::{ListItem, prelude::*}; +use gpui::{IntoElement, ParentElement, SharedString}; + +#[derive(IntoElement)] +pub struct ListBulletItem { + label: SharedString, +} + +impl ListBulletItem { + pub fn new(label: impl Into<SharedString>) -> Self { + Self { + label: label.into(), + } + } +} + +impl RenderOnce for ListBulletItem { + fn render(self, window: &mut Window, _cx: &mut App) -> impl IntoElement { + let line_height = 0.85 * window.line_height(); + + ListItem::new("list-item") + .selectable(false) + .child( + h_flex() + .w_full() + .min_w_0() + .gap_1() + .items_start() + .child( + h_flex().h(line_height).justify_center().child( + Icon::new(IconName::Dash) + .size(IconSize::XSmall) + .color(Color::Hidden), + ), + ) + .child(div().w_full().min_w_0().child(Label::new(self.label))), + ) + .into_any_element() + } +} diff --git a/crates/ui/src/components/modal.rs b/crates/ui/src/components/modal.rs index 2145b34ef2..a70f5e1ea5 100644 --- a/crates/ui/src/components/modal.rs +++ b/crates/ui/src/components/modal.rs @@ -1,5 +1,5 @@ use crate::{ - Clickable, Color, DynamicSpacing, Headline, HeadlineSize, IconButton, IconButtonShape, + Clickable, Color, DynamicSpacing, Headline, HeadlineSize, Icon, IconButton, IconButtonShape, IconName, Label, LabelCommon, LabelSize, h_flex, v_flex, }; use gpui::{prelude::FluentBuilder, *}; @@ -92,6 +92,7 @@ impl RenderOnce for Modal { #[derive(IntoElement)] pub struct ModalHeader { + icon: Option<Icon>, headline: Option<SharedString>, description: Option<SharedString>, children: SmallVec<[AnyElement; 2]>, @@ -108,6 +109,7 @@ impl Default for ModalHeader { impl ModalHeader { pub fn new() -> Self { Self { + icon: None, headline: None, description: None, children: SmallVec::new(), @@ -116,6 +118,11 @@ impl ModalHeader { } } + pub fn icon(mut self, icon: Icon) -> Self { + self.icon = Some(icon); + self + } + /// Set the headline of the modal. 
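A hedged usage sketch for the new `ListBulletItem` above; the `v_flex` container and label text are assumptions, not part of the change.

// Each bullet item renders as a non-selectable ListItem with a leading dash icon.
v_flex()
    .child(ListBulletItem::new("Accepts anything that converts into SharedString"))
    .child(ListBulletItem::new("Useful for short, read-only bullet lists in dialogs"))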
/// /// This will insert the headline as the first item @@ -179,12 +186,17 @@ impl RenderOnce for ModalHeader { ) }) .child( - v_flex().flex_1().children(children).when_some( - self.description, - |this, description| { + v_flex() + .flex_1() + .child( + h_flex() + .gap_1() + .when_some(self.icon, |this, icon| this.child(icon)) + .children(children), + ) + .when_some(self.description, |this, description| { this.child(Label::new(description).color(Color::Muted).mb_2()) - }, - ), + }), ) .when(self.show_dismiss_button, |this| { this.child( diff --git a/crates/ui/src/components/numeric_stepper.rs b/crates/ui/src/components/numeric_stepper.rs index f9e6e88f01..2ddb86d9a0 100644 --- a/crates/ui/src/components/numeric_stepper.rs +++ b/crates/ui/src/components/numeric_stepper.rs @@ -2,15 +2,24 @@ use gpui::ClickEvent; use crate::{IconButtonShape, prelude::*}; -#[derive(IntoElement)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NumericStepperStyle { + Outlined, + #[default] + Ghost, +} + +#[derive(IntoElement, RegisterComponent)] pub struct NumericStepper { id: ElementId, value: SharedString, + style: NumericStepperStyle, on_decrement: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, on_increment: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>, /// Whether to reserve space for the reset button. reserve_space_for_reset: bool, on_reset: Option<Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>>, + tab_index: Option<isize>, } impl NumericStepper { @@ -23,13 +32,20 @@ impl NumericStepper { Self { id: id.into(), value: value.into(), + style: NumericStepperStyle::default(), on_decrement: Box::new(on_decrement), on_increment: Box::new(on_increment), reserve_space_for_reset: false, on_reset: None, + tab_index: None, } } + pub fn style(mut self, style: NumericStepperStyle) -> Self { + self.style = style; + self + } + pub fn reserve_space_for_reset(mut self, reserve_space_for_reset: bool) -> Self { self.reserve_space_for_reset = reserve_space_for_reset; self @@ -42,6 +58,11 @@ impl NumericStepper { self.on_reset = Some(Box::new(on_reset)); self } + + pub fn tab_index(mut self, tab_index: isize) -> Self { + self.tab_index = Some(tab_index); + self + } } impl RenderOnce for NumericStepper { @@ -49,6 +70,9 @@ impl RenderOnce for NumericStepper { let shape = IconButtonShape::Square; let icon_size = IconSize::Small; + let is_outlined = matches!(self.style, NumericStepperStyle::Outlined); + let mut tab_index = self.tab_index; + h_flex() .id(self.id) .gap_1() @@ -58,6 +82,10 @@ impl RenderOnce for NumericStepper { IconButton::new("reset", IconName::RotateCcw) .shape(shape) .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) .on_click(on_reset), ) } else if self.reserve_space_for_reset { @@ -74,22 +102,136 @@ impl RenderOnce for NumericStepper { .child( h_flex() .gap_1() - .px_1() - .rounded_xs() - .bg(cx.theme().colors().editor_background) - .child( - IconButton::new("decrement", IconName::Dash) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_decrement), - ) - .child(Label::new(self.value)) - .child( - IconButton::new("increment", IconName::Plus) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_increment), - ), + .rounded_sm() + .map(|this| { + if is_outlined { + this.overflow_hidden() + .bg(cx.theme().colors().surface_background) + .border_1() + .border_color(cx.theme().colors().border_variant) + } else { + this.px_1().bg(cx.theme().colors().editor_background) 
+ } + }) + .map(|decrement| { + if is_outlined { + decrement.child( + h_flex() + .id("decrement_button") + .p_1p5() + .size_full() + .justify_center() + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .border_r_1() + .border_color(cx.theme().colors().border_variant) + .child(Icon::new(IconName::Dash).size(IconSize::Small)) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1).focus(|style| { + style.bg(cx.theme().colors().element_hover) + }) + }) + .on_click(self.on_decrement), + ) + } else { + decrement.child( + IconButton::new("decrement", IconName::Dash) + .shape(shape) + .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) + .on_click(self.on_decrement), + ) + } + }) + .child(Label::new(self.value).mx_3()) + .map(|increment| { + if is_outlined { + increment.child( + h_flex() + .id("increment_button") + .p_1p5() + .size_full() + .justify_center() + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .child(Icon::new(IconName::Plus).size(IconSize::Small)) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1).focus(|style| { + style.bg(cx.theme().colors().element_hover) + }) + }) + .on_click(self.on_increment), + ) + } else { + increment.child( + IconButton::new("increment", IconName::Dash) + .shape(shape) + .icon_size(icon_size) + .when_some(tab_index.as_mut(), |this, tab_index| { + *tab_index += 1; + this.tab_index(*tab_index - 1) + }) + .on_click(self.on_increment), + ) + } + }), ) } } + +impl Component for NumericStepper { + fn scope() -> ComponentScope { + ComponentScope::Input + } + + fn name() -> &'static str { + "Numeric Stepper" + } + + fn sort_name() -> &'static str { + Self::name() + } + + fn description() -> Option<&'static str> { + Some("A button used to increment or decrement a numeric value.") + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { + Some( + v_flex() + .gap_6() + .children(vec![example_group_with_title( + "Styles", + vec![ + single_example( + "Default", + NumericStepper::new( + "numeric-stepper-component-preview", + "10", + move |_, _, _| {}, + move |_, _, _| {}, + ) + .into_any_element(), + ), + single_example( + "Outlined", + NumericStepper::new( + "numeric-stepper-with-border-component-preview", + "10", + move |_, _, _| {}, + move |_, _, _| {}, + ) + .style(NumericStepperStyle::Outlined) + .into_any_element(), + ), + ], + )]) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/popover.rs b/crates/ui/src/components/popover.rs index 24460f6d9c..7143514c52 100644 --- a/crates/ui/src/components/popover.rs +++ b/crates/ui/src/components/popover.rs @@ -50,7 +50,7 @@ impl RenderOnce for Popover { v_flex() .elevation_2(cx) .py(POPOVER_Y_PADDING / 2.) 
- .children(self.children), + .child(div().children(self.children)), ) .when_some(self.aside, |this, aside| { this.child( diff --git a/crates/ui/src/components/scrollbar.rs b/crates/ui/src/components/scrollbar.rs index 17ab2e788f..605028202f 100644 --- a/crates/ui/src/components/scrollbar.rs +++ b/crates/ui/src/components/scrollbar.rs @@ -1,11 +1,20 @@ -use std::{any::Any, cell::Cell, fmt::Debug, ops::Range, rc::Rc, sync::Arc}; +use std::{ + any::Any, + cell::{Cell, RefCell}, + fmt::Debug, + ops::Range, + rc::Rc, + sync::Arc, + time::Duration, +}; use crate::{IntoElement, prelude::*, px, relative}; use gpui::{ Along, App, Axis as ScrollbarAxis, BorderStyle, Bounds, ContentMask, Corners, CursorStyle, Edges, Element, ElementId, Entity, EntityId, GlobalElementId, Hitbox, HitboxBehavior, Hsla, - IsZero, LayoutId, ListState, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, Point, - ScrollHandle, ScrollWheelEvent, Size, Style, UniformListScrollHandle, Window, quad, + IsZero, LayoutId, ListState, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, + Point, ScrollHandle, ScrollWheelEvent, Size, Style, Task, UniformListScrollHandle, Window, + quad, }; pub struct Scrollbar { @@ -108,6 +117,25 @@ pub struct ScrollbarState { thumb_state: Rc<Cell<ThumbState>>, parent_id: Option<EntityId>, scroll_handle: Arc<dyn ScrollableHandle>, + auto_hide: Rc<RefCell<AutoHide>>, +} + +#[derive(Debug)] +enum AutoHide { + Disabled, + Hidden { + parent_id: EntityId, + }, + Visible { + parent_id: EntityId, + _task: Task<()>, + }, +} + +impl AutoHide { + fn is_hidden(&self) -> bool { + matches!(self, AutoHide::Hidden { .. }) + } } impl ScrollbarState { @@ -116,6 +144,7 @@ impl ScrollbarState { thumb_state: Default::default(), parent_id: None, scroll_handle: Arc::new(scroll), + auto_hide: Rc::new(RefCell::new(AutoHide::Disabled)), } } @@ -174,6 +203,38 @@ impl ScrollbarState { let thumb_percentage_end = (start_offset + thumb_size) / viewport_size; Some(thumb_percentage_start..thumb_percentage_end) } + + fn show_temporarily(&self, parent_id: EntityId, cx: &mut App) { + const SHOW_INTERVAL: Duration = Duration::from_secs(1); + + let auto_hide = self.auto_hide.clone(); + auto_hide.replace(AutoHide::Visible { + parent_id, + _task: cx.spawn({ + let this = auto_hide.clone(); + async move |cx| { + cx.background_executor().timer(SHOW_INTERVAL).await; + this.replace(AutoHide::Hidden { parent_id }); + cx.update(|cx| { + cx.notify(parent_id); + }) + .ok(); + } + }), + }); + } + + fn unhide(&self, position: &Point<Pixels>, cx: &mut App) { + let parent_id = match &*self.auto_hide.borrow() { + AutoHide::Disabled => return, + AutoHide::Hidden { parent_id } => *parent_id, + AutoHide::Visible { parent_id, _task } => *parent_id, + }; + + if self.scroll_handle().viewport().contains(position) { + self.show_temporarily(parent_id, cx); + } + } } impl Scrollbar { @@ -189,6 +250,14 @@ impl Scrollbar { let thumb = state.thumb_range(kind)?; Some(Self { thumb, state, kind }) } + + /// Automatically hide the scrollbar when idle + pub fn auto_hide<V: 'static>(self, cx: &mut Context<V>) -> Self { + if matches!(*self.state.auto_hide.borrow(), AutoHide::Disabled) { + self.state.show_temporarily(cx.entity_id(), cx); + } + self + } } impl Element for Scrollbar { @@ -284,16 +353,18 @@ impl Element for Scrollbar { .apply_along(axis.invert(), |width| width / 1.5), ); - let corners = Corners::all(thumb_bounds.size.along(axis.invert()) / 2.0); + if thumb_state.is_dragging() || !self.state.auto_hide.borrow().is_hidden() { + let corners = 
Corners::all(thumb_bounds.size.along(axis.invert()) / 2.0); - window.paint_quad(quad( - thumb_bounds, - corners, - thumb_background, - Edges::default(), - Hsla::transparent_black(), - BorderStyle::default(), - )); + window.paint_quad(quad( + thumb_bounds, + corners, + thumb_background, + Edges::default(), + Hsla::transparent_black(), + BorderStyle::default(), + )); + } if thumb_state.is_dragging() { window.set_window_cursor_style(CursorStyle::Arrow); @@ -301,8 +372,6 @@ impl Element for Scrollbar { window.set_cursor_style(CursorStyle::Arrow, hitbox); } - let scroll = self.state.scroll_handle.clone(); - enum ScrollbarMouseEvent { GutterClick, ThumbDrag(Pixels), @@ -337,10 +406,12 @@ impl Element for Scrollbar { }; window.on_mouse_event({ - let scroll = scroll.clone(); let state = self.state.clone(); move |event: &MouseDownEvent, phase, _, _| { - if !(phase.bubble() && bounds.contains(&event.position)) { + if !phase.bubble() + || event.button != MouseButton::Left + || !bounds.contains(&event.position) + { return; } @@ -348,57 +419,78 @@ impl Element for Scrollbar { let offset = event.position.along(axis) - thumb_bounds.origin.along(axis); state.set_dragging(offset); } else { + let scroll_handle = state.scroll_handle(); let click_offset = compute_click_offset( event.position, - scroll.max_offset(), + scroll_handle.max_offset(), ScrollbarMouseEvent::GutterClick, ); - scroll.set_offset(scroll.offset().apply_along(axis, |_| click_offset)); + scroll_handle + .set_offset(scroll_handle.offset().apply_along(axis, |_| click_offset)); } } }); window.on_mouse_event({ - let scroll = scroll.clone(); - move |event: &ScrollWheelEvent, phase, window, _| { - if phase.bubble() && bounds.contains(&event.position) { - let current_offset = scroll.offset(); - scroll.set_offset( - current_offset + event.delta.pixel_delta(window.line_height()), - ); + let state = self.state.clone(); + let scroll_handle = self.state.scroll_handle().clone(); + move |event: &ScrollWheelEvent, phase, window, cx| { + if phase.bubble() { + state.unhide(&event.position, cx); + + if bounds.contains(&event.position) { + let current_offset = scroll_handle.offset(); + scroll_handle.set_offset( + current_offset + event.delta.pixel_delta(window.line_height()), + ); + } } } }); - let state = self.state.clone(); - window.on_mouse_event(move |event: &MouseMoveEvent, _, window, cx| { - match state.thumb_state.get() { - ThumbState::Dragging(drag_state) if event.dragging() => { - let drag_offset = compute_click_offset( - event.position, - scroll.max_offset(), - ScrollbarMouseEvent::ThumbDrag(drag_state), - ); - scroll.set_offset(scroll.offset().apply_along(axis, |_| drag_offset)); - window.refresh(); - if let Some(id) = state.parent_id { - cx.notify(id); + window.on_mouse_event({ + let state = self.state.clone(); + move |event: &MouseMoveEvent, phase, window, cx| { + if phase.bubble() { + state.unhide(&event.position, cx); + + match state.thumb_state.get() { + ThumbState::Dragging(drag_state) if event.dragging() => { + let scroll_handle = state.scroll_handle(); + let drag_offset = compute_click_offset( + event.position, + scroll_handle.max_offset(), + ScrollbarMouseEvent::ThumbDrag(drag_state), + ); + scroll_handle.set_offset( + scroll_handle.offset().apply_along(axis, |_| drag_offset), + ); + window.refresh(); + if let Some(id) = state.parent_id { + cx.notify(id); + } + } + _ if event.pressed_button.is_none() => { + state.set_thumb_hovered(thumb_bounds.contains(&event.position)) + } + _ => {} } } - _ => 
state.set_thumb_hovered(thumb_bounds.contains(&event.position)), } }); - let state = self.state.clone(); - let scroll = self.state.scroll_handle.clone(); - window.on_mouse_event(move |event: &MouseUpEvent, phase, _, cx| { - if phase.bubble() { - if state.is_dragging() { + + window.on_mouse_event({ + let state = self.state.clone(); + move |event: &MouseUpEvent, phase, _, cx| { + if phase.bubble() { + if state.is_dragging() { + state.scroll_handle().drag_ended(); + if let Some(id) = state.parent_id { + cx.notify(id); + } + } state.set_thumb_hovered(thumb_bounds.contains(&event.position)); } - scroll.drag_ended(); - if let Some(id) = state.parent_id { - cx.notify(id); - } } }); }) diff --git a/crates/ui/src/components/stories/icon_button.rs b/crates/ui/src/components/stories/icon_button.rs index e787e81b55..ad6886252d 100644 --- a/crates/ui/src/components/stories/icon_button.rs +++ b/crates/ui/src/components/stories/icon_button.rs @@ -77,7 +77,7 @@ impl Render for IconButtonStory { let with_tooltip_button = StoryItem::new( "With `tooltip`", - IconButton::new("with_tooltip_button", IconName::MessageBubbles) + IconButton::new("with_tooltip_button", IconName::Chat) .tooltip(Tooltip::text("Open messages")), ) .description("Displays an icon button that has a tooltip when hovered.") diff --git a/crates/ui/src/components/tab.rs b/crates/ui/src/components/tab.rs index a205c33358..d704846a68 100644 --- a/crates/ui/src/components/tab.rs +++ b/crates/ui/src/components/tab.rs @@ -179,7 +179,7 @@ impl RenderOnce for Tab { impl Component for Tab { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::Navigation } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/toggle.rs b/crates/ui/src/components/toggle.rs index cf2a56b1c9..4b985fd2c2 100644 --- a/crates/ui/src/components/toggle.rs +++ b/crates/ui/src/components/toggle.rs @@ -2,10 +2,10 @@ use gpui::{ AnyElement, AnyView, ClickEvent, ElementId, Hsla, IntoElement, Styled, Window, div, hsla, prelude::*, }; -use std::sync::Arc; +use std::{rc::Rc, sync::Arc}; use crate::utils::is_light; -use crate::{Color, Icon, IconName, ToggleState}; +use crate::{Color, Icon, IconName, ToggleState, Tooltip}; use crate::{ElevationIndex, KeyBinding, prelude::*}; // TODO: Checkbox, CheckboxWithLabel, and Switch could all be @@ -424,6 +424,7 @@ pub struct Switch { label: Option<SharedString>, key_binding: Option<KeyBinding>, color: SwitchColor, + tab_index: Option<isize>, } impl Switch { @@ -437,6 +438,7 @@ impl Switch { label: None, key_binding: None, color: SwitchColor::default(), + tab_index: None, } } @@ -472,6 +474,11 @@ impl Switch { self.key_binding = key_binding.into(); self } + + pub fn tab_index(mut self, tab_index: impl Into<isize>) -> Self { + self.tab_index = Some(tab_index.into()); + self + } } impl RenderOnce for Switch { @@ -497,29 +504,46 @@ impl RenderOnce for Switch { let group_id = format!("switch_group_{:?}", self.id); - let switch = h_flex() - .w(DynamicSpacing::Base32.rems(cx)) - .h(DynamicSpacing::Base20.rems(cx)) - .group(group_id.clone()) + let switch = div() + .id((self.id.clone(), "switch")) + .p(px(1.0)) + .border_2() + .border_color(cx.theme().colors().border_transparent) + .rounded_full() + .when_some( + self.tab_index.filter(|_| !self.disabled), + |this, tab_index| { + this.tab_index(tab_index).focus(|mut style| { + style.border_color = Some(cx.theme().colors().border_focused); + style + }) + }, + ) .child( h_flex() - .when(is_on, |on| on.justify_end()) - .when(!is_on, |off| off.justify_start()) - 
.size_full() - .rounded_full() - .px(DynamicSpacing::Base02.px(cx)) - .bg(bg_color) - .when(!self.disabled, |this| { - this.group_hover(group_id.clone(), |el| el.bg(bg_hover_color)) - }) - .border_1() - .border_color(border_color) + .w(DynamicSpacing::Base32.rems(cx)) + .h(DynamicSpacing::Base20.rems(cx)) + .group(group_id.clone()) .child( - div() - .size(DynamicSpacing::Base12.rems(cx)) + h_flex() + .when(is_on, |on| on.justify_end()) + .when(!is_on, |off| off.justify_start()) + .size_full() .rounded_full() - .bg(thumb_color) - .opacity(thumb_opacity), + .px(DynamicSpacing::Base02.px(cx)) + .bg(bg_color) + .when(!self.disabled, |this| { + this.group_hover(group_id.clone(), |el| el.bg(bg_hover_color)) + }) + .border_1() + .border_color(border_color) + .child( + div() + .size(DynamicSpacing::Base12.rems(cx)) + .rounded_full() + .bg(thumb_color) + .opacity(thumb_opacity), + ), ), ); @@ -566,32 +590,41 @@ impl RenderOnce for Switch { pub struct SwitchField { id: ElementId, label: SharedString, - description: SharedString, + description: Option<SharedString>, toggle_state: ToggleState, on_click: Arc<dyn Fn(&ToggleState, &mut Window, &mut App) + 'static>, disabled: bool, color: SwitchColor, + tooltip: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView>>, + tab_index: Option<isize>, } impl SwitchField { pub fn new( id: impl Into<ElementId>, label: impl Into<SharedString>, - description: impl Into<SharedString>, + description: Option<SharedString>, toggle_state: impl Into<ToggleState>, on_click: impl Fn(&ToggleState, &mut Window, &mut App) + 'static, ) -> Self { Self { id: id.into(), label: label.into(), - description: description.into(), + description: description, toggle_state: toggle_state.into(), on_click: Arc::new(on_click), disabled: false, color: SwitchColor::Accent, + tooltip: None, + tab_index: None, } } + pub fn description(mut self, description: impl Into<SharedString>) -> Self { + self.description = Some(description.into()); + self + } + pub fn disabled(mut self, disabled: bool) -> Self { self.disabled = disabled; self @@ -603,36 +636,75 @@ impl SwitchField { self.color = color; self } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } + + pub fn tab_index(mut self, tab_index: isize) -> Self { + self.tab_index = Some(tab_index); + self + } } impl RenderOnce for SwitchField { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { + let tooltip = self.tooltip.map(|tooltip_fn| { + h_flex() + .gap_0p5() + .child(Label::new(self.label.clone())) + .child( + IconButton::new("tooltip_button", IconName::Info) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .shape(crate::IconButtonShape::Square) + .tooltip({ + let tooltip = tooltip_fn.clone(); + move |window, cx| tooltip(window, cx) + }), + ) + }); + h_flex() - .id(SharedString::from(format!("{}-container", self.id))) + .id((self.id.clone(), "container")) + .when(!self.disabled, |this| { + this.hover(|this| this.cursor_pointer()) + }) .w_full() .gap_4() .justify_between() .flex_wrap() - .child( - v_flex() + .child(match (&self.description, tooltip) { + (Some(description), Some(tooltip)) => v_flex() .gap_0p5() .max_w_5_6() - .child(Label::new(self.label)) - .child(Label::new(self.description).color(Color::Muted)), - ) + .child(tooltip) + .child(Label::new(description.clone()).color(Color::Muted)) + .into_any_element(), + (Some(description), None) => v_flex() + .gap_0p5() + .max_w_5_6() + 
.child(Label::new(self.label.clone())) + .child(Label::new(description.clone()).color(Color::Muted)) + .into_any_element(), + (None, Some(tooltip)) => tooltip.into_any_element(), + (None, None) => Label::new(self.label.clone()).into_any_element(), + }) .child( - Switch::new( - SharedString::from(format!("{}-switch", self.id)), - self.toggle_state, - ) - .color(self.color) - .disabled(self.disabled) - .on_click({ - let on_click = self.on_click.clone(); - move |state, window, cx| { - (on_click)(state, window, cx); - } - }), + Switch::new((self.id.clone(), "switch"), self.toggle_state) + .color(self.color) + .disabled(self.disabled) + .when_some( + self.tab_index.filter(|_| !self.disabled), + |this, tab_index| this.tab_index(tab_index), + ) + .on_click({ + let on_click = self.on_click.clone(); + move |state, window, cx| { + (on_click)(state, window, cx); + } + }), ) .when(!self.disabled, |this| { this.on_click({ @@ -668,7 +740,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_unselected", "Enable notifications", - "Receive notifications when new messages arrive.", + Some("Receive notifications when new messages arrive.".into()), ToggleState::Unselected, |_, _, _| {}, ) @@ -679,7 +751,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_selected", "Enable notifications", - "Receive notifications when new messages arrive.", + Some("Receive notifications when new messages arrive.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -695,7 +767,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_default", "Default color", - "This uses the default switch color.", + Some("This uses the default switch color.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -706,7 +778,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_accent", "Accent color", - "This uses the accent color scheme.", + Some("This uses the accent color scheme.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -722,7 +794,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_disabled", "Disabled field", - "This field is disabled and cannot be toggled.", + Some("This field is disabled and cannot be toggled.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -730,6 +802,49 @@ impl Component for SwitchField { .into_any_element(), )], ), + example_group_with_title( + "No Description", + vec![single_example( + "No Description", + SwitchField::new( + "switch_field_disabled", + "Disabled field", + None, + ToggleState::Selected, + |_, _, _| {}, + ) + .into_any_element(), + )], + ), + example_group_with_title( + "With Tooltip", + vec![ + single_example( + "Tooltip with Description", + SwitchField::new( + "switch_field_tooltip_with_desc", + "Nice Feature", + Some("Enable advanced configuration options.".into()), + ToggleState::Unselected, + |_, _, _| {}, + ) + .tooltip(Tooltip::text("This is content for this tooltip!")) + .into_any_element(), + ), + single_example( + "Tooltip without Description", + SwitchField::new( + "switch_field_tooltip_no_desc", + "Nice Feature", + None, + ToggleState::Selected, + |_, _, _| {}, + ) + .tooltip(Tooltip::text("This is content for this tooltip!")) + .into_any_element(), + ), + ], + ), ]) .into_any_element(), ) diff --git a/crates/ui/src/styles/animation.rs b/crates/ui/src/styles/animation.rs index 50c4e0eb0d..ee5352d454 100644 --- a/crates/ui/src/styles/animation.rs +++ b/crates/ui/src/styles/animation.rs @@ -99,7 +99,7 @@ struct Animation {} impl Component for Animation { fn scope() -> ComponentScope { - 
ComponentScope::None + ComponentScope::Utilities } fn description() -> Option<&'static str> { @@ -109,7 +109,7 @@ impl Component for Animation { fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> { let container_size = 128.0; let element_size = 32.0; - let left_offset = element_size - container_size / 2.0; + let offset = container_size / 2.0 - element_size / 2.0; Some( v_flex() .gap_6() @@ -129,7 +129,7 @@ impl Component for Animation { .id("animate-in-from-bottom") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::red()) .animate_in(AnimationDirection::FromBottom, false), @@ -148,7 +148,7 @@ impl Component for Animation { .id("animate-in-from-top") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::blue()) .animate_in(AnimationDirection::FromTop, false), @@ -167,7 +167,7 @@ impl Component for Animation { .id("animate-in-from-left") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::green()) .animate_in(AnimationDirection::FromLeft, false), @@ -186,7 +186,7 @@ impl Component for Animation { .id("animate-in-from-right") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::yellow()) .animate_in(AnimationDirection::FromRight, false), @@ -211,7 +211,7 @@ impl Component for Animation { .id("fade-animate-in-from-bottom") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::red()) .animate_in(AnimationDirection::FromBottom, true), @@ -230,7 +230,7 @@ impl Component for Animation { .id("fade-animate-in-from-top") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .left(px(offset)) .rounded_md() .bg(gpui::blue()) .animate_in(AnimationDirection::FromTop, true), @@ -249,7 +249,7 @@ impl Component for Animation { .id("fade-animate-in-from-left") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::green()) .animate_in(AnimationDirection::FromLeft, true), @@ -268,7 +268,7 @@ impl Component for Animation { .id("fade-animate-in-from-right") .absolute() .size(px(element_size)) - .left(px(left_offset)) + .top(px(offset)) .rounded_md() .bg(gpui::yellow()) .animate_in(AnimationDirection::FromRight, true), diff --git a/crates/ui/src/styles/color.rs b/crates/ui/src/styles/color.rs index c7b995d39a..586b2ccc57 100644 --- a/crates/ui/src/styles/color.rs +++ b/crates/ui/src/styles/color.rs @@ -126,7 +126,7 @@ impl From<Hsla> for Color { impl Component for Color { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::Utilities } fn description() -> Option<&'static str> { diff --git a/crates/ui_prompt/src/ui_prompt.rs b/crates/ui_prompt/src/ui_prompt.rs index 2b6a030f26..fe6dc5b3f4 100644 --- a/crates/ui_prompt/src/ui_prompt.rs +++ b/crates/ui_prompt/src/ui_prompt.rs @@ -43,7 +43,7 @@ fn zed_prompt_renderer( let renderer = cx.new({ |cx| ZedPromptRenderer { _level: level, - message: message.to_string(), + message: cx.new(|cx| Markdown::new(SharedString::new(message), None, None, cx)), actions: actions.iter().map(|a| a.label().to_string()).collect(), focus: cx.focus_handle(), active_action_id: 0, @@ -58,7 +58,7 @@ fn zed_prompt_renderer( pub struct ZedPromptRenderer { _level: PromptLevel, - message: String, + message: Entity<Markdown>, actions: Vec<String>, focus: FocusHandle, active_action_id: usize, @@ -114,7 +114,7 @@ impl ZedPromptRenderer { impl Render 
for ZedPromptRenderer { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); - let font_family = settings.ui_font.family.clone(); + let font_size = settings.ui_font_size(cx).into(); let prompt = v_flex() .key_context("Prompt") .cursor_default() @@ -130,24 +130,38 @@ impl Render for ZedPromptRenderer { .overflow_hidden() .p_4() .gap_4() - .font_family(font_family) + .font_family(settings.ui_font.family.clone()) .child( div() .w_full() - .font_weight(FontWeight::BOLD) - .child(self.message.clone()) - .text_color(ui::Color::Default.color(cx)), + .child(MarkdownElement::new(self.message.clone(), { + let mut base_text_style = window.text_style(); + base_text_style.refine(&TextStyleRefinement { + font_family: Some(settings.ui_font.family.clone()), + font_size: Some(font_size), + font_weight: Some(FontWeight::BOLD), + color: Some(ui::Color::Default.color(cx)), + ..Default::default() + }); + MarkdownStyle { + base_text_style, + selection_background_color: cx + .theme() + .colors() + .element_selection_background, + ..Default::default() + } + })), ) .children(self.detail.clone().map(|detail| { div() .w_full() .text_xs() .child(MarkdownElement::new(detail, { - let settings = ThemeSettings::get_global(cx); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(settings.ui_font.family.clone()), - font_size: Some(settings.ui_font_size(cx).into()), + font_size: Some(font_size), color: Some(ui::Color::Muted.color(cx)), ..Default::default() }); @@ -176,24 +190,28 @@ impl Render for ZedPromptRenderer { }), )); - div().size_full().occlude().child( - div() - .size_full() - .absolute() - .top_0() - .left_0() - .flex() - .flex_col() - .justify_around() - .child( - div() - .w_full() - .flex() - .flex_row() - .justify_around() - .child(prompt), - ), - ) + div() + .size_full() + .occlude() + .bg(gpui::black().opacity(0.2)) + .child( + div() + .size_full() + .absolute() + .top_0() + .left_0() + .flex() + .flex_col() + .justify_around() + .child( + div() + .w_full() + .flex() + .flex_row() + .justify_around() + .child(prompt), + ), + ) } } diff --git a/crates/util/src/archive.rs b/crates/util/src/archive.rs index d10b996716..3e4d281c29 100644 --- a/crates/util/src/archive.rs +++ b/crates/util/src/archive.rs @@ -2,6 +2,8 @@ use std::path::Path; use anyhow::{Context as _, Result}; use async_zip::base::read; +#[cfg(not(windows))] +use futures::AsyncSeek; use futures::{AsyncRead, io::BufReader}; #[cfg(windows)] @@ -62,7 +64,15 @@ pub async fn extract_zip<R: AsyncRead + Unpin>(destination: &Path, reader: R) -> futures::io::copy(&mut BufReader::new(reader), &mut file) .await .context("saving archive contents into the temporary file")?; - let mut reader = read::seek::ZipFileReader::new(BufReader::new(file)) + extract_seekable_zip(destination, file).await +} + +#[cfg(not(windows))] +pub async fn extract_seekable_zip<R: AsyncRead + AsyncSeek + Unpin>( + destination: &Path, + reader: R, +) -> Result<()> { + let mut reader = read::seek::ZipFileReader::new(BufReader::new(reader)) .await .context("reading the zip archive")?; let destination = &destination diff --git a/crates/util/src/fs.rs b/crates/util/src/fs.rs index 2738b6e213..3e96594f85 100644 --- a/crates/util/src/fs.rs +++ b/crates/util/src/fs.rs @@ -95,9 +95,9 @@ pub async fn move_folder_files_to_folder<P: AsRef<Path>>( #[cfg(unix)] /// Set the permissions for the given path so that the file becomes executable. 
/// This is a noop for non-unix platforms. -pub async fn make_file_executable(path: &PathBuf) -> std::io::Result<()> { +pub async fn make_file_executable(path: &Path) -> std::io::Result<()> { fs::set_permissions( - &path, + path, <fs::Permissions as fs::unix::PermissionsExt>::from_mode(0o755), ) .await @@ -107,6 +107,6 @@ pub async fn make_file_executable(path: &PathBuf) -> std::io::Result<()> { #[allow(clippy::unused_async)] /// Set the permissions for the given path so that the file becomes executable. /// This is a noop for non-unix platforms. -pub async fn make_file_executable(_path: &PathBuf) -> std::io::Result<()> { +pub async fn make_file_executable(_path: &Path) -> std::io::Result<()> { Ok(()) } diff --git a/crates/util/src/shell_env.rs b/crates/util/src/shell_env.rs index 21f6096f19..2b1063316f 100644 --- a/crates/util/src/shell_env.rs +++ b/crates/util/src/shell_env.rs @@ -18,15 +18,19 @@ pub fn capture(directory: &std::path::Path) -> Result<collections::HashMap<Strin // In some shells, file descriptors greater than 2 cannot be used in interactive mode, // so file descriptor 0 (stdin) is used instead. This impacts zsh, old bash; perhaps others. // See: https://github.com/zed-industries/zed/pull/32136#issuecomment-2999645482 - const ENV_OUTPUT_FD: std::os::fd::RawFd = 0; - let redir = match shell_name { - Some("rc") => format!(">[1={}]", ENV_OUTPUT_FD), // `[1=0]` - _ => format!(">&{}", ENV_OUTPUT_FD), // `>&0` + const FD_STDIN: std::os::fd::RawFd = 0; + const FD_STDOUT: std::os::fd::RawFd = 1; + + let (fd_num, redir) = match shell_name { + Some("rc") => (FD_STDIN, format!(">[1={}]", FD_STDIN)), // `[1=0]` + Some("nu") | Some("tcsh") => (FD_STDOUT, "".to_string()), + _ => (FD_STDIN, format!(">&{}", FD_STDIN)), // `>&0` }; command.stdin(Stdio::null()); command.stdout(Stdio::piped()); command.stderr(Stdio::piped()); + let mut command_prefix = String::new(); match shell_name { Some("tcsh" | "csh") => { // For csh/tcsh, login shell requires passing `-` as 0th argument (instead of `-l`) @@ -37,18 +41,25 @@ pub fn capture(directory: &std::path::Path) -> Result<collections::HashMap<Strin command_string.push_str("emit fish_prompt;"); command.arg("-l"); } + Some("nu") => { + // nu needs special handling for -- options. 
+ command_prefix = String::from("^"); + } _ => { command.arg("-l"); } } // cd into the directory, triggering directory specific side-effects (asdf, direnv, etc) command_string.push_str(&format!("cd '{}';", directory.display())); - command_string.push_str(&format!("{} --printenv {}", zed_path, redir)); + command_string.push_str(&format!( + "{}{} --printenv {}", + command_prefix, zed_path, redir + )); command.args(["-i", "-c", &command_string]); super::set_pre_exec_to_start_new_session(&mut command); - let (env_output, process_output) = spawn_and_read_fd(command, ENV_OUTPUT_FD)?; + let (env_output, process_output) = spawn_and_read_fd(command, fd_num)?; let env_output = String::from_utf8_lossy(&env_output); anyhow::ensure!( diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index 23e04cae2c..7963db3571 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -6,7 +6,7 @@ use editor::{ actions::{SortLinesCaseInsensitive, SortLinesCaseSensitive}, display_map::ToDisplayPoint, }; -use gpui::{Action, App, AppContext as _, Context, Global, Window, actions}; +use gpui::{Action, App, AppContext as _, Context, Global, Keystroke, Window, actions}; use itertools::Itertools; use language::Point; use multi_buffer::MultiBufferRow; @@ -202,6 +202,7 @@ actions!( ArgumentRequired ] ); + /// Opens the specified file for editing. #[derive(Clone, PartialEq, Action)] #[action(namespace = vim, no_json, no_register)] @@ -209,6 +210,13 @@ struct VimEdit { pub filename: String, } +#[derive(Clone, PartialEq, Action)] +#[action(namespace = vim, no_json, no_register)] +struct VimNorm { + pub range: Option<CommandRange>, + pub command: String, +} + #[derive(Debug)] struct WrappedAction(Box<dyn Action>); @@ -447,6 +455,81 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { }); }); + Vim::action(editor, cx, |vim, action: &VimNorm, window, cx| { + let keystrokes = action + .command + .chars() + .map(|c| Keystroke::parse(&c.to_string()).unwrap()) + .collect(); + vim.switch_mode(Mode::Normal, true, window, cx); + let initial_selections = vim.update_editor(window, cx, |_, editor, _, _| { + editor.selections.disjoint_anchors() + }); + if let Some(range) = &action.range { + let result = vim.update_editor(window, cx, |vim, editor, window, cx| { + let range = range.buffer_range(vim, editor, window, cx)?; + editor.change_selections( + SelectionEffects::no_scroll().nav_history(false), + window, + cx, + |s| { + s.select_ranges( + (range.start.0..=range.end.0) + .map(|line| Point::new(line, 0)..Point::new(line, 0)), + ); + }, + ); + anyhow::Ok(()) + }); + if let Some(Err(err)) = result { + log::error!("Error selecting range: {}", err); + return; + } + }; + + let Some(workspace) = vim.workspace(window) else { + return; + }; + let task = workspace.update(cx, |workspace, cx| { + workspace.send_keystrokes_impl(keystrokes, window, cx) + }); + let had_range = action.range.is_some(); + + cx.spawn_in(window, async move |vim, cx| { + task.await; + vim.update_in(cx, |vim, window, cx| { + vim.update_editor(window, cx, |_, editor, window, cx| { + if had_range { + editor.change_selections(SelectionEffects::default(), window, cx, |s| { + s.select_anchor_ranges([s.newest_anchor().range()]); + }) + } + }); + if matches!(vim.mode, Mode::Insert | Mode::Replace) { + vim.normal_before(&Default::default(), window, cx); + } else { + vim.switch_mode(Mode::Normal, true, window, cx); + } + vim.update_editor(window, cx, |_, editor, _, cx| { + if let Some(first_sel) = initial_selections { + if let Some(tx_id) = 
editor + .buffer() + .update(cx, |multi, cx| multi.last_transaction_id(cx)) + { + let last_sel = editor.selections.disjoint_anchors(); + editor.modify_transaction_selection_history(tx_id, |old| { + old.0 = first_sel; + old.1 = Some(last_sel); + }); + } + } + }); + }) + .ok(); + }) + .detach(); + }); + Vim::action(editor, cx, |vim, _: &CountCommand, window, cx| { let Some(workspace) = vim.workspace(window) else { return; @@ -675,14 +758,15 @@ impl VimCommand { } else { return None; }; - if !args.is_empty() { + + let action = if args.is_empty() { + action + } else { // if command does not accept args and we have args then we should do no action - if let Some(args_fn) = &self.args { - args_fn.deref()(action, args) - } else { - None - } - } else if let Some(range) = range { + self.args.as_ref()?(action, args)? + }; + + if let Some(range) = range { self.range.as_ref().and_then(|f| f(action, range)) } else { Some(action) @@ -1061,6 +1145,27 @@ fn generate_commands(_: &App) -> Vec<VimCommand> { save_intent: Some(SaveIntent::Skip), close_pinned: true, }), + VimCommand::new( + ("norm", "al"), + VimNorm { + command: "".into(), + range: None, + }, + ) + .args(|_, args| { + Some( + VimNorm { + command: args, + range: None, + } + .boxed_clone(), + ) + }) + .range(|action, range| { + let mut action: VimNorm = action.as_any().downcast_ref::<VimNorm>().unwrap().clone(); + action.range.replace(range.clone()); + Some(Box::new(action)) + }), VimCommand::new(("bn", "ext"), workspace::ActivateNextItem).count(), VimCommand::new(("bN", "ext"), workspace::ActivatePreviousItem).count(), VimCommand::new(("bp", "revious"), workspace::ActivatePreviousItem).count(), @@ -2298,4 +2403,78 @@ mod test { }); assert!(mark.is_none()) } + + #[gpui::test] + async fn test_normal_command(cx: &mut TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {" + The quick + brown« fox + jumpsˇ» over + the lazy dog + "}) + .await; + + cx.simulate_shared_keystrokes(": n o r m space w C w o r d") + .await; + cx.simulate_shared_keystrokes("enter").await; + + cx.shared_state().await.assert_eq(indoc! {" + The quick + brown word + jumps worˇd + the lazy dog + "}); + + cx.simulate_shared_keystrokes(": n o r m space _ w c i w t e s t") + .await; + cx.simulate_shared_keystrokes("enter").await; + + cx.shared_state().await.assert_eq(indoc! {" + The quick + brown word + jumps tesˇt + the lazy dog + "}); + + cx.simulate_shared_keystrokes("_ l v l : n o r m space s l a") + .await; + cx.simulate_shared_keystrokes("enter").await; + + cx.shared_state().await.assert_eq(indoc! {" + The quick + brown word + lˇaumps test + the lazy dog + "}); + + cx.set_shared_state(indoc! {" + ˇThe quick + brown fox + jumps over + the lazy dog + "}) + .await; + + cx.simulate_shared_keystrokes("c i w M y escape").await; + + cx.shared_state().await.assert_eq(indoc! {" + Mˇy quick + brown fox + jumps over + the lazy dog + "}); + + cx.simulate_shared_keystrokes(": n o r m space u").await; + cx.simulate_shared_keystrokes("enter").await; + + cx.shared_state().await.assert_eq(indoc! 
{" + ˇThe quick + brown fox + jumps over + the lazy dog + "}); + // Once ctrl-v to input character literals is added there should be a test for redo + } } diff --git a/crates/vim/src/helix.rs b/crates/vim/src/helix.rs index ec9b959b12..ca93c9c1de 100644 --- a/crates/vim/src/helix.rs +++ b/crates/vim/src/helix.rs @@ -1,21 +1,31 @@ -use editor::{DisplayPoint, Editor, movement}; +use editor::{DisplayPoint, Editor, SelectionEffects, ToOffset, ToPoint, movement}; use gpui::{Action, actions}; use gpui::{Context, Window}; use language::{CharClassifier, CharKind}; -use text::SelectionGoal; +use text::{Bias, SelectionGoal}; -use crate::{Vim, motion::Motion, state::Mode}; +use crate::{ + Vim, + motion::{Motion, right}, + state::Mode, +}; actions!( vim, [ /// Switches to normal mode after the cursor (Helix-style). - HelixNormalAfter + HelixNormalAfter, + /// Inserts at the beginning of the selection. + HelixInsert, + /// Appends at the end of the selection. + HelixAppend, ] ); pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { Vim::action(editor, cx, Vim::helix_normal_after); + Vim::action(editor, cx, Vim::helix_insert); + Vim::action(editor, cx, Vim::helix_append); } impl Vim { @@ -299,6 +309,112 @@ impl Vim { _ => self.helix_move_and_collapse(motion, times, window, cx), } } + + fn helix_insert(&mut self, _: &HelixInsert, window: &mut Window, cx: &mut Context<Self>) { + self.start_recording(cx); + self.update_editor(window, cx, |_, editor, window, cx| { + editor.change_selections(Default::default(), window, cx, |s| { + s.move_with(|_map, selection| { + // In helix normal mode, move cursor to start of selection and collapse + if !selection.is_empty() { + selection.collapse_to(selection.start, SelectionGoal::None); + } + }); + }); + }); + self.switch_mode(Mode::Insert, false, window, cx); + } + + fn helix_append(&mut self, _: &HelixAppend, window: &mut Window, cx: &mut Context<Self>) { + self.start_recording(cx); + self.switch_mode(Mode::Insert, false, window, cx); + self.update_editor(window, cx, |_, editor, window, cx| { + editor.change_selections(Default::default(), window, cx, |s| { + s.move_with(|map, selection| { + let point = if selection.is_empty() { + right(map, selection.head(), 1) + } else { + selection.end + }; + selection.collapse_to(point, SelectionGoal::None); + }); + }); + }); + } + + pub fn helix_replace(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) { + self.update_editor(window, cx, |_, editor, window, cx| { + editor.transact(window, cx, |editor, window, cx| { + let (map, selections) = editor.selections.all_display(cx); + + // Store selection info for positioning after edit + let selection_info: Vec<_> = selections + .iter() + .map(|selection| { + let range = selection.range(); + let start_offset = range.start.to_offset(&map, Bias::Left); + let end_offset = range.end.to_offset(&map, Bias::Left); + let was_empty = range.is_empty(); + let was_reversed = selection.reversed; + ( + map.buffer_snapshot.anchor_at(start_offset, Bias::Left), + end_offset - start_offset, + was_empty, + was_reversed, + ) + }) + .collect(); + + let mut edits = Vec::new(); + for selection in &selections { + let mut range = selection.range(); + + // For empty selections, extend to replace one character + if range.is_empty() { + range.end = movement::saturating_right(&map, range.start); + } + + let byte_range = range.start.to_offset(&map, Bias::Left) + ..range.end.to_offset(&map, Bias::Left); + + if !byte_range.is_empty() { + let replacement_text = text.repeat(byte_range.len()); + 
edits.push((byte_range, replacement_text)); + } + } + + editor.edit(edits, cx); + + // Restore selections based on original info + let snapshot = editor.buffer().read(cx).snapshot(cx); + let ranges: Vec<_> = selection_info + .into_iter() + .map(|(start_anchor, original_len, was_empty, was_reversed)| { + let start_point = start_anchor.to_point(&snapshot); + if was_empty { + // For cursor-only, collapse to start + start_point..start_point + } else { + // For selections, span the replaced text + let replacement_len = text.len() * original_len; + let end_offset = start_anchor.to_offset(&snapshot) + replacement_len; + let end_point = snapshot.offset_to_point(end_offset); + if was_reversed { + end_point..start_point + } else { + start_point..end_point + } + } + }) + .collect(); + + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges(ranges); + }); + }); + }); + self.switch_mode(Mode::HelixNormal, true, window, cx); + } } #[cfg(test)] @@ -497,4 +613,94 @@ mod test { cx.assert_state("«ˇaa»\n", Mode::HelixNormal); } + + #[gpui::test] + async fn test_insert_selected(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + cx.set_state( + indoc! {" + «The ˇ»quick brown + fox jumps over + the lazy dog."}, + Mode::HelixNormal, + ); + + cx.simulate_keystrokes("i"); + + cx.assert_state( + indoc! {" + ˇThe quick brown + fox jumps over + the lazy dog."}, + Mode::Insert, + ); + } + + #[gpui::test] + async fn test_append(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + // test from the end of the selection + cx.set_state( + indoc! {" + «Theˇ» quick brown + fox jumps over + the lazy dog."}, + Mode::HelixNormal, + ); + + cx.simulate_keystrokes("a"); + + cx.assert_state( + indoc! {" + Theˇ quick brown + fox jumps over + the lazy dog."}, + Mode::Insert, + ); + + // test from the beginning of the selection + cx.set_state( + indoc! {" + «ˇThe» quick brown + fox jumps over + the lazy dog."}, + Mode::HelixNormal, + ); + + cx.simulate_keystrokes("a"); + + cx.assert_state( + indoc! 
{" + Theˇ quick brown + fox jumps over + the lazy dog."}, + Mode::Insert, + ); + } + + #[gpui::test] + async fn test_replace(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + + // No selection (single character) + cx.set_state("ˇaa", Mode::HelixNormal); + + cx.simulate_keystrokes("r x"); + + cx.assert_state("ˇxa", Mode::HelixNormal); + + // Cursor at the beginning + cx.set_state("«ˇaa»", Mode::HelixNormal); + + cx.simulate_keystrokes("r x"); + + cx.assert_state("«ˇxx»", Mode::HelixNormal); + + // Cursor at the end + cx.set_state("«aaˇ»", Mode::HelixNormal); + + cx.simulate_keystrokes("r x"); + + cx.assert_state("«xxˇ»", Mode::HelixNormal); + } } diff --git a/crates/vim/src/insert.rs b/crates/vim/src/insert.rs index 89c60adee7..0a370e16ba 100644 --- a/crates/vim/src/insert.rs +++ b/crates/vim/src/insert.rs @@ -21,7 +21,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) { } impl Vim { - fn normal_before( + pub(crate) fn normal_before( &mut self, action: &NormalBefore, window: &mut Window, diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index a50b238cc5..0e487f4410 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -987,7 +987,7 @@ impl Motion { SelectionGoal::None, ), NextWordEnd { ignore_punctuation } => ( - next_word_end(map, point, *ignore_punctuation, times, true), + next_word_end(map, point, *ignore_punctuation, times, true, true), SelectionGoal::None, ), PreviousWordStart { ignore_punctuation } => ( @@ -1723,14 +1723,19 @@ pub(crate) fn next_word_end( ignore_punctuation: bool, times: usize, allow_cross_newline: bool, + always_advance: bool, ) -> DisplayPoint { let classifier = map .buffer_snapshot .char_classifier_at(point.to_point(map)) .ignore_punctuation(ignore_punctuation); for _ in 0..times { - let new_point = next_char(map, point, allow_cross_newline); let mut need_next_char = false; + let new_point = if always_advance { + next_char(map, point, allow_cross_newline) + } else { + point + }; let new_point = movement::find_boundary_exclusive( map, new_point, @@ -3803,7 +3808,7 @@ mod test { cx.update_editor(|editor, _window, cx| { let range = editor.selections.newest_anchor().range(); let inlay_text = " field: int,\n field2: string\n field3: float"; - let inlay = Inlay::inline_completion(1, range.start, inlay_text); + let inlay = Inlay::edit_prediction(1, range.start, inlay_text); editor.splice_inlays(&[], vec![inlay], cx); }); @@ -3835,7 +3840,7 @@ mod test { let end_of_line = snapshot.anchor_after(Point::new(0, snapshot.line_len(MultiBufferRow(0)))); let inlay_text = " hint"; - let inlay = Inlay::inline_completion(1, end_of_line, inlay_text); + let inlay = Inlay::edit_prediction(1, end_of_line, inlay_text); editor.splice_inlays(&[], vec![inlay], cx); }); cx.simulate_keystrokes("$"); diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index 9485f17477..c1bc7a70ae 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -51,6 +51,7 @@ impl Vim { ignore_punctuation, &text_layout_details, motion == Motion::NextSubwordStart { ignore_punctuation }, + !matches!(motion, Motion::NextWordStart { .. 
}), ) } _ => { @@ -89,7 +90,7 @@ impl Vim { if let Some(kind) = motion_kind { vim.copy_selections_content(editor, kind, window, cx); editor.insert("", window, cx); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); } }); }); @@ -122,7 +123,7 @@ impl Vim { if objects_found { vim.copy_selections_content(editor, MotionKind::Exclusive, window, cx); editor.insert("", window, cx); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); } }); }); @@ -148,6 +149,7 @@ fn expand_changed_word_selection( ignore_punctuation: bool, text_layout_details: &TextLayoutDetails, use_subword: bool, + always_advance: bool, ) -> Option<MotionKind> { let is_in_word = || { let classifier = map @@ -173,8 +175,14 @@ fn expand_changed_word_selection( selection.end = motion::next_subword_end(map, selection.end, ignore_punctuation, 1, false); } else { - selection.end = - motion::next_word_end(map, selection.end, ignore_punctuation, 1, false); + selection.end = motion::next_word_end( + map, + selection.end, + ignore_punctuation, + 1, + false, + always_advance, + ); } selection.end = motion::next_char(map, selection.end, false); } @@ -271,6 +279,10 @@ mod test { cx.simulate("c shift-w", "Test teˇst-test test") .await .assert_matches(); + + // on last character of word, `cw` doesn't eat subsequent punctuation + // see https://github.com/zed-industries/zed/issues/35269 + cx.simulate("c w", "tesˇt-test").await.assert_matches(); } #[gpui::test] diff --git a/crates/vim/src/normal/delete.rs b/crates/vim/src/normal/delete.rs index ccbb3dd0fd..2cf40292cf 100644 --- a/crates/vim/src/normal/delete.rs +++ b/crates/vim/src/normal/delete.rs @@ -82,7 +82,7 @@ impl Vim { selection.collapse_to(cursor, selection.goal) }); }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); }); } @@ -169,7 +169,7 @@ impl Vim { selection.collapse_to(cursor, selection.goal) }); }); - editor.refresh_inline_completion(true, false, window, cx); + editor.refresh_edit_prediction(true, false, window, cx); }); }); } diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 95a08d7c66..72edbe77ed 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -747,7 +747,7 @@ impl Vim { Vim::action( editor, cx, - |vim, action: &editor::AcceptEditPrediction, window, cx| { + |vim, action: &editor::actions::AcceptEditPrediction, window, cx| { vim.update_editor(window, cx, |_, editor, window, cx| { editor.accept_edit_prediction(action, window, cx); }); @@ -1639,6 +1639,7 @@ impl Vim { Mode::Visual | Mode::VisualLine | Mode::VisualBlock => { self.visual_replace(text, window, cx) } + Mode::HelixNormal => self.helix_replace(&text, window, cx), _ => self.clear_operator(window, cx), }, Some(Operator::Digraph { first_char }) => { @@ -1740,11 +1741,11 @@ impl Vim { editor.set_autoindent(vim.should_autoindent()); editor.selections.line_mode = matches!(vim.mode, Mode::VisualLine); - let hide_inline_completions = match vim.mode { + let hide_edit_predictions = match vim.mode { Mode::Insert | Mode::Replace => false, _ => true, }; - editor.set_inline_completions_hidden_for_vim_mode(hide_inline_completions, window, cx); + editor.set_edit_predictions_hidden_for_vim_mode(hide_edit_predictions, window, cx); }); cx.notify() } diff --git a/crates/vim/test_data/test_change_w.json b/crates/vim/test_data/test_change_w.json index 27be543532..149dac8420 100644 --- 
a/crates/vim/test_data/test_change_w.json +++ b/crates/vim/test_data/test_change_w.json @@ -30,3 +30,7 @@ {"Key":"c"} {"Key":"shift-w"} {"Get":{"state":"Test teˇ test","mode":"Insert"}} +{"Put":{"state":"tesˇt-test"}} +{"Key":"c"} +{"Key":"w"} +{"Get":{"state":"tesˇ-test","mode":"Insert"}} diff --git a/crates/vim/test_data/test_normal_command.json b/crates/vim/test_data/test_normal_command.json new file mode 100644 index 0000000000..efd1d532c4 --- /dev/null +++ b/crates/vim/test_data/test_normal_command.json @@ -0,0 +1,64 @@ +{"Put":{"state":"The quick\nbrown« fox\njumpsˇ» over\nthe lazy dog\n"}} +{"Key":":"} +{"Key":"n"} +{"Key":"o"} +{"Key":"r"} +{"Key":"m"} +{"Key":"space"} +{"Key":"w"} +{"Key":"C"} +{"Key":"w"} +{"Key":"o"} +{"Key":"r"} +{"Key":"d"} +{"Key":"enter"} +{"Get":{"state":"The quick\nbrown word\njumps worˇd\nthe lazy dog\n","mode":"Normal"}} +{"Key":":"} +{"Key":"n"} +{"Key":"o"} +{"Key":"r"} +{"Key":"m"} +{"Key":"space"} +{"Key":"_"} +{"Key":"w"} +{"Key":"c"} +{"Key":"i"} +{"Key":"w"} +{"Key":"t"} +{"Key":"e"} +{"Key":"s"} +{"Key":"t"} +{"Key":"enter"} +{"Get":{"state":"The quick\nbrown word\njumps tesˇt\nthe lazy dog\n","mode":"Normal"}} +{"Key":"_"} +{"Key":"l"} +{"Key":"v"} +{"Key":"l"} +{"Key":":"} +{"Key":"n"} +{"Key":"o"} +{"Key":"r"} +{"Key":"m"} +{"Key":"space"} +{"Key":"s"} +{"Key":"l"} +{"Key":"a"} +{"Key":"enter"} +{"Get":{"state":"The quick\nbrown word\nlˇaumps test\nthe lazy dog\n","mode":"Normal"}} +{"Put":{"state":"ˇThe quick\nbrown fox\njumps over\nthe lazy dog\n"}} +{"Key":"c"} +{"Key":"i"} +{"Key":"w"} +{"Key":"M"} +{"Key":"y"} +{"Key":"escape"} +{"Get":{"state":"Mˇy quick\nbrown fox\njumps over\nthe lazy dog\n","mode":"Normal"}} +{"Key":":"} +{"Key":"n"} +{"Key":"o"} +{"Key":"r"} +{"Key":"m"} +{"Key":"space"} +{"Key":"u"} +{"Key":"enter"} +{"Get":{"state":"ˇThe quick\nbrown fox\njumps over\nthe lazy dog\n","mode":"Normal"}} diff --git a/crates/web_search/Cargo.toml b/crates/web_search/Cargo.toml index e5b8ca63b2..4ba46faec4 100644 --- a/crates/web_search/Cargo.toml +++ b/crates/web_search/Cargo.toml @@ -13,8 +13,8 @@ path = "src/web_search.rs" [dependencies] anyhow.workspace = true +cloud_llm_client.workspace = true collections.workspace = true gpui.workspace = true serde.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/web_search/src/web_search.rs b/crates/web_search/src/web_search.rs index a131b0de71..8578cfe4aa 100644 --- a/crates/web_search/src/web_search.rs +++ b/crates/web_search/src/web_search.rs @@ -1,8 +1,9 @@ +use std::sync::Arc; + use anyhow::Result; +use cloud_llm_client::WebSearchResponse; use collections::HashMap; use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task}; -use std::sync::Arc; -use zed_llm_client::WebSearchResponse; pub fn init(cx: &mut App) { let registry = cx.new(|_cx| WebSearchRegistry::default()); diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index 2e052796c4..f7a248d106 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,6 +14,7 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true +cloud_llm_client.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true @@ -22,4 +23,3 @@ serde.workspace = true serde_json.workspace = true web_search.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs 
b/crates/web_search_providers/src/cloud.rs index adf79b0ff6..52ee0da0d4 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; +use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Subscription, Task}; use http_client::{HttpClient, Method}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use web_search::{WebSearchProvider, WebSearchProviderId}; -use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; pub struct CloudWebSearchProvider { state: Entity<State>, diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml index 769dd8d6aa..acb3fe0f84 100644 --- a/crates/welcome/Cargo.toml +++ b/crates/welcome/Cargo.toml @@ -29,7 +29,6 @@ project.workspace = true serde.workspace = true settings.workspace = true telemetry.workspace = true -theme.workspace = true ui.workspace = true util.workspace = true vim_mode_setting.workspace = true diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index 49bf2031ab..b0a1c316f4 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -1,10 +1,11 @@ -use client::{DisableAiSettings, TelemetrySettings, telemetry::Telemetry}; +use client::{TelemetrySettings, telemetry::Telemetry}; use db::kvp::KEY_VALUE_STORE; use gpui::{ Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, ParentElement, Render, Styled, Subscription, Task, WeakEntity, Window, actions, svg, }; use language::language_settings::{EditPredictionProvider, all_language_settings}; +use project::DisableAiSettings; use settings::{Settings, SettingsStore}; use std::sync::Arc; use ui::{CheckboxWithLabel, ElevationIndex, Tooltip, prelude::*}; @@ -21,7 +22,6 @@ pub use multibuffer_hint::*; mod base_keymap_picker; mod multibuffer_hint; -mod welcome_ui; actions!( welcome, diff --git a/crates/welcome/src/welcome_ui.rs b/crates/welcome/src/welcome_ui.rs deleted file mode 100644 index 622b6f448d..0000000000 --- a/crates/welcome/src/welcome_ui.rs +++ /dev/null @@ -1 +0,0 @@ -mod theme_preview; diff --git a/crates/welcome/src/welcome_ui/theme_preview.rs b/crates/welcome/src/welcome_ui/theme_preview.rs deleted file mode 100644 index b3a80c74c3..0000000000 --- a/crates/welcome/src/welcome_ui/theme_preview.rs +++ /dev/null @@ -1,280 +0,0 @@ -#![allow(unused, dead_code)] -use gpui::{Hsla, Length}; -use std::sync::Arc; -use theme::{Theme, ThemeRegistry}; -use ui::{ - IntoElement, RenderOnce, component_prelude::Documented, prelude::*, utils::inner_corner_radius, -}; - -/// Shows a preview of a theme as an abstract illustration -/// of a thumbnail-sized editor. 
-#[derive(IntoElement, RegisterComponent, Documented)] -pub struct ThemePreviewTile { - theme: Arc<Theme>, - selected: bool, - seed: f32, -} - -impl ThemePreviewTile { - pub fn new(theme: Arc<Theme>, selected: bool, seed: f32) -> Self { - Self { - theme, - selected, - seed, - } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } -} - -impl RenderOnce for ThemePreviewTile { - fn render(self, _window: &mut ui::Window, _cx: &mut ui::App) -> impl IntoElement { - let color = self.theme.colors(); - - let root_radius = px(8.0); - let root_border = px(2.0); - let root_padding = px(2.0); - let child_border = px(1.0); - let inner_radius = - inner_corner_radius(root_radius, root_border, root_padding, child_border); - - let item_skeleton = |w: Length, h: Pixels, bg: Hsla| div().w(w).h(h).rounded_full().bg(bg); - - let skeleton_height = px(4.); - - let sidebar_seeded_width = |seed: f32, index: usize| { - let value = (seed * 1000.0 + index as f32 * 10.0).sin() * 0.5 + 0.5; - 0.5 + value * 0.45 - }; - - let sidebar_skeleton_items = 8; - - let sidebar_skeleton = (0..sidebar_skeleton_items) - .map(|i| { - let width = sidebar_seeded_width(self.seed, i); - item_skeleton( - relative(width).into(), - skeleton_height, - color.text.alpha(0.45), - ) - }) - .collect::<Vec<_>>(); - - let sidebar = div() - .h_full() - .w(relative(0.25)) - .border_r(px(1.)) - .border_color(color.border_transparent) - .bg(color.panel_background) - .child( - div() - .p_2() - .flex() - .flex_col() - .size_full() - .gap(px(4.)) - .children(sidebar_skeleton), - ); - - let pseudo_code_skeleton = |theme: Arc<Theme>, seed: f32| -> AnyElement { - let colors = theme.colors(); - let syntax = theme.syntax(); - - let keyword_color = syntax.get("keyword").color; - let function_color = syntax.get("function").color; - let string_color = syntax.get("string").color; - let comment_color = syntax.get("comment").color; - let variable_color = syntax.get("variable").color; - let type_color = syntax.get("type").color; - let punctuation_color = syntax.get("punctuation").color; - - let syntax_colors = [ - keyword_color, - function_color, - string_color, - variable_color, - type_color, - punctuation_color, - comment_color, - ]; - - let line_width = |line_idx: usize, block_idx: usize| -> f32 { - let val = (seed * 100.0 + line_idx as f32 * 20.0 + block_idx as f32 * 5.0).sin() - * 0.5 - + 0.5; - 0.05 + val * 0.2 - }; - - let indentation = |line_idx: usize| -> f32 { - let step = line_idx % 6; - if step < 3 { - step as f32 * 0.1 - } else { - (5 - step) as f32 * 0.1 - } - }; - - let pick_color = |line_idx: usize, block_idx: usize| -> Hsla { - let idx = ((seed * 10.0 + line_idx as f32 * 7.0 + block_idx as f32 * 3.0).sin() - * 3.5) - .abs() as usize - % syntax_colors.len(); - syntax_colors[idx].unwrap_or(colors.text) - }; - - let line_count = 13; - - let lines = (0..line_count) - .map(|line_idx| { - let block_count = (((seed * 30.0 + line_idx as f32 * 12.0).sin() * 0.5 + 0.5) - * 3.0) - .round() as usize - + 2; - - let indent = indentation(line_idx); - - let blocks = (0..block_count) - .map(|block_idx| { - let width = line_width(line_idx, block_idx); - let color = pick_color(line_idx, block_idx); - item_skeleton(relative(width).into(), skeleton_height, color) - }) - .collect::<Vec<_>>(); - - h_flex().gap(px(2.)).ml(relative(indent)).children(blocks) - }) - .collect::<Vec<_>>(); - - v_flex() - .size_full() - .p_1() - .gap(px(6.)) - .children(lines) - .into_any_element() - }; - - let pane = div() - .h_full() - .flex_grow() 
- .flex() - .flex_col() - // .child( - // div() - // .w_full() - // .border_color(color.border) - // .border_b(px(1.)) - // .h(relative(0.1)) - // .bg(color.tab_bar_background), - // ) - .child( - div() - .size_full() - .overflow_hidden() - .bg(color.editor_background) - .p_2() - .child(pseudo_code_skeleton(self.theme.clone(), self.seed)), - ); - - let content = div().size_full().flex().child(sidebar).child(pane); - - div() - .size_full() - .rounded(root_radius) - .p(root_padding) - .border(root_border) - .border_color(color.border_transparent) - .when(self.selected, |this| { - this.border_color(color.border_selected) - }) - .child( - div() - .size_full() - .rounded(inner_radius) - .border(child_border) - .border_color(color.border) - .bg(color.background) - .child(content), - ) - } -} - -impl Component for ThemePreviewTile { - fn description() -> Option<&'static str> { - Some(Self::DOCS) - } - - fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> { - let theme_registry = ThemeRegistry::global(cx); - - let one_dark = theme_registry.get("One Dark"); - let one_light = theme_registry.get("One Light"); - let gruvbox_dark = theme_registry.get("Gruvbox Dark"); - let gruvbox_light = theme_registry.get("Gruvbox Light"); - - let themes_to_preview = vec![ - one_dark.clone().ok(), - one_light.clone().ok(), - gruvbox_dark.clone().ok(), - gruvbox_light.clone().ok(), - ] - .into_iter() - .flatten() - .collect::<Vec<_>>(); - - Some( - v_flex() - .gap_6() - .p_4() - .children({ - if let Some(one_dark) = one_dark.ok() { - vec![example_group(vec![ - single_example( - "Default", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark.clone(), false, 0.42)) - .into_any_element(), - ), - single_example( - "Selected", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark, true, 0.42)) - .into_any_element(), - ), - ])] - } else { - vec![] - } - }) - .child( - example_group(vec![single_example( - "Default Themes", - h_flex() - .gap_4() - .children( - themes_to_preview - .iter() - .enumerate() - .map(|(i, theme)| { - div().w(px(200.)).h(px(140.)).child(ThemePreviewTile::new( - theme.clone(), - false, - 0.42, - )) - }) - .collect::<Vec<_>>(), - ) - .into_any_element(), - )]) - .grow(), - ) - .into_any_element(), - ) - } -} diff --git a/crates/workspace/src/dock.rs b/crates/workspace/src/dock.rs index 7165de23ec..ca63d3e553 100644 --- a/crates/workspace/src/dock.rs +++ b/crates/workspace/src/dock.rs @@ -934,6 +934,10 @@ impl Render for PanelButtons { h_flex() .gap_1() + .when( + has_buttons && dock.position == DockPosition::Bottom, + |this| this.child(Divider::vertical().color(DividerColor::Border)), + ) .children(buttons) .when(has_buttons && dock.position == DockPosition::Left, |this| { this.child(Divider::vertical().color(DividerColor::Border)) diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index c7a2562a1b..fff15d2b52 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -1664,10 +1664,33 @@ impl Pane { } if should_save { - if !Self::save_item(project.clone(), &pane, &*item_to_close, save_intent, cx) - .await? 
+ match Self::save_item(project.clone(), &pane, &*item_to_close, save_intent, cx) + .await { - break; + Ok(success) => { + if !success { + break; + } + } + Err(err) => { + let answer = pane.update_in(cx, |_, window, cx| { + let detail = Self::file_names_for_prompt( + &mut [&item_to_close].into_iter(), + cx, + ); + window.prompt( + PromptLevel::Warning, + &format!("Unable to save file: {}", &err), + Some(&detail), + &["Close Without Saving", "Cancel"], + cx, + ) + })?; + match answer.await { + Ok(0) => {} + Ok(1..) | Err(_) => break, + } + } } } @@ -2832,7 +2855,7 @@ impl Pane { }) .collect::<Vec<_>>(); let tab_count = tab_items.len(); - if self.pinned_tab_count > tab_count { + if self.is_tab_pinned(tab_count) { log::warn!( "Pinned tab count ({}) exceeds actual tab count ({}). \ This should not happen. If possible, add reproduction steps, \ @@ -2922,7 +2945,7 @@ impl Pane { this.handle_external_paths_drop(paths, window, cx) })) .on_click(cx.listener(move |this, event: &ClickEvent, window, cx| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.dispatch_action( this.double_click_dispatch_action.boxed_clone(), cx, @@ -3030,7 +3053,7 @@ impl Pane { || cfg!(not(target_os = "macos")) && window.modifiers().control; let from_pane = dragged_tab.pane.clone(); - let from_ix = dragged_tab.ix; + self.workspace .update(cx, |_, cx| { cx.defer_in(window, move |workspace, window, cx| { @@ -3062,9 +3085,13 @@ impl Pane { } to_pane.update(cx, |this, _| { if to_pane == from_pane { - let moved_right = ix > from_ix; - let ix = if moved_right { ix - 1 } else { ix }; - let is_pinned_in_to_pane = this.is_tab_pinned(ix); + let actual_ix = this + .items + .iter() + .position(|item| item.item_id() == item_id) + .unwrap_or(0); + + let is_pinned_in_to_pane = this.is_tab_pinned(actual_ix); if !was_pinned_in_from_pane && is_pinned_in_to_pane { this.pinned_tab_count += 1; @@ -3613,7 +3640,7 @@ impl Render for Pane { .justify_center() .on_click(cx.listener( move |this, event: &ClickEvent, window, cx| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.dispatch_action( this.double_click_dispatch_action.boxed_clone(), cx, @@ -4950,6 +4977,43 @@ mod tests { assert_item_labels(&pane_a, ["B!", "A*!"], cx); } + #[gpui::test] + async fn test_dragging_pinned_tab_onto_unpinned_tab_reduces_unpinned_tab_count( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + + let project = Project::test(fs, None, cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let pane_a = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone()); + + // Add A, B to pane A and pin A + let item_a = add_labeled_item(&pane_a, "A", false, cx); + add_labeled_item(&pane_a, "B", false, cx); + pane_a.update_in(cx, |pane, window, cx| { + let ix = pane.index_for_item_id(item_a.item_id()).unwrap(); + pane.pin_tab_at(ix, window, cx); + }); + assert_item_labels(&pane_a, ["A!", "B*"], cx); + + // Drag pinned A on top of B in the same pane, which changes tab order to B, A + pane_a.update_in(cx, |pane, window, cx| { + let dragged_tab = DraggedTab { + pane: pane_a.clone(), + item: item_a.boxed_clone(), + ix: 0, + detail: 0, + is_active: true, + }; + pane.handle_tab_drop(&dragged_tab, 1, window, cx); + }); + + // Neither are pinned + assert_item_labels(&pane_a, ["B", "A*"], cx); + } + #[gpui::test] async fn test_drag_pinned_tab_beyond_unpinned_tab_in_same_pane_becomes_unpinned( cx: &mut TestAppContext, diff --git 
a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 3f8b098203..6fa5c969e7 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -939,6 +939,26 @@ impl WorkspaceDb { } } + query! { + pub async fn update_ssh_project_paths_query(ssh_project_id: u64, paths: String) -> Result<Option<SerializedSshProject>> { + UPDATE ssh_projects + SET paths = ?2 + WHERE id = ?1 + RETURNING id, host, port, paths, user + } + } + + pub(crate) async fn update_ssh_project_paths( + &self, + ssh_project_id: SshProjectId, + new_paths: Vec<String>, + ) -> Result<SerializedSshProject> { + let paths = serde_json::to_string(&new_paths)?; + self.update_ssh_project_paths_query(ssh_project_id.0, paths) + .await? + .context("failed to update ssh project paths") + } + query! { pub async fn next_id() -> Result<WorkspaceId> { INSERT INTO workspaces DEFAULT VALUES RETURNING workspace_id @@ -2624,4 +2644,56 @@ mod tests { assert_eq!(workspace.center_group, new_workspace.center_group); } + + #[gpui::test] + async fn test_update_ssh_project_paths() { + zlog::init_test(); + + let db = WorkspaceDb::open_test_db("test_update_ssh_project_paths").await; + + let (host, port, initial_paths, user) = ( + "example.com".to_string(), + Some(22_u16), + vec!["/home/user".to_string(), "/etc/nginx".to_string()], + Some("user".to_string()), + ); + + let project = db + .get_or_create_ssh_project(host.clone(), port, initial_paths.clone(), user.clone()) + .await + .unwrap(); + + assert_eq!(project.host, host); + assert_eq!(project.paths, initial_paths); + assert_eq!(project.user, user); + + let new_paths = vec![ + "/home/user".to_string(), + "/etc/nginx".to_string(), + "/var/log".to_string(), + "/opt/app".to_string(), + ]; + + let updated_project = db + .update_ssh_project_paths(project.id, new_paths.clone()) + .await + .unwrap(); + + assert_eq!(updated_project.id, project.id); + assert_eq!(updated_project.paths, new_paths); + + let retrieved_project = db + .get_ssh_project( + host.clone(), + port, + serde_json::to_string(&new_paths).unwrap(), + user.clone(), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(retrieved_project.id, project.id); + assert_eq!(retrieved_project.paths, new_paths); + } } diff --git a/crates/workspace/src/tasks.rs b/crates/workspace/src/tasks.rs index 26edbd8d03..32d066c7eb 100644 --- a/crates/workspace/src/tasks.rs +++ b/crates/workspace/src/tasks.rs @@ -73,7 +73,7 @@ impl Workspace { if let Some(terminal_provider) = self.terminal_provider.as_ref() { let task_status = terminal_provider.spawn(spawn_in_terminal, window, cx); - cx.background_spawn(async move { + let task = cx.background_spawn(async move { match task_status.await { Some(Ok(status)) => { if status.success() { @@ -82,11 +82,11 @@ impl Workspace { log::debug!("Task spawn failed, code: {:?}", status.code()); } } - Some(Err(e)) => log::error!("Task spawn failed: {e}"), + Some(Err(e)) => log::error!("Task spawn failed: {e:#}"), None => log::debug!("Task spawn got cancelled"), } - }) - .detach(); + }); + self.scheduled_tasks.push(task); } } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 4c70c52d5a..6572574ce1 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -32,7 +32,7 @@ use futures::{ mpsc::{self, UnboundedReceiver, UnboundedSender}, oneshot, }, - future::try_join_all, + future::{Shared, try_join_all}, }; use gpui::{ Action, AnyEntity, AnyView, AnyWeakView, App, AsyncApp, AsyncWindowContext, Bounds, Context, @@ 
-48,7 +48,10 @@ pub use item::{ ProjectItem, SerializableItem, SerializableItemHandle, WeakItemHandle, }; use itertools::Itertools; -use language::{Buffer, LanguageRegistry, Rope}; +use language::{ + Buffer, LanguageRegistry, Rope, + language_settings::{AllLanguageSettings, all_language_settings}, +}; pub use modal_layer::*; use node_runtime::NodeRuntime; use notifications::{ @@ -74,7 +77,7 @@ use remote::{SshClientDelegate, SshConnectionOptions, ssh_session::ConnectionIde use schemars::JsonSchema; use serde::Deserialize; use session::AppSession; -use settings::Settings; +use settings::{Settings, update_settings_file}; use shared_screen::SharedScreen; use sqlez::{ bindable::{Bind, Column, StaticColumnCount}, @@ -87,7 +90,7 @@ use std::{ borrow::Cow, cell::RefCell, cmp, - collections::hash_map::DefaultHasher, + collections::{VecDeque, hash_map::DefaultHasher}, env, hash::{Hash, Hasher}, path::{Path, PathBuf}, @@ -233,6 +236,8 @@ actions!( ToggleBottomDock, /// Toggles centered layout mode. ToggleCenteredLayout, + /// Toggles edit prediction feature globally for all files. + ToggleEditPrediction, /// Toggles the left dock. ToggleLeftDock, /// Toggles the right dock. @@ -1043,6 +1048,13 @@ type PromptForOpenPath = Box< ) -> oneshot::Receiver<Option<Vec<PathBuf>>>, >; +#[derive(Default)] +struct DispatchingKeystrokes { + dispatched: HashSet<Vec<Keystroke>>, + queue: VecDeque<Keystroke>, + task: Option<Shared<Task<()>>>, +} + /// Collects everything project-related for a certain window opened. /// In some way, is a counterpart of a window, as the [`WindowHandle`] could be downcast into `Workspace`. /// @@ -1058,7 +1070,6 @@ pub struct Workspace { center: PaneGroup, left_dock: Entity<Dock>, bottom_dock: Entity<Dock>, - bottom_dock_layout: BottomDockLayout, right_dock: Entity<Dock>, panes: Vec<Entity<Pane>>, panes_by_item: HashMap<EntityId, WeakEntity<Pane>>, @@ -1080,11 +1091,12 @@ pub struct Workspace { leader_updates_tx: mpsc::UnboundedSender<(PeerId, proto::UpdateFollowers)>, database_id: Option<WorkspaceId>, app_state: Arc<AppState>, - dispatching_keystrokes: Rc<RefCell<(HashSet<String>, Vec<Keystroke>)>>, + dispatching_keystrokes: Rc<RefCell<DispatchingKeystrokes>>, _subscriptions: Vec<Subscription>, _apply_leader_updates: Task<Result<()>>, _observe_current_user: Task<Result<()>>, - _schedule_serialize: Option<Task<()>>, + _schedule_serialize_workspace: Option<Task<()>>, + _schedule_serialize_ssh_paths: Option<Task<()>>, pane_history_timestamp: Arc<AtomicUsize>, bounds: Bounds<Pixels>, pub centered_layout: bool, @@ -1097,6 +1109,7 @@ pub struct Workspace { serialized_ssh_project: Option<SerializedSshProject>, _items_serializer: Task<Result<()>>, session_id: Option<String>, + scheduled_tasks: Vec<Task<()>>, } impl EventEmitter<Event> for Workspace {} @@ -1142,6 +1155,8 @@ impl Workspace { project::Event::WorktreeRemoved(_) | project::Event::WorktreeAdded(_) => { this.update_window_title(window, cx); + this.update_ssh_paths(cx); + this.serialize_ssh_paths(window, cx); this.serialize_workspace(window, cx); // This event could be triggered by `AddFolderToProject` or `RemoveFromProject`. 
this.update_history(cx); @@ -1299,7 +1314,6 @@ impl Workspace { ) .detach(); - let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; let left_dock = Dock::new(DockPosition::Left, modal_layer.clone(), window, cx); let bottom_dock = Dock::new(DockPosition::Bottom, modal_layer.clone(), window, cx); let right_dock = Dock::new(DockPosition::Right, modal_layer.clone(), window, cx); @@ -1398,7 +1412,6 @@ impl Workspace { suppressed_notifications: HashSet::default(), left_dock, bottom_dock, - bottom_dock_layout, right_dock, project: project.clone(), follower_states: Default::default(), @@ -1411,7 +1424,8 @@ impl Workspace { app_state, _observe_current_user, _apply_leader_updates, - _schedule_serialize: None, + _schedule_serialize_workspace: None, + _schedule_serialize_ssh_paths: None, leader_updates_tx, _subscriptions: subscriptions, pane_history_timestamp, @@ -1428,6 +1442,7 @@ impl Workspace { _items_serializer, session_id: Some(session_id), serialized_ssh_project: None, + scheduled_tasks: Vec::new(), } } @@ -1624,10 +1639,6 @@ impl Workspace { &self.bottom_dock } - pub fn bottom_dock_layout(&self) -> BottomDockLayout { - self.bottom_dock_layout - } - pub fn set_bottom_dock_layout( &mut self, layout: BottomDockLayout, @@ -1639,7 +1650,6 @@ impl Workspace { content.bottom_dock_layout = Some(layout); }); - self.bottom_dock_layout = layout; cx.notify(); self.serialize_workspace(window, cx); } @@ -1803,10 +1813,7 @@ impl Workspace { .max_by(|b1, b2| b1.worktree_id.cmp(&b2.worktree_id)) }); - match latest_project_path_opened { - Some(latest_project_path_opened) => latest_project_path_opened == history_path, - None => true, - } + latest_project_path_opened.map_or(true, |path| path == history_path) }) } @@ -2311,49 +2318,65 @@ impl Workspace { window: &mut Window, cx: &mut Context<Self>, ) { - let mut state = self.dispatching_keystrokes.borrow_mut(); - if !state.0.insert(action.0.clone()) { - cx.propagate(); - return; - } - let mut keystrokes: Vec<Keystroke> = action + let keystrokes: Vec<Keystroke> = action .0 .split(' ') .flat_map(|k| Keystroke::parse(k).log_err()) .collect(); - keystrokes.reverse(); + let _ = self.send_keystrokes_impl(keystrokes, window, cx); + } - state.1.append(&mut keystrokes); - drop(state); + pub fn send_keystrokes_impl( + &mut self, + keystrokes: Vec<Keystroke>, + window: &mut Window, + cx: &mut Context<Self>, + ) -> Shared<Task<()>> { + let mut state = self.dispatching_keystrokes.borrow_mut(); + if !state.dispatched.insert(keystrokes.clone()) { + cx.propagate(); + return state.task.clone().unwrap(); + } + + state.queue.extend(keystrokes); let keystrokes = self.dispatching_keystrokes.clone(); - window - .spawn(cx, async move |cx| { - // limit to 100 keystrokes to avoid infinite recursion. - for _ in 0..100 { - let Some(keystroke) = keystrokes.borrow_mut().1.pop() else { - keystrokes.borrow_mut().0.clear(); - return Ok(()); - }; - cx.update(|window, cx| { - let focused = window.focused(cx); - window.dispatch_keystroke(keystroke.clone(), cx); - if window.focused(cx) != focused { - // dispatch_keystroke may cause the focus to change. - // draw's side effect is to schedule the FocusChanged events in the current flush effect cycle - // And we need that to happen before the next keystroke to keep vim mode happy... 
- // (Note that the tests always do this implicitly, so you must manually test with something like: - // "bindings": { "g z": ["workspace::SendKeystrokes", ": j <enter> u"]} - // ) - window.draw(cx).clear(); + if state.task.is_none() { + state.task = Some( + window + .spawn(cx, async move |cx| { + // limit to 100 keystrokes to avoid infinite recursion. + for _ in 0..100 { + let mut state = keystrokes.borrow_mut(); + let Some(keystroke) = state.queue.pop_front() else { + state.dispatched.clear(); + state.task.take(); + return; + }; + drop(state); + cx.update(|window, cx| { + let focused = window.focused(cx); + window.dispatch_keystroke(keystroke.clone(), cx); + if window.focused(cx) != focused { + // dispatch_keystroke may cause the focus to change. + // draw's side effect is to schedule the FocusChanged events in the current flush effect cycle + // And we need that to happen before the next keystroke to keep vim mode happy... + // (Note that the tests always do this implicitly, so you must manually test with something like: + // "bindings": { "g z": ["workspace::SendKeystrokes", ": j <enter> u"]} + // ) + window.draw(cx).clear(); + } + }) + .ok(); } - })?; - } - *keystrokes.borrow_mut() = Default::default(); - anyhow::bail!("over 100 keystrokes passed to send_keystrokes"); - }) - .detach_and_log_err(cx); + *keystrokes.borrow_mut() = Default::default(); + log::error!("over 100 keystrokes passed to send_keystrokes"); + }) + .shared(), + ); + } + state.task.clone().unwrap() } fn save_all_internal( @@ -4770,7 +4793,7 @@ impl Workspace { .remote_id(&self.app_state.client, window, cx) .map(|id| id.to_proto()); - if let Some(id) = id.clone() { + if let Some(id) = id { if let Some(variant) = item.to_state_proto(window, cx) { let view = Some(proto::View { id: id.clone(), @@ -4783,7 +4806,7 @@ impl Workspace { update = proto::UpdateActiveView { view, // TODO: Remove after version 0.145.x stabilizes. - id: id.clone(), + id, leader_id: leader_peer_id, }; } @@ -5056,6 +5079,46 @@ impl Workspace { } } + fn update_ssh_paths(&mut self, cx: &App) { + let project = self.project().read(cx); + if !project.is_local() { + let paths: Vec<String> = project + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).abs_path().to_string_lossy().to_string()) + .collect(); + if let Some(ssh_project) = &mut self.serialized_ssh_project { + ssh_project.paths = paths; + } + } + } + + fn serialize_ssh_paths(&mut self, window: &mut Window, cx: &mut Context<Workspace>) { + if self._schedule_serialize_ssh_paths.is_none() { + self._schedule_serialize_ssh_paths = + Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update_in(cx, |this, window, cx| { + let task = if let Some(ssh_project) = &this.serialized_ssh_project { + let ssh_project_id = ssh_project.id; + let ssh_project_paths = ssh_project.paths.clone(); + window.spawn(cx, async move |_| { + persistence::DB + .update_ssh_project_paths(ssh_project_id, ssh_project_paths) + .await + }) + } else { + Task::ready(Err(anyhow::anyhow!("No SSH project to serialize"))) + }; + task.detach(); + this._schedule_serialize_ssh_paths.take(); + }) + .log_err(); + })); + } + } + fn remove_panes(&mut self, member: Member, window: &mut Window, cx: &mut Context<Workspace>) { match member { Member::Axis(PaneAxis { members, .. 
}) => { @@ -5099,17 +5162,18 @@ impl Workspace { } fn serialize_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) { - if self._schedule_serialize.is_none() { - self._schedule_serialize = Some(cx.spawn_in(window, async move |this, cx| { - cx.background_executor() - .timer(Duration::from_millis(100)) - .await; - this.update_in(cx, |this, window, cx| { - this.serialize_workspace_internal(window, cx).detach(); - this._schedule_serialize.take(); - }) - .log_err(); - })); + if self._schedule_serialize_workspace.is_none() { + self._schedule_serialize_workspace = + Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update_in(cx, |this, window, cx| { + this.serialize_workspace_internal(window, cx).detach(); + this._schedule_serialize_workspace.take(); + }) + .log_err(); + })); } } @@ -5484,6 +5548,7 @@ impl Workspace { .on_action(cx.listener(Self::activate_pane_at_index)) .on_action(cx.listener(Self::move_item_to_pane_at_index)) .on_action(cx.listener(Self::move_focused_panel_to_next_position)) + .on_action(cx.listener(Self::toggle_edit_predictions_all_files)) .on_action(cx.listener(|workspace, _: &Unfollow, window, cx| { let pane = workspace.active_pane().clone(); workspace.unfollow_in_pane(&pane, window, cx); @@ -5672,7 +5737,6 @@ impl Workspace { let client = project.read(cx).client(); let user_store = project.read(cx).user_store(); - let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx)); window.activate_window(); @@ -5916,6 +5980,19 @@ impl Workspace { } }); } + + fn toggle_edit_predictions_all_files( + &mut self, + _: &ToggleEditPrediction, + _window: &mut Window, + cx: &mut Context<Self>, + ) { + let fs = self.project().read(cx).fs().clone(); + let show_edit_predictions = all_language_settings(None, cx).show_edit_predictions(None, cx); + update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| { + file.defaults.show_edit_predictions = Some(!show_edit_predictions) + }); + } } fn leader_border_for_pane( @@ -6221,6 +6298,7 @@ impl Render for Workspace { .iter() .map(|(_, notification)| notification.entity_id()) .collect::<Vec<_>>(); + let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; client_side_decorations( self.actions(div(), window, cx) @@ -6344,7 +6422,7 @@ impl Render for Workspace { )) }) .child({ - match self.bottom_dock_layout { + match bottom_dock_layout { BottomDockLayout::Full => div() .flex() .flex_col() @@ -6876,10 +6954,13 @@ async fn join_channel_internal( match status { Status::Connecting | Status::Authenticating + | Status::Authenticated | Status::Reconnecting | Status::Reauthenticating => continue, Status::Connected { .. } => break 'outer, - Status::SignedOut => return Err(ErrorCode::SignedOut.into()), + Status::SignedOut | Status::AuthenticationError => { + return Err(ErrorCode::SignedOut.into()); + } Status::UpgradeRequired => return Err(ErrorCode::UpgradeRequired.into()), Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. 
} => { return Err(ErrorCode::Disconnected.into()); diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index 4fc6b91abb..b5a0f71e81 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -62,7 +62,7 @@ use std::{ }, time::{Duration, Instant}, }; -use sum_tree::{Bias, Edit, KeyedItem, SeekTarget, SumTree, Summary, TreeMap, TreeSet, Unit}; +use sum_tree::{Bias, Dimensions, Edit, KeyedItem, SeekTarget, SumTree, Summary, TreeMap, TreeSet}; use text::{LineEnding, Rope}; use util::{ ResultExt, @@ -407,12 +407,12 @@ struct LocalRepositoryEntry { } impl sum_tree::Item for LocalRepositoryEntry { - type Summary = PathSummary<Unit>; + type Summary = PathSummary<&'static ()>; fn summary(&self, _: &<Self::Summary as Summary>::Context) -> Self::Summary { PathSummary { max_path: self.work_directory.path_key().0, - item_summary: Unit, + item_summary: &(), } } } @@ -425,12 +425,6 @@ impl KeyedItem for LocalRepositoryEntry { } } -//impl LocalRepositoryEntry { -// pub fn repo(&self) -> &Arc<dyn GitRepository> { -// &self.repo_ptr -// } -//} - impl Deref for LocalRepositoryEntry { type Target = WorkDirectory; @@ -3572,10 +3566,15 @@ impl<'a> sum_tree::Dimension<'a, PathSummary<GitSummary>> for GitSummary { } } -impl<'a> sum_tree::SeekTarget<'a, PathSummary<GitSummary>, (TraversalProgress<'a>, GitSummary)> +impl<'a> + sum_tree::SeekTarget<'a, PathSummary<GitSummary>, Dimensions<TraversalProgress<'a>, GitSummary>> for PathTarget<'_> { - fn cmp(&self, cursor_location: &(TraversalProgress<'a>, GitSummary), _: &()) -> Ordering { + fn cmp( + &self, + cursor_location: &Dimensions<TraversalProgress<'a>, GitSummary>, + _: &(), + ) -> Ordering { self.cmp_path(&cursor_location.0.max_path) } } @@ -5417,7 +5416,7 @@ impl<'a> SeekTarget<'a, EntrySummary, TraversalProgress<'a>> for TraversalTarget } } -impl<'a> SeekTarget<'a, PathSummary<Unit>, TraversalProgress<'a>> for TraversalTarget<'_> { +impl<'a> SeekTarget<'a, PathSummary<&'static ()>, TraversalProgress<'a>> for TraversalTarget<'_> { fn cmp(&self, cursor_location: &TraversalProgress<'a>, _: &()) -> Ordering { self.cmp_progress(cursor_location) } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index a864ece683..5997e43864 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." 
edition.workspace = true name = "zed" -version = "0.198.0" +version = "0.200.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team <hi@zed.dev>"] @@ -45,6 +45,7 @@ collections.workspace = true command_palette.workspace = true component.workspace = true copilot.workspace = true +crashes.workspace = true dap_adapters.workspace = true db.workspace = true debug_adapter_extension.workspace = true @@ -76,7 +77,7 @@ gpui_tokio.workspace = true http_client.workspace = true image_viewer.workspace = true indoc.workspace = true -inline_completion_button.workspace = true +edit_prediction_button.workspace = true inspector_ui.workspace = true install_cli.workspace = true jj_ui.workspace = true @@ -106,6 +107,7 @@ outline_panel.workspace = true parking_lot.workspace = true paths.workspace = true picker.workspace = true +settings_profile_selector.workspace = true profiling.workspace = true project.workspace = true project_panel.workspace = true @@ -116,6 +118,7 @@ recent_projects.workspace = true release_channel.workspace = true remote.workspace = true repl.workspace = true +reqwest.workspace = true reqwest_client.workspace = true rope.workspace = true search.workspace = true diff --git a/crates/zed/resources/app-icon-nightly.png b/crates/zed/resources/app-icon-nightly.png index 5f1304a6af..776cd06b1b 100644 Binary files a/crates/zed/resources/app-icon-nightly.png and b/crates/zed/resources/app-icon-nightly.png differ diff --git a/crates/zed/resources/app-icon-nightly@2x.png b/crates/zed/resources/app-icon-nightly@2x.png index edb416ede4..6d781594ac 100644 Binary files a/crates/zed/resources/app-icon-nightly@2x.png and b/crates/zed/resources/app-icon-nightly@2x.png differ diff --git a/crates/zed/resources/flatpak/manifest-template.json b/crates/zed/resources/flatpak/manifest-template.json index 1560027e9f..0a14a1c2b0 100644 --- a/crates/zed/resources/flatpak/manifest-template.json +++ b/crates/zed/resources/flatpak/manifest-template.json @@ -38,7 +38,7 @@ }, "build-commands": [ "install -Dm644 $ICON_FILE.png /app/share/icons/hicolor/512x512/apps/$APP_ID.png", - "envsubst < zed.desktop.in > zed.desktop && install -Dm644 zed.desktop /app/share/applications/$APP_ID.desktop", + "envsubst < zed.desktop.in > zed.desktop && install -Dm755 zed.desktop /app/share/applications/$APP_ID.desktop", "envsubst < flatpak/zed.metainfo.xml.in > zed.metainfo.xml && install -Dm644 zed.metainfo.xml /app/share/metainfo/$APP_ID.metainfo.xml", "sed -i -e '/@release_info@/{r flatpak/release-info/$CHANNEL' -e 'd}' /app/share/metainfo/$APP_ID.metainfo.xml", "install -Dm755 bin/zed /app/bin/zed", diff --git a/crates/zed/resources/windows/zed.iss b/crates/zed/resources/windows/zed.iss index 9d104d1f15..2e76f35a0b 100644 --- a/crates/zed/resources/windows/zed.iss +++ b/crates/zed/resources/windows/zed.iss @@ -62,6 +62,7 @@ Source: "{#ResourcesDir}\Zed.exe"; DestDir: "{code:GetInstallDir}"; Flags: ignor Source: "{#ResourcesDir}\bin\*"; DestDir: "{code:GetInstallDir}\bin"; Flags: ignoreversion Source: "{#ResourcesDir}\tools\*"; DestDir: "{app}\tools"; Flags: ignoreversion Source: "{#ResourcesDir}\appx\*"; DestDir: "{app}\appx"; BeforeInstall: RemoveAppxPackage; AfterInstall: AddAppxPackage; Flags: ignoreversion; Check: IsWindows11OrLater +Source: "{#ResourcesDir}\amd_ags_x64.dll"; DestDir: "{app}"; Flags: ignoreversion [Icons] Name: "{group}\{#AppName}"; Filename: "{app}\{#AppExeName}.exe"; AppUserModelID: "{#AppUserId}" @@ -1245,16 +1246,6 @@ Root: HKCU; Subkey: "Software\Classes\zed\DefaultIcon"; ValueType: 
"string"; Val Root: HKCU; Subkey: "Software\Classes\zed\shell\open\command"; ValueType: "string"; ValueData: """{app}\Zed.exe"" ""%1""" [Code] -function InitializeSetup(): Boolean; -begin - Result := True; - - if not WizardSilent() and IsAdmin() then begin - MsgBox('This User Installer is not meant to be run as an Administrator.', mbError, MB_OK); - Result := False; - end; -end; - function WizardNotSilent(): Boolean; begin Result := not WizardSilent(); diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index d0b9c53397..e4a14b5d32 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -42,7 +42,7 @@ use theme::{ ActiveTheme, IconThemeNotFoundError, SystemAppearance, ThemeNotFoundError, ThemeRegistry, ThemeSettings, }; -use util::{ConnectionResult, ResultExt, TryFutureExt, maybe}; +use util::{ResultExt, TryFutureExt, maybe}; use uuid::Uuid; use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::{ @@ -51,9 +51,9 @@ use workspace::{ }; use zed::{ OpenListener, OpenRequest, RawOpenRequest, app_menus, build_window_options, - derive_paths_with_position, handle_cli_connection, handle_keymap_file_changes, - handle_settings_changed, handle_settings_file_changes, initialize_workspace, - inline_completion_registry, open_paths_with_positions, + derive_paths_with_position, edit_prediction_registry, handle_cli_connection, + handle_keymap_file_changes, handle_settings_changed, handle_settings_file_changes, + initialize_workspace, open_paths_with_positions, }; use crate::zed::OpenRequestKind; @@ -172,6 +172,12 @@ pub fn main() { let args = Args::parse(); + // `zed --crash-handler` Makes zed operate in minidump crash handler mode + if let Some(socket) = &args.crash_handler { + crashes::crash_server(socket.as_path()); + return; + } + // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass if let Some(socket) = &args.askpass { askpass::main(socket); @@ -264,6 +270,9 @@ pub fn main() { let session_id = Uuid::new_v4().to_string(); let session = app.background_executor().block(Session::new()); + app.background_executor() + .spawn(crashes::init(session_id.clone())) + .detach(); reliability::init_panic_hook( app_version, app_commit_sha.clone(), @@ -559,11 +568,7 @@ pub fn main() { web_search::init(cx); web_search_providers::init(app_state.client.clone(), cx); snippet_provider::init(cx); - inline_completion_registry::init( - app_state.client.clone(), - app_state.user_store.clone(), - cx, - ); + edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); agent_ui::init( app_state.fs.clone(), @@ -613,6 +618,7 @@ pub fn main() { language_selector::init(cx); toolchain_selector::init(cx); theme_selector::init(cx); + settings_profile_selector::init(cx); language_tools::init(cx); call::init(app_state.client.clone(), app_state.user_store.clone(), cx); notifications::init(app_state.client.clone(), app_state.user_store.clone(), cx); @@ -681,17 +687,9 @@ pub fn main() { cx.spawn({ let client = app_state.client.clone(); - async move |cx| match authenticate(client, &cx).await { - ConnectionResult::Timeout => log::error!("Timeout during initial auth"), - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during initial auth") - } - ConnectionResult::Result(r) => { - r.log_err(); - } - } + async move |cx| authenticate(client, &cx).await }) - .detach(); + .detach_and_log_err(cx); let urls: Vec<_> = args .paths_or_urls @@ -841,15 +839,7 @@ fn 
handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut let client = app_state.client.clone(); // we continue even if authentication fails as join_channel/ open channel notes will // show a visible error message. - match authenticate(client, &cx).await { - ConnectionResult::Timeout => { - log::error!("Timeout during open request handling") - } - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during open request handling") - } - ConnectionResult::Result(r) => r?, - }; + authenticate(client, &cx).await.log_err(); if let Some(channel_id) = request.join_channel { cx.update(|cx| { @@ -899,18 +889,18 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut } } -async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> ConnectionResult<()> { +async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> Result<()> { if stdout_is_a_pty() { if client::IMPERSONATE_LOGIN.is_some() { - return client.authenticate_and_connect(false, cx).await; + client.sign_in_with_optional_connect(false, cx).await?; } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } - ConnectionResult::Result(Ok(())) + Ok(()) } async fn system_id() -> Result<IdType> { @@ -1144,6 +1134,7 @@ fn init_paths() -> HashMap<io::ErrorKind, Vec<&'static Path>> { paths::config_dir(), paths::extensions_dir(), paths::languages_dir(), + paths::debug_adapters_dir(), paths::database_dir(), paths::logs_dir(), paths::temp_dir(), @@ -1203,6 +1194,11 @@ struct Args { #[arg(long, hide = true)] nc: Option<String>, + /// Used for recording minidumps on crashes by having Zed run a separate + /// process communicating over a socket. + #[arg(long, hide = true)] + crash_handler: Option<PathBuf>, + /// Run zed in the foreground, only used on Windows, to match the behavior on macOS. 
#[arg(long)] #[cfg(target_os = "windows")] diff --git a/crates/zed/src/reliability.rs b/crates/zed/src/reliability.rs index ccbe57e7b3..53539699cc 100644 --- a/crates/zed/src/reliability.rs +++ b/crates/zed/src/reliability.rs @@ -2,21 +2,32 @@ use crate::stdout_is_a_pty; use anyhow::{Context as _, Result}; use backtrace::{self, Backtrace}; use chrono::Utc; -use client::{TelemetrySettings, telemetry}; +use client::{ + TelemetrySettings, + telemetry::{self, MINIDUMP_ENDPOINT}, +}; use db::kvp::KEY_VALUE_STORE; +use futures::AsyncReadExt; use gpui::{App, AppContext as _, SemanticVersion}; use http_client::{self, HttpClient, HttpClientWithUrl, HttpRequestExt, Method}; use paths::{crashes_dir, crashes_retired_dir}; use project::Project; use release_channel::{AppCommitSha, RELEASE_CHANNEL, ReleaseChannel}; +use reqwest::multipart::{Form, Part}; use settings::Settings; use smol::stream::StreamExt; use std::{ env, ffi::{OsStr, c_void}, - sync::{Arc, atomic::Ordering}, + fs, + io::Write, + panic, + sync::{ + Arc, + atomic::{AtomicU32, Ordering}, + }, + thread, }; -use std::{io::Write, panic, sync::atomic::AtomicU32, thread}; use telemetry_events::{LocationData, Panic, PanicRequest}; use url::Url; use util::ResultExt; @@ -37,9 +48,10 @@ pub fn init_panic_hook( if prior_panic_count > 0 { // Give the panic-ing thread time to write the panic file loop { - std::thread::yield_now(); + thread::yield_now(); } } + crashes::handle_panic(); let thread = thread::current(); let thread_name = thread.name().unwrap_or("<unnamed>"); @@ -63,7 +75,7 @@ pub fn init_panic_hook( location.column(), match app_commit_sha.as_ref() { Some(commit_sha) => format!( - "https://github.com/zed-industries/zed/blob/{}/src/{}#L{} \ + "https://github.com/zed-industries/zed/blob/{}/{}#L{} \ (may not be uploaded, line may be incorrect if files modified)\n", commit_sha.full(), location.file(), @@ -136,9 +148,9 @@ pub fn init_panic_hook( if let Some(panic_data_json) = serde_json::to_string(&panic_data).log_err() { let timestamp = chrono::Utc::now().format("%Y_%m_%d %H_%M_%S").to_string(); let panic_file_path = paths::logs_dir().join(format!("zed-{timestamp}.panic")); - let panic_file = std::fs::OpenOptions::new() - .append(true) - .create(true) + let panic_file = fs::OpenOptions::new() + .write(true) + .create_new(true) .open(&panic_file_path) .log_err(); if let Some(mut panic_file) = panic_file { @@ -205,27 +217,31 @@ pub fn init( if let Some(ssh_client) = project.ssh_client() { ssh_client.update(cx, |client, cx| { if TelemetrySettings::get_global(cx).diagnostics { - let request = client.proto_client().request(proto::GetPanicFiles {}); + let request = client.proto_client().request(proto::GetCrashFiles {}); cx.background_spawn(async move { - let panic_files = request.await?; - for file in panic_files.file_contents { - let panic: Option<Panic> = serde_json::from_str(&file) - .log_err() - .or_else(|| { - file.lines() - .next() - .and_then(|line| serde_json::from_str(line).ok()) - }) - .unwrap_or_else(|| { - log::error!("failed to deserialize panic file {:?}", file); - None - }); + let crash_files = request.await?; + for crash in crash_files.crashes { + let mut panic: Option<Panic> = crash + .panic_contents + .and_then(|s| serde_json::from_str(&s).log_err()); - if let Some(mut panic) = panic { + if let Some(panic) = panic.as_mut() { panic.session_id = session_id.clone(); panic.system_id = system_id.clone(); panic.installation_id = installation_id.clone(); + } + if let Some(minidump) = crash.minidump_contents { + upload_minidump( + 
http_client.clone(), + minidump.clone(), + panic.as_ref(), + ) + .await + .log_err(); + } + + if let Some(panic) = panic { upload_panic(&http_client, &panic_report_url, panic, &mut None) .await?; } @@ -510,6 +526,22 @@ async fn upload_previous_panics( }); if let Some(panic) = panic { + let minidump_path = paths::logs_dir() + .join(&panic.session_id) + .with_extension("dmp"); + if minidump_path.exists() { + let minidump = smol::fs::read(&minidump_path) + .await + .context("Failed to read minidump")?; + if upload_minidump(http.clone(), minidump, Some(&panic)) + .await + .log_err() + .is_some() + { + fs::remove_file(minidump_path).ok(); + } + } + if !upload_panic(&http, &panic_report_url, panic, &mut most_recent_panic).await? { continue; } @@ -517,13 +549,80 @@ async fn upload_previous_panics( } // We've done what we can, delete the file - std::fs::remove_file(child_path) + fs::remove_file(child_path) .context("error removing panic") .log_err(); } + + if MINIDUMP_ENDPOINT.is_none() { + return Ok(most_recent_panic); + } + + // loop back over the directory again to upload any minidumps that are missing panics + let mut children = smol::fs::read_dir(paths::logs_dir()).await?; + while let Some(child) = children.next().await { + let child = child?; + let child_path = child.path(); + if child_path.extension() != Some(OsStr::new("dmp")) { + continue; + } + if upload_minidump( + http.clone(), + smol::fs::read(&child_path) + .await + .context("Failed to read minidump")?, + None, + ) + .await + .log_err() + .is_some() + { + fs::remove_file(child_path).ok(); + } + } + Ok(most_recent_panic) } +async fn upload_minidump( + http: Arc<HttpClientWithUrl>, + minidump: Vec<u8>, + panic: Option<&Panic>, +) -> Result<()> { + let minidump_endpoint = MINIDUMP_ENDPOINT + .to_owned() + .ok_or_else(|| anyhow::anyhow!("Minidump endpoint not set"))?; + + let mut form = Form::new() + .part( + "upload_file_minidump", + Part::bytes(minidump) + .file_name("minidump.dmp") + .mime_str("application/octet-stream")?, + ) + .text("platform", "rust"); + if let Some(panic) = panic { + form = form + .text( + "sentry[release]", + format!("{}-{}", panic.release_channel, panic.app_version), + ) + .text("sentry[logentry][formatted]", panic.payload.clone()); + } + + let mut response_text = String::new(); + let mut response = http.send_multipart_form(&minidump_endpoint, form).await?; + response + .body_mut() + .read_to_string(&mut response_text) + .await?; + if !response.status().is_success() { + anyhow::bail!("failed to upload minidump: {response_text}"); + } + log::info!("Uploaded minidump. event id: {response_text}"); + Ok(()) +} + async fn upload_panic( http: &Arc<HttpClientWithUrl>, panic_report_url: &Url, diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 57534c8cd5..8c89a7d85a 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1,6 +1,6 @@ mod app_menus; pub mod component_preview; -pub mod inline_completion_registry; +pub mod edit_prediction_registry; #[cfg(target_os = "macos")] pub(crate) mod mac_only_instance; mod migrate; @@ -111,6 +111,8 @@ actions!( Zoom, /// Triggers a test panic for debugging. TestPanic, + /// Triggers a hard crash for debugging. 
+ TestCrash, ] ); @@ -124,9 +126,28 @@ pub fn init(cx: &mut App) { cx.on_action(quit); cx.on_action(|_: &RestoreBanner, cx| title_bar::restore_banner(cx)); - if ReleaseChannel::global(cx) == ReleaseChannel::Dev || cx.has_flag::<PanicFeatureFlag>() { - cx.on_action(|_: &TestPanic, _| panic!("Ran the TestPanic action")); - } + let flag = cx.wait_for_flag::<PanicFeatureFlag>(); + cx.spawn(async |cx| { + if cx + .update(|cx| ReleaseChannel::global(cx) == ReleaseChannel::Dev) + .unwrap_or_default() + || flag.await + { + cx.update(|cx| { + cx.on_action(|_: &TestPanic, _| panic!("Ran the TestPanic action")); + cx.on_action(|_: &TestCrash, _| { + unsafe extern "C" { + fn puts(s: *const i8); + } + unsafe { + puts(0xabad1d3a as *const i8); + } + }); + }) + .ok(); + }; + }) + .detach(); cx.on_action(|_: &OpenLog, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { open_log_file(workspace, window, cx); @@ -311,18 +332,18 @@ pub fn initialize_workspace( show_software_emulation_warning_if_needed(specs, window, cx); } - let inline_completion_menu_handle = PopoverMenuHandle::default(); + let edit_prediction_menu_handle = PopoverMenuHandle::default(); let edit_prediction_button = cx.new(|cx| { - inline_completion_button::InlineCompletionButton::new( + edit_prediction_button::EditPredictionButton::new( app_state.fs.clone(), app_state.user_store.clone(), - inline_completion_menu_handle.clone(), + edit_prediction_menu_handle.clone(), cx, ) }); workspace.register_action({ - move |_, _: &inline_completion_button::ToggleMenu, window, cx| { - inline_completion_menu_handle.toggle(window, cx); + move |_, _: &edit_prediction_button::ToggleMenu, window, cx| { + edit_prediction_menu_handle.toggle(window, cx); } }); @@ -4333,6 +4354,7 @@ mod tests { "menu", "notebook", "notification_panel", + "onboarding", "outline", "outline_panel", "pane", @@ -4345,6 +4367,7 @@ mod tests { "repl", "rules_library", "search", + "settings_profile_selector", "snippets", "supermaven", "svg", @@ -4416,7 +4439,7 @@ mod tests { }); for name in languages.language_names() { languages - .language_for_name(&name) + .language_for_name(name.as_ref()) .await .with_context(|| format!("language name {name}")) .unwrap(); diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index 78532b10b4..15d5659f03 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -24,6 +24,10 @@ pub fn app_menus() -> Vec<Menu> { zed_actions::OpenDefaultKeymap, ), MenuItem::action("Open Project Settings", super::OpenProjectSettings), + MenuItem::action( + "Select Settings Profile...", + zed_actions::settings_profile_selector::Toggle, + ), MenuItem::action( "Select Theme...", zed_actions::theme_selector::Toggle::default(), diff --git a/crates/zed/src/zed/component_preview.rs b/crates/zed/src/zed/component_preview.rs index 670793cff3..db75b544f6 100644 --- a/crates/zed/src/zed/component_preview.rs +++ b/crates/zed/src/zed/component_preview.rs @@ -105,7 +105,9 @@ enum PreviewPage { struct ComponentPreview { active_page: PreviewPage, active_thread: Option<Entity<ActiveThread>>, + reset_key: usize, component_list: ListState, + entries: Vec<PreviewEntry>, component_map: HashMap<ComponentId, ComponentMetadata>, components: Vec<ComponentMetadata>, cursor_index: usize, @@ -138,8 +140,7 @@ impl ComponentPreview { let project_clone = project.clone(); cx.spawn_in(window, async move |entity, cx| { - let thread_store_future = - load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx); + let 
thread_store_future = load_preview_thread_store(project_clone.clone(), cx); let text_thread_store_future = load_preview_text_thread_store(workspace_clone.clone(), project_clone.clone(), cx); @@ -172,23 +173,14 @@ impl ComponentPreview { sorted_components.len(), gpui::ListAlignment::Top, px(1500.0), - { - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - this.update(cx, |this, cx| { - let component = this.get_component(ix); - this.render_preview(&component, window, cx) - .into_any_element() - }) - .unwrap() - } - }, ); let mut component_preview = Self { active_page, active_thread: None, + reset_key: 0, component_list, + entries: Vec::new(), component_map: component_registry.component_map(), components: sorted_components, cursor_index: selected_index, @@ -265,15 +257,16 @@ impl ComponentPreview { } fn set_active_page(&mut self, page: PreviewPage, cx: &mut Context<Self>) { - self.active_page = page; - cx.emit(ItemEvent::UpdateTab); + if self.active_page == page { + // Force the current preview page to render again + self.reset_key = self.reset_key.wrapping_add(1); + } else { + self.active_page = page; + cx.emit(ItemEvent::UpdateTab); + } cx.notify(); } - fn get_component(&self, ix: usize) -> ComponentMetadata { - self.components[ix].clone() - } - fn filtered_components(&self) -> Vec<ComponentMetadata> { if self.filter_text.is_empty() { return self.components.clone(); @@ -414,7 +407,6 @@ impl ComponentPreview { fn update_component_list(&mut self, cx: &mut Context<Self>) { let entries = self.scope_ordered_entries(); let new_len = entries.len(); - let weak_entity = cx.entity().downgrade(); if new_len > 0 { self.nav_scroll_handle @@ -440,56 +432,9 @@ impl ComponentPreview { } } - self.component_list = ListState::new( - filtered_components.len(), - gpui::ListAlignment::Top, - px(1500.0), - { - let components = filtered_components.clone(); - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - if ix >= components.len() { - return div().w_full().h_0().into_any_element(); - } + self.component_list = ListState::new(new_len, gpui::ListAlignment::Top, px(1500.0)); + self.entries = entries; - this.update(cx, |this, cx| { - let component = &components[ix]; - this.render_preview(component, window, cx) - .into_any_element() - }) - .unwrap() - } - }, - ); - - let new_list = ListState::new( - new_len, - gpui::ListAlignment::Top, - px(1500.0), - move |ix, window, cx| { - if ix >= entries.len() { - return div().w_full().h_0().into_any_element(); - } - - let entry = &entries[ix]; - - weak_entity - .update(cx, |this, cx| match entry { - PreviewEntry::Component(component, _) => this - .render_preview(component, window, cx) - .into_any_element(), - PreviewEntry::SectionHeader(shared_string) => this - .render_scope_header(ix, shared_string.clone(), window, cx) - .into_any_element(), - PreviewEntry::AllComponents => div().w_full().h_0().into_any_element(), - PreviewEntry::ActiveThread => div().w_full().h_0().into_any_element(), - PreviewEntry::Separator => div().w_full().h_0().into_any_element(), - }) - .unwrap() - }, - ); - - self.component_list = new_list; cx.emit(ItemEvent::UpdateTab); } @@ -666,10 +611,35 @@ impl ComponentPreview { .child(format!("No components matching '{}'.", self.filter_text)) .into_any_element() } else { - list(self.component_list.clone()) - .flex_grow() - .with_sizing_behavior(gpui::ListSizingBehavior::Auto) - .into_any_element() + list( + self.component_list.clone(), + cx.processor(|this, ix, window, cx| { + if ix >= 
this.entries.len() { + return div().w_full().h_0().into_any_element(); + } + + let entry = &this.entries[ix]; + + match entry { + PreviewEntry::Component(component, _) => this + .render_preview(component, window, cx) + .into_any_element(), + PreviewEntry::SectionHeader(shared_string) => this + .render_scope_header(ix, shared_string.clone(), window, cx) + .into_any_element(), + PreviewEntry::AllComponents => { + div().w_full().h_0().into_any_element() + } + PreviewEntry::ActiveThread => { + div().w_full().h_0().into_any_element() + } + PreviewEntry::Separator => div().w_full().h_0().into_any_element(), + } + }), + ) + .flex_grow() + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .into_any_element() }, ) } @@ -690,6 +660,7 @@ impl ComponentPreview { component.clone(), self.workspace.clone(), self.active_thread.clone(), + self.reset_key, )) .into_any_element() } else { @@ -1041,6 +1012,7 @@ pub struct ComponentPreviewPage { component: ComponentMetadata, workspace: WeakEntity<Workspace>, active_thread: Option<Entity<ActiveThread>>, + reset_key: usize, } impl ComponentPreviewPage { @@ -1048,6 +1020,7 @@ impl ComponentPreviewPage { component: ComponentMetadata, workspace: WeakEntity<Workspace>, active_thread: Option<Entity<ActiveThread>>, + reset_key: usize, // languages: Arc<LanguageRegistry> ) -> Self { Self { @@ -1055,6 +1028,7 @@ impl ComponentPreviewPage { component, workspace, active_thread, + reset_key, } } @@ -1155,6 +1129,7 @@ impl ComponentPreviewPage { }; v_flex() + .id(("component-preview", self.reset_key)) .size_full() .flex_1() .px_12() diff --git a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs index 825744572d..de98106fae 100644 --- a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs +++ b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs @@ -12,21 +12,19 @@ use ui::{App, Window}; use workspace::Workspace; pub fn load_preview_thread_store( - workspace: WeakEntity<Workspace>, project: Entity<Project>, cx: &mut AsyncApp, ) -> Task<Result<Entity<ThreadStore>>> { - workspace - .update(cx, |_, cx| { - ThreadStore::load( - project.clone(), - cx.new(|_| ToolWorkingSet::default()), - None, - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - }) - .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) + cx.update(|cx| { + ThreadStore::load( + project.clone(), + cx.new(|_| ToolWorkingSet::default()), + None, + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) } pub fn load_preview_text_thread_store( diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs similarity index 91% rename from crates/zed/src/zed/inline_completion_registry.rs rename to crates/zed/src/zed/edit_prediction_registry.rs index f2e9d21b96..b9f561c0e7 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -11,7 +11,7 @@ use supermaven::{Supermaven, SupermavenCompletionProvider}; use ui::Window; use util::ResultExt; use workspace::Workspace; -use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider}; +use zeta::{ProviderDataCollection, ZetaEditPredictionProvider}; pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) { let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default(); @@ -90,10 +90,7 @@ pub fn init(client: Arc<Client>, 
user_store: Entity<UserStore>, cx: &mut App) { let new_provider = all_language_settings(None, cx).edit_predictions.provider; if new_provider != provider { - let tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); + let tos_accepted = user_store.read(cx).has_accepted_terms_of_service(); telemetry::event!( "Edit Prediction Provider Changed", @@ -174,7 +171,7 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed editor .register_action(cx.listener( |editor, _: &copilot::Suggest, window: &mut Window, cx: &mut Context<Editor>| { - editor.show_inline_completion(&Default::default(), window, cx); + editor.show_edit_prediction(&Default::default(), window, cx); }, )) .detach(); @@ -195,16 +192,6 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed }, )) .detach(); - editor - .register_action(cx.listener( - |editor, - _: &editor::actions::AcceptPartialCopilotSuggestion, - window: &mut Window, - cx: &mut Context<Editor>| { - editor.accept_partial_inline_completion(&Default::default(), window, cx); - }, - )) - .detach(); } fn assign_edit_prediction_provider( @@ -220,7 +207,7 @@ fn assign_edit_prediction_provider( match provider { EditPredictionProvider::None => { - editor.set_edit_prediction_provider::<ZetaInlineCompletionProvider>(None, window, cx); + editor.set_edit_prediction_provider::<ZetaEditPredictionProvider>(None, window, cx); } EditPredictionProvider::Copilot => { if let Some(copilot) = Copilot::global(cx) { @@ -242,7 +229,7 @@ fn assign_edit_prediction_provider( } } EditPredictionProvider::Zed => { - if client.status().borrow().is_connected() { + if user_store.read(cx).current_user().is_some() { let mut worktree = None; if let Some(buffer) = &singleton_buffer { @@ -278,7 +265,7 @@ fn assign_edit_prediction_provider( ProviderDataCollection::new(zeta.clone(), singleton_buffer, cx); let provider = - cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta, data_collection)); + cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, data_collection)); editor.set_edit_prediction_provider(Some(provider), window, cx); } diff --git a/crates/zed/src/zed/quick_action_bar.rs b/crates/zed/src/zed/quick_action_bar.rs index aff124a0bc..e76bef59a3 100644 --- a/crates/zed/src/zed/quick_action_bar.rs +++ b/crates/zed/src/zed/quick_action_bar.rs @@ -2,7 +2,6 @@ mod preview; mod repl_menu; use agent_settings::AgentSettings; -use client::DisableAiSettings; use editor::actions::{ AddSelectionAbove, AddSelectionBelow, CodeActionSource, DuplicateLineDown, GoToDiagnostic, GoToHunk, GoToPreviousDiagnostic, GoToPreviousHunk, MoveLineDown, MoveLineUp, SelectAll, @@ -16,6 +15,7 @@ use gpui::{ FocusHandle, Focusable, InteractiveElement, ParentElement, Render, Styled, Subscription, WeakEntity, Window, anchored, deferred, point, }; +use project::DisableAiSettings; use project::project_settings::DiagnosticSeverity; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, SettingsStore}; @@ -192,7 +192,7 @@ impl Render for QuickActionBar { }; v_flex() .child( - IconButton::new("toggle_code_actions_icon", IconName::Bolt) + IconButton::new("toggle_code_actions_icon", IconName::BoltOutlined) .icon_size(IconSize::Small) .style(ButtonStyle::Subtle) .disabled(!has_available_code_actions) @@ -381,7 +381,7 @@ impl Render for QuickActionBar { } if has_edit_prediction_provider { - let mut inline_completion_entry = ContextMenuEntry::new("Edit Predictions") + let mut edit_prediction_entry = ContextMenuEntry::new("Edit 
Predictions") .toggleable(IconPosition::Start, edit_predictions_enabled_at_cursor && show_edit_predictions) .disabled(!edit_predictions_enabled_at_cursor) .action( @@ -401,12 +401,12 @@ impl Render for QuickActionBar { } }); if !edit_predictions_enabled_at_cursor { - inline_completion_entry = inline_completion_entry.documentation_aside(DocumentationSide::Left, |_| { + edit_prediction_entry = edit_prediction_entry.documentation_aside(DocumentationSide::Left, |_| { Label::new("You can't toggle edit predictions for this file as it is within the excluded files list.").into_any_element() }); } - menu = menu.item(inline_completion_entry); + menu = menu.item(edit_prediction_entry); } menu = menu.separator(); diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index 4b4bf016c4..64891b6973 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -260,14 +260,25 @@ pub mod icon_theme_selector { } } +pub mod settings_profile_selector { + use gpui::Action; + use schemars::JsonSchema; + use serde::Deserialize; + + #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] + #[action(namespace = settings_profile_selector)] + pub struct Toggle; +} + pub mod agent { use gpui::actions; actions!( agent, [ - /// Opens the agent configuration panel. - OpenConfiguration, + /// Opens the agent settings panel. + #[action(deprecated_aliases = ["agent::OpenConfiguration"])] + OpenSettings, /// Opens the agent onboarding modal. OpenOnboardingModal, /// Resets the agent onboarding state. diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index c2b1de08ae..9f1d02b790 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -21,6 +21,7 @@ ai_onboarding.workspace = true anyhow.workspace = true arrayvec.workspace = true client.workspace = true +cloud_llm_client.workspace = true collections.workspace = true command_palette_hooks.workspace = true copilot.workspace = true @@ -32,14 +33,13 @@ futures.workspace = true gpui.workspace = true http_client.workspace = true indoc.workspace = true -inline_completion.workspace = true +edit_prediction.workspace = true language.workspace = true language_model.workspace = true log.workspace = true menu.workspace = true postage.workspace = true project.workspace = true -proto.workspace = true regex.workspace = true release_channel.workspace = true serde.workspace = true @@ -52,16 +52,17 @@ thiserror.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true +workspace-hack.workspace = true workspace.workspace = true worktree.workspace = true zed_actions.workspace = true -zed_llm_client.workspace = true -workspace-hack.workspace = true [dev-dependencies] -collections = { workspace = true, features = ["test-support"] } +call = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } +cloud_api_types.workspace = true +collections = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } @@ -77,5 +78,4 @@ tree-sitter-rust.workspace = true unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } -call = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/zeta/src/completion_diff_element.rs 
b/crates/zeta/src/completion_diff_element.rs index 3b7355d797..73c3cb20cd 100644 --- a/crates/zeta/src/completion_diff_element.rs +++ b/crates/zeta/src/completion_diff_element.rs @@ -1,6 +1,6 @@ use std::cmp; -use crate::InlineCompletion; +use crate::EditPrediction; use gpui::{ AnyElement, App, BorderStyle, Bounds, Corners, Edges, HighlightStyle, Hsla, StyledText, TextLayout, TextStyle, point, prelude::*, quad, size, @@ -17,7 +17,7 @@ pub struct CompletionDiffElement { } impl CompletionDiffElement { - pub fn new(completion: &InlineCompletion, cx: &App) -> Self { + pub fn new(completion: &EditPrediction, cx: &App) -> Self { let mut diff = completion .snapshot .text_for_range(completion.excerpt_range.clone()) diff --git a/crates/zeta/src/init.rs b/crates/zeta/src/init.rs index 4a65771223..a01e3a89a2 100644 --- a/crates/zeta/src/init.rs +++ b/crates/zeta/src/init.rs @@ -1,10 +1,10 @@ use std::any::{Any, TypeId}; -use client::DisableAiSettings; use command_palette_hooks::CommandPaletteFilter; use feature_flags::{FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag}; use gpui::actions; use language::language_settings::{AllLanguageSettings, EditPredictionProvider}; +use project::DisableAiSettings; use settings::{Settings, SettingsStore, update_settings_file}; use ui::App; use workspace::Workspace; diff --git a/crates/zeta/src/rate_completion_modal.rs b/crates/zeta/src/rate_completion_modal.rs index 5a873fb8de..ac7fcade91 100644 --- a/crates/zeta/src/rate_completion_modal.rs +++ b/crates/zeta/src/rate_completion_modal.rs @@ -1,4 +1,4 @@ -use crate::{CompletionDiffElement, InlineCompletion, InlineCompletionRating, Zeta}; +use crate::{CompletionDiffElement, EditPrediction, EditPredictionRating, Zeta}; use editor::Editor; use gpui::{App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, actions, prelude::*}; use language::language_settings; @@ -34,7 +34,7 @@ pub struct RateCompletionModal { } struct ActiveCompletion { - completion: InlineCompletion, + completion: EditPrediction, feedback_editor: Entity<Editor>, } @@ -157,7 +157,7 @@ impl RateCompletionModal { if let Some(active) = &self.active_completion { zeta.rate_completion( &active.completion, - InlineCompletionRating::Positive, + EditPredictionRating::Positive, active.feedback_editor.read(cx).text(cx), cx, ); @@ -189,7 +189,7 @@ impl RateCompletionModal { self.zeta.update(cx, |zeta, cx| { zeta.rate_completion( &active.completion, - InlineCompletionRating::Negative, + EditPredictionRating::Negative, active.feedback_editor.read(cx).text(cx), cx, ); @@ -250,7 +250,7 @@ impl RateCompletionModal { pub fn select_completion( &mut self, - completion: Option<InlineCompletion>, + completion: Option<EditPrediction>, focus: bool, window: &mut Window, cx: &mut Context<Self>, diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index d6f033899d..b1bd737dbf 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,8 +8,8 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::{Dismissable, KEY_VALUE_STORE}; +use edit_prediction::DataCollectionState; pub use init::*; -use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; pub use license_detection::is_license_eligible_for_data_collection; pub use rate_completion_modal::*; @@ -17,6 +17,10 @@ pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; +use cloud_llm_client::{ + AcceptEditPredictionBody, 
EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, + PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME, +}; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; use gpui::{ @@ -30,7 +34,7 @@ use language::{ }; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use postage::watch; -use project::Project; +use project::{Project, ProjectPath}; use release_channel::AppVersion; use settings::WorktreeId; use std::str::FromStr; @@ -46,17 +50,13 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use telemetry_events::InlineCompletionRating; +use telemetry_events::EditPredictionRating; use thiserror::Error; use util::ResultExt; use uuid::Uuid; use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use worktree::Worktree; -use zed_llm_client::{ - AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, -}; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>"; @@ -81,15 +81,15 @@ actions!( ); #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct InlineCompletionId(Uuid); +pub struct EditPredictionId(Uuid); -impl From<InlineCompletionId> for gpui::ElementId { - fn from(value: InlineCompletionId) -> Self { +impl From<EditPredictionId> for gpui::ElementId { + fn from(value: EditPredictionId) -> Self { gpui::ElementId::Uuid(value.0) } } -impl std::fmt::Display for InlineCompletionId { +impl std::fmt::Display for EditPredictionId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } @@ -121,9 +121,10 @@ impl Dismissable for ZedPredictUpsell { } pub fn should_show_upsell_modal(user_store: &Entity<UserStore>, cx: &App) -> bool { - match user_store.read(cx).current_user_has_accepted_terms() { - Some(true) => !ZedPredictUpsell::dismissed(), - Some(false) | None => true, + if user_store.read(cx).has_accepted_terms_of_service() { + !ZedPredictUpsell::dismissed() + } else { + true } } @@ -133,8 +134,8 @@ struct ZetaGlobal(Entity<Zeta>); impl Global for ZetaGlobal {} #[derive(Clone)] -pub struct InlineCompletion { - id: InlineCompletionId, +pub struct EditPrediction { + id: EditPredictionId, path: Arc<Path>, excerpt_range: Range<usize>, cursor_offset: usize, @@ -145,14 +146,14 @@ pub struct InlineCompletion { input_events: Arc<str>, input_excerpt: Arc<str>, output_excerpt: Arc<str>, - request_sent_at: Instant, + buffer_snapshotted_at: Instant, response_received_at: Instant, } -impl InlineCompletion { +impl EditPrediction { fn latency(&self) -> Duration { self.response_received_at - .duration_since(self.request_sent_at) + .duration_since(self.buffer_snapshotted_at) } fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> { @@ -206,9 +207,9 @@ fn interpolate( if edits.is_empty() { None } else { Some(edits) } } -impl std::fmt::Debug for InlineCompletion { +impl std::fmt::Debug for EditPrediction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InlineCompletion") + f.debug_struct("EditPrediction") .field("id", &self.id) .field("path", &self.path) .field("edits", &self.edits) @@ -221,17 +222,14 @@ pub struct Zeta { client: Arc<Client>, events: VecDeque<Event>, registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>, - shown_completions: VecDeque<InlineCompletion>, - 
rated_completions: HashSet<InlineCompletionId>, + shown_completions: VecDeque<EditPrediction>, + rated_completions: HashSet<EditPredictionId>, data_collection_choice: Entity<DataCollectionChoice>, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - /// Whether the terms of service have been accepted. - tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, user_store: Entity<UserStore>, - _user_store_subscription: Subscription, license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>, } @@ -306,22 +304,7 @@ impl Zeta { .detach_and_log_err(cx); }, ), - tos_accepted: user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false), update_required: false, - _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { - match event { - client::user::Event::PrivateUserInfoUpdated => { - this.tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); - } - _ => {} - } - }), license_detection_watchers: HashMap::default(), user_store, } @@ -401,111 +384,64 @@ impl Zeta { can_collect_data: bool, cx: &mut Context<Self>, perform_predict_edits: F, - ) -> Task<Result<Option<InlineCompletion>>> + ) -> Task<Result<Option<EditPrediction>>> where F: FnOnce(PerformPredictEditsParams) -> R + 'static, R: Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> + Send + 'static, { + let buffer = buffer.clone(); + let buffer_snapshotted_at = Instant::now(); let snapshot = self.report_changes_for_buffer(&buffer, cx); - let diagnostic_groups = snapshot.diagnostic_groups(None); - let cursor_point = cursor.to_point(&snapshot); - let cursor_offset = cursor_point.to_offset(&snapshot); - let events = self.events.clone(); - let path: Arc<Path> = snapshot - .file() - .map(|f| Arc::from(f.full_path(cx).as_path())) - .unwrap_or_else(|| Arc::from(Path::new("untitled"))); - let zeta = cx.entity(); + let events = self.events.clone(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let buffer = buffer.clone(); - - let local_lsp_store = - project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); - let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store { - Some( - diagnostic_groups - .into_iter() - .filter_map(|(language_server_id, diagnostic_group)| { - let language_server = - local_lsp_store.running_language_server_for_id(language_server_id)?; - - Some(( - language_server.name(), - diagnostic_group.resolve::<usize>(&snapshot), - )) - }) - .collect::<Vec<_>>(), - ) + let git_info = if let (true, Some(project), Some(file)) = + (can_collect_data, project, snapshot.file()) + { + git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) } else { None }; + let full_path: Arc<Path> = snapshot + .file() + .map(|f| Arc::from(f.full_path(cx).as_path())) + .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + let full_path_str = full_path.to_string_lossy().to_string(); + let cursor_point = cursor.to_point(&snapshot); + let cursor_offset = cursor_point.to_offset(&snapshot); + let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS); + let gather_task = gather_context( + project, + full_path_str, + &snapshot, + cursor_point, + make_events_prompt, + can_collect_data, + git_info, + cx, + ); + cx.spawn(async move |this, cx| { - let request_sent_at = Instant::now(); - - struct BackgroundValues { - 
input_events: String, - input_excerpt: String, - speculated_output: String, - editable_range: Range<usize>, - input_outline: String, - } - - let values = cx - .background_spawn({ - let snapshot = snapshot.clone(); - let path = path.clone(); - async move { - let path = path.to_string_lossy(); - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &path, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS); - let input_outline = prompt_for_outline(&snapshot); - - anyhow::Ok(BackgroundValues { - input_events, - input_excerpt: input_excerpt.prompt, - speculated_output: input_excerpt.speculated_output, - editable_range: input_excerpt.editable_range.to_offset(&snapshot), - input_outline, - }) - } - }) - .await?; + let GatherContextOutput { + body, + editable_range, + } = gather_task.await?; log::debug!( "Events:\n{}\nExcerpt:\n{:?}", - values.input_events, - values.input_excerpt + body.input_events, + body.input_excerpt ); - let body = PredictEditsBody { - input_events: values.input_events.clone(), - input_excerpt: values.input_excerpt.clone(), - speculated_output: Some(values.speculated_output), - outline: Some(values.input_outline.clone()), - can_collect_data, - diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| { - diagnostic_groups - .into_iter() - .map(|(name, diagnostic_group)| { - Ok((name.to_string(), serde_json::to_value(diagnostic_group)?)) - }) - .collect::<Result<Vec<_>>>() - .log_err() - }), - }; + let input_outline = body.outline.clone().unwrap_or_default(); + let input_events = body.input_events.clone(); + let input_excerpt = body.input_excerpt.clone(); let response = perform_predict_edits(PerformPredictEditsParams { client, @@ -563,13 +499,13 @@ impl Zeta { response, buffer, &snapshot, - values.editable_range, + editable_range, cursor_offset, - path, - values.input_outline, - values.input_events, - values.input_excerpt, - request_sent_at, + full_path, + input_outline, + input_events, + input_excerpt, + buffer_snapshotted_at, &cx, ) .await @@ -737,7 +673,7 @@ and then another position: language::Anchor, response: PredictEditsResponse, cx: &mut Context<Self>, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { use std::future::ready; self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { @@ -752,7 +688,7 @@ and then another position: language::Anchor, can_collect_data: bool, cx: &mut Context<Self>, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { let workspace = self .workspace .as_ref() @@ -768,7 +704,7 @@ and then another ) } - fn perform_predict_edits( + pub fn perform_predict_edits( params: PerformPredictEditsParams, ) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> { async move { @@ -846,7 +782,7 @@ and then another fn accept_edit_prediction( &mut self, - request_id: InlineCompletionId, + request_id: EditPredictionId, cx: &mut Context<Self>, ) -> Task<Result<()>> { let client = self.client.clone(); @@ -923,9 +859,9 @@ and then another input_outline: String, input_events: String, input_excerpt: String, - request_sent_at: Instant, + buffer_snapshotted_at: Instant, cx: &AsyncApp, - ) -> Task<Result<Option<InlineCompletion>>> { + ) -> Task<Result<Option<EditPrediction>>> { let snapshot = snapshot.clone(); let request_id = prediction_response.request_id; let output_excerpt = prediction_response.output_excerpt; @@ -957,8 +893,8 
@@ and then another let edit_preview = edit_preview.await; - Ok(Some(InlineCompletion { - id: InlineCompletionId(request_id), + Ok(Some(EditPrediction { + id: EditPredictionId(request_id), path, excerpt_range: editable_range, cursor_offset, @@ -969,7 +905,7 @@ and then another input_events: input_events.into(), input_excerpt: input_excerpt.into(), output_excerpt, - request_sent_at, + buffer_snapshotted_at, response_received_at: Instant::now(), })) }) @@ -1068,11 +1004,11 @@ and then another .collect() } - pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool { + pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool { self.rated_completions.contains(&completion_id) } - pub fn completion_shown(&mut self, completion: &InlineCompletion, cx: &mut Context<Self>) { + pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) { self.shown_completions.push_front(completion.clone()); if self.shown_completions.len() > 50 { let completion = self.shown_completions.pop_back().unwrap(); @@ -1083,8 +1019,8 @@ and then another pub fn rate_completion( &mut self, - completion: &InlineCompletion, - rating: InlineCompletionRating, + completion: &EditPrediction, + rating: EditPredictionRating, feedback: String, cx: &mut Context<Self>, ) { @@ -1102,7 +1038,7 @@ and then another cx.notify(); } - pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> { + pub fn shown_completions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> { self.shown_completions.iter() } @@ -1153,7 +1089,7 @@ and then another } } -struct PerformPredictEditsParams { +pub struct PerformPredictEditsParams { pub client: Arc<Client>, pub llm_token: LlmApiToken, pub app_version: SemanticVersion, @@ -1228,6 +1164,108 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: .sum() } +fn git_info_for_file( + project: &Entity<Project>, + project_path: &ProjectPath, + cx: &App, +) -> Option<PredictEditsGitInfo> { + let git_store = project.read(cx).git_store().read(cx); + if let Some((repository, _repo_path)) = + git_store.repository_and_path_for_project_path(project_path, cx) + { + let repository = repository.read(cx); + let head_sha = repository + .head_commit + .as_ref() + .map(|head_commit| head_commit.sha.to_string()); + let remote_origin_url = repository.remote_origin_url.clone(); + let remote_upstream_url = repository.remote_upstream_url.clone(); + if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { + return None; + } + Some(PredictEditsGitInfo { + head_sha, + remote_origin_url, + remote_upstream_url, + }) + } else { + None + } +} + +pub struct GatherContextOutput { + pub body: PredictEditsBody, + pub editable_range: Range<usize>, +} + +pub fn gather_context( + project: Option<&Entity<Project>>, + full_path_str: String, + snapshot: &BufferSnapshot, + cursor_point: language::Point, + make_events_prompt: impl FnOnce() -> String + Send + 'static, + can_collect_data: bool, + git_info: Option<PredictEditsGitInfo>, + cx: &App, +) -> Task<Result<GatherContextOutput>> { + let local_lsp_store = + project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); + let diagnostic_groups: Vec<(String, serde_json::Value)> = + if let Some(local_lsp_store) = local_lsp_store { + snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(language_server_id, diagnostic_group)| { + let language_server = + 
local_lsp_store.running_language_server_for_id(language_server_id)?; + let diagnostic_group = diagnostic_group.resolve::<usize>(&snapshot); + let language_server_name = language_server.name().to_string(); + let serialized = serde_json::to_value(diagnostic_group).unwrap(); + Some((language_server_name, serialized)) + }) + .collect::<Vec<_>>() + } else { + Vec::new() + }; + + cx.background_spawn({ + let snapshot = snapshot.clone(); + async move { + let diagnostic_groups = if diagnostic_groups.is_empty() { + None + } else { + Some(diagnostic_groups) + }; + + let input_excerpt = excerpt_for_cursor_position( + cursor_point, + &full_path_str, + &snapshot, + MAX_REWRITE_TOKENS, + MAX_CONTEXT_TOKENS, + ); + let input_events = make_events_prompt(); + let input_outline = prompt_for_outline(&snapshot); + let editable_range = input_excerpt.editable_range.to_offset(&snapshot); + + let body = PredictEditsBody { + input_events, + input_excerpt: input_excerpt.prompt, + speculated_output: Some(input_excerpt.speculated_output), + outline: Some(input_outline), + can_collect_data, + diagnostic_groups, + git_info, + }; + + Ok(GatherContextOutput { + body, + editable_range, + }) + } + }) +} + fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { let mut input_outline = String::new(); @@ -1278,7 +1316,7 @@ struct RegisteredBuffer { } #[derive(Clone)] -enum Event { +pub enum Event { BufferChange { old_snapshot: BufferSnapshot, new_snapshot: BufferSnapshot, @@ -1325,12 +1363,12 @@ impl Event { } #[derive(Debug, Clone)] -struct CurrentInlineCompletion { +struct CurrentEditPrediction { buffer_id: EntityId, - completion: InlineCompletion, + completion: EditPrediction, } -impl CurrentInlineCompletion { +impl CurrentEditPrediction { fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool { if self.buffer_id != old_completion.buffer_id { return true; @@ -1499,17 +1537,17 @@ async fn llm_token_retry( } } -pub struct ZetaInlineCompletionProvider { +pub struct ZetaEditPredictionProvider { zeta: Entity<Zeta>, pending_completions: ArrayVec<PendingCompletion, 2>, next_pending_completion_id: usize, - current_completion: Option<CurrentInlineCompletion>, + current_completion: Option<CurrentEditPrediction>, /// None if this is entirely disabled for this provider provider_data_collection: ProviderDataCollection, last_request_timestamp: Instant, } -impl ZetaInlineCompletionProvider { +impl ZetaEditPredictionProvider { pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); pub fn new(zeta: Entity<Zeta>, provider_data_collection: ProviderDataCollection) -> Self { @@ -1524,7 +1562,7 @@ impl ZetaInlineCompletionProvider { } } -impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider { +impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { fn name() -> &'static str { "zed-predict" } @@ -1573,7 +1611,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } fn needs_terms_acceptance(&self, cx: &App) -> bool { - !self.zeta.read(cx).tos_accepted + !self + .zeta + .read(cx) + .user_store + .read(cx) + .has_accepted_terms_of_service() } fn is_refreshing(&self) -> bool { @@ -1588,7 +1631,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider _debounce: bool, cx: &mut Context<Self>, ) { - if !self.zeta.read(cx).tos_accepted { + if self.needs_terms_acceptance(cx) { return; } @@ -1600,7 +1643,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider .zeta 
.read(cx) .user_store - .read_with(cx, |user_store, _| { + .read_with(cx, |user_store, _cx| { user_store.account_too_young() || user_store.has_overdue_invoices() }) { @@ -1647,7 +1690,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider Ok(completion_request) => { let completion_request = completion_request.await; completion_request.map(|c| { - c.map(|completion| CurrentInlineCompletion { + c.map(|completion| CurrentEditPrediction { buffer_id: buffer.entity_id(), completion, }) @@ -1720,7 +1763,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider &mut self, _buffer: Entity<Buffer>, _cursor_position: language::Anchor, - _direction: inline_completion::Direction, + _direction: edit_prediction::Direction, _cx: &mut Context<Self>, ) { // Right now we don't support cycling. @@ -1751,8 +1794,8 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider buffer: &Entity<Buffer>, cursor_position: language::Anchor, cx: &mut Context<Self>, - ) -> Option<inline_completion::InlineCompletion> { - let CurrentInlineCompletion { + ) -> Option<edit_prediction::EditPrediction> { + let CurrentEditPrediction { buffer_id, completion, .. @@ -1800,7 +1843,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } } - Some(inline_completion::InlineCompletion { + Some(edit_prediction::EditPrediction { id: Some(completion.id.to_string().into()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), edit_preview: Some(completion.edit_preview.clone()), @@ -1817,19 +1860,20 @@ fn tokens_for_bytes(bytes: usize) -> usize { #[cfg(test)] mod tests { + use client::UserStore; use client::test::FakeServer; use clock::FakeSystemClock; + use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; use language::Point; - use rpc::proto; use settings::SettingsStore; use super::*; #[gpui::test] - async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) { + async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| { to_completion_edits( @@ -1844,19 +1888,19 @@ mod tests { .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) .await; - let completion = InlineCompletion { + let completion = EditPrediction { edits, edit_preview, path: Path::new("").into(), snapshot: cx.read(|cx| buffer.read(cx).snapshot()), - id: InlineCompletionId(Uuid::new_v4()), + id: EditPredictionId(Uuid::new_v4()), excerpt_range: 0..0, cursor_offset: 0, input_outline: "".into(), input_events: "".into(), input_excerpt: "".into(), output_excerpt: "".into(), - request_sent_at: Instant::now(), + buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), }; @@ -2010,7 +2054,7 @@ mod tests { } #[gpui::test] - async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) { + async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); @@ -2027,28 +2071,45 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |_| async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: 
completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()) + let http_client = FakeHttpClient::create(move |req| async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") + .unwrap(), + output_excerpt: completion_response.to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } }); let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -2056,13 +2117,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::<proto::GetUsers>().await.unwrap(); - let token_request = server.receive::<proto::GetLlmToken>().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); buffer.update(cx, |buffer, cx| { buffer.edit(completion.edits.iter().cloned(), None, cx) @@ -2079,20 +2133,36 @@ mod tests { cx: &mut TestAppContext, ) -> Vec<(Range<Point>, String)> { let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |_| { + let http_client = FakeHttpClient::create(move |req| { let completion = completion_response.clone(); async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()) + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion, + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } } }); @@ -2100,9 +2170,10 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. 
+ let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); @@ -2111,13 +2182,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::<proto::GetUsers>().await.unwrap(); - let token_request = server.receive::<proto::GetLlmToken>().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); completion .edits diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml new file mode 100644 index 0000000000..e77351c219 --- /dev/null +++ b/crates/zeta_cli/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "zeta_cli" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[[bin]] +name = "zeta" +path = "src/main.rs" + +[dependencies] +anyhow.workspace = true +clap.workspace = true +client.workspace = true +debug_adapter_extension.workspace = true +extension.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +gpui_tokio.workspace = true +language.workspace = true +language_extension.workspace = true +language_model.workspace = true +language_models.workspace = true +languages = { workspace = true, features = ["load-grammars"] } +node_runtime.workspace = true +paths.workspace = true +project.workspace = true +prompt_store.workspace = true +release_channel.workspace = true +reqwest_client.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +shellexpand.workspace = true +terminal_view.workspace = true +util.workspace = true +watch.workspace = true +workspace-hack.workspace = true +zeta.workspace = true +smol.workspace = true diff --git a/crates/zeta_cli/LICENSE-GPL b/crates/zeta_cli/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/zeta_cli/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta_cli/build.rs b/crates/zeta_cli/build.rs new file mode 100644 index 0000000000..ccbb54c5b4 --- /dev/null +++ b/crates/zeta_cli/build.rs @@ -0,0 +1,14 @@ +fn main() { + let cargo_toml = + std::fs::read_to_string("../zed/Cargo.toml").expect("Failed to read Cargo.toml"); + let version = cargo_toml + .lines() + .find(|line| line.starts_with("version = ")) + .expect("Version not found in crates/zed/Cargo.toml") + .split('=') + .nth(1) + .expect("Invalid version format") + .trim() + .trim_matches('"'); + println!("cargo:rustc-env=ZED_PKG_VERSION={}", version); +} diff --git a/crates/zeta_cli/src/headless.rs b/crates/zeta_cli/src/headless.rs new file mode 100644 index 0000000000..959bb91a8f --- /dev/null +++ b/crates/zeta_cli/src/headless.rs @@ -0,0 +1,128 @@ +use client::{Client, ProxySettings, UserStore}; +use extension::ExtensionHostProxy; +use fs::RealFs; +use gpui::http_client::read_proxy_from_env; +use gpui::{App, AppContext, Entity}; +use gpui_tokio::Tokio; +use language::LanguageRegistry; +use language_extension::LspAccess; +use node_runtime::{NodeBinaryOptions, NodeRuntime}; +use project::Project; +use project::project_settings::ProjectSettings; +use release_channel::AppVersion; 
+use reqwest_client::ReqwestClient; +use settings::{Settings, SettingsStore}; +use std::path::PathBuf; +use std::sync::Arc; +use util::ResultExt as _; + +/// Headless subset of `workspace::AppState`. +pub struct ZetaCliAppState { + pub languages: Arc<LanguageRegistry>, + pub client: Arc<Client>, + pub user_store: Entity<UserStore>, + pub fs: Arc<dyn fs::Fs>, + pub node_runtime: NodeRuntime, +} + +// TODO: dedupe with crates/eval/src/eval.rs +pub fn init(cx: &mut App) -> ZetaCliAppState { + let app_version = AppVersion::load(env!("ZED_PKG_VERSION")); + release_channel::init(app_version, cx); + gpui_tokio::init(cx); + + let mut settings_store = SettingsStore::new(cx); + settings_store + .set_default_settings(settings::default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(settings_store); + client::init_settings(cx); + + // Set User-Agent so we can download language servers from GitHub + let user_agent = format!( + "Zed/{} ({}; {})", + app_version, + std::env::consts::OS, + std::env::consts::ARCH + ); + let proxy_str = ProxySettings::get_global(cx).proxy.to_owned(); + let proxy_url = proxy_str + .as_ref() + .and_then(|input| input.parse().ok()) + .or_else(read_proxy_from_env); + let http = { + let _guard = Tokio::handle(cx).enter(); + + ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent) + .expect("could not start HTTP client") + }; + cx.set_http_client(Arc::new(http)); + + Project::init_settings(cx); + + let client = Client::production(cx); + cx.set_http_client(client.http_client()); + + let git_binary_path = None; + let fs = Arc::new(RealFs::new( + git_binary_path, + cx.background_executor().clone(), + )); + + let mut languages = LanguageRegistry::new(cx.background_executor().clone()); + languages.set_language_server_download_dir(paths::languages_dir().clone()); + let languages = Arc::new(languages); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + + extension::init(cx); + + let (mut tx, rx) = watch::channel(None); + cx.observe_global::<SettingsStore>(move |cx| { + let settings = &ProjectSettings::get_global(cx).node; + let options = NodeBinaryOptions { + allow_path_lookup: !settings.ignore_system_version, + allow_binary_download: true, + use_paths: settings.path.as_ref().map(|node_path| { + let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref()); + let npm_path = settings + .npm_path + .as_ref() + .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref())); + ( + node_path.clone(), + npm_path.unwrap_or_else(|| { + let base_path = PathBuf::new(); + node_path.parent().unwrap_or(&base_path).join("npm") + }), + ) + }), + }; + tx.send(Some(options)).log_err(); + }) + .detach(); + let node_runtime = NodeRuntime::new(client.http_client(), None, rx); + + let extension_host_proxy = ExtensionHostProxy::global(cx); + + language::init(cx); + debug_adapter_extension::init(extension_host_proxy.clone(), cx); + language_extension::init( + LspAccess::Noop, + extension_host_proxy.clone(), + languages.clone(), + ); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + languages::init(languages.clone(), node_runtime.clone(), cx); + prompt_store::init(cx); + terminal_view::init(cx); + + ZetaCliAppState { + languages, + client, + user_store, + fs, + node_runtime, + } +} diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs new file mode 100644 index 0000000000..adf7683152 --- /dev/null +++ b/crates/zeta_cli/src/main.rs @@ -0,0 +1,378 @@ +mod headless; + +use anyhow::{Result, anyhow}; 
+use clap::{Args, Parser, Subcommand}; +use futures::channel::mpsc; +use futures::{FutureExt as _, StreamExt as _}; +use gpui::{AppContext, Application, AsyncApp}; +use gpui::{Entity, Task}; +use language::Bias; +use language::Buffer; +use language::Point; +use language_model::LlmApiToken; +use project::{Project, ProjectPath}; +use release_channel::AppVersion; +use reqwest_client::ReqwestClient; +use std::path::{Path, PathBuf}; +use std::process::exit; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; + +use crate::headless::ZetaCliAppState; + +#[derive(Parser, Debug)] +#[command(name = "zeta")] +struct ZetaCliArgs { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +enum Commands { + Context(ContextArgs), + Predict { + #[arg(long)] + predict_edits_body: Option<FileOrStdin>, + #[clap(flatten)] + context_args: Option<ContextArgs>, + }, +} + +#[derive(Debug, Args)] +#[group(requires = "worktree")] +struct ContextArgs { + #[arg(long)] + worktree: PathBuf, + #[arg(long)] + cursor: CursorPosition, + #[arg(long)] + use_language_server: bool, + #[arg(long)] + events: Option<FileOrStdin>, +} + +#[derive(Debug, Clone)] +enum FileOrStdin { + File(PathBuf), + Stdin, +} + +impl FileOrStdin { + async fn read_to_string(&self) -> Result<String, std::io::Error> { + match self { + FileOrStdin::File(path) => smol::fs::read_to_string(path).await, + FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await, + } + } +} + +impl FromStr for FileOrStdin { + type Err = <PathBuf as FromStr>::Err; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "-" => Ok(Self::Stdin), + _ => Ok(Self::File(PathBuf::from_str(s)?)), + } + } +} + +#[derive(Debug, Clone)] +struct CursorPosition { + path: PathBuf, + point: Point, +} + +impl FromStr for CursorPosition { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result<Self> { + let parts: Vec<&str> = s.split(':').collect(); + if parts.len() != 3 { + return Err(anyhow!( + "Invalid cursor format. 
Expected 'file.rs:line:column', got '{}'", + s + )); + } + + let path = PathBuf::from(parts[0]); + let line: u32 = parts[1] + .parse() + .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?; + let column: u32 = parts[2] + .parse() + .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?; + + // Convert from 1-based to 0-based indexing + let point = Point::new(line.saturating_sub(1), column.saturating_sub(1)); + + Ok(CursorPosition { path, point }) + } +} + +async fn get_context( + args: ContextArgs, + app_state: &Arc<ZetaCliAppState>, + cx: &mut AsyncApp, +) -> Result<GatherContextOutput> { + let ContextArgs { + worktree: worktree_path, + cursor, + use_language_server, + events, + } = args; + + let worktree_path = worktree_path.canonicalize()?; + if cursor.path.is_absolute() { + return Err(anyhow!("Absolute paths are not supported in --cursor")); + } + + let (project, _lsp_open_handle, buffer) = if use_language_server { + let (project, lsp_open_handle, buffer) = + open_buffer_with_language_server(&worktree_path, &cursor.path, &app_state, cx).await?; + (Some(project), Some(lsp_open_handle), buffer) + } else { + let abs_path = worktree_path.join(&cursor.path); + let content = smol::fs::read_to_string(&abs_path).await?; + let buffer = cx.new(|cx| Buffer::local(content, cx))?; + (None, None, buffer) + }; + + let worktree_name = worktree_path + .file_name() + .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?; + let full_path_str = PathBuf::from(worktree_name) + .join(&cursor.path) + .to_string_lossy() + .to_string(); + + let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?; + let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left); + if clipped_cursor != cursor.point { + let max_row = snapshot.max_point().row; + if cursor.point.row < max_row { + return Err(anyhow!( + "Cursor position {:?} is out of bounds (line length is {})", + cursor.point, + snapshot.line_len(cursor.point.row) + )); + } else { + return Err(anyhow!( + "Cursor position {:?} is out of bounds (max row is {})", + cursor.point, + max_row + )); + } + } + + let events = match events { + Some(events) => events.read_to_string().await?, + None => String::new(), + }; + let can_collect_data = false; + let git_info = None; + cx.update(|cx| { + gather_context( + project.as_ref(), + full_path_str, + &snapshot, + clipped_cursor, + move || events, + can_collect_data, + git_info, + cx, + ) + })? + .await +} + +pub async fn open_buffer_with_language_server( + worktree_path: &Path, + path: &Path, + app_state: &Arc<ZetaCliAppState>, + cx: &mut AsyncApp, +) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> { + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(worktree_path, true, cx) + })? + .await?; + + let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath { + worktree_id: worktree.id(), + path: path.to_path_buf().into(), + })?; + + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? 
+ .await?; + + let lsp_open_handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + })?; + + let log_prefix = path.to_string_lossy().to_string(); + wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; + + Ok((project, lsp_open_handle, buffer)) +} + +// TODO: Dedupe with similar function in crates/eval/src/instance.rs +pub fn wait_for_lang_server( + project: &Entity<Project>, + buffer: &Entity<Buffer>, + log_prefix: String, + cx: &mut AsyncApp, +) -> Task<Result<()>> { + println!("{}⏵ Waiting for language server", log_prefix); + + let (mut tx, mut rx) = mpsc::channel(1); + + let lsp_store = project + .read_with(cx, |project, _| project.lsp_store()) + .unwrap(); + + let has_lang_server = buffer + .update(cx, |buffer, cx| { + lsp_store.update(cx, |lsp_store, cx| { + lsp_store + .language_servers_for_local_buffer(&buffer, cx) + .next() + .is_some() + }) + }) + .unwrap_or(false); + + if has_lang_server { + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .unwrap() + .detach(); + } + + let subscriptions = [ + cx.subscribe(&lsp_store, { + let log_prefix = log_prefix.clone(); + move |_, event, _| match event { + project::LspStoreEvent::LanguageServerUpdate { + message: + client::proto::update_language_server::Variant::WorkProgress( + client::proto::LspWorkProgress { + message: Some(message), + .. + }, + ), + .. + } => println!("{}⟲ {message}", log_prefix), + _ => {} + } + }), + cx.subscribe(&project, { + let buffer = buffer.clone(); + move |project, event, cx| match event { + project::Event::LanguageServerAdded(_, _, _) => { + let buffer = buffer.clone(); + project + .update(cx, |project, cx| project.save_buffer(buffer, cx)) + .detach(); + } + project::Event::DiskBasedDiagnosticsFinished { .. } => { + tx.try_send(()).ok(); + } + _ => {} + } + }), + ]; + + cx.spawn(async move |cx| { + let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); + let result = futures::select! { + _ = rx.next() => { + println!("{}⚑ Language server idle", log_prefix); + anyhow::Ok(()) + }, + _ = timeout.fuse() => { + anyhow::bail!("LSP wait timed out after 5 minutes"); + } + }; + drop(subscriptions); + result + }) +} + +fn main() { + let args = ZetaCliArgs::parse(); + let http_client = Arc::new(ReqwestClient::new()); + let app = Application::headless().with_http_client(http_client); + + app.run(move |cx| { + let app_state = Arc::new(headless::init(cx)); + cx.spawn(async move |cx| { + let result = match args.command { + Commands::Context(context_args) => get_context(context_args, &app_state, cx) + .await + .map(|output| serde_json::to_string_pretty(&output.body).unwrap()), + Commands::Predict { + predict_edits_body, + context_args, + } => { + cx.spawn(async move |cx| { + let app_version = cx.update(|cx| AppVersion::global(cx))?; + app_state.client.sign_in(true, cx).await?; + let llm_token = LlmApiToken::default(); + llm_token.refresh(&app_state.client).await?; + + let predict_edits_body = + if let Some(predict_edits_body) = predict_edits_body { + serde_json::from_str(&predict_edits_body.read_to_string().await?)? + } else if let Some(context_args) = context_args { + get_context(context_args, &app_state, cx).await?.body + } else { + return Err(anyhow!( + "Expected either --predict-edits-body-file \ + or the required args of the `context` command." 
+ )); + }; + + let (response, _usage) = + Zeta::perform_predict_edits(PerformPredictEditsParams { + client: app_state.client.clone(), + llm_token, + app_version, + body: predict_edits_body, + }) + .await?; + + Ok(response.output_excerpt) + }) + .await + } + }; + match result { + Ok(output) => { + println!("{}", output); + let _ = cx.update(|cx| cx.quit()); + } + Err(e) => { + eprintln!("Failed: {:?}", e); + exit(1); + } + } + }) + .detach(); + }); +} diff --git a/crates/zlog/src/sink.rs b/crates/zlog/src/sink.rs index acf0469c77..17aa08026e 100644 --- a/crates/zlog/src/sink.rs +++ b/crates/zlog/src/sink.rs @@ -21,6 +21,8 @@ const ANSI_MAGENTA: &str = "\x1b[35m"; /// Whether stdout output is enabled. static mut ENABLED_SINKS_STDOUT: bool = false; +/// Whether stderr output is enabled. +static mut ENABLED_SINKS_STDERR: bool = false; /// Is Some(file) if file output is enabled. static ENABLED_SINKS_FILE: Mutex<Option<std::fs::File>> = Mutex::new(None); @@ -45,6 +47,12 @@ pub fn init_output_stdout() { } } +pub fn init_output_stderr() { + unsafe { + ENABLED_SINKS_STDERR = true; + } +} + pub fn init_output_file( path: &'static PathBuf, path_rotate: Option<&'static PathBuf>, @@ -115,6 +123,21 @@ pub fn submit(record: Record) { }, record.message ); + } else if unsafe { ENABLED_SINKS_STDERR } { + let mut stdout = std::io::stderr().lock(); + _ = writeln!( + &mut stdout, + "{} {ANSI_BOLD}{}{}{ANSI_RESET} {} {}", + chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%:z"), + LEVEL_ANSI_COLORS[record.level as usize], + LEVEL_OUTPUT_STRINGS[record.level as usize], + SourceFmt { + scope: record.scope, + module_path: record.module_path, + ansi: true, + }, + record.message + ); } let mut file = ENABLED_SINKS_FILE.lock().unwrap_or_else(|handle| { ENABLED_SINKS_FILE.clear_poison(); diff --git a/crates/zlog/src/zlog.rs b/crates/zlog/src/zlog.rs index 570c82314c..5b40278f3f 100644 --- a/crates/zlog/src/zlog.rs +++ b/crates/zlog/src/zlog.rs @@ -5,7 +5,7 @@ mod env_config; pub mod filter; pub mod sink; -pub use sink::{flush, init_output_file, init_output_stdout}; +pub use sink::{flush, init_output_file, init_output_stderr, init_output_stdout}; pub const SCOPE_DEPTH_MAX: usize = 4; diff --git a/docs/README.md b/docs/README.md index 55993c9e36..a225903674 100644 --- a/docs/README.md +++ b/docs/README.md @@ -69,3 +69,64 @@ Templates are just functions that modify the source of the docs pages (usually w - Template Trait: crates/docs_preprocessor/src/templates.rs - Example template: crates/docs_preprocessor/src/templates/keybinding.rs - Client-side plugins: docs/theme/plugins.js + +## Postprocessor + +A postprocessor is implemented as a sub-command of `docs_preprocessor` that wraps the builtin `html` renderer and applies post-processing to the `html` files, to add support for page-specific title and meta description values. + +An example of the syntax can be found in `git.md`, as well as below + +```md +--- +title: Some more detailed title for this page +description: A page-specific description +--- + +# Editor +``` + +The above will be transformed into (with non-relevant tags removed) + +```html +<head> + <title>Editor | Some more detailed title for this page + + + +

+  </title>
+  <meta name="description" content="A page-specific description" />
+</head>
+<body>
+  <h1>Editor</h1>
+</body>
+ +``` + +If no front-matter is provided, or If one or both keys aren't provided, the title and description will be set based on the `default-title` and `default-description` keys in `book.toml` respectively. + +### Implementation details + +Unfortunately, `mdbook` does not support post-processing like it does pre-processing, and only supports defining one description to put in the meta tag per book rather than per file. So in order to apply post-processing (necessary to modify the html head tags) the global book description is set to a marker value `#description#` and the html renderer is replaced with a sub-command of `docs_preprocessor` that wraps the builtin `html` renderer and applies post-processing to the `html` files, replacing the marker value and the `(.*)` with the contents of the front-matter if there is one. + +### Known limitations + +The front-matter parsing is extremely simple, which avoids needing to take on an additional dependency, or implement full yaml parsing. + +- Double quotes and multi-line values are not supported, i.e. Keys and values must be entirely on the same line, with no double quotes around the value. + +The following will not work: + +```md +--- +title: Some + Multi-line + Title +--- +``` + +And neither will: + +```md +--- +title: "Some title" +--- +``` + +- The front-matter must be at the top of the file, with only white-space preceding it + +- The contents of the title and description will not be html-escaped. They should be simple ascii text with no unicode or emoji characters diff --git a/docs/book.toml b/docs/book.toml index f5d186f377..60ddc5ac51 100644 --- a/docs/book.toml +++ b/docs/book.toml @@ -6,13 +6,27 @@ src = "src" title = "Zed" site-url = "/docs/" -[output.html] +[build] +extra-watch-dirs = ["../crates/docs_preprocessor"] + +# zed-html is a "custom" renderer that just wraps the +# builtin mdbook html renderer, and applies post-processing +# as post-processing is not possible with mdbook in the same way +# pre-processing is +# The config is passed directly to the html renderer, so all config +# options that apply to html apply to zed-html +[output.zed-html] +command = "cargo run -p docs_preprocessor -- postprocess" +# Set here instead of above as we only use it replace the `#description#` we set in the template +# when no front-matter is provided value +default-description = "Learn how to use and customize Zed, the fast, collaborative code editor. Official docs on features, configuration, AI tools, and workflows." +default-title = "Zed Code Editor Documentation" no-section-label = true preferred-dark-theme = "dark" additional-css = ["theme/page-toc.css", "theme/plugins.css", "theme/highlight.css"] additional-js = ["theme/page-toc.js", "theme/plugins.js"] -[output.html.print] +[output.zed-html.print] enable = false # Redirects for `/docs` pages. @@ -24,7 +38,7 @@ enable = false # The destination URLs are interpreted relative to `https://zed.dev`. # - Redirects to other docs pages should end in `.html` # - You can link to pages on the Zed site by omitting the `/docs` in front of it. 
-[output.html.redirect] +[output.zed-html.redirect] # AI "/ai.html" = "/docs/ai/overview.html" "/assistant-panel.html" = "/docs/ai/agent-panel.html" @@ -40,6 +54,7 @@ enable = false "/assistant/prompting.html" = "/docs/ai/rules.html" "/language-model-integration.html" = "/docs/assistant/assistant.html" "/model-improvement.html" = "/docs/ai/ai-improvement.html" +"/ai/temperature.html" = "/docs/ai/agent-settings.html#model-temperature" # Community "/community/feedback.html" = "/community-links" diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 1d43872547..fc936d6bd0 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -45,13 +45,14 @@ - [Overview](./ai/overview.md) - [Agent Panel](./ai/agent-panel.md) - [Tools](./ai/tools.md) - - [Model Temperature](./ai/temperature.md) - [Inline Assistant](./ai/inline-assistant.md) - [Edit Prediction](./ai/edit-prediction.md) - [Text Threads](./ai/text-threads.md) - [Rules](./ai/rules.md) - [Model Context Protocol](./ai/mcp.md) - [Configuration](./ai/configuration.md) + - [LLM Providers](./ai/llm-providers.md) + - [Agent Settings](./ai/agent-settings.md) - [Subscription](./ai/subscription.md) - [Plans and Usage](./ai/plans-and-usage.md) - [Billing](./ai/billing.md) diff --git a/docs/src/accounts.md b/docs/src/accounts.md index c13c98ad9a..1ce23cf902 100644 --- a/docs/src/accounts.md +++ b/docs/src/accounts.md @@ -5,7 +5,7 @@ Signing in to Zed is not a requirement. You can use most features you'd expect i ## What Features Require Signing In? 1. All real-time [collaboration features](./collaboration.md). -2. [LLM-powered features](./ai/overview.md), if you are using Zed as the provider of your LLM models. Alternatively, you can [bring and configure your own API keys](./ai/configuration.md#use-your-own-keys) if you'd prefer, and avoid having to sign in. +2. [LLM-powered features](./ai/overview.md), if you are using Zed as the provider of your LLM models. Alternatively, you can [bring and configure your own API keys](./ai/llm-providers.md#use-your-own-keys) if you'd prefer, and avoid having to sign in. ## Signing In diff --git a/docs/src/ai/agent-panel.md b/docs/src/ai/agent-panel.md index 97568d6643..f944eb88b0 100644 --- a/docs/src/ai/agent-panel.md +++ b/docs/src/ai/agent-panel.md @@ -8,7 +8,7 @@ If you're using the Agent Panel for the first time, you need to have at least on You can do that by: 1. [subscribing to our Pro plan](https://zed.dev/pricing), so you have access to our hosted models -2. or by [bringing your own API keys](./configuration.md#use-your-own-keys) for your desired provider +2. or by [bringing your own API keys](./llm-providers.md#use-your-own-keys) for your desired provider ## Overview {#overview} @@ -87,7 +87,7 @@ You can also do this at any time with an ongoing thread via the "Agent Options" ## Changing Models {#changing-models} -After you've configured your LLM providers—either via [a custom API key](./configuration.md#use-your-own-keys) or through [Zed's hosted models](./models.md)—you can switch between them by clicking on the model selector on the message editor or by using the {#kb agent::ToggleModelSelector} keybinding. +After you've configured your LLM providers—either via [a custom API key](./llm-providers.md#use-your-own-keys) or through [Zed's hosted models](./models.md)—you can switch between them by clicking on the model selector on the message editor or by using the {#kb agent::ToggleModelSelector} keybinding. 
## Using Tools {#using-tools} diff --git a/docs/src/ai/agent-settings.md b/docs/src/ai/agent-settings.md new file mode 100644 index 0000000000..ff97bcb8ee --- /dev/null +++ b/docs/src/ai/agent-settings.md @@ -0,0 +1,226 @@ +# Agent Settings + +Learn about all the settings you can customize in Zed's Agent Panel. + +## Model Settings {#model-settings} + +### Default Model {#default-model} + +If you're using [Zed's hosted LLM service](./plans-and-usage.md), it sets `claude-sonnet-4` as the default model. +But if you're not subscribed to it or simply just want to change it, you can do it so either via the model dropdown in the Agent Panel's bottom-right corner or by manually editing the `default_model` object in your settings: + +```json +{ + "agent": { + "default_model": { + "provider": "zed.dev", + "model": "gpt-4o" + } + } +} +``` + +### Feature-specific Models {#feature-specific-models} + +Assign distinct and specific models for the following AI-powered features in Zed: + +- Thread summary model: Used for generating thread summaries +- Inline assistant model: Used for the inline assistant feature +- Commit message model: Used for generating Git commit messages + +```json +{ + "agent": { + "default_model": { + "provider": "zed.dev", + "model": "claude-sonnet-4" + }, + "inline_assistant_model": { + "provider": "anthropic", + "model": "claude-3-5-sonnet" + }, + "commit_message_model": { + "provider": "openai", + "model": "gpt-4o-mini" + }, + "thread_summary_model": { + "provider": "google", + "model": "gemini-2.0-flash" + } + } +} +``` + +> If a custom model isn't set for one of these features, they automatically fall back to using the default model. + +### Alternative Models for Inline Assists {#alternative-assists} + +The Inline Assist feature in particular has the capacity to perform multiple generations in parallel using different models. +That is possible by assigning more than one model to it, taking the configuration shown above one step further. + +When configured, the inline assist UI will surface controls to cycle between the outputs generated by each model. + +The models you specify here are always used in _addition_ to your [default model](#default-model). + +For example, the following configuration will generate two outputs for every assist. +One with Claude Sonnet 4 (the default model), and one with GPT-4o. + +```json +{ + "agent": { + "default_model": { + "provider": "zed.dev", + "model": "claude-sonnet-4" + }, + "inline_alternatives": [ + { + "provider": "zed.dev", + "model": "gpt-4o" + } + ] + } +} +``` + +### Model Temperature + +Specify a custom temperature for a provider and/or model: + +```json +"model_parameters": [ + // To set parameters for all requests to OpenAI models: + { + "provider": "openai", + "temperature": 0.5 + }, + // To set parameters for all requests in general: + { + "temperature": 0 + }, + // To set parameters for a specific provider and model: + { + "provider": "zed.dev", + "model": "claude-sonnet-4", + "temperature": 1.0 + } +], +``` + +## Agent Panel Settings {#agent-panel-settings} + +Note that some of these settings are also surfaced in the Agent Panel's settings UI, which you can access either via the `agent: open settings` action or by the dropdown menu on the top-right corner of the panel. + +### Default View + +Use the `default_view` setting to change the default view of the Agent Panel. 
+You can choose between `thread` (the default) and `text_thread`: + +```json +{ + "agent": { + "default_view": "text_thread" + } +} +``` + +### Auto-run Commands + +Control whether you want to allow the agent to run commands without asking you for permission. +The default value is `false`. + +```json +{ + "agent": { + "always_allow_tool_actions": "true" + } +} +``` + +> This setting is available via the Agent Panel's settings UI. + +### Single-file Review + +Control whether you want to see review actions (accept & reject) in single buffers after the agent is done performing edits. +The default value is `false`. + +```json +{ + "agent": { + "single_file_review": "true" + } +} +``` + +When set to false, these controls are only available in the multibuffer review tab. + +> This setting is available via the Agent Panel's settings UI. + +### Sound Notification + +Control whether you want to hear a notification sound when the agent is done generating changes or needs your input. +The default value is `false`. + +```json +{ + "agent": { + "play_sound_when_agent_done": "true" + } +} +``` + +> This setting is available via the Agent Panel's settings UI. + +### Modifier to Send + +Make a modifier (`cmd` on macOS, `ctrl` on Linux) required to send messages. +This is encouraged for more thoughtful prompt crafting. +The default value is `false`. + +```json +{ + "agent": { + "use_modifier_to_send": "true" + } +} +``` + +> This setting is available via the Agent Panel's settings UI. + +### Edit Card + +Use the `expand_edit_card` setting to control whether edit cards show the full diff in the Agent Panel. +It is set to `true` by default, but if set to false, the card's height is capped to a certain number of lines, requiring a click to be expanded. + +```json +{ + "agent": { + "expand_edit_card": "false" + } +} +``` + +### Terminal Card + +Use the `expand_terminal_card` setting to control whether terminal cards show the command output in the Agent Panel. +It is set to `true` by default, but if set to false, the card will be fully collapsed even while the command is running, requiring a click to be expanded. + +```json +{ + "agent": { + "expand_terminal_card": "false" + } +} +``` + +### Feedback Controls + +Control whether you want to see the thumbs up/down buttons to give Zed feedback about the agent's performance. +The default value is `true`. + +```json +{ + "agent": { + "enable_feedback": "false" + } +} +``` diff --git a/docs/src/ai/billing.md b/docs/src/ai/billing.md index c49bacd883..d519b136ae 100644 --- a/docs/src/ai/billing.md +++ b/docs/src/ai/billing.md @@ -1,7 +1,7 @@ # Billing We use Stripe as our billing and payments provider. All Pro plans require payment via credit card. -For invoice-based billing, a Business plan is required. Contact sales@zed.dev for more information. +For invoice-based billing, a Business plan is required. Contact [sales@zed.dev](mailto:sales@zed.dev) for more information. ## Settings {#settings} @@ -12,7 +12,8 @@ Clicking the button under Account Settings will navigate you to Stripe’s secur Zed is billed on a monthly basis based on the date you initially subscribe. -We’ll also bill in-month for additional prompts used beyond your plan’s prompt limit, if usage exceeds $20 before month end. See [usage-based pricing](./plans-and-usage.md#ubp) for more. +We’ll also bill in-month for additional prompts used beyond your plan’s prompt limit, if usage exceeds $20 before month end. +See [usage-based pricing](./plans-and-usage.md#ubp) for more. 
## Invoice History {#invoice-history} @@ -25,3 +26,12 @@ From Stripe’s secure portal, you can download all current and historical invoi You can update your payment method, company name, address, and tax information through the billing portal. Please note that changes to billing information will **only** affect future invoices — **we cannot modify historical invoices**. + +## Sales Tax {#sales-tax} + +Zed partners with [Sphere](https://www.getsphere.com/) to calculate indirect tax rate for invoices, based on customer location and the product being sold. Tax is listed as a separate line item on invoices, based preferentially on your billing address, followed by the card issue country known to Stripe. + +If you have a VAT/GST ID, you can add it at [zed.dev/account](https://zed.dev/account) by clicking "Manage" on your subscription. Check the box that denotes you as a business. + +Please note that changes to VAT/GST IDs and address will **only** affect future invoices — **we cannot modify historical invoices**. +Questions or issues can be directed to [billing-support@zed.dev](mailto:billing-support@zed.dev). diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md index 414da2206f..d28a7e8ed0 100644 --- a/docs/src/ai/configuration.md +++ b/docs/src/ai/configuration.md @@ -1,735 +1,20 @@ # Configuration -There are various aspects about the Agent Panel that you can customize. -All of them can be seen by either visiting [the Configuring Zed page](../configuring-zed.md#agent) or by running the `zed: open default settings` action and searching for `"agent"`. +When using AI in Zed, you can customize several aspects: -Alternatively, you can also visit the panel's Settings view by running the `agent: open configuration` action or going to the top-right menu and hitting "Settings". +1. Which [LLM providers](./llm-providers.md) you can use +2. [Model parameters and usage](./agent-settings.md#model-settings) +3. [Interactions with the Agent Panel](./agent-settings.md#agent-panel-settings) -## LLM Providers +## Turning AI Off Entirely -Zed supports multiple large language model providers. -Here's an overview of the supported providers and tool call support: - -| Provider | Tool Use Supported | -| ----------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [Amazon Bedrock](#amazon-bedrock) | Depends on the model | -| [Anthropic](#anthropic) | ✅ | -| [DeepSeek](#deepseek) | ✅ | -| [GitHub Copilot Chat](#github-copilot-chat) | For some models ([link](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198)) | -| [Google AI](#google-ai) | ✅ | -| [LM Studio](#lmstudio) | ✅ | -| [Mistral](#mistral) | ✅ | -| [Ollama](#ollama) | ✅ | -| [OpenAI](#openai) | ✅ | -| [OpenAI API Compatible](#openai-api-compatible) | 🚫 | -| [OpenRouter](#openrouter) | ✅ | -| [Vercel](#vercel-v0) | ✅ | -| [xAI](#xai) | ✅ | - -## Use Your Own Keys {#use-your-own-keys} - -While Zed offers hosted versions of models through [our various plans](./plans-and-usage.md), we're always happy to support users wanting to supply their own API keys. -Below, you can learn how to do that for each provider. - -> Using your own API keys is _free_—you do not need to subscribe to a Zed plan to use our AI features with your own keys. 
- -### Amazon Bedrock {#amazon-bedrock} - -> ✅ Supports tool use with models that support streaming tool use. -> More details can be found in the [Amazon Bedrock's Tool Use documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html). - -To use Amazon Bedrock's models, an AWS authentication is required. -Ensure your credentials have the following permissions set up: - -- `bedrock:InvokeModelWithResponseStream` -- `bedrock:InvokeModel` -- `bedrock:ConverseStream` - -Your IAM policy should look similar to: +We want to respect users who want to use Zed without interacting with AI whatsoever. +To do that, add the following key to your `settings.json`: ```json { - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": [ - "bedrock:InvokeModel", - "bedrock:InvokeModelWithResponseStream", - "bedrock:ConverseStream" - ], - "Resource": "*" - } - ] + "disable_ai": true } ``` -With that done, choose one of the two authentication methods: - -#### Authentication via Named Profile (Recommended) - -1. Ensure you have the AWS CLI installed and configured with a named profile -2. Open your `settings.json` (`zed: open settings`) and include the `bedrock` key under `language_models` with the following settings: - ```json - { - "language_models": { - "bedrock": { - "authentication_method": "named_profile", - "region": "your-aws-region", - "profile": "your-profile-name" - } - } - } - ``` - -#### Authentication via Static Credentials - -While it's possible to configure through the Agent Panel settings UI by entering your AWS access key and secret directly, we recommend using named profiles instead for better security practices. -To do this: - -1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). -2. Create security credentials for that User, save them and keep them secure. -3. Open the Agent Configuration with (`agent: open configuration`) and go to the Amazon Bedrock section -4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. - -#### Cross-Region Inference - -The Zed implementation of Amazon Bedrock uses [Cross-Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html) for all the models and region combinations that support it. -With Cross-Region inference, you can distribute traffic across multiple AWS Regions, enabling higher throughput. - -For example, if you use `Claude Sonnet 3.7 Thinking` from `us-east-1`, it may be processed across the US regions, namely: `us-east-1`, `us-east-2`, or `us-west-2`. -Cross-Region inference requests are kept within the AWS Regions that are part of the geography where the data originally resides. -For example, a request made within the US is kept within the AWS Regions in the US. - -Although the data remains stored only in the source Region, your input prompts and output results might move outside of your source Region during cross-Region inference. -All data will be transmitted encrypted across Amazon's secure network. - -We will support Cross-Region inference for each of the models on a best-effort basis, please refer to the [Cross-Region Inference method Code](https://github.com/zed-industries/zed/blob/main/crates/bedrock/src/models.rs#L297). 
- -For the most up-to-date supported regions and models, refer to the [Supported Models and Regions for Cross Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). - -### Anthropic {#anthropic} - -> ✅ Supports tool use - -You can use Anthropic models by choosing it via the model dropdown in the Agent Panel. - -1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys) -2. Make sure that your Anthropic account has credits -3. Open the settings view (`agent: open configuration`) and go to the Anthropic section -4. Enter your Anthropic API key - -Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API. - -Zed will also use the `ANTHROPIC_API_KEY` environment variable if it's defined. - -#### Custom Models {#anthropic-custom-models} - -You can add custom models to the Anthropic provider by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "anthropic": { - "available_models": [ - { - "name": "claude-3-5-sonnet-20240620", - "display_name": "Sonnet 2024-June", - "max_tokens": 128000, - "max_output_tokens": 2560, - "cache_configuration": { - "max_cache_anchors": 10, - "min_total_token": 10000, - "should_speculate": false - }, - "tool_override": "some-model-that-supports-toolcalling" - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. - -You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it) by changing the mode in your model's configuration to `thinking`, for example: - -```json -{ - "name": "claude-sonnet-4-latest", - "display_name": "claude-sonnet-4-thinking", - "max_tokens": 200000, - "mode": { - "type": "thinking", - "budget_tokens": 4_096 - } -} -``` - -### DeepSeek {#deepseek} - -> ✅ Supports tool use - -1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys) -2. Open the settings view (`agent: open configuration`) and go to the DeepSeek section -3. Enter your DeepSeek API key - -The DeepSeek API key will be saved in your keychain. - -Zed will also use the `DEEPSEEK_API_KEY` environment variable if it's defined. - -#### Custom Models {#deepseek-custom-models} - -The Zed agent comes pre-configured to use the latest version for common models (DeepSeek Chat, DeepSeek Reasoner). -If you wish to use alternate models or customize the API endpoint, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "deepseek": { - "api_url": "https://api.deepseek.com", - "available_models": [ - { - "name": "deepseek-chat", - "display_name": "DeepSeek Chat", - "max_tokens": 64000 - }, - { - "name": "deepseek-reasoner", - "display_name": "DeepSeek Reasoner", - "max_tokens": 64000, - "max_output_tokens": 4096 - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. -You can also modify the `api_url` to use a custom endpoint if needed. - -### GitHub Copilot Chat {#github-copilot-chat} - -> ✅ Supports tool use in some cases. -> Visit [the Copilot Chat code](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198) for the supported subset. 
- -You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. - -1. Open the settings view (`agent: open configuration`) and go to the GitHub Copilot Chat section -2. Click on `Sign in to use GitHub Copilot`, follow the steps shown in the modal. - -Alternatively, you can provide an OAuth token via the `GH_COPILOT_TOKEN` environment variable. - -> **Note**: If you don't see specific models in the dropdown, you may need to enable them in your [GitHub Copilot settings](https://github.com/settings/copilot/features). - -To use Copilot Enterprise with Zed (for both agent and inline completions), you must configure your enterprise endpoint as described in [Configuring GitHub Copilot Enterprise](./edit-prediction.md#github-copilot-enterprise). - -### Google AI {#google-ai} - -> ✅ Supports tool use - -You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. - -1. Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). -2. Open the settings view (`agent: open configuration`) and go to the Google AI section -3. Enter your Google AI API key and press enter. - -The Google AI API key will be saved in your keychain. - -Zed will also use the `GEMINI_API_KEY` environment variable if it's defined. See [Using Gemini API keys](Using Gemini API keys) in the Gemini docs for more. - -#### Custom Models {#google-ai-custom-models} - -By default, Zed will use `stable` versions of models, but you can use specific versions of models, including [experimental models](https://ai.google.dev/gemini-api/docs/models/experimental-models). You can configure a model to use [thinking mode](https://ai.google.dev/gemini-api/docs/thinking) (if it supports it) by adding a `mode` configuration to your model. This is useful for controlling reasoning token usage and response speed. If not specified, Gemini will automatically choose the thinking budget. - -Here is an example of a custom Google AI model you could add to your Zed `settings.json`: - -```json -{ - "language_models": { - "google": { - "available_models": [ - { - "name": "gemini-2.5-flash-preview-05-20", - "display_name": "Gemini 2.5 Flash (Thinking)", - "max_tokens": 1000000, - "mode": { - "type": "thinking", - "budget_tokens": 24000 - } - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. - -### LM Studio {#lmstudio} - -> ✅ Supports tool use - -1. Download and install [the latest version of LM Studio](https://lmstudio.ai/download) -2. In the app press `cmd/ctrl-shift-m` and download at least one model (e.g., qwen2.5-coder-7b). Alternatively, you can get models via the LM Studio CLI: - - ```sh - lms get qwen2.5-coder-7b - ``` - -3. Make sure the LM Studio API server is running by executing: - - ```sh - lms server start - ``` - -Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless#run-the-llm-service-on-machine-login) to automate running the LM Studio server. - -### Mistral {#mistral} - -> ✅ Supports tool use - -1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) -2. Open the configuration view (`agent: open configuration`) and navigate to the Mistral section -3. Enter your Mistral API key - -The Mistral API key will be saved in your keychain. - -Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined. 
- -#### Custom Models {#mistral-custom-models} - -The Zed agent comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). -All the default models support tool use. -If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "mistral": { - "api_url": "https://api.mistral.ai/v1", - "available_models": [ - { - "name": "mistral-tiny-latest", - "display_name": "Mistral Tiny", - "max_tokens": 32000, - "max_output_tokens": 4096, - "max_completion_tokens": 1024, - "supports_tools": true, - "supports_images": false - } - ] - } - } -} -``` - -Custom models will be listed in the model dropdown in the Agent Panel. - -### Ollama {#ollama} - -> ✅ Supports tool use - -Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`. - -1. Download one of the [available models](https://ollama.com/models), for example, for `mistral`: - - ```sh - ollama pull mistral - ``` - -2. Make sure that the Ollama server is running. You can start it either via running Ollama.app (macOS) or launching: - - ```sh - ollama serve - ``` - -3. In the Agent Panel, select one of the Ollama models using the model dropdown. - -#### Ollama Context Length {#ollama-context} - -Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. -Zed API requests to Ollama include this as the `num_ctx` parameter, but the default values do not exceed `16384` so users with ~16GB of RAM are able to use most models out of the box. - -See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults. - -> **Note**: Token counts displayed in the Agent Panel are only estimates and will differ from the model's native tokenizer. - -Depending on your hardware or use-case you may wish to limit or increase the context length for a specific model via settings.json: - -```json -{ - "language_models": { - "ollama": { - "api_url": "http://localhost:11434", - "available_models": [ - { - "name": "qwen2.5-coder", - "display_name": "qwen 2.5 coder 32K", - "max_tokens": 32768, - "supports_tools": true, - "supports_thinking": true, - "supports_images": true - } - ] - } - } -} -``` - -If you specify a context length that is too large for your hardware, Ollama will log an error. -You can watch these logs by running: `tail -f ~/.ollama/logs/ollama.log` (macOS) or `journalctl -u ollama -f` (Linux). -Depending on the memory available on your machine, you may need to adjust the context length to a smaller value. - -You may also optionally specify a value for `keep_alive` for each available model. -This can be an integer (seconds) or alternatively a string duration like "5m", "10m", "1h", "1d", etc. -For example, `"keep_alive": "120s"` will allow the remote server to unload the model (freeing up GPU VRAM) after 120 seconds. - -The `supports_tools` option controls whether the model will use additional tools. -If the model is tagged with `tools` in the Ollama catalog, this option should be supplied, and the built-in profiles `Ask` and `Write` can be used. 
-If the model is not tagged with `tools` in the Ollama catalog, this option can still be supplied with the value `true`; however, be aware that only the `Minimal` built-in profile will work. - -The `supports_thinking` option controls whether the model will perform an explicit "thinking" (reasoning) pass before producing its final answer. -If the model is tagged with `thinking` in the Ollama catalog, set this option and you can use it in Zed. - -The `supports_images` option enables the model's vision capabilities, allowing it to process images included in the conversation context. -If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed. - -### OpenAI {#openai} - -> ✅ Supports tool use - -1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) -2. Make sure that your OpenAI account has credits -3. Open the settings view (`agent: open configuration`) and go to the OpenAI section -4. Enter your OpenAI API key - -The OpenAI API key will be saved in your keychain. - -Zed will also use the `OPENAI_API_KEY` environment variable if it's defined. - -#### Custom Models {#openai-custom-models} - -The Zed agent comes pre-configured to use the latest version for common models (GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, GPT-4o, GPT-4o mini). -To use alternate models, perhaps a preview release or a dated model release, or if you wish to control the request parameters, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "openai": { - "available_models": [ - { - "name": "gpt-4o-2024-08-06", - "display_name": "GPT 4o Summer 2024", - "max_tokens": 128000 - }, - { - "name": "o1-mini", - "display_name": "o1-mini", - "max_tokens": 128000, - "max_completion_tokens": 20000 - } - ], - "version": "1" - } - } -} -``` - -You must provide the model's context window in the `max_tokens` parameter; this can be found in the [OpenAI model documentation](https://platform.openai.com/docs/models). - -OpenAI `o1` models should set `max_completion_tokens` as well to avoid incurring high reasoning token costs. -Custom models will be listed in the model dropdown in the Agent Panel. - -### OpenAI API Compatible {#openai-api-compatible} - -Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models. - -To configure a compatible API, you can add a custom API URL for OpenAI either via the UI (currently available only in Preview) or by editing your `settings.json`. - -For example, to connect to [Together AI](https://www.together.ai/) via the UI: - -1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys). -2. Go to the Agent Panel's settings view, click on the "Add Provider" button, and then on the "OpenAI" menu item -3. 
Add the requested fields, such as `api_url`, `api_key`, available models, and others - -Alternatively, you can also add it via the `settings.json`: - -```json -{ - "language_models": { - "openai": { - "api_url": "https://api.together.xyz/v1", - "api_key": "YOUR_TOGETHER_AI_API_KEY", - "available_models": [ - { - "name": "mistralai/Mixtral-8x7B-Instruct-v0.1", - "display_name": "Together Mixtral 8x7B", - "max_tokens": 32768, - "supports_tools": true - } - ] - } - } -} -``` - -### OpenRouter {#openrouter} - -> ✅ Supports tool use - -OpenRouter provides access to multiple AI models through a single API. It supports tool use for compatible models. - -1. Visit [OpenRouter](https://openrouter.ai) and create an account -2. Generate an API key from your [OpenRouter keys page](https://openrouter.ai/keys) -3. Open the settings view (`agent: open configuration`) and go to the OpenRouter section -4. Enter your OpenRouter API key - -The OpenRouter API key will be saved in your keychain. - -Zed will also use the `OPENROUTER_API_KEY` environment variable if it's defined. - -#### Custom Models {#openrouter-custom-models} - -You can add custom models to the OpenRouter provider by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "open_router": { - "api_url": "https://openrouter.ai/api/v1", - "available_models": [ - { - "name": "google/gemini-2.0-flash-thinking-exp", - "display_name": "Gemini 2.0 Flash (Thinking)", - "max_tokens": 200000, - "max_output_tokens": 8192, - "supports_tools": true, - "supports_images": true, - "mode": { - "type": "thinking", - "budget_tokens": 8000 - } - } - ] - } - } -} -``` - -The available configuration options for each model are: - -- `name` (required): The model identifier used by OpenRouter -- `display_name` (optional): A human-readable name shown in the UI -- `max_tokens` (required): The model's context window size -- `max_output_tokens` (optional): Maximum tokens the model can generate -- `max_completion_tokens` (optional): Maximum completion tokens -- `supports_tools` (optional): Whether the model supports tool/function calling -- `supports_images` (optional): Whether the model supports image inputs -- `mode` (optional): Special mode configuration for thinking models - -You can find available models and their specifications on the [OpenRouter models page](https://openrouter.ai/models). - -Custom models will be listed in the model dropdown in the Agent Panel. - -### Vercel v0 {#vercel-v0} - -> ✅ Supports tool use - -[Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel. -It supports text and image inputs and provides fast streaming responses. - -The v0 models are [OpenAI-compatible models](/#openai-api-compatible), but Vercel is listed as first-class provider in the panel's settings view. - -To start using it with Zed, ensure you have first created a [v0 API key](https://v0.dev/chat/settings/keys). -Once you have it, paste it directly into the Vercel provider section in the panel's settings view. - -You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. - -### xAI {#xai} - -> ✅ Supports tool use - -Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. - -1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) -2. 
Open the settings view (`agent: open configuration`) and go to the **xAI** section -3. Enter your xAI API key - -The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined. - -> **Note:** While the xAI API is OpenAI-compatible, Zed has first-class support for it as a dedicated provider. For the best experience, we recommend using the dedicated `x_ai` provider configuration instead of the [OpenAI API Compatible](#openai-api-compatible) method. - -#### Custom Models {#xai-custom-models} - -The Zed agent comes pre-configured with common Grok models. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: - -```json -{ - "language_models": { - "x_ai": { - "api_url": "https://api.x.ai/v1", - "available_models": [ - { - "name": "grok-1.5", - "display_name": "Grok 1.5", - "max_tokens": 131072, - "max_output_tokens": 8192 - }, - { - "name": "grok-1.5v", - "display_name": "Grok 1.5V (Vision)", - "max_tokens": 131072, - "max_output_tokens": 8192, - "supports_images": true - } - ] - } - } -} -``` - -## Advanced Configuration {#advanced-configuration} - -### Custom Provider Endpoints {#custom-provider-endpoint} - -You can use a custom API endpoint for different providers, as long as it's compatible with the provider's API structure. -To do so, add the following to your `settings.json`: - -```json -{ - "language_models": { - "some-provider": { - "api_url": "http://localhost:11434" - } - } -} -``` - -Where `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`. - -### Default Model {#default-model} - -Zed's hosted LLM service sets `claude-sonnet-4` as the default model. -However, you can change it either via the model dropdown in the Agent Panel's bottom-right corner or by manually editing the `default_model` object in your settings: - -```json -{ - "agent": { - "version": "2", - "default_model": { - "provider": "zed.dev", - "model": "gpt-4o" - } - } -} -``` - -### Feature-specific Models {#feature-specific-models} - -If a feature-specific model is not set, it will fall back to using the default model, which is the one you set on the Agent Panel. - -You can configure the following feature-specific models: - -- Thread summary model: Used for generating thread summaries -- Inline assistant model: Used for the inline assistant feature -- Commit message model: Used for generating Git commit messages - -Example configuration: - -```json -{ - "agent": { - "version": "2", - "default_model": { - "provider": "zed.dev", - "model": "claude-sonnet-4" - }, - "inline_assistant_model": { - "provider": "anthropic", - "model": "claude-3-5-sonnet" - }, - "commit_message_model": { - "provider": "openai", - "model": "gpt-4o-mini" - }, - "thread_summary_model": { - "provider": "google", - "model": "gemini-2.0-flash" - } - } -} -``` - -### Alternative Models for Inline Assists {#alternative-assists} - -You can configure additional models that will be used to perform inline assists in parallel. -When you do this, the inline assist UI will surface controls to cycle between the alternatives generated by each model. - -The models you specify here are always used in _addition_ to your [default model](#default-model). -For example, the following configuration will generate two outputs for every assist. -One with Claude 3.7 Sonnet, and one with GPT-4o. 
- -```json -{ - "agent": { - "default_model": { - "provider": "zed.dev", - "model": "claude-sonnet-4" - }, - "inline_alternatives": [ - { - "provider": "zed.dev", - "model": "gpt-4o" - } - ], - "version": "2" - } -} -``` - -### Default View - -Use the `default_view` setting to set change the default view of the Agent Panel. -You can choose between `thread` (the default) and `text_thread`: - -```json -{ - "agent": { - "default_view": "text_thread" - } -} -``` - -### Edit Card - -Use the `expand_edit_card` setting to control whether edit cards show the full diff in the Agent Panel. -It is set to `true` by default, but if set to false, the card's height is capped to a certain number of lines, requiring a click to be expanded. - -```json -{ - "agent": { - "expand_edit_card": "false" - } -} -``` - -This setting is currently only available in Preview. -It should be up in Stable by the next release. - -### Terminal Card - -Use the `expand_terminal_card` setting to control whether terminal cards show the command output in the Agent Panel. -It is set to `true` by default, but if set to false, the card will be fully collapsed even while the command is running, requiring a click to be expanded. - -```json -{ - "agent": { - "expand_terminal_card": "false" - } -} -``` - -This setting is currently only available in Preview. -It should be up in Stable by the next release. +Read [the following blog post](https://zed.dev/blog/disable-ai-features) to learn more about our motivation to promote this, as much as we also encourage users to explore AI-assisted programming. diff --git a/docs/src/ai/inline-assistant.md b/docs/src/ai/inline-assistant.md index cd0ace3ce6..da894e2cd8 100644 --- a/docs/src/ai/inline-assistant.md +++ b/docs/src/ai/inline-assistant.md @@ -12,7 +12,7 @@ You can also perform multiple generation requests in parallel by pressing `ctrl- Give the Inline Assistant context the same way you can in [the Agent Panel](./agent-panel.md), allowing you to provide additional instructions or rules for code transformations with @-mentions. -A useful pattern here is to create a thread in the Agent Panel, and then use the mention that thread with `@thread` in the Inline Assistant to include it as context. +A useful pattern here is to create a thread in the Agent Panel, and then mention that thread with `@thread` in the Inline Assistant to include it as context. > The Inline Assistant is limited to normal mode context windows ([see Models](./models.md) for more). diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md new file mode 100644 index 0000000000..8fdb7ea325 --- /dev/null +++ b/docs/src/ai/llm-providers.md @@ -0,0 +1,583 @@ +# LLM Providers + +To use AI in Zed, you need to have at least one large language model provider set up. + +You can do that by either subscribing to [one of Zed's plans](./plans-and-usage.md), or by using API keys you already have for the supported providers. + +## Use Your Own Keys {#use-your-own-keys} + +If you already have an API key for an existing LLM provider—say Anthropic or OpenAI, for example—you can insert them in Zed and use the Agent Panel **_for free_**. + +You can add your API key to a given provider either via the Agent Panel's settings UI or directly via the `settings.json` through the `language_models` key. 
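+
+For illustration, here's the general shape of a provider entry under the `language_models` key. The provider name and fields in this sketch are placeholders; each provider's section below lists the exact options it supports:
+
+```json
+{
+  "language_models": {
+    "some_provider": {
+      "api_url": "https://example.com/v1",
+      "available_models": [{ "name": "example-model", "max_tokens": 32000 }]
+    }
+  }
+}
+```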
+ +## Supported Providers + +Here's all the supported LLM providers for which you can use your own API keys: + +| Provider | +| ----------------------------------------------- | +| [Amazon Bedrock](#amazon-bedrock) | +| [Anthropic](#anthropic) | +| [DeepSeek](#deepseek) | +| [GitHub Copilot Chat](#github-copilot-chat) | +| [Google AI](#google-ai) | +| [LM Studio](#lmstudio) | +| [Mistral](#mistral) | +| [Ollama](#ollama) | +| [OpenAI](#openai) | +| [OpenAI API Compatible](#openai-api-compatible) | +| [OpenRouter](#openrouter) | +| [Vercel](#vercel-v0) | +| [xAI](#xai) | + +### Amazon Bedrock {#amazon-bedrock} + +> Supports tool use with models that support streaming tool use. +> More details can be found in the [Amazon Bedrock's Tool Use documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html). + +To use Amazon Bedrock's models, an AWS authentication is required. +Ensure your credentials have the following permissions set up: + +- `bedrock:InvokeModelWithResponseStream` +- `bedrock:InvokeModel` +- `bedrock:ConverseStream` + +Your IAM policy should look similar to: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModel", + "bedrock:InvokeModelWithResponseStream", + "bedrock:ConverseStream" + ], + "Resource": "*" + } + ] +} +``` + +With that done, choose one of the two authentication methods: + +#### Authentication via Named Profile (Recommended) + +1. Ensure you have the AWS CLI installed and configured with a named profile +2. Open your `settings.json` (`zed: open settings`) and include the `bedrock` key under `language_models` with the following settings: + ```json + { + "language_models": { + "bedrock": { + "authentication_method": "named_profile", + "region": "your-aws-region", + "profile": "your-profile-name" + } + } + } + ``` + +#### Authentication via Static Credentials + +While it's possible to configure through the Agent Panel settings UI by entering your AWS access key and secret directly, we recommend using named profiles instead for better security practices. +To do this: + +1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). +2. Create security credentials for that User, save them and keep them secure. +3. Open the Agent Configuration with (`agent: open settings`) and go to the Amazon Bedrock section +4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. + +#### Cross-Region Inference + +The Zed implementation of Amazon Bedrock uses [Cross-Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html) for all the models and region combinations that support it. +With Cross-Region inference, you can distribute traffic across multiple AWS Regions, enabling higher throughput. + +For example, if you use `Claude Sonnet 3.7 Thinking` from `us-east-1`, it may be processed across the US regions, namely: `us-east-1`, `us-east-2`, or `us-west-2`. +Cross-Region inference requests are kept within the AWS Regions that are part of the geography where the data originally resides. +For example, a request made within the US is kept within the AWS Regions in the US. + +Although the data remains stored only in the source Region, your input prompts and output results might move outside of your source Region during cross-Region inference. 
+All data will be transmitted encrypted across Amazon's secure network.
+
+We will support Cross-Region inference for each of the models on a best-effort basis; please refer to the [Cross-Region Inference method code](https://github.com/zed-industries/zed/blob/main/crates/bedrock/src/models.rs#L297).
+
+For the most up-to-date supported regions and models, refer to the [Supported Models and Regions for Cross Region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html).
+
+### Anthropic {#anthropic}
+
+You can use Anthropic models by choosing them via the model dropdown in the Agent Panel.
+
+1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys)
+2. Make sure that your Anthropic account has credits
+3. Open the settings view (`agent: open settings`) and go to the Anthropic section
+4. Enter your Anthropic API key
+
+Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API.
+
+Zed will also use the `ANTHROPIC_API_KEY` environment variable if it's defined.
+
+#### Custom Models {#anthropic-custom-models}
+
+You can add custom models to the Anthropic provider by adding the following to your Zed `settings.json`:
+
+```json
+{
+  "language_models": {
+    "anthropic": {
+      "available_models": [
+        {
+          "name": "claude-3-5-sonnet-20240620",
+          "display_name": "Sonnet 2024-June",
+          "max_tokens": 128000,
+          "max_output_tokens": 2560,
+          "cache_configuration": {
+            "max_cache_anchors": 10,
+            "min_total_token": 10000,
+            "should_speculate": false
+          },
+          "tool_override": "some-model-that-supports-toolcalling"
+        }
+      ]
+    }
+  }
+}
+```
+
+Custom models will be listed in the model dropdown in the Agent Panel.
+
+You can configure a model to use [extended thinking](https://docs.anthropic.com/en/docs/about-claude/models/extended-thinking-models) (if it supports it) by changing the mode in your model's configuration to `thinking`, for example:
+
+```json
+{
+  "name": "claude-sonnet-4-latest",
+  "display_name": "claude-sonnet-4-thinking",
+  "max_tokens": 200000,
+  "mode": {
+    "type": "thinking",
+    "budget_tokens": 4096
+  }
+}
+```
+
+### DeepSeek {#deepseek}
+
+1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys)
+2. Open the settings view (`agent: open settings`) and go to the DeepSeek section
+3. Enter your DeepSeek API key
+
+The DeepSeek API key will be saved in your keychain.
+
+Zed will also use the `DEEPSEEK_API_KEY` environment variable if it's defined.
+
+#### Custom Models {#deepseek-custom-models}
+
+The Zed agent comes pre-configured to use the latest version for common models (DeepSeek Chat, DeepSeek Reasoner).
+If you wish to use alternate models or customize the API endpoint, you can do so by adding the following to your Zed `settings.json`:
+
+```json
+{
+  "language_models": {
+    "deepseek": {
+      "api_url": "https://api.deepseek.com",
+      "available_models": [
+        {
+          "name": "deepseek-chat",
+          "display_name": "DeepSeek Chat",
+          "max_tokens": 64000
+        },
+        {
+          "name": "deepseek-reasoner",
+          "display_name": "DeepSeek Reasoner",
+          "max_tokens": 64000,
+          "max_output_tokens": 4096
+        }
+      ]
+    }
+  }
+}
+```
+
+Custom models will be listed in the model dropdown in the Agent Panel.
+You can also modify the `api_url` to use a custom endpoint if needed.
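+
+As a quick sketch, overriding only the `api_url` might look like the following, where the localhost URL is a stand-in for whatever proxy or self-hosted endpoint you actually use:
+
+```json
+{
+  "language_models": {
+    "deepseek": {
+      "api_url": "http://localhost:8080/v1"
+    }
+  }
+}
+```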
+ +### GitHub Copilot Chat {#github-copilot-chat} + +You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. + +1. Open the settings view (`agent: open settings`) and go to the GitHub Copilot Chat section +2. Click on `Sign in to use GitHub Copilot`, follow the steps shown in the modal. + +Alternatively, you can provide an OAuth token via the `GH_COPILOT_TOKEN` environment variable. + +> **Note**: If you don't see specific models in the dropdown, you may need to enable them in your [GitHub Copilot settings](https://github.com/settings/copilot/features). + +To use Copilot Enterprise with Zed (for both agent and completions), you must configure your enterprise endpoint as described in [Configuring GitHub Copilot Enterprise](./edit-prediction.md#github-copilot-enterprise). + +### Google AI {#google-ai} + +You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. + +1. Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). +2. Open the settings view (`agent: open settings`) and go to the Google AI section +3. Enter your Google AI API key and press enter. + +The Google AI API key will be saved in your keychain. + +Zed will also use the `GEMINI_API_KEY` environment variable if it's defined. See [Using Gemini API keys](https://ai.google.dev/gemini-api/docs/api-key) in the Gemini docs for more. + +#### Custom Models {#google-ai-custom-models} + +By default, Zed will use `stable` versions of models, but you can use specific versions of models, including [experimental models](https://ai.google.dev/gemini-api/docs/models/experimental-models). You can configure a model to use [thinking mode](https://ai.google.dev/gemini-api/docs/thinking) (if it supports it) by adding a `mode` configuration to your model. This is useful for controlling reasoning token usage and response speed. If not specified, Gemini will automatically choose the thinking budget. + +Here is an example of a custom Google AI model you could add to your Zed `settings.json`: + +```json +{ + "language_models": { + "google": { + "available_models": [ + { + "name": "gemini-2.5-flash-preview-05-20", + "display_name": "Gemini 2.5 Flash (Thinking)", + "max_tokens": 1000000, + "mode": { + "type": "thinking", + "budget_tokens": 24000 + } + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. + +### LM Studio {#lmstudio} + +1. Download and install [the latest version of LM Studio](https://lmstudio.ai/download) +2. In the app press `cmd/ctrl-shift-m` and download at least one model (e.g., qwen2.5-coder-7b). Alternatively, you can get models via the LM Studio CLI: + + ```sh + lms get qwen2.5-coder-7b + ``` + +3. Make sure the LM Studio API server is running by executing: + + ```sh + lms server start + ``` + +Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless#run-the-llm-service-on-machine-login) to automate running the LM Studio server. + +### Mistral {#mistral} + +1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) +2. Open the configuration view (`agent: open settings`) and navigate to the Mistral section +3. Enter your Mistral API key + +The Mistral API key will be saved in your keychain. + +Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined. 
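+
+For example, assuming you launch Zed from a terminal, setting the variable before starting Zed is enough for it to be picked up (the key below is a placeholder):
+
+```sh
+export MISTRAL_API_KEY=your-mistral-api-key
+zed
+```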
+ +#### Custom Models {#mistral-custom-models} + +The Zed agent comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). +All the default models support tool use. +If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "mistral": { + "api_url": "https://api.mistral.ai/v1", + "available_models": [ + { + "name": "mistral-tiny-latest", + "display_name": "Mistral Tiny", + "max_tokens": 32000, + "max_output_tokens": 4096, + "max_completion_tokens": 1024, + "supports_tools": true, + "supports_images": false + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the Agent Panel. + +### Ollama {#ollama} + +Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`. + +1. Download one of the [available models](https://ollama.com/models), for example, for `mistral`: + + ```sh + ollama pull mistral + ``` + +2. Make sure that the Ollama server is running. You can start it either via running Ollama.app (macOS) or launching: + + ```sh + ollama serve + ``` + +3. In the Agent Panel, select one of the Ollama models using the model dropdown. + +#### Ollama Context Length {#ollama-context} + +Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. +Zed API requests to Ollama include this as the `num_ctx` parameter, but the default values do not exceed `16384` so users with ~16GB of RAM are able to use most models out of the box. + +See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults. + +> **Note**: Token counts displayed in the Agent Panel are only estimates and will differ from the model's native tokenizer. + +Depending on your hardware or use-case you may wish to limit or increase the context length for a specific model via settings.json: + +```json +{ + "language_models": { + "ollama": { + "api_url": "http://localhost:11434", + "available_models": [ + { + "name": "qwen2.5-coder", + "display_name": "qwen 2.5 coder 32K", + "max_tokens": 32768, + "supports_tools": true, + "supports_thinking": true, + "supports_images": true + } + ] + } + } +} +``` + +If you specify a context length that is too large for your hardware, Ollama will log an error. +You can watch these logs by running: `tail -f ~/.ollama/logs/ollama.log` (macOS) or `journalctl -u ollama -f` (Linux). +Depending on the memory available on your machine, you may need to adjust the context length to a smaller value. + +You may also optionally specify a value for `keep_alive` for each available model. +This can be an integer (seconds) or alternatively a string duration like "5m", "10m", "1h", "1d", etc. +For example, `"keep_alive": "120s"` will allow the remote server to unload the model (freeing up GPU VRAM) after 120 seconds. + +The `supports_tools` option controls whether the model will use additional tools. +If the model is tagged with `tools` in the Ollama catalog, this option should be supplied, and the built-in profiles `Ask` and `Write` can be used. +If the model is not tagged with `tools` in the Ollama catalog, this option can still be supplied with the value `true`; however, be aware that only the `Minimal` built-in profile will work. 
+
+The `supports_thinking` option controls whether the model will perform an explicit "thinking" (reasoning) pass before producing its final answer.
+If the model is tagged with `thinking` in the Ollama catalog, set this option and you can use it in Zed.
+
+The `supports_images` option enables the model's vision capabilities, allowing it to process images included in the conversation context.
+If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed.
+
+### OpenAI {#openai}
+
+1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys)
+2. Make sure that your OpenAI account has credits
+3. Open the settings view (`agent: open settings`) and go to the OpenAI section
+4. Enter your OpenAI API key
+
+The OpenAI API key will be saved in your keychain.
+
+Zed will also use the `OPENAI_API_KEY` environment variable if it's defined.
+
+#### Custom Models {#openai-custom-models}
+
+The Zed agent comes pre-configured to use the latest version for common models (GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, GPT-4o, GPT-4o mini).
+To use alternate models, perhaps a preview release or a dated model release, or if you wish to control the request parameters, you can do so by adding the following to your Zed `settings.json`:
+
+```json
+{
+  "language_models": {
+    "openai": {
+      "available_models": [
+        {
+          "name": "gpt-4o-2024-08-06",
+          "display_name": "GPT 4o Summer 2024",
+          "max_tokens": 128000
+        },
+        {
+          "name": "o1-mini",
+          "display_name": "o1-mini",
+          "max_tokens": 128000,
+          "max_completion_tokens": 20000
+        }
+      ],
+      "version": "1"
+    }
+  }
+}
+```
+
+You must provide the model's context window in the `max_tokens` parameter; this can be found in the [OpenAI model documentation](https://platform.openai.com/docs/models).
+
+OpenAI `o1` models should set `max_completion_tokens` as well to avoid incurring high reasoning token costs.
+Custom models will be listed in the model dropdown in the Agent Panel.
+
+### OpenAI API Compatible {#openai-api-compatible}
+
+Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider.
+This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
+
+You can add a custom, OpenAI-compatible model either via the UI or by editing your `settings.json`.
+
+To do it via the UI, go to the Agent Panel settings (`agent: open settings`) and look for the "Add Provider" button to the right of the "LLM Providers" section title.
+Then, fill in the input fields available in the modal.
+
+To do it via your `settings.json`, add the following snippet under `language_models`:
+
+```json
+{
+  "language_models": {
+    "openai": {
+      "api_url": "https://api.together.xyz/v1", // Using Together AI as an example
+      "available_models": [
+        {
+          "name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+          "display_name": "Together Mixtral 8x7B",
+          "max_tokens": 32768
+        }
+      ]
+    }
+  }
+}
+```
+
+Note that LLM API keys aren't stored in your settings file.
+So, ensure the key is set in your environment variables (`OPENAI_API_KEY=`) so Zed can pick it up.
+
+### OpenRouter {#openrouter}
+
+OpenRouter provides access to multiple AI models through a single API. It supports tool use for compatible models.
+
+1. Visit [OpenRouter](https://openrouter.ai) and create an account
+2. Generate an API key from your [OpenRouter keys page](https://openrouter.ai/keys)
+3. Open the settings view (`agent: open settings`) and go to the OpenRouter section
+4. Enter your OpenRouter API key
+
+The OpenRouter API key will be saved in your keychain.
+
+Zed will also use the `OPENROUTER_API_KEY` environment variable if it's defined.
+
+#### Custom Models {#openrouter-custom-models}
+
+You can add custom models to the OpenRouter provider by adding the following to your Zed `settings.json`:
+
+```json
+{
+  "language_models": {
+    "open_router": {
+      "api_url": "https://openrouter.ai/api/v1",
+      "available_models": [
+        {
+          "name": "google/gemini-2.0-flash-thinking-exp",
+          "display_name": "Gemini 2.0 Flash (Thinking)",
+          "max_tokens": 200000,
+          "max_output_tokens": 8192,
+          "supports_tools": true,
+          "supports_images": true,
+          "mode": {
+            "type": "thinking",
+            "budget_tokens": 8000
+          }
+        }
+      ]
+    }
+  }
+}
+```
+
+The available configuration options for each model are:
+
+- `name` (required): The model identifier used by OpenRouter
+- `display_name` (optional): A human-readable name shown in the UI
+- `max_tokens` (required): The model's context window size
+- `max_output_tokens` (optional): Maximum tokens the model can generate
+- `max_completion_tokens` (optional): Maximum completion tokens
+- `supports_tools` (optional): Whether the model supports tool/function calling
+- `supports_images` (optional): Whether the model supports image inputs
+- `mode` (optional): Special mode configuration for thinking models
+
+You can find available models and their specifications on the [OpenRouter models page](https://openrouter.ai/models).
+
+Custom models will be listed in the model dropdown in the Agent Panel.
+
+### Vercel v0 {#vercel-v0}
+
+[Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel.
+It supports text and image inputs and provides fast streaming responses.
+
+The v0 models are [OpenAI-compatible models](#openai-api-compatible), but Vercel is listed as a first-class provider in the panel's settings view.
+
+To start using it with Zed, ensure you have first created a [v0 API key](https://v0.dev/chat/settings/keys).
+Once you have it, paste it directly into the Vercel provider section in the panel's settings view.
+
+You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel.
+
+### xAI {#xai}
+
+Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models.
+
+1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys)
+2. Open the settings view (`agent: open settings`) and go to the **xAI** section
+3. Enter your xAI API key
+
+The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined.
+
+> **Note:** While the xAI API is OpenAI-compatible, Zed has first-class support for it as a dedicated provider. For the best experience, we recommend using the dedicated `x_ai` provider configuration instead of the [OpenAI API Compatible](#openai-api-compatible) method.
+
+#### Custom Models {#xai-custom-models}
+
+The Zed agent comes pre-configured with common Grok models.
If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "x_ai": { + "api_url": "https://api.x.ai/v1", + "available_models": [ + { + "name": "grok-1.5", + "display_name": "Grok 1.5", + "max_tokens": 131072, + "max_output_tokens": 8192 + }, + { + "name": "grok-1.5v", + "display_name": "Grok 1.5V (Vision)", + "max_tokens": 131072, + "max_output_tokens": 8192, + "supports_images": true + } + ] + } + } +} +``` + +## Custom Provider Endpoints {#custom-provider-endpoint} + +You can use a custom API endpoint for different providers, as long as it's compatible with the provider's API structure. +To do so, add the following to your `settings.json`: + +```json +{ + "language_models": { + "some-provider": { + "api_url": "http://localhost:11434" + } + } +} +``` + +Currently, `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`. + +This is the same infrastructure that powers models that are, for example, [OpenAI-compatible](#openai-api-compatible). diff --git a/docs/src/ai/mcp.md b/docs/src/ai/mcp.md index 95929b2d7e..dfe3e4bdb9 100644 --- a/docs/src/ai/mcp.md +++ b/docs/src/ai/mcp.md @@ -50,7 +50,7 @@ You can connect them by adding their commands directly to your `settings.json`, } ``` -Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open configuration` action). +Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open settings` action). From there, you can add it through the modal that appears when you click the "Add Custom Server" button. ## Using MCP Servers @@ -75,7 +75,7 @@ Mentioning your MCP server by name helps the agent pick it up. If you want to ensure a given server will be used, you can create [a custom profile](./agent-panel.md#custom-profiles) by turning off the built-in tools (either all of them or the ones that would cause conflicts) and turning on only the tools coming from the MCP server. -As an example, [the Dagger team suggests](https://container-use.com/agent-integrations#add-container-use-agent-profile-optional) doing that with their [Container Use MCP server](https://zed.dev/extensions/container-use-mcp-server): +As an example, [the Dagger team suggests](https://container-use.com/agent-integrations#add-container-use-agent-profile-optional) doing that with their [Container Use MCP server](https://zed.dev/extensions/mcp-server-container-use): ```json "agent": { diff --git a/docs/src/ai/overview.md b/docs/src/ai/overview.md index f437b24ba6..6f081cb243 100644 --- a/docs/src/ai/overview.md +++ b/docs/src/ai/overview.md @@ -1,15 +1,12 @@ # AI -Zed smoothly integrates LLMs in multiple ways across the editor. -Learn how to get started with AI on Zed and all its capabilities. +Learn how to get started using AI with Zed and all its capabilities. ## Setting up AI in Zed - [Configuration](./configuration.md): Learn how to set up different language model providers like Anthropic, OpenAI, Ollama, Google AI, and more. -- [Models](./models.md): Learn about the various language models available in Zed. - -- [Subscription](./subscription.md): Learn about Zed's subscriptions and other billing-related information. +- [Subscription](./subscription.md): Learn about Zed's hosted model service and other billing-related information. 
- [Privacy and Security](./privacy-and-security.md): Understand how Zed handles privacy and security with AI features. diff --git a/docs/src/ai/plans-and-usage.md b/docs/src/ai/plans-and-usage.md index a1da17f50d..1e6616c79b 100644 --- a/docs/src/ai/plans-and-usage.md +++ b/docs/src/ai/plans-and-usage.md @@ -11,7 +11,7 @@ Please note that if you’re interested in just using Zed as the world’s faste ## Usage {#usage} -- A `prompt` in Zed is an input from the user, initiated on pressing enter, composed of one or many `requests`. A `prompt` can be initiated from the Agent Panel, or via Inline Assist. +- A `prompt` in Zed is an input from the user, initiated by pressing enter, composed of one or many `requests`. A `prompt` can be initiated from the Agent Panel, or via Inline Assist. - A `request` in Zed is a response to a `prompt`, plus any tool calls that are initiated as part of that response. There may be one `request` per `prompt`, or many. Most models offered by Zed are metered per-prompt. diff --git a/docs/src/ai/rules.md b/docs/src/ai/rules.md index ed916874ca..653b907a7d 100644 --- a/docs/src/ai/rules.md +++ b/docs/src/ai/rules.md @@ -5,7 +5,7 @@ Currently, Zed supports `.rules` files at the directory's root and the Rules Lib ## `.rules` files -Zed supports including `.rules` files at the top level of worktrees, and act as project-level instructions that are included in all of your interactions with the Agent Panel. +Zed supports including `.rules` files at the top level of worktrees, and they act as project-level instructions that are included in all of your interactions with the Agent Panel. Other names for this file are also supported for compatibility with other agents, but note that the first file which matches in this list will be used: - `.rules` diff --git a/docs/src/ai/temperature.md b/docs/src/ai/temperature.md deleted file mode 100644 index bb0cef6b51..0000000000 --- a/docs/src/ai/temperature.md +++ /dev/null @@ -1,23 +0,0 @@ -# Model Temperature - -Zed's settings allow you to specify a custom temperature for a provider and/or model: - -```json -"model_parameters": [ - // To set parameters for all requests to OpenAI models: - { - "provider": "openai", - "temperature": 0.5 - }, - // To set parameters for all requests in general: - { - "temperature": 0 - }, - // To set parameters for a specific provider and model: - { - "provider": "zed.dev", - "model": "claude-sonnet-4", - "temperature": 1.0 - } - ], -``` diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index cc4800fd6d..5fd27abad6 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -2588,6 +2588,7 @@ List of `integer` column numbers "font_features": null, "font_size": null, "line_height": "comfortable", + "minimum_contrast": 45, "option_as_meta": false, "button": true, "shell": "system", @@ -2883,6 +2884,30 @@ See Buffer Font Features } ``` +### Terminal: Minimum Contrast + +- Description: Controls the minimum contrast between foreground and background colors in the terminal. Uses the APCA (Accessible Perceptual Contrast Algorithm) for color adjustments. Set this to 0 to disable this feature. +- Setting: `minimum_contrast` +- Default: `45` + +**Options** + +`integer` values from 0 to 106. 
Common recommended values: + +- `0`: No contrast adjustment +- `45`: Minimum for large fluent text (default) +- `60`: Minimum for other content text +- `75`: Minimum for body text +- `90`: Preferred for body text + +```json +{ + "terminal": { + "minimum_contrast": 45 + } +} +``` + ### Terminal: Option As Meta - Description: Re-interprets the option keys to act like a 'meta' key, like in Emacs. @@ -3390,26 +3415,7 @@ Run the `theme selector: toggle` action in the command palette to see a current ## Agent -- Description: Customize agent behavior -- Setting: `agent` -- Default: - -```json -"agent": { - "version": "2", - "enabled": true, - "button": true, - "dock": "right", - "default_width": 640, - "default_height": 320, - "default_view": "thread", - "default_model": { - "provider": "zed.dev", - "model": "claude-sonnet-4" - }, - "single_file_review": true, -} -``` +Visit [the Configuration page](./ai/configuration.md) under the AI section to learn more about all the agent-related settings. ## Outline Panel diff --git a/docs/src/development/debugging-crashes.md b/docs/src/development/debugging-crashes.md index d08ab961cc..ed0a5807a3 100644 --- a/docs/src/development/debugging-crashes.md +++ b/docs/src/development/debugging-crashes.md @@ -6,6 +6,7 @@ When an app crashes, - macOS creates a `.ips` file in `~/Library/Logs/DiagnosticReports`. You can view these using the built in Console app (`cmd-space Console`) under "Crash Reports". - Linux creates a core dump. See the [man pages](https://man7.org/linux/man-pages/man5/core.5.html) for pointers to how your system might be configured to manage core dumps. +- Windows doesn't create crash reports by default, but can be configured to create "minidump" memory dumps upon applications crashing. If you have enabled Zed's telemetry these will be uploaded to us when you restart the app. They end up in a [Slack channel (internal only)](https://zed-industries.slack.com/archives/C04S6T1T7TQ). diff --git a/docs/src/development/linux.md b/docs/src/development/linux.md index 6fff25f6c1..d7b586be34 100644 --- a/docs/src/development/linux.md +++ b/docs/src/development/linux.md @@ -91,7 +91,7 @@ Zed has two main binaries: - You will need to build `crates/cli` and make its binary available in `$PATH` with the name `zed`. - You will need to build `crates/zed` and put it at `$PATH/to/cli/../../libexec/zed-editor`. For example, if you are going to put the cli at `~/.local/bin/zed` put zed at `~/.local/libexec/zed-editor`. As some linux distributions (notably Arch) discourage the use of `libexec`, you can also put this binary at `$PATH/to/cli/../../lib/zed/zed-editor` (e.g. `~/.local/lib/zed/zed-editor`) instead. -- If you are going to provide a `.desktop` file you can find a template in `crates/zed/resources/zed.desktop.in`, and use `envsubst` to populate it with the values required. This file should also be renamed to `$APP_ID.desktop` so that the file [follows the FreeDesktop standards](https://github.com/zed-industries/zed/issues/12707#issuecomment-2168742761). +- If you are going to provide a `.desktop` file you can find a template in `crates/zed/resources/zed.desktop.in`, and use `envsubst` to populate it with the values required. This file should also be renamed to `$APP_ID.desktop` so that the file [follows the FreeDesktop standards](https://github.com/zed-industries/zed/issues/12707#issuecomment-2168742761). You should also make this desktop file executable (`chmod 755`). - You will need to ensure that the necessary libraries are installed. 
You can get the current list by [inspecting the built binary](https://github.com/zed-industries/zed/blob/935cf542aebf55122ce6ed1c91d0fe8711970c82/script/bundle-linux#L65-L67) on your system. - For an example of a complete build script, see [script/bundle-linux](https://github.com/zed-industries/zed/blob/935cf542aebf55122ce6ed1c91d0fe8711970c82/script/bundle-linux). - You can disable Zed's auto updates and provide instructions for users who try to update Zed manually by building (or running) Zed with the environment variable `ZED_UPDATE_EXPLANATION`. For example: `ZED_UPDATE_EXPLANATION="Please use flatpak to update zed."`. diff --git a/docs/src/development/local-collaboration.md b/docs/src/development/local-collaboration.md index 9f0e3ef191..eb7f3dfc43 100644 --- a/docs/src/development/local-collaboration.md +++ b/docs/src/development/local-collaboration.md @@ -1,13 +1,27 @@ # Local Collaboration -First, make sure you've installed Zed's dependencies for your platform: +1. Ensure you have access to our cloud infrastructure. If you don't have access, you can't collaborate locally at this time. -- [macOS](./macos.md#backend-dependencies) -- [Linux](./linux.md#backend-dependencies) -- [Windows](./windows.md#backend-dependencies) +2. Make sure you've installed Zed's dependencies for your platform: + +- [macOS](#macos) +- [Linux](#linux) +- [Windows](#backend-windows) Note that `collab` can be compiled only with MSVC toolchain on Windows +3. Clone down our cloud repository and follow the instructions in the cloud README + +4. Setup the local database for your platform: + +- [macOS & Linux](#database-unix) +- [Windows](#database-windows) + +5. Run collab: + +- [macOS & Linux](#run-collab-unix) +- [Windows](#run-collab-windows) + ## Backend Dependencies If you are developing collaborative features of Zed, you'll need to install the dependencies of zed's `collab` server: @@ -18,7 +32,7 @@ If you are developing collaborative features of Zed, you'll need to install the You can install these dependencies natively or run them under Docker. -### MacOS +### macOS 1. Install [Postgres.app](https://postgresapp.com) or [postgresql via homebrew](https://formulae.brew.sh/formula/postgresql@15): @@ -76,7 +90,7 @@ docker compose up -d Before you can run the `collab` server locally, you'll need to set up a `zed` Postgres database. -### On macOS and Linux +### On macOS and Linux {#database-unix} ```sh script/bootstrap @@ -99,7 +113,7 @@ To use a different set of admin users, you can create your own version of that j } ``` -### On Windows +### On Windows {#database-windows} ```powershell .\script\bootstrap.ps1 @@ -107,7 +121,7 @@ To use a different set of admin users, you can create your own version of that j ## Testing collaborative features locally -### On macOS and Linux +### On macOS and Linux {#run-collab-unix} Ensure that Postgres is configured and running, then run Zed's collaboration server and the `livekit` dev server: @@ -117,12 +131,16 @@ foreman start docker compose up ``` -Alternatively, if you're not testing voice and screenshare, you can just run `collab`, and not the `livekit` dev server: +Alternatively, if you're not testing voice and screenshare, you can just run `collab` and `cloud`, and not the `livekit` dev server: ```sh cargo run -p collab -- serve all ``` +```sh +cd ../cloud; cargo make dev +``` + In a new terminal, run two or more instances of Zed. ```sh @@ -131,7 +149,7 @@ script/zed-local -3 This script starts one to four instances of Zed, depending on the `-2`, `-3` or `-4` flags. 
Each instance will be connected to the local `collab` server, signed in as a different user from `.admins.json` or `.admins.default.json`. -### On Windows +### On Windows {#run-collab-windows} Since `foreman` is not available on Windows, you can run the following commands in separate terminals: @@ -151,6 +169,12 @@ Otherwise, .\path\to\livekit-serve.exe --dev ``` +You'll also need to start the cloud server: + +```powershell +cd ..\cloud; cargo make dev +``` + In a new terminal, run two or more instances of Zed. ```powershell @@ -161,7 +185,10 @@ Note that this requires `node.exe` to be in your `PATH`. ## Running a local collab server -If you want to run your own version of the zed collaboration service, you can, but note that this is still under development, and there is no good support for authentication nor extensions. +> [!NOTE] +> Because of recent changes to our authentication system, Zed will not be able to authenticate itself with, and therefore use, a local collab server. + +If you want to run your own version of the zed collaboration service, you can, but note that this is still under development, and there is no support for authentication nor extensions. Configuration is done through environment variables. By default it will read the configuration from [`.env.toml`](https://github.com/zed-industries/zed/blob/main/crates/collab/.env.toml) and you should use that as a guide for setting this up. diff --git a/docs/src/extensions/installing-extensions.md b/docs/src/extensions/installing-extensions.md index aed8bef428..801fe5c55c 100644 --- a/docs/src/extensions/installing-extensions.md +++ b/docs/src/extensions/installing-extensions.md @@ -1,6 +1,6 @@ # Installing Extensions -You can search for extensions by launching the Zed Extension Gallery by pressing `cmd-shift-x` (macOS) or `ctrl-shift-x` (Linux), opening the command palette and selecting `zed: extensions` or by selecting "Zed > Extensions" from the menu bar. +You can search for extensions by launching the Zed Extension Gallery by pressing {#kb zed::Extensions} , opening the command palette and selecting {#action zed::Extensions} or by selecting "Zed > Extensions" from the menu bar. Here you can view the extensions that you currently have installed or search and install new ones. diff --git a/docs/src/extensions/languages.md b/docs/src/extensions/languages.md index 44c673e3e1..6756cb8a23 100644 --- a/docs/src/extensions/languages.md +++ b/docs/src/extensions/languages.md @@ -402,11 +402,10 @@ If your language server supports additional languages, you can use `language_ids [language-servers.my-language-server] name = "Whatever LSP" -languages = ["JavaScript", "JSX", "HTML", "CSS"] +languages = ["JavaScript", "HTML", "CSS"] [language-servers.my-language-server.language_ids] "JavaScript" = "javascript" -"JSX" = "javascriptreact" "TSX" = "typescriptreact" "HTML" = "html" "CSS" = "css" diff --git a/docs/src/getting-started.md b/docs/src/getting-started.md index 5940c74b21..22af3b36d7 100644 --- a/docs/src/getting-started.md +++ b/docs/src/getting-started.md @@ -83,6 +83,6 @@ Visit [the AI overview page](./ai/overview.md) to learn how to quickly get start ## Set up your key bindings -To open your custom keymap to add your key bindings, use the {#kb zed::OpenKeymap} keybinding. 
+To edit your custom keymap and add or remap bindings, you can either use {#kb zed::OpenKeymapEditor} to spawn the Zed Keymap Editor ({#action zed::OpenKeymapEditor}) or you can directly open your Zed Keymap json (`~/.config/zed/keymap.json`) with {#action zed::OpenKeymap}. To access the default key binding set, open the Command Palette with {#kb command_palette::Toggle} and search for "zed: open default keymap". See [Key Bindings](./key-bindings.md) for more info. diff --git a/docs/src/git.md b/docs/src/git.md index 76db15a767..cccbad9b2e 100644 --- a/docs/src/git.md +++ b/docs/src/git.md @@ -1,3 +1,8 @@ +--- +description: Zed is a text editor that supports lots of Git features +title: Zed Editor Git integration documentation +--- + # Git Zed currently offers a set of fundamental Git features, with support coming in the future for more advanced ones, like conflict resolution tools, line by line staging, and more. @@ -76,7 +81,7 @@ You can ask AI to generate a commit message by focusing on the message editor wi > Note that you need to have an LLM provider configured. Visit [the AI configuration page](./ai/configuration.md) to learn how to do so. -You can specify your preferred model to use by providing a `commit_message_model` agent setting. See [Feature-specific models](./ai/configuration.md#feature-specific-models) for more information. +You can specify your preferred model to use by providing a `commit_message_model` agent setting. See [Feature-specific models](./ai/agent-settings.md#feature-specific-models) for more information. ```json { diff --git a/docs/src/key-bindings.md b/docs/src/key-bindings.md index 90aa400bb4..feed912787 100644 --- a/docs/src/key-bindings.md +++ b/docs/src/key-bindings.md @@ -18,7 +18,7 @@ You can also enable `vim_mode`, which adds vim bindings too. ## User keymaps -Zed reads your keymap from `~/.config/zed/keymap.json`. You can open the file within Zed with {#kb zed::OpenKeymap}, or via `zed: Open Keymap` in the command palette. +Zed reads your keymap from `~/.config/zed/keymap.json`. You can open the file within Zed with {#action zed::OpenKeymap} from the command palette or to spawn the Zed Keymap Editor ({#action zed::OpenKeymapEditor}) use {#kb zed::OpenKeymapEditor}. The file contains a JSON array of objects with `"bindings"`. If no `"context"` is set the bindings are always active. If it is set the binding is only active when the [context matches](#contexts). @@ -93,7 +93,7 @@ For example: # in an editor, it might look like this: Workspace os=macos keyboard_layout=com.apple.keylayout.QWERTY Pane - Editor mode=full extension=md inline_completion vim_mode=insert + Editor mode=full extension=md vim_mode=insert # in the project panel Workspace os=macos diff --git a/docs/src/languages/c.md b/docs/src/languages/c.md index 14a11c0d66..8db1bb6712 100644 --- a/docs/src/languages/c.md +++ b/docs/src/languages/c.md @@ -77,7 +77,7 @@ You can use CodeLLDB or GDB to debug native binaries. (Make sure that your build "command": "make", "args": ["-j8"], "cwd": "$ZED_WORKTREE_ROOT" - } + }, "program": "$ZED_WORKTREE_ROOT/build/prog", "request": "launch", "adapter": "CodeLLDB" diff --git a/docs/src/languages/cpp.md b/docs/src/languages/cpp.md index 1273bce2ac..e84bb6ea50 100644 --- a/docs/src/languages/cpp.md +++ b/docs/src/languages/cpp.md @@ -127,7 +127,7 @@ You can use CodeLLDB or GDB to debug native binaries. 
(Make sure that your build "command": "make", "args": ["-j8"], "cwd": "$ZED_WORKTREE_ROOT" - } + }, "program": "$ZED_WORKTREE_ROOT/build/prog", "request": "launch", "adapter": "CodeLLDB" diff --git a/docs/src/languages/deno.md b/docs/src/languages/deno.md index c18b112326..c40b6531e6 100644 --- a/docs/src/languages/deno.md +++ b/docs/src/languages/deno.md @@ -57,6 +57,40 @@ See [Configuring supported languages](../configuring-languages.md) in the Zed do TBD: Deno Typescript REPL instructions [docs/repl#typescript-deno](../repl.md#typescript-deno) --> +## DAP support + +To debug deno programs, add this to `.zed/debug.json` + +```json +[ + { + "adapter": "JavaScript", + "label": "Deno", + "request": "launch", + "type": "pwa-node", + "cwd": "$ZED_WORKTREE_ROOT", + "program": "$ZED_FILE", + "runtimeExecutable": "deno", + "runtimeArgs": ["run", "--allow-all", "--inspect-wait"], + "attachSimplePort": 9229 + } +] +``` + +## Runnable support + +To run deno tasks like tests from the ui, add this to `.zed/tasks.json` + +```json +[ + { + "label": "deno test", + "command": "deno test -A --filter '/^$ZED_CUSTOM_DENO_TEST_NAME$/' $ZED_FILE", + "tags": ["js-test"] + } +] +``` + ## See also: - [TypeScript](./typescript.md) diff --git a/docs/src/linux.md b/docs/src/linux.md index ca65da2969..309354de6d 100644 --- a/docs/src/linux.md +++ b/docs/src/linux.md @@ -294,3 +294,78 @@ If your system uses PipeWire: ``` 3. **Restart your system** + +### Forcing X11 scale factor + +On X11 systems, Zed automatically detects the appropriate scale factor for high-DPI displays. The scale factor is determined using the following priority order: + +1. `GPUI_X11_SCALE_FACTOR` environment variable (if set) +2. `Xft.dpi` from X resources database (xrdb) +3. Automatic detection via RandR based on monitor resolution and physical size + +If you want to customize the scale factor beyond what Zed detects automatically, you have several options: + +#### Check your current scale factor + +You can verify if you have `Xft.dpi` set: + +```sh +xrdb -query | grep Xft.dpi +``` + +If this command returns no output, Zed is using RandR (X11's monitor management extension) to automatically calculate the scale factor based on your monitor's reported resolution and physical dimensions. + +#### Option 1: Set Xft.dpi (X Resources Database) + +`Xft.dpi` is a standard X11 setting that many applications use for consistent font and UI scaling. Setting this ensures Zed scales the same way as other X11 applications that respect this setting. + +Edit or create the `~/.Xresources` file: + +```sh +vim ~/.Xresources +``` + +Add this line with your desired DPI: + +```sh +Xft.dpi: 96 +``` + +Common DPI values: + +- `96` for standard 1x scaling +- `144` for 1.5x scaling +- `192` for 2x scaling +- `288` for 3x scaling + +Load the configuration: + +```sh +xrdb -merge ~/.Xresources +``` + +Restart Zed for the changes to take effect. + +#### Option 2: Use the GPUI_X11_SCALE_FACTOR environment variable + +This Zed-specific environment variable directly sets the scale factor, bypassing all automatic detection. + +```sh +GPUI_X11_SCALE_FACTOR=1.5 zed +``` + +You can use decimal values (e.g., `1.25`, `1.5`, `2.0`) or set `GPUI_X11_SCALE_FACTOR=randr` to force RandR-based detection even when `Xft.dpi` is set. + +To make this permanent, add it to your shell profile or desktop entry. + +#### Option 3: Adjust system-wide RandR DPI + +This changes the reported DPI for your entire X11 session, affecting how RandR calculates scaling for all applications that use it. 
+ +Add this to your `.xprofile` or `.xinitrc`: + +```sh +xrandr --dpi 192 +``` + +Replace `192` with your desired DPI value. This affects the system globally and will be used by Zed's automatic RandR detection when `Xft.dpi` is not set. diff --git a/docs/src/telemetry.md b/docs/src/telemetry.md index 20018b920a..107aef5a96 100644 --- a/docs/src/telemetry.md +++ b/docs/src/telemetry.md @@ -21,19 +21,20 @@ The telemetry settings can also be configured via the welcome screen, which can Telemetry is sent from the application to our servers. Data is proxied through our servers to enable us to easily switch analytics services. We currently use: -- [Axiom](https://axiom.co): Cloud-monitoring service - stores diagnostic events -- [Snowflake](https://snowflake.com): Business Intelligence platform - stores both diagnostic and metric events -- [Metabase](https://www.metabase.com): Dashboards - dashboards built around data pulled from Snowflake +- [Sentry](https://sentry.io): Crash-monitoring service - stores diagnostic events +- [Snowflake](https://snowflake.com): Data warehouse - stores both diagnostic and metric events +- [Hex](https://www.hex.tech): Dashboards and data exploration - accesses data stored in Snowflake +- [Amplitude](https://www.amplitude.com): Dashboards and data exploration - accesses data stored in Snowflake ## Types of Telemetry ### Diagnostics -Diagnostic events include debug information (stack traces) from crash reports. Reports are sent on the first application launch after the crash occurred. We've built dashboards that allow us to visualize the frequency and severity of issues experienced by users. Having these reports sent automatically allows us to begin implementing fixes without the user needing to file a report in our issue tracker. The plots in the dashboards also give us an informal measurement of the stability of Zed. +Crash reports consist of a [minidump](https://learn.microsoft.com/en-us/windows/win32/debug/minidump-files) and some extra debug information. Reports are sent on the first application launch after the crash occurred. We've built dashboards that allow us to visualize the frequency and severity of issues experienced by users. Having these reports sent automatically allows us to begin implementing fixes without the user needing to file a report in our issue tracker. The plots in the dashboards also give us an informal measurement of the stability of Zed. -You can see what data is sent when a panic occurs by inspecting the `Panic` struct in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repo. You can find additional information in the [Debugging Crashes](./development/debugging-crashes.md) documentation. +You can see what extra data is sent alongside the minidump in the `Panic` struct in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repo. You can find additional information in the [Debugging Crashes](./development/debugging-crashes.md) documentation. 
-### Usage Data (Metrics) {#metrics} +### Client-Side Usage Data {#client-metrics} To improve Zed and understand how it is being used in the wild, Zed optionally collects usage data like the following: @@ -50,6 +51,12 @@ You can audit the metrics data that Zed has reported by running the command {#ac You can see the full list of the event types and exactly the data sent for each by inspecting the `Event` enum and the associated structs in [crates/telemetry_events/src/telemetry_events.rs](https://github.com/zed-industries/zed/blob/main/crates/telemetry_events/src/telemetry_events.rs) in the Zed repository. +### Server-Side Usage Data {#metrics} + +When using Zed's hosted services, we may collect, generate, and Process data to allow us to support users and improve our hosted offering. Examples include metadata around rate limiting and billing metrics/token usage. Zed does not persistently store user content or use user content to evaluate and/or improve our AI features, unless it is explicitly shared with Zed, and we have a zero-data retention agreement with Anthropic. + +You can see more about our stance on data collection (and that any prompt data shared with Zed is explicitly opt-in) at [AI Improvement](./ai/ai-improvement.md). + ## Concerns and Questions If you have concerns about telemetry, please feel free to [open an issue](https://github.com/zed-industries/zed/issues/new/choose). diff --git a/docs/src/visual-customization.md b/docs/src/visual-customization.md index 197c9b80f8..8b307d97d5 100644 --- a/docs/src/visual-customization.md +++ b/docs/src/visual-customization.md @@ -267,7 +267,7 @@ TBD: Centered layout related settings "display_in": "active_editor", // Where to show (active_editor, all_editor) "thumb": "always", // When to show thumb (always, hover) "thumb_border": "left_open", // Thumb border (left_open, right_open, full, none) - "max_width_columns": 80 // Maximum width of minimap + "max_width_columns": 80, // Maximum width of minimap "current_line_highlight": null // Highlight current line (null, line, gutter) }, diff --git a/docs/theme/index.hbs b/docs/theme/index.hbs index 8ab4f21cf1..4339a02d17 100644 --- a/docs/theme/index.hbs +++ b/docs/theme/index.hbs @@ -15,7 +15,7 @@ {{> head}} - + diff --git a/extensions/emmet/Cargo.toml b/extensions/emmet/Cargo.toml index db8aaaae41..9d72a6c5c4 100644 --- a/extensions/emmet/Cargo.toml +++ b/extensions/emmet/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zed_emmet" -version = "0.0.3" +version = "0.0.4" edition.workspace = true publish.workspace = true license = "Apache-2.0" diff --git a/extensions/ruff/Cargo.toml b/extensions/ruff/Cargo.toml index 830897279a..24616f963b 100644 --- a/extensions/ruff/Cargo.toml +++ b/extensions/ruff/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zed_ruff" -version = "0.1.0" +version = "0.1.1" edition.workspace = true publish.workspace = true license = "Apache-2.0" diff --git a/extensions/ruff/extension.toml b/extensions/ruff/extension.toml index 63929fc191..1f5a7314f4 100644 --- a/extensions/ruff/extension.toml +++ b/extensions/ruff/extension.toml @@ -1,7 +1,7 @@ id = "ruff" name = "Ruff" description = "Support for Ruff, the Python linter and formatter" -version = "0.1.0" +version = "0.1.1" schema_version = 1 authors = [] repository = "https://github.com/zed-industries/zed" diff --git a/nix/build.nix b/nix/build.nix index 873431a427..70b4f76932 100644 --- a/nix/build.nix +++ b/nix/build.nix @@ -298,6 +298,7 @@ craneLib.buildPackage ( export APP_ARGS="%U" mkdir -p "$out/share/applications" ${lib.getExe 
envsubst} < "crates/zed/resources/zed.desktop.in" > "$out/share/applications/dev.zed.Zed-Nightly.desktop" + chmod +x "$out/share/applications/dev.zed.Zed-Nightly.desktop" ) runHook postInstall diff --git a/script/bundle-freebsd b/script/bundle-freebsd index 7222a06256..87c9459ffb 100755 --- a/script/bundle-freebsd +++ b/script/bundle-freebsd @@ -138,6 +138,7 @@ fi # mkdir -p "${zed_dir}/share/applications" # envsubst <"crates/zed/resources/zed.desktop.in" >"${zed_dir}/share/applications/zed$suffix.desktop" +# chmod +x "${zed_dir}/share/applications/zed$suffix.desktop" # Copy generated licenses so they'll end up in archive too # cp "assets/licenses.md" "${zed_dir}/licenses.md" diff --git a/script/bundle-linux b/script/bundle-linux index c52312015b..ad67b7a0f7 100755 --- a/script/bundle-linux +++ b/script/bundle-linux @@ -83,6 +83,23 @@ if [[ "$remote_server_triple" == "$musl_triple" ]]; then fi cargo build --release --target "${remote_server_triple}" --package remote_server +# Upload debug info to sentry.io +if ! command -v sentry-cli >/dev/null 2>&1; then + echo "sentry-cli not found. skipping sentry upload." + echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" +else + if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then + echo "Uploading zed debug symbols to sentry..." + # note: this uploads the unstripped binary which is needed because it contains + # .eh_frame data for stack unwinding. see https://github.com/getsentry/symbolic/issues/783 + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ + "${target_dir}/${target_triple}"/release/zed \ + "${target_dir}/${remote_server_triple}"/release/remote_server + else + echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload." + fi +fi + # Strip debug symbols and save them for upload to DigitalOcean objcopy --only-keep-debug "${target_dir}/${target_triple}/release/zed" "${target_dir}/${target_triple}/release/zed.dbg" objcopy --only-keep-debug "${target_dir}/${remote_server_triple}/release/remote_server" "${target_dir}/${remote_server_triple}/release/remote_server.dbg" @@ -162,6 +179,7 @@ fi mkdir -p "${zed_dir}/share/applications" envsubst < "crates/zed/resources/zed.desktop.in" > "${zed_dir}/share/applications/zed$suffix.desktop" +chmod +x "${zed_dir}/share/applications/zed$suffix.desktop" # Copy generated licenses so they'll end up in archive too cp "assets/licenses.md" "${zed_dir}/licenses.md" diff --git a/script/bundle-mac b/script/bundle-mac index 18dfe90815..b2be573235 100755 --- a/script/bundle-mac +++ b/script/bundle-mac @@ -366,3 +366,20 @@ else gzip -f --stdout --best target/x86_64-apple-darwin/release/remote_server > target/zed-remote-server-macos-x86_64.gz gzip -f --stdout --best target/aarch64-apple-darwin/release/remote_server > target/zed-remote-server-macos-aarch64.gz fi + +# Upload debug info to sentry.io +if ! command -v sentry-cli >/dev/null 2>&1; then + echo "sentry-cli not found. skipping sentry upload." + echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" +else + if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then + echo "Uploading zed debug symbols to sentry..." + # note: this uploads the unstripped binary which is needed because it contains + # .eh_frame data for stack unwinding. see https://github.com/getsentry/symbolic/issues/783 + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ + "target/x86_64-apple-darwin/${target_dir}/" \ + "target/aarch64-apple-darwin/${target_dir}/" + else + echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload."
+ fi +fi diff --git a/script/bundle-windows.ps1 b/script/bundle-windows.ps1 index 01a1114c26..8ae0212491 100644 --- a/script/bundle-windows.ps1 +++ b/script/bundle-windows.ps1 @@ -26,6 +26,7 @@ if ($Help) { Push-Location -Path crates/zed $channel = Get-Content "RELEASE_CHANNEL" $env:ZED_RELEASE_CHANNEL = $channel +$env:RELEASE_CHANNEL = $channel Pop-Location function CheckEnvironmentVariables { @@ -96,6 +97,21 @@ function ZipZedAndItsFriendsDebug { Compress-Archive -Path $items -DestinationPath ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" -Force } + +function UploadToSentry { + if (-not (Get-Command "sentry-cli" -ErrorAction SilentlyContinue)) { + Write-Output "sentry-cli not found. skipping sentry upload." + Write-Output "install with: 'winget install -e --id=Sentry.sentry-cli'" + return + } + if (-not (Test-Path "env:SENTRY_AUTH_TOKEN")) { + Write-Output "missing SENTRY_AUTH_TOKEN. skipping sentry upload." + return + } + Write-Output "Uploading zed debug symbols to sentry..." + sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev .\target\release\ +} + function MakeAppx { switch ($channel) { "stable" { @@ -120,11 +136,22 @@ function SignZedAndItsFriends { & "$innoDir\sign.ps1" $files } +function DownloadAMDGpuServices { + # If you update the AGS SDK version, please also update the version in `crates/gpui/src/platform/windows/directx_renderer.rs` + $url = "https://codeload.github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/zip/refs/tags/v6.3.0" + $zipPath = ".\AGS_SDK_v6.3.0.zip" + # Download the AGS SDK zip file + Invoke-WebRequest -Uri $url -OutFile $zipPath + # Extract the AGS SDK zip file + Expand-Archive -Path $zipPath -DestinationPath "." -Force +} + function CollectFiles { Move-Item -Path "$innoDir\zed_explorer_command_injector.appx" -Destination "$innoDir\appx\zed_explorer_command_injector.appx" -Force Move-Item -Path "$innoDir\zed_explorer_command_injector.dll" -Destination "$innoDir\appx\zed_explorer_command_injector.dll" -Force Move-Item -Path "$innoDir\cli.exe" -Destination "$innoDir\bin\zed.exe" -Force Move-Item -Path "$innoDir\auto_update_helper.exe" -Destination "$innoDir\tools\auto_update_helper.exe" -Force + Move-Item -Path ".\AGS_SDK-6.3.0\ags_lib\lib\amd_ags_x64.dll" -Destination "$innoDir\amd_ags_x64.dll" -Force } function BuildInstaller { @@ -195,7 +222,6 @@ function BuildInstaller { # Windows runner 2022 default has iscc in PATH, https://github.com/actions/runner-images/blob/main/images/windows/Windows2022-Readme.md # Currently, we are using Windows 2022 runner. 
# Windows runner 2025 doesn't have iscc in PATH for now, https://github.com/actions/runner-images/issues/11228 - # $innoSetupPath = "iscc.exe" $innoSetupPath = "C:\Program Files (x86)\Inno Setup 6\ISCC.exe" $definitions = @{ @@ -242,6 +268,8 @@ function BuildInstaller { ParseZedWorkspace $innoDir = "$env:ZED_WORKSPACE\inno" +$debugArchive = ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" +$debugStoreKey = "$env:ZED_RELEASE_CHANNEL/zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" CheckEnvironmentVariables PrepareForBundle @@ -250,12 +278,12 @@ BuildZedAndItsFriends MakeAppx SignZedAndItsFriends ZipZedAndItsFriendsDebug +DownloadAMDGpuServices CollectFiles BuildInstaller -$debugArchive = ".\target\release\zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" -$debugStoreKey = "$env:ZED_RELEASE_CHANNEL/zed-$env:RELEASE_VERSION-$env:ZED_RELEASE_CHANNEL.dbg.zip" UploadToBlobStorePublic -BucketName "zed-debug-symbols" -FileToUpload $debugArchive -BlobStoreKey $debugStoreKey +UploadToSentry if ($buildSuccess) { Write-Output "Build successful" diff --git a/script/linux b/script/linux index 98ae026896..029278bea3 100755 --- a/script/linux +++ b/script/linux @@ -143,6 +143,7 @@ if [[ -n $zyp ]]; then gzip jq libvulkan1 + libx11-devel libxcb-devel libxkbcommon-devel libxkbcommon-x11-devel diff --git a/script/zed-local b/script/zed-local index 2568931246..99d9308232 100755 --- a/script/zed-local +++ b/script/zed-local @@ -213,7 +213,7 @@ setTimeout(() => { platform === "win32" ? "http://127.0.0.1:8080/rpc" : "http://localhost:8080/rpc", - ZED_ADMIN_API_TOKEN: "secret", + ZED_ADMIN_API_TOKEN: "internal-api-key-secret", ZED_WINDOW_SIZE: size, ZED_CLIENT_CHECKSUM_SEED: "development-checksum-seed", RUST_LOG: process.env.RUST_LOG || "info", diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index 1026454026..5678e46236 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -82,6 +82,7 @@ lyon = { version = "1", default-features = false, features = ["extra"] } lyon_path = { version = "1" } md-5 = { version = "0.10" } memchr = { version = "2" } +mime_guess = { version = "2" } miniz_oxide = { version = "0.8", features = ["simd"] } nom = { version = "7" } num-bigint = { version = "0.4" } @@ -212,6 +213,7 @@ lyon = { version = "1", default-features = false, features = ["extra"] } lyon_path = { version = "1" } md-5 = { version = "0.10" } memchr = { version = "2" } +mime_guess = { version = "2" } miniz_oxide = { version = "0.8", features = ["simd"] } nom = { version = "7" } num-bigint = { version = "0.4" } @@ -284,14 +286,13 @@ winnow = { version = "0.7", features = ["simd"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", 
features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -310,18 +311,16 @@ tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } [target.x86_64-apple-darwin.build-dependencies] -clang-sys = { version = "1", default-features = false, features = ["clang_11_0", "runtime"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -344,14 +343,13 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { 
version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -370,18 +368,16 @@ tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } [target.aarch64-apple-darwin.build-dependencies] -clang-sys = { version = "1", default-features = false, features = ["clang_11_0", "runtime"] } codespan-reporting = { version = "0.12" } core-foundation = { version = "0.9" } core-foundation-sys = { version = "0.8" } -coreaudio-sys = { version = "0.2", default-features = false, features = ["audio_toolbox", "audio_unit", "core_audio", "core_midi", "open_al"] } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } @@ -420,7 +416,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } 
num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -460,7 +457,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } @@ -498,7 +496,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -538,7 +537,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } @@ -564,7 +564,6 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { 
package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -naga = { version = "25", features = ["spv-out", "wgsl-in"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } scopeguard = { version = "1" } @@ -578,7 +577,7 @@ windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } -windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } +windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-pc-windows-msvc.build-dependencies] codespan-reporting = { version = "0.12" } @@ -588,7 +587,6 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -naga = { version = "25", features = ["spv-out", "wgsl-in"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } @@ -603,7 
+601,7 @@ windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } -windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } +windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-unknown-linux-musl.dependencies] aes = { version = "0.8", default-features = false, features = ["zeroize"] } @@ -625,7 +623,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } @@ -665,7 +664,8 @@ linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", d linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", 
default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } mio = { version = "1", features = ["net", "os-ext"] } naga = { version = "25", features = ["spv-out", "wgsl-in"] } -nix = { version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } +nix-1f5adca70f036a62 = { package = "nix", version = "0.28", features = ["fs", "mman", "ptrace", "signal", "term", "user"] } +nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "socket", "uio", "user"] } num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } diff --git a/typos.toml b/typos.toml index 7f1c6e04f1..336a829a44 100644 --- a/typos.toml +++ b/typos.toml @@ -71,6 +71,10 @@ extend-ignore-re = [ # Not an actual typo but an intentionally invalid color, in `color_extractor` "#fof", # Stripped version of reserved keyword `type` - "typ" + "typ", + # AMD GPU Services + "ags", + # AMD GPU Services + "AGS" ] check-filename = true
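The `ags`/`AGS` entries added above stop the spell checker from flagging the AMD GPU Services identifiers that the Windows bundling changes introduce. To confirm the configuration still passes after edits like these, you can run the `typos` CLI that consumes this file locally; a quick sketch, assuming `typos-cli` is installed from crates.io:

```sh
# Run the typo checker from the repository root; it picks up typos.toml automatically.
cargo install typos-cli
typos
```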