diff --git a/.github/actionlint.yml b/.github/actionlint.yml index 6bfbc27705..ad09545902 100644 --- a/.github/actionlint.yml +++ b/.github/actionlint.yml @@ -5,25 +5,25 @@ self-hosted-runner: # GitHub-hosted Runners - github-8vcpu-ubuntu-2404 - github-16vcpu-ubuntu-2404 + - github-32vcpu-ubuntu-2404 + - github-8vcpu-ubuntu-2204 + - github-16vcpu-ubuntu-2204 + - github-32vcpu-ubuntu-2204 + - github-16vcpu-ubuntu-2204-arm - windows-2025-16 - windows-2025-32 - windows-2025-64 - # Buildjet Ubuntu 20.04 - AMD x86_64 - - buildjet-2vcpu-ubuntu-2004 - - buildjet-4vcpu-ubuntu-2004 - - buildjet-8vcpu-ubuntu-2004 - - buildjet-16vcpu-ubuntu-2004 - - buildjet-32vcpu-ubuntu-2004 - # Buildjet Ubuntu 22.04 - AMD x86_64 - - buildjet-2vcpu-ubuntu-2204 - - buildjet-4vcpu-ubuntu-2204 - - buildjet-8vcpu-ubuntu-2204 - - buildjet-16vcpu-ubuntu-2204 - - buildjet-32vcpu-ubuntu-2204 - # Buildjet Ubuntu 22.04 - Graviton aarch64 - - buildjet-8vcpu-ubuntu-2204-arm - - buildjet-16vcpu-ubuntu-2204-arm - - buildjet-32vcpu-ubuntu-2204-arm + # Namespace Ubuntu 20.04 (Release builds) + - namespace-profile-16x32-ubuntu-2004 + - namespace-profile-32x64-ubuntu-2004 + - namespace-profile-16x32-ubuntu-2004-arm + - namespace-profile-32x64-ubuntu-2004-arm + # Namespace Ubuntu 22.04 (Everything else) + - namespace-profile-2x4-ubuntu-2204 + - namespace-profile-4x8-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 + - namespace-profile-32x64-ubuntu-2204 # Self Hosted Runners - self-mini-macos - self-32vcpu-windows-2022 diff --git a/.github/actions/build_docs/action.yml b/.github/actions/build_docs/action.yml index a7effad247..d2e62d5b22 100644 --- a/.github/actions/build_docs/action.yml +++ b/.github/actions/build_docs/action.yml @@ -13,7 +13,7 @@ runs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Linux dependencies shell: bash -euxo pipefail {0} diff --git a/.github/workflows/bump_patch_version.yml b/.github/workflows/bump_patch_version.yml index 8a48ff96f1..bfaf7a271b 100644 --- a/.github/workflows/bump_patch_version.yml +++ b/.github/workflows/bump_patch_version.yml @@ -16,7 +16,7 @@ jobs: bump_patch_version: if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout code uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 43d305faae..84907351fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -137,7 +137,7 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - buildjet-8vcpu-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -168,7 +168,7 @@ jobs: needs: [job_spec] if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-8vcpu-ubuntu-2204 + - namespace-profile-4x8-ubuntu-2204 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -221,7 +221,7 @@ jobs: github.repository_owner == 'zed-industries' && (needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true') runs-on: - - buildjet-8vcpu-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 steps: - name: Checkout repo uses: 
actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -328,7 +328,7 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" @@ -342,7 +342,7 @@ jobs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Linux dependencies run: ./script/linux @@ -380,7 +380,7 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - buildjet-8vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" @@ -394,7 +394,7 @@ jobs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Clang & Mold run: ./script/remote-server && ./script/install-mold 2.34.0 @@ -597,7 +597,7 @@ jobs: timeout-minutes: 60 name: Linux x86_x64 release bundle runs-on: - - buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc + - namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') @@ -650,7 +650,7 @@ jobs: timeout-minutes: 60 name: Linux arm64 release bundle runs-on: - - buildjet-32vcpu-ubuntu-2204-arm + - namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') diff --git a/.github/workflows/deploy_cloudflare.yml b/.github/workflows/deploy_cloudflare.yml index fe443d493e..df35d44ca9 100644 --- a/.github/workflows/deploy_cloudflare.yml +++ b/.github/workflows/deploy_cloudflare.yml @@ -9,7 +9,7 @@ jobs: deploy-docs: name: Deploy Docs if: github.repository_owner == 'zed-industries' - runs-on: buildjet-16vcpu-ubuntu-2204 + runs-on: namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout repo diff --git a/.github/workflows/deploy_collab.yml b/.github/workflows/deploy_collab.yml index f7348a1069..ff2a3589e4 100644 --- a/.github/workflows/deploy_collab.yml +++ b/.github/workflows/deploy_collab.yml @@ -61,7 +61,7 @@ jobs: - style - tests runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Install doctl uses: digitalocean/action-doctl@v2 @@ -94,7 +94,7 @@ jobs: needs: - publish runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout repo @@ -137,12 +137,14 @@ jobs: export ZED_SERVICE_NAME=collab export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT + export DATABASE_MAX_CONNECTIONS=850 envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" export ZED_SERVICE_NAME=api export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_API_LOAD_BALANCER_SIZE_UNIT + export DATABASE_MAX_CONNECTIONS=60 envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch echo "deployed ${ZED_SERVICE_NAME} to 
${ZED_KUBE_NAMESPACE}" diff --git a/.github/workflows/eval.yml b/.github/workflows/eval.yml index 2ad302a602..b5da9e7b7c 100644 --- a/.github/workflows/eval.yml +++ b/.github/workflows/eval.yml @@ -32,7 +32,7 @@ jobs: github.repository_owner == 'zed-industries' && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval')) runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" @@ -46,7 +46,7 @@ jobs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Linux dependencies run: ./script/linux diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml index 6c3a97c163..e682ce5890 100644 --- a/.github/workflows/nix.yml +++ b/.github/workflows/nix.yml @@ -20,7 +20,7 @@ jobs: matrix: system: - os: x86 Linux - runner: buildjet-16vcpu-ubuntu-2204 + runner: namespace-profile-16x32-ubuntu-2204 install_nix: true - os: arm Mac runner: [macOS, ARM64, test] diff --git a/.github/workflows/randomized_tests.yml b/.github/workflows/randomized_tests.yml index db4d44318e..de96c3df78 100644 --- a/.github/workflows/randomized_tests.yml +++ b/.github/workflows/randomized_tests.yml @@ -20,7 +20,7 @@ jobs: name: Run randomized tests if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Install Node uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index c847149984..b3500a085b 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -128,7 +128,7 @@ jobs: name: Create a Linux *.tar.gz bundle for x86 if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-16vcpu-ubuntu-2004 + - namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc needs: tests steps: - name: Checkout repo @@ -168,7 +168,7 @@ jobs: name: Create a Linux *.tar.gz bundle for ARM if: github.repository_owner == 'zed-industries' runs-on: - - buildjet-32vcpu-ubuntu-2204-arm + - namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc needs: tests steps: - name: Checkout repo diff --git a/.github/workflows/unit_evals.yml b/.github/workflows/unit_evals.yml index cb4e39d151..2e03fb028f 100644 --- a/.github/workflows/unit_evals.yml +++ b/.github/workflows/unit_evals.yml @@ -23,7 +23,7 @@ jobs: timeout-minutes: 60 name: Run unit evals runs-on: - - buildjet-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" @@ -37,7 +37,7 @@ jobs: uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - cache-provider: "buildjet" + # cache-provider: "buildjet" - name: Install Linux dependencies run: ./script/linux diff --git a/Cargo.lock b/Cargo.lock index 4cf5a68f1d..6f434e8685 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,7 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", "parking_lot", "project", @@ -137,9 +138,9 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.18" +version = "0.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "f8e4c1dccb35e69d32566f0d11948d902f9942fc3f038821816c1150cf5925f4" +checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8" dependencies = [ "anyhow", "futures 0.3.31", @@ -150,6 +151,54 @@ dependencies = [ "serde_json", ] +[[package]] +name = "agent2" +version = "0.1.0" +dependencies = [ + "acp_thread", + "agent-client-protocol", + "agent_servers", + "agent_settings", + "anyhow", + "assistant_tool", + "assistant_tools", + "client", + "clock", + "cloud_llm_client", + "collections", + "ctor", + "env_logger 0.11.8", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "handlebars 4.5.0", + "indoc", + "itertools 0.14.0", + "language", + "language_model", + "language_models", + "log", + "lsp", + "paths", + "pretty_assertions", + "project", + "prompt_store", + "reqwest_client", + "rust-embed", + "schemars", + "serde", + "serde_json", + "settings", + "smol", + "ui", + "util", + "uuid", + "watch", + "workspace-hack", + "worktree", +] + [[package]] name = "agent_servers" version = "0.1.0" @@ -214,6 +263,7 @@ dependencies = [ "acp_thread", "agent", "agent-client-protocol", + "agent2", "agent_servers", "agent_settings", "ai_onboarding", @@ -1371,7 +1421,7 @@ dependencies = [ "anyhow", "arrayvec", "log", - "nom", + "nom 7.1.3", "num-rational", "v_frame", ] @@ -2745,7 +2795,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" dependencies = [ - "nom", + "nom 7.1.3", ] [[package]] @@ -3031,17 +3081,22 @@ dependencies = [ "anyhow", "cloud_api_types", "futures 0.3.31", + "gpui", + "gpui_tokio", "http_client", "parking_lot", "serde_json", "workspace-hack", + "yawc", ] [[package]] name = "cloud_api_types" version = "0.1.0" dependencies = [ + "anyhow", "chrono", + "ciborium", "cloud_llm_client", "pretty_assertions", "serde", @@ -7457,9 +7512,9 @@ dependencies = [ [[package]] name = "grid" -version = "0.17.0" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71b01d27060ad58be4663b9e4ac9e2d4806918e8876af8912afbddd1a91d5eaa" +checksum = "12101ecc8225ea6d675bc70263074eab6169079621c2186fe0c66590b2df9681" [[package]] name = "group" @@ -9078,6 +9133,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "client", + "cloud_api_types", "cloud_llm_client", "collections", "futures 0.3.31", @@ -9208,6 +9264,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-compression", + "async-fs", "async-tar", "async-trait", "chrono", @@ -9239,9 +9296,11 @@ dependencies = [ "serde_json", "serde_json_lenient", "settings", + "sha2", "smol", "snippet_provider", "task", + "tempfile", "text", "theme", "toml 0.8.20", @@ -10537,6 +10596,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + [[package]] name = "noop_proc_macro" version = "0.3.0" @@ -12575,6 +12643,7 @@ dependencies = [ "editor", "file_icons", "git", + "git_ui", "gpui", "indexmap", "language", @@ -12588,6 +12657,7 @@ dependencies = [ "serde_json", "settings", "smallvec", + "telemetry", "theme", "ui", "util", @@ -15358,7 +15428,7 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" dependencies = [ - "nom", + "nom 7.1.3", "unicode_categories", ] @@ -16158,9 
+16228,9 @@ dependencies = [ [[package]] name = "taffy" -version = "0.8.3" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7aaef0ac998e6527d6d0d5582f7e43953bb17221ac75bb8eb2fcc2db3396db1c" +checksum = "a13e5d13f79d558b5d353a98072ca8ca0e99da429467804de959aa8c83c9a004" dependencies = [ "arrayvec", "grid", @@ -16561,9 +16631,8 @@ dependencies = [ [[package]] name = "tiktoken-rs" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25563eeba904d770acf527e8b370fe9a5547bacd20ff84a0b6c3bc41288e5625" +version = "0.8.0" +source = "git+https://github.com/zed-industries/tiktoken-rs?rev=30c32a4522751699adeda0d5840c71c3b75ae73d#30c32a4522751699adeda0d5840c71c3b75ae73d" dependencies = [ "anyhow", "base64 0.22.1", @@ -19934,7 +20003,7 @@ dependencies = [ "naga", "nix 0.28.0", "nix 0.29.0", - "nom", + "nom 7.1.3", "num-bigint", "num-bigint-dig", "num-integer", @@ -20269,6 +20338,34 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yawc" +version = "0.2.4" +source = "git+https://github.com/deviant-forks/yawc?rev=1899688f3e69ace4545aceb97b2a13881cf26142#1899688f3e69ace4545aceb97b2a13881cf26142" +dependencies = [ + "base64 0.22.1", + "bytes 1.10.1", + "flate2", + "futures 0.3.31", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "js-sys", + "nom 8.0.0", + "pin-project", + "rand 0.8.5", + "sha1", + "thiserror 1.0.69", + "tokio", + "tokio-rustls 0.26.2", + "tokio-util", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", +] + [[package]] name = "yazi" version = "0.2.1" @@ -20375,7 +20472,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.199.0" +version = "0.200.0" dependencies = [ "activity_indicator", "agent", diff --git a/Cargo.toml b/Cargo.toml index 733db92ce9..998e727602 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/acp_thread", "crates/activity_indicator", "crates/agent", + "crates/agent2", "crates/agent_servers", "crates/agent_settings", "crates/agent_ui", @@ -229,6 +230,7 @@ edition = "2024" acp_thread = { path = "crates/acp_thread" } agent = { path = "crates/agent" } +agent2 = { path = "crates/agent2" } activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } @@ -423,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.18" +agent-client-protocol = "0.0.23" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -459,6 +461,7 @@ bytes = "1.0" cargo_metadata = "0.19" cargo_toml = "0.21" chrono = { version = "0.4", features = ["serde"] } +ciborium = "0.2" circular-buffer = "1.0" clap = { version = "4.4", features = ["derive"] } cocoa = "0.26" @@ -598,7 +601,7 @@ sysinfo = "0.31.0" take-until = "0.2.0" tempfile = "3.20.0" thiserror = "2.0.12" -tiktoken-rs = "0.7.0" +tiktoken-rs = { git = "https://github.com/zed-industries/tiktoken-rs", rev = "30c32a4522751699adeda0d5840c71c3b75ae73d" } time = { version = "0.3", features = [ "macros", "parsing", @@ -658,6 +661,9 @@ which = "6.0.0" windows-core = "0.61" wit-component = "0.221" workspace-hack = "0.1.0" +# We can switch back to the published version once 
https://github.com/infinitefield/yawc/pull/16 is merged and a new +# version is released. +yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" } zstd = "0.11" [workspace.dependencies.async-stripe] diff --git a/Procfile b/Procfile index 5f1231b90a..b3f13f66a6 100644 --- a/Procfile +++ b/Procfile @@ -1,3 +1,4 @@ collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all +cloud: cd ../cloud; cargo make dev livekit: livekit-server --dev blob_store: ./script/run-local-minio diff --git a/assets/icons/file_icons/puppet.svg b/assets/icons/file_icons/puppet.svg new file mode 100644 index 0000000000..cdf903bc62 --- /dev/null +++ b/assets/icons/file_icons/puppet.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/tool_bulb.svg b/assets/icons/tool_think.svg similarity index 100% rename from assets/icons/tool_bulb.svg rename to assets/icons/tool_think.svg diff --git a/assets/images/pro_trial_stamp.svg b/assets/images/pro_trial_stamp.svg new file mode 100644 index 0000000000..a3f9095120 --- /dev/null +++ b/assets/images/pro_trial_stamp.svg @@ -0,0 +1 @@ + diff --git a/assets/images/pro_user_stamp.svg b/assets/images/pro_user_stamp.svg new file mode 100644 index 0000000000..d037a9e833 --- /dev/null +++ b/assets/images/pro_user_stamp.svg @@ -0,0 +1 @@ + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 81f5c695a2..c436b1a8fb 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -332,7 +332,9 @@ "enter": "agent::Chat", "up": "agent::PreviousHistoryMessage", "down": "agent::NextHistoryMessage", - "shift-ctrl-r": "agent::OpenAgentDiff" + "shift-ctrl-r": "agent::OpenAgentDiff", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll" } }, { @@ -846,6 +848,7 @@ "ctrl-delete": ["project_panel::Delete", { "skip_prompt": false }], "alt-ctrl-r": "project_panel::RevealInFileManager", "ctrl-shift-enter": "project_panel::OpenWithSystem", + "alt-d": "project_panel::CompareMarkedFiles", "shift-find": "project_panel::NewSearchInDirectory", "ctrl-alt-shift-f": "project_panel::NewSearchInDirectory", "shift-down": "menu::SelectNext", @@ -1100,6 +1103,13 @@ "ctrl-enter": "menu::Confirm" } }, + { + "context": "OnboardingAiConfigurationModal", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel" + } + }, { "context": "Diagnostics", "use_key_equivalents": true, @@ -1176,7 +1186,8 @@ "ctrl-1": "onboarding::ActivateBasicsPage", "ctrl-2": "onboarding::ActivateEditingPage", "ctrl-3": "onboarding::ActivateAISetupPage", - "ctrl-escape": "onboarding::Finish" + "ctrl-escape": "onboarding::Finish", + "alt-tab": "onboarding::SignIn" } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 69958fd1f8..960bac1479 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -384,7 +384,9 @@ "enter": "agent::Chat", "up": "agent::PreviousHistoryMessage", "down": "agent::NextHistoryMessage", - "shift-ctrl-r": "agent::OpenAgentDiff" + "shift-ctrl-r": "agent::OpenAgentDiff", + "cmd-shift-y": "agent::KeepAll", + "cmd-shift-n": "agent::RejectAll" } }, { @@ -905,6 +907,7 @@ "cmd-delete": ["project_panel::Delete", { "skip_prompt": false }], "alt-cmd-r": "project_panel::RevealInFileManager", "ctrl-shift-enter": "project_panel::OpenWithSystem", + "alt-d": "project_panel::CompareMarkedFiles", "cmd-alt-backspace": ["project_panel::Delete", { "skip_prompt": false }], "cmd-alt-shift-f": 
"project_panel::NewSearchInDirectory", "shift-down": "menu::SelectNext", @@ -1202,6 +1205,13 @@ "cmd-enter": "menu::Confirm" } }, + { + "context": "OnboardingAiConfigurationModal", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel" + } + }, { "context": "Diagnostics", "use_key_equivalents": true, @@ -1278,7 +1288,8 @@ "cmd-1": "onboarding::ActivateBasicsPage", "cmd-2": "onboarding::ActivateEditingPage", "cmd-3": "onboarding::ActivateAISetupPage", - "cmd-escape": "onboarding::Finish" + "cmd-escape": "onboarding::Finish", + "alt-tab": "onboarding::SignIn" } } ] diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index 3096ec40bb..3fca75b572 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -815,6 +815,7 @@ "ctrl-x": "project_panel::OpenSplitUp", "x": "project_panel::RevealInFileManager", "s": "project_panel::OpenWithSystem", + "z d": "project_panel::CompareMarkedFiles", "] c": "project_panel::SelectNextGitEntry", "[ c": "project_panel::SelectPrevGitEntry", "] d": "project_panel::SelectNextDiagnostic", diff --git a/assets/settings/default.json b/assets/settings/default.json index 4734b5d118..9c579b858d 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -596,6 +596,8 @@ // when a corresponding project entry becomes active. // Gitignored entries are never auto revealed. "auto_reveal_entries": true, + // Whether the project panel should open on startup. + "starts_open": true, // Whether to fold directories automatically and show compact folders // (e.g. "a/b/c" ) when a directory has only one subdirectory inside. "auto_fold_dirs": true, @@ -1171,6 +1173,9 @@ // Sets a delay after which the inline blame information is shown. // Delay is restarted with every cursor movement. "delay_ms": 0, + // The amount of padding between the end of the source line and the start + // of the inline blame in units of em widths. + "padding": 7, // Whether or not to display the git commit summary on the same line. "show_commit_summary": false, // The minimum column number to show the inline blame information at @@ -1233,6 +1238,11 @@ // 2. hour24 "hour_format": "hour12" }, + // Status bar-related settings. + "status_bar": { + // Whether to show the active language button in the status bar. + "active_language_button": true + }, // Settings specific to the terminal "terminal": { // What shell to use when opening a terminal. 
May take 3 values: diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 225597415c..1831c7e473 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -25,6 +25,7 @@ futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true project.workspace = true serde.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 44190a4860..1df0e1def7 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,23 +1,23 @@ mod connection; +mod diff; + pub use connection::*; +pub use diff::*; use agent_client_protocol as acp; use anyhow::{Context as _, Result}; use assistant_tool::ActionLog; -use buffer_diff::BufferDiff; -use editor::{Bias, MultiBuffer, PathKey}; +use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use itertools::Itertools; -use language::{ - Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, - text_diff, -}; +use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff}; use markdown::Markdown; use project::{AgentLocation, Project}; use std::collections::HashMap; use std::error::Error; use std::fmt::Formatter; +use std::process::ExitStatus; use std::rc::Rc; use std::{ fmt::Display, @@ -139,7 +139,7 @@ impl AgentThreadEntry { } } - pub fn diffs(&self) -> impl Iterator { + pub fn diffs(&self) -> impl Iterator> { if let AgentThreadEntry::ToolCall(call) = self { itertools::Either::Left(call.diffs()) } else { @@ -165,6 +165,7 @@ pub struct ToolCall { pub status: ToolCallStatus, pub locations: Vec, pub raw_input: Option, + pub raw_output: Option, } impl ToolCall { @@ -193,10 +194,11 @@ impl ToolCall { locations: tool_call.locations, status, raw_input: tool_call.raw_input, + raw_output: tool_call.raw_output, } } - fn update( + fn update_fields( &mut self, fields: acp::ToolCallUpdateFields, language_registry: Arc, @@ -209,6 +211,7 @@ impl ToolCall { content, locations, raw_input, + raw_output, } = fields; if let Some(kind) = kind { @@ -220,7 +223,9 @@ impl ToolCall { } if let Some(title) = title { - self.label = cx.new(|cx| Markdown::new_text(title.into(), cx)); + self.label.update(cx, |label, cx| { + label.replace(title, cx); + }); } if let Some(content) = content { @@ -237,9 +242,13 @@ impl ToolCall { if let Some(raw_input) = raw_input { self.raw_input = Some(raw_input); } + + if let Some(raw_output) = raw_output { + self.raw_output = Some(raw_output); + } } - pub fn diffs(&self) -> impl Iterator { + pub fn diffs(&self) -> impl Iterator> { self.content.iter().filter_map(|content| match content { ToolCallContent::ContentBlock { .. 
} => None, ToolCallContent::Diff { diff } => Some(diff), @@ -379,7 +388,7 @@ impl ContentBlock { #[derive(Debug)] pub enum ToolCallContent { ContentBlock { content: ContentBlock }, - Diff { diff: Diff }, + Diff { diff: Entity }, } impl ToolCallContent { @@ -393,7 +402,7 @@ impl ToolCallContent { content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { - diff: Diff::from_acp(diff, language_registry, cx), + diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)), }, } } @@ -401,109 +410,44 @@ impl ToolCallContent { pub fn to_markdown(&self, cx: &App) -> String { match self { Self::ContentBlock { content } => content.to_markdown(cx).to_string(), - Self::Diff { diff } => diff.to_markdown(cx), + Self::Diff { diff } => diff.read(cx).to_markdown(cx), } } } -#[derive(Debug)] -pub struct Diff { - pub multibuffer: Entity, - pub path: PathBuf, - pub new_buffer: Entity, - pub old_buffer: Entity, - _task: Task>, +#[derive(Debug, PartialEq)] +pub enum ToolCallUpdate { + UpdateFields(acp::ToolCallUpdate), + UpdateDiff(ToolCallUpdateDiff), } -impl Diff { - pub fn from_acp( - diff: acp::Diff, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); - - let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); - let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); - let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); - let old_buffer_snapshot = old_buffer.read(cx).snapshot(); - let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); - let diff_task = buffer_diff.update(cx, |diff, cx| { - diff.set_base_text( - old_buffer_snapshot, - Some(language_registry.clone()), - new_buffer_snapshot, - cx, - ) - }); - - let task = cx.spawn({ - let multibuffer = multibuffer.clone(); - let path = path.clone(); - let new_buffer = new_buffer.clone(); - async move |cx| { - diff_task.await?; - - multibuffer - .update(cx, |multibuffer, cx| { - let hunk_ranges = { - let buffer = new_buffer.read(cx); - let diff = buffer_diff.read(cx); - diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) - .collect::>() - }; - - multibuffer.set_excerpts_for_path( - PathKey::for_buffer(&new_buffer, cx), - new_buffer.clone(), - hunk_ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, - cx, - ); - multibuffer.add_diff(buffer_diff.clone(), cx); - }) - .log_err(); - - if let Some(language) = language_registry - .language_for_file_path(&path) - .await - .log_err() - { - new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?; - } - - anyhow::Ok(()) - } - }); - - Self { - multibuffer, - path, - new_buffer, - old_buffer, - _task: task, +impl ToolCallUpdate { + fn id(&self) -> &acp::ToolCallId { + match self { + Self::UpdateFields(update) => &update.id, + Self::UpdateDiff(diff) => &diff.id, } } +} - fn to_markdown(&self, cx: &App) -> String { - let buffer_text = self - .multibuffer - .read(cx) - .all_buffers() - .iter() - .map(|buffer| buffer.read(cx).text()) - .join("\n"); - format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text) +impl From for ToolCallUpdate { + fn from(update: acp::ToolCallUpdate) -> Self { + Self::UpdateFields(update) } } +impl From for ToolCallUpdate { + fn from(diff: ToolCallUpdateDiff) -> Self { + Self::UpdateDiff(diff) + } +} + +#[derive(Debug, PartialEq)] 
+pub struct ToolCallUpdateDiff { + pub id: acp::ToolCallId, + pub diff: Entity, +} + #[derive(Debug, Default)] pub struct Plan { pub entries: Vec, @@ -556,7 +500,7 @@ pub struct PlanEntry { impl PlanEntry { pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self { Self { - content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)), + content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)), priority: entry.priority, status: entry.status, } @@ -581,6 +525,7 @@ pub enum AcpThreadEvent { ToolAuthorizationRequired, Stopped, Error, + ServerExited(ExitStatus), } impl EventEmitter for AcpThread {} @@ -654,6 +599,10 @@ impl AcpThread { &self.entries } + pub fn session_id(&self) -> &acp::SessionId { + &self.session_id + } + pub fn status(&self) -> ThreadStatus { if self.send_task.is_some() { if self.waiting_for_tool_confirmation() { @@ -794,15 +743,26 @@ impl AcpThread { pub fn update_tool_call( &mut self, - update: acp::ToolCallUpdate, + update: impl Into, cx: &mut Context, ) -> Result<()> { + let update = update.into(); let languages = self.project.read(cx).languages().clone(); let (ix, current_call) = self - .tool_call_mut(&update.id) + .tool_call_mut(update.id()) .context("Tool call not found")?; - current_call.update(update.fields, languages, cx); + match update { + ToolCallUpdate::UpdateFields(update) => { + current_call.update_fields(update.fields, languages, cx); + } + ToolCallUpdate::UpdateDiff(update) => { + current_call.content.clear(); + current_call + .content + .push(ToolCallContent::Diff { diff: update.diff }); + } + } cx.emit(AcpThreadEvent::EntryUpdated(ix)); @@ -890,7 +850,7 @@ impl AcpThread { }); } - pub fn request_tool_call_permission( + pub fn request_tool_call_authorization( &mut self, tool_call: acp::ToolCall, options: Vec, @@ -965,13 +925,26 @@ impl AcpThread { } pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context) { - self.plan = Plan { - entries: request - .entries - .into_iter() - .map(|entry| PlanEntry::from_acp(entry, cx)) - .collect(), - }; + let new_entries_len = request.entries.len(); + let mut new_entries = request.entries.into_iter(); + + // Reuse existing markdown to prevent flickering + for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) { + let PlanEntry { + content, + priority, + status, + } = old; + content.update(cx, |old, cx| { + old.replace(new.content, cx); + }); + *priority = new.priority; + *status = new.status; + } + for new in new_entries { + self.plan.entries.push(PlanEntry::from_acp(new, cx)) + } + self.plan.entries.truncate(new_entries_len); cx.notify(); } @@ -1032,8 +1005,9 @@ impl AcpThread { ) })? .await; + tx.send(result).log_err(); - this.update(cx, |this, _cx| this.send_task.take())?; + anyhow::Ok(()) } .await @@ -1046,7 +1020,23 @@ impl AcpThread { .log_err(); Err(e)? } - _ => { + result => { + let cancelled = matches!( + result, + Ok(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Cancelled + })) + ); + + // We only take the task if the current prompt wasn't cancelled. + // + // This prompt may have been cancelled because another one was sent + // while it was still generating. In these cases, dropping `send_task` + // would cause the next generation to be cancelled. 
+ if !cancelled { + this.update(cx, |this, _cx| this.send_task.take()).ok(); + } + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) .log_err(); Ok(()) @@ -1229,6 +1219,10 @@ impl AcpThread { pub fn to_markdown(&self, cx: &App) -> String { self.entries.iter().map(|e| e.to_markdown(cx)).collect() } + + pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context) { + cx.emit(AcpThreadEvent::ServerExited(status)); + } } #[cfg(test)] @@ -1371,6 +1365,9 @@ mod tests { cx, ) .unwrap(); + })?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, }) } .boxed_local() @@ -1443,7 +1440,9 @@ mod tests { .unwrap() .await .unwrap(); - Ok(()) + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) } .boxed_local() }, @@ -1510,13 +1509,16 @@ mod tests { content: vec![], locations: vec![], raw_input: None, + raw_output: None, }), cx, ) }) .unwrap() .unwrap(); - Ok(()) + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) } .boxed_local() } @@ -1620,13 +1622,16 @@ mod tests { }], locations: vec![], raw_input: None, + raw_output: None, }), cx, ) }) .unwrap() .unwrap(); - Ok(()) + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) } .boxed_local() } @@ -1680,7 +1685,7 @@ mod tests { acp::PromptRequest, WeakEntity, AsyncApp, - ) -> LocalBoxFuture<'static, Result<()>> + ) -> LocalBoxFuture<'static, Result> + 'static, >, >, @@ -1707,7 +1712,7 @@ mod tests { acp::PromptRequest, WeakEntity, AsyncApp, - ) -> LocalBoxFuture<'static, Result<()>> + ) -> LocalBoxFuture<'static, Result> + 'static, ) -> Self { self.on_user_message.replace(Rc::new(handler)); @@ -1749,7 +1754,11 @@ mod tests { } } - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { let sessions = self.sessions.lock(); let thread = sessions.get(¶ms.session_id).unwrap(); if let Some(handler) = &self.on_user_message { @@ -1757,7 +1766,9 @@ mod tests { let thread = thread.clone(); cx.spawn(async move |cx| handler(params, thread, cx.clone()).await) } else { - Task::ready(Ok(())) + Task::ready(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + })) } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 929500a67b..cf06563bee 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,13 +1,61 @@ -use std::{error::Error, fmt, path::Path, rc::Rc}; +use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use agent_client_protocol::{self as acp}; use anyhow::Result; use gpui::{AsyncApp, Entity, Task}; +use language_model::LanguageModel; use project::Project; use ui::App; use crate::AcpThread; +/// Trait for agents that support listing, selecting, and querying language models. +/// +/// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. +pub trait ModelSelector: 'static { + /// Lists all available language models for this agent. + /// + /// # Parameters + /// - `cx`: The GPUI app context for async operations and global access. + /// + /// # Returns + /// A task resolving to the list of models or an error (e.g., if no models are configured). + fn list_models(&self, cx: &mut AsyncApp) -> Task>>>; + + /// Selects a model for a specific session (thread). + /// + /// This sets the default model for future interactions in the session. + /// If the session doesn't exist or the model is invalid, it returns an error. 
+ /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to apply the model to. + /// - `model`: The model to select (should be one from [list_models]). + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to `Ok(())` on success or an error. + fn select_model( + &self, + session_id: acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task>; + + /// Retrieves the currently selected model for a specific session (thread). + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to query. + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to the selected model (always set) or an error (e.g., session not found). + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>>; +} + pub trait AgentConnection { fn new_thread( self: Rc, @@ -20,9 +68,18 @@ pub trait AgentConnection { fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task>; + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) + -> Task>; fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + + /// Returns this agent as an [Rc] if the model selection capability is supported. + /// + /// If the agent does not support model selection, returns [None]. + /// This allows sharing the selector in UI components. + fn model_selector(&self) -> Option> { + None // Default impl for agents that don't support it + } } #[derive(Debug)] diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs new file mode 100644 index 0000000000..9cc6271360 --- /dev/null +++ b/crates/acp_thread/src/diff.rs @@ -0,0 +1,388 @@ +use agent_client_protocol as acp; +use anyhow::Result; +use buffer_diff::{BufferDiff, BufferDiffSnapshot}; +use editor::{MultiBuffer, PathKey}; +use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task}; +use itertools::Itertools; +use language::{ + Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _, Point, Rope, TextBuffer, +}; +use std::{ + cmp::Reverse, + ops::Range, + path::{Path, PathBuf}, + sync::Arc, +}; +use util::ResultExt; + +pub enum Diff { + Pending(PendingDiff), + Finalized(FinalizedDiff), +} + +impl Diff { + pub fn from_acp( + diff: acp::Diff, + language_registry: Arc, + cx: &mut Context, + ) -> Self { + let acp::Diff { + path, + old_text, + new_text, + } = diff; + + let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); + + let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); + let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); + let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); + let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); + + let task = cx.spawn({ + let multibuffer = multibuffer.clone(); + let path = path.clone(); + async move |_, cx| { + let language = language_registry + .language_for_file_path(&path) + .await + .log_err(); + + new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; + + let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| { + buffer.set_language(language, cx); + buffer.snapshot() + })?; + + buffer_diff + .update(cx, |diff, cx| { + diff.set_base_text( + old_buffer_snapshot, + Some(language_registry), + new_buffer_snapshot, + cx, + ) + })? 
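The `ModelSelector` trait and the `AgentConnection::model_selector` hook above define an optional capability. The following is a rough usage sketch under the signatures shown in this patch; the helper name and the pick-the-first-model policy are mine, and a real caller would thread an actual `AsyncApp` through from gpui.

```rust
use std::rc::Rc;

use acp_thread::{AgentConnection, ModelSelector};
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use gpui::AsyncApp;

/// Hypothetical helper: select the first available model for a session,
/// skipping agents that do not expose the model-selection capability.
async fn select_first_model(
    connection: &Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    cx: &mut AsyncApp,
) -> Result<()> {
    let Some(selector) = connection.model_selector() else {
        // Agents without the capability return None from `model_selector()`.
        return Ok(());
    };
    let models = selector.list_models(cx).await?;
    let model = models
        .into_iter()
        .next()
        .ok_or_else(|| anyhow!("no models available"))?;
    selector.select_model(session_id, model, cx).await
}
```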
+ .await?; + + multibuffer + .update(cx, |multibuffer, cx| { + let hunk_ranges = { + let buffer = new_buffer.read(cx); + let diff = buffer_diff.read(cx); + diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .collect::>() + }; + + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(&new_buffer, cx), + new_buffer.clone(), + hunk_ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + multibuffer.add_diff(buffer_diff, cx); + }) + .log_err(); + + anyhow::Ok(()) + } + }); + + Self::Finalized(FinalizedDiff { + multibuffer, + path, + _update_diff: task, + }) + } + + pub fn new(buffer: Entity, cx: &mut Context) -> Self { + let buffer_snapshot = buffer.read(cx).snapshot(); + let base_text = buffer_snapshot.text(); + let language_registry = buffer.read(cx).language_registry(); + let text_snapshot = buffer.read(cx).text_snapshot(); + let buffer_diff = cx.new(|cx| { + let mut diff = BufferDiff::new(&text_snapshot, cx); + let _ = diff.set_base_text( + buffer_snapshot.clone(), + language_registry, + text_snapshot, + cx, + ); + diff + }); + + let multibuffer = cx.new(|cx| { + let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly); + multibuffer.add_diff(buffer_diff.clone(), cx); + multibuffer + }); + + Self::Pending(PendingDiff { + multibuffer, + base_text: Arc::new(base_text), + _subscription: cx.observe(&buffer, |this, _, cx| { + if let Diff::Pending(diff) = this { + diff.update(cx); + } + }), + buffer, + diff: buffer_diff, + revealed_ranges: Vec::new(), + update_diff: Task::ready(Ok(())), + }) + } + + pub fn reveal_range(&mut self, range: Range, cx: &mut Context) { + if let Self::Pending(diff) = self { + diff.reveal_range(range, cx); + } + } + + pub fn finalize(&mut self, cx: &mut Context) { + if let Self::Pending(diff) = self { + *self = Self::Finalized(diff.finalize(cx)); + } + } + + pub fn multibuffer(&self) -> &Entity { + match self { + Self::Pending(PendingDiff { multibuffer, .. }) => multibuffer, + Self::Finalized(FinalizedDiff { multibuffer, .. }) => multibuffer, + } + } + + pub fn to_markdown(&self, cx: &App) -> String { + let buffer_text = self + .multibuffer() + .read(cx) + .all_buffers() + .iter() + .map(|buffer| buffer.read(cx).text()) + .join("\n"); + let path = match self { + Diff::Pending(PendingDiff { buffer, .. }) => { + buffer.read(cx).file().map(|file| file.path().as_ref()) + } + Diff::Finalized(FinalizedDiff { path, .. 
}) => Some(path.as_path()), + }; + format!( + "Diff: {}\n```\n{}\n```\n", + path.unwrap_or(Path::new("untitled")).display(), + buffer_text + ) + } +} + +pub struct PendingDiff { + multibuffer: Entity, + base_text: Arc, + buffer: Entity, + diff: Entity, + revealed_ranges: Vec>, + _subscription: Subscription, + update_diff: Task>, +} + +impl PendingDiff { + pub fn update(&mut self, cx: &mut Context) { + let buffer = self.buffer.clone(); + let buffer_diff = self.diff.clone(); + let base_text = self.base_text.clone(); + self.update_diff = cx.spawn(async move |diff, cx| { + let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?; + let diff_snapshot = BufferDiff::update_diff( + buffer_diff.clone(), + text_snapshot.clone(), + Some(base_text), + false, + false, + None, + None, + cx, + ) + .await?; + buffer_diff.update(cx, |diff, cx| { + diff.set_snapshot(diff_snapshot, &text_snapshot, cx) + })?; + diff.update(cx, |diff, cx| { + if let Diff::Pending(diff) = diff { + diff.update_visible_ranges(cx); + } + }) + }); + } + + pub fn reveal_range(&mut self, range: Range, cx: &mut Context) { + self.revealed_ranges.push(range); + self.update_visible_ranges(cx); + } + + fn finalize(&self, cx: &mut Context) -> FinalizedDiff { + let ranges = self.excerpt_ranges(cx); + let base_text = self.base_text.clone(); + let language_registry = self.buffer.read(cx).language_registry().clone(); + + let path = self + .buffer + .read(cx) + .file() + .map(|file| file.path().as_ref()) + .unwrap_or(Path::new("untitled")) + .into(); + + // Replace the buffer in the multibuffer with the snapshot + let buffer = cx.new(|cx| { + let language = self.buffer.read(cx).language().cloned(); + let buffer = TextBuffer::new_normalized( + 0, + cx.entity_id().as_non_zero_u64().into(), + self.buffer.read(cx).line_ending(), + self.buffer.read(cx).as_rope().clone(), + ); + let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite); + buffer.set_language(language, cx); + buffer + }); + + let buffer_diff = cx.spawn({ + let buffer = buffer.clone(); + let language_registry = language_registry.clone(); + async move |_this, cx| { + build_buffer_diff(base_text, &buffer, language_registry, cx).await + } + }); + + let update_diff = cx.spawn(async move |this, cx| { + let buffer_diff = buffer_diff.await?; + this.update(cx, |this, cx| { + this.multibuffer().update(cx, |multibuffer, cx| { + let path_key = PathKey::for_buffer(&buffer, cx); + multibuffer.clear(cx); + multibuffer.set_excerpts_for_path( + path_key, + buffer, + ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + multibuffer.add_diff(buffer_diff.clone(), cx); + }); + + cx.notify(); + }) + }); + + FinalizedDiff { + path, + multibuffer: self.multibuffer.clone(), + _update_diff: update_diff, + } + } + + fn update_visible_ranges(&mut self, cx: &mut Context) { + let ranges = self.excerpt_ranges(cx); + self.multibuffer.update(cx, |multibuffer, cx| { + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(&self.buffer, cx), + self.buffer.clone(), + ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + let end = multibuffer.len(cx); + Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1) + }); + cx.notify(); + } + + fn excerpt_ranges(&self, cx: &App) -> Vec> { + let buffer = self.buffer.read(cx); + let diff = self.diff.read(cx); + let mut ranges = diff + .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .collect::>(); + ranges.extend( + self.revealed_ranges + .iter() + 
.map(|range| range.to_point(&buffer)), + ); + ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end))); + + // Merge adjacent ranges + let mut ranges = ranges.into_iter().peekable(); + let mut merged_ranges = Vec::new(); + while let Some(mut range) = ranges.next() { + while let Some(next_range) = ranges.peek() { + if range.end >= next_range.start { + range.end = range.end.max(next_range.end); + ranges.next(); + } else { + break; + } + } + + merged_ranges.push(range); + } + merged_ranges + } +} + +pub struct FinalizedDiff { + path: PathBuf, + multibuffer: Entity, + _update_diff: Task>, +} + +async fn build_buffer_diff( + old_text: Arc, + buffer: &Entity, + language_registry: Option>, + cx: &mut AsyncApp, +) -> Result> { + let buffer = cx.update(|cx| buffer.read(cx).snapshot())?; + + let old_text_rope = cx + .background_spawn({ + let old_text = old_text.clone(); + async move { Rope::from(old_text.as_str()) } + }) + .await; + let base_buffer = cx + .update(|cx| { + Buffer::build_snapshot( + old_text_rope, + buffer.language().cloned(), + language_registry, + cx, + ) + })? + .await; + + let diff_snapshot = cx + .update(|cx| { + BufferDiffSnapshot::new_with_base_buffer( + buffer.text.clone(), + Some(old_text), + base_buffer, + cx, + ) + })? + .await; + + let secondary_diff = cx.new(|cx| { + let mut diff = BufferDiff::new(&buffer, cx); + diff.set_snapshot(diff_snapshot.clone(), &buffer, cx); + diff + })?; + + cx.new(|cx| { + let mut diff = BufferDiff::new(&buffer.text, cx); + diff.set_snapshot(diff_snapshot, &buffer, cx); + diff.set_secondary_diff(secondary_diff); + diff + }) +} diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index 89f75a72bd..eb39c3e454 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -212,7 +212,16 @@ impl HistoryStore { fn load_recently_opened_entries(cx: &AsyncApp) -> Task>> { cx.background_spawn(async move { let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH); - let contents = smol::fs::read_to_string(path).await?; + let contents = match smol::fs::read_to_string(path).await { + Ok(it) => it, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + return Ok(Vec::new()); + } + Err(e) => { + return Err(e) + .context("deserializing persisted agent panel navigation history"); + } + }; let entries = serde_json::from_str::>(&contents) .context("deserializing persisted agent panel navigation history")? 
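The `excerpt_ranges` helper above sorts hunk ranges and revealed ranges together and then folds overlapping neighbors into single excerpts. Below is a standalone sketch of just that merge step, simplified to `Range<u32>` instead of the buffer `Point` ranges used in the patch (the function name is invented for the example).

```rust
use std::cmp::Reverse;
use std::ops::Range;

/// Simplified version of the merge step in `PendingDiff::excerpt_ranges`:
/// sort by (start, descending end), then fold overlapping ranges together.
fn merge_ranges(mut ranges: Vec<Range<u32>>) -> Vec<Range<u32>> {
    ranges.sort_unstable_by_key(|r| (r.start, Reverse(r.end)));
    let mut iter = ranges.into_iter().peekable();
    let mut merged = Vec::new();
    while let Some(mut range) = iter.next() {
        while let Some(next) = iter.peek() {
            if range.end >= next.start {
                range.end = range.end.max(next.end);
                iter.next();
            } else {
                break;
            }
        }
        merged.push(range);
    }
    merged
}

fn main() {
    assert_eq!(merge_ranges(vec![4..6, 1..3, 2..5]), vec![1..6]);
    assert_eq!(merge_ranges(vec![0..1, 2..3]), vec![0..1, 2..3]);
}
```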
.into_iter() diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 8558dd528d..048aa4245d 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -8,7 +8,7 @@ use crate::{ }, tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}, }; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; @@ -2112,12 +2112,10 @@ impl Thread { return; } - let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt"); - let request = self.to_summarize_request( &model.model, CompletionIntent::ThreadSummarization, - added_user_message.into(), + SUMMARIZE_THREAD_PROMPT.into(), cx, ); @@ -4047,8 +4045,8 @@ fn main() {{ }); cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief"); - fake_model.stream_last_completion_response(" Introduction"); + fake_model.send_last_completion_stream_text_chunk("Brief"); + fake_model.send_last_completion_stream_text_chunk(" Introduction"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -4141,7 +4139,7 @@ fn main() {{ }); cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary"); + fake_model.send_last_completion_stream_text_chunk("A successful summary"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -4774,7 +4772,7 @@ fn main() {{ !pending.is_empty(), "Should have a pending completion after retry" ); - fake_model.stream_completion_response(&pending[0], "Success!"); + fake_model.send_completion_stream_text_chunk(&pending[0], "Success!"); fake_model.end_completion_stream(&pending[0]); cx.run_until_parked(); @@ -4942,7 +4940,7 @@ fn main() {{ // Check for pending completions and complete them if let Some(pending) = inner_fake.pending_completions().first() { - inner_fake.stream_completion_response(pending, "Success!"); + inner_fake.send_completion_stream_text_chunk(pending, "Success!"); inner_fake.end_completion_stream(pending); } cx.run_until_parked(); @@ -5427,7 +5425,7 @@ fn main() {{ fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response"); + fake_model.send_last_completion_stream_text_chunk("Assistant response"); fake_model.end_last_completion_stream(); cx.run_until_parked(); } diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml new file mode 100644 index 0000000000..3e19895a31 --- /dev/null +++ b/crates/agent2/Cargo.toml @@ -0,0 +1,64 @@ +[package] +name = "agent2" +version = "0.1.0" +edition = "2021" +license = "GPL-3.0-or-later" +publish = false + +[lib] +path = "src/agent2.rs" + +[lints] +workspace = true + +[dependencies] +acp_thread.workspace = true +agent-client-protocol.workspace = true +agent_servers.workspace = true +agent_settings.workspace = true +anyhow.workspace = true +assistant_tool.workspace = true +assistant_tools.workspace = true +cloud_llm_client.workspace = true +collections.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +handlebars = { workspace = true, features = ["rust-embed"] } +indoc.workspace = true +itertools.workspace = true +language.workspace = true +language_model.workspace = true +language_models.workspace = true +log.workspace = true +paths.workspace = true +project.workspace = 
true +prompt_store.workspace = true +rust-embed.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +ui.workspace = true +util.workspace = true +uuid.workspace = true +watch.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +ctor.workspace = true +client = { workspace = true, "features" = ["test-support"] } +clock = { workspace = true, "features" = ["test-support"] } +env_logger.workspace = true +fs = { workspace = true, "features" = ["test-support"] } +gpui = { workspace = true, "features" = ["test-support"] } +gpui_tokio.workspace = true +language = { workspace = true, "features" = ["test-support"] } +language_model = { workspace = true, "features" = ["test-support"] } +lsp = { workspace = true, "features" = ["test-support"] } +project = { workspace = true, "features" = ["test-support"] } +reqwest_client.workspace = true +settings = { workspace = true, "features" = ["test-support"] } +worktree = { workspace = true, "features" = ["test-support"] } +pretty_assertions.workspace = true diff --git a/crates/agent2/LICENSE-GPL b/crates/agent2/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/agent2/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs new file mode 100644 index 0000000000..df061cd5ed --- /dev/null +++ b/crates/agent2/src/agent.rs @@ -0,0 +1,696 @@ +use crate::{templates::Templates, AgentResponseEvent, Thread}; +use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization}; +use acp_thread::ModelSelector; +use agent_client_protocol as acp; +use anyhow::{anyhow, Context as _, Result}; +use futures::{future, StreamExt}; +use gpui::{ + App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, +}; +use language_model::{LanguageModel, LanguageModelRegistry}; +use project::{Project, ProjectItem, ProjectPath, Worktree}; +use prompt_store::{ + ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, +}; +use std::cell::RefCell; +use std::collections::HashMap; +use std::path::Path; +use std::rc::Rc; +use std::sync::Arc; +use util::ResultExt; + +const RULES_FILE_NAMES: [&'static str; 9] = [ + ".rules", + ".cursorrules", + ".windsurfrules", + ".clinerules", + ".github/copilot-instructions.md", + "CLAUDE.md", + "AGENT.md", + "AGENTS.md", + "GEMINI.md", +]; + +pub struct RulesLoadingError { + pub message: SharedString, +} + +/// Holds both the internal Thread and the AcpThread for a session +struct Session { + /// The internal thread that processes messages + thread: Entity, + /// The ACP thread that handles protocol communication + acp_thread: WeakEntity, + _subscription: Subscription, +} + +pub struct NativeAgent { + /// Session ID -> Session mapping + sessions: HashMap, + /// Shared project context for all threads + project_context: Rc>, + project_context_needs_refresh: watch::Sender<()>, + _maintain_project_context: Task>, + /// Shared templates for all threads + templates: Arc, + project: Entity, + prompt_store: Option>, + _subscriptions: Vec, +} + +impl NativeAgent { + pub async fn new( + project: Entity, + templates: Arc, + prompt_store: Option>, + cx: &mut AsyncApp, + ) -> Result> { + log::info!("Creating new NativeAgent"); + + let project_context = cx + .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? 
+ .await; + + cx.new(|cx| { + let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; + if let Some(prompt_store) = prompt_store.as_ref() { + subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) + } + + let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = + watch::channel(()); + Self { + sessions: HashMap::new(), + project_context: Rc::new(RefCell::new(project_context)), + project_context_needs_refresh: project_context_needs_refresh_tx, + _maintain_project_context: cx.spawn(async move |this, cx| { + Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await + }), + templates, + project, + prompt_store, + _subscriptions: subscriptions, + } + }) + } + + async fn maintain_project_context( + this: WeakEntity, + mut needs_refresh: watch::Receiver<()>, + cx: &mut AsyncApp, + ) -> Result<()> { + while needs_refresh.changed().await.is_ok() { + let project_context = this + .update(cx, |this, cx| { + Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx) + })? + .await; + this.update(cx, |this, _| this.project_context.replace(project_context))?; + } + + Ok(()) + } + + fn build_project_context( + project: &Entity, + prompt_store: Option<&Entity>, + cx: &mut App, + ) -> Task { + let worktrees = project.read(cx).visible_worktrees(cx).collect::>(); + let worktree_tasks = worktrees + .into_iter() + .map(|worktree| { + Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx) + }) + .collect::>(); + let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() { + prompt_store.read_with(cx, |prompt_store, cx| { + let prompts = prompt_store.default_prompt_metadata(); + let load_tasks = prompts.into_iter().map(|prompt_metadata| { + let contents = prompt_store.load(prompt_metadata.id, cx); + async move { (contents.await, prompt_metadata) } + }); + cx.background_spawn(future::join_all(load_tasks)) + }) + } else { + Task::ready(vec![]) + }; + + cx.spawn(async move |_cx| { + let (worktrees, default_user_rules) = + future::join(future::join_all(worktree_tasks), default_user_rules_task).await; + + let worktrees = worktrees + .into_iter() + .map(|(worktree, _rules_error)| { + // TODO: show error message + // if let Some(rules_error) = rules_error { + // this.update(cx, |_, cx| cx.emit(rules_error)).ok(); + // } + worktree + }) + .collect::>(); + + let default_user_rules = default_user_rules + .into_iter() + .flat_map(|(contents, prompt_metadata)| match contents { + Ok(contents) => Some(UserRulesContext { + uuid: match prompt_metadata.id { + PromptId::User { uuid } => uuid, + PromptId::EditWorkflow => return None, + }, + title: prompt_metadata.title.map(|title| title.to_string()), + contents, + }), + Err(_err) => { + // TODO: show error message + // this.update(cx, |_, cx| { + // cx.emit(RulesLoadingError { + // message: format!("{err:?}").into(), + // }); + // }) + // .ok(); + None + } + }) + .collect::>(); + + ProjectContext::new(worktrees, default_user_rules) + }) + } + + fn load_worktree_info_for_system_prompt( + worktree: Entity, + project: Entity, + cx: &mut App, + ) -> Task<(WorktreeContext, Option)> { + let tree = worktree.read(cx); + let root_name = tree.root_name().into(); + let abs_path = tree.abs_path(); + + let mut context = WorktreeContext { + root_name, + abs_path, + rules_file: None, + }; + + let rules_task = Self::load_worktree_rules_file(worktree, project, cx); + let Some(rules_task) = rules_task else { + return Task::ready((context, None)); + }; 
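`build_project_context` above fans out one task per visible worktree plus one task for the default user rules, then joins them all. A reduced sketch of that shape with plain futures follows; every name and payload type here is a simplified stand-in, not the real project types.

```rust
// Illustrative only: the join pattern from `build_project_context`, reduced to plain futures.
use futures::executor::block_on;
use futures::future;

async fn gather(worktree_names: Vec<&str>) -> (Vec<String>, Vec<String>) {
    // One future per visible worktree...
    let worktree_tasks = worktree_names
        .into_iter()
        .map(|name| async move { format!("worktree context for {name}") });
    // ...plus one future that loads the default user rules.
    let default_user_rules = async { vec!["default rules".to_string()] };
    // Join everything, mirroring `future::join(future::join_all(...), ...)` in the diff.
    future::join(future::join_all(worktree_tasks), default_user_rules).await
}

fn main() {
    let (worktrees, rules) = block_on(gather(vec!["zed", "docs"]));
    println!("{worktrees:?} {rules:?}");
}
```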
+ + cx.spawn(async move |_| { + let (rules_file, rules_file_error) = match rules_task.await { + Ok(rules_file) => (Some(rules_file), None), + Err(err) => ( + None, + Some(RulesLoadingError { + message: format!("{err}").into(), + }), + ), + }; + context.rules_file = rules_file; + (context, rules_file_error) + }) + } + + fn load_worktree_rules_file( + worktree: Entity, + project: Entity, + cx: &mut App, + ) -> Option>> { + let worktree = worktree.read(cx); + let worktree_id = worktree.id(); + let selected_rules_file = RULES_FILE_NAMES + .into_iter() + .filter_map(|name| { + worktree + .entry_for_path(name) + .filter(|entry| entry.is_file()) + .map(|entry| entry.path.clone()) + }) + .next(); + + // Note that Cline supports `.clinerules` being a directory, but that is not currently + // supported. This doesn't seem to occur often in GitHub repositories. + selected_rules_file.map(|path_in_worktree| { + let project_path = ProjectPath { + worktree_id, + path: path_in_worktree.clone(), + }; + let buffer_task = + project.update(cx, |project, cx| project.open_buffer(project_path, cx)); + let rope_task = cx.spawn(async move |cx| { + buffer_task.await?.read_with(cx, |buffer, cx| { + let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?; + anyhow::Ok((project_entry_id, buffer.as_rope().clone())) + })? + }); + // Build a string from the rope on a background thread. + cx.background_spawn(async move { + let (project_entry_id, rope) = rope_task.await?; + anyhow::Ok(RulesFileContext { + path_in_worktree, + text: rope.to_string().trim().to_string(), + project_entry_id: project_entry_id.to_usize(), + }) + }) + }) + } + + fn handle_project_event( + &mut self, + _project: Entity, + event: &project::Event, + _cx: &mut Context, + ) { + match event { + project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { + self.project_context_needs_refresh.send(()).ok(); + } + project::Event::WorktreeUpdatedEntries(_, items) => { + if items.iter().any(|(path, _, _)| { + RULES_FILE_NAMES + .iter() + .any(|name| path.as_ref() == Path::new(name)) + }) { + self.project_context_needs_refresh.send(()).ok(); + } + } + _ => {} + } + } + + fn handle_prompts_updated_event( + &mut self, + _prompt_store: Entity, + _event: &prompt_store::PromptsUpdatedEvent, + _cx: &mut Context, + ) { + self.project_context_needs_refresh.send(()).ok(); + } +} + +/// Wrapper struct that implements the AgentConnection trait +#[derive(Clone)] +pub struct NativeAgentConnection(pub Entity); + +impl ModelSelector for NativeAgentConnection { + fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { + log::debug!("NativeAgentConnection::list_models called"); + cx.spawn(async move |cx| { + cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let models = registry.available_models(cx).collect::>(); + log::info!("Found {} available models", models.len()); + if models.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(models) + } + })? + }) + } + + fn select_model( + &self, + session_id: acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task> { + log::info!( + "Setting model for session {}: {:?}", + session_id, + model.name() + ); + let agent = self.0.clone(); + + cx.spawn(async move |cx| { + agent.update(cx, |agent, cx| { + if let Some(session) = agent.sessions.get(&session_id) { + session.thread.update(cx, |thread, _cx| { + thread.selected_model = model; + }); + Ok(()) + } else { + Err(anyhow!("Session not found")) + } + })? 
+ }) + } + + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + let session_id = session_id.clone(); + cx.spawn(async move |cx| { + let thread = agent + .read_with(cx, |agent, _| { + agent + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + })? + .ok_or_else(|| anyhow::anyhow!("Session not found"))?; + let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + Ok(selected) + }) + } +} + +impl acp_thread::AgentConnection for NativeAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + log::info!("Creating new thread for project at: {:?}", cwd); + + cx.spawn(async move |cx| { + log::debug!("Starting thread creation in async context"); + + // Generate session ID + let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + log::info!("Created session with ID: {}", session_id); + + // Create AcpThread + let acp_thread = cx.update(|cx| { + cx.new(|cx| { + acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx) + }) + })?; + let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; + + // Create Thread + let thread = agent.update( + cx, + |agent, cx: &mut gpui::Context| -> Result<_> { + // Fetch default model from registry settings + let registry = LanguageModelRegistry::read_global(cx); + + // Log available models for debugging + let available_count = registry.available_models(cx).count(); + log::debug!("Total available models: {}", available_count); + + let default_model = registry + .default_model() + .map(|configured| { + log::info!( + "Using configured default model: {:?} from provider: {:?}", + configured.model.name(), + configured.provider.name() + ); + configured.model + }) + .ok_or_else(|| { + log::warn!("No default model configured in settings"); + anyhow!("No default model configured. 
Please configure a default model in settings.") + })?; + + let thread = cx.new(|cx| { + let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model); + thread.add_tool(ThinkingTool); + thread.add_tool(FindPathTool::new(project.clone())); + thread.add_tool(ReadFileTool::new(project.clone(), action_log)); + thread.add_tool(EditFileTool::new(cx.entity())); + thread + }); + + Ok(thread) + }, + )??; + + // Store the session + agent.update(cx, |agent, cx| { + agent.sessions.insert( + session_id, + Session { + thread, + acp_thread: acp_thread.downgrade(), + _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }) + }, + ); + })?; + + Ok(acp_thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] // No auth for in-process + } + + fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } + + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let session_id = params.session_id.clone(); + let agent = self.0.clone(); + log::info!("Received prompt request for session: {}", session_id); + log::debug!("Prompt blocks count: {}", params.prompt.len()); + + cx.spawn(async move |cx| { + // Get session + let (thread, acp_thread) = agent + .update(cx, |agent, _| { + agent + .sessions + .get_mut(&session_id) + .map(|s| (s.thread.clone(), s.acp_thread.clone())) + })? + .ok_or_else(|| { + log::error!("Session not found: {}", session_id); + anyhow::anyhow!("Session not found") + })?; + log::debug!("Found session for: {}", session_id); + + // Convert prompt to message + let message = convert_prompt_to_message(params.prompt); + log::info!("Converted prompt to message: {} chars", message.len()); + log::debug!("Message content: {}", message); + + // Get model using the ModelSelector capability (always available for agent2) + // Get the selected model from the thread directly + let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + + // Send to thread + log::info!("Sending message to thread with model: {:?}", model.name()); + let mut response_stream = + thread.update(cx, |thread, cx| thread.send(model, message, cx))?; + + // Handle response stream and forward to session.acp_thread + while let Some(result) = response_stream.next().await { + match result { + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + + match event { + AgentResponseEvent::Text(text) => { + acp_thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + false, + cx, + ) + })?; + } + AgentResponseEvent::Thinking(text) => { + acp_thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + true, + cx, + ) + })?; + } + AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { + tool_call, + options, + response, + }) => { + let recv = acp_thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization(tool_call, options, cx) + })?; + cx.background_spawn(async move { + if let Some(option) = recv + .await + .context("authorization sender was dropped") + .log_err() + { + response + .send(option) + .map(|_| anyhow!("authorization receiver was dropped")) + .log_err(); + } + }) + .detach(); + } 
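+                        // Plain tool call events and field updates are forwarded straight to the ACP thread; a Stop event ends the turn with its stop reason.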
+ AgentResponseEvent::ToolCall(tool_call) => { + acp_thread.update(cx, |thread, cx| { + thread.upsert_tool_call(tool_call, cx) + })?; + } + AgentResponseEvent::ToolCallUpdate(update) => { + acp_thread.update(cx, |thread, cx| { + thread.update_tool_call(update, cx) + })??; + } + AgentResponseEvent::Stop(stop_reason) => { + log::debug!("Assistant message complete: {:?}", stop_reason); + return Ok(acp::PromptResponse { stop_reason }); + } + } + } + Err(e) => { + log::error!("Error in model response stream: {:?}", e); + // TODO: Consider sending an error message to the UI + break; + } + } + } + + log::info!("Response stream completed"); + anyhow::Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + log::info!("Cancelling on session: {}", session_id); + self.0.update(cx, |agent, cx| { + if let Some(agent) = agent.sessions.get(session_id) { + agent.thread.update(cx, |thread, _cx| thread.cancel()); + } + }); + } +} + +/// Convert ACP content blocks to a message string +fn convert_prompt_to_message(blocks: Vec) -> String { + log::debug!("Converting {} content blocks to message", blocks.len()); + let mut message = String::new(); + + for block in blocks { + match block { + acp::ContentBlock::Text(text) => { + log::trace!("Processing text block: {} chars", text.text.len()); + message.push_str(&text.text); + } + acp::ContentBlock::ResourceLink(link) => { + log::trace!("Processing resource link: {}", link.uri); + message.push_str(&format!(" @{} ", link.uri)); + } + acp::ContentBlock::Image(_) => { + log::trace!("Processing image block"); + message.push_str(" [image] "); + } + acp::ContentBlock::Audio(_) => { + log::trace!("Processing audio block"); + message.push_str(" [audio] "); + } + acp::ContentBlock::Resource(resource) => { + log::trace!("Processing resource block: {:?}", resource.resource); + message.push_str(&format!(" [resource: {:?}] ", resource.resource)); + } + } + } + + message +} + +#[cfg(test)] +mod tests { + use super::*; + use fs::FakeFs; + use gpui::TestAppContext; + use serde_json::json; + use settings::SettingsStore; + + #[gpui::test] + async fn test_maintaining_project_context(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": {} + }), + ) + .await; + let project = Project::test(fs.clone(), [], cx).await; + let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async()) + .await + .unwrap(); + agent.read_with(cx, |agent, _| { + assert_eq!(agent.project_context.borrow().worktrees, vec![]) + }); + + let worktree = project + .update(cx, |project, cx| project.create_worktree("/a", true, cx)) + .await + .unwrap(); + cx.run_until_parked(); + agent.read_with(cx, |agent, _| { + assert_eq!( + agent.project_context.borrow().worktrees, + vec![WorktreeContext { + root_name: "a".into(), + abs_path: Path::new("/a").into(), + rules_file: None + }] + ) + }); + + // Creating `/a/.rules` updates the project context. 
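+        // `.rules` is one of RULES_FILE_NAMES, so the resulting WorktreeUpdatedEntries event triggers a project context refresh.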
+ fs.insert_file("/a/.rules", Vec::new()).await; + cx.run_until_parked(); + agent.read_with(cx, |agent, cx| { + let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap(); + assert_eq!( + agent.project_context.borrow().worktrees, + vec![WorktreeContext { + root_name: "a".into(), + abs_path: Path::new("/a").into(), + rules_file: Some(RulesFileContext { + path_in_worktree: Path::new(".rules").into(), + text: "".into(), + project_entry_id: rules_entry.id.to_usize() + }) + }] + ) + }); + } + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + } +} diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs new file mode 100644 index 0000000000..f13cd1bd67 --- /dev/null +++ b/crates/agent2/src/agent2.rs @@ -0,0 +1,14 @@ +mod agent; +mod native_agent_server; +mod templates; +mod thread; +mod tools; + +#[cfg(test)] +mod tests; + +pub use agent::*; +pub use native_agent_server::NativeAgentServer; +pub use templates::*; +pub use thread::*; +pub use tools::*; diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs new file mode 100644 index 0000000000..dd0188b548 --- /dev/null +++ b/crates/agent2/src/native_agent_server.rs @@ -0,0 +1,60 @@ +use std::path::Path; +use std::rc::Rc; + +use agent_servers::AgentServer; +use anyhow::Result; +use gpui::{App, Entity, Task}; +use project::Project; +use prompt_store::PromptStore; + +use crate::{templates::Templates, NativeAgent, NativeAgentConnection}; + +#[derive(Clone)] +pub struct NativeAgentServer; + +impl AgentServer for NativeAgentServer { + fn name(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_headline(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_message(&self) -> &'static str { + "How can I help you today?" 
+ } + + fn logo(&self) -> ui::IconName { + // Using the ZedAssistant icon as it's the native built-in agent + ui::IconName::ZedAssistant + } + + fn connect( + &self, + _root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + log::info!( + "NativeAgentServer::connect called for path: {:?}", + _root_dir + ); + let project = project.clone(); + let prompt_store = PromptStore::global(cx); + cx.spawn(async move |cx| { + log::debug!("Creating templates for native agent"); + let templates = Templates::new(); + let prompt_store = prompt_store.await?; + + log::debug!("Creating native agent entity"); + let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?; + + // Create the connection wrapper + let connection = NativeAgentConnection(agent); + log::info!("NativeAgentServer connection established successfully"); + + Ok(Rc::new(connection) as Rc) + }) + } +} diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs new file mode 100644 index 0000000000..a63f0ad206 --- /dev/null +++ b/crates/agent2/src/templates.rs @@ -0,0 +1,87 @@ +use anyhow::Result; +use gpui::SharedString; +use handlebars::Handlebars; +use rust_embed::RustEmbed; +use serde::Serialize; +use std::sync::Arc; + +#[derive(RustEmbed)] +#[folder = "src/templates"] +#[include = "*.hbs"] +struct Assets; + +pub struct Templates(Handlebars<'static>); + +impl Templates { + pub fn new() -> Arc { + let mut handlebars = Handlebars::new(); + handlebars.set_strict_mode(true); + handlebars.register_helper("contains", Box::new(contains)); + handlebars.register_embed_templates::().unwrap(); + Arc::new(Self(handlebars)) + } +} + +pub trait Template: Sized { + const TEMPLATE_NAME: &'static str; + + fn render(&self, templates: &Templates) -> Result + where + Self: Serialize + Sized, + { + Ok(templates.0.render(Self::TEMPLATE_NAME, self)?) 
+ } +} + +#[derive(Serialize)] +pub struct SystemPromptTemplate<'a> { + #[serde(flatten)] + pub project: &'a prompt_store::ProjectContext, + pub available_tools: Vec, +} + +impl Template for SystemPromptTemplate<'_> { + const TEMPLATE_NAME: &'static str = "system_prompt.hbs"; +} + +/// Handlebars helper for checking if an item is in a list +fn contains( + h: &handlebars::Helper, + _: &handlebars::Handlebars, + _: &handlebars::Context, + _: &mut handlebars::RenderContext, + out: &mut dyn handlebars::Output, +) -> handlebars::HelperResult { + let list = h + .param(0) + .and_then(|v| v.value().as_array()) + .ok_or_else(|| { + handlebars::RenderError::new("contains: missing or invalid list parameter") + })?; + let query = h.param(1).map(|v| v.value()).ok_or_else(|| { + handlebars::RenderError::new("contains: missing or invalid query parameter") + })?; + + if list.contains(&query) { + out.write("true")?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_system_prompt_template() { + let project = prompt_store::ProjectContext::default(); + let template = SystemPromptTemplate { + project: &project, + available_tools: vec!["echo".into()], + }; + let templates = Templates::new(); + let rendered = template.render(&templates).unwrap(); + assert!(rendered.contains("## Fixing Diagnostics")); + } +} diff --git a/crates/agent2/src/templates/system_prompt.hbs b/crates/agent2/src/templates/system_prompt.hbs new file mode 100644 index 0000000000..a9f67460d8 --- /dev/null +++ b/crates/agent2/src/templates/system_prompt.hbs @@ -0,0 +1,178 @@ +You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. + +## Communication + +1. Be conversational but professional. +2. Refer to the user in the second person and yourself in the first person. +3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. +4. NEVER lie or make things up. +5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. + +{{#if (gt (len available_tools) 0)}} +## Tool Use + +1. Make sure to adhere to the tools schema. +2. Provide every required argument. +3. DO NOT use tools to access items that are already available in the context section. +4. Use only the tools that are currently available. +5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off. +6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers. +7. Avoid HTML entity escaping - use plain characters instead. + +## Searching and Reading + +If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions. + +If appropriate, use tool calls to explore the current project, which contains the following root directories: + +{{#each worktrees}} +- `{{abs_path}}` +{{/each}} + +- Bias towards not asking the user for help if you can find the answer yourself. +- When providing paths to tools, the path should always start with the name of a project root directory listed above. +- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path! +{{# if (contains available_tools 'grep') }} +- When looking for symbols in the project, prefer the `grep` tool. 
+- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project. +- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file. +{{/if}} +{{else}} +You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you). + +As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally. + +The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response. +{{/if}} + +## Code Block Formatting + +Whenever you mention a code block, you MUST use ONLY use the following format: +```path/to/Something.blah#L123-456 +(code goes here) +``` +The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah +is a path in the project. (If there is no valid path in the project, then you can use +/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser +does not understand the more common ```language syntax, or bare ``` blocks. It only +understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again. +Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP! +You have made a mistake. You can only ever put paths after triple backticks! + +Based on all the information I've gathered, here's a summary of how this system works: +1. The README file is loaded into the system. +2. The system finds the first two headers, including everything in between. In this case, that would be: +```path/to/README.md#L8-12 +# First Header +This is the info under the first header. +## Sub-header +``` +3. Then the system finds the last header in the README: +```path/to/README.md#L27-29 +## Last Header +This is the last header in the README. +``` +4. Finally, it passes this information on to the next process. + + +In Markdown, hash marks signify headings. For example: +```/dev/null/example.md#L1-3 +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +Here are examples of ways you must never render code blocks: + +In Markdown, hash marks signify headings. For example: +``` +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +This example is unacceptable because it does not include the path. + +In Markdown, hash marks signify headings. For example: +```markdown +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +This example is unacceptable because it has the language instead of the path. + +In Markdown, hash marks signify headings. For example: + # Level 1 heading + ## Level 2 heading + ### Level 3 heading + +This example is unacceptable because it uses indentation to mark the code block +instead of backticks with a path. + +In Markdown, hash marks signify headings. 
For example: +```markdown +/dev/null/example.md#L1-3 +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks. + +{{#if (gt (len available_tools) 0)}} +## Fixing Diagnostics + +1. Make 1-2 attempts at fixing diagnostics, then defer to the user. +2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem. + +## Debugging + +When debugging, only make code changes if you are certain that you can solve the problem. +Otherwise, follow debugging best practices: +1. Address the root cause instead of the symptoms. +2. Add descriptive logging statements and error messages to track variable and code state. +3. Add test functions and statements to isolate the problem. + +{{/if}} +## Calling External APIs + +1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission. +2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data. +3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed) + +## System Information + +Operating System: {{os}} +Default Shell: {{shell}} + +{{#if (or has_rules has_user_rules)}} +## User's Custom Instructions + +The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}. 
+ +{{#if has_rules}} +There are project rules that apply to these root directories: +{{#each worktrees}} +{{#if rules_file}} +`{{root_name}}/{{rules_file.path_in_worktree}}`: +`````` +{{{rules_file.text}}} +`````` +{{/if}} +{{/each}} +{{/if}} + +{{#if has_user_rules}} +The user has specified the following rules that should be applied: +{{#each user_rules}} + +{{#if title}} +Rules title: {{title}} +{{/if}} +`````` +{{contents}}} +`````` +{{/each}} +{{/if}} +{{/if}} diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs new file mode 100644 index 0000000000..273da1dae5 --- /dev/null +++ b/crates/agent2/src/tests/mod.rs @@ -0,0 +1,846 @@ +use super::*; +use acp_thread::AgentConnection; +use agent_client_protocol::{self as acp}; +use anyhow::Result; +use assistant_tool::ActionLog; +use client::{Client, UserStore}; +use fs::FakeFs; +use futures::channel::mpsc::UnboundedReceiver; +use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext}; +use indoc::indoc; +use language_model::{ + fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult, + LanguageModelToolUse, MessageContent, Role, StopReason, +}; +use project::Project; +use prompt_store::ProjectContext; +use reqwest_client::ReqwestClient; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use smol::stream::StreamExt; +use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration}; +use util::path; + +mod test_tools; +use test_tools::*; + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_echo(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + let events = thread + .update(cx, |thread, cx| { + thread.send(model.clone(), "Testing: Reply with 'Hello'", cx) + }) + .collect() + .await; + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.messages().last().unwrap().content, + vec![MessageContent::Text("Hello".to_string())] + ); + }); + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_thinking(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await; + + let events = thread + .update(cx, |thread, cx| { + thread.send( + model.clone(), + indoc! {" + Testing: + + Generate a thinking step where you just think the word 'Think', + and have your final answer be 'Hello' + "}, + cx, + ) + }) + .collect() + .await; + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.messages().last().unwrap().to_markdown(), + indoc! {" + ## assistant + Think + Hello + "} + ) + }); + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +async fn test_system_prompt(cx: &mut TestAppContext) { + let ThreadTest { + model, + thread, + project_context, + .. 
+ } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + project_context.borrow_mut().shell = "test-shell".into(); + thread.update(cx, |thread, _| thread.add_tool(EchoTool)); + thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx)); + cx.run_until_parked(); + let mut pending_completions = fake_model.pending_completions(); + assert_eq!( + pending_completions.len(), + 1, + "unexpected pending completions: {:?}", + pending_completions + ); + + let pending_completion = pending_completions.pop().unwrap(); + assert_eq!(pending_completion.messages[0].role, Role::System); + + let system_message = &pending_completion.messages[0]; + let system_prompt = system_message.content[0].to_str().unwrap(); + assert!( + system_prompt.contains("test-shell"), + "unexpected system message: {:?}", + system_message + ); + assert!( + system_prompt.contains("## Fixing Diagnostics"), + "unexpected system message: {:?}", + system_message + ); +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_basic_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test a tool call that's likely to complete *before* streaming stops. + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send( + model.clone(), + "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", + cx, + ) + }) + .collect() + .await; + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); + + // Test a tool calls that's likely to complete *after* streaming stops. + let events = thread + .update(cx, |thread, cx| { + thread.remove_tool(&AgentTool::name(&EchoTool)); + thread.add_tool(DelayTool); + thread.send( + model.clone(), + "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", + cx, + ) + }) + .collect() + .await; + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); + thread.update(cx, |thread, _cx| { + assert!(thread + .messages() + .last() + .unwrap() + .content + .iter() + .any(|content| { + if let MessageContent::Text(text) = content { + text.contains("Ding") + } else { + false + } + })); + }); +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_streaming_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test a tool call that's likely to complete *before* streaming stops. 
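+    // Watch the raw event stream so we can catch a tool use whose input is still streaming (is_input_complete == false).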
+ let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(WordListTool); + thread.send(model.clone(), "Test the word_list tool.", cx) + }); + + let mut saw_partial_tool_use = false; + while let Some(event) = events.next().await { + if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { + thread.update(cx, |thread, _cx| { + // Look for a tool use in the thread's last message + let last_content = thread.messages().last().unwrap().content.last().unwrap(); + if let MessageContent::ToolUse(last_tool_use) = last_content { + assert_eq!(last_tool_use.name.as_ref(), "word_list"); + if tool_call.status == acp::ToolCallStatus::Pending { + if !last_tool_use.is_input_complete + && last_tool_use.input.get("g").is_none() + { + saw_partial_tool_use = true; + } + } else { + last_tool_use + .input + .get("a") + .expect("'a' has streamed because input is now complete"); + last_tool_use + .input + .get("g") + .expect("'g' has streamed because input is now complete"); + } + } else { + panic!("last content should be a tool use"); + } + }); + } + } + + assert!( + saw_partial_tool_use, + "should see at least one partially streamed tool use in the history" + ); +} + +#[gpui::test] +async fn test_tool_authorization(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(ToolRequiringPermission); + thread.send(model.clone(), "abc", cx) + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_1".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_2".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + let tool_call_auth_1 = next_tool_call_authorization(&mut events).await; + let tool_call_auth_2 = next_tool_call_authorization(&mut events).await; + + // Approve the first + tool_call_auth_1 + .response + .send(tool_call_auth_1.options[1].id.clone()) + .unwrap(); + cx.run_until_parked(); + + // Reject the second + tool_call_auth_2 + .response + .send(tool_call_auth_1.options[2].id.clone()) + .unwrap(); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + let message = completion.messages.last().unwrap(); + assert_eq!( + message.content, + vec![ + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: false, + content: "Allowed".into(), + output: Some("Allowed".into()) + }), + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: true, + content: "Permission to run tool denied by user".into(), + output: None + }) + ] + ); +} + +#[gpui::test] +async fn test_tool_hallucination(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx)); + cx.run_until_parked(); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_1".into(), + name: "nonexistent_tool".into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + + let tool_call = expect_tool_call(&mut events).await; + assert_eq!(tool_call.title, "nonexistent_tool"); + assert_eq!(tool_call.status, acp::ToolCallStatus::Pending); + let update = expect_tool_call_update_fields(&mut events).await; + assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed)); +} + +async fn expect_tool_call( + events: &mut UnboundedReceiver>, +) -> acp::ToolCall { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + match event { + AgentResponseEvent::ToolCall(tool_call) => return tool_call, + event => { + panic!("Unexpected event {event:?}"); + } + } +} + +async fn expect_tool_call_update_fields( + events: &mut UnboundedReceiver>, +) -> acp::ToolCallUpdate { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + match event { + AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { + return update + } + event => { + panic!("Unexpected event {event:?}"); + } + } +} + +async fn next_tool_call_authorization( + events: &mut UnboundedReceiver>, +) -> ToolCallAuthorization { + loop { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { + let permission_kinds = tool_call_authorization + .options + .iter() + .map(|o| o.kind) + .collect::>(); + assert_eq!( + permission_kinds, + vec![ + acp::PermissionOptionKind::AllowAlways, + acp::PermissionOptionKind::AllowOnce, + acp::PermissionOptionKind::RejectOnce, + ] + ); + return tool_call_authorization; + } + } +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + + // Test concurrent tool calls with different delay times + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(DelayTool); + thread.send( + model.clone(), + "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", + cx, + ) + }) + .collect() + .await; + + let stop_reasons = stop_events(events); + assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]); + + thread.update(cx, |thread, _cx| { + let last_message = thread.messages().last().unwrap(); + let text = last_message + .content + .iter() + .filter_map(|content| { + if let MessageContent::Text(text) = content { + Some(text.as_str()) + } else { + None + } + }) + .collect::(); + + assert!(text.contains("Ding")); + }); +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] +async fn test_cancellation(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Sonnet4).await; + + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(InfiniteTool); + thread.add_tool(EchoTool); + thread.send( + model.clone(), + "Call the echo tool and then call the infinite tool, then explain their output", + cx, + ) + }); + + // Wait until both tools are called. + let mut expected_tools = vec!["Echo", "Infinite Tool"]; + let mut echo_id = None; + let mut echo_completed = false; + while let Some(event) = events.next().await { + match event.unwrap() { + AgentResponseEvent::ToolCall(tool_call) => { + assert_eq!(tool_call.title, expected_tools.remove(0)); + if tool_call.title == "Echo" { + echo_id = Some(tool_call.id); + } + } + AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + acp::ToolCallUpdate { + id, + fields: + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + .. + }, + }, + )) if Some(&id) == echo_id.as_ref() => { + echo_completed = true; + } + _ => {} + } + + if expected_tools.is_empty() && echo_completed { + break; + } + } + + // Cancel the current send and ensure that the event stream is closed, even + // if one of the tools is still running. + thread.update(cx, |thread, _cx| thread.cancel()); + events.collect::>().await; + + // Ensure we can still send a new message after cancellation. + let events = thread + .update(cx, |thread, cx| { + thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx) + }) + .collect::>() + .await; + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.messages().last().unwrap().content, + vec![MessageContent::Text("Hello".to_string())] + ); + }); + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +async fn test_refusal(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx)); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## user + Hello + "} + ); + }); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## user + Hello + ## assistant + Hey! + "} + ); + }); + + // If the model refuses to continue, the thread should remove all the messages after the last user message. 
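+    // Thread::send truncates the message history back to the last user message when it receives StopReason::Refusal.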
+ fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal)); + let events = events.collect::>().await; + assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]); + thread.read_with(cx, |thread, _| { + assert_eq!(thread.to_markdown(), ""); + }); +} + +#[gpui::test] +async fn test_agent_connection(cx: &mut TestAppContext) { + cx.update(settings::init); + let templates = Templates::new(); + + // Initialize language model system with test provider + cx.update(|cx| { + gpui_tokio::init(cx); + client::init_settings(cx); + + let http_client = FakeHttpClient::with_404_response(); + let clock = Arc::new(clock::FakeSystemClock::new()); + let client = Client::new(clock, http_client, cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + Project::init_settings(cx); + LanguageModelRegistry::test(cx); + }); + cx.executor().forbid_parking(); + + // Create a project for new_thread + let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); + fake_fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + let cwd = Path::new("/test"); + + // Create agent and connection + let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async()) + .await + .unwrap(); + let connection = NativeAgentConnection(agent.clone()); + + // Test model_selector returns Some + let selector_opt = connection.model_selector(); + assert!( + selector_opt.is_some(), + "agent2 should always support ModelSelector" + ); + let selector = selector_opt.unwrap(); + + // Test list_models + let listed_models = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.list_models(&mut async_cx) + }) + .await + .expect("list_models should succeed"); + assert!(!listed_models.is_empty(), "should have at least one model"); + assert_eq!(listed_models[0].id().0, "fake"); + + // Create a thread using new_thread + let connection_rc = Rc::new(connection.clone()); + let acp_thread = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + connection_rc.new_thread(project, cwd, &mut async_cx) + }) + .await + .expect("new_thread should succeed"); + + // Get the session_id from the AcpThread + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + + // Test selected_model returns the default + let model = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await + .expect("selected_model should succeed"); + let model = model.as_fake(); + assert_eq!(model.id().0, "fake", "should return default model"); + + let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("def"); + cx.run_until_parked(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + abc + + ## Assistant + + def + + "} + ) + }); + + // Test cancel + cx.update(|cx| connection.cancel(&session_id, cx)); + request.await.expect("prompt should fail gracefully"); + + // Ensure that dropping the ACP thread causes the native thread to be + // dropped as well. 
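+    // The agent removes the session in its cx.observe_release subscription on the AcpThread, so the prompt below should fail with "Session not found".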
+ cx.update(|_| drop(acp_thread)); + let result = cx + .update(|cx| { + connection.prompt( + acp::PromptRequest { + session_id: session_id.clone(), + prompt: vec!["ghi".into()], + }, + cx, + ) + }) + .await; + assert_eq!( + result.as_ref().unwrap_err().to_string(), + "Session not found", + "unexpected result: {:?}", + result + ); +} + +#[gpui::test] +async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool)); + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx)); + cx.run_until_parked(); + + // Simulate streaming partial input. + let input = json!({}); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "1".into(), + name: ThinkingTool.name().into(), + raw_input: input.to_string(), + input, + is_input_complete: false, + }, + )); + + // Input streaming completed + let input = json!({ "content": "Thinking hard!" }); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "1".into(), + name: "thinking".into(), + raw_input: input.to_string(), + input, + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let tool_call = expect_tool_call(&mut events).await; + assert_eq!( + tool_call, + acp::ToolCall { + id: acp::ToolCallId("1".into()), + title: "Thinking".into(), + kind: acp::ToolKind::Think, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(json!({})), + raw_output: None, + } + ); + let update = expect_tool_call_update_fields(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + title: Some("Thinking".into()), + kind: Some(acp::ToolKind::Think), + raw_input: Some(json!({ "content": "Thinking hard!" 
})), + ..Default::default() + }, + } + ); + let update = expect_tool_call_update_fields(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress), + ..Default::default() + }, + } + ); + let update = expect_tool_call_update_fields(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + content: Some(vec!["Thinking hard!".into()]), + ..Default::default() + }, + } + ); + let update = expect_tool_call_update_fields(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + } + ); +} + +/// Filters out the stop events for asserting against in tests +fn stop_events( + result_events: Vec>, +) -> Vec { + result_events + .into_iter() + .filter_map(|event| match event.unwrap() { + AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), + _ => None, + }) + .collect() +} + +struct ThreadTest { + model: Arc, + thread: Entity, + project_context: Rc>, +} + +enum TestModel { + Sonnet4, + Sonnet4Thinking, + Fake, +} + +impl TestModel { + fn id(&self) -> LanguageModelId { + match self { + TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()), + TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()), + TestModel::Fake => unreachable!(), + } + } +} + +async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { + cx.executor().allow_parking(); + cx.update(|cx| { + settings::init(cx); + Project::init_settings(cx); + }); + let templates = Templates::new(); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + + let model = cx + .update(|cx| { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + if let TestModel::Fake = model { + Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>) + } else { + let model_id = model.id(); + let models = LanguageModelRegistry::read_global(cx); + let model = models + .available_models(cx) + .find(|model| model.id() == model_id) + .unwrap(); + + let provider = models.provider(&model.provider_id()).unwrap(); + let authenticated = provider.authenticate(cx); + + cx.spawn(async move |_cx| { + authenticated.await.unwrap(); + model + }) + } + }) + .await; + + let project_context = Rc::new(RefCell::new(ProjectContext::default())); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread = cx.new(|_| { + Thread::new( + project, + project_context.clone(), + action_log, + templates, + model.clone(), + ) + }); + ThreadTest { + model, + thread, + project_context, + } +} + +#[cfg(test)] +#[ctor::ctor] +fn init_logger() { + if std::env::var("RUST_LOG").is_ok() { + env_logger::init(); + } +} diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs new file mode 100644 index 0000000000..d06614f3fe --- /dev/null +++ 
b/crates/agent2/src/tests/test_tools.rs @@ -0,0 +1,201 @@ +use super::*; +use anyhow::Result; +use gpui::{App, SharedString, Task}; +use std::future; + +/// A tool that echoes its input +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct EchoToolInput { + /// The text to echo. + text: String, +} + +pub struct EchoTool; + +impl AgentTool for EchoTool { + type Input = EchoToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "echo".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title(&self, _input: Result) -> SharedString { + "Echo".into() + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(input.text)) + } +} + +/// A tool that waits for a specified delay +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct DelayToolInput { + /// The delay in milliseconds. + ms: u64, +} + +pub struct DelayTool; + +impl AgentTool for DelayTool { + type Input = DelayToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "delay".into() + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Ok(input) = input { + format!("Delay {}ms", input.ms).into() + } else { + "Delay".into() + } + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> + where + Self: Sized, + { + cx.foreground_executor().spawn(async move { + smol::Timer::after(Duration::from_millis(input.ms)).await; + Ok("Ding".to_string()) + }) + } +} + +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct ToolRequiringPermissionInput {} + +pub struct ToolRequiringPermission; + +impl AgentTool for ToolRequiringPermission { + type Input = ToolRequiringPermissionInput; + type Output = String; + + fn name(&self) -> SharedString { + "tool_requiring_permission".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title(&self, _input: Result) -> SharedString { + "This tool requires permission".into() + } + + fn run( + self: Arc, + _input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let auth_check = event_stream.authorize("Authorize?".into()); + cx.foreground_executor().spawn(async move { + auth_check.await?; + Ok("Allowed".to_string()) + }) + } +} + +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct InfiniteToolInput {} + +pub struct InfiniteTool; + +impl AgentTool for InfiniteTool { + type Input = InfiniteToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "infinite".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title(&self, _input: Result) -> SharedString { + "Infinite Tool".into() + } + + fn run( + self: Arc, + _input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.foreground_executor().spawn(async move { + future::pending::<()>().await; + unreachable!() + }) + } +} + +/// A tool that takes an object with map from letters to random words starting with that letter. +/// All fiealds are required! Pass a word for every letter! +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct WordListInput { + /// Provide a random word that starts with A. + a: Option, + /// Provide a random word that starts with B. + b: Option, + /// Provide a random word that starts with C. + c: Option, + /// Provide a random word that starts with D. 
+ d: Option, + /// Provide a random word that starts with E. + e: Option, + /// Provide a random word that starts with F. + f: Option, + /// Provide a random word that starts with G. + g: Option, +} + +pub struct WordListTool; + +impl AgentTool for WordListTool { + type Input = WordListInput; + type Output = String; + + fn name(&self) -> SharedString { + "word_list".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title(&self, _input: Result) -> SharedString { + "List of random words".into() + } + + fn run( + self: Arc, + _input: Self::Input, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok("ok".to_string())) + } +} diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs new file mode 100644 index 0000000000..f664e0f5d2 --- /dev/null +++ b/crates/agent2/src/thread.rs @@ -0,0 +1,1026 @@ +use crate::{SystemPromptTemplate, Template, Templates}; +use acp_thread::Diff; +use agent_client_protocol as acp; +use anyhow::{anyhow, Context as _, Result}; +use assistant_tool::{adapt_schema_to_format, ActionLog}; +use cloud_llm_client::{CompletionIntent, CompletionMode}; +use collections::HashMap; +use futures::{ + channel::{mpsc, oneshot}, + stream::FuturesUnordered, +}; +use gpui::{App, Context, Entity, SharedString, Task}; +use language_model::{ + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, +}; +use log; +use project::Project; +use prompt_store::ProjectContext; +use schemars::{JsonSchema, Schema}; +use serde::{Deserialize, Serialize}; +use smol::stream::StreamExt; +use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc}; +use util::{markdown::MarkdownCodeBlock, ResultExt}; + +#[derive(Debug, Clone)] +pub struct AgentMessage { + pub role: Role, + pub content: Vec, +} + +impl AgentMessage { + pub fn to_markdown(&self) -> String { + let mut markdown = format!("## {}\n", self.role); + + for content in &self.content { + match content { + MessageContent::Text(text) => { + markdown.push_str(text); + markdown.push('\n'); + } + MessageContent::Thinking { text, .. 
} => { + markdown.push_str(""); + markdown.push_str(text); + markdown.push_str("\n"); + } + MessageContent::RedactedThinking(_) => markdown.push_str("\n"), + MessageContent::Image(_) => { + markdown.push_str("\n"); + } + MessageContent::ToolUse(tool_use) => { + markdown.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + markdown.push_str(&format!( + "{}\n", + MarkdownCodeBlock { + tag: "json", + text: &format!("{:#}", tool_use.input) + } + )); + } + MessageContent::ToolResult(tool_result) => { + markdown.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + markdown.push_str("**ERROR:**\n"); + } + + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + writeln!(markdown, "{text}\n").ok(); + } + LanguageModelToolResultContent::Image(_) => { + writeln!(markdown, "\n").ok(); + } + } + + if let Some(output) = tool_result.output.as_ref() { + writeln!( + markdown, + "**Debug Output**:\n\n```json\n{}\n```\n", + serde_json::to_string_pretty(output).unwrap() + ) + .unwrap(); + } + } + } + } + + markdown + } +} + +#[derive(Debug)] +pub enum AgentResponseEvent { + Text(String), + Thinking(String), + ToolCall(acp::ToolCall), + ToolCallUpdate(acp_thread::ToolCallUpdate), + ToolCallAuthorization(ToolCallAuthorization), + Stop(acp::StopReason), +} + +#[derive(Debug)] +pub struct ToolCallAuthorization { + pub tool_call: acp::ToolCall, + pub options: Vec, + pub response: oneshot::Sender, +} + +pub struct Thread { + messages: Vec, + completion_mode: CompletionMode, + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + running_turn: Option>, + pending_tool_uses: HashMap, + tools: BTreeMap>, + project_context: Rc>, + templates: Arc, + pub selected_model: Arc, + project: Entity, + action_log: Entity, +} + +impl Thread { + pub fn new( + project: Entity, + project_context: Rc>, + action_log: Entity, + templates: Arc, + default_model: Arc, + ) -> Self { + Self { + messages: Vec::new(), + completion_mode: CompletionMode::Normal, + running_turn: None, + pending_tool_uses: HashMap::default(), + tools: BTreeMap::default(), + project_context, + templates, + selected_model: default_model, + project, + action_log, + } + } + + pub fn project(&self) -> &Entity { + &self.project + } + + pub fn action_log(&self) -> &Entity { + &self.action_log + } + + pub fn set_mode(&mut self, mode: CompletionMode) { + self.completion_mode = mode; + } + + pub fn messages(&self) -> &[AgentMessage] { + &self.messages + } + + pub fn add_tool(&mut self, tool: impl AgentTool) { + self.tools.insert(tool.name(), tool.erase()); + } + + pub fn remove_tool(&mut self, name: &str) -> bool { + self.tools.remove(name).is_some() + } + + pub fn cancel(&mut self) { + self.running_turn.take(); + + let tool_results = self + .pending_tool_uses + .drain() + .map(|(tool_use_id, tool_use)| { + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id, + tool_name: tool_use.name.clone(), + is_error: true, + content: LanguageModelToolResultContent::Text("Tool canceled by user".into()), + output: None, + }) + }) + .collect::>(); + self.last_user_message().content.extend(tool_results); + } + + /// Sending a message results in the model streaming a response, which could include tool calls. 
+ /// After calling tools, the model stops and waits for any outstanding tool calls to be completed and their results sent. + /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. + pub fn send( + &mut self, + model: Arc, + content: impl Into, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let content = content.into(); + log::info!("Thread::send called with model: {:?}", model.name()); + log::debug!("Thread::send content: {:?}", content); + + cx.notify(); + let (events_tx, events_rx) = + mpsc::unbounded::>(); + let event_stream = AgentResponseEventStream(events_tx); + + let user_message_ix = self.messages.len(); + self.messages.push(AgentMessage { + role: Role::User, + content: vec![content], + }); + log::info!("Total messages in thread: {}", self.messages.len()); + self.running_turn = Some(cx.spawn(async move |thread, cx| { + log::info!("Starting agent turn execution"); + let turn_result = async { + // Perform one request, then keep looping if the model makes tool calls. + let mut completion_intent = CompletionIntent::UserPrompt; + 'outer: loop { + log::debug!( + "Building completion request with intent: {:?}", + completion_intent + ); + let request = thread.update(cx, |thread, cx| { + thread.build_completion_request(completion_intent, cx) + })?; + + // println!( + // "request: {}", + // serde_json::to_string_pretty(&request).unwrap() + // ); + + // Stream events, appending to messages and collecting up tool uses. + log::info!("Calling model.stream_completion"); + let mut events = model.stream_completion(request, cx).await?; + log::debug!("Stream completion started successfully"); + let mut tool_uses = FuturesUnordered::new(); + while let Some(event) = events.next().await { + match event { + Ok(LanguageModelCompletionEvent::Stop(reason)) => { + event_stream.send_stop(reason); + if reason == StopReason::Refusal { + thread.update(cx, |thread, _cx| { + thread.messages.truncate(user_message_ix); + })?; + break 'outer; + } + } + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + thread + .update(cx, |thread, cx| { + tool_uses.extend(thread.handle_streamed_completion_event( + event, + &event_stream, + cx, + )); + }) + .ok(); + } + Err(error) => { + log::error!("Error in completion stream: {:?}", error); + event_stream.send_error(error); + break; + } + } + } + + // If there are no tool uses, the turn is done. + if tool_uses.is_empty() { + log::info!("No tool uses found, completing turn"); + break; + } + log::info!("Found {} tool uses to execute", tool_uses.len()); + + // As tool results trickle in, insert them in the last user + // message so that they can be sent on the next tick of the + // agentic loop.
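For readers following the control flow here: `send` hands the receiving half of this channel back to the caller, which is expected to drain it to drive UI updates and tool authorization while the loop below collects tool results. The following is a minimal, illustrative sketch of such a consumer; it is not part of the diff, the surrounding setup (an `Entity<Thread>`, a concrete model, a gpui context) is assumed to exist in the caller's crate, and the variant names follow the `AgentResponseEvent` enum defined earlier in this file.

use futures::StreamExt as _;

// Sketch only: drains the receiver returned by Thread::send.
// `events` is assumed to be the UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
// that `send` returns; how it was obtained is out of scope here.
async fn drain_turn(
    mut events: futures::channel::mpsc::UnboundedReceiver<
        Result<AgentResponseEvent, LanguageModelCompletionError>,
    >,
) {
    while let Some(event) = events.next().await {
        match event {
            Ok(AgentResponseEvent::Text(chunk)) => print!("{chunk}"),
            Ok(AgentResponseEvent::Thinking(chunk)) => eprintln!("[thinking] {chunk}"),
            Ok(AgentResponseEvent::ToolCall(call)) => eprintln!("tool call: {}", call.title),
            Ok(AgentResponseEvent::ToolCallUpdate(_update)) => {
                // A real UI would patch the existing tool call entry with the updated fields or diff.
            }
            Ok(AgentResponseEvent::ToolCallAuthorization(auth)) => {
                // Auto-approve for the sketch; a real caller surfaces `auth.options` to the user.
                auth.response
                    .send(acp::PermissionOptionId("allow".into()))
                    .ok();
            }
            Ok(AgentResponseEvent::Stop(reason)) => eprintln!("turn stopped: {reason:?}"),
            Err(error) => eprintln!("completion error: {error:?}"),
        }
    }
}

Note that `StopReason::ToolUse` is intentionally not forwarded on this channel (see `send_stop` further down), so a consumer only ever sees end-of-turn, max-tokens, or refusal stops.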
+ while let Some(tool_result) = tool_uses.next().await { + log::info!("Tool finished {:?}", tool_result); + + event_stream.update_tool_call_fields( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields { + status: Some(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }), + ..Default::default() + }, + ); + thread + .update(cx, |thread, _cx| { + thread.pending_tool_uses.remove(&tool_result.tool_use_id); + thread + .last_user_message() + .content + .push(MessageContent::ToolResult(tool_result)); + }) + .ok(); + } + + completion_intent = CompletionIntent::ToolResults; + } + + Ok(()) + } + .await; + + if let Err(error) = turn_result { + log::error!("Turn execution failed: {:?}", error); + event_stream.send_error(error); + } else { + log::info!("Turn execution completed successfully"); + } + })); + events_rx + } + + pub fn build_system_message(&self) -> AgentMessage { + log::debug!("Building system message"); + let prompt = SystemPromptTemplate { + project: &self.project_context.borrow(), + available_tools: self.tools.keys().cloned().collect(), + } + .render(&self.templates) + .context("failed to build system prompt") + .expect("Invalid template"); + log::debug!("System message built"); + AgentMessage { + role: Role::System, + content: vec![prompt.into()], + } + } + + /// A helper method that's called on every streamed completion event. + /// Returns an optional tool result task, which the main agentic loop in + /// send will send back to the model when it resolves. + fn handle_streamed_completion_event( + &mut self, + event: LanguageModelCompletionEvent, + event_stream: &AgentResponseEventStream, + cx: &mut Context, + ) -> Option> { + log::trace!("Handling streamed completion event: {:?}", event); + use LanguageModelCompletionEvent::*; + + match event { + StartMessage { .. 
} => { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + Text(new_text) => self.handle_text_event(new_text, event_stream, cx), + Thinking { text, signature } => { + self.handle_thinking_event(text, signature, event_stream, cx) + } + RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx), + ToolUse(tool_use) => { + return self.handle_tool_use_event(tool_use, event_stream, cx); + } + ToolUseJsonParseError { + id, + tool_name, + raw_input, + json_parse_error, + } => { + return Some(Task::ready(self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + ))); + } + UsageUpdate(_) | StatusUpdate(_) => {} + Stop(_) => unreachable!(), + } + + None + } + + fn handle_text_event( + &mut self, + new_text: String, + events_stream: &AgentResponseEventStream, + cx: &mut Context, + ) { + events_stream.send_text(&new_text); + + let last_message = self.last_assistant_message(); + if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + text.push_str(&new_text); + } else { + last_message.content.push(MessageContent::Text(new_text)); + } + + cx.notify(); + } + + fn handle_thinking_event( + &mut self, + new_text: String, + new_signature: Option, + event_stream: &AgentResponseEventStream, + cx: &mut Context, + ) { + event_stream.send_thinking(&new_text); + + let last_message = self.last_assistant_message(); + if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut() + { + text.push_str(&new_text); + *signature = new_signature.or(signature.take()); + } else { + last_message.content.push(MessageContent::Thinking { + text: new_text, + signature: new_signature, + }); + } + + cx.notify(); + } + + fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { + let last_message = self.last_assistant_message(); + last_message + .content + .push(MessageContent::RedactedThinking(data)); + cx.notify(); + } + + fn handle_tool_use_event( + &mut self, + tool_use: LanguageModelToolUse, + event_stream: &AgentResponseEventStream, + cx: &mut Context, + ) -> Option> { + cx.notify(); + + let tool = self.tools.get(tool_use.name.as_ref()).cloned(); + + self.pending_tool_uses + .insert(tool_use.id.clone(), tool_use.clone()); + let last_message = self.last_assistant_message(); + + // Ensure the last message ends in the current tool use + let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { + if let MessageContent::ToolUse(last_tool_use) = content { + if last_tool_use.id == tool_use.id { + *last_tool_use = tool_use.clone(); + false + } else { + true + } + } else { + true + } + }); + + let mut title = SharedString::from(&tool_use.name); + let mut kind = acp::ToolKind::Other; + if let Some(tool) = tool.as_ref() { + title = tool.initial_title(tool_use.input.clone()); + kind = tool.kind(); + } + + if push_new_tool_use { + event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + last_message + .content + .push(MessageContent::ToolUse(tool_use.clone())); + } else { + event_stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + title: Some(title.into()), + kind: Some(kind), + raw_input: Some(tool_use.input.clone()), + ..Default::default() + }, + ); + } + + if !tool_use.is_input_complete { + return None; + } + + let Some(tool) = tool else { + let content = format!("No tool named {} exists", tool_use.name); + return Some(Task::ready(LanguageModelToolResult { + content: 
LanguageModelToolResultContent::Text(Arc::from(content)), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })); + }; + + let tool_event_stream = + ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone()); + tool_event_stream.update_fields(acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress), + ..Default::default() + }); + let supports_images = self.selected_model.supports_images(); + let tool_result = tool.run(tool_use.input, tool_event_stream, cx); + Some(cx.foreground_executor().spawn(async move { + let tool_result = tool_result.await.and_then(|output| { + if let LanguageModelToolResultContent::Image(_) = &output.llm_output { + if !supports_images { + return Err(anyhow!( + "Attempted to read an image, but this model doesn't support it.", + )); + } + } + Ok(output) + }); + + match tool_result { + Ok(output) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: false, + content: output.llm_output, + output: Some(output.raw_output), + }, + Err(error) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), + output: None, + }, + } + })) + } + + fn handle_tool_use_json_parse_error_event( + &mut self, + tool_use_id: LanguageModelToolUseId, + tool_name: Arc, + raw_input: Arc, + json_parse_error: String, + ) -> LanguageModelToolResult { + let tool_output = format!("Error parsing input JSON: {json_parse_error}"); + LanguageModelToolResult { + tool_use_id, + tool_name, + is_error: true, + content: LanguageModelToolResultContent::Text(tool_output.into()), + output: Some(serde_json::Value::String(raw_input.to_string())), + } + } + + /// Guarantees the last message is from the assistant and returns a mutable reference. + fn last_assistant_message(&mut self) -> &mut AgentMessage { + if self + .messages + .last() + .map_or(true, |m| m.role != Role::Assistant) + { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + self.messages.last_mut().unwrap() + } + + /// Guarantees the last message is from the user and returns a mutable reference. 
+ fn last_user_message(&mut self) -> &mut AgentMessage { + if self.messages.last().map_or(true, |m| m.role != Role::User) { + self.messages.push(AgentMessage { + role: Role::User, + content: Vec::new(), + }); + } + self.messages.last_mut().unwrap() + } + + pub(crate) fn build_completion_request( + &self, + completion_intent: CompletionIntent, + cx: &mut App, + ) -> LanguageModelRequest { + log::debug!("Building completion request"); + log::debug!("Completion intent: {:?}", completion_intent); + log::debug!("Completion mode: {:?}", self.completion_mode); + + let messages = self.build_request_messages(); + log::info!("Request will include {} messages", messages.len()); + + let tools: Vec = self + .tools + .values() + .filter_map(|tool| { + let tool_name = tool.name().to_string(); + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name, + description: tool.description(cx).to_string(), + input_schema: tool + .input_schema(self.selected_model.tool_input_format()) + .log_err()?, + }) + }) + .collect(); + + log::info!("Request includes {} tools", tools.len()); + + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: Some(completion_intent), + mode: Some(self.completion_mode), + messages, + tools, + tool_choice: None, + stop: Vec::new(), + temperature: None, + thinking_allowed: true, + }; + + log::debug!("Completion request built successfully"); + request + } + + fn build_request_messages(&self) -> Vec { + log::trace!( + "Building request messages from {} thread messages", + self.messages.len() + ); + + let messages = Some(self.build_system_message()) + .iter() + .chain(self.messages.iter()) + .map(|message| { + log::trace!( + " - {} message with {} content items", + match message.role { + Role::System => "System", + Role::User => "User", + Role::Assistant => "Assistant", + }, + message.content.len() + ); + LanguageModelRequestMessage { + role: message.role, + content: message.content.clone(), + cache: false, + } + }) + .collect(); + messages + } + + pub fn to_markdown(&self) -> String { + let mut markdown = String::new(); + for message in &self.messages { + markdown.push_str(&message.to_markdown()); + } + markdown + } +} + +pub trait AgentTool +where + Self: 'static + Sized, +{ + type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; + type Output: for<'de> Deserialize<'de> + Serialize + Into; + + fn name(&self) -> SharedString; + + fn description(&self, _cx: &mut App) -> SharedString { + let schema = schemars::schema_for!(Self::Input); + SharedString::new( + schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap_or_default(), + ) + } + + fn kind(&self) -> acp::ToolKind; + + /// The initial tool title to display. Can be updated during the tool run. + fn initial_title(&self, input: Result) -> SharedString; + + /// Returns the JSON schema that describes the tool's input. + fn input_schema(&self) -> Schema { + schemars::schema_for!(Self::Input) + } + + /// Runs the tool with the provided input. 
+ fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task>; + + fn erase(self) -> Arc { + Arc::new(Erased(Arc::new(self))) + } +} + +pub struct Erased(T); + +pub struct AgentToolOutput { + llm_output: LanguageModelToolResultContent, + raw_output: serde_json::Value, +} + +pub trait AnyAgentTool { + fn name(&self) -> SharedString; + fn description(&self, cx: &mut App) -> SharedString; + fn kind(&self) -> acp::ToolKind; + fn initial_title(&self, input: serde_json::Value) -> SharedString; + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn run( + self: Arc, + input: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task>; +} + +impl AnyAgentTool for Erased> +where + T: AgentTool, +{ + fn name(&self) -> SharedString { + self.0.name() + } + + fn description(&self, cx: &mut App) -> SharedString { + self.0.description(cx) + } + + fn kind(&self) -> agent_client_protocol::ToolKind { + self.0.kind() + } + + fn initial_title(&self, input: serde_json::Value) -> SharedString { + let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input); + self.0.initial_title(parsed_input) + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + let mut json = serde_json::to_value(self.0.input_schema())?; + adapt_schema_to_format(&mut json, format)?; + Ok(json) + } + + fn run( + self: Arc, + input: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.spawn(async move |cx| { + let input = serde_json::from_value(input)?; + let output = cx + .update(|cx| self.0.clone().run(input, event_stream, cx))? + .await?; + let raw_output = serde_json::to_value(&output)?; + Ok(AgentToolOutput { + llm_output: output.into(), + raw_output, + }) + }) + } +} + +#[derive(Clone)] +struct AgentResponseEventStream( + mpsc::UnboundedSender>, +); + +impl AgentResponseEventStream { + fn send_text(&self, text: &str) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) + .ok(); + } + + fn send_thinking(&self, text: &str) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) + .ok(); + } + + fn authorize_tool_call( + &self, + id: &LanguageModelToolUseId, + title: String, + kind: acp::ToolKind, + input: serde_json::Value, + ) -> impl use<> + Future> { + let (response_tx, response_rx) = oneshot::channel(); + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + ToolCallAuthorization { + tool_call: Self::initial_tool_call(id, title, kind, input), + options: vec![ + acp::PermissionOption { + id: acp::PermissionOptionId("always_allow".into()), + name: "Always Allow".into(), + kind: acp::PermissionOptionKind::AllowAlways, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("allow".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("deny".into()), + name: "Deny".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + response: response_tx, + }, + ))) + .ok(); + async move { + match response_rx.await?.0.as_ref() { + "allow" | "always_allow" => Ok(()), + _ => Err(anyhow!("Permission to run tool denied by user")), + } + } + } + + fn send_tool_call( + &self, + id: &LanguageModelToolUseId, + title: SharedString, + kind: acp::ToolKind, + input: serde_json::Value, + ) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( + id, + title.to_string(), + kind, + input, + 
)))) + .ok(); + } + + fn initial_tool_call( + id: &LanguageModelToolUseId, + title: String, + kind: acp::ToolKind, + input: serde_json::Value, + ) -> acp::ToolCall { + acp::ToolCall { + id: acp::ToolCallId(id.to_string().into()), + title, + kind, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(input), + raw_output: None, + } + } + + fn update_tool_call_fields( + &self, + tool_use_id: &LanguageModelToolUseId, + fields: acp::ToolCallUpdateFields, + ) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use_id.to_string().into()), + fields, + } + .into(), + ))) + .ok(); + } + + fn update_tool_call_diff(&self, tool_use_id: &LanguageModelToolUseId, diff: Entity) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp_thread::ToolCallUpdateDiff { + id: acp::ToolCallId(tool_use_id.to_string().into()), + diff, + } + .into(), + ))) + .ok(); + } + + fn send_stop(&self, reason: StopReason) { + match reason { + StopReason::EndTurn => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) + .ok(); + } + StopReason::MaxTokens => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) + .ok(); + } + StopReason::Refusal => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) + .ok(); + } + StopReason::ToolUse => {} + } + } + + fn send_error(&self, error: LanguageModelCompletionError) { + self.0.unbounded_send(Err(error)).ok(); + } +} + +#[derive(Clone)] +pub struct ToolCallEventStream { + tool_use_id: LanguageModelToolUseId, + kind: acp::ToolKind, + input: serde_json::Value, + stream: AgentResponseEventStream, +} + +impl ToolCallEventStream { + #[cfg(test)] + pub fn test() -> (Self, ToolCallEventStreamReceiver) { + let (events_tx, events_rx) = + mpsc::unbounded::>(); + + let stream = ToolCallEventStream::new( + &LanguageModelToolUse { + id: "test_id".into(), + name: "test_tool".into(), + raw_input: String::new(), + input: serde_json::Value::Null, + is_input_complete: true, + }, + acp::ToolKind::Other, + AgentResponseEventStream(events_tx), + ); + + (stream, ToolCallEventStreamReceiver(events_rx)) + } + + fn new( + tool_use: &LanguageModelToolUse, + kind: acp::ToolKind, + stream: AgentResponseEventStream, + ) -> Self { + Self { + tool_use_id: tool_use.id.clone(), + kind, + input: tool_use.input.clone(), + stream, + } + } + + pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) { + self.stream + .update_tool_call_fields(&self.tool_use_id, fields); + } + + pub fn update_diff(&self, diff: Entity) { + self.stream.update_tool_call_diff(&self.tool_use_id, diff); + } + + pub fn authorize(&self, title: String) -> impl use<> + Future> { + self.stream.authorize_tool_call( + &self.tool_use_id, + title, + self.kind.clone(), + self.input.clone(), + ) + } +} + +#[cfg(test)] +pub struct ToolCallEventStreamReceiver( + mpsc::UnboundedReceiver>, +); + +#[cfg(test)] +impl ToolCallEventStreamReceiver { + pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization { + let event = self.0.next().await; + if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { + auth + } else { + panic!("Expected ToolCallAuthorization but got: {:?}", event); + } + } +} + +#[cfg(test)] +impl std::ops::Deref for ToolCallEventStreamReceiver { + type Target = mpsc::UnboundedReceiver>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(test)] +impl std::ops::DerefMut for 
ToolCallEventStreamReceiver { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs new file mode 100644 index 0000000000..5fe13db854 --- /dev/null +++ b/crates/agent2/src/tools.rs @@ -0,0 +1,9 @@ +mod edit_file_tool; +mod find_path_tool; +mod read_file_tool; +mod thinking_tool; + +pub use edit_file_tool::*; +pub use find_path_tool::*; +pub use read_file_tool::*; +pub use thinking_tool::*; diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs new file mode 100644 index 0000000000..0858bb501c --- /dev/null +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -0,0 +1,1449 @@ +use crate::{AgentTool, Thread, ToolCallEventStream}; +use acp_thread::Diff; +use agent_client_protocol as acp; +use anyhow::{anyhow, Context as _, Result}; +use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; +use cloud_llm_client::CompletionIntent; +use collections::HashSet; +use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use indoc::formatdoc; +use language::language_settings::{self, FormatOnSave}; +use language_model::LanguageModelToolResultContent; +use paths; +use project::lsp_store::{FormatTrigger, LspFormatTarget}; +use project::{Project, ProjectPath}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::Settings; +use smol::stream::StreamExt as _; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use ui::SharedString; +use util::ResultExt; + +const DEFAULT_UI_TEXT: &str = "Editing file"; + +/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. +/// +/// Before using this tool: +/// +/// 1. Use the `read_file` tool to understand the file's contents and context +/// +/// 2. Verify the directory path is correct (only applicable when creating new files): +/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct EditFileToolInput { + /// A one-line, user-friendly markdown description of the edit. This will be + /// shown in the UI and also passed to another model to perform the edit. + /// + /// Be terse, but also descriptive in what you want to achieve with this + /// edit. Avoid generic instructions. + /// + /// NEVER mention the file path in this description. + /// + /// Fix API endpoint URLs + /// Update copyright year in `page_footer` + /// + /// Make sure to include this field before all the others in the input object + /// so that we can display it immediately. + pub display_description: String, + + /// The full path of the file to create or modify in the project. + /// + /// WARNING: When specifying which file path need changing, you MUST + /// start each path with one of the project's root directories. + /// + /// The following examples assume we have two root directories in the project: + /// - /a/b/backend + /// - /c/d/frontend + /// + /// + /// `backend/src/main.rs` + /// + /// Notice how the file path starts with `backend`. Without that, the path + /// would be ambiguous and the call would fail! + /// + /// + /// + /// `frontend/db.js` + /// + pub path: PathBuf, + + /// The mode of operation on the file. Possible values: + /// - 'edit': Make granular edits to an existing file. + /// - 'create': Create a new file if it doesn't exist. 
+ /// - 'overwrite': Replace the entire contents of an existing file. + /// + /// When a file already exists or you just created it, prefer editing + /// it as opposed to recreating it from scratch. + pub mode: EditFileMode, +} + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +struct EditFileToolPartialInput { + #[serde(default)] + path: String, + #[serde(default)] + display_description: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum EditFileMode { + Edit, + Create, + Overwrite, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct EditFileToolOutput { + input_path: PathBuf, + project_path: PathBuf, + new_text: String, + old_text: Arc, + diff: String, + edit_agent_output: EditAgentOutput, +} + +impl From for LanguageModelToolResultContent { + fn from(output: EditFileToolOutput) -> Self { + if output.diff.is_empty() { + "No edits were made.".into() + } else { + format!( + "Edited {}:\n\n```diff\n{}\n```", + output.input_path.display(), + output.diff + ) + .into() + } + } +} + +pub struct EditFileTool { + thread: Entity, +} + +impl EditFileTool { + pub fn new(thread: Entity) -> Self { + Self { thread } + } + + fn authorize( + &self, + input: &EditFileToolInput, + event_stream: &ToolCallEventStream, + cx: &App, + ) -> Task> { + if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { + return Task::ready(Ok(())); + } + + // If any path component matches the local settings folder, then this could affect + // the editor in ways beyond the project source, so prompt. + let local_settings_folder = paths::local_settings_folder_relative_path(); + let path = Path::new(&input.path); + if path + .components() + .any(|component| component.as_os_str() == local_settings_folder.as_os_str()) + { + return cx.foreground_executor().spawn( + event_stream.authorize(format!("{} (local settings)", input.display_description)), + ); + } + + // It's also possible that the global config dir is configured to be inside the project, + // so check for that edge case too. + if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { + if canonical_path.starts_with(paths::config_dir()) { + return cx.foreground_executor().spawn( + event_stream + .authorize(format!("{} (global settings)", input.display_description)), + ); + } + } + + // Check if path is inside the global config directory + // First check if it's already inside project - if not, try to canonicalize + let thread = self.thread.read(cx); + let project_path = thread.project().read(cx).find_project_path(&input.path, cx); + + // If the path is inside the project, and it's not one of the above edge cases, + // then no confirmation is necessary. Otherwise, confirmation is necessary. 
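Taken together, the checks in `authorize` reduce to a small decision table; the final in-project check appears immediately below. As a rough sketch (simplified, and ignoring the gpui `Task` plumbing and the `ToolCallEventStream::authorize` round-trip the real method goes through), the policy is:

// Simplified restatement of EditFileTool::authorize's confirmation policy (illustrative only).
enum Confirmation {
    NotNeeded,
    Needed(&'static str), // suffix appended to the edit description in the prompt, e.g. " (local settings)"
}

fn confirmation_policy(
    always_allow_tool_actions: bool,
    touches_local_settings_folder: bool,
    canonicalizes_into_global_config_dir: bool,
    resolves_inside_project: bool,
) -> Confirmation {
    if always_allow_tool_actions {
        Confirmation::NotNeeded
    } else if touches_local_settings_folder {
        Confirmation::Needed(" (local settings)")
    } else if canonicalizes_into_global_config_dir {
        Confirmation::Needed(" (global settings)")
    } else if resolves_inside_project {
        Confirmation::NotNeeded
    } else {
        Confirmation::Needed("")
    }
}

This ordering matches the titles asserted in the tests further down: a `.zed/` path yields "test 1 (local settings)", while a path outside every worktree yields the plain "test 2" description.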
+ if project_path.is_some() { + Task::ready(Ok(())) + } else { + cx.foreground_executor() + .spawn(event_stream.authorize(input.display_description.clone())) + } + } +} + +impl AgentTool for EditFileTool { + type Input = EditFileToolInput; + type Output = EditFileToolOutput; + + fn name(&self) -> SharedString { + "edit_file".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Edit + } + + fn initial_title(&self, input: Result) -> SharedString { + match input { + Ok(input) => input.display_description.into(), + Err(raw_input) => { + if let Some(input) = + serde_json::from_value::(raw_input).ok() + { + let description = input.display_description.trim(); + if !description.is_empty() { + return description.to_string().into(); + } + + let path = input.path.trim().to_string(); + if !path.is_empty() { + return path.into(); + } + } + + DEFAULT_UI_TEXT.into() + } + } + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let project = self.thread.read(cx).project().clone(); + let project_path = match resolve_path(&input, project.clone(), cx) { + Ok(path) => path, + Err(err) => return Task::ready(Err(anyhow!(err))), + }; + + let request = self.thread.update(cx, |thread, cx| { + thread.build_completion_request(CompletionIntent::ToolResults, cx) + }); + let thread = self.thread.read(cx); + let model = thread.selected_model.clone(); + let action_log = thread.action_log().clone(); + + let authorize = self.authorize(&input, &event_stream, cx); + cx.spawn(async move |cx: &mut AsyncApp| { + authorize.await?; + + let edit_format = EditFormat::from_model(model.clone())?; + let edit_agent = EditAgent::new( + model, + project.clone(), + action_log.clone(), + // TODO: move edit agent to this crate so we can use our templates + assistant_tools::templates::Templates::new(), + edit_format, + ); + + let buffer = project + .update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) + })? 
+ .await?; + + let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?; + event_stream.update_diff(diff.clone()); + + let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let old_text = cx + .background_spawn({ + let old_snapshot = old_snapshot.clone(); + async move { Arc::new(old_snapshot.text()) } + }) + .await; + + + let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) { + edit_agent.edit( + buffer.clone(), + input.display_description.clone(), + &request, + cx, + ) + } else { + edit_agent.overwrite( + buffer.clone(), + input.display_description.clone(), + &request, + cx, + ) + }; + + let mut hallucinated_old_text = false; + let mut ambiguous_ranges = Vec::new(); + while let Some(event) = events.next().await { + match event { + EditAgentOutputEvent::Edited => {}, + EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true, + EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges, + EditAgentOutputEvent::ResolvingEditRange(range) => { + diff.update(cx, |card, cx| card.reveal_range(range, cx))?; + } + } + } + + // If format_on_save is enabled, format the buffer + let format_on_save_enabled = buffer + .read_with(cx, |buffer, cx| { + let settings = language_settings::language_settings( + buffer.language().map(|l| l.name()), + buffer.file(), + cx, + ); + settings.format_on_save != FormatOnSave::Off + }) + .unwrap_or(false); + + let edit_agent_output = output.await?; + + if format_on_save_enabled { + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + })?; + + let format_task = project.update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, // Don't push to history since the tool did it. + FormatTrigger::Save, + cx, + ) + })?; + format_task.await.log_err(); + } + + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? + .await?; + + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + })?; + + let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let (new_text, unified_diff) = cx + .background_spawn({ + let new_snapshot = new_snapshot.clone(); + let old_text = old_text.clone(); + async move { + let new_text = new_snapshot.text(); + let diff = language::unified_diff(&old_text, &new_text); + (new_text, diff) + } + }) + .await; + + diff.update(cx, |diff, cx| diff.finalize(cx)).ok(); + + let input_path = input.path.display(); + if unified_diff.is_empty() { + anyhow::ensure!( + !hallucinated_old_text, + formatdoc! {" + Some edits were produced but none of them could be applied. + Read the relevant sections of {input_path} again so that + I can perform the requested edits. + "} + ); + anyhow::ensure!( + ambiguous_ranges.is_empty(), + { + let line_numbers = ambiguous_ranges + .iter() + .map(|range| range.start.to_string()) + .collect::>() + .join(", "); + formatdoc! {" + matches more than one position in the file (lines: {line_numbers}). Read the + relevant sections of {input_path} again and extend so + that I can perform the requested edits. + "} + } + ); + } + + Ok(EditFileToolOutput { + input_path: input.path, + project_path: project_path.path.to_path_buf(), + new_text: new_text.clone(), + old_text, + diff: unified_diff, + edit_agent_output, + }) + }) + } +} + +/// Validate that the file path is valid, meaning: +/// +/// - For `edit` and `overwrite`, the path must point to an existing file. 
+/// - For `create`, the file must not already exist, but it's parent dir must exist. +fn resolve_path( + input: &EditFileToolInput, + project: Entity, + cx: &mut App, +) -> Result { + let project = project.read(cx); + + match input.mode { + EditFileMode::Edit | EditFileMode::Overwrite => { + let path = project + .find_project_path(&input.path, cx) + .context("Can't edit file: path not found")?; + + let entry = project + .entry_for_path(&path, cx) + .context("Can't edit file: path not found")?; + + anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory"); + Ok(path) + } + + EditFileMode::Create => { + if let Some(path) = project.find_project_path(&input.path, cx) { + anyhow::ensure!( + project.entry_for_path(&path, cx).is_none(), + "Can't create file: file already exists" + ); + } + + let parent_path = input + .path + .parent() + .context("Can't create file: incorrect path")?; + + let parent_project_path = project.find_project_path(&parent_path, cx); + + let parent_entry = parent_project_path + .as_ref() + .and_then(|path| project.entry_for_path(&path, cx)) + .context("Can't create file: parent directory doesn't exist")?; + + anyhow::ensure!( + parent_entry.is_dir(), + "Can't create file: parent is not a directory" + ); + + let file_name = input + .path + .file_name() + .context("Can't create file: invalid filename")?; + + let new_file_path = parent_project_path.map(|parent| ProjectPath { + path: Arc::from(parent.path.join(file_name)), + ..parent + }); + + new_file_path.context("Can't create file") + } + } +} + +#[cfg(test)] +mod tests { + use crate::Templates; + + use super::*; + use assistant_tool::ActionLog; + use client::TelemetrySettings; + use fs::Fs; + use gpui::{TestAppContext, UpdateGlobal}; + use language_model::fake_provider::FakeLanguageModel; + use serde_json::json; + use settings::SettingsStore; + use std::rc::Rc; + use util::path; + + #[gpui::test] + async fn test_edit_nonexistent_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = + cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model)); + let result = cx + .update(|cx| { + let input = EditFileToolInput { + display_description: "Some edit".into(), + path: "root/nonexistent_file.txt".into(), + mode: EditFileMode::Edit, + }; + Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert_eq!( + result.unwrap_err().to_string(), + "Can't edit file: path not found" + ); + } + + #[gpui::test] + async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) { + let mode = &EditFileMode::Create; + + let result = test_resolve_path(mode, "root/new.txt", cx); + assert_resolved_path_eq(result.await, "new.txt"); + + let result = test_resolve_path(mode, "new.txt", cx); + assert_resolved_path_eq(result.await, "new.txt"); + + let result = test_resolve_path(mode, "dir/new.txt", cx); + assert_resolved_path_eq(result.await, "dir/new.txt"); + + let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't create file: file already exists" + ); + + let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), 
+ "Can't create file: parent directory doesn't exist" + ); + } + + #[gpui::test] + async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) { + let mode = &EditFileMode::Edit; + + let path_with_root = "root/dir/subdir/existing.txt"; + let path_without_root = "dir/subdir/existing.txt"; + let result = test_resolve_path(mode, path_with_root, cx); + assert_resolved_path_eq(result.await, path_without_root); + + let result = test_resolve_path(mode, path_without_root, cx); + assert_resolved_path_eq(result.await, path_without_root); + + let result = test_resolve_path(mode, "root/nonexistent.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't edit file: path not found" + ); + + let result = test_resolve_path(mode, "root/dir", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't edit file: path is a directory" + ); + } + + async fn test_resolve_path( + mode: &EditFileMode, + path: &str, + cx: &mut TestAppContext, + ) -> anyhow::Result { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "dir": { + "subdir": { + "existing.txt": "hello" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let input = EditFileToolInput { + display_description: "Some edit".into(), + path: path.into(), + mode: mode.clone(), + }; + + let result = cx.update(|cx| resolve_path(&input, project, cx)); + result + } + + fn assert_resolved_path_eq(path: anyhow::Result, expected: &str) { + let actual = path + .expect("Should return valid path") + .path + .to_str() + .unwrap() + .replace("\\", "/"); // Naive Windows paths normalization + assert_eq!(actual, expected); + } + + #[gpui::test] + async fn test_format_on_save(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + // Set up a Rust language with LSP formatting support + let rust_language = Arc::new(language::Language::new( + language::LanguageConfig { + name: "Rust".into(), + matcher: language::LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + // Register the language and fake LSP + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(rust_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Rust", + language::FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + document_formatting_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + // Create the file + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n"; + const FORMATTED_CONTENT: &str = + "This file was formatted by the fake formatter in the test.\n"; + + // Get the fake language server and set up formatting handler + let fake_language_server 
= fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::({ + |_, _| async move { + Ok(Some(vec![lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)), + new_text: FORMATTED_CONTENT.to_string(), + }])) + } + }); + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + + // First, test with format_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::On); + settings.defaults.formatter = + Some(language::language_settings::SelectedFormatter::Auto); + }, + ); + }); + }); + + // Have the model stream unformatted content + let edit_result = { + let edit_task = cx.update(|cx| { + let input = EditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }; + Arc::new(EditFileTool { + thread: thread.clone(), + }) + .run(input, ToolCallEventStream::test().0, cx) + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify it was formatted automatically + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + new_content.replace("\r\n", "\n"), + FORMATTED_CONTENT, + "Code should be formatted when format_on_save is enabled" + ); + + let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count()); + + assert_eq!( + stale_buffer_count, 0, + "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \ + This causes the agent to think the file was modified externally when it was just formatted.", + stale_buffer_count + ); + + // Next, test with format_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::Off); + }, + ); + }); + }); + + // Stream unformatted edits again + let edit_result = { + let edit_task = cx.update(|cx| { + let input = EditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }; + Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. 
formatting) to complete + cx.executor().run_until_parked(); + + // Verify the file was not formatted + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + new_content.replace("\r\n", "\n"), + UNFORMATTED_CONTENT, + "Code should not be formatted when format_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + // Create a simple file with trailing whitespace + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + + // First, test with remove_trailing_whitespace_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.remove_trailing_whitespace_on_save = Some(true); + }, + ); + }); + }); + + const CONTENT_WITH_TRAILING_WHITESPACE: &str = + "fn main() { \n println!(\"Hello!\"); \n}\n"; + + // Have the model stream content that contains trailing whitespace + let edit_result = { + let edit_task = cx.update(|cx| { + let input = EditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }; + Arc::new(EditFileTool { + thread: thread.clone(), + }) + .run(input, ToolCallEventStream::test().0, cx) + }); + + // Stream the content with trailing whitespace + cx.executor().run_until_parked(); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. 
formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify trailing whitespace was removed automatically + assert_eq!( + // Ignore carriage returns on Windows + fs.load(path!("/root/src/main.rs").as_ref()) + .await + .unwrap() + .replace("\r\n", "\n"), + "fn main() {\n println!(\"Hello!\");\n}\n", + "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled" + ); + + // Next, test with remove_trailing_whitespace_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.remove_trailing_whitespace_on_save = Some(false); + }, + ); + }); + }); + + // Stream edits again with trailing whitespace + let edit_result = { + let edit_task = cx.update(|cx| { + let input = EditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }; + Arc::new(EditFileTool { + thread: thread.clone(), + }) + .run(input, ToolCallEventStream::test().0, cx) + }); + + // Stream the content with trailing whitespace + cx.executor().run_until_parked(); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Verify the file still has trailing whitespace + // Read the file again - it should still have trailing whitespace + let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + final_content.replace("\r\n", "\n"), + CONTENT_WITH_TRAILING_WHITESPACE, + "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_authorize(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + fs.insert_tree("/root", json!({})).await; + + // Test 1: Path with .zed component should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 1".into(), + path: ".zed/settings.json".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + let event = stream_rx.expect_tool_authorization().await; + assert_eq!(event.tool_call.title, "test 1 (local settings)"); + + // Test 2: Path outside project should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 2".into(), + path: "/etc/hosts".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + let event = stream_rx.expect_tool_authorization().await; + assert_eq!(event.tool_call.title, "test 2"); + + // Test 3: Relative path without .zed should not require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + 
tool.authorize( + &EditFileToolInput { + display_description: "test 3".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + + // Test 4: Path with .zed in the middle should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 4".into(), + path: "root/.zed/tasks.json".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + let event = stream_rx.expect_tool_authorization().await; + assert_eq!(event.tool_call.title, "test 4 (local settings)"); + + // Test 5: When always_allow_tool_actions is enabled, no confirmation needed + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = true; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 5.1".into(), + path: ".zed/settings.json".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 5.2".into(), + path: "/etc/hosts".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + } + + #[gpui::test] + async fn test_authorize_global_config(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test global config paths - these should require confirmation if they exist and are outside the project + let test_cases = vec![ + ( + "/etc/hosts", + true, + "System file should require confirmation", + ), + ( + "/usr/local/bin/script", + true, + "System bin file should require confirmation", + ), + ( + "project/normal_file.rs", + false, + "Normal project file should not require confirmation", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: path.into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + if should_confirm { + stream_rx.expect_tool_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + + // Create multiple worktree directories + fs.insert_tree( + "/workspace/frontend", + json!({ + "src": { + "main.js": 
"console.log('frontend');" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/backend", + json!({ + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/shared", + json!({ + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + + // Create project with multiple worktrees + let project = Project::test( + fs.clone(), + [ + path!("/workspace/frontend").as_ref(), + path!("/workspace/backend").as_ref(), + path!("/workspace/shared").as_ref(), + ], + cx, + ) + .await; + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project.clone(), + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test files in different worktrees + let test_cases = vec![ + ("frontend/src/main.js", false, "File in first worktree"), + ("backend/src/main.rs", false, "File in second worktree"), + ( + "shared/.zed/settings.json", + true, + ".zed file in third worktree", + ), + ("/etc/hosts", true, "Absolute path outside all worktrees"), + ( + "../outside/file.txt", + true, + "Relative path outside worktrees", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: path.into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + if should_confirm { + stream_rx.expect_tool_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + + #[gpui::test] + async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".zed": { + "settings.json": "{}" + }, + "src": { + ".zed": { + "local.json": "{}" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project.clone(), + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test edge cases + let test_cases = vec![ + // Empty path - find_project_path returns Some for empty paths + ("", false, "Empty path is treated as project root"), + // Root directory + ("/", true, "Root directory should be outside project"), + // Parent directory references - find_project_path resolves these + ( + "project/../other", + false, + "Path with .. is resolved by find_project_path", + ), + ( + "project/./src/file.rs", + false, + "Path with . 
should work normally", + ), + // Windows-style paths (if on Windows) + #[cfg(target_os = "windows")] + ("C:\\Windows\\System32\\hosts", true, "Windows system path"), + #[cfg(target_os = "windows")] + ("project\\src\\main.rs", false, "Windows-style project path"), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: path.into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + if should_confirm { + stream_rx.expect_tool_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "existing.txt": "content", + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project.clone(), + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test different EditFileMode values + let modes = vec![ + EditFileMode::Edit, + EditFileMode::Create, + EditFileMode::Overwrite, + ]; + + for mode in modes { + // Test .zed path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit settings".into(), + path: "project/.zed/settings.json".into(), + mode: mode.clone(), + }, + &stream_tx, + cx, + ) + }); + + stream_rx.expect_tool_authorization().await; + + // Test outside path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: "/outside/file.txt".into(), + mode: mode.clone(), + }, + &stream_tx, + cx, + ) + }); + + stream_rx.expect_tool_authorization().await; + + // Test normal path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: "project/normal.txt".into(), + mode: mode.clone(), + }, + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + } + } + + #[gpui::test] + async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project.clone(), + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + assert_eq!( + tool.initial_title(Err(json!({ + "path": "src/main.rs", + "display_description": "", + "old_string": "old code", + "new_string": "new code" + }))), + 
"src/main.rs" + ); + assert_eq!( + tool.initial_title(Err(json!({ + "path": "", + "display_description": "Fix error handling", + "old_string": "old code", + "new_string": "new code" + }))), + "Fix error handling" + ); + assert_eq!( + tool.initial_title(Err(json!({ + "path": "src/main.rs", + "display_description": "Fix error handling", + "old_string": "old code", + "new_string": "new code" + }))), + "Fix error handling" + ); + assert_eq!( + tool.initial_title(Err(json!({ + "path": "", + "display_description": "", + "old_string": "old code", + "new_string": "new code" + }))), + DEFAULT_UI_TEXT + ); + assert_eq!( + tool.initial_title(Err(serde_json::Value::Null)), + DEFAULT_UI_TEXT + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + TelemetrySettings::register(cx); + agent_settings::AgentSettings::register(cx); + Project::init_settings(cx); + }); + } +} diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent2/src/tools/find_path_tool.rs new file mode 100644 index 0000000000..f4589e5600 --- /dev/null +++ b/crates/agent2/src/tools/find_path_tool.rs @@ -0,0 +1,248 @@ +use crate::{AgentTool, ToolCallEventStream}; +use agent_client_protocol as acp; +use anyhow::{anyhow, Result}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use language_model::LanguageModelToolResultContent; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::fmt::Write; +use std::{cmp, path::PathBuf, sync::Arc}; +use util::paths::PathMatcher; + +/// Fast file path pattern matching tool that works with any codebase size +/// +/// - Supports glob patterns like "**/*.js" or "src/**/*.ts" +/// - Returns matching file paths sorted alphabetically +/// - Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths. +/// - Use this tool when you need to find files by name patterns +/// - Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FindPathToolInput { + /// The glob to match against every path in the project. + /// + /// + /// If the project has the following root directories: + /// + /// - directory1/a/something.txt + /// - directory2/a/things.txt + /// - directory3/a/other.txt + /// + /// You can get back the first two paths by providing a glob of "*thing*.txt" + /// + pub glob: String, + + /// Optional starting position for paginated results (0-based). + /// When not provided, starts from the beginning. 
+ #[serde(default)] + pub offset: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FindPathToolOutput { + offset: usize, + current_matches_page: Vec, + all_matches_len: usize, +} + +impl From for LanguageModelToolResultContent { + fn from(output: FindPathToolOutput) -> Self { + if output.current_matches_page.is_empty() { + "No matches found".into() + } else { + let mut llm_output = format!("Found {} total matches.", output.all_matches_len); + if output.all_matches_len > RESULTS_PER_PAGE { + write!( + &mut llm_output, + "\nShowing results {}-{} (provide 'offset' parameter for more results):", + output.offset + 1, + output.offset + output.current_matches_page.len() + ) + .unwrap(); + } + + for mat in output.current_matches_page { + write!(&mut llm_output, "\n{}", mat.display()).unwrap(); + } + + llm_output.into() + } + } +} + +const RESULTS_PER_PAGE: usize = 50; + +pub struct FindPathTool { + project: Entity, +} + +impl FindPathTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for FindPathTool { + type Input = FindPathToolInput; + type Output = FindPathToolOutput; + + fn name(&self) -> SharedString { + "find_path".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Search + } + + fn initial_title(&self, input: Result) -> SharedString { + let mut title = "Find paths".to_string(); + if let Ok(input) = input { + title.push_str(&format!(" matching “`{}`”", input.glob)); + } + title.into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); + + cx.background_spawn(async move { + let matches = search_paths_task.await?; + let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len()) + ..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())]; + + event_stream.update_fields(acp::ToolCallUpdateFields { + title: Some(if paginated_matches.len() == 0 { + "No matches".into() + } else if paginated_matches.len() == 1 { + "1 match".into() + } else { + format!("{} matches", paginated_matches.len()) + }), + content: Some( + paginated_matches + .iter() + .map(|path| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: format!("file://{}", path.display()), + name: path.to_string_lossy().into(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + }) + .collect(), + ), + raw_output: Some(serde_json::json!({ + "paths": &matches, + })), + ..Default::default() + }); + + Ok(FindPathToolOutput { + offset: input.offset, + current_matches_page: paginated_matches.to_vec(), + all_matches_len: matches.len(), + }) + }) + } +} + +fn search_paths(glob: &str, project: Entity, cx: &mut App) -> Task>> { + let path_matcher = match PathMatcher::new([ + // Sometimes models try to search for "". In this case, return all paths in the project. 
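// (Illustrative note: with this substitution a request like { "glob": "" } behaves like
//  { "glob": "*" } -- the expectation is every project path, paginated as usual, rather than
//  an error for a degenerate query.)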
+ if glob.is_empty() { "*" } else { glob }, + ]) { + Ok(matcher) => matcher, + Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))), + }; + let snapshots: Vec<_> = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect(); + + cx.background_spawn(async move { + Ok(snapshots + .iter() + .flat_map(|snapshot| { + let root_name = PathBuf::from(snapshot.root_name()); + snapshot + .entries(false, 0) + .map(move |entry| root_name.join(&entry.path)) + .filter(|path| path_matcher.is_match(&path)) + }) + .collect()) + }) +} + +#[cfg(test)] +mod test { + use super::*; + use gpui::TestAppContext; + use project::{FakeFs, Project}; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_find_path_tool(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + serde_json::json!({ + "apple": { + "banana": { + "carrot": "1", + }, + "bandana": { + "carbonara": "2", + }, + "endive": "3" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let matches = cx + .update(|cx| search_paths("root/**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + + let matches = cx + .update(|cx| search_paths("**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } +} diff --git a/crates/agent2/src/tools/read_file_tool.rs b/crates/agent2/src/tools/read_file_tool.rs new file mode 100644 index 0000000000..7bbe3ac4c1 --- /dev/null +++ b/crates/agent2/src/tools/read_file_tool.rs @@ -0,0 +1,951 @@ +use agent_client_protocol::{self as acp}; +use anyhow::{anyhow, Context, Result}; +use assistant_tool::{outline, ActionLog}; +use gpui::{Entity, Task}; +use indoc::formatdoc; +use language::{Anchor, Point}; +use language_model::{LanguageModelImage, LanguageModelToolResultContent}; +use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::Settings; +use std::sync::Arc; +use ui::{App, SharedString}; + +use crate::{AgentTool, ToolCallEventStream}; + +/// Reads the content of the given file in the project. +/// +/// - Never attempt to read a path that hasn't been previously mentioned. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct ReadFileToolInput { + /// The relative path of the file to read. + /// + /// This path should never be absolute, and the first component + /// of the path should always be a root directory in a project. + /// + /// + /// If the project has the following root directories: + /// + /// - /a/b/directory1 + /// - /c/d/directory2 + /// + /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`. + /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`. 
+ /// + pub path: String, + + /// Optional line number to start reading on (1-based index) + #[serde(default)] + pub start_line: Option, + + /// Optional line number to end reading on (1-based index, inclusive) + #[serde(default)] + pub end_line: Option, +} + +pub struct ReadFileTool { + project: Entity, + action_log: Entity, +} + +impl ReadFileTool { + pub fn new(project: Entity, action_log: Entity) -> Self { + Self { + project, + action_log, + } + } +} + +impl AgentTool for ReadFileTool { + type Input = ReadFileToolInput; + type Output = LanguageModelToolResultContent; + + fn name(&self) -> SharedString { + "read_file".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Read + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Ok(input) = input { + let path = &input.path; + match (input.start_line, input.end_line) { + (Some(start), Some(end)) => { + format!( + "[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))", + path, start, end, path, start, end + ) + } + (Some(start), None) => { + format!( + "[Read file `{}` (from line {})](@selection:{}:({}-{}))", + path, start, path, start, start + ) + } + _ => format!("[Read file `{}`](@file:{})", path, path), + } + .into() + } else { + "Read file".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { + return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))); + }; + + // Error out if this path is either excluded or private in global settings + let global_settings = WorktreeSettings::get_global(cx); + if global_settings.is_path_excluded(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot read file because its path matches the global `file_scan_exclusions` setting: {}", + &input.path + ))); + } + + if global_settings.is_path_private(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot read file because its path matches the global `private_files` setting: {}", + &input.path + ))); + } + + // Error out if this path is either excluded or private in worktree settings + let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); + if worktree_settings.is_path_excluded(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}", + &input.path + ))); + } + + if worktree_settings.is_path_private(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot read file because its path matches the worktree `private_files` setting: {}", + &input.path + ))); + } + + let file_path = input.path.clone(); + + if image_store::is_image_file(&self.project, &project_path, cx) { + return cx.spawn(async move |cx| { + let image_entity: Entity = cx + .update(|cx| { + self.project.update(cx, |project, cx| { + project.open_image(project_path.clone(), cx) + }) + })? + .await?; + + let image = + image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?; + + let language_model_image = cx + .update(|cx| LanguageModelImage::from_image(image, cx))? + .await + .context("processing image")?; + + Ok(language_model_image.into()) + }); + } + + let project = self.project.clone(); + let action_log = self.action_log.clone(); + + cx.spawn(async move |cx| { + let buffer = cx + .update(|cx| { + project.update(cx, |project, cx| project.open_buffer(project_path, cx)) + })? 
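// (Descriptive note: opening a buffer for a path that does not exist on disk can still succeed
//  in the project layer, so the explicit `disk_state().exists()` check below is what turns a
//  missing file into the "{file_path} not found" error the model sees -- as exercised by
//  `test_read_nonexistent_file` in the tests further down.)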
+ .await?; + if buffer.read_with(cx, |buffer, _| { + buffer + .file() + .as_ref() + .map_or(true, |file| !file.disk_state().exists()) + })? { + anyhow::bail!("{file_path} not found"); + } + + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: Anchor::MIN, + }), + cx, + ); + })?; + + // Check if specific line ranges are provided + if input.start_line.is_some() || input.end_line.is_some() { + let mut anchor = None; + let result = buffer.read_with(cx, |buffer, _cx| { + let text = buffer.text(); + // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0. + let start = input.start_line.unwrap_or(1).max(1); + let start_row = start - 1; + if start_row <= buffer.max_point().row { + let column = buffer.line_indent_for_row(start_row).raw_len(); + anchor = Some(buffer.anchor_before(Point::new(start_row, column))); + } + + let lines = text.split('\n').skip(start_row as usize); + if let Some(end) = input.end_line { + let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line + itertools::intersperse(lines.take(count as usize), "\n").collect::() + } else { + itertools::intersperse(lines, "\n").collect::() + } + })?; + + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx); + })?; + + if let Some(anchor) = anchor { + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: anchor, + }), + cx, + ); + })?; + } + + Ok(result.into()) + } else { + // No line ranges specified, so check file size to see if it's too big. + let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?; + + if file_size <= outline::AUTO_OUTLINE_SIZE { + // File is small enough, so return its contents. + let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?; + + action_log.update(cx, |log, cx| { + log.buffer_read(buffer, cx); + })?; + + Ok(result.into()) + } else { + // File is too big, so return the outline + // and a suggestion to read again with line numbers. + let outline = + outline::file_outline(project, file_path, action_log, None, cx).await?; + Ok(formatdoc! {" + This file was too big to read all at once. + + Here is an outline of its symbols: + + {outline} + + Using the line numbers in this outline, you can call this tool again + while specifying the start_line and end_line fields to see the + implementations of symbols in the outline. + + Alternatively, you can fall back to the `grep` tool (if available) + to search the file for specific content." 
+ } + .into()) + } + } + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; + use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_read_nonexistent_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/root"), json!({})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let (event_stream, _) = ToolCallEventStream::test(); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/nonexistent_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.run(input, event_stream, cx) + }) + .await; + assert_eq!( + result.unwrap_err().to_string(), + "root/nonexistent_file.txt not found" + ); + } + + #[gpui::test] + async fn test_read_small_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "small_file.txt": "This is a small file content" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/small_file.txt".into(), + start_line: None, + end_line: None, + }; + tool.run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert_eq!(result.unwrap(), "This is a small file content".into()); + } + + #[gpui::test] + async fn test_read_large_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::>().join("\n") + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(Arc::new(rust_lang())); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/large_file.rs".into(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await + .unwrap(); + let content = result.to_str().unwrap(); + + assert_eq!( + content.lines().skip(4).take(6).collect::>(), + vec![ + "struct Test0 [L1-4]", + " a [L2]", + " b [L3]", + "struct Test1 [L5-8]", + " a [L6]", + " b [L7]", + ] + ); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/large_file.rs".into(), + start_line: None, + end_line: None, + }; + tool.run(input, ToolCallEventStream::test().0, cx) + }) + .await + .unwrap(); + let content = result.to_str().unwrap(); + let expected_content = (0..1000) + .flat_map(|i| { + vec![ + format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4), + format!(" a [L{}]", i * 4 + 2), + format!(" b [L{}]", i * 4 + 3), + ] + }) + .collect::>(); + pretty_assertions::assert_eq!( + content + .lines() + .skip(4) + 
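// (Why `.skip(4)`: the large-file response produced above starts with a fixed preamble --
//  "This file was too big to read all at once.", a blank line, "Here is an outline of its
//  symbols:", and another blank line -- so the symbol outline itself begins on line 5.)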
.take(expected_content.len()) + .collect::>(), + expected_content + ); + } + + #[gpui::test] + async fn test_read_file_with_line_range(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(2), + end_line: Some(4), + }; + tool.run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into()); + } + + #[gpui::test] + async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + + // start_line of 0 should be treated as 1 + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(0), + end_line: Some(2), + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert_eq!(result.unwrap(), "Line 1\nLine 2".into()); + + // end_line of 0 should result in at least 1 line + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(1), + end_line: Some(0), + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert_eq!(result.unwrap(), "Line 1".into()); + + // when start_line > end_line, should still return at least 1 line + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(3), + end_line: Some(2), + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert_eq!(result.unwrap(), "Line 3".into()); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query( + r#" + (line_comment) @annotation + + (struct_item + "struct" @context + name: (_) @name) @item + (enum_item + "enum" @context + name: (_) @name) @item + (enum_variant + name: (_) @name) @item + (field_declaration + name: (_) @name) @item + (impl_item + "impl" @context + trait: (_)? @name + "for"? 
@context + type: (_) @name + body: (_ "{" (_)* "}")) @item + (function_item + "fn" @context + name: (_) @name) @item + (mod_item + "mod" @context + name: (_) @name) @item + "#, + ) + .unwrap() + } + + #[gpui::test] + async fn test_read_file_security(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + + fs.insert_tree( + path!("/"), + json!({ + "project_root": { + "allowed_file.txt": "This file is in the project", + ".mysecrets": "SECRET_KEY=abc123", + ".secretdir": { + "config": "special configuration" + }, + ".mymetadata": "custom metadata", + "subdir": { + "normal_file.txt": "Normal file content", + "special.privatekey": "private key content", + "data.mysensitive": "sensitive data" + } + }, + "outside_project": { + "sensitive_file.txt": "This file is outside the project" + } + }), + ) + .await; + + cx.update(|cx| { + use gpui::UpdateGlobal; + use project::WorktreeSettings; + use settings::SettingsStore; + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.file_scan_exclusions = Some(vec![ + "**/.secretdir".to_string(), + "**/.mymetadata".to_string(), + ]); + settings.private_files = Some(vec![ + "**/.mysecrets".to_string(), + "**/*.privatekey".to_string(), + "**/*.mysensitive".to_string(), + ]); + }); + }); + }); + + let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + + // Reading a file outside the project worktree should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "/outside_project/sensitive_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read an absolute path outside a worktree" + ); + + // Reading a file within the project should succeed + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/allowed_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!( + result.is_ok(), + "read_file_tool should be able to read files inside worktrees" + ); + + // Reading files that match file_scan_exclusions should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/.secretdir/config".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read files in .secretdir (file_scan_exclusions)" + ); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/.mymetadata".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read .mymetadata files (file_scan_exclusions)" + ); + + // Reading private files should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/.mysecrets".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read 
.mysecrets (private_files)" + ); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/subdir/special.privatekey".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read .privatekey files (private_files)" + ); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/subdir/data.mysensitive".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read .mysensitive files (private_files)" + ); + + // Reading a normal file should still work, even with private_files configured + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/subdir/normal_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!(result.is_ok(), "Should be able to read normal files"); + assert_eq!(result.unwrap(), "Normal file content".into()); + + // Path traversal attempts with .. should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/../outside_project/sensitive_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read a relative path that resolves to outside a worktree" + ); + } + + #[gpui::test] + async fn test_read_file_with_multiple_worktree_settings(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + + // Create first worktree with its own private_files setting + fs.insert_tree( + path!("/worktree1"), + json!({ + "src": { + "main.rs": "fn main() { println!(\"Hello from worktree1\"); }", + "secret.rs": "const API_KEY: &str = \"secret_key_1\";", + "config.toml": "[database]\nurl = \"postgres://localhost/db1\"" + }, + "tests": { + "test.rs": "mod tests { fn test_it() {} }", + "fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));" + }, + ".zed": { + "settings.json": r#"{ + "file_scan_exclusions": ["**/fixture.*"], + "private_files": ["**/secret.rs", "**/config.toml"] + }"# + } + }), + ) + .await; + + // Create second worktree with different private_files setting + fs.insert_tree( + path!("/worktree2"), + json!({ + "lib": { + "public.js": "export function greet() { return 'Hello from worktree2'; }", + "private.js": "const SECRET_TOKEN = \"private_token_2\";", + "data.json": "{\"api_key\": \"json_secret_key\"}" + }, + "docs": { + "README.md": "# Public Documentation", + "internal.md": "# Internal Secrets and Configuration" + }, + ".zed": { + "settings.json": r#"{ + "file_scan_exclusions": ["**/internal.*"], + "private_files": ["**/private.js", "**/data.json"] + }"# + } + }), + ) + .await; + + // Set global settings + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.file_scan_exclusions = + Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]); + settings.private_files = Some(vec!["**/.env".to_string()]); + }); + }); + }); + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + + let action_log = 
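// (Context for the assertions below: each worktree's `.zed/settings.json` layers its own
//  `file_scan_exclusions` and `private_files` on top of the global settings configured above,
//  so `secret.rs` is only treated as private inside worktree1 and `private.js` only inside
//  worktree2.)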
cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone())); + + // Test reading allowed files in worktree1 + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree1/src/main.rs".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await + .unwrap(); + + assert_eq!( + result, + "fn main() { println!(\"Hello from worktree1\"); }".into() + ); + + // Test reading private file in worktree1 should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree1/src/secret.rs".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `private_files` setting"), + "Error should mention worktree private_files setting" + ); + + // Test reading excluded file in worktree1 should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree1/tests/fixture.sql".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `file_scan_exclusions` setting"), + "Error should mention worktree file_scan_exclusions setting" + ); + + // Test reading allowed files in worktree2 + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree2/lib/public.js".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await + .unwrap(); + + assert_eq!( + result, + "export function greet() { return 'Hello from worktree2'; }".into() + ); + + // Test reading private file in worktree2 should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree2/lib/private.js".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `private_files` setting"), + "Error should mention worktree private_files setting" + ); + + // Test reading excluded file in worktree2 should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree2/docs/internal.md".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `file_scan_exclusions` setting"), + "Error should mention worktree file_scan_exclusions setting" + ); + + // Test that files allowed in one worktree but not in another are handled correctly + // (e.g., config.toml is private in worktree1 but doesn't exist in worktree2) + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree1/src/config.toml".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, ToolCallEventStream::test().0, cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `private_files` setting"), + "Config.toml should be blocked by worktree1's private_files setting" + ); + } +} diff --git a/crates/agent2/src/tools/thinking_tool.rs 
b/crates/agent2/src/tools/thinking_tool.rs new file mode 100644 index 0000000000..43647bb468 --- /dev/null +++ b/crates/agent2/src/tools/thinking_tool.rs @@ -0,0 +1,49 @@ +use agent_client_protocol as acp; +use anyhow::Result; +use gpui::{App, SharedString, Task}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +use crate::{AgentTool, ToolCallEventStream}; + +/// A tool for thinking through problems, brainstorming ideas, or planning without executing any actions. +/// Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct ThinkingToolInput { + /// Content to think about. This should be a description of what to think about or + /// a problem to solve. + content: String, +} + +pub struct ThinkingTool; + +impl AgentTool for ThinkingTool { + type Input = ThinkingToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "thinking".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Think + } + + fn initial_title(&self, _input: Result) -> SharedString { + "Thinking".into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + event_stream.update_fields(acp::ToolCallUpdateFields { + content: Some(vec![input.content.into()]), + ..Default::default() + }); + Task::ready(Ok("Finished thinking.".to_string())) + } +} diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 3dcda4ce8d..8d85435f92 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -135,7 +135,7 @@ impl acp_old::Client for OldAcpClientDelegate { let response = cx .update(|cx| { self.thread.borrow().update(cx, |thread, cx| { - thread.request_tool_call_permission(tool_call, acp_options, cx) + thread.request_tool_call_authorization(tool_call, acp_options, cx) }) })? .context("Failed to update thread")? 
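// (Mechanical rename applied across the ACP adapters and the Claude permission MCP server in
//  this change: the thread method `request_tool_call_permission` is now
//  `request_tool_call_authorization`; the surrounding flow that returns the chosen
//  PermissionOption over a oneshot channel appears unchanged.)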
@@ -280,6 +280,7 @@ fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) .map(into_new_tool_call_location) .collect(), raw_input: None, + raw_output: None, } } @@ -380,6 +381,7 @@ impl AcpConnection { let stdin = child.stdin.take().unwrap(); let stdout = child.stdout.take().unwrap(); + log::trace!("Spawned (pid: {})", child.id()); let foreground_executor = cx.foreground_executor().clone(); @@ -463,7 +465,11 @@ impl AgentConnection for AcpConnection { }) } - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { let chunks = params .prompt .into_iter() @@ -483,7 +489,9 @@ impl AgentConnection for AcpConnection { .request_any(acp_old::SendUserMessageParams { chunks }.into_any()); cx.foreground_executor().spawn(async move { task.await?; - anyhow::Ok(()) + anyhow::Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) }) } diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index a4f0e996b5..ff71783b48 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -19,7 +19,6 @@ pub struct AcpConnection { sessions: Rc>>, auth_methods: Vec, _io_task: Task>, - _child: smol::process::Child, } pub struct AcpSession { @@ -47,6 +46,7 @@ impl AcpConnection { let stdout = child.stdout.take().expect("Failed to take stdout"); let stdin = child.stdin.take().expect("Failed to take stdin"); + log::trace!("Spawned (pid: {})", child.id()); let sessions = Rc::new(RefCell::new(HashMap::default())); @@ -63,6 +63,23 @@ impl AcpConnection { let io_task = cx.background_spawn(io_task); + cx.spawn({ + let sessions = sessions.clone(); + async move |cx| { + let status = child.status().await?; + + for session in sessions.borrow().values() { + session + .thread + .update(cx, |thread, cx| thread.emit_server_exited(status, cx)) + .ok(); + } + + anyhow::Ok(()) + } + }) + .detach(); + let response = connection .initialize(acp::InitializeRequest { protocol_version: acp::VERSION, @@ -84,7 +101,6 @@ impl AcpConnection { connection: connection.into(), server_name, sessions, - _child: child, _io_task: io_task, }) } @@ -153,10 +169,16 @@ impl AgentConnection for AcpConnection { }) } - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { let conn = self.connection.clone(); - cx.foreground_executor() - .spawn(async move { Ok(conn.prompt(params).await?) }) + cx.foreground_executor().spawn(async move { + let response = conn.prompt(params).await?; + Ok(response) + }) } fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { @@ -188,7 +210,7 @@ impl acp::Client for ClientDelegate { .context("Failed to get session")? 
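// (Design note on the change above: instead of keeping the child process alive via a `_child`
//  field, the connection now awaits `child.status()` in a detached task and forwards the exit
//  status to every open session through `emit_server_exited`, presumably so a crashed agent
//  server surfaces to its threads rather than leaving them hanging.)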
.thread .update(cx, |thread, cx| { - thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx) + thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx) })?; let result = rx.await; diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index ec69290206..b3b8a33170 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -89,6 +89,7 @@ impl AgentServerCommand { pub(crate) async fn resolve( path_bin_name: &'static str, extra_args: &[&'static str], + fallback_path: Option<&Path>, settings: Option, project: &Entity, cx: &mut AsyncApp, @@ -105,13 +106,24 @@ impl AgentServerCommand { env: agent_settings.command.env, }); } else { - find_bin_in_path(path_bin_name, project, cx) - .await - .map(|path| Self { + match find_bin_in_path(path_bin_name, project, cx).await { + Some(path) => Some(Self { path, args: extra_args.iter().map(|arg| arg.to_string()).collect(), env: None, - }) + }), + None => fallback_path.and_then(|path| { + if path.exists() { + Some(Self { + path: path.to_path_buf(), + args: extra_args.iter().map(|arg| arg.to_string()).collect(), + env: None, + }) + } else { + None + } + }), + } } } } diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 9040b83085..c65508f152 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -24,7 +24,7 @@ use futures::{ }; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use serde::{Deserialize, Serialize}; -use util::ResultExt; +use util::{ResultExt, debug_panic}; use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; @@ -101,8 +101,15 @@ impl AgentConnection for ClaudeAgentConnection { settings.get::(None).claude.clone() })?; - let Some(command) = - AgentServerCommand::resolve("claude", &[], settings, &project, cx).await + let Some(command) = AgentServerCommand::resolve( + "claude", + &[], + Some(&util::paths::home_dir().join(".claude/local/claude")), + settings, + &project, + cx, + ) + .await else { anyhow::bail!("Failed to find claude binary"); }; @@ -114,53 +121,63 @@ impl AgentConnection for ClaudeAgentConnection { log::trace!("Starting session with id: {}", session_id); - cx.background_spawn({ - let session_id = session_id.clone(); - async move { - let mut outgoing_rx = Some(outgoing_rx); + let mut child = spawn_claude( + &command, + ClaudeSessionMode::Start, + session_id.clone(), + &mcp_config_path, + &cwd, + )?; - let mut child = spawn_claude( - &command, - ClaudeSessionMode::Start, - session_id.clone(), - &mcp_config_path, - &cwd, - ) - .await?; + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); - let pid = child.id(); - log::trace!("Spawned (pid: {})", pid); + let pid = child.id(); + log::trace!("Spawned (pid: {})", pid); - ClaudeAgentSession::handle_io( - outgoing_rx.take().unwrap(), - incoming_message_tx.clone(), - child.stdin.take().unwrap(), - child.stdout.take().unwrap(), - ) - .await?; + cx.background_spawn(async move { + let mut outgoing_rx = Some(outgoing_rx); - log::trace!("Stopped (pid: {})", pid); + ClaudeAgentSession::handle_io( + outgoing_rx.take().unwrap(), + incoming_message_tx.clone(), + stdin, + stdout, + ) + .await?; - drop(mcp_config_path); - anyhow::Ok(()) - } + log::trace!("Stopped (pid: {})", pid); + + drop(mcp_config_path); + anyhow::Ok(()) }) .detach(); - let end_turn_tx = Rc::new(RefCell::new(None)); + let 
turn_state = Rc::new(RefCell::new(TurnState::None)); + let handler_task = cx.spawn({ - let end_turn_tx = end_turn_tx.clone(); - let thread_rx = thread_rx.clone(); + let turn_state = turn_state.clone(); + let mut thread_rx = thread_rx.clone(); async move |cx| { while let Some(message) = incoming_message_rx.next().await { ClaudeAgentSession::handle_message( thread_rx.clone(), message, - end_turn_tx.clone(), + turn_state.clone(), cx, ) .await } + + if let Some(status) = child.status().await.log_err() { + if let Some(thread) = thread_rx.recv().await.ok() { + thread + .update(cx, |thread, cx| { + thread.emit_server_exited(status, cx); + }) + .ok(); + } + } } }); @@ -172,7 +189,7 @@ impl AgentConnection for ClaudeAgentConnection { let session = ClaudeAgentSession { outgoing_tx, - end_turn_tx, + turn_state, _handler_task: handler_task, _mcp_server: Some(permission_mcp_server), }; @@ -191,7 +208,11 @@ impl AgentConnection for ClaudeAgentConnection { Task::ready(Err(anyhow!("Authentication not supported"))) } - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { let sessions = self.sessions.borrow(); let Some(session) = sessions.get(¶ms.session_id) else { return Task::ready(Err(anyhow!( @@ -200,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection { ))); }; - let (tx, rx) = oneshot::channel(); - session.end_turn_tx.borrow_mut().replace(tx); + let (end_tx, end_rx) = oneshot::channel(); + session.turn_state.replace(TurnState::InProgress { end_tx }); let mut content = String::new(); for chunk in params.prompt { @@ -235,10 +256,7 @@ impl AgentConnection for ClaudeAgentConnection { return Task::ready(Err(anyhow!(err))); } - cx.foreground_executor().spawn(async move { - rx.await??; - Ok(()) - }) + cx.foreground_executor().spawn(async move { end_rx.await? }) } fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { @@ -248,9 +266,26 @@ impl AgentConnection for ClaudeAgentConnection { return; }; + let request_id = new_request_id(); + + let turn_state = session.turn_state.take(); + let TurnState::InProgress { end_tx } = turn_state else { + // Already cancelled or idle, put it back + session.turn_state.replace(turn_state); + return; + }; + + session.turn_state.replace(TurnState::CancelRequested { + end_tx, + request_id: request_id.clone(), + }); + session .outgoing_tx - .unbounded_send(SdkMessage::new_interrupt_message()) + .unbounded_send(SdkMessage::ControlRequest { + request_id, + request: ControlRequest::Interrupt, + }) .log_err(); } } @@ -262,7 +297,7 @@ enum ClaudeSessionMode { Resume, } -async fn spawn_claude( +fn spawn_claude( command: &AgentServerCommand, mode: ClaudeSessionMode, session_id: acp::SessionId, @@ -313,26 +348,139 @@ async fn spawn_claude( struct ClaudeAgentSession { outgoing_tx: UnboundedSender, - end_turn_tx: Rc>>>>, + turn_state: Rc>, _mcp_server: Option, _handler_task: Task<()>, } +#[derive(Debug, Default)] +enum TurnState { + #[default] + None, + InProgress { + end_tx: oneshot::Sender>, + }, + CancelRequested { + end_tx: oneshot::Sender>, + request_id: String, + }, + CancelConfirmed { + end_tx: oneshot::Sender>, + }, +} + +impl TurnState { + fn is_cancelled(&self) -> bool { + matches!(self, TurnState::CancelConfirmed { .. }) + } + + fn end_tx(self) -> Option>> { + match self { + TurnState::None => None, + TurnState::InProgress { end_tx, .. } => Some(end_tx), + TurnState::CancelRequested { end_tx, .. 
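// (Summary of the turn lifecycle encoded by this enum, as used elsewhere in this file:
//  prompt() moves None -> InProgress { end_tx }; cancel() sends a ControlRequest::Interrupt and
//  moves to CancelRequested { end_tx, request_id }; a successful ControlResponse with a matching
//  request_id confirms it as CancelConfirmed; the final SdkMessage::Result then resolves end_tx,
//  reporting StopReason::Cancelled for turns that were cancelled locally.)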
} => Some(end_tx), + TurnState::CancelConfirmed { end_tx } => Some(end_tx), + } + } + + fn confirm_cancellation(self, id: &str) -> Self { + match self { + TurnState::CancelRequested { request_id, end_tx } if request_id == id => { + TurnState::CancelConfirmed { end_tx } + } + _ => self, + } + } +} + impl ClaudeAgentSession { async fn handle_message( mut thread_rx: watch::Receiver>, message: SdkMessage, - end_turn_tx: Rc>>>>, + turn_state: Rc>, cx: &mut AsyncApp, ) { match message { // we should only be sending these out, they don't need to be in the thread SdkMessage::ControlRequest { .. } => {} - SdkMessage::Assistant { + SdkMessage::User { message, session_id: _, + } => { + let Some(thread) = thread_rx + .recv() + .await + .log_err() + .and_then(|entity| entity.upgrade()) + else { + log::error!("Received an SDK message but thread is gone"); + return; + }; + + for chunk in message.content.chunks() { + match chunk { + ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { + if !turn_state.borrow().is_cancelled() { + thread + .update(cx, |thread, cx| { + thread.push_user_content_block(text.into(), cx) + }) + .log_err(); + } + } + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + let content = content.to_string(); + thread + .update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use_id.into()), + fields: acp::ToolCallUpdateFields { + status: if turn_state.borrow().is_cancelled() { + // Do not set to completed if turn was cancelled + None + } else { + Some(acp::ToolCallStatus::Completed) + }, + content: (!content.is_empty()) + .then(|| vec![content.into()]), + ..Default::default() + }, + }, + cx, + ) + }) + .log_err(); + } + ContentChunk::Thinking { .. } + | ContentChunk::RedactedThinking + | ContentChunk::ToolUse { .. } => { + debug_panic!( + "Should not get {:?} with role: assistant. should we handle this?", + chunk + ); + } + + ContentChunk::Image + | ContentChunk::Document + | ContentChunk::WebSearchToolResult => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block( + format!("Unsupported content: {:?}", chunk).into(), + false, + cx, + ) + }) + .log_err(); + } + } + } } - | SdkMessage::User { + SdkMessage::Assistant { message, session_id: _, } => { @@ -355,6 +503,24 @@ impl ClaudeAgentSession { }) .log_err(); } + ContentChunk::Thinking { thinking } => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block(thinking.into(), true, cx) + }) + .log_err(); + } + ContentChunk::RedactedThinking => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block( + "[REDACTED]".into(), + true, + cx, + ) + }) + .log_err(); + } ContentChunk::ToolUse { id, name, input } => { let claude_tool = ClaudeTool::infer(&name, input); @@ -380,33 +546,12 @@ impl ClaudeAgentSession { }) .log_err(); } - ContentChunk::ToolResult { - content, - tool_use_id, - } => { - let content = content.to_string(); - thread - .update(cx, |thread, cx| { - thread.update_tool_call( - acp::ToolCallUpdate { - id: acp::ToolCallId(tool_use_id.into()), - fields: acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::Completed), - content: (!content.is_empty()) - .then(|| vec![content.into()]), - ..Default::default() - }, - }, - cx, - ) - }) - .log_err(); + ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => { + debug_panic!( + "Should not get tool results with role: assistant. should we handle this?" 
+ ); } - ContentChunk::Image - | ContentChunk::Document - | ContentChunk::Thinking - | ContentChunk::RedactedThinking - | ContentChunk::WebSearchToolResult => { + ContentChunk::Image | ContentChunk::Document => { thread .update(cx, |thread, cx| { thread.push_assistant_content_block( @@ -426,20 +571,41 @@ impl ClaudeAgentSession { result, .. } => { - if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { - if is_error { - end_turn_tx - .send(Err(anyhow!( - "Error: {}", - result.unwrap_or_else(|| subtype.to_string()) - ))) - .ok(); - } else { - end_turn_tx.send(Ok(())).ok(); - } + let turn_state = turn_state.take(); + let was_cancelled = turn_state.is_cancelled(); + let Some(end_turn_tx) = turn_state.end_tx() else { + debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn"); + return; + }; + + if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution) + { + end_turn_tx + .send(Err(anyhow!( + "Error: {}", + result.unwrap_or_else(|| subtype.to_string()) + ))) + .ok(); + } else { + let stop_reason = match subtype { + ResultErrorType::Success => acp::StopReason::EndTurn, + ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, + ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, + }; + end_turn_tx + .send(Ok(acp::PromptResponse { stop_reason })) + .ok(); } } - SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {} + SdkMessage::ControlResponse { response } => { + if matches!(response.subtype, ResultErrorType::Success) { + let new_state = turn_state.take().confirm_cancellation(&response.request_id); + turn_state.replace(new_state); + } else { + log::error!("Control response error: {:?}", response); + } + } + SdkMessage::System { .. } => {} } } @@ -548,11 +714,13 @@ enum ContentChunk { content: Content, tool_use_id: String, }, + Thinking { + thinking: String, + }, + RedactedThinking, // TODO Image, Document, - Thinking, - RedactedThinking, WebSearchToolResult, #[serde(untagged)] UntaggedText(String), @@ -562,12 +730,12 @@ impl Display for ContentChunk { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ContentChunk::Text { text } => write!(f, "{}", text), + ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking), + ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"), ContentChunk::UntaggedText(text) => write!(f, "{}", text), ContentChunk::ToolResult { content, .. } => write!(f, "{}", content), ContentChunk::Image | ContentChunk::Document - | ContentChunk::Thinking - | ContentChunk::RedactedThinking | ContentChunk::ToolUse { .. 
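// (Mapping recap for the Result handling above: subtype "success" -> StopReason::EndTurn,
//  "error_max_turns" -> StopReason::MaxTurnRequests, "error_during_execution" ->
//  StopReason::Cancelled when the turn was cancelled locally, and a turn-level error otherwise.)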
} | ContentChunk::WebSearchToolResult => { write!(f, "\n{:?}\n", &self) @@ -660,7 +828,7 @@ struct ControlResponse { subtype: ResultErrorType, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[serde(rename_all = "snake_case")] enum ResultErrorType { Success, @@ -678,22 +846,15 @@ impl Display for ResultErrorType { } } -impl SdkMessage { - fn new_interrupt_message() -> Self { - use rand::Rng; - // In the Claude Code TS SDK they just generate a random 12 character string, - // `Math.random().toString(36).substring(2, 15)` - let request_id = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(12) - .map(char::from) - .collect(); - - Self::ControlRequest { - request_id, - request: ControlRequest::Interrupt, - } - } +fn new_request_id() -> String { + use rand::Rng; + // In the Claude Code TS SDK they just generate a random 12 character string, + // `Math.random().toString(36).substring(2, 15)` + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(12) + .map(char::from) + .collect() } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -714,6 +875,8 @@ enum PermissionMode { #[cfg(test)] pub(crate) mod tests { use super::*; + use crate::e2e_tests; + use gpui::TestAppContext; use serde_json::json; crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow"); @@ -726,6 +889,68 @@ pub(crate) mod tests { } } + #[gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn test_todo_plan(cx: &mut TestAppContext) { + let fs = e2e_tests::init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = + e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Create a todo plan for initializing a new React app. 
I'll follow it myself, do not execute on it.", + cx, + ) + }) + .await + .unwrap(); + + let mut entries_len = 0; + + thread.read_with(cx, |thread, _| { + entries_len = thread.plan().entries.len(); + assert!(thread.plan().entries.len() > 0, "Empty plan"); + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Mark the first entry status as in progress without acting on it.", + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.plan().entries[0].status, + acp::PlanEntryStatus::InProgress + )); + assert_eq!(thread.plan().entries.len(), entries_len); + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Now mark the first entry as completed without acting on it.", + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.plan().entries[0].status, + acp::PlanEntryStatus::Completed + )); + assert_eq!(thread.plan().entries.len(), entries_len); + }); + } + #[test] fn test_deserialize_content_untagged_text() { let json = json!("Hello, world!"); diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index c6f8bb5b69..53a8556e74 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -153,7 +153,7 @@ impl McpServerTool for PermissionTool { let chosen_option = thread .update(cx, |thread, cx| { - thread.request_tool_call_permission( + thread.request_tool_call_authorization( claude_tool.as_acp(tool_call_id), vec![ acp::PermissionOption { diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index e7d33e5298..7ca150c0bd 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -143,25 +143,6 @@ impl ClaudeTool { Self::Grep(Some(params)) => vec![format!("`{params}`").into()], Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], Self::WebSearch(Some(params)) => vec![params.to_string().into()], - Self::TodoWrite(Some(params)) => vec![ - params - .todos - .iter() - .map(|todo| { - format!( - "- {} {}: {}", - match todo.status { - TodoStatus::Completed => "✅", - TodoStatus::InProgress => "🚧", - TodoStatus::Pending => "⬜", - }, - todo.priority, - todo.content - ) - }) - .join("\n") - .into(), - ], Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { diff: acp::Diff { @@ -193,6 +174,10 @@ impl ClaudeTool { }) .unwrap_or_default() } + Self::TodoWrite(Some(_)) => { + // These are mapped to plan updates later + vec![] + } Self::Task(None) | Self::NotebookRead(None) | Self::NotebookEdit(None) @@ -312,6 +297,7 @@ impl ClaudeTool { content: self.content(), locations: self.locations(), raw_input: None, + raw_output: None, } } } @@ -488,10 +474,11 @@ impl std::fmt::Display for GrepToolParams { } } -#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)] +#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)] #[serde(rename_all = "snake_case")] pub enum TodoPriority { High, + #[default] Medium, Low, } @@ -526,14 +513,13 @@ impl Into for TodoStatus { #[derive(Deserialize, Serialize, JsonSchema, Debug)] pub struct Todo { - /// Unique identifier - pub id: String, /// Task description pub content: String, - /// Priority level of the todo - pub priority: TodoPriority, /// Current status of the todo pub status: TodoStatus, + /// Priority level of the todo + #[serde(default)] + 
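// (Note on the shape change: `id` is gone and `priority` is now optional. With
//  `#[serde(default)]` here and `#[default]` on `TodoPriority::Medium`, a todo entry that only
//  carries `content` and `status` deserializes with Medium priority -- presumably because the
//  upstream TodoWrite payload does not always include one.)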
pub priority: TodoPriority, } impl Into for Todo { diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index a60aefb7b9..ec6ca29b9d 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { + let _ = thread.update(cx, |thread, cx| { thread.send_raw( r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, cx, @@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon id.clone() }); - let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); - full_turn.await.unwrap(); - thread.read_with(cx, |thread, _| { + thread.update(cx, |thread, cx| thread.cancel(cx)).await; + thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Canceled, .. @@ -311,6 +310,27 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon }); } +pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx)) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert!(thread.entries().len() >= 2, "Expected at least 2 entries"); + }); + + let weak_thread = thread.downgrade(); + drop(thread); + + cx.executor().run_until_parked(); + assert!(!weak_thread.is_upgradable()); +} + #[macro_export] macro_rules! common_e2e_tests { ($server:expr, allow_option_id = $allow_option_id:expr) => { @@ -351,6 +371,12 @@ macro_rules! 
common_e2e_tests { async fn cancel(cx: &mut ::gpui::TestAppContext) { $crate::e2e_tests::test_cancel($server, cx).await; } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn thread_drop(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_thread_drop($server, cx).await; + } } }; } diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 2366783d22..ad883f6da8 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -2,7 +2,7 @@ use std::path::Path; use std::rc::Rc; use crate::{AgentServer, AgentServerCommand}; -use acp_thread::AgentConnection; +use acp_thread::{AgentConnection, LoadError}; use anyhow::Result; use gpui::{Entity, Task}; use project::Project; @@ -48,12 +48,42 @@ impl AgentServer for Gemini { })?; let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await else { anyhow::bail!("Failed to find gemini binary"); }; - crate::acp::connect(server_name, command, &root_dir, cx).await + let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await; + if result.is_err() { + let version_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--version") + .kill_on_drop(true) + .output(); + + let help_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--help") + .kill_on_drop(true) + .output(); + + let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; + + let current_version = String::from_utf8(version_output?.stdout)?; + let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); + + if !supported { + return Err(LoadError::Unsupported { + error_message: format!( + "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", + current_version + ).into(), + upgrade_message: "Upgrade Gemini to Latest".into(), + upgrade_command: "npm install -g @google/gemini-cli@latest".into(), + }.into()) + } + } + result }) } } diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 4e872c78d7..e6a79963d6 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -13,6 +13,9 @@ use std::borrow::Cow; pub use crate::agent_profile::*; +pub const SUMMARIZE_THREAD_PROMPT: &str = + include_str!("../../agent/src/prompts/summarize_thread_prompt.txt"); + pub fn init(cx: &mut App) { AgentSettings::register(cx); } diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 95fd2b1757..c145df0eae 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -19,6 +19,7 @@ test-support = ["gpui/test-support", "language/test-support"] acp_thread.workspace = true agent-client-protocol.workspace = true agent.workspace = true +agent2.workspace = true agent_servers.workspace = true agent_settings.workspace = true ai_onboarding.workspace = true diff --git a/crates/agent_ui/src/acp/message_history.rs b/crates/agent_ui/src/acp/message_history.rs index d0fb1f0990..c6106c7578 100644 --- a/crates/agent_ui/src/acp/message_history.rs +++ b/crates/agent_ui/src/acp/message_history.rs @@ -45,6 +45,11 @@ impl MessageHistory { None }) } + + #[cfg(test)] + pub fn items(&self) -> &[T] { + &self.items + } } #[cfg(test)] mod tests { diff --git a/crates/agent_ui/src/acp/thread_view.rs 
b/crates/agent_ui/src/acp/thread_view.rs index a8e2d59b62..7f4e7e7208 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -5,6 +5,7 @@ use audio::{Audio, Sound}; use std::cell::RefCell; use std::collections::BTreeMap; use std::path::Path; +use std::process::ExitStatus; use std::rc::Rc; use std::sync::Arc; use std::time::Duration; @@ -20,26 +21,28 @@ use editor::{ use file_icons::FileIcons; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, - FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString, - StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation, - UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient, - list, percentage, point, prelude::*, pulsating_between, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay, + SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, + Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, + linear_gradient, list, percentage, point, prelude::*, pulsating_between, }; use language::language_settings::SoftWrap; use language::{Buffer, Language}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::Project; -use settings::Settings as _; +use settings::{Settings as _, SettingsStore}; use text::{Anchor, BufferSnapshot}; use theme::ThemeSettings; -use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*}; +use ui::{ + Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*, +}; use util::ResultExt; use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; use ::acp_thread::{ - AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, + AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, }; @@ -68,6 +71,7 @@ pub struct AcpThreadView { notification_subscriptions: HashMap, Vec>, last_error: Option>, list_state: ListState, + scrollbar_state: ScrollbarState, auth_task: Option<Task<()>>, expanded_tool_calls: HashSet<acp::ToolCallId>, expanded_thinking_blocks: HashSet<(usize, usize)>, @@ -76,6 +80,7 @@ pub struct AcpThreadView { editor_expanded: bool, message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>, _cancel_task: Option<Task<()>>, + _subscriptions: [Subscription; 1], } enum ThreadState { @@ -90,6 +95,9 @@ enum ThreadState { Unauthenticated { connection: Rc<dyn AgentConnection>, }, + ServerExited { + status: ExitStatus, + }, } impl AcpThreadView { @@ -169,22 +177,9 @@ impl AcpThreadView { let mention_set = mention_set.clone(); - let list_state = ListState::new( - 0, - gpui::ListAlignment::Bottom, - px(2048.0), - cx.processor({ - move |this: &mut Self, index: usize, window, cx| { - let Some((entry, len)) = this.thread().and_then(|thread| { - let entries = &thread.read(cx).entries(); - Some((entries.get(index)?, entries.len())) - }) else { - return Empty.into_any(); - }; - this.render_entry(index, len, entry, window, cx) - } - }), - ); + let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); + + let subscription = cx.observe_global_in::<SettingsStore>(window, Self::settings_changed); Self { agent: agent.clone(), @@ -198,7 +193,8 @@ impl AcpThreadView { notifications: Vec::new(), notification_subscriptions:
HashMap::default(), diff_editors: Default::default(), - list_state: list_state, + list_state: list_state.clone(), + scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), last_error: None, auth_task: None, expanded_tool_calls: HashSet::default(), @@ -207,6 +203,7 @@ impl AcpThreadView { plan_expanded: false, editor_expanded: false, message_history, + _subscriptions: [subscription], _cancel_task: None, } } @@ -228,7 +225,7 @@ impl AcpThreadView { let connect_task = agent.connect(&root_dir, &project, cx); let load_task = cx.spawn_in(window, async move |this, cx| { let connection = match connect_task.await { - Ok(thread) => thread, + Ok(connection) => connection, Err(err) => { this.update(cx, |this, cx| { this.handle_load_error(err, cx); @@ -239,6 +236,20 @@ impl AcpThreadView { } }; + // this.update_in(cx, |_this, _window, cx| { + // let status = connection.exit_status(cx); + // cx.spawn(async move |this, cx| { + // let status = status.await.ok(); + // this.update(cx, |this, cx| { + // this.thread_state = ThreadState::ServerExited { status }; + // cx.notify(); + // }) + // .ok(); + // }) + // .detach(); + // }) + // .ok(); + let result = match connection .clone() .new_thread(project.clone(), &root_dir, cx) @@ -307,7 +318,8 @@ impl AcpThreadView { ThreadState::Ready { thread, .. } => Some(thread), ThreadState::Unauthenticated { .. } | ThreadState::Loading { .. } - | ThreadState::LoadError(..) => None, + | ThreadState::LoadError(..) + | ThreadState::ServerExited { .. } => None, } } @@ -317,6 +329,7 @@ impl AcpThreadView { ThreadState::Loading { .. } => "Loading…".into(), ThreadState::LoadError(_) => "Failed to load".into(), ThreadState::Unauthenticated { .. } => "Not authenticated".into(), + ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(), } } @@ -368,6 +381,11 @@ impl AcpThreadView { editor.display_map.update(cx, |map, cx| { let snapshot = map.snapshot(cx); for (crease_id, crease) in snapshot.crease_snapshot.creases() { + // Skip creases that have been edited out of the message buffer. 
+ if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { + continue; + } + if let Some(project_path) = self.mention_set.lock().path_for_crease_id(crease_id) { @@ -646,6 +664,9 @@ impl AcpThreadView { cx, ); } + AcpThreadEvent::ServerExited(status) => { + self.thread_state = ThreadState::ServerExited { status: *status }; + } } cx.notify(); } @@ -692,15 +713,7 @@ impl AcpThreadView { editor.set_show_code_actions(false, cx); editor.set_show_git_diff_gutter(false, cx); editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) - .into(), - ), - ..Default::default() - }); + editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); editor }); let entity_id = multibuffer.entity_id(); @@ -719,7 +732,11 @@ impl AcpThreadView { cx: &App, ) -> Option>> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some(entry.diffs().map(|diff| diff.multibuffer.clone())) + Some( + entry + .diffs() + .map(|diff| diff.read(cx).multibuffer().clone()), + ) } fn authenticate( @@ -785,7 +802,7 @@ impl AcpThreadView { window: &mut Window, cx: &Context, ) -> AnyElement { - match &entry { + let primary = match &entry { AgentThreadEntry::UserMessage(message) => div() .py_4() .px_2() @@ -846,10 +863,25 @@ impl AcpThreadView { .into_any() } AgentThreadEntry::ToolCall(tool_call) => div() + .w_full() .py_1p5() .px_5() .child(self.render_tool_call(index, tool_call, window, cx)) .into_any(), + }; + + let Some(thread) = self.thread() else { + return primary; + }; + let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); + if index == total_entries - 1 && !is_generating { + v_flex() + .w_full() + .child(primary) + .child(self.render_thread_controls(cx)) + .into_any_element() + } else { + primary } } @@ -877,6 +909,7 @@ impl AcpThreadView { cx: &Context, ) -> AnyElement { let header_id = SharedString::from(format!("thinking-block-header-{}", entry_ix)); + let card_header_id = SharedString::from("inner-card-header"); let key = (entry_ix, chunk_ix); let is_open = self.expanded_thinking_blocks.contains(&key); @@ -884,41 +917,53 @@ impl AcpThreadView { .child( h_flex() .id(header_id) - .group("disclosure-header") + .group(&card_header_id) + .relative() .w_full() - .justify_between() + .gap_1p5() .opacity(0.8) .hover(|style| style.opacity(1.)) .child( h_flex() - .gap_1p5() - .child( - Icon::new(IconName::ToolBulb) - .size(IconSize::Small) - .color(Color::Muted), - ) + .size_4() + .justify_center() .child( div() - .text_size(self.tool_name_font_size()) - .child("Thinking"), + .group_hover(&card_header_id, |s| s.invisible().w_0()) + .child( + Icon::new(IconName::ToolThink) + .size(IconSize::Small) + .color(Color::Muted), + ), + ) + .child( + h_flex() + .absolute() + .inset_0() + .invisible() + .justify_center() + .group_hover(&card_header_id, |s| s.visible()) + .child( + Disclosure::new(("expand", entry_ix), is_open) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronRight) + .on_click(cx.listener({ + move |this, _event, _window, cx| { + if is_open { + this.expanded_thinking_blocks.remove(&key); + } else { + this.expanded_thinking_blocks.insert(key); + } + cx.notify(); + } + })), + ), ), ) .child( - div().visible_on_hover("disclosure-header").child( - Disclosure::new("thinking-disclosure", is_open) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .on_click(cx.listener({ - move |this, 
_event, _window, cx| { - if is_open { - this.expanded_thinking_blocks.remove(&key); - } else { - this.expanded_thinking_blocks.insert(key); - } - cx.notify(); - } - })), - ), + div() + .text_size(self.tool_name_font_size()) + .child("Thinking"), ) .on_click(cx.listener({ move |this, _event, _window, cx| { @@ -949,6 +994,67 @@ impl AcpThreadView { .into_any_element() } + fn render_tool_call_icon( + &self, + group_name: SharedString, + entry_ix: usize, + is_collapsible: bool, + is_open: bool, + tool_call: &ToolCall, + cx: &Context, + ) -> Div { + let tool_icon = Icon::new(match tool_call.kind { + acp::ToolKind::Read => IconName::ToolRead, + acp::ToolKind::Edit => IconName::ToolPencil, + acp::ToolKind::Delete => IconName::ToolDeleteFile, + acp::ToolKind::Move => IconName::ArrowRightLeft, + acp::ToolKind::Search => IconName::ToolSearch, + acp::ToolKind::Execute => IconName::ToolTerminal, + acp::ToolKind::Think => IconName::ToolThink, + acp::ToolKind::Fetch => IconName::ToolWeb, + acp::ToolKind::Other => IconName::ToolHammer, + }) + .size(IconSize::Small) + .color(Color::Muted); + + if is_collapsible { + h_flex() + .size_4() + .justify_center() + .child( + div() + .group_hover(&group_name, |s| s.invisible().w_0()) + .child(tool_icon), + ) + .child( + h_flex() + .absolute() + .inset_0() + .invisible() + .justify_center() + .group_hover(&group_name, |s| s.visible()) + .child( + Disclosure::new(("expand", entry_ix), is_open) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronRight) + .on_click(cx.listener({ + let id = tool_call.id.clone(); + move |this: &mut Self, _, _, cx: &mut Context| { + if is_open { + this.expanded_tool_calls.remove(&id); + } else { + this.expanded_tool_calls.insert(id.clone()); + } + cx.notify(); + } + })), + ), + ) + } else { + div().child(tool_icon) + } + } + fn render_tool_call( &self, entry_ix: usize, @@ -956,7 +1062,8 @@ impl AcpThreadView { window: &Window, cx: &Context, ) -> Div { - let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix)); + let header_id = SharedString::from(format!("outer-tool-call-header-{}", entry_ix)); + let card_header_id = SharedString::from("inner-tool-call-header"); let status_icon = match &tool_call.status { ToolCallStatus::Allowed { @@ -1005,6 +1112,21 @@ impl AcpThreadView { let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); + let gradient_color = cx.theme().colors().panel_background; + let gradient_overlay = { + div() + .absolute() + .top_0() + .right_0() + .w_12() + .h_full() + .bg(linear_gradient( + 90., + linear_color_stop(gradient_color, 1.), + linear_color_stop(gradient_color.opacity(0.2), 0.), + )) + }; + v_flex() .when(needs_confirmation, |this| { this.rounded_lg() @@ -1021,43 +1143,38 @@ impl AcpThreadView { .justify_between() .map(|this| { if needs_confirmation { - this.px_2() + this.pl_2() + .pr_1() .py_1() .rounded_t_md() - .bg(self.tool_card_header_bg(cx)) .border_b_1() .border_color(self.tool_card_border_color(cx)) + .bg(self.tool_card_header_bg(cx)) } else { this.opacity(0.8).hover(|style| style.opacity(1.)) } }) .child( h_flex() - .id("tool-call-header") - .overflow_x_scroll() + .group(&card_header_id) + .relative() + .w_full() .map(|this| { - if needs_confirmation { - this.text_xs() + if tool_call.locations.len() == 1 { + this.gap_0() } else { - this.text_size(self.tool_name_font_size()) + this.gap_1p5() } }) - .gap_1p5() - .child( - Icon::new(match tool_call.kind { - 
acp::ToolKind::Read => IconName::ToolRead, - acp::ToolKind::Edit => IconName::ToolPencil, - acp::ToolKind::Delete => IconName::ToolDeleteFile, - acp::ToolKind::Move => IconName::ArrowRightLeft, - acp::ToolKind::Search => IconName::ToolSearch, - acp::ToolKind::Execute => IconName::ToolTerminal, - acp::ToolKind::Think => IconName::ToolBulb, - acp::ToolKind::Fetch => IconName::ToolWeb, - acp::ToolKind::Other => IconName::ToolHammer, - }) - .size(IconSize::Small) - .color(Color::Muted), - ) + .text_size(self.tool_name_font_size()) + .child(self.render_tool_call_icon( + card_header_id, + entry_ix, + is_collapsible, + is_open, + tool_call, + cx, + )) .child(if tool_call.locations.len() == 1 { let name = tool_call.locations[0] .path @@ -1068,13 +1185,11 @@ impl AcpThreadView { h_flex() .id(("open-tool-call-location", entry_ix)) - .child(name) .w_full() .max_w_full() - .pr_1() - .gap_0p5() - .cursor_pointer() + .px_1p5() .rounded_sm() + .overflow_x_scroll() .opacity(0.8) .hover(|label| { label.opacity(1.).bg(cx @@ -1083,53 +1198,49 @@ impl AcpThreadView { .element_hover .opacity(0.5)) }) + .child(name) .tooltip(Tooltip::text("Jump to File")) .on_click(cx.listener(move |this, _, window, cx| { this.open_tool_call_location(entry_ix, 0, window, cx); })) .into_any_element() } else { - self.render_markdown( - tool_call.label.clone(), - default_markdown_style(needs_confirmation, window, cx), - ) - .into_any() + h_flex() + .id("non-card-label-container") + .w_full() + .relative() + .overflow_hidden() + .child( + h_flex() + .id("non-card-label") + .pr_8() + .w_full() + .overflow_x_scroll() + .child(self.render_markdown( + tool_call.label.clone(), + default_markdown_style( + needs_confirmation, + window, + cx, + ), + )), + ) + .child(gradient_overlay) + .on_click(cx.listener({ + let id = tool_call.id.clone(); + move |this: &mut Self, _, _, cx: &mut Context| { + if is_open { + this.expanded_tool_calls.remove(&id); + } else { + this.expanded_tool_calls.insert(id.clone()); + } + cx.notify(); + } + })) + .into_any() }), ) - .child( - h_flex() - .gap_0p5() - .when(is_collapsible, |this| { - this.child( - Disclosure::new(("expand", entry_ix), is_open) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .on_click(cx.listener({ - let id = tool_call.id.clone(); - move |this: &mut Self, _, _, cx: &mut Context| { - if is_open { - this.expanded_tool_calls.remove(&id); - } else { - this.expanded_tool_calls.insert(id.clone()); - } - cx.notify(); - } - })), - ) - }) - .children(status_icon), - ) - .on_click(cx.listener({ - let id = tool_call.id.clone(); - move |this: &mut Self, _, _, cx: &mut Context| { - if is_open { - this.expanded_tool_calls.remove(&id); - } else { - this.expanded_tool_calls.insert(id.clone()); - } - cx.notify(); - } - })), + .children(status_icon), ) .when(is_open, |this| { this.child( @@ -1207,10 +1318,9 @@ impl AcpThreadView { Empty.into_any_element() } } - ToolCallContent::Diff { - diff: Diff { multibuffer, .. }, - .. - } => self.render_diff_editor(multibuffer), + ToolCallContent::Diff { diff, .. 
} => { + self.render_diff_editor(&diff.read(cx).multibuffer()) + } } } @@ -1223,8 +1333,7 @@ impl AcpThreadView { cx: &Context, ) -> Div { h_flex() - .py_1p5() - .px_1p5() + .p_1p5() .gap_1() .justify_end() .when(!empty_content, |this| { @@ -1250,6 +1359,7 @@ impl AcpThreadView { }) .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) + .label_size(LabelSize::Small) .on_click(cx.listener({ let tool_call_id = tool_call_id.clone(); let option_id = option.id.clone(); @@ -1369,7 +1479,29 @@ impl AcpThreadView { .into_any() } - fn render_error_state(&self, e: &LoadError, cx: &Context) -> AnyElement { + fn render_server_exited(&self, status: ExitStatus, _cx: &Context) -> AnyElement { + v_flex() + .items_center() + .justify_center() + .child(self.render_error_agent_logo()) + .child( + v_flex() + .mt_4() + .mb_2() + .gap_0p5() + .text_center() + .items_center() + .child(Headline::new("Server exited unexpectedly").size(HeadlineSize::Medium)) + .child( + Label::new(format!("Exit status: {}", status.code().unwrap_or(-127))) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .into_any_element() + } + + fn render_load_error(&self, e: &LoadError, cx: &Context) -> AnyElement { let mut container = v_flex() .items_center() .justify_center() @@ -1477,7 +1609,7 @@ impl AcpThreadView { }) }) .when(!changed_buffers.is_empty(), |this| { - this.child(Divider::horizontal()) + this.child(Divider::horizontal().color(DividerColor::Border)) .child(self.render_edits_summary( action_log, &changed_buffers, @@ -1507,6 +1639,7 @@ impl AcpThreadView { { h_flex() .w_full() + .cursor_default() .gap_1() .text_xs() .text_color(cx.theme().colors().text_muted) @@ -1536,7 +1669,7 @@ impl AcpThreadView { let status_label = if stats.pending == 0 { "All Done".to_string() } else if stats.completed == 0 { - format!("{}", plan.entries.len()) + format!("{} Tasks", plan.entries.len()) } else { format!("{}/{}", stats.completed, plan.entries.len()) }; @@ -1650,7 +1783,6 @@ impl AcpThreadView { .child( h_flex() .id("edits-container") - .cursor_pointer() .w_full() .gap_1() .child(Disclosure::new("edits-disclosure", expanded)) @@ -2404,7 +2536,7 @@ impl AcpThreadView { } } - fn render_thread_controls(&mut self, cx: &mut Context) -> impl IntoElement { + fn render_thread_controls(&self, cx: &Context) -> impl IntoElement { let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileText) .icon_size(IconSize::XSmall) .icon_color(Color::Ignored) @@ -2425,9 +2557,9 @@ impl AcpThreadView { })); h_flex() - .mt_1() + .w_full() .mr_1() - .py_2() + .pb_2() .px(RESPONSE_PADDING_X) .opacity(0.4) .hover(|style| style.opacity(1.)) @@ -2436,6 +2568,48 @@ impl AcpThreadView { .child(open_as_markdown) .child(scroll_to_top) } + + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .id("acp-thread-scrollbar") + .occlude() + .on_mouse_move(cx.listener(|_, _, _, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { + cx.stop_propagation(); + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) + } + + fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context) { + for diff_editor in self.diff_editors.values() { + diff_editor.update(cx, |diff_editor, cx| { + diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); + cx.notify(); + }) + } + } } impl Focusable for AcpThreadView { @@ -2481,24 +2655,36 @@ impl Render for AcpThreadView { .flex_1() .items_center() .justify_center() - .child(self.render_error_state(e, cx)), + .child(self.render_load_error(e, cx)), + ThreadState::ServerExited { status } => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_server_exited(*status, cx)), ThreadState::Ready { thread, .. } => { let thread_clone = thread.clone(); v_flex().flex_1().map(|this| { if self.list_state.item_count() > 0 { - let is_generating = - matches!(thread_clone.read(cx).status(), ThreadStatus::Generating); - this.child( - list(self.list_state.clone()) - .with_sizing_behavior(gpui::ListSizingBehavior::Auto) - .flex_grow() - .into_any(), + list( + self.list_state.clone(), + cx.processor(|this, index: usize, window, cx| { + let Some((entry, len)) = this.thread().and_then(|thread| { + let entries = &thread.read(cx).entries(); + Some((entries.get(index)?, entries.len())) + }) else { + return Empty.into_any(); + }; + this.render_entry(index, len, entry, window, cx) + }), + ) + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .flex_grow() + .into_any(), ) - .when(!is_generating, |this| { - this.child(self.render_thread_controls(cx)) - }) + .child(self.render_vertical_scrollbar(cx)) .children(match thread_clone.read(cx).status() { ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => { None @@ -2701,6 +2887,18 @@ fn plan_label_markdown_style( } } +fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { + TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + } +} + #[cfg(test)] mod tests { use agent_client_protocol::SessionId; @@ -2708,11 +2906,25 @@ mod tests { use fs::FakeFs; use futures::future::try_join_all; use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; + use lsp::{CompletionContext, CompletionTriggerKind}; + use project::CompletionIntent; use rand::Rng; + use serde_json::json; use settings::SettingsStore; + use util::path; use super::*; + #[gpui::test] + async fn test_drop(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, _cx) = setup_thread_view(StubAgentServer::default(), cx).await; + let weak_view = thread_view.downgrade(); + drop(thread_view); + assert!(!weak_view.is_upgradable()); + } + #[gpui::test] async fn test_notification_for_stop_event(cx: &mut TestAppContext) { init_test(cx); @@ -2779,6 +2991,7 @@ mod tests { content: vec!["hi".into()], locations: vec![], raw_input: None, + raw_output: None, }; let 
connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) .with_permission_requests(HashMap::from_iter([( @@ -2811,6 +3024,109 @@ mod tests { ); } + #[gpui::test] + async fn test_crease_removal(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({"file": ""})).await; + let project = Project::test(fs, [Path::new(path!("/project"))], cx).await; + let agent = StubAgentServer::default(); + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(agent), + workspace.downgrade(), + project, + Rc::new(RefCell::new(MessageHistory::default())), + 1, + None, + window, + cx, + ) + }) + }); + + cx.run_until_parked(); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + let excerpt_id = message_editor.update(cx, |editor, cx| { + editor + .buffer() + .read(cx) + .excerpt_ids() + .into_iter() + .next() + .unwrap() + }); + let completions = message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello @", window, cx); + let buffer = editor.buffer().read(cx).as_singleton().unwrap(); + let completion_provider = editor.completion_provider().unwrap(); + completion_provider.completions( + excerpt_id, + &buffer, + Anchor::MAX, + CompletionContext { + trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER, + trigger_character: Some("@".into()), + }, + window, + cx, + ) + }); + let [_, completion]: [_; 2] = completions + .await + .unwrap() + .into_iter() + .flat_map(|response| response.completions) + .collect::>() + .try_into() + .unwrap(); + + message_editor.update_in(cx, |editor, window, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let start = snapshot + .anchor_in_excerpt(excerpt_id, completion.replace_range.start) + .unwrap(); + let end = snapshot + .anchor_in_excerpt(excerpt_id, completion.replace_range.end) + .unwrap(); + editor.edit([(start..end, completion.new_text)], cx); + (completion.confirm.unwrap())(CompletionIntent::Complete, window, cx); + }); + + cx.run_until_parked(); + + // Backspace over the inserted crease (and the following space). + message_editor.update_in(cx, |editor, window, cx| { + editor.backspace(&Default::default(), window, cx); + editor.backspace(&Default::default(), window, cx); + }); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + let content = thread_view.update_in(cx, |thread_view, _window, _cx| { + thread_view + .message_history + .borrow() + .items() + .iter() + .flatten() + .cloned() + .collect::>() + }); + + // We don't send a resource link for the deleted crease. + pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. 
}]); + } + async fn setup_thread_view( agent: impl AgentServer + 'static, cx: &mut TestAppContext, @@ -2943,7 +3259,11 @@ mod tests { unimplemented!() } - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + fn prompt( + &self, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { let sessions = self.sessions.lock(); let thread = sessions.get(¶ms.session_id).unwrap(); let mut tasks = vec![]; @@ -2960,7 +3280,7 @@ mod tests { let task = cx.spawn(async move |cx| { if let Some((tool_call, options)) = permission_request { let permission = thread.update(cx, |thread, cx| { - thread.request_tool_call_permission( + thread.request_tool_call_authorization( tool_call.clone(), options.clone(), cx, @@ -2977,7 +3297,9 @@ mod tests { } cx.spawn(async move |_| { try_join_all(tasks).await?; - Ok(()) + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) }) } @@ -3021,7 +3343,11 @@ mod tests { unimplemented!() } - fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task> { + fn prompt( + &self, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { Task::ready(Err(anyhow::anyhow!("Error prompting"))) } diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 04a093c7d0..71526c8fe1 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -69,8 +69,6 @@ pub struct ActiveThread { messages: Vec, list_state: ListState, scrollbar_state: ScrollbarState, - show_scrollbar: bool, - hide_scrollbar_task: Option>, rendered_messages_by_id: HashMap, rendered_tool_uses: HashMap, editing_message: Option<(MessageId, EditingMessageState)>, @@ -780,13 +778,7 @@ impl ActiveThread { cx.observe_global::(|_, cx| cx.notify()), ]; - let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), { - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - this.update(cx, |this, cx| this.render_message(ix, window, cx)) - .unwrap() - } - }); + let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.)); let workspace_subscription = if let Some(workspace) = workspace.upgrade() { Some(cx.observe_release(&workspace, |this, _, cx| { @@ -811,9 +803,7 @@ impl ActiveThread { expanded_thinking_segments: HashMap::default(), expanded_code_blocks: HashMap::default(), list_state: list_state.clone(), - scrollbar_state: ScrollbarState::new(list_state), - show_scrollbar: false, - hide_scrollbar_task: None, + scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), editing_message: None, last_error: None, copied_code_block_ids: HashSet::default(), @@ -1846,7 +1836,12 @@ impl ActiveThread { ))) } - fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context) -> AnyElement { + fn render_message( + &mut self, + ix: usize, + window: &mut Window, + cx: &mut Context, + ) -> AnyElement { let message_id = self.messages[ix]; let workspace = self.workspace.clone(); let thread = self.thread.read(cx); @@ -2629,7 +2624,7 @@ impl ActiveThread { h_flex() .gap_1p5() .child( - Icon::new(IconName::ToolBulb) + Icon::new(IconName::ToolThink) .size(IconSize::Small) .color(Color::Muted), ) @@ -3503,60 +3498,37 @@ impl ActiveThread { } } - fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { - if !self.show_scrollbar && !self.scrollbar_state.is_dragging() { - return None; - } - - Some( - div() - .occlude() - .id("active-thread-scrollbar") - .on_mouse_move(cx.listener(|_, _, _, cx| { - cx.notify(); - cx.stop_propagation() - })) - .on_hover(|_, 
_, cx| { + fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div>
{ + div() + .occlude() + .id("active-thread-scrollbar") + .on_mouse_move(cx.listener(|_, _, _, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { cx.stop_propagation(); - }) - .on_any_mouse_down(|_, _, cx| { - cx.stop_propagation(); - }) - .on_mouse_up( - MouseButton::Left, - cx.listener(|_, _, _, cx| { - cx.stop_propagation(); - }), - ) - .on_scroll_wheel(cx.listener(|_, _, _, cx| { - cx.notify(); - })) - .h_full() - .absolute() - .right_1() - .top_1() - .bottom_0() - .w(px(12.)) - .cursor_default() - .children(Scrollbar::vertical(self.scrollbar_state.clone())), - ) - } - - fn hide_scrollbar_later(&mut self, cx: &mut Context) { - const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); - self.hide_scrollbar_task = Some(cx.spawn(async move |thread, cx| { - cx.background_executor() - .timer(SCROLLBAR_SHOW_INTERVAL) - .await; - thread - .update(cx, |thread, cx| { - if !thread.scrollbar_state.is_dragging() { - thread.show_scrollbar = false; - cx.notify(); - } - }) - .log_err(); - })) + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) } pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool { @@ -3597,26 +3569,8 @@ impl Render for ActiveThread { .size_full() .relative() .bg(cx.theme().colors().panel_background) - .on_mouse_move(cx.listener(|this, _, _, cx| { - this.show_scrollbar = true; - this.hide_scrollbar_later(cx); - cx.notify(); - })) - .on_scroll_wheel(cx.listener(|this, _, _, cx| { - this.show_scrollbar = true; - this.hide_scrollbar_later(cx); - cx.notify(); - })) - .on_mouse_up( - MouseButton::Left, - cx.listener(|this, _, _, cx| { - this.hide_scrollbar_later(cx); - }), - ) - .child(list(self.list_state.clone()).flex_grow()) - .when_some(self.render_vertical_scrollbar(cx), |this, scrollbar| { - this.child(scrollbar) - }) + .child(list(self.list_state.clone(), cx.processor(Self::render_message)).flex_grow()) + .child(self.render_vertical_scrollbar(cx)) } } diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index c4dc359093..e1ceaf761d 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1523,7 +1523,8 @@ impl AgentDiff { } AcpThreadEvent::Stopped | AcpThreadEvent::ToolAuthorizationRequired - | AcpThreadEvent::Error => {} + | AcpThreadEvent::Error + | AcpThreadEvent::ServerExited(_) => {} } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 5f3315f69a..6b8e36066b 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -970,13 +970,7 @@ impl AgentPanel { ) }); - this.set_active_view( - ActiveView::ExternalAgentThread { - thread_view: thread_view.clone(), - }, - window, - cx, - ); + this.set_active_view(ActiveView::ExternalAgentThread { thread_view }, window, cx); }) }) .detach_and_log_err(cx); @@ -1987,6 +1981,22 @@ impl AgentPanel { ); }), ) + .item( + ContextMenuEntry::new("New Native Agent Thread") + .icon(IconName::ZedAssistant) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::NativeAgent, + ), + } + .boxed_clone(), 
+ cx, + ); + }), + ) }); menu })) @@ -2333,6 +2343,16 @@ impl AgentPanel { ) } + fn render_backdrop(&self, cx: &mut Context) -> impl IntoElement { + div() + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().panel_background) + .opacity(0.8) + .block_mouse_except_scroll() + } + fn render_trial_end_upsell( &self, _window: &mut Window, @@ -2342,15 +2362,24 @@ impl AgentPanel { return None; } - Some(EndTrialUpsell::new(Arc::new({ - let this = cx.entity(); - move |_, cx| { - this.update(cx, |_this, cx| { - TrialEndUpsell::set_dismissed(true, cx); - cx.notify(); - }); - } - }))) + Some( + v_flex() + .absolute() + .inset_0() + .size_full() + .bg(cx.theme().colors().panel_background) + .opacity(0.85) + .block_mouse_except_scroll() + .child(EndTrialUpsell::new(Arc::new({ + let this = cx.entity(); + move |_, cx| { + this.update(cx, |_this, cx| { + TrialEndUpsell::set_dismissed(true, cx); + cx.notify(); + }); + } + }))), + ) } fn render_empty_state_section_header( @@ -2649,6 +2678,31 @@ impl AgentPanel { ) }, ), + ) + .child( + NewThreadButton::new( + "new-native-agent-thread-btn", + "New Native Agent Thread", + IconName::ZedAssistant, + ) + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::NativeAgent, + ), + }), + cx, + ) + }, + ), ), ) }), @@ -3175,9 +3229,10 @@ impl Render for AgentPanel { // - Scrolling in all views works as expected // - Files can be dropped into the panel let content = v_flex() - .key_context(self.key_context()) - .justify_between() + .relative() .size_full() + .justify_between() + .key_context(self.key_context()) .on_action(cx.listener(Self::cancel)) .on_action(cx.listener(|this, action: &NewThread, window, cx| { this.new_thread(action, window, cx); @@ -3220,14 +3275,12 @@ impl Render for AgentPanel { .on_action(cx.listener(Self::toggle_burn_mode)) .child(self.render_toolbar(window, cx)) .children(self.render_onboarding(window, cx)) - .children(self.render_trial_end_upsell(window, cx)) .map(|parent| match &self.active_view { ActiveView::Thread { thread, message_editor, .. } => parent - .relative() .child( if thread.read(cx).is_empty() && !self.should_render_onboarding(cx) { self.render_thread_empty_state(window, cx) @@ -3264,21 +3317,10 @@ impl Render for AgentPanel { }) .child(h_flex().relative().child(message_editor.clone()).when( !LanguageModelRegistry::read_global(cx).has_authenticated_provider(cx), - |this| { - this.child( - div() - .size_full() - .absolute() - .inset_0() - .bg(cx.theme().colors().panel_background) - .opacity(0.8) - .block_mouse_except_scroll(), - ) - }, + |this| this.child(self.render_backdrop(cx)), )) .child(self.render_drag_target(cx)), ActiveView::ExternalAgentThread { thread_view, .. 
} => parent - .relative() .child(thread_view.clone()) .child(self.render_drag_target(cx)), ActiveView::History => parent.child(self.history.clone()), @@ -3317,7 +3359,8 @@ impl Render for AgentPanel { )) } ActiveView::Configuration => parent.children(self.configuration.clone()), - }); + }) + .children(self.render_trial_end_upsell(window, cx)); match self.active_view.which_font_size_used() { WhichFontSize::AgentFont => { diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 30faf5ef2e..fceb8f4c45 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -151,6 +151,7 @@ enum ExternalAgent { #[default] Gemini, ClaudeCode, + NativeAgent, } impl ExternalAgent { @@ -158,6 +159,7 @@ impl ExternalAgent { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), + ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer), } } } diff --git a/crates/agent_ui/src/context_strip.rs b/crates/agent_ui/src/context_strip.rs index 080ffd2ea0..369964f165 100644 --- a/crates/agent_ui/src/context_strip.rs +++ b/crates/agent_ui/src/context_strip.rs @@ -504,7 +504,7 @@ impl Render for ContextStrip { ) .on_click({ Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| { - if event.down.click_count > 1 { + if event.click_count() > 1 { this.open_context(&context, window, cx); } else { this.focused_index = Some(i); diff --git a/crates/agent_ui/src/ui/end_trial_upsell.rs b/crates/agent_ui/src/ui/end_trial_upsell.rs index 36770c2197..3a8a119800 100644 --- a/crates/agent_ui/src/ui/end_trial_upsell.rs +++ b/crates/agent_ui/src/ui/end_trial_upsell.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use ai_onboarding::{AgentPanelOnboardingCard, BulletItem}; +use ai_onboarding::{AgentPanelOnboardingCard, PlanDefinitions}; use client::zed_urls; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; -use ui::{Divider, List, Tooltip, prelude::*}; +use ui::{Divider, Tooltip, prelude::*}; #[derive(IntoElement, RegisterComponent)] pub struct EndTrialUpsell { @@ -18,6 +18,8 @@ impl EndTrialUpsell { impl RenderOnce for EndTrialUpsell { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let plan_definitions = PlanDefinitions; + let pro_section = v_flex() .gap_1() .child( @@ -31,13 +33,7 @@ impl RenderOnce for EndTrialUpsell { ) .child(Divider::horizontal()), ) - .child( - List::new() - .child(BulletItem::new("500 prompts with Claude models")) - .child(BulletItem::new( - "Unlimited edit predictions with Zeta, our open-source model", - )), - ) + .child(plan_definitions.pro_plan(false)) .child( Button::new("cta-button", "Upgrade to Zed Pro") .full_width() @@ -68,11 +64,7 @@ impl RenderOnce for EndTrialUpsell { ) .child(Divider::horizontal()), ) - .child( - List::new() - .child(BulletItem::new("50 prompts with the Claude models")) - .child(BulletItem::new("2,000 accepted edit predictions")), - ); + .child(plan_definitions.free_plan()); AgentPanelOnboardingCard::new() .child(Headline::new("Your Zed Pro Trial has expired")) @@ -102,18 +94,20 @@ impl RenderOnce for EndTrialUpsell { impl Component for EndTrialUpsell { fn scope() -> ComponentScope { - ComponentScope::Agent + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "End of Trial Upsell Banner" } fn sort_name() -> &'static str { - "AgentEndTrialUpsell" + "End of Trial Upsell Banner" } fn preview(_window: &mut Window, _cx: &mut App) -> Option { Some( v_flex() - .p_4() - .gap_4() .child(EndTrialUpsell 
{ dismiss_upsell: Arc::new(|_, _| {}), }) diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index e86568fe7a..b55ad4c895 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -1,8 +1,6 @@ use gpui::{Action, IntoElement, ParentElement, RenderOnce, point}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; -use ui::{Divider, List, prelude::*}; - -use crate::BulletItem; +use ui::{Divider, List, ListBulletItem, prelude::*}; pub struct ApiKeysWithProviders { configured_providers: Vec<(IconName, SharedString)>, @@ -128,7 +126,7 @@ impl RenderOnce for ApiKeysWithoutProviders { ) .child(Divider::horizontal()), ) - .child(List::new().child(BulletItem::new( + .child(List::new().child(ListBulletItem::new( "Add your own keys to use AI without signing in.", ))) .child( diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs index c252b65f20..b9a1e49a4a 100644 --- a/crates/ai_onboarding/src/ai_onboarding.rs +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -3,6 +3,7 @@ mod agent_panel_onboarding_card; mod agent_panel_onboarding_content; mod ai_upsell_card; mod edit_prediction_onboarding_content; +mod plan_definitions; mod young_account_banner; pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders}; @@ -11,51 +12,14 @@ pub use agent_panel_onboarding_content::AgentPanelOnboarding; pub use ai_upsell_card::AiUpsellCard; use cloud_llm_client::Plan; pub use edit_prediction_onboarding_content::EditPredictionOnboarding; +pub use plan_definitions::PlanDefinitions; pub use young_account_banner::YoungAccountBanner; use std::sync::Arc; use client::{Client, UserStore, zed_urls}; -use gpui::{AnyElement, Entity, IntoElement, ParentElement, SharedString}; -use ui::{Divider, List, ListItem, RegisterComponent, TintColor, Tooltip, prelude::*}; - -#[derive(IntoElement)] -pub struct BulletItem { - label: SharedString, -} - -impl BulletItem { - pub fn new(label: impl Into) -> Self { - Self { - label: label.into(), - } - } -} - -impl RenderOnce for BulletItem { - fn render(self, window: &mut Window, _cx: &mut App) -> impl IntoElement { - let line_height = 0.85 * window.line_height(); - - ListItem::new("list-item") - .selectable(false) - .child( - h_flex() - .w_full() - .min_w_0() - .gap_1() - .items_start() - .child( - h_flex().h(line_height).justify_center().child( - Icon::new(IconName::Dash) - .size(IconSize::XSmall) - .color(Color::Hidden), - ), - ) - .child(div().w_full().min_w_0().child(Label::new(self.label))), - ) - .into_any_element() - } -} +use gpui::{AnyElement, Entity, IntoElement, ParentElement}; +use ui::{Divider, RegisterComponent, TintColor, Tooltip, prelude::*}; #[derive(PartialEq)] pub enum SignInStatus { @@ -130,107 +94,6 @@ impl ZedAiOnboarding { self } - fn free_plan_definition(&self, cx: &mut App) -> impl IntoElement { - v_flex() - .mt_2() - .gap_1() - .child( - h_flex() - .gap_2() - .child( - Label::new("Free") - .size(LabelSize::Small) - .color(Color::Muted) - .buffer_font(cx), - ) - .child( - Label::new("(Current Plan)") - .size(LabelSize::Small) - .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6))) - .buffer_font(cx), - ) - .child(Divider::horizontal()), - ) - .child( - List::new() - .child(BulletItem::new("50 prompts per month with Claude models")) - .child(BulletItem::new( - "2,000 accepted edit predictions with Zeta, our open-source model", - )), - 
) - } - - fn pro_trial_definition(&self) -> impl IntoElement { - List::new() - .child(BulletItem::new("150 prompts with Claude models")) - .child(BulletItem::new( - "Unlimited accepted edit predictions with Zeta, our open-source model", - )) - } - - fn pro_plan_definition(&self, cx: &mut App) -> impl IntoElement { - v_flex().mt_2().gap_1().map(|this| { - if self.account_too_young { - this.child( - h_flex() - .gap_2() - .child( - Label::new("Pro") - .size(LabelSize::Small) - .color(Color::Accent) - .buffer_font(cx), - ) - .child(Divider::horizontal()), - ) - .child( - List::new() - .child(BulletItem::new("500 prompts per month with Claude models")) - .child(BulletItem::new( - "Unlimited accepted edit predictions with Zeta, our open-source model", - )) - .child(BulletItem::new("$20 USD per month")), - ) - .child( - Button::new("pro", "Get Started") - .full_width() - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .on_click(move |_, _window, cx| { - telemetry::event!("Upgrade To Pro Clicked", state = "young-account"); - cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) - }), - ) - } else { - this.child( - h_flex() - .gap_2() - .child( - Label::new("Pro Trial") - .size(LabelSize::Small) - .color(Color::Accent) - .buffer_font(cx), - ) - .child(Divider::horizontal()), - ) - .child( - List::new() - .child(self.pro_trial_definition()) - .child(BulletItem::new( - "Try it out for 14 days for free, no credit card required", - )), - ) - .child( - Button::new("pro", "Start Free Trial") - .full_width() - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .on_click(move |_, _window, cx| { - telemetry::event!("Start Trial Clicked", state = "post-sign-in"); - cx.open_url(&zed_urls::start_trial_url(cx)) - }), - ) - } - }) - } - fn render_accept_terms_of_service(&self) -> AnyElement { v_flex() .gap_1() @@ -269,6 +132,7 @@ impl ZedAiOnboarding { fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement { let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn); + let plan_definitions = PlanDefinitions; v_flex() .gap_1() @@ -278,7 +142,7 @@ impl ZedAiOnboarding { .color(Color::Muted) .mb_2(), ) - .child(self.pro_trial_definition()) + .child(plan_definitions.pro_plan(false)) .child( Button::new("sign_in", "Try Zed Pro for Free") .disabled(signing_in) @@ -297,43 +161,132 @@ impl ZedAiOnboarding { fn render_free_plan_state(&self, cx: &mut App) -> AnyElement { let young_account_banner = YoungAccountBanner; + let plan_definitions = PlanDefinitions; - v_flex() - .relative() - .gap_1() - .child(Headline::new("Welcome to Zed AI")) - .map(|this| { - if self.account_too_young { - this.child(young_account_banner) - } else { - this.child(self.free_plan_definition(cx)).when_some( - self.dismiss_onboarding.as_ref(), - |this, dismiss_callback| { - let callback = dismiss_callback.clone(); + if self.account_too_young { + v_flex() + .relative() + .max_w_full() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .child(young_account_banner) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_plan(true)) + .child( + Button::new("pro", "Get Started") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Upgrade To Pro Clicked", + state = "young-account" + ); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ), + ) + 
.into_any_element() + } else { + v_flex() + .relative() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Free") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child( + Label::new("(Current Plan)") + .size(LabelSize::Small) + .color(Color::Custom( + cx.theme().colors().text_muted.opacity(0.6), + )) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.free_plan()), + ) + .when_some( + self.dismiss_onboarding.as_ref(), + |this, dismiss_callback| { + let callback = dismiss_callback.clone(); - this.child( - h_flex().absolute().top_0().right_0().child( - IconButton::new("dismiss_onboarding", IconName::Close) - .icon_size(IconSize::Small) - .tooltip(Tooltip::text("Dismiss")) - .on_click(move |_, window, cx| { - telemetry::event!( - "Banner Dismissed", - source = "AI Onboarding", - ); - callback(window, cx) - }), - ), - ) - }, - ) - } - }) - .child(self.pro_plan_definition(cx)) - .into_any_element() + this.child( + h_flex().absolute().top_0().right_0().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click(move |_, window, cx| { + telemetry::event!( + "Banner Dismissed", + source = "AI Onboarding", + ); + callback(window, cx) + }), + ), + ) + }, + ) + .child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro Trial") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_trial(true)) + .child( + Button::new("pro", "Start Free Trial") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Start Trial Clicked", + state = "post-sign-in" + ); + cx.open_url(&zed_urls::start_trial_url(cx)) + }), + ), + ) + .into_any_element() + } } fn render_trial_state(&self, _cx: &mut App) -> AnyElement { + let plan_definitions = PlanDefinitions; + v_flex() .relative() .gap_1() @@ -343,13 +296,7 @@ impl ZedAiOnboarding { .color(Color::Muted) .mb_2(), ) - .child( - List::new() - .child(BulletItem::new("150 prompts with Claude models")) - .child(BulletItem::new( - "Unlimited edit predictions with Zeta, our open-source model", - )), - ) + .child(plan_definitions.pro_trial(false)) .when_some( self.dismiss_onboarding.as_ref(), |this, dismiss_callback| { @@ -374,6 +321,8 @@ impl ZedAiOnboarding { } fn render_pro_plan_state(&self, _cx: &mut App) -> AnyElement { + let plan_definitions = PlanDefinitions; + v_flex() .gap_1() .child(Headline::new("Welcome to Zed Pro")) @@ -382,13 +331,7 @@ impl ZedAiOnboarding { .color(Color::Muted) .mb_2(), ) - .child( - List::new() - .child(BulletItem::new("500 prompts with Claude models")) - .child(BulletItem::new( - "Unlimited edit predictions with Zeta, our open-source model", - )), - ) + .child(plan_definitions.pro_plan(false)) .child( Button::new("pro", "Continue with Zed Pro") .full_width() @@ -425,7 +368,15 @@ impl RenderOnce for ZedAiOnboarding { impl Component for ZedAiOnboarding { fn scope() -> ComponentScope { - ComponentScope::Agent + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "Agent Panel Banners" + } + + fn sort_name() -> &'static str { + "Agent Panel Banners" } fn preview(_window: &mut Window, _cx: &mut App) -> Option { @@ -450,8 +401,9 @@ impl Component for ZedAiOnboarding { Some( v_flex() - .p_4() .gap_4() 
+ .items_center() + .max_w_4_5() .children(vec![ single_example( "Not Signed-in", @@ -462,8 +414,8 @@ impl Component for ZedAiOnboarding { onboarding(SignInStatus::SignedIn, false, None, false), ), single_example( - "Account too young", - onboarding(SignInStatus::SignedIn, false, None, true), + "Young Account", + onboarding(SignInStatus::SignedIn, true, None, true), ), single_example( "Free Plan", diff --git a/crates/ai_onboarding/src/ai_upsell_card.rs b/crates/ai_onboarding/src/ai_upsell_card.rs index 89a782a7c2..e9639ca075 100644 --- a/crates/ai_onboarding/src/ai_upsell_card.rs +++ b/crates/ai_onboarding/src/ai_upsell_card.rs @@ -1,23 +1,33 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; -use client::{Client, zed_urls}; +use client::{Client, UserStore, zed_urls}; use cloud_llm_client::Plan; -use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; -use ui::{Divider, List, Vector, VectorName, prelude::*}; +use gpui::{ + Animation, AnimationExt, AnyElement, App, Entity, IntoElement, RenderOnce, Transformation, + Window, percentage, +}; +use ui::{Divider, Vector, VectorName, prelude::*}; -use crate::{BulletItem, SignInStatus}; +use crate::{SignInStatus, YoungAccountBanner, plan_definitions::PlanDefinitions}; #[derive(IntoElement, RegisterComponent)] pub struct AiUpsellCard { pub sign_in_status: SignInStatus, pub sign_in: Arc, + pub account_too_young: bool, pub user_plan: Option, pub tab_index: Option, } impl AiUpsellCard { - pub fn new(client: Arc, user_plan: Option) -> Self { + pub fn new( + client: Arc, + user_store: &Entity, + user_plan: Option, + cx: &mut App, + ) -> Self { let status = *client.status().borrow(); + let store = user_store.read(cx); Self { user_plan, @@ -29,6 +39,7 @@ impl AiUpsellCard { }) .detach_and_log_err(cx); }), + account_too_young: store.account_too_young(), tab_index: None, } } @@ -36,6 +47,9 @@ impl AiUpsellCard { impl RenderOnce for AiUpsellCard { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let plan_definitions = PlanDefinitions; + let young_account_banner = YoungAccountBanner; + let pro_section = v_flex() .flex_grow() .w_full() @@ -51,13 +65,7 @@ impl RenderOnce for AiUpsellCard { ) .child(Divider::horizontal()), ) - .child( - List::new() - .child(BulletItem::new("500 prompts with Claude models")) - .child(BulletItem::new( - "Unlimited edit predictions with Zeta, our open-source model", - )), - ); + .child(plan_definitions.pro_plan(false)); let free_section = v_flex() .flex_grow() @@ -74,11 +82,7 @@ impl RenderOnce for AiUpsellCard { ) .child(Divider::horizontal()), ) - .child( - List::new() - .child(BulletItem::new("50 prompts with Claude models")) - .child(BulletItem::new("2,000 accepted edit predictions")), - ); + .child(plan_definitions.free_plan()); let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child( Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.)) @@ -101,44 +105,11 @@ impl RenderOnce for AiUpsellCard { ), )); - const DESCRIPTION: &str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI."; + let description = PlanDefinitions::AI_DESCRIPTION; - let footer_buttons = match self.sign_in_status { - SignInStatus::SignedIn => v_flex() - .items_center() - .gap_1() - .child( - Button::new("sign_in", "Start 14-day Free Pro Trial") - .full_width() - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .on_click(move |_, _window, cx| { - telemetry::event!("Start Trial Clicked", state = "post-sign-in"); - 
cx.open_url(&zed_urls::start_trial_url(cx)) - }) - .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)), - ) - .child( - Label::new("No credit card required") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .into_any_element(), - _ => Button::new("sign_in", "Sign In") - .full_width() - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)) - .on_click({ - let callback = self.sign_in.clone(); - move |_, window, cx| { - telemetry::event!("Start Trial Clicked", state = "pre-sign-in"); - callback(window, cx) - } - }) - .into_any_element(), - }; - - v_flex() + let card = v_flex() .relative() + .flex_grow() .p_4() .pt_3() .border_1() @@ -146,31 +117,169 @@ impl RenderOnce for AiUpsellCard { .rounded_lg() .overflow_hidden() .child(grid_bg) - .child(gradient_bg) - .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .child(gradient_bg); + + let plans_section = h_flex() + .w_full() + .mt_1p5() + .mb_2p5() + .items_start() + .gap_6() + .child(free_section) + .child(pro_section); + + let footer_container = v_flex().items_center().gap_1(); + + let certified_user_stamp = div() + .absolute() + .top_2() + .right_2() + .size(rems_from_px(72.)) .child( - div() - .max_w_3_4() - .mb_2() - .child(Label::new(DESCRIPTION).color(Color::Muted)), - ) + Vector::new( + VectorName::ProUserStamp, + rems_from_px(72.), + rems_from_px(72.), + ) + .color(Color::Custom(cx.theme().colors().text_accent.alpha(0.3))) + .with_animation( + "loading_stamp", + Animation::new(Duration::from_secs(10)).repeat(), + |this, delta| this.transform(Transformation::rotate(percentage(delta))), + ), + ); + + let pro_trial_stamp = div() + .absolute() + .top_2() + .right_2() + .size(rems_from_px(72.)) .child( - h_flex() - .w_full() - .mt_1p5() - .mb_2p5() - .items_start() - .gap_6() - .child(free_section) - .child(pro_section), - ) - .child(footer_buttons) + Vector::new( + VectorName::ProTrialStamp, + rems_from_px(72.), + rems_from_px(72.), + ) + .color(Color::Custom(cx.theme().colors().text.alpha(0.2))), + ); + + match self.sign_in_status { + SignInStatus::SignedIn => match self.user_plan { + None | Some(Plan::ZedFree) => card + .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .map(|this| { + if self.account_too_young { + this.child(young_account_banner).child( + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(plan_definitions.pro_plan(true)) + .child( + Button::new("pro", "Get Started") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Upgrade To Pro Clicked", + state = "young-account" + ); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ), + ) + } else { + this.child( + div() + .max_w_3_4() + .mb_2() + .child(Label::new(description).color(Color::Muted)), + ) + .child(plans_section) + .child( + footer_container + .child( + Button::new("start_trial", "Start 14-day Free Pro Trial") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .when_some(self.tab_index, |this, tab_index| { + this.tab_index(tab_index) + }) + .on_click(move |_, _window, cx| { + telemetry::event!( + "Start Trial Clicked", + state = "post-sign-in" + ); + cx.open_url(&zed_urls::start_trial_url(cx)) + }), + ) + .child( + Label::new("No credit card required") + .size(LabelSize::Small) + 
.color(Color::Muted), + ), + ) + } + }), + Some(Plan::ZedProTrial) => card + .child(pro_trial_stamp) + .child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large)) + .child( + Label::new("Here's what you get for the next 14 days:") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_trial(false)), + Some(Plan::ZedPro) => card + .child(certified_user_stamp) + .child(Label::new("You're in the Zed Pro plan").size(LabelSize::Large)) + .child( + Label::new("Here's what you get:") + .color(Color::Muted) + .mb_2(), + ) + .child(plan_definitions.pro_plan(false)), + }, + // Signed Out State + _ => card + .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .child( + div() + .max_w_3_4() + .mb_2() + .child(Label::new(description).color(Color::Muted)), + ) + .child(plans_section) + .child( + Button::new("sign_in", "Sign In") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)) + .on_click({ + let callback = self.sign_in.clone(); + move |_, window, cx| { + telemetry::event!("Start Trial Clicked", state = "pre-sign-in"); + callback(window, cx) + } + }), + ), + } } } impl Component for AiUpsellCard { fn scope() -> ComponentScope { - ComponentScope::Agent + ComponentScope::Onboarding } fn name() -> &'static str { @@ -188,30 +297,69 @@ impl Component for AiUpsellCard { fn preview(_window: &mut Window, _cx: &mut App) -> Option { Some( v_flex() - .p_4() .gap_4() - .children(vec![example_group(vec![ - single_example( - "Signed Out State", - AiUpsellCard { - sign_in_status: SignInStatus::SignedOut, - sign_in: Arc::new(|_, _| {}), - user_plan: None, - tab_index: Some(0), - } - .into_any_element(), - ), - single_example( - "Signed In State", - AiUpsellCard { - sign_in_status: SignInStatus::SignedIn, - sign_in: Arc::new(|_, _| {}), - user_plan: None, - tab_index: Some(1), - } - .into_any_element(), - ), - ])]) + .items_center() + .max_w_4_5() + .child(single_example( + "Signed Out State", + AiUpsellCard { + sign_in_status: SignInStatus::SignedOut, + sign_in: Arc::new(|_, _| {}), + account_too_young: false, + user_plan: None, + tab_index: Some(0), + } + .into_any_element(), + )) + .child(example_group_with_title( + "Signed In States", + vec![ + single_example( + "Free Plan", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + account_too_young: false, + user_plan: Some(Plan::ZedFree), + tab_index: Some(1), + } + .into_any_element(), + ), + single_example( + "Free Plan but Young Account", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + account_too_young: true, + user_plan: Some(Plan::ZedFree), + tab_index: Some(1), + } + .into_any_element(), + ), + single_example( + "Pro Trial", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + account_too_young: false, + user_plan: Some(Plan::ZedProTrial), + tab_index: Some(1), + } + .into_any_element(), + ), + single_example( + "Pro Plan", + AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + account_too_young: false, + user_plan: Some(Plan::ZedPro), + tab_index: Some(1), + } + .into_any_element(), + ), + ], + )) .into_any_element(), ) } diff --git a/crates/ai_onboarding/src/plan_definitions.rs b/crates/ai_onboarding/src/plan_definitions.rs new file mode 100644 index 0000000000..8d66f6c356 --- /dev/null +++ b/crates/ai_onboarding/src/plan_definitions.rs @@ -0,0 +1,39 @@ +use 
gpui::{IntoElement, ParentElement}; +use ui::{List, ListBulletItem, prelude::*}; + +/// Centralized definitions for Zed AI plans +pub struct PlanDefinitions; + +impl PlanDefinitions { + pub const AI_DESCRIPTION: &'static str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI."; + + pub fn free_plan(&self) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("50 prompts with Claude models")) + .child(ListBulletItem::new("2,000 accepted edit predictions")) + } + + pub fn pro_trial(&self, period: bool) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("150 prompts with Claude models")) + .child(ListBulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )) + .when(period, |this| { + this.child(ListBulletItem::new( + "Try it out for 14 days for free, no credit card required", + )) + }) + } + + pub fn pro_plan(&self, price: bool) -> impl IntoElement { + List::new() + .child(ListBulletItem::new("500 prompts with Claude models")) + .child(ListBulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )) + .when(price, |this| { + this.child(ListBulletItem::new("$20 USD per month")) + }) + } +} diff --git a/crates/ai_onboarding/src/young_account_banner.rs b/crates/ai_onboarding/src/young_account_banner.rs index a43625a60e..54f563e4aa 100644 --- a/crates/ai_onboarding/src/young_account_banner.rs +++ b/crates/ai_onboarding/src/young_account_banner.rs @@ -15,6 +15,7 @@ impl RenderOnce for YoungAccountBanner { .child(YOUNG_ACCOUNT_DISCLAIMER); div() + .max_w_full() .my_1() .child(Banner::new().severity(ui::Severity::Warning).child(label)) } diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_context/src/assistant_context.rs index 4518bbff79..557f9592e4 100644 --- a/crates/assistant_context/src/assistant_context.rs +++ b/crates/assistant_context/src/assistant_context.rs @@ -2,16 +2,16 @@ mod assistant_context_tests; mod context_store; -use agent_settings::AgentSettings; +use agent_settings::{AgentSettings, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Context as _, Result, bail}; use assistant_slash_command::{ SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection, SlashCommandResult, SlashCommandWorkingSet, }; use assistant_slash_commands::FileCommandMetadata; -use client::{self, Client, proto, telemetry::Telemetry}; +use client::{self, Client, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry}; use clock::ReplicaId; -use cloud_llm_client::CompletionIntent; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use collections::{HashMap, HashSet}; use fs::{Fs, RenameOptions}; use futures::{FutureExt, StreamExt, future::Shared}; @@ -2080,7 +2080,18 @@ impl AssistantContext { }); match event { - LanguageModelCompletionEvent::StatusUpdate { .. } => {} + LanguageModelCompletionEvent::StatusUpdate(status_update) => { + match status_update { + CompletionRequestStatus::UsageUpdated { amount, limit } => { + this.update_model_request_usage( + amount as u32, + limit, + cx, + ); + } + _ => {} + } + } LanguageModelCompletionEvent::StartMessage { .. 
} => {} LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; @@ -2677,10 +2688,7 @@ impl AssistantContext { let mut request = self.to_completion_request(Some(&model.model), cx); request.messages.push(LanguageModelRequestMessage { role: Role::User, - content: vec![ - "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`" - .into(), - ], + content: vec![SUMMARIZE_THREAD_PROMPT.into()], cache: false, }); @@ -2956,6 +2964,21 @@ impl AssistantContext { summary.text = custom_summary; cx.emit(ContextEvent::SummaryChanged); } + + fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) { + let Some(project) = &self.project else { + return; + }; + project.read(cx).user_store().update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }); + } } #[derive(Debug, Default)] diff --git a/crates/assistant_context/src/assistant_context_tests.rs b/crates/assistant_context/src/assistant_context_tests.rs index f139d525d3..efcad8ed96 100644 --- a/crates/assistant_context/src/assistant_context_tests.rs +++ b/crates/assistant_context/src/assistant_context_tests.rs @@ -1210,8 +1210,8 @@ async fn test_summarization(cx: &mut TestAppContext) { }); cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief"); - fake_model.stream_last_completion_response(" Introduction"); + fake_model.send_last_completion_stream_text_chunk("Brief"); + fake_model.send_last_completion_stream_text_chunk(" Introduction"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -1274,7 +1274,7 @@ async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { }); cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary"); + fake_model.send_last_completion_stream_text_chunk("A successful summary"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -1356,7 +1356,7 @@ fn setup_context_editor_with_fake_model( fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response"); + fake_model.send_last_completion_stream_text_chunk("Assistant response"); fake_model.end_last_completion_stream(); cx.run_until_parked(); } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 57fdc51336..bf668e6918 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -2,7 +2,7 @@ mod copy_path_tool; mod create_directory_tool; mod delete_path_tool; mod diagnostics_tool; -mod edit_agent; +pub mod edit_agent; mod edit_file_tool; mod fetch_tool; mod find_path_tool; @@ -14,7 +14,7 @@ mod open_tool; mod project_notifications_tool; mod read_file_tool; mod schema; -mod templates; +pub mod templates; mod terminal_tool; mod thinking_tool; mod ui; @@ -36,13 +36,12 @@ use crate::delete_path_tool::DeletePathTool; use crate::diagnostics_tool::DiagnosticsTool; use crate::edit_file_tool::EditFileTool; use crate::fetch_tool::FetchTool; -use crate::find_path_tool::FindPathTool; use crate::list_directory_tool::ListDirectoryTool; use crate::now_tool::NowTool; use crate::thinking_tool::ThinkingTool; pub use edit_file_tool::{EditFileMode, EditFileToolInput}; -pub use find_path_tool::FindPathToolInput; +pub 
use find_path_tool::*; pub use grep_tool::{GrepTool, GrepToolInput}; pub use open_tool::OpenTool; pub use project_notifications_tool::ProjectNotificationsTool; diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index fed79434bb..dcb14a48f3 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -29,7 +29,6 @@ use serde::{Deserialize, Serialize}; use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll}; use streaming_diff::{CharOperation, StreamingDiff}; use streaming_fuzzy_matcher::StreamingFuzzyMatcher; -use util::debug_panic; #[derive(Serialize)] struct CreateFilePromptTemplate { @@ -682,11 +681,6 @@ impl EditAgent { if last_message.content.is_empty() { conversation.messages.pop(); } - } else { - debug_panic!( - "Last message must be an Assistant tool calling! Got {:?}", - last_message.content - ); } } @@ -962,7 +956,7 @@ mod tests { ); cx.run_until_parked(); - model.stream_last_completion_response("a"); + model.send_last_completion_stream_text_chunk("a"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( @@ -974,7 +968,7 @@ mod tests { None ); - model.stream_last_completion_response("bc"); + model.send_last_completion_stream_text_chunk("bc"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -996,7 +990,7 @@ mod tests { }) ); - model.stream_last_completion_response("abX"); + model.send_last_completion_stream_text_chunk("abX"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); assert_eq!( @@ -1011,7 +1005,7 @@ mod tests { }) ); - model.stream_last_completion_response("cY"); + model.send_last_completion_stream_text_chunk("cY"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); assert_eq!( @@ -1026,8 +1020,8 @@ mod tests { }) ); - model.stream_last_completion_response(""); - model.stream_last_completion_response("hall"); + model.send_last_completion_stream_text_chunk(""); + model.send_last_completion_stream_text_chunk("hall"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( @@ -1042,8 +1036,8 @@ mod tests { }) ); - model.stream_last_completion_response("ucinated old"); - model.stream_last_completion_response(""); + model.send_last_completion_stream_text_chunk("ucinated old"); + model.send_last_completion_stream_text_chunk(""); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1061,8 +1055,8 @@ mod tests { }) ); - model.stream_last_completion_response("hallucinated new"); + model.send_last_completion_stream_text_chunk("hallucinated new"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( @@ -1077,7 +1071,7 @@ mod tests { }) ); - model.stream_last_completion_response("\nghi\nj"); + model.send_last_completion_stream_text_chunk("\nghi\nj"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1099,8 +1093,8 @@ mod tests { }) ); - model.stream_last_completion_response("kl"); - model.stream_last_completion_response(""); + model.send_last_completion_stream_text_chunk("kl"); + model.send_last_completion_stream_text_chunk(""); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1122,7 +1116,7 @@ mod tests { }) ); - model.stream_last_completion_response("GHI"); + model.send_last_completion_stream_text_chunk("GHI"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1367,7 +1361,9 @@ mod tests { cx.background_spawn(async move 
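// Feed the prepared chunks back on a background task with random delays to
// mimic a streaming provider, then close the stream so the edit agent's
// pending completion resolves.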
{ for chunk in chunks { executor.simulate_random_delay().await; - model.as_fake().stream_last_completion_response(chunk); + model + .as_fake() + .send_last_completion_stream_text_chunk(chunk); } model.as_fake().end_last_completion_stream(); }) diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 1c41b26092..dce9f49abd 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -1577,7 +1577,7 @@ mod tests { // Stream the unformatted content cx.executor().run_until_parked(); - model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); model.end_last_completion_stream(); edit_task.await @@ -1641,7 +1641,7 @@ mod tests { // Stream the unformatted content cx.executor().run_until_parked(); - model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string()); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); model.end_last_completion_stream(); edit_task.await @@ -1720,7 +1720,9 @@ mod tests { // Stream the content with trailing whitespace cx.executor().run_until_parked(); - model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); model.end_last_completion_stream(); edit_task.await @@ -1777,7 +1779,9 @@ mod tests { // Stream the content with trailing whitespace cx.executor().run_until_parked(); - model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string()); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); model.end_last_completion_stream(); edit_task.await diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 58833c5208..8add60f09a 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -225,7 +225,6 @@ impl Tool for TerminalTool { env, ..Default::default() }), - window, cx, ) })? 
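The test updates above all switch from `stream_last_completion_response` to the renamed `send_last_completion_stream_text_chunk`. For reference, a small helper in the style of `simulate_successful_response` (illustrative only, not part of the diff; it assumes the same test imports used by the surrounding tests) that drives a fake completion with the renamed API:

fn stream_fake_completion(
    fake_model: &FakeLanguageModel,
    chunks: &[&str],
    cx: &mut TestAppContext,
) {
    cx.run_until_parked();
    for chunk in chunks {
        // Appends a text chunk to the most recently started pending completion stream.
        fake_model.send_last_completion_stream_text_chunk(*chunk);
    }
    // Close the stream so tasks awaiting the completion can finish.
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
}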
diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index 443c2930be..76c6e6c0ba 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -37,7 +37,7 @@ impl Tool for ThinkingTool { } fn icon(&self) -> IconName { - IconName::ToolBulb + IconName::ToolThink } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/auto_update/src/auto_update.rs b/crates/auto_update/src/auto_update.rs index d62a9cdbe3..074aaa6fea 100644 --- a/crates/auto_update/src/auto_update.rs +++ b/crates/auto_update/src/auto_update.rs @@ -134,10 +134,15 @@ impl Settings for AutoUpdateSetting { type FileContent = Option; fn load(sources: SettingsSources, _: &mut App) -> Result { - let auto_update = [sources.server, sources.release_channel, sources.user] - .into_iter() - .find_map(|value| value.copied().flatten()) - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); + let auto_update = [ + sources.server, + sources.release_channel, + sources.operating_system, + sources.user, + ] + .into_iter() + .find_map(|value| value.copied().flatten()) + .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); Ok(Self(auto_update.0)) } diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 287c62b753..8d6cd2544a 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -400,7 +400,6 @@ mod linux { os::unix::net::{SocketAddr, UnixDatagram}, path::{Path, PathBuf}, process::{self, ExitStatus}, - sync::LazyLock, thread, time::Duration, }; @@ -411,9 +410,6 @@ mod linux { use crate::{Detect, InstalledApp}; - static RELEASE_CHANNEL: LazyLock = - LazyLock::new(|| include_str!("../../zed/RELEASE_CHANNEL").trim().to_string()); - struct App(PathBuf); impl Detect { @@ -444,10 +440,10 @@ mod linux { fn zed_version_string(&self) -> String { format!( "Zed {}{}{} – {}", - if *RELEASE_CHANNEL == "stable" { + if *release_channel::RELEASE_CHANNEL_NAME == "stable" { "".to_string() } else { - format!("{} ", *RELEASE_CHANNEL) + format!("{} ", *release_channel::RELEASE_CHANNEL_NAME) }, option_env!("RELEASE_VERSION").unwrap_or_default(), match option_env!("ZED_COMMIT_SHA") { @@ -459,7 +455,10 @@ mod linux { } fn launch(&self, ipc_url: String) -> anyhow::Result<()> { - let sock_path = paths::data_dir().join(format!("zed-{}.sock", *RELEASE_CHANNEL)); + let sock_path = paths::data_dir().join(format!( + "zed-{}.sock", + *release_channel::RELEASE_CHANNEL_NAME + )); let sock = UnixDatagram::unbound()?; if sock.connect(&sock_path).is_err() { self.boot_background(ipc_url)?; diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index b4894cddcf..f09c012a85 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -14,7 +14,9 @@ use async_tungstenite::tungstenite::{ }; use clock::SystemClock; use cloud_api_client::CloudApiClient; +use cloud_api_client::websocket_protocol::MessageToClient; use credentials_provider::CredentialsProvider; +use feature_flags::FeatureFlagAppExt as _; use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, channel::oneshot, future::BoxFuture, @@ -191,6 +193,8 @@ pub fn init(client: &Arc, cx: &mut App) { }); } +pub type MessageToClientHandler = Box; + struct GlobalClient(Arc); impl Global for GlobalClient {} @@ -204,6 +208,7 @@ pub struct Client { credentials_provider: ClientCredentialsProvider, state: RwLock, handler_set: parking_lot::Mutex, + message_to_client_handlers: 
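// Callbacks registered via `add_message_to_client_handler`; each one runs for
// every `MessageToClient` frame received over the Cloud WebSocket connection
// (see `handle_message_to_client` below).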
parking_lot::Mutex>, #[allow(clippy::type_complexity)] #[cfg(any(test, feature = "test-support"))] @@ -553,6 +558,7 @@ impl Client { credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), handler_set: Default::default(), + message_to_client_handlers: parking_lot::Mutex::new(Vec::new()), #[cfg(any(test, feature = "test-support"))] authenticate: Default::default(), @@ -933,23 +939,77 @@ impl Client { } } - /// Performs a sign-in and also connects to Collab. + /// Establishes a WebSocket connection with Cloud for receiving updates from the server. + async fn connect_to_cloud(self: &Arc, cx: &AsyncApp) -> Result<()> { + let connect_task = cx.update({ + let cloud_client = self.cloud_client.clone(); + move |cx| cloud_client.connect(cx) + })??; + let connection = connect_task.await?; + + let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?; + task.detach(); + + cx.spawn({ + let this = self.clone(); + async move |cx| { + while let Some(message) = messages.next().await { + if let Some(message) = message.log_err() { + this.handle_message_to_client(message, cx); + } + } + } + }) + .detach(); + + Ok(()) + } + + /// Performs a sign-in and also (optionally) connects to Collab. /// - /// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls - /// to `sign_in` when we're ready to remove auto-connection to Collab. + /// Only Zed staff automatically connect to Collab. pub async fn sign_in_with_optional_connect( self: &Arc, try_provider: bool, cx: &AsyncApp, ) -> Result<()> { + let (is_staff_tx, is_staff_rx) = oneshot::channel::(); + let mut is_staff_tx = Some(is_staff_tx); + cx.update(|cx| { + cx.on_flags_ready(move |state, _cx| { + if let Some(is_staff_tx) = is_staff_tx.take() { + is_staff_tx.send(state.is_staff).log_err(); + } + }) + .detach(); + }) + .log_err(); + let credentials = self.sign_in(try_provider, cx).await?; - let connect_result = match self.connect_with_credentials(credentials, cx).await { - ConnectionResult::Timeout => Err(anyhow!("connection timed out")), - ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")), - ConnectionResult::Result(result) => result.context("client auth and connect"), - }; - connect_result.log_err(); + self.connect_to_cloud(cx).await.log_err(); + + cx.update(move |cx| { + cx.spawn({ + let client = self.clone(); + async move |cx| { + let is_staff = is_staff_rx.await?; + if is_staff { + match client.connect_with_credentials(credentials, cx).await { + ConnectionResult::Timeout => Err(anyhow!("connection timed out")), + ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")), + ConnectionResult::Result(result) => { + result.context("client auth and connect") + } + } + } else { + Ok(()) + } + } + }) + .detach_and_log_err(cx); + }) + .log_err(); Ok(()) } @@ -1622,6 +1682,24 @@ impl Client { } } + pub fn add_message_to_client_handler( + self: &Arc, + handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static, + ) { + self.message_to_client_handlers + .lock() + .push(Box::new(handler)); + } + + fn handle_message_to_client(self: &Arc, message: MessageToClient, cx: &AsyncApp) { + cx.update(|cx| { + for handler in self.message_to_client_handlers.lock().iter() { + handler(&message, cx); + } + }) + .ok(); + } + pub fn telemetry(&self) -> &Arc { &self.telemetry } diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 3c125a0882..9f76dd7ad0 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,6 +1,7 @@ 
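The `UserStore` wiring that follows is the first real consumer of this hook. As a reference, a minimal handler has the following shape (illustrative only; just `Client::add_message_to_client_handler` and `MessageToClient` come from this diff, and each registered handler is invoked from `handle_message_to_client` via `cx.update`):

use std::sync::Arc;

use client::Client;
use cloud_api_client::websocket_protocol::MessageToClient;

fn log_cloud_messages(client: &Arc<Client>) {
    // Runs once per message received over the Cloud WebSocket.
    client.add_message_to_client_handler(|message, _cx| match message {
        MessageToClient::UserUpdated => log::info!("cloud: authenticated user changed"),
    });
}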
use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; +use cloud_api_client::websocket_protocol::MessageToClient; use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; use cloud_llm_client::{ EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, @@ -181,6 +182,12 @@ impl UserStore { client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info), client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts), ]; + + client.add_message_to_client_handler({ + let this = cx.weak_entity(); + move |message, cx| Self::handle_message_to_client(this.clone(), message, cx) + }); + Self { users: Default::default(), by_github_login: Default::default(), @@ -813,6 +820,32 @@ impl UserStore { cx.emit(Event::PrivateUserInfoUpdated); } + fn handle_message_to_client(this: WeakEntity, message: &MessageToClient, cx: &App) { + cx.spawn(async move |cx| { + match message { + MessageToClient::UserUpdated => { + let cloud_client = cx + .update(|cx| { + this.read_with(cx, |this, _cx| { + this.client.upgrade().map(|client| client.cloud_client()) + }) + })?? + .ok_or(anyhow::anyhow!("Failed to get Cloud client"))?; + + let response = cloud_client.get_authenticated_user().await?; + cx.update(|cx| { + this.update(cx, |this, cx| { + this.update_authenticated_user(response, cx); + }) + })??; + } + } + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml index d56aa94c6e..8e50ccb191 100644 --- a/crates/cloud_api_client/Cargo.toml +++ b/crates/cloud_api_client/Cargo.toml @@ -15,7 +15,10 @@ path = "src/cloud_api_client.rs" anyhow.workspace = true cloud_api_types.workspace = true futures.workspace = true +gpui.workspace = true +gpui_tokio.workspace = true http_client.workspace = true parking_lot.workspace = true serde_json.workspace = true workspace-hack.workspace = true +yawc.workspace = true diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index edac051a0e..ef9a1a9a55 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -1,11 +1,19 @@ +mod websocket; + use std::sync::Arc; use anyhow::{Context, Result, anyhow}; +use cloud_api_types::websocket_protocol::{PROTOCOL_VERSION, PROTOCOL_VERSION_HEADER_NAME}; pub use cloud_api_types::*; use futures::AsyncReadExt as _; +use gpui::{App, Task}; +use gpui_tokio::Tokio; use http_client::http::request; use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode}; use parking_lot::RwLock; +use yawc::WebSocket; + +use crate::websocket::Connection; struct Credentials { user_id: u32, @@ -78,6 +86,41 @@ impl CloudApiClient { Ok(serde_json::from_str(&body)?) 
} + pub fn connect(&self, cx: &App) -> Result>> { + let mut connect_url = self + .http_client + .build_zed_cloud_url("/client/users/connect", &[])?; + connect_url + .set_scheme(match connect_url.scheme() { + "https" => "wss", + "http" => "ws", + scheme => Err(anyhow!("invalid URL scheme: {scheme}"))?, + }) + .map_err(|_| anyhow!("failed to set URL scheme"))?; + + let credentials = self.credentials.read(); + let credentials = credentials.as_ref().context("no credentials provided")?; + let authorization_header = format!("{} {}", credentials.user_id, credentials.access_token); + + Ok(cx.spawn(async move |cx| { + let handle = cx + .update(|cx| Tokio::handle(cx)) + .ok() + .context("failed to get Tokio handle")?; + let _guard = handle.enter(); + + let ws = WebSocket::connect(connect_url) + .with_request( + request::Builder::new() + .header("Authorization", authorization_header) + .header(PROTOCOL_VERSION_HEADER_NAME, PROTOCOL_VERSION.to_string()), + ) + .await?; + + Ok(Connection::new(ws)) + })) + } + pub async fn accept_terms_of_service(&self) -> Result { let request = self.build_request( Request::builder().method(Method::POST).uri( diff --git a/crates/cloud_api_client/src/websocket.rs b/crates/cloud_api_client/src/websocket.rs new file mode 100644 index 0000000000..48a628db78 --- /dev/null +++ b/crates/cloud_api_client/src/websocket.rs @@ -0,0 +1,73 @@ +use std::pin::Pin; +use std::time::Duration; + +use anyhow::Result; +use cloud_api_types::websocket_protocol::MessageToClient; +use futures::channel::mpsc::unbounded; +use futures::stream::{SplitSink, SplitStream}; +use futures::{FutureExt as _, SinkExt as _, Stream, StreamExt as _, TryStreamExt as _, pin_mut}; +use gpui::{App, BackgroundExecutor, Task}; +use yawc::WebSocket; +use yawc::frame::{FrameView, OpCode}; + +const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1); + +pub type MessageStream = Pin>>>; + +pub struct Connection { + tx: SplitSink, + rx: SplitStream, +} + +impl Connection { + pub fn new(ws: WebSocket) -> Self { + let (tx, rx) = ws.split(); + + Self { tx, rx } + } + + pub fn spawn(self, cx: &App) -> (MessageStream, Task<()>) { + let (mut tx, rx) = (self.tx, self.rx); + + let (message_tx, message_rx) = unbounded(); + + let handle_io = |executor: BackgroundExecutor| async move { + // Send messages on this frequency so the connection isn't closed. + let keepalive_timer = executor.timer(KEEPALIVE_INTERVAL).fuse(); + futures::pin_mut!(keepalive_timer); + + let rx = rx.fuse(); + pin_mut!(rx); + + loop { + futures::select_biased! 
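// The loop below races the keepalive timer against incoming frames: when the
// timer fires it sends a WebSocket ping and re-arms the timer; Binary frames
// are decoded as `MessageToClient` and forwarded to the channel; a Close frame
// (or a closed stream) breaks the loop.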
{ + _ = keepalive_timer => { + let _ = tx.send(FrameView::ping(Vec::new())).await; + + keepalive_timer.set(executor.timer(KEEPALIVE_INTERVAL).fuse()); + } + frame = rx.next() => { + let Some(frame) = frame else { + break; + }; + + match frame.opcode { + OpCode::Binary => { + let message_result = MessageToClient::deserialize(&frame.payload); + message_tx.unbounded_send(message_result).ok(); + } + OpCode::Close => { + break; + } + _ => {} + } + } + } + } + }; + + let task = cx.spawn(async move |cx| handle_io(cx.background_executor().clone()).await); + + (message_rx.into_stream().boxed(), task) + } +} diff --git a/crates/cloud_api_types/Cargo.toml b/crates/cloud_api_types/Cargo.toml index 868797df3b..28e0a36a44 100644 --- a/crates/cloud_api_types/Cargo.toml +++ b/crates/cloud_api_types/Cargo.toml @@ -12,7 +12,9 @@ workspace = true path = "src/cloud_api_types.rs" [dependencies] +anyhow.workspace = true chrono.workspace = true +ciborium.workspace = true cloud_llm_client.workspace = true serde.workspace = true workspace-hack.workspace = true diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs index b38b38cde1..fa189cd3b5 100644 --- a/crates/cloud_api_types/src/cloud_api_types.rs +++ b/crates/cloud_api_types/src/cloud_api_types.rs @@ -1,4 +1,5 @@ mod timestamp; +pub mod websocket_protocol; use serde::{Deserialize, Serialize}; diff --git a/crates/cloud_api_types/src/websocket_protocol.rs b/crates/cloud_api_types/src/websocket_protocol.rs new file mode 100644 index 0000000000..75f6a73b43 --- /dev/null +++ b/crates/cloud_api_types/src/websocket_protocol.rs @@ -0,0 +1,28 @@ +use anyhow::{Context as _, Result}; +use serde::{Deserialize, Serialize}; + +/// The version of the Cloud WebSocket protocol. +pub const PROTOCOL_VERSION: u32 = 0; + +/// The name of the header used to indicate the protocol version in use. +pub const PROTOCOL_VERSION_HEADER_NAME: &str = "x-zed-protocol-version"; + +/// A message from Cloud to the Zed client. +#[derive(Debug, Serialize, Deserialize)] +pub enum MessageToClient { + /// The user was updated and should be refreshed. 
+ UserUpdated, +} + +impl MessageToClient { + pub fn serialize(&self) -> Result> { + let mut buffer = Vec::new(); + ciborium::into_writer(self, &mut buffer).context("failed to serialize message")?; + + Ok(buffer) + } + + pub fn deserialize(data: &[u8]) -> Result { + ciborium::from_reader(data).context("failed to deserialize message") + } +} diff --git a/crates/collab/k8s/environments/production.sh b/crates/collab/k8s/environments/production.sh index e9e68849b8..2861f37896 100644 --- a/crates/collab/k8s/environments/production.sh +++ b/crates/collab/k8s/environments/production.sh @@ -2,5 +2,6 @@ ZED_ENVIRONMENT=production RUST_LOG=info INVITE_LINK_PREFIX=https://zed.dev/invites/ AUTO_JOIN_CHANNEL_ID=283 -DATABASE_MAX_CONNECTIONS=250 +# Set DATABASE_MAX_CONNECTIONS max connections in the `deploy_collab.yml`: +# https://github.com/zed-industries/zed/blob/main/.github/workflows/deploy_collab.yml LLM_DATABASE_MAX_CONNECTIONS=25 diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 0e15308ffe..a0325d14c4 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,485 +1,10 @@ -use anyhow::{Context as _, bail}; -use chrono::{DateTime, Utc}; -use cloud_llm_client::LanguageModelProvider; -use collections::{HashMap, HashSet}; -use sea_orm::ActiveValue; -use std::{sync::Arc, time::Duration}; -use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus}; -use util::{ResultExt, maybe}; +use std::sync::Arc; +use stripe::SubscriptionStatus; use crate::AppState; -use crate::db::billing_subscription::{ - StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, -}; -use crate::llm::db::subscription_usage_meter::{self, CompletionMode}; -use crate::rpc::{ResultExt as _, Server}; -use crate::stripe_client::{ - StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription, - StripeSubscriptionId, -}; -use crate::{db::UserId, llm::db::LlmDatabase}; -use crate::{ - db::{ - CreateBillingCustomerParams, CreateBillingSubscriptionParams, - CreateProcessedStripeEventParams, UpdateBillingCustomerParams, - UpdateBillingSubscriptionParams, billing_customer, - }, - stripe_billing::StripeBilling, -}; - -/// The amount of time we wait in between each poll of Stripe events. -/// -/// This value should strike a balance between: -/// 1. Being short enough that we update quickly when something in Stripe changes -/// 2. Being long enough that we don't eat into our rate limits. -/// -/// As a point of reference, the Sequin folks say they have this at **500ms**: -/// -/// > We poll the Stripe /events endpoint every 500ms per account -/// > -/// > — https://blog.sequinstream.com/events-not-webhooks/ -const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5); - -/// The maximum number of events to return per page. -/// -/// We set this to 100 (the max) so we have to make fewer requests to Stripe. -/// -/// > Limit can range between 1 and 100, and the default is 10. -const EVENTS_LIMIT_PER_PAGE: u64 = 100; - -/// The number of pages consisting entirely of already-processed events that we -/// will see before we stop retrieving events. -/// -/// This is used to prevent over-fetching the Stripe events API for events we've -/// already seen and processed. -const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4; - -/// Polls the Stripe events API periodically to reconcile the records in our -/// database with the data in Stripe. 
-pub fn poll_stripe_events_periodically(app: Arc, rpc_server: Arc) { - let Some(real_stripe_client) = app.real_stripe_client.clone() else { - log::warn!("failed to retrieve Stripe client"); - return; - }; - let Some(stripe_client) = app.stripe_client.clone() else { - log::warn!("failed to retrieve Stripe client"); - return; - }; - - let executor = app.executor.clone(); - executor.spawn_detached({ - let executor = executor.clone(); - async move { - loop { - poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client) - .await - .log_err(); - - executor.sleep(POLL_EVENTS_INTERVAL).await; - } - } - }); -} - -async fn poll_stripe_events( - app: &Arc, - rpc_server: &Arc, - stripe_client: &Arc, - real_stripe_client: &stripe::Client, -) -> anyhow::Result<()> { - let feature_flags = app.db.list_feature_flags().await?; - let sync_events_using_cloud = feature_flags - .iter() - .any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all); - if sync_events_using_cloud { - return Ok(()); - } - - fn event_type_to_string(event_type: EventType) -> String { - // Calling `to_string` on `stripe::EventType` members gives us a quoted string, - // so we need to unquote it. - event_type.to_string().trim_matches('"').to_string() - } - - let event_types = [ - EventType::CustomerCreated, - EventType::CustomerUpdated, - EventType::CustomerSubscriptionCreated, - EventType::CustomerSubscriptionUpdated, - EventType::CustomerSubscriptionPaused, - EventType::CustomerSubscriptionResumed, - EventType::CustomerSubscriptionDeleted, - ] - .into_iter() - .map(event_type_to_string) - .collect::>(); - - let mut pages_of_already_processed_events = 0; - let mut unprocessed_events = Vec::new(); - - log::info!( - "Stripe events: starting retrieval for {}", - event_types.join(", ") - ); - let mut params = ListEvents::new(); - params.types = Some(event_types.clone()); - params.limit = Some(EVENTS_LIMIT_PER_PAGE); - - let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms) - .await? - .paginate(params); - - loop { - let processed_event_ids = { - let event_ids = event_pages - .page - .data - .iter() - .map(|event| event.id.as_str()) - .collect::>(); - app.db - .get_processed_stripe_events_by_event_ids(&event_ids) - .await? - .into_iter() - .map(|event| event.stripe_event_id) - .collect::>() - }; - - let mut processed_events_in_page = 0; - let events_in_page = event_pages.page.data.len(); - for event in &event_pages.page.data { - if processed_event_ids.contains(&event.id.to_string()) { - processed_events_in_page += 1; - log::debug!("Stripe events: already processed '{}', skipping", event.id); - } else { - unprocessed_events.push(event.clone()); - } - } - - if processed_events_in_page == events_in_page { - pages_of_already_processed_events += 1; - } - - if event_pages.page.has_more { - if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP - { - log::info!( - "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events" - ); - break; - } else { - log::info!("Stripe events: retrieving next page"); - event_pages = event_pages.next(&real_stripe_client).await?; - } - } else { - break; - } - } - - log::info!("Stripe events: unprocessed {}", unprocessed_events.len()); - - // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred. 
- unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id))); - - for event in unprocessed_events { - let event_id = event.id.clone(); - let processed_event_params = CreateProcessedStripeEventParams { - stripe_event_id: event.id.to_string(), - stripe_event_type: event_type_to_string(event.type_), - stripe_event_created_timestamp: event.created, - }; - - // If the event has happened too far in the past, we don't want to - // process it and risk overwriting other more-recent updates. - // - // 1 day was chosen arbitrarily. This could be made longer or shorter. - let one_day = Duration::from_secs(24 * 60 * 60); - let a_day_ago = Utc::now() - one_day; - if a_day_ago.timestamp() > event.created { - log::info!( - "Stripe events: event '{}' is more than {one_day:?} old, marking as processed", - event_id - ); - app.db - .create_processed_stripe_event(&processed_event_params) - .await?; - - continue; - } - - let process_result = match event.type_ { - EventType::CustomerCreated | EventType::CustomerUpdated => { - handle_customer_event(app, real_stripe_client, event).await - } - EventType::CustomerSubscriptionCreated - | EventType::CustomerSubscriptionUpdated - | EventType::CustomerSubscriptionPaused - | EventType::CustomerSubscriptionResumed - | EventType::CustomerSubscriptionDeleted => { - handle_customer_subscription_event(app, rpc_server, stripe_client, event).await - } - _ => Ok(()), - }; - - if let Some(()) = process_result - .with_context(|| format!("failed to process event {event_id} successfully")) - .log_err() - { - app.db - .create_processed_stripe_event(&processed_event_params) - .await?; - } - } - - Ok(()) -} - -async fn handle_customer_event( - app: &Arc, - _stripe_client: &stripe::Client, - event: stripe::Event, -) -> anyhow::Result<()> { - let EventObject::Customer(customer) = event.data.object else { - bail!("unexpected event payload for {}", event.id); - }; - - log::info!("handling Stripe {} event: {}", event.type_, event.id); - - let Some(email) = customer.email else { - log::info!("Stripe customer has no email: skipping"); - return Ok(()); - }; - - let Some(user) = app.db.get_user_by_email(&email).await? else { - log::info!("no user found for email: skipping"); - return Ok(()); - }; - - if let Some(existing_customer) = app - .db - .get_billing_customer_by_stripe_customer_id(&customer.id) - .await? - { - app.db - .update_billing_customer( - existing_customer.id, - &UpdateBillingCustomerParams { - // For now we just leave the information as-is, as it is not - // likely to change. - ..Default::default() - }, - ) - .await?; - } else { - app.db - .create_billing_customer(&CreateBillingCustomerParams { - user_id: user.id, - stripe_customer_id: customer.id.to_string(), - }) - .await?; - } - - Ok(()) -} - -async fn sync_subscription( - app: &Arc, - stripe_client: &Arc, - subscription: StripeSubscription, -) -> anyhow::Result { - let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing { - stripe_billing - .determine_subscription_kind(&subscription) - .await - } else { - None - }; - - let billing_customer = - find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer) - .await? 
- .context("billing customer not found")?; - - if let Some(SubscriptionKind::ZedProTrial) = subscription_kind { - if subscription.status == SubscriptionStatus::Trialing { - let current_period_start = - DateTime::from_timestamp(subscription.current_period_start, 0) - .context("No trial subscription period start")?; - - app.db - .update_billing_customer( - billing_customer.id, - &UpdateBillingCustomerParams { - trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())), - ..Default::default() - }, - ) - .await?; - } - } - - let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled - && subscription - .cancellation_details - .as_ref() - .and_then(|details| details.reason) - .map_or(false, |reason| { - reason == StripeCancellationDetailsReason::PaymentFailed - }); - - if was_canceled_due_to_payment_failure { - app.db - .update_billing_customer( - billing_customer.id, - &UpdateBillingCustomerParams { - has_overdue_invoices: ActiveValue::set(true), - ..Default::default() - }, - ) - .await?; - } - - if let Some(existing_subscription) = app - .db - .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref()) - .await? - { - app.db - .update_billing_subscription( - existing_subscription.id, - &UpdateBillingSubscriptionParams { - billing_customer_id: ActiveValue::set(billing_customer.id), - kind: ActiveValue::set(subscription_kind), - stripe_subscription_id: ActiveValue::set(subscription.id.to_string()), - stripe_subscription_status: ActiveValue::set(subscription.status.into()), - stripe_cancel_at: ActiveValue::set( - subscription - .cancel_at - .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0)) - .map(|time| time.naive_utc()), - ), - stripe_cancellation_reason: ActiveValue::set( - subscription - .cancellation_details - .and_then(|details| details.reason) - .map(|reason| reason.into()), - ), - stripe_current_period_start: ActiveValue::set(Some( - subscription.current_period_start, - )), - stripe_current_period_end: ActiveValue::set(Some( - subscription.current_period_end, - )), - }, - ) - .await?; - } else { - if let Some(existing_subscription) = app - .db - .get_active_billing_subscription(billing_customer.user_id) - .await? - { - if existing_subscription.kind == Some(SubscriptionKind::ZedFree) - && subscription_kind == Some(SubscriptionKind::ZedProTrial) - { - let stripe_subscription_id = StripeSubscriptionId( - existing_subscription.stripe_subscription_id.clone().into(), - ); - - stripe_client - .cancel_subscription(&stripe_subscription_id) - .await?; - } else { - // If the user already has an active billing subscription, ignore the - // event and return an `Ok` to signal that it was processed - // successfully. - // - // There is the possibility that this could cause us to not create a - // subscription in the following scenario: - // - // 1. User has an active subscription A - // 2. User cancels subscription A - // 3. User creates a new subscription B - // 4. We process the new subscription B before the cancellation of subscription A - // 5. User ends up with no subscriptions - // - // In theory this situation shouldn't arise as we try to process the events in the order they occur. 
- - log::info!( - "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}", - user_id = billing_customer.user_id, - subscription_id = subscription.id - ); - return Ok(billing_customer); - } - } - - app.db - .create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - kind: subscription_kind, - stripe_subscription_id: subscription.id.to_string(), - stripe_subscription_status: subscription.status.into(), - stripe_cancellation_reason: subscription - .cancellation_details - .and_then(|details| details.reason) - .map(|reason| reason.into()), - stripe_current_period_start: Some(subscription.current_period_start), - stripe_current_period_end: Some(subscription.current_period_end), - }) - .await?; - } - - if let Some(stripe_billing) = app.stripe_billing.as_ref() { - if subscription.status == SubscriptionStatus::Canceled - || subscription.status == SubscriptionStatus::Paused - { - let already_has_active_billing_subscription = app - .db - .has_active_billing_subscription(billing_customer.user_id) - .await?; - if !already_has_active_billing_subscription { - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - stripe_billing - .subscribe_to_zed_free(stripe_customer_id) - .await?; - } - } - } - - Ok(billing_customer) -} - -async fn handle_customer_subscription_event( - app: &Arc, - rpc_server: &Arc, - stripe_client: &Arc, - event: stripe::Event, -) -> anyhow::Result<()> { - let EventObject::Subscription(subscription) = event.data.object else { - bail!("unexpected event payload for {}", event.id); - }; - - log::info!("handling Stripe {} event: {}", event.type_, event.id); - - let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?; - - // When the user's subscription changes, push down any changes to their plan. - rpc_server - .update_plan_for_user_legacy(billing_customer.user_id) - .await - .trace_err(); - - // When the user's subscription changes, we want to refresh their LLM tokens - // to either grant/revoke access. - rpc_server - .refresh_llm_tokens_for_user(billing_customer.user_id) - .await; - - Ok(()) -} +use crate::db::billing_subscription::StripeSubscriptionStatus; +use crate::db::{CreateBillingCustomerParams, billing_customer}; +use crate::stripe_client::{StripeClient, StripeCustomerId}; impl From for StripeSubscriptionStatus { fn from(value: SubscriptionStatus) -> Self { @@ -496,16 +21,6 @@ impl From for StripeSubscriptionStatus { } } -impl From for StripeCancellationReason { - fn from(value: CancellationDetailsReason) -> Self { - match value { - CancellationDetailsReason::CancellationRequested => Self::CancellationRequested, - CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed, - CancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} - /// Finds or creates a billing customer using the provided customer. 
pub async fn find_or_create_billing_customer( app: &Arc, @@ -542,194 +57,3 @@ pub async fn find_or_create_billing_customer( Ok(Some(billing_customer)) } - -const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); - -pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc) { - let Some(stripe_billing) = app.stripe_billing.clone() else { - log::warn!("failed to retrieve Stripe billing object"); - return; - }; - let Some(llm_db) = app.llm_db.clone() else { - log::warn!("failed to retrieve LLM database"); - return; - }; - - let executor = app.executor.clone(); - executor.spawn_detached({ - let executor = executor.clone(); - async move { - loop { - sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing) - .await - .context("failed to sync LLM request usage to Stripe") - .trace_err(); - executor - .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL) - .await; - } - } - }); -} - -async fn sync_model_request_usage_with_stripe( - app: &Arc, - llm_db: &Arc, - stripe_billing: &Arc, -) -> anyhow::Result<()> { - let feature_flags = app.db.list_feature_flags().await?; - let sync_model_request_usage_using_cloud = feature_flags - .iter() - .any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all); - if sync_model_request_usage_using_cloud { - return Ok(()); - } - - log::info!("Stripe usage sync: Starting"); - let started_at = Utc::now(); - - let staff_users = app.db.get_staff_users().await?; - let staff_user_ids = staff_users - .iter() - .map(|user| user.id) - .collect::>(); - - let usage_meters = llm_db - .get_current_subscription_usage_meters(Utc::now()) - .await?; - let mut usage_meters_by_user_id = - HashMap::>::default(); - for (usage_meter, usage) in usage_meters { - let meters = usage_meters_by_user_id.entry(usage.user_id).or_default(); - meters.push(usage_meter); - } - - log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions"); - let get_zed_pro_subscriptions_started_at = Utc::now(); - let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?; - log::info!( - "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}", - billing_subscriptions.len(), - Utc::now() - get_zed_pro_subscriptions_started_at - ); - - let claude_sonnet_4 = stripe_billing - .find_price_by_lookup_key("claude-sonnet-4-requests") - .await?; - let claude_sonnet_4_max = stripe_billing - .find_price_by_lookup_key("claude-sonnet-4-requests-max") - .await?; - let claude_opus_4 = stripe_billing - .find_price_by_lookup_key("claude-opus-4-requests") - .await?; - let claude_opus_4_max = stripe_billing - .find_price_by_lookup_key("claude-opus-4-requests-max") - .await?; - let claude_3_5_sonnet = stripe_billing - .find_price_by_lookup_key("claude-3-5-sonnet-requests") - .await?; - let claude_3_7_sonnet = stripe_billing - .find_price_by_lookup_key("claude-3-7-sonnet-requests") - .await?; - let claude_3_7_sonnet_max = stripe_billing - .find_price_by_lookup_key("claude-3-7-sonnet-requests-max") - .await?; - - let model_mode_combinations = [ - ("claude-opus-4", CompletionMode::Max), - ("claude-opus-4", CompletionMode::Normal), - ("claude-sonnet-4", CompletionMode::Max), - ("claude-sonnet-4", CompletionMode::Normal), - ("claude-3-7-sonnet", CompletionMode::Max), - ("claude-3-7-sonnet", CompletionMode::Normal), - ("claude-3-5-sonnet", CompletionMode::Normal), - ]; - - let billing_subscription_count = billing_subscriptions.len(); - - log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions"); - - for 
(user_id, (billing_customer, billing_subscription)) in billing_subscriptions { - maybe!(async { - if staff_user_ids.contains(&user_id) { - return anyhow::Ok(()); - } - - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - let stripe_subscription_id = - StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into()); - - let usage_meters = usage_meters_by_user_id.get(&user_id); - - for (model, mode) in &model_mode_combinations { - let Ok(model) = - llm_db.model(LanguageModelProvider::Anthropic, model) - else { - log::warn!("Failed to load model for user {user_id}: {model}"); - continue; - }; - - let (price, meter_event_name) = match model.name.as_str() { - "claude-opus-4" => match mode { - CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"), - CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"), - }, - "claude-sonnet-4" => match mode { - CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"), - CompletionMode::Max => { - (&claude_sonnet_4_max, "claude_sonnet_4/requests/max") - } - }, - "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), - "claude-3-7-sonnet" => match mode { - CompletionMode::Normal => { - (&claude_3_7_sonnet, "claude_3_7_sonnet/requests") - } - CompletionMode::Max => { - (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") - } - }, - model_name => { - bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") - } - }; - - let model_requests = usage_meters - .and_then(|usage_meters| { - usage_meters - .iter() - .find(|meter| meter.model_id == model.id && meter.mode == *mode) - }) - .map(|usage_meter| usage_meter.requests) - .unwrap_or(0); - - if model_requests > 0 { - stripe_billing - .subscribe_to_price(&stripe_subscription_id, price) - .await?; - } - - stripe_billing - .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests) - .await - .with_context(|| { - format!( - "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}", - ) - })?; - } - - Ok(()) - }) - .await - .log_err(); - } - - log::info!( - "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}", - Utc::now() - started_at - ); - - Ok(()) -} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 9f82e3dbc4..8361d6b4d0 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -85,19 +85,6 @@ impl Database { .await } - /// Returns the billing subscription with the specified ID. - pub async fn get_billing_subscription_by_id( - &self, - id: BillingSubscriptionId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_subscription::Entity::find_by_id(id) - .one(&*tx) - .await?) - }) - .await - } - /// Returns the billing subscription with the specified Stripe subscription ID. pub async fn get_billing_subscription_by_stripe_subscription_id( &self, @@ -143,119 +130,6 @@ impl Database { .await } - /// Returns all of the billing subscriptions for the user with the specified ID. - /// - /// Note that this returns the subscriptions regardless of their status. - /// If you're wanting to check if a use has an active billing subscription, - /// use `get_active_billing_subscriptions` instead. 
- pub async fn get_billing_subscriptions( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - let subscriptions = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .filter(billing_customer::Column::UserId.eq(user_id)) - .order_by_asc(billing_subscription::Column::Id) - .all(&*tx) - .await?; - - Ok(subscriptions) - }) - .await - } - - pub async fn get_active_billing_subscriptions( - &self, - user_ids: HashSet, - ) -> Result> { - self.transaction(|tx| { - let user_ids = user_ids.clone(); - async move { - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .filter(billing_customer::Column::UserId.is_in(user_ids)) - .filter( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .filter(billing_subscription::Column::Kind.is_null()) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; - - let mut subscriptions = HashMap::default(); - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? { - subscriptions.insert(customer.user_id, (customer, subscription)); - } - } - Ok(subscriptions) - } - }) - .await - } - - pub async fn get_active_zed_pro_billing_subscriptions( - &self, - ) -> Result> { - self.transaction(|tx| async move { - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .filter( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro)) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; - - let mut subscriptions = HashMap::default(); - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? { - subscriptions.insert(customer.user_id, (customer, subscription)); - } - } - Ok(subscriptions) - }) - .await - } - - pub async fn get_active_zed_pro_billing_subscriptions_for_users( - &self, - user_ids: HashSet, - ) -> Result> { - self.transaction(|tx| { - let user_ids = user_ids.clone(); - async move { - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .filter(billing_customer::Column::UserId.is_in(user_ids)) - .filter( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro)) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; - - let mut subscriptions = HashMap::default(); - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? { - subscriptions.insert(customer.user_id, (customer, subscription)); - } - } - Ok(subscriptions) - } - }) - .await - } - /// Returns whether the user has an active billing subscription. pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result { Ok(self.count_active_billing_subscriptions(user_id).await? 
> 0) diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index 31635575a8..82f74d910b 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -699,7 +699,10 @@ impl Database { language_server::Column::ProjectId, language_server::Column::Id, ]) - .update_column(language_server::Column::Name) + .update_columns([ + language_server::Column::Name, + language_server::Column::Capabilities, + ]) .to_owned(), ) .exec(&*tx) diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 9404e2670c..6c2f9dc82a 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -1,4 +1,3 @@ -mod billing_subscription_tests; mod buffer_tests; mod channel_tests; mod contributor_tests; diff --git a/crates/collab/src/db/tests/billing_subscription_tests.rs b/crates/collab/src/db/tests/billing_subscription_tests.rs deleted file mode 100644 index fb5f8552a3..0000000000 --- a/crates/collab/src/db/tests/billing_subscription_tests.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::sync::Arc; - -use crate::db::billing_subscription::StripeSubscriptionStatus; -use crate::db::tests::new_test_user; -use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams}; -use crate::test_both_dbs; - -use super::Database; - -test_both_dbs!( - test_get_active_billing_subscriptions, - test_get_active_billing_subscriptions_postgres, - test_get_active_billing_subscriptions_sqlite -); - -async fn test_get_active_billing_subscriptions(db: &Arc) { - // A user with no subscription has no active billing subscriptions. - { - let user_id = new_test_user(db, "no-subscription-user@example.com").await; - let subscription_count = db - .count_active_billing_subscriptions(user_id) - .await - .unwrap(); - - assert_eq!(subscription_count, 0); - } - - // A user with an active subscription has one active billing subscription. - { - let user_id = new_test_user(db, "active-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_active_user".into(), - }) - .await - .unwrap(); - assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string()); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - kind: None, - stripe_subscription_id: "sub_active_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::Active, - stripe_cancellation_reason: None, - stripe_current_period_start: None, - stripe_current_period_end: None, - }) - .await - .unwrap(); - - let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap(); - assert_eq!(subscriptions.len(), 1); - - let subscription = &subscriptions[0]; - assert_eq!( - subscription.stripe_subscription_id, - "sub_active_user".to_string() - ); - assert_eq!( - subscription.stripe_subscription_status, - StripeSubscriptionStatus::Active - ); - } - - // A user with a past-due subscription has no active billing subscriptions. 
- { - let user_id = new_test_user(db, "past-due-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_past_due_user".into(), - }) - .await - .unwrap(); - assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string()); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - kind: None, - stripe_subscription_id: "sub_past_due_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::PastDue, - stripe_cancellation_reason: None, - stripe_current_period_start: None, - stripe_current_period_end: None, - }) - .await - .unwrap(); - - let subscription_count = db - .count_active_billing_subscriptions(user_id) - .await - .unwrap(); - assert_eq!(subscription_count, 0); - } -} diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 3565366fdd..0087218b3f 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -1,6 +1,5 @@ use super::*; pub mod providers; -pub mod subscription_usage_meters; pub mod subscription_usages; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/subscription_usage_meters.rs b/crates/collab/src/llm/db/queries/subscription_usage_meters.rs deleted file mode 100644 index c0ce5d679b..0000000000 --- a/crates/collab/src/llm/db/queries/subscription_usage_meters.rs +++ /dev/null @@ -1,72 +0,0 @@ -use crate::db::UserId; -use crate::llm::db::queries::subscription_usages::convert_chrono_to_time; - -use super::*; - -impl LlmDatabase { - /// Returns all current subscription usage meters as of the given timestamp. - pub async fn get_current_subscription_usage_meters( - &self, - now: DateTimeUtc, - ) -> Result> { - let now = convert_chrono_to_time(now)?; - - self.transaction(|tx| async move { - let result = subscription_usage_meter::Entity::find() - .inner_join(subscription_usage::Entity) - .filter( - subscription_usage::Column::PeriodStartAt - .lte(now) - .and(subscription_usage::Column::PeriodEndAt.gte(now)), - ) - .select_also(subscription_usage::Entity) - .all(&*tx) - .await?; - - let result = result - .into_iter() - .filter_map(|(meter, usage)| { - let usage = usage?; - Some((meter, usage)) - }) - .collect(); - - Ok(result) - }) - .await - } - - /// Returns all current subscription usage meters for the given user as of the given timestamp. 
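// A side note on the query deleted above and its sibling deleted just below: both
// pair each usage meter with its joined subscription-usage row via `select_also`,
// then `filter_map` away any pair whose joined side came back as `None`.
// A minimal stand-alone sketch of that shape, using stand-in unit structs rather
// than the real sea-orm models (names here are illustrative only):
struct Meter;
struct Usage;

fn keep_paired(rows: Vec<(Meter, Option<Usage>)>) -> Vec<(Meter, Usage)> {
    rows.into_iter()
        // `usage?` drops any meter whose joined usage row is missing, mirroring
        // the `filter_map(|(meter, usage)| { let usage = usage?; ... })` in the
        // removed code.
        .filter_map(|(meter, usage)| Some((meter, usage?)))
        .collect()
}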
- pub async fn get_current_subscription_usage_meters_for_user( - &self, - user_id: UserId, - now: DateTimeUtc, - ) -> Result> { - let now = convert_chrono_to_time(now)?; - - self.transaction(|tx| async move { - let result = subscription_usage_meter::Entity::find() - .inner_join(subscription_usage::Entity) - .filter(subscription_usage::Column::UserId.eq(user_id)) - .filter( - subscription_usage::Column::PeriodStartAt - .lte(now) - .and(subscription_usage::Column::PeriodEndAt.gte(now)), - ) - .select_also(subscription_usage::Entity) - .all(&*tx) - .await?; - - let result = result - .into_iter() - .filter_map(|(meter, usage)| { - let usage = usage?; - Some((meter, usage)) - }) - .collect(); - - Ok(result) - }) - .await - } -} diff --git a/crates/collab/src/llm/db/queries/subscription_usages.rs b/crates/collab/src/llm/db/queries/subscription_usages.rs index ee1ebf59b8..8a51979075 100644 --- a/crates/collab/src/llm/db/queries/subscription_usages.rs +++ b/crates/collab/src/llm/db/queries/subscription_usages.rs @@ -1,28 +1,7 @@ -use time::PrimitiveDateTime; - use crate::db::UserId; use super::*; -pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result { - use chrono::{Datelike as _, Timelike as _}; - - let date = time::Date::from_calendar_date( - datetime.year(), - time::Month::try_from(datetime.month() as u8).unwrap(), - datetime.day() as u8, - )?; - - let time = time::Time::from_hms_nano( - datetime.hour() as u8, - datetime.minute() as u8, - datetime.second() as u8, - datetime.nanosecond(), - )?; - - Ok(PrimitiveDateTime::new(date, time)) -} - impl LlmDatabase { pub async fn get_subscription_usage_for_period( &self, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 6a78049b3f..20641cb232 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -7,8 +7,8 @@ use axum::{ routing::get, }; +use collab::ServiceMode; use collab::api::CloudflareIpCountryHeader; -use collab::api::billing::sync_llm_request_usage_with_stripe_periodically; use collab::llm::db::LlmDatabase; use collab::migrations::run_database_migrations; use collab::user_backfiller::spawn_user_backfiller; @@ -16,7 +16,6 @@ use collab::{ AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, rpc::ResultExt, }; -use collab::{ServiceMode, api::billing::poll_stripe_events_periodically}; use db::Database; use std::{ env::args, @@ -31,7 +30,7 @@ use tower_http::trace::TraceLayer; use tracing_subscriber::{ Layer, filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; const VERSION: &str = env!("CARGO_PKG_VERSION"); const REVISION: Option<&'static str> = option_env!("GITHUB_SHA"); @@ -120,8 +119,6 @@ async fn main() -> Result<()> { let rpc_server = collab::rpc::Server::new(epoch, state.clone()); rpc_server.start().await?; - poll_stripe_events_periodically(state.clone(), rpc_server.clone()); - app = app .merge(collab::api::routes(rpc_server.clone())) .merge(collab::rpc::routes(rpc_server.clone())); @@ -133,29 +130,6 @@ async fn main() -> Result<()> { fetch_extensions_from_blob_store_periodically(state.clone()); spawn_user_backfiller(state.clone()); - let llm_db = maybe!(async { - let database_url = state - .config - .llm_database_url - .as_ref() - .context("missing LLM_DATABASE_URL")?; - let max_connections = state - .config - .llm_database_max_connections - .context("missing LLM_DATABASE_MAX_CONNECTIONS")?; - - let mut db_options = 
db::ConnectOptions::new(database_url); - db_options.max_connections(max_connections); - LlmDatabase::new(db_options, state.executor.clone()).await - }) - .await - .trace_err(); - - if let Some(mut llm_db) = llm_db { - llm_db.initialize().await?; - sync_llm_request_usage_with_stripe_periodically(state.clone()); - } - app = app .merge(collab::api::events::router()) .merge(collab::api::extensions::router()) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 22b21f2c7a..18eb1457dc 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -41,9 +41,11 @@ use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; +use futures::TryFutureExt as _; use reqwest_client::ReqwestClient; use rpc::proto::{MultiLspQuery, split_repository_update}; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; +use tracing::Span; use futures::{ FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture, @@ -94,8 +96,13 @@ const MAX_CONCURRENT_CONNECTIONS: usize = 512; static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0); +const TOTAL_DURATION_MS: &str = "total_duration_ms"; +const PROCESSING_DURATION_MS: &str = "processing_duration_ms"; +const QUEUE_DURATION_MS: &str = "queue_duration_ms"; +const HOST_WAITING_MS: &str = "host_waiting_ms"; + type MessageHandler = - Box, Session) -> BoxFuture<'static, ()>>; + Box, Session, Span) -> BoxFuture<'static, ()>>; pub struct ConnectionGuard; @@ -163,6 +170,42 @@ impl Principal { } } +#[derive(Clone)] +struct MessageContext { + session: Session, + span: tracing::Span, +} + +impl Deref for MessageContext { + type Target = Session; + + fn deref(&self) -> &Self::Target { + &self.session + } +} + +impl MessageContext { + pub fn forward_request( + &self, + receiver_id: ConnectionId, + request: T, + ) -> impl Future> { + let request_start_time = Instant::now(); + let span = self.span.clone(); + tracing::info!("start forwarding request"); + self.peer + .forward_request(self.connection_id, receiver_id, request) + .inspect(move |_| { + span.record( + HOST_WAITING_MS, + request_start_time.elapsed().as_micros() as f64 / 1000.0, + ); + }) + .inspect_err(|_| tracing::error!("error forwarding request")) + .inspect_ok(|_| tracing::info!("finished forwarding request")) + } +} + #[derive(Clone)] struct Session { principal: Principal, @@ -340,9 +383,6 @@ impl Server { .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) - .add_request_handler( - forward_read_only_project_request::, - ) .add_request_handler(forward_read_only_project_request::) .add_request_handler( forward_mutating_project_request::, @@ -649,40 +689,37 @@ impl Server { fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(TypedEnvelope, Session) -> Fut, + F: 'static + Send + Sync + Fn(TypedEnvelope, MessageContext) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |envelope, session| { + Box::new(move |envelope, session, span| { let envelope = envelope.into_any().downcast::>().unwrap(); let received_at = envelope.received_at; tracing::info!("message received"); let start_time = Instant::now(); - let future = (handler)(*envelope, session); + let future = (handler)( + *envelope, + MessageContext { + session, + span: 
span.clone(), + }, + ); async move { let result = future.await; let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0; let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; let queue_duration_ms = total_duration_ms - processing_duration_ms; - + span.record(TOTAL_DURATION_MS, total_duration_ms); + span.record(PROCESSING_DURATION_MS, processing_duration_ms); + span.record(QUEUE_DURATION_MS, queue_duration_ms); match result { Err(error) => { - tracing::error!( - ?error, - total_duration_ms, - processing_duration_ms, - queue_duration_ms, - "error handling message" - ) + tracing::error!(?error, "error handling message") } - Ok(()) => tracing::info!( - total_duration_ms, - processing_duration_ms, - queue_duration_ms, - "finished handling message" - ), + Ok(()) => tracing::info!("finished handling message"), } } .boxed() @@ -696,7 +733,7 @@ impl Server { fn add_message_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(M, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { @@ -706,7 +743,7 @@ impl Server { fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(M, Response, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, Response, MessageContext) -> Fut, Fut: Send + Future>, M: RequestMessage, { @@ -749,6 +786,7 @@ impl Server { address: String, principal: Principal, zed_version: ZedVersion, + release_channel: Option, user_agent: Option, geoip_country_code: Option, system_id: Option, @@ -763,12 +801,16 @@ impl Server { login=field::Empty, impersonator=field::Empty, user_agent=field::Empty, - geoip_country_code=field::Empty + geoip_country_code=field::Empty, + release_channel=field::Empty, ); principal.update_span(&span); if let Some(user_agent) = user_agent { span.record("user_agent", user_agent); } + if let Some(release_channel) = release_channel { + span.record("release_channel", release_channel); + } if let Some(country_code) = geoip_country_code.as_ref() { span.record("geoip_country_code", country_code); @@ -887,12 +929,17 @@ impl Server { login=field::Empty, impersonator=field::Empty, multi_lsp_query_request=field::Empty, + release_channel=field::Empty, + { TOTAL_DURATION_MS }=field::Empty, + { PROCESSING_DURATION_MS }=field::Empty, + { QUEUE_DURATION_MS }=field::Empty, + { HOST_WAITING_MS }=field::Empty ); principal.update_span(&span); let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let handle_message = (handler)(message, session.clone()); + let handle_message = (handler)(message, session.clone(), span.clone()); drop(span_enter); let handle_message = async move { @@ -1184,6 +1231,35 @@ impl Header for AppVersionHeader { } } +#[derive(Debug)] +pub struct ReleaseChannelHeader(String); + +impl Header for ReleaseChannelHeader { + fn name() -> &'static HeaderName { + static ZED_RELEASE_CHANNEL: OnceLock = OnceLock::new(); + ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel")) + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + Ok(Self( + values + .next() + .ok_or_else(axum::headers::Error::invalid)? + .to_str() + .map_err(|_| axum::headers::Error::invalid())? 
+ .to_owned(), + )) + } + + fn encode>(&self, values: &mut E) { + values.extend([self.0.parse().unwrap()]); + } +} + pub fn routes(server: Arc) -> Router<(), Body> { Router::new() .route("/rpc", get(handle_websocket_request)) @@ -1199,6 +1275,7 @@ pub fn routes(server: Arc) -> Router<(), Body> { pub async fn handle_websocket_request( TypedHeader(ProtocolVersion(protocol_version)): TypedHeader, app_version_header: Option>, + release_channel_header: Option>, ConnectInfo(socket_address): ConnectInfo, Extension(server): Extension>, Extension(principal): Extension, @@ -1223,6 +1300,8 @@ pub async fn handle_websocket_request( .into_response(); }; + let release_channel = release_channel_header.map(|header| header.0.0); + if !version.can_collaborate() { return ( StatusCode::UPGRADE_REQUIRED, @@ -1258,6 +1337,7 @@ pub async fn handle_websocket_request( socket_address, principal, version, + release_channel, user_agent.map(|header| header.to_string()), country_code_header.map(|header| header.to_string()), system_id_header.map(|header| header.to_string()), @@ -1351,7 +1431,11 @@ async fn connection_lost( } /// Acknowledges a ping from a client, used to keep the connection alive. -async fn ping(_: proto::Ping, response: Response, _session: Session) -> Result<()> { +async fn ping( + _: proto::Ping, + response: Response, + _session: MessageContext, +) -> Result<()> { response.send(proto::Ack {})?; Ok(()) } @@ -1360,7 +1444,7 @@ async fn ping(_: proto::Ping, response: Response, _session: Session async fn create_room( _request: proto::CreateRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let livekit_room = nanoid::nanoid!(30); @@ -1400,7 +1484,7 @@ async fn create_room( async fn join_room( request: proto::JoinRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.id); @@ -1467,7 +1551,7 @@ async fn join_room( async fn rejoin_room( request: proto::RejoinRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room; let channel; @@ -1595,15 +1679,15 @@ fn notify_rejoined_projects( } // Stream this worktree's diagnostics. 
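// The hunk just below swaps the per-summary send loop for a single batched
// message: the first summary is sent as `summary`, the remainder travel in the
// new `more_summaries` field, and nothing is sent when the worktree has no
// summaries. A tiny stand-alone sketch of that head-plus-tail batching
// (generic `T` stands in for the proto summary type; illustrative only):
fn batch_head_and_tail<T>(items: Vec<T>) -> Option<(T, Vec<T>)> {
    let mut iter = items.into_iter();
    let first = iter.next()?; // no message at all if there are no summaries
    Some((first, iter.collect())) // the rest become `more_summaries`
}
// e.g. batch_head_and_tail(vec![1, 2, 3]) == Some((1, vec![2, 3]))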
- for summary in worktree.diagnostic_summaries { - session.peer.send( - session.connection_id, - proto::UpdateDiagnosticSummary { - project_id: project.id.to_proto(), - worktree_id: worktree.id, - summary: Some(summary), - }, - )?; + let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter(); + if let Some(summary) = worktree_diagnostics.next() { + let message = proto::UpdateDiagnosticSummary { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + summary: Some(summary), + more_summaries: worktree_diagnostics.collect(), + }; + session.peer.send(session.connection_id, message)?; } for settings_file in worktree.settings_files { @@ -1644,7 +1728,7 @@ fn notify_rejoined_projects( async fn leave_room( _: proto::LeaveRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { leave_room_for_session(&session, session.connection_id).await?; response.send(proto::Ack {})?; @@ -1655,7 +1739,7 @@ async fn leave_room( async fn set_room_participant_role( request: proto::SetRoomParticipantRole, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_id = UserId::from_proto(request.user_id); let role = ChannelRole::from(request.role()); @@ -1703,7 +1787,7 @@ async fn set_room_participant_role( async fn call( request: proto::Call, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let calling_user_id = session.user_id(); @@ -1772,7 +1856,7 @@ async fn call( async fn cancel_call( request: proto::CancelCall, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let called_user_id = UserId::from_proto(request.called_user_id); let room_id = RoomId::from_proto(request.room_id); @@ -1807,7 +1891,7 @@ async fn cancel_call( } /// Decline an incoming call. -async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { +async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(message.room_id); { let room = session @@ -1842,7 +1926,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( async fn update_participant_location( request: proto::UpdateParticipantLocation, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let location = request.location.context("invalid location")?; @@ -1861,7 +1945,7 @@ async fn update_participant_location( async fn share_project( request: proto::ShareProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let (project_id, room) = &*session .db() @@ -1882,7 +1966,7 @@ async fn share_project( } /// Unshare a project from the room. -async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { +async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> { let project_id = ProjectId::from_proto(message.project_id); unshare_project_internal(project_id, session.connection_id, &session).await } @@ -1929,7 +2013,7 @@ async fn unshare_project_internal( async fn join_project( request: proto::JoinProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); @@ -2025,15 +2109,15 @@ async fn join_project( } // Stream this worktree's diagnostics. 
- for summary in worktree.diagnostic_summaries { - session.peer.send( - session.connection_id, - proto::UpdateDiagnosticSummary { - project_id: project_id.to_proto(), - worktree_id: worktree.id, - summary: Some(summary), - }, - )?; + let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter(); + if let Some(summary) = worktree_diagnostics.next() { + let message = proto::UpdateDiagnosticSummary { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + summary: Some(summary), + more_summaries: worktree_diagnostics.collect(), + }; + session.peer.send(session.connection_id, message)?; } for settings_file in worktree.settings_files { @@ -2076,7 +2160,7 @@ async fn join_project( } /// Leave someone elses shared project. -async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { +async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> { let sender_id = session.connection_id; let project_id = ProjectId::from_proto(request.project_id); let db = session.db().await; @@ -2099,7 +2183,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result async fn update_project( request: proto::UpdateProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let (room, guest_connection_ids) = &*session @@ -2128,7 +2212,7 @@ async fn update_project( async fn update_worktree( request: proto::UpdateWorktree, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2152,7 +2236,7 @@ async fn update_worktree( async fn update_repository( request: proto::UpdateRepository, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2176,7 +2260,7 @@ async fn update_repository( async fn remove_repository( request: proto::RemoveRepository, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2200,7 +2284,7 @@ async fn remove_repository( /// Updates other participants with changes to the diagnostics async fn update_diagnostic_summary( message: proto::UpdateDiagnosticSummary, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2224,7 +2308,7 @@ async fn update_diagnostic_summary( /// Updates other participants with changes to the worktree settings async fn update_worktree_settings( message: proto::UpdateWorktreeSettings, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2248,7 +2332,7 @@ async fn update_worktree_settings( /// Notify other participants that a language server has started. async fn start_language_server( request: proto::StartLanguageServer, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2271,7 +2355,7 @@ async fn start_language_server( /// Notify other participants that a language server has changed. 
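// Aside on the signature changes running through these handlers: each one now
// takes a MessageContext instead of a Session. Because MessageContext derefs to
// Session, existing call sites such as `session.db()` compile unchanged, while
// the wrapper also carries the tracing span used for timing. A stand-alone
// sketch of that wrapper pattern (Session and the span field here are
// simplified stand-ins, not the real collab types):
use std::ops::Deref;

struct Session { user: &'static str }
impl Session {
    fn db(&self) -> &'static str { self.user }
}

struct MessageContext {
    session: Session,
    span: &'static str, // stand-in for tracing::Span
}

impl Deref for MessageContext {
    type Target = Session;
    fn deref(&self) -> &Self::Target { &self.session }
}

fn handler(session: MessageContext) {
    // Deref coercion: the wrapper works wherever &Session methods are needed.
    let _db = session.db();
    let _span = session.span;
}

fn main() {
    handler(MessageContext { session: Session { user: "alice" }, span: "rpc-span" });
}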
async fn update_language_server( request: proto::UpdateLanguageServer, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let db = session.db().await; @@ -2304,7 +2388,7 @@ async fn update_language_server( async fn forward_read_only_project_request( request: T, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> where T: EntityMessage + RequestMessage, @@ -2315,10 +2399,7 @@ where .await .host_for_read_only_project_request(project_id, session.connection_id) .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; + let payload = session.forward_request(host_connection_id, request).await?; response.send(payload)?; Ok(()) } @@ -2328,7 +2409,7 @@ where async fn forward_mutating_project_request( request: T, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> where T: EntityMessage + RequestMessage, @@ -2340,10 +2421,7 @@ where .await .host_for_mutating_project_request(project_id, session.connection_id) .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; + let payload = session.forward_request(host_connection_id, request).await?; response.send(payload)?; Ok(()) } @@ -2351,7 +2429,7 @@ where async fn multi_lsp_query( request: MultiLspQuery, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { tracing::Span::current().record("multi_lsp_query_request", request.request_str()); tracing::info!("multi_lsp_query message received"); @@ -2361,7 +2439,7 @@ async fn multi_lsp_query( /// Notify other participants that a new buffer has been created async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, - session: Session, + session: MessageContext, ) -> Result<()> { session .db() @@ -2383,7 +2461,7 @@ async fn create_buffer_for_peer( async fn update_buffer( request: proto::UpdateBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let mut capability = Capability::ReadOnly; @@ -2418,17 +2496,14 @@ async fn update_buffer( }; if host != session.connection_id { - session - .peer - .forward_request(session.connection_id, host, request.clone()) - .await?; + session.forward_request(host, request.clone()).await?; } response.send(proto::Ack {})?; Ok(()) } -async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> { +async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> { let project_id = ProjectId::from_proto(message.project_id); let operation = message.operation.as_ref().context("invalid operation")?; @@ -2473,7 +2548,7 @@ async fn update_context(message: proto::UpdateContext, session: Session) -> Resu /// Notify other participants that a project has been updated. 
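// Aside on the tracing changes these forwarding helpers rely on: the handler
// wrapper earlier in this file records total, processing, and queue durations on
// the span, and MessageContext::forward_request additionally records how long
// the host took to answer. The duration arithmetic in isolation, std-only (the
// names mirror the new TOTAL_DURATION_MS / PROCESSING_DURATION_MS /
// QUEUE_DURATION_MS constants; `record` is a stand-in for span.record):
use std::time::Instant;

fn record(field: &str, ms: f64) {
    println!("{field} = {ms:.3}");
}

fn main() {
    let received_at = Instant::now(); // envelope arrival
    // ... message waits in the queue ...
    let start_time = Instant::now(); // handler actually starts
    // ... handler body runs ...
    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
    // Queue time is whatever part of the total was not spent processing.
    let queue_duration_ms = total_duration_ms - processing_duration_ms;
    record("total_duration_ms", total_duration_ms);
    record("processing_duration_ms", processing_duration_ms);
    record("queue_duration_ms", queue_duration_ms);
}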
async fn broadcast_project_message_from_host>( request: T, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.remote_entity_id()); let project_connection_ids = session @@ -2498,7 +2573,7 @@ async fn broadcast_project_message_from_host, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); @@ -2511,10 +2586,7 @@ async fn follow( .check_room_participants(room_id, leader_id, session.connection_id) .await?; - let response_payload = session - .peer - .forward_request(session.connection_id, leader_id, request) - .await?; + let response_payload = session.forward_request(leader_id, request).await?; response.send(response_payload)?; if let Some(project_id) = project_id { @@ -2530,7 +2602,7 @@ async fn follow( } /// Stop following another user in a call. -async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { +async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); let leader_id = request.leader_id.context("invalid leader id")?.into(); @@ -2559,7 +2631,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { } /// Notify everyone following you of your current location. -async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { +async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let database = session.db.lock().await; @@ -2594,7 +2666,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) -> async fn get_users( request: proto::GetUsers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_ids = request .user_ids @@ -2622,7 +2694,7 @@ async fn get_users( async fn fuzzy_search_users( request: proto::FuzzySearchUsers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let query = request.query; let users = match query.len() { @@ -2654,7 +2726,7 @@ async fn fuzzy_search_users( async fn request_contact( request: proto::RequestContact, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.responder_id); @@ -2701,7 +2773,7 @@ async fn request_contact( async fn respond_to_contact_request( request: proto::RespondToContactRequest, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let responder_id = session.user_id(); let requester_id = UserId::from_proto(request.requester_id); @@ -2759,7 +2831,7 @@ async fn respond_to_contact_request( async fn remove_contact( request: proto::RemoveContact, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.user_id); @@ -3018,7 +3090,10 @@ async fn update_user_plan(session: &Session) -> Result<()> { Ok(()) } -async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> { +async fn subscribe_to_channels( + _: proto::SubscribeToChannels, + session: MessageContext, +) -> Result<()> { subscribe_user_to_channels(session.user_id(), &session).await?; Ok(()) } @@ -3044,7 +3119,7 @@ 
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Resul async fn create_channel( request: proto::CreateChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -3099,7 +3174,7 @@ async fn create_channel( async fn delete_channel( request: proto::DeleteChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -3127,7 +3202,7 @@ async fn delete_channel( async fn invite_channel_member( request: proto::InviteChannelMember, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3164,7 +3239,7 @@ async fn invite_channel_member( async fn remove_channel_member( request: proto::RemoveChannelMember, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3208,7 +3283,7 @@ async fn remove_channel_member( async fn set_channel_visibility( request: proto::SetChannelVisibility, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3253,7 +3328,7 @@ async fn set_channel_visibility( async fn set_channel_member_role( request: proto::SetChannelMemberRole, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3301,7 +3376,7 @@ async fn set_channel_member_role( async fn rename_channel( request: proto::RenameChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3333,7 +3408,7 @@ async fn rename_channel( async fn move_channel( request: proto::MoveChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let to = ChannelId::from_proto(request.to); @@ -3375,7 +3450,7 @@ async fn move_channel( async fn reorder_channel( request: proto::ReorderChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let direction = request.direction(); @@ -3421,7 +3496,7 @@ async fn reorder_channel( async fn get_channel_members( request: proto::GetChannelMembers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3441,7 +3516,7 @@ async fn get_channel_members( async fn respond_to_channel_invite( request: proto::RespondToChannelInvite, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3482,7 +3557,7 @@ async fn respond_to_channel_invite( async fn join_channel( request: proto::JoinChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); join_channel_internal(channel_id, Box::new(response), session).await @@ -3505,7 +3580,7 @@ impl JoinChannelInternalResponse for Response { async fn join_channel_internal( channel_id: ChannelId, 
response: Box, - session: Session, + session: MessageContext, ) -> Result<()> { let joined_room = { let mut db = session.db().await; @@ -3600,7 +3675,7 @@ async fn join_channel_internal( async fn join_channel_buffer( request: proto::JoinChannelBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3631,7 +3706,7 @@ async fn join_channel_buffer( /// Edit the channel notes async fn update_channel_buffer( request: proto::UpdateChannelBuffer, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3683,7 +3758,7 @@ async fn update_channel_buffer( async fn rejoin_channel_buffers( request: proto::RejoinChannelBuffers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let buffers = db @@ -3718,7 +3793,7 @@ async fn rejoin_channel_buffers( async fn leave_channel_buffer( request: proto::LeaveChannelBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3780,7 +3855,7 @@ fn send_notifications( async fn send_channel_message( request: proto::SendChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { // Validate the message body. let body = request.body.trim().to_string(); @@ -3873,7 +3948,7 @@ async fn send_channel_message( async fn remove_channel_message( request: proto::RemoveChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -3908,7 +3983,7 @@ async fn remove_channel_message( async fn update_channel_message( request: proto::UpdateChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -3992,7 +4067,7 @@ async fn update_channel_message( /// Mark a channel message as read async fn acknowledge_channel_message( request: proto::AckChannelMessage, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -4012,7 +4087,7 @@ async fn acknowledge_channel_message( /// Mark a buffer version as synced async fn acknowledge_buffer_version( request: proto::AckBufferOperation, - session: Session, + session: MessageContext, ) -> Result<()> { let buffer_id = BufferId::from_proto(request.buffer_id); session @@ -4032,7 +4107,7 @@ async fn acknowledge_buffer_version( async fn get_supermaven_api_key( _request: proto::GetSupermavenApiKey, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_id: String = session.user_id().to_string(); if !session.is_staff() { @@ -4061,7 +4136,7 @@ async fn get_supermaven_api_key( async fn join_channel_chat( request: proto::JoinChannelChat, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); @@ -4079,7 +4154,10 @@ async fn join_channel_chat( } /// Stop receiving chat updates for a channel -async fn leave_channel_chat(request: 
proto::LeaveChannelChat, session: Session) -> Result<()> { +async fn leave_channel_chat( + request: proto::LeaveChannelChat, + session: MessageContext, +) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); session .db() @@ -4093,7 +4171,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) async fn get_channel_messages( request: proto::GetChannelMessages, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let messages = session @@ -4117,7 +4195,7 @@ async fn get_channel_messages( async fn get_channel_messages_by_id( request: proto::GetChannelMessagesById, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let message_ids = request .message_ids @@ -4140,7 +4218,7 @@ async fn get_channel_messages_by_id( async fn get_notifications( request: proto::GetNotifications, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let notifications = session .db() @@ -4162,7 +4240,7 @@ async fn get_notifications( async fn mark_notification_as_read( request: proto::MarkNotificationRead, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let database = &session.db().await; let notifications = database @@ -4184,7 +4262,7 @@ async fn mark_notification_as_read( async fn get_private_user_info( _request: proto::GetPrivateUserInfo, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -4208,7 +4286,7 @@ async fn get_private_user_info( async fn accept_terms_of_service( _request: proto::AcceptTermsOfService, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -4232,7 +4310,7 @@ async fn accept_terms_of_service( async fn get_llm_api_token( _request: proto::GetLlmToken, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 850b716a9f..ef5bef3e7e 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,21 +1,15 @@ use std::sync::Arc; use anyhow::anyhow; -use chrono::Utc; use collections::HashMap; use stripe::SubscriptionStatus; use tokio::sync::RwLock; -use uuid::Uuid; use crate::Result; -use crate::db::billing_subscription::SubscriptionKind; use crate::stripe_client::{ - RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams, - StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams, - StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, - UpdateSubscriptionParams, + RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems, + StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId, + StripeSubscription, }; pub struct StripeBilling { @@ -94,30 +88,6 @@ impl StripeBilling { .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}"))) } - pub async fn determine_subscription_kind( - &self, - subscription: &StripeSubscription, - ) -> Option { - let zed_pro_price_id = self.zed_pro_price_id().await.ok()?; - let zed_free_price_id = 
self.zed_free_price_id().await.ok()?; - - subscription.items.iter().find_map(|item| { - let price = item.price.as_ref()?; - - if price.id == zed_pro_price_id { - Some(if subscription.status == SubscriptionStatus::Trialing { - SubscriptionKind::ZedProTrial - } else { - SubscriptionKind::ZedPro - }) - } else if price.id == zed_free_price_id { - Some(SubscriptionKind::ZedFree) - } else { - None - } - }) - } - /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does /// not already exist. /// @@ -150,65 +120,6 @@ impl StripeBilling { Ok(customer_id) } - pub async fn subscribe_to_price( - &self, - subscription_id: &StripeSubscriptionId, - price: &StripePrice, - ) -> Result<()> { - let subscription = self.client.get_subscription(subscription_id).await?; - - if subscription_contains_price(&subscription, &price.id) { - return Ok(()); - } - - const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100; - - let price_per_unit = price.unit_amount.unwrap_or_default(); - let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit; - - self.client - .update_subscription( - subscription_id, - UpdateSubscriptionParams { - items: Some(vec![UpdateSubscriptionItems { - price: Some(price.id.clone()), - }]), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel - }, - }), - }, - ) - .await?; - - Ok(()) - } - - pub async fn bill_model_request_usage( - &self, - customer_id: &StripeCustomerId, - event_name: &str, - requests: i32, - ) -> Result<()> { - let timestamp = Utc::now().timestamp(); - let idempotency_key = Uuid::new_v4(); - - self.client - .create_meter_event(StripeCreateMeterEventParams { - identifier: &format!("model_requests/{}", idempotency_key), - event_name, - payload: StripeCreateMeterEventPayload { - value: requests as u64, - stripe_customer_id: customer_id, - }, - timestamp: Some(timestamp), - }) - .await?; - - Ok(()) - } - pub async fn subscribe_to_zed_free( &self, customer_id: StripeCustomerId, @@ -243,14 +154,3 @@ impl StripeBilling { Ok(subscription) } } - -fn subscription_contains_price( - subscription: &StripeSubscription, - price_id: &StripePriceId, -) -> bool { - subscription.items.iter().any(|item| { - item.price - .as_ref() - .map_or(false, |price| price.id == *price_id) - }) -} diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 1d28c7f6ef..7b95fdd458 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -24,10 +24,7 @@ use language::{ }; use project::{ ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT, - lsp_store::{ - lsp_ext_command::{ExpandedMacro, LspExtExpandMacro}, - rust_analyzer_ext::RUST_ANALYZER_NAME, - }, + lsp_store::lsp_ext_command::{ExpandedMacro, LspExtExpandMacro}, project_settings::{InlineBlameSettings, ProjectSettings}, }; use recent_projects::disconnected_overlay::DisconnectedOverlay; @@ -3104,9 +3101,7 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA // Turn inline-blame-off by default so no state is transferred without us explicitly doing so let inline_blame_off_settings = Some(InlineBlameSettings { enabled: false, - delay_ms: None, - min_column: None, - show_commit_summary: false, + ..Default::default() }); cx_a.update(|cx| { SettingsStore::update_global(cx, |store, cx| { @@ -3786,11 +3781,18 @@ async fn 
test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes cx_b.update(editor::init); client_a.language_registry().add(rust_lang()); - client_b.language_registry().add(rust_lang()); let mut fake_language_servers = client_a.language_registry().register_fake_lsp( "Rust", FakeLspAdapter { - name: RUST_ANALYZER_NAME, + name: "rust-analyzer", + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + name: "rust-analyzer", ..FakeLspAdapter::default() }, ); diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index 5c5bcd5832..bb84bedfcf 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -1,14 +1,9 @@ use std::sync::Arc; -use chrono::{Duration, Utc}; use pretty_assertions::assert_eq; use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{ - FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, - StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, - StripeSubscriptionItemId, UpdateSubscriptionItems, -}; +use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring}; fn make_stripe_billing() -> (StripeBilling, Arc) { let stripe_client = Arc::new(FakeStripeClient::new()); @@ -21,24 +16,6 @@ fn make_stripe_billing() -> (StripeBilling, Arc) { async fn test_initialize() { let (stripe_billing, stripe_client) = make_stripe_billing(); - // Add test meters - let meter1 = StripeMeter { - id: StripeMeterId("meter_1".into()), - event_name: "event_1".to_string(), - }; - let meter2 = StripeMeter { - id: StripeMeterId("meter_2".into()), - event_name: "event_2".to_string(), - }; - stripe_client - .meters - .lock() - .insert(meter1.id.clone(), meter1); - stripe_client - .meters - .lock() - .insert(meter2.id.clone(), meter2); - // Add test prices let price1 = StripePrice { id: StripePriceId("price_1".into()), @@ -144,217 +121,3 @@ async fn test_find_or_create_customer_by_email() { assert_eq!(customer.email.as_deref(), Some(email)); } } - -#[gpui::test] -async fn test_subscribe_to_price() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let price = StripePrice { - id: StripePriceId("price_test".into()), - unit_amount: Some(2000), - lookup_key: Some("test-price".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(price.id.clone(), price.clone()); - - let now = Utc::now(); - let subscription = StripeSubscription { - id: StripeSubscriptionId("sub_test".into()), - customer: StripeCustomerId("cus_test".into()), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![], - cancel_at: None, - cancellation_details: None, - }; - stripe_client - .subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - stripe_billing - .subscribe_to_price(&subscription.id, &price) - .await - .unwrap(); - - let update_subscription_calls = stripe_client - .update_subscription_calls - .lock() - .iter() - .map(|(id, params)| (id.clone(), params.clone())) - .collect::>(); - assert_eq!(update_subscription_calls.len(), 1); - assert_eq!(update_subscription_calls[0].0, subscription.id); - assert_eq!( - update_subscription_calls[0].1.items, - Some(vec![UpdateSubscriptionItems { - price: 
Some(price.id.clone()) - }]) - ); - - // Subscribing to a price that is already on the subscription is a no-op. - { - let now = Utc::now(); - let subscription = StripeSubscription { - id: StripeSubscriptionId("sub_test".into()), - customer: StripeCustomerId("cus_test".into()), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client - .subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - stripe_billing - .subscribe_to_price(&subscription.id, &price) - .await - .unwrap(); - - assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1); - } -} - -#[gpui::test] -async fn test_subscribe_to_zed_free() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let zed_pro_price = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(0), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(zed_pro_price.id.clone(), zed_pro_price.clone()); - let zed_free_price = StripePrice { - id: StripePriceId("price_2".into()), - unit_amount: Some(0), - lookup_key: Some("zed-free".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(zed_free_price.id.clone(), zed_free_price.clone()); - - stripe_billing.initialize().await.unwrap(); - - // Customer is subscribed to Zed Free when not already subscribed to a plan. - { - let customer_id = StripeCustomerId("cus_no_plan".into()); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price)); - } - - // Customer is not subscribed to Zed Free when they already have an active subscription. - { - let customer_id = StripeCustomerId("cus_active_subscription".into()); - - let now = Utc::now(); - let existing_subscription = StripeSubscription { - id: StripeSubscriptionId("sub_existing_active".into()), - customer: customer_id.clone(), - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(zed_pro_price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client.subscriptions.lock().insert( - existing_subscription.id.clone(), - existing_subscription.clone(), - ); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription, existing_subscription); - } - - // Customer is not subscribed to Zed Free when they already have a trial subscription. 
- { - let customer_id = StripeCustomerId("cus_trial_subscription".into()); - - let now = Utc::now(); - let existing_subscription = StripeSubscription { - id: StripeSubscriptionId("sub_existing_trial".into()), - customer: customer_id.clone(), - status: stripe::SubscriptionStatus::Trialing, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(14)).timestamp(), - items: vec![StripeSubscriptionItem { - id: StripeSubscriptionItemId("si_test".into()), - price: Some(zed_pro_price.clone()), - }], - cancel_at: None, - cancellation_details: None, - }; - stripe_client.subscriptions.lock().insert( - existing_subscription.id.clone(), - existing_subscription.clone(), - ); - - let subscription = stripe_billing - .subscribe_to_zed_free(customer_id) - .await - .unwrap(); - - assert_eq!(subscription, existing_subscription); - } -} - -#[gpui::test] -async fn test_bill_model_request_usage() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let customer_id = StripeCustomerId("cus_test".into()); - - stripe_billing - .bill_model_request_usage(&customer_id, "some_model/requests", 73) - .await - .unwrap(); - - let create_meter_event_calls = stripe_client - .create_meter_event_calls - .lock() - .iter() - .cloned() - .collect::>(); - assert_eq!(create_meter_event_calls.len(), 1); - assert!( - create_meter_event_calls[0] - .identifier - .starts_with("model_requests/") - ); - assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id); - assert_eq!( - create_meter_event_calls[0].event_name.as_ref(), - "some_model/requests" - ); - assert_eq!(create_meter_event_calls[0].value, 73); -} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 5fcc622fc1..f5a0e8ea81 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -297,6 +297,7 @@ impl TestServer { client_name, Principal::User(user), ZedVersion(SemanticVersion::new(1, 0, 0)), + Some("test".to_string()), None, None, None, diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 3a9b568264..51d9f003f8 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -103,28 +103,16 @@ impl ChatPanel { }); cx.new(|cx| { - let entity = cx.entity().downgrade(); - let message_list = ListState::new( - 0, - gpui::ListAlignment::Bottom, - px(1000.), - move |ix, window, cx| { - if let Some(entity) = entity.upgrade() { - entity.update(cx, |this: &mut Self, cx| { - this.render_message(ix, window, cx).into_any_element() - }) - } else { - div().into_any() - } - }, - ); + let message_list = ListState::new(0, gpui::ListAlignment::Bottom, px(1000.)); - message_list.set_scroll_handler(cx.listener(|this, event: &ListScrollEvent, _, cx| { - if event.visible_range.start < MESSAGE_LOADING_THRESHOLD { - this.load_more_messages(cx); - } - this.is_scrolled_to_bottom = !event.is_scrolled; - })); + message_list.set_scroll_handler(cx.listener( + |this: &mut Self, event: &ListScrollEvent, _, cx| { + if event.visible_range.start < MESSAGE_LOADING_THRESHOLD { + this.load_more_messages(cx); + } + this.is_scrolled_to_bottom = !event.is_scrolled; + }, + )); let local_offset = chrono::Local::now().offset().local_minus_utc(); let mut this = Self { @@ -399,7 +387,7 @@ impl ChatPanel { ix: usize, window: &mut Window, cx: &mut Context, - ) -> impl IntoElement { + ) -> AnyElement { let active_chat = &self.active_chat.as_ref().unwrap().0; let (message, is_continuation_from_previous, is_admin) = 
active_chat.update(cx, |active_chat, cx| { @@ -582,6 +570,7 @@ impl ChatPanel { self.render_popover_buttons(message_id, can_delete_message, can_edit_message, cx) .mt_neg_2p5(), ) + .into_any_element() } fn has_open_menu(&self, message_id: Option) -> bool { @@ -979,7 +968,13 @@ impl Render for ChatPanel { ) .child(div().flex_grow().px_2().map(|this| { if self.active_chat.is_some() { - this.child(list(self.message_list.clone()).size_full()) + this.child( + list( + self.message_list.clone(), + cx.processor(Self::render_message), + ) + .size_full(), + ) } else { this.child( div() diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 689591df12..51e4ff8965 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -324,20 +324,6 @@ impl CollabPanel { ) .detach(); - let entity = cx.entity().downgrade(); - let list_state = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - if let Some(entity) = entity.upgrade() { - entity.update(cx, |this, cx| this.render_list_entry(ix, window, cx)) - } else { - div().into_any() - } - }, - ); - let mut this = Self { width: None, focus_handle: cx.focus_handle(), @@ -345,7 +331,7 @@ impl CollabPanel { fs: workspace.app_state().fs.clone(), pending_serialization: Task::ready(None), context_menu: None, - list_state, + list_state: ListState::new(0, gpui::ListAlignment::Top, px(1000.)), channel_name_editor, filter_editor, entries: Vec::default(), @@ -2431,7 +2417,13 @@ impl CollabPanel { }); v_flex() .size_full() - .child(list(self.list_state.clone()).size_full()) + .child( + list( + self.list_state.clone(), + cx.processor(Self::render_list_entry), + ) + .size_full(), + ) .child( v_flex() .child(div().mx_2().border_primary(cx).border_t_1()) @@ -2605,7 +2597,7 @@ impl CollabPanel { let contact = contact.clone(); move |this, event: &ClickEvent, window, cx| { this.deploy_contact_context_menu( - event.down.position, + event.position(), contact.clone(), window, cx, @@ -3061,7 +3053,7 @@ impl Render for CollabPanel { .on_action(cx.listener(CollabPanel::move_channel_down)) .track_focus(&self.focus_handle) .size_full() - .child(if self.user_store.read(cx).current_user().is_none() { + .child(if !self.client.status().borrow().is_connected() { self.render_signed_out(cx) } else { self.render_signed_in(window, cx) diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index c3e834b645..3a280ff667 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -118,16 +118,7 @@ impl NotificationPanel { }) .detach(); - let entity = cx.entity().downgrade(); - let notification_list = - ListState::new(0, ListAlignment::Top, px(1000.), move |ix, window, cx| { - entity - .upgrade() - .and_then(|entity| { - entity.update(cx, |this, cx| this.render_notification(ix, window, cx)) - }) - .unwrap_or_else(|| div().into_any()) - }); + let notification_list = ListState::new(0, ListAlignment::Top, px(1000.)); notification_list.set_scroll_handler(cx.listener( |this, event: &ListScrollEvent, _, cx| { if event.count.saturating_sub(event.visible_range.end) < LOADING_THRESHOLD { @@ -687,7 +678,16 @@ impl Render for NotificationPanel { ), ) } else { - this.child(list(self.notification_list.clone()).size_full()) + this.child( + list( + self.notification_list.clone(), + cx.processor(|this, ix, window, cx| { + this.render_notification(ix, window, cx) + .unwrap_or_else(|| div().into_any()) + }), + ) + 
.size_full(), + ) } }) } diff --git a/crates/command_palette/src/command_palette.rs b/crates/command_palette/src/command_palette.rs index dfaede0dc4..b8800ff912 100644 --- a/crates/command_palette/src/command_palette.rs +++ b/crates/command_palette/src/command_palette.rs @@ -136,7 +136,10 @@ impl Focusable for CommandPalette { impl Render for CommandPalette { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("CommandPalette") + .w(rems(34.)) + .child(self.picker.clone()) } } diff --git a/crates/component/src/component.rs b/crates/component/src/component.rs index 02840cc3cb..0c05ba4a97 100644 --- a/crates/component/src/component.rs +++ b/crates/component/src/component.rs @@ -318,8 +318,10 @@ pub enum ComponentScope { Notification, #[strum(serialize = "Overlays & Layering")] Overlays, + Onboarding, Status, Typography, + Utilities, #[strum(serialize = "Version Control")] VersionControl, } diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 2a7225c4e3..2fd6df27b9 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -58,11 +58,19 @@ impl EditPredictionProvider for CopilotCompletionProvider { } fn show_completions_in_menu() -> bool { + true + } + + fn show_tab_accept_marker() -> bool { + true + } + + fn supports_jump_to_edit() -> bool { false } fn is_refreshing(&self) -> bool { - self.pending_refresh.is_some() + self.pending_refresh.is_some() && self.completions.is_empty() } fn is_enabled( @@ -343,8 +351,8 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, window, cx| { assert!(editor.context_menu_visible()); - assert!(!editor.has_active_edit_prediction()); - // Since we have both, the copilot suggestion is not shown inline + assert!(editor.has_active_edit_prediction()); + // Since we have both, the copilot suggestion is existing but does not show up as ghost text assert_eq!(editor.text(cx), "one.\ntwo\nthree\n"); assert_eq!(editor.display_text(cx), "one.\ntwo\nthree\n"); @@ -934,8 +942,9 @@ mod tests { executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT); cx.update_editor(|editor, _, cx| { assert!(editor.context_menu_visible()); - assert!(!editor.has_active_edit_prediction(),); + assert!(editor.has_active_edit_prediction()); assert_eq!(editor.text(cx), "one\ntwo.\nthree\n"); + assert_eq!(editor.display_text(cx), "one\ntwo.\nthree\n"); }); } @@ -1077,8 +1086,6 @@ mod tests { vec![complete_from_marker.clone(), replace_range_marker.clone()], ); - let complete_from_position = - cx.to_lsp(marked_ranges.remove(&complete_from_marker).unwrap()[0].start); let replace_range = cx.to_lsp_range(marked_ranges.remove(&replace_range_marker).unwrap()[0].clone()); @@ -1087,10 +1094,6 @@ mod tests { let completions = completions.clone(); async move { assert_eq!(params.text_document_position.text_document.uri, url.clone()); - assert_eq!( - params.text_document_position.position, - complete_from_position - ); Ok(Some(lsp::CompletionResponse::Array( completions .iter() diff --git a/crates/dap/src/adapters.rs b/crates/dap/src/adapters.rs index 0c88f37ff8..687305ae94 100644 --- a/crates/dap/src/adapters.rs +++ b/crates/dap/src/adapters.rs @@ -74,6 +74,12 @@ impl Borrow for DebugAdapterName { } } +impl Borrow for DebugAdapterName { + fn borrow(&self) -> &SharedString { + &self.0 + } +} + impl std::fmt::Display for DebugAdapterName { fn 
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(&self.0, f) diff --git a/crates/dap/src/registry.rs b/crates/dap/src/registry.rs index d56e2f8f34..212fa2bc23 100644 --- a/crates/dap/src/registry.rs +++ b/crates/dap/src/registry.rs @@ -87,7 +87,7 @@ impl DapRegistry { self.0.read().adapters.get(name).cloned() } - pub fn enumerate_adapters(&self) -> Vec { + pub fn enumerate_adapters>(&self) -> B { self.0.read().adapters.keys().cloned().collect() } } diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index d81c593484..0ac419580b 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -300,7 +300,7 @@ impl DebugPanel { }); session.update(cx, |session, _| match &mut session.mode { - SessionState::Building(state_task) => { + SessionState::Booting(state_task) => { *state_task = Some(boot_task); } SessionState::Running(_) => { diff --git a/crates/debugger_ui/src/debugger_ui.rs b/crates/debugger_ui/src/debugger_ui.rs index 9eac59af83..5f5dfd1a1e 100644 --- a/crates/debugger_ui/src/debugger_ui.rs +++ b/crates/debugger_ui/src/debugger_ui.rs @@ -299,59 +299,76 @@ pub fn init(cx: &mut App) { else { return; }; + + let session = active_session + .read(cx) + .running_state + .read(cx) + .session() + .read(cx); + + if session.is_terminated() { + return; + } + let editor = cx.entity().downgrade(); - window.on_action(TypeId::of::(), { - let editor = editor.clone(); - let active_session = active_session.clone(); - move |_, phase, _, cx| { - if phase != DispatchPhase::Bubble { - return; - } - maybe!({ - let (buffer, position, _) = editor - .update(cx, |editor, cx| { - let cursor_point: language::Point = - editor.selections.newest(cx).head(); - editor - .buffer() - .read(cx) - .point_to_buffer_point(cursor_point, cx) - }) - .ok()??; + window.on_action_when( + session.any_stopped_thread(), + TypeId::of::(), + { + let editor = editor.clone(); + let active_session = active_session.clone(); + move |_, phase, _, cx| { + if phase != DispatchPhase::Bubble { + return; + } + maybe!({ + let (buffer, position, _) = editor + .update(cx, |editor, cx| { + let cursor_point: language::Point = + editor.selections.newest(cx).head(); - let path = + editor + .buffer() + .read(cx) + .point_to_buffer_point(cursor_point, cx) + }) + .ok()??; + + let path = debugger::breakpoint_store::BreakpointStore::abs_path_from_buffer( &buffer, cx, )?; - let source_breakpoint = SourceBreakpoint { - row: position.row, - path, - message: None, - condition: None, - hit_condition: None, - state: debugger::breakpoint_store::BreakpointState::Enabled, - }; + let source_breakpoint = SourceBreakpoint { + row: position.row, + path, + message: None, + condition: None, + hit_condition: None, + state: debugger::breakpoint_store::BreakpointState::Enabled, + }; - active_session.update(cx, |session, cx| { - session.running_state().update(cx, |state, cx| { - if let Some(thread_id) = state.selected_thread_id() { - state.session().update(cx, |session, cx| { - session.run_to_position( - source_breakpoint, - thread_id, - cx, - ); - }) - } + active_session.update(cx, |session, cx| { + session.running_state().update(cx, |state, cx| { + if let Some(thread_id) = state.selected_thread_id() { + state.session().update(cx, |session, cx| { + session.run_to_position( + source_breakpoint, + thread_id, + cx, + ); + }) + } + }); }); - }); - Some(()) - }); - } - }); + Some(()) + }); + } + }, + ); window.on_action( TypeId::of::(), diff --git 
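// Editorial note (not part of the original patch): `DapRegistry::enumerate_adapters`
// is now generic over the returned collection, so each call site chooses its own
// container. The generic bound did not survive in this copy of the diff; it is
// presumably something along the lines of `B: FromIterator<DebugAdapterName>`.
// Hedged usage sketch mirroring the call sites below:

let adapters: Vec<DebugAdapterName> = registry.enumerate_adapters();
let valid_adapters: HashSet<DebugAdapterName> = registry.enumerate_adapters();
// ...or with a turbofish, as the new_process_modal test does:
let names = registry.enumerate_adapters::<Vec<_>>();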
a/crates/debugger_ui/src/new_process_modal.rs b/crates/debugger_ui/src/new_process_modal.rs index 42f77ab056..4ac8e371a1 100644 --- a/crates/debugger_ui/src/new_process_modal.rs +++ b/crates/debugger_ui/src/new_process_modal.rs @@ -1,5 +1,5 @@ use anyhow::{Context as _, bail}; -use collections::{FxHashMap, HashMap}; +use collections::{FxHashMap, HashMap, HashSet}; use language::LanguageRegistry; use std::{ borrow::Cow, @@ -450,7 +450,7 @@ impl NewProcessModal { .and_then(|buffer| buffer.read(cx).language()) .cloned(); - let mut available_adapters = workspace + let mut available_adapters: Vec<_> = workspace .update(cx, |_, cx| DapRegistry::global(cx).enumerate_adapters()) .unwrap_or_default(); if let Some(language) = active_buffer_language { @@ -1015,15 +1015,13 @@ impl DebugDelegate { let language_names = languages.language_names(); let language = dap_registry .adapter_language(&scenario.adapter) - .map(|language| TaskSourceKind::Language { - name: language.into(), - }); + .map(|language| TaskSourceKind::Language { name: language.0 }); let language = language.or_else(|| { scenario.label.split_whitespace().find_map(|word| { language_names .iter() - .find(|name| name.eq_ignore_ascii_case(word)) + .find(|name| name.as_ref().eq_ignore_ascii_case(word)) .map(|name| TaskSourceKind::Language { name: name.to_owned().into(), }) @@ -1056,6 +1054,9 @@ impl DebugDelegate { }) }) }); + + let valid_adapters: HashSet<_> = cx.global::().enumerate_adapters(); + cx.spawn(async move |this, cx| { let (recent, scenarios) = if let Some(task) = task { task.await @@ -1096,6 +1097,7 @@ impl DebugDelegate { } => !(hide_vscode && dir.ends_with(".vscode")), _ => true, }) + .filter(|(_, scenario)| valid_adapters.contains(&scenario.adapter)) .map(|(kind, scenario)| { let (language, scenario) = Self::get_scenario_kind(&languages, &dap_registry, scenario); diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index 2651a94520..a3e2805e2b 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -1014,10 +1014,9 @@ impl RunningState { ..task.resolved.clone() }; let terminal = project - .update_in(cx, |project, window, cx| { + .update(cx, |project, cx| { project.create_terminal( TerminalKind::Task(task_with_shell.clone()), - window.window_handle(), cx, ) })? 
@@ -1189,9 +1188,7 @@ impl RunningState { let workspace = self.workspace.clone(); let weak_project = project.downgrade(); - let terminal_task = project.update(cx, |project, cx| { - project.create_terminal(kind, window.window_handle(), cx) - }); + let terminal_task = project.update(cx, |project, cx| project.create_terminal(kind, cx)); let terminal_task = cx.spawn_in(window, async move |_, cx| { let terminal = terminal_task.await?; @@ -1651,7 +1648,7 @@ impl RunningState { let is_building = self.session.update(cx, |session, cx| { session.shutdown(cx).detach(); - matches!(session.mode, session::SessionState::Building(_)) + matches!(session.mode, session::SessionState::Booting(_)) }); if is_building { diff --git a/crates/debugger_ui/src/session/running/breakpoint_list.rs b/crates/debugger_ui/src/session/running/breakpoint_list.rs index 6ac4b1c878..a6defbbf35 100644 --- a/crates/debugger_ui/src/session/running/breakpoint_list.rs +++ b/crates/debugger_ui/src/session/running/breakpoint_list.rs @@ -29,7 +29,6 @@ use ui::{ Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, Toggleable, Tooltip, Window, div, h_flex, px, v_flex, }; -use util::ResultExt; use workspace::Workspace; use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint}; @@ -56,8 +55,6 @@ pub(crate) struct BreakpointList { scrollbar_state: ScrollbarState, breakpoints: Vec, session: Option>, - hide_scrollbar_task: Option>, - show_scrollbar: bool, focus_handle: FocusHandle, scroll_handle: UniformListScrollHandle, selected_ix: Option, @@ -103,8 +100,6 @@ impl BreakpointList { worktree_store, scrollbar_state, breakpoints: Default::default(), - hide_scrollbar_task: None, - show_scrollbar: false, workspace, session, focus_handle, @@ -565,21 +560,6 @@ impl BreakpointList { Ok(()) } - fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context) { - const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); - self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| { - cx.background_executor() - .timer(SCROLLBAR_SHOW_INTERVAL) - .await; - panel - .update(cx, |panel, cx| { - panel.show_scrollbar = false; - cx.notify(); - }) - .log_err(); - })) - } - fn render_list(&mut self, cx: &mut Context) -> impl IntoElement { let selected_ix = self.selected_ix; let focus_handle = self.focus_handle.clone(); @@ -614,43 +594,39 @@ impl BreakpointList { .flex_grow() } - fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { - if !(self.show_scrollbar || self.scrollbar_state.is_dragging()) { - return None; - } - Some( - div() - .occlude() - .id("breakpoint-list-vertical-scrollbar") - .on_mouse_move(cx.listener(|_, _, _, cx| { - cx.notify(); - cx.stop_propagation() - })) - .on_hover(|_, _, cx| { + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .occlude() + .id("breakpoint-list-vertical-scrollbar") + .on_mouse_move(cx.listener(|_, _, _, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { cx.stop_propagation(); - }) - .on_any_mouse_down(|_, _, cx| { - cx.stop_propagation(); - }) - .on_mouse_up( - MouseButton::Left, - cx.listener(|_, _, _, cx| { - cx.stop_propagation(); - }), - ) - .on_scroll_wheel(cx.listener(|_, _, _, cx| { - cx.notify(); - })) - .h_full() - .absolute() - .right_1() - .top_1() - .bottom_0() - .w(px(12.)) - .cursor_default() - .children(Scrollbar::vertical(self.scrollbar_state.clone())), - ) + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) } + pub(crate) fn render_control_strip(&self) -> AnyElement { let selection_kind = self.selection_kind(); let focus_handle = self.focus_handle.clone(); @@ -819,15 +795,6 @@ impl Render for BreakpointList { .id("breakpoint-list") .key_context("BreakpointList") .track_focus(&self.focus_handle) - .on_hover(cx.listener(|this, hovered, window, cx| { - if *hovered { - this.show_scrollbar = true; - this.hide_scrollbar_task.take(); - cx.notify(); - } else if !this.focus_handle.contains_focused(window, cx) { - this.hide_scrollbar(window, cx); - } - })) .on_action(cx.listener(Self::select_next)) .on_action(cx.listener(Self::select_previous)) .on_action(cx.listener(Self::select_first)) @@ -844,7 +811,7 @@ impl Render for BreakpointList { v_flex() .size_full() .child(self.render_list(cx)) - .children(self.render_vertical_scrollbar(cx)), + .child(self.render_vertical_scrollbar(cx)), ) .when_some(self.strip_mode, |this, _| { this.child(Divider::horizontal()).child( diff --git a/crates/debugger_ui/src/session/running/loaded_source_list.rs b/crates/debugger_ui/src/session/running/loaded_source_list.rs index dd5487e042..6b376bb892 100644 --- a/crates/debugger_ui/src/session/running/loaded_source_list.rs +++ b/crates/debugger_ui/src/session/running/loaded_source_list.rs @@ -13,22 +13,8 @@ pub(crate) struct LoadedSourceList { impl LoadedSourceList { pub fn new(session: Entity, cx: &mut Context) -> Self { - let weak_entity = cx.weak_entity(); let focus_handle = cx.focus_handle(); - - let list = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, _window, cx| { - weak_entity - .upgrade() - .map(|loaded_sources| { - loaded_sources.update(cx, |this, cx| this.render_entry(ix, cx)) - }) - .unwrap_or(div().into_any()) - }, - ); + let list = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let _subscription = cx.subscribe(&session, |this, _, event, cx| match event { SessionEvent::Stopped(_) | SessionEvent::LoadedSources => { @@ -98,6 +84,12 @@ impl Render for LoadedSourceList { .track_focus(&self.focus_handle) .size_full() .p_1() - .child(list(self.list.clone()).size_full()) + .child( + list( + self.list.clone(), + cx.processor(|this, ix, _window, cx| this.render_entry(ix, cx)), + ) + .size_full(), + ) } } diff --git a/crates/debugger_ui/src/session/running/memory_view.rs b/crates/debugger_ui/src/session/running/memory_view.rs index 7b62a1d55d..75b8938371 100644 --- a/crates/debugger_ui/src/session/running/memory_view.rs +++ 
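// Editorial note (not part of the original patch): BreakpointList here, and MemoryView
// just below, drop their hand-rolled scrollbar visibility state (hover tracking plus a
// delayed hide task) in favour of the scrollbar's own auto-hide. A hedged sketch of the
// essential shape — the return type appears to be `Stateful<Div>` and `Scrollbar::vertical`
// appears to return an `Option`, judging from the `.map(...)`/`.children(...)` call sites,
// but both type details were lost in this copy of the patch:

fn render_vertical_scrollbar_sketch(&self, cx: &mut Context<Self>) -> Stateful<Div> {
    div()
        .id("example-vertical-scrollbar")
        .occlude()
        .h_full()
        .absolute()
        .right_1()
        .top_1()
        .bottom_0()
        .w(px(12.))
        .cursor_default()
        // auto_hide replaces the manual show_scrollbar flag and hide task.
        .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
}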
b/crates/debugger_ui/src/session/running/memory_view.rs @@ -23,7 +23,6 @@ use ui::{ ParentElement, Pixels, PopoverMenuHandle, Render, Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, TextSize, Tooltip, Window, div, h_flex, px, v_flex, }; -use util::ResultExt; use workspace::Workspace; use crate::{ToggleDataBreakpoint, session::running::stack_frame_list::StackFrameList}; @@ -34,9 +33,7 @@ pub(crate) struct MemoryView { workspace: WeakEntity, scroll_handle: UniformListScrollHandle, scroll_state: ScrollbarState, - show_scrollbar: bool, stack_frame_list: WeakEntity, - hide_scrollbar_task: Option>, focus_handle: FocusHandle, view_state: ViewState, query_editor: Entity, @@ -150,8 +147,6 @@ impl MemoryView { scroll_state, scroll_handle, stack_frame_list, - show_scrollbar: false, - hide_scrollbar_task: None, focus_handle: cx.focus_handle(), view_state, query_editor, @@ -168,61 +163,42 @@ impl MemoryView { .detach(); this } - fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context) { - const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); - self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| { - cx.background_executor() - .timer(SCROLLBAR_SHOW_INTERVAL) - .await; - panel - .update(cx, |panel, cx| { - panel.show_scrollbar = false; - cx.notify(); - }) - .log_err(); - })) - } - fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { - if !(self.show_scrollbar || self.scroll_state.is_dragging()) { - return None; - } - Some( - div() - .occlude() - .id("memory-view-vertical-scrollbar") - .on_drag_move(cx.listener(|this, evt, _, cx| { - let did_handle = this.handle_scroll_drag(evt); - cx.notify(); - if did_handle { - cx.stop_propagation() - } - })) - .on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty)) - .on_hover(|_, _, cx| { + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ + div() + .occlude() + .id("memory-view-vertical-scrollbar") + .on_drag_move(cx.listener(|this, evt, _, cx| { + let did_handle = this.handle_scroll_drag(evt); + cx.notify(); + if did_handle { + cx.stop_propagation() + } + })) + .on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty)) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { cx.stop_propagation(); - }) - .on_any_mouse_down(|_, _, cx| { - cx.stop_propagation(); - }) - .on_mouse_up( - MouseButton::Left, - cx.listener(|_, _, _, cx| { - cx.stop_propagation(); - }), - ) - .on_scroll_wheel(cx.listener(|_, _, _, cx| { - cx.notify(); - })) - .h_full() - .absolute() - .right_1() - .top_1() - .bottom_0() - .w(px(12.)) - .cursor_default() - .children(Scrollbar::vertical(self.scroll_state.clone())), - ) + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scroll_state.clone()).map(|s| s.auto_hide(cx))) } fn render_memory(&self, cx: &mut Context) -> UniformList { @@ -920,15 +896,6 @@ impl Render for MemoryView { .on_action(cx.listener(Self::page_up)) .size_full() .track_focus(&self.focus_handle) - .on_hover(cx.listener(|this, hovered, window, cx| { - if *hovered { - this.show_scrollbar = true; - this.hide_scrollbar_task.take(); - cx.notify(); - } else if !this.focus_handle.contains_focused(window, cx) { - this.hide_scrollbar(window, cx); - } - })) .child( h_flex() .w_full() @@ -978,7 +945,7 @@ impl Render for MemoryView { ) .with_priority(1) })) - .children(self.render_vertical_scrollbar(cx)), + .child(self.render_vertical_scrollbar(cx)), ) } } diff --git a/crates/debugger_ui/src/session/running/stack_frame_list.rs b/crates/debugger_ui/src/session/running/stack_frame_list.rs index da3674c8e2..2149502f4a 100644 --- a/crates/debugger_ui/src/session/running/stack_frame_list.rs +++ b/crates/debugger_ui/src/session/running/stack_frame_list.rs @@ -70,13 +70,7 @@ impl StackFrameList { _ => {} }); - let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.), { - let this = cx.weak_entity(); - move |ix, _window, cx| { - this.update(cx, |this, cx| this.render_entry(ix, cx)) - .unwrap_or(div().into_any()) - } - }); + let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let scrollbar_state = ScrollbarState::new(list_state.clone()); let mut this = Self { @@ -708,11 +702,14 @@ impl StackFrameList { self.activate_selected_entry(window, cx); } - fn render_list(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - div() - .p_1() - .size_full() - .child(list(self.list_state.clone()).size_full()) + fn render_list(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + div().p_1().size_full().child( + list( + self.list_state.clone(), + cx.processor(|this, ix, _window, cx| this.render_entry(ix, cx)), + ) + .size_full(), + ) } } diff --git a/crates/debugger_ui/src/session/running/variable_list.rs b/crates/debugger_ui/src/session/running/variable_list.rs index 906e482687..efbc72e8cf 100644 --- a/crates/debugger_ui/src/session/running/variable_list.rs +++ b/crates/debugger_ui/src/session/running/variable_list.rs @@ -1107,7 +1107,7 @@ impl VariableList { let variable_value = value.clone(); this.on_click(cx.listener( move |this, click: &ClickEvent, window, cx| { - if click.down.click_count < 2 { + if 
click.click_count() < 2 { return; } let editor = Self::create_variable_editor( diff --git a/crates/debugger_ui/src/tests/new_process_modal.rs b/crates/debugger_ui/src/tests/new_process_modal.rs index 0805060bf4..d6b0dfa004 100644 --- a/crates/debugger_ui/src/tests/new_process_modal.rs +++ b/crates/debugger_ui/src/tests/new_process_modal.rs @@ -298,7 +298,7 @@ async fn test_dap_adapter_config_conversion_and_validation(cx: &mut TestAppConte let adapter_names = cx.update(|cx| { let registry = DapRegistry::global(cx); - registry.enumerate_adapters() + registry.enumerate_adapters::>() }); let zed_config = ZedDebugConfig { diff --git a/crates/diagnostics/src/diagnostics.rs b/crates/diagnostics/src/diagnostics.rs index ba64ba0eed..e7660920da 100644 --- a/crates/diagnostics/src/diagnostics.rs +++ b/crates/diagnostics/src/diagnostics.rs @@ -177,9 +177,9 @@ impl ProjectDiagnosticsEditor { } project::Event::DiagnosticsUpdated { language_server_id, - path, + paths, } => { - this.paths_to_update.insert(path.clone()); + this.paths_to_update.extend(paths.clone()); let project = project.clone(); this.diagnostic_summary_update = cx.spawn(async move |this, cx| { cx.background_executor() @@ -193,9 +193,9 @@ impl ProjectDiagnosticsEditor { cx.emit(EditorEvent::TitleChanged); if this.editor.focus_handle(cx).contains_focused(window, cx) || this.focus_handle.contains_focused(window, cx) { - log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. recording change"); + log::debug!("diagnostics updated for server {language_server_id}, paths {paths:?}. recording change"); } else { - log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. updating excerpts"); + log::debug!("diagnostics updated for server {language_server_id}, paths {paths:?}. 
updating excerpts"); this.update_stale_excerpts(window, cx); } } diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 1bb84488e8..8fb223b2cb 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -876,7 +876,7 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S vec![Inlay::edit_prediction( post_inc(&mut next_inlay_id), snapshot.buffer_snapshot.anchor_before(position), - format!("Test inlay {next_inlay_id}"), + Rope::from_iter(["Test inlay ", "next_inlay_id"]), )], cx, ); diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index fd4e9bb21d..c8502f75de 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -61,6 +61,10 @@ pub trait EditPredictionProvider: 'static + Sized { fn show_tab_accept_marker() -> bool { false } + fn supports_jump_to_edit() -> bool { + true + } + fn data_collection_state(&self, _cx: &App) -> DataCollectionState { DataCollectionState::Unsupported } @@ -116,6 +120,7 @@ pub trait EditPredictionProviderHandle { ) -> bool; fn show_completions_in_menu(&self) -> bool; fn show_tab_accept_marker(&self) -> bool; + fn supports_jump_to_edit(&self) -> bool; fn data_collection_state(&self, cx: &App) -> DataCollectionState; fn usage(&self, cx: &App) -> Option; fn toggle_data_collection(&self, cx: &mut App); @@ -166,6 +171,10 @@ where T::show_tab_accept_marker() } + fn supports_jump_to_edit(&self) -> bool { + T::supports_jump_to_edit() + } + fn data_collection_state(&self, cx: &App) -> DataCollectionState { self.read(cx).data_collection_state(cx) } diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 9ab94a4095..3d3b43d71b 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -491,7 +491,12 @@ impl EditPredictionButton { let subtle_mode = matches!(current_mode, EditPredictionsMode::Subtle); let eager_mode = matches!(current_mode, EditPredictionsMode::Eager); - if matches!(provider, EditPredictionProvider::Zed) { + if matches!( + provider, + EditPredictionProvider::Zed + | EditPredictionProvider::Copilot + | EditPredictionProvider::Supermaven + ) { menu = menu .separator() .header("Display Modes") diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index 3a3a57ca64..39433b3c27 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -745,5 +745,6 @@ actions!( UniqueLinesCaseInsensitive, /// Removes duplicate lines (case-sensitive). 
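// Editorial note (not part of the original patch): `supports_jump_to_edit` is a new
// associated function on `EditPredictionProvider`, defaulting to `true` and forwarded
// through `EditPredictionProviderHandle`; Copilot (and the fake non-Zed test provider)
// return `false`, and the editor consults it before turning a prediction into a
// "Jump to Edit" move or rendering the jump affordance. Hedged sketch of a provider
// opting out (`MyProvider` is a placeholder; the trait's other required methods are
// elided — see FakeNonZedEditPredictionProvider below for a full impl):

impl EditPredictionProvider for MyProvider {
    fn name() -> &'static str {
        "my-provider"
    }

    fn display_name() -> &'static str {
        "My Provider"
    }

    // Keep predictions inline-only; the editor will not produce Move predictions
    // for this provider.
    fn supports_jump_to_edit() -> bool {
        false
    }

    // ...remaining trait methods elided...
}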
UniqueLinesCaseSensitive, + UnwrapSyntaxNode ] ); diff --git a/crates/editor/src/clangd_ext.rs b/crates/editor/src/clangd_ext.rs index b745bf8c37..3239fdc653 100644 --- a/crates/editor/src/clangd_ext.rs +++ b/crates/editor/src/clangd_ext.rs @@ -29,16 +29,14 @@ pub fn switch_source_header( return; }; - let server_lookup = - find_specific_language_server_in_selection(editor, cx, is_c_language, CLANGD_SERVER_NAME); + let Some((_, _, server_to_query, buffer)) = + find_specific_language_server_in_selection(editor, cx, is_c_language, CLANGD_SERVER_NAME) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((_, _, server_to_query, buffer)) = - server_lookup.await - else { - return Ok(()); - }; let source_file = buffer.read_with(cx, |buffer, _| { buffer.file().map(|file| file.path()).map(|path| path.to_string_lossy().to_string()).unwrap_or_else(|| "Unknown".to_string()) })?; diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index fd49c262c6..b296b3e62a 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -48,16 +48,16 @@ pub struct Inlay { impl Inlay { pub fn hint(id: usize, position: Anchor, hint: &project::InlayHint) -> Self { let mut text = hint.text(); - if hint.padding_right && !text.ends_with(' ') { - text.push(' '); + if hint.padding_right && text.chars_at(text.len().saturating_sub(1)).next() != Some(' ') { + text.push(" "); } - if hint.padding_left && !text.starts_with(' ') { - text.insert(0, ' '); + if hint.padding_left && text.chars_at(0).next() != Some(' ') { + text.push_front(" "); } Self { id: InlayId::Hint(id), position, - text: text.into(), + text, color: None, } } @@ -737,13 +737,13 @@ impl InlayMap { Inlay::mock_hint( post_inc(next_inlay_id), snapshot.buffer.anchor_at(position, bias), - text.clone(), + &text, ) } else { Inlay::edit_prediction( post_inc(next_inlay_id), snapshot.buffer.anchor_at(position, bias), - text.clone(), + &text, ) }; let inlay_id = next_inlay.id; @@ -1694,7 +1694,7 @@ mod tests { (offset, inlay.clone()) }) .collect::>(); - let mut expected_text = Rope::from(buffer_snapshot.text()); + let mut expected_text = Rope::from(&buffer_snapshot.text()); for (offset, inlay) in inlays.iter().rev() { expected_text.replace(*offset..*offset, &inlay.text.to_string()); } diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index 527dfb8832..7bf51e45d7 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -228,6 +228,49 @@ async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext) }); } +#[gpui::test] +async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let provider = cx.new(|_| FakeNonZedEditPredictionProvider::default()); + assign_editor_completion_provider_non_zed(provider.clone(), &mut cx); + + // Cursor is 2+ lines above the proposed edit + cx.set_state(indoc! 
{" + line 0 + line ˇ1 + line 2 + line 3 + line + "}); + + propose_edits_non_zed( + &provider, + vec![(Point::new(4, 3)..Point::new(4, 3), " 4")], + &mut cx, + ); + + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); + + // For non-Zed providers, there should be no move completion (jump functionality disabled) + cx.editor(|editor, _, _| { + if let Some(completion_state) = &editor.active_edit_prediction { + // Should be an Edit prediction, not a Move prediction + match &completion_state.completion { + EditPrediction::Edit { .. } => { + // This is expected for non-Zed providers + } + EditPrediction::Move { .. } => { + panic!( + "Non-Zed providers should not show Move predictions (jump functionality)" + ); + } + } + } + }); +} + fn assert_editor_active_edit_completion( cx: &mut EditorTestContext, assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range, String)>), @@ -301,6 +344,37 @@ fn assign_editor_completion_provider( }) } +fn propose_edits_non_zed( + provider: &Entity, + edits: Vec<(Range, &str)>, + cx: &mut EditorTestContext, +) { + let snapshot = cx.buffer_snapshot(); + let edits = edits.into_iter().map(|(range, text)| { + let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end); + (range, text.into()) + }); + + cx.update(|_, cx| { + provider.update(cx, |provider, _| { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction { + id: None, + edits: edits.collect(), + edit_preview: None, + })) + }) + }); +} + +fn assign_editor_completion_provider_non_zed( + provider: Entity, + cx: &mut EditorTestContext, +) { + cx.update_editor(|editor, window, cx| { + editor.set_edit_prediction_provider(Some(provider), window, cx); + }) +} + #[derive(Default, Clone)] pub struct FakeEditPredictionProvider { pub completion: Option, @@ -325,6 +399,84 @@ impl EditPredictionProvider for FakeEditPredictionProvider { false } + fn supports_jump_to_edit() -> bool { + true + } + + fn is_enabled( + &self, + _buffer: &gpui::Entity, + _cursor_position: language::Anchor, + _cx: &gpui::App, + ) -> bool { + true + } + + fn is_refreshing(&self) -> bool { + false + } + + fn refresh( + &mut self, + _project: Option>, + _buffer: gpui::Entity, + _cursor_position: language::Anchor, + _debounce: bool, + _cx: &mut gpui::Context, + ) { + } + + fn cycle( + &mut self, + _buffer: gpui::Entity, + _cursor_position: language::Anchor, + _direction: edit_prediction::Direction, + _cx: &mut gpui::Context, + ) { + } + + fn accept(&mut self, _cx: &mut gpui::Context) {} + + fn discard(&mut self, _cx: &mut gpui::Context) {} + + fn suggest<'a>( + &mut self, + _buffer: &gpui::Entity, + _cursor_position: language::Anchor, + _cx: &mut gpui::Context, + ) -> Option { + self.completion.clone() + } +} + +#[derive(Default, Clone)] +pub struct FakeNonZedEditPredictionProvider { + pub completion: Option, +} + +impl FakeNonZedEditPredictionProvider { + pub fn set_edit_prediction(&mut self, completion: Option) { + self.completion = completion; + } +} + +impl EditPredictionProvider for FakeNonZedEditPredictionProvider { + fn name() -> &'static str { + "fake-non-zed-provider" + } + + fn display_name() -> &'static str { + "Fake Non-Zed Provider" + } + + fn show_completions_in_menu() -> bool { + false + } + + fn supports_jump_to_edit() -> bool { + false + } + fn is_enabled( &self, _buffer: &gpui::Entity, diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index ff9b703d66..d1bf95c794 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -118,7 
+118,7 @@ use hover_links::{HoverLink, HoveredLinkState, InlayHighlight, find_file}; use hover_popover::{HoverState, hide_hover}; use indent_guides::ActiveIndentGuidesState; use inlay_hint_cache::{InlayHintCache, InlaySplice, InvalidationStrategy}; -use itertools::Itertools; +use itertools::{Either, Itertools}; use language::{ AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, BufferRow, BufferSnapshot, Capability, CharClassifier, CharKind, CodeLabel, CursorShape, DiagnosticEntry, @@ -2705,6 +2705,11 @@ impl Editor { self.completion_provider = provider; } + #[cfg(any(test, feature = "test-support"))] + pub fn completion_provider(&self) -> Option> { + self.completion_provider.clone() + } + pub fn semantics_provider(&self) -> Option> { self.semantics_provider.clone() } @@ -7760,8 +7765,14 @@ impl Editor { } else { None }; - let is_move = - move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode; + let supports_jump = self + .edit_prediction_provider + .as_ref() + .map(|provider| provider.provider.supports_jump_to_edit()) + .unwrap_or(true); + + let is_move = supports_jump + && (move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode); let completion = if is_move { invalidation_row_range = move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row); @@ -8183,7 +8194,7 @@ impl Editor { editor.set_breakpoint_context_menu( row, Some(position), - event.down.position, + event.position(), window, cx, ); @@ -8350,7 +8361,11 @@ impl Editor { .icon_color(color) .toggle_state(is_active) .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| { - let quick_launch = e.down.button == MouseButton::Left; + let quick_launch = match e { + ClickEvent::Keyboard(_) => true, + ClickEvent::Mouse(e) => e.down.button == MouseButton::Left, + }; + window.focus(&editor.focus_handle(cx)); editor.toggle_code_actions( &ToggleCodeActions { @@ -8362,7 +8377,7 @@ impl Editor { ); })) .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| { - editor.set_breakpoint_context_menu(row, position, event.down.position, window, cx); + editor.set_breakpoint_context_menu(row, position, event.position(), window, cx); })) } @@ -8795,8 +8810,12 @@ impl Editor { return None; } - let highlighted_edits = - crate::edit_prediction_edit_text(&snapshot, edits, edit_preview.as_ref()?, false, cx); + let highlighted_edits = if let Some(edit_preview) = edit_preview.as_ref() { + crate::edit_prediction_edit_text(&snapshot, edits, edit_preview, false, cx) + } else { + // Fallback for providers without edit_preview + crate::edit_prediction_fallback_text(edits, cx) + }; let styled_text = highlighted_edits.to_styled_text(&style.text); let line_count = highlighted_edits.text.lines().count(); @@ -9064,6 +9083,18 @@ impl Editor { let editor_bg_color = cx.theme().colors().editor_background; editor_bg_color.blend(accent_color.opacity(0.6)) } + fn get_prediction_provider_icon_name( + provider: &Option, + ) -> IconName { + match provider { + Some(provider) => match provider.provider.name() { + "copilot" => IconName::Copilot, + "supermaven" => IconName::Supermaven, + _ => IconName::ZedPredict, + }, + None => IconName::ZedPredict, + } + } fn render_edit_prediction_cursor_popover( &self, @@ -9076,6 +9107,7 @@ impl Editor { cx: &mut Context, ) -> Option { let provider = self.edit_prediction_provider.as_ref()?; + let provider_icon = Self::get_prediction_provider_icon_name(&self.edit_prediction_provider); if provider.provider.needs_terms_acceptance(cx) { return 
Some( @@ -9102,7 +9134,7 @@ impl Editor { h_flex() .flex_1() .gap_2() - .child(Icon::new(IconName::ZedPredict)) + .child(Icon::new(provider_icon)) .child(Label::new("Accept Terms of Service")) .child(div().w_full()) .child( @@ -9118,12 +9150,8 @@ impl Editor { let is_refreshing = provider.provider.is_refreshing(cx); - fn pending_completion_container() -> Div { - h_flex() - .h_full() - .flex_1() - .gap_2() - .child(Icon::new(IconName::ZedPredict)) + fn pending_completion_container(icon: IconName) -> Div { + h_flex().h_full().flex_1().gap_2().child(Icon::new(icon)) } let completion = match &self.active_edit_prediction { @@ -9153,7 +9181,7 @@ impl Editor { Icon::new(IconName::ZedPredictUp) } } - EditPrediction::Edit { .. } => Icon::new(IconName::ZedPredict), + EditPrediction::Edit { .. } => Icon::new(provider_icon), })) .child( h_flex() @@ -9220,15 +9248,15 @@ impl Editor { cx, )?, - None => { - pending_completion_container().child(Label::new("...").size(LabelSize::Small)) - } + None => pending_completion_container(provider_icon) + .child(Label::new("...").size(LabelSize::Small)), }, - None => pending_completion_container().child(Label::new("No Prediction")), + None => pending_completion_container(provider_icon) + .child(Label::new("...").size(LabelSize::Small)), }; - let completion = if is_refreshing { + let completion = if is_refreshing || self.active_edit_prediction.is_none() { completion .with_animation( "loading-completion", @@ -9328,23 +9356,35 @@ impl Editor { .child(Icon::new(arrow).color(Color::Muted).size(IconSize::Small)) } + let supports_jump = self + .edit_prediction_provider + .as_ref() + .map(|provider| provider.provider.supports_jump_to_edit()) + .unwrap_or(true); + match &completion.completion { EditPrediction::Move { target, snapshot, .. 
- } => Some( - h_flex() - .px_2() - .gap_2() - .flex_1() - .child( - if target.text_anchor.to_point(&snapshot).row > cursor_point.row { - Icon::new(IconName::ZedPredictDown) - } else { - Icon::new(IconName::ZedPredictUp) - }, - ) - .child(Label::new("Jump to Edit")), - ), + } => { + if !supports_jump { + return None; + } + + Some( + h_flex() + .px_2() + .gap_2() + .flex_1() + .child( + if target.text_anchor.to_point(&snapshot).row > cursor_point.row { + Icon::new(IconName::ZedPredictDown) + } else { + Icon::new(IconName::ZedPredictUp) + }, + ) + .child(Label::new("Jump to Edit")), + ) + } EditPrediction::Edit { edits, @@ -9354,14 +9394,13 @@ impl Editor { } => { let first_edit_row = edits.first()?.0.start.text_anchor.to_point(&snapshot).row; - let (highlighted_edits, has_more_lines) = crate::edit_prediction_edit_text( - &snapshot, - &edits, - edit_preview.as_ref()?, - true, - cx, - ) - .first_line_preview(); + let (highlighted_edits, has_more_lines) = + if let Some(edit_preview) = edit_preview.as_ref() { + crate::edit_prediction_edit_text(&snapshot, &edits, edit_preview, true, cx) + .first_line_preview() + } else { + crate::edit_prediction_fallback_text(&edits, cx).first_line_preview() + }; let styled_text = gpui::StyledText::new(highlighted_edits.text) .with_default_highlights(&style.text, highlighted_edits.highlights); @@ -9372,11 +9411,13 @@ impl Editor { .child(styled_text) .when(has_more_lines, |parent| parent.child("…")); - let left = if first_edit_row != cursor_point.row { + let left = if supports_jump && first_edit_row != cursor_point.row { render_relative_row_jump("", cursor_point.row, first_edit_row) .into_any_element() } else { - Icon::new(IconName::ZedPredict).into_any_element() + let icon_name = + Editor::get_prediction_provider_icon_name(&self.edit_prediction_provider); + Icon::new(icon_name).into_any_element() }; Some( @@ -14707,6 +14748,81 @@ impl Editor { } } + pub fn unwrap_syntax_node( + &mut self, + _: &UnwrapSyntaxNode, + window: &mut Window, + cx: &mut Context, + ) { + self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); + + let buffer = self.buffer.read(cx).snapshot(cx); + let old_selections: Box<[_]> = self.selections.all::(cx).into(); + + let edits = old_selections + .iter() + // only consider the first selection for now + .take(1) + .map(|selection| { + // Only requires two branches once if-let-chains stabilize (#53667) + let selection_range = if !selection.is_empty() { + selection.range() + } else if let Some((_, ancestor_range)) = + buffer.syntax_ancestor(selection.start..selection.end) + { + match ancestor_range { + MultiOrSingleBufferOffsetRange::Single(range) => range, + MultiOrSingleBufferOffsetRange::Multi(range) => range, + } + } else { + selection.range() + }; + + let mut new_range = selection_range.clone(); + while let Some((_, ancestor_range)) = buffer.syntax_ancestor(new_range.clone()) { + new_range = match ancestor_range { + MultiOrSingleBufferOffsetRange::Single(range) => range, + MultiOrSingleBufferOffsetRange::Multi(range) => range, + }; + if new_range.start < selection_range.start + || new_range.end > selection_range.end + { + break; + } + } + + (selection, selection_range, new_range) + }) + .collect::>(); + + self.transact(window, cx, |editor, window, cx| { + for (_, child, parent) in &edits { + let text = buffer.text_for_range(child.clone()).collect::(); + editor.replace_text_in_range(Some(parent.clone()), &text, window, cx); + } + + editor.change_selections( + SelectionEffects::scroll(Autoscroll::fit()), + window, + cx, + |s| { + 
s.select( + edits + .iter() + .map(|(s, old, new)| Selection { + id: s.id, + start: new.start, + end: new.start + old.len(), + goal: SelectionGoal::None, + reversed: s.reversed, + }) + .collect(), + ); + }, + ); + }); + } + fn refresh_runnables(&mut self, window: &mut Window, cx: &mut Context) -> Task<()> { if !EditorSettings::get_global(cx).gutter.runnables { self.clear_tasks(); @@ -15550,12 +15666,9 @@ impl Editor { }; let head = self.selections.newest::(cx).head(); let buffer = self.buffer.read(cx); - let (buffer, head) = if let Some(text_anchor) = buffer.text_anchor_for_position(head, cx) { - text_anchor - } else { + let Some((buffer, head)) = buffer.text_anchor_for_position(head, cx) else { return Task::ready(Ok(Navigated::No)); }; - let Some(definitions) = provider.definitions(&buffer, head, kind, cx) else { return Task::ready(Ok(Navigated::No)); }; @@ -15660,62 +15773,109 @@ impl Editor { pub(crate) fn navigate_to_hover_links( &mut self, kind: Option, - mut definitions: Vec, + definitions: Vec, split: bool, window: &mut Window, cx: &mut Context, ) -> Task> { - // If there is one definition, just open it directly - if definitions.len() == 1 { - let definition = definitions.pop().unwrap(); - - enum TargetTaskResult { - Location(Option), - AlreadyNavigated, - } - - let target_task = match definition { - HoverLink::Text(link) => { - Task::ready(anyhow::Ok(TargetTaskResult::Location(Some(link.target)))) - } + // Separate out url and file links, we can only handle one of them at most or an arbitrary number of locations + let mut first_url_or_file = None; + let definitions: Vec<_> = definitions + .into_iter() + .filter_map(|def| match def { + HoverLink::Text(link) => Some(Task::ready(anyhow::Ok(Some(link.target)))), HoverLink::InlayHint(lsp_location, server_id) => { let computation = self.compute_target_location(lsp_location, server_id, window, cx); - cx.background_spawn(async move { - let location = computation.await?; - Ok(TargetTaskResult::Location(location)) - }) + Some(cx.background_spawn(computation)) } HoverLink::Url(url) => { - cx.open_url(&url); - Task::ready(Ok(TargetTaskResult::AlreadyNavigated)) + first_url_or_file = Some(Either::Left(url)); + None } HoverLink::File(path) => { - if let Some(workspace) = self.workspace() { - cx.spawn_in(window, async move |_, cx| { - workspace - .update_in(cx, |workspace, window, cx| { - workspace.open_resolved_path(path, window, cx) - })? - .await - .map(|_| TargetTaskResult::AlreadyNavigated) - }) - } else { - Task::ready(Ok(TargetTaskResult::Location(None))) - } + first_url_or_file = Some(Either::Right(path)); + None } - }; - cx.spawn_in(window, async move |editor, cx| { - let target = match target_task.await.context("target resolution task")? 
{ - TargetTaskResult::AlreadyNavigated => return Ok(Navigated::Yes), - TargetTaskResult::Location(None) => return Ok(Navigated::No), - TargetTaskResult::Location(Some(target)) => target, + }) + .collect(); + + let workspace = self.workspace(); + + cx.spawn_in(window, async move |editor, acx| { + let mut locations: Vec = future::join_all(definitions) + .await + .into_iter() + .filter_map(|location| location.transpose()) + .collect::>() + .context("location tasks")?; + + if locations.len() > 1 { + let Some(workspace) = workspace else { + return Ok(Navigated::No); }; - editor.update_in(cx, |editor, window, cx| { - let Some(workspace) = editor.workspace() else { - return Navigated::No; - }; + let tab_kind = match kind { + Some(GotoDefinitionKind::Implementation) => "Implementations", + _ => "Definitions", + }; + let title = editor + .update_in(acx, |_, _, cx| { + let origin = locations.first().unwrap(); + let buffer = origin.buffer.read(cx); + format!( + "{} for {}", + tab_kind, + buffer + .text_for_range(origin.range.clone()) + .collect::() + ) + }) + .context("buffer title")?; + + let opened = workspace + .update_in(acx, |workspace, window, cx| { + Self::open_locations_in_multibuffer( + workspace, + locations, + title, + split, + MultibufferSelectionMode::First, + window, + cx, + ) + }) + .is_ok(); + + anyhow::Ok(Navigated::from_bool(opened)) + } else if locations.is_empty() { + // If there is one definition, just open it directly + match first_url_or_file { + Some(Either::Left(url)) => { + acx.update(|_, cx| cx.open_url(&url))?; + Ok(Navigated::Yes) + } + Some(Either::Right(path)) => { + let Some(workspace) = workspace else { + return Ok(Navigated::No); + }; + + workspace + .update_in(acx, |workspace, window, cx| { + workspace.open_resolved_path(path, window, cx) + })? 
+ .await?; + Ok(Navigated::Yes) + } + None => Ok(Navigated::No), + } + } else { + let Some(workspace) = workspace else { + return Ok(Navigated::No); + }; + + let target = locations.pop().unwrap(); + editor.update_in(acx, |editor, window, cx| { let pane = workspace.read(cx).active_pane().clone(); let range = target.range.to_point(target.buffer.read(cx)); @@ -15756,81 +15916,8 @@ impl Editor { } Navigated::Yes }) - }) - } else if !definitions.is_empty() { - cx.spawn_in(window, async move |editor, cx| { - let (title, location_tasks, workspace) = editor - .update_in(cx, |editor, window, cx| { - let tab_kind = match kind { - Some(GotoDefinitionKind::Implementation) => "Implementations", - _ => "Definitions", - }; - let title = definitions - .iter() - .find_map(|definition| match definition { - HoverLink::Text(link) => link.origin.as_ref().map(|origin| { - let buffer = origin.buffer.read(cx); - format!( - "{} for {}", - tab_kind, - buffer - .text_for_range(origin.range.clone()) - .collect::() - ) - }), - HoverLink::InlayHint(_, _) => None, - HoverLink::Url(_) => None, - HoverLink::File(_) => None, - }) - .unwrap_or(tab_kind.to_string()); - let location_tasks = definitions - .into_iter() - .map(|definition| match definition { - HoverLink::Text(link) => Task::ready(Ok(Some(link.target))), - HoverLink::InlayHint(lsp_location, server_id) => editor - .compute_target_location(lsp_location, server_id, window, cx), - HoverLink::Url(_) => Task::ready(Ok(None)), - HoverLink::File(_) => Task::ready(Ok(None)), - }) - .collect::>(); - (title, location_tasks, editor.workspace().clone()) - }) - .context("location tasks preparation")?; - - let locations: Vec = future::join_all(location_tasks) - .await - .into_iter() - .filter_map(|location| location.transpose()) - .collect::>() - .context("location tasks")?; - - if locations.is_empty() { - return Ok(Navigated::No); - } - - let Some(workspace) = workspace else { - return Ok(Navigated::No); - }; - - let opened = workspace - .update_in(cx, |workspace, window, cx| { - Self::open_locations_in_multibuffer( - workspace, - locations, - title, - split, - MultibufferSelectionMode::First, - window, - cx, - ) - }) - .ok(); - - anyhow::Ok(Navigated::from_bool(opened.is_some())) - }) - } else { - Task::ready(Ok(Navigated::No)) - } + } + }) } fn compute_target_location( @@ -22188,7 +22275,6 @@ impl SemanticsProvider for Entity { } fn supports_inlay_hints(&self, buffer: &Entity, cx: &mut App) -> bool { - // TODO: make this work for remote projects self.update(cx, |project, cx| { if project .active_debug_session(cx) @@ -23192,6 +23278,33 @@ fn edit_prediction_edit_text( edit_preview.highlight_edits(current_snapshot, &edits, include_deletions, cx) } +fn edit_prediction_fallback_text(edits: &[(Range, String)], cx: &App) -> HighlightedText { + // Fallback for providers that don't provide edit_preview (like Copilot/Supermaven) + // Just show the raw edit text with basic styling + let mut text = String::new(); + let mut highlights = Vec::new(); + + let insertion_highlight_style = HighlightStyle { + color: Some(cx.theme().colors().text), + ..Default::default() + }; + + for (_, edit_text) in edits { + let start_offset = text.len(); + text.push_str(edit_text); + let end_offset = text.len(); + + if start_offset < end_offset { + highlights.push((start_offset..end_offset, insertion_highlight_style)); + } + } + + HighlightedText { + text: text.into(), + highlights, + } +} + pub fn diagnostic_style(severity: lsp::DiagnosticSeverity, colors: &StatusColors) -> Hsla { match severity { 
lsp::DiagnosticSeverity::ERROR => colors.error, diff --git a/crates/editor/src/editor_settings.rs b/crates/editor/src/editor_settings.rs index 14f46c0e60..3d132651b8 100644 --- a/crates/editor/src/editor_settings.rs +++ b/crates/editor/src/editor_settings.rs @@ -20,6 +20,7 @@ pub struct EditorSettings { pub lsp_highlight_debounce: u64, pub hover_popover_enabled: bool, pub hover_popover_delay: u64, + pub status_bar: StatusBar, pub toolbar: Toolbar, pub scrollbar: Scrollbar, pub minimap: Minimap, @@ -125,6 +126,14 @@ pub struct JupyterContent { pub enabled: Option, } +#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub struct StatusBar { + /// Whether to display the active language button in the status bar. + /// + /// Default: true + pub active_language_button: bool, +} + #[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] pub struct Toolbar { pub breadcrumbs: bool, @@ -440,6 +449,8 @@ pub struct EditorSettingsContent { /// /// Default: 300 pub hover_popover_delay: Option, + /// Status bar related settings + pub status_bar: Option, /// Toolbar related settings pub toolbar: Option, /// Scrollbar related settings @@ -567,6 +578,15 @@ pub struct EditorSettingsContent { pub lsp_document_colors: Option, } +// Status bar related settings +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub struct StatusBarContent { + /// Whether to display the active language button in the status bar. + /// + /// Default: true + pub active_language_button: Option, +} + // Toolbar related settings #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] pub struct ToolbarContent { diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 1cb3565733..b31963c9c8 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -7969,6 +7969,38 @@ async fn test_select_larger_smaller_syntax_node_for_string(cx: &mut TestAppConte }); } +#[gpui::test] +async fn test_unwrap_syntax_node(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + let language = Arc::new(Language::new( + LanguageConfig::default(), + Some(tree_sitter_rust::LANGUAGE.into()), + )); + + cx.update_buffer(|buffer, cx| { + buffer.set_language(Some(language), cx); + }); + + cx.set_state( + &r#" + use mod1::mod2::{«mod3ˇ», mod4}; + "# + .unindent(), + ); + cx.update_editor(|editor, window, cx| { + editor.unwrap_syntax_node(&UnwrapSyntaxNode, window, cx); + }); + cx.assert_editor_state( + &r#" + use mod1::mod2::«mod3ˇ»; + "# + .unindent(), + ); +} + #[gpui::test] async fn test_fold_function_bodies(cx: &mut TestAppContext) { init_test(cx, |_| {}); diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 268855ab61..a7fd0abf88 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -43,11 +43,11 @@ use gpui::{ Bounds, ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges, Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, Keystroke, Length, - ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, - ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, - Size, StatefulInteractiveElement, Style, Styled, TextRun, TextStyleRefinement, WeakEntity, - Window, anchored, deferred, div, 
fill, linear_color_stop, linear_gradient, outline, point, px, - quad, relative, size, solid_background, transparent_black, + ModifiersChangedEvent, MouseButton, MouseClickEvent, MouseDownEvent, MouseMoveEvent, + MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, + ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled, TextRun, + TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill, linear_color_stop, + linear_gradient, outline, point, px, quad, relative, size, solid_background, transparent_black, }; use itertools::Itertools; use language::language_settings::{ @@ -86,8 +86,6 @@ use util::post_inc; use util::{RangeExt, ResultExt, debug_panic}; use workspace::{CollaboratorId, Workspace, item::Item, notifications::NotifyTaskExt}; -const INLINE_BLAME_PADDING_EM_WIDTHS: f32 = 7.; - /// Determines what kinds of highlights should be applied to a lines background. #[derive(Clone, Copy, Default)] struct LineHighlightSpec { @@ -357,6 +355,7 @@ impl EditorElement { register_action(editor, window, Editor::toggle_comments); register_action(editor, window, Editor::select_larger_syntax_node); register_action(editor, window, Editor::select_smaller_syntax_node); + register_action(editor, window, Editor::unwrap_syntax_node); register_action(editor, window, Editor::select_enclosing_symbol); register_action(editor, window, Editor::move_to_enclosing_bracket); register_action(editor, window, Editor::undo_selection); @@ -949,8 +948,12 @@ impl EditorElement { let hovered_link_modifier = Editor::multi_cursor_modifier(false, &event.modifiers(), cx); - if !pending_nonempty_selections && hovered_link_modifier && text_hitbox.is_hovered(window) { - let point = position_map.point_for_position(event.up.position); + if let Some(mouse_position) = event.mouse_position() + && !pending_nonempty_selections + && hovered_link_modifier + && text_hitbox.is_hovered(window) + { + let point = position_map.point_for_position(mouse_position); editor.handle_click_hovered_link(point, event.modifiers(), window, cx); editor.selection_drag_state = SelectionDragState::None; @@ -2423,10 +2426,13 @@ impl EditorElement { let editor = self.editor.read(cx); let blame = editor.blame.clone()?; let padding = { - const INLINE_BLAME_PADDING_EM_WIDTHS: f32 = 6.; const INLINE_ACCEPT_SUGGESTION_EM_WIDTHS: f32 = 14.; - let mut padding = INLINE_BLAME_PADDING_EM_WIDTHS; + let mut padding = ProjectSettings::get_global(cx) + .git + .inline_blame + .unwrap_or_default() + .padding as f32; if let Some(edit_prediction) = editor.active_edit_prediction.as_ref() { match &edit_prediction.completion { @@ -2464,7 +2470,7 @@ impl EditorElement { let min_column_in_pixels = ProjectSettings::get_global(cx) .git .inline_blame - .and_then(|settings| settings.min_column) + .map(|settings| settings.min_column) .map(|col| self.column_pixels(col as usize, window)) .unwrap_or(px(0.)); let min_start = content_origin.x - scroll_pixel_position.x + min_column_in_pixels; @@ -3677,6 +3683,7 @@ impl EditorElement { .id("path header block") .size_full() .justify_between() + .overflow_hidden() .child( h_flex() .gap_2() @@ -3735,7 +3742,7 @@ impl EditorElement { move |editor, e: &ClickEvent, window, cx| { editor.open_excerpts_common( Some(jump_data.clone()), - e.down.modifiers.secondary(), + e.modifiers().secondary(), window, cx, ); @@ -6882,10 +6889,10 @@ impl EditorElement { // Fire click handlers during the bubble phase. 
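// Editorial note (not part of the original patch): `ClickEvent` is now an enum over
// mouse- and keyboard-originated clicks (see the `ClickEvent::Mouse(MouseClickEvent { .. })`
// construction just below), so call sites in this patch stop reaching into `event.down.*`
// and use the accessors `position()`, `modifiers()` and `click_count()`, or match on the
// variant when the physical mouse button matters. Hedged sketch:

fn on_click_sketch(event: &ClickEvent) {
    let _position = event.position();               // replaces event.down.position
    let _secondary = event.modifiers().secondary(); // replaces event.down.modifiers.secondary()
    let quick_launch = match event {
        ClickEvent::Keyboard(_) => true,
        ClickEvent::Mouse(e) => e.down.button == MouseButton::Left,
    };
    let _ = quick_launch;
}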
DispatchPhase::Bubble => editor.update(cx, |editor, cx| { if let Some(mouse_down) = captured_mouse_down.take() { - let event = ClickEvent { + let event = ClickEvent::Mouse(MouseClickEvent { down: mouse_down, up: event.clone(), - }; + }); Self::click(editor, &event, &position_map, window, cx); } }), @@ -8024,12 +8031,20 @@ impl Element for EditorElement { autoscroll_containing_element, needs_horizontal_autoscroll, ) = self.editor.update(cx, |editor, cx| { - let autoscroll_request = editor.autoscroll_request(); + let autoscroll_request = editor.scroll_manager.take_autoscroll_request(); + let autoscroll_containing_element = autoscroll_request.is_some() || editor.has_pending_selection(); let (needs_horizontal_autoscroll, was_scrolled) = editor - .autoscroll_vertically(bounds, line_height, max_scroll_top, window, cx); + .autoscroll_vertically( + bounds, + line_height, + max_scroll_top, + autoscroll_request, + window, + cx, + ); if was_scrolled.0 { snapshot = editor.snapshot(window, cx); } @@ -8351,7 +8366,13 @@ impl Element for EditorElement { }) .flatten()?; let mut element = render_inline_blame_entry(blame_entry, &style, cx)?; - let inline_blame_padding = INLINE_BLAME_PADDING_EM_WIDTHS * em_advance; + let inline_blame_padding = ProjectSettings::get_global(cx) + .git + .inline_blame + .unwrap_or_default() + .padding + as f32 + * em_advance; Some( element .layout_as_root(AvailableSpace::min_size(), window, cx) @@ -8419,7 +8440,11 @@ impl Element for EditorElement { Ok(blocks) => blocks, Err(resized_blocks) => { self.editor.update(cx, |editor, cx| { - editor.resize_blocks(resized_blocks, autoscroll_request, cx) + editor.resize_blocks( + resized_blocks, + autoscroll_request.map(|(autoscroll, _)| autoscroll), + cx, + ) }); return self.prepaint(None, _inspector_id, bounds, &mut (), window, cx); } @@ -8464,6 +8489,7 @@ impl Element for EditorElement { scroll_width, em_advance, &line_layouts, + autoscroll_request, window, cx, ) diff --git a/crates/editor/src/inlay_hint_cache.rs b/crates/editor/src/inlay_hint_cache.rs index db01cc7ad1..60ad0e5bf6 100644 --- a/crates/editor/src/inlay_hint_cache.rs +++ b/crates/editor/src/inlay_hint_cache.rs @@ -3546,7 +3546,7 @@ pub mod tests { let excerpt_hints = excerpt_hints.read(); for id in &excerpt_hints.ordered_hints { let hint = &excerpt_hints.hints_by_id[id]; - let mut label = hint.text(); + let mut label = hint.text().to_string(); if hint.padding_left { label.insert(0, ' '); } diff --git a/crates/editor/src/lsp_ext.rs b/crates/editor/src/lsp_ext.rs index 8d078f304c..6161afbbc0 100644 --- a/crates/editor/src/lsp_ext.rs +++ b/crates/editor/src/lsp_ext.rs @@ -3,9 +3,8 @@ use std::time::Duration; use crate::Editor; use collections::HashMap; -use futures::stream::FuturesUnordered; use gpui::AsyncApp; -use gpui::{App, AppContext as _, Entity, Task}; +use gpui::{App, Entity, Task}; use itertools::Itertools; use language::Buffer; use language::Language; @@ -18,7 +17,6 @@ use project::Project; use project::TaskSourceKind; use project::lsp_store::lsp_ext_command::GetLspRunnables; use smol::future::FutureExt as _; -use smol::stream::StreamExt; use task::ResolvedTask; use task::TaskContext; use text::BufferId; @@ -29,52 +27,32 @@ pub(crate) fn find_specific_language_server_in_selection( editor: &Editor, cx: &mut App, filter_language: F, - language_server_name: &str, -) -> Task, LanguageServerId, Entity)>> + language_server_name: LanguageServerName, +) -> Option<(Anchor, Arc, LanguageServerId, Entity)> where F: Fn(&Language) -> bool, { - let Some(project) = 
&editor.project else { - return Task::ready(None); - }; - - let applicable_buffers = editor + let project = editor.project.clone()?; + editor .selections .disjoint_anchors() .iter() .filter_map(|selection| Some((selection.head(), selection.head().buffer_id?))) .unique_by(|(_, buffer_id)| *buffer_id) - .filter_map(|(trigger_anchor, buffer_id)| { + .find_map(|(trigger_anchor, buffer_id)| { let buffer = editor.buffer().read(cx).buffer(buffer_id)?; let language = buffer.read(cx).language_at(trigger_anchor.text_anchor)?; if filter_language(&language) { - Some((trigger_anchor, buffer, language)) + let server_id = buffer.update(cx, |buffer, cx| { + project + .read(cx) + .language_server_id_for_name(buffer, &language_server_name, cx) + })?; + Some((trigger_anchor, language, server_id, buffer)) } else { None } }) - .collect::>(); - - let applicable_buffer_tasks = applicable_buffers - .into_iter() - .map(|(trigger_anchor, buffer, language)| { - let task = buffer.update(cx, |buffer, cx| { - project.update(cx, |project, cx| { - project.language_server_id_for_name(buffer, language_server_name, cx) - }) - }); - (trigger_anchor, buffer, language, task) - }) - .collect::>(); - cx.background_spawn(async move { - for (trigger_anchor, buffer, language, task) in applicable_buffer_tasks { - if let Some(server_id) = task.await { - return Some((trigger_anchor, language, server_id, buffer)); - } - } - - None - }) } async fn lsp_task_context( @@ -116,9 +94,9 @@ pub fn lsp_tasks( for_position: Option, cx: &mut App, ) -> Task, ResolvedTask)>)>> { - let mut lsp_task_sources = task_sources + let lsp_task_sources = task_sources .iter() - .map(|(name, buffer_ids)| { + .filter_map(|(name, buffer_ids)| { let buffers = buffer_ids .iter() .filter(|&&buffer_id| match for_position { @@ -127,61 +105,63 @@ pub fn lsp_tasks( }) .filter_map(|&buffer_id| project.read(cx).buffer_for_id(buffer_id, cx)) .collect::>(); - language_server_for_buffers(project.clone(), name.clone(), buffers, cx) + + let server_id = buffers.iter().find_map(|buffer| { + project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), name, cx) + }) + }); + server_id.zip(Some(buffers)) }) - .collect::>(); + .collect::>(); cx.spawn(async move |cx| { cx.spawn(async move |cx| { let mut lsp_tasks = HashMap::default(); - while let Some(server_to_query) = lsp_task_sources.next().await { - if let Some((server_id, buffers)) = server_to_query { - let mut new_lsp_tasks = Vec::new(); - for buffer in buffers { - let source_kind = match buffer.update(cx, |buffer, _| { - buffer.language().map(|language| language.name()) - }) { - Ok(Some(language_name)) => TaskSourceKind::Lsp { - server: server_id, - language_name: SharedString::from(language_name), - }, - Ok(None) => continue, - Err(_) => return Vec::new(), - }; - let id_base = source_kind.to_id_base(); - let lsp_buffer_context = lsp_task_context(&project, &buffer, cx) - .await - .unwrap_or_default(); + for (server_id, buffers) in lsp_task_sources { + let mut new_lsp_tasks = Vec::new(); + for buffer in buffers { + let source_kind = match buffer.update(cx, |buffer, _| { + buffer.language().map(|language| language.name()) + }) { + Ok(Some(language_name)) => TaskSourceKind::Lsp { + server: server_id, + language_name: SharedString::from(language_name), + }, + Ok(None) => continue, + Err(_) => return Vec::new(), + }; + let id_base = source_kind.to_id_base(); + let lsp_buffer_context = lsp_task_context(&project, &buffer, cx) + .await + .unwrap_or_default(); - if let Ok(runnables_task) = 
project.update(cx, |project, cx| { - let buffer_id = buffer.read(cx).remote_id(); - project.request_lsp( - buffer, - LanguageServerToQuery::Other(server_id), - GetLspRunnables { - buffer_id, - position: for_position, + if let Ok(runnables_task) = project.update(cx, |project, cx| { + let buffer_id = buffer.read(cx).remote_id(); + project.request_lsp( + buffer, + LanguageServerToQuery::Other(server_id), + GetLspRunnables { + buffer_id, + position: for_position, + }, + cx, + ) + }) { + if let Some(new_runnables) = runnables_task.await.log_err() { + new_lsp_tasks.extend(new_runnables.runnables.into_iter().filter_map( + |(location, runnable)| { + let resolved_task = + runnable.resolve_task(&id_base, &lsp_buffer_context)?; + Some((location, resolved_task)) }, - cx, - ) - }) { - if let Some(new_runnables) = runnables_task.await.log_err() { - new_lsp_tasks.extend( - new_runnables.runnables.into_iter().filter_map( - |(location, runnable)| { - let resolved_task = runnable - .resolve_task(&id_base, &lsp_buffer_context)?; - Some((location, resolved_task)) - }, - ), - ); - } + )); } - lsp_tasks - .entry(source_kind) - .or_insert_with(Vec::new) - .append(&mut new_lsp_tasks); } + lsp_tasks + .entry(source_kind) + .or_insert_with(Vec::new) + .append(&mut new_lsp_tasks); } } lsp_tasks.into_iter().collect() @@ -198,27 +178,3 @@ pub fn lsp_tasks( .await }) } - -fn language_server_for_buffers( - project: Entity, - name: LanguageServerName, - candidates: Vec>, - cx: &mut App, -) -> Task>)>> { - cx.spawn(async move |cx| { - for buffer in &candidates { - let server_id = buffer - .update(cx, |buffer, cx| { - project.update(cx, |project, cx| { - project.language_server_id_for_name(buffer, &name.0, cx) - }) - }) - .ok()? - .await; - if let Some(server_id) = server_id { - return Some((server_id, candidates)); - } - } - None - }) -} diff --git a/crates/editor/src/mouse_context_menu.rs b/crates/editor/src/mouse_context_menu.rs index cbb6791a2f..9d5145dec1 100644 --- a/crates/editor/src/mouse_context_menu.rs +++ b/crates/editor/src/mouse_context_menu.rs @@ -1,8 +1,8 @@ use crate::{ Copy, CopyAndTrim, CopyPermalinkToLine, Cut, DisplayPoint, DisplaySnapshot, Editor, EvaluateSelectedText, FindAllReferences, GoToDeclaration, GoToDefinition, GoToImplementation, - GoToTypeDefinition, Paste, Rename, RevealInFileManager, SelectMode, SelectionEffects, - SelectionExt, ToDisplayPoint, ToggleCodeActions, + GoToTypeDefinition, Paste, Rename, RevealInFileManager, RunToCursor, SelectMode, + SelectionEffects, SelectionExt, ToDisplayPoint, ToggleCodeActions, actions::{Format, FormatSelections}, selections_collection::SelectionsCollection, }; @@ -200,15 +200,21 @@ pub fn deploy_context_menu( }); let evaluate_selection = window.is_action_available(&EvaluateSelectedText, cx); + let run_to_cursor = window.is_action_available(&RunToCursor, cx); ui::ContextMenu::build(window, cx, |menu, _window, _cx| { let builder = menu .on_blur_subscription(Subscription::new(|| {})) - .when(evaluate_selection && has_selections, |builder| { - builder - .action("Evaluate Selection", Box::new(EvaluateSelectedText)) - .separator() + .when(run_to_cursor, |builder| { + builder.action("Run to Cursor", Box::new(RunToCursor)) }) + .when(evaluate_selection && has_selections, |builder| { + builder.action("Evaluate Selection", Box::new(EvaluateSelectedText)) + }) + .when( + run_to_cursor || (evaluate_selection && has_selections), + |builder| builder.separator(), + ) .action("Go to Definition", Box::new(GoToDefinition)) .action("Go to Declaration", 
Box::new(GoToDeclaration)) .action("Go to Type Definition", Box::new(GoToTypeDefinition)) diff --git a/crates/editor/src/rust_analyzer_ext.rs b/crates/editor/src/rust_analyzer_ext.rs index da0f11036f..2b8150de67 100644 --- a/crates/editor/src/rust_analyzer_ext.rs +++ b/crates/editor/src/rust_analyzer_ext.rs @@ -57,21 +57,21 @@ pub fn go_to_parent_module( return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); + let Some((trigger_anchor, _, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let lsp_store = project.read(cx).lsp_store(); let upstream_client = lsp_store.read(cx).upstream_client(); cx.spawn_in(window, async move |editor, cx| { - let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else { - return anyhow::Ok(()); - }; - let location_links = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id())?; @@ -121,7 +121,7 @@ pub fn go_to_parent_module( ) })? .await?; - Ok(()) + anyhow::Ok(()) }) .detach_and_log_err(cx); } @@ -139,21 +139,19 @@ pub fn expand_macro_recursively( return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); - + let Some((trigger_anchor, rust_language, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((trigger_anchor, rust_language, server_to_query, buffer)) = server_lookup.await - else { - return Ok(()); - }; - let macro_expansion = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.update(cx, |buffer, _| buffer.remote_id())?; let request = proto::LspExtExpandMacro { @@ -231,20 +229,20 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu return; }; - let server_lookup = find_specific_language_server_in_selection( - editor, - cx, - is_rust_language, - RUST_ANALYZER_NAME, - ); + let Some((trigger_anchor, _, server_to_query, buffer)) = + find_specific_language_server_in_selection( + editor, + cx, + is_rust_language, + RUST_ANALYZER_NAME, + ) + else { + return; + }; let project = project.clone(); let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); cx.spawn_in(window, async move |_editor, cx| { - let Some((trigger_anchor, _, server_to_query, buffer)) = server_lookup.await else { - return Ok(()); - }; - let docs_urls = if let Some((client, project_id)) = upstream_client { let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id())?; let request = proto::LspExtOpenDocs { diff --git a/crates/editor/src/scroll.rs b/crates/editor/src/scroll.rs index ecaf7c11e4..08ff23f8f7 100644 --- a/crates/editor/src/scroll.rs +++ b/crates/editor/src/scroll.rs @@ -348,8 +348,8 @@ impl ScrollManager { self.show_scrollbars } - pub fn autoscroll_request(&self) -> Option { - self.autoscroll_request.map(|(autoscroll, _)| autoscroll) + pub fn take_autoscroll_request(&mut self) -> Option<(Autoscroll, bool)> { + self.autoscroll_request.take() } pub fn active_scrollbar_state(&self) -> Option<&ActiveScrollbarState> { diff --git 
a/crates/editor/src/scroll/autoscroll.rs b/crates/editor/src/scroll/autoscroll.rs index e8a1f8da73..88d3b52d76 100644 --- a/crates/editor/src/scroll/autoscroll.rs +++ b/crates/editor/src/scroll/autoscroll.rs @@ -102,15 +102,12 @@ impl AutoscrollStrategy { pub(crate) struct NeedsHorizontalAutoscroll(pub(crate) bool); impl Editor { - pub fn autoscroll_request(&self) -> Option { - self.scroll_manager.autoscroll_request() - } - pub(crate) fn autoscroll_vertically( &mut self, bounds: Bounds, line_height: Pixels, max_scroll_top: f32, + autoscroll_request: Option<(Autoscroll, bool)>, window: &mut Window, cx: &mut Context, ) -> (NeedsHorizontalAutoscroll, WasScrolled) { @@ -137,7 +134,7 @@ impl Editor { WasScrolled(false) }; - let Some((autoscroll, local)) = self.scroll_manager.autoscroll_request.take() else { + let Some((autoscroll, local)) = autoscroll_request else { return (NeedsHorizontalAutoscroll(false), editor_was_scrolled); }; @@ -284,9 +281,12 @@ impl Editor { scroll_width: Pixels, em_advance: Pixels, layouts: &[LineWithInvisibles], + autoscroll_request: Option<(Autoscroll, bool)>, window: &mut Window, cx: &mut Context, ) -> Option> { + let (_, local) = autoscroll_request?; + let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); let selections = self.selections.all::(cx); let mut scroll_position = self.scroll_manager.scroll_position(&display_map); @@ -335,10 +335,10 @@ impl Editor { let was_scrolled = if target_left < scroll_left { scroll_position.x = target_left / em_advance; - self.set_scroll_position_internal(scroll_position, true, true, window, cx) + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } else if target_right > scroll_right { scroll_position.x = (target_right - viewport_width) / em_advance; - self.set_scroll_position_internal(scroll_position, true, true, window, cx) + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } else { WasScrolled(false) }; diff --git a/crates/editor/src/signature_help.rs b/crates/editor/src/signature_help.rs index 3447e66ccd..e9f8d2dbd3 100644 --- a/crates/editor/src/signature_help.rs +++ b/crates/editor/src/signature_help.rs @@ -191,7 +191,7 @@ impl Editor { if let Some(language) = language { for signature in &mut signature_help.signatures { - let text = Rope::from(signature.label.to_string()); + let text = Rope::from(signature.label.as_ref()); let highlights = language .highlight_text(&text, 0..signature.label.len()) .into_iter() diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index 891ab91852..c31774c20d 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -10,7 +10,7 @@ use fs::{FakeFs, Fs, RealFs}; use futures::{AsyncReadExt, StreamExt, io::BufReader}; use gpui::{AppContext as _, SemanticVersion, TestAppContext}; use http_client::{FakeHttpClient, Response}; -use language::{BinaryStatus, LanguageMatcher, LanguageRegistry}; +use language::{BinaryStatus, LanguageMatcher, LanguageName, LanguageRegistry}; use language_extension::LspAccess; use lsp::LanguageServerName; use node_runtime::NodeRuntime; @@ -306,7 +306,11 @@ async fn test_extension_store(cx: &mut TestAppContext) { assert_eq!( language_registry.language_names(), - ["ERB", "Plain Text", "Ruby"] + [ + LanguageName::new("ERB"), + LanguageName::new("Plain Text"), + LanguageName::new("Ruby"), + ] ); assert_eq!( theme_registry.list_names(), @@ -458,7 +462,11 @@ async fn 
test_extension_store(cx: &mut TestAppContext) { assert_eq!( language_registry.language_names(), - ["ERB", "Plain Text", "Ruby"] + [ + LanguageName::new("ERB"), + LanguageName::new("Plain Text"), + LanguageName::new("Ruby"), + ] ); assert_eq!( language_registry.grammar_names(), @@ -513,7 +521,10 @@ async fn test_extension_store(cx: &mut TestAppContext) { assert_eq!(actual_language.hidden, expected_language.hidden); } - assert_eq!(language_registry.language_names(), ["Plain Text"]); + assert_eq!( + language_registry.language_names(), + [LanguageName::new("Plain Text")] + ); assert_eq!(language_registry.grammar_names(), []); }); } diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 631bafc841..ef357adf35 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -158,6 +158,11 @@ where } } +#[derive(Debug)] +pub struct OnFlagsReady { + pub is_staff: bool, +} + pub trait FeatureFlagAppExt { fn wait_for_flag(&mut self) -> WaitForFlag; @@ -169,6 +174,10 @@ pub trait FeatureFlagAppExt { fn has_flag(&self) -> bool; fn is_staff(&self) -> bool; + fn on_flags_ready(&mut self, callback: F) -> Subscription + where + F: FnMut(OnFlagsReady, &mut App) + 'static; + fn observe_flag(&mut self, callback: F) -> Subscription where F: FnMut(bool, &mut App) + 'static; @@ -198,6 +207,21 @@ impl FeatureFlagAppExt for App { .unwrap_or(false) } + fn on_flags_ready(&mut self, mut callback: F) -> Subscription + where + F: FnMut(OnFlagsReady, &mut App) + 'static, + { + self.observe_global::(move |cx| { + let feature_flags = cx.global::(); + callback( + OnFlagsReady { + is_staff: feature_flags.staff, + }, + cx, + ); + }) + } + fn observe_flag(&mut self, mut callback: F) -> Subscription where F: FnMut(bool, &mut App) + 'static, diff --git a/crates/fs/src/fs_watcher.rs b/crates/fs/src/fs_watcher.rs index 9fdf2ad0b1..a5ce21294f 100644 --- a/crates/fs/src/fs_watcher.rs +++ b/crates/fs/src/fs_watcher.rs @@ -1,6 +1,9 @@ use notify::EventKind; use parking_lot::Mutex; -use std::sync::{Arc, OnceLock}; +use std::{ + collections::HashMap, + sync::{Arc, OnceLock}, +}; use util::{ResultExt, paths::SanitizedPath}; use crate::{PathEvent, PathEventKind, Watcher}; @@ -8,6 +11,7 @@ use crate::{PathEvent, PathEventKind, Watcher}; pub struct FsWatcher { tx: smol::channel::Sender<()>, pending_path_events: Arc>>, + registrations: Mutex, WatcherRegistrationId>>, } impl FsWatcher { @@ -18,10 +22,24 @@ impl FsWatcher { Self { tx, pending_path_events, + registrations: Default::default(), } } } +impl Drop for FsWatcher { + fn drop(&mut self) { + let mut registrations = self.registrations.lock(); + let registrations = registrations.drain(); + + let _ = global(|g| { + for (_, registration) in registrations { + g.remove(registration); + } + }); + } +} + impl Watcher for FsWatcher { fn add(&self, path: &std::path::Path) -> anyhow::Result<()> { let root_path = SanitizedPath::from(path); @@ -29,75 +47,143 @@ impl Watcher for FsWatcher { let tx = self.tx.clone(); let pending_paths = self.pending_path_events.clone(); - use notify::Watcher; + let path: Arc = path.into(); - global({ + if self.registrations.lock().contains_key(&path) { + return Ok(()); + } + + let registration_id = global({ + let path = path.clone(); |g| { - g.add(move |event: ¬ify::Event| { - let kind = match event.kind { - EventKind::Create(_) => Some(PathEventKind::Created), - EventKind::Modify(_) => Some(PathEventKind::Changed), - EventKind::Remove(_) => Some(PathEventKind::Removed), - _ 
=> None, - }; - let mut path_events = event - .paths - .iter() - .filter_map(|event_path| { - let event_path = SanitizedPath::from(event_path); - event_path.starts_with(&root_path).then(|| PathEvent { - path: event_path.as_path().to_path_buf(), - kind, + g.add( + path, + notify::RecursiveMode::NonRecursive, + move |event: ¬ify::Event| { + let kind = match event.kind { + EventKind::Create(_) => Some(PathEventKind::Created), + EventKind::Modify(_) => Some(PathEventKind::Changed), + EventKind::Remove(_) => Some(PathEventKind::Removed), + _ => None, + }; + let mut path_events = event + .paths + .iter() + .filter_map(|event_path| { + let event_path = SanitizedPath::from(event_path); + event_path.starts_with(&root_path).then(|| PathEvent { + path: event_path.as_path().to_path_buf(), + kind, + }) }) - }) - .collect::>(); + .collect::>(); - if !path_events.is_empty() { - path_events.sort(); - let mut pending_paths = pending_paths.lock(); - if pending_paths.is_empty() { - tx.try_send(()).ok(); + if !path_events.is_empty() { + path_events.sort(); + let mut pending_paths = pending_paths.lock(); + if pending_paths.is_empty() { + tx.try_send(()).ok(); + } + util::extend_sorted( + &mut *pending_paths, + path_events, + usize::MAX, + |a, b| a.path.cmp(&b.path), + ); } - util::extend_sorted( - &mut *pending_paths, - path_events, - usize::MAX, - |a, b| a.path.cmp(&b.path), - ); - } - }) + }, + ) } - })?; - - global(|g| { - g.watcher - .lock() - .watch(path, notify::RecursiveMode::NonRecursive) })??; + self.registrations.lock().insert(path, registration_id); + Ok(()) } fn remove(&self, path: &std::path::Path) -> anyhow::Result<()> { - use notify::Watcher; - Ok(global(|w| w.watcher.lock().unwatch(path))??) + let Some(registration) = self.registrations.lock().remove(path) else { + return Ok(()); + }; + + global(|w| w.remove(registration)) } } +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct WatcherRegistrationId(u32); + +struct WatcherRegistrationState { + callback: Arc, + path: Arc, +} + +struct WatcherState { + watchers: HashMap, + path_registrations: HashMap, u32>, + last_registration: WatcherRegistrationId, +} + pub struct GlobalWatcher { + state: Mutex, + + // DANGER: never keep the state lock while holding the watcher lock // two mutexes because calling watcher.add triggers an watcher.event, which needs watchers. 
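For context on the fs watcher rework above: callbacks are now tracked per path behind a `WatcherRegistrationId`, so `FsWatcher::remove` (and its new `Drop` impl) detaches exactly what it registered, and the underlying path is only unwatched once its last registration is gone. A rough crate-internal sketch of that lifecycle, assuming the `global(..)` helper and the `add`/`remove` methods introduced in this diff:

```rust
// Sketch (crate-internal, error handling elided to `?`):
let path: std::sync::Arc<std::path::Path> = std::path::Path::new("/some/project").into();

// `global` returns anyhow::Result<T> and `add` itself returns
// anyhow::Result<WatcherRegistrationId>, hence the double `?`.
let registration = global({
    let path = path.clone();
    move |g| {
        g.add(path, notify::RecursiveMode::NonRecursive, |_event: &notify::Event| {
            // react to events scoped to this registration
        })
    }
})??;

// Later (FsWatcher::remove / Drop): drops the registration; the notify
// watch is removed only when the per-path refcount reaches zero.
global(|g| g.remove(registration))?;
```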
#[cfg(target_os = "linux")] - pub(super) watcher: Mutex, + watcher: Mutex, #[cfg(target_os = "freebsd")] - pub(super) watcher: Mutex, + watcher: Mutex, #[cfg(target_os = "windows")] - pub(super) watcher: Mutex, - pub(super) watchers: Mutex>>, + watcher: Mutex, } impl GlobalWatcher { - pub(super) fn add(&self, cb: impl Fn(¬ify::Event) + Send + Sync + 'static) { - self.watchers.lock().push(Box::new(cb)) + #[must_use] + fn add( + &self, + path: Arc, + mode: notify::RecursiveMode, + cb: impl Fn(¬ify::Event) + Send + Sync + 'static, + ) -> anyhow::Result { + use notify::Watcher; + + self.watcher.lock().watch(&path, mode)?; + + let mut state = self.state.lock(); + + let id = state.last_registration; + state.last_registration = WatcherRegistrationId(id.0 + 1); + + let registration_state = WatcherRegistrationState { + callback: Arc::new(cb), + path: path.clone(), + }; + state.watchers.insert(id, registration_state); + *state.path_registrations.entry(path.clone()).or_insert(0) += 1; + + Ok(id) + } + + pub fn remove(&self, id: WatcherRegistrationId) { + use notify::Watcher; + let mut state = self.state.lock(); + let Some(registration_state) = state.watchers.remove(&id) else { + return; + }; + + let Some(count) = state.path_registrations.get_mut(®istration_state.path) else { + return; + }; + *count -= 1; + if *count == 0 { + state.path_registrations.remove(®istration_state.path); + + drop(state); + self.watcher + .lock() + .unwatch(®istration_state.path) + .log_err(); + } } } @@ -114,8 +200,16 @@ fn handle_event(event: Result) { return; }; global::<()>(move |watcher| { - for f in watcher.watchers.lock().iter() { - f(&event) + let callbacks = { + let state = watcher.state.lock(); + state + .watchers + .values() + .map(|r| r.callback.clone()) + .collect::>() + }; + for callback in callbacks { + callback(&event); } }) .log_err(); @@ -124,8 +218,12 @@ fn handle_event(event: Result) { pub fn global(f: impl FnOnce(&GlobalWatcher) -> T) -> anyhow::Result { let result = FS_WATCHER_INSTANCE.get_or_init(|| { notify::recommended_watcher(handle_event).map(|file_watcher| GlobalWatcher { + state: Mutex::new(WatcherState { + watchers: Default::default(), + path_registrations: Default::default(), + last_registration: Default::default(), + }), watcher: Mutex::new(file_watcher), - watchers: Default::default(), }) }); match result { diff --git a/crates/fuzzy/src/matcher.rs b/crates/fuzzy/src/matcher.rs index aff6390534..e649d47dd6 100644 --- a/crates/fuzzy/src/matcher.rs +++ b/crates/fuzzy/src/matcher.rs @@ -208,8 +208,15 @@ impl<'a> Matcher<'a> { return 1.0; } - let path_len = prefix.len() + path.len(); + let limit = self.last_positions[query_idx]; + let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1); + let safe_limit = limit.min(max_valid_index); + if path_idx > safe_limit { + return 0.0; + } + + let path_len = prefix.len() + path.len(); if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] { return memoized; } @@ -218,16 +225,13 @@ impl<'a> Matcher<'a> { let mut best_position = 0; let query_char = self.lowercase_query[query_idx]; - let limit = self.last_positions[query_idx]; - - let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1); - let safe_limit = limit.min(max_valid_index); let mut last_slash = 0; + for j in path_idx..=safe_limit { let extra_lowercase_chars_count = extra_lowercase_chars .iter() - .take_while(|(i, _)| i < &&j) + .take_while(|&(&i, _)| i < j) .map(|(_, increment)| increment) .sum::(); let j_regular = j - 
extra_lowercase_chars_count; @@ -236,10 +240,9 @@ impl<'a> Matcher<'a> { lowercase_prefix[j] } else { let path_index = j - prefix.len(); - if path_index < path_lowercased.len() { - path_lowercased[path_index] - } else { - continue; + match path_lowercased.get(path_index) { + Some(&char) => char, + None => continue, } }; let is_path_sep = path_char == MAIN_SEPARATOR; @@ -255,18 +258,16 @@ impl<'a> Matcher<'a> { #[cfg(target_os = "windows")] let need_to_score = query_char == path_char || (is_path_sep && query_char == '_'); if need_to_score { - let curr = if j_regular < prefix.len() { - prefix[j_regular] - } else { - path[j_regular - prefix.len()] + let curr = match prefix.get(j_regular) { + Some(&curr) => curr, + None => path[j_regular - prefix.len()], }; let mut char_score = 1.0; if j > path_idx { - let last = if j_regular - 1 < prefix.len() { - prefix[j_regular - 1] - } else { - path[j_regular - 1 - prefix.len()] + let last = match prefix.get(j_regular - 1) { + Some(&last) => last, + None => path[j_regular - 1 - prefix.len()], }; if last == MAIN_SEPARATOR { diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index b536bed710..dc7ab0af65 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -846,14 +846,12 @@ impl GitRepository for RealGitRepository { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn()?; - child - .stdin - .take() - .unwrap() - .write_all(content.as_bytes()) - .await?; + let mut stdin = child.stdin.take().unwrap(); + stdin.write_all(content.as_bytes()).await?; + stdin.flush().await?; + drop(stdin); let output = child.output().await?.stdout; - let sha = String::from_utf8(output)?; + let sha = str::from_utf8(&output)?.trim(); log::debug!("indexing SHA: {sha}, path {path:?}"); @@ -871,6 +869,7 @@ impl GitRepository for RealGitRepository { String::from_utf8_lossy(&output.stderr) ); } else { + log::debug!("removing path {path:?} from the index"); let output = new_smol_command(&git_binary_path) .current_dir(&working_directory) .envs(env.iter()) @@ -921,6 +920,7 @@ impl GitRepository for RealGitRepository { for rev in &revs { write!(&mut stdin, "{rev}\n")?; } + stdin.flush()?; drop(stdin); let output = process.wait_with_output()?; diff --git a/crates/git_hosting_providers/src/providers/bitbucket.rs b/crates/git_hosting_providers/src/providers/bitbucket.rs index 074a169135..26df7b567a 100644 --- a/crates/git_hosting_providers/src/providers/bitbucket.rs +++ b/crates/git_hosting_providers/src/providers/bitbucket.rs @@ -1,12 +1,22 @@ use std::str::FromStr; +use std::sync::LazyLock; +use regex::Regex; use url::Url; use git::{ BuildCommitPermalinkParams, BuildPermalinkParams, GitHostingProvider, ParsedGitRemote, - RemoteUrl, + PullRequest, RemoteUrl, }; +fn pull_request_regex() -> &'static Regex { + static PULL_REQUEST_REGEX: LazyLock = LazyLock::new(|| { + // This matches Bitbucket PR reference pattern: (pull request #xxx) + Regex::new(r"\(pull request #(\d+)\)").unwrap() + }); + &PULL_REQUEST_REGEX +} + pub struct Bitbucket { name: String, base_url: Url, @@ -96,6 +106,22 @@ impl GitHostingProvider for Bitbucket { ); permalink } + + fn extract_pull_request(&self, remote: &ParsedGitRemote, message: &str) -> Option { + // Check first line of commit message for PR references + let first_line = message.lines().next()?; + + // Try to match against our PR patterns + let capture = pull_request_regex().captures(first_line)?; + let number = capture.get(1)?.as_str().parse::().ok()?; + + // Construct the PR URL in Bitbucket format + let mut url = 
self.base_url(); + let path = format!("/{}/{}/pull-requests/{}", remote.owner, remote.repo, number); + url.set_path(&path); + + Some(PullRequest { number, url }) + } } #[cfg(test)] @@ -203,4 +229,34 @@ mod tests { "https://bitbucket.org/zed-industries/zed/src/f00b4r/main.rs#lines-24:48"; assert_eq!(permalink.to_string(), expected_url.to_string()) } + + #[test] + fn test_bitbucket_pull_requests() { + use indoc::indoc; + + let remote = ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed".into(), + }; + + let bitbucket = Bitbucket::public_instance(); + + // Test message without PR reference + let message = "This does not contain a pull request"; + assert!(bitbucket.extract_pull_request(&remote, message).is_none()); + + // Pull request number at end of first line + let message = indoc! {r#" + Merged in feature-branch (pull request #123) + + Some detailed description of the changes. + "#}; + + let pr = bitbucket.extract_pull_request(&remote, message).unwrap(); + assert_eq!(pr.number, 123); + assert_eq!( + pr.url.as_str(), + "https://bitbucket.org/zed-industries/zed/pull-requests/123" + ); + } } diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 1092ba33d1..b74fa649b0 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -180,6 +180,7 @@ impl Focusable for BranchList { impl Render for BranchList { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() + .key_context("GitBranchSelector") .w(self.width) .on_modifiers_changed(cx.listener(Self::handle_modifiers_changed)) .child(self.picker.clone()) diff --git a/crates/git_ui/src/repository_selector.rs b/crates/git_ui/src/repository_selector.rs index b5865e9a85..db080ab0b4 100644 --- a/crates/git_ui/src/repository_selector.rs +++ b/crates/git_ui/src/repository_selector.rs @@ -109,7 +109,10 @@ impl Focusable for RepositorySelector { impl Render for RepositorySelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - div().w(self.width).child(self.picker.clone()) + div() + .key_context("GitRepositorySelector") + .w(self.width) + .child(self.picker.clone()) } } diff --git a/crates/go_to_line/src/cursor_position.rs b/crates/go_to_line/src/cursor_position.rs index 322a791b13..29064eb29c 100644 --- a/crates/go_to_line/src/cursor_position.rs +++ b/crates/go_to_line/src/cursor_position.rs @@ -308,10 +308,14 @@ impl Settings for LineIndicatorFormat { type FileContent = Option; fn load(sources: SettingsSources, _: &mut App) -> anyhow::Result { - let format = [sources.release_channel, sources.user] - .into_iter() - .find_map(|value| value.copied().flatten()) - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); + let format = [ + sources.release_channel, + sources.operating_system, + sources.user, + ] + .into_iter() + .find_map(|value| value.copied().flatten()) + .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); Ok(format.0) } diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 2bf49fa7d8..6e5a76d441 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -121,7 +121,7 @@ smallvec.workspace = true smol.workspace = true strum.workspace = true sum_tree.workspace = true -taffy = "=0.8.3" +taffy = "=0.9.0" thiserror.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/gpui/examples/tab_stop.rs b/crates/gpui/examples/tab_stop.rs index 1f6500f3e6..8dbcbeccb7 100644 --- a/crates/gpui/examples/tab_stop.rs +++ 
b/crates/gpui/examples/tab_stop.rs @@ -111,8 +111,24 @@ impl Render for Example { .flex_row() .gap_3() .items_center() - .child(button("el1").tab_index(4).child("Button 1")) - .child(button("el2").tab_index(5).child("Button 2")), + .child( + button("el1") + .tab_index(4) + .child("Button 1") + .on_click(cx.listener(|this, _, _, cx| { + this.message = "You have clicked Button 1.".into(); + cx.notify(); + })), + ) + .child( + button("el2") + .tab_index(5) + .child("Button 2") + .on_click(cx.listener(|this, _, _, cx| { + this.message = "You have clicked Button 2.".into(); + cx.notify(); + })), + ), ) } } diff --git a/crates/gpui/examples/window_shadow.rs b/crates/gpui/examples/window_shadow.rs index 06dde91133..469017da79 100644 --- a/crates/gpui/examples/window_shadow.rs +++ b/crates/gpui/examples/window_shadow.rs @@ -165,8 +165,8 @@ impl Render for WindowShadow { }, ) .on_click(|e, window, _| { - if e.down.button == MouseButton::Right { - window.show_window_menu(e.up.position); + if e.is_right_click() { + window.show_window_menu(e.position()); } }) .text_color(black()) diff --git a/crates/gpui/src/elements/div.rs b/crates/gpui/src/elements/div.rs index fa47758581..09afbff929 100644 --- a/crates/gpui/src/elements/div.rs +++ b/crates/gpui/src/elements/div.rs @@ -19,10 +19,10 @@ use crate::{ Action, AnyDrag, AnyElement, AnyTooltip, AnyView, App, Bounds, ClickEvent, DispatchPhase, Element, ElementId, Entity, FocusHandle, Global, GlobalElementId, Hitbox, HitboxBehavior, HitboxId, InspectorElementId, IntoElement, IsZero, KeyContext, KeyDownEvent, KeyUpEvent, - LayoutId, ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, - Overflow, ParentElement, Pixels, Point, Render, ScrollWheelEvent, SharedString, Size, Style, - StyleRefinement, Styled, Task, TooltipId, Visibility, Window, WindowControlArea, point, px, - size, + KeyboardButton, KeyboardClickEvent, LayoutId, ModifiersChangedEvent, MouseButton, + MouseClickEvent, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Overflow, ParentElement, Pixels, + Point, Render, ScrollWheelEvent, SharedString, Size, Style, StyleRefinement, Styled, Task, + TooltipId, Visibility, Window, WindowControlArea, point, px, size, }; use collections::HashMap; use refineable::Refineable; @@ -484,10 +484,9 @@ impl Interactivity { where Self: Sized, { - self.click_listeners - .push(Box::new(move |event, window, cx| { - listener(event, window, cx) - })); + self.click_listeners.push(Rc::new(move |event, window, cx| { + listener(event, window, cx) + })); } /// On drag initiation, this callback will be used to create a new view to render the dragged value for a @@ -1156,7 +1155,7 @@ pub(crate) type MouseMoveListener = pub(crate) type ScrollWheelListener = Box; -pub(crate) type ClickListener = Box; +pub(crate) type ClickListener = Rc; pub(crate) type DragListener = Box, &mut Window, &mut App) -> AnyView + 'static>; @@ -1950,6 +1949,12 @@ impl Interactivity { window: &mut Window, cx: &mut App, ) { + let is_focused = self + .tracked_focus_handle + .as_ref() + .map(|handle| handle.is_focused(window)) + .unwrap_or(false); + // If this element can be focused, register a mouse down listener // that will automatically transfer focus when hitting the element. // This behavior can be suppressed by using `cx.prevent_default()`. @@ -2113,6 +2118,39 @@ impl Interactivity { } }); + if is_focused { + // Press enter, space to trigger click, when the element is focused. 
+ window.on_key_event({ + let click_listeners = click_listeners.clone(); + let hitbox = hitbox.clone(); + move |event: &KeyUpEvent, phase, window, cx| { + if phase.bubble() && !window.default_prevented() { + let stroke = &event.keystroke; + let keyboard_button = if stroke.key.eq("enter") { + Some(KeyboardButton::Enter) + } else if stroke.key.eq("space") { + Some(KeyboardButton::Space) + } else { + None + }; + + if let Some(button) = keyboard_button + && !stroke.modifiers.modified() + { + let click_event = ClickEvent::Keyboard(KeyboardClickEvent { + button, + bounds: hitbox.bounds, + }); + + for listener in &click_listeners { + listener(&click_event, window, cx); + } + } + } + } + }); + } + window.on_mouse_event({ let mut captured_mouse_down = None; let hitbox = hitbox.clone(); @@ -2138,10 +2176,10 @@ impl Interactivity { // Fire click handlers during the bubble phase. DispatchPhase::Bubble => { if let Some(mouse_down) = captured_mouse_down.take() { - let mouse_click = ClickEvent { + let mouse_click = ClickEvent::Mouse(MouseClickEvent { down: mouse_down, up: event.clone(), - }; + }); for listener in &click_listeners { listener(&mouse_click, window, cx); } diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs index 709323ef58..39f38bdc69 100644 --- a/crates/gpui/src/elements/list.rs +++ b/crates/gpui/src/elements/list.rs @@ -18,10 +18,16 @@ use refineable::Refineable as _; use std::{cell::RefCell, ops::Range, rc::Rc}; use sum_tree::{Bias, Dimensions, SumTree}; +type RenderItemFn = dyn FnMut(usize, &mut Window, &mut App) -> AnyElement + 'static; + /// Construct a new list element -pub fn list(state: ListState) -> List { +pub fn list( + state: ListState, + render_item: impl FnMut(usize, &mut Window, &mut App) -> AnyElement + 'static, +) -> List { List { state, + render_item: Box::new(render_item), style: StyleRefinement::default(), sizing_behavior: ListSizingBehavior::default(), } @@ -30,6 +36,7 @@ pub fn list(state: ListState) -> List { /// A list element pub struct List { state: ListState, + render_item: Box, style: StyleRefinement, sizing_behavior: ListSizingBehavior, } @@ -55,7 +62,6 @@ impl std::fmt::Debug for ListState { struct StateInner { last_layout_bounds: Option>, last_padding: Option>, - render_item: Box AnyElement>, items: SumTree, logical_scroll_top: Option, alignment: ListAlignment, @@ -186,19 +192,10 @@ impl ListState { /// above and below the visible area. Elements within this area will /// be measured even though they are not visible. This can help ensure /// that the list doesn't flicker or pop in when scrolling. 
- pub fn new( - item_count: usize, - alignment: ListAlignment, - overdraw: Pixels, - render_item: R, - ) -> Self - where - R: 'static + FnMut(usize, &mut Window, &mut App) -> AnyElement, - { + pub fn new(item_count: usize, alignment: ListAlignment, overdraw: Pixels) -> Self { let this = Self(Rc::new(RefCell::new(StateInner { last_layout_bounds: None, last_padding: None, - render_item: Box::new(render_item), items: SumTree::default(), logical_scroll_top: None, alignment, @@ -532,6 +529,7 @@ impl StateInner { available_width: Option, available_height: Pixels, padding: &Edges, + render_item: &mut RenderItemFn, window: &mut Window, cx: &mut App, ) -> LayoutItemsResponse { @@ -566,7 +564,7 @@ impl StateInner { // If we're within the visible area or the height wasn't cached, render and measure the item's element if visible_height < available_height || size.is_none() { let item_index = scroll_top.item_ix + ix; - let mut element = (self.render_item)(item_index, window, cx); + let mut element = render_item(item_index, window, cx); let element_size = element.layout_as_root(available_item_space, window, cx); size = Some(element_size); if visible_height < available_height { @@ -601,7 +599,7 @@ impl StateInner { cursor.prev(); if let Some(item) = cursor.item() { let item_index = cursor.start().0; - let mut element = (self.render_item)(item_index, window, cx); + let mut element = render_item(item_index, window, cx); let element_size = element.layout_as_root(available_item_space, window, cx); let focus_handle = item.focus_handle(); rendered_height += element_size.height; @@ -650,7 +648,7 @@ impl StateInner { let size = if let ListItem::Measured { size, .. } = item { *size } else { - let mut element = (self.render_item)(cursor.start().0, window, cx); + let mut element = render_item(cursor.start().0, window, cx); element.layout_as_root(available_item_space, window, cx) }; @@ -683,7 +681,7 @@ impl StateInner { while let Some(item) = cursor.item() { if item.contains_focused(window, cx) { let item_index = cursor.start().0; - let mut element = (self.render_item)(cursor.start().0, window, cx); + let mut element = render_item(cursor.start().0, window, cx); let size = element.layout_as_root(available_item_space, window, cx); item_layouts.push_back(ItemLayout { index: item_index, @@ -708,6 +706,7 @@ impl StateInner { bounds: Bounds, padding: Edges, autoscroll: bool, + render_item: &mut RenderItemFn, window: &mut Window, cx: &mut App, ) -> Result { @@ -716,6 +715,7 @@ impl StateInner { Some(bounds.size.width), bounds.size.height, &padding, + render_item, window, cx, ); @@ -753,8 +753,7 @@ impl StateInner { let Some(item) = cursor.item() else { break }; let size = item.size().unwrap_or_else(|| { - let mut item = - (self.render_item)(cursor.start().0, window, cx); + let mut item = render_item(cursor.start().0, window, cx); let item_available_size = size( bounds.size.width.into(), AvailableSpace::MinContent, @@ -876,8 +875,14 @@ impl Element for List { window.rem_size(), ); - let layout_response = - state.layout_items(None, available_height, &padding, window, cx); + let layout_response = state.layout_items( + None, + available_height, + &padding, + &mut self.render_item, + window, + cx, + ); let max_element_width = layout_response.max_item_width; let summary = state.items.summary(); @@ -951,15 +956,16 @@ impl Element for List { let padding = style .padding .to_pixels(bounds.size.into(), window.rem_size()); - let layout = match state.prepaint_items(bounds, padding, true, window, cx) { - Ok(layout) => layout, - 
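The `list` element change above moves the item renderer out of `ListState::new` and onto the `list(..)` constructor, so the state no longer captures a render closure for its whole lifetime. A minimal sketch of the new shape, mirroring the updated tests later in this hunk (trait imports assumed to come from gpui's prelude):

```rust
use gpui::{Context, IntoElement, ListState, Render, Window, div, list, prelude::*, px};

struct ItemList(ListState);

impl Render for ItemList {
    fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
        // The render closure is now supplied per frame, on the element:
        list(self.0.clone(), |_ix, _window, _cx| {
            div().h(px(20.)).w_full().into_any()
        })
        .w_full()
        .h_full()
    }
}

// Constructed elsewhere, without a closure:
// let state = ListState::new(5, ListAlignment::Top, px(10.));
```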
Err(autoscroll_request) => { - state.logical_scroll_top = Some(autoscroll_request); - state - .prepaint_items(bounds, padding, false, window, cx) - .unwrap() - } - }; + let layout = + match state.prepaint_items(bounds, padding, true, &mut self.render_item, window, cx) { + Ok(layout) => layout, + Err(autoscroll_request) => { + state.logical_scroll_top = Some(autoscroll_request); + state + .prepaint_items(bounds, padding, false, &mut self.render_item, window, cx) + .unwrap() + } + }; state.last_layout_bounds = Some(bounds); state.last_padding = Some(padding); @@ -1108,9 +1114,7 @@ mod test { let cx = cx.add_empty_window(); - let state = ListState::new(5, crate::ListAlignment::Top, px(10.), |_, _, _| { - div().h(px(10.)).w_full().into_any() - }); + let state = ListState::new(5, crate::ListAlignment::Top, px(10.)); // Ensure that the list is scrolled to the top state.scroll_to(gpui::ListOffset { @@ -1121,7 +1125,11 @@ mod test { struct TestView(ListState); impl Render for TestView { fn render(&mut self, _: &mut Window, _: &mut Context) -> impl IntoElement { - list(self.0.clone()).w_full().h_full() + list(self.0.clone(), |_, _, _| { + div().h(px(10.)).w_full().into_any() + }) + .w_full() + .h_full() } } @@ -1154,14 +1162,16 @@ mod test { let cx = cx.add_empty_window(); - let state = ListState::new(5, crate::ListAlignment::Top, px(10.), |_, _, _| { - div().h(px(20.)).w_full().into_any() - }); + let state = ListState::new(5, crate::ListAlignment::Top, px(10.)); struct TestView(ListState); impl Render for TestView { fn render(&mut self, _: &mut Window, _: &mut Context) -> impl IntoElement { - list(self.0.clone()).w_full().h_full() + list(self.0.clone(), |_, _, _| { + div().h(px(20.)).w_full().into_any() + }) + .w_full() + .h_full() } } diff --git a/crates/gpui/src/interactive.rs b/crates/gpui/src/interactive.rs index edd807da11..218ae5fcdf 100644 --- a/crates/gpui/src/interactive.rs +++ b/crates/gpui/src/interactive.rs @@ -1,6 +1,6 @@ use crate::{ - Capslock, Context, Empty, IntoElement, Keystroke, Modifiers, Pixels, Point, Render, Window, - point, seal::Sealed, + Bounds, Capslock, Context, Empty, IntoElement, Keystroke, Modifiers, Pixels, Point, Render, + Window, point, seal::Sealed, }; use smallvec::SmallVec; use std::{any::Any, fmt::Debug, ops::Deref, path::PathBuf}; @@ -141,7 +141,7 @@ impl MouseEvent for MouseUpEvent {} /// A click event, generated when a mouse button is pressed and released. #[derive(Clone, Debug, Default)] -pub struct ClickEvent { +pub struct MouseClickEvent { /// The mouse event when the button was pressed. pub down: MouseDownEvent, @@ -149,18 +149,126 @@ pub struct ClickEvent { pub up: MouseUpEvent, } +/// A click event that was generated by a keyboard button being pressed and released. +#[derive(Clone, Debug, Default)] +pub struct KeyboardClickEvent { + /// The keyboard button that was pressed to trigger the click. + pub button: KeyboardButton, + + /// The bounds of the element that was clicked. + pub bounds: Bounds, +} + +/// A click event, generated when a mouse button or keyboard button is pressed and released. +#[derive(Clone, Debug)] +pub enum ClickEvent { + /// A click event trigger by a mouse button being pressed and released. + Mouse(MouseClickEvent), + /// A click event trigger by a keyboard button being pressed and released. 
+ Keyboard(KeyboardClickEvent), +} + +impl Default for ClickEvent { + fn default() -> Self { + ClickEvent::Keyboard(KeyboardClickEvent::default()) + } +} + impl ClickEvent { - /// Returns the modifiers that were held down during both the - /// mouse down and mouse up events + /// Returns the modifiers that were held during the click event + /// + /// `Keyboard`: The keyboard click events never have modifiers. + /// `Mouse`: Modifiers that were held during the mouse key up event. pub fn modifiers(&self) -> Modifiers { - Modifiers { - control: self.up.modifiers.control && self.down.modifiers.control, - alt: self.up.modifiers.alt && self.down.modifiers.alt, - shift: self.up.modifiers.shift && self.down.modifiers.shift, - platform: self.up.modifiers.platform && self.down.modifiers.platform, - function: self.up.modifiers.function && self.down.modifiers.function, + match self { + // Click events are only generated from keyboard events _without any modifiers_, so we know the modifiers are always Default + ClickEvent::Keyboard(_) => Modifiers::default(), + // Click events on the web only reflect the modifiers for the keyup event, + // tested via observing the behavior of the `ClickEvent.shiftKey` field in Chrome 138 + // under various combinations of modifiers and keyUp / keyDown events. + ClickEvent::Mouse(event) => event.up.modifiers, } } + + /// Returns the position of the click event + /// + /// `Keyboard`: The bottom left corner of the clicked hitbox + /// `Mouse`: The position of the mouse when the button was released. + pub fn position(&self) -> Point { + match self { + ClickEvent::Keyboard(event) => event.bounds.bottom_left(), + ClickEvent::Mouse(event) => event.up.position, + } + } + + /// Returns the mouse position of the click event + /// + /// `Keyboard`: None + /// `Mouse`: The position of the mouse when the button was released. + pub fn mouse_position(&self) -> Option> { + match self { + ClickEvent::Keyboard(_) => None, + ClickEvent::Mouse(event) => Some(event.up.position), + } + } + + /// Returns if this was a right click + /// + /// `Keyboard`: false + /// `Mouse`: Whether the right button was pressed and released + pub fn is_right_click(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => false, + ClickEvent::Mouse(event) => { + event.down.button == MouseButton::Right && event.up.button == MouseButton::Right + } + } + } + + /// Returns whether the click was a standard click + /// + /// `Keyboard`: Always true + /// `Mouse`: Left button pressed and released + pub fn standard_click(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => true, + ClickEvent::Mouse(event) => { + event.down.button == MouseButton::Left && event.up.button == MouseButton::Left + } + } + } + + /// Returns whether the click focused the element + /// + /// `Keyboard`: false, keyboard clicks only work if an element is already focused + /// `Mouse`: Whether this was the first focusing click + pub fn first_focus(&self) -> bool { + match self { + ClickEvent::Keyboard(_) => false, + ClickEvent::Mouse(event) => event.down.first_mouse, + } + } + + /// Returns the click count of the click event + /// + /// `Keyboard`: Always 1 + /// `Mouse`: Count of clicks from MouseUpEvent + pub fn click_count(&self) -> usize { + match self { + ClickEvent::Keyboard(_) => 1, + ClickEvent::Mouse(event) => event.up.click_count, + } + } +} + +/// An enum representing the keyboard button that was pressed for a click event. 
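Since `ClickEvent` is now an enum over mouse- and keyboard-generated clicks, call sites stop reaching into `event.down` / `event.up` and use the accessors above instead, as the `window_shadow` example and the editor's `open_excerpts` handler do elsewhere in this diff. A hedged sketch of a handler written against the new API:

```rust
use gpui::{ClickEvent, IntoElement, Pixels, Point, div, prelude::*};

// Sketch: one handler that behaves sensibly for both mouse clicks and
// keyboard (Enter/Space) activation of a focused element.
fn example() -> impl IntoElement {
    div().id("example").on_click(|event: &ClickEvent, window, _cx| {
        if event.is_right_click() {
            // For keyboard clicks this is the element's bottom-left corner;
            // for mouse clicks it is the mouse-up position.
            window.show_window_menu(event.position());
            return;
        }
        if event.standard_click() {
            // None for keyboard-triggered clicks.
            let _cursor: Option<Point<Pixels>> = event.mouse_position();
            // Keyboard clicks always report default (empty) modifiers.
            let _secondary = event.modifiers().secondary();
        }
    })
}
```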
+#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug, Default)] +pub enum KeyboardButton { + /// Enter key was clicked + #[default] + Enter, + /// Space key was clicked + Space, } /// An enum representing the mouse button that was pressed. diff --git a/crates/gpui/src/platform/blade/blade_renderer.rs b/crates/gpui/src/platform/blade/blade_renderer.rs index 2e18d2be22..46d3c16c72 100644 --- a/crates/gpui/src/platform/blade/blade_renderer.rs +++ b/crates/gpui/src/platform/blade/blade_renderer.rs @@ -606,7 +606,7 @@ impl BladeRenderer { xy_position: v.xy_position, st_position: v.st_position, color: path.color, - bounds: path.bounds.intersect(&path.content_mask.bounds), + bounds: path.clipped_bounds(), })); } let vertex_buf = unsafe { self.instance_belt.alloc_typed(&vertices, &self.gpu) }; @@ -735,13 +735,13 @@ impl BladeRenderer { paths .iter() .map(|path| PathSprite { - bounds: path.bounds, + bounds: path.clipped_bounds(), }) .collect() } else { - let mut bounds = first_path.bounds; + let mut bounds = first_path.clipped_bounds(); for path in paths.iter().skip(1) { - bounds = bounds.union(&path.bounds); + bounds = bounds.union(&path.clipped_bounds()); } vec![PathSprite { bounds }] }; diff --git a/crates/gpui/src/platform/mac/metal_renderer.rs b/crates/gpui/src/platform/mac/metal_renderer.rs index fb5cb852d6..629654014d 100644 --- a/crates/gpui/src/platform/mac/metal_renderer.rs +++ b/crates/gpui/src/platform/mac/metal_renderer.rs @@ -791,13 +791,13 @@ impl MetalRenderer { sprites = paths .iter() .map(|path| PathSprite { - bounds: path.bounds, + bounds: path.clipped_bounds(), }) .collect(); } else { - let mut bounds = first_path.bounds; + let mut bounds = first_path.clipped_bounds(); for path in paths.iter().skip(1) { - bounds = bounds.union(&path.bounds); + bounds = bounds.union(&path.clipped_bounds()); } sprites = vec![PathSprite { bounds }]; } diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs index 72cc12a5b4..585b1dab1c 100644 --- a/crates/gpui/src/platform/windows/directx_renderer.rs +++ b/crates/gpui/src/platform/windows/directx_renderer.rs @@ -4,15 +4,16 @@ use ::util::ResultExt; use anyhow::{Context, Result}; use windows::{ Win32::{ - Foundation::{HMODULE, HWND}, + Foundation::{FreeLibrary, HMODULE, HWND}, Graphics::{ Direct3D::*, Direct3D11::*, DirectComposition::*, Dxgi::{Common::*, *}, }, + System::LibraryLoader::LoadLibraryA, }, - core::Interface, + core::{Interface, PCSTR}, }; use crate::{ @@ -435,7 +436,7 @@ impl DirectXRenderer { xy_position: v.xy_position, st_position: v.st_position, color: path.color, - bounds: path.bounds.intersect(&path.content_mask.bounds), + bounds: path.clipped_bounds(), })); } @@ -487,13 +488,13 @@ impl DirectXRenderer { paths .iter() .map(|path| PathSprite { - bounds: path.bounds, + bounds: path.clipped_bounds(), }) .collect::>() } else { - let mut bounds = first_path.bounds; + let mut bounds = first_path.clipped_bounds(); for path in paths.iter().skip(1) { - bounds = bounds.union(&path.bounds); + bounds = bounds.union(&path.clipped_bounds()); } vec![PathSprite { bounds }] }; @@ -1618,17 +1619,32 @@ pub(crate) mod shader_resources { } } +fn with_dll_library(dll_name: PCSTR, f: F) -> Result +where + F: FnOnce(HMODULE) -> Result, +{ + let library = unsafe { + LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))? 
+ }; + let result = f(library); + unsafe { + FreeLibrary(library) + .with_context(|| format!("Freeing dll: {}", dll_name.display())) + .log_err(); + } + result +} + mod nvidia { use std::{ ffi::CStr, os::raw::{c_char, c_int, c_uint}, }; - use anyhow::{Context, Result}; - use windows::{ - Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, - core::s, - }; + use anyhow::Result; + use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s}; + + use crate::platform::windows::directx_renderer::with_dll_library; // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 const NVAPI_SHORT_STRING_MAX: usize = 64; @@ -1645,13 +1661,12 @@ mod nvidia { ) -> c_int; pub(super) fn get_driver_version() -> Result { - unsafe { - // Try to load the NVIDIA driver DLL - #[cfg(target_pointer_width = "64")] - let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?; - #[cfg(target_pointer_width = "32")] - let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?; + #[cfg(target_pointer_width = "64")] + let nvidia_dll_name = s!("nvapi64.dll"); + #[cfg(target_pointer_width = "32")] + let nvidia_dll_name = s!("nvapi.dll"); + with_dll_library(nvidia_dll_name, |nvidia_dll| unsafe { let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface")) .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?; let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr); @@ -1686,18 +1701,17 @@ mod nvidia { minor, branch_string.to_string_lossy() )) - } + }) } } mod amd { use std::os::raw::{c_char, c_int, c_void}; - use anyhow::{Context, Result}; - use windows::{ - Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, - core::s, - }; + use anyhow::Result; + use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s}; + + use crate::platform::windows::directx_renderer::with_dll_library; // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); @@ -1731,14 +1745,12 @@ mod amd { type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int; pub(super) fn get_driver_version() -> Result { - unsafe { - #[cfg(target_pointer_width = "64")] - let amd_dll = - LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?; - #[cfg(target_pointer_width = "32")] - let amd_dll = - LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?; + #[cfg(target_pointer_width = "64")] + let amd_dll_name = s!("amd_ags_x64.dll"); + #[cfg(target_pointer_width = "32")] + let amd_dll_name = s!("amd_ags_x86.dll"); + with_dll_library(amd_dll_name, |amd_dll| unsafe { let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize")) .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?; let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize")) @@ -1784,7 +1796,7 @@ mod amd { ags_deinitialize(context); Ok(format!("{} ({})", software_version, driver_version)) - } + }) } } diff --git a/crates/gpui/src/platform/windows/events.rs b/crates/gpui/src/platform/windows/events.rs index 00b22fa807..4ab257d27a 100644 --- a/crates/gpui/src/platform/windows/events.rs +++ b/crates/gpui/src/platform/windows/events.rs @@ -174,20 +174,37 @@ impl WindowsWindowInner { let width = lparam.loword().max(1) as i32; let height = lparam.hiword().max(1) as i32; let new_size = 
size(DevicePixels(width), DevicePixels(height)); + let scale_factor = lock.scale_factor; + let mut should_resize_renderer = false; if lock.restore_from_minimized.is_some() { lock.callbacks.request_frame = lock.restore_from_minimized.take(); } else { - lock.renderer.resize(new_size).log_err(); + should_resize_renderer = true; + } + drop(lock); + + self.handle_size_change(new_size, scale_factor, should_resize_renderer); + Some(0) + } + + fn handle_size_change( + &self, + device_size: Size, + scale_factor: f32, + should_resize_renderer: bool, + ) { + let new_logical_size = device_size.to_pixels(scale_factor); + let mut lock = self.state.borrow_mut(); + lock.logical_size = new_logical_size; + if should_resize_renderer { + lock.renderer.resize(device_size).log_err(); } - let new_size = new_size.to_pixels(scale_factor); - lock.logical_size = new_size; if let Some(mut callback) = lock.callbacks.resize.take() { drop(lock); - callback(new_size, scale_factor); + callback(new_logical_size, scale_factor); self.state.borrow_mut().callbacks.resize = Some(callback); } - Some(0) } fn handle_size_move_loop(&self, handle: HWND) -> Option { @@ -747,7 +764,9 @@ impl WindowsWindowInner { ) -> Option { let new_dpi = wparam.loword() as f32; let mut lock = self.state.borrow_mut(); - lock.scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32; + let is_maximized = lock.is_maximized(); + let new_scale_factor = new_dpi / USER_DEFAULT_SCREEN_DPI as f32; + lock.scale_factor = new_scale_factor; lock.border_offset.update(handle).log_err(); drop(lock); @@ -771,6 +790,13 @@ impl WindowsWindowInner { .log_err(); } + // When maximized, SetWindowPos doesn't send WM_SIZE, so we need to manually + // update the size and call the resize callback + if is_maximized { + let device_size = size(DevicePixels(width), DevicePixels(height)); + self.handle_size_change(device_size, new_scale_factor, true); + } + Some(0) } diff --git a/crates/gpui/src/scene.rs b/crates/gpui/src/scene.rs index ec8d720cdf..c527dfe750 100644 --- a/crates/gpui/src/scene.rs +++ b/crates/gpui/src/scene.rs @@ -8,7 +8,12 @@ use crate::{ AtlasTextureId, AtlasTile, Background, Bounds, ContentMask, Corners, Edges, Hsla, Pixels, Point, Radians, ScaledPixels, Size, bounds_tree::BoundsTree, point, }; -use std::{fmt::Debug, iter::Peekable, ops::Range, slice}; +use std::{ + fmt::Debug, + iter::Peekable, + ops::{Add, Range, Sub}, + slice, +}; #[allow(non_camel_case_types, unused)] pub(crate) type PathVertex_ScaledPixels = PathVertex; @@ -793,6 +798,16 @@ impl Path { } } +impl Path +where + T: Clone + Debug + Default + PartialEq + PartialOrd + Add + Sub, +{ + #[allow(unused)] + pub(crate) fn clipped_bounds(&self) -> Bounds { + self.bounds.intersect(&self.content_mask.bounds) + } +} + impl From> for Primitive { fn from(path: Path) -> Self { Primitive::Path(path) diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 9e4c1c26c5..40d3845ff9 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -79,11 +79,13 @@ pub enum DispatchPhase { impl DispatchPhase { /// Returns true if this represents the "bubble" phase. + #[inline] pub fn bubble(self) -> bool { self == DispatchPhase::Bubble } /// Returns true if this represents the "capture" phase. + #[inline] pub fn capture(self) -> bool { self == DispatchPhase::Capture } @@ -4246,6 +4248,25 @@ impl Window { .on_action(action_type, Rc::new(listener)); } + /// Register an action listener on the window for the next frame if the condition is true. 
+ /// The type of action is determined by the first parameter of the given listener. + /// When the next frame is rendered the listener will be cleared. + /// + /// This is a fairly low-level method, so prefer using action handlers on elements unless you have + /// a specific need to register a global listener. + pub fn on_action_when( + &mut self, + condition: bool, + action_type: TypeId, + listener: impl Fn(&dyn Any, DispatchPhase, &mut Window, &mut App) + 'static, + ) { + if condition { + self.next_frame + .dispatch_tree + .on_action(action_type, Rc::new(listener)); + } + } + /// Read information about the GPU backing this window. /// Currently returns None on Mac and Windows. pub fn gpu_specs(&self) -> Option { diff --git a/crates/http_client/src/github.rs b/crates/http_client/src/github.rs index a038915e2f..a19c13b0ff 100644 --- a/crates/http_client/src/github.rs +++ b/crates/http_client/src/github.rs @@ -8,6 +8,7 @@ use url::Url; pub struct GitHubLspBinaryVersion { pub name: String, pub url: String, + pub digest: Option, } #[derive(Deserialize, Debug)] @@ -24,6 +25,7 @@ pub struct GithubRelease { pub struct GithubReleaseAsset { pub name: String, pub browser_download_url: String, + pub digest: Option, } pub async fn latest_github_release( diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index a94d89bdc8..12805e62e0 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -261,7 +261,6 @@ pub enum IconName { TodoComplete, TodoPending, TodoProgress, - ToolBulb, ToolCopy, ToolDeleteFile, ToolDiagnostics, @@ -273,6 +272,7 @@ pub enum IconName { ToolRegex, ToolSearch, ToolTerminal, + ToolThink, ToolWeb, Trash, Triangle, diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 894625b982..b9933dfcec 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -2353,9 +2353,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); @@ -2366,9 +2366,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); @@ -2379,9 +2379,9 @@ mod tests { assert_eq!( languages.language_names(), &[ - "JSON".to_string(), - "Plain Text".to_string(), - "Rust".to_string(), + LanguageName::new("JSON"), + LanguageName::new("Plain Text"), + LanguageName::new("Rust"), ] ); diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs index 85123d2373..ea988e8098 100644 --- a/crates/language/src/language_registry.rs +++ b/crates/language/src/language_registry.rs @@ -547,15 +547,15 @@ impl LanguageRegistry { self.state.read().language_settings.clone() } - pub fn language_names(&self) -> Vec { + pub fn language_names(&self) -> Vec { let state = self.state.read(); let mut result = state .available_languages .iter() - .filter_map(|l| l.loaded.not().then_some(l.name.to_string())) - .chain(state.languages.iter().map(|l| l.config.name.to_string())) + .filter_map(|l| l.loaded.not().then_some(l.name.clone())) + .chain(state.languages.iter().map(|l| l.config.name.clone())) .collect::>(); - result.sort_unstable_by_key(|language_name| language_name.to_lowercase()); + result.sort_unstable_by_key(|language_name| language_name.as_ref().to_lowercase()); result } diff --git 
a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 841be60b0e..f9920623b5 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,6 +20,7 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true base64.workspace = true client.workspace = true +cloud_api_types.workspace = true cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index d54db7554a..a9c7d5c034 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -92,7 +92,12 @@ pub struct ToolUseRequest { pub struct FakeLanguageModel { provider_id: LanguageModelProviderId, provider_name: LanguageModelProviderName, - current_completion_txs: Mutex)>>, + current_completion_txs: Mutex< + Vec<( + LanguageModelRequest, + mpsc::UnboundedSender, + )>, + >, } impl Default for FakeLanguageModel { @@ -118,10 +123,21 @@ impl FakeLanguageModel { self.current_completion_txs.lock().len() } - pub fn stream_completion_response( + pub fn send_completion_stream_text_chunk( &self, request: &LanguageModelRequest, chunk: impl Into, + ) { + self.send_completion_stream_event( + request, + LanguageModelCompletionEvent::Text(chunk.into()), + ); + } + + pub fn send_completion_stream_event( + &self, + request: &LanguageModelRequest, + event: impl Into, ) { let current_completion_txs = self.current_completion_txs.lock(); let tx = current_completion_txs @@ -129,7 +145,7 @@ impl FakeLanguageModel { .find(|(req, _)| req == request) .map(|(_, tx)| tx) .unwrap(); - tx.unbounded_send(chunk.into()).unwrap(); + tx.unbounded_send(event.into()).unwrap(); } pub fn end_completion_stream(&self, request: &LanguageModelRequest) { @@ -138,8 +154,15 @@ impl FakeLanguageModel { .retain(|(req, _)| req != request); } - pub fn stream_last_completion_response(&self, chunk: impl Into) { - self.stream_completion_response(self.pending_completions().last().unwrap(), chunk); + pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into) { + self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk); + } + + pub fn send_last_completion_stream_event( + &self, + event: impl Into, + ) { + self.send_completion_stream_event(self.pending_completions().last().unwrap(), event); } pub fn end_last_completion_stream(&self) { @@ -201,12 +224,7 @@ impl LanguageModel for FakeLanguageModel { > { let (tx, rx) = mpsc::unbounded(); self.current_completion_txs.lock().push((request, tx)); - async move { - Ok(rx - .map(|text| Ok(LanguageModelCompletionEvent::Text(text))) - .boxed()) - } - .boxed() + async move { Ok(rx.map(Ok).boxed()) }.boxed() } fn as_fake(&self) -> &Self { diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 8ae5893410..3b4c1fa269 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,11 +3,9 @@ use std::sync::Arc; use anyhow::Result; use client::Client; +use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::Plan; -use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, -}; -use proto::TypedEnvelope; +use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -82,9 
+80,7 @@ impl Global for GlobalRefreshLlmTokenListener {} pub struct RefreshLlmTokenEvent; -pub struct RefreshLlmTokenListener { - _llm_token_subscription: client::Subscription, -} +pub struct RefreshLlmTokenListener; impl EventEmitter for RefreshLlmTokenListener {} @@ -99,17 +95,21 @@ impl RefreshLlmTokenListener { } fn new(client: Arc, cx: &mut Context) -> Self { - Self { - _llm_token_subscription: client - .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token), - } + client.add_message_to_client_handler({ + let this = cx.entity(); + move |message, cx| { + Self::handle_refresh_llm_token(this.clone(), message, cx); + } + }); + + Self } - async fn handle_refresh_llm_token( - this: Entity, - _: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)) + fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { + match message { + MessageToClient::UserUpdated => { + this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)); + } + } } } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index dc485e9937..edce3d03b7 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -297,6 +297,12 @@ impl From for LanguageModelToolResultContent { } } +impl From for LanguageModelToolResultContent { + fn from(image: LanguageModelImage) -> Self { + Self::Image(image) + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub enum MessageContent { Text(String), diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 2108547c4f..40dd120761 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -136,6 +136,7 @@ impl State { cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); + let mut current_user = user_store.read(cx).watch_current_user(); Self { client: client.clone(), llm_api_token: LlmApiToken::default(), @@ -151,22 +152,14 @@ impl State { let (client, llm_api_token) = this .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; - loop { - let is_authenticated = user_store - .read_with(cx, |user_store, _cx| user_store.current_user().is_some())?; - if is_authenticated { - break; - } - - cx.background_executor() - .timer(Duration::from_millis(100)) - .await; + while current_user.borrow().is_none() { + current_user.next().await; } - let response = Self::fetch_models(client, llm_api_token).await?; - this.update(cx, |this, cx| { - this.update_models(response, cx); - }) + let response = + Self::fetch_models(client.clone(), llm_api_token.clone()).await?; + this.update(cx, |this, cx| this.update_models(response, cx))?; + anyhow::Ok(()) }) .await .context("failed to fetch Zed models") @@ -1267,8 +1260,16 @@ impl Render for ConfigurationView { } impl Component for ZedAiConfiguration { + fn name() -> &'static str { + "AI Configuration Content" + } + + fn sort_name() -> &'static str { + "AI Configuration Content" + } + fn scope() -> ComponentScope { - ComponentScope::Agent + ComponentScope::Onboarding } fn preview(_window: &mut Window, _cx: &mut App) -> Option { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 5185e979b7..ee74562687 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -674,6 +674,10 @@ pub fn 
count_open_ai_tokens( | Model::O3 | Model::O3Mini | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), + // GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer + Model::Five | Model::FiveMini | Model::FiveNano => { + tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages) + } } .map(|tokens| tokens as u64) }) diff --git a/crates/language_selector/src/active_buffer_language.rs b/crates/language_selector/src/active_buffer_language.rs index 250d0c23d8..c5c5eceab5 100644 --- a/crates/language_selector/src/active_buffer_language.rs +++ b/crates/language_selector/src/active_buffer_language.rs @@ -1,8 +1,9 @@ -use editor::Editor; +use editor::{Editor, EditorSettings}; use gpui::{ Context, Entity, IntoElement, ParentElement, Render, Subscription, WeakEntity, Window, div, }; use language::LanguageName; +use settings::Settings as _; use ui::{Button, ButtonCommon, Clickable, FluentBuilder, LabelSize, Tooltip}; use workspace::{StatusItemView, Workspace, item::ItemHandle}; @@ -39,6 +40,13 @@ impl ActiveBufferLanguage { impl Render for ActiveBufferLanguage { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + if !EditorSettings::get_global(cx) + .status_bar + .active_language_button + { + return div(); + } + div().when_some(self.active_language.as_ref(), |el, active_language| { let active_language_text = if let Some(active_language_text) = active_language { active_language_text.to_string() diff --git a/crates/language_selector/src/language_selector.rs b/crates/language_selector/src/language_selector.rs index 4c03430553..f6e2d75015 100644 --- a/crates/language_selector/src/language_selector.rs +++ b/crates/language_selector/src/language_selector.rs @@ -86,7 +86,10 @@ impl LanguageSelector { impl Render for LanguageSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("LanguageSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } @@ -121,13 +124,13 @@ impl LanguageSelectorDelegate { .into_iter() .filter_map(|name| { language_registry - .available_language_for_name(&name)? + .available_language_for_name(name.as_ref())? 
.hidden() .not() .then_some(name) }) .enumerate() - .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, &name)) + .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, name.as_ref())) .collect::>(); Self { diff --git a/crates/languages/Cargo.toml b/crates/languages/Cargo.toml index 260126da63..8e25818070 100644 --- a/crates/languages/Cargo.toml +++ b/crates/languages/Cargo.toml @@ -36,6 +36,7 @@ load-grammars = [ [dependencies] anyhow.workspace = true async-compression.workspace = true +async-fs.workspace = true async-tar.workspace = true async-trait.workspace = true chrono.workspace = true @@ -62,6 +63,7 @@ regex.workspace = true rope.workspace = true rust-embed.workspace = true schemars.workspace = true +sha2.workspace = true serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true @@ -69,6 +71,7 @@ settings.workspace = true smol.workspace = true snippet_provider.workspace = true task.workspace = true +tempfile.workspace = true toml.workspace = true tree-sitter = { workspace = true, optional = true } tree-sitter-bash = { workspace = true, optional = true } diff --git a/crates/languages/src/c.rs b/crates/languages/src/c.rs index c06c35ee69..df93e51760 100644 --- a/crates/languages/src/c.rs +++ b/crates/languages/src/c.rs @@ -2,14 +2,16 @@ use anyhow::{Context as _, Result, bail}; use async_trait::async_trait; use futures::StreamExt; use gpui::{App, AsyncApp}; -use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; +use http_client::github::{AssetKind, GitHubLspBinaryVersion, latest_github_release}; pub use language::*; use lsp::{InitializeParams, LanguageServerBinary, LanguageServerName}; use project::lsp_store::clangd_ext; use serde_json::json; use smol::fs; use std::{any::Any, env::consts, path::PathBuf, sync::Arc}; -use util::{ResultExt, archive::extract_zip, fs::remove_matching, maybe, merge_json_value_into}; +use util::{ResultExt, fs::remove_matching, maybe, merge_json_value_into}; + +use crate::github_download::{GithubBinaryMetadata, download_server_binary}; pub struct CLspAdapter; @@ -58,6 +60,7 @@ impl super::LspAdapter for CLspAdapter { let version = GitHubLspBinaryVersion { name: release.tag_name, url: asset.browser_download_url.clone(), + digest: asset.digest.clone(), }; Ok(Box::new(version) as Box<_>) } @@ -68,32 +71,72 @@ impl super::LspAdapter for CLspAdapter { container_dir: PathBuf, delegate: &dyn LspAdapterDelegate, ) -> Result { - let version = version.downcast::().unwrap(); - let version_dir = container_dir.join(format!("clangd_{}", version.name)); + let GitHubLspBinaryVersion { name, url, digest } = + &*version.downcast::().unwrap(); + let version_dir = container_dir.join(format!("clangd_{name}")); let binary_path = version_dir.join("bin/clangd"); + let expected_digest = digest + .as_ref() + .and_then(|digest| digest.strip_prefix("sha256:")); - if fs::metadata(&binary_path).await.is_err() { - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .context("error downloading release")?; - anyhow::ensure!( - response.status().is_success(), - "download failed with status {}", - response.status().to_string() - ); - extract_zip(&container_dir, response.body_mut()) - .await - .with_context(|| format!("unzipping clangd archive to {container_dir:?}"))?; - remove_matching(&container_dir, |entry| entry != version_dir).await; - } - - Ok(LanguageServerBinary { - path: binary_path, + let binary = LanguageServerBinary { + path: binary_path.clone(), env: 
None, - arguments: Vec::new(), - }) + arguments: Default::default(), + }; + + let metadata_path = version_dir.join("metadata"); + let metadata = GithubBinaryMetadata::read_from_file(&metadata_path) + .await + .ok(); + if let Some(metadata) = metadata { + let validity_check = async || { + delegate + .try_exec(LanguageServerBinary { + path: binary_path.clone(), + arguments: vec!["--version".into()], + env: None, + }) + .await + .inspect_err(|err| { + log::warn!("Unable to run {binary_path:?} asset, redownloading: {err}",) + }) + }; + if let (Some(actual_digest), Some(expected_digest)) = + (&metadata.digest, expected_digest) + { + if actual_digest == expected_digest { + if validity_check().await.is_ok() { + return Ok(binary); + } + } else { + log::info!( + "SHA-256 mismatch for {binary_path:?} asset, downloading new asset. Expected: {expected_digest}, Got: {actual_digest}" + ); + } + } else if validity_check().await.is_ok() { + return Ok(binary); + } + } + download_server_binary( + delegate, + url, + digest.as_deref(), + &container_dir, + AssetKind::Zip, + ) + .await?; + remove_matching(&container_dir, |entry| entry != version_dir).await; + GithubBinaryMetadata::write_to_file( + &GithubBinaryMetadata { + metadata_version: 1, + digest: digest.clone(), + }, + &metadata_path, + ) + .await?; + + Ok(binary) } async fn cached_server_binary( diff --git a/crates/languages/src/cpp/config.toml b/crates/languages/src/cpp/config.toml index fab88266d7..7e24415f9d 100644 --- a/crates/languages/src/cpp/config.toml +++ b/crates/languages/src/cpp/config.toml @@ -1,6 +1,6 @@ name = "C++" grammar = "cpp" -path_suffixes = ["cc", "hh", "cpp", "h", "hpp", "cxx", "hxx", "c++", "ipp", "inl", "ixx", "cu", "cuh", "C", "H"] +path_suffixes = ["cc", "hh", "cpp", "h", "hpp", "cxx", "hxx", "c++", "ipp", "inl", "ino", "ixx", "cu", "cuh", "C", "H"] line_comments = ["// ", "/// ", "//! 
"] decrease_indent_patterns = [ { pattern = "^\\s*\\{.*\\}?\\s*$", valid_after = ["if", "for", "while", "do", "switch", "else"] }, diff --git a/crates/languages/src/css.rs b/crates/languages/src/css.rs index f2a94809a0..7725e079be 100644 --- a/crates/languages/src/css.rs +++ b/crates/languages/src/css.rs @@ -5,7 +5,7 @@ use gpui::AsyncApp; use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; -use project::Fs; +use project::{Fs, lsp_store::language_server_settings}; use serde_json::json; use smol::fs; use std::{ @@ -14,7 +14,7 @@ use std::{ path::{Path, PathBuf}, sync::Arc, }; -use util::{ResultExt, maybe}; +use util::{ResultExt, maybe, merge_json_value_into}; const SERVER_PATH: &str = "node_modules/vscode-langservers-extracted/bin/vscode-css-language-server"; @@ -134,6 +134,37 @@ impl LspAdapter for CssLspAdapter { "provideFormatter": true }))) } + + async fn workspace_configuration( + self: Arc, + _: &dyn Fs, + delegate: &Arc, + _: Arc, + cx: &mut AsyncApp, + ) -> Result { + let mut default_config = json!({ + "css": { + "lint": {} + }, + "less": { + "lint": {} + }, + "scss": { + "lint": {} + } + }); + + let project_options = cx.update(|cx| { + language_server_settings(delegate.as_ref(), &self.name(), cx) + .and_then(|s| s.settings.clone()) + })?; + + if let Some(override_options) = project_options { + merge_json_value_into(override_options, &mut default_config); + } + + Ok(default_config) + } } async fn get_cached_server_binary( diff --git a/crates/languages/src/github_download.rs b/crates/languages/src/github_download.rs new file mode 100644 index 0000000000..a3cd0a964b --- /dev/null +++ b/crates/languages/src/github_download.rs @@ -0,0 +1,190 @@ +use std::{path::Path, pin::Pin, task::Poll}; + +use anyhow::{Context, Result}; +use async_compression::futures::bufread::GzipDecoder; +use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader}; +use http_client::github::AssetKind; +use language::LspAdapterDelegate; +use sha2::{Digest, Sha256}; + +#[derive(serde::Deserialize, serde::Serialize, Debug)] +pub(crate) struct GithubBinaryMetadata { + pub(crate) metadata_version: u64, + pub(crate) digest: Option, +} + +impl GithubBinaryMetadata { + pub(crate) async fn read_from_file(metadata_path: &Path) -> Result { + let metadata_content = async_fs::read_to_string(metadata_path) + .await + .with_context(|| format!("reading metadata file at {metadata_path:?}"))?; + let metadata: GithubBinaryMetadata = serde_json::from_str(&metadata_content) + .with_context(|| format!("parsing metadata file at {metadata_path:?}"))?; + Ok(metadata) + } + + pub(crate) async fn write_to_file(&self, metadata_path: &Path) -> Result<()> { + let metadata_content = serde_json::to_string(self) + .with_context(|| format!("serializing metadata for {metadata_path:?}"))?; + async_fs::write(metadata_path, metadata_content.as_bytes()) + .await + .with_context(|| format!("writing metadata file at {metadata_path:?}"))?; + Ok(()) + } +} + +pub(crate) async fn download_server_binary( + delegate: &dyn LspAdapterDelegate, + url: &str, + digest: Option<&str>, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<(), anyhow::Error> { + log::info!("downloading github artifact from {url}"); + let mut response = delegate + .http_client() + .get(url, Default::default(), true) + .await + .with_context(|| format!("downloading release from {url}"))?; + let body = response.body_mut(); + match digest { + 
Some(expected_sha_256) => { + let temp_asset_file = tempfile::NamedTempFile::new() + .with_context(|| format!("creating a temporary file for {url}"))?; + let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts(); + let mut writer = HashingWriter { + writer: async_fs::File::from(temp_asset_file), + hasher: Sha256::new(), + }; + futures::io::copy(&mut BufReader::new(body), &mut writer) + .await + .with_context(|| { + format!("saving archive contents into the temporary file for {url}",) + })?; + let asset_sha_256 = format!("{:x}", writer.hasher.finalize()); + anyhow::ensure!( + asset_sha_256 == expected_sha_256, + "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}", + ); + writer + .writer + .seek(std::io::SeekFrom::Start(0)) + .await + .with_context(|| format!("seeking temporary file {destination_path:?}",))?; + stream_file_archive(&mut writer.writer, url, destination_path, asset_kind) + .await + .with_context(|| { + format!("extracting downloaded asset for {url} into {destination_path:?}",) + })?; + } + None => stream_response_archive(body, url, destination_path, asset_kind) + .await + .with_context(|| { + format!("extracting response for asset {url} into {destination_path:?}",) + })?, + } + Ok(()) +} + +async fn stream_response_archive( + response: impl AsyncRead + Unpin, + url: &str, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<()> { + match asset_kind { + AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?, + AssetKind::Gz => extract_gz(destination_path, url, response).await?, + AssetKind::Zip => { + util::archive::extract_zip(&destination_path, response).await?; + } + }; + Ok(()) +} + +async fn stream_file_archive( + file_archive: impl AsyncRead + AsyncSeek + Unpin, + url: &str, + destination_path: &Path, + asset_kind: AssetKind, +) -> Result<()> { + match asset_kind { + AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?, + AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?, + #[cfg(not(windows))] + AssetKind::Zip => { + util::archive::extract_seekable_zip(&destination_path, file_archive).await?; + } + #[cfg(windows)] + AssetKind::Zip => { + util::archive::extract_zip(&destination_path, file_archive).await?; + } + }; + Ok(()) +} + +async fn extract_tar_gz( + destination_path: &Path, + url: &str, + from: impl AsyncRead + Unpin, +) -> Result<(), anyhow::Error> { + let decompressed_bytes = GzipDecoder::new(BufReader::new(from)); + let archive = async_tar::Archive::new(decompressed_bytes); + archive + .unpack(&destination_path) + .await + .with_context(|| format!("extracting {url} to {destination_path:?}"))?; + Ok(()) +} + +async fn extract_gz( + destination_path: &Path, + url: &str, + from: impl AsyncRead + Unpin, +) -> Result<(), anyhow::Error> { + let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from)); + let mut file = smol::fs::File::create(&destination_path) + .await + .with_context(|| { + format!("creating a file {destination_path:?} for a download from {url}") + })?; + futures::io::copy(&mut decompressed_bytes, &mut file) + .await + .with_context(|| format!("extracting {url} to {destination_path:?}"))?; + Ok(()) +} + +struct HashingWriter { + writer: W, + hasher: Sha256, +} + +impl AsyncWrite for HashingWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match Pin::new(&mut self.writer).poll_write(cx, buf) { + Poll::Ready(Ok(n)) => { + self.hasher.update(&buf[..n]); + 
Poll::Ready(Ok(n)) + } + other => other, + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.writer).poll_flush(cx) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.writer).poll_close(cx) + } +} diff --git a/crates/languages/src/javascript/highlights.scm b/crates/languages/src/javascript/highlights.scm index 73cb1a5e45..9d5ebbaf71 100644 --- a/crates/languages/src/javascript/highlights.scm +++ b/crates/languages/src/javascript/highlights.scm @@ -146,6 +146,7 @@ "&&=" "||=" "??=" + "..." ] @operator (regex "/" @string.regex) diff --git a/crates/languages/src/json.rs b/crates/languages/src/json.rs index 601b4620c5..ca82bb2431 100644 --- a/crates/languages/src/json.rs +++ b/crates/languages/src/json.rs @@ -269,7 +269,15 @@ impl JsonLspAdapter { .await; let config = cx.update(|cx| { - Self::get_workspace_config(self.languages.language_names().clone(), adapter_schemas, cx) + Self::get_workspace_config( + self.languages + .language_names() + .into_iter() + .map(|name| name.to_string()) + .collect(), + adapter_schemas, + cx, + ) })?; writer.replace(config.clone()); return Ok(config); @@ -509,6 +517,7 @@ impl LspAdapter for NodeVersionAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: release.tag_name, url: asset.browser_download_url.clone(), + digest: asset.digest.clone(), })) } diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index 001fd15200..195ba79e1d 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -17,6 +17,7 @@ use crate::{json::JsonTaskProvider, python::BasedPyrightLspAdapter}; mod bash; mod c; mod css; +mod github_download; mod go; mod json; mod package_json; diff --git a/crates/languages/src/rust.rs b/crates/languages/src/rust.rs index 3f83c9c000..b52b1e7d55 100644 --- a/crates/languages/src/rust.rs +++ b/crates/languages/src/rust.rs @@ -1,8 +1,7 @@ use anyhow::{Context as _, Result}; -use async_compression::futures::bufread::GzipDecoder; use async_trait::async_trait; use collections::HashMap; -use futures::{StreamExt, io::BufReader}; +use futures::StreamExt; use gpui::{App, AppContext, AsyncApp, SharedString, Task}; use http_client::github::AssetKind; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; @@ -16,6 +15,7 @@ use serde_json::json; use settings::Settings as _; use smol::fs::{self}; use std::fmt::Display; +use std::ops::Range; use std::{ any::Any, borrow::Cow, @@ -23,14 +23,11 @@ use std::{ sync::{Arc, LazyLock}, }; use task::{TaskTemplate, TaskTemplates, TaskVariables, VariableName}; -use util::archive::extract_zip; +use util::fs::make_file_executable; use util::merge_json_value_into; -use util::{ - ResultExt, - fs::{make_file_executable, remove_matching}, - maybe, -}; +use util::{ResultExt, maybe}; +use crate::github_download::{GithubBinaryMetadata, download_server_binary}; use crate::language_settings::language_settings; pub struct RustLspAdapter; @@ -163,7 +160,6 @@ impl LspAdapter for RustLspAdapter { ) .await?; let asset_name = Self::build_asset_name(); - let asset = release .assets .iter() @@ -172,6 +168,7 @@ impl LspAdapter for RustLspAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: release.tag_name, url: asset.browser_download_url.clone(), + digest: asset.digest.clone(), })) } @@ -181,58 +178,76 @@ impl LspAdapter for RustLspAdapter { container_dir: PathBuf, delegate: &dyn LspAdapterDelegate, ) -> Result { - let version = 
version.downcast::().unwrap(); - let destination_path = container_dir.join(format!("rust-analyzer-{}", version.name)); + let GitHubLspBinaryVersion { name, url, digest } = + &*version.downcast::().unwrap(); + let expected_digest = digest + .as_ref() + .and_then(|digest| digest.strip_prefix("sha256:")); + let destination_path = container_dir.join(format!("rust-analyzer-{name}")); let server_path = match Self::GITHUB_ASSET_KIND { AssetKind::TarGz | AssetKind::Gz => destination_path.clone(), // Tar and gzip extract in place. AssetKind::Zip => destination_path.clone().join("rust-analyzer.exe"), // zip contains a .exe }; - if fs::metadata(&server_path).await.is_err() { - remove_matching(&container_dir, |entry| entry != destination_path).await; + let binary = LanguageServerBinary { + path: server_path.clone(), + env: None, + arguments: Default::default(), + }; - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .with_context(|| format!("downloading release from {}", version.url))?; - match Self::GITHUB_ASSET_KIND { - AssetKind::TarGz => { - let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); - let archive = async_tar::Archive::new(decompressed_bytes); - archive.unpack(&destination_path).await.with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Gz => { - let mut decompressed_bytes = - GzipDecoder::new(BufReader::new(response.body_mut())); - let mut file = - fs::File::create(&destination_path).await.with_context(|| { - format!( - "creating a file {:?} for a download from {}", - destination_path, version.url, - ) - })?; - futures::io::copy(&mut decompressed_bytes, &mut file) - .await - .with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Zip => { - extract_zip(&destination_path, response.body_mut()) - .await - .with_context(|| { - format!("unzipping {} to {:?}", version.url, destination_path) - })?; - } + let metadata_path = destination_path.with_extension("metadata"); + let metadata = GithubBinaryMetadata::read_from_file(&metadata_path) + .await + .ok(); + if let Some(metadata) = metadata { + let validity_check = async || { + delegate + .try_exec(LanguageServerBinary { + path: server_path.clone(), + arguments: vec!["--version".into()], + env: None, + }) + .await + .inspect_err(|err| { + log::warn!("Unable to run {server_path:?} asset, redownloading: {err}",) + }) }; - - // todo("windows") - make_file_executable(&server_path).await?; + if let (Some(actual_digest), Some(expected_digest)) = + (&metadata.digest, expected_digest) + { + if actual_digest == expected_digest { + if validity_check().await.is_ok() { + return Ok(binary); + } + } else { + log::info!( + "SHA-256 mismatch for {destination_path:?} asset, downloading new asset. 
Expected: {expected_digest}, Got: {actual_digest}" + ); + } + } else if validity_check().await.is_ok() { + return Ok(binary); + } } + _ = fs::remove_dir_all(&destination_path).await; + download_server_binary( + delegate, + url, + expected_digest, + &destination_path, + Self::GITHUB_ASSET_KIND, + ) + .await?; + make_file_executable(&server_path).await?; + GithubBinaryMetadata::write_to_file( + &GithubBinaryMetadata { + metadata_version: 1, + digest: expected_digest.map(ToString::to_string), + }, + &metadata_path, + ) + .await?; + Ok(LanguageServerBinary { path: server_path, env: None, @@ -291,66 +306,62 @@ impl LspAdapter for RustLspAdapter { completion: &lsp::CompletionItem, language: &Arc, ) -> Option { - let detail = completion + // rust-analyzer calls these detail left and detail right in terms of where it expects things to be rendered + // this usually contains signatures of the thing to be completed + let detail_right = completion .label_details .as_ref() - .and_then(|detail| detail.detail.as_ref()) + .and_then(|detail| detail.description.as_ref()) .or(completion.detail.as_ref()) .map(|detail| detail.trim()); - let function_signature = completion + // this tends to contain alias and import information + let detail_left = completion .label_details .as_ref() - .and_then(|detail| detail.description.as_deref()) - .or(completion.detail.as_deref()); - match (detail, completion.kind) { - (Some(detail), Some(lsp::CompletionItemKind::FIELD)) => { + .and_then(|detail| detail.detail.as_deref()); + let mk_label = |text: String, filter_range: &dyn Fn() -> Range, runs| { + let filter_range = completion + .filter_text + .as_deref() + .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) + .or_else(|| { + text.find(&completion.label) + .map(|ix| ix..ix + completion.label.len()) + }) + .unwrap_or_else(filter_range); + + CodeLabel { + text, + runs, + filter_range, + } + }; + let mut label = match (detail_right, completion.kind) { + (Some(signature), Some(lsp::CompletionItemKind::FIELD)) => { let name = &completion.label; - let text = format!("{name}: {detail}"); + let text = format!("{name}: {signature}"); let prefix = "struct S { "; - let source = Rope::from(format!("{prefix}{text} }}")); + let source = Rope::from_iter([prefix, &text, " }"]); let runs = language.highlight_text(&source, prefix.len()..prefix.len() + text.len()); - let filter_range = completion - .filter_text - .as_deref() - .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..name.len()); - return Some(CodeLabel { - text, - runs, - filter_range, - }); + mk_label(text, &|| 0..completion.label.len(), runs) } ( - Some(detail), + Some(signature), Some(lsp::CompletionItemKind::CONSTANT | lsp::CompletionItemKind::VARIABLE), ) if completion.insert_text_format != Some(lsp::InsertTextFormat::SNIPPET) => { let name = &completion.label; - let text = format!( - "{}: {}", - name, - completion.detail.as_deref().unwrap_or(detail) - ); + let text = format!("{name}: {signature}",); let prefix = "let "; - let source = Rope::from(format!("{prefix}{text} = ();")); + let source = Rope::from_iter([prefix, &text, " = ();"]); let runs = language.highlight_text(&source, prefix.len()..prefix.len() + text.len()); - let filter_range = completion - .filter_text - .as_deref() - .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..name.len()); - return Some(CodeLabel { - text, - runs, - filter_range, - }); + mk_label(text, &|| 0..completion.label.len(), runs) } ( - Some(detail), + 
function_signature, Some(lsp::CompletionItemKind::FUNCTION | lsp::CompletionItemKind::METHOD), ) => { - static REGEX: LazyLock = LazyLock::new(|| Regex::new("\\(…?\\)").unwrap()); const FUNCTION_PREFIXES: [&str; 6] = [ "async fn", "async unsafe fn", @@ -359,34 +370,40 @@ impl LspAdapter for RustLspAdapter { "unsafe fn", "fn", ]; - // Is it function `async`? - let fn_keyword = FUNCTION_PREFIXES.iter().find_map(|prefix| { - function_signature.as_ref().and_then(|signature| { - signature - .strip_prefix(*prefix) - .map(|suffix| (*prefix, suffix)) - }) + let fn_prefixed = FUNCTION_PREFIXES.iter().find_map(|&prefix| { + function_signature? + .strip_prefix(prefix) + .map(|suffix| (prefix, suffix)) }); - // fn keyword should be followed by opening parenthesis. - if let Some((prefix, suffix)) = fn_keyword { - let mut text = REGEX.replace(&completion.label, suffix).to_string(); - let source = Rope::from(format!("{prefix} {text} {{}}")); + let label = if let Some(label) = completion + .label + .strip_suffix("(…)") + .or_else(|| completion.label.strip_suffix("()")) + { + label + } else { + &completion.label + }; + + static FULL_SIGNATURE_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"fn (.?+)\(").expect("Failed to create REGEX")); + if let Some((function_signature, match_)) = function_signature + .filter(|it| it.contains(&label)) + .and_then(|it| Some((it, FULL_SIGNATURE_REGEX.find(it)?))) + { + let source = Rope::from(function_signature); + let runs = language.highlight_text(&source, 0..function_signature.len()); + mk_label( + function_signature.to_owned(), + &|| match_.range().start - 3..match_.range().end - 1, + runs, + ) + } else if let Some((prefix, suffix)) = fn_prefixed { + let text = format!("{label}{suffix}"); + let source = Rope::from_iter([prefix, " ", &text, " {}"]); let run_start = prefix.len() + 1; let runs = language.highlight_text(&source, run_start..run_start + text.len()); - if detail.starts_with("(") { - text.push(' '); - text.push_str(&detail); - } - let filter_range = completion - .filter_text - .as_deref() - .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..completion.label.find('(').unwrap_or(text.len())); - return Some(CodeLabel { - filter_range, - text, - runs, - }); + mk_label(text, &|| 0..label.len(), runs) } else if completion .detail .as_ref() @@ -396,20 +413,19 @@ impl LspAdapter for RustLspAdapter { let len = text.len(); let source = Rope::from(text.as_str()); let runs = language.highlight_text(&source, 0..len); - let filter_range = completion - .filter_text - .as_deref() - .and_then(|filter| text.find(filter).map(|ix| ix..ix + filter.len())) - .unwrap_or(0..len); - return Some(CodeLabel { - filter_range, - text, - runs, - }); + mk_label(text, &|| 0..completion.label.len(), runs) + } else if detail_left.is_none() { + return None; + } else { + mk_label( + completion.label.clone(), + &|| 0..completion.label.len(), + vec![], + ) } } - (_, Some(kind)) => { - let highlight_name = match kind { + (_, kind) => { + let highlight_name = kind.and_then(|kind| match kind { lsp::CompletionItemKind::STRUCT | lsp::CompletionItemKind::INTERFACE | lsp::CompletionItemKind::ENUM => Some("type"), @@ -419,27 +435,35 @@ impl LspAdapter for RustLspAdapter { Some("constant") } _ => None, - }; + }); - let mut label = completion.label.clone(); - if let Some(detail) = detail.filter(|detail| detail.starts_with("(")) { - label.push(' '); - label.push_str(detail); - } - let mut label = CodeLabel::plain(label, completion.filter_text.as_deref()); + let label = 
completion.label.clone(); + let mut runs = vec![]; if let Some(highlight_name) = highlight_name { let highlight_id = language.grammar()?.highlight_id_for_name(highlight_name)?; - label.runs.push(( - 0..label.text.rfind('(').unwrap_or(completion.label.len()), + runs.push(( + 0..label.rfind('(').unwrap_or(completion.label.len()), highlight_id, )); + } else if detail_left.is_none() { + return None; } - return Some(label); + mk_label(label, &|| 0..completion.label.len(), runs) + } + }; + + if let Some(detail_left) = detail_left { + label.text.push(' '); + if !detail_left.starts_with('(') { + label.text.push('('); + } + label.text.push_str(detail_left); + if !detail_left.ends_with(')') { + label.text.push(')'); } - _ => {} } - None + Some(label) } async fn label_for_symbol( @@ -448,55 +472,22 @@ impl LspAdapter for RustLspAdapter { kind: lsp::SymbolKind, language: &Arc, ) -> Option { - let (text, filter_range, display_range) = match kind { - lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => { - let text = format!("fn {} () {{}}", name); - let filter_range = 3..3 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::STRUCT => { - let text = format!("struct {} {{}}", name); - let filter_range = 7..7 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::ENUM => { - let text = format!("enum {} {{}}", name); - let filter_range = 5..5 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::INTERFACE => { - let text = format!("trait {} {{}}", name); - let filter_range = 6..6 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::CONSTANT => { - let text = format!("const {}: () = ();", name); - let filter_range = 6..6 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::MODULE => { - let text = format!("mod {} {{}}", name); - let filter_range = 4..4 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } - lsp::SymbolKind::TYPE_PARAMETER => { - let text = format!("type {} {{}}", name); - let filter_range = 5..5 + name.len(); - let display_range = 0..filter_range.end; - (text, filter_range, display_range) - } + let (prefix, suffix) = match kind { + lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => ("fn ", " () {}"), + lsp::SymbolKind::STRUCT => ("struct ", " {}"), + lsp::SymbolKind::ENUM => ("enum ", " {}"), + lsp::SymbolKind::INTERFACE => ("trait ", " {}"), + lsp::SymbolKind::CONSTANT => ("const ", ": () = ();"), + lsp::SymbolKind::MODULE => ("mod ", " {}"), + lsp::SymbolKind::TYPE_PARAMETER => ("type ", " {}"), _ => return None, }; + let filter_range = prefix.len()..prefix.len() + name.len(); + let display_range = 0..filter_range.end; Some(CodeLabel { - runs: language.highlight_text(&text.as_str().into(), display_range.clone()), - text: text[display_range].to_string(), + runs: language.highlight_text(&Rope::from_iter([prefix, name, suffix]), display_range), + text: format!("{prefix}{name}"), filter_range, }) } @@ -1025,7 +1016,11 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option) -> Vec".to_string()), }), ..Default::default() @@ -1216,7 +1211,7 @@ mod tests { kind: Some(lsp::CompletionItemKind::FUNCTION), label: "hello(…)".to_string(), label_details: Some(CompletionItemLabelDetails { - detail: Some(" (use 
crate::foo)".to_string()), + detail: Some("(use crate::foo)".to_string()), description: Some("fn(&mut Option) -> Vec".to_string()), }), @@ -1239,6 +1234,35 @@ mod tests { }) ); + assert_eq!( + adapter + .label_for_completion( + &lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::FUNCTION), + label: "hello".to_string(), + label_details: Some(CompletionItemLabelDetails { + detail: Some("(use crate::foo)".to_string()), + description: Some("fn(&mut Option) -> Vec".to_string()), + }), + ..Default::default() + }, + &language + ) + .await, + Some(CodeLabel { + text: "hello(&mut Option) -> Vec (use crate::foo)".to_string(), + filter_range: 0..5, + runs: vec![ + (0..5, highlight_function), + (7..10, highlight_keyword), + (11..17, highlight_type), + (18..19, highlight_type), + (25..28, highlight_type), + (29..30, highlight_type), + ], + }) + ); + assert_eq!( adapter .label_for_completion( @@ -1256,9 +1280,46 @@ mod tests { ) .await, Some(CodeLabel { - text: "await.as_deref_mut()".to_string(), + text: "await.as_deref_mut(&mut self) -> IterMut<'_, T>".to_string(), filter_range: 6..18, - runs: vec![], + runs: vec![ + (6..18, HighlightId(2)), + (20..23, HighlightId(1)), + (33..40, HighlightId(0)), + (45..46, HighlightId(0)) + ], + }) + ); + + assert_eq!( + adapter + .label_for_completion( + &lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::METHOD), + label: "as_deref_mut()".to_string(), + filter_text: Some("as_deref_mut".to_string()), + label_details: Some(CompletionItemLabelDetails { + detail: None, + description: Some( + "pub fn as_deref_mut(&mut self) -> IterMut<'_, T>".to_string() + ), + }), + ..Default::default() + }, + &language + ) + .await, + Some(CodeLabel { + text: "pub fn as_deref_mut(&mut self) -> IterMut<'_, T>".to_string(), + filter_range: 7..19, + runs: vec![ + (0..3, HighlightId(1)), + (4..6, HighlightId(1)), + (7..19, HighlightId(2)), + (21..24, HighlightId(1)), + (34..41, HighlightId(0)), + (46..47, HighlightId(0)) + ], }) ); diff --git a/crates/languages/src/tsx/highlights.scm b/crates/languages/src/tsx/highlights.scm index e2837c61fd..5e2fbbf63a 100644 --- a/crates/languages/src/tsx/highlights.scm +++ b/crates/languages/src/tsx/highlights.scm @@ -146,6 +146,7 @@ "&&=" "||=" "??=" + "..." 
] @operator (regex "/" @string.regex) diff --git a/crates/languages/src/typescript.rs b/crates/languages/src/typescript.rs index 9dc3ee303d..f976b62614 100644 --- a/crates/languages/src/typescript.rs +++ b/crates/languages/src/typescript.rs @@ -1,6 +1,4 @@ use anyhow::{Context as _, Result}; -use async_compression::futures::bufread::GzipDecoder; -use async_tar::Archive; use async_trait::async_trait; use chrono::{DateTime, Local}; use collections::HashMap; @@ -15,7 +13,7 @@ use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; use node_runtime::NodeRuntime; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; -use smol::{fs, io::BufReader, lock::RwLock, stream::StreamExt}; +use smol::{fs, lock::RwLock, stream::StreamExt}; use std::{ any::Any, borrow::Cow, @@ -24,11 +22,10 @@ use std::{ sync::Arc, }; use task::{TaskTemplate, TaskTemplates, VariableName}; -use util::archive::extract_zip; use util::merge_json_value_into; use util::{ResultExt, fs::remove_matching, maybe}; -use crate::{PackageJson, PackageJsonData}; +use crate::{PackageJson, PackageJsonData, github_download::download_server_binary}; #[derive(Debug)] pub(crate) struct TypeScriptContextProvider { @@ -897,6 +894,7 @@ impl LspAdapter for EsLintLspAdapter { Ok(Box::new(GitHubLspBinaryVersion { name: Self::CURRENT_VERSION.into(), + digest: None, url, })) } @@ -914,43 +912,14 @@ impl LspAdapter for EsLintLspAdapter { if fs::metadata(&server_path).await.is_err() { remove_matching(&container_dir, |entry| entry != destination_path).await; - let mut response = delegate - .http_client() - .get(&version.url, Default::default(), true) - .await - .context("downloading release")?; - match Self::GITHUB_ASSET_KIND { - AssetKind::TarGz => { - let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); - let archive = Archive::new(decompressed_bytes); - archive.unpack(&destination_path).await.with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Gz => { - let mut decompressed_bytes = - GzipDecoder::new(BufReader::new(response.body_mut())); - let mut file = - fs::File::create(&destination_path).await.with_context(|| { - format!( - "creating a file {:?} for a download from {}", - destination_path, version.url, - ) - })?; - futures::io::copy(&mut decompressed_bytes, &mut file) - .await - .with_context(|| { - format!("extracting {} to {:?}", version.url, destination_path) - })?; - } - AssetKind::Zip => { - extract_zip(&destination_path, response.body_mut()) - .await - .with_context(|| { - format!("unzipping {} to {:?}", version.url, destination_path) - })?; - } - } + download_server_binary( + delegate, + &version.url, + None, + &destination_path, + Self::GITHUB_ASSET_KIND, + ) + .await?; let mut dir = fs::read_dir(&destination_path).await?; let first = dir.next().await.context("missing first file")??; diff --git a/crates/languages/src/typescript/highlights.scm b/crates/languages/src/typescript/highlights.scm index 486e5a7684..af37ef6415 100644 --- a/crates/languages/src/typescript/highlights.scm +++ b/crates/languages/src/typescript/highlights.scm @@ -167,6 +167,7 @@ "&&=" "||=" "??=" + "..." 
] @operator (regex "/" @string.regex) diff --git a/crates/languages/src/yaml/config.toml b/crates/languages/src/yaml/config.toml index 4dfb890c54..e54bceda1a 100644 --- a/crates/languages/src/yaml/config.toml +++ b/crates/languages/src/yaml/config.toml @@ -1,6 +1,6 @@ name = "YAML" grammar = "yaml" -path_suffixes = ["yml", "yaml"] +path_suffixes = ["yml", "yaml", "pixi.lock"] line_comments = ["# "] autoclose_before = ",]}" brackets = [ diff --git a/crates/lsp/src/input_handler.rs b/crates/lsp/src/input_handler.rs index db3f1190fc..001ebf1fc9 100644 --- a/crates/lsp/src/input_handler.rs +++ b/crates/lsp/src/input_handler.rs @@ -13,14 +13,15 @@ use parking_lot::Mutex; use smol::io::BufReader; use crate::{ - AnyNotification, AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, RequestId, ResponseHandler, + AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, NotificationOrRequest, RequestId, + ResponseHandler, }; const HEADER_DELIMITER: &[u8; 4] = b"\r\n\r\n"; /// Handler for stdout of language server. pub struct LspStdoutHandler { pub(super) loop_handle: Task>, - pub(super) notifications_channel: UnboundedReceiver, + pub(super) incoming_messages: UnboundedReceiver, } async fn read_headers(reader: &mut BufReader, buffer: &mut Vec) -> Result<()> @@ -54,13 +55,13 @@ impl LspStdoutHandler { let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers)); Self { loop_handle, - notifications_channel, + incoming_messages: notifications_channel, } } async fn handler( stdout: Input, - notifications_sender: UnboundedSender, + notifications_sender: UnboundedSender, response_handlers: Arc>>>, io_handlers: Arc>>, ) -> anyhow::Result<()> @@ -96,7 +97,7 @@ impl LspStdoutHandler { } } - if let Ok(msg) = serde_json::from_slice::(&buffer) { + if let Ok(msg) = serde_json::from_slice::(&buffer) { notifications_sender.unbounded_send(msg)?; } else if let Ok(AnyResponse { id, error, result, .. diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index b9701a83d2..a92787cd3e 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -242,7 +242,7 @@ struct Notification<'a, T> { /// Language server RPC notification message before it is deserialized into a concrete type. 
#[derive(Debug, Clone, Deserialize)] -struct AnyNotification { +struct NotificationOrRequest { #[serde(default)] id: Option, method: String, @@ -252,7 +252,10 @@ struct AnyNotification { #[derive(Debug, Serialize, Deserialize)] struct Error { + code: i64, message: String, + #[serde(default)] + data: Option, } pub trait LspRequestFuture: Future> { @@ -364,6 +367,7 @@ impl LanguageServer { notification.method, serde_json::to_string_pretty(¬ification.params).unwrap(), ); + false }, ); @@ -389,7 +393,7 @@ impl LanguageServer { Stdin: AsyncWrite + Unpin + Send + 'static, Stdout: AsyncRead + Unpin + Send + 'static, Stderr: AsyncRead + Unpin + Send + 'static, - F: FnMut(AnyNotification) + 'static + Send + Sync + Clone, + F: Fn(&NotificationOrRequest) -> bool + 'static + Send + Sync + Clone, { let (outbound_tx, outbound_rx) = channel::unbounded::(); let (output_done_tx, output_done_rx) = barrier::channel(); @@ -400,14 +404,34 @@ impl LanguageServer { let io_handlers = Arc::new(Mutex::new(HashMap::default())); let stdout_input_task = cx.spawn({ - let on_unhandled_notification = on_unhandled_notification.clone(); + let unhandled_notification_wrapper = { + let response_channel = outbound_tx.clone(); + async move |msg: NotificationOrRequest| { + let did_handle = on_unhandled_notification(&msg); + if !did_handle && let Some(message_id) = msg.id { + let response = AnyResponse { + jsonrpc: JSON_RPC_VERSION, + id: message_id, + error: Some(Error { + code: -32601, + message: format!("Unrecognized method `{}`", msg.method), + data: None, + }), + result: None, + }; + if let Ok(response) = serde_json::to_string(&response) { + response_channel.send(response).await.ok(); + } + } + } + }; let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); let io_handlers = io_handlers.clone(); async move |cx| { - Self::handle_input( + Self::handle_incoming_messages( stdout, - on_unhandled_notification, + unhandled_notification_wrapper, notification_handlers, response_handlers, io_handlers, @@ -433,7 +457,7 @@ impl LanguageServer { stdout.or(stderr) }); let output_task = cx.background_spawn({ - Self::handle_output( + Self::handle_outgoing_messages( stdin, outbound_rx, output_done_tx, @@ -479,9 +503,9 @@ impl LanguageServer { self.code_action_kinds.clone() } - async fn handle_input( + async fn handle_incoming_messages( stdout: Stdout, - mut on_unhandled_notification: F, + on_unhandled_notification: impl AsyncFn(NotificationOrRequest) + 'static + Send, notification_handlers: Arc>>, response_handlers: Arc>>>, io_handlers: Arc>>, @@ -489,7 +513,6 @@ impl LanguageServer { ) -> anyhow::Result<()> where Stdout: AsyncRead + Unpin + Send + 'static, - F: FnMut(AnyNotification) + 'static + Send, { use smol::stream::StreamExt; let stdout = BufReader::new(stdout); @@ -506,15 +529,19 @@ impl LanguageServer { cx.background_executor().clone(), ); - while let Some(msg) = input_handler.notifications_channel.next().await { - { + while let Some(msg) = input_handler.incoming_messages.next().await { + let unhandled_message = { let mut notification_handlers = notification_handlers.lock(); if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) { handler(msg.id, msg.params.unwrap_or(Value::Null), cx); + None } else { - drop(notification_handlers); - on_unhandled_notification(msg); + Some(msg) } + }; + + if let Some(msg) = unhandled_message { + on_unhandled_notification(msg).await; } // Don't starve the main thread when receiving lots of notifications at once. 
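The `unhandled_notification_wrapper` added above leans on a JSON-RPC 2.0 rule: a request (a message that carries an `id`) must always receive a response, a notification (no `id`) must never get one, and `-32601` is the spec's "Method not found" error code. Below is a minimal, self-contained sketch of that wire-level shape using plain `serde_json` rather than the crate's `AnyResponse`/`Error` structs; the helper name and exact envelope layout are illustrative only, the error code and id handling follow the JSON-RPC 2.0 spec.

use serde_json::{Value, json};

/// Build a "Method not found" reply for an unrecognized incoming message,
/// but only when it is a request (i.e. it carries an `id`).
fn method_not_found_reply(incoming: &Value) -> Option<String> {
    // Notifications have no `id` and must not be answered.
    let id = incoming.get("id")?.clone();
    let method = incoming.get("method").and_then(Value::as_str).unwrap_or("");
    let reply = json!({
        "jsonrpc": "2.0",
        "id": id,
        "error": {
            "code": -32601, // JSON-RPC 2.0: Method not found
            "message": format!("Unrecognized method `{method}`"),
            "data": Value::Null,
        }
    });
    serde_json::to_string(&reply).ok()
}

fn main() {
    let request = json!({"jsonrpc": "2.0", "id": 7, "method": "workspace/unknown"});
    let notification = json!({"jsonrpc": "2.0", "method": "$/unknown"});
    assert!(method_not_found_reply(&request).is_some());
    assert!(method_not_found_reply(&notification).is_none());
}
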
@@ -558,7 +585,7 @@ impl LanguageServer { } } - async fn handle_output( + async fn handle_outgoing_messages( stdin: Stdin, outbound_rx: channel::Receiver, output_done_tx: barrier::Sender, @@ -720,6 +747,10 @@ impl LanguageServer { InsertTextMode::ADJUST_INDENTATION, ], }), + documentation_format: Some(vec![ + MarkupKind::Markdown, + MarkupKind::PlainText, + ]), ..Default::default() }), insert_text_mode: Some(InsertTextMode::ADJUST_INDENTATION), @@ -1036,7 +1067,9 @@ impl LanguageServer { jsonrpc: JSON_RPC_VERSION, id, value: LspResult::Error(Some(Error { + code: lsp_types::error_codes::REQUEST_FAILED, message: error.to_string(), + data: None, })), }, }; @@ -1057,7 +1090,9 @@ impl LanguageServer { id, result: None, error: Some(Error { + code: -32700, // Parse error message: error.to_string(), + data: None, }), }; if let Some(response) = serde_json::to_string(&response).log_err() { @@ -1559,7 +1594,7 @@ impl FakeLanguageServer { root, Some(workspace_folders.clone()), cx, - |_| {}, + |_| false, ); server.process_name = process_name; let fake = FakeLanguageServer { @@ -1582,9 +1617,10 @@ impl FakeLanguageServer { notifications_tx .try_send(( msg.method.to_string(), - msg.params.unwrap_or(Value::Null).to_string(), + msg.params.as_ref().unwrap_or(&Value::Null).to_string(), )) .ok(); + true }, ); server.process_name = name.as_str().into(); @@ -1862,7 +1898,7 @@ mod tests { #[gpui::test] fn test_deserialize_string_digit_id() { let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = serde_json::from_str::(json) + let notification = serde_json::from_str::(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Str("2".to_string()); assert_eq!(notification.id, Some(expected_id)); @@ -1871,7 +1907,7 @@ mod tests { #[gpui::test] fn test_deserialize_string_id() { let json = r#"{"jsonrpc":"2.0","id":"anythingAtAll","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = serde_json::from_str::(json) + let notification = serde_json::from_str::(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Str("anythingAtAll".to_string()); assert_eq!(notification.id, Some(expected_id)); @@ -1880,7 +1916,7 @@ mod tests { #[gpui::test] fn test_deserialize_int_id() { let json = r#"{"jsonrpc":"2.0","id":2,"method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#; - let notification = serde_json::from_str::(json) + let notification = serde_json::from_str::(json) .expect("message with string id should be parsed"); let expected_id = RequestId::Int(2); assert_eq!(notification.id, Some(expected_id)); diff --git a/crates/markdown_preview/src/markdown_preview_view.rs b/crates/markdown_preview/src/markdown_preview_view.rs index 03cfd7ee82..a0c8819991 100644 --- a/crates/markdown_preview/src/markdown_preview_view.rs +++ b/crates/markdown_preview/src/markdown_preview_view.rs @@ -18,6 +18,7 @@ use workspace::item::{Item, ItemHandle}; use workspace::{Pane, Workspace}; use crate::markdown_elements::ParsedMarkdownElement; +use crate::markdown_renderer::CheckboxClickedEvent; use crate::{ MovePageDown, MovePageUp, OpenFollowingPreview, OpenPreview, OpenPreviewToTheSide, markdown_elements::ParsedMarkdown, @@ -203,114 +204,7 @@ impl 
MarkdownPreviewView { cx: &mut Context, ) -> Entity { cx.new(|cx| { - let view = cx.entity().downgrade(); - - let list_state = ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - if let Some(view) = view.upgrade() { - view.update(cx, |this: &mut Self, cx| { - let Some(contents) = &this.contents else { - return div().into_any(); - }; - - let mut render_cx = - RenderContext::new(Some(this.workspace.clone()), window, cx) - .with_checkbox_clicked_callback({ - let view = view.clone(); - move |checked, source_range, window, cx| { - view.update(cx, |view, cx| { - if let Some(editor) = view - .active_editor - .as_ref() - .map(|s| s.editor.clone()) - { - editor.update(cx, |editor, cx| { - let task_marker = - if checked { "[x]" } else { "[ ]" }; - - editor.edit( - vec![(source_range, task_marker)], - cx, - ); - }); - view.parse_markdown_from_active_editor( - false, window, cx, - ); - cx.notify(); - } - }) - } - }); - - let block = contents.children.get(ix).unwrap(); - let rendered_block = render_markdown_block(block, &mut render_cx); - - let should_apply_padding = Self::should_apply_padding_between( - block, - contents.children.get(ix + 1), - ); - - div() - .id(ix) - .when(should_apply_padding, |this| { - this.pb(render_cx.scaled_rems(0.75)) - }) - .group("markdown-block") - .on_click(cx.listener( - move |this, event: &ClickEvent, window, cx| { - if event.down.click_count == 2 { - if let Some(source_range) = this - .contents - .as_ref() - .and_then(|c| c.children.get(ix)) - .and_then(|block| block.source_range()) - { - this.move_cursor_to_block( - window, - cx, - source_range.start..source_range.start, - ); - } - } - }, - )) - .map(move |container| { - let indicator = div() - .h_full() - .w(px(4.0)) - .when(ix == this.selected_block, |this| { - this.bg(cx.theme().colors().border) - }) - .group_hover("markdown-block", |s| { - if ix == this.selected_block { - s - } else { - s.bg(cx.theme().colors().border_variant) - } - }) - .rounded_xs(); - - container.child( - div() - .relative() - .child( - div() - .pl(render_cx.scaled_rems(1.0)) - .child(rendered_block), - ) - .child(indicator.absolute().left_0().top_0()), - ) - }) - .into_any() - }) - } else { - div().into_any() - } - }, - ); + let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let mut this = Self { selected_block: 0, @@ -607,10 +501,107 @@ impl Render for MarkdownPreviewView { .p_4() .text_size(buffer_size) .line_height(relative(buffer_line_height.value())) - .child( - div() - .flex_grow() - .map(|this| this.child(list(self.list_state.clone()).size_full())), - ) + .child(div().flex_grow().map(|this| { + this.child( + list( + self.list_state.clone(), + cx.processor(|this, ix, window, cx| { + let Some(contents) = &this.contents else { + return div().into_any(); + }; + + let mut render_cx = + RenderContext::new(Some(this.workspace.clone()), window, cx) + .with_checkbox_clicked_callback(cx.listener( + move |this, e: &CheckboxClickedEvent, window, cx| { + if let Some(editor) = this + .active_editor + .as_ref() + .map(|s| s.editor.clone()) + { + editor.update(cx, |editor, cx| { + let task_marker = + if e.checked() { "[x]" } else { "[ ]" }; + + editor.edit( + vec![(e.source_range(), task_marker)], + cx, + ); + }); + this.parse_markdown_from_active_editor( + false, window, cx, + ); + cx.notify(); + } + }, + )); + + let block = contents.children.get(ix).unwrap(); + let rendered_block = render_markdown_block(block, &mut render_cx); + + let should_apply_padding = 
Self::should_apply_padding_between( + block, + contents.children.get(ix + 1), + ); + + div() + .id(ix) + .when(should_apply_padding, |this| { + this.pb(render_cx.scaled_rems(0.75)) + }) + .group("markdown-block") + .on_click(cx.listener( + move |this, event: &ClickEvent, window, cx| { + if event.click_count() == 2 { + if let Some(source_range) = this + .contents + .as_ref() + .and_then(|c| c.children.get(ix)) + .and_then(|block: &ParsedMarkdownElement| { + block.source_range() + }) + { + this.move_cursor_to_block( + window, + cx, + source_range.start..source_range.start, + ); + } + } + }, + )) + .map(move |container| { + let indicator = div() + .h_full() + .w(px(4.0)) + .when(ix == this.selected_block, |this| { + this.bg(cx.theme().colors().border) + }) + .group_hover("markdown-block", |s| { + if ix == this.selected_block { + s + } else { + s.bg(cx.theme().colors().border_variant) + } + }) + .rounded_xs(); + + container.child( + div() + .relative() + .child( + div() + .pl(render_cx.scaled_rems(1.0)) + .child(rendered_block), + ) + .child(indicator.absolute().left_0().top_0()), + ) + }) + .into_any() + }), + ) + .size_full(), + ) + })) } } diff --git a/crates/markdown_preview/src/markdown_renderer.rs b/crates/markdown_preview/src/markdown_renderer.rs index 80bed8a6e8..37d2ca2110 100644 --- a/crates/markdown_preview/src/markdown_renderer.rs +++ b/crates/markdown_preview/src/markdown_renderer.rs @@ -26,7 +26,22 @@ use ui::{ }; use workspace::{OpenOptions, OpenVisible, Workspace}; -type CheckboxClickedCallback = Arc, &mut Window, &mut App)>>; +pub struct CheckboxClickedEvent { + pub checked: bool, + pub source_range: Range, +} + +impl CheckboxClickedEvent { + pub fn source_range(&self) -> Range { + self.source_range.clone() + } + + pub fn checked(&self) -> bool { + self.checked + } +} + +type CheckboxClickedCallback = Arc>; #[derive(Clone)] pub struct RenderContext { @@ -80,7 +95,7 @@ impl RenderContext { pub fn with_checkbox_clicked_callback( mut self, - callback: impl Fn(bool, Range, &mut Window, &mut App) + 'static, + callback: impl Fn(&CheckboxClickedEvent, &mut Window, &mut App) + 'static, ) -> Self { self.checkbox_clicked_callback = Some(Arc::new(Box::new(callback))); self @@ -229,7 +244,14 @@ fn render_markdown_list_item( }; if window.modifiers().secondary() { - callback(checked, range.clone(), window, cx); + callback( + &CheckboxClickedEvent { + checked, + source_range: range.clone(), + }, + window, + cx, + ); } } }) diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 62c32b4161..64cd1cc0cb 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -58,7 +58,7 @@ fn get_max_tokens(name: &str) -> u64 { "magistral" => 40000, "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r" | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder" - | "devstral" => 128000, + | "devstral" | "gpt-oss" => 128000, _ => DEFAULT_TOKENS, } .clamp(1, MAXIMUM_TOKENS) diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs index 098907870b..00f2d5fc8b 100644 --- a/crates/onboarding/src/ai_setup_page.rs +++ b/crates/onboarding/src/ai_setup_page.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use ai_onboarding::{AiUpsellCard, SignInStatus}; -use client::UserStore; +use ai_onboarding::AiUpsellCard; +use client::{Client, UserStore}; use fs::Fs; use gpui::{ Action, AnyView, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, @@ -12,8 +12,8 @@ use 
language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageMod use project::DisableAiSettings; use settings::{Settings, update_settings_file}; use ui::{ - Badge, ButtonLike, Divider, Modal, ModalFooter, ModalHeader, Section, SwitchField, ToggleState, - prelude::*, tooltip_container, + Badge, ButtonLike, Divider, KeyBinding, Modal, ModalFooter, ModalHeader, Section, SwitchField, + ToggleState, prelude::*, tooltip_container, }; use util::ResultExt; use workspace::{ModalView, Workspace}; @@ -88,7 +88,7 @@ fn render_privacy_card(tab_index: &mut isize, disabled: bool, cx: &mut App) -> i h_flex() .gap_2() .justify_between() - .child(Label::new("We don't train models using your data")) + .child(Label::new("Privacy is the default for Zed")) .child( h_flex().gap_1().child(privacy_badge()).child( Button::new("learn_more", "Learn More") @@ -109,7 +109,7 @@ fn render_privacy_card(tab_index: &mut isize, disabled: bool, cx: &mut App) -> i ) .child( Label::new( - "Feel confident in the security and privacy of your projects using Zed.", + "Any use or storage of your data is with your explicit, single-use, opt-in consent.", ) .size(LabelSize::Small) .color(Color::Muted), @@ -240,6 +240,7 @@ fn render_llm_provider_card( pub(crate) fn render_ai_setup_page( workspace: WeakEntity, user_store: Entity, + client: Arc, window: &mut Window, cx: &mut App, ) -> impl IntoElement { @@ -283,14 +284,16 @@ pub(crate) fn render_ai_setup_page( v_flex() .mt_2() .gap_6() - .child(AiUpsellCard { - sign_in_status: SignInStatus::SignedIn, - sign_in: Arc::new(|_, _| {}), - user_plan: user_store.read(cx).plan(), - tab_index: Some({ + .child({ + let mut ai_upsell_card = + AiUpsellCard::new(client, &user_store, user_store.read(cx).plan(), cx); + + ai_upsell_card.tab_index = Some({ tab_index += 1; tab_index - 1 - }), + }); + + ai_upsell_card }) .child(render_llm_provider_section( &mut tab_index, @@ -335,6 +338,10 @@ impl AiConfigurationModal { selected_provider, } } + + fn cancel(&mut self, _: &menu::Cancel, cx: &mut Context) { + cx.emit(DismissEvent); + } } impl ModalView for AiConfigurationModal {} @@ -348,11 +355,15 @@ impl Focusable for AiConfigurationModal { } impl Render for AiConfigurationModal { - fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() + .key_context("OnboardingAiConfigurationModal") .w(rems(34.)) .elevation_3(cx) .track_focus(&self.focus_handle) + .on_action( + cx.listener(|this, _: &menu::Cancel, _window, cx| this.cancel(&menu::Cancel, cx)), + ) .child( Modal::new("onboarding-ai-setup-modal", None) .header( @@ -367,18 +378,19 @@ impl Render for AiConfigurationModal { .section(Section::new().child(self.configuration_view.clone())) .footer( ModalFooter::new().end_slot( - h_flex() - .gap_1() - .child( - Button::new("onboarding-closing-cancel", "Cancel") - .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + Button::new("ai-onb-modal-Done", "Done") + .key_binding( + KeyBinding::for_action_in( + &menu::Cancel, + &self.focus_handle.clone(), + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), ) - .child(Button::new("save-btn", "Done").on_click(cx.listener( - |_, _, window, cx| { - window.dispatch_action(menu::Confirm.boxed_clone(), cx); - cx.emit(DismissEvent); - }, - ))), + .on_click(cx.listener(|this, _event, _window, cx| { + this.cancel(&menu::Cancel, cx) + })), ), ), ) @@ -395,7 +407,7 @@ impl AiPrivacyTooltip { impl Render for AiPrivacyTooltip { fn 
render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - const DESCRIPTION: &'static str = "One of Zed's most important principles is transparency. This is why we are and value open-source so much. And it wouldn't be any different with AI."; + const DESCRIPTION: &'static str = "We believe in opt-in data sharing as the default for building AI products, rather than opt-out. We'll only use or store your data if you affirmatively send it to us. "; tooltip_container(window, cx, move |this, _, _| { this.child( @@ -406,7 +418,7 @@ impl Render for AiPrivacyTooltip { .size(IconSize::Small) .color(Color::Muted), ) - .child(Label::new("Privacy Principle")), + .child(Label::new("Privacy First")), ) .child( div().max_w_64().child( diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs index a4e4028051..a19a21fddf 100644 --- a/crates/onboarding/src/basics_page.rs +++ b/crates/onboarding/src/basics_page.rs @@ -201,12 +201,15 @@ fn render_telemetry_section(tab_index: &mut isize, cx: &App) -> impl IntoElement let fs = ::global(cx); v_flex() + .pt_6() .gap_4() + .border_t_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) .child(Label::new("Telemetry").size(LabelSize::Large)) .child(SwitchField::new( "onboarding-telemetry-metrics", "Help Improve Zed", - Some("Sending anonymous usage data helps us build the right features and create the best experience.".into()), + Some("Anonymous usage data helps us build the right features and improve your experience.".into()), if TelemetrySettings::get_global(cx).metrics { ui::ToggleState::Selected } else { @@ -294,7 +297,7 @@ fn render_base_keymap_section(tab_index: &mut isize, cx: &mut App) -> impl IntoE ToggleButtonWithIcon::new("Emacs", IconName::EditorEmacs, |_, _, cx| { write_keymap_base(BaseKeymap::Emacs, cx); }), - ToggleButtonWithIcon::new("Cursor (Beta)", IconName::EditorCursor, |_, _, cx| { + ToggleButtonWithIcon::new("Cursor", IconName::EditorCursor, |_, _, cx| { write_keymap_base(BaseKeymap::Cursor, cx); }), ], @@ -326,10 +329,7 @@ fn render_vim_mode_switch(tab_index: &mut isize, cx: &mut App) -> impl IntoEleme SwitchField::new( "onboarding-vim-mode", "Vim Mode", - Some( - "Coming from Neovim? Zed's first-class implementation of Vim Mode has got your back." - .into(), - ), + Some("Coming from Neovim? Use our first-class implementation of Vim Mode.".into()), toggle_state, { let fs = ::global(cx); diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs index a8f0265b6b..8b4293db0d 100644 --- a/crates/onboarding/src/editing_page.rs +++ b/crates/onboarding/src/editing_page.rs @@ -584,11 +584,15 @@ fn render_popular_settings_section( window: &mut Window, cx: &mut App, ) -> impl IntoElement { - const LIGATURE_TOOLTIP: &'static str = "Ligatures are when a font creates a special character out of combining two characters into one. For example, with ligatures turned on, =/= would become ≠."; + const LIGATURE_TOOLTIP: &'static str = + "Font ligatures combine two characters into one. 
For example, turning =/= into ≠."; v_flex() - .gap_5() - .child(Label::new("Popular Settings").size(LabelSize::Large).mt_8()) + .pt_6() + .gap_4() + .border_t_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .child(Label::new("Popular Settings").size(LabelSize::Large)) .child(render_font_customization_section(tab_index, window, cx)) .child( SwitchField::new( @@ -683,7 +687,10 @@ fn render_popular_settings_section( [ ToggleButtonSimple::new("Auto", |_, _, cx| { write_show_mini_map(ShowMinimap::Auto, cx); - }), + }) + .tooltip(Tooltip::text( + "Show the minimap if the editor's scrollbar is visible.", + )), ToggleButtonSimple::new("Always", |_, _, cx| { write_show_mini_map(ShowMinimap::Always, cx); }), @@ -707,7 +714,7 @@ fn render_popular_settings_section( pub(crate) fn render_editing_page(window: &mut Window, cx: &mut App) -> impl IntoElement { let mut tab_index = 0; v_flex() - .gap_4() + .gap_6() .child(render_import_settings_section(&mut tab_index, cx)) .child(render_popular_settings_section(&mut tab_index, window, cx)) } diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index c4d2b6847c..98f61df97b 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -77,6 +77,8 @@ actions!( ActivateAISetupPage, /// Finish the onboarding process. Finish, + /// Sign in while in the onboarding flow. + SignIn ] ); @@ -376,6 +378,7 @@ impl Onboarding { cx, ) .map(|kb| kb.size(rems_from_px(12.))); + if ai_setup_page { this.child( ButtonLike::new("start_building") @@ -387,14 +390,7 @@ impl Onboarding { .w_full() .justify_between() .child(Label::new("Start Building")) - .child(keybinding.map_or_else( - || { - Icon::new(IconName::Check) - .size(IconSize::Small) - .into_any_element() - }, - IntoElement::into_any_element, - )), + .children(keybinding), ) .on_click(|_, window, cx| { window.dispatch_action(Finish.boxed_clone(), cx); @@ -409,11 +405,10 @@ impl Onboarding { .ml_1() .w_full() .justify_between() - .child(Label::new("Skip All")) - .child(keybinding.map_or_else( - || gpui::Empty.into_any_element(), - IntoElement::into_any_element, - )), + .child( + Label::new("Skip All").color(Color::Muted), + ) + .children(keybinding), ) .on_click(|_, window, cx| { window.dispatch_action(Finish.boxed_clone(), cx); @@ -435,23 +430,39 @@ impl Onboarding { Button::new("sign_in", "Sign In") .full_width() .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .key_binding( + KeyBinding::for_action_in(&SignIn, &self.focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))), + ) .on_click(|_, window, cx| { - let client = Client::global(cx); - window - .spawn(cx, async move |cx| { - client - .sign_in_with_optional_connect(true, &cx) - .await - .notify_async_err(cx); - }) - .detach(); + window.dispatch_action(SignIn.boxed_clone(), cx); }) .into_any_element() }, ) } + fn on_finish(_: &Finish, _: &mut Window, cx: &mut App) { + go_to_welcome_page(cx); + } + + fn handle_sign_in(_: &SignIn, window: &mut Window, cx: &mut App) { + let client = Client::global(cx); + + window + .spawn(cx, async move |cx| { + client + .sign_in_with_optional_connect(true, &cx) + .await + .notify_async_err(cx); + }) + .detach(); + } + fn render_page(&mut self, window: &mut Window, cx: &mut Context) -> AnyElement { + let client = Client::global(cx); + match self.selected_page { SelectedPage::Basics => crate::basics_page::render_basics_page(cx).into_any_element(), SelectedPage::Editing => { @@ -460,16 +471,13 @@ impl Onboarding { 
SelectedPage::AiSetup => crate::ai_setup_page::render_ai_setup_page( self.workspace.clone(), self.user_store.clone(), + client, window, cx, ) .into_any_element(), } } - - fn on_finish(_: &Finish, _: &mut Window, cx: &mut App) { - go_to_welcome_page(cx); - } } impl Render for Onboarding { @@ -486,6 +494,7 @@ impl Render for Onboarding { .size_full() .bg(cx.theme().colors().editor_background) .on_action(Self::on_finish) + .on_action(Self::handle_sign_in) .on_action(cx.listener(|this, _: &ActivateBasicsPage, _, cx| { this.set_page(SelectedPage::Basics, cx); })) diff --git a/crates/onboarding/src/theme_preview.rs b/crates/onboarding/src/theme_preview.rs index 53631be1c9..81eb14ec4b 100644 --- a/crates/onboarding/src/theme_preview.rs +++ b/crates/onboarding/src/theme_preview.rs @@ -299,6 +299,18 @@ impl RenderOnce for ThemePreviewTile { } impl Component for ThemePreviewTile { + fn scope() -> ComponentScope { + ComponentScope::Onboarding + } + + fn name() -> &'static str { + "Theme Preview Tile" + } + + fn sort_name() -> &'static str { + "Theme Preview Tile" + } + fn description() -> Option<&'static str> { Some(Self::DOCS) } diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 12a5cf52d2..4697d71ed3 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -74,6 +74,12 @@ pub enum Model { O3, #[serde(rename = "o4-mini")] O4Mini, + #[serde(rename = "gpt-5")] + Five, + #[serde(rename = "gpt-5-mini")] + FiveMini, + #[serde(rename = "gpt-5-nano")] + FiveNano, #[serde(rename = "custom")] Custom { @@ -105,6 +111,9 @@ impl Model { "o3-mini" => Ok(Self::O3Mini), "o3" => Ok(Self::O3), "o4-mini" => Ok(Self::O4Mini), + "gpt-5" => Ok(Self::Five), + "gpt-5-mini" => Ok(Self::FiveMini), + "gpt-5-nano" => Ok(Self::FiveNano), invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"), } } @@ -123,6 +132,9 @@ impl Model { Self::O3Mini => "o3-mini", Self::O3 => "o3", Self::O4Mini => "o4-mini", + Self::Five => "gpt-5", + Self::FiveMini => "gpt-5-mini", + Self::FiveNano => "gpt-5-nano", Self::Custom { name, .. } => name, } } @@ -141,6 +153,9 @@ impl Model { Self::O3Mini => "o3-mini", Self::O3 => "o3", Self::O4Mini => "o4-mini", + Self::Five => "gpt-5", + Self::FiveMini => "gpt-5-mini", + Self::FiveNano => "gpt-5-nano", Self::Custom { name, display_name, .. } => display_name.as_ref().unwrap_or(name), @@ -161,6 +176,9 @@ impl Model { Self::O3Mini => 200_000, Self::O3 => 200_000, Self::O4Mini => 200_000, + Self::Five => 272_000, + Self::FiveMini => 272_000, + Self::FiveNano => 272_000, Self::Custom { max_tokens, .. } => *max_tokens, } } @@ -182,6 +200,9 @@ impl Model { Self::O3Mini => Some(100_000), Self::O3 => Some(100_000), Self::O4Mini => Some(100_000), + Self::Five => Some(128_000), + Self::FiveMini => Some(128_000), + Self::FiveNano => Some(128_000), } } @@ -197,7 +218,10 @@ impl Model { | Self::FourOmniMini | Self::FourPointOne | Self::FourPointOneMini - | Self::FourPointOneNano => true, + | Self::FourPointOneNano + | Self::Five + | Self::FiveMini + | Self::FiveNano => true, Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. 
} => false, } } diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index ad96670db9..1cda3897ec 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -2570,11 +2570,11 @@ impl OutlinePanel { .on_click({ let clicked_entry = rendered_entry.clone(); cx.listener(move |outline_panel, event: &gpui::ClickEvent, window, cx| { - if event.down.button == MouseButton::Right || event.down.first_mouse { + if event.is_right_click() || event.first_focus() { return; } - let change_focus = event.down.click_count > 1; + let change_focus = event.click_count() > 1; outline_panel.toggle_expanded(&clicked_entry, window, cx); outline_panel.scroll_editor_to_entry( diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 692bdd5bd7..34af5fed02 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -292,7 +292,7 @@ impl Picker { window: &mut Window, cx: &mut Context, ) -> Self { - let element_container = Self::create_element_container(container, cx); + let element_container = Self::create_element_container(container); let scrollbar_state = match &element_container { ElementContainer::UniformList(scroll_handle) => { ScrollbarState::new(scroll_handle.clone()) @@ -323,31 +323,13 @@ impl Picker { this } - fn create_element_container( - container: ContainerKind, - cx: &mut Context, - ) -> ElementContainer { + fn create_element_container(container: ContainerKind) -> ElementContainer { match container { ContainerKind::UniformList => { ElementContainer::UniformList(UniformListScrollHandle::new()) } ContainerKind::List => { - let entity = cx.entity().downgrade(); - ElementContainer::List(ListState::new( - 0, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - entity - .upgrade() - .map(|entity| { - entity.update(cx, |this, cx| { - this.render_element(window, cx, ix).into_any_element() - }) - }) - .unwrap_or_else(|| div().into_any_element()) - }, - )) + ElementContainer::List(ListState::new(0, gpui::ListAlignment::Top, px(1000.))) } } } @@ -786,11 +768,16 @@ impl Picker { .py_1() .track_scroll(scroll_handle.clone()) .into_any_element(), - ElementContainer::List(state) => list(state.clone()) - .with_sizing_behavior(sizing_behavior) - .flex_grow() - .py_2() - .into_any_element(), + ElementContainer::List(state) => list( + state.clone(), + cx.processor(|this, ix, window, cx| { + this.render_element(window, cx, ix).into_any_element() + }), + ) + .with_sizing_behavior(sizing_behavior) + .flex_grow() + .py_2() + .into_any_element(), } } diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index f60a7becf7..d9c28df497 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -56,7 +56,7 @@ use std::{ }; use task::TaskContext; use text::{PointUtf16, ToPointUtf16}; -use util::{ResultExt, maybe}; +use util::{ResultExt, debug_panic, maybe}; use worktree::Worktree; #[derive(Debug, Copy, Clone, Hash, PartialEq, PartialOrd, Ord, Eq)] @@ -141,7 +141,10 @@ pub struct DataBreakpointState { } pub enum SessionState { - Building(Option>>), + /// Represents a session that is building/initializing. + /// Even if a session doesn't have a pre-build task, this state + /// is used to run all the async tasks required to start the session. + Booting(Option>>), Running(RunningMode), } @@ -574,7 +577,7 @@ impl SessionState { { match self { SessionState::Running(debug_adapter_client) =>
debug_adapter_client.request(request), - SessionState::Building(_) => Task::ready(Err(anyhow!( + SessionState::Booting(_) => Task::ready(Err(anyhow!( "no adapter running to send request: {request:?}" ))), } @@ -583,7 +586,7 @@ impl SessionState { /// Did this debug session stop at least once? pub(crate) fn has_ever_stopped(&self) -> bool { match self { - SessionState::Building(_) => false, + SessionState::Booting(_) => false, SessionState::Running(running_mode) => running_mode.has_ever_stopped, } } @@ -839,7 +842,7 @@ impl Session { .detach(); let this = Self { - mode: SessionState::Building(None), + mode: SessionState::Booting(None), id: session_id, child_session_ids: HashSet::default(), parent_session, @@ -879,7 +882,7 @@ impl Session { pub fn worktree(&self) -> Option> { match &self.mode { - SessionState::Building(_) => None, + SessionState::Booting(_) => None, SessionState::Running(local_mode) => local_mode.worktree.upgrade(), } } @@ -940,14 +943,12 @@ impl Session { .await?; this.update(cx, |this, cx| { match &mut this.mode { - SessionState::Building(task) if task.is_some() => { + SessionState::Booting(task) if task.is_some() => { task.take().unwrap().detach_and_log_err(cx); } - _ => { - debug_assert!( - this.parent_session.is_some(), - "Booting a root debug session without a boot task" - ); + SessionState::Booting(_) => {} + SessionState::Running(_) => { + debug_panic!("Attempting to boot a session that is already running"); } }; this.mode = SessionState::Running(mode); @@ -1043,7 +1044,7 @@ impl Session { pub fn binary(&self) -> Option<&DebugAdapterBinary> { match &self.mode { - SessionState::Building(_) => None, + SessionState::Booting(_) => None, SessionState::Running(running_mode) => Some(&running_mode.binary), } } @@ -1089,26 +1090,26 @@ impl Session { pub fn is_started(&self) -> bool { match &self.mode { - SessionState::Building(_) => false, + SessionState::Booting(_) => false, SessionState::Running(running) => running.is_started, } } pub fn is_building(&self) -> bool { - matches!(self.mode, SessionState::Building(_)) + matches!(self.mode, SessionState::Booting(_)) } pub fn as_running_mut(&mut self) -> Option<&mut RunningMode> { match &mut self.mode { SessionState::Running(local_mode) => Some(local_mode), - SessionState::Building(_) => None, + SessionState::Booting(_) => None, } } pub fn as_running(&self) -> Option<&RunningMode> { match &self.mode { SessionState::Running(local_mode) => Some(local_mode), - SessionState::Building(_) => None, + SessionState::Booting(_) => None, } } @@ -1302,7 +1303,7 @@ impl Session { SessionState::Running(local_mode) => { local_mode.initialize_sequence(&self.capabilities, initialize_rx, dap_store, cx) } - SessionState::Building(_) => { + SessionState::Booting(_) => { Task::ready(Err(anyhow!("cannot initialize, still building"))) } } @@ -1339,7 +1340,7 @@ impl Session { }) .detach(); } - SessionState::Building(_) => {} + SessionState::Booting(_) => {} } } @@ -2145,7 +2146,7 @@ impl Session { ) } } - SessionState::Building(build_task) => { + SessionState::Booting(build_task) => { build_task.take(); Task::ready(Some(())) } @@ -2199,7 +2200,7 @@ impl Session { pub fn adapter_client(&self) -> Option> { match self.mode { SessionState::Running(ref local) => Some(local.client.clone()), - SessionState::Building(_) => None, + SessionState::Booting(_) => None, } } diff --git a/crates/project/src/git_store/git_traversal.rs b/crates/project/src/git_store/git_traversal.rs index 777042cb02..bbcffe046d 100644 --- a/crates/project/src/git_store/git_traversal.rs 
+++ b/crates/project/src/git_store/git_traversal.rs @@ -110,11 +110,7 @@ impl<'a> GitTraversal<'a> { } pub fn advance(&mut self) -> bool { - self.advance_by(1) - } - - pub fn advance_by(&mut self, count: usize) -> bool { - let found = self.traversal.advance_by(count); + let found = self.traversal.advance_by(1); self.synchronize_statuses(false); found } diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index f8e69e2185..c458b6b300 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -3284,6 +3284,16 @@ impl InlayHints { }) .unwrap_or(false) } + + pub fn check_capabilities(capabilities: &ServerCapabilities) -> bool { + capabilities + .inlay_hint_provider + .as_ref() + .is_some_and(|inlay_hint_provider| match inlay_hint_provider { + lsp::OneOf::Left(enabled) => *enabled, + lsp::OneOf::Right(_) => true, + }) + } } #[async_trait(?Send)] @@ -3297,17 +3307,7 @@ impl LspCommand for InlayHints { } fn check_capabilities(&self, capabilities: AdapterServerCapabilities) -> bool { - let Some(inlay_hint_provider) = &capabilities.server_capabilities.inlay_hint_provider - else { - return false; - }; - match inlay_hint_provider { - lsp::OneOf::Left(enabled) => *enabled, - lsp::OneOf::Right(inlay_hint_capabilities) => match inlay_hint_capabilities { - lsp::InlayHintServerCapabilities::Options(_) => true, - lsp::InlayHintServerCapabilities::RegistrationOptions(_) => false, - }, - } + Self::check_capabilities(&capabilities.server_capabilities) } fn to_lsp( diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 6d448a6fea..b88cf42ff5 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -140,6 +140,20 @@ impl FormatTrigger { } } +#[derive(Debug)] +pub struct DocumentDiagnosticsUpdate<'a, D> { + pub diagnostics: D, + pub result_id: Option, + pub server_id: LanguageServerId, + pub disk_based_sources: Cow<'a, [String]>, +} + +pub struct DocumentDiagnostics { + diagnostics: Vec>>, + document_abs_path: PathBuf, + version: Option, +} + pub struct LocalLspStore { weak: WeakEntity, worktree_store: Entity, @@ -503,12 +517,16 @@ impl LocalLspStore { adapter.process_diagnostics(&mut params, server_id, buffer); } - this.merge_diagnostics( - server_id, - params, - None, + this.merge_lsp_diagnostics( DiagnosticSourceKind::Pushed, - &adapter.disk_based_diagnostic_sources, + vec![DocumentDiagnosticsUpdate { + server_id, + diagnostics: params, + result_id: None, + disk_based_sources: Cow::Borrowed( + &adapter.disk_based_diagnostic_sources, + ), + }], |_, diagnostic, cx| match diagnostic.source_kind { DiagnosticSourceKind::Other | DiagnosticSourceKind::Pushed => { adapter.retain_old_diagnostic(diagnostic, cx) @@ -3610,8 +3628,8 @@ pub enum LspStoreEvent { RefreshInlayHints, RefreshCodeLens, DiagnosticsUpdated { - language_server_id: LanguageServerId, - path: ProjectPath, + server_id: LanguageServerId, + paths: Vec, }, DiskBasedDiagnosticsStarted { language_server_id: LanguageServerId, @@ -3671,7 +3689,6 @@ impl LspStore { client.add_entity_request_handler(Self::handle_apply_additional_edits_for_completion); client.add_entity_request_handler(Self::handle_register_buffer_with_language_servers); client.add_entity_request_handler(Self::handle_rename_project_entry); - client.add_entity_request_handler(Self::handle_language_server_id_for_name); client.add_entity_request_handler(Self::handle_pull_workspace_diagnostics); client.add_entity_request_handler(Self::handle_lsp_command::); 
client.add_entity_request_handler(Self::handle_lsp_command::); @@ -4441,17 +4458,24 @@ impl LspStore { pub(crate) fn send_diagnostic_summaries(&self, worktree: &mut Worktree) { if let Some((client, downstream_project_id)) = self.downstream_client.clone() { - if let Some(summaries) = self.diagnostic_summaries.get(&worktree.id()) { - for (path, summaries) in summaries { - for (&server_id, summary) in summaries { - client - .send(proto::UpdateDiagnosticSummary { - project_id: downstream_project_id, - worktree_id: worktree.id().to_proto(), - summary: Some(summary.to_proto(server_id, path)), - }) - .log_err(); - } + if let Some(diagnostic_summaries) = self.diagnostic_summaries.get(&worktree.id()) { + let mut summaries = + diagnostic_summaries + .into_iter() + .flat_map(|(path, summaries)| { + summaries + .into_iter() + .map(|(server_id, summary)| summary.to_proto(*server_id, path)) + }); + if let Some(summary) = summaries.next() { + client + .send(proto::UpdateDiagnosticSummary { + project_id: downstream_project_id, + worktree_id: worktree.id().to_proto(), + summary: Some(summary), + more_summaries: summaries.collect(), + }) + .log_err(); } } } @@ -6565,7 +6589,7 @@ impl LspStore { &mut self, buffer: Entity, cx: &mut Context, - ) -> Task>> { + ) -> Task>>> { let buffer_id = buffer.read(cx).remote_id(); if let Some((client, upstream_project_id)) = self.upstream_client() { @@ -6576,7 +6600,7 @@ impl LspStore { }, cx, ) { - return Task::ready(Ok(Vec::new())); + return Task::ready(Ok(None)); } let request_task = client.request(proto::MultiLspQuery { buffer_id: buffer_id.to_proto(), @@ -6594,7 +6618,7 @@ impl LspStore { )), }); cx.background_spawn(async move { - Ok(request_task + let _proto_responses = request_task .await? .responses .into_iter() @@ -6607,8 +6631,11 @@ impl LspStore { None } }) - .flat_map(GetDocumentDiagnostics::diagnostics_from_proto) - .collect()) + .collect::>(); + // Proto requests cause the diagnostics to be pulled from the language server(s) on the local side, + // and the buffer state is then updated with the received diagnostics, which are later propagated to the client. + // Do not attempt to further process the dummy responses here. + Ok(None) }) } else { let server_ids = buffer.update(cx, |buffer, cx| { @@ -6636,7 +6663,7 @@ impl LspStore { for diagnostics in join_all(pull_diagnostics).await { responses.extend(diagnostics?); } - Ok(responses) + Ok(Some(responses)) }) } } @@ -6702,75 +6729,93 @@ impl LspStore { buffer: Entity, cx: &mut Context, ) -> Task> { - let buffer_id = buffer.read(cx).remote_id(); let diagnostics = self.pull_diagnostics(buffer, cx); cx.spawn(async move |lsp_store, cx| { - let diagnostics = diagnostics.await.context("pulling diagnostics")?; + let Some(diagnostics) = diagnostics.await.context("pulling diagnostics")?
else { + return Ok(()); + }; lsp_store.update(cx, |lsp_store, cx| { if lsp_store.as_local().is_none() { return; } - for diagnostics_set in diagnostics { - let LspPullDiagnostics::Response { - server_id, - uri, - diagnostics, - } = diagnostics_set - else { - continue; - }; - - let adapter = lsp_store.language_server_adapter_for_id(server_id); - let disk_based_sources = adapter - .as_ref() - .map(|adapter| adapter.disk_based_diagnostic_sources.as_slice()) - .unwrap_or(&[]); - match diagnostics { - PulledDiagnostics::Unchanged { result_id } => { - lsp_store - .merge_diagnostics( - server_id, - lsp::PublishDiagnosticsParams { - uri: uri.clone(), - diagnostics: Vec::new(), - version: None, - }, - Some(result_id), - DiagnosticSourceKind::Pulled, - disk_based_sources, - |_, _, _| true, - cx, - ) - .log_err(); - } - PulledDiagnostics::Changed { + let mut unchanged_buffers = HashSet::default(); + let mut changed_buffers = HashSet::default(); + let server_diagnostics_updates = diagnostics + .into_iter() + .filter_map(|diagnostics_set| match diagnostics_set { + LspPullDiagnostics::Response { + server_id, + uri, diagnostics, - result_id, - } => { - lsp_store - .merge_diagnostics( + } => Some((server_id, uri, diagnostics)), + LspPullDiagnostics::Default => None, + }) + .fold( + HashMap::default(), + |mut acc, (server_id, uri, diagnostics)| { + let (result_id, diagnostics) = match diagnostics { + PulledDiagnostics::Unchanged { result_id } => { + unchanged_buffers.insert(uri.clone()); + (Some(result_id), Vec::new()) + } + PulledDiagnostics::Changed { + result_id, + diagnostics, + } => { + changed_buffers.insert(uri.clone()); + (result_id, diagnostics) + } + }; + let disk_based_sources = Cow::Owned( + lsp_store + .language_server_adapter_for_id(server_id) + .as_ref() + .map(|adapter| adapter.disk_based_diagnostic_sources.as_slice()) + .unwrap_or(&[]) + .to_vec(), + ); + acc.entry(server_id).or_insert_with(Vec::new).push( + DocumentDiagnosticsUpdate { server_id, - lsp::PublishDiagnosticsParams { - uri: uri.clone(), + diagnostics: lsp::PublishDiagnosticsParams { + uri, diagnostics, version: None, }, result_id, - DiagnosticSourceKind::Pulled, disk_based_sources, - |buffer, old_diagnostic, _| match old_diagnostic.source_kind { - DiagnosticSourceKind::Pulled => { - buffer.remote_id() != buffer_id - } - DiagnosticSourceKind::Other - | DiagnosticSourceKind::Pushed => true, - }, - cx, - ) - .log_err(); - } - } + }, + ); + acc + }, + ); + + for diagnostic_updates in server_diagnostics_updates.into_values() { + lsp_store + .merge_lsp_diagnostics( + DiagnosticSourceKind::Pulled, + diagnostic_updates, + |buffer, old_diagnostic, cx| { + File::from_dyn(buffer.file()) + .and_then(|file| { + let abs_path = file.as_local()?.abs_path(cx); + lsp::Url::from_file_path(abs_path).ok() + }) + .is_none_or(|buffer_uri| { + unchanged_buffers.contains(&buffer_uri) + || match old_diagnostic.source_kind { + DiagnosticSourceKind::Pulled => { + !changed_buffers.contains(&buffer_uri) + } + DiagnosticSourceKind::Other + | DiagnosticSourceKind::Pushed => true, + } + }) + }, + cx, + ) + .log_err(); } }) }) @@ -7792,88 +7837,135 @@ impl LspStore { cx: &mut Context, ) -> anyhow::Result<()> { self.merge_diagnostic_entries( - server_id, - abs_path, - result_id, - version, - diagnostics, + vec![DocumentDiagnosticsUpdate { + diagnostics: DocumentDiagnostics { + diagnostics, + document_abs_path: abs_path, + version, + }, + result_id, + server_id, + disk_based_sources: Cow::Borrowed(&[]), + }], |_, _, _| false, cx, )?; Ok(()) } - pub fn 
merge_diagnostic_entries( + pub fn merge_diagnostic_entries<'a>( &mut self, - server_id: LanguageServerId, - abs_path: PathBuf, - result_id: Option, - version: Option, - mut diagnostics: Vec>>, - filter: impl Fn(&Buffer, &Diagnostic, &App) -> bool + Clone, + diagnostic_updates: Vec>, + merge: impl Fn(&Buffer, &Diagnostic, &App) -> bool + Clone, cx: &mut Context, ) -> anyhow::Result<()> { - let Some((worktree, relative_path)) = - self.worktree_store.read(cx).find_worktree(&abs_path, cx) - else { - log::warn!("skipping diagnostics update, no worktree found for path {abs_path:?}"); - return Ok(()); - }; + let mut diagnostics_summary = None::; + let mut updated_diagnostics_paths = HashMap::default(); + for mut update in diagnostic_updates { + let abs_path = &update.diagnostics.document_abs_path; + let server_id = update.server_id; + let Some((worktree, relative_path)) = + self.worktree_store.read(cx).find_worktree(abs_path, cx) + else { + log::warn!("skipping diagnostics update, no worktree found for path {abs_path:?}"); + return Ok(()); + }; - let project_path = ProjectPath { - worktree_id: worktree.read(cx).id(), - path: relative_path.into(), - }; + let worktree_id = worktree.read(cx).id(); + let project_path = ProjectPath { + worktree_id, + path: relative_path.into(), + }; - if let Some(buffer_handle) = self.buffer_store.read(cx).get_by_path(&project_path) { - let snapshot = buffer_handle.read(cx).snapshot(); - let buffer = buffer_handle.read(cx); - let reused_diagnostics = buffer - .get_diagnostics(server_id) - .into_iter() - .flat_map(|diag| { - diag.iter() - .filter(|v| filter(buffer, &v.diagnostic, cx)) - .map(|v| { - let start = Unclipped(v.range.start.to_point_utf16(&snapshot)); - let end = Unclipped(v.range.end.to_point_utf16(&snapshot)); - DiagnosticEntry { - range: start..end, - diagnostic: v.diagnostic.clone(), - } - }) - }) - .collect::>(); + if let Some(buffer_handle) = self.buffer_store.read(cx).get_by_path(&project_path) { + let snapshot = buffer_handle.read(cx).snapshot(); + let buffer = buffer_handle.read(cx); + let reused_diagnostics = buffer + .get_diagnostics(server_id) + .into_iter() + .flat_map(|diag| { + diag.iter() + .filter(|v| merge(buffer, &v.diagnostic, cx)) + .map(|v| { + let start = Unclipped(v.range.start.to_point_utf16(&snapshot)); + let end = Unclipped(v.range.end.to_point_utf16(&snapshot)); + DiagnosticEntry { + range: start..end, + diagnostic: v.diagnostic.clone(), + } + }) + }) + .collect::>(); - self.as_local_mut() - .context("cannot merge diagnostics on a remote LspStore")? - .update_buffer_diagnostics( - &buffer_handle, + self.as_local_mut() + .context("cannot merge diagnostics on a remote LspStore")? 
+ .update_buffer_diagnostics( + &buffer_handle, + server_id, + update.result_id, + update.diagnostics.version, + update.diagnostics.diagnostics.clone(), + reused_diagnostics.clone(), + cx, + )?; + + update.diagnostics.diagnostics.extend(reused_diagnostics); + } + + let updated = worktree.update(cx, |worktree, cx| { + self.update_worktree_diagnostics( + worktree.id(), server_id, - result_id, - version, - diagnostics.clone(), - reused_diagnostics.clone(), + project_path.path.clone(), + update.diagnostics.diagnostics, cx, - )?; - - diagnostics.extend(reused_diagnostics); + ) + })?; + match updated { + ControlFlow::Continue(new_summary) => { + if let Some((project_id, new_summary)) = new_summary { + match &mut diagnostics_summary { + Some(diagnostics_summary) => { + diagnostics_summary + .more_summaries + .push(proto::DiagnosticSummary { + path: project_path.path.as_ref().to_proto(), + language_server_id: server_id.0 as u64, + error_count: new_summary.error_count, + warning_count: new_summary.warning_count, + }) + } + None => { + diagnostics_summary = Some(proto::UpdateDiagnosticSummary { + project_id: project_id, + worktree_id: worktree_id.to_proto(), + summary: Some(proto::DiagnosticSummary { + path: project_path.path.as_ref().to_proto(), + language_server_id: server_id.0 as u64, + error_count: new_summary.error_count, + warning_count: new_summary.warning_count, + }), + more_summaries: Vec::new(), + }) + } + } + } + updated_diagnostics_paths + .entry(server_id) + .or_insert_with(Vec::new) + .push(project_path); + } + ControlFlow::Break(()) => {} + } } - let updated = worktree.update(cx, |worktree, cx| { - self.update_worktree_diagnostics( - worktree.id(), - server_id, - project_path.path.clone(), - diagnostics, - cx, - ) - })?; - if updated { - cx.emit(LspStoreEvent::DiagnosticsUpdated { - language_server_id: server_id, - path: project_path, - }) + if let Some((diagnostics_summary, (downstream_client, _))) = + diagnostics_summary.zip(self.downstream_client.as_ref()) + { + downstream_client.send(diagnostics_summary).log_err(); + } + for (server_id, paths) in updated_diagnostics_paths { + cx.emit(LspStoreEvent::DiagnosticsUpdated { server_id, paths }); } Ok(()) } @@ -7882,10 +7974,10 @@ impl LspStore { &mut self, worktree_id: WorktreeId, server_id: LanguageServerId, - worktree_path: Arc, + path_in_worktree: Arc, diagnostics: Vec>>, _: &mut Context, - ) -> Result { + ) -> Result>> { let local = match &mut self.mode { LspStoreMode::Local(local_lsp_store) => local_lsp_store, _ => anyhow::bail!("update_worktree_diagnostics called on remote"), @@ -7893,7 +7985,9 @@ impl LspStore { let summaries_for_tree = self.diagnostic_summaries.entry(worktree_id).or_default(); let diagnostics_for_tree = local.diagnostics.entry(worktree_id).or_default(); - let summaries_by_server_id = summaries_for_tree.entry(worktree_path.clone()).or_default(); + let summaries_by_server_id = summaries_for_tree + .entry(path_in_worktree.clone()) + .or_default(); let old_summary = summaries_by_server_id .remove(&server_id) @@ -7901,18 +7995,19 @@ impl LspStore { let new_summary = DiagnosticSummary::new(&diagnostics); if new_summary.is_empty() { - if let Some(diagnostics_by_server_id) = diagnostics_for_tree.get_mut(&worktree_path) { + if let Some(diagnostics_by_server_id) = diagnostics_for_tree.get_mut(&path_in_worktree) + { if let Ok(ix) = diagnostics_by_server_id.binary_search_by_key(&server_id, |e| e.0) { diagnostics_by_server_id.remove(ix); } if diagnostics_by_server_id.is_empty() { - 
diagnostics_for_tree.remove(&worktree_path); + diagnostics_for_tree.remove(&path_in_worktree); } } } else { summaries_by_server_id.insert(server_id, new_summary); let diagnostics_by_server_id = diagnostics_for_tree - .entry(worktree_path.clone()) + .entry(path_in_worktree.clone()) .or_default(); match diagnostics_by_server_id.binary_search_by_key(&server_id, |e| e.0) { Ok(ix) => { @@ -7925,23 +8020,22 @@ impl LspStore { } if !old_summary.is_empty() || !new_summary.is_empty() { - if let Some((downstream_client, project_id)) = &self.downstream_client { - downstream_client - .send(proto::UpdateDiagnosticSummary { - project_id: *project_id, - worktree_id: worktree_id.to_proto(), - summary: Some(proto::DiagnosticSummary { - path: worktree_path.to_proto(), - language_server_id: server_id.0 as u64, - error_count: new_summary.error_count as u32, - warning_count: new_summary.warning_count as u32, - }), - }) - .log_err(); + if let Some((_, project_id)) = &self.downstream_client { + Ok(ControlFlow::Continue(Some(( + *project_id, + proto::DiagnosticSummary { + path: path_in_worktree.to_proto(), + language_server_id: server_id.0 as u64, + error_count: new_summary.error_count as u32, + warning_count: new_summary.warning_count as u32, + }, + )))) + } else { + Ok(ControlFlow::Continue(None)) } + } else { + Ok(ControlFlow::Break(())) } - - Ok(!old_summary.is_empty() || !new_summary.is_empty()) } pub fn open_buffer_for_symbol( @@ -8745,34 +8839,6 @@ impl LspStore { Ok(proto::Ack {}) } - async fn handle_language_server_id_for_name( - lsp_store: Entity, - envelope: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result { - let name = &envelope.payload.name; - let buffer_id = BufferId::new(envelope.payload.buffer_id)?; - lsp_store - .update(&mut cx, |lsp_store, cx| { - let buffer = lsp_store.buffer_store.read(cx).get_existing(buffer_id)?; - let server_id = buffer.update(cx, |buffer, cx| { - lsp_store - .language_servers_for_local_buffer(buffer, cx) - .find_map(|(adapter, server)| { - if adapter.name.0.as_ref() == name { - Some(server.server_id()) - } else { - None - } - }) - }); - Ok(server_id) - })? 
- .map(|server_id| proto::LanguageServerIdForNameResponse { - server_id: server_id.map(|id| id.to_proto()), - }) - } - async fn handle_rename_project_entry( this: Entity, envelope: TypedEnvelope, @@ -8822,23 +8888,30 @@ impl LspStore { envelope: TypedEnvelope, mut cx: AsyncApp, ) -> Result<()> { - this.update(&mut cx, |this, cx| { + this.update(&mut cx, |lsp_store, cx| { let worktree_id = WorktreeId::from_proto(envelope.payload.worktree_id); - if let Some(message) = envelope.payload.summary { + let mut updated_diagnostics_paths = HashMap::default(); + let mut diagnostics_summary = None::; + for message_summary in envelope + .payload + .summary + .into_iter() + .chain(envelope.payload.more_summaries) + { let project_path = ProjectPath { worktree_id, - path: Arc::::from_proto(message.path), + path: Arc::::from_proto(message_summary.path), }; let path = project_path.path.clone(); - let server_id = LanguageServerId(message.language_server_id as usize); + let server_id = LanguageServerId(message_summary.language_server_id as usize); let summary = DiagnosticSummary { - error_count: message.error_count as usize, - warning_count: message.warning_count as usize, + error_count: message_summary.error_count as usize, + warning_count: message_summary.warning_count as usize, }; if summary.is_empty() { if let Some(worktree_summaries) = - this.diagnostic_summaries.get_mut(&worktree_id) + lsp_store.diagnostic_summaries.get_mut(&worktree_id) { if let Some(summaries) = worktree_summaries.get_mut(&path) { summaries.remove(&server_id); @@ -8848,31 +8921,55 @@ impl LspStore { } } } else { - this.diagnostic_summaries + lsp_store + .diagnostic_summaries .entry(worktree_id) .or_default() .entry(path) .or_default() .insert(server_id, summary); } - if let Some((downstream_client, project_id)) = &this.downstream_client { - downstream_client - .send(proto::UpdateDiagnosticSummary { - project_id: *project_id, - worktree_id: worktree_id.to_proto(), - summary: Some(proto::DiagnosticSummary { - path: project_path.path.as_ref().to_proto(), - language_server_id: server_id.0 as u64, - error_count: summary.error_count as u32, - warning_count: summary.warning_count as u32, - }), - }) - .log_err(); + + if let Some((_, project_id)) = &lsp_store.downstream_client { + match &mut diagnostics_summary { + Some(diagnostics_summary) => { + diagnostics_summary + .more_summaries + .push(proto::DiagnosticSummary { + path: project_path.path.as_ref().to_proto(), + language_server_id: server_id.0 as u64, + error_count: summary.error_count as u32, + warning_count: summary.warning_count as u32, + }) + } + None => { + diagnostics_summary = Some(proto::UpdateDiagnosticSummary { + project_id: *project_id, + worktree_id: worktree_id.to_proto(), + summary: Some(proto::DiagnosticSummary { + path: project_path.path.as_ref().to_proto(), + language_server_id: server_id.0 as u64, + error_count: summary.error_count as u32, + warning_count: summary.warning_count as u32, + }), + more_summaries: Vec::new(), + }) + } + } } - cx.emit(LspStoreEvent::DiagnosticsUpdated { - language_server_id: LanguageServerId(message.language_server_id as usize), - path: project_path, - }); + updated_diagnostics_paths + .entry(server_id) + .or_insert_with(Vec::new) + .push(project_path); + } + + if let Some((diagnostics_summary, (downstream_client, _))) = + diagnostics_summary.zip(lsp_store.downstream_client.as_ref()) + { + downstream_client.send(diagnostics_summary).log_err(); + } + for (server_id, paths) in updated_diagnostics_paths { + 
cx.emit(LspStoreEvent::DiagnosticsUpdated { server_id, paths }); } Ok(()) })? @@ -10390,6 +10487,7 @@ impl LspStore { error_count: 0, warning_count: 0, }), + more_summaries: Vec::new(), }) .log_err(); } @@ -10678,52 +10776,80 @@ impl LspStore { ) } + #[cfg(any(test, feature = "test-support"))] pub fn update_diagnostics( &mut self, - language_server_id: LanguageServerId, - params: lsp::PublishDiagnosticsParams, + server_id: LanguageServerId, + diagnostics: lsp::PublishDiagnosticsParams, result_id: Option, source_kind: DiagnosticSourceKind, disk_based_sources: &[String], cx: &mut Context, ) -> Result<()> { - self.merge_diagnostics( - language_server_id, - params, - result_id, + self.merge_lsp_diagnostics( source_kind, - disk_based_sources, + vec![DocumentDiagnosticsUpdate { + diagnostics, + result_id, + server_id, + disk_based_sources: Cow::Borrowed(disk_based_sources), + }], |_, _, _| false, cx, ) } - pub fn merge_diagnostics( + pub fn merge_lsp_diagnostics( &mut self, - language_server_id: LanguageServerId, - mut params: lsp::PublishDiagnosticsParams, - result_id: Option, source_kind: DiagnosticSourceKind, - disk_based_sources: &[String], - filter: impl Fn(&Buffer, &Diagnostic, &App) -> bool + Clone, + lsp_diagnostics: Vec>, + merge: impl Fn(&Buffer, &Diagnostic, &App) -> bool + Clone, cx: &mut Context, ) -> Result<()> { anyhow::ensure!(self.mode.is_local(), "called update_diagnostics on remote"); - let abs_path = params - .uri - .to_file_path() - .map_err(|()| anyhow!("URI is not a file"))?; + let updates = lsp_diagnostics + .into_iter() + .filter_map(|update| { + let abs_path = update.diagnostics.uri.to_file_path().ok()?; + Some(DocumentDiagnosticsUpdate { + diagnostics: self.lsp_to_document_diagnostics( + abs_path, + source_kind, + update.server_id, + update.diagnostics, + &update.disk_based_sources, + ), + result_id: update.result_id, + server_id: update.server_id, + disk_based_sources: update.disk_based_sources, + }) + }) + .collect(); + self.merge_diagnostic_entries(updates, merge, cx)?; + Ok(()) + } + + fn lsp_to_document_diagnostics( + &mut self, + document_abs_path: PathBuf, + source_kind: DiagnosticSourceKind, + server_id: LanguageServerId, + mut lsp_diagnostics: lsp::PublishDiagnosticsParams, + disk_based_sources: &[String], + ) -> DocumentDiagnostics { let mut diagnostics = Vec::default(); let mut primary_diagnostic_group_ids = HashMap::default(); let mut sources_by_group_id = HashMap::default(); let mut supporting_diagnostics = HashMap::default(); - let adapter = self.language_server_adapter_for_id(language_server_id); + let adapter = self.language_server_adapter_for_id(server_id); // Ensure that primary diagnostics are always the most severe - params.diagnostics.sort_by_key(|item| item.severity); + lsp_diagnostics + .diagnostics + .sort_by_key(|item| item.severity); - for diagnostic in ¶ms.diagnostics { + for diagnostic in &lsp_diagnostics.diagnostics { let source = diagnostic.source.as_ref(); let range = range_from_lsp(diagnostic.range); let is_supporting = diagnostic @@ -10745,7 +10871,7 @@ impl LspStore { .map_or(false, |tags| tags.contains(&DiagnosticTag::UNNECESSARY)); let underline = self - .language_server_adapter_for_id(language_server_id) + .language_server_adapter_for_id(server_id) .map_or(true, |adapter| adapter.underline_diagnostic(diagnostic)); if is_supporting { @@ -10787,7 +10913,7 @@ impl LspStore { }); if let Some(infos) = &diagnostic.related_information { for info in infos { - if info.location.uri == params.uri && !info.message.is_empty() { + if 
info.location.uri == lsp_diagnostics.uri && !info.message.is_empty() { let range = range_from_lsp(info.location.range); diagnostics.push(DiagnosticEntry { range, @@ -10835,16 +10961,11 @@ impl LspStore { } } - self.merge_diagnostic_entries( - language_server_id, - abs_path, - result_id, - params.version, + DocumentDiagnostics { diagnostics, - filter, - cx, - )?; - Ok(()) + document_abs_path, + version: lsp_diagnostics.version, + } } fn insert_newly_running_language_server( @@ -11600,67 +11721,84 @@ impl LspStore { ) { let workspace_diagnostics = GetDocumentDiagnostics::deserialize_workspace_diagnostics_report(report, server_id); - for workspace_diagnostics in workspace_diagnostics { - let LspPullDiagnostics::Response { - server_id, - uri, - diagnostics, - } = workspace_diagnostics.diagnostics - else { - continue; - }; - - let adapter = self.language_server_adapter_for_id(server_id); - let disk_based_sources = adapter - .as_ref() - .map(|adapter| adapter.disk_based_diagnostic_sources.as_slice()) - .unwrap_or(&[]); - - match diagnostics { - PulledDiagnostics::Unchanged { result_id } => { - self.merge_diagnostics( + let mut unchanged_buffers = HashSet::default(); + let mut changed_buffers = HashSet::default(); + let workspace_diagnostics_updates = workspace_diagnostics + .into_iter() + .filter_map( + |workspace_diagnostics| match workspace_diagnostics.diagnostics { + LspPullDiagnostics::Response { server_id, - lsp::PublishDiagnosticsParams { - uri: uri.clone(), - diagnostics: Vec::new(), - version: None, - }, - Some(result_id), - DiagnosticSourceKind::Pulled, - disk_based_sources, - |_, _, _| true, - cx, - ) - .log_err(); - } - PulledDiagnostics::Changed { - diagnostics, - result_id, - } => { - self.merge_diagnostics( - server_id, - lsp::PublishDiagnosticsParams { - uri: uri.clone(), + uri, + diagnostics, + } => Some((server_id, uri, diagnostics, workspace_diagnostics.version)), + LspPullDiagnostics::Default => None, + }, + ) + .fold( + HashMap::default(), + |mut acc, (server_id, uri, diagnostics, version)| { + let (result_id, diagnostics) = match diagnostics { + PulledDiagnostics::Unchanged { result_id } => { + unchanged_buffers.insert(uri.clone()); + (Some(result_id), Vec::new()) + } + PulledDiagnostics::Changed { + result_id, diagnostics, - version: workspace_diagnostics.version, - }, - result_id, - DiagnosticSourceKind::Pulled, - disk_based_sources, - |buffer, old_diagnostic, cx| match old_diagnostic.source_kind { - DiagnosticSourceKind::Pulled => { - let buffer_url = File::from_dyn(buffer.file()) - .map(|f| f.abs_path(cx)) - .and_then(|abs_path| file_path_to_lsp_url(&abs_path).ok()); - buffer_url.is_none_or(|buffer_url| buffer_url != uri) - } - DiagnosticSourceKind::Other | DiagnosticSourceKind::Pushed => true, - }, - cx, - ) - .log_err(); - } - } + } => { + changed_buffers.insert(uri.clone()); + (result_id, diagnostics) + } + }; + let disk_based_sources = Cow::Owned( + self.language_server_adapter_for_id(server_id) + .as_ref() + .map(|adapter| adapter.disk_based_diagnostic_sources.as_slice()) + .unwrap_or(&[]) + .to_vec(), + ); + acc.entry(server_id) + .or_insert_with(Vec::new) + .push(DocumentDiagnosticsUpdate { + server_id, + diagnostics: lsp::PublishDiagnosticsParams { + uri, + diagnostics, + version, + }, + result_id, + disk_based_sources, + }); + acc + }, + ); + + for diagnostic_updates in workspace_diagnostics_updates.into_values() { + self.merge_lsp_diagnostics( + DiagnosticSourceKind::Pulled, + diagnostic_updates, + |buffer, old_diagnostic, cx| { + File::from_dyn(buffer.file()) 
+ .and_then(|file| { + let abs_path = file.as_local()?.abs_path(cx); + lsp::Url::from_file_path(abs_path).ok() + }) + .is_none_or(|buffer_uri| { + unchanged_buffers.contains(&buffer_uri) + || match old_diagnostic.source_kind { + DiagnosticSourceKind::Pulled => { + !changed_buffers.contains(&buffer_uri) + } + DiagnosticSourceKind::Other | DiagnosticSourceKind::Pushed => { + true + } + } + }) + }, + cx, + ) + .log_err(); } } } diff --git a/crates/project/src/lsp_store/clangd_ext.rs b/crates/project/src/lsp_store/clangd_ext.rs index 6a09bb99b4..274b1b8980 100644 --- a/crates/project/src/lsp_store/clangd_ext.rs +++ b/crates/project/src/lsp_store/clangd_ext.rs @@ -1,14 +1,14 @@ -use std::sync::Arc; +use std::{borrow::Cow, sync::Arc}; use ::serde::{Deserialize, Serialize}; use gpui::WeakEntity; use language::{CachedLspAdapter, Diagnostic, DiagnosticSourceKind}; -use lsp::LanguageServer; +use lsp::{LanguageServer, LanguageServerName}; use util::ResultExt as _; -use crate::LspStore; +use crate::{LspStore, lsp_store::DocumentDiagnosticsUpdate}; -pub const CLANGD_SERVER_NAME: &str = "clangd"; +pub const CLANGD_SERVER_NAME: LanguageServerName = LanguageServerName::new_static("clangd"); const INACTIVE_REGION_MESSAGE: &str = "inactive region"; const INACTIVE_DIAGNOSTIC_SEVERITY: lsp::DiagnosticSeverity = lsp::DiagnosticSeverity::INFORMATION; @@ -34,7 +34,7 @@ pub fn is_inactive_region(diag: &Diagnostic) -> bool { && diag .source .as_ref() - .is_some_and(|v| v == CLANGD_SERVER_NAME) + .is_some_and(|v| v == &CLANGD_SERVER_NAME.0) } pub fn is_lsp_inactive_region(diag: &lsp::Diagnostic) -> bool { @@ -43,7 +43,7 @@ pub fn is_lsp_inactive_region(diag: &lsp::Diagnostic) -> bool { && diag .source .as_ref() - .is_some_and(|v| v == CLANGD_SERVER_NAME) + .is_some_and(|v| v == &CLANGD_SERVER_NAME.0) } pub fn register_notifications( @@ -51,7 +51,7 @@ pub fn register_notifications( language_server: &LanguageServer, adapter: Arc, ) { - if language_server.name().0 != CLANGD_SERVER_NAME { + if language_server.name() != CLANGD_SERVER_NAME { return; } let server_id = language_server.server_id(); @@ -81,12 +81,16 @@ pub fn register_notifications( version: params.text_document.version, diagnostics, }; - this.merge_diagnostics( - server_id, - mapped_diagnostics, - None, + this.merge_lsp_diagnostics( DiagnosticSourceKind::Pushed, - &adapter.disk_based_diagnostic_sources, + vec![DocumentDiagnosticsUpdate { + server_id, + diagnostics: mapped_diagnostics, + result_id: None, + disk_based_sources: Cow::Borrowed( + &adapter.disk_based_diagnostic_sources, + ), + }], |_, diag, _| !is_inactive_region(diag), cx, ) diff --git a/crates/project/src/lsp_store/rust_analyzer_ext.rs b/crates/project/src/lsp_store/rust_analyzer_ext.rs index d78715d385..6c425717a8 100644 --- a/crates/project/src/lsp_store/rust_analyzer_ext.rs +++ b/crates/project/src/lsp_store/rust_analyzer_ext.rs @@ -2,12 +2,12 @@ use ::serde::{Deserialize, Serialize}; use anyhow::Context as _; use gpui::{App, Entity, Task, WeakEntity}; use language::ServerHealth; -use lsp::LanguageServer; +use lsp::{LanguageServer, LanguageServerName}; use rpc::proto; use crate::{LspStore, LspStoreEvent, Project, ProjectPath, lsp_store}; -pub const RUST_ANALYZER_NAME: &str = "rust-analyzer"; +pub const RUST_ANALYZER_NAME: LanguageServerName = LanguageServerName::new_static("rust-analyzer"); pub const CARGO_DIAGNOSTICS_SOURCE_NAME: &str = "rustc"; /// Experimental: Informs the end user about the state of the server @@ -97,13 +97,9 @@ pub fn cancel_flycheck( cx.spawn(async move |cx| { let 
buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? - .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? else { return Ok(()); }; @@ -148,13 +144,9 @@ pub fn run_flycheck( cx.spawn(async move |cx| { let buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? - .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? else { return Ok(()); }; @@ -204,13 +196,9 @@ pub fn clear_flycheck( cx.spawn(async move |cx| { let buffer = buffer.await?; - let Some(rust_analyzer_server) = project - .update(cx, |project, cx| { - buffer.update(cx, |buffer, cx| { - project.language_server_id_for_name(buffer, RUST_ANALYZER_NAME, cx) - }) - })? - .await + let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + })? else { return Ok(()); }; diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 398e8bde87..b3a9e6fdf5 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -73,11 +73,10 @@ use gpui::{ App, AppContext, AsyncApp, BorrowAppContext, Context, Entity, EventEmitter, Hsla, SharedString, Task, WeakEntity, Window, }; -use itertools::Itertools; use language::{ - Buffer, BufferEvent, Capability, CodeLabel, CursorShape, DiagnosticSourceKind, Language, - LanguageName, LanguageRegistry, PointUtf16, ToOffset, ToPointUtf16, Toolchain, ToolchainList, - Transaction, Unclipped, language_settings::InlayHintKind, proto::split_operations, + Buffer, BufferEvent, Capability, CodeLabel, CursorShape, Language, LanguageName, + LanguageRegistry, PointUtf16, ToOffset, ToPointUtf16, Toolchain, ToolchainList, Transaction, + Unclipped, language_settings::InlayHintKind, proto::split_operations, }; use lsp::{ CodeActionKind, CompletionContext, CompletionItemKind, DocumentHighlightKind, InsertTextMode, @@ -113,7 +112,7 @@ use std::{ use task_store::TaskStore; use terminals::Terminals; -use text::{Anchor, BufferId, OffsetRangeExt, Point}; +use text::{Anchor, BufferId, OffsetRangeExt, Point, Rope}; use toolchain_store::EmptyToolchainStore; use util::{ ResultExt as _, @@ -306,7 +305,7 @@ pub enum Event { language_server_id: LanguageServerId, }, DiagnosticsUpdated { - path: ProjectPath, + paths: Vec, language_server_id: LanguageServerId, }, RemoteIdChanged(Option), @@ -668,10 +667,10 @@ pub enum ResolveState { } impl InlayHint { - pub fn text(&self) -> String { + pub fn text(&self) -> Rope { match &self.label { - InlayHintLabel::String(s) => s.to_owned(), - InlayHintLabel::LabelParts(parts) => parts.iter().map(|part| &part.value).join(""), + InlayHintLabel::String(s) => Rope::from(s), + InlayHintLabel::LabelParts(parts) => parts.iter().map(|part| &*part.value).collect(), } } } @@ -2896,18 +2895,17 @@ impl Project { cx: &mut Context, ) { match event { - LspStoreEvent::DiagnosticsUpdated { - language_server_id, - path, - } => cx.emit(Event::DiagnosticsUpdated { - path: path.clone(), - language_server_id: *language_server_id, - }), - 
LspStoreEvent::LanguageServerAdded(language_server_id, name, worktree_id) => cx.emit( - Event::LanguageServerAdded(*language_server_id, name.clone(), *worktree_id), + LspStoreEvent::DiagnosticsUpdated { server_id, paths } => { + cx.emit(Event::DiagnosticsUpdated { + paths: paths.clone(), + language_server_id: *server_id, + }) + } + LspStoreEvent::LanguageServerAdded(server_id, name, worktree_id) => cx.emit( + Event::LanguageServerAdded(*server_id, name.clone(), *worktree_id), ), - LspStoreEvent::LanguageServerRemoved(language_server_id) => { - cx.emit(Event::LanguageServerRemoved(*language_server_id)) + LspStoreEvent::LanguageServerRemoved(server_id) => { + cx.emit(Event::LanguageServerRemoved(*server_id)) } LspStoreEvent::LanguageServerLog(server_id, log_type, string) => cx.emit( Event::LanguageServerLog(*server_id, log_type.clone(), string.clone()), @@ -3830,27 +3828,6 @@ impl Project { }) } - pub fn update_diagnostics( - &mut self, - language_server_id: LanguageServerId, - source_kind: DiagnosticSourceKind, - result_id: Option, - params: lsp::PublishDiagnosticsParams, - disk_based_sources: &[String], - cx: &mut Context, - ) -> Result<(), anyhow::Error> { - self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store.update_diagnostics( - language_server_id, - params, - result_id, - source_kind, - disk_based_sources, - cx, - ) - }) - } - pub fn search(&mut self, query: SearchQuery, cx: &mut Context) -> Receiver { let (result_tx, result_rx) = smol::channel::unbounded(); @@ -5002,63 +4979,53 @@ impl Project { } pub fn any_language_server_supports_inlay_hints(&self, buffer: &Buffer, cx: &mut App) -> bool { - self.lsp_store.update(cx, |this, cx| { - this.language_servers_for_local_buffer(buffer, cx) - .any( - |(_, server)| match server.capabilities().inlay_hint_provider { - Some(lsp::OneOf::Left(enabled)) => enabled, - Some(lsp::OneOf::Right(_)) => true, - None => false, - }, - ) + let Some(language) = buffer.language().cloned() else { + return false; + }; + self.lsp_store.update(cx, |lsp_store, _| { + let relevant_language_servers = lsp_store + .languages + .lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::>(); + lsp_store + .language_server_statuses() + .filter_map(|(server_id, server_status)| { + relevant_language_servers + .contains(&server_status.name) + .then_some(server_id) + }) + .filter_map(|server_id| lsp_store.lsp_server_capabilities.get(&server_id)) + .any(InlayHints::check_capabilities) }) } pub fn language_server_id_for_name( &self, buffer: &Buffer, - name: &str, - cx: &mut App, - ) -> Task> { - if self.is_local() { - Task::ready(self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store - .language_servers_for_local_buffer(buffer, cx) - .find_map(|(adapter, server)| { - if adapter.name.0 == name { - Some(server.server_id()) - } else { - None - } - }) - })) - } else if let Some(project_id) = self.remote_id() { - let request = self.client.request(proto::LanguageServerIdForName { - project_id, - buffer_id: buffer.remote_id().to_proto(), - name: name.to_string(), - }); - cx.background_spawn(async move { - let response = request.await.log_err()?; - response.server_id.map(LanguageServerId::from_proto) - }) - } else if let Some(ssh_client) = self.ssh_client.as_ref() { - let request = - ssh_client - .read(cx) - .proto_client() - .request(proto::LanguageServerIdForName { - project_id: SSH_PROJECT_ID, - buffer_id: buffer.remote_id().to_proto(), - name: name.to_string(), - }); - cx.background_spawn(async move { - let response = 
request.await.log_err()?; - response.server_id.map(LanguageServerId::from_proto) - }) - } else { - Task::ready(None) + name: &LanguageServerName, + cx: &App, + ) -> Option { + let language = buffer.language()?; + let relevant_language_servers = self + .languages + .lsp_adapters(&language.name()) + .into_iter() + .map(|lsp_adapter| lsp_adapter.name()) + .collect::>(); + if !relevant_language_servers.contains(name) { + return None; } + self.language_server_statuses(cx) + .filter(|(_, server_status)| relevant_language_servers.contains(&server_status.name)) + .find_map(|(server_id, server_status)| { + if &server_status.name == name { + Some(server_id) + } else { + None + } + }) } pub fn has_language_servers_for(&self, buffer: &Buffer, cx: &mut App) -> bool { diff --git a/crates/project/src/project_settings.rs b/crates/project/src/project_settings.rs index 20be7fef85..12e3aa88ad 100644 --- a/crates/project/src/project_settings.rs +++ b/crates/project/src/project_settings.rs @@ -431,10 +431,9 @@ impl GitSettings { pub fn inline_blame_delay(&self) -> Option { match self.inline_blame { - Some(InlineBlameSettings { - delay_ms: Some(delay_ms), - .. - }) if delay_ms > 0 => Some(Duration::from_millis(delay_ms)), + Some(InlineBlameSettings { delay_ms, .. }) if delay_ms > 0 => { + Some(Duration::from_millis(delay_ms)) + } _ => None, } } @@ -470,7 +469,7 @@ pub enum GitGutterSetting { Hide, } -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] pub struct InlineBlameSettings { /// Whether or not to show git blame data inline in @@ -483,11 +482,19 @@ pub struct InlineBlameSettings { /// after a delay once the cursor stops moving. /// /// Default: 0 - pub delay_ms: Option, + #[serde(default)] + pub delay_ms: u64, + /// The amount of padding between the end of the source line and the start + /// of the inline blame in units of columns. + /// + /// Default: 7 + #[serde(default = "default_inline_blame_padding")] + pub padding: u32, /// The minimum column number to show the inline blame information at /// /// Default: 0 - pub min_column: Option, + #[serde(default)] + pub min_column: u32, /// Whether to show commit summary as part of the inline blame. 
/// /// Default: false @@ -495,6 +502,22 @@ pub struct InlineBlameSettings { pub show_commit_summary: bool, } +fn default_inline_blame_padding() -> u32 { + 7 +} + +impl Default for InlineBlameSettings { + fn default() -> Self { + Self { + enabled: true, + delay_ms: 0, + padding: default_inline_blame_padding(), + min_column: 0, + show_commit_summary: false, + } + } +} + #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] pub struct BinarySettings { pub path: Option, diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index 75ebc8339a..cb3c9efe60 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -18,9 +18,10 @@ use git::{ use git2::RepositoryInitOptions; use gpui::{App, BackgroundExecutor, SemanticVersion, UpdateGlobal}; use http_client::Url; +use itertools::Itertools; use language::{ - Diagnostic, DiagnosticEntry, DiagnosticSet, DiskState, FakeLspAdapter, LanguageConfig, - LanguageMatcher, LanguageName, LineEnding, OffsetRangeExt, Point, ToPoint, + Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, DiskState, FakeLspAdapter, + LanguageConfig, LanguageMatcher, LanguageName, LineEnding, OffsetRangeExt, Point, ToPoint, language_settings::{AllLanguageSettings, LanguageSettingsContent, language_settings}, tree_sitter_rust, tree_sitter_typescript, }; @@ -1618,7 +1619,7 @@ async fn test_disk_based_diagnostics_progress(cx: &mut gpui::TestAppContext) { events.next().await.unwrap(), Event::DiagnosticsUpdated { language_server_id: LanguageServerId(0), - path: (worktree_id, Path::new("a.rs")).into() + paths: vec![(worktree_id, Path::new("a.rs")).into()], } ); @@ -1666,7 +1667,7 @@ async fn test_disk_based_diagnostics_progress(cx: &mut gpui::TestAppContext) { events.next().await.unwrap(), Event::DiagnosticsUpdated { language_server_id: LanguageServerId(0), - path: (worktree_id, Path::new("a.rs")).into() + paths: vec![(worktree_id, Path::new("a.rs")).into()], } ); diff --git a/crates/project/src/terminals.rs b/crates/project/src/terminals.rs index 973d4e8811..41d8c4b2fd 100644 --- a/crates/project/src/terminals.rs +++ b/crates/project/src/terminals.rs @@ -1,7 +1,7 @@ use crate::{Project, ProjectPath}; use anyhow::{Context as _, Result}; use collections::HashMap; -use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, Task, WeakEntity}; +use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity}; use itertools::Itertools; use language::LanguageName; use remote::ssh_session::SshArgs; @@ -98,7 +98,6 @@ impl Project { pub fn create_terminal( &mut self, kind: TerminalKind, - window: AnyWindowHandle, cx: &mut Context, ) -> Task>> { let path: Option> = match &kind { @@ -134,7 +133,7 @@ impl Project { None }; project.update(cx, |project, cx| { - project.create_terminal_with_venv(kind, python_venv_directory, window, cx) + project.create_terminal_with_venv(kind, python_venv_directory, cx) })? 
}) } @@ -209,7 +208,6 @@ impl Project { &mut self, kind: TerminalKind, python_venv_directory: Option, - window: AnyWindowHandle, cx: &mut Context, ) -> Result> { let this = &mut *self; @@ -396,7 +394,7 @@ impl Project { settings.alternate_scroll, settings.max_scroll_history_lines, is_ssh_terminal, - window, + cx.entity_id().as_u64(), completion_tx, cx, ) diff --git a/crates/project_panel/Cargo.toml b/crates/project_panel/Cargo.toml index ce5fec0b13..6ad3c4c2cd 100644 --- a/crates/project_panel/Cargo.toml +++ b/crates/project_panel/Cargo.toml @@ -19,6 +19,7 @@ command_palette_hooks.workspace = true db.workspace = true editor.workspace = true file_icons.workspace = true +git_ui.workspace = true indexmap.workspace = true git.workspace = true gpui.workspace = true @@ -40,6 +41,7 @@ worktree.workspace = true workspace.workspace = true language.workspace = true zed_actions.workspace = true +telemetry.workspace = true workspace-hack.workspace = true [dev-dependencies] diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 28d97f85fa..8b32c1157d 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -16,6 +16,7 @@ use editor::{ }; use file_icons::FileIcons; use git::status::GitSummary; +use git_ui::file_diff_view::FileDiffView; use gpui::{ Action, AnyElement, App, ArcCow, AsyncWindowContext, Bounds, ClipboardItem, Context, CursorStyle, DismissEvent, Div, DragMoveEvent, Entity, EventEmitter, ExternalPaths, @@ -43,7 +44,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore, update_settings_file}; use smallvec::SmallVec; -use std::any::TypeId; +use std::{any::TypeId, time::Instant}; use std::{ cell::OnceCell, cmp, @@ -73,6 +74,12 @@ use zed_actions::OpenRecent; const PROJECT_PANEL_KEY: &str = "ProjectPanel"; const NEW_ENTRY_ID: ProjectEntryId = ProjectEntryId::MAX; +struct VisibleEntriesForWorktree { + worktree_id: WorktreeId, + entries: Vec, + index: OnceCell>>, +} + pub struct ProjectPanel { project: Entity, fs: Arc, @@ -81,7 +88,7 @@ pub struct ProjectPanel { // An update loop that keeps incrementing/decrementing scroll offset while there is a dragged entry that's // hovered over the start/end of a list. hover_scroll_task: Option>, - visible_entries: Vec<(WorktreeId, Vec, OnceCell>>)>, + visible_entries: Vec, /// Maps from leaf project entry ID to the currently selected ancestor. /// Relevant only for auto-fold dirs, where a single project panel entry may actually consist of several /// project entries (and all non-leaf nodes are guaranteed to be directories). @@ -93,7 +100,7 @@ pub struct ProjectPanel { unfolded_dir_ids: HashSet, // Currently selected leaf entry (see auto-folding for a definition of that) in a file tree selection: Option, - marked_entries: BTreeSet, + marked_entries: Vec, context_menu: Option<(Entity, Point, Subscription)>, edit_state: Option, filename_editor: Entity, @@ -115,6 +122,7 @@ pub struct ProjectPanel { hover_expand_task: Option>, previous_drag_position: Option>, sticky_items_count: usize, + last_reported_update: Instant, } struct DragTargetEntry { @@ -284,6 +292,8 @@ actions!( SelectNextDirectory, /// Selects the previous directory. SelectPrevDirectory, + /// Opens a diff view to compare two marked files. 
+ CompareMarkedFiles, ] ); @@ -381,7 +391,7 @@ struct DraggedProjectEntryView { selection: SelectedEntry, details: EntryDetails, click_offset: Point, - selections: Arc>, + selections: Arc<[SelectedEntry]>, } struct ItemColors { @@ -447,7 +457,15 @@ impl ProjectPanel { } } project::Event::ActiveEntryChanged(None) => { - this.marked_entries.clear(); + let is_active_item_file_diff_view = this + .workspace + .upgrade() + .and_then(|ws| ws.read(cx).active_item(cx)) + .map(|item| item.act_as_type(TypeId::of::(), cx).is_some()) + .unwrap_or(false); + if !is_active_item_file_diff_view { + this.marked_entries.clear(); + } } project::Event::RevealInProjectPanel(entry_id) => { if let Some(()) = this @@ -625,6 +643,7 @@ impl ProjectPanel { hover_expand_task: None, previous_drag_position: None, sticky_items_count: 0, + last_reported_update: Instant::now(), }; this.update_visible_entries(None, cx); @@ -681,7 +700,7 @@ impl ProjectPanel { project_panel.update(cx, |project_panel, _| { let entry = SelectedEntry { worktree_id, entry_id }; project_panel.marked_entries.clear(); - project_panel.marked_entries.insert(entry); + project_panel.marked_entries.push(entry); project_panel.selection = Some(entry); }); if !focus_opened_item { @@ -894,6 +913,7 @@ impl ProjectPanel { let should_hide_rename = is_root && (cfg!(target_os = "windows") || (settings.hide_root && visible_worktrees_count == 1)); + let should_show_compare = !is_dir && self.file_abs_paths_to_diff(cx).is_some(); let context_menu = ContextMenu::build(window, cx, |menu, _, _| { menu.context(self.focus_handle.clone()).map(|menu| { @@ -925,6 +945,10 @@ impl ProjectPanel { .when(is_foldable, |menu| { menu.action("Fold Directory", Box::new(FoldDirectory)) }) + .when(should_show_compare, |menu| { + menu.separator() + .action("Compare marked files", Box::new(CompareMarkedFiles)) + }) .separator() .action("Cut", Box::new(Cut)) .action("Copy", Box::new(Copy)) @@ -1257,19 +1281,23 @@ impl ProjectPanel { entry_ix -= 1; } else if worktree_ix > 0 { worktree_ix -= 1; - entry_ix = self.visible_entries[worktree_ix].1.len() - 1; + entry_ix = self.visible_entries[worktree_ix].entries.len() - 1; } else { return; } - let (worktree_id, worktree_entries, _) = &self.visible_entries[worktree_ix]; + let VisibleEntriesForWorktree { + worktree_id, + entries, + .. 
+ } = &self.visible_entries[worktree_ix]; let selection = SelectedEntry { worktree_id: *worktree_id, - entry_id: worktree_entries[entry_ix].id, + entry_id: entries[entry_ix].id, }; self.selection = Some(selection); if window.modifiers().shift { - self.marked_entries.insert(selection); + self.marked_entries.push(selection); } self.autoscroll(cx); cx.notify(); @@ -1295,31 +1323,34 @@ impl ProjectPanel { window: &mut Window, cx: &mut Context, ) { - if let Some((_, entry)) = self.selected_entry(cx) { - if entry.is_file() { - self.split_entry(entry.id, Some(SplitDirection::Right), cx); - cx.notify(); - } else { - self.toggle_expanded(entry.id, window, cx); - } - } + self.open_split_internal(Some(SplitDirection::Right), window, cx); } fn open_split_up(&mut self, _: &OpenSplitUp, window: &mut Window, cx: &mut Context) { - if let Some((_, entry)) = self.selected_entry(cx) { - if entry.is_file() { - self.split_entry(entry.id, Some(SplitDirection::Up), cx); - cx.notify(); - } else { - self.toggle_expanded(entry.id, window, cx); - } - } + self.open_split_internal(Some(SplitDirection::Up), window, cx); } fn open_permanent(&mut self, _: &OpenPermanent, window: &mut Window, cx: &mut Context) { self.open_internal(false, true, window, cx); } + fn open_split_internal( + &mut self, + split_direction: Option, + window: &mut Window, + cx: &mut Context, + ) { + let split_direction = split_direction.or(Some(SplitDirection::Right)); + if let Some((_, entry)) = self.selected_entry(cx) { + if entry.is_file() { + self.split_entry(entry.id, split_direction, cx); + cx.notify(); + } else { + self.toggle_expanded(entry.id, window, cx); + } + } + } + fn open_internal( &mut self, allow_preview: bool, @@ -2031,7 +2062,9 @@ impl ProjectPanel { if let Some(selection) = self.selection { let (mut worktree_ix, mut entry_ix, _) = self.index_for_selection(selection).unwrap_or_default(); - if let Some((_, worktree_entries, _)) = self.visible_entries.get(worktree_ix) { + if let Some(worktree_entries) = + self.visible_entries.get(worktree_ix).map(|v| &v.entries) + { if entry_ix + 1 < worktree_entries.len() { entry_ix += 1; } else { @@ -2040,16 +2073,20 @@ impl ProjectPanel { } } - if let Some((worktree_id, worktree_entries, _)) = self.visible_entries.get(worktree_ix) + if let Some(VisibleEntriesForWorktree { + worktree_id, + entries, + .. + }) = self.visible_entries.get(worktree_ix) { - if let Some(entry) = worktree_entries.get(entry_ix) { + if let Some(entry) = entries.get(entry_ix) { let selection = SelectedEntry { worktree_id: *worktree_id, entry_id: entry.id, }; self.selection = Some(selection); if window.modifiers().shift { - self.marked_entries.insert(selection); + self.marked_entries.push(selection); } self.autoscroll(cx); @@ -2278,15 +2315,20 @@ impl ProjectPanel { } fn select_first(&mut self, _: &SelectFirst, window: &mut Window, cx: &mut Context) { - if let Some((worktree_id, visible_worktree_entries, _)) = self.visible_entries.first() { - if let Some(entry) = visible_worktree_entries.first() { + if let Some(VisibleEntriesForWorktree { + worktree_id, + entries, + .. 
+ }) = self.visible_entries.first() + { + if let Some(entry) = entries.first() { let selection = SelectedEntry { worktree_id: *worktree_id, entry_id: entry.id, }; self.selection = Some(selection); if window.modifiers().shift { - self.marked_entries.insert(selection); + self.marked_entries.push(selection); } self.autoscroll(cx); cx.notify(); @@ -2295,9 +2337,14 @@ impl ProjectPanel { } fn select_last(&mut self, _: &SelectLast, _: &mut Window, cx: &mut Context) { - if let Some((worktree_id, visible_worktree_entries, _)) = self.visible_entries.last() { + if let Some(VisibleEntriesForWorktree { + worktree_id, + entries, + .. + }) = self.visible_entries.last() + { let worktree = self.project.read(cx).worktree_for_id(*worktree_id, cx); - if let (Some(worktree), Some(entry)) = (worktree, visible_worktree_entries.last()) { + if let (Some(worktree), Some(entry)) = (worktree, entries.last()) { let worktree = worktree.read(cx); if let Some(entry) = worktree.entry_for_id(entry.id) { let selection = SelectedEntry { @@ -2614,6 +2661,43 @@ impl ProjectPanel { } } + fn file_abs_paths_to_diff(&self, cx: &Context) -> Option<(PathBuf, PathBuf)> { + let mut selections_abs_path = self + .marked_entries + .iter() + .filter_map(|entry| { + let project = self.project.read(cx); + let worktree = project.worktree_for_id(entry.worktree_id, cx)?; + let entry = worktree.read(cx).entry_for_id(entry.entry_id)?; + if !entry.is_file() { + return None; + } + worktree.read(cx).absolutize(&entry.path).ok() + }) + .rev(); + + let last_path = selections_abs_path.next()?; + let previous_to_last = selections_abs_path.next()?; + Some((previous_to_last, last_path)) + } + + fn compare_marked_files( + &mut self, + _: &CompareMarkedFiles, + window: &mut Window, + cx: &mut Context, + ) { + let selected_files = self.file_abs_paths_to_diff(cx); + if let Some((file_path1, file_path2)) = selected_files { + self.workspace + .update(cx, |workspace, cx| { + FileDiffView::open(file_path1, file_path2, workspace, window, cx) + .detach_and_log_err(cx); + }) + .ok(); + } + } + fn open_system(&mut self, _: &OpenWithSystem, _: &mut Window, cx: &mut Context) { if let Some((worktree, entry)) = self.selected_entry(cx) { let abs_path = worktree.abs_path().join(&entry.path); @@ -2949,6 +3033,7 @@ impl ProjectPanel { new_selected_entry: Option<(WorktreeId, ProjectEntryId)>, cx: &mut Context, ) { + let now = Instant::now(); let settings = ProjectPanelSettings::get_global(cx); let auto_collapse_dirs = settings.auto_fold_dirs; let hide_gitignore = settings.hide_gitignore; @@ -3146,19 +3231,23 @@ impl ProjectPanel { project::sort_worktree_entries(&mut visible_worktree_entries); - self.visible_entries - .push((worktree_id, visible_worktree_entries, OnceCell::new())); + self.visible_entries.push(VisibleEntriesForWorktree { + worktree_id, + entries: visible_worktree_entries, + index: OnceCell::new(), + }) } if let Some((project_entry_id, worktree_id, _)) = max_width_item { let mut visited_worktrees_length = 0; - let index = self.visible_entries.iter().find_map(|(id, entries, _)| { - if worktree_id == *id { - entries + let index = self.visible_entries.iter().find_map(|visible_entries| { + if worktree_id == visible_entries.worktree_id { + visible_entries + .entries .iter() .position(|entry| entry.id == project_entry_id) } else { - visited_worktrees_length += entries.len(); + visited_worktrees_length += visible_entries.entries.len(); None } }); @@ -3172,6 +3261,18 @@ impl ProjectPanel { entry_id, }); } + let elapsed = now.elapsed(); + if 
self.last_reported_update.elapsed() > Duration::from_secs(3600) { + telemetry::event!( + "Project Panel Updated", + elapsed_ms = elapsed.as_millis() as u64, + worktree_entries = self + .visible_entries + .iter() + .map(|worktree| worktree.entries.len()) + .sum::(), + ) + } } fn expand_entry( @@ -3385,15 +3486,14 @@ impl ProjectPanel { worktree_id: WorktreeId, ) -> Option<(usize, usize, usize)> { let mut total_ix = 0; - for (worktree_ix, (current_worktree_id, visible_worktree_entries, _)) in - self.visible_entries.iter().enumerate() - { - if worktree_id != *current_worktree_id { - total_ix += visible_worktree_entries.len(); + for (worktree_ix, visible) in self.visible_entries.iter().enumerate() { + if worktree_id != visible.worktree_id { + total_ix += visible.entries.len(); continue; } - return visible_worktree_entries + return visible + .entries .iter() .enumerate() .find(|(_, entry)| entry.id == entry_id) @@ -3404,12 +3504,13 @@ impl ProjectPanel { fn entry_at_index(&self, index: usize) -> Option<(WorktreeId, GitEntryRef<'_>)> { let mut offset = 0; - for (worktree_id, visible_worktree_entries, _) in &self.visible_entries { - let current_len = visible_worktree_entries.len(); + for worktree in &self.visible_entries { + let current_len = worktree.entries.len(); if index < offset + current_len { - return visible_worktree_entries + return worktree + .entries .get(index - offset) - .map(|entry| (*worktree_id, entry.to_ref())); + .map(|entry| (worktree.worktree_id, entry.to_ref())); } offset += current_len; } @@ -3430,26 +3531,23 @@ impl ProjectPanel { ), ) { let mut ix = 0; - for (_, visible_worktree_entries, entries_paths) in &self.visible_entries { + for visible in &self.visible_entries { if ix >= range.end { return; } - if ix + visible_worktree_entries.len() <= range.start { - ix += visible_worktree_entries.len(); + if ix + visible.entries.len() <= range.start { + ix += visible.entries.len(); continue; } - let end_ix = range.end.min(ix + visible_worktree_entries.len()); + let end_ix = range.end.min(ix + visible.entries.len()); let entry_range = range.start.saturating_sub(ix)..end_ix - ix; - let entries = entries_paths.get_or_init(|| { - visible_worktree_entries - .iter() - .map(|e| (e.path.clone())) - .collect() - }); + let entries = visible + .index + .get_or_init(|| visible.entries.iter().map(|e| (e.path.clone())).collect()); let base_index = ix + entry_range.start; - for (i, entry) in visible_worktree_entries[entry_range].iter().enumerate() { + for (i, entry) in visible.entries[entry_range].iter().enumerate() { let global_index = base_index + i; callback(&entry, global_index, entries, window, cx); } @@ -3465,40 +3563,41 @@ impl ProjectPanel { mut callback: impl FnMut(ProjectEntryId, EntryDetails, &mut Window, &mut Context), ) { let mut ix = 0; - for (worktree_id, visible_worktree_entries, entries_paths) in &self.visible_entries { + for visible in &self.visible_entries { if ix >= range.end { return; } - if ix + visible_worktree_entries.len() <= range.start { - ix += visible_worktree_entries.len(); + if ix + visible.entries.len() <= range.start { + ix += visible.entries.len(); continue; } - let end_ix = range.end.min(ix + visible_worktree_entries.len()); + let end_ix = range.end.min(ix + visible.entries.len()); let git_status_setting = { let settings = ProjectPanelSettings::get_global(cx); settings.git_status }; - if let Some(worktree) = self.project.read(cx).worktree_for_id(*worktree_id, cx) { + if let Some(worktree) = self + .project + .read(cx) + .worktree_for_id(visible.worktree_id, cx) 
+ { let snapshot = worktree.read(cx).snapshot(); let root_name = OsStr::new(snapshot.root_name()); let entry_range = range.start.saturating_sub(ix)..end_ix - ix; - let entries = entries_paths.get_or_init(|| { - visible_worktree_entries - .iter() - .map(|e| (e.path.clone())) - .collect() - }); - for entry in visible_worktree_entries[entry_range].iter() { + let entries = visible + .index + .get_or_init(|| visible.entries.iter().map(|e| (e.path.clone())).collect()); + for entry in visible.entries[entry_range].iter() { let status = git_status_setting .then_some(entry.git_summary) .unwrap_or_default(); let mut details = self.details_for_entry( entry, - *worktree_id, + visible.worktree_id, root_name, entries, status, @@ -3584,9 +3683,9 @@ impl ProjectPanel { let entries = self .visible_entries .iter() - .find_map(|(tree_id, entries, _)| { - if worktree_id == *tree_id { - Some(entries) + .find_map(|visible| { + if worktree_id == visible.worktree_id { + Some(&visible.entries) } else { None } @@ -3625,7 +3724,7 @@ impl ProjectPanel { let mut worktree_ids: Vec<_> = self .visible_entries .iter() - .map(|(worktree_id, _, _)| *worktree_id) + .map(|worktree| worktree.worktree_id) .collect(); let repo_snapshots = self .project @@ -3741,7 +3840,7 @@ impl ProjectPanel { let mut worktree_ids: Vec<_> = self .visible_entries .iter() - .map(|(worktree_id, _, _)| *worktree_id) + .map(|worktree| worktree.worktree_id) .collect(); let mut last_found: Option = None; @@ -3750,8 +3849,8 @@ impl ProjectPanel { let entries = self .visible_entries .iter() - .find(|(worktree_id, _, _)| *worktree_id == start.worktree_id) - .map(|(_, entries, _)| entries)?; + .find(|worktree| worktree.worktree_id == start.worktree_id) + .map(|worktree| &worktree.entries)?; let mut start_idx = entries .iter() @@ -3956,11 +4055,9 @@ impl ProjectPanel { let depth = details.depth; let worktree_id = details.worktree_id; - let selections = Arc::new(self.marked_entries.clone()); - let dragged_selection = DraggedSelection { active_selection: selection, - marked_selections: selections, + marked_selections: Arc::from(self.marked_entries.clone()), }; let bg_color = if is_marked { @@ -4131,7 +4228,7 @@ impl ProjectPanel { }); if drag_state.items().count() == 1 { this.marked_entries.clear(); - this.marked_entries.insert(drag_state.active_selection); + this.marked_entries.push(drag_state.active_selection); } this.hover_expand_task.take(); @@ -4198,66 +4295,69 @@ impl ProjectPanel { }), ) .on_click( - cx.listener(move |this, event: &gpui::ClickEvent, window, cx| { - if event.down.button == MouseButton::Right - || event.down.first_mouse + cx.listener(move |project_panel, event: &gpui::ClickEvent, window, cx| { + if event.is_right_click() || event.first_focus() || show_editor { return; } - if event.down.button == MouseButton::Left { - this.mouse_down = false; + if event.standard_click() { + project_panel.mouse_down = false; } cx.stop_propagation(); - if let Some(selection) = this.selection.filter(|_| event.modifiers().shift) { - let current_selection = this.index_for_selection(selection); + if let Some(selection) = project_panel.selection.filter(|_| event.modifiers().shift) { + let current_selection = project_panel.index_for_selection(selection); let clicked_entry = SelectedEntry { entry_id, worktree_id, }; - let target_selection = this.index_for_selection(clicked_entry); + let target_selection = project_panel.index_for_selection(clicked_entry); if let Some(((_, _, source_index), (_, _, target_index))) = current_selection.zip(target_selection) { let 
range_start = source_index.min(target_index); let range_end = source_index.max(target_index) + 1; - let mut new_selections = BTreeSet::new(); - this.for_each_visible_entry( + let mut new_selections = Vec::new(); + project_panel.for_each_visible_entry( range_start..range_end, window, cx, |entry_id, details, _, _| { - new_selections.insert(SelectedEntry { + new_selections.push(SelectedEntry { entry_id, worktree_id: details.worktree_id, }); }, ); - this.marked_entries = this - .marked_entries - .union(&new_selections) - .cloned() - .collect(); + for selection in &new_selections { + if !project_panel.marked_entries.contains(selection) { + project_panel.marked_entries.push(*selection); + } + } - this.selection = Some(clicked_entry); - this.marked_entries.insert(clicked_entry); + project_panel.selection = Some(clicked_entry); + if !project_panel.marked_entries.contains(&clicked_entry) { + project_panel.marked_entries.push(clicked_entry); + } } } else if event.modifiers().secondary() { - if event.down.click_count > 1 { - this.split_entry(entry_id,None, cx); + if event.click_count() > 1 { + project_panel.split_entry(entry_id, None, cx); } else { - this.selection = Some(selection); - if !this.marked_entries.insert(selection) { - this.marked_entries.remove(&selection); + project_panel.selection = Some(selection); + if let Some(position) = project_panel.marked_entries.iter().position(|e| *e == selection) { + project_panel.marked_entries.remove(position); + } else { + project_panel.marked_entries.push(selection); } } } else if kind.is_dir() { - this.marked_entries.clear(); + project_panel.marked_entries.clear(); if is_sticky { - if let Some((_, _, index)) = this.index_for_entry(entry_id, worktree_id) { - this.scroll_handle.scroll_to_item_with_offset(index, ScrollStrategy::Top, sticky_index.unwrap_or(0)); + if let Some((_, _, index)) = project_panel.index_for_entry(entry_id, worktree_id) { + project_panel.scroll_handle.scroll_to_item_with_offset(index, ScrollStrategy::Top, sticky_index.unwrap_or(0)); cx.notify(); // move down by 1px so that clicked item // don't count as sticky anymore @@ -4273,16 +4373,16 @@ impl ProjectPanel { } } if event.modifiers().alt { - this.toggle_expand_all(entry_id, window, cx); + project_panel.toggle_expand_all(entry_id, window, cx); } else { - this.toggle_expanded(entry_id, window, cx); + project_panel.toggle_expanded(entry_id, window, cx); } } else { let preview_tabs_enabled = PreviewTabsSettings::get_global(cx).enabled; - let click_count = event.up.click_count; + let click_count = event.click_count(); let focus_opened_item = !preview_tabs_enabled || click_count > 1; let allow_preview = preview_tabs_enabled && click_count == 1; - this.open_entry(entry_id, focus_opened_item, allow_preview, cx); + project_panel.open_entry(entry_id, focus_opened_item, allow_preview, cx); } }), ) @@ -4853,12 +4953,21 @@ impl ProjectPanel { { anyhow::bail!("can't reveal an ignored entry in the project panel"); } + let is_active_item_file_diff_view = self + .workspace + .upgrade() + .and_then(|ws| ws.read(cx).active_item(cx)) + .map(|item| item.act_as_type(TypeId::of::(), cx).is_some()) + .unwrap_or(false); + if is_active_item_file_diff_view { + return Ok(()); + } let worktree_id = worktree.id(); self.expand_entry(worktree_id, entry_id, cx); self.update_visible_entries(Some((worktree_id, entry_id)), cx); self.marked_entries.clear(); - self.marked_entries.insert(SelectedEntry { + self.marked_entries.push(SelectedEntry { worktree_id, entry_id, }); @@ -4893,7 +5002,7 @@ impl ProjectPanel { let 
(active_indent_range, depth) = { let (worktree_ix, child_offset, ix) = self.index_for_entry(entry.id, worktree.id())?; - let child_paths = &self.visible_entries[worktree_ix].1; + let child_paths = &self.visible_entries[worktree_ix].entries; let mut child_count = 0; let depth = entry.path.ancestors().count(); while let Some(entry) = child_paths.get(child_offset + child_count + 1) { @@ -4906,9 +5015,14 @@ impl ProjectPanel { let start = ix + 1; let end = start + child_count; - let (_, entries, paths) = &self.visible_entries[worktree_ix]; - let visible_worktree_entries = - paths.get_or_init(|| entries.iter().map(|e| (e.path.clone())).collect()); + let visible_worktree = &self.visible_entries[worktree_ix]; + let visible_worktree_entries = visible_worktree.index.get_or_init(|| { + visible_worktree + .entries + .iter() + .map(|e| (e.path.clone())) + .collect() + }); // Calculate the actual depth of the entry, taking into account that directories can be auto-folded. let (depth, _) = Self::calculate_depth_and_difference(entry, visible_worktree_entries); @@ -4943,10 +5057,10 @@ impl ProjectPanel { return SmallVec::new(); }; - let Some((_, visible_worktree_entries, entries_paths)) = self + let Some(visible) = self .visible_entries .iter() - .find(|(id, _, _)| *id == worktree_id) + .find(|worktree| worktree.worktree_id == worktree_id) else { return SmallVec::new(); }; @@ -4956,12 +5070,9 @@ impl ProjectPanel { }; let worktree = worktree.read(cx).snapshot(); - let paths = entries_paths.get_or_init(|| { - visible_worktree_entries - .iter() - .map(|e| e.path.clone()) - .collect() - }); + let paths = visible + .index + .get_or_init(|| visible.entries.iter().map(|e| e.path.clone()).collect()); let mut sticky_parents = Vec::new(); let mut current_path = entry_ref.path.clone(); @@ -4991,7 +5102,8 @@ impl ProjectPanel { let root_name = OsStr::new(worktree.root_name()); let git_summaries_by_id = if git_status_enabled { - visible_worktree_entries + visible + .entries .iter() .map(|e| (e.id, e.git_summary)) .collect::>() @@ -5089,7 +5201,7 @@ impl Render for ProjectPanel { let item_count = self .visible_entries .iter() - .map(|(_, worktree_entries, _)| worktree_entries.len()) + .map(|worktree| worktree.entries.len()) .sum(); fn handle_drag_move( @@ -5180,7 +5292,10 @@ impl Render for ProjectPanel { this.hide_scrollbar(window, cx); } })) - .on_click(cx.listener(|this, _event, _, cx| { + .on_click(cx.listener(|this, event, _, cx| { + if matches!(event, gpui::ClickEvent::Keyboard(_)) { + return; + } cx.stop_propagation(); this.selection = None; this.marked_entries.clear(); @@ -5212,6 +5327,7 @@ impl Render for ProjectPanel { .on_action(cx.listener(Self::unfold_directory)) .on_action(cx.listener(Self::fold_directory)) .on_action(cx.listener(Self::remove_from_project)) + .on_action(cx.listener(Self::compare_marked_files)) .when(!project.is_read_only(cx), |el| { el.on_action(cx.listener(Self::new_file)) .on_action(cx.listener(Self::new_directory)) @@ -5223,7 +5339,7 @@ impl Render for ProjectPanel { .on_action(cx.listener(Self::paste)) .on_action(cx.listener(Self::duplicate)) .on_click(cx.listener(|this, event: &gpui::ClickEvent, window, cx| { - if event.up.click_count > 1 { + if event.click_count() > 1 { if let Some(entry_id) = this.last_worktree_root_id { let project = this.project.read(cx); @@ -5602,6 +5718,10 @@ impl Panel for ProjectPanel { } fn starts_open(&self, _: &Window, cx: &App) -> bool { + if !ProjectPanelSettings::get_global(cx).starts_open { + return false; + } + let project = &self.project.read(cx); 
project.visible_worktrees(cx).any(|tree| { tree.read(cx) diff --git a/crates/project_panel/src/project_panel_settings.rs b/crates/project_panel/src/project_panel_settings.rs index 9057480972..8a243589ed 100644 --- a/crates/project_panel/src/project_panel_settings.rs +++ b/crates/project_panel/src/project_panel_settings.rs @@ -43,6 +43,7 @@ pub struct ProjectPanelSettings { pub sticky_scroll: bool, pub auto_reveal_entries: bool, pub auto_fold_dirs: bool, + pub starts_open: bool, pub scrollbar: ScrollbarSettings, pub show_diagnostics: ShowDiagnostics, pub hide_root: bool, @@ -139,6 +140,10 @@ pub struct ProjectPanelSettingsContent { /// /// Default: true pub auto_fold_dirs: Option, + /// Whether the project panel should open on startup. + /// + /// Default: true + pub starts_open: Option, /// Scrollbar-related settings pub scrollbar: Option, /// Which files containing diagnostic errors/warnings to mark in the project panel. diff --git a/crates/project_panel/src/project_panel_tests.rs b/crates/project_panel/src/project_panel_tests.rs index 7699256bc9..6c62c8db93 100644 --- a/crates/project_panel/src/project_panel_tests.rs +++ b/crates/project_panel/src/project_panel_tests.rs @@ -8,7 +8,7 @@ use settings::SettingsStore; use std::path::{Path, PathBuf}; use util::path; use workspace::{ - AppState, Pane, + AppState, ItemHandle, Pane, item::{Item, ProjectItem}, register_project_item, }; @@ -3068,7 +3068,7 @@ async fn test_multiple_marked_entries(cx: &mut gpui::TestAppContext) { panel.update(cx, |this, cx| { let drag = DraggedSelection { active_selection: this.selection.unwrap(), - marked_selections: Arc::new(this.marked_entries.clone()), + marked_selections: this.marked_entries.clone().into(), }; let target_entry = this .project @@ -5562,10 +5562,10 @@ async fn test_highlight_entry_for_selection_drag(cx: &mut gpui::TestAppContext) worktree_id, entry_id: child_file.id, }, - marked_selections: Arc::new(BTreeSet::from([SelectedEntry { + marked_selections: Arc::new([SelectedEntry { worktree_id, entry_id: child_file.id, - }])), + }]), }; let result = panel.highlight_entry_for_selection_drag(parent_dir, worktree, &dragged_selection, cx); @@ -5604,7 +5604,7 @@ async fn test_highlight_entry_for_selection_drag(cx: &mut gpui::TestAppContext) worktree_id, entry_id: child_file.id, }, - marked_selections: Arc::new(BTreeSet::from([ + marked_selections: Arc::new([ SelectedEntry { worktree_id, entry_id: child_file.id, @@ -5613,7 +5613,7 @@ async fn test_highlight_entry_for_selection_drag(cx: &mut gpui::TestAppContext) worktree_id, entry_id: sibling_file.id, }, - ])), + ]), }; let result = panel.highlight_entry_for_selection_drag(parent_dir, worktree, &dragged_selection, cx); @@ -5821,6 +5821,186 @@ async fn test_hide_root(cx: &mut gpui::TestAppContext) { } } +#[gpui::test] +async fn test_compare_selected_files(cx: &mut gpui::TestAppContext) { + init_test_with_editor(cx); + + let fs = FakeFs::new(cx.executor().clone()); + fs.insert_tree( + "/root", + json!({ + "file1.txt": "content of file1", + "file2.txt": "content of file2", + "dir1": { + "file3.txt": "content of file3" + } + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await; + let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let panel = workspace.update(cx, ProjectPanel::new).unwrap(); + + let file1_path = path!("root/file1.txt"); + let file2_path = path!("root/file2.txt"); + select_path_with_mark(&panel, 
file1_path, cx); + select_path_with_mark(&panel, file2_path, cx); + + panel.update_in(cx, |panel, window, cx| { + panel.compare_marked_files(&CompareMarkedFiles, window, cx); + }); + cx.executor().run_until_parked(); + + workspace + .update(cx, |workspace, _, cx| { + let active_items = workspace + .panes() + .iter() + .filter_map(|pane| pane.read(cx).active_item()) + .collect::<Vec<_>>(); + assert_eq!(active_items.len(), 1); + let diff_view = active_items + .into_iter() + .next() + .unwrap() + .downcast::<FileDiffView>() + .expect("Open item should be a FileDiffView"); + assert_eq!(diff_view.tab_content_text(0, cx), "file1.txt ↔ file2.txt"); + assert_eq!( + diff_view.tab_tooltip_text(cx).unwrap(), + format!("{} ↔ {}", file1_path, file2_path) + ); + }) + .unwrap(); + + let file1_entry_id = find_project_entry(&panel, file1_path, cx).unwrap(); + let file2_entry_id = find_project_entry(&panel, file2_path, cx).unwrap(); + let worktree_id = panel.update(cx, |panel, cx| { + panel + .project + .read(cx) + .worktrees(cx) + .next() + .unwrap() + .read(cx) + .id() + }); + + let expected_entries = [ + SelectedEntry { + worktree_id, + entry_id: file1_entry_id, + }, + SelectedEntry { + worktree_id, + entry_id: file2_entry_id, + }, + ]; + panel.update(cx, |panel, _cx| { + assert_eq!( + &panel.marked_entries, &expected_entries, + "Should keep marked entries after comparison" + ); + }); + + panel.update(cx, |panel, cx| { + panel.project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(file2_entry_id)) + }) + }); + + panel.update(cx, |panel, _cx| { + assert_eq!( + &panel.marked_entries, &expected_entries, + "Marked entries should persist after focusing back on the project panel" + ); + }); +} + +#[gpui::test] +async fn test_compare_files_context_menu(cx: &mut gpui::TestAppContext) { + init_test_with_editor(cx); + + let fs = FakeFs::new(cx.executor().clone()); + fs.insert_tree( + "/root", + json!({ + "file1.txt": "content of file1", + "file2.txt": "content of file2", + "dir1": {}, + "dir2": { + "file3.txt": "content of file3" + } + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await; + let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let panel = workspace.update(cx, ProjectPanel::new).unwrap(); + + // Test 1: When only one file is selected, there should be no compare option + select_path(&panel, "root/file1.txt", cx); + + let selected_files = panel.update(cx, |panel, cx| panel.file_abs_paths_to_diff(cx)); + assert_eq!( + selected_files, None, + "Should not have compare option when only one file is selected" + ); + + // Test 2: When multiple files are selected, there should be a compare option + select_path_with_mark(&panel, "root/file1.txt", cx); + select_path_with_mark(&panel, "root/file2.txt", cx); + + let selected_files = panel.update(cx, |panel, cx| panel.file_abs_paths_to_diff(cx)); + assert!( + selected_files.is_some(), + "Should have files selected for comparison" + ); + if let Some((file1, file2)) = selected_files { + assert!( + file1.to_string_lossy().ends_with("file1.txt") + && file2.to_string_lossy().ends_with("file2.txt"), + "Should have file1.txt and file2.txt as the selected files when multi-selecting" + ); + } + + // Test 3: Selecting a directory shouldn't count as a comparable file + select_path_with_mark(&panel, "root/dir1", cx); + + let selected_files = panel.update(cx, |panel, cx| panel.file_abs_paths_to_diff(cx)); + assert!( +
selected_files.is_some(), + "Directory selection should not affect comparable files" + ); + if let Some((file1, file2)) = selected_files { + assert!( + file1.to_string_lossy().ends_with("file1.txt") + && file2.to_string_lossy().ends_with("file2.txt"), + "Selecting a directory should not affect the number of comparable files" + ); + } + + // Test 4: Selecting one more file + select_path_with_mark(&panel, "root/dir2/file3.txt", cx); + + let selected_files = panel.update(cx, |panel, cx| panel.file_abs_paths_to_diff(cx)); + assert!( + selected_files.is_some(), + "Should still have two files selected for comparison after marking another file" + ); + if let Some((file1, file2)) = selected_files { + assert!( + file1.to_string_lossy().ends_with("file2.txt") + && file2.to_string_lossy().ends_with("file3.txt"), + "Should compare the two most recently marked files" + ); + } +} + fn select_path(panel: &Entity, path: impl AsRef, cx: &mut VisualTestContext) { let path = path.as_ref(); panel.update(cx, |panel, cx| { @@ -5855,7 +6035,7 @@ fn select_path_with_mark( entry_id, }; if !panel.marked_entries.contains(&entry) { - panel.marked_entries.insert(entry); + panel.marked_entries.push(entry); } panel.selection = Some(entry); return; diff --git a/crates/prompt_store/src/prompts.rs index d737ef9246..7eb63eec5e 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -18,7 +18,7 @@ use util::{ResultExt, get_system_shell}; use crate::UserPromptId; -#[derive(Debug, Clone, Serialize)] +#[derive(Default, Debug, Clone, Serialize)] pub struct ProjectContext { pub worktrees: Vec, /// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this. @@ -71,14 +71,14 @@ pub struct UserRulesContext { pub contents: String, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Serialize)] pub struct WorktreeContext { pub root_name: String, pub abs_path: Arc, pub rules_file: Option, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Serialize)] pub struct RulesFileContext { pub path_in_worktree: Arc, pub text: String, diff --git a/crates/proto/proto/lsp.proto index 1e693dfdf3..ea9647feff 100644 --- a/crates/proto/proto/lsp.proto +++ b/crates/proto/proto/lsp.proto @@ -525,6 +525,7 @@ message UpdateDiagnosticSummary { uint64 project_id = 1; uint64 worktree_id = 2; DiagnosticSummary summary = 3; + repeated DiagnosticSummary more_summaries = 4; } message DiagnosticSummary { @@ -818,16 +819,6 @@ message LspResponse { uint64 server_id = 7; } -message LanguageServerIdForName { - uint64 project_id = 1; - uint64 buffer_id = 2; - string name = 3; -} - -message LanguageServerIdForNameResponse { - optional uint64 server_id = 1; -} - message LspExtRunnables { uint64 project_id = 1; uint64 buffer_id = 2; diff --git a/crates/proto/proto/zed.proto index 9de5c2c0c7..bb97bd500a 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -362,9 +362,6 @@ message Envelope { GetDocumentSymbols get_document_symbols = 330; GetDocumentSymbolsResponse get_document_symbols_response = 331; - LanguageServerIdForName language_server_id_for_name = 332; - LanguageServerIdForNameResponse language_server_id_for_name_response = 333; - LoadCommitDiff load_commit_diff = 334; LoadCommitDiffResponse load_commit_diff_response = 335; @@ -424,6 +421,7 @@ message Envelope { reserved 247 to 254; reserved 255 to 256; reserved 280 to 281; +
reserved 332 to 333; } message Hello { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 4c447e2eca..9edb041b4b 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -121,8 +121,6 @@ messages!( (GetImplementationResponse, Background), (GetLlmToken, Background), (GetLlmTokenResponse, Background), - (LanguageServerIdForName, Background), - (LanguageServerIdForNameResponse, Background), (OpenUnstagedDiff, Foreground), (OpenUnstagedDiffResponse, Foreground), (OpenUncommittedDiff, Foreground), @@ -431,7 +429,6 @@ request_messages!( (UpdateWorktree, Ack), (UpdateRepository, Ack), (RemoveRepository, Ack), - (LanguageServerIdForName, LanguageServerIdForNameResponse), (LspExtExpandMacro, LspExtExpandMacroResponse), (LspExtOpenDocs, LspExtOpenDocsResponse), (LspExtRunnables, LspExtRunnablesResponse), @@ -588,7 +585,6 @@ entity_messages!( OpenServerSettings, GetPermalinkToLine, LanguageServerPromptRequest, - LanguageServerIdForName, GitGetBranches, UpdateGitBranch, ListToolchains, diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index 5dbde6496d..2093e96cae 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -141,6 +141,7 @@ impl Focusable for RecentProjects { impl Render for RecentProjects { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() + .key_context("RecentProjects") .w(rems(self.rem_width)) .child(self.picker.clone()) .on_mouse_down_out(cx.listener(|this, _, window, cx| { diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index 655e24860a..354434a7fc 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -953,7 +953,7 @@ impl RemoteServerProjects { ) .child(Label::new(project.paths.join(", "))) .on_click(cx.listener(move |this, e: &ClickEvent, window, cx| { - let secondary_confirm = e.down.modifiers.platform; + let secondary_confirm = e.modifiers().platform; callback(this, secondary_confirm, window, cx) })) .when(is_from_zed, |server_list_item| { diff --git a/crates/repl/src/notebook/notebook_ui.rs b/crates/repl/src/notebook/notebook_ui.rs index 3e96cc4d11..2efa51e0cc 100644 --- a/crates/repl/src/notebook/notebook_ui.rs +++ b/crates/repl/src/notebook/notebook_ui.rs @@ -126,29 +126,7 @@ impl NotebookEditor { let cell_count = cell_order.len(); let this = cx.entity(); - let cell_list = ListState::new( - cell_count, - gpui::ListAlignment::Top, - px(1000.), - move |ix, window, cx| { - notebook_handle - .upgrade() - .and_then(|notebook_handle| { - notebook_handle.update(cx, |notebook, cx| { - notebook - .cell_order - .get(ix) - .and_then(|cell_id| notebook.cell_map.get(cell_id)) - .map(|cell| { - notebook - .render_cell(ix, cell, window, cx) - .into_any_element() - }) - }) - }) - .unwrap_or_else(|| div().into_any()) - }, - ); + let cell_list = ListState::new(cell_count, gpui::ListAlignment::Top, px(1000.)); Self { project, @@ -544,7 +522,19 @@ impl Render for NotebookEditor { .flex_1() .size_full() .overflow_y_scroll() - .child(list(self.cell_list.clone()).size_full()), + .child(list( + self.cell_list.clone(), + cx.processor(|this, ix, window, cx| { + this.cell_order + .get(ix) + .and_then(|cell_id| this.cell_map.get(cell_id)) + .map(|cell| { + this.render_cell(ix, cell, window, cx).into_any_element() + }) + .unwrap_or_else(|| div().into_any()) + }), + )) + .size_full(), ) 
.child(self.render_notebook_controls(window, cx)) } diff --git a/crates/rope/src/rope.rs b/crates/rope/src/rope.rs index aa3ed5db57..d8ed3bfac8 100644 --- a/crates/rope/src/rope.rs +++ b/crates/rope/src/rope.rs @@ -471,11 +471,19 @@ impl<'a> FromIterator<&'a str> for Rope { } impl From for Rope { + #[inline(always)] fn from(text: String) -> Self { Rope::from(text.as_str()) } } +impl From<&String> for Rope { + #[inline(always)] + fn from(text: &String) -> Self { + Rope::from(text.as_str()) + } +} + impl fmt::Display for Rope { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for chunk in self.chunks() { diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index c1fd1df5ff..80a104641f 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -422,26 +422,8 @@ impl Peer { receiver_id: ConnectionId, request: T, ) -> impl Future> { - let request_start_time = Instant::now(); - let payload_type = T::NAME; - let elapsed_time = move || request_start_time.elapsed().as_millis(); - tracing::info!(payload_type, "start forwarding request"); self.request_internal(Some(sender_id), receiver_id, request) .map_ok(|envelope| envelope.payload) - .inspect_err(move |_| { - tracing::error!( - waiting_for_host_ms = elapsed_time(), - payload_type, - "error forwarding request" - ) - }) - .inspect_ok(move |_| { - tracing::info!( - waiting_for_host_ms = elapsed_time(), - payload_type, - "finished forwarding request" - ) - }) } fn request_internal( diff --git a/crates/semantic_index/src/project_index_debug_view.rs b/crates/semantic_index/src/project_index_debug_view.rs index 1b0d87fca0..8d6a49c45c 100644 --- a/crates/semantic_index/src/project_index_debug_view.rs +++ b/crates/semantic_index/src/project_index_debug_view.rs @@ -115,21 +115,9 @@ impl ProjectIndexDebugView { .collect::>(); this.update(cx, |this, cx| { - let view = cx.entity().downgrade(); this.selected_path = Some(PathState { path: file_path, - list_state: ListState::new( - chunks.len(), - gpui::ListAlignment::Top, - px(100.), - move |ix, _, cx| { - if let Some(view) = view.upgrade() { - view.update(cx, |view, cx| view.render_chunk(ix, cx)) - } else { - div().into_any() - } - }, - ), + list_state: ListState::new(chunks.len(), gpui::ListAlignment::Top, px(100.)), chunks, }); cx.notify(); @@ -219,7 +207,13 @@ impl Render for ProjectIndexDebugView { cx.notify(); })), ) - .child(list(selected_path.list_state.clone()).size_full()) + .child( + list( + selected_path.list_state.clone(), + cx.processor(|this, ix, _, cx| this.render_chunk(ix, cx)), + ) + .size_full(), + ) .size_full() .into_any_element() } else { diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index bc42d2c886..bfdafbffe8 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -16,6 +16,7 @@ use serde_json::{Value, json}; use smallvec::SmallVec; use std::{ any::{Any, TypeId, type_name}, + env, fmt::Debug, ops::Range, path::{Path, PathBuf}, @@ -126,6 +127,8 @@ pub struct SettingsSources<'a, T> { pub user: Option<&'a T>, /// The user settings for the current release channel. pub release_channel: Option<&'a T>, + /// The user settings for the current operating system. + pub operating_system: Option<&'a T>, /// The settings associated with an enabled settings profile pub profile: Option<&'a T>, /// The server's settings. 
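The settings_store.rs changes above and below add an operating-system layer to the settings cascade: a top-level "linux", "macos", or "windows" object in the user settings is looked up via std::env::consts::OS and merged after the release-channel overrides but before the profile, server, and project layers. A minimal sketch of that layering idea, using plain serde_json rather than Zed's SettingsStore; the recursive merge helper and the concrete setting keys here are illustrative assumptions, not the store's real merge logic:

use serde_json::{Value, json};

// Recursively overlay `overlay` onto `base`; objects merge key by key, every
// other value is replaced. This approximates how a per-OS block overrides the
// flat user settings.
fn merge(base: &mut Value, overlay: &Value) {
    match (base, overlay) {
        (Value::Object(base), Value::Object(overlay)) => {
            for (key, value) in overlay {
                merge(base.entry(key.clone()).or_insert(Value::Null), value);
            }
        }
        (base, overlay) => *base = overlay.clone(),
    }
}

fn main() {
    // Hypothetical settings.json contents; `git.inline_blame.padding` and
    // `delay_ms` are the fields touched elsewhere in this diff.
    let user_settings = json!({
        "git": { "inline_blame": { "enabled": true, "padding": 7 } },
        "linux": { "git": { "inline_blame": { "delay_ms": 500 } } }
    });

    let mut effective = user_settings.clone();
    // Only the block matching the current OS ("linux", "macos", or "windows",
    // per std::env::consts::OS) is applied, mirroring the new
    // `operating_system` slot in SettingsSources.
    if let Some(os_overrides) = user_settings.get(std::env::consts::OS) {
        merge(&mut effective, os_overrides);
    }
    println!("{}", serde_json::to_string_pretty(&effective).unwrap());
}

In the real store the per-OS block sits between the release-channel and profile layers, so a profile, server, or project setting still wins over an OS-specific value.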
@@ -147,6 +150,7 @@ impl<'a, T: Serialize> SettingsSources<'a, T> { .chain(self.extensions) .chain(self.user) .chain(self.release_channel) + .chain(self.operating_system) .chain(self.profile) .chain(self.server) .chain(self.project.iter().copied()) @@ -336,6 +340,11 @@ impl SettingsStore { .log_err(); } + let mut os_settings_value = None; + if let Some(os_settings) = &self.raw_user_settings.get(env::consts::OS) { + os_settings_value = setting_value.deserialize_setting(os_settings).log_err(); + } + let mut profile_value = None; if let Some(active_profile) = cx.try_global::() { if let Some(profiles) = self.raw_user_settings.get("profiles") { @@ -366,6 +375,7 @@ impl SettingsStore { extensions: extension_value.as_ref(), user: user_value.as_ref(), release_channel: release_channel_value.as_ref(), + operating_system: os_settings_value.as_ref(), profile: profile_value.as_ref(), server: server_value.as_ref(), project: &[], @@ -1092,7 +1102,7 @@ impl SettingsStore { "$schema": meta_schema, "title": "Zed Settings", "unevaluatedProperties": false, - // ZedSettings + settings overrides for each release stage / profiles + // ZedSettings + settings overrides for each release stage / OS / profiles "allOf": [ zed_settings_ref, { @@ -1101,6 +1111,9 @@ impl SettingsStore { "nightly": zed_settings_override_ref, "stable": zed_settings_override_ref, "preview": zed_settings_override_ref, + "linux": zed_settings_override_ref, + "macos": zed_settings_override_ref, + "windows": zed_settings_override_ref, "profiles": { "type": "object", "description": "Configures any number of settings profiles.", @@ -1164,6 +1177,13 @@ impl SettingsStore { } } + let mut os_settings = None; + if let Some(settings) = &self.raw_user_settings.get(env::consts::OS) { + if let Some(settings) = setting_value.deserialize_setting(settings).log_err() { + os_settings = Some(settings); + } + } + let mut profile_settings = None; if let Some(active_profile) = cx.try_global::() { if let Some(profiles) = self.raw_user_settings.get("profiles") { @@ -1184,6 +1204,7 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + operating_system: os_settings.as_ref(), profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &[], @@ -1237,6 +1258,7 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + operating_system: os_settings.as_ref(), profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &project_settings_stack.iter().collect::>(), @@ -1363,6 +1385,9 @@ impl AnySettingValue for SettingValue { release_channel: values .release_channel .map(|value| value.0.downcast_ref::().unwrap()), + operating_system: values + .operating_system + .map(|value| value.0.downcast_ref::().unwrap()), profile: values .profile .map(|value| value.0.downcast_ref::().unwrap()), diff --git a/crates/settings_ui/src/keybindings.rs b/crates/settings_ui/src/keybindings.rs index 70afe1729c..599bb0b18f 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/settings_ui/src/keybindings.rs @@ -374,6 +374,14 @@ impl Focusable for KeymapEditor { } } } +/// Helper function to check if two keystroke sequences match exactly +fn keystrokes_match_exactly(keystrokes1: &[Keystroke], keystrokes2: &[Keystroke]) -> bool { + keystrokes1.len() == keystrokes2.len() + && keystrokes1 + .iter() + .zip(keystrokes2) + .all(|(k1, k2)| k1.key == k2.key && k1.modifiers == 
k2.modifiers) +} impl KeymapEditor { fn new(workspace: WeakEntity, window: &mut Window, cx: &mut Context) -> Self { @@ -549,13 +557,7 @@ impl KeymapEditor { .keystrokes() .is_some_and(|keystrokes| { if exact_match { - keystroke_query.len() == keystrokes.len() - && keystroke_query.iter().zip(keystrokes).all( - |(query, keystroke)| { - query.key == keystroke.key - && query.modifiers == keystroke.modifiers - }, - ) + keystrokes_match_exactly(&keystroke_query, keystrokes) } else if keystroke_query.len() > keystrokes.len() { return false; } else { @@ -1855,7 +1857,7 @@ impl Render for KeymapEditor { .on_click(cx.listener( move |this, event: &ClickEvent, window, cx| { this.select_index(row_index, None, window, cx); - if event.up.click_count == 2 { + if event.click_count() == 2 { this.open_edit_keybinding_modal( false, window, cx, ); @@ -2340,8 +2342,50 @@ impl KeybindingEditorModal { self.save_or_display_error(cx); } - fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context) { - cx.emit(DismissEvent) + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(DismissEvent); + } + + fn get_matching_bindings_count(&self, cx: &Context) -> usize { + let current_keystrokes = self.keybind_editor.read(cx).keystrokes().to_vec(); + + if current_keystrokes.is_empty() { + return 0; + } + + self.keymap_editor + .read(cx) + .keybindings + .iter() + .enumerate() + .filter(|(idx, binding)| { + // Don't count the binding we're currently editing + if !self.creating && *idx == self.editing_keybind_idx { + return false; + } + + binding + .keystrokes() + .map(|keystrokes| keystrokes_match_exactly(keystrokes, ¤t_keystrokes)) + .unwrap_or(false) + }) + .count() + } + + fn show_matching_bindings(&mut self, _window: &mut Window, cx: &mut Context) { + let keystrokes = self.keybind_editor.read(cx).keystrokes().to_vec(); + + // Dismiss the modal + cx.emit(DismissEvent); + + // Update the keymap editor to show matching keystrokes + self.keymap_editor.update(cx, |editor, cx| { + editor.filter_state = FilterState::All; + editor.search_mode = SearchMode::KeyStroke { exact_match: true }; + editor.keystroke_editor.update(cx, |keystroke_editor, cx| { + keystroke_editor.set_keystrokes(keystrokes, cx); + }); + }); } } @@ -2356,6 +2400,7 @@ fn remove_key_char(Keystroke { modifiers, key, .. 
}: Keystroke) -> Keystroke { impl Render for KeybindingEditorModal { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let theme = cx.theme().colors(); + let matching_bindings_count = self.get_matching_bindings_count(cx); v_flex() .w(rems(34.)) @@ -2370,6 +2415,7 @@ impl Render for KeybindingEditorModal { .header( ModalHeader::new().child( v_flex() + .w_full() .pb_1p5() .mb_1() .gap_0p5() @@ -2393,17 +2439,55 @@ impl Render for KeybindingEditorModal { .section( Section::new().child( v_flex() - .gap_2() + .gap_2p5() .child( v_flex() - .child(Label::new("Edit Keystroke")) .gap_1() - .child(self.keybind_editor.clone()), + .child(Label::new("Edit Keystroke")) + .child(self.keybind_editor.clone()) + .child(h_flex().gap_px().when( + matching_bindings_count > 0, + |this| { + let label = format!( + "There {} {} {} with the same keystrokes.", + if matching_bindings_count == 1 { + "is" + } else { + "are" + }, + matching_bindings_count, + if matching_bindings_count == 1 { + "binding" + } else { + "bindings" + } + ); + + this.child( + Label::new(label) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Button::new("show_matching", "View") + .label_size(LabelSize::Small) + .icon(IconName::ArrowUpRight) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(cx.listener( + |this, _, window, cx| { + this.show_matching_bindings( + window, cx, + ); + }, + )), + ) + }, + )), ) .when_some(self.action_arguments_editor.clone(), |this, editor| { this.child( v_flex() - .mt_1p5() .gap_1() .child(Label::new("Edit Arguments")) .child(editor), @@ -2414,14 +2498,7 @@ impl Render for KeybindingEditorModal { this.child( Banner::new() .severity(error.severity) - // For some reason, the div overflows its container to the - //right. The padding accounts for that. 
- .child( - div() - .size_full() - .pr_2() - .child(Label::new(error.content.clone())), - ), + .child(Label::new(error.content.clone())), ) }), ), diff --git a/crates/settings_ui/src/ui_components/keystroke_input.rs b/crates/settings_ui/src/ui_components/keystroke_input.rs index 03d27d0ab9..ee5c4036ea 100644 --- a/crates/settings_ui/src/ui_components/keystroke_input.rs +++ b/crates/settings_ui/src/ui_components/keystroke_input.rs @@ -529,7 +529,7 @@ impl Render for KeystrokeInput { .w_full() .flex_1() .justify_between() - .rounded_lg() + .rounded_sm() .overflow_hidden() .map(|this| { if is_recording { diff --git a/crates/settings_ui/src/ui_components/table.rs b/crates/settings_ui/src/ui_components/table.rs index 3c9992bd68..2b3e815f36 100644 --- a/crates/settings_ui/src/ui_components/table.rs +++ b/crates/settings_ui/src/ui_components/table.rs @@ -248,7 +248,7 @@ impl TableInteractionState { .cursor_col_resize() .when_some(columns.clone(), |this, columns| { this.on_click(move |event, window, cx| { - if event.down.click_count >= 2 { + if event.click_count() >= 2 { columns.update(cx, |columns, _| { columns.on_double_click( column_ix, @@ -997,7 +997,7 @@ pub fn render_header( |this, (column_widths, resizables, initial_sizes)| { if resizables[header_idx].is_resizable() { this.on_click(move |event, window, cx| { - if event.down.click_count > 1 { + if event.click_count() > 1 { column_widths .update(cx, |column, _| { column.on_double_click( diff --git a/crates/snippets_ui/src/snippets_ui.rs b/crates/snippets_ui/src/snippets_ui.rs index 1cc16c5576..a8710d1672 100644 --- a/crates/snippets_ui/src/snippets_ui.rs +++ b/crates/snippets_ui/src/snippets_ui.rs @@ -149,13 +149,12 @@ impl ScopeSelectorDelegate { scope_selector: WeakEntity, language_registry: Arc, ) -> Self { - let candidates = Vec::from([GLOBAL_SCOPE_NAME.to_string()]).into_iter(); let languages = language_registry.language_names().into_iter(); - let candidates = candidates + let candidates = std::iter::once(LanguageName::new(GLOBAL_SCOPE_NAME)) .chain(languages) .enumerate() - .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, &name)) + .map(|(candidate_id, name)| StringMatchCandidate::new(candidate_id, name.as_ref())) .collect::>(); let mut existing_scopes = HashSet::new(); diff --git a/crates/supermaven/src/supermaven.rs b/crates/supermaven/src/supermaven.rs index ab500fb79d..a31b96d882 100644 --- a/crates/supermaven/src/supermaven.rs +++ b/crates/supermaven/src/supermaven.rs @@ -234,16 +234,14 @@ fn find_relevant_completion<'a>( } let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left); - let text_inserted_since_completion_request = - buffer.text_for_range(original_cursor_offset..current_cursor_offset); - let mut trimmed_completion = state_completion; - for chunk in text_inserted_since_completion_request { - if let Some(suffix) = trimmed_completion.strip_prefix(chunk) { - trimmed_completion = suffix; - } else { - continue 'completions; - } - } + let text_inserted_since_completion_request: String = buffer + .text_for_range(original_cursor_offset..current_cursor_offset) + .collect(); + let trimmed_completion = + match state_completion.strip_prefix(&text_inserted_since_completion_request) { + Some(suffix) => suffix, + None => continue 'completions, + }; if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) { continue; @@ -439,3 +437,77 @@ pub struct SupermavenCompletion { pub id: SupermavenCompletionStateId, pub updates: watch::Receiver<()>, } + +#[cfg(test)] +mod 
tests { + use super::*; + use collections::BTreeMap; + use gpui::TestAppContext; + use language::Buffer; + + #[gpui::test] + async fn test_find_relevant_completion_no_first_letter_skip(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("hello world", cx)); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let mut states = BTreeMap::new(); + let state_id = SupermavenCompletionStateId(1); + let (updates_tx, _) = watch::channel(); + + states.insert( + state_id, + SupermavenCompletionState { + buffer_id: buffer.entity_id(), + prefix_anchor: buffer_snapshot.anchor_before(0), // Start of buffer + prefix_offset: 0, + text: "hello".to_string(), + dedent: String::new(), + updates_tx, + }, + ); + + let cursor_position = buffer_snapshot.anchor_after(1); + + let result = find_relevant_completion( + &states, + buffer.entity_id(), + &buffer_snapshot, + cursor_position, + ); + + assert_eq!(result, Some("ello")); + } + + #[gpui::test] + async fn test_find_relevant_completion_with_multiple_chars(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("hello world", cx)); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let mut states = BTreeMap::new(); + let state_id = SupermavenCompletionStateId(1); + let (updates_tx, _) = watch::channel(); + + states.insert( + state_id, + SupermavenCompletionState { + buffer_id: buffer.entity_id(), + prefix_anchor: buffer_snapshot.anchor_before(0), // Start of buffer + prefix_offset: 0, + text: "hello".to_string(), + dedent: String::new(), + updates_tx, + }, + ); + + let cursor_position = buffer_snapshot.anchor_after(3); + + let result = find_relevant_completion( + &states, + buffer.entity_id(), + &buffer_snapshot, + cursor_position, + ); + + assert_eq!(result, Some("lo")); + } +} diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index 2660a03e6f..1b1fc54a7a 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -108,6 +108,14 @@ impl EditPredictionProvider for SupermavenCompletionProvider { } fn show_completions_in_menu() -> bool { + true + } + + fn show_tab_accept_marker() -> bool { + true + } + + fn supports_jump_to_edit() -> bool { false } @@ -116,7 +124,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { } fn is_refreshing(&self) -> bool { - self.pending_refresh.is_some() + self.pending_refresh.is_some() && self.completion_id.is_none() } fn refresh( @@ -197,6 +205,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { let mut point = cursor_position.to_point(&snapshot); point.column = snapshot.line_len(point.row); let range = cursor_position..snapshot.anchor_after(point); + Some(completion_from_diff( snapshot, completion_text, diff --git a/crates/terminal/src/terminal.rs b/crates/terminal/src/terminal.rs index 6e359414d7..d6a09a590f 100644 --- a/crates/terminal/src/terminal.rs +++ b/crates/terminal/src/terminal.rs @@ -63,9 +63,9 @@ use std::{ use thiserror::Error; use gpui::{ - AnyWindowHandle, App, AppContext as _, Bounds, ClipboardItem, Context, EventEmitter, Hsla, - Keystroke, Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, Point, - Rgba, ScrollWheelEvent, SharedString, Size, Task, TouchPhase, Window, actions, black, px, + App, AppContext as _, Bounds, ClipboardItem, Context, EventEmitter, Hsla, Keystroke, Modifiers, + MouseButton, MouseDownEvent, 
MouseMoveEvent, MouseUpEvent, Pixels, Point, Rgba, + ScrollWheelEvent, SharedString, Size, Task, TouchPhase, Window, actions, black, px, }; use crate::mappings::{colors::to_alac_rgb, keys::to_esc_str}; @@ -351,7 +351,7 @@ impl TerminalBuilder { alternate_scroll: AlternateScroll, max_scroll_history_lines: Option, is_ssh_terminal: bool, - window: AnyWindowHandle, + window_id: u64, completion_tx: Sender>, cx: &App, ) -> Result { @@ -463,11 +463,7 @@ impl TerminalBuilder { let term = Arc::new(FairMutex::new(term)); //Setup the pty... - let pty = match tty::new( - &pty_options, - TerminalBounds::default().into(), - window.window_id().as_u64(), - ) { + let pty = match tty::new(&pty_options, TerminalBounds::default().into(), window_id) { Ok(pty) => pty, Err(error) => { bail!(TerminalError { diff --git a/crates/terminal_view/src/persistence.rs b/crates/terminal_view/src/persistence.rs index 056365ab8c..b93b267f58 100644 --- a/crates/terminal_view/src/persistence.rs +++ b/crates/terminal_view/src/persistence.rs @@ -245,9 +245,8 @@ async fn deserialize_pane_group( let kind = TerminalKind::Shell( working_directory.as_deref().map(Path::to_path_buf), ); - let window = window.window_handle(); - let terminal = project - .update(cx, |project, cx| project.create_terminal(kind, window, cx)); + let terminal = + project.update(cx, |project, cx| project.create_terminal(kind, cx)); Some(Some(terminal)) } else { Some(None) diff --git a/crates/terminal_view/src/terminal_panel.rs b/crates/terminal_view/src/terminal_panel.rs index cb1e362884..c9528c39b9 100644 --- a/crates/terminal_view/src/terminal_panel.rs +++ b/crates/terminal_view/src/terminal_panel.rs @@ -432,10 +432,9 @@ impl TerminalPanel { }) .unwrap_or((None, None)); let kind = TerminalKind::Shell(working_directory); - let window_handle = window.window_handle(); let terminal = project .update(cx, |project, cx| { - project.create_terminal_with_venv(kind, python_venv_directory, window_handle, cx) + project.create_terminal_with_venv(kind, python_venv_directory, cx) }) .ok()?; @@ -666,13 +665,10 @@ impl TerminalPanel { "terminal not yet supported for remote projects" ))); } - let window_handle = window.window_handle(); let project = workspace.project().downgrade(); cx.spawn_in(window, async move |workspace, cx| { let terminal = project - .update(cx, |project, cx| { - project.create_terminal(kind, window_handle, cx) - })? + .update(cx, |project, cx| project.create_terminal(kind, cx))? .await?; workspace.update_in(cx, |workspace, window, cx| { @@ -709,11 +705,8 @@ impl TerminalPanel { terminal_panel.active_pane.clone() })?; let project = workspace.read_with(cx, |workspace, _| workspace.project().clone())?; - let window_handle = cx.window_handle(); let terminal = project - .update(cx, |project, cx| { - project.create_terminal(kind, window_handle, cx) - })? + .update(cx, |project, cx| project.create_terminal(kind, cx))? 
.await?; let result = workspace.update_in(cx, |workspace, window, cx| { let terminal_view = Box::new(cx.new(|cx| { @@ -814,7 +807,6 @@ impl TerminalPanel { ) -> Task>> { let reveal = spawn_task.reveal; let reveal_target = spawn_task.reveal_target; - let window_handle = window.window_handle(); let task_workspace = self.workspace.clone(); cx.spawn_in(window, async move |terminal_panel, cx| { let project = terminal_panel.update(cx, |this, cx| { @@ -823,7 +815,7 @@ impl TerminalPanel { })??; let new_terminal = project .update(cx, |project, cx| { - project.create_terminal(TerminalKind::Task(spawn_task), window_handle, cx) + project.create_terminal(TerminalKind::Task(spawn_task), cx) })? .await?; terminal_to_replace.update_in(cx, |terminal_to_replace, window, cx| { diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 2e6be5aaf4..361cdd0b1c 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -1654,7 +1654,6 @@ impl Item for TerminalView { window: &mut Window, cx: &mut Context, ) -> Option> { - let window_handle = window.window_handle(); let terminal = self .project .update(cx, |project, cx| { @@ -1666,7 +1665,6 @@ impl Item for TerminalView { project.create_terminal_with_venv( TerminalKind::Shell(working_directory), python_venv_directory, - window_handle, cx, ) }) @@ -1802,7 +1800,6 @@ impl SerializableItem for TerminalView { window: &mut Window, cx: &mut App, ) -> Task>> { - let window_handle = window.window_handle(); window.spawn(cx, async move |cx| { let cwd = cx .update(|_window, cx| { @@ -1826,7 +1823,7 @@ impl SerializableItem for TerminalView { let terminal = project .update(cx, |project, cx| { - project.create_terminal(TerminalKind::Shell(cwd), window_handle, cx) + project.create_terminal(TerminalKind::Shell(cwd), cx) })? 
.await?; cx.update(|window, cx| { diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index 68c7b2a2cd..9f7e49d24d 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -713,7 +713,7 @@ impl Buffer { let mut base_text = base_text.into(); let line_ending = LineEnding::detect(&base_text); LineEnding::normalize(&mut base_text); - Self::new_normalized(replica_id, remote_id, line_ending, Rope::from(base_text)) + Self::new_normalized(replica_id, remote_id, line_ending, Rope::from(&*base_text)) } pub fn new_normalized( diff --git a/crates/theme/src/icon_theme.rs b/crates/theme/src/icon_theme.rs index 10fd1e002d..5bd69c1733 100644 --- a/crates/theme/src/icon_theme.rs +++ b/crates/theme/src/icon_theme.rs @@ -183,6 +183,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[ ], ), ("prisma", &["prisma"]), + ("puppet", &["pp"]), ("python", &["py"]), ("r", &["r", "R"]), ("react", &["cjsx", "ctsx", "jsx", "mjsx", "mtsx", "tsx"]), @@ -331,6 +332,7 @@ const FILE_ICONS: &[(&str, &str)] = &[ ("php", "icons/file_icons/php.svg"), ("prettier", "icons/file_icons/prettier.svg"), ("prisma", "icons/file_icons/prisma.svg"), + ("puppet", "icons/file_icons/puppet.svg"), ("python", "icons/file_icons/python.svg"), ("r", "icons/file_icons/r.svg"), ("react", "icons/file_icons/react.svg"), diff --git a/crates/theme/src/settings.rs b/crates/theme/src/settings.rs index 20c837f287..6d19494f40 100644 --- a/crates/theme/src/settings.rs +++ b/crates/theme/src/settings.rs @@ -867,6 +867,7 @@ impl settings::Settings for ThemeSettings { .user .into_iter() .chain(sources.release_channel) + .chain(sources.operating_system) .chain(sources.profile) .chain(sources.server) { diff --git a/crates/theme_selector/src/icon_theme_selector.rs b/crates/theme_selector/src/icon_theme_selector.rs index 1adfc4b5d8..2d0b9480d5 100644 --- a/crates/theme_selector/src/icon_theme_selector.rs +++ b/crates/theme_selector/src/icon_theme_selector.rs @@ -40,7 +40,10 @@ impl IconThemeSelector { impl Render for IconThemeSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("IconThemeSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } diff --git a/crates/theme_selector/src/theme_selector.rs b/crates/theme_selector/src/theme_selector.rs index 022daced7a..ba8bde243b 100644 --- a/crates/theme_selector/src/theme_selector.rs +++ b/crates/theme_selector/src/theme_selector.rs @@ -92,7 +92,10 @@ impl Focusable for ThemeSelector { impl Render for ThemeSelector { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + v_flex() + .key_context("ThemeSelector") + .w(rems(34.)) + .child(self.picker.clone()) } } diff --git a/crates/title_bar/src/platform_title_bar.rs b/crates/title_bar/src/platform_title_bar.rs index 30b1b4c3f8..ef6ef93eed 100644 --- a/crates/title_bar/src/platform_title_bar.rs +++ b/crates/title_bar/src/platform_title_bar.rs @@ -106,14 +106,14 @@ impl Render for PlatformTitleBar { // Note: On Windows the title bar behavior is handled by the platform implementation. 
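The `text.rs` hunk above changes `Buffer::new` to pass `Rope::from(&*base_text)`, a `&str` reborrow of the owned `String`, and the earlier `rope.rs` hunk adds a `From<&String>` impl so borrowed strings convert directly as well. A small standalone sketch of that reborrow pattern, using plain `std` types only (editorial illustration, not part of the diff):

```rust
fn byte_len(s: &str) -> usize {
    s.len()
}

fn main() {
    let base_text = String::from("hello world");
    // `&*base_text` dereferences the String to `str` and reborrows it, so the
    // callee receives a `&str` while `base_text` stays owned and usable after
    // the call; `Rope::from(&*base_text)` relies on the same shape to avoid
    // moving the String into the conversion.
    let n = byte_len(&*base_text);
    println!("{n} bytes, still own {:?}", base_text);
}
```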
.when(self.platform_style == PlatformStyle::Mac, |this| { this.on_click(|event, window, _| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.titlebar_double_click(); } }) }) .when(self.platform_style == PlatformStyle::Linux, |this| { this.on_click(|event, window, _| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.zoom_window(); } }) diff --git a/crates/ui/src/components/banner.rs b/crates/ui/src/components/banner.rs index b16ca795b4..d88905d466 100644 --- a/crates/ui/src/components/banner.rs +++ b/crates/ui/src/components/banner.rs @@ -131,7 +131,7 @@ impl RenderOnce for Banner { impl Component for Banner { fn scope() -> ComponentScope { - ComponentScope::Notification + ComponentScope::DataDisplay } fn preview(_window: &mut Window, _cx: &mut App) -> Option { diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index 15ab00e7e5..35c78fbb5d 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -1,7 +1,8 @@ use documented::Documented; use gpui::{ AnyElement, AnyView, ClickEvent, CursorStyle, DefiniteLength, Hsla, MouseButton, - MouseDownEvent, MouseUpEvent, Rems, StyleRefinement, relative, transparent_black, + MouseClickEvent, MouseDownEvent, MouseUpEvent, Rems, StyleRefinement, relative, + transparent_black, }; use smallvec::SmallVec; @@ -620,7 +621,7 @@ impl RenderOnce for ButtonLike { MouseButton::Right, move |event, window, cx| { cx.stop_propagation(); - let click_event = ClickEvent { + let click_event = ClickEvent::Mouse(MouseClickEvent { down: MouseDownEvent { button: MouseButton::Right, position: event.position, @@ -634,7 +635,7 @@ impl RenderOnce for ButtonLike { modifiers: event.modifiers, click_count: 1, }, - }; + }); (on_right_click)(&click_event, window, cx) }, ) diff --git a/crates/ui/src/components/button/toggle_button.rs b/crates/ui/src/components/button/toggle_button.rs index 6fbf834667..91defa730b 100644 --- a/crates/ui/src/components/button/toggle_button.rs +++ b/crates/ui/src/components/button/toggle_button.rs @@ -1,6 +1,8 @@ +use std::rc::Rc; + use gpui::{AnyView, ClickEvent}; -use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, TintColor, prelude::*}; +use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, TintColor, Tooltip, prelude::*}; /// The position of a [`ToggleButton`] within a group of buttons. 
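Several UI call sites in this change move from `event.up.click_count` to an `event.click_count()` accessor, and `ButtonLike` now wraps its synthetic right-click in `ClickEvent::Mouse(MouseClickEvent { .. })`. The sketch below uses simplified stand-in types rather than gpui's real definitions, purely to show the shape of that refactor: callers ask the event for its click count instead of reaching into the underlying mouse-up event.

```rust
// Simplified stand-ins, not gpui's actual types (editorial illustration only).
struct MouseClickEvent {
    click_count: usize, // previously read by call sites as `event.up.click_count`
}

enum ClickEvent {
    Mouse(MouseClickEvent),
    // Other variants could exist; call sites that go through `click_count()`
    // keep working regardless of the concrete variant.
}

impl ClickEvent {
    fn click_count(&self) -> usize {
        match self {
            ClickEvent::Mouse(mouse) => mouse.click_count,
        }
    }
}

fn main() {
    let event = ClickEvent::Mouse(MouseClickEvent { click_count: 2 });
    if event.click_count() == 2 {
        println!("double click detected");
    }
}
```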
#[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -301,6 +303,7 @@ pub struct ButtonConfiguration { icon: Option, on_click: Box, selected: bool, + tooltip: Option AnyView>>, } mod private { @@ -315,6 +318,7 @@ pub struct ToggleButtonSimple { label: SharedString, on_click: Box, selected: bool, + tooltip: Option AnyView>>, } impl ToggleButtonSimple { @@ -326,6 +330,7 @@ impl ToggleButtonSimple { label: label.into(), on_click: Box::new(on_click), selected: false, + tooltip: None, } } @@ -333,6 +338,11 @@ impl ToggleButtonSimple { self.selected = selected; self } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } } impl private::ToggleButtonStyle for ToggleButtonSimple {} @@ -344,6 +354,7 @@ impl ButtonBuilder for ToggleButtonSimple { icon: None, on_click: self.on_click, selected: self.selected, + tooltip: self.tooltip, } } } @@ -353,6 +364,7 @@ pub struct ToggleButtonWithIcon { icon: IconName, on_click: Box, selected: bool, + tooltip: Option AnyView>>, } impl ToggleButtonWithIcon { @@ -366,6 +378,7 @@ impl ToggleButtonWithIcon { icon, on_click: Box::new(on_click), selected: false, + tooltip: None, } } @@ -373,6 +386,11 @@ impl ToggleButtonWithIcon { self.selected = selected; self } + + pub fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self { + self.tooltip = Some(Rc::new(tooltip)); + self + } } impl private::ToggleButtonStyle for ToggleButtonWithIcon {} @@ -384,6 +402,7 @@ impl ButtonBuilder for ToggleButtonWithIcon { icon: Some(self.icon), on_click: self.on_click, selected: self.selected, + tooltip: self.tooltip, } } } @@ -486,11 +505,13 @@ impl RenderOnce icon, on_click, selected, + tooltip, } = button.into_configuration(); let entry_index = row_index * COLS + col_index; ButtonLike::new((self.group_name, entry_index)) + .rounding(None) .when_some(self.tab_index, |this, tab_index| { this.tab_index(tab_index + entry_index as isize) }) @@ -498,7 +519,6 @@ impl RenderOnce this.toggle_state(true) .selected_style(ButtonStyle::Tinted(TintColor::Accent)) }) - .rounding(None) .when(self.style == ToggleButtonGroupStyle::Filled, |button| { button.style(ButtonStyle::Filled) }) @@ -527,6 +547,9 @@ impl RenderOnce |this| this.color(Color::Accent), )), ) + .when_some(tooltip, |this, tooltip| { + this.tooltip(move |window, cx| tooltip(window, cx)) + }) .on_click(on_click) .into_any_element() }) @@ -920,6 +943,23 @@ impl Component ), ], )]) + .children(vec![single_example( + "With Tooltips", + ToggleButtonGroup::single_row( + "with_tooltips", + [ + ToggleButtonSimple::new("First", |_, _, _| {}) + .tooltip(Tooltip::text("This is a tooltip. Hello!")), + ToggleButtonSimple::new("Second", |_, _, _| {}) + .tooltip(Tooltip::text("This is a tooltip. Hey?")), + ToggleButtonSimple::new("Third", |_, _, _| {}) + .tooltip(Tooltip::text("This is a tooltip. 
Get out of here now!")), + ], + ) + .selected_index(1) + .button_width(rems_from_px(100.)) + .into_any_element(), + )]) .into_any_element(), ) } diff --git a/crates/ui/src/components/callout.rs b/crates/ui/src/components/callout.rs index d15fa122ed..9c1c9fb1a9 100644 --- a/crates/ui/src/components/callout.rs +++ b/crates/ui/src/components/callout.rs @@ -158,7 +158,7 @@ impl RenderOnce for Callout { impl Component for Callout { fn scope() -> ComponentScope { - ComponentScope::Notification + ComponentScope::DataDisplay } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/context_menu.rs b/crates/ui/src/components/context_menu.rs index 77468fd295..21ab283d88 100644 --- a/crates/ui/src/components/context_menu.rs +++ b/crates/ui/src/components/context_menu.rs @@ -679,18 +679,18 @@ impl ContextMenu { let next_index = ix + 1; if self.items.len() <= next_index { self.select_first(&SelectFirst, window, cx); + return; } else { for (ix, item) in self.items.iter().enumerate().skip(next_index) { if item.is_selectable() { self.select_index(ix, window, cx); cx.notify(); - break; + return; } } } - } else { - self.select_first(&SelectFirst, window, cx); } + self.select_first(&SelectFirst, window, cx); } pub fn select_previous( @@ -1203,6 +1203,7 @@ mod tests { .separator() .separator() .entry("Last entry", None, |_, _| {}) + .header("Last header") }) }); @@ -1255,5 +1256,27 @@ mod tests { "Should go back to previous selectable entry (first)" ); }); + + context_menu.update_in(cx, |context_menu, window, cx| { + context_menu.select_first(&SelectFirst, window, cx); + assert_eq!( + Some(2), + context_menu.selected_index, + "Should start from the first selectable entry" + ); + + context_menu.select_previous(&SelectPrevious, window, cx); + assert_eq!( + Some(5), + context_menu.selected_index, + "Should wrap around to last selectable entry" + ); + context_menu.select_next(&SelectNext, window, cx); + assert_eq!( + Some(2), + context_menu.selected_index, + "Should wrap around to first selectable entry" + ); + }); } } diff --git a/crates/ui/src/components/disclosure.rs b/crates/ui/src/components/disclosure.rs index a1fab02e54..98406cd1e2 100644 --- a/crates/ui/src/components/disclosure.rs +++ b/crates/ui/src/components/disclosure.rs @@ -95,7 +95,7 @@ impl RenderOnce for Disclosure { impl Component for Disclosure { fn scope() -> ComponentScope { - ComponentScope::Navigation + ComponentScope::Input } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/image.rs b/crates/ui/src/components/image.rs index 2deba68d88..09c3bbeb94 100644 --- a/crates/ui/src/components/image.rs +++ b/crates/ui/src/components/image.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use gpui::Transformation; use gpui::{App, IntoElement, Rems, RenderOnce, Size, Styled, Window, svg}; use serde::{Deserialize, Serialize}; use strum::{EnumIter, EnumString, IntoStaticStr}; @@ -12,11 +13,13 @@ use crate::prelude::*; )] #[strum(serialize_all = "snake_case")] pub enum VectorName { - ZedLogo, - ZedXCopilot, - Grid, AiGrid, DebuggerGrid, + Grid, + ProTrialStamp, + ProUserStamp, + ZedLogo, + ZedXCopilot, } impl VectorName { @@ -37,6 +40,7 @@ pub struct Vector { path: Arc, color: Color, size: Size, + transformation: Transformation, } impl Vector { @@ -46,6 +50,7 @@ impl Vector { path: vector.path(), color: Color::default(), size: Size { width, height }, + transformation: Transformation::default(), } } @@ -66,6 +71,11 @@ impl Vector { self.size = size; self } + + pub fn transform(mut self, transformation: 
Transformation) -> Self { + self.transformation = transformation; + self + } } impl RenderOnce for Vector { @@ -81,6 +91,7 @@ impl RenderOnce for Vector { .h(height) .path(self.path) .text_color(self.color.color(cx)) + .with_transformation(self.transformation) } } diff --git a/crates/ui/src/components/list.rs b/crates/ui/src/components/list.rs index 88650b6ae8..6876f290ce 100644 --- a/crates/ui/src/components/list.rs +++ b/crates/ui/src/components/list.rs @@ -1,10 +1,12 @@ mod list; +mod list_bullet_item; mod list_header; mod list_item; mod list_separator; mod list_sub_header; pub use list::*; +pub use list_bullet_item::*; pub use list_header::*; pub use list_item::*; pub use list_separator::*; diff --git a/crates/ui/src/components/list/list_bullet_item.rs b/crates/ui/src/components/list/list_bullet_item.rs new file mode 100644 index 0000000000..6e079d9f11 --- /dev/null +++ b/crates/ui/src/components/list/list_bullet_item.rs @@ -0,0 +1,40 @@ +use crate::{ListItem, prelude::*}; +use gpui::{IntoElement, ParentElement, SharedString}; + +#[derive(IntoElement)] +pub struct ListBulletItem { + label: SharedString, +} + +impl ListBulletItem { + pub fn new(label: impl Into) -> Self { + Self { + label: label.into(), + } + } +} + +impl RenderOnce for ListBulletItem { + fn render(self, window: &mut Window, _cx: &mut App) -> impl IntoElement { + let line_height = 0.85 * window.line_height(); + + ListItem::new("list-item") + .selectable(false) + .child( + h_flex() + .w_full() + .min_w_0() + .gap_1() + .items_start() + .child( + h_flex().h(line_height).justify_center().child( + Icon::new(IconName::Dash) + .size(IconSize::XSmall) + .color(Color::Hidden), + ), + ) + .child(div().w_full().min_w_0().child(Label::new(self.label))), + ) + .into_any_element() + } +} diff --git a/crates/ui/src/components/scrollbar.rs b/crates/ui/src/components/scrollbar.rs index 7af55b76b7..605028202f 100644 --- a/crates/ui/src/components/scrollbar.rs +++ b/crates/ui/src/components/scrollbar.rs @@ -1,11 +1,20 @@ -use std::{any::Any, cell::Cell, fmt::Debug, ops::Range, rc::Rc, sync::Arc}; +use std::{ + any::Any, + cell::{Cell, RefCell}, + fmt::Debug, + ops::Range, + rc::Rc, + sync::Arc, + time::Duration, +}; use crate::{IntoElement, prelude::*, px, relative}; use gpui::{ Along, App, Axis as ScrollbarAxis, BorderStyle, Bounds, ContentMask, Corners, CursorStyle, Edges, Element, ElementId, Entity, EntityId, GlobalElementId, Hitbox, HitboxBehavior, Hsla, IsZero, LayoutId, ListState, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, - Point, ScrollHandle, ScrollWheelEvent, Size, Style, UniformListScrollHandle, Window, quad, + Point, ScrollHandle, ScrollWheelEvent, Size, Style, Task, UniformListScrollHandle, Window, + quad, }; pub struct Scrollbar { @@ -108,6 +117,25 @@ pub struct ScrollbarState { thumb_state: Rc>, parent_id: Option, scroll_handle: Arc, + auto_hide: Rc>, +} + +#[derive(Debug)] +enum AutoHide { + Disabled, + Hidden { + parent_id: EntityId, + }, + Visible { + parent_id: EntityId, + _task: Task<()>, + }, +} + +impl AutoHide { + fn is_hidden(&self) -> bool { + matches!(self, AutoHide::Hidden { .. 
}) + } } impl ScrollbarState { @@ -116,6 +144,7 @@ impl ScrollbarState { thumb_state: Default::default(), parent_id: None, scroll_handle: Arc::new(scroll), + auto_hide: Rc::new(RefCell::new(AutoHide::Disabled)), } } @@ -174,6 +203,38 @@ impl ScrollbarState { let thumb_percentage_end = (start_offset + thumb_size) / viewport_size; Some(thumb_percentage_start..thumb_percentage_end) } + + fn show_temporarily(&self, parent_id: EntityId, cx: &mut App) { + const SHOW_INTERVAL: Duration = Duration::from_secs(1); + + let auto_hide = self.auto_hide.clone(); + auto_hide.replace(AutoHide::Visible { + parent_id, + _task: cx.spawn({ + let this = auto_hide.clone(); + async move |cx| { + cx.background_executor().timer(SHOW_INTERVAL).await; + this.replace(AutoHide::Hidden { parent_id }); + cx.update(|cx| { + cx.notify(parent_id); + }) + .ok(); + } + }), + }); + } + + fn unhide(&self, position: &Point, cx: &mut App) { + let parent_id = match &*self.auto_hide.borrow() { + AutoHide::Disabled => return, + AutoHide::Hidden { parent_id } => *parent_id, + AutoHide::Visible { parent_id, _task } => *parent_id, + }; + + if self.scroll_handle().viewport().contains(position) { + self.show_temporarily(parent_id, cx); + } + } } impl Scrollbar { @@ -189,6 +250,14 @@ impl Scrollbar { let thumb = state.thumb_range(kind)?; Some(Self { thumb, state, kind }) } + + /// Automatically hide the scrollbar when idle + pub fn auto_hide(self, cx: &mut Context) -> Self { + if matches!(*self.state.auto_hide.borrow(), AutoHide::Disabled) { + self.state.show_temporarily(cx.entity_id(), cx); + } + self + } } impl Element for Scrollbar { @@ -284,16 +353,18 @@ impl Element for Scrollbar { .apply_along(axis.invert(), |width| width / 1.5), ); - let corners = Corners::all(thumb_bounds.size.along(axis.invert()) / 2.0); + if thumb_state.is_dragging() || !self.state.auto_hide.borrow().is_hidden() { + let corners = Corners::all(thumb_bounds.size.along(axis.invert()) / 2.0); - window.paint_quad(quad( - thumb_bounds, - corners, - thumb_background, - Edges::default(), - Hsla::transparent_black(), - BorderStyle::default(), - )); + window.paint_quad(quad( + thumb_bounds, + corners, + thumb_background, + Edges::default(), + Hsla::transparent_black(), + BorderStyle::default(), + )); + } if thumb_state.is_dragging() { window.set_window_cursor_style(CursorStyle::Arrow); @@ -361,13 +432,18 @@ impl Element for Scrollbar { }); window.on_mouse_event({ + let state = self.state.clone(); let scroll_handle = self.state.scroll_handle().clone(); - move |event: &ScrollWheelEvent, phase, window, _| { - if phase.bubble() && bounds.contains(&event.position) { - let current_offset = scroll_handle.offset(); - scroll_handle.set_offset( - current_offset + event.delta.pixel_delta(window.line_height()), - ); + move |event: &ScrollWheelEvent, phase, window, cx| { + if phase.bubble() { + state.unhide(&event.position, cx); + + if bounds.contains(&event.position) { + let current_offset = scroll_handle.offset(); + scroll_handle.set_offset( + current_offset + event.delta.pixel_delta(window.line_height()), + ); + } } } }); @@ -376,6 +452,8 @@ impl Element for Scrollbar { let state = self.state.clone(); move |event: &MouseMoveEvent, phase, window, cx| { if phase.bubble() { + state.unhide(&event.position, cx); + match state.thumb_state.get() { ThumbState::Dragging(drag_state) if event.dragging() => { let scroll_handle = state.scroll_handle(); diff --git a/crates/ui/src/components/tab.rs b/crates/ui/src/components/tab.rs index a205c33358..d704846a68 100644 --- 
a/crates/ui/src/components/tab.rs +++ b/crates/ui/src/components/tab.rs @@ -179,7 +179,7 @@ impl RenderOnce for Tab { impl Component for Tab { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::Navigation } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/components/toggle.rs b/crates/ui/src/components/toggle.rs index 53df4767b0..4b985fd2c2 100644 --- a/crates/ui/src/components/toggle.rs +++ b/crates/ui/src/components/toggle.rs @@ -504,15 +504,12 @@ impl RenderOnce for Switch { let group_id = format!("switch_group_{:?}", self.id); - let switch = h_flex() - .w(DynamicSpacing::Base32.rems(cx)) - .h(DynamicSpacing::Base20.rems(cx)) - .group(group_id.clone()) - .border_1() + let switch = div() + .id((self.id.clone(), "switch")) .p(px(1.0)) + .border_2() .border_color(cx.theme().colors().border_transparent) .rounded_full() - .id((self.id.clone(), "switch")) .when_some( self.tab_index.filter(|_| !self.disabled), |this, tab_index| { @@ -524,23 +521,29 @@ impl RenderOnce for Switch { ) .child( h_flex() - .when(is_on, |on| on.justify_end()) - .when(!is_on, |off| off.justify_start()) - .size_full() - .rounded_full() - .px(DynamicSpacing::Base02.px(cx)) - .bg(bg_color) - .when(!self.disabled, |this| { - this.group_hover(group_id.clone(), |el| el.bg(bg_hover_color)) - }) - .border_1() - .border_color(border_color) + .w(DynamicSpacing::Base32.rems(cx)) + .h(DynamicSpacing::Base20.rems(cx)) + .group(group_id.clone()) .child( - div() - .size(DynamicSpacing::Base12.rems(cx)) + h_flex() + .when(is_on, |on| on.justify_end()) + .when(!is_on, |off| off.justify_start()) + .size_full() .rounded_full() - .bg(thumb_color) - .opacity(thumb_opacity), + .px(DynamicSpacing::Base02.px(cx)) + .bg(bg_color) + .when(!self.disabled, |this| { + this.group_hover(group_id.clone(), |el| el.bg(bg_hover_color)) + }) + .border_1() + .border_color(border_color) + .child( + div() + .size(DynamicSpacing::Base12.rems(cx)) + .rounded_full() + .bg(thumb_color) + .opacity(thumb_opacity), + ), ), ); diff --git a/crates/ui/src/styles/animation.rs b/crates/ui/src/styles/animation.rs index 0649bee1f8..ee5352d454 100644 --- a/crates/ui/src/styles/animation.rs +++ b/crates/ui/src/styles/animation.rs @@ -99,7 +99,7 @@ struct Animation {} impl Component for Animation { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::Utilities } fn description() -> Option<&'static str> { diff --git a/crates/ui/src/styles/color.rs b/crates/ui/src/styles/color.rs index c7b995d39a..586b2ccc57 100644 --- a/crates/ui/src/styles/color.rs +++ b/crates/ui/src/styles/color.rs @@ -126,7 +126,7 @@ impl From for Color { impl Component for Color { fn scope() -> ComponentScope { - ComponentScope::None + ComponentScope::Utilities } fn description() -> Option<&'static str> { diff --git a/crates/ui_input/src/ui_input.rs b/crates/ui_input/src/ui_input.rs index 309b3f62f6..1a5bebaf1e 100644 --- a/crates/ui_input/src/ui_input.rs +++ b/crates/ui_input/src/ui_input.rs @@ -168,7 +168,7 @@ impl Render for SingleLineInput { .py_1p5() .flex_grow() .text_color(style.text_color) - .rounded_lg() + .rounded_sm() .bg(style.background_color) .border_1() .border_color(style.border_color) diff --git a/crates/util/src/archive.rs b/crates/util/src/archive.rs index d10b996716..3e4d281c29 100644 --- a/crates/util/src/archive.rs +++ b/crates/util/src/archive.rs @@ -2,6 +2,8 @@ use std::path::Path; use anyhow::{Context as _, Result}; use async_zip::base::read; +#[cfg(not(windows))] +use futures::AsyncSeek; use futures::{AsyncRead, 
io::BufReader}; #[cfg(windows)] @@ -62,7 +64,15 @@ pub async fn extract_zip(destination: &Path, reader: R) -> futures::io::copy(&mut BufReader::new(reader), &mut file) .await .context("saving archive contents into the temporary file")?; - let mut reader = read::seek::ZipFileReader::new(BufReader::new(file)) + extract_seekable_zip(destination, file).await +} + +#[cfg(not(windows))] +pub async fn extract_seekable_zip( + destination: &Path, + reader: R, +) -> Result<()> { + let mut reader = read::seek::ZipFileReader::new(BufReader::new(reader)) .await .context("reading the zip archive")?; let destination = &destination diff --git a/crates/util/src/fs.rs b/crates/util/src/fs.rs index 2738b6e213..3e96594f85 100644 --- a/crates/util/src/fs.rs +++ b/crates/util/src/fs.rs @@ -95,9 +95,9 @@ pub async fn move_folder_files_to_folder>( #[cfg(unix)] /// Set the permissions for the given path so that the file becomes executable. /// This is a noop for non-unix platforms. -pub async fn make_file_executable(path: &PathBuf) -> std::io::Result<()> { +pub async fn make_file_executable(path: &Path) -> std::io::Result<()> { fs::set_permissions( - &path, + path, ::from_mode(0o755), ) .await @@ -107,6 +107,6 @@ pub async fn make_file_executable(path: &PathBuf) -> std::io::Result<()> { #[allow(clippy::unused_async)] /// Set the permissions for the given path so that the file becomes executable. /// This is a noop for non-unix platforms. -pub async fn make_file_executable(_path: &PathBuf) -> std::io::Result<()> { +pub async fn make_file_executable(_path: &Path) -> std::io::Result<()> { Ok(()) } diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 2062255f4b..a9e7304e47 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -62,7 +62,7 @@ pub struct SelectedEntry { #[derive(Debug)] pub struct DraggedSelection { pub active_selection: SelectedEntry, - pub marked_selections: Arc>, + pub marked_selections: Arc<[SelectedEntry]>, } impl DraggedSelection { @@ -2945,7 +2945,7 @@ impl Pane { this.handle_external_paths_drop(paths, window, cx) })) .on_click(cx.listener(move |this, event: &ClickEvent, window, cx| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.dispatch_action( this.double_click_dispatch_action.boxed_clone(), cx, @@ -3640,7 +3640,7 @@ impl Render for Pane { .justify_center() .on_click(cx.listener( move |this, event: &ClickEvent, window, cx| { - if event.up.click_count == 2 { + if event.click_count() == 2 { window.dispatch_action( this.double_click_dispatch_action.boxed_clone(), cx, diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 63953ff802..aab8a36f45 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -1086,6 +1086,7 @@ pub struct Workspace { follower_states: HashMap, last_leaders_by_pane: HashMap, CollaboratorId>, window_edited: bool, + last_window_title: Option, dirty_items: HashMap, active_call: Option<(Entity, Vec)>, leader_updates_tx: mpsc::UnboundedSender<(PeerId, proto::UpdateFollowers)>, @@ -1418,6 +1419,7 @@ impl Workspace { last_leaders_by_pane: Default::default(), dispatching_keystrokes: Default::default(), window_edited: false, + last_window_title: None, dirty_items: Default::default(), active_call, database_id: workspace_id, @@ -1813,10 +1815,7 @@ impl Workspace { .max_by(|b1, b2| b1.worktree_id.cmp(&b2.worktree_id)) }); - match latest_project_path_opened { - Some(latest_project_path_opened) => latest_project_path_opened == history_path, - None 
=> true, - } + latest_project_path_opened.map_or(true, |path| path == history_path) }) } @@ -4406,7 +4405,13 @@ impl Workspace { title.push_str(" ↗"); } + if let Some(last_title) = self.last_window_title.as_ref() { + if &title == last_title { + return; + } + } window.set_window_title(&title); + self.last_window_title = Some(title); } fn update_window_edited(&mut self, window: &mut Window, cx: &mut App) { @@ -4796,7 +4801,7 @@ impl Workspace { .remote_id(&self.app_state.client, window, cx) .map(|id| id.to_proto()); - if let Some(id) = id.clone() { + if let Some(id) = id { if let Some(variant) = item.to_state_proto(window, cx) { let view = Some(proto::View { id: id.clone(), @@ -4809,7 +4814,7 @@ impl Workspace { update = proto::UpdateActiveView { view, // TODO: Remove after version 0.145.x stabilizes. - id: id.clone(), + id, leader_id: leader_peer_id, }; } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 5bd6d981fa..5997e43864 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." edition.workspace = true name = "zed" -version = "0.199.0" +version = "0.200.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team "] diff --git a/crates/zed/resources/windows/app-icon-nightly.ico b/crates/zed/resources/windows/app-icon-nightly.ico index 15e06a6e17..165e4ce1f7 100644 Binary files a/crates/zed/resources/windows/app-icon-nightly.ico and b/crates/zed/resources/windows/app-icon-nightly.ico differ diff --git a/crates/zed/src/reliability.rs b/crates/zed/src/reliability.rs index ed149a470a..53539699cc 100644 --- a/crates/zed/src/reliability.rs +++ b/crates/zed/src/reliability.rs @@ -149,6 +149,7 @@ pub fn init_panic_hook( let timestamp = chrono::Utc::now().format("%Y_%m_%d %H_%M_%S").to_string(); let panic_file_path = paths::logs_dir().join(format!("zed-{timestamp}.panic")); let panic_file = fs::OpenOptions::new() + .write(true) .create_new(true) .open(&panic_file_path) .log_err(); @@ -553,6 +554,10 @@ async fn upload_previous_panics( .log_err(); } + if MINIDUMP_ENDPOINT.is_none() { + return Ok(most_recent_panic); + } + // loop back over the directory again to upload any minidumps that are missing panics let mut children = smol::fs::read_dir(paths::logs_dir()).await?; while let Some(child) = children.next().await { @@ -597,11 +602,12 @@ async fn upload_minidump( ) .text("platform", "rust"); if let Some(panic) = panic { - form = form.text( - "release", - format!("{}-{}", panic.release_channel, panic.app_version), - ); - // TODO: tack on more fields + form = form + .text( + "sentry[release]", + format!("{}-{}", panic.release_channel, panic.app_version), + ) + .text("sentry[logentry][formatted]", panic.payload.clone()); } let mut response_text = String::new(); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index ec62ed33fd..8c89a7d85a 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -4439,7 +4439,7 @@ mod tests { }); for name in languages.language_names() { languages - .language_for_name(&name) + .language_for_name(name.as_ref()) .await .with_context(|| format!("language name {name}")) .unwrap(); diff --git a/crates/zed/src/zed/component_preview.rs b/crates/zed/src/zed/component_preview.rs index 480505338b..db75b544f6 100644 --- a/crates/zed/src/zed/component_preview.rs +++ b/crates/zed/src/zed/component_preview.rs @@ -107,6 +107,7 @@ struct ComponentPreview { active_thread: Option>, reset_key: usize, component_list: ListState, + entries: Vec, component_map: 
HashMap, components: Vec, cursor_index: usize, @@ -172,17 +173,6 @@ impl ComponentPreview { sorted_components.len(), gpui::ListAlignment::Top, px(1500.0), - { - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - this.update(cx, |this, cx| { - let component = this.get_component(ix); - this.render_preview(&component, window, cx) - .into_any_element() - }) - .unwrap() - } - }, ); let mut component_preview = Self { @@ -190,6 +180,7 @@ impl ComponentPreview { active_thread: None, reset_key: 0, component_list, + entries: Vec::new(), component_map: component_registry.component_map(), components: sorted_components, cursor_index: selected_index, @@ -276,10 +267,6 @@ impl ComponentPreview { cx.notify(); } - fn get_component(&self, ix: usize) -> ComponentMetadata { - self.components[ix].clone() - } - fn filtered_components(&self) -> Vec { if self.filter_text.is_empty() { return self.components.clone(); @@ -420,7 +407,6 @@ impl ComponentPreview { fn update_component_list(&mut self, cx: &mut Context) { let entries = self.scope_ordered_entries(); let new_len = entries.len(); - let weak_entity = cx.entity().downgrade(); if new_len > 0 { self.nav_scroll_handle @@ -446,56 +432,9 @@ impl ComponentPreview { } } - self.component_list = ListState::new( - filtered_components.len(), - gpui::ListAlignment::Top, - px(1500.0), - { - let components = filtered_components.clone(); - let this = cx.entity().downgrade(); - move |ix, window: &mut Window, cx: &mut App| { - if ix >= components.len() { - return div().w_full().h_0().into_any_element(); - } + self.component_list = ListState::new(new_len, gpui::ListAlignment::Top, px(1500.0)); + self.entries = entries; - this.update(cx, |this, cx| { - let component = &components[ix]; - this.render_preview(component, window, cx) - .into_any_element() - }) - .unwrap() - } - }, - ); - - let new_list = ListState::new( - new_len, - gpui::ListAlignment::Top, - px(1500.0), - move |ix, window, cx| { - if ix >= entries.len() { - return div().w_full().h_0().into_any_element(); - } - - let entry = &entries[ix]; - - weak_entity - .update(cx, |this, cx| match entry { - PreviewEntry::Component(component, _) => this - .render_preview(component, window, cx) - .into_any_element(), - PreviewEntry::SectionHeader(shared_string) => this - .render_scope_header(ix, shared_string.clone(), window, cx) - .into_any_element(), - PreviewEntry::AllComponents => div().w_full().h_0().into_any_element(), - PreviewEntry::ActiveThread => div().w_full().h_0().into_any_element(), - PreviewEntry::Separator => div().w_full().h_0().into_any_element(), - }) - .unwrap() - }, - ); - - self.component_list = new_list; cx.emit(ItemEvent::UpdateTab); } @@ -672,10 +611,35 @@ impl ComponentPreview { .child(format!("No components matching '{}'.", self.filter_text)) .into_any_element() } else { - list(self.component_list.clone()) - .flex_grow() - .with_sizing_behavior(gpui::ListSizingBehavior::Auto) - .into_any_element() + list( + self.component_list.clone(), + cx.processor(|this, ix, window, cx| { + if ix >= this.entries.len() { + return div().w_full().h_0().into_any_element(); + } + + let entry = &this.entries[ix]; + + match entry { + PreviewEntry::Component(component, _) => this + .render_preview(component, window, cx) + .into_any_element(), + PreviewEntry::SectionHeader(shared_string) => this + .render_scope_header(ix, shared_string.clone(), window, cx) + .into_any_element(), + PreviewEntry::AllComponents => { + div().w_full().h_0().into_any_element() + } + PreviewEntry::ActiveThread => { 
+ div().w_full().h_0().into_any_element() + } + PreviewEntry::Separator => div().w_full().h_0().into_any_element(), + } + }), + ) + .flex_grow() + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .into_any_element() }, ) } diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md index 04646213e6..8fdb7ea325 100644 --- a/docs/src/ai/llm-providers.md +++ b/docs/src/ai/llm-providers.md @@ -14,25 +14,25 @@ You can add your API key to a given provider either via the Agent Panel's settin Here's all the supported LLM providers for which you can use your own API keys: -| Provider | Tool Use Supported | -| ----------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [Amazon Bedrock](#amazon-bedrock) | Depends on the model | -| [Anthropic](#anthropic) | ✅ | -| [DeepSeek](#deepseek) | ✅ | -| [GitHub Copilot Chat](#github-copilot-chat) | For some models ([link](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198)) | -| [Google AI](#google-ai) | ✅ | -| [LM Studio](#lmstudio) | ✅ | -| [Mistral](#mistral) | ✅ | -| [Ollama](#ollama) | ✅ | -| [OpenAI](#openai) | ✅ | -| [OpenAI API Compatible](#openai-api-compatible) | ✅ | -| [OpenRouter](#openrouter) | ✅ | -| [Vercel](#vercel-v0) | ✅ | -| [xAI](#xai) | ✅ | +| Provider | +| ----------------------------------------------- | +| [Amazon Bedrock](#amazon-bedrock) | +| [Anthropic](#anthropic) | +| [DeepSeek](#deepseek) | +| [GitHub Copilot Chat](#github-copilot-chat) | +| [Google AI](#google-ai) | +| [LM Studio](#lmstudio) | +| [Mistral](#mistral) | +| [Ollama](#ollama) | +| [OpenAI](#openai) | +| [OpenAI API Compatible](#openai-api-compatible) | +| [OpenRouter](#openrouter) | +| [Vercel](#vercel-v0) | +| [xAI](#xai) | ### Amazon Bedrock {#amazon-bedrock} -> ✅ Supports tool use with models that support streaming tool use. +> Supports tool use with models that support streaming tool use. > More details can be found in the [Amazon Bedrock's Tool Use documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html). To use Amazon Bedrock's models, an AWS authentication is required. @@ -107,8 +107,6 @@ For the most up-to-date supported regions and models, refer to the [Supported Mo ### Anthropic {#anthropic} -> ✅ Supports tool use - You can use Anthropic models by choosing them via the model dropdown in the Agent Panel. 1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys) @@ -165,8 +163,6 @@ You can configure a model to use [extended thinking](https://docs.anthropic.com/ ### DeepSeek {#deepseek} -> ✅ Supports tool use - 1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys) 2. Open the settings view (`agent: open settings`) and go to the DeepSeek section 3. Enter your DeepSeek API key @@ -208,9 +204,6 @@ You can also modify the `api_url` to use a custom endpoint if needed. ### GitHub Copilot Chat {#github-copilot-chat} -> ✅ Supports tool use in some cases. -> Visit [the Copilot Chat code](https://github.com/zed-industries/zed/blob/9e0330ba7d848755c9734bf456c716bddf0973f3/crates/language_models/src/provider/copilot_chat.rs#L189-L198) for the supported subset. 
- You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. 1. Open the settings view (`agent: open settings`) and go to the GitHub Copilot Chat section @@ -224,8 +217,6 @@ To use Copilot Enterprise with Zed (for both agent and completions), you must co ### Google AI {#google-ai} -> ✅ Supports tool use - You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. 1. Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). @@ -266,8 +257,6 @@ Custom models will be listed in the model dropdown in the Agent Panel. ### LM Studio {#lmstudio} -> ✅ Supports tool use - 1. Download and install [the latest version of LM Studio](https://lmstudio.ai/download) 2. In the app press `cmd/ctrl-shift-m` and download at least one model (e.g., qwen2.5-coder-7b). Alternatively, you can get models via the LM Studio CLI: @@ -285,8 +274,6 @@ Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless# ### Mistral {#mistral} -> ✅ Supports tool use - 1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) 2. Open the configuration view (`agent: open settings`) and navigate to the Mistral section 3. Enter your Mistral API key @@ -326,8 +313,6 @@ Custom models will be listed in the model dropdown in the Agent Panel. ### Ollama {#ollama} -> ✅ Supports tool use - Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`. 1. Download one of the [available models](https://ollama.com/models), for example, for `mistral`: @@ -395,8 +380,6 @@ If the model is tagged with `vision` in the Ollama catalog, set this option and ### OpenAI {#openai} -> ✅ Supports tool use - 1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) 2. Make sure that your OpenAI account has credits 3. Open the settings view (`agent: open settings`) and go to the OpenAI section @@ -473,8 +456,6 @@ So, ensure you have it set in your environment variables (`OPENAI_API_KEY= ✅ Supports tool use - OpenRouter provides access to multiple AI models through a single API. It supports tool use for compatible models. 1. Visit [OpenRouter](https://openrouter.ai) and create an account @@ -531,8 +512,6 @@ Custom models will be listed in the model dropdown in the Agent Panel. ### Vercel v0 {#vercel-v0} -> ✅ Supports tool use - [Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel. It supports text and image inputs and provides fast streaming responses. @@ -545,8 +524,6 @@ You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. ### xAI {#xai} -> ✅ Supports tool use - Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. 1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index 5fd27abad6..1996e1c4ee 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -1275,6 +1275,18 @@ Each option controls displaying of a particular toolbar element. If all elements `boolean` values +## Status Bar + +- Description: Control various elements in the status bar. Note that some items in the status bar have their own settings set elsewhere. 
+- Setting: `status_bar` +- Default: + +```json +"status_bar": { + "active_language_button": true, +}, +``` + ## LSP - Description: Configuration for language servers. @@ -1795,7 +1807,6 @@ Example: { "git": { "inline_blame": { - "enabled": true, "delay_ms": 500 } } @@ -1808,7 +1819,6 @@ Example: { "git": { "inline_blame": { - "enabled": true, "show_commit_summary": true } } @@ -1821,13 +1831,24 @@ Example: { "git": { "inline_blame": { - "enabled": true, "min_column": 80 } } } ``` +5. Set the padding between the end of the line and the inline blame hint, in ems: + +```json +{ + "git": { + "inline_blame": { + "padding": 10 + } + } +} +``` + ### Hunk Style - Description: What styling we should use for the diff hunks. @@ -3204,7 +3225,8 @@ Run the `theme selector: toggle` action in the command palette to see a current "indent_guides": { "show": "always" }, - "hide_root": false + "hide_root": false, + "starts_open": true } } ``` diff --git a/docs/src/development/local-collaboration.md b/docs/src/development/local-collaboration.md index 9f0e3ef191..eb7f3dfc43 100644 --- a/docs/src/development/local-collaboration.md +++ b/docs/src/development/local-collaboration.md @@ -1,13 +1,27 @@ # Local Collaboration -First, make sure you've installed Zed's dependencies for your platform: +1. Ensure you have access to our cloud infrastructure. If you don't have access, you can't collaborate locally at this time. -- [macOS](./macos.md#backend-dependencies) -- [Linux](./linux.md#backend-dependencies) -- [Windows](./windows.md#backend-dependencies) +2. Make sure you've installed Zed's dependencies for your platform: + +- [macOS](#macos) +- [Linux](#linux) +- [Windows](#backend-windows) Note that `collab` can be compiled only with MSVC toolchain on Windows +3. Clone down our cloud repository and follow the instructions in the cloud README + +4. Setup the local database for your platform: + +- [macOS & Linux](#database-unix) +- [Windows](#database-windows) + +5. Run collab: + +- [macOS & Linux](#run-collab-unix) +- [Windows](#run-collab-windows) + ## Backend Dependencies If you are developing collaborative features of Zed, you'll need to install the dependencies of zed's `collab` server: @@ -18,7 +32,7 @@ If you are developing collaborative features of Zed, you'll need to install the You can install these dependencies natively or run them under Docker. -### MacOS +### macOS 1. Install [Postgres.app](https://postgresapp.com) or [postgresql via homebrew](https://formulae.brew.sh/formula/postgresql@15): @@ -76,7 +90,7 @@ docker compose up -d Before you can run the `collab` server locally, you'll need to set up a `zed` Postgres database. 
diff --git a/docs/src/development/local-collaboration.md b/docs/src/development/local-collaboration.md
index 9f0e3ef191..eb7f3dfc43 100644
--- a/docs/src/development/local-collaboration.md
+++ b/docs/src/development/local-collaboration.md
@@ -1,13 +1,27 @@
 # Local Collaboration
 
-First, make sure you've installed Zed's dependencies for your platform:
+1. Ensure you have access to our cloud infrastructure. If you don't have access, you can't collaborate locally at this time.
 
-- [macOS](./macos.md#backend-dependencies)
-- [Linux](./linux.md#backend-dependencies)
-- [Windows](./windows.md#backend-dependencies)
+2. Make sure you've installed Zed's dependencies for your platform:
+
+- [macOS](#macos)
+- [Linux](#linux)
+- [Windows](#backend-windows)
 
 Note that `collab` can be compiled only with MSVC toolchain on Windows
 
+3. Clone down our cloud repository and follow the instructions in the cloud README
+
+4. Set up the local database for your platform:
+
+- [macOS & Linux](#database-unix)
+- [Windows](#database-windows)
+
+5. Run collab:
+
+- [macOS & Linux](#run-collab-unix)
+- [Windows](#run-collab-windows)
+
 ## Backend Dependencies
 
 If you are developing collaborative features of Zed, you'll need to install the dependencies of zed's `collab` server:
@@ -18,7 +32,7 @@ If you are developing collaborative features of Zed, you'll need to install the
 
 You can install these dependencies natively or run them under Docker.
 
-### MacOS
+### macOS
 
 1. Install [Postgres.app](https://postgresapp.com) or [postgresql via homebrew](https://formulae.brew.sh/formula/postgresql@15):
@@ -76,7 +90,7 @@ docker compose up -d
 
 Before you can run the `collab` server locally, you'll need to set up a `zed` Postgres database.
 
-### On macOS and Linux
+### On macOS and Linux {#database-unix}
 
 ```sh
 script/bootstrap
@@ -99,7 +113,7 @@ To use a different set of admin users, you can create your own version of that j
 }
 ```
 
-### On Windows
+### On Windows {#database-windows}
 
 ```powershell
 .\script\bootstrap.ps1
@@ -107,7 +121,7 @@ To use a different set of admin users, you can create your own version of that j
 
 ## Testing collaborative features locally
 
-### On macOS and Linux
+### On macOS and Linux {#run-collab-unix}
 
 Ensure that Postgres is configured and running, then run Zed's collaboration server and the `livekit` dev server:
@@ -117,12 +131,16 @@ foreman start
 docker compose up
 ```
 
-Alternatively, if you're not testing voice and screenshare, you can just run `collab`, and not the `livekit` dev server:
+Alternatively, if you're not testing voice and screenshare, you can just run `collab` and `cloud`, and not the `livekit` dev server:
 
 ```sh
 cargo run -p collab -- serve all
 ```
 
+```sh
+cd ../cloud; cargo make dev
+```
+
 In a new terminal, run two or more instances of Zed.
 
 ```sh
@@ -131,7 +149,7 @@ script/zed-local -3
 
 This script starts one to four instances of Zed, depending on the `-2`, `-3` or `-4` flags. Each instance will be connected to the local `collab` server, signed in as a different user from `.admins.json` or `.admins.default.json`.
 
-### On Windows
+### On Windows {#run-collab-windows}
 
 Since `foreman` is not available on Windows, you can run the following commands in separate terminals:
@@ -151,6 +169,12 @@ Otherwise,
 .\path\to\livekit-serve.exe --dev
 ```
 
+You'll also need to start the cloud server:
+
+```powershell
+cd ..\cloud; cargo make dev
+```
+
 In a new terminal, run two or more instances of Zed.
 
 ```powershell
@@ -161,7 +185,10 @@ Note that this requires `node.exe` to be in your `PATH`.
 
 ## Running a local collab server
 
-If you want to run your own version of the zed collaboration service, you can, but note that this is still under development, and there is no good support for authentication nor extensions.
+> [!NOTE]
+> Because of recent changes to our authentication system, Zed will not be able to authenticate itself with, and therefore use, a local collab server.
+
+If you want to run your own version of the zed collaboration service, you can, but note that this is still under development, and there is no support for authentication or extensions.
 
 Configuration is done through environment variables. By default it will read the configuration from [`.env.toml`](https://github.com/zed-industries/zed/blob/main/crates/collab/.env.toml) and you should use that as a guide for setting this up.
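Putting the Unix commands above together, a minimal sketch of the local test loop without `livekit`; it assumes the cloud repository is checked out next to the zed checkout (as the `../cloud` path implies) and that Postgres is already set up:

```sh
# Terminal 1: the collab server (no livekit, so no voice or screenshare)
cargo run -p collab -- serve all

# Terminal 2: the cloud server, from the sibling checkout
cd ../cloud; cargo make dev

# Terminal 3: two Zed instances signed in as different admin users
script/zed-local -2
```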
+ "active_language_button": true, + }, +``` + ### Multibuffer ```json diff --git a/extensions/emmet/extension.toml b/extensions/emmet/extension.toml index 99aa80a2d4..9fa14d091f 100644 --- a/extensions/emmet/extension.toml +++ b/extensions/emmet/extension.toml @@ -9,7 +9,7 @@ repository = "https://github.com/zed-industries/zed" [language_servers.emmet-language-server] name = "Emmet Language Server" language = "HTML" -languages = ["HTML", "PHP", "ERB", "HTML/ERB", "JavaScript", "TSX", "CSS", "HEEX", "Elixir"] +languages = ["HTML", "PHP", "ERB", "HTML/ERB", "JavaScript", "TSX", "CSS", "HEEX", "Elixir", "Vue.js"] [language_servers.emmet-language-server.language_ids] "HTML" = "html" @@ -21,3 +21,4 @@ languages = ["HTML", "PHP", "ERB", "HTML/ERB", "JavaScript", "TSX", "CSS", "HEEX "CSS" = "css" "HEEX" = "heex" "Elixir" = "heex" +"Vue.js" = "vue" diff --git a/script/install.sh b/script/install.sh index 9cd21119b7..feb140c984 100755 --- a/script/install.sh +++ b/script/install.sh @@ -9,7 +9,12 @@ main() { platform="$(uname -s)" arch="$(uname -m)" channel="${ZED_CHANNEL:-stable}" - temp="$(mktemp -d "/tmp/zed-XXXXXX")" + # Use TMPDIR if available (for environments with non-standard temp directories) + if [ -n "${TMPDIR:-}" ] && [ -d "${TMPDIR}" ]; then + temp="$(mktemp -d "$TMPDIR/zed-XXXXXX")" + else + temp="$(mktemp -d "/tmp/zed-XXXXXX")" + fi if [ "$platform" = "Darwin" ]; then platform="macos" diff --git a/script/lib/deploy-helpers.sh b/script/lib/deploy-helpers.sh index c0feb2f861..bd7b3c4d6f 100644 --- a/script/lib/deploy-helpers.sh +++ b/script/lib/deploy-helpers.sh @@ -5,7 +5,7 @@ function export_vars_for_environment { echo "Invalid environment name '${environment}'" >&2 exit 1 fi - export $(cat $env_file) + export $(grep -v '^#' $env_file | grep -v '^[[:space:]]*$') } function target_zed_kube_cluster { diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index 5678e46236..338985ed95 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -305,7 +305,7 @@ scopeguard = { version = "1" } security-framework = { version = "3", features = ["OSX_10_14"] } security-framework-sys = { version = "2", features = ["OSX_10_14"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } @@ -334,7 +334,7 @@ scopeguard = { version = "1" } security-framework = { version = "3", features = ["OSX_10_14"] } security-framework-sys = { version = "2", features = ["OSX_10_14"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } @@ -362,7 +362,7 @@ scopeguard = { version = "1" } security-framework = { version = "3", features = ["OSX_10_14"] } security-framework-sys = { version = "2", features = ["OSX_10_14"] } sync_wrapper = { version = "1", 
diff --git a/script/lib/deploy-helpers.sh b/script/lib/deploy-helpers.sh
index c0feb2f861..bd7b3c4d6f 100644
--- a/script/lib/deploy-helpers.sh
+++ b/script/lib/deploy-helpers.sh
@@ -5,7 +5,7 @@ function export_vars_for_environment {
     echo "Invalid environment name '${environment}'" >&2
     exit 1
   fi
-  export $(cat $env_file)
+  export $(grep -v '^#' $env_file | grep -v '^[[:space:]]*$')
 }
 
 function target_zed_kube_cluster {
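The updated `export` line above drops comment and blank lines before exporting; a small self-contained sketch of that behavior (the file name and variable names are made up for illustration):

```sh
# A made-up environment file containing a comment and a blank line
cat > /tmp/example.env <<'EOF'
# deployment target
EXAMPLE_NAMESPACE=production

EXAMPLE_IMAGE_TAG=v0.1.0
EOF

# Same pipeline as the updated helper: strip comments and blank lines, then export
export $(grep -v '^#' /tmp/example.env | grep -v '^[[:space:]]*$')

echo "$EXAMPLE_NAMESPACE $EXAMPLE_IMAGE_TAG"   # -> production v0.1.0
```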
["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } @@ -568,7 +568,7 @@ ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } scopeguard = { version = "1" } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } @@ -592,7 +592,7 @@ ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } scopeguard = { version = "1" } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } @@ -636,7 +636,7 @@ rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", features = ["fs", scopeguard = { version = "1" } syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } @@ -675,7 +675,7 @@ rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["ev rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", features = ["fs", "net", "process", "termios", "time"] } scopeguard = { version = "1" } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] }